From a513330f5f6b1070a98aa1e2e70f3f827f8ca8c2 Mon Sep 17 00:00:00 2001 From: piaoliu <441594700@qq.com> Date: Mon, 8 Apr 2019 17:04:09 +0800 Subject: [PATCH] feat: Allow users to customize the version select procedure --- src/builder.rs | 14 ++++++++++++++ src/protocol_select/mod.rs | 35 ++++++++++++++++++++++++----------- src/service/config.rs | 3 ++- src/session.rs | 3 ++- 4 files changed, 42 insertions(+), 13 deletions(-) diff --git a/src/builder.rs b/src/builder.rs index 50ee14f6..8eed2904 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -4,6 +4,7 @@ use std::time::Duration; use tokio::codec::LengthDelimitedCodec; use crate::{ + protocol_select::SelectFn, secio::SecioKeyPair, service::{ config::{Meta, ServiceConfig}, @@ -109,6 +110,7 @@ pub(crate) type NameFn = Box String + Send + Sync>; pub(crate) type CodecFn = Box Box + Send + Sync>; pub(crate) type SessionHandleFn = Box ProtocolHandle> + Send>; +pub(crate) type SelectVersionFn = Box Option> + Send + Sync + 'static>; /// Builder for protocol meta pub struct MetaBuilder { @@ -118,6 +120,7 @@ pub struct MetaBuilder { codec: CodecFn, service_handle: ProtocolHandle>, session_handle: SessionHandleFn, + select_version: SelectVersionFn, } impl MetaBuilder { @@ -191,6 +194,15 @@ impl MetaBuilder { self } + /// Protocol version selection rule, default is [select_version](../protocol_select/fn.select_version.html) + pub fn select_version(mut self, f: T) -> Self + where + T: Fn() -> Option> + Send + Sync + 'static, + { + self.select_version = Box::new(f); + self + } + /// Combine the configuration of this builder to create a ProtocolMeta pub fn build(self) -> ProtocolMeta { let meta = Meta { @@ -198,6 +210,7 @@ impl MetaBuilder { name: self.name, support_versions: self.support_versions, codec: self.codec, + select_version: self.select_version, }; ProtocolMeta { inner: Arc::new(meta), @@ -216,6 +229,7 @@ impl Default for MetaBuilder { codec: Box::new(|| Box::new(LengthDelimitedCodec::new())), service_handle: ProtocolHandle::Neither, session_handle: Box::new(|_| ProtocolHandle::Neither), + select_version: Box::new(|| None), } } } diff --git a/src/protocol_select/mod.rs b/src/protocol_select/mod.rs index 5db142c2..272fadde 100644 --- a/src/protocol_select/mod.rs +++ b/src/protocol_select/mod.rs @@ -22,6 +22,9 @@ mod protocol_select_generated; #[allow(dead_code)] mod protocol_select_generated_verifier; +/// Function for protocol version select +pub type SelectFn = Box Option + Send + 'static>; + /// Protocol Info #[derive(Clone, Debug, Default, Eq, PartialEq)] pub struct ProtocolInfo { @@ -140,7 +143,7 @@ pub(crate) fn client_select( /// plus the protocol name, plus the version option. pub(crate) fn server_select( handle: T, - proto_infos: HashMap, + proto_infos: HashMap>)>, ) -> impl Future, String, Option), Error = io::Error> { let socket = Framed::new(handle, LengthDelimitedCodec::new()); @@ -167,14 +170,24 @@ pub(crate) fn server_select( return Err(err); } }; - let version = proto_infos - .remove(&remote_info.name) - .and_then(|local_info| { - select_version( - &local_info.support_versions, - &remote_info.support_versions, - ) - }); + let version = + proto_infos + .remove(&remote_info.name) + .and_then(|(local_info, select)| { + select + .map(|f| { + f( + &local_info.support_versions, + &remote_info.support_versions, + ) + }) + .unwrap_or_else(|| { + select_version( + &local_info.support_versions, + &remote_info.support_versions, + ) + }) + }); Ok((socket, remote_info.name, version)) }) }) @@ -195,7 +208,7 @@ pub(crate) fn server_select( /// Choose the highest version of the two sides, assume that slices are sorted #[inline] -fn select_version(local: &[T], remote: &[T]) -> Option { +pub fn select_version(local: &[T], remote: &[T]) -> Option { let (mut local_iter, mut remote_iter) = (local.iter().rev(), remote.iter().rev()); let (mut local, mut remote) = (local_iter.next(), remote_iter.next()); while let (Some(l), Some(r)) = (local, remote) { @@ -263,7 +276,7 @@ mod tests { message.name = "test".to_owned(); message.support_versions = server; let mut messages = HashMap::new(); - messages.insert("test".to_owned(), message); + messages.insert("test".to_owned(), (message, None)); let task = server_select(connect.unwrap(), messages) .map(|(_, _, a)| { diff --git a/src/service/config.rs b/src/service/config.rs index 29fe6a5a..6b90ef03 100644 --- a/src/service/config.rs +++ b/src/service/config.rs @@ -1,5 +1,5 @@ use crate::{ - builder::{CodecFn, NameFn, SessionHandleFn}, + builder::{CodecFn, NameFn, SelectVersionFn, SessionHandleFn}, traits::{Codec, ServiceProtocol, SessionProtocol}, yamux::config::Config as YamuxConfig, ProtocolId, SessionId, @@ -120,6 +120,7 @@ pub(crate) struct Meta { pub(crate) name: NameFn, pub(crate) support_versions: Vec, pub(crate) codec: CodecFn, + pub(crate) select_version: SelectVersionFn, } /// Protocol handle diff --git a/src/session.rs b/src/session.rs index 98157a19..f0e44056 100644 --- a/src/session.rs +++ b/src/session.rs @@ -329,7 +329,8 @@ where .map(|proto_meta| { let name = (proto_meta.name)(proto_meta.id); let proto_info = ProtocolInfo::new(&name, proto_meta.support_versions.clone()); - (name, proto_info) + let select_fn = (proto_meta.select_version)(); + (name, (proto_info, select_fn)) }) .collect();