diff --git a/src/protocol/resp/src/request/badd.rs b/src/protocol/resp/src/request/badd.rs new file mode 100644 index 000000000..ff68c5564 --- /dev/null +++ b/src/protocol/resp/src/request/badd.rs @@ -0,0 +1,145 @@ +// Copyright 2022 Twitter, Inc. +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use super::*; +use std::io::{Error, ErrorKind}; +use std::sync::Arc; + +type ArcByteSlice = Arc>; +type ArcKeyValuePair = (ArcByteSlice, ArcByteSlice); + +/// Represents the btree add command which was added to Twitter's internal +/// version of redis32. +/// format is: badd outer_key (inner_key value)+ +#[derive(Debug, PartialEq, Eq)] +pub struct BAddRequest { + outer_key: Arc>, + inner_key_value_pairs: Arc>, +} + +impl BAddRequest { + pub fn outer_key(&self) -> &[u8] { + &self.outer_key + } + + pub fn inner_key_value_pairs(&self) -> Box<[(&[u8], &[u8])]> { + self.inner_key_value_pairs + .iter() + .map(|(k, v)| (&***k, &***v)) + .collect::>() + .into_boxed_slice() + } +} + +impl TryFrom for BAddRequest { + type Error = Error; + + fn try_from(other: Message) -> Result { + if let Message::Array(array) = other { + if array.inner.is_none() { + return Err(Error::new(ErrorKind::Other, "malformed command")); + } + + let mut array = array.inner.unwrap(); + + if array.len() < 4 { + return Err(Error::new(ErrorKind::Other, "malformed command")); + } + + if array.len() % 2 == 1 { + return Err(Error::new(ErrorKind::Other, "malformed command")); + } + + let outer_key = take_bulk_string(&mut array)?; + if outer_key.is_empty() { + return Err(Error::new(ErrorKind::Other, "malformed command")); + } + + //loop as long as we have at least 2 arguments after the command + let mut pairs = Vec::with_capacity(array.len() / 2); + while array.len() >= 3 { + let inner_key = take_bulk_string(&mut array)?; + if inner_key.is_empty() { + return Err(Error::new(ErrorKind::Other, "malformed command")); + } + + let value = take_bulk_string(&mut array)?; + if value.is_empty() { + return Err(Error::new(ErrorKind::Other, "malformed command")); + } + + pairs.push((inner_key, value)); + } + + Ok(Self { + outer_key, + inner_key_value_pairs: Arc::new(Box::<[ArcKeyValuePair]>::from(pairs)), + }) + } else { + Err(Error::new(ErrorKind::Other, "malformed command")) + } + } +} + +impl From<&BAddRequest> for Message { + fn from(other: &BAddRequest) -> Message { + let mut v = vec![ + Message::bulk_string(b"BADD"), + Message::BulkString(BulkString::from(other.outer_key.clone())), + ]; + for kv in (*other.inner_key_value_pairs).iter() { + v.push(Message::BulkString(BulkString::from(kv.0.clone()))); + v.push(Message::BulkString(BulkString::from(kv.1.clone()))); + } + + Message::Array(Array { inner: Some(v) }) + } +} + +impl Compose for BAddRequest { + fn compose(&self, buf: &mut dyn BufMut) -> usize { + let message = Message::from(self); + message.compose(buf) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parser() { + let parser = RequestParser::new(); + + //1 key value pair + if let Request::BAdd(request) = parser + .parse(b"badd outer inner 42\r\n") + .unwrap() + .into_inner() + { + assert_eq!(request.outer_key(), b"outer"); + assert_eq!(request.inner_key_value_pairs.len(), 1); + assert_eq!(request.inner_key_value_pairs()[0].0, b"inner"); + assert_eq!(request.inner_key_value_pairs()[0].1, b"42"); + } else { + panic!("invalid parse result"); + } + + //> 1 key value pairs + if let Request::BAdd(request) = parser + .parse(b"badd outer inner 42 inner2 7*6\r\n") + .unwrap() + .into_inner() + { + assert_eq!(request.outer_key(), b"outer"); + assert_eq!(request.inner_key_value_pairs.len(), 2); + assert_eq!(request.inner_key_value_pairs()[0].0, b"inner"); + assert_eq!(request.inner_key_value_pairs()[0].1, b"42"); + assert_eq!(request.inner_key_value_pairs()[1].0, b"inner2"); + assert_eq!(request.inner_key_value_pairs()[1].1, b"7*6"); + } else { + panic!("invalid parse result"); + } + } +} diff --git a/src/protocol/resp/src/request/mod.rs b/src/protocol/resp/src/request/mod.rs index 4bae15391..534ccebc8 100644 --- a/src/protocol/resp/src/request/mod.rs +++ b/src/protocol/resp/src/request/mod.rs @@ -10,9 +10,11 @@ use protocol_common::ParseOk; use std::io::{Error, ErrorKind}; use std::sync::Arc; +mod badd; mod get; mod set; +pub use badd::BAddRequest; pub use get::GetRequest; pub use set::SetRequest; @@ -87,6 +89,9 @@ impl Parse for RequestParser { match &array[0] { Message::BulkString(c) => match c.inner.as_ref().map(|v| v.as_ref().as_ref()) { + Some(b"badd") | Some(b"BADD") => { + BAddRequest::try_from(message).map(Request::from) + } Some(b"get") | Some(b"GET") => { GetRequest::try_from(message).map(Request::from) } @@ -113,6 +118,7 @@ impl Parse for RequestParser { impl Compose for Request { fn compose(&self, buf: &mut dyn BufMut) -> usize { match self { + Self::BAdd(r) => r.compose(buf), Self::Get(r) => r.compose(buf), Self::Set(r) => r.compose(buf), } @@ -121,10 +127,17 @@ impl Compose for Request { #[derive(Debug, PartialEq, Eq)] pub enum Request { + BAdd(BAddRequest), Get(GetRequest), Set(SetRequest), } +impl From for Request { + fn from(other: BAddRequest) -> Self { + Self::BAdd(other) + } +} + impl From for Request { fn from(other: GetRequest) -> Self { Self::Get(other) @@ -139,6 +152,7 @@ impl From for Request { #[derive(Debug, PartialEq, Eq)] pub enum Command { + BAdd, Get, Set, } @@ -148,6 +162,7 @@ impl TryFrom<&[u8]> for Command { fn try_from(other: &[u8]) -> Result { match other { + b"badd" | b"BADD" => Ok(Command::BAdd), b"get" | b"GET" => Ok(Command::Get), b"set" | b"SET" => Ok(Command::Set), _ => Err(()), diff --git a/src/proxy/momento/src/frontend.rs b/src/proxy/momento/src/frontend.rs index 8cbb70fd4..55a7729c2 100644 --- a/src/proxy/momento/src/frontend.rs +++ b/src/proxy/momento/src/frontend.rs @@ -102,6 +102,11 @@ pub(crate) async fn handle_resp_client( break; } } + _ => { + println!("bad request"); + let _ = socket.write_all(b"CLIENT_ERROR\r\n").await; + break; + } } buf.advance(consumed); }