Skip to content

Commit

Permalink
feat: add the ability to store API key in a secure maner (#158)
Browse files Browse the repository at this point in the history
* feat(secure_api_key): Adding keyring dependency to use secure credential store

* feat(secure_api_key): update cx_Freeze settings for keyring

* feat(secure_api_key): update account model to store and retreve api key from system keyring

* feat(secure_api_key): Update account dialog to choose the api key storage method

* refactor(secure_api_key): use a field validator instead of init to get keyring secret

* feat(secure_api_key): use app name for service name and account id for username in system keyring

* fix(secure_api_key): change api key field validation method to fix account creation

* feat(secure_api_key): implement secure storage for organization key

Refactor validation method for api_key
Restore custom __init__ method for account

* feat(secure_api_key): implement orgazization GUI to secure API key

* fix(secure_api_key): fix saving key storage method on account organisation edition

* fix: reset active_organization property when necessary
  • Loading branch information
clementb49 authored Aug 25, 2024
1 parent fc2ee29 commit 068d3ac
Show file tree
Hide file tree
Showing 4 changed files with 488 additions and 138 deletions.
109 changes: 99 additions & 10 deletions basilisk/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Iterable, Optional
from uuid import uuid4

import keyring
from pydantic import (
UUID4,
BaseModel,
Expand All @@ -15,7 +16,7 @@
OnErrorOmit,
RootModel,
SecretStr,
ValidationError,
ValidationInfo,
field_serializer,
field_validator,
model_serializer,
Expand All @@ -24,11 +25,17 @@

import basilisk.global_vars as global_vars

from .consts import APP_NAME
from .provider import Provider, get_provider, providers

log = getLogger(__name__)


class KeyStorageMethodEnum(Enum):
plain = "plain"
system = "system"


class AccountSource(Enum):
ENV_VAR = "env_var"
CONFIG = "config"
Expand All @@ -38,12 +45,45 @@ class AccountOrganization(BaseModel):
model_config = ConfigDict(populate_by_name=True)
id: UUID4 = Field(default_factory=uuid4)
name: str
key_storage_method: KeyStorageMethodEnum = Field(
default=KeyStorageMethodEnum.plain
)
key: SecretStr
source: AccountSource = Field(default=AccountSource.CONFIG, exclude=True)

@field_validator("key", mode="before")
@classmethod
def validate_key(
cls, value: Optional[Any], info: ValidationInfo
) -> SecretStr:
if isinstance(value, SecretStr):
return value
data = info.data
if data["key_storage_method"] == KeyStorageMethodEnum.plain:
if not isinstance(value, str):
raise ValueError("Key must be a string")
return SecretStr(value)
elif data["key_storage_method"] == KeyStorageMethodEnum.system:
value = keyring.get_password(APP_NAME, str(data["id"]))
if not value:
raise ValueError("Key not found in keyring")
return SecretStr(value)
else:
raise ValueError("Invalid key storage method")

@field_serializer("key", when_used="json")
def dump_secret(self, value: SecretStr) -> str:
return value.get_secret_value()
if self.key_storage_method == KeyStorageMethodEnum.plain:
return value.get_secret_value()
elif self.key_storage_method == KeyStorageMethodEnum.system:
keyring.set_password(
APP_NAME, str(self.id), value.get_secret_value()
)
return None

def delete_keyring_password(self):
if self.key_storage_method == KeyStorageMethodEnum.system:
keyring.delete_password(APP_NAME, str(self.id))


class Account(BaseModel):
Expand All @@ -57,6 +97,9 @@ class Account(BaseModel):
provider: Provider = Field(
validation_alias="provider_id", serialization_alias="provider_id"
)
api_key_storage_method: Optional[KeyStorageMethodEnum] = Field(
default=KeyStorageMethodEnum.plain
)
api_key: Optional[SecretStr] = Field(default=None)
organizations: Optional[list[AccountOrganization]] = Field(default=None)
active_organization_id: Optional[UUID4] = Field(default=None)
Expand All @@ -65,19 +108,46 @@ class Account(BaseModel):
def __init__(self, **data: Any):
try:
super().__init__(**data)
except ValidationError as e:
except Exception as e:
log.error(
f"Error in account {e} the account will not be accessible"
f"Error in account {e} the account will not be accessible",
exc_info=e,
)
raise e

@field_serializer("provider", when_used="always")
def serialize_provider(value: Provider) -> str:
return value.id

@field_validator("api_key", mode="before")
@classmethod
def validate_api_key(
cls, value: Optional[Any], info: ValidationInfo
) -> Optional[SecretStr]:
if isinstance(value, SecretStr):
return value
data = info.data
if data["api_key_storage_method"] == KeyStorageMethodEnum.plain:
if not isinstance(value, str):
raise ValueError("API key must be a string")
return SecretStr(value)
elif data["api_key_storage_method"] == KeyStorageMethodEnum.system:
value = keyring.get_password(APP_NAME, str(data["id"]))
if not value:
raise ValueError("API key not found in keyring")
return SecretStr(value)
else:
raise ValueError("Invalid API key storage method")

@field_serializer("api_key", when_used="json")
def dump_secret(self, value: SecretStr) -> str:
return value.get_secret_value()
def dump_secret(self, value: SecretStr) -> Optional[str]:
if self.api_key_storage_method == KeyStorageMethodEnum.plain:
return value.get_secret_value()
elif self.api_key_storage_method == KeyStorageMethodEnum.system:
keyring.set_password(
APP_NAME, str(self.id), value.get_secret_value()
)
return None

@field_validator("provider", mode="plain")
@classmethod
Expand Down Expand Up @@ -128,18 +198,31 @@ def active_organization(self) -> Optional[AccountOrganization]:
None,
)

def reset_active_organization(self):
try:
del self.active_organization
except AttributeError:
pass

@property
def active_organization_name(self) -> Optional[str]:
if not self.active_organization:
return None
return self.active_organization.name
return (
self.active_organization.name if self.active_organization else None
)

@property
def active_organization_key(self) -> Optional[SecretStr]:
return (
self.active_organization.key if self.active_organization else None
)

def delete_keyring_password(self):
if self.organisations:
for org in self.organisations:
org.delete_keyring_password()
if self.api_key_storage_method == KeyStorageMethodEnum.system:
keyring.delete_password(APP_NAME, str(self.id))


class AccountManager(RootModel):
root: list[OnErrorOmit[Account]] = Field(default=list())
Expand Down Expand Up @@ -197,7 +280,12 @@ def serialize_account_config(self) -> list[dict[str, Any]]:
lambda x: x.source == AccountSource.CONFIG, self.root
)
return [
acc.model_dump(mode="json", by_alias=True, exclude_none=True)
acc.model_dump(
mode="json",
by_alias=True,
exclude_none=True,
exclude_defaults=True,
)
for acc in accounts_config
]

Expand All @@ -215,6 +303,7 @@ def get_accounts_by_provider(
return filter(lambda x: x.provider.name == provider_name, self.root)

def remove(self, account: Account):
account.delete_keyring_password()
self.root.remove(account)

def clear(self):
Expand Down
Loading

0 comments on commit 068d3ac

Please sign in to comment.