Skip to content

Commit

Permalink
[api] Fixes loading BlockFactory bug (#1547)
Browse files Browse the repository at this point in the history
Change-Id: I9fe68ae6af4be85bcd8a626cfa28534f8d141798
  • Loading branch information
frankfliu authored Mar 30, 2022
1 parent ff35da9 commit dbdabb9
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
5 changes: 3 additions & 2 deletions api/src/main/java/ai/djl/repository/zoo/BaseModelLoader.java
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ protected Model createModel(
Model model = Model.newInstance(name, device, engine);
if (block == null) {
String className = (String) arguments.get("blockFactory");
BlockFactory factory = ClassLoaderUtils.findImplementation(modelPath, className);
BlockFactory factory =
ClassLoaderUtils.findImplementation(modelPath, BlockFactory.class, className);
if (factory != null) {
block = factory.newBlock(model, modelPath, arguments);
}
Expand Down Expand Up @@ -229,7 +230,7 @@ protected TranslatorFactory getTranslatorFactory(
String factoryClass = (String) arguments.get("translatorFactory");
if (factoryClass != null) {
ClassLoader cl = ClassLoaderUtils.getContextClassLoader();
factory = ClassLoaderUtils.initClass(cl, factoryClass);
factory = ClassLoaderUtils.initClass(cl, TranslatorFactory.class, factoryClass);
}
return factory;
}
Expand Down
24 changes: 13 additions & 11 deletions api/src/main/java/ai/djl/util/ClassLoaderUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ private ClassLoaderUtils() {}
* <p>For .class file, this function expects them in classes/your/package/ClassName.class
*
* @param path the path to scan from
* @param type the type of the class
* @param className the name of the classes, pass null if name is unknown
* @param <T> the Template T for the output Class
* @return the Class implementation
*/
public static <T> T findImplementation(Path path, String className) {
public static <T> T findImplementation(Path path, Class<T> type, String className) {
try {
Path classesDir = path.resolve("classes");
// we only consider .class files and skip .java files
Expand All @@ -75,16 +76,16 @@ public static <T> T findImplementation(Path path, String className) {
(PrivilegedAction<ClassLoader>)
() -> new URLClassLoader(urls, contextCl));
if (className != null && !className.isEmpty()) {
return initClass(cl, className);
return initClass(cl, type, className);
}

T implemented = scanDirectory(cl, classesDir);
T implemented = scanDirectory(cl, type, classesDir);
if (implemented != null) {
return implemented;
}

for (Path p : jarFiles) {
implemented = scanJarFile(cl, p);
implemented = scanJarFile(cl, type, p);
if (implemented != null) {
return implemented;
}
Expand All @@ -95,7 +96,7 @@ public static <T> T findImplementation(Path path, String className) {
return null;
}

private static <T> T scanDirectory(ClassLoader cl, Path dir) throws IOException {
private static <T> T scanDirectory(ClassLoader cl, Class<T> type, Path dir) throws IOException {
if (!Files.isDirectory(dir)) {
logger.trace("Directory not exists: {}", dir);
return null;
Expand All @@ -109,15 +110,15 @@ private static <T> T scanDirectory(ClassLoader cl, Path dir) throws IOException
String className = p.toString();
className = className.substring(0, className.lastIndexOf('.'));
className = className.replace(File.separatorChar, '.');
T implemented = initClass(cl, className);
T implemented = initClass(cl, type, className);
if (implemented != null) {
return implemented;
}
}
return null;
}

private static <T> T scanJarFile(ClassLoader cl, Path path) throws IOException {
private static <T> T scanJarFile(ClassLoader cl, Class<T> type, Path path) throws IOException {
try (JarFile jarFile = new JarFile(path.toFile())) {
Enumeration<JarEntry> en = jarFile.entries();
while (en.hasMoreElements()) {
Expand All @@ -126,7 +127,7 @@ private static <T> T scanJarFile(ClassLoader cl, Path path) throws IOException {
if (fileName.endsWith(".class")) {
fileName = fileName.substring(0, fileName.lastIndexOf('.'));
fileName = fileName.replace('/', '.');
T implemented = initClass(cl, fileName);
T implemented = initClass(cl, type, fileName);
if (implemented != null) {
return implemented;
}
Expand All @@ -140,15 +141,16 @@ private static <T> T scanJarFile(ClassLoader cl, Path path) throws IOException {
* Loads the specified class and constructs an instance.
*
* @param cl the {@code ClassLoader} to use
* @param type the type of the class
* @param className the class to be loaded
* @param <T> the type of the class
* @return an instance of the class, null if the class not found
*/
@SuppressWarnings("unchecked")
public static <T> T initClass(ClassLoader cl, String className) {
public static <T> T initClass(ClassLoader cl, Class<T> type, String className) {
try {
Class<?> clazz = Class.forName(className, true, cl);
Constructor<T> constructor = (Constructor<T>) clazz.getConstructor();
Class<? extends T> sub = clazz.asSubclass(type);
Constructor<? extends T> constructor = sub.getConstructor();
return constructor.newInstance();
} catch (Throwable e) {
logger.trace("Not able to load Object", e);
Expand Down

0 comments on commit dbdabb9

Please sign in to comment.