Skip to content

Commit

Permalink
reformat: some manual newline removals & improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Janmm14 committed May 28, 2022
1 parent 4167796 commit 2687779
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 314 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,17 @@ public class InvokedynamicTransformer extends Transformer<TransformerConfig> {
@Override
public boolean transform() {
System.out.println("[AntiReleak] [InvokedynamicTransformer] Starting");
System.out.println(
"[AntiReleak] [InvokedynamicTransformer] Finding invokedynamic instructions");
System.out.println("[AntiReleak] [InvokedynamicTransformer] Finding invokedynamic instructions");
int amount = findInvokeDynamic(classNodes());
System.out.println("[AntiReleak] [InvokedynamicTransformer] Found "
+ amount + " invokedynamic instructions");
System.out.println("[AntiReleak] [InvokedynamicTransformer] Found " + amount + " invokedynamic instructions");
if (amount > 0) {
System.out.println(
"[AntiReleak] [InvokedynamicTransformer] Inlining invokedynamic");
System.out.println("[AntiReleak] [InvokedynamicTransformer] Inlining invokedynamic");
long start = System.currentTimeMillis();
int inlined = inlineInvokeDynamic(amount);
long end = System.currentTimeMillis();
System.out.println("[AntiReleak] [InvokedynamicTransformer] Removed "
+ inlined + " invokedynamic instructions, took "
System.out.println("[AntiReleak] [InvokedynamicTransformer] Removed " + inlined + " invokedynamic instructions, took "
+ TimeUnit.MILLISECONDS.toSeconds(end - start) + "s");
System.out.println(
"[AntiReleak] [InvokedynamicTransformer] Cleaning up bootstrap methods");
System.out.println("[AntiReleak] [InvokedynamicTransformer] Cleaning up bootstrap methods");
cleanup();
}
System.out.println("[AntiReleak] [InvokedynamicTransformer] Done");
Expand All @@ -82,21 +77,16 @@ public static int findInvokeDynamic(Collection<ClassNode> classNodes) {
classNodes.forEach(classNode -> {
classNode.methods.forEach(methodNode -> {
for (int i = 0; i < methodNode.instructions.size(); i++) {
AbstractInsnNode abstractInsnNode =
methodNode.instructions.get(i);
AbstractInsnNode abstractInsnNode = methodNode.instructions.get(i);
if (abstractInsnNode instanceof InvokeDynamicInsnNode) {
InvokeDynamicInsnNode dyn =
(InvokeDynamicInsnNode) abstractInsnNode;
if (dyn.bsmArgs.length == 3 &&
(dyn.bsmArgs[2].equals(182) || dyn.bsmArgs[2].equals(184)))
InvokeDynamicInsnNode dyn = (InvokeDynamicInsnNode) abstractInsnNode;
if (dyn.bsmArgs.length == 3 && (dyn.bsmArgs[2].equals(182) || dyn.bsmArgs[2].equals(184)))
total.incrementAndGet();
else if (dyn.bsmArgs.length == 8 && dyn.bsmArgs[0] instanceof Integer)
total.incrementAndGet();
else if (dyn.bsmArgs.length == 9 && dyn.bsmArgs[0] instanceof String
&& isInteger((String) dyn.bsmArgs[0]))
else if (dyn.bsmArgs.length == 9 && dyn.bsmArgs[0] instanceof String && isInteger((String) dyn.bsmArgs[0]))
total.incrementAndGet();
else if (dyn.bsmArgs.length == 7 && dyn.bsmArgs[0].equals("1")
&& dyn.bsmArgs[1].equals("8"))
else if (dyn.bsmArgs.length == 7 && dyn.bsmArgs[0].equals("1") && dyn.bsmArgs[1].equals("8"))
total.incrementAndGet();
}
}
Expand All @@ -116,8 +106,7 @@ private int inlineInvokeDynamic(int expected) {
provider.register(new MappedMethodProvider(classes));
provider.register(new ComparisonProvider() {
@Override
public boolean instanceOf(JavaValue target, Type type,
Context context) {
public boolean instanceOf(JavaValue target, Type type, Context context) {
if (type.getDescriptor().equals("Ljava/lang/String;"))
if (!(target.value() instanceof String))
return false;
Expand All @@ -128,35 +117,30 @@ public boolean instanceOf(JavaValue target, Type type,
}

@Override
public boolean checkcast(JavaValue target, Type type,
Context context) {
public boolean checkcast(JavaValue target, Type type, Context context) {
if (type.getDescriptor().equals("[C"))
if (!(target.value() instanceof char[]))
return false;
return true;
}

@Override
public boolean checkEquality(JavaValue first, JavaValue second,
Context context) {
public boolean checkEquality(JavaValue first, JavaValue second, Context context) {
return true;
}

@Override
public boolean canCheckInstanceOf(JavaValue target, Type type,
Context context) {
public boolean canCheckInstanceOf(JavaValue target, Type type, Context context) {
return true;
}

@Override
public boolean canCheckcast(JavaValue target, Type type,
Context context) {
public boolean canCheckcast(JavaValue target, Type type, Context context) {
return true;
}

@Override
public boolean canCheckEquality(JavaValue first, JavaValue second,
Context context) {
public boolean canCheckEquality(JavaValue first, JavaValue second, Context context) {
return false;
}
});
Expand All @@ -176,15 +160,10 @@ public boolean canCheckEquality(JavaValue first, JavaValue second,
|| (dyn.bsmArgs.length == 7 && dyn.bsmArgs[0].equals("1") && dyn.bsmArgs[1].equals("8")))
{
Handle bootstrap = dyn.bsm;
ClassNode bootstrapClassNode =
classes.get(bootstrap.getOwner());
MethodNode bootstrapMethodNode =
bootstrapClassNode.methods.stream()
.filter(mn -> mn.name
.equals(bootstrap.getName())
&& mn.desc
.equals(bootstrap.getDesc()))
.findFirst().orElse(null);
ClassNode bootstrapClassNode = classes.get(bootstrap.getOwner());
MethodNode bootstrapMethodNode = bootstrapClassNode.methods.stream()
.filter(mn -> mn.name.equals(bootstrap.getName()) && mn.desc.equals(bootstrap.getDesc()))
.findFirst().orElse(null);
List<JavaValue> args = new ArrayList<>();
args.add(new JavaObject(null, "java/lang/invoke/MethodHandles$Lookup")); // Lookup
args.add(JavaValue.valueOf(dyn.name)); // dyn
Expand All @@ -200,21 +179,15 @@ public boolean canCheckEquality(JavaValue first, JavaValue second,
Object o = dyn.bsmArgs[i1];
if (o.getClass() == Type.class) {
Type type = (Type) o;
args.add(JavaValue.valueOf(new JavaClass(
type.getInternalName().replace('/', '.'), context)));
} else if (o.getClass() == Integer.class && Type.getArgumentTypes(bootstrapMethodNode.desc)[i1 + 3]
.getSort() == Type.INT)
args.add(JavaValue.valueOf(new JavaClass(type.getInternalName().replace('/', '.'), context)));
} else if (o.getClass() == Integer.class && Type.getArgumentTypes(bootstrapMethodNode.desc)[i1 + 3].getSort() == Type.INT)
args.add(new JavaInteger((int) o));
else
args.add(JavaValue.valueOf(o));
}
try {
JavaMethodHandle result =
MethodExecutor.execute(bootstrapClassNode,
bootstrapMethodNode, args, null,
context);
String clazz =
result.clazz.replace('.', '/');
JavaMethodHandle result = MethodExecutor.execute(bootstrapClassNode, bootstrapMethodNode, args, null, context);
String clazz = result.clazz.replace('.', '/');
MethodInsnNode replacement = null;
switch (result.type) {
case "virtual":
Expand All @@ -236,20 +209,18 @@ public boolean canCheckEquality(JavaValue first, JavaValue second,
if (replacement.getNext() != null
&& replacement.getNext().getOpcode() == Opcodes.CHECKCAST
&& Type.getReturnType(replacement.desc).getDescriptor().equals(((TypeInsnNode) replacement.getNext()).desc))
{
methodNode.instructions.remove(replacement.getNext());
}
total.incrementAndGet();
int x = (int) (total.get() * 1.0d / expected
* 100);
int x = (int) (total.get() * 1.0d / expected * 100);
if (x != 0 && x % 10 == 0 && !alerted[x - 1]) {
System.out.println(
"[AntiReleak] [InvokedynamicTransformer] Done "
+ x + "%");
System.out.println("[AntiReleak] [InvokedynamicTransformer] Done " + x + "%");
alerted[x - 1] = true;
}
} catch (ExecutionException ex) {
if (ex.getCause() != null)
ex.getCause()
.printStackTrace(System.out);
ex.getCause().printStackTrace(System.out);
throw ex;
} catch (Throwable t) {
System.out.println(classNode.name);
Expand Down Expand Up @@ -278,14 +249,13 @@ private void patchIndyTransformer(ClassNode classNode) {
InsnList list = Utils.cloneInsnList(node.instructions);
boolean patched = false;
for (int i = 0; i < node.instructions.size(); i++) {
AbstractInsnNode abstractInsnNode =
node.instructions.get(i);
AbstractInsnNode abstractInsnNode = node.instructions.get(i);
if (abstractInsnNode.getType() == AbstractInsnNode.LABEL) {
LabelNode labelNode = (LabelNode) abstractInsnNode;
AbstractInsnNode after3 = labelNode.getNext().getNext().getNext();
if (labelNode.getNext() != null && labelNode.getNext().getOpcode() == Opcodes.ASTORE
&& labelNode.getNext().getNext() != null && labelNode.getNext().getNext().getOpcode() ==
Opcodes.ALOAD && after3 != null && after3.getOpcode() == Opcodes.INVOKESTATIC)
&& labelNode.getNext().getNext() != null && labelNode.getNext().getNext().getOpcode() == Opcodes.ALOAD
&& after3 != null && after3.getOpcode() == Opcodes.INVOKESTATIC)
{
//Start at label 80 by removing everything before it
while (labelNode.getPrevious() != null) {
Expand All @@ -295,8 +265,7 @@ private void patchIndyTransformer(ClassNode classNode) {
node.instructions.remove(labelNode.getNext());
//5 nexts from after3 if the beginning of label 101
AbstractInsnNode after8 = after3.getNext().getNext().getNext().getNext().getNext();
if (after8 != null
&& after8.getType() == AbstractInsnNode.LABEL)
if (after8 != null && after8.getType() == AbstractInsnNode.LABEL)
if (after8.getNext() != null
&& after8.getNext().getOpcode() == Opcodes.ILOAD
&& after8.getNext().getNext() != null
Expand All @@ -319,8 +288,7 @@ private void patchIndyTransformer(ClassNode classNode) {
InsnList list = Utils.cloneInsnList(node.instructions);
boolean patched = false;
for (int i = 0; i < node.instructions.size(); i++) {
AbstractInsnNode ain =
node.instructions.get(i);
AbstractInsnNode ain = node.instructions.get(i);
if (ain.getOpcode() == Opcodes.INVOKESTATIC) {
MethodInsnNode methodInsn = (MethodInsnNode) ain;
if (methodInsn.desc.equals("(Ljava/lang/String;)Ljava/lang/String;")
Expand All @@ -335,8 +303,7 @@ private void patchIndyTransformer(ClassNode classNode) {
}
if (patched) {
for (int i = 0; i < node.instructions.size(); i++) {
AbstractInsnNode ain =
node.instructions.get(i);
AbstractInsnNode ain = node.instructions.get(i);
if (ain.getOpcode() == Opcodes.INVOKEVIRTUAL
&& ((MethodInsnNode) ain).owner.equals("java/lang/Object")
&& ((MethodInsnNode) ain).name.equals("equals"))
Expand Down Expand Up @@ -529,8 +496,7 @@ else if (ain.getOpcode() == Opcodes.LDC
InsnList list = Utils.cloneInsnList(node.instructions);
boolean patched = false;
for (int i = 0; i < node.instructions.size(); i++) {
AbstractInsnNode ain =
node.instructions.get(i);
AbstractInsnNode ain = node.instructions.get(i);
if (ain.getOpcode() == Opcodes.CHECKCAST
&& ((TypeInsnNode) ain).desc.equals("java/lang/invoke/MethodHandles$Lookup")
&& Utils.getPrevious(ain) != null
Expand All @@ -556,8 +522,7 @@ else if (ain.getOpcode() == Opcodes.LDC
InsnList list = Utils.cloneInsnList(node.instructions);
boolean patched = false;
for (int i = 0; i < node.instructions.size(); i++) {
AbstractInsnNode ain =
node.instructions.get(i);
AbstractInsnNode ain = node.instructions.get(i);
if (ain.getOpcode() == Opcodes.NEW
&& ((TypeInsnNode) ain).desc.equals("java/util/zip/ZipFile"))
{
Expand Down Expand Up @@ -658,11 +623,10 @@ private void cleanup() {
}
}
for (MethodInsnNode insnNode : bootstrapReferences) {
MethodNode method = classNode.methods.stream().filter(
m -> m.name.equals(insnNode.name) && m.desc.equals(insnNode.desc)).findFirst().orElse(null);
if (method != null && method.desc
.equals("(Ljava/lang/String;)Ljava/lang/String;"))
{
MethodNode method = classNode.methods.stream()
.filter(m -> m.name.equals(insnNode.name) && m.desc.equals(insnNode.desc))
.findFirst().orElse(null);
if (method != null && method.desc.equals("(Ljava/lang/String;)Ljava/lang/String;")) {
classNode.methods.remove(method);
hasIndyNode = true;
}
Expand All @@ -681,8 +645,9 @@ private void cleanup() {
}
}
for (FieldInsnNode fieldInsn : bootstrapFieldReferences) {
FieldNode field = classNode.fields.stream().filter(f -> f.name.equals(fieldInsn.name)
&& f.desc.equals(fieldInsn.desc)).findFirst().orElse(null);
FieldNode field = classNode.fields.stream()
.filter(f -> f.name.equals(fieldInsn.name) && f.desc.equals(fieldInsn.desc))
.findFirst().orElse(null);
if (field != null) {
if (classNode.fields.indexOf(field) < classNode.fields.size() - 1) {
FieldNode next = classNode.fields.get(classNode.fields.indexOf(field) + 1);
Expand Down
Loading

0 comments on commit 2687779

Please sign in to comment.