From 8bed231ec2dedadb0b9b42cde9f18f046cef0c68 Mon Sep 17 00:00:00 2001 From: "Jose Fernandez (magec)" Date: Wed, 7 Dec 2022 13:18:42 +0100 Subject: [PATCH] Add dns_cache so server addresses are cached and invalidated when DNS changes. Adds a module to deal with dns_cache feature. It's main struct is CachedResolver, which is a simple thread safe hostname <-> Ips cache with the ability to refresh resolutions every `dns_max_ttl` seconds. This way, a client can check whether its ip address has changed. --- Cargo.lock | 230 ++++++++++++++++++++++++++ Cargo.toml | 1 + README.md | 2 + examples/docker/pgcat.toml | 9 + src/config.rs | 12 ++ src/dns_cache.rs | 328 +++++++++++++++++++++++++++++++++++++ src/errors.rs | 1 + src/lib.rs | 1 + src/main.rs | 8 + src/server.rs | 35 ++++ 10 files changed, 627 insertions(+) create mode 100644 src/dns_cache.rs diff --git a/Cargo.lock b/Cargo.lock index 29992a3bc..fd8c64e57 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -195,6 +195,12 @@ dependencies = [ "syn", ] +[[package]] +name = "data-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57" + [[package]] name = "digest" version = "0.10.5" @@ -206,6 +212,18 @@ dependencies = [ "subtle", ] +[[package]] +name = "enum-as-inner" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9720bba047d567ffc8a3cba48bf19126600e249ab7f128e9233e6376976a116" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "env_logger" version = "0.10.0" @@ -252,6 +270,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "form_urlencoded" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8" +dependencies = [ + "percent-encoding", +] + [[package]] name = "fs_extra" version = "1.2.0" @@ -273,6 +300,12 @@ version = "0.3.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0c8ff0461b82559810cdccfde3215c3f373807f5e5232b71479bff7bb2583d7" +[[package]] +name = "futures-io" +version = "0.3.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00f5fb52a06bdcadeb54e8d3671f8888a39697dcb0b81b23b55174030427f4eb" + [[package]] name = "futures-sink" version = "0.3.21" @@ -345,6 +378,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "heck" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9" + [[package]] name = "hermit-abi" version = "0.1.19" @@ -372,6 +411,17 @@ dependencies = [ "digest", ] +[[package]] +name = "hostname" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867" +dependencies = [ + "libc", + "match_cfg", + "winapi", +] + [[package]] name = "http" version = "0.2.8" @@ -460,6 +510,27 @@ dependencies = [ "cxx-build", ] +[[package]] +name = "idna" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8" +dependencies = [ + "matches", + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "idna" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "indexmap" version = "1.9.1" @@ -480,6 +551,24 @@ dependencies = [ "windows-sys 0.42.0", ] +[[package]] +name = "ipconfig" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd302af1b90f2463a98fa5ad469fc212c8e3175a41c3068601bfa2727591c5be" +dependencies = [ + "socket2", + "widestring", + "winapi", + "winreg", +] + +[[package]] +name = "ipnet" +version = "2.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745" + [[package]] name = "is-terminal" version = "0.4.0" @@ -549,6 +638,12 @@ dependencies = [ "cc", ] +[[package]] +name = "linked-hash-map" +version = "0.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" + [[package]] name = "linux-raw-sys" version = "0.1.3" @@ -573,6 +668,27 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "lru-cache" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c" +dependencies = [ + "linked-hash-map", +] + +[[package]] +name = "match_cfg" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4" + +[[package]] +name = "matches" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" + [[package]] name = "md-5" version = "0.10.5" @@ -658,6 +774,12 @@ dependencies = [ "windows-sys 0.36.1", ] +[[package]] +name = "percent-encoding" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" + [[package]] name = "pgcat" version = "0.6.0-alpha1" @@ -691,6 +813,7 @@ dependencies = [ "tokio", "tokio-rustls", "toml", + "trust-dns-resolver", ] [[package]] @@ -762,6 +885,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quick-error" +version = "1.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" + [[package]] name = "quote" version = "1.0.15" @@ -827,6 +956,16 @@ version = "0.6.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244" +[[package]] +name = "resolv-conf" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00" +dependencies = [ + "hostname", + "quick-error", +] + [[package]] name = "ring" version = "0.16.20" @@ -1026,6 +1165,26 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "thiserror" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "time" version = "0.1.44" @@ -1155,6 +1314,51 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "trust-dns-proto" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f7f83d1e4a0e4358ac54c5c3681e5d7da5efc5a7a632c90bb6d6669ddd9bc26" +dependencies = [ + "async-trait", + "cfg-if", + "data-encoding", + "enum-as-inner", + "futures-channel", + "futures-io", + "futures-util", + "idna 0.2.3", + "ipnet", + "lazy_static", + "rand", + "smallvec", + "thiserror", + "tinyvec", + "tokio", + "tracing", + "url", +] + +[[package]] +name = "trust-dns-resolver" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aff21aa4dcefb0a1afbfac26deb0adc93888c7d295fb63ab273ef276ba2b7cfe" +dependencies = [ + "cfg-if", + "futures-util", + "ipconfig", + "lazy_static", + "lru-cache", + "parking_lot", + "resolv-conf", + "smallvec", + "thiserror", + "tokio", + "tracing", + "trust-dns-proto", +] + [[package]] name = "try-lock" version = "0.2.3" @@ -1200,6 +1404,17 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" +[[package]] +name = "url" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643" +dependencies = [ + "form_urlencoded", + "idna 0.3.0", + "percent-encoding", +] + [[package]] name = "version_check" version = "0.9.4" @@ -1302,6 +1517,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "widestring" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17882f045410753661207383517a6f62ec3dbeb6a4ed2acce01f0728238d1983" + [[package]] name = "winapi" version = "0.3.9" @@ -1432,3 +1653,12 @@ name = "windows_x86_64_msvc" version = "0.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f40009d85759725a34da6d89a94e63d7bdc50a862acf0dbc7c8e488f1edcb6f5" + +[[package]] +name = "winreg" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +dependencies = [ + "winapi", +] diff --git a/Cargo.toml b/Cargo.toml index ca97c2846..3d59f0e8f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ rustls-pemfile = "1" hyper = { version = "0.14", features = ["full"] } phf = { version = "0.11.1", features = ["macros"] } exitcode = "1.1.2" +trust-dns-resolver = "0.22" [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/README.md b/README.md index 313c11008..110223656 100644 --- a/README.md +++ b/README.md @@ -276,6 +276,8 @@ The config can be reloaded by sending a `kill -s SIGHUP` to the process or by qu | `default_role` | no | | `primary_reads_enabled` | no | | `query_parser_enabled` | no | +| `dns_max_ttl` | no | +| `dns_cache_enabled` | no | ## Benchmarks diff --git a/examples/docker/pgcat.toml b/examples/docker/pgcat.toml index c41c8cdd6..2c7f1e3c6 100644 --- a/examples/docker/pgcat.toml +++ b/examples/docker/pgcat.toml @@ -41,6 +41,15 @@ log_client_disconnections = false # Reload config automatically if it changes. autoreload = false +# If enabled, hostname resolution will be cached and +# and server connections will be invalidated if a change on the ip is +# detected. This check is done every `dns_max_ttl` seconds. +# dns_cache_enabled = false + +# The number of seconds to wait until we check again the +# cached hostnames resolution. 30 seconds by default. +# dns_max_ttl = 30 + # TLS # tls_certificate = "server.cert" # tls_private_key = "server.key" diff --git a/src/config.rs b/src/config.rs index 39bff1bea..4aab87194 100644 --- a/src/config.rs +++ b/src/config.rs @@ -163,6 +163,12 @@ pub struct General { #[serde(default)] // False pub log_client_disconnections: bool, + #[serde(default)] // False + pub dns_cache_enabled: bool, + + #[serde(default = "General::default_dns_max_ttl")] + pub dns_max_ttl: u64, + #[serde(default = "General::default_shutdown_timeout")] pub shutdown_timeout: u64, @@ -201,6 +207,10 @@ impl General { 60000 } + pub fn default_dns_max_ttl() -> u64 { + 30 + } + pub fn default_healthcheck_timeout() -> u64 { 1000 } @@ -228,6 +238,8 @@ impl Default for General { ban_time: Self::default_ban_time(), log_client_connections: false, log_client_disconnections: false, + dns_cache_enabled: false, + dns_max_ttl: Self::default_dns_max_ttl(), autoreload: false, tls_certificate: None, tls_private_key: None, diff --git a/src/dns_cache.rs b/src/dns_cache.rs new file mode 100644 index 000000000..aee180f50 --- /dev/null +++ b/src/dns_cache.rs @@ -0,0 +1,328 @@ +use crate::config::get_config; +use crate::errors::Error; +use arc_swap::ArcSwap; +use log::{debug, error, info}; +use once_cell::sync::Lazy; +use std::collections::{HashMap, HashSet}; +use std::io; +use std::net::IpAddr; +use std::sync::Arc; +use std::sync::RwLock; +use tokio::time::{sleep, Duration}; +use trust_dns_resolver::error::ResolveResult; +use trust_dns_resolver::lookup_ip::LookupIp; +use trust_dns_resolver::TokioAsyncResolver; + +/// Cached Resolver Globally available +pub static CACHED_RESOLVER: Lazy>>> = + Lazy::new(|| ArcSwap::from_pointee(None)); + +// Ip addressed are returned as a set of addresses +// so we can compare. +#[derive(Clone, PartialEq, Debug)] +pub struct AddrSet { + set: HashSet, +} + +impl AddrSet { + fn new() -> AddrSet { + AddrSet { + set: HashSet::new(), + } + } +} + +impl From for AddrSet { + fn from(lookup_ip: LookupIp) -> Self { + let mut addr_set = AddrSet::new(); + for address in lookup_ip.iter() { + addr_set.set.insert(address); + } + addr_set + } +} + +/// +/// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time. +/// +/// The system works as follows: +/// +/// When a host is to be resolved, if we have not resolved it before, a new resolution is +/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the +/// cache is refreshed. +/// +/// # Example: +/// +/// ``` +/// let config = CachedResolverConfig(dns_max_ttl: 10); +/// let resolver = CachedResolver::new(config).await.unwrap() +/// let addrset = resolver.lookup_ip("www.example.com.").unwrap(); +/// +/// // Now the ip resolution is stored in local cache and subsequent +/// // calls will be returned from cache. Also, the cache is refreshed +/// // and updated every 10 seconds. +/// +/// // You can now check if an 'old' lookup differs from what it's currently +/// // store in cache by using `has_changed`. +/// resolver.has_changed("www.example.com.", addrset) +pub struct CachedResolver { + // The configuration of the cached_resolver. + config: CachedResolverConfig, + + // This is the hash that contains the hash. + data: Arc>>, + + // The resolver to be used for DNS queries. + resolver: Arc, +} + +/// +/// Configuration +#[derive(Clone, Debug)] +pub struct CachedResolverConfig { + /// Amount of time in secods that a resolved dns address is considered stale. + pub dns_max_ttl: u64, +} + +impl CachedResolver { + /// + /// Returns a new Arc based on passed configuration. + /// It also starts the loop that will refresh cache entries. + /// + /// # Arguments: + /// + /// * `config` - The `CachedResolverConfig` to be used to create the resolver. + /// + /// # Example: + /// + /// ``` + /// let config = CachedResolverConfig(dns_max_ttl: 10); + /// let resolver = CachedResolver::new(config) + /// ``` + /// + pub async fn new(config: CachedResolverConfig) -> io::Result> { + // Construct a new Resolver with default configuration options + let resolver = Arc::new(TokioAsyncResolver::tokio_from_system_conf()?); + let data = Arc::new(RwLock::new(HashMap::new())); + + let self_ref = Arc::new(Self { + config, + resolver, + data, + }); + let clone_self_ref = self_ref.clone(); + + info!("Scheduling DNS refresh loop"); + tokio::task::spawn(async move { + clone_self_ref.refresh_dns_entries_loop().await; + }); + + Ok(self_ref) + } + + // Schedules the refresher + async fn refresh_dns_entries_loop(&self) { + let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap(); + let interval = Duration::from_secs(self.config.dns_max_ttl); + loop { + debug!("Begin refreshing cached DNS addresses."); + // To minimize the time we hold the lock, we first create + // an array with keys. + let mut hostnames: Vec = Vec::new(); + { + for hostname in self.data.read().unwrap().keys() { + hostnames.push(hostname.clone()); + } + } + + for hostname in hostnames.iter() { + let addrset = self + .fetch_from_cache(hostname.as_str()) + .expect("Could not obtain expected address from cache, this should not happen"); + + match resolver.lookup_ip(hostname).await { + Ok(lookup_ip) => { + let new_addrset = AddrSet::from(lookup_ip); + debug!( + "Obtained address for host ({}) -> ({:?})", + hostname, new_addrset + ); + + if addrset != new_addrset { + debug!( + "Addr changed from {:?} to {:?} updating cache.", + addrset, new_addrset + ); + self.store_in_cache(hostname, new_addrset); + } + } + Err(err) => { + error!( + "There was an error trying to resolv {}: ({}).", + hostname, err + ); + } + } + } + debug!("Finished refreshing cached DNS addresses."); + sleep(interval).await; + } + } + + /// Returns a `AddrSet` given the specified hostname. + /// + /// This method first tries to fetch the value from the cache, if it misses + /// then it is resolved and stored in the cache. TTL from records is ignored. + /// + /// # Arguments + /// + /// * `host` - A string slice referencing the hostname to be resolved. + /// + /// # Example: + /// + /// ``` + /// let config = CachedResolverConfig { dns_max_ttl: 10 }; + /// let resolver = CachedResolver::new(config).await.unwrap(); + /// let response = resolver.lookup_ip("www.google.com."); + /// ``` + /// + pub async fn lookup_ip(&self, host: &str) -> ResolveResult { + debug!("Lookup up {} in cache", host); + match self.fetch_from_cache(host) { + Some(addr_set) => { + debug!("Cache hit!"); + Ok(addr_set) + } + None => { + debug!("Not found, executing a dns query!"); + let addr_set = AddrSet::from(self.resolver.lookup_ip(host).await?); + debug!("Obtained: {:?}", addr_set); + self.store_in_cache(host, addr_set.clone()); + Ok(addr_set) + } + } + } + + // + // Returns true if the stored host resolution differs from the AddrSet passed. + pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool { + if let Some(fetched_addr_set) = self.fetch_from_cache(host) { + return fetched_addr_set != *addr_set; + } + false + } + + // Fetches an AddrSet from the inner cache adquiring the read lock. + fn fetch_from_cache(&self, key: &str) -> Option { + let hash = &self.data.read().unwrap(); + if let Some(addr_set) = hash.get(key) { + return Some(addr_set.clone()); + } + None + } + + // Sets up the global CACHED_RESOLVER static variable so we can globally use DNS + // cache. + pub async fn from_config() -> Result<(), Error> { + let config = get_config(); + + info!("Starting Dns cache? {:?}", config); + + // Configure dns_cache if enabled + if config.general.dns_cache_enabled { + info!("Starting Dns cache"); + let cached_resolver_config = CachedResolverConfig { + dns_max_ttl: config.general.dns_max_ttl, + }; + return match CachedResolver::new(cached_resolver_config).await { + Ok(ok) => { + let value = Some(ArcSwap::from(ok)); + CACHED_RESOLVER.store(Arc::new(value)); + Ok(()) + } + Err(err) => { + let message = format!("Error Starting cached_resolver error: {:?}, will continue without this feature.", err); + Err(Error::DNSCachedError(message)) + } + }; + } + Ok(()) + } + + // Stores the AddrSet in cache adquiring the write lock. + fn store_in_cache(&self, host: &str, addr_set: AddrSet) { + self.data + .write() + .unwrap() + .insert(host.to_string(), addr_set); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use trust_dns_resolver::error::ResolveError; + + #[tokio::test] + async fn new() { + let config = CachedResolverConfig { dns_max_ttl: 10 }; + let resolver = CachedResolver::new(config).await; + assert!(resolver.is_ok()); + } + + #[tokio::test] + async fn lookup_ip() { + let config = CachedResolverConfig { dns_max_ttl: 10 }; + let resolver = CachedResolver::new(config).await.unwrap(); + let response = resolver.lookup_ip("www.google.com.").await; + assert!(response.is_ok()); + } + + #[tokio::test] + async fn has_changed() { + let config = CachedResolverConfig { dns_max_ttl: 10 }; + let resolver = CachedResolver::new(config).await.unwrap(); + let hostname = "www.google.com."; + let response = resolver.lookup_ip(hostname).await; + let addr_set = response.unwrap(); + assert!(!resolver.has_changed(hostname, &addr_set)); + } + + #[tokio::test] + async fn unknown_host() { + let config = CachedResolverConfig { dns_max_ttl: 10 }; + let resolver = CachedResolver::new(config).await.unwrap(); + let hostname = "www.idontexists."; + let response = resolver.lookup_ip(hostname).await; + assert!(matches!(response, Err(ResolveError { .. }))); + } + + #[tokio::test] + async fn incorrect_address() { + let config = CachedResolverConfig { dns_max_ttl: 10 }; + let resolver = CachedResolver::new(config).await.unwrap(); + let hostname = "w ww.idontexists."; + let response = resolver.lookup_ip(hostname).await; + assert!(matches!(response, Err(ResolveError { .. }))); + assert!(!resolver.has_changed(hostname, &AddrSet::new())); + } + + #[tokio::test] + // Ok, this test is based on the fact that google does DNS RR + // and does not responds with every available ip everytime, so + // if I cache here, it will miss after one cache iteration or two. + async fn thread() { + env_logger::init(); + let config = CachedResolverConfig { dns_max_ttl: 10 }; + let resolver = CachedResolver::new(config).await.unwrap(); + let hostname = "www.google.com."; + let response = resolver.lookup_ip(hostname).await; + let addr_set = response.unwrap(); + assert!(!resolver.has_changed(hostname, &addr_set)); + let resolver_for_refresher = resolver.clone(); + let _thread_handle = tokio::task::spawn(async move { + resolver_for_refresher.refresh_dns_entries_loop().await; + }); + assert!(!resolver.has_changed(hostname, &addr_set)); + } +} diff --git a/src/errors.rs b/src/errors.rs index 7789a8a77..baf1552bc 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -12,5 +12,6 @@ pub enum Error { ClientError(String), TlsError, StatementTimeout, + DNSCachedError(String), ShuttingDown, } diff --git a/src/lib.rs b/src/lib.rs index e9a683f3d..7702f3070 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ pub mod config; pub mod constants; +pub mod dns_cache; pub mod errors; pub mod messages; pub mod pool; diff --git a/src/main.rs b/src/main.rs index 0b5f7324e..d3fb7b630 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,6 +36,7 @@ extern crate sqlparser; extern crate tokio; extern crate tokio_rustls; extern crate toml; +extern crate trust_dns_resolver; #[cfg(not(target_env = "msvc"))] use jemallocator::Jemalloc; @@ -63,6 +64,7 @@ mod admin; mod client; mod config; mod constants; +mod dns_cache; mod errors; mod messages; mod pool; @@ -146,6 +148,12 @@ async fn main() { let (stats_tx, stats_rx) = mpsc::channel(100_000); REPORTER.store(Arc::new(Reporter::new(stats_tx.clone()))); + // Starts (if enabled) dns cache before pools initialization + match dns_cache::CachedResolver::from_config().await { + Ok(_) => (), + Err(err) => error!("DNS cache initialization error: {:?}", err), + }; + // Connection pool that allows to query all shards and replicas. match ConnectionPool::from_config(client_server_map.clone()).await { Ok(_) => (), diff --git a/src/server.rs b/src/server.rs index 65fb8d9eb..f796b6cd7 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,6 +3,7 @@ use bytes::{Buf, BufMut, BytesMut}; use log::{debug, error, info, trace, warn}; use std::io::Read; +use std::net::IpAddr; use std::time::SystemTime; use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::{ @@ -12,6 +13,7 @@ use tokio::net::{ use crate::config::{Address, User}; use crate::constants::*; +use crate::dns_cache::{AddrSet, CACHED_RESOLVER}; use crate::errors::Error; use crate::messages::*; use crate::pool::ClientServerMap; @@ -68,6 +70,9 @@ pub struct Server { // Last time that a successful server send or response happened last_activity: SystemTime, + + // Associated addresses used + addr_set: Option, } impl Server { @@ -81,6 +86,28 @@ impl Server { client_server_map: ClientServerMap, stats: Reporter, ) -> Result { + let cached_resolver = CACHED_RESOLVER.load(); + let addr_set = match cached_resolver.as_ref() { + Some(cached_resolver) => { + if address.host.parse::().is_err() { + debug!("Resolving {}", &address.host); + match cached_resolver.load().lookup_ip(&address.host).await { + Ok(ok) => { + debug!("Obtained: {:?}", ok); + Some(ok) + } + Err(err) => { + warn!("Error trying to resolve {}, ({:?})", &address.host, err); + None + } + } + } else { + None + } + } + None => None, + }; + let mut stream = match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await { Ok(stream) => stream, @@ -329,6 +356,7 @@ impl Server { bad: false, needs_cleanup: false, client_server_map, + addr_set, connected_at: chrono::offset::Utc::now().naive_utc(), stats, application_name: String::new(), @@ -554,6 +582,13 @@ impl Server { /// Server & client are out of sync, we must discard this connection. /// This happens with clients that misbehave. pub fn is_bad(&self) -> bool { + if let Some(cached_resolver) = CACHED_RESOLVER.load().as_ref() { + if let Some(addr_set) = &self.addr_set { + return cached_resolver + .load() + .has_changed(self.address.host.as_str(), addr_set); + } + } self.bad }