1 package org.kite9.framework.classloading;
2
3 import java.io.ByteArrayOutputStream;
4 import java.io.IOException;
5 import java.net.URL;
6 import java.net.URLConnection;
7 import java.util.HashMap;
8 import java.util.HashSet;
9 import java.util.Map;
10 import java.util.Set;
11 import java.util.zip.ZipEntry;
12 import java.util.zip.ZipInputStream;
13
14 import org.kite9.framework.common.RepositoryHelp;
15
/**
 * Class loader that reads the contents of one or more jar files into memory
 * and defines classes from those in-memory copies. A single, special
 * 'override class' is never delegated to the system or parent class loader:
 * it is always defined by this loader, provided one of the jars contains it.
 * <p>
 * For every other class the lookup order is: classes already loaded by this
 * loader, then the system class loader, then the parent class loader, and
 * finally the mapped jars.
 *
 * @author robmoffat
 */
public class OverrideJarClassLoader extends ClassLoader {

    /** Fully-qualified name of the class that must be defined by this loader; may be {@code null}. */
    String overrideClass;

    /** When {@code true}, {@link #log(String)} writes diagnostics to {@code System.err}. */
    boolean logging;

    /** The jar locations read by {@link #init()}. */
    URL[] urls;

    /**
     * Entry name (e.g. {@code a/b/C.class}) to the byte array holding the
     * concatenated entry contents of the jar it came from. All entries of one
     * jar share the same array; later jars win when entry names collide.
     */
    Map<String, byte[]> memoryMappedFiles = new HashMap<String, byte[]>();

    /** Entry name to the offset of that entry's bytes within its jar's array. */
    Map<String, Integer> startPosition = new HashMap<String, Integer>();

    /** Entry name to the number of bytes the entry occupies in its jar's array. */
    Map<String, Integer> lengths = new HashMap<String, Integer>();

    /**
     * Creates the loader and immediately reads every jar in {@code urls} into memory.
     *
     * @param urls          jar files to read; each is opened once, during construction
     * @param parent        parent class loader, consulted after the system class loader
     * @param overrideClass fully-qualified name of the class that must always be defined
     *                      by this loader, or {@code null} for none
     * @param log           whether to write diagnostic messages to {@code System.err}
     * @throws IOException  if any jar cannot be opened or read
     */
    public OverrideJarClassLoader(URL[] urls, ClassLoader parent, String overrideClass, boolean log)
            throws IOException {
        super(parent);
        this.urls = urls;
        this.overrideClass = overrideClass;
        this.logging = log;
        // init() must run only after every field is set: it logs (which reads
        // 'logging'), and it is protected, so an override may read the other fields.
        init();
    }

    /**
     * Loads a class, preferring delegation for everything except the override class.
     *
     * @param name    fully-qualified binary class name
     * @param resolve whether to link the class before returning it
     * @return the loaded class
     * @throws ClassNotFoundException if no source can supply the class, or if the
     *                                mapped bytes are not a valid class file
     */
    @Override
    public Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
        // NOTE(review): unlike ClassLoader.loadClass, this override takes no
        // class-loading lock. Concurrent first loads of the same jar class could
        // attempt a duplicate defineClass. Confirm callers are single-threaded.
        Class<?> clazz = findLoadedClass(name);
        if (clazz != null) {
            return clazz;
        }

        // Null-safe: a loader built without an override class delegates everything.
        boolean isOverride = overrideClass != null && overrideClass.equals(name);
        if (!isOverride) {
            clazz = tryLoad(getSystemClassLoader(), name);
            if (clazz != null) {
                return finishLoad(clazz, resolve, "System ClassLoader");
            }

            ClassLoader parent = getParent();
            if (parent != null) {
                clazz = tryLoad(parent, name);
                if (clazz != null) {
                    return finishLoad(clazz, resolve, "Parent ClassLoader");
                }
            }
        }

        if (memoryMappedFiles.containsKey(convertToFile(name))) {
            // defineClass never returns null, so no null check is needed here.
            return finishLoad(findClass(name), resolve, "OverrideJarClassLoader");
        }

        throw new ClassNotFoundException(name);
    }

    /**
     * Asks {@code loader} for {@code name}.
     *
     * @return the class, or {@code null} if that loader cannot find it
     */
    private static Class<?> tryLoad(ClassLoader loader, String name) {
        try {
            return loader.loadClass(name);
        } catch (ClassNotFoundException e) {
            // Deliberately ignored: the caller falls through to the next source.
            return null;
        }
    }

    /** Optionally resolves {@code clazz}, logs which source supplied it, and returns it. */
    private Class<?> finishLoad(Class<?> clazz, boolean resolve, String source) {
        if (resolve) {
            resolveClass(clazz);
        }
        log("Returning (" + source + ") " + clazz + " with class loader: " + clazz.getClassLoader());
        return clazz;
    }

    /**
     * Reads every jar in {@link #urls} fully into memory, recording each
     * entry's offset and length within its jar's byte array.
     *
     * @throws IOException if a jar cannot be opened or read
     */
    protected void init() throws IOException {
        for (URL url : urls) {
            mapJar(url);
        }
    }

    /** Copies all entries of one jar into a single byte array and indexes them by entry name. */
    private void mapJar(URL url) throws IOException {
        Set<String> entryNames = new HashSet<String>(500);
        ByteArrayOutputStream contents = new ByteArrayOutputStream(20000);

        URLConnection connection = url.openConnection();
        ZipInputStream zis = new ZipInputStream(connection.getInputStream());
        try {
            for (ZipEntry ze = zis.getNextEntry(); ze != null; ze = zis.getNextEntry()) {
                String entryName = ze.getName();
                int start = contents.size();
                copyEntry(zis, contents);
                startPosition.put(entryName, start);
                lengths.put(entryName, contents.size() - start);
                entryNames.add(entryName);
            }
        } finally {
            // Close even if reading an entry fails part-way through.
            zis.close();
        }

        byte[] jarBytes = contents.toByteArray();
        for (String entryName : entryNames) {
            memoryMappedFiles.put(entryName, jarBytes);
        }

        log("mapped file: " + url);
    }

    /** Appends the remaining bytes of the current zip entry to {@code out}; leaves {@code in} open. */
    private static void copyEntry(ZipInputStream in, ByteArrayOutputStream out) throws IOException {
        byte[] buffer = new byte[8192];
        int read;
        // ZipInputStream.read returns -1 at the end of the current entry only.
        while ((read = in.read(buffer)) != -1) {
            out.write(buffer, 0, read);
        }
    }

    /**
     * Defines {@code name} from the in-memory jar contents.
     *
     * @param name fully-qualified binary class name
     * @return the newly defined class
     * @throws ClassNotFoundException if no mapped jar contains the class file,
     *                                or its bytes are not a valid class
     */
    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
        log("Finding: " + name);
        String path = convertToFile(name);
        byte[] jarBytes = memoryMappedFiles.get(path);
        if (jarBytes == null) {
            throw new ClassNotFoundException("Could not load class, not found in any jar: " + path);
        }

        int start = startPosition.get(path);
        int length = lengths.get(path);
        try {
            return defineClass(name, jarBytes, start, length);
        } catch (ClassFormatError e) {
            throw new ClassNotFoundException("Could not load class: " + name, e);
        }
    }

    /** Converts a binary class name such as {@code a.b.C} to its jar entry path {@code a/b/C.class}. */
    private static String convertToFile(String name) {
        return name.replace('.', '/').concat(".class");
    }

    /** @return whether diagnostic logging is enabled */
    public boolean isLogging() {
        return logging;
    }

    /** @param logging whether diagnostic messages should be written to {@code System.err} */
    public void setLogging(boolean logging) {
        this.logging = logging;
    }

    /** Writes {@code l} to {@code System.err}, prefixed with {@code "OJCL: "}, when logging is enabled. */
    public void log(String l) {
        if (logging) {
            System.err.println("OJCL: " + l);
        }
    }
}