Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use args map instead of method args in generated java code #9779

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.kafka.connect.data.Schema;
import org.codehaus.commons.compiler.CompileException;
import org.codehaus.commons.compiler.CompilerFactoryFactory;
Expand Down Expand Up @@ -164,8 +163,7 @@ public CompiledExpression buildCodeGenFromParseTree(

final Class<?> expressionType = SQL_TO_JAVA_TYPE_CONVERTER.toJavaType(returnType);

final IExpressionEvaluator ee =
cook(javaCode, expressionType, spec.argumentNames(), spec.argumentTypes());
final IExpressionEvaluator ee = cook(javaCode, expressionType);

return new CompiledExpression(ee, spec, returnType, expression);
} catch (KsqlException | CompileException e) {
Expand All @@ -185,17 +183,15 @@ public CompiledExpression buildCodeGenFromParseTree(
@VisibleForTesting
public static IExpressionEvaluator cook(
final String javaCode,
final Class<?> expressionType,
final String[] argNames,
final Class<?>[] argTypes
final Class<?> expressionType
) throws Exception {
final IExpressionEvaluator ee = CompilerFactoryFactory.getDefaultCompilerFactory()
.newExpressionEvaluator();

ee.setDefaultImports(SqlToJavaVisitor.JAVA_IMPORTS.toArray(new String[0]));
ee.setParameters(
ArrayUtils.addAll(argNames, "defaultValue", "logger", "row"),
ArrayUtils.addAll(argTypes, Object.class, ProcessingLogger.class, GenericRow.class)
new String[]{"arguments", "defaultValue", "logger", "row"},
new Class[]{Map.class, Object.class, ProcessingLogger.class, GenericRow.class}
);
ee.setExpressionType(expressionType);
ee.cook(javaCode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,6 @@ private CodeGenSpec(
this.structToCodeName = structToCodeName;
}

public String[] argumentNames() {
return arguments.stream().map(ArgumentSpec::name).toArray(String[]::new);
}

public Class<?>[] argumentTypes() {
return arguments.stream().map(ArgumentSpec::type).toArray(Class[]::new);
}

@SuppressFBWarnings(value = "EI_EXPOSE_REP", justification = "arguments is ImmutableList")
public List<ArgumentSpec> arguments() {
return arguments;
Expand All @@ -81,10 +73,14 @@ public String getUniqueNameForFunction(final FunctionName functionName, final in
return names.get(index);
}

public void resolve(final GenericRow row, final Object[] parameters) {
public Map<String, Object> resolveArguments(final GenericRow row) {
final Map<String, Object> resolvedArguments = new HashMap<>(arguments.size());
for (int paramIdx = 0; paramIdx < arguments.size(); paramIdx++) {
parameters[paramIdx] = arguments.get(paramIdx).resolve(row);
final String name = arguments.get(paramIdx).name();
final Object value = arguments.get(paramIdx).resolve(row);
resolvedArguments.put(name, value);
}
return resolvedArguments;
}

public String getStructSchemaName(final CreateStructExpression createStructExpression) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package io.confluent.ksql.execution.codegen;

import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.schema.ksql.SchemaConverters;
import io.confluent.ksql.schema.ksql.types.SqlType;

public final class CodeGenUtil {

Expand All @@ -37,4 +39,15 @@ public static String functionName(final FunctionName fun, final int index) {
return fun.text() + "_" + index;
}

public static String argumentAccessor(final String name,
final SqlType type) {
final Class<?> javaType = SchemaConverters.sqlToJavaConverter().toJavaType(type);
return argumentAccessor(name, javaType);
}

public static String argumentAccessor(final String name,
final Class<?> type) {
return String.format("((%s) arguments.get(\"%s\"))", type.getCanonicalName(), name);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
import org.apache.commons.lang3.ArrayUtils;
import org.codehaus.commons.compiler.IExpressionEvaluator;

@Immutable
Expand All @@ -37,7 +36,6 @@ public class CompiledExpression implements ExpressionEvaluator {
@EffectivelyImmutable
private final IExpressionEvaluator expressionEvaluator;
private final SqlType expressionType;
private final ThreadLocal<Object[]> threadLocalParameters;
private final Expression expression;
private final CodeGenSpec spec;

Expand All @@ -51,7 +49,6 @@ public CompiledExpression(
this.expressionType = Objects.requireNonNull(expressionType, "expressionType");
this.expression = Objects.requireNonNull(expression, "expression");
this.spec = Objects.requireNonNull(spec, "spec");
this.threadLocalParameters = ThreadLocal.withInitial(() -> new Object[spec.arguments().size()]);
}

public List<ArgumentSpec> arguments() {
Expand Down Expand Up @@ -85,8 +82,12 @@ public Object evaluate(
final Supplier<String> errorMsg
) {
try {
return expressionEvaluator.evaluate(
ArrayUtils.addAll(getParameters(row), defaultValue, logger, row));
return expressionEvaluator.evaluate(new Object[]{
spec.resolveArguments(row),
defaultValue,
logger,
row
});
} catch (final Exception e) {
final Throwable cause = e instanceof InvocationTargetException
? e.getCause()
Expand All @@ -96,10 +97,4 @@ public Object evaluate(
return defaultValue;
}
}

private Object[] getParameters(final GenericRow row) {
final Object[] parameters = threadLocalParameters.get();
spec.resolve(row, parameters);
return parameters;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
import io.confluent.ksql.function.types.ArrayType;
import io.confluent.ksql.function.types.ParamType;
import io.confluent.ksql.function.types.ParamTypes;
import io.confluent.ksql.function.udf.Kudf;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.name.FunctionName;
import io.confluent.ksql.schema.Operator;
Expand Down Expand Up @@ -124,6 +125,7 @@
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringEscapeUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.connect.data.Struct;

Expand Down Expand Up @@ -465,7 +467,9 @@ public Pair<String, SqlType> visitUnqualifiedColumnReference(
.orElseThrow(() ->
new KsqlException("Field not found: " + node.getColumnName()));

return new Pair<>(colRefToCodeName.apply(fieldName), schemaColumn.type());
final String codeName = colRefToCodeName.apply(fieldName);
final String paramAccessor = CodeGenUtil.argumentAccessor(codeName, schemaColumn.type());
return new Pair<>(paramAccessor, schemaColumn.type());
}

@Override
Expand Down Expand Up @@ -515,6 +519,7 @@ public Pair<String, SqlType> visitFunctionCall(
) {
final FunctionName functionName = node.getName();
final String instanceName = funNameToCodeName.apply(functionName);
final String functionAccessor = CodeGenUtil.argumentAccessor(instanceName, Kudf.class);
final UdfFactory udfFactory = functionRegistry.getUdfFactory(node.getName());
final FunctionTypeInfo argumentsAndContext = FunctionArgumentsUtil
.getFunctionTypeInfo(
Expand Down Expand Up @@ -561,7 +566,7 @@ public Pair<String, SqlType> visitFunctionCall(
}

final String argumentsString = joiner.toString();
final String codeString = "((" + javaReturnType + ") " + instanceName
final String codeString = "((" + javaReturnType + ") " + functionAccessor
+ ".evaluate(" + argumentsString + "))";
return new Pair<>(codeString, returnType);
}
Expand Down Expand Up @@ -1165,7 +1170,10 @@ public Pair<String, SqlType> visitStructExpression(
final Context context
) {
final String schemaName = structToCodeName.apply(node);
final StringBuilder struct = new StringBuilder("new Struct(").append(schemaName).append(")");
final String schemaAccessor = CodeGenUtil.argumentAccessor(schemaName, Schema.class);
final StringBuilder struct = new StringBuilder("new Struct(")
.append(schemaAccessor)
.append(")");
for (final Field field : node.getFields()) {
struct.append(".put(")
.append('"')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,9 @@

import static java.util.Objects.requireNonNull;

import com.google.common.collect.ImmutableList;
import io.confluent.ksql.GenericRow;
import io.confluent.ksql.logging.processing.ProcessingLogger;
import java.lang.reflect.InvocationTargetException;
import java.util.Collections;
import java.util.List;
import org.apache.commons.lang3.ArrayUtils;
import java.util.Map;
import org.codehaus.commons.compiler.IExpressionEvaluator;

public final class CodeGenTestUtil {
Expand All @@ -22,77 +18,27 @@ public static Object cookAndEval(
return cookAndEval(
javaCode,
resultType,
ImmutableList.of(),
ImmutableList.of(),
ImmutableList.of()
Collections.emptyMap()
);
}

public static Object cookAndEval(
final String javaCode,
final Class<?> resultType,
final String argName,
final Class<?> argType,
final Object arg
final Map<String, Object> args
) {
return cookAndEval(
javaCode,
resultType,
ImmutableList.of(argName),
ImmutableList.of(argType),
Collections.singletonList(arg)
);
}

public static Object cookAndEval(
final String javaCode,
final Class<?> resultType,
final List<String> argNames,
final List<Class<?>> argTypes,
final List<Object> args
) {
final Evaluator evaluator = CodeGenTestUtil.cookCode(javaCode, resultType, argNames, argTypes);
final Evaluator evaluator = CodeGenTestUtil.cookCode(javaCode, resultType);
return evaluator.evaluate(args);
}

public static Evaluator cookCode(
final String javaCode,
final Class<?> resultType
) {
return cookCode(
javaCode,
resultType,
ImmutableList.of(),
ImmutableList.of()
);
}

public static Evaluator cookCode(
final String javaCode,
final Class<?> resultType,
final String argName,
final Class<?> argType
) {
return cookCode(
javaCode,
resultType,
ImmutableList.of(argName),
ImmutableList.of(argType)
);
}

public static Evaluator cookCode(
final String javaCode,
final Class<?> resultType,
final List<String> argNames,
final List<Class<?>> argTypes
) {
try {
final IExpressionEvaluator ee = CodeGenRunner.cook(
javaCode,
resultType,
argNames.toArray(new String[0]),
argTypes.toArray(new Class<?>[0])
resultType
);

return new Evaluator(ee, javaCode);
Expand All @@ -116,11 +62,15 @@ public Evaluator(final IExpressionEvaluator ee, final String javaCode) {
this.javaCode = requireNonNull(javaCode, "javaCode");
}

public Object evaluate(final Object arg) {
return evaluate(Collections.singletonList(arg));
public Object evaluate() {
return evaluate(Collections.emptyMap());
}

public Object evaluate(final String argName, final Object argValue) {
return evaluate(Collections.singletonMap(argName, argValue));
}

public Object evaluate(final List<?> args) {
public Object evaluate(final Map<String, Object> args) {
try {
return rawEvaluate(args);
} catch (final Exception e) {
Expand All @@ -133,13 +83,13 @@ public Object evaluate(final List<?> args) {
}
}

public Object rawEvaluate(final Object arg) throws Exception {
return rawEvaluate(Collections.singletonList(arg));
public Object rawEvaluate(final String argName, final Object argValue) throws Exception {
return rawEvaluate(Collections.singletonMap(argName, argValue));
}

public Object rawEvaluate(final List<?> args) throws Exception {
public Object rawEvaluate(final Map<String, Object> args) throws Exception {
try {
return ee.evaluate(ArrayUtils.addAll(args == null ? new Object[]{null} : args.toArray(), null, null, null));
return ee.evaluate(new Object[]{args, null, null, null});
} catch (final InvocationTargetException e) {
throw e.getTargetException() instanceof Exception
? (Exception) e.getTargetException()
Expand Down
Loading