Skip to content

Commit

Permalink
ProstBuild: CodeGen: Add support for adding Cow types
Browse files Browse the repository at this point in the history
Signed-off-by: Jon Doron <jond@wiz.io>
  • Loading branch information
arilou committed Dec 11, 2024
1 parent a40f358 commit 4556734
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 36 deletions.
109 changes: 78 additions & 31 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ impl CodeGenerator<'_> {
self.push_indent();
self.buf.push_str("pub struct ");
self.buf.push_str(&to_upper_camel(&message_name));
if self.message_graph.message_has_lifetime(&fq_message_name) {
self.buf.push_str("<'a>");
}
self.buf.push_str(" {\n");

self.depth += 1;
Expand Down Expand Up @@ -406,13 +409,15 @@ impl CodeGenerator<'_> {
let deprecated = self.deprecated(&field.descriptor);
let optional = self.optional(&field.descriptor);
let boxed = self.boxed(&field.descriptor, fq_message_name, None);
let ty = self.resolve_type(&field.descriptor, fq_message_name);
let cowed = self.cowed(&field.descriptor, fq_message_name, None);
let ty = self.resolve_type(&field.descriptor, fq_message_name, cowed);

debug!(
" field: {:?}, type: {:?}, boxed: {}",
" field: {:?}, type: {:?}, boxed: {} cowed: {}",
field.descriptor.name(),
ty,
boxed
boxed,
cowed
);

self.append_doc(fq_message_name, Some(field.descriptor.name()));
Expand All @@ -424,10 +429,10 @@ impl CodeGenerator<'_> {

self.push_indent();
self.buf.push_str("#[prost(");
let type_tag = self.field_type_tag(&field.descriptor);
let type_tag = self.field_type_tag(&field.descriptor, cowed);
self.buf.push_str(&type_tag);

if type_ == Type::Bytes {
if !cowed && type_ == Type::Bytes {
let bytes_type = self
.config
.bytes_type
Expand Down Expand Up @@ -532,8 +537,10 @@ impl CodeGenerator<'_> {
key: &FieldDescriptorProto,
value: &FieldDescriptorProto,
) {
let key_ty = self.resolve_type(key, fq_message_name);
let value_ty = self.resolve_type(value, fq_message_name);
let key_cowed = self.cowed(key, fq_message_name, None);
let key_ty = self.resolve_type(key, fq_message_name, key_cowed);
let value_cowed = self.cowed(value, fq_message_name, None);
let value_ty = self.resolve_type(value, fq_message_name, value_cowed);

debug!(
" map field: {:?}, key type: {:?}, value type: {:?}",
Expand All @@ -551,8 +558,8 @@ impl CodeGenerator<'_> {
.get_first_field(fq_message_name, field.descriptor.name())
.copied()
.unwrap_or_default();
let key_tag = self.field_type_tag(key);
let value_tag = self.map_value_type_tag(value);
let key_tag = self.field_type_tag(key, key_cowed);
let value_tag = self.map_value_type_tag(value, value_cowed);

self.buf.push_str(&format!(
"#[prost({}=\"{}, {}\", tag=\"{}\")]\n",
Expand Down Expand Up @@ -597,9 +604,11 @@ impl CodeGenerator<'_> {
self.append_field_attributes(fq_message_name, oneof.descriptor.name());
self.push_indent();
self.buf.push_str(&format!(
"pub {}: ::core::option::Option<{}>,\n",
"pub {}: ::core::option::Option<{}{}>,\n",
oneof.rust_name(),
type_name
type_name,
if self.message_graph
.message_has_lifetime(fq_message_name) { "<'a>" } else { "" },
));
}

Expand Down Expand Up @@ -628,6 +637,9 @@ impl CodeGenerator<'_> {
self.push_indent();
self.buf.push_str("pub enum ");
self.buf.push_str(&to_upper_camel(oneof.descriptor.name()));
if self.message_graph.message_has_lifetime(fq_message_name) {
self.buf.push_str("<'a>");
}
self.buf.push_str(" {\n");

self.path.push(2);
Expand All @@ -637,8 +649,14 @@ impl CodeGenerator<'_> {
self.append_doc(fq_message_name, Some(field.descriptor.name()));
self.path.pop();

let cowed = self.cowed(
&field.descriptor,
fq_message_name,
Some(oneof.descriptor.name()),
);

self.push_indent();
let ty_tag = self.field_type_tag(&field.descriptor);
let ty_tag = self.field_type_tag(&field.descriptor, cowed);
self.buf.push_str(&format!(
"#[prost({}, tag=\"{}\")]\n",
ty_tag,
Expand All @@ -647,7 +665,7 @@ impl CodeGenerator<'_> {
self.append_field_attributes(&oneof_name, field.descriptor.name());

self.push_indent();
let ty = self.resolve_type(&field.descriptor, fq_message_name);
let ty = self.resolve_type(&field.descriptor, fq_message_name, cowed);

let boxed = self.boxed(
&field.descriptor,
Expand All @@ -656,10 +674,11 @@ impl CodeGenerator<'_> {
);

debug!(
" oneof: {:?}, type: {:?}, boxed: {}",
" oneof: {:?}, type: {:?}, boxed: {} cowed: {}",
field.descriptor.name(),
ty,
boxed
boxed,
cowed,
);

if boxed {
Expand Down Expand Up @@ -883,8 +902,8 @@ impl CodeGenerator<'_> {
let name = method.name.take().unwrap();
let input_proto_type = method.input_type.take().unwrap();
let output_proto_type = method.output_type.take().unwrap();
let input_type = self.resolve_ident(&input_proto_type);
let output_type = self.resolve_ident(&output_proto_type);
let input_type = self.resolve_ident(&input_proto_type).0;
let output_type = self.resolve_ident(&output_proto_type).0;
let client_streaming = method.client_streaming();
let server_streaming = method.server_streaming();

Expand Down Expand Up @@ -947,7 +966,12 @@ impl CodeGenerator<'_> {
self.buf.push_str("}\n");
}

fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String {
fn resolve_type(
&self,
field: &FieldDescriptorProto,
fq_message_name: &str,
cowed: bool,
) -> String {
match field.r#type() {
Type::Float => String::from("f32"),
Type::Double => String::from("f64"),
Expand All @@ -956,7 +980,13 @@ impl CodeGenerator<'_> {
Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"),
Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"),
Type::Bool => String::from("bool"),
Type::String if cowed => {
format!("{}::alloc::borrow::Cow<'a, str>", prost_path(self.config))
}
Type::String => format!("{}::alloc::string::String", prost_path(self.config)),
Type::Bytes if cowed => {
format!("{}::alloc::borrow::Cow<'a, [u8]>", prost_path(self.config))
}
Type::Bytes => self
.config
.bytes_type
Expand All @@ -965,16 +995,28 @@ impl CodeGenerator<'_> {
.unwrap_or_default()
.rust_type()
.to_owned(),
Type::Group | Type::Message => self.resolve_ident(field.type_name()),
Type::Group | Type::Message => {
let (mut s, is_extern) = self.resolve_ident(field.type_name());
if !is_extern
&& cowed
&& self
.message_graph
.field_has_lifetime(fq_message_name, field)
{
s.push_str("<'a>");
}
s
}
}
}

fn resolve_ident(&self, pb_ident: &str) -> String {
/// Returns the identifier and a bool indicating if its an extern
fn resolve_ident(&self, pb_ident: &str) -> (String, bool) {
// protoc should always give fully qualified identifiers.
assert_eq!(".", &pb_ident[..1]);

if let Some(proto_ident) = self.extern_paths.resolve_ident(pb_ident) {
return proto_ident;
return (proto_ident, true);
}

let mut local_path = self
Expand All @@ -1000,14 +1042,15 @@ impl CodeGenerator<'_> {
ident_path.next();
}

local_path
let s = local_path
.map(|_| "super".to_string())
.chain(ident_path.map(to_snake))
.chain(iter::once(to_upper_camel(ident_type)))
.join("::")
.join("::");
(s, false)
}

fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
fn field_type_tag(&self, field: &FieldDescriptorProto, cowed: bool) -> Cow<'static, str> {
match field.r#type() {
Type::Float => Cow::Borrowed("float"),
Type::Double => Cow::Borrowed("double"),
Expand All @@ -1022,24 +1065,26 @@ impl CodeGenerator<'_> {
Type::Sfixed32 => Cow::Borrowed("sfixed32"),
Type::Sfixed64 => Cow::Borrowed("sfixed64"),
Type::Bool => Cow::Borrowed("bool"),
Type::String if cowed => Cow::Borrowed("cow_str"),
Type::String => Cow::Borrowed("string"),
Type::Bytes if cowed => Cow::Borrowed("cow_bytes"),
Type::Bytes => Cow::Borrowed("bytes"),
Type::Group => Cow::Borrowed("group"),
Type::Message => Cow::Borrowed("message"),
Type::Enum => Cow::Owned(format!(
"enumeration={:?}",
self.resolve_ident(field.type_name())
self.resolve_ident(field.type_name()).0
)),
}
}

fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
fn map_value_type_tag(&self, field: &FieldDescriptorProto, cowed: bool) -> Cow<'static, str> {
match field.r#type() {
Type::Enum => Cow::Owned(format!(
"enumeration({})",
self.resolve_ident(field.type_name())
self.resolve_ident(field.type_name()).0
)),
_ => self.field_type_tag(field),
_ => self.field_type_tag(field, cowed),
}
}

Expand Down Expand Up @@ -1111,16 +1156,18 @@ impl CodeGenerator<'_> {
let fd_type = field.r#type();

// We only support Cow for Bytes and String
if !matches!(fd_type, Type::Bytes | Type::String) {
if !matches!(
fd_type,
Type::Message | Type::Group | Type::Bytes | Type::String
) {
return false;
}

let config_path = match oneof {
None => Cow::Borrowed(fq_message_name),
Some(ooname) => Cow::Owned(format!("{fq_message_name}.{ooname}")),
};
self
.config
self.config
.cowed
.get_first_field(&config_path, field.name())
.is_some()
Expand Down
33 changes: 28 additions & 5 deletions prost-build/src/message_graph.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use petgraph::algo::has_path_connecting;
use petgraph::graph::NodeIndex;
Expand Down Expand Up @@ -157,20 +157,38 @@ impl MessageGraph {
}
}

pub fn message_has_lifetime(&self, fq_message_name: &str) -> bool {
fn message_has_lifetime_internal(
&self,
fq_message_name: &str,
visited: &mut HashSet<String>,
) -> bool {
visited.insert(fq_message_name.to_string());
assert_eq!(".", &fq_message_name[..1]);
self.get_message(fq_message_name)
.unwrap()
.field
.iter()
.any(|field| self.field_has_lifetime(fq_message_name, field))
.any(|field| self.field_has_lifetime_internal(fq_message_name, field, visited))
}

pub fn field_has_lifetime(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool {
pub fn message_has_lifetime(&self, fq_message_name: &str) -> bool {
let mut visited = Default::default();
self.message_has_lifetime_internal(fq_message_name, &mut visited)
}

fn field_has_lifetime_internal(
&self,
fq_message_name: &str,
field: &FieldDescriptorProto,
visited: &mut HashSet<String>,
) -> bool {
assert_eq!(".", &fq_message_name[..1]);

if field.r#type() == Type::Message {
self.message_has_lifetime(field.type_name())
if visited.contains(field.type_name()) {
return false;
}
self.message_has_lifetime_internal(field.type_name(), visited)
} else {
matches!(field.r#type(), Type::Bytes | Type::String)
&& self
Expand All @@ -179,4 +197,9 @@ impl MessageGraph {
.is_some()
}
}

pub fn field_has_lifetime(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool {
let mut visited = Default::default();
self.field_has_lifetime_internal(fq_message_name, field, &mut visited)
}
}

0 comments on commit 4556734

Please sign in to comment.