Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unknown fields #317

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,20 @@ impl<'a> CodeGenerator<'a> {
}
self.path.pop();

if self
.config
.unknown_fields_messages
.iter()
// Skip the leading '.' when matching
.any(|m| m == &fq_message_name[1..])
{
self.push_indent();
self.buf.push_str("#[prost(unknown_fields)]\n");
self.push_indent();
self.buf
.push_str("pub protobuf_unknown_fields: Vec<::prost::UnknownField>,\n");
}

self.depth -= 1;
self.push_indent();
self.buf.push_str("}\n");
Expand Down
11 changes: 11 additions & 0 deletions prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ pub struct Config {
strip_enum_prefix: bool,
out_dir: Option<PathBuf>,
extern_paths: Vec<(String, String)>,
unknown_fields_messages: Vec<String>,
}

impl Config {
Expand Down Expand Up @@ -484,6 +485,15 @@ impl Config {
self
}

/// Registers a message whose generated struct should retain unrecognized fields.
///
/// The code generator compares each registered name against the fully-qualified message
/// name with its leading `.` removed, so pass names in the form `package.MessageName`.
/// Matching messages get an extra `protobuf_unknown_fields` member annotated with
/// `#[prost(unknown_fields)]`.
///
/// Returns `self` so that further configuration calls can be chained.
pub fn unknown_fields_message<S>(&mut self, message: S) -> &mut Self
where
    S: AsRef<str>,
{
    let name = String::from(message.as_ref());
    self.unknown_fields_messages.push(name);
    self
}

/// Compile `.proto` files into Rust files during a Cargo build with additional code generator
/// configuration options.
///
Expand Down Expand Up @@ -626,6 +636,7 @@ impl default::Default for Config {
strip_enum_prefix: true,
out_dir: None,
extern_paths: Vec::new(),
unknown_fields_messages: Vec::new(),
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions prost-derive/src/field/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod map;
mod message;
mod oneof;
mod scalar;
mod unknown_fields;

use std::fmt;
use std::slice;
Expand All @@ -24,6 +25,8 @@ pub enum Field {
Oneof(oneof::Field),
/// A group field.
Group(group::Field),
/// Unknown fields
UnknownFields(unknown_fields::Field),
}

impl Field {
Expand All @@ -46,6 +49,8 @@ impl Field {
Field::Oneof(field)
} else if let Some(field) = group::Field::new(&attrs, inferred_tag)? {
Field::Group(field)
} else if let Some(field) = unknown_fields::Field::new(&attrs)? {
Field::UnknownFields(field)
} else {
bail!("no type attribute");
};
Expand Down Expand Up @@ -84,6 +89,7 @@ impl Field {
Field::Map(ref map) => vec![map.tag],
Field::Oneof(ref oneof) => oneof.tags.clone(),
Field::Group(ref group) => vec![group.tag],
Field::UnknownFields(_) => Vec::new(),
}
}

Expand All @@ -95,6 +101,7 @@ impl Field {
Field::Map(ref map) => map.encode(ident),
Field::Oneof(ref oneof) => oneof.encode(ident),
Field::Group(ref group) => group.encode(ident),
Field::UnknownFields(ref unknown_fields) => unknown_fields.encode(ident),
}
}

Expand All @@ -107,6 +114,7 @@ impl Field {
Field::Map(ref map) => map.merge(ident),
Field::Oneof(ref oneof) => oneof.merge(ident),
Field::Group(ref group) => group.merge(ident),
Field::UnknownFields(ref unknown_fields) => unknown_fields.merge(ident),
}
}

Expand All @@ -118,6 +126,7 @@ impl Field {
Field::Message(ref msg) => msg.encoded_len(ident),
Field::Oneof(ref oneof) => oneof.encoded_len(ident),
Field::Group(ref group) => group.encoded_len(ident),
Field::UnknownFields(ref unknown_fields) => unknown_fields.encoded_len(ident),
}
}

Expand All @@ -129,6 +138,7 @@ impl Field {
Field::Map(ref map) => map.clear(ident),
Field::Oneof(ref oneof) => oneof.clear(ident),
Field::Group(ref group) => group.clear(ident),
Field::UnknownFields(ref unknown_fields) => unknown_fields.clear(ident),
}
}

Expand Down
47 changes: 47 additions & 0 deletions prost-derive/src/field/unknown_fields.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use anyhow::{bail, Error};
use proc_macro2::TokenStream;
use quote::quote;
use syn::Meta;

/// Marker for a struct member annotated with `#[prost(unknown_fields)]`.
///
/// It carries no data of its own; it only selects the unknown-fields code generation paths.
#[derive(Clone)]
pub struct Field;

impl Field {
    /// Recognizes the `unknown_fields` attribute.
    ///
    /// Only the first attribute is inspected: when it is `unknown_fields` the member is
    /// treated as the unknown-fields container, otherwise `Ok(None)` is returned so that the
    /// caller can report the missing type attribute.
    ///
    /// # Errors
    ///
    /// Fails when `unknown_fields` is followed by any further attribute; the error message
    /// lists the offending attributes.
    pub fn new(attrs: &[Meta]) -> Result<Option<Field>, Error> {
        match attrs.split_first() {
            Some((first, rest)) if first.path().is_ident("unknown_fields") => {
                // `rest` is everything after `unknown_fields`. Slicing with an inverted range
                // such as `1..0` would panic instead of producing this error.
                if !rest.is_empty() {
                    bail!("invalid unknown_fields attribute(s): {:?}", rest);
                }
                Ok(Some(Field))
            }
            _ => Ok(None),
        }
    }

    /// Generates the statements that write every stored unknown field into `buf`.
    ///
    /// `ident` is the expression naming the member; the emitted code expects a `buf` binding
    /// to be in scope at the expansion site.
    ///
    /// NOTE(review): every entry is written through `::prost::encoding::bytes::encode`, which
    /// presumably emits a length-delimited record. Confirm that unknown fields of other wire
    /// types round-trip as intended.
    pub fn encode(&self, ident: TokenStream) -> TokenStream {
        quote! {
            for field in #ident.iter() {
                ::prost::encoding::bytes::encode(field.tag, &field.value, buf);
            }
        }
    }

    /// Generates the expression that empties the unknown-fields container.
    pub fn clear(&self, ident: TokenStream) -> TokenStream {
        quote!(#ident.clear())
    }

    /// Generates the expression that sums the encoded length of every stored unknown field.
    pub fn encoded_len(&self, ident: TokenStream) -> TokenStream {
        quote! {
            #ident
                .iter()
                .map(|field| ::prost::encoding::bytes::encoded_len(field.tag, &field.value))
                .sum::<usize>()
        }
    }

    /// Generates a no-op merge expression that always evaluates to `Ok(())`.
    pub fn merge(&self, _ident: TokenStream) -> TokenStream {
        // Adding unknown fields is handled separately by the decoding logic
        quote!(Result::<(), ::prost::DecodeError>::Ok(()))
    }
}
51 changes: 35 additions & 16 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,19 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
})
.collect::<Result<Vec<_>, _>>()?;

let has_unknown_fields_accessor = fields.iter().any(|(_, f)| match f {
Field::UnknownFields(_) => true,
_ => false,
});

// We want Debug to be in declaration order
let unsorted_fields = fields.clone();

// Sort the fields by tag number so that fields will be encoded in tag order.
// TODO: This encodes oneof fields in the position of their lowest tag,
// regardless of the currently occupied variant, is that consequential?
// See: https://developers.google.com/protocol-buffers/docs/encoding#order
fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap_or(std::u32::MAX));
let fields = fields;

let mut tags = fields
Expand All @@ -103,20 +108,24 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
.map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));

let merge = fields.iter().map(|&(ref field_ident, ref field)| {
let merge = field.merge(quote!(value));
let tags = field
.tags()
.into_iter()
.map(|tag| quote!(#tag))
.intersperse(quote!(|));
quote! {
#(#tags)* => {
let mut value = &mut self.#field_ident;
#merge.map_err(|mut error| {
error.push(STRUCT_NAME, stringify!(#field_ident));
error
})
},
if let Field::UnknownFields(_) = field {
quote! {}
} else {
let merge = field.merge(quote!(value));
let tags = field
.tags()
.into_iter()
.map(|tag| quote!(#tag))
.intersperse(quote!(|));
quote! {
#(#tags)* => {
let mut value = &mut self.#field_ident;
#merge.map_err(|mut error| {
error.push(STRUCT_NAME, stringify!(#field_ident));
error
})
},
}
}
});

Expand Down Expand Up @@ -175,6 +184,16 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
quote!(f.debug_tuple(stringify!(#ident)))
};

let unknown_field_handler = if has_unknown_fields_accessor {
quote! {
::prost::encoding::unknown_field(wire_type, tag, buf, &mut self.protobuf_unknown_fields, ctx)
}
} else {
quote! {
::prost::encoding::skip_field(wire_type, tag, buf, ctx)
}
};

let expanded = quote! {
impl ::prost::Message for #ident {
#[allow(unused_variables)]
Expand All @@ -194,7 +213,7 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
#struct_name
match tag {
#(#merge)*
_ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
_ => #unknown_field_handler,
}
}

Expand Down
57 changes: 55 additions & 2 deletions src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use ::bytes::{buf::ext::BufExt, Buf, BufMut};

use crate::DecodeError;
use crate::Message;
use crate::UnknownField;

/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer.
/// The buffer must have enough remaining space (maximum 10 bytes).
Expand Down Expand Up @@ -391,6 +392,52 @@ pub fn skip_field<B>(
) -> Result<(), DecodeError>
where
B: Buf,
{
handle_unknown_field(
wire_type,
tag,
buf,
&mut |_, len, buf| {
buf.advance(len);
},
ctx,
)
}

/// Consumes a field that the message does not declare and records it in `unknown_fields`.
///
/// The traversal is delegated to `handle_unknown_field`; each time it invokes the callback,
/// the reported number of bytes is copied out of `buf` and appended to `unknown_fields`
/// together with the reported tag.
///
/// # Errors
///
/// Propagates any `DecodeError` returned by `handle_unknown_field`.
pub fn unknown_field<B>(
    wire_type: WireType,
    tag: u32,
    buf: &mut B,
    unknown_fields: &mut Vec<UnknownField>,
    ctx: DecodeContext,
) -> Result<(), DecodeError>
where
    B: Buf,
{
    handle_unknown_field(
        wire_type,
        tag,
        buf,
        &mut |field_tag, byte_count, source| {
            // Zero-filled scratch space that `copy_to_slice` overwrites in full.
            let mut value = vec![0u8; byte_count];
            source.copy_to_slice(value.as_mut_slice());
            unknown_fields.push(UnknownField {
                tag: field_tag,
                value,
            });
        },
        ctx,
    )
}

fn handle_unknown_field<B, F>(
wire_type: WireType,
tag: u32,
buf: &mut B,
handler: &mut F,
ctx: DecodeContext,
) -> Result<(), DecodeError>
where
B: Buf,
F: FnMut(u32, usize, &mut B),
{
ctx.limit_reached()?;
let len = match wire_type {
Expand All @@ -407,7 +454,13 @@ where
}
break 0;
}
_ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
_ => handle_unknown_field(
inner_wire_type,
inner_tag,
buf,
handler,
ctx.enter_recursion(),
)?,
}
},
WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
Expand All @@ -417,7 +470,7 @@ where
return Err(DecodeError::new("buffer underflow"));
}

buf.advance(len as usize);
handler(tag, len as usize, buf);
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod types;
pub mod encoding;

pub use crate::error::{DecodeError, EncodeError};
pub use crate::message::Message;
pub use crate::message::{Message, UnknownField};

use bytes::{Buf, BufMut};

Expand Down
13 changes: 13 additions & 0 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,16 @@ where
(**self).clear()
}
}

/// An unknown field on a protobuf message.
///
/// Structs generated by `prost-build` with the `unknown_fields` flag enabled have a field
/// `protobuf_unknown_fields: Vec<UnknownField>` that gives access to any fields that were not
/// recognized.
/// Any values added to this field will also be present in serialized form.
///
/// Both members are plain data, so the type also derives `Eq` and `Hash` and can be stored in
/// hashed collections.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct UnknownField {
    /// The tag the field was encountered with.
    pub tag: u32,
    /// The raw field value.
    pub value: Vec<u8>,
}
5 changes: 5 additions & 0 deletions tests/src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@ fn main() {
.compile_protos(&[src.join("well_known_types.proto")], includes)
.unwrap();

config
.unknown_fields_message("unknown_fields.MessageWithUnknownFields")
.compile_protos(&[src.join("unknown_fields.proto")], includes)
.unwrap();

config
.compile_protos(
&[src.join("packages/widget_factory.proto")],
Expand Down
2 changes: 2 additions & 0 deletions tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ mod message_encoding;
#[cfg(test)]
mod no_unused_results;
#[cfg(test)]
mod unknown_fields;
#[cfg(test)]
mod well_known_types;

pub mod foo {
Expand Down
Loading