Skip to content

Commit

Permalink
Fix unpacking any messages containing extension fields
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 572685714
  • Loading branch information
l46kok authored and copybara-github committed Oct 17, 2023
1 parent 3c3aa79 commit ce2ae17
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 6 deletions.
8 changes: 8 additions & 0 deletions bundle/src/main/java/dev/cel/bundle/CelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import dev.cel.checker.ProtoTypeMask;
import dev.cel.checker.TypeProvider;
Expand Down Expand Up @@ -282,6 +283,13 @@ public interface CelBuilder {
@CanIgnoreReturnValue
CelBuilder addRuntimeLibraries(Iterable<CelRuntimeLibrary> libraries);

/**
* Sets a proto ExtensionRegistry to assist with unpacking Any messages containing a proto2
extension field.
*/
@CanIgnoreReturnValue
CelBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry);

/** Construct a new {@code Cel} instance from the provided configuration. */
Cel build();
}
8 changes: 8 additions & 0 deletions bundle/src/main/java/dev/cel/bundle/CelImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import dev.cel.checker.CelCheckerBuilder;
import dev.cel.checker.ProtoTypeMask;
Expand Down Expand Up @@ -339,6 +340,13 @@ public Builder addRuntimeLibraries(Iterable<CelRuntimeLibrary> libraries) {
return this;
}

@Override
public CelBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry) {
checkNotNull(extensionRegistry);
runtimeBuilder.setExtensionRegistry(extensionRegistry);
return this;
}

@Override
public Cel build() {
return new CelImpl(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ public final class DefaultDescriptorPool implements CelDescriptorPool {

/** A DefaultDescriptorPool instance with just well known types loaded. */
public static final DefaultDescriptorPool INSTANCE =
new DefaultDescriptorPool(WELL_KNOWN_TYPE_DESCRIPTORS, ImmutableMultimap.of());
new DefaultDescriptorPool(
WELL_KNOWN_TYPE_DESCRIPTORS,
ImmutableMultimap.of(),
ExtensionRegistry.getEmptyRegistry());

// K: Fully qualified message type name, V: Message descriptor
private final ImmutableMap<String, Descriptor> descriptorMap;
Expand All @@ -55,7 +58,15 @@ public final class DefaultDescriptorPool implements CelDescriptorPool {
// V: Field descriptor for the extension message
private final ImmutableMultimap<String, FieldDescriptor> extensionDescriptorMap;

@SuppressWarnings("Immutable") // ExtensionRegistry is immutable, just not marked as such.
private final ExtensionRegistry extensionRegistry;

public static DefaultDescriptorPool create(CelDescriptors celDescriptors) {
return create(celDescriptors, ExtensionRegistry.getEmptyRegistry());
}

public static DefaultDescriptorPool create(
CelDescriptors celDescriptors, ExtensionRegistry extensionRegistry) {
Map<String, Descriptor> descriptorMap = new HashMap<>(); // Using a hashmap to allow deduping
stream(WellKnownProto.values()).forEach(d -> descriptorMap.put(d.typeName(), d.descriptor()));

Expand All @@ -64,7 +75,9 @@ public static DefaultDescriptorPool create(CelDescriptors celDescriptors) {
}

return new DefaultDescriptorPool(
ImmutableMap.copyOf(descriptorMap), celDescriptors.extensionDescriptors());
ImmutableMap.copyOf(descriptorMap),
celDescriptors.extensionDescriptors(),
extensionRegistry);
}

@Override
Expand All @@ -83,14 +96,15 @@ public Optional<FieldDescriptor> findExtensionDescriptor(

@Override
public ExtensionRegistry getExtensionRegistry() {
// TODO: Populate one from runtime builder.
return ExtensionRegistry.getEmptyRegistry();
return extensionRegistry;
}

private DefaultDescriptorPool(
ImmutableMap<String, Descriptor> descriptorMap,
ImmutableMultimap<String, FieldDescriptor> extensionDescriptorMap) {
ImmutableMultimap<String, FieldDescriptor> extensionDescriptorMap,
ExtensionRegistry extensionRegistry) {
this.descriptorMap = checkNotNull(descriptorMap);
this.extensionDescriptorMap = checkNotNull(extensionDescriptorMap);
this.extensionRegistry = checkNotNull(extensionRegistry);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.protobuf.Any;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.testing.junit.testparameterinjector.TestParameter;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;
import com.google.testing.junit.testparameterinjector.TestParameters;
import dev.cel.bundle.Cel;
import dev.cel.bundle.CelFactory;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelFunctionDecl;
import dev.cel.common.CelOverloadDecl;
Expand Down Expand Up @@ -277,6 +281,29 @@ public void getExt_nonProtoNamespace_success(String expr) throws Exception {
assertThat(result).isTrue();
}

@Test
public void getExt_onAnyPackedExtensionField_success() throws Exception {
ExtensionRegistry extensionRegistry = ExtensionRegistry.newInstance();
MessagesProto2Extensions.registerAllExtensions(extensionRegistry);
Cel cel =
CelFactory.standardCelBuilder()
.addCompilerLibraries(CelExtensions.protos())
.addFileTypes(MessagesProto2Extensions.getDescriptor())
.setExtensionRegistry(extensionRegistry)
.addVar(
"msg", StructTypeReference.create("dev.cel.testing.testdata.proto2.Proto2Message"))
.build();
CelAbstractSyntaxTree ast =
cel.compile("proto.getExt(msg, dev.cel.testing.testdata.proto2.int32_ext)").getAst();
Any anyMsg =
Any.pack(
Proto2Message.newBuilder().setExtension(MessagesProto2Extensions.int32Ext, 1).build());

Long result = (Long) cel.createProgram(ast).eval(ImmutableMap.of("msg", anyMsg));

assertThat(result).isEqualTo(1);
}

private enum ParseErrorTestCase {
FIELD_NOT_FULLY_QUALIFIED(
"proto.getExt(Proto2ExtensionScopedMessage{}, int64_ext)",
Expand Down
8 changes: 8 additions & 0 deletions runtime/src/main/java/dev/cel/runtime/CelRuntimeBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import dev.cel.common.CelOptions;
import java.util.function.Function;
Expand Down Expand Up @@ -149,6 +150,13 @@ public interface CelRuntimeBuilder {
@CanIgnoreReturnValue
CelRuntimeBuilder addLibraries(Iterable<? extends CelRuntimeLibrary> libraries);

/**
* Sets a proto ExtensionRegistry to assist with unpacking Any messages containing a proto2
extension field.
*/
@CanIgnoreReturnValue
CelRuntimeBuilder setExtensionRegistry(ExtensionRegistry extensionRegistry);

/** Build a new instance of the {@code CelRuntime}. */
@CheckReturnValue
CelRuntime build();
Expand Down
15 changes: 14 additions & 1 deletion runtime/src/main/java/dev/cel/runtime/CelRuntimeLegacyImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import dev.cel.common.CelAbstractSyntaxTree;
import dev.cel.common.CelDescriptorUtil;
Expand Down Expand Up @@ -79,6 +80,7 @@ public static final class Builder implements CelRuntimeBuilder {

private boolean standardEnvironmentEnabled;
private Function<String, Message.Builder> customTypeFactory;
private ExtensionRegistry extensionRegistry;

@Override
@CanIgnoreReturnValue
Expand Down Expand Up @@ -161,6 +163,14 @@ public Builder addLibraries(Iterable<? extends CelRuntimeLibrary> libraries) {
return this;
}

@Override
@CanIgnoreReturnValue
public Builder setExtensionRegistry(ExtensionRegistry extensionRegistry) {
checkNotNull(extensionRegistry);
this.extensionRegistry = extensionRegistry.getUnmodifiable();
return this;
}

/** Build a new {@code CelRuntimeLegacyImpl} instance from the builder config. */
@Override
@CanIgnoreReturnValue
Expand All @@ -171,6 +181,7 @@ public CelRuntimeLegacyImpl build() {
CelDescriptorPool celDescriptorPool =
newDescriptorPool(
fileTypes.build(),
extensionRegistry,
options);

@SuppressWarnings("Immutable")
Expand Down Expand Up @@ -214,14 +225,15 @@ public CelRuntimeLegacyImpl build() {

private static CelDescriptorPool newDescriptorPool(
ImmutableSet<FileDescriptor> fileTypeSet,
ExtensionRegistry extensionRegistry,
CelOptions celOptions) {
CelDescriptors celDescriptors =
CelDescriptorUtil.getAllDescriptorsFromFileDescriptor(
fileTypeSet, celOptions.resolveTypeDependencies());

ImmutableList.Builder<CelDescriptorPool> descriptorPools = new ImmutableList.Builder<>();

descriptorPools.add(DefaultDescriptorPool.create(celDescriptors));
descriptorPools.add(DefaultDescriptorPool.create(celDescriptors, extensionRegistry));

return CombinedDescriptorPool.create(descriptorPools.build());
}
Expand All @@ -241,6 +253,7 @@ private Builder() {
this.fileTypes = ImmutableSet.builder();
this.functionBindings = ImmutableMap.builder();
this.celRuntimeLibraries = ImmutableSet.builder();
this.extensionRegistry = ExtensionRegistry.getEmptyRegistry();
this.customTypeFactory = null;
}
}
Expand Down

0 comments on commit ce2ae17

Please sign in to comment.