Skip to content

Commit

Permalink
mostly working, alpha build
Browse files Browse the repository at this point in the history
  • Loading branch information
Julia Merz committed Aug 10, 2023
1 parent 8eae2d7 commit 24bb063
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 10 deletions.
4 changes: 3 additions & 1 deletion src-tauri/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ pub fn update_permissions(
) -> Result<usize, diesel::result::Error> {
let conn = &mut pool.get().unwrap();
use schema::user::dsl::*;
println!("Updating to {:?}", perms);
diesel::update(user)
.filter(id.eq(DbUuid(user_id)))
.set((perm_superuser.eq(perms.perm_superuser),
Expand All @@ -332,7 +333,8 @@ pub fn update_permissions(
perm_request_download.eq(perms.perm_request_download),
perm_request_load.eq(perms.perm_request_load),
perm_request_unload.eq(perms.perm_request_unload),
perm_view_llms.eq(perms.perm_view_llms)))
perm_view_llms.eq(perms.perm_view_llms),
perm_bare_model.eq(perms.perm_bare_model)))
.execute(conn)
}

Expand Down
45 changes: 41 additions & 4 deletions src-tauri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::connectors::llm_manager;

use dashmap::DashMap;
use diesel::prelude::*;
use diesel::connection::SimpleConnection;
use diesel::r2d2::ConnectionManager;
use diesel::r2d2::Pool;
use diesel::sqlite::Sqlite;
Expand Down Expand Up @@ -51,16 +52,51 @@ mod server;
mod state;
mod user;

#[derive(Debug)]
pub struct ConnectionOptions {
pub enable_wal: bool,
pub enable_foreign_keys: bool,
pub busy_timeout: Option<Duration>,
}

impl diesel::r2d2::CustomizeConnection<SqliteConnection, diesel::r2d2::Error>
for ConnectionOptions
{
fn on_acquire(&self, conn: &mut SqliteConnection) -> Result<(), diesel::r2d2::Error> {
(|| {
if self.enable_wal {
conn.batch_execute("PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL;")?;
}
if self.enable_foreign_keys {
conn.batch_execute("PRAGMA foreign_keys = ON;")?;
}
if let Some(d) = self.busy_timeout {
conn.batch_execute(&format!("PRAGMA busy_timeout = {};", d.as_millis()))?;
}
Ok(())
})()
.map_err(diesel::r2d2::Error::QueryError)
}
}

pub fn get_connection_pool(db_url: String) -> Pool<ConnectionManager<SqliteConnection>> {
// let url = database_url_for_env();

let manager = ConnectionManager::<SqliteConnection>::new(db_url);
// Refer to the `r2d2` documentation for more methods to use
// when building a connection pool
Pool::builder()
let pool = Pool::builder()
.max_size(8)
.connection_customizer(Box::new(ConnectionOptions {
enable_wal: true,
enable_foreign_keys: true,
busy_timeout: Some(Duration::from_secs(10)),
}))
.test_on_check_out(true)
.build(manager)
.expect("Could not build connection pool")
.expect("Could not build connection pool");
pool

}

// pub fn establish_connection() -> SqliteConnection {
Expand All @@ -74,7 +110,7 @@ fn run_migrations(
//
// See the documentation for `MigrationHarness` for
// all available methods.
println!("Mirations:\n{:?}", connection.revert_all_migrations(MIGRATIONS));
// println!("Mirations:\n{:?}", connection.revert_all_migrations(MIGRATIONS));
// println!("Mirations:\n{:?}", connection.applied_migrations());
connection.run_pending_migrations(MIGRATIONS)?;

Expand Down Expand Up @@ -128,7 +164,8 @@ async fn main() {
let context = tauri::generate_context!();

let mut db_path = tauri::api::path::app_local_data_dir(context.config()).unwrap();
let llm_path = tauri::api::path::local_data_dir().unwrap();
let mut llm_path = tauri::api::path::local_data_dir().unwrap();
llm_path.push("pantry");

if !llm_path.exists() {
fs::create_dir_all(&llm_path).unwrap();
Expand Down
10 changes: 6 additions & 4 deletions src-tauri/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,12 @@ async fn register_user(
let user = user::User::new(payload.user_name);
match database::save_new_user(user, state.pool.clone()) {
Ok(user) => Ok(Json((&user).into())),
Err(err) => Err((
Err(err) => {
println!("Error creating user: {:?}", err.to_string());
Err((
StatusCode::INTERNAL_SERVER_ERROR,
"Error creating user".into(),
)),
))},
}
}

Expand Down Expand Up @@ -1523,7 +1525,7 @@ async fn bare_model_flex(
state: State<state::GlobalStateWrapper>,
Json(payload): Json<BareModelFlexRequest>,
) -> Result<Json<BareModelResponse>, (StatusCode, String)> {
println!("Called bare_mode_flex from API.");
println!("Called bare_model_flex from API.");
let user_uuid =
Uuid::parse_str(&payload.user_id).map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
let user = user_permission_check("bare_model", payload.api_key, user_uuid, state.pool.clone())?;
Expand Down Expand Up @@ -1706,7 +1708,7 @@ async fn bare_model(
state: State<state::GlobalStateWrapper>,
Json(payload): Json<BareModelRequest>,
) -> Result<Json<BareModelResponse>, (StatusCode, String)> {
println!("Called load_llm from API.");
println!("Called bare_model from API.");
let user_uuid =
Uuid::parse_str(&payload.user_id).map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;

Expand Down
3 changes: 2 additions & 1 deletion src/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ function toUserPermissions(permissions: any) {
permRequestDownload: permissions.perm_request_download || false,
permRequestLoad: permissions.perm_request_load || false,
permRequestUnload: permissions.perm_request_unload || false,
permViewLlms: permissions.perm_view_llms || false
permViewLlms: permissions.perm_view_llms || false,
permBareModel: permissions.perm_bare_model || false
};
}

Expand Down

0 comments on commit 24bb063

Please sign in to comment.