feat: Duckdb #578

Draft · wants to merge 4 commits into master
284 changes: 259 additions & 25 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -85,6 +85,7 @@ tree-sitter-rust = "0.23"
tree-sitter-typescript = "0.23"
tree-sitter-go = "0.23"
tree-sitter-solidity = "1.2.11"
duckdb = { version = "1.1.1", default-features = false }


# Testing
1 change: 1 addition & 0 deletions swiftide-core/src/template.rs
@@ -2,6 +2,7 @@ use anyhow::{Context as _, Result};
use tokio::sync::RwLock;

use lazy_static::lazy_static;
pub use tera::Context;
use tera::Tera;
use uuid::Uuid;

5 changes: 5 additions & 0 deletions swiftide-integrations/Cargo.toml
@@ -27,6 +27,7 @@ strum = { workspace = true }
strum_macros = { workspace = true }
regex = { workspace = true }
futures-util = { workspace = true }
uuid = { workspace = true }

# Integrations
async-openai = { workspace = true, optional = true }
@@ -84,6 +85,7 @@ parquet = { workspace = true, optional = true, features = [
] }
arrow = { workspace = true, optional = true }
redb = { workspace = true, optional = true }
duckdb = { workspace = true, optional = true }

[dev-dependencies]
swiftide-core = { path = "../swiftide-core", features = ["test-utils"] }
@@ -93,6 +95,7 @@ swiftide-test-utils = { path = "../swiftide-test-utils", features = [
temp-dir = { workspace = true }
pretty_assertions = { workspace = true }
arrow = { workspace = true, features = ["test_utils"] }
duckdb = { workspace = true, features = ["bundled"] }

# Used for hacking fluv to play nice
flv-util = { workspace = true }
@@ -154,6 +157,8 @@ fluvio = ["dep:fluvio"]
parquet = ["dep:arrow-array", "dep:parquet", "dep:arrow"]
# Redb as an embeddable node cache
redb = ["dep:redb"]
# Duckdb for indexing and retrieval
duckdb = ["dep:duckdb"]


[lints]
62 changes: 62 additions & 0 deletions swiftide-integrations/src/duckdb/mod.rs
@@ -0,0 +1,62 @@
use std::{
collections::HashMap,
sync::{Arc, OnceLock},
};

use derive_builder::Builder;
use swiftide_core::indexing::EmbeddedField;
use tokio::sync::Mutex;

pub mod persist;
pub mod retrieve;

#[derive(Clone, Builder)]
#[builder(setter(into))]
pub struct Duckdb {
#[builder(setter(custom))]
connection: Arc<Mutex<duckdb::Connection>>,
table_name: String,

// The vectors to be stored, field name -> size
vectors: HashMap<EmbeddedField, usize>,

#[builder(default = "256")]
batch_size: usize,

#[builder(default = OnceLock::new())]
node_upsert_sql: OnceLock<String>,
}

impl std::fmt::Debug for Duckdb {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Duckdb")
.field("connection", &"Arc<Mutex<duckdb::Connection>>")
.field("table_name", &self.table_name)
.field("batch_size", &self.batch_size)
.finish()
}
}

impl Duckdb {
pub fn builder() -> DuckdbBuilder {
DuckdbBuilder::default()
}

pub async fn connection(&self) -> &Mutex<duckdb::Connection> {
&self.connection
}
}

impl DuckdbBuilder {
pub fn connection(&mut self, connection: impl Into<duckdb::Connection>) -> &mut Self {
self.connection = Some(Arc::new(Mutex::new(connection.into())));
self
}

pub fn with_vector(&mut self, field: EmbeddedField, size: usize) -> &mut Self {
self.vectors
.get_or_insert_with(HashMap::new)
.insert(field, size);
self
}
}
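
As a usage sketch (not part of the diff): with the `duckdb` feature enabled, the store is built through `DuckdbBuilder` and then driven through the `Persist` trait implemented in `persist.rs` below. The table name, vector size, and module path (`swiftide_integrations::duckdb`) are illustrative assumptions.

use swiftide_core::indexing::{EmbeddedField, Node};
use swiftide_core::Persist as _;
use swiftide_integrations::duckdb::Duckdb;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // In-memory connection; a file path would work the same way.
    let store = Duckdb::builder()
        .connection(duckdb::Connection::open_in_memory()?)
        .table_name("swiftide_nodes")
        .with_vector(EmbeddedField::Combined, 384)
        .build()?;

    // Creates the table and caches the upsert statement from the bundled templates.
    store.setup().await?;

    // Store a single node that already carries its embedding.
    let node = Node::new("Hello duckdb!")
        .with_vectors([(EmbeddedField::Combined, vec![0.1; 384])])
        .to_owned();
    store.store(node).await?;

    Ok(())
}

Note that `setup` runs `INSTALL vss; LOAD vss;`, so the DuckDB vss extension must be installable in the environment.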
188 changes: 188 additions & 0 deletions swiftide-integrations/src/duckdb/persist.rs
@@ -0,0 +1,188 @@
use std::{borrow::Cow, collections::HashMap, path::Path};

use anyhow::{Context as _, Result};
use async_trait::async_trait;
use duckdb::{
params_from_iter,
types::{ToSqlOutput, Value},
ToSql,
};
use swiftide_core::{
indexing::{self, EmbeddedField, Metadata},
template::{Context, Template},
Persist,
};
use uuid::Uuid;

use super::Duckdb;

const SCHEMA: &str = include_str!("schema.sql");
const UPSERT: &str = include_str!("upsert.sql");

#[allow(dead_code)]
enum NodeValues<'a> {
Uuid(Uuid),
Path(&'a Path),
Chunk(&'a str),
Metadata(&'a Metadata),
Vector(Cow<'a, [f32]>),
Null,
}

impl ToSql for NodeValues<'_> {
fn to_sql(&self) -> duckdb::Result<ToSqlOutput<'_>> {
match self {
NodeValues::Uuid(uuid) => Ok(ToSqlOutput::Owned(uuid.to_string().into())),
NodeValues::Path(path) => Ok(path.to_string_lossy().to_string().into()), // Should be borrow-able
NodeValues::Chunk(chunk) => chunk.to_sql(),
NodeValues::Metadata(_metadata) => {
unimplemented!("maps are not yet implemented for duckdb");
// let ordered_map = metadata
// .iter()
// .map(|(k, v)| format!("'{}': {}", k, serde_json::to_string(v).unwrap()))
// .collect::<Vec<_>>()
// .join(",");
//
// let formatted = format!("MAP {{{ordered_map}}}");
// Ok(ToSqlOutput::Owned(formatted.into()))
}
NodeValues::Vector(vector) => {
let array_str = format!(
"[{}]",
vector
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(",")
);
Ok(ToSqlOutput::Owned(array_str.into()))
}
NodeValues::Null => Ok(ToSqlOutput::Owned(Value::Null)),
}
}
}

#[async_trait]
impl Persist for Duckdb {
async fn setup(&self) -> Result<()> {
let mut context = Context::default();
context.insert("table_name", &self.table_name);
context.insert("vectors", &self.vectors);

let rendered = Template::Static(SCHEMA).render(&context).await?;
self.connection.lock().await.execute_batch(&rendered)?;

context.insert(
"vector_field_names",
&self.vectors.keys().collect::<Vec<_>>(),
);

// The user may have overridden the upsert SQL, which is fine
let upsert = Template::Static(UPSERT).render(&context).await?;
self.node_upsert_sql
.set(upsert)
.map_err(|_| anyhow::anyhow!("Failed to set upsert sql"))?;

Ok(())
}

async fn store(&self, node: indexing::Node) -> Result<indexing::Node> {
let Some(query) = self.node_upsert_sql.get() else {
anyhow::bail!("Upsert sql in Duckdb not set");
};

// TODO: We take potentially many locks here for the duration of a single query;
// SOMEONE IS GOING TO HAVE A BAD TIME

// metadata needs to be converted to `map_from_entries([('key1', value)])`
// TODO: Investigate if we can do this with far fewer allocations
let mut values = vec![
NodeValues::Uuid(node.id()),
NodeValues::Chunk(&node.chunk),
NodeValues::Path(&node.path),
];

// if node.metadata.is_empty() {
// values.push(NodeValues::Null);
// } else {
// values.push(NodeValues::Metadata(&node.metadata));
// }

let Some(node_vectors) = &node.vectors else {
anyhow::bail!("Expected node to have vectors; cannot store into duckdb");
};

for field in self.vectors.keys() {
let Some(vector) = node_vectors.get(field) else {
anyhow::bail!("Expected vector for field {} in node", field);
};

values.push(NodeValues::Vector(vector.into()));
}

let lock = self.connection.lock().await;
let mut stmt = lock.prepare(query)?;
// TODO: Investigate concurrency in duckdb; maybe optimistic locking, if it works
stmt.execute(params_from_iter(values))
.context("Failed to store node")?;

Ok(node)
}

async fn batch_store(&self, nodes: Vec<indexing::Node>) -> indexing::IndexingStream {
// TODO: Must batch
let mut new_nodes = vec![];
for node in nodes {
new_nodes.push(self.store(node).await);
}
new_nodes.into()
}
}

#[cfg(test)]
mod tests {
use indexing::{EmbeddedField, Node};

use super::*;

#[test_log::test(tokio::test)]
async fn test_persisting_nodes() {
let client = Duckdb::builder()
.connection(duckdb::Connection::open_in_memory().unwrap())
.table_name("test".to_string())
.with_vector(EmbeddedField::Combined, 3)
.build()
.unwrap();

let node = Node::new("Hello duckdb!")
.with_vectors([(EmbeddedField::Combined, vec![1.0, 2.0, 3.0])])
.to_owned();

client.setup().await.unwrap();
client.store(node).await.unwrap();

tracing::info!("Stored node");

let connection = client.connection.lock().await;
let mut stmt = connection
.prepare("SELECT uuid,path,chunk FROM test")
.unwrap();
let node_iter = stmt
.query_map([], |row| {
Ok((
row.get::<_, String>(0).unwrap(), // uuid
row.get::<_, String>(1).unwrap(), // path
row.get::<_, String>(2).unwrap(), // chunk
// row.get::<_, String>(3).unwrap(), // metadata
// row.get::<_, Vec<f32>>(4).unwrap(), // vector
))
})
.unwrap();

let retrieved = node_iter.collect::<Result<Vec<_>, _>>().unwrap();
dbg!(&retrieved);
//
assert_eq!(retrieved.len(), 1);
}
}
4 changes: 4 additions & 0 deletions swiftide-integrations/src/duckdb/retrieve.rs
@@ -0,0 +1,4 @@
use duckdb::Statement;
use swiftide_core::{querying::search_strategies::CustomStrategy, Retrieve};

// impl Retrieve<CustomStrategy<Statement<''>
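
Retrieval is still a stub. Purely as an illustrative sketch under assumptions (a hypothetical `similar_chunks` helper, the `Combined` column name produced by the schema template, and a query vector whose length matches the stored `FLOAT[n]` size), a brute-force nearest-neighbour lookup could be issued straight over the connection with DuckDB's `array_cosine_similarity`:

use anyhow::Result;

use super::Duckdb;

impl Duckdb {
    // Hypothetical helper, not part of this PR: return the `top_k` chunks whose
    // `field` column is most similar to `query_vector` by cosine similarity.
    pub async fn similar_chunks(
        &self,
        field: &str,
        query_vector: &[f32],
        top_k: usize,
    ) -> Result<Vec<String>> {
        // Inline the query vector the same way `NodeValues::Vector` does on insert.
        let vector_literal = format!(
            "[{}]",
            query_vector
                .iter()
                .map(ToString::to_string)
                .collect::<Vec<_>>()
                .join(",")
        );

        let sql = format!(
            "SELECT chunk FROM {} ORDER BY array_cosine_similarity({}, {}::FLOAT[{}]) DESC LIMIT {}",
            self.table_name,
            field,
            vector_literal,
            query_vector.len(),
            top_k
        );

        let conn = self.connection.lock().await;
        let mut stmt = conn.prepare(&sql)?;
        let chunks = stmt
            .query_map([], |row| row.get::<_, String>(0))?
            .collect::<Result<Vec<_>, _>>()?;
        Ok(chunks)
    }
}

A proper implementation would presumably go through `Retrieve`/`CustomStrategy` as hinted above, and could lean on the vss extension's HNSW index instead of a full scan.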
12 changes: 12 additions & 0 deletions swiftide-integrations/src/duckdb/schema.sql
@@ -0,0 +1,12 @@
INSTALL vss;
LOAD vss;

CREATE TABLE IF NOT EXISTS {{table_name}} (
uuid VARCHAR PRIMARY KEY,
chunk VARCHAR NOT NULL,
path VARCHAR,

{% for vector, size in vectors %}
{{vector}} FLOAT[{{size}}],
{% endfor %}
);
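
The `{% for vector, size in vectors %}` loop expands the serialized `vectors` map (embedded field -> dimension) into one fixed-size FLOAT array column per field. A rough sketch of that rendering, using tera directly and assuming `EmbeddedField::Combined` serializes to the key "Combined":

use std::collections::HashMap;

use tera::{Context, Tera};

fn main() -> tera::Result<()> {
    // Stand-in for the serialized `vectors` map (EmbeddedField -> size).
    let vectors: HashMap<&str, usize> = HashMap::from([("Combined", 3)]);

    let mut context = Context::new();
    context.insert("table_name", "test");
    context.insert("vectors", &vectors);

    // Just the CREATE TABLE part of schema.sql, for illustration.
    let template = r"
CREATE TABLE IF NOT EXISTS {{table_name}} (
uuid VARCHAR PRIMARY KEY,
chunk VARCHAR NOT NULL,
path VARCHAR,
{% for vector, size in vectors %}
{{vector}} FLOAT[{{size}}],
{% endfor %}
);
";

    // Prints a `test` table with a `Combined FLOAT[3]` column alongside uuid/chunk/path.
    println!("{}", Tera::one_off(template, &context, false)?);
    Ok(())
}

upsert.sql below is rendered the same way during `setup`, with `vector_field_names` driving both the column list and the extra `?` placeholders.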
13 changes: 13 additions & 0 deletions swiftide-integrations/src/duckdb/upsert.sql
@@ -0,0 +1,13 @@
INSERT INTO {{ table_name }} (uuid, chunk, path, {{ vector_field_names | join(sep=", ") }})
VALUES (?, ?, ?,
{% for _ in range(end=vector_field_names | length) %}
?,
{% endfor %}
)
ON CONFLICT (uuid) DO UPDATE SET
chunk = EXCLUDED.chunk,
path = EXCLUDED.path,
{% for vector in vector_field_names %}
{{ vector }} = EXCLUDED.{{ vector }},
{% endfor %}
;
2 changes: 2 additions & 0 deletions swiftide-integrations/src/lib.rs
@@ -4,6 +4,8 @@
pub mod aws_bedrock;
#[cfg(feature = "dashscope")]
pub mod dashscope;
#[cfg(feature = "duckdb")]
pub mod duckdb;
#[cfg(feature = "fastembed")]
pub mod fastembed;
#[cfg(feature = "fluvio")]