Skip to content

Commit

Permalink
Take superclasses into account when looking up main method
Browse files Browse the repository at this point in the history
  • Loading branch information
geoand committed Sep 29, 2023
1 parent af5b725 commit ca49710
Show file tree
Hide file tree
Showing 4 changed files with 246 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -518,80 +518,142 @@ public MainMethodTransformer(IndexView index) {

@Override
public ClassVisitor apply(String mainClassName, ClassVisitor outputClassVisitor) {
ClassInfo classByName = index.getClassByName(mainClassName);
if (classByName == null) {
ClassInfo mainClassInfo = index.getClassByName(mainClassName);
if (mainClassInfo == null) {
throw new IllegalStateException("mainClassName should have a corresponding ClassInfo at this point");
}
ClassTransformer transformer = new ClassTransformer(mainClassName);
ForStringArgsResult forStringArgsResult = applyForStringArgs(mainClassName, outputClassVisitor, transformer,
mainClassInfo, true);
if (forStringArgsResult.classVisitor != null) {
return forStringArgsResult.classVisitor;
}
return applyForNoArgs(mainClassName, outputClassVisitor, transformer, mainClassInfo, forStringArgsResult.state,
true);
}

MethodInfo withStringArgs = classByName.method("main", STRING_ARRAY);
StringArgsPassState state = StringArgsPassState.NO_METHOD;
private ForStringArgsResult applyForStringArgs(String originalMainClassName, ClassVisitor originalVisitor,
ClassTransformer transformer,
ClassInfo currentClassInfo, boolean isTopLevel) {
MethodInfo withStringArgs = currentClassInfo.method("main", STRING_ARRAY);
StringArgsPassState withArgsCheckState = StringArgsPassState.NO_METHOD;

if (withStringArgs != null) {
short modifiers = withStringArgs.flags();
if (Modifier.isStatic(modifiers)) {
if (Modifier.isPublic(modifiers)) {
// nothing to do
state = StringArgsPassState.EXIT;
} else {
// this is the simplest case where we just make the method public
transformer.modifyMethod(MethodDescriptor.of(withStringArgs)).removeModifiers(Modifier.PROTECTED)
.addModifiers(Modifier.PUBLIC);
state = StringArgsPassState.NO_MORE_ACTIONS_NEEDED;
if (isTopLevel) {
if (Modifier.isPublic(modifiers)) {
// nothing to do
withArgsCheckState = StringArgsPassState.EXIT;
} else {
// this is the simplest case where we just make the method public
transformer.modifyMethod(MethodDescriptor.of(withStringArgs)).removeModifiers(Modifier.PROTECTED)
.addModifiers(Modifier.PUBLIC);
withArgsCheckState = StringArgsPassState.NO_MORE_ACTIONS_NEEDED;
}
}
} else {
if (Modifier.isPrivate(modifiers)) {
state = StringArgsPassState.HAS_PRIVATE_MAIN;
withArgsCheckState = StringArgsPassState.HAS_PRIVATE_MAIN;
} else {
// here we need to construct an instance and call the instance method with the args parameter
MethodCreator standardMain = createStandardMain(transformer);
ResultHandle instanceHandle = standardMain.newInstance(ofConstructor(mainClassName));
standardMain.invokeVirtualMethod(ofMethod(mainClassName, "$$main$$", void.class, String[].class),
instanceHandle, standardMain.getMethodParam(0));
ResultHandle instanceHandle = standardMain.newInstance(ofConstructor(originalMainClassName));
ResultHandle argsParamHandle = standardMain.getMethodParam(0);
if (isTopLevel) {
// we need to rename the method in order to avoid having two main methods with the same name
standardMain.invokeVirtualMethod(
ofMethod(originalMainClassName, "$$main$$", void.class, String[].class),
instanceHandle, argsParamHandle);

transformer.modifyMethod(MethodDescriptor.of(withStringArgs)).rename("$$main$$");
} else {
// Invoke super
standardMain.invokeSpecialMethod(withStringArgs, instanceHandle, argsParamHandle);
}
withArgsCheckState = StringArgsPassState.NO_MORE_ACTIONS_NEEDED;
standardMain.returnValue(null);
transformer.modifyMethod(MethodDescriptor.of(withStringArgs)).rename("$$main$$");
state = StringArgsPassState.NO_MORE_ACTIONS_NEEDED;
}
}
}

if (state == StringArgsPassState.EXIT) {
return outputClassVisitor;
} else if (state == StringArgsPassState.NO_MORE_ACTIONS_NEEDED) {
return transformer.applyTo(outputClassVisitor);
if (withArgsCheckState == StringArgsPassState.EXIT) {
// no transformations were necessary, so just make the result a pass-through
return new ForStringArgsResult(originalVisitor, withArgsCheckState);
} else if (withArgsCheckState == StringArgsPassState.NO_MORE_ACTIONS_NEEDED) {
// no more transformations are needed, so just set the result
return new ForStringArgsResult(transformer.applyTo(originalVisitor), withArgsCheckState);
} else {
DotName superName = currentClassInfo.superName();
if (superName.equals(OBJECT)) {
return new ForStringArgsResult(null, withArgsCheckState);
}
ClassInfo superClassInfo = getSuperClassInfo(originalMainClassName, currentClassInfo);
return applyForStringArgs(originalMainClassName, originalVisitor, transformer, superClassInfo, false);
}
}

private static MethodCreator createStandardMain(ClassTransformer transformer) {
return transformer.addMethod("main", void.class, String[].class)
.setModifiers(Modifier.PUBLIC | Modifier.STATIC);
}

private ClassVisitor applyForNoArgs(String originalMainClassName, ClassVisitor originalVisitor,
ClassTransformer transformer,
ClassInfo currentClassInfo,
StringArgsPassState withArgsCheckState, boolean allowStatic) {

MethodInfo withoutArgs = classByName.method("main");
boolean hasValidNoArgsMethod = true;
MethodInfo withoutArgs = currentClassInfo.method("main");
if (withoutArgs == null) {
if (state == StringArgsPassState.HAS_PRIVATE_MAIN) {
throw new IllegalStateException("Main method on class '" + mainClassName + "' cannot be private");
if (withArgsCheckState == StringArgsPassState.HAS_PRIVATE_MAIN) {
throw new IllegalStateException(
"Main method on class '" + originalMainClassName + "' cannot be private");
} else {
throw new IllegalStateException("Unable to find main method on class '" + mainClassName + "'");
hasValidNoArgsMethod = false;
}
} else {
short modifiers = withoutArgs.flags();
if (Modifier.isPrivate(modifiers)) {
throw new IllegalStateException("Main method on class '" + mainClassName + "' cannot be private");
throw new IllegalStateException(
"Main method on class '" + originalMainClassName + "' cannot be private");
} else {
MethodCreator standardMain = createStandardMain(transformer);
if (Modifier.isStatic(modifiers)) {
// call the static main without any parameters
standardMain.invokeStaticMethod(MethodDescriptor.of(withoutArgs));
if (allowStatic) {
// call the static main without any parameters
standardMain.invokeStaticMethod(MethodDescriptor.of(withoutArgs));
}
} else {
// here we need to construct an instance and call the instance method without any parameters
ResultHandle instanceHandle = standardMain.newInstance(ofConstructor(mainClassName));
ResultHandle instanceHandle = standardMain.newInstance(ofConstructor(originalMainClassName));
standardMain.invokeVirtualMethod(MethodDescriptor.of(withoutArgs), instanceHandle);
}
standardMain.returnValue(null);
}
}

return transformer.applyTo(outputClassVisitor);
if (hasValidNoArgsMethod) {
return transformer.applyTo(originalVisitor);
} else {
ClassInfo superClassInfo = getSuperClassInfo(originalMainClassName, currentClassInfo);
return applyForNoArgs(originalMainClassName, originalVisitor, transformer, superClassInfo, withArgsCheckState,
false);
}
}

private static MethodCreator createStandardMain(ClassTransformer transformer) {
return transformer.addMethod("main", void.class, String[].class)
.setModifiers(Modifier.PUBLIC | Modifier.STATIC);
private ClassInfo getSuperClassInfo(String originalMainClassName, ClassInfo currentClassInfo) {
DotName superName = currentClassInfo.superName();
if (superName.equals(OBJECT)) {
// no valid main method was found, so we need to fail
throw new IllegalStateException("Unable to find main method on class '" + originalMainClassName + "'");
}
ClassInfo superClassInfo = index.getClassByName(superName);
if (superClassInfo == null) {
throw new IllegalStateException("Unable to find main method on class '" + originalMainClassName
+ "' while it was also not possible to traverse the class hierarchy");
}
return superClassInfo;
}

enum StringArgsPassState {
Expand All @@ -601,6 +663,16 @@ enum StringArgsPassState {
HAS_PRIVATE_MAIN,
REQUIRES_RENAME
}

private static class ForStringArgsResult {
private final ClassVisitor classVisitor;
private final StringArgsPassState state;

public ForStringArgsResult(ClassVisitor classVisitor, StringArgsPassState state) {
this.classVisitor = classVisitor;
this.state = state;
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io.quarkus.commandmode;

import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.runtime.annotations.QuarkusMain;
import io.quarkus.test.QuarkusProdModeTest;

public class InstanceMainInSuperClassCommandModeTestCase {
@RegisterExtension
static final QuarkusProdModeTest config = new QuarkusProdModeTest()
.withApplicationRoot((jar) -> jar
.addClasses(HelloWorldSuperSuper.class, HelloWorldSuper.class, HelloWorldMain.class))
.setApplicationName("run-exit")
.setApplicationVersion("0.1-SNAPSHOT")
.setExpectExit(true)
.setRun(true);

@Test
public void testRun() {
Assertions.assertThat(config.getStartupConsoleOutput()).contains("Hello World");
Assertions.assertThat(config.getExitCode()).isEqualTo(0);
}

@QuarkusMain
public static class HelloWorldMain extends HelloWorldSuper {

}

public static class HelloWorldSuperSuper {

protected void main() {
System.out.println("Hello World");
}
}

public static class HelloWorldSuper extends HelloWorldSuperSuper {

protected void main2() {
System.out.println("Hello");
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package io.quarkus.commandmode;

import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.runtime.annotations.QuarkusMain;
import io.quarkus.test.QuarkusProdModeTest;

public class InstanceMainInSuperClassNoArgsCommandModeTestCase {
@RegisterExtension
static final QuarkusProdModeTest config = new QuarkusProdModeTest()
.withApplicationRoot((jar) -> jar
.addClasses(HelloWorldSuper.class, HelloWorldMain.class))
.setApplicationName("run-exit")
.setApplicationVersion("0.1-SNAPSHOT")
.setExpectExit(true)
.setRun(true);

@Test
public void testRun() {
Assertions.assertThat(config.getStartupConsoleOutput()).contains("Hello World");
Assertions.assertThat(config.getExitCode()).isEqualTo(0);
}

@QuarkusMain
public static class HelloWorldMain extends HelloWorldSuper {

}

public static class HelloWorldSuper {

void main() {
System.out.println("Hello World");
}

void main2() {
System.out.println("Hello");
}

void main3(String[] args) {
System.out.println("Hello");
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package io.quarkus.commandmode;

import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.runtime.annotations.QuarkusMain;
import io.quarkus.test.QuarkusProdModeTest;

public class MultipleInstanceMainInSuperClassCommandModeTestCase {
@RegisterExtension
static final QuarkusProdModeTest config = new QuarkusProdModeTest()
.withApplicationRoot((jar) -> jar
.addClasses(HelloWorldSuperSuper.class, HelloWorldSuper.class, HelloWorldMain.class))
.setApplicationName("run-exit")
.setApplicationVersion("0.1-SNAPSHOT")
.setExpectExit(true)
.setRun(true);

@Test
public void testRun() {
Assertions.assertThat(config.getStartupConsoleOutput()).contains("Hi World");
Assertions.assertThat(config.getExitCode()).isEqualTo(0);
}

@QuarkusMain
public static class HelloWorldMain extends HelloWorldSuper {

}

public static class HelloWorldSuperSuper {

protected void main(String[] args) {
System.out.println("Hi World");
}

protected void main() {
System.out.println("Hello World");
}
}

public static class HelloWorldSuper extends HelloWorldSuperSuper {

protected void main2() {
System.out.println("Hello");
}
}

}

0 comments on commit ca49710

Please sign in to comment.