From 8c6d635907dae649676b4e5d22f1fa492a57aeca Mon Sep 17 00:00:00 2001
From: etorreborre
Date: Thu, 30 Nov 2023 11:50:40 +0100
Subject: [PATCH] fix(rust): use the same in-memory database for the vault used in nif

---
 .../ockam/ockly/native/ockly/Cargo.toml    |  1 +
 .../ockam/ockly/native/ockly/src/lib.rs    | 28 ++++---
 .../src/storage/database/sqlx_database.rs  | 76 +++++++++++++++----
 3 files changed, 80 insertions(+), 25 deletions(-)

diff --git a/implementations/elixir/ockam/ockly/native/ockly/Cargo.toml b/implementations/elixir/ockam/ockly/native/ockly/Cargo.toml
index a14feb205dc..f374e7eafa8 100644
--- a/implementations/elixir/ockam/ockly/native/ockly/Cargo.toml
+++ b/implementations/elixir/ockam/ockly/native/ockly/Cargo.toml
@@ -16,6 +16,7 @@ hex = { version = "0.4", default-features = false }
 lazy_static = "1.4.0"
 minicbor = { version = "0.20.0", features = ["alloc", "derive"] }
 ockam_identity = { path = "../../../../../rust/ockam/ockam_identity" }
+ockam_node = { path = "../../../../../rust/ockam/ockam_node" }
 ockam_vault = { path = "../../../../../rust/ockam/ockam_vault" }
 ockam_vault_aws = { path = "../../../../../rust/ockam/ockam_vault_aws" }
 # Enable credentials-sso feature in ockam_vault_aws for use on sso environments (like dev machines)
diff --git a/implementations/elixir/ockam/ockly/native/ockly/src/lib.rs b/implementations/elixir/ockam/ockly/native/ockly/src/lib.rs
index 531d1ca0b61..fcd1fb7ce86 100644
--- a/implementations/elixir/ockam/ockly/native/ockly/src/lib.rs
+++ b/implementations/elixir/ockam/ockly/native/ockly/src/lib.rs
@@ -12,11 +12,14 @@ use ockam_identity::{
     utils::AttributesBuilder, Identifier, Identities, Vault,
 };
+use ockam_node::database::SqlxDatabase;
 use ockam_vault::{
     EdDSACurve25519SecretKey, HandleToSecret, SigningKeyType, SigningSecret,
     SigningSecretKeyHandle, SoftwareVaultForSecureChannels, SoftwareVaultForSigning,
     X25519PublicKey, X25519SecretKey,
 };
+use ockam_vault::storage::SecretsRepository;
+use ockam_vault::storage::SecretsSqlxDatabase;
 use ockam_vault_aws::{AwsKmsConfig, AwsSigningVault, InitialKeysDiscovery};
 use rustler::{Atom, Binary, Env, Error, NewBinary, NifResult};
 use std::clone::Clone;
@@ -65,8 +68,8 @@ fn get_runtime() -> Arc<Runtime> {
 }
 
 fn block_future<F>(f: F) -> <F as Future>::Output
-where
-    F: Future,
+    where
+        F: Future,
 {
     let rt = get_runtime();
     task::block_in_place(move || {
@@ -89,17 +92,20 @@ fn identities_ref() -> NifResult<Arc<Identities>> {
 
 fn load_memory_vault() -> bool {
     block_future(async move {
-        let identity_vault = SoftwareVaultForSigning::create().await.unwrap();
-        let secure_channel_vault = SoftwareVaultForSecureChannels::create().await.unwrap();
+        let database = SqlxDatabase::in_memory("in-memory-vault").await.unwrap();
+        let secrets_repository: Arc<dyn SecretsRepository> =
+            Arc::new(SecretsSqlxDatabase::new(database));
+        let identity_vault = Arc::new(SoftwareVaultForSigning::new(secrets_repository.clone()));
+        let secure_channel_vault = Arc::new(SoftwareVaultForSecureChannels::new(secrets_repository.clone()));
         *IDENTITY_MEMORY_VAULT.write().unwrap() = Some(identity_vault.clone());
         *SECURE_CHANNEL_MEMORY_VAULT.write().unwrap() = Some(secure_channel_vault.clone());
         let builder = ockam_identity::Identities::builder()
             .await
             .unwrap()
             .with_vault(Vault::new(
-                identity_vault,
+                identity_vault.clone(),
                 secure_channel_vault,
-                Vault::create_credential_vault().await.unwrap(),
+                identity_vault,
                 Vault::create_verifying_vault(),
             ));
         *IDENTITIES.write().unwrap() = Some(builder.build());
@@ -182,7 +188,7 @@ fn create_identity(env: Env, existing_key: Option<String>) -> NifResult<(Binary,
         let identifier = builder.build().await?;
         identities_ref.get_identity(&identifier).await
     })
-    .map_err(|e| Error::Term(Box::new((atoms::identity_creation_error(), e.to_string()))))?;
+        .map_err(|e| Error::Term(Box::new((atoms::identity_creation_error(), e.to_string()))))?;
 
     let exported = identity
         .export()
@@ -224,7 +230,7 @@ fn attest_secure_channel_key<'a>(
         .build()
         .await
     })
-    .map_err(|e| Error::Term(Box::new((atoms::attest_error(), e.to_string()))))?;
+        .map_err(|e| Error::Term(Box::new((atoms::attest_error(), e.to_string()))))?;
     let encoded = minicbor::to_vec(purpose_key.attestation())
         .map_err(|e| Error::Term(Box::new((atoms::attestation_encode_error(), e.to_string()))))?;
     let mut exp_binary = NewBinary::new(env, encoded.len());
@@ -270,7 +276,7 @@ fn verify_secure_channel_key_attestation(
             }
         })
     })
-    .map_err(|reason| Error::Term(Box::new(reason)))
+        .map_err(|reason| Error::Term(Box::new(reason)))
 }
 
 #[rustler::nif]
@@ -283,7 +289,7 @@ fn check_identity<'a>(env: Env<'a>, identity: Binary) -> NifResult<Binary<'a>> {
             .await
             .map_err(|e| (atoms::identity_import_error(), e.to_string()))
     })
-    .map_err(|reason| Error::Term(Box::new(reason)))?;
+        .map_err(|reason| Error::Term(Box::new(reason)))?;
     let identifier = identifier.to_string();
     let mut binary = NewBinary::new(env, identifier.len());
     binary.copy_from_slice(identifier.as_bytes());
@@ -323,7 +329,7 @@ fn issue_credential<'a>(
             .await
             .map_err(|e| (atoms::credential_issuing_error(), e.to_string()))
     })
-    .map_err(|reason| Error::Term(Box::new(reason)))?;
+        .map_err(|reason| Error::Term(Box::new(reason)))?;
     let encoded = minicbor::to_vec(credential_and_purpose_key)
         .map_err(|e| Error::Term(Box::new((atoms::credential_encode_error(), e.to_string()))))?;
     let mut binary = NewBinary::new(env, encoded.len());
diff --git a/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_database.rs b/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_database.rs
index 685116f05a4..5141765406e 100644
--- a/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_database.rs
+++ b/implementations/rust/ockam/ockam_node/src/storage/database/sqlx_database.rs
@@ -10,6 +10,7 @@ use tokio_retry::Retry;
 use tracing::debug;
 use tracing::log::LevelFilter;
 
+use ockam_core::compat::rand;
 use ockam_core::compat::sync::Arc;
 use ockam_core::{Error, Result};
 
@@ -53,7 +54,7 @@ impl SqlxDatabase {
         let db = Retry::spawn(retry_strategy, || async {
             Self::create_at(path.as_ref()).await
         })
-        .await?;
+            .await?;
         db.migrate().await?;
         Ok(db)
     }
@@ -88,7 +89,14 @@ impl SqlxDatabase {
     }
     async fn create_in_memory_connection_pool() -> Result<SqlitePool> {
-        let pool = SqlitePool::connect("sqlite::memory:")
+        // the database url has to be a random one and specify a shared cache
+        // to avoid data leakage: https://github.com/p2panda/aquadoggo/pull/595
+        let database_url = {
+            let db_name = format!("dbmem{}", rand::random::<u32>());
+            format!("sqlite://file:{db_name}?mode=memory&cache=shared")
+        };
+
+        let pool = SqlitePool::connect(&database_url)
             .await
             .map_err(Self::map_sql_err)?;
         Ok(pool)
     }
@@ -149,6 +157,7 @@ impl<T> ToVoid for core::result::Result<T, sqlx::error::Error> {
 mod tests {
     use sqlx::sqlite::SqliteQueryResult;
     use sqlx::FromRow;
+    use std::thread;
     use tempfile::NamedTempFile;
 
     use crate::database::ToSqlxType;
@@ -160,9 +169,9 @@ mod tests {
     #[tokio::test]
     async fn test_create_identity_table() -> Result<()> {
         let db_file = NamedTempFile::new().unwrap();
-        let db = SqlxDatabase::create(db_file.path()).await?;
+        let db = Arc::new(SqlxDatabase::create(db_file.path()).await?);
 
-        let inserted = insert_identity(&db).await.unwrap();
+        let inserted = insert_identity(db).await.unwrap();
         assert_eq!(inserted.rows_affected(), 1);
 
         Ok(())
@@ -173,14 +182,15 @@ mod tests {
     async fn test_query() -> Result<()> {
         let db_file = NamedTempFile::new().unwrap();
         let db = SqlxDatabase::create(db_file.path()).await?;
+        let pool = db.pool.clone();
 
-        insert_identity(&db).await.unwrap();
+        insert_identity(Arc::new(db)).await.unwrap();
 
         // successful query
         let result: Option<IdentifierRow> =
             sqlx::query_as("SELECT identifier FROM identity WHERE identifier=?1")
                 .bind("Ifa804b7fca12a19eed206ae180b5b576860ae651")
-                .fetch_optional(&db.pool)
+                .fetch_optional(&pool)
                 .await
                 .unwrap();
         assert_eq!(
@@ -194,21 +204,59 @@ mod tests {
         let result: Option<IdentifierRow> =
             sqlx::query_as("SELECT identifier FROM identity WHERE identifier=?1")
                 .bind("x")
-                .fetch_optional(&db.pool)
+                .fetch_optional(&pool)
                 .await
                 .unwrap();
         assert_eq!(result, None);
         Ok(())
     }
 
+    /// This test checks that we can access the in-memory database from several threads concurrently
+    #[tokio::test]
+    async fn test_in_memory() -> Result<()> {
+        let db = SqlxDatabase::in_memory("test").await?;
+
+        let handles = (0..5)
+            .map(|i| {
+                let db_arc = db.clone();
+                thread::spawn(move || async move {
+                    insert_identity_row(db_arc, &format!("{i}"), "123")
+                        .await
+                        .unwrap()
+                })
+            })
+            .collect::<Vec<_>>();
+
+        for handle in handles {
+            handle.join().unwrap().await;
+        }
+
+        let result: Vec<IdentifierRow> =
+            sqlx::query_as("SELECT * FROM identity ORDER BY identifier ASC")
+                .fetch_all(&db.pool)
+                .await
+                .into_core()?;
+        assert_eq!(
+            result.iter().map(|r| r.0.as_str()).collect::<Vec<&str>>(),
+            vec!["0", "1", "2", "3", "4"]
+        );
+        Ok(())
+    }
+
     /// HELPERS
 
-    async fn insert_identity(db: &SqlxDatabase) -> Result<SqliteQueryResult> {
-        sqlx::query("INSERT INTO identity VALUES (?1, ?2)")
-            .bind("Ifa804b7fca12a19eed206ae180b5b576860ae651")
-            .bind("123".to_sql())
-            .execute(&db.pool)
-            .await
-            .into_core()
+    async fn insert_identity(db: Arc<SqlxDatabase>) -> Result<SqliteQueryResult> {
+        insert_identity_row(db, "Ifa804b7fca12a19eed206ae180b5b576860ae651", "123").await
+    }
+
+    async fn insert_identity_row(
+        db: Arc<SqlxDatabase>,
+        identifier: &str,
+        change_history: &str,
+    ) -> Result<SqliteQueryResult> {
+        let query = sqlx::query("INSERT INTO identity VALUES (?1, ?2)")
+            .bind(identifier.to_sql())
+            .bind(change_history.to_sql());
+        db.pool.execute(query).await.into_core()
     }
     #[derive(FromRow, PartialEq, Eq, Debug)]
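
Note: the sketch below is illustrative and not part of the patch. It demonstrates the shared-cache behaviour that the new create_in_memory_connection_pool relies on: every connection opened with the same "sqlite://file:<name>?mode=memory&cache=shared" URL (including all connections handed out by a pool) sees one database, whereas the plain "sqlite::memory:" URL gives each connection its own empty database. It assumes only the sqlx crate (sqlite feature, tokio runtime) and tokio; the database name "demo42" is an arbitrary placeholder, not a name used by the patch.

use sqlx::sqlite::SqlitePool;

#[tokio::main]
async fn main() -> Result<(), sqlx::Error> {
    // A named in-memory database with a shared cache: every connection opened
    // with this URL sees the same data. "demo42" is a placeholder; the patch
    // generates a random name per database to keep separate databases isolated.
    let url = "sqlite://file:demo42?mode=memory&cache=shared";

    let pool_a = SqlitePool::connect(url).await?;
    let pool_b = SqlitePool::connect(url).await?;

    // Create and populate a table through the first pool.
    sqlx::query("CREATE TABLE identity (identifier TEXT, change_history TEXT)")
        .execute(&pool_a)
        .await?;
    sqlx::query("INSERT INTO identity VALUES (?1, ?2)")
        .bind("I123")
        .bind("123")
        .execute(&pool_a)
        .await?;

    // The row is visible through the second pool because both pools point at
    // the same shared cache. With "sqlite::memory:", pool_b (and even other
    // connections inside pool_a) would each see their own empty database.
    let (identifier,): (String,) = sqlx::query_as("SELECT identifier FROM identity")
        .fetch_one(&pool_b)
        .await?;
    assert_eq!(identifier, "I123");
    Ok(())
}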