diff --git a/.github/workflows/publish-docker-dev.yml b/.github/workflows/publish-docker-dev.yml index 3cdaa72548a..2a755385f80 100644 --- a/.github/workflows/publish-docker-dev.yml +++ b/.github/workflows/publish-docker-dev.yml @@ -51,28 +51,15 @@ jobs: notebook: tests/tests.ipynb image: ghcr.io/${{ needs.build-agixt.outputs.github_user }}/${{ needs.build-agixt.outputs.repo_name }}:${{ github.sha }} port: "7437" - db-connected: false report-name: "agixt-tests" additional-python-dependencies: agixtsdk needs: build-agixt - test-agixt-db: - uses: josh-xt/AGiXT/.github/workflows/operation-test-with-jupyter.yml@main - with: - notebook: tests/tests.ipynb - image: ghcr.io/${{ needs.build-agixt.outputs.github_user }}/${{ needs.build-agixt.outputs.repo_name }}:${{ github.sha }} - port: "7437" - db-connected: true - database-type: "postgresql" - report-name: "agixt-db-tests" - additional-python-dependencies: agixtsdk - needs: build-agixt test-completions: uses: josh-xt/AGiXT/.github/workflows/operation-test-with-jupyter.yml@main with: notebook: tests/completions-tests.ipynb image: ghcr.io/${{ needs.build-agixt.outputs.github_user }}/${{ needs.build-agixt.outputs.repo_name }}:${{ github.sha }} port: "7437" - db-connected: false report-name: "completions-tests" additional-python-dependencies: openai requests python-dotenv needs: build-agixt \ No newline at end of file diff --git a/.github/workflows/publish-docker.yml b/.github/workflows/publish-docker.yml index db0799feead..39e1d65f546 100644 --- a/.github/workflows/publish-docker.yml +++ b/.github/workflows/publish-docker.yml @@ -31,28 +31,15 @@ jobs: notebook: tests/tests.ipynb image: ${{ needs.build-agixt.outputs.primary-image }} port: "7437" - db-connected: false report-name: "agixt-tests" additional-python-dependencies: agixtsdk needs: build-agixt - test-agixt-db: - uses: josh-xt/AGiXT/.github/workflows/operation-test-with-jupyter.yml@main - with: - notebook: tests/tests.ipynb - image: ${{ needs.build-agixt.outputs.primary-image }} - port: "7437" - db-connected: true - database-type: "postgresql" - report-name: "agixt-db-tests" - additional-python-dependencies: agixtsdk - needs: build-agixt test-completions: uses: josh-xt/AGiXT/.github/workflows/operation-test-with-jupyter.yml@main with: notebook: tests/completions-tests.ipynb image: ${{ needs.build-agixt.outputs.primary-image }} port: "7437" - db-connected: false report-name: "completions-tests" additional-python-dependencies: openai requests python-dotenv needs: build-agixt \ No newline at end of file diff --git a/agixt/AGiXT.py b/agixt/AGiXT.py index dd2cb8b65b6..494ee8b0a12 100644 --- a/agixt/AGiXT.py +++ b/agixt/AGiXT.py @@ -4,7 +4,7 @@ from Extensions import Extensions from Chains import Chains from pydub import AudioSegment -from Defaults import getenv, get_tokens, DEFAULT_SETTINGS +from Globals import getenv, get_tokens, DEFAULT_SETTINGS from Models import ChatCompletions import os import base64 diff --git a/agixt/db/Agent.py b/agixt/Agent.py similarity index 97% rename from agixt/db/Agent.py rename to agixt/Agent.py index 4047e6fef02..919011b059e 100644 --- a/agixt/db/Agent.py +++ b/agixt/Agent.py @@ -1,4 +1,4 @@ -from DBConnection import ( +from DB import ( Agent as AgentModel, AgentSetting as AgentSettingModel, AgentBrowsedLink, @@ -15,7 +15,7 @@ ) from Providers import Providers from Extensions import Extensions -from Defaults import getenv, DEFAULT_SETTINGS, DEFAULT_USER +from Globals import getenv, DEFAULT_SETTINGS, DEFAULT_USER from datetime import datetime, timezone, timedelta import logging import json @@ -169,9 +169,14 @@ class Agent: def __init__(self, agent_name=None, user=DEFAULT_USER, ApiClient=None): self.agent_name = agent_name if agent_name is not None else "AGiXT" self.session = get_session() - self.user = user - user_data = self.session.query(User).filter(User.email == self.user).first() - self.user_id = user_data.id + user = user if user is not None else DEFAULT_USER + self.user = user.lower() + try: + user_data = self.session.query(User).filter(User.email == self.user).first() + self.user_id = user_data.id + except Exception as e: + logging.error(f"User {self.user} not found.") + raise self.AGENT_CONFIG = self.get_agent_config() self.load_config_keys() if "settings" not in self.AGENT_CONFIG: diff --git a/agixt/ApiClient.py b/agixt/ApiClient.py index 351c62745c9..82fb43d6c0a 100644 --- a/agixt/ApiClient.py +++ b/agixt/ApiClient.py @@ -2,30 +2,21 @@ import jwt from agixtsdk import AGiXTSDK from fastapi import Header, HTTPException -from Defaults import getenv +from Globals import getenv from datetime import datetime logging.basicConfig( level=getenv("LOG_LEVEL"), format=getenv("LOG_FORMAT"), ) -DB_CONNECTED = True if getenv("DB_CONNECTED").lower() == "true" else False WORKERS = int(getenv("UVICORN_WORKERS")) AGIXT_URI = getenv("AGIXT_URI") # Defining these here to be referenced externally. -if DB_CONNECTED: - from db.Agent import Agent, add_agent, delete_agent, rename_agent, get_agents - from db.Chain import Chain - from db.Prompts import Prompts - from db.Conversations import Conversations - from db.User import User -else: - from fb.Agent import Agent, add_agent, delete_agent, rename_agent, get_agents - from fb.Chain import Chain - from fb.Prompts import Prompts - from fb.Conversations import Conversations - from Models import User_fb as User +from Agent import Agent, add_agent, delete_agent, rename_agent, get_agents +from Chain import Chain +from Prompts import Prompts +from Conversations import Conversations def verify_api_key(authorization: str = Header(None)): @@ -34,7 +25,7 @@ def verify_api_key(authorization: str = Header(None)): DEFAULT_USER = getenv("DEFAULT_USER") authorization = str(authorization).replace("Bearer ", "").replace("bearer ", "") if DEFAULT_USER == "" or DEFAULT_USER is None or DEFAULT_USER == "None": - DEFAULT_USER = "USER" + DEFAULT_USER = "user" if getenv("AUTH_PROVIDER") == "magicalauth": auth_key = AGIXT_API_KEY + str(datetime.now().strftime("%Y%m%d")) try: @@ -88,20 +79,15 @@ def is_admin(email: str = "USER", api_key: str = None): return True # Commenting out functionality until testing is complete. AGIXT_API_KEY = getenv("AGIXT_API_KEY") - DB_CONNECTED = True if getenv("DB_CONNECTED").lower() == "true" else False - if DB_CONNECTED != True: - return True if api_key is None: api_key = "" api_key = api_key.replace("Bearer ", "").replace("bearer ", "") if AGIXT_API_KEY == api_key: return True - if DB_CONNECTED == True: - from db.User import is_agixt_admin + if email == "" or email is None or email == "None": + email = getenv("DEFAULT_USER") if email == "" or email is None or email == "None": - email = getenv("DEFAULT_USER") - if email == "" or email is None or email == "None": - email = "USER" - return is_agixt_admin(email=email, api_key=api_key) + email = "USER" + return is_agixt_admin(email=email, api_key=api_key) return False diff --git a/agixt/db/Chain.py b/agixt/Chain.py similarity index 99% rename from agixt/db/Chain.py rename to agixt/Chain.py index b0ac07e0756..054a402b3b9 100644 --- a/agixt/db/Chain.py +++ b/agixt/Chain.py @@ -1,4 +1,4 @@ -from DBConnection import ( +from DB import ( get_session, Chain as ChainDB, ChainStep, @@ -10,7 +10,7 @@ Command, User, ) -from Defaults import getenv, DEFAULT_USER +from Globals import getenv, DEFAULT_USER import logging logging.basicConfig( diff --git a/agixt/Chains.py b/agixt/Chains.py index cf2919ff07c..6951a5055c3 100644 --- a/agixt/Chains.py +++ b/agixt/Chains.py @@ -1,5 +1,5 @@ import logging -from Defaults import getenv +from Globals import getenv from ApiClient import Chain, Prompts, Conversations from Extensions import Extensions diff --git a/agixt/db/Conversations.py b/agixt/Conversations.py similarity index 99% rename from agixt/db/Conversations.py rename to agixt/Conversations.py index 18aaf2cd794..420e546ecdb 100644 --- a/agixt/db/Conversations.py +++ b/agixt/Conversations.py @@ -1,12 +1,12 @@ from datetime import datetime import logging -from DBConnection import ( +from DB import ( Conversation, Message, User, get_session, ) -from Defaults import getenv, DEFAULT_USER +from Globals import getenv, DEFAULT_USER logging.basicConfig( level=getenv("LOG_LEVEL"), diff --git a/agixt/DB.py b/agixt/DB.py new file mode 100644 index 00000000000..6d78c287951 --- /dev/null +++ b/agixt/DB.py @@ -0,0 +1,481 @@ +import uuid +import time +import logging +from sqlalchemy import ( + create_engine, + Column, + Text, + String, + Integer, + ForeignKey, + DateTime, + Boolean, + func, +) +from sqlalchemy.orm import sessionmaker, relationship, declarative_base +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.sql import text +from Globals import getenv + +logging.basicConfig( + level=getenv("LOG_LEVEL"), + format=getenv("LOG_FORMAT"), +) +DEFAULT_USER = getenv("DEFAULT_USER") +try: + DATABASE_TYPE = getenv("DATABASE_TYPE") + DATABASE_NAME = getenv("DATABASE_NAME") + if DATABASE_TYPE != "sqlite": + DATABASE_USER = getenv("DATABASE_USER") + DATABASE_PASSWORD = getenv("DATABASE_PASSWORD") + DATABASE_HOST = getenv("DATABASE_HOST") + DATABASE_PORT = getenv("DATABASE_PORT") + LOGIN_URI = f"{DATABASE_USER}:{DATABASE_PASSWORD}@{DATABASE_HOST}:{DATABASE_PORT}/{DATABASE_NAME}" + DATABASE_URI = f"postgresql://{LOGIN_URI}" + else: + DATABASE_URI = f"sqlite:///{DATABASE_NAME}.db" + engine = create_engine(DATABASE_URI, pool_size=40, max_overflow=-1) + connection = engine.connect() + Base = declarative_base() +except Exception as e: + logging.error(f"Error connecting to database: {e}") + Base = None + engine = None + + +def get_session(): + Session = sessionmaker(bind=engine, autoflush=False) + session = Session() + return session + + +class User(Base): + __tablename__ = "user" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + email = Column(String, unique=True) + first_name = Column(String, default="", nullable=True) + last_name = Column(String, default="", nullable=True) + admin = Column(Boolean, default=False, nullable=False) + mfa_token = Column(String, default="", nullable=True) + created_at = Column(DateTime, server_default=func.now()) + updated_at = Column(DateTime, server_default=func.now(), onupdate=func.now()) + is_active = Column(Boolean, default=True) + + +class FailedLogins(Base): + __tablename__ = "failed_logins" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + ) + user = relationship("User") + ip_address = Column(String, default="", nullable=True) + created_at = Column(DateTime, server_default=func.now()) + + +class Provider(Base): + __tablename__ = "provider" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + provider_settings = relationship("ProviderSetting", backref="provider") + + +class ProviderSetting(Base): + __tablename__ = "provider_setting" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + provider_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("provider.id"), + nullable=False, + ) + name = Column(Text, nullable=False) + value = Column(Text, nullable=True) + + +class AgentProviderSetting(Base): + __tablename__ = "agent_provider_setting" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + provider_setting_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("provider_setting.id"), + nullable=False, + ) + agent_provider_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent_provider.id"), + nullable=False, + ) + value = Column(Text, nullable=True) + + +class AgentProvider(Base): + __tablename__ = "agent_provider" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + provider_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("provider.id"), + nullable=False, + ) + agent_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent.id"), + nullable=False, + ) + settings = relationship("AgentProviderSetting", backref="agent_provider") + + +class AgentBrowsedLink(Base): + __tablename__ = "agent_browsed_link" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + agent_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent.id"), + nullable=False, + ) + link = Column(Text, nullable=False) + timestamp = Column(DateTime, server_default=text("now()")) + + +class Agent(Base): + __tablename__ = "agent" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + provider_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("provider.id"), + nullable=True, + default=None, + ) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) + settings = relationship("AgentSetting", backref="agent") # One-to-many relationship + browsed_links = relationship("AgentBrowsedLink", backref="agent") + user = relationship("User", backref="agent") + + +class Command(Base): + __tablename__ = "command" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + extension_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("extension.id"), + ) + extension = relationship("Extension", backref="commands") + + +class AgentCommand(Base): + __tablename__ = "agent_command" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + command_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("command.id"), + nullable=False, + ) + agent_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent.id"), + nullable=False, + ) + state = Column(Boolean, nullable=False) + command = relationship("Command") # Add this line to define the relationship + + +class Conversation(Base): + __tablename__ = "conversation" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) + user = relationship("User", backref="conversation") + + +class Message(Base): + __tablename__ = "message" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + role = Column(Text, nullable=False) + content = Column(Text, nullable=False) + timestamp = Column(DateTime, server_default=text("now()")) + conversation_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("conversation.id"), + nullable=False, + ) + + +class Setting(Base): + __tablename__ = "setting" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + extension_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("extension.id"), + ) + value = Column(Text) + + +class AgentSetting(Base): + __tablename__ = "agent_setting" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + agent_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent.id"), + nullable=False, + ) + name = Column(String) + value = Column(String) + + +class Chain(Base): + __tablename__ = "chain" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + description = Column(Text, nullable=True) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) + steps = relationship( + "ChainStep", + backref="chain", + cascade="all, delete", # Add the cascade option for deleting steps + passive_deletes=True, + foreign_keys="ChainStep.chain_id", + ) + target_steps = relationship( + "ChainStep", backref="target_chain", foreign_keys="ChainStep.target_chain_id" + ) + user = relationship("User", backref="chain") + + +class ChainStep(Base): + __tablename__ = "chain_step" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + chain_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain.id", ondelete="CASCADE"), + nullable=False, + ) + agent_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("agent.id"), + nullable=False, + ) + prompt_type = Column(Text) # Add the prompt_type field + prompt = Column(Text) # Add the prompt field + target_chain_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain.id", ondelete="SET NULL"), + ) + target_command_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("command.id", ondelete="SET NULL"), + ) + target_prompt_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("prompt.id", ondelete="SET NULL"), + ) + step_number = Column(Integer, nullable=False) + responses = relationship( + "ChainStepResponse", backref="chain_step", cascade="all, delete" + ) + + def add_response(self, content): + session = get_session() + response = ChainStepResponse(content=content, chain_step=self) + session.add(response) + session.commit() + + +class ChainStepArgument(Base): + __tablename__ = "chain_step_argument" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + argument_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("argument.id"), + nullable=False, + ) + chain_step_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain_step.id", ondelete="CASCADE"), + nullable=False, # Add the ondelete option + ) + value = Column(Text, nullable=True) + + +class ChainStepResponse(Base): + __tablename__ = "chain_step_response" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + chain_step_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain_step.id", ondelete="CASCADE"), + nullable=False, # Add the ondelete option + ) + timestamp = Column(DateTime, server_default=text("now()")) + content = Column(Text, nullable=False) + + +class Extension(Base): + __tablename__ = "extension" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + description = Column(Text, nullable=True, default="") + + +class Argument(Base): + __tablename__ = "argument" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + prompt_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("prompt.id"), + ) + command_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("command.id"), + ) + chain_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("chain.id"), + ) + name = Column(Text, nullable=False) + + +class PromptCategory(Base): + __tablename__ = "prompt_category" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + name = Column(Text, nullable=False) + description = Column(Text, nullable=False) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) + user = relationship("User", backref="prompt_category") + + +class Prompt(Base): + __tablename__ = "prompt" + id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + primary_key=True, + default=uuid.uuid4 if DATABASE_TYPE != "sqlite" else str(uuid.uuid4()), + ) + prompt_category_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("prompt_category.id"), + nullable=False, + ) + name = Column(Text, nullable=False) + description = Column(Text, nullable=False) + content = Column(Text, nullable=False) + user_id = Column( + UUID(as_uuid=True) if DATABASE_TYPE != "sqlite" else String, + ForeignKey("user.id"), + nullable=True, + ) + prompt_category = relationship("PromptCategory", backref="prompts") + user = relationship("User", backref="prompt") + arguments = relationship("Argument", backref="prompt", cascade="all, delete-orphan") + + +if __name__ == "__main__": + logging.info("Connecting to database...") + time.sleep(10) + Base.metadata.create_all(engine) + logging.info("Connected to database.") + # Check if the user table is empty + from SeedImports import import_all_data + + import_all_data() diff --git a/agixt/DBConnection.py b/agixt/DBConnection.py deleted file mode 100644 index 4f372dc8f04..00000000000 --- a/agixt/DBConnection.py +++ /dev/null @@ -1,287 +0,0 @@ -import uuid -import time -import logging -from sqlalchemy import ( - create_engine, - Column, - Text, - String, - Integer, - ForeignKey, - DateTime, - Boolean, -) -from sqlalchemy.orm import sessionmaker, relationship, declarative_base -from sqlalchemy.dialects.postgresql import UUID -from sqlalchemy.sql import text -from Defaults import getenv - -logging.basicConfig( - level=getenv("LOG_LEVEL"), - format=getenv("LOG_FORMAT"), -) -DB_CONNECTED = True if getenv("DB_CONNECTED").lower() == "true" else False -DEFAULT_USER = getenv("DEFAULT_USER") -if DB_CONNECTED: - DATABASE_USER = getenv("DATABASE_USER") - DATABASE_PASSWORD = getenv("DATABASE_PASSWORD") - DATABASE_HOST = getenv("DATABASE_HOST") - DATABASE_PORT = getenv("DATABASE_PORT") - DATABASE_NAME = getenv("DATABASE_NAME") - LOGIN_URI = f"{DATABASE_USER}:{DATABASE_PASSWORD}@{DATABASE_HOST}:{DATABASE_PORT}/{DATABASE_NAME}" - DATABASE_URL = f"postgresql://{LOGIN_URI}" - try: - engine = create_engine(DATABASE_URL, pool_size=40, max_overflow=-1) - except Exception as e: - logging.error(f"Error connecting to database: {e}") - connection = engine.connect() - Base = declarative_base() -else: - Base = None - engine = None - - -def get_session(): - Session = sessionmaker(bind=engine, autoflush=False) - session = Session() - return session - - -class User(Base): - __tablename__ = "user" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - email = Column(String, default=DEFAULT_USER, unique=True) - role = Column(String, default="user") - - -class Provider(Base): - __tablename__ = "provider" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - provider_settings = relationship("ProviderSetting", backref="provider") - - -class ProviderSetting(Base): - __tablename__ = "provider_setting" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - provider_id = Column(UUID(as_uuid=True), ForeignKey("provider.id"), nullable=False) - name = Column(Text, nullable=False) - value = Column(Text, nullable=True) - - -class AgentProviderSetting(Base): - __tablename__ = "agent_provider_setting" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - provider_setting_id = Column( - UUID(as_uuid=True), ForeignKey("provider_setting.id"), nullable=False - ) - agent_provider_id = Column( - UUID(as_uuid=True), ForeignKey("agent_provider.id"), nullable=False - ) - value = Column(Text, nullable=True) - - -class AgentProvider(Base): - __tablename__ = "agent_provider" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - provider_id = Column(UUID(as_uuid=True), ForeignKey("provider.id"), nullable=False) - agent_id = Column(UUID(as_uuid=True), ForeignKey("agent.id"), nullable=False) - settings = relationship("AgentProviderSetting", backref="agent_provider") - - -class AgentBrowsedLink(Base): - __tablename__ = "agent_browsed_link" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - agent_id = Column(UUID(as_uuid=True), ForeignKey("agent.id"), nullable=False) - link = Column(Text, nullable=False) - timestamp = Column(DateTime, server_default=text("now()")) - - -class Agent(Base): - __tablename__ = "agent" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - provider_id = Column( - UUID(as_uuid=True), ForeignKey("provider.id"), nullable=True, default=None - ) - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id"), nullable=True) - settings = relationship("AgentSetting", backref="agent") # One-to-many relationship - browsed_links = relationship("AgentBrowsedLink", backref="agent") - user = relationship("User", backref="agent") - - -class Command(Base): - __tablename__ = "command" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - extension_id = Column(UUID(as_uuid=True), ForeignKey("extension.id")) - extension = relationship("Extension", backref="commands") - - -class AgentCommand(Base): - __tablename__ = "agent_command" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - command_id = Column(UUID(as_uuid=True), ForeignKey("command.id"), nullable=False) - agent_id = Column(UUID(as_uuid=True), ForeignKey("agent.id"), nullable=False) - state = Column(Boolean, nullable=False) - command = relationship("Command") # Add this line to define the relationship - - -class Conversation(Base): - __tablename__ = "conversation" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id"), nullable=True) - user = relationship("User", backref="conversation") - - -class Message(Base): - __tablename__ = "message" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - role = Column(Text, nullable=False) - content = Column(Text, nullable=False) - timestamp = Column(DateTime, server_default=text("now()")) - conversation_id = Column( - UUID(as_uuid=True), ForeignKey("conversation.id"), nullable=False - ) - - -class Setting(Base): - __tablename__ = "setting" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - extension_id = Column(UUID(as_uuid=True), ForeignKey("extension.id")) - value = Column(Text) - - -class AgentSetting(Base): - __tablename__ = "agent_setting" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - agent_id = Column(UUID(as_uuid=True), ForeignKey("agent.id"), nullable=False) - name = Column(String) - value = Column(String) - - -class Chain(Base): - __tablename__ = "chain" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - description = Column(Text, nullable=True) - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id"), nullable=True) - steps = relationship( - "ChainStep", - backref="chain", - cascade="all, delete", # Add the cascade option for deleting steps - passive_deletes=True, - foreign_keys="ChainStep.chain_id", - ) - target_steps = relationship( - "ChainStep", backref="target_chain", foreign_keys="ChainStep.target_chain_id" - ) - user = relationship("User", backref="chain") - - -class ChainStep(Base): - __tablename__ = "chain_step" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - chain_id = Column( - UUID(as_uuid=True), ForeignKey("chain.id", ondelete="CASCADE"), nullable=False - ) - agent_id = Column(UUID(as_uuid=True), ForeignKey("agent.id"), nullable=False) - prompt_type = Column(Text) # Add the prompt_type field - prompt = Column(Text) # Add the prompt field - target_chain_id = Column( - UUID(as_uuid=True), ForeignKey("chain.id", ondelete="SET NULL") - ) - target_command_id = Column( - UUID(as_uuid=True), ForeignKey("command.id", ondelete="SET NULL") - ) - target_prompt_id = Column( - UUID(as_uuid=True), ForeignKey("prompt.id", ondelete="SET NULL") - ) - step_number = Column(Integer, nullable=False) - responses = relationship( - "ChainStepResponse", backref="chain_step", cascade="all, delete" - ) - - def add_response(self, content): - session = get_session() - response = ChainStepResponse(content=content, chain_step=self) - session.add(response) - session.commit() - - -class ChainStepArgument(Base): - __tablename__ = "chain_step_argument" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - argument_id = Column(UUID(as_uuid=True), ForeignKey("argument.id"), nullable=False) - chain_step_id = Column( - UUID(as_uuid=True), - ForeignKey("chain_step.id", ondelete="CASCADE"), - nullable=False, # Add the ondelete option - ) - value = Column(Text, nullable=True) - - -class ChainStepResponse(Base): - __tablename__ = "chain_step_response" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - chain_step_id = Column( - UUID(as_uuid=True), - ForeignKey("chain_step.id", ondelete="CASCADE"), - nullable=False, # Add the ondelete option - ) - timestamp = Column(DateTime, server_default=text("now()")) - content = Column(Text, nullable=False) - - -class Extension(Base): - __tablename__ = "extension" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - description = Column(Text, nullable=True, default="") - - -class Argument(Base): - __tablename__ = "argument" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - prompt_id = Column(UUID(as_uuid=True), ForeignKey("prompt.id")) - command_id = Column(UUID(as_uuid=True), ForeignKey("command.id")) - chain_id = Column(UUID(as_uuid=True), ForeignKey("chain.id")) - name = Column(Text, nullable=False) - - -class PromptCategory(Base): - __tablename__ = "prompt_category" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - name = Column(Text, nullable=False) - description = Column(Text, nullable=False) - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id"), nullable=True) - user = relationship("User", backref="prompt_category") - - -class Prompt(Base): - __tablename__ = "prompt" - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) - prompt_category_id = Column( - UUID(as_uuid=True), ForeignKey("prompt_category.id"), nullable=False - ) - name = Column(Text, nullable=False) - description = Column(Text, nullable=False) - content = Column(Text, nullable=False) - user_id = Column(UUID(as_uuid=True), ForeignKey("user.id"), nullable=True) - prompt_category = relationship("PromptCategory", backref="prompts") - user = relationship("User", backref="prompt") - arguments = relationship("Argument", backref="prompt", cascade="all, delete-orphan") - - -if __name__ == "__main__": - if DB_CONNECTED: - logging.info("Connecting to database...") - time.sleep(10) - Base.metadata.create_all(engine) - logging.info("Connected to database.") - # Check if the user table is empty - from db.imports import import_all_data - - import_all_data() diff --git a/agixt/Extensions.py b/agixt/Extensions.py index dcc7a3c842c..9c8bbe02bcb 100644 --- a/agixt/Extensions.py +++ b/agixt/Extensions.py @@ -4,7 +4,7 @@ from inspect import signature, Parameter import logging import inspect -from Defaults import getenv, DEFAULT_USER +from Globals import getenv, DEFAULT_USER logging.basicConfig( level=getenv("LOG_LEVEL"), diff --git a/agixt/Defaults.py b/agixt/Globals.py similarity index 94% rename from agixt/Defaults.py rename to agixt/Globals.py index f09bdbef6fe..28efef1582b 100644 --- a/agixt/Defaults.py +++ b/agixt/Globals.py @@ -44,13 +44,12 @@ def getenv(var_name: str): "LOG_LEVEL": "INFO", "LOG_FORMAT": "%(asctime)s | %(levelname)s | %(message)s", "UVICORN_WORKERS": 10, - "DB_CONNECTED": "false", "DATABASE_NAME": "postgres", "DATABASE_USER": "postgres", "DATABASE_PASSWORD": "postgres", "DATABASE_HOST": "localhost", "DATABASE_PORT": "5432", - "DEFAULT_USER": "USER", + "DEFAULT_USER": "user", "USING_JWT": "false", "CHROMA_PORT": "8000", "CHROMA_SSL": "false", @@ -68,4 +67,4 @@ def get_tokens(text: str) -> int: return num_tokens -DEFAULT_USER = getenv("DEFAULT_USER") +DEFAULT_USER = str(getenv("DEFAULT_USER")).lower() diff --git a/agixt/Interactions.py b/agixt/Interactions.py index 1746a495c6c..b58eef72993 100644 --- a/agixt/Interactions.py +++ b/agixt/Interactions.py @@ -18,7 +18,7 @@ Conversations, AGIXT_URI, ) -from Defaults import getenv, DEFAULT_USER, get_tokens +from Globals import getenv, DEFAULT_USER, get_tokens logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -470,8 +470,8 @@ async def run( websearch_depth=websearch_depth, websearch_timeout=websearch_timeout, ) - except: - logging.warning("Failed to websearch.") + except Exception as e: + logging.warning("Failed to websearch. Error: {e}") vision_response = "" if "vision_provider" in self.agent.AGENT_CONFIG["settings"]: vision_provider = self.agent.AGENT_CONFIG["settings"]["vision_provider"] diff --git a/agixt/MagicalAuth.py b/agixt/MagicalAuth.py new file mode 100644 index 00000000000..471580bbf52 --- /dev/null +++ b/agixt/MagicalAuth.py @@ -0,0 +1,425 @@ +from DB import User, FailedLogins, get_session +from Models import UserInfo, Register, Login +from fastapi import Header, HTTPException +from Globals import getenv +from datetime import datetime, timedelta +from Agent import add_agent +from agixtsdk import AGiXTSDK +from fastapi import HTTPException +from sendgrid import SendGridAPIClient +from sendgrid.helpers.mail import ( + Attachment, + FileContent, + FileName, + FileType, + Disposition, + Mail, +) +import pyotp +import requests +import logging +import jwt + + +logging.basicConfig( + level=getenv("LOG_LEVEL"), + format=getenv("LOG_FORMAT"), +) +""" +Required environment variables: + +- SENDGRID_API_KEY: SendGrid API key +- SENDGRID_FROM_EMAIL: Default email address to send emails from +- ENCRYPTION_SECRET: Encryption key to encrypt and decrypt data +- MAGIC_LINK_URL: URL to send in the email for the user to click on +- REGISTRATION_WEBHOOK: URL to send a POST request to when a user registers +""" + + +def is_agixt_admin(email: str = "", api_key: str = ""): + if api_key == getenv("AGIXT_API_KEY"): + return True + session = get_session() + user = session.query(User).filter_by(email=email).first() + if not user: + return False + if user.admin is True: + return True + return False + + +def webhook_create_user( + api_key: str, + email: str, + role: str = "user", + agent_name: str = "", + settings: dict = {}, + commands: dict = {}, + training_urls: list = [], + github_repos: list = [], + ApiClient: AGiXTSDK = AGiXTSDK(), +): + if not is_agixt_admin(email=email, api_key=api_key): + return {"error": "Access Denied"}, 403 + session = get_session() + email = email.lower() + user_exists = session.query(User).filter_by(email=email).first() + if user_exists: + session.close() + return {"error": "User already exists"}, 400 + admin = True if role.lower() == "admin" else False + user = User( + email=email, + admin=admin, + first_name="", + last_name="", + ) + session.add(user) + session.commit() + session.close() + if agent_name != "" and agent_name is not None: + add_agent( + agent_name=agent_name, + provider_settings=settings, + commands=commands, + user=email, + ) + if training_urls != []: + for url in training_urls: + ApiClient.learn_url(agent_name=agent_name, url=url) + if github_repos != []: + for repo in github_repos: + ApiClient.learn_github_repo(agent_name=agent_name, github_repo=repo) + return {"status": "Success"}, 200 + + +def verify_api_key(authorization: str = Header(None)): + ENCRYPTION_SECRET = getenv("ENCRYPTION_SECRET") + if getenv("AUTH_PROVIDER") == "magicalauth": + ENCRYPTION_SECRET = f'{ENCRYPTION_SECRET}{datetime.now().strftime("%Y%m%d")}' + authorization = str(authorization).replace("Bearer ", "").replace("bearer ", "") + if ENCRYPTION_SECRET: + if authorization is None: + raise HTTPException( + status_code=401, detail="Authorization header is missing" + ) + if authorization == ENCRYPTION_SECRET: + return "ADMIN" + try: + if authorization == ENCRYPTION_SECRET: + return "ADMIN" + token = jwt.decode( + jwt=authorization, + key=ENCRYPTION_SECRET, + algorithms=["HS256"], + ) + db = get_session() + user = db.query(User).filter(User.id == token["sub"]).first() + db.close() + return user + except Exception as e: + raise HTTPException(status_code=401, detail="Invalid API Key") + else: + return authorization + + +def send_email( + email: str, + subject: str, + body: str, + attachment_content=None, + attachment_file_type=None, + attachment_file_name=None, +): + message = Mail( + from_email=getenv("SENDGRID_FROM_EMAIL"), + to_emails=email, + subject=subject, + html_content=body, + ) + if ( + attachment_content != None + and attachment_file_type != None + and attachment_file_name != None + ): + attachment = Attachment( + FileContent(attachment_content), + FileName(attachment_file_name), + FileType(attachment_file_type), + Disposition("attachment"), + ) + message.attachment = attachment + + try: + response = SendGridAPIClient(getenv("SENDGRID_API_KEY")).send(message) + except Exception as e: + print(e) + raise HTTPException(status_code=400, detail="Email could not be sent.") + if response.status_code != 202: + raise HTTPException(status_code=400, detail="Email could not be sent.") + return None + + +class MagicalAuth: + def __init__(self, token: str = None): + encryption_key = getenv("ENCRYPTION_SECRET") + self.link = getenv("MAGIC_LINK_URL") + self.encryption_key = f'{encryption_key}{datetime.now().strftime("%Y%m%d")}' + self.token = ( + str(token) + .replace("%2B", "+") + .replace("%2F", "/") + .replace("%3D", "=") + .replace("%20", " ") + .replace("%3A", ":") + .replace("%3F", "?") + .replace("%26", "&") + .replace("%23", "#") + .replace("%3B", ";") + .replace("%40", "@") + .replace("%21", "!") + .replace("%24", "$") + .replace("%27", "'") + .replace("%28", "(") + .replace("%29", ")") + .replace("%2A", "*") + .replace("%2C", ",") + .replace("%3B", ";") + .replace("%5B", "[") + .replace("%5D", "]") + .replace("%7B", "{") + .replace("%7D", "}") + .replace("%7C", "|") + .replace("%5C", "\\") + .replace("%5E", "^") + .replace("%60", "`") + .replace("%7E", "~") + .replace("Bearer ", "") + .replace("bearer ", "") + if token + else None + ) + try: + # Decode jwt + decoded = jwt.decode( + jwt=token, key=self.encryption_key, algorithms=["HS256"] + ) + self.email = decoded["email"] + self.token = token + except: + self.email = None + self.token = None + + def user_exists(self, email: str = None): + self.email = email.lower() + session = get_session() + user = session.query(User).filter(User.email == self.email).first() + session.close() + if not user: + raise HTTPException(status_code=404, detail="User not found") + return True + + def add_failed_login(self, ip_address): + session = get_session() + user = session.query(User).filter(User.email == self.email).first() + if user is not None: + failed_login = FailedLogins(user_id=user.id, ip_address=ip_address) + session.add(failed_login) + session.commit() + session.close() + + def count_failed_logins(self): + session = get_session() + user = session.query(User).filter(User.email == self.email).first() + if user is None: + session.close() + return 0 + failed_logins = ( + session.query(FailedLogins) + .filter(FailedLogins.user_id == user.id) + .filter(FailedLogins.created_at >= datetime.now() - timedelta(hours=24)) + .count() + ) + session.close() + return failed_logins + + def send_magic_link(self, ip_address, login: Login, referrer=None): + self.email = login.email.lower() + session = get_session() + user = session.query(User).filter(User.email == self.email).first() + session.close() + if user is None: + raise HTTPException(status_code=404, detail="User not found") + if not pyotp.TOTP(user.mfa_token).verify(login.token): + self.add_failed_login(ip_address=ip_address) + raise HTTPException( + status_code=401, detail="Invalid MFA token. Please try again." + ) + self.token = jwt.encode( + { + "sub": str(user.id), + "email": self.email, + "admin": user.admin, + "exp": datetime.utcnow() + timedelta(hours=24), + }, + self.encryption_key, + algorithm="HS256", + ) + token = ( + self.token.replace("+", "%2B") + .replace("/", "%2F") + .replace("=", "%3D") + .replace(" ", "%20") + .replace(":", "%3A") + .replace("?", "%3F") + .replace("&", "%26") + .replace("#", "%23") + .replace(";", "%3B") + .replace("@", "%40") + .replace("!", "%21") + .replace("$", "%24") + .replace("'", "%27") + .replace("(", "%28") + .replace(")", "%29") + .replace("*", "%2A") + .replace(",", "%2C") + .replace(";", "%3B") + .replace("[", "%5B") + .replace("]", "%5D") + .replace("{", "%7B") + .replace("}", "%7D") + .replace("|", "%7C") + .replace("\\", "%5C") + .replace("^", "%5E") + .replace("`", "%60") + .replace("~", "%7E") + ) + if referrer is not None: + self.link = referrer + magic_link = f"{self.link}?token={token}" + if ( + getenv("SENDGRID_API_KEY") != "" + and str(getenv("SENDGRID_API_KEY")).lower() != "none" + and getenv("SENDGRID_FROM_EMAIL") != "" + and str(getenv("SENDGRID_FROM_EMAIL")).lower() != "none" + ): + send_email( + email=self.email, + subject="Magic Link", + body=f"Click here to log in", + ) + else: + return magic_link + # Upon clicking the link, the front end will call the login method and save the email and encrypted_id in the session + return f"A login link has been sent to {self.email}, please check your email and click the link to log in. The link will expire in 24 hours." + + def login(self, ip_address): + """ " + Login method to verify the token and return the user object + + :param ip_address: IP address of the user + :return: User object + """ + session = get_session() + failures = self.count_failed_logins() + if failures >= 50: + raise HTTPException( + status_code=429, + detail="Too many failed login attempts today. Please try again tomorrow.", + ) + try: + user_info = jwt.decode( + jwt=self.token, key=self.encryption_key, algorithms=["HS256"] + ) + except: + self.add_failed_login(ip_address=ip_address) + raise HTTPException( + status_code=401, + detail="Invalid login token. Please log out and try again.", + ) + user_id = user_info["sub"] + user = session.query(User).filter(User.id == user_id).first() + session.close() + if user is None: + raise HTTPException(status_code=404, detail="User not found") + if str(user.id) == str(user_id): + return user + self.add_failed_login(ip_address=ip_address) + raise HTTPException( + status_code=401, + detail="Invalid login token. Please log out and try again.", + ) + + def register( + self, + new_user: Register, + ): + new_user.email = new_user.email.lower() + self.email = new_user.email + allowed_domains = getenv("ALLOWED_DOMAINS") + if allowed_domains is None or allowed_domains == "": + allowed_domains = "*" + if allowed_domains != "*": + if "," in allowed_domains: + allowed_domains = allowed_domains.split(",") + else: + allowed_domains = [allowed_domains] + domain = self.email.split("@")[1] + if domain not in allowed_domains: + raise HTTPException( + status_code=403, + detail="Registration is not allowed for this domain.", + ) + session = get_session() + user = session.query(User).filter(User.email == self.email).first() + if user is not None: + session.close() + raise HTTPException( + status_code=409, detail="User already exists with this email." + ) + mfa_token = pyotp.random_base32() + user = User( + mfa_token=mfa_token, + **new_user.model_dump(), + ) + session.add(user) + session.commit() + session.close() + # Send registration webhook out to third party application such as AGiXT to create a user there. + registration_webhook = getenv("REGISTRATION_WEBHOOK") + if registration_webhook: + try: + requests.post( + registration_webhook, + json={"email": self.email}, + headers={"Authorization": getenv("ENCRYPTION_SECRET")}, + ) + except Exception as e: + pass + # Return mfa_token for QR code generation + return mfa_token + + def update_user(self, **kwargs): + user = verify_api_key(self.token) + if user is None: + raise HTTPException(status_code=404, detail="User not found") + session = get_session() + user = session.query(User).filter(User.id == user.id).first() + allowed_keys = list(UserInfo.__annotations__.keys()) + for key, value in kwargs.items(): + if key in allowed_keys: + setattr(user, key, value) + session.commit() + session.close() + return "User updated successfully" + + def delete_user(self): + user = verify_api_key(self.token) + if user is None: + raise HTTPException(status_code=404, detail="User not found") + session = get_session() + user = session.query(User).filter(User.id == user.id).first() + user.is_active = False + session.commit() + session.close() + return "User deleted successfully" diff --git a/agixt/Memories.py b/agixt/Memories.py index 3752e5781aa..6c2040415a9 100644 --- a/agixt/Memories.py +++ b/agixt/Memories.py @@ -14,7 +14,7 @@ from datetime import datetime from collections import Counter from typing import List -from Defaults import getenv, DEFAULT_USER +from Globals import getenv, DEFAULT_USER logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -154,9 +154,9 @@ def __init__( global DEFAULT_USER self.agent_name = agent_name if not DEFAULT_USER: - DEFAULT_USER = "USER" + DEFAULT_USER = "user" if not user: - user = "USER" + user = "user" if user != DEFAULT_USER: self.collection_name = f"{snake(user)}_{snake(agent_name)}" else: diff --git a/agixt/Models.py b/agixt/Models.py index 26a5e7197ea..2fb75888dfa 100644 --- a/agixt/Models.py +++ b/agixt/Models.py @@ -1,6 +1,6 @@ from pydantic import BaseModel from typing import Optional, Dict, List, Any, Union -from Defaults import DEFAULT_USER +from Globals import DEFAULT_USER class AgentName(BaseModel): @@ -266,7 +266,7 @@ class CommandExecution(BaseModel): conversation_name: str = "AGiXT Terminal Command Execution" -class User(BaseModel): +class WebhookUser(BaseModel): email: str agent_name: Optional[str] = "" settings: Optional[Dict[str, Any]] = {} @@ -275,5 +275,26 @@ class User(BaseModel): github_repos: Optional[List[str]] = [] -class User_fb(BaseModel): - email: str = DEFAULT_USER +# Auth user models +class Login(BaseModel): + email: str + token: str + + +class Register(BaseModel): + email: str + first_name: str + last_name: str + company_name: str + job_title: str + + +class UserInfo(BaseModel): + first_name: str + last_name: str + company_name: str + job_title: str + + +class Detail(BaseModel): + detail: str diff --git a/agixt/db/Prompts.py b/agixt/Prompts.py similarity index 98% rename from agixt/db/Prompts.py rename to agixt/Prompts.py index ee8ac50467d..544e3fa422c 100644 --- a/agixt/db/Prompts.py +++ b/agixt/Prompts.py @@ -1,5 +1,5 @@ -from DBConnection import Prompt, PromptCategory, Argument, User, get_session -from Defaults import DEFAULT_USER +from DB import Prompt, PromptCategory, Argument, User, get_session +from Globals import DEFAULT_USER class Prompts: diff --git a/agixt/Providers.py b/agixt/Providers.py index da151f49201..9c9234294d1 100644 --- a/agixt/Providers.py +++ b/agixt/Providers.py @@ -5,7 +5,7 @@ import os import inspect import logging -from Defaults import getenv +from Globals import getenv logging.basicConfig( level=getenv("LOG_LEVEL"), diff --git a/agixt/db/imports.py b/agixt/SeedImports.py similarity index 98% rename from agixt/db/imports.py rename to agixt/SeedImports.py index 59cf487594f..ff854e0a443 100644 --- a/agixt/db/imports.py +++ b/agixt/SeedImports.py @@ -2,7 +2,7 @@ import json import yaml import logging -from DBConnection import ( +from DB import ( get_session, Provider, ProviderSetting, @@ -17,8 +17,8 @@ User, ) from Providers import get_providers, get_provider_options -from db.Agent import add_agent -from Defaults import getenv, DEFAULT_USER +from Agent import add_agent +from Globals import getenv, DEFAULT_USER logging.basicConfig( level=getenv("LOG_LEVEL"), @@ -186,7 +186,7 @@ def import_chains(user=DEFAULT_USER): if not chain_files: logging.info(f"No JSON files found in chains directory.") return - from db.Chain import Chain + from Chain import Chain chain_importer = Chain(user=user) for file in chain_files: @@ -406,7 +406,7 @@ def import_all_data(): if user_count == 0: # Create the default user logging.info("Creating default admin user...") - user = User(email=DEFAULT_USER, role="admin") + user = User(email=DEFAULT_USER, admin=True) session.add(user) session.commit() logging.info("Default user created.") diff --git a/agixt/Tunnel.py b/agixt/Tunnel.py index 9f9ac779c2d..2d51484a38e 100644 --- a/agixt/Tunnel.py +++ b/agixt/Tunnel.py @@ -1,5 +1,5 @@ import logging -from Defaults import getenv +from Globals import getenv logging.basicConfig( level=getenv("LOG_LEVEL"), diff --git a/agixt/Websearch.py b/agixt/Websearch.py index f5b074b7cb1..74d7b4cace4 100644 --- a/agixt/Websearch.py +++ b/agixt/Websearch.py @@ -11,7 +11,7 @@ from bs4 import BeautifulSoup from typing import List from ApiClient import Agent, Conversations -from Defaults import getenv, get_tokens +from Globals import getenv, get_tokens from readers.youtube import YoutubeReader from readers.github import GithubReader @@ -517,10 +517,7 @@ async def websearch_agent( if len(search_string) > 0: links = [] logging.info(f"Searching for: {search_string}") - if ( - self.searx_instance_url != "" - and self.searx_instance_url is not None - ): + if self.searx_instance_url != "": links = await self.search(query=search_string) else: links = await self.ddg_search(query=search_string) diff --git a/agixt/app.py b/agixt/app.py index 0647fbdb1df..e1644c2d6e7 100644 --- a/agixt/app.py +++ b/agixt/app.py @@ -12,7 +12,7 @@ from endpoints.Memory import app as memory_endpoints from endpoints.Prompt import app as prompt_endpoints from endpoints.Provider import app as provider_endpoints -from Defaults import getenv +from Globals import getenv os.environ["TOKENIZERS_PARALLELISM"] = "false" diff --git a/agixt/db/User.py b/agixt/db/User.py deleted file mode 100644 index 6688020f5c9..00000000000 --- a/agixt/db/User.py +++ /dev/null @@ -1,55 +0,0 @@ -from DBConnection import User, get_session -from db.Agent import add_agent -from Defaults import getenv -from agixtsdk import AGiXTSDK - - -def is_agixt_admin(email: str = "", api_key: str = ""): - if api_key == getenv("AGIXT_API_KEY"): - return True - session = get_session() - user = session.query(User).filter_by(email=email).first() - if not user: - return False - if user.role == "admin": - return True - return False - - -def create_user( - api_key: str, - email: str, - role: str = "user", - agent_name: str = "", - settings: dict = {}, - commands: dict = {}, - training_urls: list = [], - github_repos: list = [], - ApiClient: AGiXTSDK = AGiXTSDK(), -): - if not is_agixt_admin(email=email, api_key=api_key): - return {"error": "Access Denied"}, 403 - session = get_session() - email = email.lower() - user_exists = session.query(User).filter_by(email=email).first() - if user_exists: - session.close() - return {"error": "User already exists"}, 400 - user = User(email=email, role=role.lower()) - session.add(user) - session.commit() - session.close() - if agent_name != "" and agent_name is not None: - add_agent( - agent_name=agent_name, - provider_settings=settings, - commands=commands, - user=email, - ) - if training_urls != []: - for url in training_urls: - ApiClient.learn_url(agent_name=agent_name, url=url) - if github_repos != []: - for repo in github_repos: - ApiClient.learn_github_repo(agent_name=agent_name, github_repo=repo) - return {"status": "Success"}, 200 diff --git a/agixt/endpoints/Agent.py b/agixt/endpoints/Agent.py index ef7949c7c12..c63133a0377 100644 --- a/agixt/endpoints/Agent.py +++ b/agixt/endpoints/Agent.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, HTTPException, Depends, Header from Interactions import Interactions from Websearch import Websearch -from Defaults import getenv +from Globals import getenv from ApiClient import ( Agent, add_agent, diff --git a/agixt/endpoints/Auth.py b/agixt/endpoints/Auth.py new file mode 100644 index 00000000000..44e5108fa34 --- /dev/null +++ b/agixt/endpoints/Auth.py @@ -0,0 +1,132 @@ +from fastapi import APIRouter, Request, Header, Depends, HTTPException +from Models import Detail, Login, UserInfo, Register +from MagicalAuth import MagicalAuth, verify_api_key, webhook_create_user +from ApiClient import get_api_client, is_admin +from Models import WebhookUser +from Globals import getenv +import pyotp + +app = APIRouter() + + +@app.post("/v1/user") +def register(register: Register): + mfa_token = MagicalAuth().register(new_user=register) + totp = pyotp.TOTP(mfa_token) + otp_uri = totp.provisioning_uri(name=register.email, issuer_name=getenv("APP_NAME")) + return {"otp_uri": otp_uri} + + +@app.get("/v1/user/exists", response_model=bool, summary="Check if user exists") +def get_user(email: str) -> bool: + try: + return MagicalAuth().user_exists(email=email) + except: + return False + + +@app.get( + "/v1/user", + dependencies=[Depends(verify_api_key)], + summary="Get user details", +) +def log_in( + request: Request, + authorization: str = Header(None), +): + user_data = MagicalAuth(token=authorization).login(ip_address=request.client.host) + return { + "email": user_data.email, + "first_name": user_data.first_name, + "last_name": user_data.last_name, + } + + +@app.post( + "/v1/login", + response_model=Detail, + summary="Login with email and OTP token", +) +async def send_magic_link(request: Request, login: Login): + auth = MagicalAuth() + data = await request.json() + referrer = None + if "referrer" in data: + referrer = data["referrer"] + magic_link = auth.send_magic_link( + ip_address=request.client.host, login=login, referrer=referrer + ) + return Detail(detail=magic_link) + + +@app.put( + "/v1/user", + dependencies=[Depends(verify_api_key)], + response_model=Detail, + summary="Update user details", +) +def update_user(update: UserInfo, request: Request, authorization: str = Header(None)): + user = MagicalAuth(token=authorization).update_user( + ip_address=request.client.host, **update.model_dump() + ) + return Detail(detail=user) + + +# Delete user +@app.delete( + "/v1/user", + dependencies=[Depends(verify_api_key)], + response_model=Detail, + summary="Delete user", +) +def delete_user( + user=Depends(verify_api_key), + authorization: str = Header(None), +): + MagicalAuth(token=authorization).delete_user() + return Detail(detail="User deleted successfully.") + + +# Webhook user creations from other applications +@app.post("/api/user", tags=["User"]) +async def createuser( + account: WebhookUser, + authorization: str = Header(None), + user=Depends(verify_api_key), +): + if is_admin(email=user, api_key=authorization) != True: + raise HTTPException(status_code=403, detail="Access Denied") + ApiClient = get_api_client(authorization=authorization) + return webhook_create_user( + api_key=authorization, + email=account.email, + role="user", + agent_name=account.agent_name, + settings=account.settings, + commands=account.commands, + training_urls=account.training_urls, + github_repos=account.github_repos, + ApiClient=ApiClient, + ) + + +@app.post("/api/admin", tags=["User"]) +async def createadmin( + account: WebhookUser, + authorization: str = Header(None), + user=Depends(verify_api_key), +): + if is_admin(email=user, api_key=authorization) != True: + raise HTTPException(status_code=403, detail="Access Denied") + ApiClient = get_api_client(authorization=authorization) + return webhook_create_user( + api_key=authorization, + email=account.email, + role="admin", + agent_name=account.agent_name, + settings=account.settings, + commands=account.commands, + training_urls=account.training_urls, + github_repos=account.github_repos, + ApiClient=ApiClient, + ) diff --git a/agixt/endpoints/Completions.py b/agixt/endpoints/Completions.py index c6717e3d9b8..1362f282dd7 100644 --- a/agixt/endpoints/Completions.py +++ b/agixt/endpoints/Completions.py @@ -2,7 +2,7 @@ import base64 import uuid from fastapi import APIRouter, Depends, Header -from Defaults import get_tokens +from Globals import get_tokens from ApiClient import Agent, verify_api_key, get_api_client from providers.default import DefaultProvider from fastapi import UploadFile, File, Form diff --git a/agixt/endpoints/Provider.py b/agixt/endpoints/Provider.py index 3f0c458e11b..256459f7def 100644 --- a/agixt/endpoints/Provider.py +++ b/agixt/endpoints/Provider.py @@ -6,7 +6,7 @@ get_providers_with_settings, get_providers_by_service, ) -from ApiClient import verify_api_key, DB_CONNECTED, get_api_client, is_admin +from ApiClient import verify_api_key, get_api_client, is_admin from typing import Any app = APIRouter() @@ -67,46 +67,3 @@ async def get_embed_providers(user=Depends(verify_api_key)): ) async def get_embedder_info(user=Depends(verify_api_key)) -> Dict[str, Any]: return {"embedders": get_providers_by_service(service="embeddings")} - - -if DB_CONNECTED: - from db.User import create_user - from Models import User - - @app.post("/api/user", tags=["User"]) - async def createuser( - account: User, authorization: str = Header(None), user=Depends(verify_api_key) - ): - if is_admin(email=user, api_key=authorization) != True: - raise HTTPException(status_code=403, detail="Access Denied") - ApiClient = get_api_client(authorization=authorization) - return create_user( - api_key=authorization, - email=account.email, - role="user", - agent_name=account.agent_name, - settings=account.settings, - commands=account.commands, - training_urls=account.training_urls, - github_repos=account.github_repos, - ApiClient=ApiClient, - ) - - @app.post("/api/admin", tags=["User"]) - async def createadmin( - account: User, authorization: str = Header(None), user=Depends(verify_api_key) - ): - if is_admin(email=user, api_key=authorization) != True: - raise HTTPException(status_code=403, detail="Access Denied") - ApiClient = get_api_client(authorization=authorization) - return create_user( - api_key=authorization, - email=account.email, - role="admin", - agent_name=account.agent_name, - settings=account.settings, - commands=account.commands, - training_urls=account.training_urls, - github_repos=account.github_repos, - ApiClient=ApiClient, - ) diff --git a/agixt/extensions/agixt_actions.py b/agixt/extensions/agixt_actions.py index 82455c7250e..8bd4e7b26d6 100644 --- a/agixt/extensions/agixt_actions.py +++ b/agixt/extensions/agixt_actions.py @@ -163,7 +163,7 @@ def __init__(self, **kwargs): "Strip CSV Data from Code Block": self.get_csv_from_response, "Convert a string to a Pydantic model": self.convert_string_to_pydantic_model, } - user = kwargs["user"] if "user" in kwargs else "USER" + user = kwargs["user"] if "user" in kwargs else "user" for chain in Chain(user=user).get_chains(): self.commands[chain] = self.run_chain self.command_name = ( diff --git a/agixt/fb/Agent.py b/agixt/fb/Agent.py deleted file mode 100644 index 97865409b2f..00000000000 --- a/agixt/fb/Agent.py +++ /dev/null @@ -1,389 +0,0 @@ -import os -import json -import glob -import shutil -import importlib -import numpy as np -from inspect import signature, Parameter -from Providers import Providers -from Extensions import Extensions -from Defaults import DEFAULT_SETTINGS -from datetime import datetime, timezone, timedelta - - -def get_agent_file_paths(agent_name, user="USER"): - base_path = os.path.join(os.getcwd(), "agents") - folder_path = os.path.normpath(os.path.join(base_path, agent_name)) - config_path = os.path.normpath(os.path.join(folder_path, "config.json")) - if not config_path.startswith(base_path) or not folder_path.startswith(base_path): - raise ValueError("Invalid path, agent name must not contain slashes.") - if not os.path.exists(folder_path): - os.mkdir(folder_path) - return config_path, folder_path - - -def add_agent(agent_name, provider_settings=None, commands={}, user="USER"): - if not agent_name: - return "Agent name cannot be empty." - provider_settings = ( - DEFAULT_SETTINGS - if not provider_settings or provider_settings == {} - else provider_settings - ) - config_path, folder_path = get_agent_file_paths(agent_name=agent_name) - if provider_settings is None or provider_settings == "" or provider_settings == {}: - provider_settings = DEFAULT_SETTINGS - settings = json.dumps( - { - "commands": commands, - "settings": provider_settings, - } - ) - # Write the settings to the agent config file - with open(config_path, "w") as f: - f.write(settings) - return {"message": f"Agent {agent_name} created."} - - -def delete_agent(agent_name, user="USER"): - config_path, folder_path = get_agent_file_paths(agent_name=agent_name) - try: - if os.path.exists(folder_path): - shutil.rmtree(folder_path) - return {"message": f"Agent {agent_name} deleted."}, 200 - except: - return {"message": f"Agent {agent_name} could not be deleted."}, 400 - - -def rename_agent(agent_name, new_name, user="USER"): - config_path, folder_path = get_agent_file_paths(agent_name=agent_name) - base_path = os.path.join(os.getcwd(), "agents") - new_agent_folder = os.path.normpath(os.path.join(base_path, new_name)) - if not new_agent_folder.startswith(base_path): - raise ValueError("Invalid path, agent name must not contain slashes.") - - if os.path.exists(folder_path): - # Check if the new name is already taken - if os.path.exists(new_agent_folder): - # Add a number to the end of the new name - i = 1 - while os.path.exists(new_agent_folder): - i += 1 - new_name = f"{new_name}_{i}" - new_agent_folder = os.path.normpath(os.path.join(base_path, new_name)) - if not new_agent_folder.startswith(base_path): - raise ValueError("Invalid path, agent name must not contain slashes.") - os.rename(folder_path, new_agent_folder) - return {"message": f"Agent {agent_name} renamed to {new_name}."}, 200 - - -def get_agents(user="USER"): - agents_dir = "agents" - if not os.path.exists(agents_dir): - os.makedirs(agents_dir) - agents = [ - dir_name - for dir_name in os.listdir(agents_dir) - if os.path.isdir(os.path.join(agents_dir, dir_name)) - ] - output = [] - if agents: - for agent in agents: - agent_config = Agent(agent_name=agent, user=user).get_agent_config() - if "settings" not in agent_config: - agent_config["settings"] = {} - if "training" in agent_config["settings"]: - if str(agent_config["settings"]["training"]).lower() == "true": - output.append({"name": agent, "status": True}) - else: - output.append({"name": agent, "status": False}) - else: - output.append({"name": agent, "status": False}) - return output - - -class Agent: - def __init__(self, agent_name=None, user="USER", ApiClient=None): - self.USER = user - self.agent_name = agent_name if agent_name is not None else "AGiXT" - self.config_path, self.folder_path = get_agent_file_paths( - agent_name=self.agent_name - ) - self.AGENT_CONFIG = self.get_agent_config() - if "settings" not in self.AGENT_CONFIG: - self.AGENT_CONFIG["settings"] = {} - self.PROVIDER_SETTINGS = self.AGENT_CONFIG["settings"] - for setting in DEFAULT_SETTINGS: - if setting not in self.PROVIDER_SETTINGS: - self.PROVIDER_SETTINGS[setting] = DEFAULT_SETTINGS[setting] - self.AI_PROVIDER = self.PROVIDER_SETTINGS["provider"] - self.PROVIDER = Providers( - name=self.AI_PROVIDER, ApiClient=ApiClient, **self.PROVIDER_SETTINGS - ) - self._load_agent_config_keys(["AI_MODEL", "AI_TEMPERATURE", "MAX_TOKENS"]) - tts_provider = ( - self.AGENT_CONFIG["settings"]["tts_provider"] - if "tts_provider" in self.AGENT_CONFIG["settings"] - else "None" - ) - if tts_provider != "None" and tts_provider != None and tts_provider != "": - self.TTS_PROVIDER = Providers( - name=tts_provider, ApiClient=ApiClient, **self.PROVIDER_SETTINGS - ) - else: - self.TTS_PROVIDER = None - transcription_provider = ( - self.AGENT_CONFIG["settings"]["transcription_provider"] - if "transcription_provider" in self.AGENT_CONFIG["settings"] - else "default" - ) - self.TRANSCRIPTION_PROVIDER = Providers( - name=transcription_provider, ApiClient=ApiClient, **self.PROVIDER_SETTINGS - ) - translation_provider = ( - self.AGENT_CONFIG["settings"]["translation_provider"] - if "translation_provider" in self.AGENT_CONFIG["settings"] - else "default" - ) - self.TRANSLATION_PROVIDER = Providers( - name=translation_provider, ApiClient=ApiClient, **self.PROVIDER_SETTINGS - ) - image_provider = ( - self.AGENT_CONFIG["settings"]["image_provider"] - if "image_provider" in self.AGENT_CONFIG["settings"] - else "default" - ) - self.IMAGE_PROVIDER = Providers( - name=image_provider, ApiClient=ApiClient, **self.PROVIDER_SETTINGS - ) - embeddings_provider = ( - self.AGENT_CONFIG["settings"]["embeddings_provider"] - if "embeddings_provider" in self.AGENT_CONFIG["settings"] - else "default" - ) - self.EMBEDDINGS_PROVIDER = Providers( - name=embeddings_provider, ApiClient=ApiClient, **self.PROVIDER_SETTINGS - ) - if hasattr(self.EMBEDDINGS_PROVIDER, "chunk_size"): - self.chunk_size = self.EMBEDDINGS_PROVIDER.chunk_size - else: - self.chunk_size = 256 - self.embedder = self.EMBEDDINGS_PROVIDER.embedder - if "AI_MODEL" in self.PROVIDER_SETTINGS: - self.AI_MODEL = self.PROVIDER_SETTINGS["AI_MODEL"] - if self.AI_MODEL == "": - self.AI_MODEL = "default" - else: - self.AI_MODEL = "openassistant" - if "embedder" in self.PROVIDER_SETTINGS: - self.EMBEDDER = self.PROVIDER_SETTINGS["embedder"] - else: - if self.AI_PROVIDER == "openai": - self.EMBEDDER = "openai" - else: - self.EMBEDDER = "default" - if "MAX_TOKENS" in self.PROVIDER_SETTINGS: - self.MAX_TOKENS = self.PROVIDER_SETTINGS["MAX_TOKENS"] - else: - self.MAX_TOKENS = 4000 - self.commands = self.load_commands() - self.available_commands = Extensions( - agent_name=self.agent_name, - agent_config=self.AGENT_CONFIG, - ApiClient=ApiClient, - user=user, - ).get_available_commands() - self.clean_agent_config_commands() - - async def inference(self, prompt: str, tokens: int = 0, images: list = []): - if not prompt: - return "" - answer = await self.PROVIDER.inference( - prompt=prompt, tokens=tokens, images=images - ) - return answer.replace("\_", "_") - - def embeddings(self, input) -> np.ndarray: - return self.embedder(input=input) - - async def transcribe_audio(self, audio_path: str): - return await self.TRANSCRIPTION_PROVIDER.transcribe_audio(audio_path=audio_path) - - async def translate_audio(self, audio_path: str): - return await self.TRANSLATION_PROVIDER.translate_audio(audio_path=audio_path) - - async def generate_image(self, prompt: str): - return await self.IMAGE_PROVIDER.generate_image(prompt=prompt) - - async def text_to_speech(self, text: str): - if self.TTS_PROVIDER is not None: - return await self.TTS_PROVIDER.text_to_speech(text=text) - - def _load_agent_config_keys(self, keys): - for key in keys: - if key in self.AGENT_CONFIG: - setattr(self, key, self.AGENT_CONFIG[key]) - - def clean_agent_config_commands(self): - for command in self.commands: - friendly_name = command[0] - if friendly_name not in self.AGENT_CONFIG["commands"]: - self.AGENT_CONFIG["commands"][friendly_name] = False - for command in list(self.AGENT_CONFIG["commands"]): - if command not in [cmd[0] for cmd in self.commands]: - del self.AGENT_CONFIG["commands"][command] - with open(self.config_path, "w") as f: - json.dump(self.AGENT_CONFIG, f) - - def get_commands_string(self): - if len(self.available_commands) == 0: - return "" - working_dir = ( - self.AGENT_CONFIG["WORKING_DIRECTORY"] - if "WORKING_DIRECTORY" in self.AGENT_CONFIG - else os.path.join(os.getcwd(), "WORKSPACE") - ) - verbose_commands = f"### Available Commands\n**The assistant has commands available to use if they would be useful to provide a better user experience.**\nIf a file needs saved, the assistant's working directory is {working_dir}, use that as the file path.\n\n" - verbose_commands += "**See command execution examples of commands that the assistant has access to below:**\n" - for command in self.available_commands: - command_args = json.dumps(command["args"]) - command_args = command_args.replace( - '""', - '"The assistant will fill in the value based on relevance to the conversation."', - ) - verbose_commands += ( - f"\n- #execute('{command['friendly_name']}', {command_args})" - ) - verbose_commands += "\n\n**To execute an available command, the assistant can reference the examples and the command execution response will be replaced with the commands output for the user in the assistants response. The assistant can execute a command anywhere in the response and the commands will be executed in the order they are used.**\n**THE ASSISTANT CANNOT EXECUTE A COMMAND THAT IS NOT ON THE LIST OF EXAMPLES!**\n\n" - return verbose_commands - - def get_provider(self): - config_file = self.get_agent_config() - if "provider" in config_file: - return config_file["provider"] - else: - return "openai" - - def get_command_params(self, func): - params = {} - sig = signature(func) - for name, param in sig.parameters.items(): - if param.default == Parameter.empty: - params[name] = None - else: - params[name] = param.default - return params - - def load_commands(self): - commands = [] - command_files = glob.glob("extensions/*.py") - for command_file in command_files: - module_name = os.path.splitext(os.path.basename(command_file))[0] - module = importlib.import_module(f"extensions.{module_name}") - command_class = getattr(module, module_name.lower())() - if hasattr(command_class, "commands"): - for command_name, command_function in command_class.commands.items(): - params = self.get_command_params(command_function) - commands.append((command_name, command_function.__name__, params)) - return commands - - def get_agent_config(self): - while True: - if os.path.exists(self.config_path): - try: - with open(self.config_path, "r") as f: - file_content = f.read().strip() - if file_content: - return json.loads(file_content) - except: - None - add_agent(agent_name=self.agent_name) - return self.get_agent_config() - - def update_agent_config(self, new_config, config_key): - if os.path.exists(self.config_path): - with open(self.config_path, "r") as f: - current_config = json.load(f) - - # Ensure the config_key is present in the current configuration - if config_key not in current_config: - current_config[config_key] = {} - - # Update the specified key with new_config while preserving other keys and values - for key, value in new_config.items(): - current_config[config_key][key] = value - - # Save the updated configuration back to the file - with open(self.config_path, "w") as f: - json.dump(current_config, f) - return f"Agent {self.agent_name} configuration updated." - else: - return f"Agent {self.agent_name} configuration not found." - - def get_browsed_links(self): - """ - Get the list of URLs that have been browsed by the agent. - - Returns: - list: The list of URLs that have been browsed by the agent. - """ - # They will be stored in the agent's config file as: - # "browsed_links": [{"url": "https://example.com", "timestamp": "2021-01-01T00:00:00Z"}] - return self.AGENT_CONFIG.get("browsed_links", []) - - def browsed_recently(self, url) -> bool: - """ - Check if the given URL has been browsed by the agent within the last 24 hours. - - Args: - url (str): The URL to check. - - Returns: - bool: True if the URL has been browsed within the last 24 hours, False otherwise. - """ - browsed_links = self.get_browsed_links() - if not browsed_links: - return False - for link in browsed_links: - if link["url"] == url: - if link["timestamp"] >= datetime.now(timezone.utc) - timedelta(days=1): - return True - return False - - def add_browsed_link(self, url): - """ - Add a URL to the list of browsed links for the agent. - - Args: - url (str): The URL to add. - - Returns: - str: The response message. - """ - browsed_links = self.get_browsed_links() - # check if the URL has already been browsed - if self.browsed_recently(url): - return "URL has already been browsed recently." - browsed_links.append( - {"url": url, "timestamp": datetime.now(timezone.utc).isoformat()} - ) - self.update_agent_config(browsed_links, "browsed_links") - return "URL added to browsed links." - - def delete_browsed_link(self, url): - """ - Delete a URL from the list of browsed links for the agent. - - Args: - url (str): The URL to delete. - - Returns: - str: The response message. - """ - browsed_links = self.get_browsed_links() - for link in browsed_links: - if link["url"] == url: - browsed_links.remove(link) - self.update_agent_config(browsed_links, "browsed_links") - return "URL deleted from browsed links." - return "URL not found in browsed links." diff --git a/agixt/fb/Chain.py b/agixt/fb/Chain.py deleted file mode 100644 index b465bf6614d..00000000000 --- a/agixt/fb/Chain.py +++ /dev/null @@ -1,284 +0,0 @@ -import os -import json -import logging -from Defaults import getenv - -logging.basicConfig( - level=getenv("LOG_LEVEL"), - format=getenv("LOG_FORMAT"), -) - - -def create_command_suggestion_chain( - agent_name, command_name, command_args, user="USER" -): - chain = Chain() - chains = chain.get_chains() - chain_name = f"{agent_name} Command Suggestions" - if chain_name in chains: - step = int(chain.get_chain(chain_name=chain_name)["steps"][-1]["step"]) + 1 - else: - chain.add_chain(chain_name=chain_name) - step = 1 - chain.add_chain_step( - chain_name=chain_name, - agent_name=agent_name, - step_number=step, - prompt_type="Command", - prompt={ - "command_name": command_name, - **command_args, - }, - ) - return f"The command has been added to a chain called '{agent_name} Command Suggestions' for you to review and execute manually." - - -def get_chain_file_path(chain_name, user="USER"): - base_path = os.path.join(os.getcwd(), "chains") - folder_path = os.path.normpath(os.path.join(base_path, chain_name)) - file_path = os.path.normpath(os.path.join(base_path, f"{chain_name}.json")) - if not file_path.startswith(base_path) or not folder_path.startswith(base_path): - raise ValueError("Invalid path, chain name must not contain slashes.") - if not os.path.exists(folder_path): - os.mkdir(folder_path) - return file_path - - -def get_chain_responses_file_path(chain_name, user="USER"): - base_path = os.path.join(os.getcwd(), "chains") - file_path = os.path.normpath(os.path.join(base_path, chain_name, "responses.json")) - if not file_path.startswith(base_path): - raise ValueError("Invalid path, chain name must not contain slashes.") - return file_path - - -class Chain: - def __init__(self, user="USER"): - self.user = user - - def import_chain(self, chain_name: str, steps: dict): - file_path = get_chain_file_path(chain_name=chain_name) - steps = steps["steps"] if "steps" in steps else steps - with open(file_path, "w") as f: - json.dump({"chain_name": chain_name, "steps": steps}, f) - return f"Chain '{chain_name}' imported." - - def get_chain(self, chain_name): - try: - file_path = get_chain_file_path(chain_name=chain_name) - with open(file_path, "r") as f: - chain_data = json.load(f) - return chain_data - except: - return {} - - def get_chains(self): - chains = [ - f.replace(".json", "") for f in os.listdir("chains") if f.endswith(".json") - ] - return chains - - def add_chain(self, chain_name): - file_path = get_chain_file_path(chain_name=chain_name) - chain_data = {"chain_name": chain_name, "steps": []} - with open(file_path, "w") as f: - json.dump(chain_data, f) - - def rename_chain(self, chain_name, new_name): - file_path = get_chain_file_path(chain_name=chain_name) - new_file_path = get_chain_file_path(chain_name=new_name) - os.rename( - os.path.join(file_path), - os.path.join(new_file_path), - ) - chain_data = self.get_chain(chain_name=new_name) - chain_data["chain_name"] = new_name - with open(new_file_path, "w") as f: - json.dump(chain_data, f) - - def add_chain_step(self, chain_name, step_number, agent_name, prompt_type, prompt): - file_path = get_chain_file_path(chain_name=chain_name) - chain_data = self.get_chain(chain_name=chain_name) - chain_data["steps"].append( - { - "step": step_number, - "agent_name": agent_name, - "prompt_type": prompt_type, - "prompt": prompt, - } - ) - with open(file_path, "w") as f: - json.dump(chain_data, f) - - def update_step(self, chain_name, step_number, agent_name, prompt_type, prompt): - file_path = get_chain_file_path(chain_name=chain_name) - chain_data = self.get_chain(chain_name=chain_name) - for step in chain_data["steps"]: - if step["step"] == step_number: - step["agent_name"] = agent_name - step["prompt_type"] = prompt_type - step["prompt"] = prompt - break - with open(file_path, "w") as f: - json.dump(chain_data, f) - - def delete_step(self, chain_name, step_number): - file_path = get_chain_file_path(chain_name=chain_name) - chain_data = self.get_chain(chain_name=chain_name) - chain_data["steps"] = [ - step for step in chain_data["steps"] if step["step"] != step_number - ] - with open(file_path, "w") as f: - json.dump(chain_data, f) - - def delete_chain(self, chain_name): - file_path = get_chain_file_path(chain_name=chain_name) - os.remove(file_path) - - def get_step(self, chain_name, step_number): - chain_data = self.get_chain(chain_name=chain_name) - for step in chain_data["steps"]: - if step["step"] == step_number: - return step - return None - - def get_steps(self, chain_name): - chain_data = self.get_chain(chain_name=chain_name) - return chain_data["steps"] - - def move_step(self, chain_name, current_step_number, new_step_number): - file_path = get_chain_file_path(chain_name=chain_name) - chain_data = self.get_chain(chain_name=chain_name) - if not 1 <= new_step_number <= len( - chain_data["steps"] - ) or current_step_number not in [step["step"] for step in chain_data["steps"]]: - logging.info(f"Error: Invalid step numbers.") - return - moved_step = None - for step in chain_data["steps"]: - if step["step"] == current_step_number: - moved_step = step - chain_data["steps"].remove(step) - break - for step in chain_data["steps"]: - if new_step_number < current_step_number: - if new_step_number <= step["step"] < current_step_number: - step["step"] += 1 - else: - if current_step_number < step["step"] <= new_step_number: - step["step"] -= 1 - moved_step["step"] = new_step_number - chain_data["steps"].append(moved_step) - chain_data["steps"] = sorted(chain_data["steps"], key=lambda x: x["step"]) - with open(file_path, "w") as f: - json.dump(chain_data, f) - - def get_step_response(self, chain_name, step_number="all"): - file_path = get_chain_responses_file_path(chain_name=chain_name) - try: - with open(file_path, "r") as f: - responses = json.load(f) - if step_number == "all": - return responses - else: - data = responses.get(str(step_number)) - if isinstance(data, dict) and "response" in data: - data = data["response"] - logging.info(f"Step {step_number} response: {data}") - return data - except: - return "" - - def get_chain_responses(self, chain_name): - file_path = get_chain_responses_file_path(chain_name=chain_name) - try: - with open(file_path, "r") as f: - responses = json.load(f) - return responses - except: - return {} - - def get_step_content(self, chain_name, prompt_content, user_input, agent_name): - if isinstance(prompt_content, dict): - new_prompt_content = {} - for arg, value in prompt_content.items(): - if isinstance(value, str): - if "{user_input}" in value: - value = value.replace("{user_input}", user_input) - if "{agent_name}" in value: - value = value.replace("{agent_name}", agent_name) - if "{STEP" in value: - step_count = value.count("{STEP") - for i in range(step_count): - new_step_number = int(value.split("{STEP")[1].split("}")[0]) - step_response = self.get_step_response( - chain_name=chain_name, step_number=new_step_number - ) - if step_response: - resp = ( - step_response[0] - if isinstance(step_response, list) - else step_response - ) - value = value.replace( - f"{{STEP{new_step_number}}}", f"{resp}" - ) - new_prompt_content[arg] = value - return new_prompt_content - elif isinstance(prompt_content, str): - new_prompt_content = prompt_content - if "{user_input}" in prompt_content: - new_prompt_content = new_prompt_content.replace( - "{user_input}", user_input - ) - if "{agent_name}" in new_prompt_content: - new_prompt_content = new_prompt_content.replace( - "{agent_name}", agent_name - ) - if "{STEP" in prompt_content: - step_count = prompt_content.count("{STEP") - for i in range(step_count): - new_step_number = int( - prompt_content.split("{STEP")[1].split("}")[0] - ) - step_response = self.get_step_response( - chain_name=chain_name, step_number=new_step_number - ) - if step_response: - resp = ( - step_response[0] - if isinstance(step_response, list) - else step_response - ) - new_prompt_content = new_prompt_content.replace( - f"{{STEP{new_step_number}}}", f"{resp}" - ) - return new_prompt_content - else: - return prompt_content - - async def update_step_response(self, chain_name, step_number, response): - file_path = get_chain_responses_file_path(chain_name=chain_name) - try: - with open(file_path, "r") as f: - responses = json.load(f) - except: - responses = {} - - if str(step_number) not in responses: - responses[str(step_number)] = response - else: - if isinstance(responses[str(step_number)], dict) and isinstance( - response, dict - ): - responses[str(step_number)].update(response) - elif isinstance(responses[str(step_number)], list): - if isinstance(response, list): - responses[str(step_number)].extend(response) - else: - responses[str(step_number)].append(response) - else: - responses[str(step_number)] = response - - with open(file_path, "w") as f: - json.dump(responses, f) diff --git a/agixt/fb/Conversations.py b/agixt/fb/Conversations.py deleted file mode 100644 index c129143c118..00000000000 --- a/agixt/fb/Conversations.py +++ /dev/null @@ -1,103 +0,0 @@ -from datetime import datetime -import yaml -import os -import logging -from Defaults import getenv, DEFAULT_USER - -logging.basicConfig( - level=getenv("LOG_LEVEL"), - format=getenv("LOG_FORMAT"), -) - - -class Conversations: - def __init__(self, conversation_name=None, user=DEFAULT_USER): - self.conversation_name = conversation_name - self.user = user - - def export_conversation(self): - if not self.conversation_name: - self.conversation_name = f"{str(datetime.now())} Conversation" - history_file = os.path.join("conversations", f"{self.conversation_name}.yaml") - if os.path.exists(history_file): - with open(history_file, "r") as file: - history = yaml.safe_load(file) - return history - return {"interactions": []} - - def get_conversation(self, limit=100, page=1): - history = {"interactions": []} - try: - history_file = os.path.join( - "conversations", f"{self.conversation_name}.yaml" - ) - if os.path.exists(history_file): - with open(history_file, "r") as file: - history = yaml.safe_load(file) - except: - history = self.new_conversation() - return history - - def get_conversations(self): - conversation_dir = os.path.join("conversations") - if os.path.exists(conversation_dir): - conversations = os.listdir(conversation_dir) - return [conversation.split(".")[0] for conversation in conversations] - return [] - - def new_conversation(self, conversation_content=[]): - history = {"interactions": conversation_content} - history_file = os.path.join("conversations", f"{self.conversation_name}.yaml") - os.makedirs(os.path.dirname(history_file), exist_ok=True) - with open(history_file, "w") as file: - yaml.safe_dump(history, file) - return history - - def log_interaction(self, role: str, message: str): - history = self.get_conversation() - history_file = os.path.join("conversations", f"{self.conversation_name}.yaml") - if not os.path.exists(history_file): - os.makedirs(os.path.dirname(history_file), exist_ok=True) - if not history: - history = {"interactions": []} - if "interactions" not in history: - history["interactions"] = [] - history["interactions"].append( - { - "role": role, - "message": message, - "timestamp": datetime.now().strftime("%B %d, %Y %I:%M %p"), - } - ) - with open(history_file, "w") as file: - yaml.safe_dump(history, file) - if role.lower() == "user": - logging.info(f"{self.user}: {message}") - else: - logging.info(f"{role}: {message}") - - def delete_conversation(self): - history_file = os.path.join("conversations", f"{self.conversation_name}.yaml") - if os.path.exists(history_file): - os.remove(history_file) - - def delete_message(self, message): - history = self.get_conversation() - history["interactions"] = [ - interaction - for interaction in history["interactions"] - if interaction["message"] != message - ] - history_file = os.path.join("conversations", f"{self.conversation_name}.yaml") - with open(history_file, "w") as file: - yaml.safe_dump(history, file) - - def update_message(self, message, new_message): - history = self.get_conversation() - for interaction in history["interactions"]: - if interaction["message"] == message: - interaction["message"] = new_message - break - history_file = os.path.join("conversations", f"{self.conversation_name}.yaml") - with open(history_file, "w") as file: - yaml.safe_dump(history, file) diff --git a/agixt/fb/Prompts.py b/agixt/fb/Prompts.py deleted file mode 100644 index 1b845bfd37c..00000000000 --- a/agixt/fb/Prompts.py +++ /dev/null @@ -1,112 +0,0 @@ -import os - - -def get_prompt_file_path(prompt_name, prompt_category="Default", user="USER"): - base_path = os.path.join(os.getcwd(), "prompts") - base_model_path = os.path.normpath( - os.path.join(os.getcwd(), "prompts", prompt_category) - ) - model_prompt_file = os.path.normpath( - os.path.join(base_model_path, f"{prompt_name}.txt") - ) - default_prompt_file = os.path.normpath( - os.path.join(base_path, "Default", f"{prompt_name}.txt") - ) - if ( - not base_model_path.startswith(base_path) - or not model_prompt_file.startswith(base_model_path) - or not default_prompt_file.startswith(base_path) - ): - raise ValueError( - "Invalid file path. Prompt name cannot contain '/', '\\' or '..' in" - ) - if not os.path.exists(base_path): - os.mkdir(base_path) - if not os.path.exists(base_model_path): - os.mkdir(base_model_path) - prompt_file = ( - model_prompt_file if os.path.isfile(model_prompt_file) else default_prompt_file - ) - return prompt_file - - -class Prompts: - def __init__(self, user="USER"): - self.user = user - - def add_prompt(self, prompt_name, prompt, prompt_category="Default"): - # if prompts folder does not exist, create it - file_path = get_prompt_file_path( - prompt_name=prompt_name, prompt_category=prompt_category - ) - # if prompt file does not exist, create it - if not os.path.exists(file_path): - with open(file_path, "w") as f: - f.write(prompt) - - def get_prompt(self, prompt_name, prompt_category="Default"): - prompt_file = get_prompt_file_path( - prompt_name=prompt_name, prompt_category=prompt_category - ) - with open(prompt_file, "r") as f: - prompt = f.read() - return prompt - - def get_prompts(self, prompt_category="Default"): - # Get all files in prompts folder that end in .txt and replace .txt with empty string - prompts = [] - # For each folder in prompts folder, get all files that end in .txt and replace .txt with empty string - base_path = os.path.join("prompts", prompt_category) - base_path = os.path.join(os.getcwd(), "prompts") - base_model_path = os.path.normpath( - os.path.join(os.getcwd(), "prompts", prompt_category) - ) - if not base_model_path.startswith(base_path) or not base_model_path.startswith( - base_model_path - ): - raise ValueError( - "Invalid file path. Prompt name cannot contain '/', '\\' or '..' in" - ) - if not os.path.exists(base_model_path): - os.mkdir(base_model_path) - for file in os.listdir(base_model_path): - if file.endswith(".txt"): - prompts.append(file.replace(".txt", "")) - return prompts - - def get_prompt_categories(self): - prompt_categories = [] - for folder in os.listdir("prompts"): - if os.path.isdir(os.path.join("prompts", folder)): - prompt_categories.append(folder) - return prompt_categories - - def get_prompt_args(self, prompt_text): - # Find anything in the file between { and } and add them to a list to return - prompt_vars = [] - for word in prompt_text.split(): - if word.startswith("{") and word.endswith("}"): - prompt_vars.append(word[1:-1]) - return prompt_vars - - def delete_prompt(self, prompt_name, prompt_category="Default"): - prompt_file = get_prompt_file_path( - prompt_name=prompt_name, prompt_category=prompt_category - ) - os.remove(prompt_file) - - def update_prompt(self, prompt_name, prompt, prompt_category="Default"): - prompt_file = get_prompt_file_path( - prompt_name=prompt_name, prompt_category=prompt_category - ) - with open(prompt_file, "w") as f: - f.write(prompt) - - def rename_prompt(self, prompt_name, new_prompt_name, prompt_category="Default"): - prompt_file = get_prompt_file_path( - prompt_name=prompt_name, prompt_category=prompt_category - ) - new_prompt_file = get_prompt_file_path( - prompt_name=new_prompt_name, prompt_category=prompt_category - ) - os.rename(prompt_file, new_prompt_file) diff --git a/agixt/launch-backend.sh b/agixt/launch-backend.sh index 7ce76dc82d8..a415e693886 100755 --- a/agixt/launch-backend.sh +++ b/agixt/launch-backend.sh @@ -1,11 +1,8 @@ #!/bin/sh echo "Starting AGiXT..." -if [ "$DB_CONNECTED" = "true" ]; then - sleep 15 - echo "Connecting to DB..." - python3 DBConnection.py - sleep 5 -fi +sleep 15 +python3 DB.py +sleep 5 if [ -n "$NGROK_TOKEN" ]; then echo "Starting ngrok..." python3 Tunnel.py diff --git a/agixt/providers/ezlocalai.py b/agixt/providers/ezlocalai.py index 3325bccac97..11290b0cf88 100644 --- a/agixt/providers/ezlocalai.py +++ b/agixt/providers/ezlocalai.py @@ -3,7 +3,7 @@ import re import numpy as np import requests -from Defaults import getenv +from Globals import getenv import uuid from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction diff --git a/agixt/providers/openai.py b/agixt/providers/openai.py index 14fba0223e0..6a095d5b5b0 100644 --- a/agixt/providers/openai.py +++ b/agixt/providers/openai.py @@ -3,7 +3,7 @@ import random import requests import uuid -from Defaults import getenv +from Globals import getenv import numpy as np from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction diff --git a/agixt/version b/agixt/version index d9f31c4efcf..05f629f1b7f 100644 --- a/agixt/version +++ b/agixt/version @@ -1 +1 @@ -v1.5.18 \ No newline at end of file +v1.6.0 \ No newline at end of file diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 8273bd26524..5736bc2d89b 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -1,4 +1,3 @@ -version: "3.7" services: db: image: postgres @@ -14,7 +13,6 @@ services: image: joshxt/agixt:main init: true environment: - - DB_CONNECTED=${DB_CONNECTED:-false} - DATABASE_HOST=${DATABASE_HOST:-db} - DATABASE_USER=${DATABASE_USER:-postgres} - DATABASE_PASSWORD=${DATABASE_PASSWORD:-postgres} @@ -29,6 +27,7 @@ services: - WORKING_DIRECTORY=${WORKING_DIRECTORY:-/agixt/WORKSPACE} - TOKENIZERS_PARALLELISM=False - LOG_LEVEL=${LOG_LEVEL:-INFO} + - AUTH_PROVIDER=${AUTH_PROVIDER:-none} - TZ=${TZ-America/New_York} ports: - "7437:7437" diff --git a/docker-compose.yml b/docker-compose.yml index 84b05ae02da..d03b6fbddfb 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,15 +1,33 @@ -version: "3.7" services: + db: + image: postgres + ports: + - 5432:5432 + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: ${DATABASE_PASSWORD:-postgres} + POSTGRES_DB: postgres + volumes: + - ./data:/var/lib/postgresql/data agixt: image: joshxt/agixt:latest init: true environment: + - DATABASE_HOST=${DATABASE_HOST:-db} + - DATABASE_USER=${DATABASE_USER:-postgres} + - DATABASE_PASSWORD=${DATABASE_PASSWORD:-postgres} + - DATABASE_NAME=${DATABASE_NAME:-postgres} + - DATABASE_PORT=${DATABASE_PORT:-5432} - UVICORN_WORKERS=${UVICORN_WORKERS:-10} + - USING_JWT=${USING_JWT:-false} - AGIXT_API_KEY=${AGIXT_API_KEY} - AGIXT_URI=${AGIXT_URI-http://agixt:7437} + - DISABLED_EXTENSIONS=${DISABLED_EXTENSIONS:-} + - DISABLED_PROVIDERS=${DISABLED_PROVIDERS:-} - WORKING_DIRECTORY=${WORKING_DIRECTORY:-/agixt/WORKSPACE} - TOKENIZERS_PARALLELISM=False - - LOG_LEVEL=${LOG_LEVEL:-ERROR} + - LOG_LEVEL=${LOG_LEVEL:-INFO} + - AUTH_PROVIDER=${AUTH_PROVIDER:-none} - TZ=${TZ-America/New_York} ports: - "7437:7437" diff --git a/tests/completions-tests.ipynb b/tests/completions-tests.ipynb index 74f6a794938..54336981657 100644 --- a/tests/completions-tests.ipynb +++ b/tests/completions-tests.ipynb @@ -32,7 +32,9 @@ "from dotenv import load_dotenv\n", "\n", "load_dotenv()\n", + "import time\n", "\n", + "time.sleep(180) # wait for the AGiXT server to start\n", "# Set your system message, max tokens, temperature, and top p here, or use the defaults.\n", "AGENT_NAME = \"gpt4free\"\n", "AGIXT_SERVER = \"http://localhost:7437\"\n",