diff --git a/alembic/versions/2024_09_03_0356-0de9150bc624_update_task_generation_table_use_user_.py b/alembic/versions/2024_09_03_0356-0de9150bc624_update_task_generation_table_use_user_.py new file mode 100644 index 0000000000..86ae059b7a --- /dev/null +++ b/alembic/versions/2024_09_03_0356-0de9150bc624_update_task_generation_table_use_user_.py @@ -0,0 +1,46 @@ +"""update task_generation table - use user_prompt_hash as the index of a user prompt + +Revision ID: 0de9150bc624 +Revises: 6de11b2be7c8 +Create Date: 2024-09-03 03:56:58.352307+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0de9150bc624" +down_revision: Union[str, None] = "6de11b2be7c8" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("task_generations", sa.Column("user_prompt_hash", sa.String(), nullable=True)) + op.add_column("task_generations", sa.Column("source_task_generation_id", sa.String(), nullable=True)) + op.drop_index("ix_task_generations_user_prompt", table_name="task_generations") + op.create_index( + op.f("ix_task_generations_source_task_generation_id"), + "task_generations", + ["source_task_generation_id"], + unique=False, + ) + op.create_index( + op.f("ix_task_generations_user_prompt_hash"), "task_generations", ["user_prompt_hash"], unique=False + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_task_generations_user_prompt_hash"), table_name="task_generations") + op.drop_index(op.f("ix_task_generations_source_task_generation_id"), table_name="task_generations") + op.create_index("ix_task_generations_user_prompt", "task_generations", ["user_prompt"], unique=False) + op.drop_column("task_generations", "source_task_generation_id") + op.drop_column("task_generations", "user_prompt_hash") + # ### end Alembic commands ### diff --git a/skyvern/config.py b/skyvern/config.py index 7ebbccc679..9641a8f1f2 100644 --- a/skyvern/config.py +++ b/skyvern/config.py @@ -77,6 +77,9 @@ class Settings(BaseSettings): BITWARDEN_TIMEOUT_SECONDS: int = 60 BITWARDEN_MAX_RETRIES: int = 1 + # task generation settings + PROMPT_CACHE_WINDOW_HOURS: int = 24 + ##################### # LLM Configuration # ##################### diff --git a/skyvern/forge/sdk/db/client.py b/skyvern/forge/sdk/db/client.py index 87a80fe1ab..786243742a 100644 --- a/skyvern/forge/sdk/db/client.py +++ b/skyvern/forge/sdk/db/client.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timedelta from typing import Any, Sequence import structlog @@ -6,6 +6,7 @@ from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from skyvern.config import settings from skyvern.exceptions import WorkflowParameterNotFound from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType @@ -1386,6 +1387,7 @@ async def create_task_generation( self, organization_id: str, user_prompt: str, + user_prompt_hash: str, url: str | None = None, navigation_goal: str | None = None, navigation_payload: dict[str, Any] | None = None, @@ -1395,11 +1397,13 @@ async def create_task_generation( llm: str | None = None, llm_prompt: str | None = None, llm_response: str | None = None, + source_task_generation_id: str | None = None, ) -> TaskGeneration: async with self.Session() as session: new_task_generation = TaskGenerationModel( organization_id=organization_id, user_prompt=user_prompt, + user_prompt_hash=user_prompt_hash, url=url, navigation_goal=navigation_goal, navigation_payload=navigation_payload, @@ -1409,8 +1413,27 @@ async def create_task_generation( llm_prompt=llm_prompt, llm_response=llm_response, suggested_title=suggested_title, + source_task_generation_id=source_task_generation_id, ) session.add(new_task_generation) await session.commit() await session.refresh(new_task_generation) return TaskGeneration.model_validate(new_task_generation) + + async def get_task_generation_by_prompt_hash( + self, + user_prompt_hash: str, + query_window_hours: int = settings.PROMPT_ACTION_HISTORY_WINDOW, + ) -> TaskGeneration | None: + before_time = datetime.utcnow() - timedelta(hours=query_window_hours) + async with self.Session() as session: + query = ( + select(TaskGenerationModel) + .filter_by(user_prompt_hash=user_prompt_hash) + .filter(TaskGenerationModel.llm.is_not(None)) + .filter(TaskGenerationModel.created_at > before_time) + ) + task_generation = (await session.scalars(query)).first() + if not task_generation: + return None + return TaskGeneration.model_validate(task_generation) diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 7762c0f573..6cee7f4d23 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -374,7 +374,8 @@ class TaskGenerationModel(Base): task_generation_id = Column(String, primary_key=True, default=generate_task_generation_id) organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=False) - user_prompt = Column(String, nullable=False, index=True) # The prompt from the user + user_prompt = Column(String, nullable=False) + user_prompt_hash = Column(String, index=True) url = Column(String) navigation_goal = Column(String) navigation_payload = Column(JSON) @@ -386,5 +387,7 @@ class TaskGenerationModel(Base): llm_prompt = Column(String) # The prompt sent to the language model llm_response = Column(String) # The response from the language model + source_task_generation_id = Column(String, index=True) + created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) diff --git a/skyvern/forge/sdk/routes/agent_protocol.py b/skyvern/forge/sdk/routes/agent_protocol.py index 0dca5dc6ca..4f604258cc 100644 --- a/skyvern/forge/sdk/routes/agent_protocol.py +++ b/skyvern/forge/sdk/routes/agent_protocol.py @@ -1,4 +1,5 @@ import datetime +import hashlib import uuid from typing import Annotated, Any @@ -54,6 +55,7 @@ base_router = APIRouter() LOG = structlog.get_logger() +PROMPT_CACHE_WINDOW_HOURS = 24 @base_router.post("/webhook", tags=["server"]) @@ -766,6 +768,32 @@ async def generate_task( data: GenerateTaskRequest, current_org: Organization = Depends(org_auth_service.get_current_org), ) -> TaskGeneration: + user_prompt = data.prompt + hash_object = hashlib.sha256() + hash_object.update(user_prompt.encode("utf-8")) + user_prompt_hash = hash_object.hexdigest() + # check if there's a same user_prompt within the past x Hours + # in the future, we can use vector db to fetch similar prompts + existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash( + user_prompt_hash=user_prompt_hash, query_window_hours=PROMPT_CACHE_WINDOW_HOURS + ) + if existing_task_generation: + new_task_generation = await app.DATABASE.create_task_generation( + organization_id=current_org.organization_id, + user_prompt=data.prompt, + user_prompt_hash=user_prompt_hash, + url=existing_task_generation.url, + navigation_goal=existing_task_generation.navigation_goal, + navigation_payload=existing_task_generation.navigation_payload, + data_extraction_goal=existing_task_generation.data_extraction_goal, + extracted_information_schema=existing_task_generation.extracted_information_schema, + llm=existing_task_generation.llm, + llm_prompt=existing_task_generation.llm_prompt, + llm_response=existing_task_generation.llm_response, + source_task_generation_id=existing_task_generation.task_generation_id, + ) + return new_task_generation + llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt) try: llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt) @@ -775,6 +803,7 @@ async def generate_task( task_generation = await app.DATABASE.create_task_generation( organization_id=current_org.organization_id, user_prompt=data.prompt, + user_prompt_hash=user_prompt_hash, url=parsed_task_generation_obj.url, navigation_goal=parsed_task_generation_obj.navigation_goal, navigation_payload=parsed_task_generation_obj.navigation_payload, diff --git a/skyvern/forge/sdk/schemas/task_generations.py b/skyvern/forge/sdk/schemas/task_generations.py index 383724f6be..54342ebd36 100644 --- a/skyvern/forge/sdk/schemas/task_generations.py +++ b/skyvern/forge/sdk/schemas/task_generations.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import Any -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field class TaskGenerationBase(BaseModel): @@ -9,6 +9,7 @@ class TaskGenerationBase(BaseModel): organization_id: str | None = None user_prompt: str | None = None + user_prompt_hash: str | None = None url: str | None = None navigation_goal: str | None = None navigation_payload: dict[str, Any] | None = None @@ -20,19 +21,16 @@ class TaskGenerationBase(BaseModel): suggested_title: str | None = None -class TaskGenerationCreate(TaskGenerationBase): - organization_id: str - user_prompt: str - - class TaskGeneration(TaskGenerationBase): task_generation_id: str organization_id: str user_prompt: str + user_prompt_hash: str created_at: datetime modified_at: datetime class GenerateTaskRequest(BaseModel): - prompt: str + # prompt needs to be at least 1 character long + prompt: str = Field(..., min_length=1)