From 7a793fa5f7c9dbe0913aadce4250ac7f6d6f9dc5 Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Fri, 15 Sep 2023 13:41:52 -0700 Subject: [PATCH] Add MutableAst PiperOrigin-RevId: 565767218 --- .../main/java/dev/cel/common/ast/CelExpr.java | 125 ++++++- .../common/ast/CelExprIdGeneratorFactory.java | 11 +- .../java/dev/cel/common/ast/CelExprTest.java | 23 ++ optimizer/BUILD.bazel | 38 ++ .../main/java/dev/cel/optimizer/BUILD.bazel | 96 +++++ .../dev/cel/optimizer/CelAstOptimizer.java | 49 +++ .../optimizer/CelOptimizationException.java | 23 ++ .../java/dev/cel/optimizer/CelOptimizer.java | 36 ++ .../cel/optimizer/CelOptimizerBuilder.java | 34 ++ .../cel/optimizer/CelOptimizerFactory.java | 47 +++ .../dev/cel/optimizer/CelOptimizerImpl.java | 91 +++++ .../java/dev/cel/optimizer/MutableAst.java | 164 +++++++++ .../test/java/dev/cel/optimizer/BUILD.bazel | 39 +++ .../optimizer/CelOptimizerFactoryTest.java | 61 ++++ .../cel/optimizer/CelOptimizerImplTest.java | 131 +++++++ .../dev/cel/optimizer/MutableAstTest.java | 331 ++++++++++++++++++ .../src/main/java/dev/cel/parser/BUILD.bazel | 1 + .../dev/cel/parser/CelUnparserFactory.java | 24 ++ .../validators/DurationLiteralValidator.java | 2 +- .../HomogeneousLiteralValidator.java | 2 +- .../validators/RegexLiteralValidator.java | 2 +- .../validators/TimestampLiteralValidator.java | 2 +- .../cel/validator/CelValidatorImplTest.java | 2 - 23 files changed, 1320 insertions(+), 14 deletions(-) create mode 100644 optimizer/BUILD.bazel create mode 100644 optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel create mode 100644 optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java create mode 100644 optimizer/src/main/java/dev/cel/optimizer/CelOptimizationException.java create mode 100644 optimizer/src/main/java/dev/cel/optimizer/CelOptimizer.java create mode 100644 optimizer/src/main/java/dev/cel/optimizer/CelOptimizerBuilder.java create mode 100644 optimizer/src/main/java/dev/cel/optimizer/CelOptimizerFactory.java create mode 100644 optimizer/src/main/java/dev/cel/optimizer/CelOptimizerImpl.java create mode 100644 optimizer/src/main/java/dev/cel/optimizer/MutableAst.java create mode 100644 optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel create mode 100644 optimizer/src/test/java/dev/cel/optimizer/CelOptimizerFactoryTest.java create mode 100644 optimizer/src/test/java/dev/cel/optimizer/CelOptimizerImplTest.java create mode 100644 optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java create mode 100644 parser/src/main/java/dev/cel/parser/CelUnparserFactory.java diff --git a/common/src/main/java/dev/cel/common/ast/CelExpr.java b/common/src/main/java/dev/cel/common/ast/CelExpr.java index 3e9f0bc17..82d3e42eb 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExpr.java +++ b/common/src/main/java/dev/cel/common/ast/CelExpr.java @@ -221,6 +221,80 @@ public abstract static class Builder { public abstract Builder setExprKind(ExprKind value); + public abstract ExprKind exprKind(); + + /** + * Gets the underlying constant expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#CONSTANT}. + */ + public CelConstant constant() { + return exprKind().constant(); + } + + /** + * Gets the underlying identifier expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#IDENT}. + */ + public CelIdent ident() { + return exprKind().ident(); + } + + /** + * Gets the underlying select expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#SELECT}. + */ + public CelSelect select() { + return exprKind().select(); + } + + /** + * Gets the underlying call expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#CALL}. + */ + public CelCall call() { + return exprKind().call(); + } + + /** + * Gets the underlying createList expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#CREATE_LIST}. + */ + public CelCreateList createList() { + return exprKind().createList(); + } + + /** + * Gets the underlying createStruct expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#CREATE_STRUCT}. + */ + public CelCreateStruct createStruct() { + return exprKind().createStruct(); + } + + /** + * Gets the underlying createMap expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#createMap}. + */ + public CelCreateMap createMap() { + return exprKind().createMap(); + } + + /** + * Gets the underlying comprehension expression. + * + * @throws UnsupportedOperationException if expression is not {@link Kind#COMPREHENSION}. + */ + public CelComprehension comprehension() { + return exprKind().comprehension(); + } + public Builder setConstant(CelConstant constant) { return setExprKind(AutoOneOf_CelExpr_ExprKind.constant(constant)); } @@ -373,6 +447,11 @@ public abstract static class CelSelect { /** Builder for CelSelect. */ @AutoValue.Builder public abstract static class Builder { + public abstract CelExpr operand(); + + public abstract String field(); + + public abstract boolean testOnly(); public abstract Builder setOperand(CelExpr value); @@ -418,9 +497,9 @@ public abstract static class CelCall { /** Builder for CelCall. */ @AutoValue.Builder public abstract static class Builder { - List mutableArgs = new ArrayList<>(); + private List mutableArgs = new ArrayList<>(); - abstract ImmutableList args(); + public abstract ImmutableList args(); public abstract Builder setTarget(CelExpr value); @@ -428,6 +507,8 @@ public abstract static class Builder { public abstract Builder setFunction(String value); + public abstract Optional target(); + // Not public. This only exists to make AutoValue.Builder work. abstract Builder setArgs(ImmutableList value); @@ -501,16 +582,23 @@ public abstract static class CelCreateList { /** Builder for CelCreateList. */ @AutoValue.Builder public abstract static class Builder { - List mutableElements = new ArrayList<>(); + private List mutableElements = new ArrayList<>(); + // Not public. This only exists to make AutoValue.Builder work. abstract ImmutableList elements(); + // Not public. This only exists to make AutoValue.Builder work. abstract ImmutableList.Builder optionalIndicesBuilder(); // Not public. This only exists to make AutoValue.Builder work. @CanIgnoreReturnValue abstract Builder setElements(ImmutableList elements); + /** Returns an immutable copy of the current mutable elements present in the builder. */ + public ImmutableList getElements() { + return ImmutableList.copyOf(mutableElements); + } + @CanIgnoreReturnValue public Builder setElement(int index, CelExpr element) { checkNotNull(element); @@ -586,8 +674,9 @@ public abstract static class CelCreateStruct { /** Builder for CelCreateStruct. */ @AutoValue.Builder public abstract static class Builder { - List mutableEntries = new ArrayList<>(); + private List mutableEntries = new ArrayList<>(); + // Not public. This only exists to make AutoValue.Builder work. abstract ImmutableList entries(); @CanIgnoreReturnValue @@ -597,6 +686,11 @@ public abstract static class Builder { @CanIgnoreReturnValue abstract Builder setEntries(ImmutableList entries); + /** Returns an immutable copy of the current mutable entries present in the builder. */ + public ImmutableList getEntries() { + return ImmutableList.copyOf(mutableEntries); + } + @CanIgnoreReturnValue public Builder setEntry(int index, CelCreateStruct.Entry entry) { checkNotNull(entry); @@ -669,6 +763,8 @@ public abstract static class Entry { @AutoValue.Builder public abstract static class Builder { + public abstract CelExpr value(); + public abstract Builder setId(long value); public abstract Builder setFieldKey(String value); @@ -704,14 +800,20 @@ public abstract static class CelCreateMap { @AutoValue.Builder public abstract static class Builder { - List mutableEntries = new ArrayList<>(); + private List mutableEntries = new ArrayList<>(); + // Not public. This only exists to make AutoValue.Builder work. abstract ImmutableList entries(); // Not public. This only exists to make AutoValue.Builder work. @CanIgnoreReturnValue abstract Builder setEntries(ImmutableList entries); + /** Returns an immutable copy of the current mutable entries present in the builder. */ + public ImmutableList getEntries() { + return ImmutableList.copyOf(mutableEntries); + } + @CanIgnoreReturnValue public Builder setEntry(int index, CelCreateMap.Entry entry) { checkNotNull(entry); @@ -784,6 +886,10 @@ public abstract static class Entry { @AutoValue.Builder public abstract static class Builder { + public abstract CelExpr key(); + + public abstract CelExpr value(); + public abstract CelCreateMap.Entry.Builder setId(long value); public abstract CelCreateMap.Entry.Builder setKey(CelExpr value); @@ -868,6 +974,15 @@ public abstract static class CelComprehension { /** Builder for Comprehension. */ @AutoValue.Builder public abstract static class Builder { + public abstract CelExpr iterRange(); + + public abstract CelExpr accuInit(); + + public abstract CelExpr loopCondition(); + + public abstract CelExpr loopStep(); + + public abstract CelExpr result(); public abstract Builder setIterVar(String value); diff --git a/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java b/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java index d7f312d27..493bdc2de 100644 --- a/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java +++ b/common/src/main/java/dev/cel/common/ast/CelExprIdGeneratorFactory.java @@ -18,10 +18,15 @@ import java.util.HashMap; /** Factory for populating expression IDs */ -final class CelExprIdGeneratorFactory { +public final class CelExprIdGeneratorFactory { - /** MonotonicIdGenerator increments expression IDs from an initial seed value. */ - static CelExprIdGenerator newMonotonicIdGenerator(long exprId) { + /** + * MonotonicIdGenerator increments expression IDs from an initial seed value. + * + * @param exprId Seed value. Must be non-negative. For example, if 1 is provided {@link + * CelExprIdGenerator#nextExprId} will return 2. + */ + public static CelExprIdGenerator newMonotonicIdGenerator(long exprId) { return new MonotonicIdGenerator(exprId); } diff --git a/common/src/test/java/dev/cel/common/ast/CelExprTest.java b/common/src/test/java/dev/cel/common/ast/CelExprTest.java index 42a40d7a8..12c9e972e 100644 --- a/common/src/test/java/dev/cel/common/ast/CelExprTest.java +++ b/common/src/test/java/dev/cel/common/ast/CelExprTest.java @@ -111,6 +111,7 @@ public void celExprBuilder_setConstant() { CelExpr celExpr = CelExpr.newBuilder().setConstant(celConstant).build(); assertThat(celExpr.constant()).isEqualTo(celConstant); + assertThat(celExpr.toBuilder().constant()).isEqualTo(celConstant); } @Test @@ -119,6 +120,7 @@ public void celExprBuilder_setIdent() { CelExpr celExpr = CelExpr.newBuilder().setIdent(celIdent).build(); assertThat(celExpr.ident()).isEqualTo(celIdent); + assertThat(celExpr.toBuilder().ident()).isEqualTo(celIdent); } @Test @@ -131,6 +133,7 @@ public void celExprBuilder_setCall() { CelExpr celExpr = CelExpr.newBuilder().setCall(celCall).build(); assertThat(celExpr.call()).isEqualTo(celCall); + assertThat(celExpr.toBuilder().call()).isEqualTo(celCall); } @Test @@ -144,6 +147,8 @@ public void celExprBuilder_setCall_clearTarget() { CelExpr.newBuilder().setCall(celCall.toBuilder().clearTarget().build()).build(); assertThat(celExpr.call()).isEqualTo(CelCall.newBuilder().setFunction("function").build()); + assertThat(celExpr.toBuilder().call()) + .isEqualTo(CelCall.newBuilder().setFunction("function").build()); } @Test @@ -182,6 +187,7 @@ public void celExprBuilder_setSelect() { assertThat(celExpr.select().testOnly()).isFalse(); assertThat(celExpr.select()).isEqualTo(celSelect); + assertThat(celExpr.toBuilder().select()).isEqualTo(celSelect); } @Test @@ -193,6 +199,7 @@ public void celExprBuilder_setCreateList() { CelExpr celExpr = CelExpr.newBuilder().setCreateList(celCreateList).build(); assertThat(celExpr.createList()).isEqualTo(celCreateList); + assertThat(celExpr.toBuilder().createList()).isEqualTo(celCreateList); } @Test @@ -236,6 +243,7 @@ public void celExprBuilder_setCreateStruct() { assertThat(celExpr.createStruct().entries().get(0).optionalEntry()).isFalse(); assertThat(celExpr.createStruct()).isEqualTo(celCreateStruct); + assertThat(celExpr.toBuilder().createStruct()).isEqualTo(celCreateStruct); } @Test @@ -309,6 +317,7 @@ public void celExprBuilder_setComprehension() { CelExpr celExpr = CelExpr.newBuilder().setComprehension(celComprehension).build(); assertThat(celExpr.comprehension()).isEqualTo(celComprehension); + assertThat(celExpr.toBuilder().comprehension()).isEqualTo(celComprehension); } @Test @@ -316,30 +325,44 @@ public void getUnderlyingExpression_unmatchedKind_throws( @TestParameter BuilderExprKindTestCase testCase) { if (!testCase.expectedExprKind.equals(Kind.NOT_SET)) { assertThrows(UnsupportedOperationException.class, () -> testCase.expr.exprKind().notSet()); + assertThrows( + UnsupportedOperationException.class, () -> testCase.expr.toBuilder().exprKind().notSet()); } if (!testCase.expectedExprKind.equals(Kind.CONSTANT)) { assertThrows(UnsupportedOperationException.class, testCase.expr::constant); + assertThrows(UnsupportedOperationException.class, () -> testCase.expr.toBuilder().constant()); } if (!testCase.expectedExprKind.equals(Kind.IDENT)) { assertThrows(UnsupportedOperationException.class, testCase.expr::ident); + assertThrows(UnsupportedOperationException.class, () -> testCase.expr.toBuilder().ident()); } if (!testCase.expectedExprKind.equals(Kind.SELECT)) { assertThrows(UnsupportedOperationException.class, testCase.expr::select); + assertThrows(UnsupportedOperationException.class, () -> testCase.expr.toBuilder().select()); } if (!testCase.expectedExprKind.equals(Kind.CALL)) { assertThrows(UnsupportedOperationException.class, testCase.expr::call); + assertThrows(UnsupportedOperationException.class, () -> testCase.expr.toBuilder().call()); } if (!testCase.expectedExprKind.equals(Kind.CREATE_LIST)) { assertThrows(UnsupportedOperationException.class, testCase.expr::createList); + assertThrows( + UnsupportedOperationException.class, () -> testCase.expr.toBuilder().createList()); } if (!testCase.expectedExprKind.equals(Kind.CREATE_STRUCT)) { assertThrows(UnsupportedOperationException.class, testCase.expr::createStruct); + assertThrows( + UnsupportedOperationException.class, () -> testCase.expr.toBuilder().createStruct()); } if (!testCase.expectedExprKind.equals(Kind.CREATE_MAP)) { assertThrows(UnsupportedOperationException.class, testCase.expr::createMap); + assertThrows( + UnsupportedOperationException.class, () -> testCase.expr.toBuilder().createMap()); } if (!testCase.expectedExprKind.equals(Kind.COMPREHENSION)) { assertThrows(UnsupportedOperationException.class, testCase.expr::comprehension); + assertThrows( + UnsupportedOperationException.class, () -> testCase.expr.toBuilder().comprehension()); } } diff --git a/optimizer/BUILD.bazel b/optimizer/BUILD.bazel new file mode 100644 index 000000000..49cfa4bad --- /dev/null +++ b/optimizer/BUILD.bazel @@ -0,0 +1,38 @@ +package( + default_applicable_licenses = ["//:license"], + default_visibility = ["//visibility:public"], # TODO: Expose to public +) + +java_library( + name = "optimizer", + exports = ["//optimizer/src/main/java/dev/cel/optimizer"], +) + +java_library( + name = "optimizer_builder", + exports = ["//optimizer/src/main/java/dev/cel/optimizer:optimizer_builder"], +) + +java_library( + name = "ast_optimizer", + exports = ["//optimizer/src/main/java/dev/cel/optimizer:ast_optimizer"], +) + +java_library( + name = "optimization_exception", + exports = ["//optimizer/src/main/java/dev/cel/optimizer:optimization_exception"], +) + +java_library( + name = "mutable_ast", + testonly = 1, + visibility = ["//optimizer/src/test/java/dev/cel/optimizer:__pkg__"], + exports = ["//optimizer/src/main/java/dev/cel/optimizer:mutable_ast"], +) + +java_library( + name = "optimizer_impl", + testonly = 1, + visibility = ["//optimizer/src/test/java/dev/cel/optimizer:__pkg__"], + exports = ["//optimizer/src/main/java/dev/cel/optimizer:optimizer_impl"], +) diff --git a/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel new file mode 100644 index 000000000..88999315c --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/BUILD.bazel @@ -0,0 +1,96 @@ +package( + default_applicable_licenses = ["//:license"], + default_visibility = [ + "//visibility:public", + ], +) + +java_library( + name = "optimizer", + srcs = [ + "CelOptimizerFactory.java", + ], + tags = [ + ], + deps = [ + ":optimizer_impl", + "//bundle:cel", + "//checker:checker_builder", + "//compiler", + "//compiler:compiler_builder", + "//optimizer:optimizer_builder", + "//parser:parser_builder", + "//runtime", + ], +) + +java_library( + name = "optimizer_builder", + srcs = [ + "CelOptimizer.java", + "CelOptimizerBuilder.java", + ], + tags = [ + ], + deps = [ + ":ast_optimizer", + ":optimization_exception", + "//common", + "@maven//:com_google_errorprone_error_prone_annotations", + ], +) + +java_library( + name = "optimizer_impl", + srcs = [ + "CelOptimizerImpl.java", + ], + tags = [ + ], + deps = [ + ":ast_optimizer", + ":optimization_exception", + ":optimizer_builder", + "//bundle:cel", + "//common", + "//common:compiler_common", + "//common/navigation", + "@maven//:com_google_guava_guava", + ], +) + +java_library( + name = "ast_optimizer", + srcs = ["CelAstOptimizer.java"], + tags = [ + ], + deps = [ + ":mutable_ast", + "//bundle:cel", + "//checker:checker_builder", + "//common", + "//common:compiler_common", + "//common/ast", + "//common/navigation", + ], +) + +java_library( + name = "mutable_ast", + srcs = ["MutableAst.java"], + tags = [ + ], + deps = [ + "//common/annotations", + "//common/ast", + "//common/ast:expr_factory", + "@maven//:com_google_guava_guava", + ], +) + +java_library( + name = "optimization_exception", + srcs = ["CelOptimizationException.java"], + tags = [ + ], +) diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java new file mode 100644 index 000000000..21c095f40 --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/CelAstOptimizer.java @@ -0,0 +1,49 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +import dev.cel.bundle.Cel; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.navigation.CelNavigableAst; + +/** Public interface for performing a single, custom optimization on an AST. */ +public interface CelAstOptimizer { + + /** Optimizes a single AST. */ + CelAbstractSyntaxTree optimize(CelNavigableAst navigableAst, Cel cel); + + /** + * Replaces a subtree in the given CelExpr. This operation is intended for AST optimization + * purposes. + * + *

This is a very dangerous operation. Callers should re-typecheck the mutated AST and + * additionally verify that the resulting AST is semantically valid. + * + *

All expression IDs will be renumbered in a stable manner to ensure there's no ID collision + * between the nodes. The renumbering occurs even if the subtree was not replaced. + * + * @param ast Original ast to mutate. + * @param newExpr New CelExpr to replace the subtree with. + * @param exprIdToReplace Expression id of the subtree that is getting replaced. + */ + default CelAbstractSyntaxTree replaceSubtree( + CelAbstractSyntaxTree ast, CelExpr newExpr, long exprIdToReplace) { + CelExpr newRoot = MutableAst.replaceSubtree(ast.getExpr(), newExpr, exprIdToReplace); + + return CelAbstractSyntaxTree.newCheckedAst( + newRoot, ast.getSource(), ast.getReferenceMap(), ast.getTypeMap()); + } +} diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelOptimizationException.java b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizationException.java new file mode 100644 index 000000000..baa79e618 --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizationException.java @@ -0,0 +1,23 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +/** Checked exception thrown by CelOptimizer during AST optimization. */ +public final class CelOptimizationException extends Exception { + + public CelOptimizationException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizer.java new file mode 100644 index 000000000..a47b83e98 --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizer.java @@ -0,0 +1,36 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +import dev.cel.common.CelAbstractSyntaxTree; + +/** Public interface for optimizing an AST. */ +public interface CelOptimizer { + + /** + * Performs custom optimization of the provided AST. + * + *

This invokes all the AST optimizers present in this CelOptimizer instance via {@link + * CelOptimizerBuilder#addAstOptimizers} in their added order. Any exceptions thrown within the + * AST optimizer will be propagated to the caller and will abort the optimization process. + * + *

Note that the produced expression string from unparsing an optimized AST will likely not be + * equal to the original expression. + * + * @param ast A type-checked AST. + * @throws CelOptimizationException If any failures occur during any of the AST optimization pass. + */ + CelAbstractSyntaxTree optimize(CelAbstractSyntaxTree ast) throws CelOptimizationException; +} diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerBuilder.java b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerBuilder.java new file mode 100644 index 000000000..abfed8f38 --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerBuilder.java @@ -0,0 +1,34 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import com.google.errorprone.annotations.CheckReturnValue; + +/** Interface for building an instance of CelOptimizer. */ +public interface CelOptimizerBuilder { + + /** Adds one or more optimizer to perform custom AST optimizations. */ + @CanIgnoreReturnValue + CelOptimizerBuilder addAstOptimizers(CelAstOptimizer... astOptimizers); + + /** Adds one or more optimizer to perform custom AST optimizations. */ + @CanIgnoreReturnValue + CelOptimizerBuilder addAstOptimizers(Iterable astOptimizers); + + /** Build a new instance of the {@link CelOptimizer}. */ + @CheckReturnValue + CelOptimizer build(); +} diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerFactory.java b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerFactory.java new file mode 100644 index 000000000..1ebfd293e --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerFactory.java @@ -0,0 +1,47 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; +import dev.cel.checker.CelChecker; +import dev.cel.compiler.CelCompiler; +import dev.cel.compiler.CelCompilerFactory; +import dev.cel.parser.CelParser; +import dev.cel.runtime.CelRuntime; + +/** Factory class for constructing an {@link CelOptimizer} instance. */ +public final class CelOptimizerFactory { + + /** Create a new builder for constructing a {@link CelOptimizer} instance. */ + public static CelOptimizerBuilder standardCelOptimizerBuilder(Cel cel) { + return CelOptimizerImpl.newBuilder(cel); + } + + /** Create a new builder for constructing a {@link CelOptimizer} instance. */ + public static CelOptimizerBuilder standardCelOptimizerBuilder( + CelCompiler celCompiler, CelRuntime celRuntime) { + return standardCelOptimizerBuilder(CelFactory.combine(celCompiler, celRuntime)); + } + + /** Create a new builder for constructing a {@link CelOptimizer} instance. */ + public static CelOptimizerBuilder standardCelOptimizerBuilder( + CelParser celParser, CelChecker celChecker, CelRuntime celRuntime) { + return standardCelOptimizerBuilder( + CelCompilerFactory.combine(celParser, celChecker), celRuntime); + } + + private CelOptimizerFactory() {} +} diff --git a/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerImpl.java b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerImpl.java new file mode 100644 index 000000000..8ce74e1a3 --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/CelOptimizerImpl.java @@ -0,0 +1,91 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.collect.ImmutableSet; +import dev.cel.bundle.Cel; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelValidationException; +import dev.cel.common.navigation.CelNavigableAst; +import java.util.Arrays; + +final class CelOptimizerImpl implements CelOptimizer { + private final Cel cel; + private final ImmutableSet astOptimizers; + + CelOptimizerImpl(Cel cel, ImmutableSet astOptimizers) { + this.cel = cel; + this.astOptimizers = astOptimizers; + } + + @Override + public CelAbstractSyntaxTree optimize(CelAbstractSyntaxTree ast) throws CelOptimizationException { + if (!ast.isChecked()) { + throw new IllegalArgumentException("AST must be type-checked."); + } + + CelAbstractSyntaxTree optimizedAst = ast; + try { + for (CelAstOptimizer optimizer : astOptimizers) { + CelNavigableAst navigableAst = CelNavigableAst.fromAst(ast); + optimizedAst = optimizer.optimize(navigableAst, cel); + optimizedAst = cel.check(optimizedAst).getAst(); + } + } catch (CelValidationException e) { + throw new CelOptimizationException( + "Optimized AST failed to type-check: " + e.getMessage(), e); + } catch (RuntimeException e) { + throw new CelOptimizationException("Optimization failure: " + e.getMessage(), e); + } + + return optimizedAst; + } + + /** Create a new builder for constructing a {@link CelOptimizer} instance. */ + static CelOptimizerImpl.Builder newBuilder(Cel cel) { + return new CelOptimizerImpl.Builder(cel); + } + + /** Builder class for {@link CelOptimizerImpl}. */ + static final class Builder implements CelOptimizerBuilder { + private final Cel cel; + private final ImmutableSet.Builder astOptimizers; + + private Builder(Cel cel) { + this.cel = cel; + this.astOptimizers = ImmutableSet.builder(); + } + + @Override + public CelOptimizerBuilder addAstOptimizers(CelAstOptimizer... astOptimizers) { + checkNotNull(astOptimizers); + return addAstOptimizers(Arrays.asList(astOptimizers)); + } + + @Override + public CelOptimizerBuilder addAstOptimizers(Iterable astOptimizers) { + checkNotNull(astOptimizers); + this.astOptimizers.addAll(astOptimizers); + return this; + } + + @Override + public CelOptimizer build() { + return new CelOptimizerImpl(cel, astOptimizers.build()); + } + } +} diff --git a/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java new file mode 100644 index 000000000..03eb4f762 --- /dev/null +++ b/optimizer/src/main/java/dev/cel/optimizer/MutableAst.java @@ -0,0 +1,164 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +import com.google.common.collect.ImmutableList; +import dev.cel.common.annotations.Internal; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelCall; +import dev.cel.common.ast.CelExpr.CelCreateList; +import dev.cel.common.ast.CelExpr.CelCreateMap; +import dev.cel.common.ast.CelExpr.CelCreateStruct; +import dev.cel.common.ast.CelExpr.CelSelect; +import dev.cel.common.ast.CelExprIdGenerator; +import dev.cel.common.ast.CelExprIdGeneratorFactory; + +/** MutableAst contains logic for mutating a {@link CelExpr}. */ +@Internal +final class MutableAst { + private static final int MAX_ITERATION_COUNT = 500; + private final CelExpr.Builder newExpr; + private final long exprIdToReplace; + private final CelExprIdGenerator celExprIdGenerator; + private int iterationCount; + + private MutableAst(CelExprIdGenerator celExprIdGenerator, CelExpr.Builder newExpr, long exprId) { + this.celExprIdGenerator = celExprIdGenerator; + this.newExpr = newExpr; + this.exprIdToReplace = exprId; + } + + /** + * Replaces a subtree in the given CelExpr. This is a very dangerous operation. Callers should + * re-typecheck the mutated AST and additionally verify that the resulting AST is semantically + * valid. + * + *

This method should remain package-private. + */ + static CelExpr replaceSubtree(CelExpr root, CelExpr newExpr, long exprIdToReplace) { + // Zero out the expr IDs in the new expression tree first. This ensures that no ID collision + // occurs while attempting to replace the subtree, potentially leading to infinite loop + CelExpr.Builder newExprBuilder = newExpr.toBuilder(); + MutableAst mutableAst = new MutableAst(() -> 0, CelExpr.newBuilder(), -1); + newExprBuilder = mutableAst.visit(newExprBuilder); + + // Replace the subtree + mutableAst = + new MutableAst( + CelExprIdGeneratorFactory.newMonotonicIdGenerator(0), newExprBuilder, exprIdToReplace); + + // TODO: Normalize IDs for macro calls + + return mutableAst.visit(root.toBuilder()).build(); + } + + private CelExpr.Builder visit(CelExpr.Builder expr) { + if (++iterationCount > MAX_ITERATION_COUNT) { + throw new IllegalStateException("Max iteration count reached."); + } + + if (expr.id() == exprIdToReplace) { + return visit(newExpr); + } + + switch (expr.exprKind().getKind()) { + case SELECT: + return visit(expr, expr.select().toBuilder()); + case CALL: + return visit(expr, expr.call().toBuilder()); + case CREATE_LIST: + return visit(expr, expr.createList().toBuilder()); + case CREATE_STRUCT: + return visit(expr, expr.createStruct().toBuilder()); + case CREATE_MAP: + return visit(expr, expr.createMap().toBuilder()); + case COMPREHENSION: + throw new UnsupportedOperationException("Augmenting comprehensions are not supported"); + case CONSTANT: + case IDENT: + return expr.setId(celExprIdGenerator.nextExprId()); + default: + throw new IllegalArgumentException("unexpected expr kind: " + expr.exprKind().getKind()); + } + } + + private CelExpr.Builder visit(CelExpr.Builder celExpr, CelSelect.Builder selectExpr) { + CelExpr.Builder visitedOperand = visit(selectExpr.operand().toBuilder()); + selectExpr = selectExpr.setOperand(visitedOperand.build()); + + return celExpr.setSelect(selectExpr.build()).setId(celExprIdGenerator.nextExprId()); + } + + private CelExpr.Builder visit(CelExpr.Builder celExpr, CelCall.Builder callExpr) { + if (callExpr.target().isPresent()) { + CelExpr.Builder visitedTargetExpr = visit(callExpr.target().get().toBuilder()); + callExpr = callExpr.setTarget(visitedTargetExpr.build()); + + celExpr.setCall(callExpr.build()); + } + + ImmutableList args = callExpr.args(); + for (int i = 0; i < args.size(); i++) { + CelExpr arg = args.get(i); + CelExpr.Builder visitedArg = visit(arg.toBuilder()); + callExpr.setArg(i, visitedArg.build()); + } + + return celExpr.setCall(callExpr.build()).setId(celExprIdGenerator.nextExprId()); + } + + private CelExpr.Builder visit(CelExpr.Builder celExpr, CelCreateList.Builder createListBuilder) { + ImmutableList elements = createListBuilder.getElements(); + for (int i = 0; i < elements.size(); i++) { + CelExpr.Builder visitedElement = visit(elements.get(i).toBuilder()); + createListBuilder.setElement(i, visitedElement.build()); + } + + return celExpr.setCreateList(createListBuilder.build()).setId(celExprIdGenerator.nextExprId()); + } + + private CelExpr.Builder visit( + CelExpr.Builder celExpr, CelCreateStruct.Builder createStructBuilder) { + ImmutableList entries = createStructBuilder.getEntries(); + for (int i = 0; i < entries.size(); i++) { + CelCreateStruct.Entry.Builder entryBuilder = + entries.get(i).toBuilder().setId(celExprIdGenerator.nextExprId()); + CelExpr.Builder visitedValue = visit(entryBuilder.value().toBuilder()); + entryBuilder.setValue(visitedValue.build()); + + createStructBuilder.setEntry(i, entryBuilder.build()); + } + + return celExpr + .setCreateStruct(createStructBuilder.build()) + .setId(celExprIdGenerator.nextExprId()); + } + + private CelExpr.Builder visit(CelExpr.Builder celExpr, CelCreateMap.Builder createMapBuilder) { + ImmutableList entries = createMapBuilder.getEntries(); + for (int i = 0; i < entries.size(); i++) { + CelCreateMap.Entry.Builder entryBuilder = + entries.get(i).toBuilder().setId(celExprIdGenerator.nextExprId()); + CelExpr.Builder visitedKey = visit(entryBuilder.key().toBuilder()); + entryBuilder.setKey(visitedKey.build()); + CelExpr.Builder visitedValue = visit(entryBuilder.value().toBuilder()); + entryBuilder.setValue(visitedValue.build()); + + createMapBuilder.setEntry(i, entryBuilder.build()); + } + + return celExpr.setCreateMap(createMapBuilder.build()).setId(celExprIdGenerator.nextExprId()); + } +} diff --git a/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel b/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel new file mode 100644 index 000000000..075aaad2d --- /dev/null +++ b/optimizer/src/test/java/dev/cel/optimizer/BUILD.bazel @@ -0,0 +1,39 @@ +load("//:testing.bzl", "junit4_test_suites") + +package(default_applicable_licenses = ["//:license"]) + +java_library( + name = "tests", + testonly = 1, + srcs = glob(["*.java"]), + deps = [ + "//:java_truth", + "//bundle:cel", + "//common", + "//common:compiler_common", + "//common/ast", + "//common/resources/testdata/proto3:test_all_types_java_proto", + "//common/types", + "//compiler", + "//optimizer", + "//optimizer:optimization_exception", + "//optimizer:optimizer_builder", + "//optimizer:optimizer_impl", + "//optimizer/src/main/java/dev/cel/optimizer:mutable_ast", + "//parser", + "//parser:macro", + "//parser:unparser", + "//runtime", + "@maven//:com_google_testparameterinjector_test_parameter_injector", + "@maven//:junit_junit", + ], +) + +junit4_test_suites( + name = "test_suites", + sizes = [ + "small", + ], + src_dir = "src/test/java", + deps = [":tests"], +) diff --git a/optimizer/src/test/java/dev/cel/optimizer/CelOptimizerFactoryTest.java b/optimizer/src/test/java/dev/cel/optimizer/CelOptimizerFactoryTest.java new file mode 100644 index 000000000..41c7ecd74 --- /dev/null +++ b/optimizer/src/test/java/dev/cel/optimizer/CelOptimizerFactoryTest.java @@ -0,0 +1,61 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +import static com.google.common.truth.Truth.assertThat; + +import dev.cel.bundle.CelFactory; +import dev.cel.compiler.CelCompilerFactory; +import dev.cel.parser.CelParserFactory; +import dev.cel.runtime.CelRuntimeFactory; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CelOptimizerFactoryTest { + + @Test + public void standardCelOptimizerBuilder_withParserCheckerAndRuntime() { + CelOptimizerBuilder builder = + CelOptimizerFactory.standardCelOptimizerBuilder( + CelParserFactory.standardCelParserBuilder().build(), + CelCompilerFactory.standardCelCheckerBuilder().build(), + CelRuntimeFactory.standardCelRuntimeBuilder().build()); + + assertThat(builder).isNotNull(); + assertThat(builder.build()).isNotNull(); + } + + @Test + public void standardCelOptimizerBuilder_withCompilerAndRuntime() { + CelOptimizerBuilder builder = + CelOptimizerFactory.standardCelOptimizerBuilder( + CelCompilerFactory.standardCelCompilerBuilder().build(), + CelRuntimeFactory.standardCelRuntimeBuilder().build()); + + assertThat(builder).isNotNull(); + assertThat(builder.build()).isNotNull(); + } + + @Test + public void standardCelOptimizerBuilder_withCel() { + CelOptimizerBuilder builder = + CelOptimizerFactory.standardCelOptimizerBuilder(CelFactory.standardCelBuilder().build()); + + assertThat(builder).isNotNull(); + assertThat(builder.build()).isNotNull(); + } +} diff --git a/optimizer/src/test/java/dev/cel/optimizer/CelOptimizerImplTest.java b/optimizer/src/test/java/dev/cel/optimizer/CelOptimizerImplTest.java new file mode 100644 index 000000000..fdb6cb968 --- /dev/null +++ b/optimizer/src/test/java/dev/cel/optimizer/CelOptimizerImplTest.java @@ -0,0 +1,131 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelSource; +import dev.cel.common.CelValidationException; +import dev.cel.common.ast.CelExpr; +import java.util.ArrayList; +import java.util.List; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class CelOptimizerImplTest { + + private static final Cel CEL = CelFactory.standardCelBuilder().build(); + + @Test + public void constructCelOptimizer_success() { + CelOptimizer celOptimizer = + CelOptimizerImpl.newBuilder(CEL) + .addAstOptimizers( + (navigableAst, cel) -> + // no-op + navigableAst.getAst()) + .build(); + + assertThat(celOptimizer).isNotNull(); + assertThat(celOptimizer).isInstanceOf(CelOptimizerImpl.class); + } + + @Test + public void astOptimizers_invokedInOrder() throws Exception { + List list = new ArrayList<>(); + + CelOptimizer celOptimizer = + CelOptimizerImpl.newBuilder(CEL) + .addAstOptimizers( + (navigableAst, cel) -> { + list.add(1); + return navigableAst.getAst(); + }) + .addAstOptimizers( + (navigableAst, cel) -> { + list.add(2); + return navigableAst.getAst(); + }) + .addAstOptimizers( + (navigableAst, cel) -> { + list.add(3); + return navigableAst.getAst(); + }) + .build(); + + CelAbstractSyntaxTree ast = celOptimizer.optimize(CEL.compile("'hello world'").getAst()); + + assertThat(ast).isNotNull(); + assertThat(list).containsExactly(1, 2, 3).inOrder(); + } + + @Test + public void optimizer_whenAstOptimizerThrows_throwsException() { + CelOptimizer celOptimizer = + CelOptimizerImpl.newBuilder(CEL) + .addAstOptimizers( + (navigableAst, cel) -> { + throw new IllegalArgumentException("Test exception"); + }) + .build(); + + CelOptimizationException e = + assertThrows( + CelOptimizationException.class, + () -> celOptimizer.optimize(CEL.compile("'hello world'").getAst())); + assertThat(e).hasMessageThat().isEqualTo("Optimization failure: Test exception"); + assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void parsedAst_throwsException() { + CelOptimizer celOptimizer = CelOptimizerImpl.newBuilder(CEL).build(); + + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, + () -> celOptimizer.optimize(CEL.parse("'test'").getAst())); + assertThat(e).hasMessageThat().contains("AST must be type-checked."); + } + + @Test + public void optimizedAst_failsToTypeCheck_throwsException() { + CelOptimizer celOptimizer = + CelOptimizerImpl.newBuilder(CEL) + .addAstOptimizers( + (navigableAst, cel) -> + CelAbstractSyntaxTree.newParsedAst( + CelExpr.ofIdentExpr(1, "undeclared_ident"), CelSource.newBuilder().build())) + .build(); + + CelOptimizationException e = + assertThrows( + CelOptimizationException.class, + () -> celOptimizer.optimize(CEL.compile("'hello world'").getAst())); + + assertThat(e) + .hasMessageThat() + .contains( + "Optimized AST failed to type-check: ERROR: :1:1: undeclared reference to" + + " 'undeclared_ident' (in container '')"); + assertThat(e).hasCauseThat().isInstanceOf(CelValidationException.class); + } +} diff --git a/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java new file mode 100644 index 000000000..5cff92d09 --- /dev/null +++ b/optimizer/src/test/java/dev/cel/optimizer/MutableAstTest.java @@ -0,0 +1,331 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.optimizer; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.testing.junit.testparameterinjector.TestParameterInjector; +import dev.cel.bundle.Cel; +import dev.cel.bundle.CelFactory; +import dev.cel.common.CelAbstractSyntaxTree; +import dev.cel.common.CelFunctionDecl; +import dev.cel.common.CelOverloadDecl; +import dev.cel.common.CelSource; +import dev.cel.common.ast.CelConstant; +import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelIdent; +import dev.cel.common.ast.CelExpr.CelSelect; +import dev.cel.common.types.SimpleType; +import dev.cel.common.types.StructTypeReference; +import dev.cel.parser.CelStandardMacro; +import dev.cel.parser.CelUnparser; +import dev.cel.parser.CelUnparserFactory; +import dev.cel.testing.testdata.proto3.TestAllTypesProto.TestAllTypes; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(TestParameterInjector.class) +public class MutableAstTest { + private static final Cel CEL = + CelFactory.standardCelBuilder() + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addMessageTypes(TestAllTypes.getDescriptor()) + .setContainer("dev.cel.testing.testdata.proto3") + .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) + .addVar("x", SimpleType.INT) + .build(); + + private static final CelUnparser CEL_UNPARSER = CelUnparserFactory.newUnparser(); + + @Test + public void constExpr() throws Exception { + CelExpr root = CEL.compile("10").getAst().getExpr(); + + CelExpr replacedExpr = + MutableAst.replaceSubtree( + root, CelExpr.newBuilder().setConstant(CelConstant.ofValue(true)).build(), 1); + + assertThat(replacedExpr).isEqualTo(CelExpr.ofConstantExpr(1, CelConstant.ofValue(true))); + } + + @Test + public void globalCallExpr_replaceRoot() throws Exception { + // Tree shape (brackets are expr IDs): + // + [4] + // + [2] x [5] + // 1 [1] 2 [3] + CelExpr root = CEL.compile("1 + 2 + x").getAst().getExpr(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + root, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 4); + + assertThat(replacedRoot).isEqualTo(CelExpr.ofConstantExpr(1, CelConstant.ofValue(10))); + } + + @Test + public void globalCallExpr_replaceLeaf() throws Exception { + // Tree shape (brackets are expr IDs): + // + [4] + // + [2] x [5] + // 1 [1] 2 [3] + CelExpr root = CEL.compile("1 + 2 + x").getAst().getExpr(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + root, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 1); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10 + 2 + x"); + } + + @Test + public void globalCallExpr_replaceMiddleBranch() throws Exception { + // Tree shape (brackets are expr IDs): + // + [4] + // + [2] x [5] + // 1 [1] 2 [3] + CelExpr root = CEL.compile("1 + 2 + x").getAst().getExpr(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + root, CelExpr.newBuilder().setConstant(CelConstant.ofValue(10)).build(), 2); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10 + x"); + } + + @Test + public void globalCallExpr_replaceMiddleBranch_withCallExpr() throws Exception { + // Tree shape (brackets are expr IDs): + // + [4] + // + [2] x [5] + // 1 [1] 2 [3] + CelExpr root = CEL.compile("1 + 2 + x").getAst().getExpr(); + CelExpr root2 = CEL.compile("4 + 5 + 6").getAst().getExpr(); + + CelExpr replacedRoot = MutableAst.replaceSubtree(root, root2, 2); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("4 + 5 + 6 + x"); + } + + @Test + public void memberCallExpr_replaceLeafTarget() throws Exception { + // Tree shape (brackets are expr IDs): + // func [2] + // 10 [1] func [4] + // 4 [3] 5 [5] + Cel cel = + CelFactory.standardCelBuilder() + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "func", + CelOverloadDecl.newMemberOverload( + "func_overload", SimpleType.INT, SimpleType.INT, SimpleType.INT))) + .build(); + CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 3); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10.func(20.func(5))"); + } + + @Test + public void memberCallExpr_replaceLeafArgument() throws Exception { + // Tree shape (brackets are expr IDs): + // func [2] + // 10 [1] func [4] + // 4 [3] 5 [5] + Cel cel = + CelFactory.standardCelBuilder() + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "func", + CelOverloadDecl.newMemberOverload( + "func_overload", SimpleType.INT, SimpleType.INT, SimpleType.INT))) + .build(); + CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 5); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10.func(4.func(20))"); + } + + @Test + public void memberCallExpr_replaceMiddleBranchTarget() throws Exception { + // Tree shape (brackets are expr IDs): + // func [2] + // 10 [1] func [4] + // 4 [3] 5 [5] + Cel cel = + CelFactory.standardCelBuilder() + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "func", + CelOverloadDecl.newMemberOverload( + "func_overload", SimpleType.INT, SimpleType.INT, SimpleType.INT))) + .build(); + CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 1); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("20.func(4.func(5))"); + } + + @Test + public void memberCallExpr_replaceMiddleBranchArgument() throws Exception { + // Tree shape (brackets are expr IDs): + // func [2] + // 10 [1] func [4] + // 4 [3] 5 [5] + Cel cel = + CelFactory.standardCelBuilder() + .addFunctionDeclarations( + CelFunctionDecl.newFunctionDeclaration( + "func", + CelOverloadDecl.newMemberOverload( + "func_overload", SimpleType.INT, SimpleType.INT, SimpleType.INT))) + .build(); + CelAbstractSyntaxTree ast = cel.compile("10.func(4.func(5))").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(20)).build(), 4); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("10.func(20)"); + } + + @Test + public void select_replaceField() throws Exception { + // Tree shape (brackets are expr IDs): + // + [2] + // 5 [1] select [4] + // msg [3] + CelAbstractSyntaxTree ast = CEL.compile("5 + msg.single_int64").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), + CelExpr.newBuilder() + .setSelect( + CelSelect.newBuilder() + .setField("single_sint32") + .setOperand( + CelExpr.newBuilder() + .setIdent(CelIdent.newBuilder().setName("test").build()) + .build()) + .build()) + .build(), + 4); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("5 + test.single_sint32"); + } + + @Test + public void select_replaceOperand() throws Exception { + // Tree shape (brackets are expr IDs): + // + [2] + // 5 [1] select [4] + // msg [3] + CelAbstractSyntaxTree ast = CEL.compile("5 + msg.single_int64").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), + CelExpr.newBuilder().setIdent(CelIdent.newBuilder().setName("test").build()).build(), + 3); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("5 + test.single_int64"); + } + + @Test + public void list_replaceElement() throws Exception { + // Tree shape (brackets are expr IDs): + // list [1] + // 2 [2] 3 [3] 4 [4] + CelAbstractSyntaxTree ast = CEL.compile("[2, 3, 4]").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 4); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("[2, 3, 5]"); + } + + @Test + public void createStruct_replaceValue() throws Exception { + // Tree shape (brackets are expr IDs): + // TestAllTypes [1] + // single_int64 [2] + // 2 [3] + CelAbstractSyntaxTree ast = CEL.compile("TestAllTypes{single_int64: 2}").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 3); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("TestAllTypes{single_int64: 5}"); + } + + @Test + public void createMap_replaceKey() throws Exception { + // Tree shape (brackets are expr IDs): + // map [1] + // map_entry [2] + // 'a' [3] : 1 [4] + CelAbstractSyntaxTree ast = CEL.compile("{'a': 1}").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 3); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("{5: 1}"); + } + + @Test + public void createMap_replaceValue() throws Exception { + // Tree shape (brackets are expr IDs): + // map [1] + // map_entry [2] + // 'a' [3] : 1 [4] + CelAbstractSyntaxTree ast = CEL.compile("{'a': 1}").getAst(); + + CelExpr replacedRoot = + MutableAst.replaceSubtree( + ast.getExpr(), CelExpr.newBuilder().setConstant(CelConstant.ofValue(5)).build(), 4); + + assertThat(getUnparsedExpression(replacedRoot)).isEqualTo("{\"a\": 5}"); + } + + @Test + public void invalidCelExprKind_throwsException() { + assertThrows( + IllegalArgumentException.class, + () -> + MutableAst.replaceSubtree( + CelExpr.ofConstantExpr(1, CelConstant.ofValue("test")), CelExpr.ofNotSet(1), 1)); + } + + private static String getUnparsedExpression(CelExpr expr) { + CelAbstractSyntaxTree ast = + CelAbstractSyntaxTree.newParsedAst(expr, CelSource.newBuilder().build()); + return CEL_UNPARSER.unparse(ast); + } +} diff --git a/parser/src/main/java/dev/cel/parser/BUILD.bazel b/parser/src/main/java/dev/cel/parser/BUILD.bazel index d9d826c49..3def9c472 100644 --- a/parser/src/main/java/dev/cel/parser/BUILD.bazel +++ b/parser/src/main/java/dev/cel/parser/BUILD.bazel @@ -34,6 +34,7 @@ MACRO_SOURCES = [ # keep sorted UNPARSER_SOURCES = [ "CelUnparser.java", + "CelUnparserFactory.java", "CelUnparserImpl.java", ] diff --git a/parser/src/main/java/dev/cel/parser/CelUnparserFactory.java b/parser/src/main/java/dev/cel/parser/CelUnparserFactory.java new file mode 100644 index 000000000..e7a14bb8b --- /dev/null +++ b/parser/src/main/java/dev/cel/parser/CelUnparserFactory.java @@ -0,0 +1,24 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.parser; + +/** Factory class for producing {@link CelUnparser} instances and builders. */ +public final class CelUnparserFactory { + public static CelUnparser newUnparser() { + return new CelUnparserImpl(); + } + + private CelUnparserFactory() {} +} diff --git a/validator/src/main/java/dev/cel/validator/validators/DurationLiteralValidator.java b/validator/src/main/java/dev/cel/validator/validators/DurationLiteralValidator.java index 622bc2e1d..349374ca2 100644 --- a/validator/src/main/java/dev/cel/validator/validators/DurationLiteralValidator.java +++ b/validator/src/main/java/dev/cel/validator/validators/DurationLiteralValidator.java @@ -17,7 +17,7 @@ import com.google.protobuf.Duration; /** DurationLiteralValidator ensures that duration literal arguments are valid. */ -public class DurationLiteralValidator extends LiteralValidator { +public final class DurationLiteralValidator extends LiteralValidator { public static final DurationLiteralValidator INSTANCE = new DurationLiteralValidator("duration", Duration.class); diff --git a/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java b/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java index db787d89f..a6d641311 100644 --- a/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java +++ b/validator/src/main/java/dev/cel/validator/validators/HomogeneousLiteralValidator.java @@ -34,7 +34,7 @@ * HomogeneousLiteralValidator checks that all list and map literals entries have the same types, * i.e. no mixed list element types or mixed map key or map value types. */ -public class HomogeneousLiteralValidator implements CelAstValidator { +public final class HomogeneousLiteralValidator implements CelAstValidator { private final ImmutableSet exemptFunctions; /** diff --git a/validator/src/main/java/dev/cel/validator/validators/RegexLiteralValidator.java b/validator/src/main/java/dev/cel/validator/validators/RegexLiteralValidator.java index 6c1ac1437..fc79083c5 100644 --- a/validator/src/main/java/dev/cel/validator/validators/RegexLiteralValidator.java +++ b/validator/src/main/java/dev/cel/validator/validators/RegexLiteralValidator.java @@ -24,7 +24,7 @@ import java.util.regex.PatternSyntaxException; /** RegexLiteralValidator ensures that regex patterns are valid. */ -public class RegexLiteralValidator implements CelAstValidator { +public final class RegexLiteralValidator implements CelAstValidator { public static final RegexLiteralValidator INSTANCE = new RegexLiteralValidator(); @Override diff --git a/validator/src/main/java/dev/cel/validator/validators/TimestampLiteralValidator.java b/validator/src/main/java/dev/cel/validator/validators/TimestampLiteralValidator.java index bc179c97f..4f6a5209e 100644 --- a/validator/src/main/java/dev/cel/validator/validators/TimestampLiteralValidator.java +++ b/validator/src/main/java/dev/cel/validator/validators/TimestampLiteralValidator.java @@ -17,7 +17,7 @@ import com.google.protobuf.Timestamp; /** TimestampLiteralValidator ensures that timestamp literal arguments are valid. */ -public class TimestampLiteralValidator extends LiteralValidator { +public final class TimestampLiteralValidator extends LiteralValidator { public static final TimestampLiteralValidator INSTANCE = new TimestampLiteralValidator("timestamp", Timestamp.class); diff --git a/validator/src/test/java/dev/cel/validator/CelValidatorImplTest.java b/validator/src/test/java/dev/cel/validator/CelValidatorImplTest.java index 495b10e9e..521faa9d4 100644 --- a/validator/src/test/java/dev/cel/validator/CelValidatorImplTest.java +++ b/validator/src/test/java/dev/cel/validator/CelValidatorImplTest.java @@ -31,8 +31,6 @@ public class CelValidatorImplTest { private static final Cel CEL = CelFactory.standardCelBuilder().build(); - // private static final CelValidatorImpl CEL_VALIDATOR = new CelValidatorImpl(CEL, - @Test public void constructCelValidator_success() { CelValidator celValidator =