Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add opt-in support for unknown fields #574

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
3 changes: 0 additions & 3 deletions conformance/failing_tests.txt
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
# TODO(tokio-rs/prost#2): prost doesn't preserve unknown fields.
Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput
Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput
16 changes: 16 additions & 0 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,14 @@ impl<'a> CodeGenerator<'a> {
}
self.path.pop();
}
if let Some(field_name) = self
.config
.include_unknown_fields
.get_first(&fq_message_name)
.cloned()
{
self.append_unknown_field_set(&fq_message_name, &field_name);
}
self.path.pop();

self.path.push(8);
Expand Down Expand Up @@ -451,6 +459,14 @@ impl<'a> CodeGenerator<'a> {
));
}

/// Emits the struct field that stores a message's unknown fields.
///
/// Writes, in order: the `#[prost(unknown)]` attribute, whatever
/// `append_field_attributes` emits for `field_name` on `fq_message_name`, and the
/// `pub <field_name>: ::prost::UnknownFieldSet,` declaration.
fn append_unknown_field_set(&mut self, fq_message_name: &str, field_name: &str) {
    // Indent the attribute line as well, so it lines up with the field
    // declaration emitted below instead of starting at column zero.
    self.push_indent();
    self.buf.push_str("#[prost(unknown)]\n");
    self.append_field_attributes(fq_message_name, field_name);
    self.push_indent();
    self.buf
        .push_str(&format!("pub {}: ::prost::UnknownFieldSet,\n", field_name));
}

fn append_oneof_field(
&mut self,
message_name: &str,
Expand Down
29 changes: 29 additions & 0 deletions prost-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ pub struct Config {
bytes_type: PathMap<BytesType>,
type_attributes: PathMap<String>,
field_attributes: PathMap<String>,
include_unknown_fields: PathMap<String>,
prost_types: bool,
strip_enum_prefix: bool,
out_dir: Option<PathBuf>,
Expand Down Expand Up @@ -459,6 +460,32 @@ impl Config {
self
}

/// Configures the code generator to preserve unknown fields for matching messages.
///
/// # Arguments
///
/// **`path`** - path to a specific message, or to a package, whose messages should
/// preserve unknown fields during deserialization.
///
/// **`field_name`** - the name of the field to place unknown fields in. A field with this
/// name and type `prost::UnknownFieldSet` will be added to the generated struct.
///
/// # Examples
///
/// ```rust
/// # let mut config = prost_build::Config::new();
/// config.include_unknown_fields(".my_messages.MyMessageType", "unknown_fields");
/// ```
pub fn include_unknown_fields<P, A>(&mut self, path: P, field_name: A) -> &mut Self
where
    P: AsRef<str>,
    A: AsRef<str>,
{
    // Own both strings up front; the path map stores them beyond this call.
    let matcher = path.as_ref().to_string();
    let unknown_field_name = field_name.as_ref().to_string();
    self.include_unknown_fields
        .insert(matcher, unknown_field_name);
    self
}

/// Configures the code generator to use the provided service generator.
pub fn service_generator(&mut self, service_generator: Box<dyn ServiceGenerator>) -> &mut Self {
self.service_generator = Some(service_generator);
Expand Down Expand Up @@ -1046,6 +1073,7 @@ impl default::Default for Config {
bytes_type: PathMap::default(),
type_attributes: PathMap::default(),
field_attributes: PathMap::default(),
include_unknown_fields: PathMap::default(),
prost_types: true,
strip_enum_prefix: true,
out_dir: None,
Expand All @@ -1068,6 +1096,7 @@ impl fmt::Debug for Config {
.field("bytes_type", &self.bytes_type)
.field("type_attributes", &self.type_attributes)
.field("field_attributes", &self.field_attributes)
.field("include_unknown_fields", &self.include_unknown_fields)
.field("prost_types", &self.prost_types)
.field("strip_enum_prefix", &self.strip_enum_prefix)
.field("out_dir", &self.out_dir)
Expand Down
14 changes: 14 additions & 0 deletions prost-derive/src/field/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod map;
mod message;
mod oneof;
mod scalar;
mod unknown;

use std::fmt;
use std::slice;
Expand All @@ -24,6 +25,8 @@ pub enum Field {
Oneof(oneof::Field),
/// A group field.
Group(group::Field),
/// A set of unknown message fields.
Unknown(unknown::Field),
}

impl Field {
Expand All @@ -46,6 +49,8 @@ impl Field {
Field::Oneof(field)
} else if let Some(field) = group::Field::new(&attrs, inferred_tag)? {
Field::Group(field)
} else if let Some(field) = unknown::Field::new(&attrs)? {
Field::Unknown(field)
} else {
bail!("no type attribute");
};
Expand Down Expand Up @@ -84,6 +89,7 @@ impl Field {
Field::Map(ref map) => vec![map.tag],
Field::Oneof(ref oneof) => oneof.tags.clone(),
Field::Group(ref group) => vec![group.tag],
Field::Unknown(_) => vec![],
}
}

Expand All @@ -95,6 +101,7 @@ impl Field {
Field::Map(ref map) => map.encode(ident),
Field::Oneof(ref oneof) => oneof.encode(ident),
Field::Group(ref group) => group.encode(ident),
Field::Unknown(ref unknown) => unknown.encode(ident),
}
}

Expand All @@ -107,6 +114,7 @@ impl Field {
Field::Map(ref map) => map.merge(ident),
Field::Oneof(ref oneof) => oneof.merge(ident),
Field::Group(ref group) => group.merge(ident),
Field::Unknown(ref unknown) => unknown.merge(ident),
}
}

Expand All @@ -118,6 +126,7 @@ impl Field {
Field::Message(ref msg) => msg.encoded_len(ident),
Field::Oneof(ref oneof) => oneof.encoded_len(ident),
Field::Group(ref group) => group.encoded_len(ident),
Field::Unknown(ref unknown) => unknown.encoded_len(ident),
}
}

Expand All @@ -129,6 +138,7 @@ impl Field {
Field::Map(ref map) => map.clear(ident),
Field::Oneof(ref oneof) => oneof.clear(ident),
Field::Group(ref group) => group.clear(ident),
Field::Unknown(ref unknown) => unknown.clear(ident),
}
}

Expand Down Expand Up @@ -171,6 +181,10 @@ impl Field {
_ => None,
}
}

/// Returns `true` if this field is the `Field::Unknown` variant, i.e. the
/// field that holds a message's set of unknown fields.
pub fn is_unknown(&self) -> bool {
matches!(self, Field::Unknown(_))
}
}

#[derive(Clone, Copy, PartialEq, Eq)]
Expand Down
66 changes: 66 additions & 0 deletions prost-derive/src/field/unknown.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use anyhow::{bail, Error};
use proc_macro2::TokenStream;
use quote::quote;
use syn::Meta;

use crate::field::{set_bool, word_attr};

/// Marker for a struct field annotated with `#[prost(unknown)]`.
///
/// The type carries no configuration: an unknown-field set has no tag, label
/// or type parameters of its own.
#[derive(Clone)]
pub struct Field {}

impl Field {
    /// Tries to interpret a field's `prost` attributes as an unknown-field set.
    ///
    /// Returns `Ok(None)` when the `unknown` word attribute is absent, and
    /// `Ok(Some(Field))` when it is the only attribute present.
    ///
    /// # Errors
    ///
    /// Propagates the error from `set_bool` when `unknown` appears more than
    /// once, and fails when `unknown` is combined with any other attribute.
    pub fn new(attrs: &[Meta]) -> Result<Option<Field>, Error> {
        let mut unknown = false;
        let mut unknown_attrs = Vec::new();

        for attr in attrs {
            if word_attr("unknown", attr) {
                // Name the attribute that was actually repeated.
                set_bool(&mut unknown, "duplicate unknown attribute")?;
            } else {
                unknown_attrs.push(attr);
            }
        }

        if !unknown {
            return Ok(None);
        }

        // `unknown` takes no companions; report every attribute left over.
        match unknown_attrs.len() {
            0 => (),
            1 => bail!(
                "unknown attribute for unknown field set: {:?}",
                unknown_attrs[0]
            ),
            _ => bail!(
                "unknown attributes for unknown field set: {:?}",
                unknown_attrs
            ),
        }

        Ok(Some(Field {}))
    }

    /// Returns tokens that call `encode_raw(buf)` on `ident`.
    ///
    /// The generated expression expects a `buf` binding in the surrounding code.
    pub fn encode(&self, ident: TokenStream) -> TokenStream {
        quote! {
            #ident.encode_raw(buf)
        }
    }

    /// Returns tokens that call `merge_field(tag, wire_type, buf, ctx)` on `ident`.
    ///
    /// The generated expression expects `tag`, `wire_type`, `buf` and `ctx`
    /// bindings in the surrounding code.
    pub fn merge(&self, ident: TokenStream) -> TokenStream {
        quote! {
            #ident.merge_field(tag, wire_type, buf, ctx)
        }
    }

    /// Returns tokens that call `encoded_len()` on `ident`.
    pub fn encoded_len(&self, ident: TokenStream) -> TokenStream {
        quote! {
            #ident.encoded_len()
        }
    }

    /// Returns tokens that call `clear()` on `ident`.
    pub fn clear(&self, ident: TokenStream) -> TokenStream {
        quote! {
            #ident.clear()
        }
    }
}
33 changes: 30 additions & 3 deletions prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,17 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
// We want Debug to be in declaration order
let unsorted_fields = fields.clone();

// Sort the fields by tag number so that fields will be encoded in tag order.
// Sort the fields by tag number so that fields will be encoded in tag order,
// and unknown fields are encoded last.
// TODO: This encodes oneof fields in the position of their lowest tag,
// regardless of the currently occupied variant, is that consequential?
// See: https://developers.google.com/protocol-buffers/docs/encoding#order
fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
fields.sort_by_key(|&(_, ref field)| {
(
field.is_unknown(),
field.tags().into_iter().min().unwrap_or(0),
)
});
let fields = fields;

let mut tags = fields
Expand All @@ -101,6 +107,10 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
.map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));

let merge = fields.iter().map(|&(ref field_ident, ref field)| {
if field.is_unknown() {
return quote!();
}

let merge = field.merge(quote!(value));
let tags = field.tags().into_iter().map(|tag| quote!(#tag));
let tags = Itertools::intersperse(tags, quote!(|));
Expand All @@ -115,6 +125,23 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
},
}
});
let merge_fallback = match fields.iter().find(|&(_, f)| f.is_unknown()) {
Some((field_ident, field)) => {
let merge = field.merge(quote!(value));
quote! {
_ => {
let mut value = &mut self.#field_ident;
#merge.map_err(|mut error| {
error.push(STRUCT_NAME, stringify!(#field_ident));
error
})
},
}
}
None => quote! {
_ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
},
};

let struct_name = if fields.is_empty() {
quote!()
Expand Down Expand Up @@ -190,7 +217,7 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
#struct_name
match tag {
#(#merge)*
_ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
#merge_fallback
}
}

Expand Down
1 change: 1 addition & 0 deletions protobuf/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ fn main() -> Result<()> {
// values.
prost_build::Config::new()
.btree_map(&["."])
.include_unknown_fields(".", "unknown_fields")
.compile_protos(
&[
test_includes.join("test_messages_proto2.proto"),
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ pub use bytes;
mod error;
mod message;
mod types;
mod unknown;

#[doc(hidden)]
pub mod encoding;

pub use crate::error::{DecodeError, EncodeError};
pub use crate::message::Message;
pub use crate::unknown::{UnknownField, UnknownFieldIter, UnknownFieldSet};

use bytes::{Buf, BufMut};

Expand Down
Loading