diff --git a/crates/atuin-client/src/database.rs b/crates/atuin-client/src/database.rs index d01dadb44c7..15b0797d5e0 100644 --- a/crates/atuin-client/src/database.rs +++ b/crates/atuin-client/src/database.rs @@ -144,10 +144,21 @@ impl Sqlite { .with_regexp() .create_if_missing(true); - let pool = SqlitePoolOptions::new() - .acquire_timeout(Duration::from_secs_f64(timeout)) - .connect_with(opts) - .await?; + let mut pool_opts = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)); + + // Workaround for sqlx-sqlite in-memory failure. + // Make sure a single connection is created to avoid overwriting the database. + // Safe to use in this case since memory databases are only used for testing. + // https://github.com/launchbadge/sqlx/issues/2510 + if path.to_str().unwrap().ends_with(":memory:"){ + pool_opts = pool_opts + .max_connections(1) + .idle_timeout(None) + .max_lifetime(None); + } + + let pool = pool_opts.connect_with(opts).await?; Self::setup_db(&pool).await?; diff --git a/crates/atuin-client/src/record/sqlite_store.rs b/crates/atuin-client/src/record/sqlite_store.rs index 31de311b65f..1d47c435c60 100644 --- a/crates/atuin-client/src/record/sqlite_store.rs +++ b/crates/atuin-client/src/record/sqlite_store.rs @@ -45,8 +45,21 @@ impl SqliteStore { .foreign_keys(true) .create_if_missing(true); - let pool = SqlitePoolOptions::new() - .acquire_timeout(Duration::from_secs_f64(timeout)) + let mut pool_opts = SqlitePoolOptions::new() + .acquire_timeout(Duration::from_secs_f64(timeout)); + + // Workaround for sqlx-sqlite in-memory failure. + // Make sure a single connection is created to avoid overwriting the database. + // Safe to use in this case since memory databases are only used for testing. + // https://github.com/launchbadge/sqlx/issues/2510 + if path.to_str().unwrap().ends_with(":memory:"){ + pool_opts = pool_opts + .max_connections(1) + .idle_timeout(None) + .max_lifetime(None); + } + + let pool = pool_opts .connect_with(opts) .await?; @@ -72,16 +85,16 @@ impl SqliteStore { "insert or ignore into store(id, idx, host, tag, timestamp, version, data, cek) values(?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", ) - .bind(r.id.0.as_hyphenated().to_string()) - .bind(r.idx as i64) - .bind(r.host.id.0.as_hyphenated().to_string()) - .bind(r.tag.as_str()) - .bind(r.timestamp as i64) - .bind(r.version.as_str()) - .bind(r.data.data.as_str()) - .bind(r.data.content_encryption_key.as_str()) - .execute(&mut **tx) - .await?; + .bind(r.id.0.as_hyphenated().to_string()) + .bind(r.idx as i64) + .bind(r.host.id.0.as_hyphenated().to_string()) + .bind(r.tag.as_str()) + .bind(r.timestamp as i64) + .bind(r.version.as_str()) + .bind(r.data.data.as_str()) + .bind(r.data.content_encryption_key.as_str()) + .execute(&mut **tx) + .await?; Ok(()) } @@ -122,7 +135,7 @@ impl SqliteStore { impl Store for SqliteStore { async fn push_batch( &self, - records: impl Iterator> + Send + Sync, + records: impl Iterator> + Send + Sync, ) -> Result<()> { let mut tx = self.pool.begin().await?; @@ -181,7 +194,7 @@ impl Store for SqliteStore { } async fn len_all(&self) -> Result { - let res: Result<(i64,), sqlx::Error> = sqlx::query_as("select count(*) from store") + let res: Result<(i64, ), sqlx::Error> = sqlx::query_as("select count(*) from store") .fetch_one(&self.pool) .await; match res { @@ -191,7 +204,7 @@ impl Store for SqliteStore { } async fn len_tag(&self, tag: &str) -> Result { - let res: Result<(i64,), sqlx::Error> = + let res: Result<(i64, ), sqlx::Error> = sqlx::query_as("select count(*) from store where tag=?1") .bind(tag) .fetch_one(&self.pool)