Skip to content

Commit

Permalink
Optional support for bytes::Bytes type backing bytes field (tokio-rs#341
Browse files Browse the repository at this point in the history
)

* Optional support for Bytes type backing bytes field.

* Use actual Bytes type during build.

* Run rustfmt.

* Address clippy lints.

* Clean up code generation logic.

* Address fmt.

* Update types.

* Reinstate missing reserve().

* Rework trait to be private and clarify current behaviour.

* Use trait bounds where possible.

* Add link to related issue from bytes crate.

* Add tests for bytes.

* minor comment cleanup

* fix default bytes values

* clippy

* fmt

Co-authored-by: Rolf Timmermans <rolftimmermans@voormedia.com>
  • Loading branch information
danburkert and rolftimmermans committed Jul 18, 2020
1 parent fa71719 commit a1cccbc
Show file tree
Hide file tree
Showing 7 changed files with 366 additions and 67 deletions.
50 changes: 44 additions & 6 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@ enum Syntax {
Proto3,
}

/// Selects which Rust type the code generator emits for a Protobuf `bytes` field.
#[derive(PartialEq)]
enum BytesTy {
    /// Emit `::prost::alloc::vec::Vec<u8>`.
    Vec,
    /// Emit `::prost::bytes::Bytes`.
    Bytes,
}

impl BytesTy {
    /// Returns the variant name wrapped in double quotes (the quotes are part of the
    /// returned text), so it can be appended verbatim after `=` when writing the field's
    /// type tag into the output buffer.
    fn as_str(&self) -> &'static str {
        match *self {
            Self::Bytes => "\"bytes\"",
            Self::Vec => "\"vec\"",
        }
    }
}

pub struct CodeGenerator<'a> {
config: &'a mut Config,
package: String,
Expand Down Expand Up @@ -277,7 +292,7 @@ impl<'a> CodeGenerator<'a> {
let repeated = field.label == Some(Label::Repeated as i32);
let deprecated = self.deprecated(&field);
let optional = self.optional(&field);
let ty = self.resolve_type(&field);
let ty = self.resolve_type(&field, msg_name);

let boxed = !repeated
&& (type_ == Type::Message || type_ == Type::Group)
Expand All @@ -302,6 +317,12 @@ impl<'a> CodeGenerator<'a> {
let type_tag = self.field_type_tag(&field);
self.buf.push_str(&type_tag);

if type_ == Type::Bytes {
self.buf.push_str("=");
self.buf
.push_str(self.bytes_backing_type(&field, msg_name).as_str());
}

match field.label() {
Label::Optional => {
if optional {
Expand Down Expand Up @@ -394,8 +415,8 @@ impl<'a> CodeGenerator<'a> {
key: &FieldDescriptorProto,
value: &FieldDescriptorProto,
) {
let key_ty = self.resolve_type(key);
let value_ty = self.resolve_type(value);
let key_ty = self.resolve_type(key, msg_name);
let value_ty = self.resolve_type(value, msg_name);

debug!(
" map field: {:?}, key type: {:?}, value type: {:?}",
Expand All @@ -420,6 +441,7 @@ impl<'a> CodeGenerator<'a> {

let key_tag = self.field_type_tag(key);
let value_tag = self.map_value_type_tag(value);

self.buf.push_str(&format!(
"#[prost({}=\"{}, {}\", tag=\"{}\")]\n",
annotation_ty,
Expand Down Expand Up @@ -512,7 +534,7 @@ impl<'a> CodeGenerator<'a> {
self.append_field_attributes(&oneof_name, field.name());

self.push_indent();
let ty = self.resolve_type(&field);
let ty = self.resolve_type(&field, msg_name);

let boxed = (type_ == Type::Message || type_ == Type::Group)
&& self.message_graph.is_nested(field.type_name(), msg_name);
Expand Down Expand Up @@ -715,7 +737,7 @@ impl<'a> CodeGenerator<'a> {
self.buf.push_str("}\n");
}

fn resolve_type(&self, field: &FieldDescriptorProto) -> String {
fn resolve_type(&self, field: &FieldDescriptorProto, msg_name: &str) -> String {
match field.r#type() {
Type::Float => String::from("f32"),
Type::Double => String::from("f64"),
Expand All @@ -725,7 +747,10 @@ impl<'a> CodeGenerator<'a> {
Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"),
Type::Bool => String::from("bool"),
Type::String => String::from("::prost::alloc::string::String"),
Type::Bytes => String::from("::prost::alloc::vec::Vec<u8>"),
Type::Bytes => match self.bytes_backing_type(field, msg_name) {
BytesTy::Bytes => String::from("::prost::bytes::Bytes"),
BytesTy::Vec => String::from("::prost::alloc::vec::Vec<u8>"),
},
Type::Group | Type::Message => self.resolve_ident(field.type_name()),
}
}
Expand Down Expand Up @@ -804,6 +829,19 @@ impl<'a> CodeGenerator<'a> {
}
}

/// Decides which Rust type should back the given Protobuf `bytes` field.
///
/// Returns `BytesTy::Bytes` as soon as one of the paths stored in `config.bytes` matches
/// the field (per `match_ident`, given the message name and the field name); returns
/// `BytesTy::Vec` when no configured path matches.
fn bytes_backing_type(&self, field: &FieldDescriptorProto, msg_name: &str) -> BytesTy {
    for matcher in self.config.bytes.iter() {
        if match_ident(matcher, msg_name, Some(field.name())) {
            return BytesTy::Bytes;
        }
    }
    BytesTy::Vec
}

/// Returns `true` if the field options includes the `deprecated` option.
fn deprecated(&self, field: &FieldDescriptorProto) -> bool {
field
Expand Down
59 changes: 59 additions & 0 deletions prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ pub trait ServiceGenerator {
pub struct Config {
service_generator: Option<Box<dyn ServiceGenerator>>,
btree_map: Vec<String>,
bytes: Vec<String>,
type_attributes: Vec<(String, String)>,
field_attributes: Vec<(String, String)>,
prost_types: bool,
Expand Down Expand Up @@ -255,6 +256,63 @@ impl Config {
self
}

/// Configure the code generator to generate Rust [`bytes::Bytes`][1] fields for Protobuf
/// [`bytes`][2] type fields.
///
/// Each call replaces the list of paths set by any earlier call to this method; pass every
/// path you need in a single call.
///
/// # Arguments
///
/// **`paths`** - paths to specific fields, messages, or packages which should use a Rust
/// `Bytes` for Protobuf `bytes` fields. Paths are specified in terms of the Protobuf type
/// name (not the generated Rust type name). Paths with a leading `.` are treated as fully
/// qualified names. Paths without a leading `.` are treated as relative, and are suffix
/// matched on the fully qualified field name. If a Protobuf `bytes` field matches any of the
/// paths, a Rust `Bytes` field is generated instead of the default [`Vec<u8>`][3].
///
/// The matching is done on the Protobuf names, before converting to Rust-friendly casing
/// standards.
///
/// # Examples
///
/// ```rust
/// # let mut config = prost_build::Config::new();
/// // Match a specific field in a message type.
/// config.bytes(&[".my_messages.MyMessageType.my_bytes_field"]);
///
/// // Match all bytes fields in a message type.
/// config.bytes(&[".my_messages.MyMessageType"]);
///
/// // Match all bytes fields in a package.
/// config.bytes(&[".my_messages"]);
///
/// // Match all bytes fields. Especially useful in `no_std` contexts.
/// config.bytes(&["."]);
///
/// // Match all bytes fields in a nested message.
/// config.bytes(&[".my_messages.MyMessageType.MyNestedMessageType"]);
///
/// // Match all fields named 'my_bytes_field'.
/// config.bytes(&["my_bytes_field"]);
///
/// // Match all fields named 'my_bytes_field' in messages named 'MyMessageType', regardless of
/// // package or nesting.
/// config.bytes(&["MyMessageType.my_bytes_field"]);
///
/// // Match all fields named 'my_bytes_field', and all fields in the 'foo.bar' package.
/// config.bytes(&["my_bytes_field", ".foo.bar"]);
/// ```
///
/// [1]: https://docs.rs/bytes/latest/bytes/struct.Bytes.html
/// [2]: https://developers.google.com/protocol-buffers/docs/proto3#scalar
/// [3]: https://doc.rust-lang.org/std/vec/struct.Vec.html
pub fn bytes<I, S>(&mut self, paths: I) -> &mut Self
where
    I: IntoIterator<Item = S>,
    S: AsRef<str>,
{
    // Build the complete owned list first, then swap it in, so the stored paths are
    // replaced wholesale rather than appended to.
    let mut owned_paths = Vec::new();
    for path in paths {
        owned_paths.push(String::from(path.as_ref()));
    }
    self.bytes = owned_paths;
    self
}

/// Add additional attribute to matched fields.
///
/// # Arguments
Expand Down Expand Up @@ -626,6 +684,7 @@ impl default::Default for Config {
Config {
service_generator: None,
btree_map: Vec::new(),
bytes: Vec::new(),
type_attributes: Vec::new(),
field_attributes: Vec::new(),
prost_types: true,
Expand Down
72 changes: 53 additions & 19 deletions prost-derive/src/field/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt;

use anyhow::{anyhow, bail, Error};
use proc_macro2::{Span, TokenStream};
use quote::{quote, ToTokens};
use quote::{quote, ToTokens, TokenStreamExt};
use syn::{parse_str, Ident, Lit, LitByteStr, Meta, MetaList, MetaNameValue, NestedMeta, Path};

use crate::field::{bool_attr, set_option, tag_attr, Label};
Expand Down Expand Up @@ -194,7 +194,7 @@ impl Field {
Kind::Plain(ref default) | Kind::Required(ref default) => {
let default = default.typed();
match self.ty {
Ty::String | Ty::Bytes => quote!(#ident.clear()),
Ty::String | Ty::Bytes(..) => quote!(#ident.clear()),
_ => quote!(#ident = #default),
}
}
Expand Down Expand Up @@ -381,10 +381,33 @@ pub enum Ty {
Sfixed64,
Bool,
String,
Bytes,
Bytes(BytesTy),
Enumeration(Path),
}

/// The Rust type used to hold the value of a Protobuf `bytes` field.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BytesTy {
    /// `::prost::alloc::vec::Vec<u8>`
    Vec,
    /// `::prost::bytes::Bytes`
    Bytes,
}

impl BytesTy {
    /// Parses the string value of a `bytes = "..."` attribute.
    ///
    /// `"vec"` and `"bytes"` are the only accepted spellings; any other input yields an
    /// error that names the rejected string.
    fn try_from_str(s: &str) -> Result<Self, Error> {
        let parsed = match s {
            "bytes" => BytesTy::Bytes,
            "vec" => BytesTy::Vec,
            _ => bail!("Invalid bytes type: {}", s),
        };
        Ok(parsed)
    }

    /// Returns the tokens naming the owned Rust type that corresponds to this variant.
    fn rust_type(&self) -> TokenStream {
        match *self {
            BytesTy::Bytes => quote!(::prost::bytes::Bytes),
            BytesTy::Vec => quote!(::prost::alloc::vec::Vec<u8>),
        }
    }
}

impl Ty {
pub fn from_attr(attr: &Meta) -> Result<Option<Ty>, Error> {
let ty = match *attr {
Expand All @@ -402,7 +425,12 @@ impl Ty {
Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
Meta::Path(ref name) if name.is_ident("string") => Ty::String,
Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes,
Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
Meta::NameValue(MetaNameValue {
ref path,
lit: Lit::Str(ref l),
..
}) if path.is_ident("bytes") => Ty::Bytes(BytesTy::try_from_str(&l.value())?),
Meta::NameValue(MetaNameValue {
ref path,
lit: Lit::Str(ref l),
Expand Down Expand Up @@ -447,7 +475,7 @@ impl Ty {
"sfixed64" => Ty::Sfixed64,
"bool" => Ty::Bool,
"string" => Ty::String,
"bytes" => Ty::Bytes,
"bytes" => Ty::Bytes(BytesTy::Vec),
s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => {
let s = &s[enumeration_len..].trim();
match s.chars().next() {
Expand Down Expand Up @@ -483,16 +511,16 @@ impl Ty {
Ty::Sfixed64 => "sfixed64",
Ty::Bool => "bool",
Ty::String => "string",
Ty::Bytes => "bytes",
Ty::Bytes(..) => "bytes",
Ty::Enumeration(..) => "enum",
}
}

// TODO: rename to 'owned_type'.
pub fn rust_type(&self) -> TokenStream {
match *self {
match self {
Ty::String => quote!(::prost::alloc::string::String),
Ty::Bytes => quote!(::prost::alloc::vec::Vec<u8>),
Ty::Bytes(ty) => ty.rust_type(),
_ => self.rust_ref_type(),
}
}
Expand All @@ -514,7 +542,7 @@ impl Ty {
Ty::Sfixed64 => quote!(i64),
Ty::Bool => quote!(bool),
Ty::String => quote!(&str),
Ty::Bytes => quote!(&[u8]),
Ty::Bytes(..) => quote!(&[u8]),
Ty::Enumeration(..) => quote!(i32),
}
}
Expand All @@ -526,9 +554,12 @@ impl Ty {
}
}

/// Returns true if the scalar type is length delimited (i.e., `string` or `bytes`).
/// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`).
pub fn is_numeric(&self) -> bool {
*self != Ty::String && *self != Ty::Bytes
match self {
Ty::String | Ty::Bytes(..) => false,
_ => true,
}
}
}

Expand Down Expand Up @@ -621,7 +652,11 @@ impl DefaultValue {

Lit::Bool(ref lit) if *ty == Ty::Bool => DefaultValue::Bool(lit.value),
Lit::Str(ref lit) if *ty == Ty::String => DefaultValue::String(lit.value()),
Lit::ByteStr(ref lit) if *ty == Ty::Bytes => DefaultValue::Bytes(lit.value()),
Lit::ByteStr(ref lit)
if *ty == Ty::Bytes(BytesTy::Bytes) || *ty == Ty::Bytes(BytesTy::Vec) =>
{
DefaultValue::Bytes(lit.value())
}

Lit::Str(ref lit) => {
let value = lit.value();
Expand Down Expand Up @@ -734,7 +769,7 @@ impl DefaultValue {

Ty::Bool => DefaultValue::Bool(false),
Ty::String => DefaultValue::String(String::new()),
Ty::Bytes => DefaultValue::Bytes(Vec::new()),
Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
}
}
Expand All @@ -744,13 +779,11 @@ impl DefaultValue {
DefaultValue::String(ref value) if value.is_empty() => {
quote!(::prost::alloc::string::String::new())
}
DefaultValue::String(ref value) => quote!(#value.to_owned()),
DefaultValue::Bytes(ref value) if value.is_empty() => {
quote!(::prost::alloc::vec::Vec::new())
}
DefaultValue::String(ref value) => quote!(#value.into()),
DefaultValue::Bytes(ref value) if value.is_empty() => quote!(Default::default()),
DefaultValue::Bytes(ref value) => {
let lit = LitByteStr::new(value, Span::call_site());
quote!(#lit.to_owned())
quote!(#lit.as_ref().into())
}

ref other => other.typed(),
Expand Down Expand Up @@ -778,7 +811,8 @@ impl ToTokens for DefaultValue {
DefaultValue::Bool(value) => value.to_tokens(tokens),
DefaultValue::String(ref value) => value.to_tokens(tokens),
DefaultValue::Bytes(ref value) => {
LitByteStr::new(value, Span::call_site()).to_tokens(tokens)
let byte_str = LitByteStr::new(value, Span::call_site());
tokens.append_all(quote!(#byte_str as &[u8]));
}
DefaultValue::Enumeration(ref value) => value.to_tokens(tokens),
DefaultValue::Path(ref value) => value.to_tokens(tokens),
Expand Down
4 changes: 2 additions & 2 deletions prost-types/src/protobuf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ pub struct UninterpretedOption {
pub negative_int_value: ::core::option::Option<i64>,
#[prost(double, optional, tag="6")]
pub double_value: ::core::option::Option<f64>,
#[prost(bytes, optional, tag="7")]
#[prost(bytes="vec", optional, tag="7")]
pub string_value: ::core::option::Option<::prost::alloc::vec::Vec<u8>>,
#[prost(string, optional, tag="8")]
pub aggregate_value: ::core::option::Option<::prost::alloc::string::String>,
Expand Down Expand Up @@ -1003,7 +1003,7 @@ pub struct Any {
#[prost(string, tag="1")]
pub type_url: ::prost::alloc::string::String,
/// Must be a valid serialized protocol buffer of the above specified type.
#[prost(bytes, tag="2")]
#[prost(bytes="vec", tag="2")]
pub value: ::prost::alloc::vec::Vec<u8>,
}
/// `SourceContext` represents information about the source of a
Expand Down
Loading

0 comments on commit a1cccbc

Please sign in to comment.