Skip to content

Commit

Permalink
Implement support for Unknown Fields.
Browse files Browse the repository at this point in the history
  • Loading branch information
wildarch committed May 7, 2020
1 parent 2de785a commit 9b3c001
Show file tree
Hide file tree
Showing 12 changed files with 286 additions and 19 deletions.
14 changes: 14 additions & 0 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,20 @@ impl<'a> CodeGenerator<'a> {
}
self.path.pop();

if self
.config
.unknown_fields_messages
.iter()
// Skip the leading '.' when matching
.any(|m| m == &fq_message_name[1..])
{
self.push_indent();
self.buf.push_str("#[prost(unknown_fields)]\n");
self.push_indent();
self.buf
.push_str("pub protobuf_unknown_fields: Vec<::prost::UnknownField>,\n");
}

self.depth -= 1;
self.push_indent();
self.buf.push_str("}\n");
Expand Down
11 changes: 11 additions & 0 deletions prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ pub struct Config {
strip_enum_prefix: bool,
out_dir: Option<PathBuf>,
extern_paths: Vec<(String, String)>,
unknown_fields_messages: Vec<String>,
}

impl Config {
Expand Down Expand Up @@ -484,6 +485,15 @@ impl Config {
self
}

/// Registers a message whose unrecognized fields should be kept when decoding.
///
/// For each registered message the code generator adds a
/// `pub protobuf_unknown_fields: Vec<::prost::UnknownField>` member annotated with
/// `#[prost(unknown_fields)]` to the generated struct.
///
/// `message` is the fully-qualified protobuf name of the message, for example
/// `"my_package.MyMessage"`. The generator compares registered names against the
/// message's fully-qualified name with its leading `.` removed, so a leading `.`
/// supplied here is stripped before the name is stored.
///
/// Returns `self` so that calls can be chained.
pub fn unknown_fields_message<S>(&mut self, message: S) -> &mut Self
where
    S: AsRef<str>,
{
    // Normalise ".pkg.Msg" to "pkg.Msg": names are matched without the leading
    // '.', so keeping it would make the registration silently never match.
    let name = message.as_ref().trim_start_matches('.');
    self.unknown_fields_messages.push(name.to_owned());
    self
}

/// Compile `.proto` files into Rust files during a Cargo build with additional code generator
/// configuration options.
///
Expand Down Expand Up @@ -626,6 +636,7 @@ impl default::Default for Config {
strip_enum_prefix: true,
out_dir: None,
extern_paths: Vec::new(),
unknown_fields_messages: Vec::new(),
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions prost-derive/src/field/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod map;
mod message;
mod oneof;
mod scalar;
mod unknown_fields;

use std::fmt;
use std::slice;
Expand All @@ -24,6 +25,8 @@ pub enum Field {
Oneof(oneof::Field),
/// A group field.
Group(group::Field),
/// Unknown fields
UnknownFields(unknown_fields::Field),
}

impl Field {
Expand All @@ -46,6 +49,8 @@ impl Field {
Field::Oneof(field)
} else if let Some(field) = group::Field::new(&attrs, inferred_tag)? {
Field::Group(field)
} else if let Some(field) = unknown_fields::Field::new(&attrs)? {
Field::UnknownFields(field)
} else {
bail!("no type attribute");
};
Expand Down Expand Up @@ -84,6 +89,7 @@ impl Field {
Field::Map(ref map) => vec![map.tag],
Field::Oneof(ref oneof) => oneof.tags.clone(),
Field::Group(ref group) => vec![group.tag],
Field::UnknownFields(_) => Vec::new(),
}
}

Expand All @@ -95,6 +101,7 @@ impl Field {
Field::Map(ref map) => map.encode(ident),
Field::Oneof(ref oneof) => oneof.encode(ident),
Field::Group(ref group) => group.encode(ident),
Field::UnknownFields(ref unknown_fields) => unknown_fields.encode(ident),
}
}

Expand All @@ -107,6 +114,7 @@ impl Field {
Field::Map(ref map) => map.merge(ident),
Field::Oneof(ref oneof) => oneof.merge(ident),
Field::Group(ref group) => group.merge(ident),
Field::UnknownFields(ref unknown_fields) => unknown_fields.merge(ident),
}
}

Expand All @@ -118,6 +126,7 @@ impl Field {
Field::Message(ref msg) => msg.encoded_len(ident),
Field::Oneof(ref oneof) => oneof.encoded_len(ident),
Field::Group(ref group) => group.encoded_len(ident),
Field::UnknownFields(ref unknown_fields) => unknown_fields.encoded_len(ident),
}
}

Expand All @@ -129,6 +138,7 @@ impl Field {
Field::Map(ref map) => map.clear(ident),
Field::Oneof(ref oneof) => oneof.clear(ident),
Field::Group(ref group) => group.clear(ident),
Field::UnknownFields(ref unknown_fields) => unknown_fields.clear(ident),
}
}

Expand Down
47 changes: 47 additions & 0 deletions prost-derive/src/field/unknown_fields.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use anyhow::{bail, Error};
use proc_macro2::TokenStream;
use quote::quote;
use syn::Meta;

/// Marker for a struct member annotated with `#[prost(unknown_fields)]`.
///
/// It carries no data; it only selects the token streams produced by the
/// methods below.
#[derive(Clone)]
pub struct Field;

impl Field {
    /// Recognises the `unknown_fields` attribute.
    ///
    /// Returns `Ok(Some(Field))` when the first attribute is `unknown_fields` and
    /// it is the only attribute, and `Ok(None)` when `attrs` is empty or its first
    /// attribute is something else.
    ///
    /// # Errors
    ///
    /// Fails when `unknown_fields` is followed by any further attributes; the
    /// error message lists the offending attributes.
    pub fn new(attrs: &[Meta]) -> Result<Option<Field>, Error> {
        match attrs.first() {
            Some(attr) if attr.path().is_ident("unknown_fields") => {
                if attrs.len() > 1 {
                    // Report everything after the `unknown_fields` attribute. The
                    // range must be `1..`; `1..0` has start > end and panics when
                    // used as a slice index instead of yielding this error.
                    bail!("invalid unknown_fields attribute(s): {:?}", &attrs[1..]);
                }
                Ok(Some(Field))
            }
            _ => Ok(None),
        }
    }

    /// Emits code that writes every stored unknown field to `buf`.
    ///
    /// `ident` is the expression naming the collection of unknown fields.
    // NOTE(review): every value goes through `::prost::encoding::bytes::encode`
    // regardless of the wire type it was decoded with — confirm this round-trips
    // fields that were not originally encoded as bytes.
    pub fn encode(&self, ident: TokenStream) -> TokenStream {
        quote! {
            for field in #ident.iter() {
                ::prost::encoding::bytes::encode(field.tag, &field.value, buf);
            }
        }
    }

    /// Emits code that empties the collection of unknown fields.
    pub fn clear(&self, ident: TokenStream) -> TokenStream {
        quote!(#ident.clear())
    }

    /// Emits an expression summing the encoded length of every stored unknown
    /// field, using the same `bytes` encoding as [`Field::encode`].
    pub fn encoded_len(&self, ident: TokenStream) -> TokenStream {
        quote! {
            #ident
                .iter()
                .map(|field| ::prost::encoding::bytes::encoded_len(field.tag, &field.value))
                .sum::<usize>()
        }
    }

    /// Emits a no-op merge expression that always evaluates to `Ok(())`.
    pub fn merge(&self, _ident: TokenStream) -> TokenStream {
        // Adding unknown fields is handled separately by the decoding logic
        quote!(Result::<(), ::prost::DecodeError>::Ok(()))
    }
}
51 changes: 35 additions & 16 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,19 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
})
.collect::<Result<Vec<_>, _>>()?;

let has_unknown_fields_accessor = fields.iter().any(|(_, f)| match f {
Field::UnknownFields(_) => true,
_ => false,
});

// We want Debug to be in declaration order
let unsorted_fields = fields.clone();

// Sort the fields by tag number so that fields will be encoded in tag order.
// TODO: This encodes oneof fields in the position of their lowest tag,
// regardless of the currently occupied variant, is that consequential?
// See: https://developers.google.com/protocol-buffers/docs/encoding#order
fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap_or(std::u32::MAX));
let fields = fields;

let mut tags = fields
Expand All @@ -103,20 +108,24 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
.map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));

let merge = fields.iter().map(|&(ref field_ident, ref field)| {
let merge = field.merge(quote!(value));
let tags = field
.tags()
.into_iter()
.map(|tag| quote!(#tag))
.intersperse(quote!(|));
quote! {
#(#tags)* => {
let mut value = &mut self.#field_ident;
#merge.map_err(|mut error| {
error.push(STRUCT_NAME, stringify!(#field_ident));
error
})
},
if let Field::UnknownFields(_) = field {
quote! {}
} else {
let merge = field.merge(quote!(value));
let tags = field
.tags()
.into_iter()
.map(|tag| quote!(#tag))
.intersperse(quote!(|));
quote! {
#(#tags)* => {
let mut value = &mut self.#field_ident;
#merge.map_err(|mut error| {
error.push(STRUCT_NAME, stringify!(#field_ident));
error
})
},
}
}
});

Expand Down Expand Up @@ -175,6 +184,16 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
quote!(f.debug_tuple(stringify!(#ident)))
};

let unknown_field_handler = if has_unknown_fields_accessor {
quote! {
::prost::encoding::unknown_field(wire_type, tag, buf, &mut self.protobuf_unknown_fields, ctx)
}
} else {
quote! {
::prost::encoding::skip_field(wire_type, tag, buf, ctx)
}
};

let expanded = quote! {
impl ::prost::Message for #ident {
#[allow(unused_variables)]
Expand All @@ -194,7 +213,7 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
#struct_name
match tag {
#(#merge)*
_ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
_ => #unknown_field_handler,
}
}

Expand Down
57 changes: 55 additions & 2 deletions src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use ::bytes::{buf::ext::BufExt, Buf, BufMut};

use crate::DecodeError;
use crate::Message;
use crate::UnknownField;

/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer.
/// The buffer must have enough remaining space (maximum 10 bytes).
Expand Down Expand Up @@ -391,6 +392,52 @@ pub fn skip_field<B>(
) -> Result<(), DecodeError>
where
B: Buf,
{
handle_unknown_field(
wire_type,
tag,
buf,
&mut |_, len, buf| {
buf.advance(len);
},
ctx,
)
}

/// Handles a field the message does not recognise by recording it.
///
/// Delegates to `handle_unknown_field` with a handler that copies the `len`
/// bytes it is given out of `buf` and appends them, together with the tag the
/// handler receives, to `unknown_fields` as an [`UnknownField`].
///
/// Returns whatever `handle_unknown_field` returns.
pub fn unknown_field<B>(
    wire_type: WireType,
    tag: u32,
    buf: &mut B,
    unknown_fields: &mut Vec<UnknownField>,
    ctx: DecodeContext,
) -> Result<(), DecodeError>
where
    B: Buf,
{
    // Handler invoked by `handle_unknown_field` with the tag, the payload length
    // and the buffer positioned at the payload.
    let mut record = |field_tag: u32, field_len: usize, src: &mut B| {
        // Zero-filled buffer of exactly the payload size, then overwritten.
        let mut payload = vec![0u8; field_len];
        src.copy_to_slice(&mut payload);
        unknown_fields.push(UnknownField {
            tag: field_tag,
            value: payload,
        });
    };

    handle_unknown_field(wire_type, tag, buf, &mut record, ctx)
}

fn handle_unknown_field<B, F>(
wire_type: WireType,
tag: u32,
buf: &mut B,
handler: &mut F,
ctx: DecodeContext,
) -> Result<(), DecodeError>
where
B: Buf,
F: FnMut(u32, usize, &mut B),
{
ctx.limit_reached()?;
let len = match wire_type {
Expand All @@ -407,7 +454,13 @@ where
}
break 0;
}
_ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
_ => handle_unknown_field(
inner_wire_type,
inner_tag,
buf,
handler,
ctx.enter_recursion(),
)?,
}
},
WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
Expand All @@ -417,7 +470,7 @@ where
return Err(DecodeError::new("buffer underflow"));
}

buf.advance(len as usize);
handler(tag, len as usize, buf);
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod types;
pub mod encoding;

pub use crate::error::{DecodeError, EncodeError};
pub use crate::message::Message;
pub use crate::message::{Message, UnknownField};

use bytes::{Buf, BufMut};

Expand Down
13 changes: 13 additions & 0 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,16 @@ where
(**self).clear()
}
}

/// An unknown field on a protobuf message.
///
/// Structs generated by `prost-build` with the `unknown_fields` flag enabled have a field
/// `protobuf_unknown_fields: Vec<UnknownField>` that gives access to any fields that were not
/// recognized.
/// Any values added to this field will also be present in serialized form.
// `Eq` is derived alongside `PartialEq` because both members (`u32`, `Vec<u8>`)
// have total equality.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct UnknownField {
    /// The tag the field was recorded under.
    pub tag: u32,
    /// The raw field value.
    pub value: Vec<u8>,
}
5 changes: 5 additions & 0 deletions tests/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ fn main() {
.compile_protos(&[src.join("well_known_types.proto")], includes)
.unwrap();

config
.unknown_fields_message("unknown_fields.MessageWithUnknownFields")
.compile_protos(&[src.join("unknown_fields.proto")], includes)
.unwrap();

config
.compile_protos(
&[src.join("packages/widget_factory.proto")],
Expand Down
2 changes: 2 additions & 0 deletions tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ mod message_encoding;
#[cfg(test)]
mod no_unused_results;
#[cfg(test)]
mod unknown_fields;
#[cfg(test)]
mod well_known_types;

pub mod foo {
Expand Down
Loading

0 comments on commit 9b3c001

Please sign in to comment.