Skip to content

Commit

Permalink
feat: centralize global variable management (#3284)
Browse files Browse the repository at this point in the history
* test: add tests for global variable endpoints

* test: add unit tests variable service

* fix: anticipate checks to prevent the code from breaking

* feat: add a new method to interface

* feat: add method to update fields in variable service

* feat: replace variable api code

* fix: mypy error

* fix: mypy error

* feat(variable): Allow deleting variables by name or ID in DatabaseVariableService.

* refactor(api): Simplify delete method in variable router.

---------

Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
  • Loading branch information
italojohnny and ogabrielluiz authored Aug 13, 2024
1 parent 6658426 commit 952ba5e
Show file tree
Hide file tree
Showing 6 changed files with 499 additions and 67 deletions.
90 changes: 39 additions & 51 deletions src/backend/base/langflow/api/v1/variable.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from datetime import datetime, timezone
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException
from sqlmodel import Session, select
from sqlalchemy.exc import NoResultFound
from sqlmodel import Session

from langflow.services.auth import utils as auth_utils
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.user.model import User
from langflow.services.database.models.variable import Variable, VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_session, get_settings_service
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_session, get_settings_service, get_variable_service
from langflow.services.variable.base import VariableService
from langflow.services.variable.service import GENERIC_TYPE, DatabaseVariableService

router = APIRouter(prefix="/variables", tags=["Variables"])

Expand All @@ -20,36 +21,30 @@ def create_variable(
variable: VariableCreate,
current_user: User = Depends(get_current_active_user),
settings_service=Depends(get_settings_service),
variable_service: DatabaseVariableService = Depends(get_variable_service),
):
"""Create a new variable."""
try:
# check if variable name already exists
variable_exists = session.exec(
select(Variable).where(
Variable.name == variable.name,
Variable.user_id == current_user.id,
)
).first()
if variable_exists:
raise HTTPException(status_code=400, detail="Variable name already exists")

variable_dict = variable.model_dump()
variable_dict["user_id"] = current_user.id

db_variable = Variable.model_validate(variable_dict)
if not db_variable.name and not db_variable.value:
if not variable.name and not variable.value:
raise HTTPException(status_code=400, detail="Variable name and value cannot be empty")
elif not db_variable.name:

if not variable.name:
raise HTTPException(status_code=400, detail="Variable name cannot be empty")
elif not db_variable.value:

if not variable.value:
raise HTTPException(status_code=400, detail="Variable value cannot be empty")
encrypted = auth_utils.encrypt_api_key(db_variable.value, settings_service=settings_service)
db_variable.value = encrypted
db_variable.user_id = current_user.id
session.add(db_variable)
session.commit()
session.refresh(db_variable)
return db_variable

if variable.name in variable_service.list_variables(user_id=current_user.id, session=session):
raise HTTPException(status_code=400, detail="Variable name already exists")

return variable_service.create_variable(
user_id=current_user.id,
name=variable.name,
value=variable.value,
default_fields=variable.default_fields or [],
_type=variable.type or GENERIC_TYPE,
session=session,
)
except Exception as e:
if isinstance(e, HTTPException):
raise e
Expand All @@ -61,11 +56,12 @@ def read_variables(
*,
session: Session = Depends(get_session),
current_user: User = Depends(get_current_active_user),
variable_service: DatabaseVariableService = Depends(get_variable_service),
):
"""Read all variables."""
try:
variables = session.exec(select(Variable).where(Variable.user_id == current_user.id)).all()
return variables
return variable_service.get_all(user_id=current_user.id, session=session)

except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

Expand All @@ -77,22 +73,19 @@ def update_variable(
variable_id: UUID,
variable: VariableUpdate,
current_user: User = Depends(get_current_active_user),
variable_service: DatabaseVariableService = Depends(get_variable_service),
):
"""Update a variable."""
try:
db_variable = session.exec(
select(Variable).where(Variable.id == variable_id, Variable.user_id == current_user.id)
).first()
if not db_variable:
raise HTTPException(status_code=404, detail="Variable not found")

variable_data = variable.model_dump(exclude_unset=True)
for key, value in variable_data.items():
setattr(db_variable, key, value)
db_variable.updated_at = datetime.now(timezone.utc)
session.commit()
session.refresh(db_variable)
return db_variable
return variable_service.update_variable_fields(
user_id=current_user.id,
variable_id=variable_id,
variable=variable,
session=session,
)
except NoResultFound:
raise HTTPException(status_code=404, detail="Variable not found")

except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e

Expand All @@ -103,15 +96,10 @@ def delete_variable(
session: Session = Depends(get_session),
variable_id: UUID,
current_user: User = Depends(get_current_active_user),
variable_service: VariableService = Depends(get_variable_service),
):
"""Delete a variable."""
try:
db_variable = session.exec(
select(Variable).where(Variable.id == variable_id, Variable.user_id == current_user.id)
).first()
if not db_variable:
raise HTTPException(status_code=404, detail="Variable not found")
session.delete(db_variable)
session.commit()
variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
13 changes: 12 additions & 1 deletion src/backend/base/langflow/services/variable/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def update_variable(self, user_id: Union[UUID, str], name: str, value: str, sess
"""

@abc.abstractmethod
def delete_variable(self, user_id: Union[UUID, str], name: str, session: Session) -> Variable:
def delete_variable(self, user_id: Union[UUID, str], name: str, session: Session) -> None:
"""
Delete a variable.
Expand All @@ -82,6 +82,17 @@ def delete_variable(self, user_id: Union[UUID, str], name: str, session: Session
The deleted variable.
"""

@abc.abstractmethod
def delete_variable_by_id(self, user_id: Union[UUID, str], variable_id: UUID, session: Session) -> None:
"""
Delete a variable by ID.
Args:
user_id: The user ID.
variable_id: The ID of the variable.
session: The database session.
"""

@abc.abstractmethod
def create_variable(
self,
Expand Down
16 changes: 8 additions & 8 deletions src/backend/base/langflow/services/variable/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
from typing import Optional, Tuple, Union
from uuid import UUID

from loguru import logger
from sqlmodel import Session

from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable, VariableCreate
from langflow.services.settings.service import SettingsService
from langflow.services.variable.base import VariableService
from langflow.services.variable.kubernetes_secrets import KubernetesSecretManager, encode_user_id
from langflow.services.variable.service import CREDENTIAL_TYPE, GENERIC_TYPE
from loguru import logger
from sqlmodel import Session


class KubernetesSecretService(VariableService, Service):
Expand Down Expand Up @@ -110,17 +111,16 @@ def update_variable(
secret_key, _ = self.resolve_variable(secret_name, user_id, name)
return self.kubernetes_secrets.update_secret(name=secret_name, data={secret_key: value})

def delete_variable(
self,
user_id: Union[UUID, str],
name: str,
_session: Session,
):
def delete_variable(self, user_id: Union[UUID, str], name: str, _session: Session) -> None:
secret_name = encode_user_id(user_id)

secret_key, _ = self.resolve_variable(secret_name, user_id, name)
self.kubernetes_secrets.delete_secret_key(name=secret_name, key=secret_key)
return

def delete_variable_by_id(self, user_id: Union[UUID, str], variable_id: UUID | str, _session: Session) -> None:
self.delete_variable(user_id, _session, str(variable_id))

def create_variable(
self,
user_id: Union[UUID, str],
Expand Down
48 changes: 41 additions & 7 deletions src/backend/base/langflow/services/variable/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Optional, Union
from uuid import UUID

Expand All @@ -8,7 +9,7 @@

from langflow.services.auth import utils as auth_utils
from langflow.services.base import Service
from langflow.services.database.models.variable.model import Variable, VariableCreate
from langflow.services.database.models.variable.model import Variable, VariableCreate, VariableUpdate
from langflow.services.deps import get_session
from langflow.services.variable.base import VariableService

Expand Down Expand Up @@ -76,21 +77,25 @@ def get_variable(
# credential = session.query(Variable).filter(Variable.user_id == user_id, Variable.name == name).first()
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()

if not variable or not variable.value:
raise ValueError(f"{name} variable not found.")

if variable.type == CREDENTIAL_TYPE and field == "session_id": # type: ignore
raise TypeError(
f"variable {name} of type 'Credential' cannot be used in a Session ID field "
"because its purpose is to prevent the exposure of values."
)

# we decrypt the value
if not variable or not variable.value:
raise ValueError(f"{name} variable not found.")
decrypted = auth_utils.decrypt_api_key(variable.value, settings_service=self.settings_service)
return decrypted

def get_all(self, user_id: Union[UUID, str], session: Session = Depends(get_session)) -> list[Optional[Variable]]:
return list(session.exec(select(Variable).where(Variable.user_id == user_id)).all())

def list_variables(self, user_id: Union[UUID, str], session: Session = Depends(get_session)) -> list[Optional[str]]:
variables = session.exec(select(Variable).where(Variable.user_id == user_id)).all()
return [variable.name for variable in variables]
variables = self.get_all(user_id=user_id, session=session)
return [variable.name for variable in variables if variable]

def update_variable(
self,
Expand All @@ -109,18 +114,47 @@ def update_variable(
session.refresh(variable)
return variable

def update_variable_fields(
self,
user_id: Union[UUID, str],
variable_id: Union[UUID, str],
variable: VariableUpdate,
session: Session = Depends(get_session),
):
query = select(Variable).where(Variable.id == variable_id, Variable.user_id == user_id)
db_variable = session.exec(query).one()

variable_data = variable.model_dump(exclude_unset=True)
for key, value in variable_data.items():
setattr(db_variable, key, value)
db_variable.updated_at = datetime.now(timezone.utc)
encrypted = auth_utils.encrypt_api_key(db_variable.value, settings_service=self.settings_service)
variable.value = encrypted

session.add(db_variable)
session.commit()
session.refresh(db_variable)
return db_variable

def delete_variable(
self,
user_id: Union[UUID, str],
name: str,
session: Session = Depends(get_session),
):
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.name == name)).first()
stmt = select(Variable).where(Variable.user_id == user_id).where(Variable.name == name)
variable = session.exec(stmt).first()
if not variable:
raise ValueError(f"{name} variable not found.")
session.delete(variable)
session.commit()
return variable

def delete_variable_by_id(self, user_id: Union[UUID, str], variable_id: UUID, session: Session):
variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)).first()
if not variable:
raise ValueError(f"{variable_id} variable not found.")
session.delete(variable)
session.commit()

def create_variable(
self,
Expand Down
Loading

0 comments on commit 952ba5e

Please sign in to comment.