Skip to content

Commit 0330d92

Browse files
committed
WIP
1 parent 6f3c60f commit 0330d92

File tree

3 files changed

+216
-89
lines changed

3 files changed

+216
-89
lines changed

prost-derive/src/field/scalar.rs

+7-14
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,6 @@ impl Field {
281281

282282
if let Ty::Enumeration(ref ty) = self.ty {
283283
let set = Ident::new(&format!("set_{}", ident_str), Span::call_site());
284-
let set_doc = format!("Sets `{}` to the provided enum value.", ident_str);
285284
Some(match self.kind {
286285
Kind::Plain(ref default) | Kind::Required(ref default) => {
287286
let get_doc = format!(
@@ -292,12 +291,11 @@ impl Field {
292291
quote! {
293292
#[doc=#get_doc]
294293
pub fn #ident(&self) -> #ty {
295-
#ty::from_i32(self.#ident).unwrap_or(#default)
296-
}
297-
298-
#[doc=#set_doc]
299-
pub fn #set(&mut self, value: #ty) {
300-
self.#ident = value as i32;
294+
if self.#ident.is_valid() {
295+
self.#ident
296+
} else {
297+
#default
298+
}
301299
}
302300
}
303301
}
@@ -310,12 +308,7 @@ impl Field {
310308
quote! {
311309
#[doc=#get_doc]
312310
pub fn #ident(&self) -> #ty {
313-
self.#ident.and_then(#ty::from_i32).unwrap_or(#default)
314-
}
315-
316-
#[doc=#set_doc]
317-
pub fn #set(&mut self, value: #ty) {
318-
self.#ident = ::std::option::Option::Some(value as i32);
311+
self.#ident.map(#ty::from).filter(#ty::is_valid).unwrap_or(#default)
319312
}
320313
}
321314
}
@@ -527,7 +520,7 @@ impl Ty {
527520

528521
pub fn module(&self) -> Ident {
529522
match *self {
530-
Ty::Enumeration(..) => Ident::new("int32", Span::call_site()),
523+
Ty::Enumeration(..) => Ident::new("enumeration", Span::call_site()),
531524
_ => Ident::new(self.as_str(), Span::call_site()),
532525
}
533526
}

prost-derive/src/lib.rs

+49-74
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ use itertools::Itertools;
1212
use proc_macro::TokenStream;
1313
use proc_macro2::Span;
1414
use syn::{
15-
punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
16-
FieldsUnnamed, Ident, Variant,
15+
punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Fields, FieldsNamed,
16+
FieldsUnnamed, Ident, ImplItem, ItemImpl, Variant,
1717
};
1818

1919
mod field;
@@ -230,104 +230,79 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
230230
Ok(expanded.into())
231231
}
232232

233-
#[proc_macro_derive(Message, attributes(prost))]
234-
pub fn message(input: TokenStream) -> TokenStream {
235-
try_message(input).unwrap()
236-
}
237-
238-
fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
239-
let input: DeriveInput = syn::parse(input)?;
240-
let ident = input.ident;
233+
fn try_enumeration(_attr: TokenStream, input: TokenStream) -> Result<TokenStream, Error> {
234+
let mut impl_: ItemImpl = syn::parse(input)?;
241235

242-
if !input.generics.params.is_empty() || input.generics.where_clause.is_some() {
243-
bail!("Message may not be derived for generic type");
236+
if !impl_.generics.params.is_empty() || impl_.generics.where_clause.is_some() {
237+
bail!("enumeration may not be applied to generic types");
244238
}
245239

246-
let punctuated_variants = match input.data {
247-
Data::Enum(DataEnum { variants, .. }) => variants,
248-
Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
249-
Data::Union(..) => bail!("Enumeration can not be derived for a union"),
250-
};
251-
252-
// Map the variants into 'fields'.
253-
let mut variants: Vec<(Ident, Expr)> = Vec::new();
254-
for Variant {
255-
ident,
256-
fields,
257-
discriminant,
258-
..
259-
} in punctuated_variants
260-
{
261-
match fields {
262-
Fields::Unit => (),
263-
Fields::Named(_) | Fields::Unnamed(_) => {
264-
bail!("Enumeration variants may not have fields")
265-
}
266-
}
240+
if impl_.trait_.is_some() {
241+
bail!("enumeration may not be applied to trait impls");
242+
}
267243

268-
match discriminant {
269-
Some((_, expr)) => variants.push((ident, expr)),
270-
None => bail!("Enumeration variants must have a disriminant"),
271-
}
244+
let mut variants = Vec::new();
245+
for item in &impl_.items {
246+
let const_ = match item {
247+
ImplItem::Const(const_) => const_.ident.clone(),
248+
_ => bail!("enumeration may only be applied to impls with only consts"),
249+
};
250+
variants.push(const_);
272251
}
273252

274253
if variants.is_empty() {
275-
panic!("Enumeration must have at least one variant");
254+
bail!("enumeration must be applied to impls with consts");
276255
}
277256

278-
let default = variants[0].0.clone();
279-
280-
let is_valid = variants
281-
.iter()
282-
.map(|&(_, ref value)| quote!(#value => true));
283-
let from = variants.iter().map(
284-
|&(ref variant, ref value)| quote!(#value => ::std::option::Option::Some(#ident::#variant)),
285-
);
257+
let ty = &impl_.self_ty;
258+
let is_valid = quote! {
259+
/// Returns true if the enum's value corresponds to a known variant.
260+
#[inline]
261+
pub fn is_valid(&self) -> bool {
262+
match self {
263+
#(#ty::#variants)|* => true,
264+
_ => false,
265+
}
266+
}
267+
};
268+
impl_
269+
.items
270+
.push(ImplItem::Method(syn::parse(is_valid.into()).unwrap()));
286271

287-
let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
288-
let from_i32_doc = format!(
289-
"Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
290-
ident
291-
);
272+
let default = &variants[0];
273+
let ty = &impl_.self_ty;
292274

293275
let expanded = quote! {
294-
impl #ident {
295-
#[doc=#is_valid_doc]
296-
pub fn is_valid(value: i32) -> bool {
297-
match value {
298-
#(#is_valid,)*
299-
_ => false,
300-
}
301-
}
276+
#impl_
302277

303-
#[doc=#from_i32_doc]
304-
pub fn from_i32(value: i32) -> ::std::option::Option<#ident> {
305-
match value {
306-
#(#from,)*
307-
_ => ::std::option::Option::None,
308-
}
278+
impl ::std::default::Default for #ty {
279+
#[inline]
280+
fn default() -> #ty {
281+
#ty::#default
309282
}
310283
}
311284

312-
impl ::std::default::Default for #ident {
313-
fn default() -> #ident {
314-
#ident::#default
285+
impl ::std::convert::From<i32> for #ty {
286+
#[inline]
287+
fn from(value: i32) -> #ty {
288+
#ty(value)
315289
}
316290
}
317291

318-
impl ::std::convert::From<#ident> for i32 {
319-
fn from(value: #ident) -> i32 {
320-
value as i32
292+
impl ::std::convert::From<#ty> for i32 {
293+
#[inline]
294+
fn from(value: #ty) -> i32 {
295+
value.0
321296
}
322297
}
323298
};
324299

325300
Ok(expanded.into())
326301
}
327302

328-
#[proc_macro_derive(Enumeration, attributes(prost))]
329-
pub fn enumeration(input: TokenStream) -> TokenStream {
330-
try_enumeration(input).unwrap()
303+
#[proc_macro_attribute]
304+
pub fn enumeration(attr: TokenStream, input: TokenStream) -> TokenStream {
305+
try_enumeration(attr, input).unwrap()
331306
}
332307

333308
fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {

src/encoding.rs

+160-1
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,12 @@ where
381381
Ok(())
382382
}
383383

384-
pub fn skip_field<B>(wire_type: WireType, tag: u32, buf: &mut B, ctx: DecodeContext) -> Result<(), DecodeError>
384+
pub fn skip_field<B>(
385+
wire_type: WireType,
386+
tag: u32,
387+
buf: &mut B,
388+
ctx: DecodeContext,
389+
) -> Result<(), DecodeError>
385390
where
386391
B: Buf,
387392
{
@@ -744,6 +749,160 @@ fixed_width!(
744749
get_i64_le
745750
);
746751

752+
pub mod enumeration {
753+
use super::*;
754+
755+
pub fn encode<T, B>(tag: u32, value: &T, buf: &mut B)
756+
where
757+
T: Copy,
758+
i32: From<T>,
759+
B: BufMut,
760+
{
761+
encode_key(tag, WireType::Varint, buf);
762+
encode_varint(i32::from(*value) as u64, buf);
763+
}
764+
765+
pub fn merge<T, B>(
766+
wire_type: WireType,
767+
value: &mut T,
768+
buf: &mut B,
769+
_: DecodeContext,
770+
) -> Result<(), DecodeError>
771+
where
772+
T: Copy + From<i32>,
773+
i32: From<T>,
774+
B: Buf,
775+
{
776+
check_wire_type(WireType::Varint, wire_type)?;
777+
let new_value = decode_varint(buf)? as i32;
778+
*value = T::from(new_value);
779+
Ok(())
780+
}
781+
782+
pub fn encode_repeated<T, B>(tag: u32, values: &[T], buf: &mut B)
783+
where
784+
T: Copy,
785+
i32: From<T>,
786+
B: BufMut,
787+
{
788+
for value in values {
789+
encode(tag, value, buf);
790+
}
791+
}
792+
793+
pub fn encode_packed<T, B>(tag: u32, values: &[T], buf: &mut B)
794+
where
795+
T: Copy,
796+
i32: From<T>,
797+
B: BufMut,
798+
{
799+
if values.is_empty() {
800+
return;
801+
}
802+
803+
encode_key(tag, WireType::LengthDelimited, buf);
804+
let len: usize = values
805+
.iter()
806+
.map(|v| encoded_len_varint(i32::from(*v) as u64))
807+
.sum();
808+
encode_varint(len as u64, buf);
809+
810+
for value in values {
811+
encode_varint(i32::from(*value) as u64, buf);
812+
}
813+
}
814+
815+
pub fn merge_repeated<T, B>(
816+
wire_type: WireType,
817+
values: &mut Vec<T>,
818+
buf: &mut B,
819+
ctx: DecodeContext,
820+
) -> Result<(), DecodeError>
821+
where
822+
T: Default + Copy + From<i32>,
823+
i32: From<T>,
824+
B: Buf,
825+
{
826+
if wire_type == WireType::LengthDelimited {
827+
// Packed.
828+
merge_loop(values, buf, ctx, |values, buf, ctx| {
829+
let mut value = Default::default();
830+
merge(WireType::Varint, &mut value, buf, ctx)?;
831+
values.push(value);
832+
Ok(())
833+
})
834+
} else {
835+
// Unpacked.
836+
check_wire_type(WireType::Varint, wire_type)?;
837+
let mut value = Default::default();
838+
merge(wire_type, &mut value, buf, ctx)?;
839+
values.push(value);
840+
Ok(())
841+
}
842+
}
843+
844+
pub fn encoded_len<T>(tag: u32, value: &T) -> usize
845+
where
846+
T: Copy,
847+
i32: From<T>,
848+
{
849+
key_len(tag) + encoded_len_varint(i32::from(*value) as u64)
850+
}
851+
852+
pub fn encoded_len_repeated<T>(tag: u32, values: &[T]) -> usize
853+
where
854+
T: Copy,
855+
i32: From<T>,
856+
{
857+
key_len(tag) * values.len()
858+
+ values
859+
.iter()
860+
.map(|value| encoded_len_varint(i32::from(*value) as u64))
861+
.sum::<usize>()
862+
}
863+
864+
pub fn encoded_len_packed<T>(tag: u32, values: &[T]) -> usize
865+
where
866+
T: Copy,
867+
i32: From<T>,
868+
{
869+
if values.is_empty() {
870+
0
871+
} else {
872+
let len = values
873+
.iter()
874+
.map(|value| encoded_len_varint(i32::from(*value) as u64))
875+
.sum::<usize>();
876+
key_len(tag) + encoded_len_varint(len as u64) + len
877+
}
878+
}
879+
880+
#[cfg(test)]
881+
mod test {
882+
use quickcheck::{quickcheck, TestResult};
883+
884+
use super::super::test::{check_collection_type, check_type};
885+
use super::*;
886+
887+
quickcheck! {
888+
fn check(value: i32, tag: u32) -> TestResult {
889+
check_type(value, tag, WireType::Varint,
890+
encode, merge, encoded_len)
891+
}
892+
fn check_repeated(value: Vec<i32>, tag: u32) -> TestResult {
893+
check_collection_type(value, tag, WireType::Varint,
894+
encode_repeated, merge_repeated,
895+
encoded_len_repeated)
896+
}
897+
fn check_packed(value: Vec<i32>, tag: u32) -> TestResult {
898+
check_type(value, tag, WireType::LengthDelimited,
899+
encode_packed, merge_repeated,
900+
encoded_len_packed)
901+
}
902+
}
903+
}
904+
}
905+
747906
/// Macro which emits encoding functions for a length-delimited type.
748907
macro_rules! length_delimited {
749908
($ty:ty) => {

0 commit comments

Comments
 (0)