Skip to content

Commit

Permalink
fix(rust): use the same in-memory database for the vault used in nif
Browse files Browse the repository at this point in the history
  • Loading branch information
etorreborre committed Nov 30, 2023
1 parent 62d32a5 commit 8c6d635
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 25 deletions.
1 change: 1 addition & 0 deletions implementations/elixir/ockam/ockly/native/ockly/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ hex = { version = "0.4", default-features = false }
lazy_static = "1.4.0"
minicbor = { version = "0.20.0", features = ["alloc", "derive"] }
ockam_identity = { path = "../../../../../rust/ockam/ockam_identity" }
ockam_node = { path = "../../../../../rust/ockam/ockam_node" }
ockam_vault = { path = "../../../../../rust/ockam/ockam_vault" }
ockam_vault_aws = { path = "../../../../../rust/ockam/ockam_vault_aws" }
# Enable credentials-sso feature in ockam_vault_aws for use on sso environments (like dev machines)
Expand Down
28 changes: 17 additions & 11 deletions implementations/elixir/ockam/ockly/native/ockly/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@ use ockam_identity::{
utils::AttributesBuilder,
Identifier, Identities, Vault,
};
use ockam_node::database::SqlxDatabase;
use ockam_vault::{
EdDSACurve25519SecretKey, HandleToSecret, SigningKeyType, SigningSecret,
SigningSecretKeyHandle, SoftwareVaultForSecureChannels, SoftwareVaultForSigning,
X25519PublicKey, X25519SecretKey,
};
use ockam_vault::storage::SecretsRepository;
use ockam_vault::storage::SecretsSqlxDatabase;
use ockam_vault_aws::{AwsKmsConfig, AwsSigningVault, InitialKeysDiscovery};
use rustler::{Atom, Binary, Env, Error, NewBinary, NifResult};
use std::clone::Clone;
Expand Down Expand Up @@ -65,8 +68,8 @@ fn get_runtime() -> Arc<Runtime> {
}

fn block_future<F>(f: F) -> <F as Future>::Output
where
F: Future,
where
F: Future,
{
let rt = get_runtime();
task::block_in_place(move || {
Expand All @@ -89,17 +92,20 @@ fn identities_ref() -> NifResult<Arc<Identities>> {

fn load_memory_vault() -> bool {
block_future(async move {
let identity_vault = SoftwareVaultForSigning::create().await.unwrap();
let secure_channel_vault = SoftwareVaultForSecureChannels::create().await.unwrap();
let database = SqlxDatabase::in_memory("in-memory-vault").await.unwrap();
let secrets_repository: Arc<dyn SecretsRepository> =
Arc::new(SecretsSqlxDatabase::new(database));
let identity_vault = Arc::new(SoftwareVaultForSigning::new(secrets_repository.clone()));
let secure_channel_vault = Arc::new(SoftwareVaultForSecureChannels::new(secrets_repository.clone()));
*IDENTITY_MEMORY_VAULT.write().unwrap() = Some(identity_vault.clone());
*SECURE_CHANNEL_MEMORY_VAULT.write().unwrap() = Some(secure_channel_vault.clone());
let builder = ockam_identity::Identities::builder()
.await
.unwrap()
.with_vault(Vault::new(
identity_vault,
identity_vault.clone(),
secure_channel_vault,
Vault::create_credential_vault().await.unwrap(),
identity_vault,
Vault::create_verifying_vault(),
));
*IDENTITIES.write().unwrap() = Some(builder.build());
Expand Down Expand Up @@ -182,7 +188,7 @@ fn create_identity(env: Env, existing_key: Option<String>) -> NifResult<(Binary,
let identifier = builder.build().await?;
identities_ref.get_identity(&identifier).await
})
.map_err(|e| Error::Term(Box::new((atoms::identity_creation_error(), e.to_string()))))?;
.map_err(|e| Error::Term(Box::new((atoms::identity_creation_error(), e.to_string()))))?;

let exported = identity
.export()
Expand Down Expand Up @@ -224,7 +230,7 @@ fn attest_secure_channel_key<'a>(
.build()
.await
})
.map_err(|e| Error::Term(Box::new((atoms::attest_error(), e.to_string()))))?;
.map_err(|e| Error::Term(Box::new((atoms::attest_error(), e.to_string()))))?;
let encoded = minicbor::to_vec(purpose_key.attestation())
.map_err(|e| Error::Term(Box::new((atoms::attestation_encode_error(), e.to_string()))))?;
let mut exp_binary = NewBinary::new(env, encoded.len());
Expand Down Expand Up @@ -270,7 +276,7 @@ fn verify_secure_channel_key_attestation(
}
})
})
.map_err(|reason| Error::Term(Box::new(reason)))
.map_err(|reason| Error::Term(Box::new(reason)))
}

#[rustler::nif]
Expand All @@ -283,7 +289,7 @@ fn check_identity<'a>(env: Env<'a>, identity: Binary) -> NifResult<Binary<'a>> {
.await
.map_err(|e| (atoms::identity_import_error(), e.to_string()))
})
.map_err(|reason| Error::Term(Box::new(reason)))?;
.map_err(|reason| Error::Term(Box::new(reason)))?;
let identifier = identifier.to_string();
let mut binary = NewBinary::new(env, identifier.len());
binary.copy_from_slice(identifier.as_bytes());
Expand Down Expand Up @@ -323,7 +329,7 @@ fn issue_credential<'a>(
.await
.map_err(|e| (atoms::credential_issuing_error(), e.to_string()))
})
.map_err(|reason| Error::Term(Box::new(reason)))?;
.map_err(|reason| Error::Term(Box::new(reason)))?;
let encoded = minicbor::to_vec(credential_and_purpose_key)
.map_err(|e| Error::Term(Box::new((atoms::credential_encode_error(), e.to_string()))))?;
let mut binary = NewBinary::new(env, encoded.len());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use tokio_retry::Retry;
use tracing::debug;
use tracing::log::LevelFilter;

use ockam_core::compat::rand;
use ockam_core::compat::sync::Arc;
use ockam_core::{Error, Result};

Expand Down Expand Up @@ -53,7 +54,7 @@ impl SqlxDatabase {
let db = Retry::spawn(retry_strategy, || async {
Self::create_at(path.as_ref()).await
})
.await?;
.await?;
db.migrate().await?;
Ok(db)
}
Expand Down Expand Up @@ -88,7 +89,14 @@ impl SqlxDatabase {
}

async fn create_in_memory_connection_pool() -> Result<SqlitePool> {
let pool = SqlitePool::connect("sqlite::memory:")
// the database url has to be a random one and specify a shared cache
// to avoid data leakage: https://github.com/p2panda/aquadoggo/pull/595
let database_url = {
let db_name = format!("dbmem{}", rand::random::<u32>());
format!("sqlite://file:{db_name}?mode=memory&cache=shared")
};

let pool = SqlitePool::connect(&database_url)
.await
.map_err(Self::map_sql_err)?;
Ok(pool)
Expand Down Expand Up @@ -149,6 +157,7 @@ impl<T> ToVoid<T> for core::result::Result<T, sqlx::error::Error> {
mod tests {
use sqlx::sqlite::SqliteQueryResult;
use sqlx::FromRow;
use std::thread;
use tempfile::NamedTempFile;

use crate::database::ToSqlxType;
Expand All @@ -160,9 +169,9 @@ mod tests {
#[tokio::test]
async fn test_create_identity_table() -> Result<()> {
let db_file = NamedTempFile::new().unwrap();
let db = SqlxDatabase::create(db_file.path()).await?;
let db = Arc::new(SqlxDatabase::create(db_file.path()).await?);

let inserted = insert_identity(&db).await.unwrap();
let inserted = insert_identity(db).await.unwrap();

assert_eq!(inserted.rows_affected(), 1);
Ok(())
Expand All @@ -173,14 +182,15 @@ mod tests {
async fn test_query() -> Result<()> {
let db_file = NamedTempFile::new().unwrap();
let db = SqlxDatabase::create(db_file.path()).await?;
let pool = db.pool.clone();

insert_identity(&db).await.unwrap();
insert_identity(Arc::new(db)).await.unwrap();

// successful query
let result: Option<IdentifierRow> =
sqlx::query_as("SELECT identifier FROM identity WHERE identifier=?1")
.bind("Ifa804b7fca12a19eed206ae180b5b576860ae651")
.fetch_optional(&db.pool)
.fetch_optional(&pool)
.await
.unwrap();
assert_eq!(
Expand All @@ -194,21 +204,59 @@ mod tests {
let result: Option<IdentifierRow> =
sqlx::query_as("SELECT identifier FROM identity WHERE identifier=?1")
.bind("x")
.fetch_optional(&db.pool)
.fetch_optional(&pool)
.await
.unwrap();
assert_eq!(result, None);
Ok(())
}

/// This test checks that we can access the in-memory database from several threads concurrently
#[tokio::test]
async fn test_in_memory() -> Result<()> {
let db = SqlxDatabase::in_memory("test").await?;

let handles = (0..5)
.map(|i| {
let db_arc = db.clone();
thread::spawn(move || async move {
insert_identity_row(db_arc, &format!("{i}"), "123")
.await
.unwrap()
})
})
.collect::<Vec<_>>();

for handle in handles {
handle.join().unwrap().await;
}

let result: Vec<IdentifierRow> =
sqlx::query_as("SELECT * FROM identity ORDER BY identifier ASC")
.fetch_all(&db.pool)
.await
.into_core()?;
assert_eq!(
result.iter().map(|r| r.0.as_str()).collect::<Vec<_>>(),
vec!["0", "1", "2", "3", "4"]
);
Ok(())
}

/// HELPERS
async fn insert_identity(db: &SqlxDatabase) -> Result<SqliteQueryResult> {
sqlx::query("INSERT INTO identity VALUES (?1, ?2)")
.bind("Ifa804b7fca12a19eed206ae180b5b576860ae651")
.bind("123".to_sql())
.execute(&db.pool)
.await
.into_core()
async fn insert_identity(db: Arc<SqlxDatabase>) -> Result<SqliteQueryResult> {
insert_identity_row(db, "Ifa804b7fca12a19eed206ae180b5b576860ae651", "123").await
}

async fn insert_identity_row(
db: Arc<SqlxDatabase>,
identifier: &str,
change_history: &str,
) -> Result<SqliteQueryResult> {
let query = sqlx::query("INSERT INTO identity VALUES (?1, ?2)")
.bind(identifier.to_sql())
.bind(change_history.to_sql());
db.pool.execute(query).await.into_core()
}

#[derive(FromRow, PartialEq, Eq, Debug)]
Expand Down

0 comments on commit 8c6d635

Please sign in to comment.