Skip to content

Commit

Permalink
Fix sqlite memory connection issues
Browse files Browse the repository at this point in the history
Signed-off-by: Cristian Le <cristian.le@mpsd.mpg.de>
  • Loading branch information
LecrisUT committed Jun 10, 2024
1 parent 85aa88a commit d2afd5b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 19 deletions.
19 changes: 15 additions & 4 deletions crates/atuin-client/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,21 @@ impl Sqlite {
.with_regexp()
.create_if_missing(true);

let pool = SqlitePoolOptions::new()
.acquire_timeout(Duration::from_secs_f64(timeout))
.connect_with(opts)
.await?;
let mut pool_opts = SqlitePoolOptions::new()
.acquire_timeout(Duration::from_secs_f64(timeout));

// Workaround for sqlx-sqlite in-memory failure.
// Make sure a single connection is created to avoid overwriting the database.
// Safe to use in this case since memory databases are only used for testing.
// https://github.com/launchbadge/sqlx/issues/2510
if path.to_str().unwrap().ends_with(":memory:"){
pool_opts = pool_opts
.max_connections(1)
.idle_timeout(None)
.max_lifetime(None);
}

let pool = pool_opts.connect_with(opts).await?;

Self::setup_db(&pool).await?;

Expand Down
43 changes: 28 additions & 15 deletions crates/atuin-client/src/record/sqlite_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,21 @@ impl SqliteStore {
.foreign_keys(true)
.create_if_missing(true);

let pool = SqlitePoolOptions::new()
.acquire_timeout(Duration::from_secs_f64(timeout))
let mut pool_opts = SqlitePoolOptions::new()
.acquire_timeout(Duration::from_secs_f64(timeout));

// Workaround for sqlx-sqlite in-memory failure.
// Make sure a single connection is created to avoid overwriting the database.
// Safe to use in this case since memory databases are only used for testing.
// https://github.com/launchbadge/sqlx/issues/2510
if path.to_str().unwrap().ends_with(":memory:"){
pool_opts = pool_opts
.max_connections(1)
.idle_timeout(None)
.max_lifetime(None);
}

let pool = pool_opts
.connect_with(opts)
.await?;

Expand All @@ -72,16 +85,16 @@ impl SqliteStore {
"insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek)
values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
)
.bind(r.id.0.as_hyphenated().to_string())
.bind(r.idx as i64)
.bind(r.host.id.0.as_hyphenated().to_string())
.bind(r.tag.as_str())
.bind(r.timestamp as i64)
.bind(r.version.as_str())
.bind(r.data.data.as_str())
.bind(r.data.content_encryption_key.as_str())
.execute(&mut **tx)
.await?;
.bind(r.id.0.as_hyphenated().to_string())
.bind(r.idx as i64)
.bind(r.host.id.0.as_hyphenated().to_string())
.bind(r.tag.as_str())
.bind(r.timestamp as i64)
.bind(r.version.as_str())
.bind(r.data.data.as_str())
.bind(r.data.content_encryption_key.as_str())
.execute(&mut **tx)
.await?;

Ok(())
}
Expand Down Expand Up @@ -122,7 +135,7 @@ impl SqliteStore {
impl Store for SqliteStore {
async fn push_batch(
&self,
records: impl Iterator<Item = &Record<EncryptedData>> + Send + Sync,
records: impl Iterator<Item=&Record<EncryptedData>> + Send + Sync,
) -> Result<()> {
let mut tx = self.pool.begin().await?;

Expand Down Expand Up @@ -181,7 +194,7 @@ impl Store for SqliteStore {
}

async fn len_all(&self) -> Result<u64> {
let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store")
let res: Result<(i64, ), sqlx::Error> = sqlx::query_as("select count(*) from store")
.fetch_one(&self.pool)
.await;
match res {
Expand All @@ -191,7 +204,7 @@ impl Store for SqliteStore {
}

async fn len_tag(&self, tag: &str) -> Result<u64> {
let res: Result<(i64,), sqlx::Error> =
let res: Result<(i64, ), sqlx::Error> =
sqlx::query_as("select count(*) from store where tag=?1")
.bind(tag)
.fetch_one(&self.pool)
Expand Down

0 comments on commit d2afd5b

Please sign in to comment.