Skip to content

Commit

Permalink
Add support for extensions in CRuby, JRuby, and FFI Ruby (#14703)
Browse files Browse the repository at this point in the history
Follow up to #14594, which added support for custom options, this PR implements extensions support, which should fully resolve #1198.

Closes #14703

COPYBARA_INTEGRATE_REVIEW=#14703 from protocolbuffers:add-support-for-extensions-in-ruby 601aca4
PiperOrigin-RevId: 582460674
  • Loading branch information
JasonLunn authored and zhangskz committed Nov 14, 2023
1 parent 2495d4f commit 98fa596
Show file tree
Hide file tree
Showing 13 changed files with 257 additions and 55 deletions.
10 changes: 8 additions & 2 deletions ruby/ext/google/protobuf_c/defs.c
Original file line number Diff line number Diff line change
Expand Up @@ -144,20 +144,26 @@ VALUE DescriptorPool_add_serialized_file(VALUE _self,
* call-seq:
* DescriptorPool.lookup(name) => descriptor
*
* Finds a Descriptor or EnumDescriptor by name and returns it, or nil if none
* exists with the given name.
* Finds a Descriptor, EnumDescriptor or FieldDescriptor by name and returns it,
* or nil if none exists with the given name.
*/
static VALUE DescriptorPool_lookup(VALUE _self, VALUE name) {
DescriptorPool* self = ruby_to_DescriptorPool(_self);
const char* name_str = get_str(name);
const upb_MessageDef* msgdef;
const upb_EnumDef* enumdef;
const upb_FieldDef* fielddef;

msgdef = upb_DefPool_FindMessageByName(self->symtab, name_str);
if (msgdef) {
return get_msgdef_obj(_self, msgdef);
}

fielddef = upb_DefPool_FindExtensionByName(self->symtab, name_str);
if (fielddef) {
return get_fielddef_obj(_self, fielddef);
}

enumdef = upb_DefPool_FindEnumByName(self->symtab, name_str);
if (enumdef) {
return get_enumdef_obj(_self, enumdef);
Expand Down
10 changes: 8 additions & 2 deletions ruby/ext/google/protobuf_c/message.c
Original file line number Diff line number Diff line change
Expand Up @@ -977,9 +977,12 @@ VALUE Message_decode_bytes(int size, const char* bytes, int options,
VALUE msg_rb = initialize_rb_class_with_no_args(klass);
Message* msg = ruby_to_Message(msg_rb);

const upb_FileDef* file = upb_MessageDef_File(msg->msgdef);
const upb_ExtensionRegistry* extreg =
upb_DefPool_ExtensionRegistry(upb_FileDef_Pool(file));
upb_DecodeStatus status = upb_Decode(bytes, size, (upb_Message*)msg->msg,
upb_MessageDef_MiniTable(msg->msgdef),
NULL, options, Arena_get(msg->arena));
extreg, options, Arena_get(msg->arena));
if (status != kUpb_DecodeStatus_Ok) {
rb_raise(cParseError, "Error occurred during parsing");
}
Expand Down Expand Up @@ -1303,9 +1306,12 @@ upb_Message* Message_deep_copy(const upb_Message* msg, const upb_MessageDef* m,
upb_Message* new_msg = upb_Message_New(layout, arena);
char* data;

const upb_FileDef* file = upb_MessageDef_File(m);
const upb_ExtensionRegistry* extreg =
upb_DefPool_ExtensionRegistry(upb_FileDef_Pool(file));
if (upb_Encode(msg, layout, 0, tmp_arena, &data, &size) !=
kUpb_EncodeStatus_Ok ||
upb_Decode(data, size, new_msg, layout, NULL, 0, arena) !=
upb_Decode(data, size, new_msg, layout, extreg, 0, arena) !=
kUpb_DecodeStatus_Ok) {
upb_Arena_Free(tmp_arena);
rb_raise(cParseError, "Error occurred copying proto");
Expand Down
20 changes: 12 additions & 8 deletions ruby/lib/google/protobuf/ffi/descriptor_pool.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@ module Google
module Protobuf
class FFI
# DefPool
attach_function :add_serialized_file, :upb_DefPool_AddFile, [:DefPool, :FileDescriptorProto, Status.by_ref], :FileDef
attach_function :free_descriptor_pool, :upb_DefPool_Free, [:DefPool], :void
attach_function :create_descriptor_pool,:upb_DefPool_New, [], :DefPool
attach_function :lookup_enum, :upb_DefPool_FindEnumByName, [:DefPool, :string], EnumDescriptor
attach_function :lookup_msg, :upb_DefPool_FindMessageByName, [:DefPool, :string], Descriptor
# FileDescriptorProto
attach_function :parse, :FileDescriptorProto_parse, [:binary_string, :size_t, Internal::Arena], :FileDescriptorProto
attach_function :add_serialized_file, :upb_DefPool_AddFile, [:DefPool, :FileDescriptorProto, Status.by_ref], :FileDef
attach_function :free_descriptor_pool, :upb_DefPool_Free, [:DefPool], :void
attach_function :create_descriptor_pool,:upb_DefPool_New, [], :DefPool
attach_function :get_extension_registry,:upb_DefPool_ExtensionRegistry, [:DefPool], :ExtensionRegistry
attach_function :lookup_enum, :upb_DefPool_FindEnumByName, [:DefPool, :string], EnumDescriptor
attach_function :lookup_extension, :upb_DefPool_FindExtensionByName,[:DefPool, :string], FieldDescriptor
attach_function :lookup_msg, :upb_DefPool_FindMessageByName, [:DefPool, :string], Descriptor

# FileDescriptorProto
attach_function :parse, :FileDescriptorProto_parse, [:binary_string, :size_t, Internal::Arena], :FileDescriptorProto
end
class DescriptorPool
attr :descriptor_pool
Expand Down Expand Up @@ -50,7 +53,8 @@ def add_serialized_file(file_contents)

def lookup name
Google::Protobuf::FFI.lookup_msg(@descriptor_pool, name) ||
Google::Protobuf::FFI.lookup_enum(@descriptor_pool, name)
Google::Protobuf::FFI.lookup_enum(@descriptor_pool, name) ||
Google::Protobuf::FFI.lookup_extension(@descriptor_pool, name)
end

def self.generated_pool
Expand Down
10 changes: 9 additions & 1 deletion ruby/lib/google/protobuf/ffi/message.rb
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,15 @@ def self.decode(data, options = {})

message = new
mini_table_ptr = Google::Protobuf::FFI.get_mini_table(message.class.descriptor)
status = Google::Protobuf::FFI.decode_message(data, data.bytesize, message.instance_variable_get(:@msg), mini_table_ptr, nil, decoding_options, message.instance_variable_get(:@arena))
status = Google::Protobuf::FFI.decode_message(
data,
data.bytesize,
message.instance_variable_get(:@msg),
mini_table_ptr,
Google::Protobuf::FFI.get_extension_registry(message.class.descriptor.send(:pool).descriptor_pool),
decoding_options,
message.instance_variable_get(:@arena)
)
raise ParseError.new "Error occurred during parsing" unless status == :Ok
message
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.DescriptorValidationException;
import com.google.protobuf.Descriptors.EnumDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.InvalidProtocolBufferException;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -70,6 +72,7 @@ public IRubyObject allocate(Ruby runtime, RubyClass klazz) {
cDescriptorPool.newInstance(runtime.getCurrentContext(), Block.NULL_BLOCK);
cDescriptor = (RubyClass) runtime.getClassFromPath("Google::Protobuf::Descriptor");
cEnumDescriptor = (RubyClass) runtime.getClassFromPath("Google::Protobuf::EnumDescriptor");
cFieldDescriptor = (RubyClass) runtime.getClassFromPath("Google::Protobuf::FieldDescriptor");
}

public RubyDescriptorPool(Ruby runtime, RubyClass klazz) {
Expand All @@ -92,7 +95,7 @@ public IRubyObject build(ThreadContext context, Block block) {
* call-seq:
* DescriptorPool.lookup(name) => descriptor
*
* Finds a Descriptor or EnumDescriptor by name and returns it, or nil if none
* Finds a Descriptor, EnumDescriptor or FieldDescriptor by name and returns it, or nil if none
* exists with the given name.
*
* This currently lazy loads the ruby descriptor objects as they are requested.
Expand Down Expand Up @@ -121,7 +124,8 @@ public static IRubyObject generatedPool(ThreadContext context, IRubyObject recv)
public IRubyObject add_serialized_file(ThreadContext context, IRubyObject data) {
byte[] bin = data.convertToString().getBytes();
try {
FileDescriptorProto.Builder builder = FileDescriptorProto.newBuilder().mergeFrom(bin);
FileDescriptorProto.Builder builder =
FileDescriptorProto.newBuilder().mergeFrom(bin, registry);
registerFileDescriptor(context, builder);
} catch (InvalidProtocolBufferException e) {
throw RaiseException.from(
Expand Down Expand Up @@ -150,6 +154,8 @@ protected void registerFileDescriptor(
for (EnumDescriptor ed : fd.getEnumTypes()) registerEnumDescriptor(context, ed, packageName);
for (Descriptor message : fd.getMessageTypes())
registerDescriptor(context, message, packageName);
for (FieldDescriptor fieldDescriptor : fd.getExtensions())
registerExtension(context, fieldDescriptor, packageName);

// Mark this as a loaded file
fileDescriptors.add(fd);
Expand All @@ -170,6 +176,24 @@ private void registerDescriptor(ThreadContext context, Descriptor descriptor, St
registerEnumDescriptor(context, ed, fullPath);
for (Descriptor message : descriptor.getNestedTypes())
registerDescriptor(context, message, fullPath);
for (FieldDescriptor fieldDescriptor : descriptor.getExtensions())
registerExtension(context, fieldDescriptor, fullPath);
}

private void registerExtension(
ThreadContext context, FieldDescriptor descriptor, String parentPath) {
if (descriptor.getJavaType() == FieldDescriptor.JavaType.MESSAGE) {
registry.add(descriptor, descriptor.toProto());
} else {
registry.add(descriptor);
}
RubyString name = context.runtime.newString(parentPath + descriptor.getName());
RubyFieldDescriptor des =
(RubyFieldDescriptor) cFieldDescriptor.newInstance(context, Block.NULL_BLOCK);
des.setName(name);
des.setDescriptor(context, descriptor, this);
// For MessageSet extensions, there is the possibility of a name conflict. Prefer the Message.
symtab.putIfAbsent(name, des);
}

private void registerEnumDescriptor(
Expand All @@ -188,8 +212,10 @@ private FileDescriptor[] existingFileDescriptors() {

private static RubyClass cDescriptor;
private static RubyClass cEnumDescriptor;
private static RubyClass cFieldDescriptor;
private static RubyDescriptorPool descriptorPool;

private List<FileDescriptor> fileDescriptors;
private Map<IRubyObject, IRubyObject> symtab;
protected static final ExtensionRegistry registry = ExtensionRegistry.newInstance();
}
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ public IRubyObject getName(ThreadContext context) {
return this.name;
}

protected void setName(IRubyObject name) {
this.name = name;
}

/*
* call-seq:
* FieldDescriptor.subtype => message_or_enum_descriptor
Expand Down Expand Up @@ -229,7 +233,7 @@ public IRubyObject has(ThreadContext context, IRubyObject message) {
*/
@JRubyMethod(name = "set")
public IRubyObject setValue(ThreadContext context, IRubyObject message, IRubyObject value) {
((RubyMessage) message).setField(context, descriptor, value);
((RubyMessage) message).setField(context, this, value);
return context.nil;
}

Expand Down Expand Up @@ -263,6 +267,10 @@ protected void setDescriptor(
this.pool = pool;
}

protected FieldDescriptor getDescriptor() {
return descriptor;
}

private void calculateLabel(ThreadContext context) {
if (descriptor.isRepeated()) {
this.label = context.runtime.newSymbol("repeated");
Expand Down
24 changes: 21 additions & 3 deletions ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ public static IRubyObject decode(ThreadContext context, IRubyObject recv, IRubyO
public static IRubyObject decodeBytes(
ThreadContext context, RubyMessage ret, CodedInputStream input, boolean freeze) {
try {
ret.builder.mergeFrom(input);
ret.builder.mergeFrom(input, RubyDescriptorPool.registry);
} catch (Exception e) {
throw RaiseException.from(
context.runtime,
Expand Down Expand Up @@ -965,6 +965,12 @@ protected IRubyObject setField(
return setFieldInternal(context, fieldDescriptor, value);
}

protected IRubyObject setField(
ThreadContext context, RubyFieldDescriptor fieldDescriptor, IRubyObject value) {
validateMessageType(context, fieldDescriptor.getDescriptor(), "set");
return setFieldInternal(context, fieldDescriptor.getDescriptor(), fieldDescriptor, value);
}

private RubyRepeatedField getRepeatedField(
ThreadContext context, FieldDescriptor fieldDescriptor) {
if (fields.containsKey(fieldDescriptor)) {
Expand Down Expand Up @@ -1275,6 +1281,14 @@ private IRubyObject getFieldInternal(

private IRubyObject setFieldInternal(
ThreadContext context, FieldDescriptor fieldDescriptor, IRubyObject value) {
return setFieldInternal(context, fieldDescriptor, null, value);
}

private IRubyObject setFieldInternal(
ThreadContext context,
FieldDescriptor fieldDescriptor,
RubyFieldDescriptor rubyFieldDescriptor,
IRubyObject value) {
testFrozen("can't modify frozen " + getMetaClass());

if (fieldDescriptor.isMapField()) {
Expand All @@ -1299,8 +1313,12 @@ private IRubyObject setFieldInternal(
// Determine the typeclass, if any
IRubyObject typeClass = context.runtime.getObject();
if (fieldType == FieldDescriptor.Type.MESSAGE) {
typeClass =
((RubyDescriptor) getDescriptorForField(context, fieldDescriptor)).msgclass(context);
if (rubyFieldDescriptor != null) {
typeClass = ((RubyDescriptor) rubyFieldDescriptor.getSubtype(context)).msgclass(context);
} else {
typeClass =
((RubyDescriptor) getDescriptorForField(context, fieldDescriptor)).msgclass(context);
}
if (value.isNil()) {
addValue = false;
}
Expand Down
32 changes: 32 additions & 0 deletions ruby/tests/basic.rb
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,19 @@ def test_oneof_descriptor_options
oneof_descriptor = descriptor.lookup_oneof("test_deprecated_message_oneof")

assert_instance_of Google::Protobuf::OneofOptions, oneof_descriptor.options
test_top_level_option = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test.test_top_level_option'
assert_instance_of Google::Protobuf::FieldDescriptor, test_top_level_option
assert_equal "Custom option value", test_top_level_option.get(oneof_descriptor.options)
end

def test_nested_extension
descriptor = TestDeprecatedMessage.descriptor
oneof_descriptor = descriptor.lookup_oneof("test_deprecated_message_oneof")

assert_instance_of Google::Protobuf::OneofOptions, oneof_descriptor.options
test_nested_option = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test.TestDeprecatedMessage.test_nested_option'
assert_instance_of Google::Protobuf::FieldDescriptor, test_nested_option
assert_equal "Another custom option value", test_nested_option.get(oneof_descriptor.options)
end

def test_options_deep_freeze
Expand All @@ -739,6 +752,25 @@ def test_options_deep_freeze
Google::Protobuf::UninterpretedOption.new
end
end

def test_message_deep_freeze
message = TestDeprecatedMessage.new
omit(":internal_deep_freeze only exists under FFI") unless message.respond_to? :internal_deep_freeze, true
nested_message_2 = TestMessage2.new

message.map_string_msg["message"] = TestMessage2.new
message.repeated_msg.push(TestMessage2.new)

message.send(:internal_deep_freeze)

assert_raise FrozenError do
message.map_string_msg["message"].foo = "bar"
end

assert_raise FrozenError do
message.repeated_msg[0].foo = "bar"
end
end
end

def test_oneof_fields_respond_to? # regression test for issue 9202
Expand Down
52 changes: 52 additions & 0 deletions ruby/tests/basic_proto2.rb
Original file line number Diff line number Diff line change
Expand Up @@ -269,5 +269,57 @@ def test_oneof_fields_respond_to? # regression test for issue 9202
assert msg.respond_to? :has_d?
refute msg.has_d?
end

def test_extension
message = TestExtensions.new
extension = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.optional_int32_extension'
assert_instance_of Google::Protobuf::FieldDescriptor, extension
assert_equal 0, extension.get(message)
extension.set message, 42
assert_equal 42, extension.get(message)
end

def test_nested_extension
message = TestExtensions.new
extension = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestNestedExtension.test'
assert_instance_of Google::Protobuf::FieldDescriptor, extension
assert_equal 'test', extension.get(message)
extension.set message, 'another test'
assert_equal 'another test', extension.get(message)
end

def test_message_set_extension_json_roundtrip
omit "Java Protobuf JsonFormat does not handle Proto2 extensions" if defined? JRUBY_VERSION and :NATIVE == Google::Protobuf::IMPLEMENTATION
message = TestMessageSet.new
ext1 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestMessageSetExtension1.message_set_extension'
assert_instance_of Google::Protobuf::FieldDescriptor, ext1
ext2 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestMessageSetExtension2.message_set_extension'
assert_instance_of Google::Protobuf::FieldDescriptor, ext2
ext3 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.message_set_extension3'
assert_instance_of Google::Protobuf::FieldDescriptor, ext3
ext1.set(message, ext1.subtype.msgclass.new(i: 42))
ext2.set(message, ext2.subtype.msgclass.new(str: 'foo'))
ext3.set(message, ext3.subtype.msgclass.new(text: 'bar'))
message_text = message.to_json
parsed_message = TestMessageSet.decode_json message_text
assert_equal message, parsed_message
end


def test_message_set_extension_roundtrip
message = TestMessageSet.new
ext1 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestMessageSetExtension1.message_set_extension'
assert_instance_of Google::Protobuf::FieldDescriptor, ext1
ext2 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.TestMessageSetExtension2.message_set_extension'
assert_instance_of Google::Protobuf::FieldDescriptor, ext2
ext3 = Google::Protobuf::DescriptorPool.generated_pool.lookup 'basic_test_proto2.message_set_extension3'
assert_instance_of Google::Protobuf::FieldDescriptor, ext3
ext1.set(message, ext1.subtype.msgclass.new(i: 42))
ext2.set(message, ext2.subtype.msgclass.new(str: 'foo'))
ext3.set(message, ext3.subtype.msgclass.new(text: 'bar'))
encoded_message = TestMessageSet.encode message
decoded_message = TestMessageSet.decode encoded_message
assert_equal message, decoded_message
end
end
end
Loading

0 comments on commit 98fa596

Please sign in to comment.