Skip to content

Commit

Permalink
Use a Vec<UnknownField> instead of Map.
Browse files Browse the repository at this point in the history
This allows accessing all elements in a repeated field.
  • Loading branch information
wildarch committed May 7, 2020
1 parent 753a510 commit e564c56
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 48 deletions.
5 changes: 2 additions & 3 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,8 @@ impl<'a> CodeGenerator<'a> {
self.push_indent();
self.buf.push_str("#[prost(unknown_fields)]\n");
self.push_indent();
self.buf.push_str(
"pub protobuf_unknown_fields: std::collections::HashMap<u32, Vec<u8>>,\n",
);
self.buf
.push_str("pub protobuf_unknown_fields: Vec<::prost::UnknownField>,\n");
}

self.depth -= 1;
Expand Down
9 changes: 6 additions & 3 deletions prost-derive/src/field/unknown_fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ impl Field {

pub fn encode(&self, ident: TokenStream) -> TokenStream {
quote! {
for (tag, bytes) in #ident.iter() {
::prost::encoding::bytes::encode(*tag, bytes, buf);
for field in #ident.iter() {
::prost::encoding::bytes::encode(field.tag, &field.value, buf);
}
}
}
Expand All @@ -33,7 +33,10 @@ impl Field {

pub fn encoded_len(&self, ident: TokenStream) -> TokenStream {
quote! {
#ident.iter().map(|(tag, value)| ::prost::encoding::bytes::encoded_len(*tag, value)).sum::<usize>()
#ident
.iter()
.map(|field| ::prost::encoding::bytes::encoded_len(field.tag, &field.value))
.sum::<usize>()
}
}

Expand Down
72 changes: 38 additions & 34 deletions src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#![allow(clippy::implicit_hasher, clippy::ptr_arg)]

use std::cmp::min;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::mem;
use std::str;
Expand All @@ -16,6 +15,7 @@ use ::bytes::{buf::ext::BufExt, Buf, BufMut};

use crate::DecodeError;
use crate::Message;
use crate::UnknownField;

/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer.
/// The buffer must have enough remaining space (maximum 10 bytes).
Expand Down Expand Up @@ -393,44 +393,51 @@ pub fn skip_field<B>(
where
B: Buf,
{
ctx.limit_reached()?;
let len = match wire_type {
WireType::Varint => decode_varint(buf).map(|_| 0)?,
WireType::ThirtyTwoBit => 4,
WireType::SixtyFourBit => 8,
WireType::LengthDelimited => decode_varint(buf)?,
WireType::StartGroup => loop {
let (inner_tag, inner_wire_type) = decode_key(buf)?;
match inner_wire_type {
WireType::EndGroup => {
if inner_tag != tag {
return Err(DecodeError::new("unexpected end group tag"));
}
break 0;
}
_ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
}
handle_unknown_field(
wire_type,
tag,
buf,
&mut |_, len, buf| {
buf.advance(len);
},
WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
};

if len > buf.remaining() as u64 {
return Err(DecodeError::new("buffer underflow"));
}

buf.advance(len as usize);
Ok(())
ctx,
)
}

pub fn unknown_field<B>(
wire_type: WireType,
tag: u32,
buf: &mut B,
unknown_fields_map: &mut HashMap<u32, Vec<u8>>,
unknown_fields: &mut Vec<UnknownField>,
ctx: DecodeContext,
) -> Result<(), DecodeError>
where
B: Buf,
{
handle_unknown_field(
wire_type,
tag,
buf,
&mut |tag, len, buf| {
let mut value = Vec::new();
value.resize(len as usize, 0);
buf.copy_to_slice(value.as_mut_slice());
unknown_fields.push(UnknownField { tag, value });
},
ctx,
)
}

fn handle_unknown_field<B, F>(
wire_type: WireType,
tag: u32,
buf: &mut B,
handler: &mut F,
ctx: DecodeContext,
) -> Result<(), DecodeError>
where
B: Buf,
F: FnMut(u32, usize, &mut B),
{
ctx.limit_reached()?;
let len = match wire_type {
Expand All @@ -447,11 +454,11 @@ where
}
break 0;
}
_ => unknown_field(
_ => handle_unknown_field(
inner_wire_type,
inner_tag,
buf,
unknown_fields_map,
handler,
ctx.enter_recursion(),
)?,
}
Expand All @@ -463,10 +470,7 @@ where
return Err(DecodeError::new("buffer underflow"));
}

let mut field_value = Vec::new();
field_value.resize(len as usize, 0);
buf.copy_to_slice(field_value.as_mut_slice());
let _ = unknown_fields_map.insert(tag, field_value);
handler(tag, len as usize, buf);
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod types;
pub mod encoding;

pub use crate::error::{DecodeError, EncodeError};
pub use crate::message::Message;
pub use crate::message::{Message, UnknownField};

use bytes::{Buf, BufMut};

Expand Down
6 changes: 6 additions & 0 deletions src/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,9 @@ where
(**self).clear()
}
}

/// A single entry in a message's `protobuf_unknown_fields` list.
///
/// Entries are appended one at a time while decoding, so several occurrences of
/// the same tag (for example a repeated field) are all kept, in the order they
/// were pushed, rather than overwriting one another.
#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)]
pub struct UnknownField {
    /// Tag the field was decoded under; reused as the tag when the field is encoded again.
    pub tag: u32,
    /// Bytes copied out of the input buffer for this field.
    pub value: Vec<u8>,
}
5 changes: 5 additions & 0 deletions tests/src/unknown_fields.proto
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,8 @@ message MessageWithExtraFields {
string normal_field = 1;
string extra_field = 2;
}

// Variant of MessageWithExtraFields in which field 2 is repeated.
// The unknown-fields tests encode this message and decode the bytes as
// MessageWithUnknownFields to check that every occurrence of field 2 is kept.
message MessageWithRepeatedExtraFields {
string normal_field = 1;
repeated string extra_field = 2;
}
65 changes: 58 additions & 7 deletions tests/src/unknown_fields.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use prost::Message;
use prost::{Message, UnknownField};

mod unknown_fields {
include!(concat!(env!("OUT_DIR"), "/unknown_fields.rs"));
Expand All @@ -15,11 +15,62 @@ fn test_access_unknown_field() {
let message = unknown_fields::MessageWithUnknownFields::decode(&encoded[..])
.expect("Could not decode as MessageWithUnknownFields");

let extra_field = message
.protobuf_unknown_fields
.get(&2)
.expect("extra_field not in unknown_fields");
assert_eq!(
message.protobuf_unknown_fields,
vec![UnknownField {
tag: 2,
value: b"extra".to_vec()
}]
);
}

// An unknown field set by hand must be written out on encode, so that a message
// type which declares tag 2 reads it back as its regular `extra_field`.
#[test]
fn test_serialize_unknown_field() {
    let extra = UnknownField {
        tag: 2,
        value: b"extra".to_vec(),
    };
    let source = unknown_fields::MessageWithUnknownFields {
        normal_field: "normal".to_string(),
        protobuf_unknown_fields: vec![extra],
    };

    let mut wire = Vec::new();
    source.encode(&mut wire).unwrap();

    let decoded = unknown_fields::MessageWithExtraFields::decode(&wire[..])
        .expect("Could not decode as MessageWithExtraFields");
    assert_eq!(decoded.extra_field, "extra");
}

#[test]
fn test_access_repeated_unknown_field() {
let message = unknown_fields::MessageWithRepeatedExtraFields {
normal_field: "normal".to_string(),
extra_field: vec![
"repeated".to_string(),
"extra".to_string(),
"repeated".to_string(),
],
};
let mut encoded = Vec::new();
message.encode(&mut encoded).unwrap();
let message = unknown_fields::MessageWithUnknownFields::decode(&encoded[..])
.expect("Could not decode as MessageWithUnknownFields");

assert_eq!(extra_field, b"extra");
assert_eq!(message.protobuf_unknown_fields.len(), 1);
assert_eq!(
message.protobuf_unknown_fields,
vec![
UnknownField {
tag: 2,
value: b"repeated".to_vec(),
},
UnknownField {
tag: 2,
value: b"extra".to_vec(),
},
UnknownField {
tag: 2,
value: b"repeated".to_vec(),
},
]
);
}

0 comments on commit e564c56

Please sign in to comment.