From 58d29a6f7c477e4c5940b816674b3e9b466a3b4e Mon Sep 17 00:00:00 2001 From: upupnoah Date: Tue, 16 Jul 2024 00:31:54 +0700 Subject: [PATCH] feat(resp-decode): support resp decode for simple redis --- Cargo.lock | 21 +++ Cargo.toml | 1 + examples/bytes_mut.rs | 15 ++ examples/enum_dispatch.rs | 39 ++++ examples/thiserorr_anyhow.rs | 57 ++++++ src/resp.rs | 152 +++++++++++++-- src/resp/decode.rs | 353 +++++++++++++++++++++++++++++++++++ 7 files changed, 625 insertions(+), 13 deletions(-) create mode 100644 examples/bytes_mut.rs create mode 100644 examples/enum_dispatch.rs create mode 100644 examples/thiserorr_anyhow.rs create mode 100644 src/resp/decode.rs diff --git a/Cargo.lock b/Cargo.lock index 6730f5a..3a52224 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,6 +57,7 @@ dependencies = [ "anyhow", "bytes", "enum_dispatch", + "thiserror", ] [[package]] @@ -70,6 +71,26 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "thiserror" +version = "1.0.62" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2675633b1499176c2dff06b0856a27976a8f9d436737b4cf4f312d4d91d8bbb" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.62" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d20468752b09f49e909e55a5d338caa8bedf615594e9d80bc4c565d30faf798c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-ident" version = "1.0.12" diff --git a/Cargo.toml b/Cargo.toml index 3bdd43a..4ad1dbb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,3 +10,4 @@ authors = ["Noah "] anyhow = "^1.0" bytes = "^1.6.1" enum_dispatch = "^0.3.13" +thiserror = "^1.0.62" diff --git a/examples/bytes_mut.rs b/examples/bytes_mut.rs new file mode 100644 index 0000000..fc0c6e5 --- /dev/null +++ b/examples/bytes_mut.rs @@ -0,0 +1,15 @@ +use anyhow::Result; +use bytes::BytesMut; +fn main() -> Result<()> { + let a = "hello"; // a.as_bytes() = b"hello" + + // bytes_mut + let 
mut bytes_mut = BytesMut::new(); + bytes_mut.extend_from_slice(a.as_bytes()); + println!("bytes_mut: {:?}", bytes_mut); + + let b = bytes_mut.split_to(3); + println!("b: {:?}", b); + println!("after split_to(3) -> bytes_mut: {:?}", bytes_mut); + Ok(()) +} diff --git a/examples/enum_dispatch.rs b/examples/enum_dispatch.rs new file mode 100644 index 0000000..da21222 --- /dev/null +++ b/examples/enum_dispatch.rs @@ -0,0 +1,39 @@ +use enum_dispatch::enum_dispatch; + +#[enum_dispatch] +trait DoSomething { + fn do_something(&self); +} + +#[enum_dispatch(DoSomething)] +enum Types { + Apple(A), + Banana(B), +} + +struct A; +struct B; + +impl DoSomething for A { + fn do_something(&self) { + println!("A"); + } +} + +impl DoSomething for B { + fn do_something(&self) { + println!("B"); + } +} +fn main() { + // test enum_dispatch + let apple = Types::Apple(A); + let banana = Types::Banana(B); + + let type_apple = apple; + let type_banana = banana; + + // 都是 types 类型的, 但是结果不同, enum_dispatch 相当于是主动帮我 match 了 + type_apple.do_something(); + type_banana.do_something(); +} diff --git a/examples/thiserorr_anyhow.rs b/examples/thiserorr_anyhow.rs new file mode 100644 index 0000000..8439727 --- /dev/null +++ b/examples/thiserorr_anyhow.rs @@ -0,0 +1,57 @@ +use anyhow::{Context, Result}; +use std::io::Error as IoError; +use std::num::ParseIntError; +use thiserror::Error; + +// 定义自定义错误类型 +#[derive(Error, Debug)] +enum MyError { + #[error("An IO error occurred: {0}")] + Io(#[from] IoError), + + #[error("A parsing error occurred: {0}")] + Parse(#[from] ParseIntError), + + #[error("Custom error: {0}")] + Custom(String), + + #[error("Anyhow error: {0}")] + Anyhow(#[from] anyhow::Error), +} + +// 一个可能返回错误的函数 +fn parse_number(input: &str) -> Result { + let trimmed = input.trim(); + if trimmed.is_empty() { + return Err(MyError::Custom("Input is empty".into())); + } + + let number: i32 = trimmed + .parse() + // .map_err(|e| MyError::Parse(e)) + .map_err(MyError::Parse) // 更好的写法 + .context("Failed 
to parse number")?; + Ok(number) +} + +fn main() -> Result<(), MyError> { + // 示例一: 正确的输入 + match parse_number("42") { + Ok(number) => println!("Parsed number: {}", number), + Err(e) => eprintln!("Error: {}", e), + } + + // 示例二: 空输入 + match parse_number("") { + Ok(number) => println!("Parsed number: {}", number), + Err(e) => eprintln!("Error: {}", e), + } + + // 示例三: 无效输入 + match parse_number("abc") { + Ok(number) => println!("Parsed number: {}", number), + Err(e) => eprintln!("Error: {}", e), + } + + Ok(()) +} diff --git a/src/resp.rs b/src/resp.rs index 6de9f66..ed9d3b2 100644 --- a/src/resp.rs +++ b/src/resp.rs @@ -1,18 +1,60 @@ -use std::collections::BTreeMap; - -use bytes::BytesMut; +use bytes::{Buf, BytesMut}; use enum_dispatch::enum_dispatch; +use std::collections::BTreeMap; +use thiserror::Error; +mod decode; mod encode; +const CRLF: &[u8] = b"\r\n"; +const CRLF_LEN: usize = CRLF.len(); + #[enum_dispatch] pub trait RespEncode { fn encode(self) -> Vec; } -pub trait RespDecode { - fn decode(buf: Self) -> Result; +// Sized 表示这个 trait 只能被 [大小确定的类型] 实现 +// 因为 decode 方法的返回值是一个 Self, 因此必须将这个 trait 标记为 Sized +pub trait RespDecode: Sized { + const PREFIX: &'static str; + fn decode(buf: &mut BytesMut) -> Result; + fn expect_length(buf: &[u8]) -> Result; +} + +#[derive(Error, Debug, PartialEq, Eq)] +pub enum RespError { + // region: --- thiserror format usage + + // #[error("{var}")] ⟶ write!("{}", self.var) + // #[error("{0}")] ⟶ write!("{}", self.0) + // #[error("{var:?}")] ⟶ write!("{:?}", self.var) + // #[error("{0:?}")] ⟶ write!("{:?}", self.0) + + // endregion: --- thiserror format usage + #[error("Invalid frame: {0}")] // 这里的 0 表示 self.0。 会转化为 write! 
+ InvalidFrame(String), + #[error("Invalid frame type: {0}")] + InvalidFrameType(String), + #[error("Invalid frame length: {0}")] + InvalidFrameLength(isize), + #[error("Frame is not complete")] + NotComplete, + + #[error("Parse error: {0}")] + ParseIntError(#[from] std::num::ParseIntError), + #[error("Utf8 error: {0}")] + Utf8Error(#[from] std::string::FromUtf8Error), + #[error("Parse float error: {0}")] + ParseFloatError(#[from] std::num::ParseFloatError), } + +// pub trait RespDecode: Sized { +// const PREFIX: &'static str; +// fn decode(buf: &mut BytesMut) -> Result; +// fn expect_length(buf: &[u8]) -> Result; +// } + // 之所以要定义一些新的结构体, 是因为要在实现 trait 的时候, 要区分开这些类型 #[enum_dispatch(RespEncode)] pub enum RespFrame { @@ -35,12 +77,19 @@ pub enum RespFrame { // 2. 新类型模式:这是 Rust 中常用的一种模式,用于在类型系统层面区分不同用途的相同底层类型。比如,你可能想区分普通的字符串和特定格式的字符串。 // 3. 添加方法:你可以为 SimpleString 实现方法,这些方法特定于这种类型的字符串。 // 4. 语义清晰:在复杂的数据结构中(如你展示的 RespFrame 枚举),使用 SimpleString 而不是直接使用 String 可以使代码的意图更加明确。 +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)] pub struct SimpleString(String); // Simple String, 用于存储简单字符串 +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)] pub struct SimpleError(String); +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)] pub struct BulkString(Vec); // 单个二进制字符串, 用于存储二进制数据(最大512MB) +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)] pub struct RespNullBulkString; + pub struct RespArray(Vec); +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)] pub struct RespNullArray; +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)] pub struct RespNull; #[derive(Default)] pub struct RespMap(BTreeMap); // 改为 BTreeMap, 用于有序的 key-value 数据 @@ -84,14 +133,91 @@ impl RespSet { } } -impl RespDecode for BytesMut { - fn decode(_buf: Self) -> Result { - todo!() +// utility functions +fn extract_fixed_data( + buf: &mut BytesMut, + expect: &str, + expect_type: &str, +) -> Result<(), RespError> { + if buf.len() < expect.len() { + return Err(RespError::NotComplete); } + + if 
!buf.starts_with(expect.as_bytes()) { + return Err(RespError::InvalidFrameType(format!( + "expect: {}, got: {:?}", + expect_type, buf + ))); + } + + buf.advance(expect.len()); + Ok(()) } -// impl RespEncode for RespFrame { -// fn encode(self) -> Vec { -// todo!() -// } -// } +fn extract_simple_frame_data(buf: &[u8], prefix: &str) -> Result { + if buf.len() < 3 { + return Err(RespError::NotComplete); + } + + if !buf.starts_with(prefix.as_bytes()) { + return Err(RespError::InvalidFrameType(format!( + "expect: SimpleString({}), got: {:?}", + prefix, buf + ))); + } + + let end = find_crlf(buf, 1).ok_or(RespError::NotComplete)?; + + Ok(end) +} + +// find nth CRLF in the buffer +fn find_crlf(buf: &[u8], nth: usize) -> Option { + let mut count = 0; + for i in 1..buf.len() - 1 { + if buf[i] == b'\r' && buf[i + 1] == b'\n' { + count += 1; + if count == nth { + return Some(i); + } + } + } + None +} + +fn parse_length(buf: &[u8], prefix: &str) -> Result<(usize, usize), RespError> { + let end = extract_simple_frame_data(buf, prefix)?; + let s = String::from_utf8_lossy(&buf[prefix.len()..end]); + Ok((end, s.parse()?)) +} + +fn calc_total_length(buf: &[u8], end: usize, len: usize, prefix: &str) -> Result { + let mut total = end + CRLF_LEN; + let mut data = &buf[total..]; + match prefix { + "*" | "~" => { + // find nth CRLF in the buffer, for array and set, we need to find 1 CRLF for each element + for _ in 0..len { + let len = RespFrame::expect_length(data)?; + data = &data[len..]; + total += len; + } + Ok(total) + } + "%" => { + // find nth CRLF in the buffer. 
For map, we need to find 2 CRLF for each key-value pair + for _ in 0..len { + let len = SimpleString::expect_length(data)?; + + data = &data[len..]; + total += len; + + let len = RespFrame::expect_length(data)?; + data = &data[len..]; + total += len; + } + Ok(total) + } + _ => Ok(len + CRLF_LEN), + } +} diff --git a/src/resp/decode.rs b/src/resp/decode.rs new file mode 100644 index 0000000..a511d42 --- /dev/null +++ b/src/resp/decode.rs @@ -0,0 +1,353 @@ +use bytes::{Buf, BytesMut}; + +use crate::{ + BulkString, RespArray, RespDecode, RespError, RespFrame, RespMap, RespNull, RespNullArray, + RespNullBulkString, RespSet, SimpleError, SimpleString, +}; + +use super::{ + calc_total_length, extract_fixed_data, extract_simple_frame_data, parse_length, CRLF_LEN, +}; + +impl RespDecode for RespFrame { + const PREFIX: &'static str = ""; + fn decode(buf: &mut BytesMut) -> Result { + let mut iter = buf.iter().peekable(); + match iter.peek() { + Some(b'+') => { + let frame = SimpleString::decode(buf)?; + Ok(frame.into()) + } + Some(b'-') => { + let frame = SimpleError::decode(buf)?; + Ok(frame.into()) + } + Some(b':') => { + let frame = i64::decode(buf)?; + Ok(frame.into()) + } + Some(b'$') => { + // try null bulk string first + match RespNullBulkString::decode(buf) { + Ok(frame) => Ok(frame.into()), + Err(RespError::NotComplete) => Err(RespError::NotComplete), + Err(_) => { + let frame = BulkString::decode(buf)?; + Ok(frame.into()) + } + } + } + Some(b'*') => { + // try null array first + match RespNullArray::decode(buf) { + Ok(frame) => Ok(frame.into()), + Err(RespError::NotComplete) => Err(RespError::NotComplete), + Err(_) => { + let frame = RespArray::decode(buf)?; + Ok(frame.into()) + } + } + } + Some(b'_') => { + let frame = RespNull::decode(buf)?; + Ok(frame.into()) + } + Some(b'#') => { + let frame = bool::decode(buf)?; + Ok(frame.into()) + } + Some(b',') => { + let frame = f64::decode(buf)?; + Ok(frame.into()) + } + Some(b'%') => { + let frame = 
RespMap::decode(buf)?; + Ok(frame.into()) + } + Some(b'~') => { + let frame = RespSet::decode(buf)?; + Ok(frame.into()) + } + None => Err(RespError::NotComplete), + _ => Err(RespError::InvalidFrameType(format!( + "expect_length: unknown frame type: {:?}", + buf + ))), + } + } + + fn expect_length(buf: &[u8]) -> Result { + let mut iter = buf.iter().peekable(); + match iter.peek() { + Some(b'*') => RespArray::expect_length(buf), + Some(b'~') => RespSet::expect_length(buf), + Some(b'%') => RespMap::expect_length(buf), + Some(b'$') => BulkString::expect_length(buf), + Some(b':') => i64::expect_length(buf), + Some(b'+') => SimpleString::expect_length(buf), + Some(b'-') => SimpleError::expect_length(buf), + Some(b'#') => bool::expect_length(buf), + Some(b',') => f64::expect_length(buf), + Some(b'_') => RespNull::expect_length(buf), + _ => Err(RespError::NotComplete), + } + } +} + +// - simple string: "+\r\n" +impl RespDecode for SimpleString { + const PREFIX: &'static str = "+"; + fn decode(buf: &mut BytesMut) -> Result { + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + // split the buffer + let data = buf.split_to(end + CRLF_LEN); // 把 buf 截断, 返回截止到 at 位置的数据 + let s = String::from_utf8_lossy(&data[Self::PREFIX.len()..end]); + Ok(SimpleString::new(s.to_string())) + } + + fn expect_length(buf: &[u8]) -> Result { + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + Ok(end + CRLF_LEN) + } +} + +// - error: "-Error message\r\n" +impl RespDecode for SimpleError { + const PREFIX: &'static str = "-"; + fn decode(buf: &mut BytesMut) -> Result { + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + // split the buffer + let data = buf.split_to(end + CRLF_LEN); + let s = String::from_utf8_lossy(&data[Self::PREFIX.len()..end]); + Ok(SimpleError::new(s.to_string())) + } + + fn expect_length(buf: &[u8]) -> Result { + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + Ok(end + CRLF_LEN) + } +} + +// - integer: ":[<+|->]\r\n" +impl RespDecode for i64 
{ + const PREFIX: &'static str = ":"; + fn decode(buf: &mut BytesMut) -> Result { + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + // split the buffer + let data = buf.split_to(end + CRLF_LEN); + let s = String::from_utf8_lossy(&data[Self::PREFIX.len()..end]); + Ok(s.parse()?) + } + + fn expect_length(buf: &[u8]) -> Result { + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + Ok(end + CRLF_LEN) + } +} + +// - bulk string: "$\r\n\r\n" +impl RespDecode for BulkString { + const PREFIX: &'static str = "$"; + fn decode(buf: &mut BytesMut) -> Result { + let (end, len) = parse_length(buf, Self::PREFIX)?; + let remained = &buf[end + CRLF_LEN..]; + if remained.len() < len + CRLF_LEN { + return Err(RespError::NotComplete); + } + + buf.advance(end + CRLF_LEN); + + let data = buf.split_to(len + CRLF_LEN); + Ok(BulkString::new(data[..len].to_vec())) + } + + fn expect_length(buf: &[u8]) -> Result { + let (end, len) = parse_length(buf, Self::PREFIX)?; + Ok(end + CRLF_LEN + len + CRLF_LEN) + } +} + +// - null bulk string: "$-1\r\n" +impl RespDecode for RespNullBulkString { + const PREFIX: &'static str = "$"; + fn decode(buf: &mut BytesMut) -> Result { + extract_fixed_data(buf, "$-1\r\n", "NullBulkString")?; + Ok(RespNullBulkString) + } + + fn expect_length(_buf: &[u8]) -> Result { + Ok(5) + } +} + +// - array: "*\r\n..." 
+/// RESP array: `*<number-of-elements>\r\n<element-1>...<element-n>`.
+impl RespDecode for RespArray {
+    const PREFIX: &'static str = "*";
+
+    /// Decodes one array frame from the front of `buf`.
+    ///
+    /// Returns `RespError::NotComplete` without consuming anything when `buf`
+    /// is shorter than the frame length computed by `calc_total_length`.
+    /// Errors raised while decoding an element are propagated unchanged; by
+    /// then the array header has already been consumed.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        let total_len = calc_total_length(buf, end, len, Self::PREFIX)?;
+
+        if buf.len() < total_len {
+            return Err(RespError::NotComplete);
+        }
+
+        // Drop the `*<len>\r\n` header; the elements follow back to back.
+        buf.advance(end + CRLF_LEN);
+
+        let mut frames = Vec::with_capacity(len);
+        for _ in 0..len {
+            frames.push(RespFrame::decode(buf)?);
+        }
+
+        Ok(RespArray::new(frames))
+    }
+
+    /// Total encoded length of the array frame starting at `buf[0]`:
+    /// header plus every element.
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        calc_total_length(buf, end, len, Self::PREFIX)
+    }
+}
+
+// Wire images of the fixed-size frames. Their lengths are derived from these
+// literals rather than hard-coded, so a length can never drift from its frame.
+const NULL_ARRAY_FRAME: &str = "*-1\r\n";
+const NULL_FRAME: &str = "_\r\n";
+const BOOL_TRUE_FRAME: &str = "#t\r\n";
+const BOOL_FALSE_FRAME: &str = "#f\r\n";
+
+/// RESP2 null array: the fixed frame `*-1\r\n`.
+impl RespDecode for RespNullArray {
+    const PREFIX: &'static str = "*";
+
+    /// Consumes `*-1\r\n` from the front of `buf`.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        extract_fixed_data(buf, NULL_ARRAY_FRAME, "NullArray")?;
+        Ok(RespNullArray)
+    }
+
+    /// Length of `*-1\r\n`, i.e. 5 bytes.
+    fn expect_length(_buf: &[u8]) -> Result<usize, RespError> {
+        Ok(NULL_ARRAY_FRAME.len())
+    }
+}
+
+/// RESP3 null: the fixed frame `_\r\n`.
+impl RespDecode for RespNull {
+    const PREFIX: &'static str = "_";
+
+    /// Consumes `_\r\n` from the front of `buf`.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        extract_fixed_data(buf, NULL_FRAME, "Null")?;
+        Ok(RespNull)
+    }
+
+    /// Length of `_\r\n`, i.e. 3 bytes.
+    fn expect_length(_buf: &[u8]) -> Result<usize, RespError> {
+        Ok(NULL_FRAME.len())
+    }
+}
+
+/// RESP3 boolean: `#<t|f>\r\n`.
+impl RespDecode for bool {
+    const PREFIX: &'static str = "#";
+
+    /// Consumes `#t\r\n` or `#f\r\n` from the front of `buf`.
+    ///
+    /// A buffer shorter than four bytes yields `RespError::NotComplete`;
+    /// any other mismatch yields the error produced for the `#f` attempt.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        match extract_fixed_data(buf, BOOL_TRUE_FRAME, "Bool") {
+            Ok(()) => Ok(true),
+            Err(RespError::NotComplete) => Err(RespError::NotComplete),
+            // Not `#t`: the only other valid frame is `#f`.
+            Err(_) => extract_fixed_data(buf, BOOL_FALSE_FRAME, "Bool").map(|()| false),
+        }
+    }
+
+    /// Both boolean frames are 4 bytes long.
+    fn expect_length(_buf: &[u8]) -> Result<usize, RespError> {
+        Ok(BOOL_TRUE_FRAME.len())
+    }
+}
+
+/// RESP3 double: `,[<+|->]<integral>[.<fractional>][<E|e>[sign]<exponent>]\r\n`.
+impl RespDecode for f64 {
+    const PREFIX: &'static str = ",";
+
+    /// Consumes one CRLF-terminated double frame and parses its payload.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        let end = extract_simple_frame_data(buf, Self::PREFIX)?;
+        // Take the whole frame (including CRLF) off the buffer, then parse
+        // only the bytes between the prefix and the CRLF.
+        let data = buf.split_to(end + CRLF_LEN);
+        let s = String::from_utf8_lossy(&data[Self::PREFIX.len()..end]);
+        Ok(s.parse()?)
+    }
+
+    /// Offset of the terminating CRLF plus the CRLF itself.
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        let end = extract_simple_frame_data(buf, Self::PREFIX)?;
+        Ok(end + CRLF_LEN)
+    }
+}
+
+/// RESP3 map: `%<number-of-entries>\r\n<key-1><value-1>...<key-n><value-n>`.
+impl RespDecode for RespMap {
+    const PREFIX: &'static str = "%";
+
+    /// Decodes one map frame from the front of `buf`.
+    ///
+    /// Keys are decoded as simple strings, values as arbitrary frames.
+    /// Returns `RespError::NotComplete` without consuming anything when `buf`
+    /// is shorter than the frame length computed by `calc_total_length`.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        let total_len = calc_total_length(buf, end, len, Self::PREFIX)?;
+
+        if buf.len() < total_len {
+            return Err(RespError::NotComplete);
+        }
+
+        // Drop the `%<len>\r\n` header; key/value pairs follow.
+        buf.advance(end + CRLF_LEN);
+
+        let mut frames = RespMap::new();
+        for _ in 0..len {
+            let key = SimpleString::decode(buf)?;
+            let value = RespFrame::decode(buf)?;
+            frames.insert(key.0, value);
+        }
+
+        Ok(frames)
+    }
+
+    /// Total encoded length of the map frame starting at `buf[0]`.
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        calc_total_length(buf, end, len, Self::PREFIX)
+    }
+}
+
+/// RESP3 set: `~<number-of-elements>\r\n<element-1>...<element-n>`.
+impl RespDecode for RespSet {
+    const PREFIX: &'static str = "~";
+
+    /// Decodes one set frame from the front of `buf`.
+    ///
+    /// Returns `RespError::NotComplete` without consuming anything when `buf`
+    /// is shorter than the frame length computed by `calc_total_length`.
+    fn decode(buf: &mut BytesMut) -> Result<Self, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        let total_len = calc_total_length(buf, end, len, Self::PREFIX)?;
+
+        if buf.len() < total_len {
+            return Err(RespError::NotComplete);
+        }
+
+        // Drop the `~<len>\r\n` header; the elements follow back to back.
+        buf.advance(end + CRLF_LEN);
+
+        // `calc_total_length` has already walked `len` elements inside `buf`,
+        // so `len` is bounded by the bytes received and pre-allocating is safe.
+        let mut frames = Vec::with_capacity(len);
+        for _ in 0..len {
+            frames.push(RespFrame::decode(buf)?);
+        }
+
+        Ok(RespSet::new(frames))
+    }
+
+    /// Total encoded length of the set frame starting at `buf[0]`.
+    fn expect_length(buf: &[u8]) -> Result<usize, RespError> {
+        let (end, len) = parse_length(buf, Self::PREFIX)?;
+        calc_total_length(buf, end, len, Self::PREFIX)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use anyhow::Result;
+    use bytes::BufMut;
+
+    #[test]
+    fn test_simple_string_decode() -> Result<()> {
+        let mut buf = BytesMut::new();
+        buf.extend_from_slice(b"+OK\r\n");
+
+        let frame = SimpleString::decode(&mut buf)?;
+        assert_eq!(frame, SimpleString::new("OK".to_string()));
+
+        // Frame without its final `\n` must be reported as incomplete.
+        buf.extend_from_slice(b"+hello\r");
+
+        let ret = SimpleString::decode(&mut buf);
+        assert_eq!(ret.unwrap_err(), RespError::NotComplete);
+
+        // Once the missing byte arrives the same buffer decodes cleanly.
+        buf.put_u8(b'\n');
+        let frame = SimpleString::decode(&mut buf)?;
+        assert_eq!(frame, SimpleString::new("hello".to_string()));
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_fixed_frame_expect_length() -> Result<()> {
+        // Each fixed frame must report exactly the size of its wire image.
+        assert_eq!(RespNullArray::expect_length(b"*-1\r\n")?, 5);
+        assert_eq!(RespNull::expect_length(b"_\r\n")?, 3);
+        assert_eq!(bool::expect_length(b"#t\r\n")?, 4);
+        assert_eq!(bool::expect_length(b"#f\r\n")?, 4);
+
+        Ok(())
+    }
+}