Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Fix optional enums #17

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions actix-prost-macros/src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ use crate::prost::parse_attrs;

pub fn process_field(f: &syn::Field) -> (Option<syn::Attribute>, bool) {
let metas = parse_attrs(f.attrs.clone());
for m in metas {
if let (Some(attr), need_serde_as) = parse_meta(&m) {
return (Some(attr), need_serde_as);
}
}

if let syn::Type::Path(ty) = &f.ty {
for m in metas {
if let (Some(attr), need_serde_as) = parse_meta(&ty.path, &m) {
return (Some(attr), need_serde_as);
}
}

if let (Some(attr), need_serde_as) = parse_path(&ty.path) {
return (Some(attr), need_serde_as);
}
Expand Down Expand Up @@ -48,7 +49,7 @@ fn parse_path(name: &syn::Path) -> (Option<syn::Attribute>, bool) {
}
}

fn parse_meta(meta: &syn::Meta) -> (Option<syn::Attribute>, bool) {
fn parse_meta(name: &syn::Path, meta: &syn::Meta) -> (Option<syn::Attribute>, bool) {
let meta = match meta {
syn::Meta::NameValue(value) => value,
_ => return (None, false),
Expand All @@ -58,8 +59,22 @@ fn parse_meta(meta: &syn::Meta) -> (Option<syn::Attribute>, bool) {
syn::Lit::Str(name) => name,
_ => return (None, false),
};
let as_value = format!("serde_with::TryFromInto<{}>", enum_name.value());
(Some(syn::parse_quote!(#[serde_as(as = #as_value)])), true)
let option: syn::Path = syn::parse_quote!(::core::option::Option);
if name.segments.len() == option.segments.len()
&& name
.segments
.iter()
.zip(option.segments.iter())
.all(|(a, b)| a.ident == b.ident)
{
(
Some(syn::parse_quote!(#[serde(default,with="actix_prost::serde::option_enum")])),
false,
)
} else {
let as_value = format!("serde_with::TryFromInto<{}>", enum_name.value());
(Some(syn::parse_quote!(#[serde_as(as = #as_value)])), true)
}
} else if meta.path == syn::parse_quote!(oneof) {
(Some(syn::parse_quote!(#[serde(flatten)])), false)
} else {
Expand Down
47 changes: 47 additions & 0 deletions actix-prost/src/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,50 @@ pub mod option_bytes {
Ok(helper.map(|Helper(external)| external))
}
}

pub mod option_enum {
use super::*;
use serde::de::DeserializeOwned;
use serde_with::TryFromInto;

#[serde_as]
#[derive(Serialize, Deserialize)]
struct Helper<T: Serialize + DeserializeOwned + TryFrom<i32, Error = String> + Into<i32>>(
#[serde_as(as = "TryFromInto<T>")] i32,
std::marker::PhantomData<T>,
);

pub fn serialize<
'a,
T: Serialize + DeserializeOwned + TryFrom<i32, Error = String> + Into<i32> + Copy,
S,
>(
value: &'a Option<T>,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
value
.as_ref()
.map(|x| Helper((*x).into(), std::marker::PhantomData::<T>))
.serialize(serializer)
}

pub fn deserialize<
'de,
T: Serialize + DeserializeOwned + TryFrom<i32, Error = String> + Into<i32> + Copy,
D,
>(
deserializer: D,
) -> Result<Option<T>, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct Helper<T>(T);

let helper = Option::deserialize(deserializer)?;
Ok(helper.map(|Helper(external)| external))
}
}
3 changes: 3 additions & 0 deletions tests/proto/http_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ http:
- selector: "types.TypesRPC.EnumsRPC"
post: /types/enums
body: "*"
- selector: "types.TypesRPC.OptionalEnumsRPC"
post: /types/optional_enums
body: "*"
- selector: "types.TypesRPC.RepeatedRPC"
post: /types/repeated
body: "*"
Expand Down
3 changes: 3 additions & 0 deletions tests/proto/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ enum Values {

message Enums { Values values = 1; }

message OptionalEnums { optional Values values = 1; }

message Repeated { repeated string foo = 1; }

message Maps { map<string, int32> foo = 1; }
Expand Down Expand Up @@ -59,6 +61,7 @@ service TypesRPC {
rpc ScalarsRPC(Scalars) returns (Scalars);
rpc OptionalScalarsRPC(OptionalScalars) returns (OptionalScalars);
rpc EnumsRPC(Enums) returns (Enums);
rpc OptionalEnumsRPC(OptionalEnums) returns (OptionalEnums);
rpc RepeatedRPC(Repeated) returns (Repeated);
rpc MapsRPC(Maps) returns (Maps);
rpc OneOfsRPC(OneOfs) returns (OneOfs);
Expand Down
25 changes: 25 additions & 0 deletions tests/proto/types.swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,26 @@ paths:
$ref: '#/definitions/typesOneOfs'
tags:
- TypesRPC
/types/optional_enums:
post:
operationId: TypesRPC_OptionalEnumsRPC
responses:
"200":
description: A successful response.
schema:
$ref: '#/definitions/typesOptionalEnums'
default:
description: An unexpected error response.
schema:
$ref: '#/definitions/rpcStatus'
parameters:
- name: body
in: body
required: true
schema:
$ref: '#/definitions/typesOptionalEnums'
tags:
- TypesRPC
/types/optional_scalars:
post:
operationId: TypesRPC_OptionalScalarsRPC
Expand Down Expand Up @@ -317,6 +337,11 @@ definitions:
format: int64
foo:
type: string
typesOptionalEnums:
type: object
properties:
values:
$ref: '#/definitions/typesValues'
typesOptionalScalars:
type: object
properties:
Expand Down
102 changes: 102 additions & 0 deletions tests/src/proto/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ pub struct Enums {
#[actix_prost_macros::serde]
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct OptionalEnums {
#[prost(enumeration = "Values", optional, tag = "1")]
pub values: ::core::option::Option<i32>,
}
#[actix_prost_macros::serde]
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Repeated {
#[prost(string, repeated, tag = "1")]
pub foo: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
Expand Down Expand Up @@ -158,6 +165,13 @@ pub mod types_rpc_actix {
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[actix_prost_macros::serde]
pub struct OptionalEnumsRPCJson {
#[prost(enumeration = "Values", optional, tag = "1")]
pub values: ::core::option::Option<i32>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
#[actix_prost_macros::serde]
pub struct RepeatedRPCJson {
#[prost(string, repeated, tag = "1")]
pub foo: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
Expand Down Expand Up @@ -267,6 +281,28 @@ pub mod types_rpc_actix {
let response = response.into_inner();
Ok(::actix_web::web::Json(response))
}
async fn call_optional_enums_rpc(
service: ::actix_web::web::Data<dyn TypesRpc + Sync + Send + 'static>,
http_request: ::actix_web::HttpRequest,
payload: ::actix_web::web::Payload,
) -> Result<::actix_web::web::Json<OptionalEnums>, ::actix_web::Error> {
let mut payload = payload.into_inner();
let json = <::actix_web::web::Json::<
OptionalEnumsRPCJson,
> as ::actix_web::FromRequest>::from_request(&http_request, &mut payload)
.await?
.into_inner();
let request = OptionalEnums {
values: json.values,
};
let request = ::actix_prost::new_request(request, &http_request);
let response = service
.optional_enums_rpc(request)
.await
.map_err(::actix_prost::map_tonic_error)?;
let response = response.into_inner();
Ok(::actix_web::web::Json(response))
}
async fn call_repeated_rpc(
service: ::actix_web::web::Data<dyn TypesRpc + Sync + Send + 'static>,
http_request: ::actix_web::HttpRequest,
Expand Down Expand Up @@ -369,6 +405,11 @@ pub mod types_rpc_actix {
::actix_web::web::post().to(call_optional_scalars_rpc),
);
config.route("/types/enums", ::actix_web::web::post().to(call_enums_rpc));
config
.route(
"/types/optional_enums",
::actix_web::web::post().to(call_optional_enums_rpc),
);
config.route("/types/repeated", ::actix_web::web::post().to(call_repeated_rpc));
config.route("/types/maps", ::actix_web::web::post().to(call_maps_rpc));
config.route("/types/oneofs", ::actix_web::web::post().to(call_one_ofs_rpc));
Expand Down Expand Up @@ -499,6 +540,25 @@ pub mod types_rpc_client {
let path = http::uri::PathAndQuery::from_static("/types.TypesRPC/EnumsRPC");
self.inner.unary(request.into_request(), path, codec).await
}
pub async fn optional_enums_rpc(
&mut self,
request: impl tonic::IntoRequest<super::OptionalEnums>,
) -> Result<tonic::Response<super::OptionalEnums>, tonic::Status> {
self.inner
.ready()
.await
.map_err(|e| {
tonic::Status::new(
tonic::Code::Unknown,
format!("Service was not ready: {}", e.into()),
)
})?;
let codec = tonic::codec::ProstCodec::default();
let path = http::uri::PathAndQuery::from_static(
"/types.TypesRPC/OptionalEnumsRPC",
);
self.inner.unary(request.into_request(), path, codec).await
}
pub async fn repeated_rpc(
&mut self,
request: impl tonic::IntoRequest<super::Repeated>,
Expand Down Expand Up @@ -593,6 +653,10 @@ pub mod types_rpc_server {
&self,
request: tonic::Request<super::Enums>,
) -> Result<tonic::Response<super::Enums>, tonic::Status>;
async fn optional_enums_rpc(
&self,
request: tonic::Request<super::OptionalEnums>,
) -> Result<tonic::Response<super::OptionalEnums>, tonic::Status>;
async fn repeated_rpc(
&self,
request: tonic::Request<super::Repeated>,
Expand Down Expand Up @@ -780,6 +844,44 @@ pub mod types_rpc_server {
};
Box::pin(fut)
}
"/types.TypesRPC/OptionalEnumsRPC" => {
#[allow(non_camel_case_types)]
struct OptionalEnumsRPCSvc<T: TypesRpc>(pub Arc<T>);
impl<T: TypesRpc> tonic::server::UnaryService<super::OptionalEnums>
for OptionalEnumsRPCSvc<T> {
type Response = super::OptionalEnums;
type Future = BoxFuture<
tonic::Response<Self::Response>,
tonic::Status,
>;
fn call(
&mut self,
request: tonic::Request<super::OptionalEnums>,
) -> Self::Future {
let inner = self.0.clone();
let fut = async move {
(*inner).optional_enums_rpc(request).await
};
Box::pin(fut)
}
}
let accept_compression_encodings = self.accept_compression_encodings;
let send_compression_encodings = self.send_compression_encodings;
let inner = self.inner.clone();
let fut = async move {
let inner = inner.0;
let method = OptionalEnumsRPCSvc(inner);
let codec = tonic::codec::ProstCodec::default();
let mut grpc = tonic::server::Grpc::new(codec)
.apply_compression_config(
accept_compression_encodings,
send_compression_encodings,
);
let res = grpc.unary(method, req).await;
Ok(res)
};
Box::pin(fut)
}
"/types.TypesRPC/RepeatedRPC" => {
#[allow(non_camel_case_types)]
struct RepeatedRPCSvc<T: TypesRpc>(pub Arc<T>);
Expand Down
11 changes: 10 additions & 1 deletion tests/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
proto::types::{
types_rpc_actix::route_types_rpc, types_rpc_server::TypesRpc, Complex, Enums, Maps, OneOfs,
OptionalScalars, Repeated, Scalars,
OptionalEnums, OptionalScalars, Repeated, Scalars,
},
test,
};
Expand All @@ -27,6 +27,12 @@ impl TypesRpc for TypesServer {
async fn enums_rpc(&self, request: Request<Enums>) -> Result<Response<Enums>, Status> {
Ok(Response::new(request.into_inner()))
}
async fn optional_enums_rpc(
&self,
request: Request<OptionalEnums>,
) -> Result<Response<OptionalEnums>, Status> {
Ok(Response::new(request.into_inner()))
}
async fn repeated_rpc(&self, request: Request<Repeated>) -> Result<Response<Repeated>, Status> {
Ok(Response::new(request.into_inner()))
}
Expand Down Expand Up @@ -90,6 +96,9 @@ async fn ping() {
)
.await;
assert_ping(&addr, "/types/enums", r#"{"values":"BAR"}"#.into()).await;
assert_ping(&addr, "/types/optional_enums", r#"{"values":"BAR"}"#.into()).await;
assert_ping(&addr, "/types/optional_enums", r#"{"values":null}"#.into()).await;
assert_ping(&addr, "/types/optional_enums", r#"{}"#.into()).await;
assert_ping(
&addr,
"/types/repeated",
Expand Down