Skip to content

Commit

Permalink
Actually plugins (#421)
Browse files Browse the repository at this point in the history
* more plugins

* clean up

* fix tests

* fix flakey test
  • Loading branch information
levkk authored May 3, 2023
1 parent d5e329f commit 811885f
Show file tree
Hide file tree
Showing 11 changed files with 264 additions and 70 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ serde_json = "1"

[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0"

42 changes: 39 additions & 3 deletions pgcat.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ admin_username = "admin_user"
# Password to access the virtual administrative database
admin_password = "admin_pass"

# Plugins!!
# query_router_plugins = ["pg_table_access", "intercept"]

# pool configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting.
# For a pool named `sharded_db`, clients access that pool using connection string like
Expand Down Expand Up @@ -157,6 +154,45 @@ connect_timeout = 3000
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
# dns_max_ttl = 30

[plugins]

[plugins.query_logger]
enabled = false

[plugins.table_access]
enabled = false
tables = [
"pg_user",
"pg_roles",
"pg_database",
]

[plugins.intercept]
enabled = true

[plugins.intercept.queries.0]

query = "select current_database() as a, current_schemas(false) as b"
schema = [
["a", "text"],
["b", "text"],
]
result = [
["${DATABASE}", "{public}"],
]

[plugins.intercept.queries.1]

query = "select current_database(), current_schema(), current_user"
schema = [
["current_database", "text"],
["current_schema", "text"],
["current_user", "text"],
]
result = [
["${DATABASE}", "public", "${USER}"],
]

# User configs are structured as pool.<pool_name>.users.<user_index>
# This section holds the credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
Expand Down
63 changes: 56 additions & 7 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,6 @@ pub struct General {
pub auth_query: Option<String>,
pub auth_query_user: Option<String>,
pub auth_query_password: Option<String>,

pub query_router_plugins: Option<Vec<String>>,
}

impl General {
Expand Down Expand Up @@ -404,7 +402,6 @@ impl Default for General {
auth_query_user: None,
auth_query_password: None,
server_lifetime: 1000 * 3600 * 24, // 24 hours,
query_router_plugins: None,
}
}
}
Expand Down Expand Up @@ -682,6 +679,55 @@ impl Default for Shard {
}
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Plugins {
pub intercept: Option<Intercept>,
pub table_access: Option<TableAccess>,
pub query_logger: Option<QueryLogger>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Intercept {
pub enabled: bool,
pub queries: BTreeMap<String, Query>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct TableAccess {
pub enabled: bool,
pub tables: Vec<String>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct QueryLogger {
pub enabled: bool,
}

impl Intercept {
pub fn substitute(&mut self, db: &str, user: &str) {
for (_, query) in self.queries.iter_mut() {
query.substitute(db, user);
}
}
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Default)]
pub struct Query {
pub query: String,
pub schema: Vec<Vec<String>>,
pub result: Vec<Vec<String>>,
}

impl Query {
pub fn substitute(&mut self, db: &str, user: &str) {
for col in self.result.iter_mut() {
for i in 0..col.len() {
col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db);
}
}
}
}

/// Configuration wrapper.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Config {
Expand All @@ -700,6 +746,7 @@ pub struct Config {
pub path: String,

pub general: General,
pub plugins: Option<Plugins>,
pub pools: HashMap<String, Pool>,
}

Expand Down Expand Up @@ -737,6 +784,7 @@ impl Default for Config {
path: Self::default_path(),
general: General::default(),
pools: HashMap::default(),
plugins: None,
}
}
}
Expand Down Expand Up @@ -1128,25 +1176,26 @@ pub async fn parse(path: &str) -> Result<(), Error> {

pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
let old_config = get_config();

match parse(&old_config.path).await {
Ok(()) => (),
Err(err) => {
error!("Config reload error: {:?}", err);
return Err(Error::BadConfig);
}
};

let new_config = get_config();

match CachedResolver::from_config().await {
Ok(_) => (),
Err(err) => error!("DNS cache reinitialization error: {:?}", err),
};

if old_config.pools != new_config.pools {
info!("Pool configuration changed");
if old_config != new_config {
info!("Config changed, reloading");
ConnectionPool::from_config(client_server_map).await?;
Ok(true)
} else if old_config != new_config {
Ok(true)
} else {
Ok(false)
}
Expand Down
50 changes: 30 additions & 20 deletions src/plugins/intercept.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,41 @@ use serde_json::{json, Value};
use sqlparser::ast::Statement;
use std::collections::HashMap;

use log::debug;
use log::{debug, info};
use std::sync::Arc;

use crate::{
config::Intercept as InterceptConfig,
errors::Error,
messages::{command_complete, data_row_nullable, row_description, DataType},
plugins::{Plugin, PluginOutput},
pool::{PoolIdentifier, PoolMap},
query_router::QueryRouter,
};

pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, Value>>> =
pub static CONFIG: Lazy<ArcSwap<HashMap<PoolIdentifier, InterceptConfig>>> =
Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));

/// Configure the intercept plugin.
pub fn configure(pools: &PoolMap) {
/// Check if the interceptor plugin has been enabled.
pub fn enabled() -> bool {
!CONFIG.load().is_empty()
}

pub fn setup(intercept_config: &InterceptConfig, pools: &PoolMap) {
let mut config = HashMap::new();
for (identifier, _) in pools.iter() {
// TODO: make this configurable from a text config.
let value = fool_datagrip(&identifier.db, &identifier.user);
config.insert(identifier.clone(), value);
let mut intercept_config = intercept_config.clone();
intercept_config.substitute(&identifier.db, &identifier.user);
config.insert(identifier.clone(), intercept_config);
}

CONFIG.store(Arc::new(config));

info!("Intercepting {} queries", intercept_config.queries.len());
}

pub fn disable() {
CONFIG.store(Arc::new(HashMap::new()));
}

// TODO: use these structs for deserialization
Expand Down Expand Up @@ -78,19 +89,19 @@ impl Plugin for Intercept {
// Normalization
let q = q.to_string().to_ascii_lowercase();

for target in query_map.as_array().unwrap().iter() {
if target["query"].as_str().unwrap() == q {
debug!("Query matched: {}", q);
for (_, target) in query_map.queries.iter() {
if target.query.as_str() == q {
debug!("Intercepting query: {}", q);

let rd = target["schema"]
.as_array()
.unwrap()
let rd = target
.schema
.iter()
.map(|row| {
let row = row.as_object().unwrap();
let name = &row[0];
let data_type = &row[1];
(
row["name"].as_str().unwrap(),
match row["data_type"].as_str().unwrap() {
name.as_str(),
match data_type.as_str() {
"text" => DataType::Text,
"anyarray" => DataType::AnyArray,
"oid" => DataType::Oid,
Expand All @@ -104,13 +115,11 @@ impl Plugin for Intercept {

result.put(row_description(&rd));

target["result"].as_array().unwrap().iter().for_each(|row| {
target.result.iter().for_each(|row| {
let row = row
.as_array()
.unwrap()
.iter()
.map(|s| {
let s = s.as_str().unwrap().to_string();
let s = s.as_str().to_string();

if s == "" {
None
Expand Down Expand Up @@ -141,6 +150,7 @@ impl Plugin for Intercept {

/// Make IntelliJ SQL plugin believe it's talking to an actual database
/// instead of PgCat.
#[allow(dead_code)]
fn fool_datagrip(database: &str, user: &str) -> Value {
json!([
{
Expand Down
9 changes: 6 additions & 3 deletions src/plugins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//!

pub mod intercept;
pub mod query_logger;
pub mod table_access;

use crate::{errors::Error, query_router::QueryRouter};
Expand All @@ -17,6 +18,7 @@ use bytes::BytesMut;
use sqlparser::ast::Statement;

pub use intercept::Intercept;
pub use query_logger::QueryLogger;
pub use table_access::TableAccess;

#[derive(Clone, Debug, PartialEq)]
Expand All @@ -29,12 +31,13 @@ pub enum PluginOutput {

#[async_trait]
pub trait Plugin {
// Custom output is allowed because we want to extend this system
// to rewriting queries some day. So an output of a plugin could be
// a rewritten AST.
// Run before the query is sent to the server.
async fn run(
&mut self,
query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error>;

// TODO: run after the result is returned
// async fn callback(&mut self, query_router: &QueryRouter);
}
49 changes: 49 additions & 0 deletions src/plugins/query_logger.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//! Log all queries to stdout (or somewhere else, why not).

use crate::{
errors::Error,
plugins::{Plugin, PluginOutput},
query_router::QueryRouter,
};
use arc_swap::ArcSwap;
use async_trait::async_trait;
use log::info;
use once_cell::sync::Lazy;
use sqlparser::ast::Statement;
use std::sync::Arc;

static ENABLED: Lazy<ArcSwap<bool>> = Lazy::new(|| ArcSwap::from_pointee(false));

pub struct QueryLogger;

pub fn setup() {
ENABLED.store(Arc::new(true));

info!("Logging queries to stdout");
}

pub fn disable() {
ENABLED.store(Arc::new(false));
}

pub fn enabled() -> bool {
**ENABLED.load()
}

#[async_trait]
impl Plugin for QueryLogger {
async fn run(
&mut self,
_query_router: &QueryRouter,
ast: &Vec<Statement>,
) -> Result<PluginOutput, Error> {
let query = ast
.iter()
.map(|q| q.to_string())
.collect::<Vec<String>>()
.join("; ");
info!("{}", query);

Ok(PluginOutput::Allow)
}
}
Loading

0 comments on commit 811885f

Please sign in to comment.