Skip to content

Commit

Permalink
prost-build: consolidate message field data
Browse files Browse the repository at this point in the history
When massaging field data in CodeGenerator::append_message,
move it into lists of Field and OneofField structs so that later
generation passes can operate on the data with less code duplication.

Subsidiary append_* methods are changed to take references to these
structs rather than moved data, as generation of lexical tokens
does not actually consume any owned data, and we will need more
passes over the same field lists for the upcoming builder code.
  • Loading branch information
mzabaluev committed Mar 29, 2024
1 parent 50bab4f commit 486cb60
Showing 1 changed file with 101 additions and 67 deletions.
168 changes: 101 additions & 67 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,45 @@ fn push_indent(buf: &mut String, depth: u8) {
buf.push_str(" ");
}
}

struct Field {
rust_name: String,
descriptor: FieldDescriptorProto,
path_index: i32,
}

impl Field {
fn new(descriptor: FieldDescriptorProto, path_index: i32) -> Self {
Self {
rust_name: to_snake(descriptor.name()),
descriptor,
path_index,
}
}
}

struct OneofField {
rust_name: String,
descriptor: OneofDescriptorProto,
fields: Vec<(FieldDescriptorProto, i32)>,
path_index: i32,
}

impl OneofField {
fn new(
descriptor: OneofDescriptorProto,
fields: Vec<(FieldDescriptorProto, i32)>,
path_index: i32,
) -> Self {
Self {
rust_name: to_snake(descriptor.name()),
descriptor,
fields,
path_index,
}
}
}

impl<'a> CodeGenerator<'a> {
pub fn generate(
config: &mut Config,
Expand Down Expand Up @@ -159,21 +198,33 @@ impl<'a> CodeGenerator<'a> {

// Split the fields into a vector of the normal fields, and oneof fields.
// Path indexes are preserved so that comments can be retrieved.
type Fields = Vec<(FieldDescriptorProto, usize)>;
type OneofFields = MultiMap<i32, (FieldDescriptorProto, usize)>;
let (fields, mut oneof_fields): (Fields, OneofFields) = message
type OneofFieldsByIndex = MultiMap<i32, (FieldDescriptorProto, i32)>;
let (fields, mut oneof_map): (Vec<Field>, OneofFieldsByIndex) = message
.field
.into_iter()
.enumerate()
.partition_map(|(idx, field)| {
if field.proto3_optional.unwrap_or(false) {
Either::Left((field, idx))
} else if let Some(oneof_index) = field.oneof_index {
Either::Right((oneof_index, (field, idx)))
.partition_map(|(idx, proto)| {
let idx = idx as i32;
if proto.proto3_optional.unwrap_or(false) {
Either::Left(Field::new(proto, idx))
} else if let Some(oneof_index) = proto.oneof_index {
Either::Right((oneof_index, (proto, idx)))
} else {
Either::Left((field, idx))
Either::Left(Field::new(proto, idx))
}
});
// Optional fields create a synthetic oneof that we want to skip
let oneof_fields: Vec<OneofField> = message
.oneof_decl
.into_iter()
.enumerate()
.filter_map(move |(idx, proto)| {
let idx = idx as i32;
oneof_map
.remove(&idx)
.map(|fields| OneofField::new(proto, fields, idx))
})
.collect();

self.append_doc(&fq_message_name, None);
self.append_type_attributes(&fq_message_name);
Expand All @@ -193,33 +244,25 @@ impl<'a> CodeGenerator<'a> {

self.depth += 1;
self.path.push(2);
for (field, idx) in fields {
self.path.push(idx as i32);
for field in &fields {
self.path.push(field.path_index);
match field
.descriptor
.type_name
.as_ref()
.and_then(|type_name| map_types.get(type_name))
{
Some(&(ref key, ref value)) => {
self.append_map_field(&fq_message_name, field, key, value)
}
Some((key, value)) => self.append_map_field(&fq_message_name, field, key, value),
None => self.append_field(&fq_message_name, field),
}
self.path.pop();
}
self.path.pop();

self.path.push(8);
for (idx, oneof) in message.oneof_decl.iter().enumerate() {
let idx = idx as i32;

let fields = match oneof_fields.get_vec(&idx) {
Some(fields) => fields,
None => continue,
};

self.path.push(idx);
self.append_oneof_field(&message_name, &fq_message_name, oneof, fields);
for oneof in &oneof_fields {
self.path.push(oneof.path_index);
self.append_oneof_field(&message_name, &fq_message_name, oneof);
self.path.pop();
}
self.path.pop();
Expand All @@ -246,14 +289,8 @@ impl<'a> CodeGenerator<'a> {
}
self.path.pop();

for (idx, oneof) in message.oneof_decl.into_iter().enumerate() {
let idx = idx as i32;
// optional fields create a synthetic oneof that we want to skip
let fields = match oneof_fields.remove(&idx) {
Some(fields) => fields,
None => continue,
};
self.append_oneof(&fq_message_name, oneof, idx, fields);
for oneof in &oneof_fields {
self.append_oneof(&fq_message_name, oneof);
}

self.pop_mod();
Expand Down Expand Up @@ -362,12 +399,14 @@ impl<'a> CodeGenerator<'a> {
}
}

fn append_field(&mut self, fq_message_name: &str, field: FieldDescriptorProto) {
fn append_field(&mut self, fq_message_name: &str, field: &Field) {
let rust_name = &field.rust_name;
let field = &field.descriptor;
let type_ = field.r#type();
let repeated = field.label == Some(Label::Repeated as i32);
let deprecated = self.deprecated(&field);
let optional = self.optional(&field);
let ty = self.resolve_type(&field, fq_message_name);
let deprecated = self.deprecated(field);
let optional = self.optional(field);
let ty = self.resolve_type(field, fq_message_name);

let boxed = !repeated
&& ((type_ == Type::Message || type_ == Type::Group)
Expand Down Expand Up @@ -396,7 +435,7 @@ impl<'a> CodeGenerator<'a> {

self.push_indent();
self.buf.push_str("#[prost(");
let type_tag = self.field_type_tag(&field);
let type_tag = self.field_type_tag(field);
self.buf.push_str(&type_tag);

if type_ == Type::Bytes {
Expand All @@ -419,7 +458,7 @@ impl<'a> CodeGenerator<'a> {
Label::Required => self.buf.push_str(", required"),
Label::Repeated => {
self.buf.push_str(", repeated");
if can_pack(&field)
if can_pack(field)
&& !field
.options
.as_ref()
Expand Down Expand Up @@ -470,7 +509,7 @@ impl<'a> CodeGenerator<'a> {
self.append_field_attributes(fq_message_name, field.name());
self.push_indent();
self.buf.push_str("pub ");
self.buf.push_str(&to_snake(field.name()));
self.buf.push_str(rust_name);
self.buf.push_str(": ");

let prost_path = self.config.prost_path.as_deref().unwrap_or("::prost");
Expand Down Expand Up @@ -498,10 +537,12 @@ impl<'a> CodeGenerator<'a> {
fn append_map_field(
&mut self,
fq_message_name: &str,
field: FieldDescriptorProto,
field: &Field,
key: &FieldDescriptorProto,
value: &FieldDescriptorProto,
) {
let rust_name = &field.rust_name;
let field = &field.descriptor;
let key_ty = self.resolve_type(key, fq_message_name);
let value_ty = self.resolve_type(value, fq_message_name);

Expand Down Expand Up @@ -535,7 +576,7 @@ impl<'a> CodeGenerator<'a> {
self.push_indent();
self.buf.push_str(&format!(
"pub {}: {}<{}, {}>,\n",
to_snake(field.name()),
rust_name,
map_type.rust_type(),
key_ty,
value_ty
Expand All @@ -546,47 +587,40 @@ impl<'a> CodeGenerator<'a> {
&mut self,
message_name: &str,
fq_message_name: &str,
oneof: &OneofDescriptorProto,
fields: &[(FieldDescriptorProto, usize)],
oneof: &OneofField,
) {
let name = format!(
let type_name = format!(
"{}::{}",
to_snake(message_name),
to_upper_camel(oneof.name())
to_upper_camel(oneof.descriptor.name())
);
let field_tags = oneof
.fields
.iter()
.map(|(field, _)| field.number())
.join(", ");
self.append_doc(fq_message_name, None);
self.push_indent();
self.buf.push_str(&format!(
"#[prost(oneof=\"{}\", tags=\"{}\")]\n",
name,
fields
.iter()
.map(|&(ref field, _)| field.number())
.join(", ")
type_name, field_tags,
));
self.append_field_attributes(fq_message_name, oneof.name());
self.append_field_attributes(fq_message_name, oneof.descriptor.name());
self.push_indent();
self.buf.push_str(&format!(
"pub {}: ::core::option::Option<{}>,\n",
to_snake(oneof.name()),
name
oneof.rust_name, type_name
));
}

fn append_oneof(
&mut self,
fq_message_name: &str,
oneof: OneofDescriptorProto,
idx: i32,
fields: Vec<(FieldDescriptorProto, usize)>,
) {
fn append_oneof(&mut self, fq_message_name: &str, oneof: &OneofField) {
self.path.push(8);
self.path.push(idx);
self.path.push(oneof.path_index);
self.append_doc(fq_message_name, None);
self.path.pop();
self.path.pop();

let oneof_name = format!("{}.{}", fq_message_name, oneof.name());
let oneof_name = format!("{}.{}", fq_message_name, oneof.descriptor.name());
self.append_type_attributes(&oneof_name);
self.append_enum_attributes(&oneof_name);
self.push_indent();
Expand All @@ -599,20 +633,20 @@ impl<'a> CodeGenerator<'a> {
self.append_skip_debug(&fq_message_name);
self.push_indent();
self.buf.push_str("pub enum ");
self.buf.push_str(&to_upper_camel(oneof.name()));
self.buf.push_str(&to_upper_camel(oneof.descriptor.name()));
self.buf.push_str(" {\n");

self.path.push(2);
self.depth += 1;
for (field, idx) in fields {
for (field, idx) in &oneof.fields {
let type_ = field.r#type();

self.path.push(idx as i32);
self.path.push(*idx);
self.append_doc(fq_message_name, Some(field.name()));
self.path.pop();

self.push_indent();
let ty_tag = self.field_type_tag(&field);
let ty_tag = self.field_type_tag(field);
self.buf.push_str(&format!(
"#[prost({}, tag=\"{}\")]\n",
ty_tag,
Expand All @@ -621,7 +655,7 @@ impl<'a> CodeGenerator<'a> {
self.append_field_attributes(&oneof_name, field.name());

self.push_indent();
let ty = self.resolve_type(&field, fq_message_name);
let ty = self.resolve_type(field, fq_message_name);

let boxed = ((type_ == Type::Message || type_ == Type::Group)
&& self
Expand Down

0 comments on commit 486cb60

Please sign in to comment.