diff --git a/assets/mini-redis/.github/workflows/ci.yml b/assets/mini-redis/.github/workflows/ci.yml new file mode 100644 index 0000000..513a5ca --- /dev/null +++ b/assets/mini-redis/.github/workflows/ci.yml @@ -0,0 +1,24 @@ +name: CI + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + build: + + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v2 + - name: Build + run: cargo build --verbose + - name: Run tests + run: cargo test --verbose + - name: rustfmt + uses: actions-rs/cargo@v1 + with: + command: fmt + args: --all -- --check diff --git a/assets/mini-redis/.gitignore b/assets/mini-redis/.gitignore new file mode 100644 index 0000000..53eaa21 --- /dev/null +++ b/assets/mini-redis/.gitignore @@ -0,0 +1,2 @@ +/target +**/*.rs.bk diff --git a/assets/mini-redis/Cargo.lock b/assets/mini-redis/Cargo.lock new file mode 100644 index 0000000..629806e --- /dev/null +++ b/assets/mini-redis/Cargo.lock @@ -0,0 +1,695 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "ansi_term" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee49baf6cb617b853aa8d93bf420db2383fab46d314482ca2803b40d5fde979b" +dependencies = [ + "winapi", +] + +[[package]] +name = "ansi_term" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d52a9bb7ec0cf484c551830a7ce27bd20d67eac647e1befb56b0be4ee39a55d2" +dependencies = [ + "winapi", +] + +[[package]] +name = "async-stream" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "171374e7e3b2504e0e5236e3b59260560f9fe94bfe9ac39ba5e4e929c5590625" +dependencies = [ + "async-stream-impl", + "futures-core", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "648ed8c8d2ce5409ccd57453d9d1b214b342a0d69376a6feda1fd6cae3299308" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "atoi" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c897df197d57c25b37df9d8fa2f93ddbfeee9ebd2264350ac79c8ec4b795885" +dependencies = [ + "num-traits", +] + +[[package]] +name = "atty" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" +dependencies = [ + "hermit-abi", + "libc", + "winapi", +] + +[[package]] +name = "autocfg" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a" + +[[package]] +name = "bitflags" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf1de2fe8c75bc145a2f577add951f8134889b4795d47466a54a5c846d691693" + +[[package]] +name = "bytes" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b700ce4376041dcd0a327fd0097c41095743c4c8af8887265942faf1100bd040" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "chrono" +version = "0.4.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670ad68c9088c2a963aaa298cb369688cf3f9465ce5e2d4ca10e6e0098a1ce73" +dependencies = [ + "libc", + "num-integer", + "num-traits", + "winapi", +] + +[[package]] +name = "clap" +version = "2.33.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37e58ac78573c40708d45522f0d80fa2f01cc4f9b4e2bf749807255454312002" +dependencies = [ + "ansi_term 0.11.0", + "atty", + "bitflags", + "strsim", + "textwrap", + "unicode-width", + "vec_map", +] + +[[package]] +name = "futures-core" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0402f765d8a89a26043b889b26ce3c4679d268fa6bb22cd7c6aad98340e179d1" + +[[package]] +name = "heck" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d621efb26863f0e9924c6ac577e8275e5e6b77455db64ffa6c65c904e9e132c" +dependencies = [ + "unicode-segmentation", +] + +[[package]] +name = "hermit-abi" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" +dependencies = [ + "libc", +] + +[[package]] +name = "instant" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bee0328b1209d157ef001c94dd85b4f8f64139adb0eac2659f4b08382b2f474d" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "itoa" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd25036021b0de88a0aff6b850051563c6516d0bf53f8638938edbb9de732736" + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "libc" +version = "0.2.98" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320cfe77175da3a483efed4bc0adc1968ca050b098ce4f2f1c13a56626128790" + +[[package]] +name = "lock_api" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0382880606dff6d15c9476c416d18690b72742aa7b605bb6dd6ec9030fbf07eb" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51b9bbe6c47d51fc3e1a9b945965946b4c44142ab8792c50835a980d362c2710" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "matchers" +version = "0.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f099785f7595cc4b4553a174ce30dd7589ef93391ff414dbb67f62392b9e0ce1" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "memchr" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b16bd47d9e329435e309c58469fe0791c2d0d1ba96ec0954152a5ae2b04387dc" + +[[package]] +name = "mini-redis" +version = "0.4.1" +dependencies = [ + "async-stream", + "atoi", + "bytes", + "structopt", + "tokio", + "tokio-stream", + "tracing", + "tracing-futures", + "tracing-subscriber", +] + +[[package]] +name = "mio" +version = "0.7.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c2bdb6314ec10835cd3293dd268473a835c02b7b352e788be788b3c6ca6bb16" +dependencies = [ + "libc", + "log", + "miow", + "ntapi", + "winapi", +] + +[[package]] +name = "miow" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9f1c5b025cda876f66ef43a113f91ebc9f4ccef34843000e0adf6ebbab84e21" +dependencies = [ + "winapi", +] + +[[package]] +name = "ntapi" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f6bb902e437b6d86e03cce10a7e2af662292c5dfef23b65899ea3ac9354ad44" +dependencies = [ + "winapi", +] + +[[package]] +name = "num-integer" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num_cpus" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05499f3756671c15885fee9034446956fff3f243d6077b91e5767df161f766b3" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "once_cell" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56" + +[[package]] +name = "parking_lot" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d7744ac029df22dca6284efe4e898991d28e3085c706c972bcd7da4a27a15eb" +dependencies = [ + "instant", + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7a782938e745763fe6907fc6ba86946d72f49fe7e21de074e08128a99fb018" +dependencies = [ + "cfg-if", + "instant", + "libc", + "redox_syscall", + "smallvec", + "winapi", +] + +[[package]] +name = "pin-project" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7509cc106041c40a4518d2af7a61530e1eed0e6285296a3d8c5472806ccc4a4" +dependencies = [ + "pin-project-internal", +] + +[[package]] +name = "pin-project-internal" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c950132583b500556b1efd71d45b319029f2b71518d979fcc208e16b42426f" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443" + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + +[[package]] +name = "proc-macro2" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8caf72986c1a598726adc988bb5984792ef84f5ee5aa50209145ee8077038" +dependencies = [ + "unicode-xid", +] + +[[package]] +name = "quote" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d0b9745dc2debf507c8422de05d7226cc1f0644216dfdfead988f9b1ab32a7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "redox_syscall" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ab49abadf3f9e1c4bc499e8845e152ad87d2ad2d30371841171169e9d75feee" +dependencies = [ + "bitflags", +] + +[[package]] +name = "regex" +version = "1.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461" +dependencies = [ + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.6.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" + +[[package]] +name = "ryu" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71d301d4193d031abdd79ff7e3dd721168a9572ef3fe51a1517aba235bd8f86e" + +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + +[[package]] +name = "serde" +version = "1.0.126" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec7505abeacaec74ae4778d9d9328fe5a5d04253220a85c4ee022239fc996d03" + +[[package]] +name = "serde_json" +version = "1.0.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "799e97dc9fdae36a5c8b8f2cae9ce2ee9fdce2058c57a93e6099d919fd982f79" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sharded-slab" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79c719719ee05df97490f80a45acfc99e5a30ce98a1e4fb67aee422745ae14e3" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51e73328dc4ac0c7ccbda3a494dfa03df1de2f46018127f60c693f2648455b0" +dependencies = [ + "libc", +] + +[[package]] +name = "smallvec" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe0f37c9e8f3c5a4a66ad655a93c74daac4ad00c441533bf5c6e7990bb42604e" + +[[package]] +name = "strsim" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a" + +[[package]] +name = "structopt" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69b041cdcb67226aca307e6e7be44c8806423d83e018bd662360a93dabce4d71" +dependencies = [ + "clap", + "lazy_static", + "structopt-derive", +] + +[[package]] +name = "structopt-derive" +version = "0.4.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7813934aecf5f51a54775e00068c237de98489463968231a51746bbbc03f9c10" +dependencies = [ + "heck", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "syn" +version = "1.0.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f71489ff30030d2ae598524f61326b902466f72a0fb1a8564c001cc63425bcc7" +dependencies = [ + "proc-macro2", + "quote", + "unicode-xid", +] + +[[package]] +name = "textwrap" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060" +dependencies = [ + "unicode-width", +] + +[[package]] +name = "thread_local" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8018d24e04c95ac8790716a5987d0fec4f8b27249ffa0f7d33f1369bdfb88cbd" +dependencies = [ + "once_cell", +] + +[[package]] +name = "tokio" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98c8b05dc14c75ea83d63dd391100353789f5f24b8b3866542a5e85c8be8e985" +dependencies = [ + "autocfg", + "bytes", + "libc", + "memchr", + "mio", + "num_cpus", + "once_cell", + "parking_lot", + "pin-project-lite", + "signal-hook-registry", + "tokio-macros", + "winapi", +] + +[[package]] +name = "tokio-macros" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54473be61f4ebe4efd09cec9bd5d16fa51d70ea0192213d754d2d500457db110" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-stream" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2f3f698253f03119ac0102beaa64f67a67e08074d03a22d18784104543727f" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tracing" +version = "0.1.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09adeb8c97449311ccd28a427f96fb563e7fd31aabf994189879d9da2394b89d" +dependencies = [ + "cfg-if", + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c42e6fa53307c8a17e4ccd4dc81cf5ec38db9209f59b222210375b54ee40d1e2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9ff14f98b1a4b289c6248a023c1c2fa1491062964e9fed67ab29c4e4da4a052" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "tracing-futures" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97d095ae15e245a057c8e8451bab9b3ee1e1f68e9ba2b4fbc18d0ac5237835f2" +dependencies = [ + "pin-project", + "tracing", +] + +[[package]] +name = "tracing-log" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6923477a48e41c1951f1999ef8bb5a3023eb723ceadafe78ffb65dc366761e3" +dependencies = [ + "lazy_static", + "log", + "tracing-core", +] + +[[package]] +name = "tracing-serde" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb65ea441fbb84f9f6748fd496cf7f63ec9af5bca94dd86456978d055e8eb28b" +dependencies = [ + "serde", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab69019741fca4d98be3c62d2b75254528b5432233fd8a4d2739fec20278de48" +dependencies = [ + "ansi_term 0.12.1", + "chrono", + "lazy_static", + "matchers", + "regex", + "serde", + "serde_json", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", + "tracing-serde", +] + +[[package]] +name = "unicode-segmentation" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8895849a949e7845e06bd6dc1aa51731a103c42707010a5b591c0038fb73385b" + +[[package]] +name = "unicode-width" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9337591893a19b88d8d87f2cec1e73fad5cdfd10e5a6f349f498ad6ea2ffb1e3" + +[[package]] +name = "unicode-xid" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3" + +[[package]] +name = "vec_map" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bddf1187be692e79c5ffeab891132dfb0f236ed36a43c7ed39f1165ee20191" + +[[package]] +name = "version_check" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fecdca9a5291cc2b8dcf7dc02453fee791a280f3743cb0905f8822ae463b3fe" + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" diff --git a/assets/mini-redis/Cargo.toml b/assets/mini-redis/Cargo.toml new file mode 100644 index 0000000..7ed7c81 --- /dev/null +++ b/assets/mini-redis/Cargo.toml @@ -0,0 +1,36 @@ +[package] +authors = ["Carl Lerche "] +edition = "2018" +name = "mini-redis" +version = "0.4.1" +license = "MIT" +readme = "README.md" +documentation = "https://docs.rs/mini-redis/0.4.0/mini-redis/" +repository = "https://github.com/tokio-rs/mini-redis" +description = """ +An incomplete implementation of a Rust client and server. Used as a +larger example of an idiomatic Tokio application. +""" + +[[bin]] +name = "mini-redis-cli" +path = "src/bin/cli.rs" + +[[bin]] +name = "mini-redis-server" +path = "src/bin/server.rs" + +[dependencies] +async-stream = "0.3.0" +atoi = "0.3.2" +bytes = "1" +structopt = "0.3.14" +tokio = { version = "1", features = ["full"] } +tokio-stream = "0.1" +tracing = "0.1.13" +tracing-futures = { version = "0.2.3" } +tracing-subscriber = "0.2.2" + +[dev-dependencies] +# Enable test-utilities in dev mode only. This is mostly for tests. +tokio = { version = "1", features = ["test-util"] } diff --git a/assets/mini-redis/LICENSE b/assets/mini-redis/LICENSE new file mode 100644 index 0000000..243fcd6 --- /dev/null +++ b/assets/mini-redis/LICENSE @@ -0,0 +1,25 @@ +Copyright (c) 2020 Tokio Contributors + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/assets/mini-redis/README.md b/assets/mini-redis/README.md new file mode 100644 index 0000000..cf83b51 --- /dev/null +++ b/assets/mini-redis/README.md @@ -0,0 +1,169 @@ +# mini-redis + +本项目从[tokio/mini-redis](https://github.com/tokio-rs/mini-redis)fork而来,作为rust course的练习项目之一,**文档和注释还未进行翻译**,欢迎大家贡献。 + + +`mini-redis` is an incomplete, idiomatic implementation of a +[Redis](https://redis.io) client and server built with +[Tokio](https://tokio.rs). + +The intent of this project is to provide a larger example of writing a Tokio +application. + +**Disclaimer** Please don't use mini-redis in production. This project is +intended to be a learning resource, and omits various parts of the Redis +protocol because implementing them would not introduce any new concepts. We will +not add new features because you need them in your project — use one of the +fully featured alternatives instead. + +## Why Redis + +The primary goal of this project is teaching Tokio. Doing this requires a +project with a wide range of features with a focus on implementation simplicity. +Redis, an in-memory database, provides a wide range of features and uses a +simple wire protocol. The wide range of features allows demonstrating many Tokio +patterns in a "real world" context. + +The Redis wire protocol documentation can be found [here](https://redis.io/topics/protocol). + +The set of commands Redis provides can be found +[here](https://redis.io/commands). + + +## Running + +The repository provides a server, client library, and some client executables +for interacting with the server. + +Start the server: + +``` +RUST_LOG=debug cargo run --bin mini-redis-server +``` + +The [`tracing`](https://github.com/tokio-rs/tracing) crate is used to provide structured logs. +You can substitute `debug` with the desired [log level][level]. + +[level]: https://docs.rs/tracing-subscriber/latest/tracing_subscriber/filter/struct.EnvFilter.html#directives + +Then, in a different terminal window, the various client [examples](examples) +can be executed. For example: + +``` +cargo run --example hello_world +``` + +Additionally, a CLI client is provided to run arbitrary commands from the +terminal. With the server running, the following works: + +``` +cargo run --bin mini-redis-cli set foo bar + +cargo run --bin mini-redis-cli get foo +``` + +## Supported commands + +`mini-redis` currently supports the following commands. + +* [GET](https://redis.io/commands/get) +* [SET](https://redis.io/commands/set) +* [PUBLISH](https://redis.io/commands/publish) +* [SUBSCRIBE](https://redis.io/commands/subscribe) + +The Redis wire protocol specification can be found +[here](https://redis.io/topics/protocol). + +There is no support for persistence yet. + +## Tokio patterns + +The project demonstrates a number of useful patterns, including: + +### TCP server + +[`server.rs`](src/server.rs) starts a TCP server that accepts connections, +and spawns a new task per connection. It gracefully handles `accept` errors. + +### Client library + +[`client.rs`](src/client.rs) shows how to model an asynchronous client. The +various capabilities are exposed as `async` methods. + +### State shared across sockets + +The server maintains a [`Db`] instance that is accessible from all connected +connections. The [`Db`] instance manages the key-value state as well as pub/sub +capabilities. + +[`Db`]: src/db.rs + +### Framing + +[`connection.rs`](src/connection.rs) and [`frame.rs`](src/frame.rs) show how to +idiomatically implement a wire protocol. The protocol is modeled using an +intermediate representation, the `Frame` structure. `Connection` takes a +`TcpStream` and exposes an API that sends and receives `Frame` values. + +### Graceful shutdown + +The server implements graceful shutdown. [`tokio::signal`] is used to listen for +a SIGINT. Once the signal is received, shutdown begins. The server stops +accepting new connections. Existing connections are notified to shutdown +gracefully. In-flight work is completed, and the connection is closed. + +[`tokio::signal`]: https://docs.rs/tokio/*/tokio/signal/ + +### Concurrent connection limiting + +The server uses a [`Semaphore`] limits the maximum number of concurrent +connections. Once the limit is reached, the server stops accepting new +connections until an existing one terminates. + +[`Semaphore`]: https://docs.rs/tokio/*/tokio/sync/struct.Semaphore.html + +### Pub/Sub + +The server implements non-trivial pub/sub capability. The client may subscribe +to multiple channels and update its subscription at any time. The server +implements this using one [broadcast channel][broadcast] per channel and a +[`StreamMap`] per connection. Clients are able to send subscription commands to +the server to update the active subscriptions. + +[broadcast]: https://docs.rs/tokio/*/tokio/sync/broadcast/index.html +[`StreamMap`]: https://docs.rs/tokio/*/tokio/stream/struct.StreamMap.html + +### Using a `std::sync::Mutex` in an async application + +The server uses a `std::sync::Mutex` and **not** a Tokio mutex to synchronize +access to shared state. See [`db.rs`](src/db.rs) for more details. + +### Testing asynchronous code that relies on time + +In [`tests/server.rs`](tests/server.rs), there are tests for key expiration. +These tests depend on time passing. In order to make the tests deterministic, +time is mocked out using Tokio's testing utilities. + +## Contributing + +Contributions to `mini-redis` are welcome. Keep in mind, the goal of the project +is **not** to reach feature parity with real Redis, but to demonstrate +asynchronous Rust patterns with Tokio. + +Commands or other features should only be added if doing so is useful to +demonstrate a new pattern. + +Contributions should come with extensive comments targetted to new Tokio users. + +Contributions that only focus on clarifying and improving comments are very +welcome. + +## License + +This project is licensed under the [MIT license](LICENSE). + +### Contribution + +Unless you explicitly state otherwise, any contribution intentionally submitted +for inclusion in `mini-redis` by you, shall be licensed as MIT, without any +additional terms or conditions. diff --git a/assets/mini-redis/examples/chat.rs b/assets/mini-redis/examples/chat.rs new file mode 100644 index 0000000..607aa27 --- /dev/null +++ b/assets/mini-redis/examples/chat.rs @@ -0,0 +1,4 @@ +#[tokio::main] +async fn main() { + unimplemented!(); +} diff --git a/assets/mini-redis/examples/hello_world.rs b/assets/mini-redis/examples/hello_world.rs new file mode 100644 index 0000000..34d2ae8 --- /dev/null +++ b/assets/mini-redis/examples/hello_world.rs @@ -0,0 +1,32 @@ +//! Hello world server. +//! +//! A simple client that connects to a mini-redis server, sets key "hello" with value "world", +//! and gets it from the server after +//! +//! You can test this out by running: +//! +//! cargo run --bin mini-redis-server +//! +//! And then in another terminal run: +//! +//! cargo run --example hello_world + +#![warn(rust_2018_idioms)] + +use mini_redis::{client, Result}; + +#[tokio::main] +pub async fn main() -> Result<()> { + // Open a connection to the mini-redis address. + let mut client = client::connect("127.0.0.1:6379").await?; + + // Set the key "hello" with value "world" + client.set("hello", "world".into()).await?; + + // Get key "hello" + let result = client.get("hello").await?; + + println!("got value from the server; success={:?}", result.is_some()); + + Ok(()) +} diff --git a/assets/mini-redis/examples/pub.rs b/assets/mini-redis/examples/pub.rs new file mode 100644 index 0000000..bdae6dd --- /dev/null +++ b/assets/mini-redis/examples/pub.rs @@ -0,0 +1,31 @@ +//! Publish to a redis channel example. +//! +//! A simple client that connects to a mini-redis server, and +//! publishes a message on `foo` channel +//! +//! You can test this out by running: +//! +//! cargo run --bin mini-redis-server +//! +//! Then in another terminal run: +//! +//! cargo run --example sub +//! +//! And then in another terminal run: +//! +//! cargo run --example pub + +#![warn(rust_2018_idioms)] + +use mini_redis::{client, Result}; + +#[tokio::main] +async fn main() -> Result<()> { + // Open a connection to the mini-redis address. + let mut client = client::connect("127.0.0.1:6379").await?; + + // publish message `bar` on channel foo + client.publish("foo", "bar".into()).await?; + + Ok(()) +} diff --git a/assets/mini-redis/examples/sub.rs b/assets/mini-redis/examples/sub.rs new file mode 100644 index 0000000..69f179d --- /dev/null +++ b/assets/mini-redis/examples/sub.rs @@ -0,0 +1,37 @@ +//! Subscribe to a redis channel example. +//! +//! A simple client that connects to a mini-redis server, subscribes to "foo" and "bar" channels +//! and awaits messages published on those channels +//! +//! You can test this out by running: +//! +//! cargo run --bin mini-redis-server +//! +//! Then in another terminal run: +//! +//! cargo run --example sub +//! +//! And then in another terminal run: +//! +//! cargo run --example pub + + + +use mini_redis::{client, Result}; +use tokio_stream::StreamExt; +#[tokio::main] +pub async fn main() -> Result<()> { + // Open a connection to the mini-redis address. + let client = client::connect("127.0.0.1:6379").await?; + + // subscribe to channel foo + let mut subscriber = client.subscribe(vec!["foo".into()]).await?; + let messages = subscriber.into_stream(); + tokio::pin!(messages); + // await messages on channel foo + while let Some(msg) = messages.next().await { + println!("got = {:?}", msg); + } + + Ok(()) +} diff --git a/assets/mini-redis/src/bin/cli.rs b/assets/mini-redis/src/bin/cli.rs new file mode 100644 index 0000000..dfd539f --- /dev/null +++ b/assets/mini-redis/src/bin/cli.rs @@ -0,0 +1,108 @@ +use mini_redis::{client, DEFAULT_PORT}; + +use bytes::Bytes; +use std::num::ParseIntError; +use std::str; +use std::time::Duration; +use structopt::StructOpt; + +#[derive(StructOpt, Debug)] +#[structopt(name = "mini-redis-cli", author = env!("CARGO_PKG_AUTHORS"), about = "Issue Redis commands")] +struct Cli { + #[structopt(subcommand)] + command: Command, + + #[structopt(name = "hostname", long = "--host", default_value = "127.0.0.1")] + host: String, + + #[structopt(name = "port", long = "--port", default_value = DEFAULT_PORT)] + port: String, +} + +#[derive(StructOpt, Debug)] +enum Command { + /// Get the value of key. + Get { + /// Name of key to get + key: String, + }, + /// Set key to hold the string value. + Set { + /// Name of key to set + key: String, + + /// Value to set. + #[structopt(parse(from_str = bytes_from_str))] + value: Bytes, + + /// Expire the value after specified amount of time + #[structopt(parse(try_from_str = duration_from_ms_str))] + expires: Option, + }, +} + +/// Entry point for CLI tool. +/// +/// The `[tokio::main]` annotation signals that the Tokio runtime should be +/// started when the function is called. The body of the function is executed +/// within the newly spawned runtime. +/// +/// `flavor = "current_thread"` is used here to avoid spawning background +/// threads. The CLI tool use case benefits more by being lighter instead of +/// multi-threaded. +#[tokio::main(flavor = "current_thread")] +async fn main() -> mini_redis::Result<()> { + // Enable logging + tracing_subscriber::fmt::try_init()?; + + // Parse command line arguments + let cli = Cli::from_args(); + + // Get the remote address to connect to + let addr = format!("{}:{}", cli.host, cli.port); + + // Establish a connection + let mut client = client::connect(&addr).await?; + + // Process the requested command + match cli.command { + Command::Get { key } => { + if let Some(value) = client.get(&key).await? { + if let Ok(string) = str::from_utf8(&value) { + println!("\"{}\"", string); + } else { + println!("{:?}", value); + } + } else { + println!("(nil)"); + } + } + Command::Set { + key, + value, + expires: None, + } => { + client.set(&key, value).await?; + println!("OK"); + } + Command::Set { + key, + value, + expires: Some(expires), + } => { + client.set_expires(&key, value, expires).await?; + println!("OK"); + } + } + + Ok(()) +} + +fn duration_from_ms_str(src: &str) -> Result { + let ms = src.parse::()?; + Ok(Duration::from_millis(ms)) +} + +fn bytes_from_str(src: &str) -> Bytes { + Bytes::from(src.to_string()) +} diff --git a/assets/mini-redis/src/bin/server.rs b/assets/mini-redis/src/bin/server.rs new file mode 100644 index 0000000..2f76ad7 --- /dev/null +++ b/assets/mini-redis/src/bin/server.rs @@ -0,0 +1,37 @@ +//! mini-redis server. +//! +//! This file is the entry point for the server implemented in the library. It +//! performs command line parsing and passes the arguments on to +//! `mini_redis::server`. +//! +//! The `clap` crate is used for parsing arguments. + +use mini_redis::{server, DEFAULT_PORT}; + +use structopt::StructOpt; +use tokio::net::TcpListener; +use tokio::signal; + +#[tokio::main] +pub async fn main() -> mini_redis::Result<()> { + // enable logging + // see https://docs.rs/tracing for more info + tracing_subscriber::fmt::try_init()?; + + let cli = Cli::from_args(); + let port = cli.port.as_deref().unwrap_or(DEFAULT_PORT); + + // Bind a TCP listener + let listener = TcpListener::bind(&format!("127.0.0.1:{}", port)).await?; + + server::run(listener, signal::ctrl_c()).await; + + Ok(()) +} + +#[derive(StructOpt, Debug)] +#[structopt(name = "mini-redis-server", version = env!("CARGO_PKG_VERSION"), author = env!("CARGO_PKG_AUTHORS"), about = "A Redis server")] +struct Cli { + #[structopt(name = "port", long = "--port")] + port: Option, +} diff --git a/assets/mini-redis/src/blocking_client.rs b/assets/mini-redis/src/blocking_client.rs new file mode 100644 index 0000000..962a1e9 --- /dev/null +++ b/assets/mini-redis/src/blocking_client.rs @@ -0,0 +1,264 @@ +//! Minimal blocking Redis client implementation +//! +//! Provides a blocking connect and methods for issuing the supported commands. + +use bytes::Bytes; +use std::time::Duration; +use tokio::net::ToSocketAddrs; +use tokio::runtime::Runtime; + +pub use crate::client::Message; + +/// Established connection with a Redis server. +/// +/// Backed by a single `TcpStream`, `BlockingClient` provides basic network +/// client functionality (no pooling, retrying, ...). Connections are +/// established using the [`connect`](fn@connect) function. +/// +/// Requests are issued using the various methods of `Client`. +pub struct BlockingClient { + /// The asynchronous `Client`. + inner: crate::client::Client, + + /// A `current_thread` runtime for executing operations on the asynchronous + /// client in a blocking manner. + rt: Runtime, +} + +/// A client that has entered pub/sub mode. +/// +/// Once clients subscribe to a channel, they may only perform pub/sub related +/// commands. The `BlockingClient` type is transitioned to a +/// `BlockingSubscriber` type in order to prevent non-pub/sub methods from being +/// called. +pub struct BlockingSubscriber { + /// The asynchronous `Subscriber`. + inner: crate::client::Subscriber, + + /// A `current_thread` runtime for executing operations on the asynchronous + /// `Subscriber` in a blocking manner. + rt: Runtime, +} + +/// The iterator returned by `Subscriber::into_iter`. +struct SubscriberIterator { + /// The asynchronous `Subscriber`. + inner: crate::client::Subscriber, + + /// A `current_thread` runtime for executing operations on the asynchronous + /// `Subscriber` in a blocking manner. + rt: Runtime, +} + +/// Establish a connection with the Redis server located at `addr`. +/// +/// `addr` may be any type that can be asynchronously converted to a +/// `SocketAddr`. This includes `SocketAddr` and strings. The `ToSocketAddrs` +/// trait is the Tokio version and not the `std` version. +/// +/// # Examples +/// +/// ```no_run +/// use mini_redis::blocking_client; +/// +/// fn main() { +/// let client = match blocking_client::connect("localhost:6379") { +/// Ok(client) => client, +/// Err(_) => panic!("failed to establish connection"), +/// }; +/// # drop(client); +/// } +/// ``` +pub fn connect(addr: T) -> crate::Result { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + + let inner = rt.block_on(crate::client::connect(addr))?; + + Ok(BlockingClient { inner, rt }) +} + +impl BlockingClient { + /// Get the value of key. + /// + /// If the key does not exist the special value `None` is returned. + /// + /// # Examples + /// + /// Demonstrates basic usage. + /// + /// ```no_run + /// use mini_redis::blocking_client; + /// + /// fn main() { + /// let mut client = blocking_client::connect("localhost:6379").unwrap(); + /// + /// let val = client.get("foo").unwrap(); + /// println!("Got = {:?}", val); + /// } + /// ``` + pub fn get(&mut self, key: &str) -> crate::Result> { + self.rt.block_on(self.inner.get(key)) + } + + /// Set `key` to hold the given `value`. + /// + /// The `value` is associated with `key` until it is overwritten by the next + /// call to `set` or it is removed. + /// + /// If key already holds a value, it is overwritten. Any previous time to + /// live associated with the key is discarded on successful SET operation. + /// + /// # Examples + /// + /// Demonstrates basic usage. + /// + /// ```no_run + /// use mini_redis::blocking_client; + /// + /// fn main() { + /// let mut client = blocking_client::connect("localhost:6379").unwrap(); + /// + /// client.set("foo", "bar".into()).unwrap(); + /// + /// // Getting the value immediately works + /// let val = client.get("foo").unwrap().unwrap(); + /// assert_eq!(val, "bar"); + /// } + /// ``` + pub fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> { + self.rt.block_on(self.inner.set(key, value)) + } + + /// Set `key` to hold the given `value`. The value expires after `expiration` + /// + /// The `value` is associated with `key` until one of the following: + /// - it expires. + /// - it is overwritten by the next call to `set`. + /// - it is removed. + /// + /// If key already holds a value, it is overwritten. Any previous time to + /// live associated with the key is discarded on a successful SET operation. + /// + /// # Examples + /// + /// Demonstrates basic usage. This example is not **guaranteed** to always + /// work as it relies on time based logic and assumes the client and server + /// stay relatively synchronized in time. The real world tends to not be so + /// favorable. + /// + /// ```no_run + /// use mini_redis::blocking_client; + /// use std::thread; + /// use std::time::Duration; + /// + /// fn main() { + /// let ttl = Duration::from_millis(500); + /// let mut client = blocking_client::connect("localhost:6379").unwrap(); + /// + /// client.set_expires("foo", "bar".into(), ttl).unwrap(); + /// + /// // Getting the value immediately works + /// let val = client.get("foo").unwrap().unwrap(); + /// assert_eq!(val, "bar"); + /// + /// // Wait for the TTL to expire + /// thread::sleep(ttl); + /// + /// let val = client.get("foo").unwrap(); + /// assert!(val.is_some()); + /// } + /// ``` + pub fn set_expires( + &mut self, + key: &str, + value: Bytes, + expiration: Duration, + ) -> crate::Result<()> { + self.rt + .block_on(self.inner.set_expires(key, value, expiration)) + } + + /// Posts `message` to the given `channel`. + /// + /// Returns the number of subscribers currently listening on the channel. + /// There is no guarantee that these subscribers receive the message as they + /// may disconnect at any time. + /// + /// # Examples + /// + /// Demonstrates basic usage. + /// + /// ```no_run + /// use mini_redis::blocking_client; + /// + /// fn main() { + /// let mut client = blocking_client::connect("localhost:6379").unwrap(); + /// + /// let val = client.publish("foo", "bar".into()).unwrap(); + /// println!("Got = {:?}", val); + /// } + /// ``` + pub fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result { + self.rt.block_on(self.inner.publish(channel, message)) + } + + /// Subscribes the client to the specified channels. + /// + /// Once a client issues a subscribe command, it may no longer issue any + /// non-pub/sub commands. The function consumes `self` and returns a + /// `BlockingSubscriber`. + /// + /// The `BlockingSubscriber` value is used to receive messages as well as + /// manage the list of channels the client is subscribed to. + pub fn subscribe(self, channels: Vec) -> crate::Result { + let subscriber = self.rt.block_on(self.inner.subscribe(channels))?; + Ok(BlockingSubscriber { + inner: subscriber, + rt: self.rt, + }) + } +} + +impl BlockingSubscriber { + /// Returns the set of channels currently subscribed to. + pub fn get_subscribed(&self) -> &[String] { + self.inner.get_subscribed() + } + + /// Receive the next message published on a subscribed channel, waiting if + /// necessary. + /// + /// `None` indicates the subscription has been terminated. + pub fn next_message(&mut self) -> crate::Result> { + self.rt.block_on(self.inner.next_message()) + } + + /// Convert the subscriber into an `Iterator` yielding new messages published + /// on subscribed channels. + pub fn into_iter(self) -> impl Iterator> { + SubscriberIterator { + inner: self.inner, + rt: self.rt, + } + } + + /// Subscribe to a list of new channels + pub fn subscribe(&mut self, channels: &[String]) -> crate::Result<()> { + self.rt.block_on(self.inner.subscribe(channels)) + } + + /// Unsubscribe to a list of new channels + pub fn unsubscribe(&mut self, channels: &[String]) -> crate::Result<()> { + self.rt.block_on(self.inner.unsubscribe(channels)) + } +} + +impl Iterator for SubscriberIterator { + type Item = crate::Result; + + fn next(&mut self) -> Option> { + self.rt.block_on(self.inner.next_message()).transpose() + } +} diff --git a/assets/mini-redis/src/buffer.rs b/assets/mini-redis/src/buffer.rs new file mode 100644 index 0000000..be7b0ee --- /dev/null +++ b/assets/mini-redis/src/buffer.rs @@ -0,0 +1,120 @@ +use crate::client::Client; +use crate::Result; + +use bytes::Bytes; +use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::sync::oneshot; + +/// Create a new client request buffer +/// +/// The `Client` performs Redis commands directly on the TCP connection. Only a +/// single request may be in-flight at a given time and operations require +/// mutable access to the `Client` handle. This prevents using a single Redis +/// connection from multiple Tokio tasks. +/// +/// The strategy for dealing with this class of problem is to spawn a dedicated +/// Tokio task to manage the Redis connection and using "message passing" to +/// operate on the connection. Commands are pushed into a channel. The +/// connection task pops commands off of the channel and applies them to the +/// Redis connection. When the response is received, it is forwarded to the +/// original requester. +/// +/// The returned `Buffer` handle may be cloned before passing the new handle to +/// separate tasks. +pub fn buffer(client: Client) -> Buffer { + // Setting the message limit to a hard coded value of 32. in a real-app, the + // buffer size should be configurable, but we don't need to do that here. + let (tx, rx) = channel(32); + + // Spawn a task to process requests for the connection. + tokio::spawn(async move { run(client, rx).await }); + + // Return the `Buffer` handle. + Buffer { tx } +} + +// Enum used to message pass the requested command from the `Buffer` handle +#[derive(Debug)] +enum Command { + Get(String), + Set(String, Bytes), +} + +// Message type sent over the channel to the connection task. +// +// `Command` is the command to forward to the connection. +// +// `oneshot::Sender` is a channel type that sends a **single** value. It is used +// here to send the response received from the connection back to the original +// requester. +type Message = (Command, oneshot::Sender>>); + +/// Receive commands sent through the channel and forward them to client. The +/// response is returned back to the caller via a `oneshot`. +async fn run(mut client: Client, mut rx: Receiver) { + // Repeatedly pop messages from the channel. A return value of `None` + // indicates that all `Buffer` handles have dropped and there will never be + // another message sent on the channel. + while let Some((cmd, tx)) = rx.recv().await { + // The command is forwarded to the connection + let response = match cmd { + Command::Get(key) => client.get(&key).await, + Command::Set(key, value) => client.set(&key, value).await.map(|_| None), + }; + + // Send the response back to the caller. + // + // Failing to send the message indicates the `rx` half dropped + // before receiving the message. This is a normal runtime event. + let _ = tx.send(response); + } +} + +#[derive(Clone)] +pub struct Buffer { + tx: Sender, +} + +impl Buffer { + /// Get the value of a key. + /// + /// Same as `Client::get` but requests are **buffered** until the associated + /// connection has the ability to send the request. + pub async fn get(&mut self, key: &str) -> Result> { + // Initialize a new `Get` command to send via the channel. + let get = Command::Get(key.into()); + + // Initialize a new oneshot to be used to receive the response back from the connection. + let (tx, rx) = oneshot::channel(); + + // Send the request + self.tx.send((get, tx)).await?; + + // Await the response + match rx.await { + Ok(res) => res, + Err(err) => Err(err.into()), + } + } + + /// Set `key` to hold the given `value`. + /// + /// Same as `Client::set` but requests are **buffered** until the associated + /// connection has the ability to send the request + pub async fn set(&mut self, key: &str, value: Bytes) -> Result<()> { + // Initialize a new `Set` command to send via the channel. + let set = Command::Set(key.into(), value); + + // Initialize a new oneshot to be used to receive the response back from the connection. + let (tx, rx) = oneshot::channel(); + + // Send the request + self.tx.send((set, tx)).await?; + + // Await the response + match rx.await { + Ok(res) => res.map(|_| ()), + Err(err) => Err(err.into()), + } + } +} diff --git a/assets/mini-redis/src/client.rs b/assets/mini-redis/src/client.rs new file mode 100644 index 0000000..08223b2 --- /dev/null +++ b/assets/mini-redis/src/client.rs @@ -0,0 +1,473 @@ +//! Minimal Redis client implementation +//! +//! Provides an async connect and methods for issuing the supported commands. + +use crate::cmd::{Get, Publish, Set, Subscribe, Unsubscribe}; +use crate::{Connection, Frame}; + +use async_stream::try_stream; +use bytes::Bytes; +use std::io::{Error, ErrorKind}; +use std::time::Duration; +use tokio::net::{TcpStream, ToSocketAddrs}; +use tokio_stream::Stream; +use tracing::{debug, instrument}; + +/// Established connection with a Redis server. +/// +/// Backed by a single `TcpStream`, `Client` provides basic network client +/// functionality (no pooling, retrying, ...). Connections are established using +/// the [`connect`](fn@connect) function. +/// +/// Requests are issued using the various methods of `Client`. +pub struct Client { + /// The TCP connection decorated with the redis protocol encoder / decoder + /// implemented using a buffered `TcpStream`. + /// + /// When `Listener` receives an inbound connection, the `TcpStream` is + /// passed to `Connection::new`, which initializes the associated buffers. + /// `Connection` allows the handler to operate at the "frame" level and keep + /// the byte level protocol parsing details encapsulated in `Connection`. + connection: Connection, +} + +/// A client that has entered pub/sub mode. +/// +/// Once clients subscribe to a channel, they may only perform pub/sub related +/// commands. The `Client` type is transitioned to a `Subscriber` type in order +/// to prevent non-pub/sub methods from being called. +pub struct Subscriber { + /// The subscribed client. + client: Client, + + /// The set of channels to which the `Subscriber` is currently subscribed. + subscribed_channels: Vec, +} + +/// A message received on a subscribed channel. +#[derive(Debug, Clone)] +pub struct Message { + pub channel: String, + pub content: Bytes, +} + +/// Establish a connection with the Redis server located at `addr`. +/// +/// `addr` may be any type that can be asynchronously converted to a +/// `SocketAddr`. This includes `SocketAddr` and strings. The `ToSocketAddrs` +/// trait is the Tokio version and not the `std` version. +/// +/// # Examples +/// +/// ```no_run +/// use mini_redis::client; +/// +/// #[tokio::main] +/// async fn main() { +/// let client = match client::connect("localhost:6379").await { +/// Ok(client) => client, +/// Err(_) => panic!("failed to establish connection"), +/// }; +/// # drop(client); +/// } +/// ``` +/// +pub async fn connect(addr: T) -> crate::Result { + // The `addr` argument is passed directly to `TcpStream::connect`. This + // performs any asynchronous DNS lookup and attempts to establish the TCP + // connection. An error at either step returns an error, which is then + // bubbled up to the caller of `mini_redis` connect. + let socket = TcpStream::connect(addr).await?; + + // Initialize the connection state. This allocates read/write buffers to + // perform redis protocol frame parsing. + let connection = Connection::new(socket); + + Ok(Client { connection }) +} + +impl Client { + /// Get the value of key. + /// + /// If the key does not exist the special value `None` is returned. + /// + /// # Examples + /// + /// Demonstrates basic usage. + /// + /// ```no_run + /// use mini_redis::client; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut client = client::connect("localhost:6379").await.unwrap(); + /// + /// let val = client.get("foo").await.unwrap(); + /// println!("Got = {:?}", val); + /// } + /// ``` + #[instrument(skip(self))] + pub async fn get(&mut self, key: &str) -> crate::Result> { + // Create a `Get` command for the `key` and convert it to a frame. + let frame = Get::new(key).into_frame(); + + debug!(request = ?frame); + + // Write the frame to the socket. This writes the full frame to the + // socket, waiting if necessary. + self.connection.write_frame(&frame).await?; + + // Wait for the response from the server + // + // Both `Simple` and `Bulk` frames are accepted. `Null` represents the + // key not being present and `None` is returned. + match self.read_response().await? { + Frame::Simple(value) => Ok(Some(value.into())), + Frame::Bulk(value) => Ok(Some(value)), + Frame::Null => Ok(None), + frame => Err(frame.to_error()), + } + } + + /// Set `key` to hold the given `value`. + /// + /// The `value` is associated with `key` until it is overwritten by the next + /// call to `set` or it is removed. + /// + /// If key already holds a value, it is overwritten. Any previous time to + /// live associated with the key is discarded on successful SET operation. + /// + /// # Examples + /// + /// Demonstrates basic usage. + /// + /// ```no_run + /// use mini_redis::client; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut client = client::connect("localhost:6379").await.unwrap(); + /// + /// client.set("foo", "bar".into()).await.unwrap(); + /// + /// // Getting the value immediately works + /// let val = client.get("foo").await.unwrap().unwrap(); + /// assert_eq!(val, "bar"); + /// } + /// ``` + #[instrument(skip(self))] + pub async fn set(&mut self, key: &str, value: Bytes) -> crate::Result<()> { + // Create a `Set` command and pass it to `set_cmd`. A separate method is + // used to set a value with an expiration. The common parts of both + // functions are implemented by `set_cmd`. + self.set_cmd(Set::new(key, value, None)).await + } + + /// Set `key` to hold the given `value`. The value expires after `expiration` + /// + /// The `value` is associated with `key` until one of the following: + /// - it expires. + /// - it is overwritten by the next call to `set`. + /// - it is removed. + /// + /// If key already holds a value, it is overwritten. Any previous time to + /// live associated with the key is discarded on a successful SET operation. + /// + /// # Examples + /// + /// Demonstrates basic usage. This example is not **guaranteed** to always + /// work as it relies on time based logic and assumes the client and server + /// stay relatively synchronized in time. The real world tends to not be so + /// favorable. + /// + /// ```no_run + /// use mini_redis::client; + /// use tokio::time; + /// use std::time::Duration; + /// + /// #[tokio::main] + /// async fn main() { + /// let ttl = Duration::from_millis(500); + /// let mut client = client::connect("localhost:6379").await.unwrap(); + /// + /// client.set_expires("foo", "bar".into(), ttl).await.unwrap(); + /// + /// // Getting the value immediately works + /// let val = client.get("foo").await.unwrap().unwrap(); + /// assert_eq!(val, "bar"); + /// + /// // Wait for the TTL to expire + /// time::sleep(ttl).await; + /// + /// let val = client.get("foo").await.unwrap(); + /// assert!(val.is_some()); + /// } + /// ``` + #[instrument(skip(self))] + pub async fn set_expires( + &mut self, + key: &str, + value: Bytes, + expiration: Duration, + ) -> crate::Result<()> { + // Create a `Set` command and pass it to `set_cmd`. A separate method is + // used to set a value with an expiration. The common parts of both + // functions are implemented by `set_cmd`. + self.set_cmd(Set::new(key, value, Some(expiration))).await + } + + /// The core `SET` logic, used by both `set` and `set_expires. + async fn set_cmd(&mut self, cmd: Set) -> crate::Result<()> { + // Convert the `Set` command into a frame + let frame = cmd.into_frame(); + + debug!(request = ?frame); + + // Write the frame to the socket. This writes the full frame to the + // socket, waiting if necessary. + self.connection.write_frame(&frame).await?; + + // Wait for the response from the server. On success, the server + // responds simply with `OK`. Any other response indicates an error. + match self.read_response().await? { + Frame::Simple(response) if response == "OK" => Ok(()), + frame => Err(frame.to_error()), + } + } + + /// Posts `message` to the given `channel`. + /// + /// Returns the number of subscribers currently listening on the channel. + /// There is no guarantee that these subscribers receive the message as they + /// may disconnect at any time. + /// + /// # Examples + /// + /// Demonstrates basic usage. + /// + /// ```no_run + /// use mini_redis::client; + /// + /// #[tokio::main] + /// async fn main() { + /// let mut client = client::connect("localhost:6379").await.unwrap(); + /// + /// let val = client.publish("foo", "bar".into()).await.unwrap(); + /// println!("Got = {:?}", val); + /// } + /// ``` + #[instrument(skip(self))] + pub async fn publish(&mut self, channel: &str, message: Bytes) -> crate::Result { + // Convert the `Publish` command into a frame + let frame = Publish::new(channel, message).into_frame(); + + debug!(request = ?frame); + + // Write the frame to the socket + self.connection.write_frame(&frame).await?; + + // Read the response + match self.read_response().await? { + Frame::Integer(response) => Ok(response), + frame => Err(frame.to_error()), + } + } + + /// Subscribes the client to the specified channels. + /// + /// Once a client issues a subscribe command, it may no longer issue any + /// non-pub/sub commands. The function consumes `self` and returns a `Subscriber`. + /// + /// The `Subscriber` value is used to receive messages as well as manage the + /// list of channels the client is subscribed to. + #[instrument(skip(self))] + pub async fn subscribe(mut self, channels: Vec) -> crate::Result { + // Issue the subscribe command to the server and wait for confirmation. + // The client will then have been transitioned into the "subscriber" + // state and may only issue pub/sub commands from that point on. + self.subscribe_cmd(&channels).await?; + + // Return the `Subscriber` type + Ok(Subscriber { + client: self, + subscribed_channels: channels, + }) + } + + /// The core `SUBSCRIBE` logic, used by misc subscribe fns + async fn subscribe_cmd(&mut self, channels: &[String]) -> crate::Result<()> { + // Convert the `Subscribe` command into a frame + let frame = Subscribe::new(&channels).into_frame(); + + debug!(request = ?frame); + + // Write the frame to the socket + self.connection.write_frame(&frame).await?; + + // For each channel being subscribed to, the server responds with a + // message confirming subscription to that channel. + for channel in channels { + // Read the response + let response = self.read_response().await?; + + // Verify it is confirmation of subscription. + match response { + Frame::Array(ref frame) => match frame.as_slice() { + // The server responds with an array frame in the form of: + // + // ``` + // [ "subscribe", channel, num-subscribed ] + // ``` + // + // where channel is the name of the channel and + // num-subscribed is the number of channels that the client + // is currently subscribed to. + [subscribe, schannel, ..] + if *subscribe == "subscribe" && *schannel == channel => {} + _ => return Err(response.to_error()), + }, + frame => return Err(frame.to_error()), + }; + } + + Ok(()) + } + + /// Reads a response frame from the socket. + /// + /// If an `Error` frame is received, it is converted to `Err`. + async fn read_response(&mut self) -> crate::Result { + let response = self.connection.read_frame().await?; + + debug!(?response); + + match response { + // Error frames are converted to `Err` + Some(Frame::Error(msg)) => Err(msg.into()), + Some(frame) => Ok(frame), + None => { + // Receiving `None` here indicates the server has closed the + // connection without sending a frame. This is unexpected and is + // represented as a "connection reset by peer" error. + let err = Error::new(ErrorKind::ConnectionReset, "connection reset by server"); + + Err(err.into()) + } + } + } +} + +impl Subscriber { + /// Returns the set of channels currently subscribed to. + pub fn get_subscribed(&self) -> &[String] { + &self.subscribed_channels + } + + /// Receive the next message published on a subscribed channel, waiting if + /// necessary. + /// + /// `None` indicates the subscription has been terminated. + pub async fn next_message(&mut self) -> crate::Result> { + match self.client.connection.read_frame().await? { + Some(mframe) => { + debug!(?mframe); + + match mframe { + Frame::Array(ref frame) => match frame.as_slice() { + [message, channel, content] if *message == "message" => Ok(Some(Message { + channel: channel.to_string(), + content: Bytes::from(content.to_string()), + })), + _ => Err(mframe.to_error()), + }, + frame => Err(frame.to_error()), + } + } + None => Ok(None), + } + } + + /// Convert the subscriber into a `Stream` yielding new messages published + /// on subscribed channels. + /// + /// `Subscriber` does not implement stream itself as doing so with safe code + /// is non trivial. The usage of async/await would require a manual Stream + /// implementation to use `unsafe` code. Instead, a conversion function is + /// provided and the returned stream is implemented with the help of the + /// `async-stream` crate. + pub fn into_stream(mut self) -> impl Stream> { + // Uses the `try_stream` macro from the `async-stream` crate. Generators + // are not stable in Rust. The crate uses a macro to simulate generators + // on top of async/await. There are limitations, so read the + // documentation there. + try_stream! { + while let Some(message) = self.next_message().await? { + yield message; + } + } + } + + /// Subscribe to a list of new channels + #[instrument(skip(self))] + pub async fn subscribe(&mut self, channels: &[String]) -> crate::Result<()> { + // Issue the subscribe command + self.client.subscribe_cmd(channels).await?; + + // Update the set of subscribed channels. + self.subscribed_channels + .extend(channels.iter().map(Clone::clone)); + + Ok(()) + } + + /// Unsubscribe to a list of new channels + #[instrument(skip(self))] + pub async fn unsubscribe(&mut self, channels: &[String]) -> crate::Result<()> { + let frame = Unsubscribe::new(&channels).into_frame(); + + debug!(request = ?frame); + + // Write the frame to the socket + self.client.connection.write_frame(&frame).await?; + + // if the input channel list is empty, server acknowledges as unsubscribing + // from all subscribed channels, so we assert that the unsubscribe list received + // matches the client subscribed one + let num = if channels.is_empty() { + self.subscribed_channels.len() + } else { + channels.len() + }; + + // Read the response + for _ in 0..num { + let response = self.client.read_response().await?; + + match response { + Frame::Array(ref frame) => match frame.as_slice() { + [unsubscribe, channel, ..] if *unsubscribe == "unsubscribe" => { + let len = self.subscribed_channels.len(); + + if len == 0 { + // There must be at least one channel + return Err(response.to_error()); + } + + // unsubscribed channel should exist in the subscribed list at this point + self.subscribed_channels.retain(|c| *channel != &c[..]); + + // Only a single channel should be removed from the + // list of subscribed channels. + if self.subscribed_channels.len() != len - 1 { + return Err(response.to_error()); + } + } + _ => return Err(response.to_error()), + }, + frame => return Err(frame.to_error()), + }; + } + + Ok(()) + } +} diff --git a/assets/mini-redis/src/cmd/get.rs b/assets/mini-redis/src/cmd/get.rs new file mode 100644 index 0000000..81964a8 --- /dev/null +++ b/assets/mini-redis/src/cmd/get.rs @@ -0,0 +1,93 @@ +use crate::{Connection, Db, Frame, Parse}; + +use bytes::Bytes; +use tracing::{debug, instrument}; + +/// Get the value of key. +/// +/// If the key does not exist the special value nil is returned. An error is +/// returned if the value stored at key is not a string, because GET only +/// handles string values. +#[derive(Debug)] +pub struct Get { + /// Name of the key to get + key: String, +} + +impl Get { + /// Create a new `Get` command which fetches `key`. + pub fn new(key: impl ToString) -> Get { + Get { + key: key.to_string(), + } + } + + /// Get the key + pub fn key(&self) -> &str { + &self.key + } + + /// Parse a `Get` instance from a received frame. + /// + /// The `Parse` argument provides a cursor-like API to read fields from the + /// `Frame`. At this point, the entire frame has already been received from + /// the socket. + /// + /// The `GET` string has already been consumed. + /// + /// # Returns + /// + /// Returns the `Get` value on success. If the frame is malformed, `Err` is + /// returned. + /// + /// # Format + /// + /// Expects an array frame containing two entries. + /// + /// ```text + /// GET key + /// ``` + pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { + // The `GET` string has already been consumed. The next value is the + // name of the key to get. If the next value is not a string or the + // input is fully consumed, then an error is returned. + let key = parse.next_string()?; + + Ok(Get { key }) + } + + /// Apply the `Get` command to the specified `Db` instance. + /// + /// The response is written to `dst`. This is called by the server in order + /// to execute a received command. + #[instrument(skip(self, db, dst))] + pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> { + // Get the value from the shared database state + let response = if let Some(value) = db.get(&self.key) { + // If a value is present, it is written to the client in "bulk" + // format. + Frame::Bulk(value) + } else { + // If there is no value, `Null` is written. + Frame::Null + }; + + debug!(?response); + + // Write the response back to the client + dst.write_frame(&response).await?; + + Ok(()) + } + + /// Converts the command into an equivalent `Frame`. + /// + /// This is called by the client when encoding a `Get` command to send to + /// the server. + pub(crate) fn into_frame(self) -> Frame { + let mut frame = Frame::array(); + frame.push_bulk(Bytes::from("get".as_bytes())); + frame.push_bulk(Bytes::from(self.key.into_bytes())); + frame + } +} diff --git a/assets/mini-redis/src/cmd/mod.rs b/assets/mini-redis/src/cmd/mod.rs new file mode 100644 index 0000000..2da5ad0 --- /dev/null +++ b/assets/mini-redis/src/cmd/mod.rs @@ -0,0 +1,116 @@ +mod get; +pub use get::Get; + +mod publish; +pub use publish::Publish; + +mod set; +pub use set::Set; + +mod subscribe; +pub use subscribe::{Subscribe, Unsubscribe}; + +mod unknown; +pub use unknown::Unknown; + +use crate::{Connection, Db, Frame, Parse, ParseError, Shutdown}; + +/// Enumeration of supported Redis commands. +/// +/// Methods called on `Command` are delegated to the command implementation. +#[derive(Debug)] +pub enum Command { + Get(Get), + Publish(Publish), + Set(Set), + Subscribe(Subscribe), + Unsubscribe(Unsubscribe), + Unknown(Unknown), +} + +impl Command { + /// Parse a command from a received frame. + /// + /// The `Frame` must represent a Redis command supported by `mini-redis` and + /// be the array variant. + /// + /// # Returns + /// + /// On success, the command value is returned, otherwise, `Err` is returned. + pub fn from_frame(frame: Frame) -> crate::Result { + // The frame value is decorated with `Parse`. `Parse` provides a + // "cursor" like API which makes parsing the command easier. + // + // The frame value must be an array variant. Any other frame variants + // result in an error being returned. + let mut parse = Parse::new(frame)?; + + // All redis commands begin with the command name as a string. The name + // is read and converted to lower cases in order to do case sensitive + // matching. + let command_name = parse.next_string()?.to_lowercase(); + + // Match the command name, delegating the rest of the parsing to the + // specific command. + let command = match &command_name[..] { + "get" => Command::Get(Get::parse_frames(&mut parse)?), + "publish" => Command::Publish(Publish::parse_frames(&mut parse)?), + "set" => Command::Set(Set::parse_frames(&mut parse)?), + "subscribe" => Command::Subscribe(Subscribe::parse_frames(&mut parse)?), + "unsubscribe" => Command::Unsubscribe(Unsubscribe::parse_frames(&mut parse)?), + _ => { + // The command is not recognized and an Unknown command is + // returned. + // + // `return` is called here to skip the `finish()` call below. As + // the command is not recognized, there is most likely + // unconsumed fields remaining in the `Parse` instance. + return Ok(Command::Unknown(Unknown::new(command_name))); + } + }; + + // Check if there is any remaining unconsumed fields in the `Parse` + // value. If fields remain, this indicates an unexpected frame format + // and an error is returned. + parse.finish()?; + + // The command has been successfully parsed + Ok(command) + } + + /// Apply the command to the specified `Db` instance. + /// + /// The response is written to `dst`. This is called by the server in order + /// to execute a received command. + pub(crate) async fn apply( + self, + db: &Db, + dst: &mut Connection, + shutdown: &mut Shutdown, + ) -> crate::Result<()> { + use Command::*; + + match self { + Get(cmd) => cmd.apply(db, dst).await, + Publish(cmd) => cmd.apply(db, dst).await, + Set(cmd) => cmd.apply(db, dst).await, + Subscribe(cmd) => cmd.apply(db, dst, shutdown).await, + Unknown(cmd) => cmd.apply(dst).await, + // `Unsubscribe` cannot be applied. It may only be received from the + // context of a `Subscribe` command. + Unsubscribe(_) => Err("`Unsubscribe` is unsupported in this context".into()), + } + } + + /// Returns the command name + pub(crate) fn get_name(&self) -> &str { + match self { + Command::Get(_) => "get", + Command::Publish(_) => "pub", + Command::Set(_) => "set", + Command::Subscribe(_) => "subscribe", + Command::Unsubscribe(_) => "unsubscribe", + Command::Unknown(cmd) => cmd.get_name(), + } + } +} diff --git a/assets/mini-redis/src/cmd/publish.rs b/assets/mini-redis/src/cmd/publish.rs new file mode 100644 index 0000000..3c28b1c --- /dev/null +++ b/assets/mini-redis/src/cmd/publish.rs @@ -0,0 +1,101 @@ +use crate::{Connection, Db, Frame, Parse}; + +use bytes::Bytes; + +/// Posts a message to the given channel. +/// +/// Send a message into a channel without any knowledge of individual consumers. +/// Consumers may subscribe to channels in order to receive the messages. +/// +/// Channel names have no relation to the key-value namespace. Publishing on a +/// channel named "foo" has no relation to setting the "foo" key. +#[derive(Debug)] +pub struct Publish { + /// Name of the channel on which the message should be published. + channel: String, + + /// The message to publish. + message: Bytes, +} + +impl Publish { + /// Create a new `Publish` command which sends `message` on `channel`. + pub(crate) fn new(channel: impl ToString, message: Bytes) -> Publish { + Publish { + channel: channel.to_string(), + message, + } + } + + /// Parse a `Publish` instance from a received frame. + /// + /// The `Parse` argument provides a cursor-like API to read fields from the + /// `Frame`. At this point, the entire frame has already been received from + /// the socket. + /// + /// The `PUBLISH` string has already been consumed. + /// + /// # Returns + /// + /// On success, the `Publish` value is returned. If the frame is malformed, + /// `Err` is returned. + /// + /// # Format + /// + /// Expects an array frame containing three entries. + /// + /// ```text + /// PUBLISH channel message + /// ``` + pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { + // The `PUBLISH` string has already been consumed. Extract the `channel` + // and `message` values from the frame. + // + // The `channel` must be a valid string. + let channel = parse.next_string()?; + + // The `message` is arbitrary bytes. + let message = parse.next_bytes()?; + + Ok(Publish { channel, message }) + } + + /// Apply the `Publish` command to the specified `Db` instance. + /// + /// The response is written to `dst`. This is called by the server in order + /// to execute a received command. + pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> { + // The shared state contains the `tokio::sync::broadcast::Sender` for + // all active channels. Calling `db.publish` dispatches the message into + // the appropriate channel. + // + // The number of subscribers currently listening on the channel is + // returned. This does not mean that `num_subscriber` channels will + // receive the message. Subscribers may drop before receiving the + // message. Given this, `num_subscribers` should only be used as a + // "hint". + let num_subscribers = db.publish(&self.channel, self.message); + + // The number of subscribers is returned as the response to the publish + // request. + let response = Frame::Integer(num_subscribers as u64); + + // Write the frame to the client. + dst.write_frame(&response).await?; + + Ok(()) + } + + /// Converts the command into an equivalent `Frame`. + /// + /// This is called by the client when encoding a `Publish` command to send + /// to the server. + pub(crate) fn into_frame(self) -> Frame { + let mut frame = Frame::array(); + frame.push_bulk(Bytes::from("publish".as_bytes())); + frame.push_bulk(Bytes::from(self.channel.into_bytes())); + frame.push_bulk(self.message); + + frame + } +} diff --git a/assets/mini-redis/src/cmd/set.rs b/assets/mini-redis/src/cmd/set.rs new file mode 100644 index 0000000..eae05d7 --- /dev/null +++ b/assets/mini-redis/src/cmd/set.rs @@ -0,0 +1,161 @@ +use crate::cmd::{Parse, ParseError}; +use crate::{Connection, Db, Frame}; + +use bytes::Bytes; +use std::time::Duration; +use tracing::{debug, instrument}; + +/// Set `key` to hold the string `value`. +/// +/// If `key` already holds a value, it is overwritten, regardless of its type. +/// Any previous time to live associated with the key is discarded on successful +/// SET operation. +/// +/// # Options +/// +/// Currently, the following options are supported: +/// +/// * EX `seconds` -- Set the specified expire time, in seconds. +/// * PX `milliseconds` -- Set the specified expire time, in milliseconds. +#[derive(Debug)] +pub struct Set { + /// the lookup key + key: String, + + /// the value to be stored + value: Bytes, + + /// When to expire the key + expire: Option, +} + +impl Set { + /// Create a new `Set` command which sets `key` to `value`. + /// + /// If `expire` is `Some`, the value should expire after the specified + /// duration. + pub fn new(key: impl ToString, value: Bytes, expire: Option) -> Set { + Set { + key: key.to_string(), + value, + expire, + } + } + + /// Get the key + pub fn key(&self) -> &str { + &self.key + } + + /// Get the value + pub fn value(&self) -> &Bytes { + &self.value + } + + /// Get the expire + pub fn expire(&self) -> Option { + self.expire + } + + /// Parse a `Set` instance from a received frame. + /// + /// The `Parse` argument provides a cursor-like API to read fields from the + /// `Frame`. At this point, the entire frame has already been received from + /// the socket. + /// + /// The `SET` string has already been consumed. + /// + /// # Returns + /// + /// Returns the `Set` value on success. If the frame is malformed, `Err` is + /// returned. + /// + /// # Format + /// + /// Expects an array frame containing at least 3 entries. + /// + /// ```text + /// SET key value [EX seconds|PX milliseconds] + /// ``` + pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { + use ParseError::EndOfStream; + + // Read the key to set. This is a required field + let key = parse.next_string()?; + + // Read the value to set. This is a required field. + let value = parse.next_bytes()?; + + // The expiration is optional. If nothing else follows, then it is + // `None`. + let mut expire = None; + + // Attempt to parse another string. + match parse.next_string() { + Ok(s) if s.to_uppercase() == "EX" => { + // An expiration is specified in seconds. The next value is an + // integer. + let secs = parse.next_int()?; + expire = Some(Duration::from_secs(secs)); + } + Ok(s) if s.to_uppercase() == "PX" => { + // An expiration is specified in milliseconds. The next value is + // an integer. + let ms = parse.next_int()?; + expire = Some(Duration::from_millis(ms)); + } + // Currently, mini-redis does not support any of the other SET + // options. An error here results in the connection being + // terminated. Other connections will continue to operate normally. + Ok(_) => return Err("currently `SET` only supports the expiration option".into()), + // The `EndOfStream` error indicates there is no further data to + // parse. In this case, it is a normal run time situation and + // indicates there are no specified `SET` options. + Err(EndOfStream) => {} + // All other errors are bubbled up, resulting in the connection + // being terminated. + Err(err) => return Err(err.into()), + } + + Ok(Set { key, value, expire }) + } + + /// Apply the `Set` command to the specified `Db` instance. + /// + /// The response is written to `dst`. This is called by the server in order + /// to execute a received command. + #[instrument(skip(self, db, dst))] + pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> crate::Result<()> { + // Set the value in the shared database state. + db.set(self.key, self.value, self.expire); + + // Create a success response and write it to `dst`. + let response = Frame::Simple("OK".to_string()); + debug!(?response); + dst.write_frame(&response).await?; + + Ok(()) + } + + /// Converts the command into an equivalent `Frame`. + /// + /// This is called by the client when encoding a `Set` command to send to + /// the server. + pub(crate) fn into_frame(self) -> Frame { + let mut frame = Frame::array(); + frame.push_bulk(Bytes::from("set".as_bytes())); + frame.push_bulk(Bytes::from(self.key.into_bytes())); + frame.push_bulk(self.value); + if let Some(ms) = self.expire { + // Expirations in Redis procotol can be specified in two ways + // 1. SET key value EX seconds + // 2. SET key value PX milliseconds + // We the second option because it allows greater precision and + // src/bin/cli.rs parses the expiration argument as milliseconds + // in duration_from_ms_str() + frame.push_bulk(Bytes::from("px".as_bytes())); + frame.push_int(ms.as_millis() as u64); + } + frame + } +} diff --git a/assets/mini-redis/src/cmd/subscribe.rs b/assets/mini-redis/src/cmd/subscribe.rs new file mode 100644 index 0000000..ab87763 --- /dev/null +++ b/assets/mini-redis/src/cmd/subscribe.rs @@ -0,0 +1,351 @@ +use crate::cmd::{Parse, ParseError, Unknown}; +use crate::{Command, Connection, Db, Frame, Shutdown}; + +use bytes::Bytes; +use std::pin::Pin; +use tokio::select; +use tokio::sync::broadcast; +use tokio_stream::{Stream, StreamExt, StreamMap}; + +/// Subscribes the client to one or more channels. +/// +/// Once the client enters the subscribed state, it is not supposed to issue any +/// other commands, except for additional SUBSCRIBE, PSUBSCRIBE, UNSUBSCRIBE, +/// PUNSUBSCRIBE, PING and QUIT commands. +#[derive(Debug)] +pub struct Subscribe { + channels: Vec, +} + +/// Unsubscribes the client from one or more channels. +/// +/// When no channels are specified, the client is unsubscribed from all the +/// previously subscribed channels. +#[derive(Clone, Debug)] +pub struct Unsubscribe { + channels: Vec, +} + +/// Stream of messages. The stream receives messages from the +/// `broadcast::Receiver`. We use `stream!` to create a `Stream` that consumes +/// messages. Because `stream!` values cannot be named, we box the stream using +/// a trait object. +type Messages = Pin + Send>>; + +impl Subscribe { + /// Creates a new `Subscribe` command to listen on the specified channels. + pub(crate) fn new(channels: &[String]) -> Subscribe { + Subscribe { + channels: channels.to_vec(), + } + } + + /// Parse a `Subscribe` instance from a received frame. + /// + /// The `Parse` argument provides a cursor-like API to read fields from the + /// `Frame`. At this point, the entire frame has already been received from + /// the socket. + /// + /// The `SUBSCRIBE` string has already been consumed. + /// + /// # Returns + /// + /// On success, the `Subscribe` value is returned. If the frame is + /// malformed, `Err` is returned. + /// + /// # Format + /// + /// Expects an array frame containing two or more entries. + /// + /// ```text + /// SUBSCRIBE channel [channel ...] + /// ``` + pub(crate) fn parse_frames(parse: &mut Parse) -> crate::Result { + use ParseError::EndOfStream; + + // The `SUBSCRIBE` string has already been consumed. At this point, + // there is one or more strings remaining in `parse`. These represent + // the channels to subscribe to. + // + // Extract the first string. If there is none, the the frame is + // malformed and the error is bubbled up. + let mut channels = vec![parse.next_string()?]; + + // Now, the remainder of the frame is consumed. Each value must be a + // string or the frame is malformed. Once all values in the frame have + // been consumed, the command is fully parsed. + loop { + match parse.next_string() { + // A string has been consumed from the `parse`, push it into the + // list of channels to subscribe to. + Ok(s) => channels.push(s), + // The `EndOfStream` error indicates there is no further data to + // parse. + Err(EndOfStream) => break, + // All other errors are bubbled up, resulting in the connection + // being terminated. + Err(err) => return Err(err.into()), + } + } + + Ok(Subscribe { channels }) + } + + /// Apply the `Subscribe` command to the specified `Db` instance. + /// + /// This function is the entry point and includes the initial list of + /// channels to subscribe to. Additional `subscribe` and `unsubscribe` + /// commands may be received from the client and the list of subscriptions + /// are updated accordingly. + /// + /// [here]: https://redis.io/topics/pubsub + pub(crate) async fn apply( + mut self, + db: &Db, + dst: &mut Connection, + shutdown: &mut Shutdown, + ) -> crate::Result<()> { + // Each individual channel subscription is handled using a + // `sync::broadcast` channel. Messages are then fanned out to all + // clients currently subscribed to the channels. + // + // An individual client may subscribe to multiple channels and may + // dynamically add and remove channels from its subscription set. To + // handle this, a `StreamMap` is used to track active subscriptions. The + // `StreamMap` merges messages from individual broadcast channels as + // they are received. + let mut subscriptions = StreamMap::new(); + + loop { + // `self.channels` is used to track additional channels to subscribe + // to. When new `SUBSCRIBE` commands are received during the + // execution of `apply`, the new channels are pushed onto this vec. + for channel_name in self.channels.drain(..) { + subscribe_to_channel(channel_name, &mut subscriptions, db, dst).await?; + } + + // Wait for one of the following to happen: + // + // - Receive a message from one of the subscribed channels. + // - Receive a subscribe or unsubscribe command from the client. + // - A server shutdown signal. + select! { + // Receive messages from subscribed channels + Some((channel_name, msg)) = subscriptions.next() => { + dst.write_frame(&make_message_frame(channel_name, msg)).await?; + }, + res = dst.read_frame() => { + let frame = match res? { + Some(frame) => frame, + // This happens if the remote client has disconnected. + None => return Ok(()) + }; + + handle_command( + frame, + &mut self.channels, + &mut subscriptions, + dst, + ).await?; + }, + _ = shutdown.recv() => { + return Ok(()); + } + }; + } + } + + /// Converts the command into an equivalent `Frame`. + /// + /// This is called by the client when encoding a `Subscribe` command to send + /// to the server. + pub(crate) fn into_frame(self) -> Frame { + let mut frame = Frame::array(); + frame.push_bulk(Bytes::from("subscribe".as_bytes())); + for channel in self.channels { + frame.push_bulk(Bytes::from(channel.into_bytes())); + } + frame + } +} + +async fn subscribe_to_channel( + channel_name: String, + subscriptions: &mut StreamMap, + db: &Db, + dst: &mut Connection, +) -> crate::Result<()> { + let mut rx = db.subscribe(channel_name.clone()); + + // Subscribe to the channel. + let rx = Box::pin(async_stream::stream! { + loop { + match rx.recv().await { + Ok(msg) => yield msg, + // If we lagged in consuming messages, just resume. + Err(broadcast::error::RecvError::Lagged(_)) => {} + Err(_) => break, + } + } + }); + + // Track subscription in this client's subscription set. + subscriptions.insert(channel_name.clone(), rx); + + // Respond with the successful subscription + let response = make_subscribe_frame(channel_name, subscriptions.len()); + dst.write_frame(&response).await?; + + Ok(()) +} + +/// Handle a command received while inside `Subscribe::apply`. Only subscribe +/// and unsubscribe commands are permitted in this context. +/// +/// Any new subscriptions are appended to `subscribe_to` instead of modifying +/// `subscriptions`. +async fn handle_command( + frame: Frame, + subscribe_to: &mut Vec, + subscriptions: &mut StreamMap, + dst: &mut Connection, +) -> crate::Result<()> { + // A command has been received from the client. + // + // Only `SUBSCRIBE` and `UNSUBSCRIBE` commands are permitted + // in this context. + match Command::from_frame(frame)? { + Command::Subscribe(subscribe) => { + // The `apply` method will subscribe to the channels we add to this + // vector. + subscribe_to.extend(subscribe.channels.into_iter()); + } + Command::Unsubscribe(mut unsubscribe) => { + // If no channels are specified, this requests unsubscribing from + // **all** channels. To implement this, the `unsubscribe.channels` + // vec is populated with the list of channels currently subscribed + // to. + if unsubscribe.channels.is_empty() { + unsubscribe.channels = subscriptions + .keys() + .map(|channel_name| channel_name.to_string()) + .collect(); + } + + for channel_name in unsubscribe.channels { + subscriptions.remove(&channel_name); + + let response = make_unsubscribe_frame(channel_name, subscriptions.len()); + dst.write_frame(&response).await?; + } + } + command => { + let cmd = Unknown::new(command.get_name()); + cmd.apply(dst).await?; + } + } + Ok(()) +} + +/// Creates the response to a subcribe request. +/// +/// All of these functions take the `channel_name` as a `String` instead of +/// a `&str` since `Bytes::from` can reuse the allocation in the `String`, and +/// taking a `&str` would require copying the data. This allows the caller to +/// decide whether to clone the channel name or not. +fn make_subscribe_frame(channel_name: String, num_subs: usize) -> Frame { + let mut response = Frame::array(); + response.push_bulk(Bytes::from_static(b"subscribe")); + response.push_bulk(Bytes::from(channel_name)); + response.push_int(num_subs as u64); + response +} + +/// Creates the response to an unsubcribe request. +fn make_unsubscribe_frame(channel_name: String, num_subs: usize) -> Frame { + let mut response = Frame::array(); + response.push_bulk(Bytes::from_static(b"unsubscribe")); + response.push_bulk(Bytes::from(channel_name)); + response.push_int(num_subs as u64); + response +} + +/// Creates a message informing the client about a new message on a channel that +/// the client subscribes to. +fn make_message_frame(channel_name: String, msg: Bytes) -> Frame { + let mut response = Frame::array(); + response.push_bulk(Bytes::from_static(b"message")); + response.push_bulk(Bytes::from(channel_name)); + response.push_bulk(msg); + response +} + +impl Unsubscribe { + /// Create a new `Unsubscribe` command with the given `channels`. + pub(crate) fn new(channels: &[String]) -> Unsubscribe { + Unsubscribe { + channels: channels.to_vec(), + } + } + + /// Parse a `Unsubscribe` instance from a received frame. + /// + /// The `Parse` argument provides a cursor-like API to read fields from the + /// `Frame`. At this point, the entire frame has already been received from + /// the socket. + /// + /// The `UNSUBSCRIBE` string has already been consumed. + /// + /// # Returns + /// + /// On success, the `Unsubscribe` value is returned. If the frame is + /// malformed, `Err` is returned. + /// + /// # Format + /// + /// Expects an array frame containing at least one entry. + /// + /// ```text + /// UNSUBSCRIBE [channel [channel ...]] + /// ``` + pub(crate) fn parse_frames(parse: &mut Parse) -> Result { + use ParseError::EndOfStream; + + // There may be no channels listed, so start with an empty vec. + let mut channels = vec![]; + + // Each entry in the frame must be a string or the frame is malformed. + // Once all values in the frame have been consumed, the command is fully + // parsed. + loop { + match parse.next_string() { + // A string has been consumed from the `parse`, push it into the + // list of channels to unsubscribe from. + Ok(s) => channels.push(s), + // The `EndOfStream` error indicates there is no further data to + // parse. + Err(EndOfStream) => break, + // All other errors are bubbled up, resulting in the connection + // being terminated. + Err(err) => return Err(err), + } + } + + Ok(Unsubscribe { channels }) + } + + /// Converts the command into an equivalent `Frame`. + /// + /// This is called by the client when encoding an `Unsubscribe` command to + /// send to the server. + pub(crate) fn into_frame(self) -> Frame { + let mut frame = Frame::array(); + frame.push_bulk(Bytes::from("unsubscribe".as_bytes())); + + for channel in self.channels { + frame.push_bulk(Bytes::from(channel.into_bytes())); + } + + frame + } +} diff --git a/assets/mini-redis/src/cmd/unknown.rs b/assets/mini-redis/src/cmd/unknown.rs new file mode 100644 index 0000000..25f869a --- /dev/null +++ b/assets/mini-redis/src/cmd/unknown.rs @@ -0,0 +1,37 @@ +use crate::{Connection, Frame}; + +use tracing::{debug, instrument}; + +/// Represents an "unknown" command. This is not a real `Redis` command. +#[derive(Debug)] +pub struct Unknown { + command_name: String, +} + +impl Unknown { + /// Create a new `Unknown` command which responds to unknown commands + /// issued by clients + pub(crate) fn new(key: impl ToString) -> Unknown { + Unknown { + command_name: key.to_string(), + } + } + + /// Returns the command name + pub(crate) fn get_name(&self) -> &str { + &self.command_name + } + + /// Responds to the client, indicating the command is not recognized. + /// + /// This usually means the command is not yet implemented by `mini-redis`. + #[instrument(skip(self, dst))] + pub(crate) async fn apply(self, dst: &mut Connection) -> crate::Result<()> { + let response = Frame::Error(format!("ERR unknown command '{}'", self.command_name)); + + debug!(?response); + + dst.write_frame(&response).await?; + Ok(()) + } +} diff --git a/assets/mini-redis/src/connection.rs b/assets/mini-redis/src/connection.rs new file mode 100644 index 0000000..cf3c296 --- /dev/null +++ b/assets/mini-redis/src/connection.rs @@ -0,0 +1,237 @@ +use crate::frame::{self, Frame}; + +use bytes::{Buf, BytesMut}; +use std::io::{self, Cursor}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; +use tokio::net::TcpStream; + +/// Send and receive `Frame` values from a remote peer. +/// +/// When implementing networking protocols, a message on that protocol is +/// often composed of several smaller messages known as frames. The purpose of +/// `Connection` is to read and write frames on the underlying `TcpStream`. +/// +/// To read frames, the `Connection` uses an internal buffer, which is filled +/// up until there are enough bytes to create a full frame. Once this happens, +/// the `Connection` creates the frame and returns it to the caller. +/// +/// When sending frames, the frame is first encoded into the write buffer. +/// The contents of the write buffer are then written to the socket. +#[derive(Debug)] +pub struct Connection { + // The `TcpStream`. It is decorated with a `BufWriter`, which provides write + // level buffering. The `BufWriter` implementation provided by Tokio is + // sufficient for our needs. + stream: BufWriter, + + // The buffer for reading frames. + buffer: BytesMut, +} + +impl Connection { + /// Create a new `Connection`, backed by `socket`. Read and write buffers + /// are initialized. + pub fn new(socket: TcpStream) -> Connection { + Connection { + stream: BufWriter::new(socket), + // Default to a 4KB read buffer. For the use case of mini redis, + // this is fine. However, real applications will want to tune this + // value to their specific use case. There is a high likelihood that + // a larger read buffer will work better. + buffer: BytesMut::with_capacity(4 * 1024), + } + } + + /// Read a single `Frame` value from the underlying stream. + /// + /// The function waits until it has retrieved enough data to parse a frame. + /// Any data remaining in the read buffer after the frame has been parsed is + /// kept there for the next call to `read_frame`. + /// + /// # Returns + /// + /// On success, the received frame is returned. If the `TcpStream` + /// is closed in a way that doesn't break a frame in half, it returns + /// `None`. Otherwise, an error is returned. + pub async fn read_frame(&mut self) -> crate::Result> { + loop { + // Attempt to parse a frame from the buffered data. If enough data + // has been buffered, the frame is returned. + if let Some(frame) = self.parse_frame()? { + return Ok(Some(frame)); + } + + // There is not enough buffered data to read a frame. Attempt to + // read more data from the socket. + // + // On success, the number of bytes is returned. `0` indicates "end + // of stream". + if 0 == self.stream.read_buf(&mut self.buffer).await? { + // The remote closed the connection. For this to be a clean + // shutdown, there should be no data in the read buffer. If + // there is, this means that the peer closed the socket while + // sending a frame. + if self.buffer.is_empty() { + return Ok(None); + } else { + let s = "connection reset by peer".into(); + return Err(s); + } + } + } + } + + /// Tries to parse a frame from the buffer. If the buffer contains enough + /// data, the frame is returned and the data removed from the buffer. If not + /// enough data has been buffered yet, `Ok(None)` is returned. If the + /// buffered data does not represent a valid frame, `Err` is returned. + fn parse_frame(&mut self) -> crate::Result> { + use frame::Error::Incomplete; + + // Cursor is used to track the "current" location in the + // buffer. Cursor also implements `Buf` from the `bytes` crate + // which provides a number of helpful utilities for working + // with bytes. + let mut buf = Cursor::new(&self.buffer[..]); + + // The first step is to check if enough data has been buffered to parse + // a single frame. This step is usually much faster than doing a full + // parse of the frame, and allows us to skip allocating data structures + // to hold the frame data unless we know the full frame has been + // received. + match Frame::check(&mut buf) { + Ok(_) => { + // The `check` function will have advanced the cursor until the + // end of the frame. Since the cursor had position set to zero + // before `Frame::check` was called, we obtain the length of the + // frame by checking the cursor position. + let len = buf.position() as usize; + + // Reset the position to zero before passing the cursor to + // `Frame::parse`. + buf.set_position(0); + + // Parse the frame from the buffer. This allocates the necessary + // structures to represent the frame and returns the frame + // value. + // + // If the encoded frame representation is invalid, an error is + // returned. This should terminate the **current** connection + // but should not impact any other connected client. + let frame = Frame::parse(&mut buf)?; + + // Discard the parsed data from the read buffer. + // + // When `advance` is called on the read buffer, all of the data + // up to `len` is discarded. The details of how this works is + // left to `BytesMut`. This is often done by moving an internal + // cursor, but it may be done by reallocating and copying data. + self.buffer.advance(len); + + // Return the parsed frame to the caller. + Ok(Some(frame)) + } + // There is not enough data present in the read buffer to parse a + // single frame. We must wait for more data to be received from the + // socket. Reading from the socket will be done in the statement + // after this `match`. + // + // We do not want to return `Err` from here as this "error" is an + // expected runtime condition. + Err(Incomplete) => Ok(None), + // An error was encountered while parsing the frame. The connection + // is now in an invalid state. Returning `Err` from here will result + // in the connection being closed. + Err(e) => Err(e.into()), + } + } + + /// Write a single `Frame` value to the underlying stream. + /// + /// The `Frame` value is written to the socket using the various `write_*` + /// functions provided by `AsyncWrite`. Calling these functions directly on + /// a `TcpStream` is **not** advised, as this will result in a large number of + /// syscalls. However, it is fine to call these functions on a *buffered* + /// write stream. The data will be written to the buffer. Once the buffer is + /// full, it is flushed to the underlying socket. + pub async fn write_frame(&mut self, frame: &Frame) -> io::Result<()> { + // Arrays are encoded by encoding each entry. All other frame types are + // considered literals. For now, mini-redis is not able to encode + // recursive frame structures. See below for more details. + match frame { + Frame::Array(val) => { + // Encode the frame type prefix. For an array, it is `*`. + self.stream.write_u8(b'*').await?; + + // Encode the length of the array. + self.write_decimal(val.len() as u64).await?; + + // Iterate and encode each entry in the array. + for entry in &**val { + self.write_value(entry).await?; + } + } + // The frame type is a literal. Encode the value directly. + _ => self.write_value(frame).await?, + } + + // Ensure the encoded frame is written to the socket. The calls above + // are to the buffered stream and writes. Calling `flush` writes the + // remaining contents of the buffer to the socket. + self.stream.flush().await + } + + /// Write a frame literal to the stream + async fn write_value(&mut self, frame: &Frame) -> io::Result<()> { + match frame { + Frame::Simple(val) => { + self.stream.write_u8(b'+').await?; + self.stream.write_all(val.as_bytes()).await?; + self.stream.write_all(b"\r\n").await?; + } + Frame::Error(val) => { + self.stream.write_u8(b'-').await?; + self.stream.write_all(val.as_bytes()).await?; + self.stream.write_all(b"\r\n").await?; + } + Frame::Integer(val) => { + self.stream.write_u8(b':').await?; + self.write_decimal(*val).await?; + } + Frame::Null => { + self.stream.write_all(b"$-1\r\n").await?; + } + Frame::Bulk(val) => { + let len = val.len(); + + self.stream.write_u8(b'$').await?; + self.write_decimal(len as u64).await?; + self.stream.write_all(val).await?; + self.stream.write_all(b"\r\n").await?; + } + // Encoding an `Array` from within a value cannot be done using a + // recursive strategy. In general, async fns do not support + // recursion. Mini-redis has not needed to encode nested arrays yet, + // so for now it is skipped. + Frame::Array(_val) => unreachable!(), + } + + Ok(()) + } + + /// Write a decimal frame to the stream + async fn write_decimal(&mut self, val: u64) -> io::Result<()> { + use std::io::Write; + + // Convert the value to a string + let mut buf = [0u8; 20]; + let mut buf = Cursor::new(&mut buf[..]); + write!(&mut buf, "{}", val)?; + + let pos = buf.position() as usize; + self.stream.write_all(&buf.get_ref()[..pos]).await?; + self.stream.write_all(b"\r\n").await?; + + Ok(()) + } +} diff --git a/assets/mini-redis/src/db.rs b/assets/mini-redis/src/db.rs new file mode 100644 index 0000000..07e33a2 --- /dev/null +++ b/assets/mini-redis/src/db.rs @@ -0,0 +1,378 @@ +use tokio::sync::{broadcast, Notify}; +use tokio::time::{self, Duration, Instant}; + +use bytes::Bytes; +use std::collections::{BTreeMap, HashMap}; +use std::sync::{Arc, Mutex}; +use tracing::debug; + +/// A wrapper around a `Db` instance. This exists to allow orderly cleanup +/// of the `Db` by signalling the background purge task to shut down when +/// this struct is dropped. +#[derive(Debug)] +pub(crate) struct DbDropGuard { + /// The `Db` instance that will be shut down when this `DbHolder` struct + /// is dropped. + db: Db, +} + +/// Server state shared across all connections. +/// +/// `Db` contains a `HashMap` storing the key/value data and all +/// `broadcast::Sender` values for active pub/sub channels. +/// +/// A `Db` instance is a handle to shared state. Cloning `Db` is shallow and +/// only incurs an atomic ref count increment. +/// +/// When a `Db` value is created, a background task is spawned. This task is +/// used to expire values after the requested duration has elapsed. The task +/// runs until all instances of `Db` are dropped, at which point the task +/// terminates. +#[derive(Debug, Clone)] +pub(crate) struct Db { + /// Handle to shared state. The background task will also have an + /// `Arc`. + shared: Arc, +} + +#[derive(Debug)] +struct Shared { + /// The shared state is guarded by a mutex. This is a `std::sync::Mutex` and + /// not a Tokio mutex. This is because there are no asynchronous operations + /// being performed while holding the mutex. Additionally, the critical + /// sections are very small. + /// + /// A Tokio mutex is mostly intended to be used when locks need to be held + /// across `.await` yield points. All other cases are **usually** best + /// served by a std mutex. If the critical section does not include any + /// async operations but is long (CPU intensive or performing blocking + /// operations), then the entire operation, including waiting for the mutex, + /// is considered a "blocking" operation and `tokio::task::spawn_blocking` + /// should be used. + state: Mutex, + + /// Notifies the background task handling entry expiration. The background + /// task waits on this to be notified, then checks for expired values or the + /// shutdown signal. + background_task: Notify, +} + +#[derive(Debug)] +struct State { + /// The key-value data. We are not trying to do anything fancy so a + /// `std::collections::HashMap` works fine. + entries: HashMap, + + /// The pub/sub key-space. Redis uses a **separate** key space for key-value + /// and pub/sub. `mini-redis` handles this by using a separate `HashMap`. + pub_sub: HashMap>, + + /// Tracks key TTLs. + /// + /// A `BTreeMap` is used to maintain expirations sorted by when they expire. + /// This allows the background task to iterate this map to find the value + /// expiring next. + /// + /// While highly unlikely, it is possible for more than one expiration to be + /// created for the same instant. Because of this, the `Instant` is + /// insufficient for the key. A unique expiration identifier (`u64`) is used + /// to break these ties. + expirations: BTreeMap<(Instant, u64), String>, + + /// Identifier to use for the next expiration. Each expiration is associated + /// with a unique identifier. See above for why. + next_id: u64, + + /// True when the Db instance is shutting down. This happens when all `Db` + /// values drop. Setting this to `true` signals to the background task to + /// exit. + shutdown: bool, +} + +/// Entry in the key-value store +#[derive(Debug)] +struct Entry { + /// Uniquely identifies this entry. + id: u64, + + /// Stored data + data: Bytes, + + /// Instant at which the entry expires and should be removed from the + /// database. + expires_at: Option, +} + +impl DbDropGuard { + /// Create a new `DbHolder`, wrapping a `Db` instance. When this is dropped + /// the `Db`'s purge task will be shut down. + pub(crate) fn new() -> DbDropGuard { + DbDropGuard { db: Db::new() } + } + + /// Get the shared database. Internally, this is an + /// `Arc`, so a clone only increments the ref count. + pub(crate) fn db(&self) -> Db { + self.db.clone() + } +} + +impl Drop for DbDropGuard { + fn drop(&mut self) { + // Signal the 'Db' instance to shut down the task that purges expired keys + self.db.shutdown_purge_task(); + } +} + +impl Db { + /// Create a new, empty, `Db` instance. Allocates shared state and spawns a + /// background task to manage key expiration. + pub(crate) fn new() -> Db { + let shared = Arc::new(Shared { + state: Mutex::new(State { + entries: HashMap::new(), + pub_sub: HashMap::new(), + expirations: BTreeMap::new(), + next_id: 0, + shutdown: false, + }), + background_task: Notify::new(), + }); + + // Start the background task. + tokio::spawn(purge_expired_tasks(shared.clone())); + + Db { shared } + } + + /// Get the value associated with a key. + /// + /// Returns `None` if there is no value associated with the key. This may be + /// due to never having assigned a value to the key or a previously assigned + /// value expired. + pub(crate) fn get(&self, key: &str) -> Option { + // Acquire the lock, get the entry and clone the value. + // + // Because data is stored using `Bytes`, a clone here is a shallow + // clone. Data is not copied. + let state = self.shared.state.lock().unwrap(); + state.entries.get(key).map(|entry| entry.data.clone()) + } + + /// Set the value associated with a key along with an optional expiration + /// Duration. + /// + /// If a value is already associated with the key, it is removed. + pub(crate) fn set(&self, key: String, value: Bytes, expire: Option) { + let mut state = self.shared.state.lock().unwrap(); + + // Get and increment the next insertion ID. Guarded by the lock, this + // ensures a unique identifier is associated with each `set` operation. + let id = state.next_id; + state.next_id += 1; + + // If this `set` becomes the key that expires **next**, the background + // task needs to be notified so it can update its state. + // + // Whether or not the task needs to be notified is computed during the + // `set` routine. + let mut notify = false; + + let expires_at = expire.map(|duration| { + // `Instant` at which the key expires. + let when = Instant::now() + duration; + + // Only notify the worker task if the newly inserted expiration is the + // **next** key to evict. In this case, the worker needs to be woken up + // to update its state. + notify = state + .next_expiration() + .map(|expiration| expiration > when) + .unwrap_or(true); + + // Track the expiration. + state.expirations.insert((when, id), key.clone()); + when + }); + + // Insert the entry into the `HashMap`. + let prev = state.entries.insert( + key, + Entry { + id, + data: value, + expires_at, + }, + ); + + // If there was a value previously associated with the key **and** it + // had an expiration time. The associated entry in the `expirations` map + // must also be removed. This avoids leaking data. + if let Some(prev) = prev { + if let Some(when) = prev.expires_at { + // clear expiration + state.expirations.remove(&(when, prev.id)); + } + } + + // Release the mutex before notifying the background task. This helps + // reduce contention by avoiding the background task waking up only to + // be unable to acquire the mutex due to this function still holding it. + drop(state); + + if notify { + // Finally, only notify the background task if it needs to update + // its state to reflect a new expiration. + self.shared.background_task.notify_one(); + } + } + + /// Returns a `Receiver` for the requested channel. + /// + /// The returned `Receiver` is used to receive values broadcast by `PUBLISH` + /// commands. + pub(crate) fn subscribe(&self, key: String) -> broadcast::Receiver { + use std::collections::hash_map::Entry; + + // Acquire the mutex + let mut state = self.shared.state.lock().unwrap(); + + // If there is no entry for the requested channel, then create a new + // broadcast channel and associate it with the key. If one already + // exists, return an associated receiver. + match state.pub_sub.entry(key) { + Entry::Occupied(e) => e.get().subscribe(), + Entry::Vacant(e) => { + // No broadcast channel exists yet, so create one. + // + // The channel is created with a capacity of `1024` messages. A + // message is stored in the channel until **all** subscribers + // have seen it. This means that a slow subscriber could result + // in messages being held indefinitely. + // + // When the channel's capacity fills up, publishing will result + // in old messages being dropped. This prevents slow consumers + // from blocking the entire system. + let (tx, rx) = broadcast::channel(1024); + e.insert(tx); + rx + } + } + } + + /// Publish a message to the channel. Returns the number of subscribers + /// listening on the channel. + pub(crate) fn publish(&self, key: &str, value: Bytes) -> usize { + let state = self.shared.state.lock().unwrap(); + + state + .pub_sub + .get(key) + // On a successful message send on the broadcast channel, the number + // of subscribers is returned. An error indicates there are no + // receivers, in which case, `0` should be returned. + .map(|tx| tx.send(value).unwrap_or(0)) + // If there is no entry for the channel key, then there are no + // subscribers. In this case, return `0`. + .unwrap_or(0) + } + + /// Signals the purge background task to shut down. This is called by the + /// `DbShutdown`s `Drop` implementation. + fn shutdown_purge_task(&self) { + // The background task must be signaled to shut down. This is done by + // setting `State::shutdown` to `true` and signalling the task. + let mut state = self.shared.state.lock().unwrap(); + state.shutdown = true; + + // Drop the lock before signalling the background task. This helps + // reduce lock contention by ensuring the background task doesn't + // wake up only to be unable to acquire the mutex. + drop(state); + self.shared.background_task.notify_one(); + } +} + +impl Shared { + /// Purge all expired keys and return the `Instant` at which the **next** + /// key will expire. The background task will sleep until this instant. + fn purge_expired_keys(&self) -> Option { + let mut state = self.state.lock().unwrap(); + + if state.shutdown { + // The database is shutting down. All handles to the shared state + // have dropped. The background task should exit. + return None; + } + + // This is needed to make the borrow checker happy. In short, `lock()` + // returns a `MutexGuard` and not a `&mut State`. The borrow checker is + // not able to see "through" the mutex guard and determine that it is + // safe to access both `state.expirations` and `state.entries` mutably, + // so we get a "real" mutable reference to `State` outside of the loop. + let state = &mut *state; + + // Find all keys scheduled to expire **before** now. + let now = Instant::now(); + + while let Some((&(when, id), key)) = state.expirations.iter().next() { + if when > now { + // Done purging, `when` is the instant at which the next key + // expires. The worker task will wait until this instant. + return Some(when); + } + + // The key expired, remove it + state.entries.remove(key); + state.expirations.remove(&(when, id)); + } + + None + } + + /// Returns `true` if the database is shutting down + /// + /// The `shutdown` flag is set when all `Db` values have dropped, indicating + /// that the shared state can no longer be accessed. + fn is_shutdown(&self) -> bool { + self.state.lock().unwrap().shutdown + } +} + +impl State { + fn next_expiration(&self) -> Option { + self.expirations + .keys() + .next() + .map(|expiration| expiration.0) + } +} + +/// Routine executed by the background task. +/// +/// Wait to be notified. On notification, purge any expired keys from the shared +/// state handle. If `shutdown` is set, terminate the task. +async fn purge_expired_tasks(shared: Arc) { + // If the shutdown flag is set, then the task should exit. + while !shared.is_shutdown() { + // Purge all keys that are expired. The function returns the instant at + // which the **next** key will expire. The worker should wait until the + // instant has passed then purge again. + if let Some(when) = shared.purge_expired_keys() { + // Wait until the next key expires **or** until the background task + // is notified. If the task is notified, then it must reload its + // state as new keys have been set to expire early. This is done by + // looping. + tokio::select! { + _ = time::sleep_until(when) => {} + _ = shared.background_task.notified() => {} + } + } else { + // There are no keys expiring in the future. Wait until the task is + // notified. + shared.background_task.notified().await; + } + } + + debug!("Purge background task shut down") +} diff --git a/assets/mini-redis/src/frame.rs b/assets/mini-redis/src/frame.rs new file mode 100644 index 0000000..6b26719 --- /dev/null +++ b/assets/mini-redis/src/frame.rs @@ -0,0 +1,300 @@ +//! Provides a type representing a Redis protocol frame as well as utilities for +//! parsing frames from a byte array. + +use bytes::{Buf, Bytes}; +use std::convert::TryInto; +use std::fmt; +use std::io::Cursor; +use std::num::TryFromIntError; +use std::string::FromUtf8Error; + +/// A frame in the Redis protocol. +#[derive(Clone, Debug)] +pub enum Frame { + Simple(String), + Error(String), + Integer(u64), + Bulk(Bytes), + Null, + Array(Vec), +} + +#[derive(Debug)] +pub enum Error { + /// Not enough data is available to parse a message + Incomplete, + + /// Invalid message encoding + Other(crate::Error), +} + +impl Frame { + /// Returns an empty array + pub(crate) fn array() -> Frame { + Frame::Array(vec![]) + } + + /// Push a "bulk" frame into the array. `self` must be an Array frame. + /// + /// # Panics + /// + /// panics if `self` is not an array + pub(crate) fn push_bulk(&mut self, bytes: Bytes) { + match self { + Frame::Array(vec) => { + vec.push(Frame::Bulk(bytes)); + } + _ => panic!("not an array frame"), + } + } + + /// Push an "integer" frame into the array. `self` must be an Array frame. + /// + /// # Panics + /// + /// panics if `self` is not an array + pub(crate) fn push_int(&mut self, value: u64) { + match self { + Frame::Array(vec) => { + vec.push(Frame::Integer(value)); + } + _ => panic!("not an array frame"), + } + } + + /// Checks if an entire message can be decoded from `src` + pub fn check(src: &mut Cursor<&[u8]>) -> Result<(), Error> { + match get_u8(src)? { + b'+' => { + get_line(src)?; + Ok(()) + } + b'-' => { + get_line(src)?; + Ok(()) + } + b':' => { + let _ = get_decimal(src)?; + Ok(()) + } + b'$' => { + if b'-' == peek_u8(src)? { + // Skip '-1\r\n' + skip(src, 4) + } else { + // Read the bulk string + let len: usize = get_decimal(src)?.try_into()?; + + // skip that number of bytes + 2 (\r\n). + skip(src, len + 2) + } + } + b'*' => { + let len = get_decimal(src)?; + + for _ in 0..len { + Frame::check(src)?; + } + + Ok(()) + } + actual => Err(format!("protocol error; invalid frame type byte `{}`", actual).into()), + } + } + + /// The message has already been validated with `check`. + pub fn parse(src: &mut Cursor<&[u8]>) -> Result { + match get_u8(src)? { + b'+' => { + // Read the line and convert it to `Vec` + let line = get_line(src)?.to_vec(); + + // Convert the line to a String + let string = String::from_utf8(line)?; + + Ok(Frame::Simple(string)) + } + b'-' => { + // Read the line and convert it to `Vec` + let line = get_line(src)?.to_vec(); + + // Convert the line to a String + let string = String::from_utf8(line)?; + + Ok(Frame::Error(string)) + } + b':' => { + let len = get_decimal(src)?; + Ok(Frame::Integer(len)) + } + b'$' => { + if b'-' == peek_u8(src)? { + let line = get_line(src)?; + + if line != b"-1" { + return Err("protocol error; invalid frame format".into()); + } + + Ok(Frame::Null) + } else { + // Read the bulk string + let len = get_decimal(src)?.try_into()?; + let n = len + 2; + + if src.remaining() < n { + return Err(Error::Incomplete); + } + + let data = Bytes::copy_from_slice(&src.chunk()[..len]); + + // skip that number of bytes + 2 (\r\n). + skip(src, n)?; + + Ok(Frame::Bulk(data)) + } + } + b'*' => { + let len = get_decimal(src)?.try_into()?; + let mut out = Vec::with_capacity(len); + + for _ in 0..len { + out.push(Frame::parse(src)?); + } + + Ok(Frame::Array(out)) + } + _ => unimplemented!(), + } + } + + /// Converts the frame to an "unexpected frame" error + pub(crate) fn to_error(&self) -> crate::Error { + format!("unexpected frame: {}", self).into() + } +} + +impl PartialEq<&str> for Frame { + fn eq(&self, other: &&str) -> bool { + match self { + Frame::Simple(s) => s.eq(other), + Frame::Bulk(s) => s.eq(other), + _ => false, + } + } +} + +impl fmt::Display for Frame { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + use std::str; + + match self { + Frame::Simple(response) => response.fmt(fmt), + Frame::Error(msg) => write!(fmt, "error: {}", msg), + Frame::Integer(num) => num.fmt(fmt), + Frame::Bulk(msg) => match str::from_utf8(msg) { + Ok(string) => string.fmt(fmt), + Err(_) => write!(fmt, "{:?}", msg), + }, + Frame::Null => "(nil)".fmt(fmt), + Frame::Array(parts) => { + for (i, part) in parts.iter().enumerate() { + if i > 0 { + write!(fmt, " ")?; + part.fmt(fmt)?; + } + } + + Ok(()) + } + } + } +} + +fn peek_u8(src: &mut Cursor<&[u8]>) -> Result { + if !src.has_remaining() { + return Err(Error::Incomplete); + } + + Ok(src.chunk()[0]) +} + +fn get_u8(src: &mut Cursor<&[u8]>) -> Result { + if !src.has_remaining() { + return Err(Error::Incomplete); + } + + Ok(src.get_u8()) +} + +fn skip(src: &mut Cursor<&[u8]>, n: usize) -> Result<(), Error> { + if src.remaining() < n { + return Err(Error::Incomplete); + } + + src.advance(n); + Ok(()) +} + +/// Read a new-line terminated decimal +fn get_decimal(src: &mut Cursor<&[u8]>) -> Result { + use atoi::atoi; + + let line = get_line(src)?; + + atoi::(line).ok_or_else(|| "protocol error; invalid frame format".into()) +} + +/// Find a line +fn get_line<'a>(src: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], Error> { + // Scan the bytes directly + let start = src.position() as usize; + // Scan to the second to last byte + let end = src.get_ref().len() - 1; + + for i in start..end { + if src.get_ref()[i] == b'\r' && src.get_ref()[i + 1] == b'\n' { + // We found a line, update the position to be *after* the \n + src.set_position((i + 2) as u64); + + // Return the line + return Ok(&src.get_ref()[start..i]); + } + } + + Err(Error::Incomplete) +} + +impl From for Error { + fn from(src: String) -> Error { + Error::Other(src.into()) + } +} + +impl From<&str> for Error { + fn from(src: &str) -> Error { + src.to_string().into() + } +} + +impl From for Error { + fn from(_src: FromUtf8Error) -> Error { + "protocol error; invalid frame format".into() + } +} + +impl From for Error { + fn from(_src: TryFromIntError) -> Error { + "protocol error; invalid frame format".into() + } +} + +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Incomplete => "stream ended early".fmt(fmt), + Error::Other(err) => err.fmt(fmt), + } + } +} diff --git a/assets/mini-redis/src/lib.rs b/assets/mini-redis/src/lib.rs new file mode 100644 index 0000000..48c472b --- /dev/null +++ b/assets/mini-redis/src/lib.rs @@ -0,0 +1,76 @@ +//! A minimal (i.e. very incomplete) implementation of a Redis server and +//! client. +//! +//! The purpose of this project is to provide a larger example of an +//! asynchronous Rust project built with Tokio. Do not attempt to run this in +//! production... seriously. +//! +//! # Layout +//! +//! The library is structured such that it can be used with guides. There are +//! modules that are public that probably would not be public in a "real" redis +//! client library. +//! +//! The major components are: +//! +//! * `server`: Redis server implementation. Includes a single `run` function +//! that takes a `TcpListener` and starts accepting redis client connections. +//! +//! * `client`: an asynchronous Redis client implementation. Demonstrates how to +//! build clients with Tokio. +//! +//! * `cmd`: implementations of the supported Redis commands. +//! +//! * `frame`: represents a single Redis protocol frame. A frame is used as an +//! intermediate representation between a "command" and the byte +//! representation. + +pub mod blocking_client; +pub mod client; + +pub mod cmd; +pub use cmd::Command; + +mod connection; +pub use connection::Connection; + +pub mod frame; +pub use frame::Frame; + +mod db; +use db::Db; +use db::DbDropGuard; + +mod parse; +use parse::{Parse, ParseError}; + +pub mod server; + +mod buffer; +pub use buffer::{buffer, Buffer}; + +mod shutdown; +use shutdown::Shutdown; + +/// Default port that a redis server listens on. +/// +/// Used if no port is specified. +pub const DEFAULT_PORT: &str = "6379"; + +/// Error returned by most functions. +/// +/// When writing a real application, one might want to consider a specialized +/// error handling crate or defining an error type as an `enum` of causes. +/// However, for our example, using a boxed `std::error::Error` is sufficient. +/// +/// For performance reasons, boxing is avoided in any hot path. For example, in +/// `parse`, a custom error `enum` is defined. This is because the error is hit +/// and handled during normal execution when a partial frame is received on a +/// socket. `std::error::Error` is implemented for `parse::Error` which allows +/// it to be converted to `Box`. +pub type Error = Box; + +/// A specialized `Result` type for mini-redis operations. +/// +/// This is defined as a convenience. +pub type Result = std::result::Result; diff --git a/assets/mini-redis/src/parse.rs b/assets/mini-redis/src/parse.rs new file mode 100644 index 0000000..5cb0d7e --- /dev/null +++ b/assets/mini-redis/src/parse.rs @@ -0,0 +1,152 @@ +use crate::Frame; + +use bytes::Bytes; +use std::{fmt, str, vec}; + +/// Utility for parsing a command +/// +/// Commands are represented as array frames. Each entry in the frame is a +/// "token". A `Parse` is initialized with the array frame and provides a +/// cursor-like API. Each command struct includes a `parse_frame` method that +/// uses a `Parse` to extract its fields. +#[derive(Debug)] +pub(crate) struct Parse { + /// Array frame iterator. + parts: vec::IntoIter, +} + +/// Error encountered while parsing a frame. +/// +/// Only `EndOfStream` errors are handled at runtime. All other errors result in +/// the connection being terminated. +#[derive(Debug)] +pub(crate) enum ParseError { + /// Attempting to extract a value failed due to the frame being fully + /// consumed. + EndOfStream, + + /// All other errors + Other(crate::Error), +} + +impl Parse { + /// Create a new `Parse` to parse the contents of `frame`. + /// + /// Returns `Err` if `frame` is not an array frame. + pub(crate) fn new(frame: Frame) -> Result { + let array = match frame { + Frame::Array(array) => array, + frame => return Err(format!("protocol error; expected array, got {:?}", frame).into()), + }; + + Ok(Parse { + parts: array.into_iter(), + }) + } + + /// Return the next entry. Array frames are arrays of frames, so the next + /// entry is a frame. + fn next(&mut self) -> Result { + self + .parts + .next() + .ok_or(ParseError::EndOfStream) + } + + /// Return the next entry as a string. + /// + /// If the next entry cannot be represented as a String, then an error is returned. + pub(crate) fn next_string(&mut self) -> Result { + match self.next()? { + // Both `Simple` and `Bulk` representation may be strings. Strings + // are parsed to UTF-8. + // + // While errors are stored as strings, they are considered separate + // types. + Frame::Simple(s) => Ok(s), + Frame::Bulk(data) => str::from_utf8(&data[..]) + .map(|s| s.to_string()) + .map_err(|_| "protocol error; invalid string".into()), + frame => Err(format!( + "protocol error; expected simple frame or bulk frame, got {:?}", + frame + ) + .into()), + } + } + + /// Return the next entry as raw bytes. + /// + /// If the next entry cannot be represented as raw bytes, an error is + /// returned. + pub(crate) fn next_bytes(&mut self) -> Result { + match self.next()? { + // Both `Simple` and `Bulk` representation may be raw bytes. + // + // Although errors are stored as strings and could be represented as + // raw bytes, they are considered separate types. + Frame::Simple(s) => Ok(Bytes::from(s.into_bytes())), + Frame::Bulk(data) => Ok(data), + frame => Err(format!( + "protocol error; expected simple frame or bulk frame, got {:?}", + frame + ) + .into()), + } + } + + /// Return the next entry as an integer. + /// + /// This includes `Simple`, `Bulk`, and `Integer` frame types. `Simple` and + /// `Bulk` frame types are parsed. + /// + /// If the next entry cannot be represented as an integer, then an error is + /// returned. + pub(crate) fn next_int(&mut self) -> Result { + use atoi::atoi; + + const MSG: &str = "protocol error; invalid number"; + + match self.next()? { + // An integer frame type is already stored as an integer. + Frame::Integer(v) => Ok(v), + // Simple and bulk frames must be parsed as integers. If the parsing + // fails, an error is returned. + Frame::Simple(data) => atoi::(data.as_bytes()).ok_or_else(|| MSG.into()), + Frame::Bulk(data) => atoi::(&data).ok_or_else(|| MSG.into()), + frame => Err(format!("protocol error; expected int frame but got {:?}", frame).into()), + } + } + + /// Ensure there are no more entries in the array + pub(crate) fn finish(&mut self) -> Result<(), ParseError> { + if self.parts.next().is_none() { + Ok(()) + } else { + Err("protocol error; expected end of frame, but there was more".into()) + } + } +} + +impl From for ParseError { + fn from(src: String) -> ParseError { + ParseError::Other(src.into()) + } +} + +impl From<&str> for ParseError { + fn from(src: &str) -> ParseError { + src.to_string().into() + } +} + +impl fmt::Display for ParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + ParseError::EndOfStream => "protocol error; unexpected end of stream".fmt(f), + ParseError::Other(err) => err.fmt(f), + } + } +} + +impl std::error::Error for ParseError {} diff --git a/assets/mini-redis/src/server.rs b/assets/mini-redis/src/server.rs new file mode 100644 index 0000000..4728365 --- /dev/null +++ b/assets/mini-redis/src/server.rs @@ -0,0 +1,399 @@ +//! Minimal Redis server implementation +//! +//! Provides an async `run` function that listens for inbound connections, +//! spawning a task per connection. + +use crate::{Command, Connection, Db, DbDropGuard, Shutdown}; + +use std::future::Future; +use std::sync::Arc; +use tokio::net::{TcpListener, TcpStream}; +use tokio::sync::{broadcast, mpsc, Semaphore}; +use tokio::time::{self, Duration}; +use tracing::{debug, error, info, instrument}; + +/// Server listener state. Created in the `run` call. It includes a `run` method +/// which performs the TCP listening and initialization of per-connection state. +#[derive(Debug)] +struct Listener { + /// Shared database handle. + /// + /// Contains the key / value store as well as the broadcast channels for + /// pub/sub. + /// + /// This holds a wrapper around an `Arc`. The internal `Db` can be + /// retrieved and passed into the per connection state (`Handler`). + db_holder: DbDropGuard, + + /// TCP listener supplied by the `run` caller. + listener: TcpListener, + + /// Limit the max number of connections. + /// + /// A `Semaphore` is used to limit the max number of connections. Before + /// attempting to accept a new connection, a permit is acquired from the + /// semaphore. If none are available, the listener waits for one. + /// + /// When handlers complete processing a connection, the permit is returned + /// to the semaphore. + limit_connections: Arc, + + /// Broadcasts a shutdown signal to all active connections. + /// + /// The initial `shutdown` trigger is provided by the `run` caller. The + /// server is responsible for gracefully shutting down active connections. + /// When a connection task is spawned, it is passed a broadcast receiver + /// handle. When a graceful shutdown is initiated, a `()` value is sent via + /// the broadcast::Sender. Each active connection receives it, reaches a + /// safe terminal state, and completes the task. + notify_shutdown: broadcast::Sender<()>, + + /// Used as part of the graceful shutdown process to wait for client + /// connections to complete processing. + /// + /// Tokio channels are closed once all `Sender` handles go out of scope. + /// When a channel is closed, the receiver receives `None`. This is + /// leveraged to detect all connection handlers completing. When a + /// connection handler is initialized, it is assigned a clone of + /// `shutdown_complete_tx`. When the listener shuts down, it drops the + /// sender held by this `shutdown_complete_tx` field. Once all handler tasks + /// complete, all clones of the `Sender` are also dropped. This results in + /// `shutdown_complete_rx.recv()` completing with `None`. At this point, it + /// is safe to exit the server process. + shutdown_complete_rx: mpsc::Receiver<()>, + shutdown_complete_tx: mpsc::Sender<()>, +} + +/// Per-connection handler. Reads requests from `connection` and applies the +/// commands to `db`. +#[derive(Debug)] +struct Handler { + /// Shared database handle. + /// + /// When a command is received from `connection`, it is applied with `db`. + /// The implementation of the command is in the `cmd` module. Each command + /// will need to interact with `db` in order to complete the work. + db: Db, + + /// The TCP connection decorated with the redis protocol encoder / decoder + /// implemented using a buffered `TcpStream`. + /// + /// When `Listener` receives an inbound connection, the `TcpStream` is + /// passed to `Connection::new`, which initializes the associated buffers. + /// `Connection` allows the handler to operate at the "frame" level and keep + /// the byte level protocol parsing details encapsulated in `Connection`. + connection: Connection, + + /// Max connection semaphore. + /// + /// When the handler is dropped, a permit is returned to this semaphore. If + /// the listener is waiting for connections to close, it will be notified of + /// the newly available permit and resume accepting connections. + limit_connections: Arc, + + /// Listen for shutdown notifications. + /// + /// A wrapper around the `broadcast::Receiver` paired with the sender in + /// `Listener`. The connection handler processes requests from the + /// connection until the peer disconnects **or** a shutdown notification is + /// received from `shutdown`. In the latter case, any in-flight work being + /// processed for the peer is continued until it reaches a safe state, at + /// which point the connection is terminated. + shutdown: Shutdown, + + /// Not used directly. Instead, when `Handler` is dropped...? + _shutdown_complete: mpsc::Sender<()>, +} + +/// Maximum number of concurrent connections the redis server will accept. +/// +/// When this limit is reached, the server will stop accepting connections until +/// an active connection terminates. +/// +/// A real application will want to make this value configurable, but for this +/// example, it is hard coded. +/// +/// This is also set to a pretty low value to discourage using this in +/// production (you'd think that all the disclaimers would make it obvious that +/// this is not a serious project... but I thought that about mini-http as +/// well). +const MAX_CONNECTIONS: usize = 250; + +/// Run the mini-redis server. +/// +/// Accepts connections from the supplied listener. For each inbound connection, +/// a task is spawned to handle that connection. The server runs until the +/// `shutdown` future completes, at which point the server shuts down +/// gracefully. +/// +/// `tokio::signal::ctrl_c()` can be used as the `shutdown` argument. This will +/// listen for a SIGINT signal. +pub async fn run(listener: TcpListener, shutdown: impl Future) { + // When the provided `shutdown` future completes, we must send a shutdown + // message to all active connections. We use a broadcast channel for this + // purpose. The call below ignores the receiver of the broadcast pair, and when + // a receiver is needed, the subscribe() method on the sender is used to create + // one. + let (notify_shutdown, _) = broadcast::channel(1); + + let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1); + + // Initialize the listener state + let mut server = Listener { + listener, + db_holder: DbDropGuard::new(), + limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)), + notify_shutdown, + shutdown_complete_tx, + shutdown_complete_rx, + }; + + // Concurrently run the server and listen for the `shutdown` signal. The + // server task runs until an error is encountered, so under normal + // circumstances, this `select!` statement runs until the `shutdown` signal + // is received. + // + // `select!` statements are written in the form of: + // + // ``` + // = => + // ``` + // + // All `` statements are executed concurrently. Once the **first** + // op completes, its associated `` is + // performed. + // + // The `select! macro is a foundational building block for writing + // asynchronous Rust. See the API docs for more details: + // + // https://docs.rs/tokio/*/tokio/macro.select.html + tokio::select! { + res = server.run() => { + // If an error is received here, accepting connections from the TCP + // listener failed multiple times and the server is giving up and + // shutting down. + // + // Errors encountered when handling individual connections do not + // bubble up to this point. + if let Err(err) = res { + error!(cause = %err, "failed to accept"); + } + } + _ = shutdown => { + // The shutdown signal has been received. + info!("shutting down"); + } + } + + // Extract the `shutdown_complete` receiver and transmitter + // explicitly drop `shutdown_transmitter`. This is important, as the + // `.await` below would otherwise never complete. + let Listener { + mut shutdown_complete_rx, + shutdown_complete_tx, + notify_shutdown, + .. + } = server; + + // When `notify_shutdown` is dropped, all tasks which have `subscribe`d will + // receive the shutdown signal and can exit + drop(notify_shutdown); + // Drop final `Sender` so the `Receiver` below can complete + drop(shutdown_complete_tx); + + // Wait for all active connections to finish processing. As the `Sender` + // handle held by the listener has been dropped above, the only remaining + // `Sender` instances are held by connection handler tasks. When those drop, + // the `mpsc` channel will close and `recv()` will return `None`. + let _ = shutdown_complete_rx.recv().await; +} + +impl Listener { + /// Run the server + /// + /// Listen for inbound connections. For each inbound connection, spawn a + /// task to process that connection. + /// + /// # Errors + /// + /// Returns `Err` if accepting returns an error. This can happen for a + /// number reasons that resolve over time. For example, if the underlying + /// operating system has reached an internal limit for max number of + /// sockets, accept will fail. + /// + /// The process is not able to detect when a transient error resolves + /// itself. One strategy for handling this is to implement a back off + /// strategy, which is what we do here. + async fn run(&mut self) -> crate::Result<()> { + info!("accepting inbound connections"); + + loop { + // Wait for a permit to become available + // + // `acquire` returns a permit that is bound via a lifetime to the + // semaphore. When the permit value is dropped, it is automatically + // returned to the semaphore. This is convenient in many cases. + // However, in this case, the permit must be returned in a different + // task than it is acquired in (the handler task). To do this, we + // "forget" the permit, which drops the permit value **without** + // incrementing the semaphore's permits. Then, in the handler task + // we manually add a new permit when processing completes. + // + // `acquire()` returns `Err` when the semaphore has been closed. We + // don't ever close the sempahore, so `unwrap()` is safe. + self.limit_connections.acquire().await.unwrap().forget(); + + // Accept a new socket. This will attempt to perform error handling. + // The `accept` method internally attempts to recover errors, so an + // error here is non-recoverable. + let socket = self.accept().await?; + + // Create the necessary per-connection handler state. + let mut handler = Handler { + // Get a handle to the shared database. + db: self.db_holder.db(), + + // Initialize the connection state. This allocates read/write + // buffers to perform redis protocol frame parsing. + connection: Connection::new(socket), + + // The connection state needs a handle to the max connections + // semaphore. When the handler is done processing the + // connection, a permit is added back to the semaphore. + limit_connections: self.limit_connections.clone(), + + // Receive shutdown notifications. + shutdown: Shutdown::new(self.notify_shutdown.subscribe()), + + // Notifies the receiver half once all clones are + // dropped. + _shutdown_complete: self.shutdown_complete_tx.clone(), + }; + + // Spawn a new task to process the connections. Tokio tasks are like + // asynchronous green threads and are executed concurrently. + tokio::spawn(async move { + // Process the connection. If an error is encountered, log it. + if let Err(err) = handler.run().await { + error!(cause = ?err, "connection error"); + } + }); + } + } + + /// Accept an inbound connection. + /// + /// Errors are handled by backing off and retrying. An exponential backoff + /// strategy is used. After the first failure, the task waits for 1 second. + /// After the second failure, the task waits for 2 seconds. Each subsequent + /// failure doubles the wait time. If accepting fails on the 6th try after + /// waiting for 64 seconds, then this function returns with an error. + async fn accept(&mut self) -> crate::Result { + let mut backoff = 1; + + // Try to accept a few times + loop { + // Perform the accept operation. If a socket is successfully + // accepted, return it. Otherwise, save the error. + match self.listener.accept().await { + Ok((socket, _)) => return Ok(socket), + Err(err) => { + if backoff > 64 { + // Accept has failed too many times. Return the error. + return Err(err.into()); + } + } + } + + // Pause execution until the back off period elapses. + time::sleep(Duration::from_secs(backoff)).await; + + // Double the back off + backoff *= 2; + } + } +} + +impl Handler { + /// Process a single connection. + /// + /// Request frames are read from the socket and processed. Responses are + /// written back to the socket. + /// + /// Currently, pipelining is not implemented. Pipelining is the ability to + /// process more than one request concurrently per connection without + /// interleaving frames. See for more details: + /// https://redis.io/topics/pipelining + /// + /// When the shutdown signal is received, the connection is processed until + /// it reaches a safe state, at which point it is terminated. + #[instrument(skip(self))] + async fn run(&mut self) -> crate::Result<()> { + // As long as the shutdown signal has not been received, try to read a + // new request frame. + while !self.shutdown.is_shutdown() { + // While reading a request frame, also listen for the shutdown + // signal. + let maybe_frame = tokio::select! { + res = self.connection.read_frame() => res?, + _ = self.shutdown.recv() => { + // If a shutdown signal is received, return from `run`. + // This will result in the task terminating. + return Ok(()); + } + }; + + // If `None` is returned from `read_frame()` then the peer closed + // the socket. There is no further work to do and the task can be + // terminated. + let frame = match maybe_frame { + Some(frame) => frame, + None => return Ok(()), + }; + + // Convert the redis frame into a command struct. This returns an + // error if the frame is not a valid redis command or it is an + // unsupported command. + let cmd = Command::from_frame(frame)?; + + // Logs the `cmd` object. The syntax here is a shorthand provided by + // the `tracing` crate. It can be thought of as similar to: + // + // ``` + // debug!(cmd = format!("{:?}", cmd)); + // ``` + // + // `tracing` provides structured logging, so information is "logged" + // as key-value pairs. + debug!(?cmd); + + // Perform the work needed to apply the command. This may mutate the + // database state as a result. + // + // The connection is passed into the apply function which allows the + // command to write response frames directly to the connection. In + // the case of pub/sub, multiple frames may be send back to the + // peer. + cmd.apply(&self.db, &mut self.connection, &mut self.shutdown) + .await?; + } + + Ok(()) + } +} + +impl Drop for Handler { + fn drop(&mut self) { + // Add a permit back to the semaphore. + // + // Doing so unblocks the listener if the max number of + // connections has been reached. + // + // This is done in a `Drop` implementation in order to guarantee that + // the permit is added even if the task handling the connection panics. + // If `add_permit` was called at the end of the `run` function and some + // bug causes a panic. The permit would never be returned to the + // semaphore. + self.limit_connections.add_permits(1); + } +} diff --git a/assets/mini-redis/src/shutdown.rs b/assets/mini-redis/src/shutdown.rs new file mode 100644 index 0000000..bf1b1c3 --- /dev/null +++ b/assets/mini-redis/src/shutdown.rs @@ -0,0 +1,49 @@ +use tokio::sync::broadcast; + +/// Listens for the server shutdown signal. +/// +/// Shutdown is signalled using a `broadcast::Receiver`. Only a single value is +/// ever sent. Once a value has been sent via the broadcast channel, the server +/// should shutdown. +/// +/// The `Shutdown` struct listens for the signal and tracks that the signal has +/// been received. Callers may query for whether the shutdown signal has been +/// received or not. +#[derive(Debug)] +pub(crate) struct Shutdown { + /// `true` if the shutdown signal has been received + shutdown: bool, + + /// The receive half of the channel used to listen for shutdown. + notify: broadcast::Receiver<()>, +} + +impl Shutdown { + /// Create a new `Shutdown` backed by the given `broadcast::Receiver`. + pub(crate) fn new(notify: broadcast::Receiver<()>) -> Shutdown { + Shutdown { + shutdown: false, + notify, + } + } + + /// Returns `true` if the shutdown signal has been received. + pub(crate) fn is_shutdown(&self) -> bool { + self.shutdown + } + + /// Receive the shutdown notice, waiting if necessary. + pub(crate) async fn recv(&mut self) { + // If the shutdown signal has already been received, then return + // immediately. + if self.shutdown { + return; + } + + // Cannot receive a "lag error" as only one value is ever sent. + let _ = self.notify.recv().await; + + // Remember that the signal has been received. + self.shutdown = true; + } +} diff --git a/assets/mini-redis/tests/buffer.rs b/assets/mini-redis/tests/buffer.rs new file mode 100644 index 0000000..823b720 --- /dev/null +++ b/assets/mini-redis/tests/buffer.rs @@ -0,0 +1,30 @@ +use mini_redis::{buffer, client, server}; +use std::net::SocketAddr; +use tokio::net::TcpListener; +use tokio::task::JoinHandle; + +/// A basic "hello world" style test. A server instance is started in a +/// background task. A client instance is then established and used to intialize +/// the buffer. Set and get commands are sent to the server. The response is +/// then evaluated. +#[tokio::test] +async fn pool_key_value_get_set() { + let (addr, _) = start_server().await; + + let client = client::connect(addr).await.unwrap(); + let mut client = buffer(client); + + client.set("hello", "world".into()).await.unwrap(); + + let value = client.get("hello").await.unwrap().unwrap(); + assert_eq!(b"world", &value[..]) +} + +async fn start_server() -> (SocketAddr, JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let handle = tokio::spawn(async move { server::run(listener, tokio::signal::ctrl_c()).await }); + + (addr, handle) +} diff --git a/assets/mini-redis/tests/client.rs b/assets/mini-redis/tests/client.rs new file mode 100644 index 0000000..e2e7b42 --- /dev/null +++ b/assets/mini-redis/tests/client.rs @@ -0,0 +1,92 @@ +use mini_redis::{client, server}; +use std::net::SocketAddr; +use tokio::net::TcpListener; +use tokio::task::JoinHandle; + +/// A basic "hello world" style test. A server instance is started in a +/// background task. A client instance is then established and set and get +/// commands are sent to the server. The response is then evaluated +#[tokio::test] +async fn key_value_get_set() { + let (addr, _) = start_server().await; + + let mut client = client::connect(addr).await.unwrap(); + client.set("hello", "world".into()).await.unwrap(); + + let value = client.get("hello").await.unwrap().unwrap(); + assert_eq!(b"world", &value[..]) +} + +/// similar to the "hello world" style test, But this time +/// a single channel subscription will be tested instead +#[tokio::test] +async fn receive_message_subscribed_channel() { + let (addr, _) = start_server().await; + + let client = client::connect(addr.clone()).await.unwrap(); + let mut subscriber = client.subscribe(vec!["hello".into()]).await.unwrap(); + + tokio::spawn(async move { + let mut client = client::connect(addr).await.unwrap(); + client.publish("hello", "world".into()).await.unwrap() + }); + + let message = subscriber.next_message().await.unwrap().unwrap(); + assert_eq!("hello", &message.channel); + assert_eq!(b"world", &message.content[..]) +} + +/// test that a client gets messages from multiple subscribed channels +#[tokio::test] +async fn receive_message_multiple_subscribed_channels() { + let (addr, _) = start_server().await; + + let client = client::connect(addr.clone()).await.unwrap(); + let mut subscriber = client + .subscribe(vec!["hello".into(), "world".into()]) + .await + .unwrap(); + + tokio::spawn(async move { + let mut client = client::connect(addr).await.unwrap(); + client.publish("hello", "world".into()).await.unwrap() + }); + + let message1 = subscriber.next_message().await.unwrap().unwrap(); + assert_eq!("hello", &message1.channel); + assert_eq!(b"world", &message1.content[..]); + + tokio::spawn(async move { + let mut client = client::connect(addr).await.unwrap(); + client.publish("world", "howdy?".into()).await.unwrap() + }); + + let message2 = subscriber.next_message().await.unwrap().unwrap(); + assert_eq!("world", &message2.channel); + assert_eq!(b"howdy?", &message2.content[..]) +} + +/// test that a client accurately removes its own subscribed chanel list +/// when unbscribing to all subscribed channels by submitting an empty vec +#[tokio::test] +async fn unsubscribes_from_channels() { + let (addr, _) = start_server().await; + + let client = client::connect(addr.clone()).await.unwrap(); + let mut subscriber = client + .subscribe(vec!["hello".into(), "world".into()]) + .await + .unwrap(); + + subscriber.unsubscribe(&[]).await.unwrap(); + assert_eq!(subscriber.get_subscribed().len(), 0); +} + +async fn start_server() -> (SocketAddr, JoinHandle<()>) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let handle = tokio::spawn(async move { server::run(listener, tokio::signal::ctrl_c()).await }); + + (addr, handle) +} diff --git a/assets/mini-redis/tests/server.rs b/assets/mini-redis/tests/server.rs new file mode 100644 index 0000000..488cb58 --- /dev/null +++ b/assets/mini-redis/tests/server.rs @@ -0,0 +1,407 @@ +use mini_redis::server; + +use std::net::SocketAddr; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; +use tokio::time::{self, Duration}; + +/// A basic "hello world" style test. A server instance is started in a +/// background task. A client TCP connection is then established and raw redis +/// commands are sent to the server. The response is evaluated at the byte +/// level. +#[tokio::test] +async fn key_value_get_set() { + let addr = start_server().await; + + // Establish a connection to the server + let mut stream = TcpStream::connect(addr).await.unwrap(); + + // Get a key, data is missing + stream + .write_all(b"*2\r\n$3\r\nGET\r\n$5\r\nhello\r\n") + .await + .unwrap(); + + // Read nil response + let mut response = [0; 5]; + stream.read_exact(&mut response).await.unwrap(); + assert_eq!(b"$-1\r\n", &response); + + // Set a key + stream + .write_all(b"*3\r\n$3\r\nSET\r\n$5\r\nhello\r\n$5\r\nworld\r\n") + .await + .unwrap(); + + // Read OK + let mut response = [0; 5]; + stream.read_exact(&mut response).await.unwrap(); + assert_eq!(b"+OK\r\n", &response); + + // Get the key, data is present + stream + .write_all(b"*2\r\n$3\r\nGET\r\n$5\r\nhello\r\n") + .await + .unwrap(); + + // Shutdown the write half + stream.shutdown().await.unwrap(); + + // Read "world" response + let mut response = [0; 11]; + stream.read_exact(&mut response).await.unwrap(); + assert_eq!(b"$5\r\nworld\r\n", &response); + + // Receive `None` + assert_eq!(0, stream.read(&mut response).await.unwrap()); +} + +/// Similar to the basic key-value test, however, this time timeouts will be +/// tested. This test demonstrates how to test time related behavior. +/// +/// When writing tests, it is useful to remove sources of non-determinism. Time +/// is a source of non-determinism. Here, we "pause" time using the +/// `time::pause()` function. This function is available with the `test-util` +/// feature flag. This allows us to deterministically control how time appears +/// to advance to the application. +#[tokio::test] +async fn key_value_timeout() { + tokio::time::pause(); + + let addr = start_server().await; + + // Establish a connection to the server + let mut stream = TcpStream::connect(addr).await.unwrap(); + + // Set a key + stream + .write_all( + b"*5\r\n$3\r\nSET\r\n$5\r\nhello\r\n$5\r\nworld\r\n\ + +EX\r\n:1\r\n", + ) + .await + .unwrap(); + + let mut response = [0; 5]; + + // Read OK + stream.read_exact(&mut response).await.unwrap(); + + assert_eq!(b"+OK\r\n", &response); + + // Get the key, data is present + stream + .write_all(b"*2\r\n$3\r\nGET\r\n$5\r\nhello\r\n") + .await + .unwrap(); + + // Read "world" response + let mut response = [0; 11]; + + stream.read_exact(&mut response).await.unwrap(); + + assert_eq!(b"$5\r\nworld\r\n", &response); + + // Wait for the key to expire + time::advance(Duration::from_secs(1)).await; + + // Get a key, data is missing + stream + .write_all(b"*2\r\n$3\r\nGET\r\n$5\r\nhello\r\n") + .await + .unwrap(); + + // Read nil response + let mut response = [0; 5]; + + stream.read_exact(&mut response).await.unwrap(); + + assert_eq!(b"$-1\r\n", &response); +} + +#[tokio::test] +async fn pub_sub() { + let addr = start_server().await; + + let mut publisher = TcpStream::connect(addr).await.unwrap(); + + // Publish a message, there are no subscribers yet so the server will + // return `0`. + publisher + .write_all(b"*3\r\n$7\r\nPUBLISH\r\n$5\r\nhello\r\n$5\r\nworld\r\n") + .await + .unwrap(); + + let mut response = [0; 4]; + publisher.read_exact(&mut response).await.unwrap(); + assert_eq!(b":0\r\n", &response); + + // Create a subscriber. This subscriber will only subscribe to the `hello` + // channel. + let mut sub1 = TcpStream::connect(addr).await.unwrap(); + sub1.write_all(b"*2\r\n$9\r\nSUBSCRIBE\r\n$5\r\nhello\r\n") + .await + .unwrap(); + + // Read the subscribe response + let mut response = [0; 34]; + sub1.read_exact(&mut response).await.unwrap(); + assert_eq!( + &b"*3\r\n$9\r\nsubscribe\r\n$5\r\nhello\r\n:1\r\n"[..], + &response[..] + ); + + // Publish a message, there now is a subscriber + publisher + .write_all(b"*3\r\n$7\r\nPUBLISH\r\n$5\r\nhello\r\n$5\r\nworld\r\n") + .await + .unwrap(); + + let mut response = [0; 4]; + publisher.read_exact(&mut response).await.unwrap(); + assert_eq!(b":1\r\n", &response); + + // The first subscriber received the message + let mut response = [0; 39]; + sub1.read_exact(&mut response).await.unwrap(); + assert_eq!( + &b"*3\r\n$7\r\nmessage\r\n$5\r\nhello\r\n$5\r\nworld\r\n"[..], + &response[..] + ); + + // Create a second subscriber + // + // This subscriber will be subscribed to both `hello` and `foo` + let mut sub2 = TcpStream::connect(addr).await.unwrap(); + sub2.write_all(b"*3\r\n$9\r\nSUBSCRIBE\r\n$5\r\nhello\r\n$3\r\nfoo\r\n") + .await + .unwrap(); + + // Read the subscribe response + let mut response = [0; 34]; + sub2.read_exact(&mut response).await.unwrap(); + assert_eq!( + &b"*3\r\n$9\r\nsubscribe\r\n$5\r\nhello\r\n:1\r\n"[..], + &response[..] + ); + let mut response = [0; 32]; + sub2.read_exact(&mut response).await.unwrap(); + assert_eq!( + &b"*3\r\n$9\r\nsubscribe\r\n$3\r\nfoo\r\n:2\r\n"[..], + &response[..] + ); + + // Publish another message on `hello`, there are two subscribers + publisher + .write_all(b"*3\r\n$7\r\nPUBLISH\r\n$5\r\nhello\r\n$5\r\njazzy\r\n") + .await + .unwrap(); + + let mut response = [0; 4]; + publisher.read_exact(&mut response).await.unwrap(); + assert_eq!(b":2\r\n", &response); + + // Publish a message on `foo`, there is only one subscriber + publisher + .write_all(b"*3\r\n$7\r\nPUBLISH\r\n$3\r\nfoo\r\n$3\r\nbar\r\n") + .await + .unwrap(); + + let mut response = [0; 4]; + publisher.read_exact(&mut response).await.unwrap(); + assert_eq!(b":1\r\n", &response); + + // The first subscriber received the message + let mut response = [0; 39]; + sub1.read_exact(&mut response).await.unwrap(); + assert_eq!( + &b"*3\r\n$7\r\nmessage\r\n$5\r\nhello\r\n$5\r\njazzy\r\n"[..], + &response[..] + ); + + // The second subscriber received the message + let mut response = [0; 39]; + sub2.read_exact(&mut response).await.unwrap(); + assert_eq!( + &b"*3\r\n$7\r\nmessage\r\n$5\r\nhello\r\n$5\r\njazzy\r\n"[..], + &response[..] + ); + + // The first subscriber **did not** receive the second message + let mut response = [0; 1]; + time::timeout(Duration::from_millis(100), sub1.read(&mut response)) + .await + .unwrap_err(); + + // The second subscriber **did** receive the message + let mut response = [0; 35]; + sub2.read_exact(&mut response).await.unwrap(); + assert_eq!( + &b"*3\r\n$7\r\nmessage\r\n$3\r\nfoo\r\n$3\r\nbar\r\n"[..], + &response[..] + ); +} + +#[tokio::test] +async fn manage_subscription() { + let addr = start_server().await; + + let mut publisher = TcpStream::connect(addr).await.unwrap(); + + // Create a subscriber + let mut sub = TcpStream::connect(addr).await.unwrap(); + sub.write_all(b"*2\r\n$9\r\nSUBSCRIBE\r\n$5\r\nhello\r\n") + .await + .unwrap(); + + // Read the subscribe response + let mut response = [0; 34]; + sub.read_exact(&mut response).await.unwrap(); + assert_eq!( + &b"*3\r\n$9\r\nsubscribe\r\n$5\r\nhello\r\n:1\r\n"[..], + &response[..] + ); + + // Update subscription to add `foo` + sub.write_all(b"*2\r\n$9\r\nSUBSCRIBE\r\n$3\r\nfoo\r\n") + .await + .unwrap(); + + let mut response = [0; 32]; + sub.read_exact(&mut response).await.unwrap(); + assert_eq!( + &b"*3\r\n$9\r\nsubscribe\r\n$3\r\nfoo\r\n:2\r\n"[..], + &response[..] + ); + + // Update subscription to remove `hello` + sub.write_all(b"*2\r\n$11\r\nUNSUBSCRIBE\r\n$5\r\nhello\r\n") + .await + .unwrap(); + + let mut response = [0; 37]; + sub.read_exact(&mut response).await.unwrap(); + assert_eq!( + &b"*3\r\n$11\r\nunsubscribe\r\n$5\r\nhello\r\n:1\r\n"[..], + &response[..] + ); + + // Publish a message to `hello` and then a message to `foo` + publisher + .write_all(b"*3\r\n$7\r\nPUBLISH\r\n$5\r\nhello\r\n$5\r\nworld\r\n") + .await + .unwrap(); + let mut response = [0; 4]; + publisher.read_exact(&mut response).await.unwrap(); + assert_eq!(b":0\r\n", &response); + + publisher + .write_all(b"*3\r\n$7\r\nPUBLISH\r\n$3\r\nfoo\r\n$3\r\nbar\r\n") + .await + .unwrap(); + let mut response = [0; 4]; + publisher.read_exact(&mut response).await.unwrap(); + assert_eq!(b":1\r\n", &response); + + // Receive the message + // The second subscriber **did** receive the message + let mut response = [0; 35]; + sub.read_exact(&mut response).await.unwrap(); + assert_eq!( + &b"*3\r\n$7\r\nmessage\r\n$3\r\nfoo\r\n$3\r\nbar\r\n"[..], + &response[..] + ); + + // No more messages + let mut response = [0; 1]; + time::timeout(Duration::from_millis(100), sub.read(&mut response)) + .await + .unwrap_err(); + + // Unsubscribe from all channels + sub.write_all(b"*1\r\n$11\r\nunsubscribe\r\n") + .await + .unwrap(); + + let mut response = [0; 35]; + sub.read_exact(&mut response).await.unwrap(); + assert_eq!( + &b"*3\r\n$11\r\nunsubscribe\r\n$3\r\nfoo\r\n:0\r\n"[..], + &response[..] + ); +} + +// In this case we test that server Responds with an Error message if a client +// sends an unknown command +#[tokio::test] +async fn send_error_unknown_command() { + let addr = start_server().await; + + // Establish a connection to the server + let mut stream = TcpStream::connect(addr).await.unwrap(); + + // Get a key, data is missing + stream + .write_all(b"*2\r\n$3\r\nFOO\r\n$5\r\nhello\r\n") + .await + .unwrap(); + + let mut response = [0; 28]; + + stream.read_exact(&mut response).await.unwrap(); + + assert_eq!(b"-ERR unknown command \'foo\'\r\n", &response); +} + +// In this case we test that server Responds with an Error message if a client +// sends an GET or SET command after a SUBSCRIBE +#[tokio::test] +async fn send_error_get_set_after_subscribe() { + let addr = start_server().await; + + let mut stream = TcpStream::connect(addr).await.unwrap(); + + // send SUBSCRIBE command + stream + .write_all(b"*2\r\n$9\r\nsubscribe\r\n$5\r\nhello\r\n") + .await + .unwrap(); + + let mut response = [0; 34]; + + stream.read_exact(&mut response).await.unwrap(); + + assert_eq!( + &b"*3\r\n$9\r\nsubscribe\r\n$5\r\nhello\r\n:1\r\n"[..], + &response[..] + ); + + stream + .write_all(b"*3\r\n$3\r\nSET\r\n$5\r\nhello\r\n$5\r\nworld\r\n") + .await + .unwrap(); + + let mut response = [0; 28]; + + stream.read_exact(&mut response).await.unwrap(); + assert_eq!(b"-ERR unknown command \'set\'\r\n", &response); + + stream + .write_all(b"*2\r\n$3\r\nGET\r\n$5\r\nhello\r\n") + .await + .unwrap(); + + let mut response = [0; 28]; + + stream.read_exact(&mut response).await.unwrap(); + assert_eq!(b"-ERR unknown command \'get\'\r\n", &response); +} + +async fn start_server() -> SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { server::run(listener, tokio::signal::ctrl_c()).await }); + + addr +}