Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use cached prompt generation #768

Merged
merged 2 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions skyvern/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ class Settings(BaseSettings):
BITWARDEN_TIMEOUT_SECONDS: int = 60
BITWARDEN_MAX_RETRIES: int = 1

# task generation settings
PROMPT_CACHE_WINDOW_HOURS: int = 24

#####################
# LLM Configuration #
#####################
Expand Down
25 changes: 24 additions & 1 deletion skyvern/forge/sdk/db/client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from datetime import datetime
from datetime import datetime, timedelta
from typing import Any, Sequence

import structlog
from sqlalchemy import and_, delete, func, select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from skyvern.config import settings
from skyvern.exceptions import WorkflowParameterNotFound
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.db.enums import OrganizationAuthTokenType
Expand Down Expand Up @@ -1386,6 +1387,7 @@ async def create_task_generation(
self,
organization_id: str,
user_prompt: str,
user_prompt_hash: str,
url: str | None = None,
navigation_goal: str | None = None,
navigation_payload: dict[str, Any] | None = None,
Expand All @@ -1395,11 +1397,13 @@ async def create_task_generation(
llm: str | None = None,
llm_prompt: str | None = None,
llm_response: str | None = None,
source_task_generation_id: str | None = None,
) -> TaskGeneration:
async with self.Session() as session:
new_task_generation = TaskGenerationModel(
organization_id=organization_id,
user_prompt=user_prompt,
user_prompt_hash=user_prompt_hash,
url=url,
navigation_goal=navigation_goal,
navigation_payload=navigation_payload,
Expand All @@ -1409,8 +1413,27 @@ async def create_task_generation(
llm_prompt=llm_prompt,
llm_response=llm_response,
suggested_title=suggested_title,
source_task_generation_id=source_task_generation_id,
)
session.add(new_task_generation)
await session.commit()
await session.refresh(new_task_generation)
return TaskGeneration.model_validate(new_task_generation)

async def get_task_generation_by_prompt_hash(
    self,
    user_prompt_hash: str,
    query_window_hours: int = settings.PROMPT_CACHE_WINDOW_HOURS,
) -> TaskGeneration | None:
    """Return the most recent reusable task generation for a prompt hash.

    Args:
        user_prompt_hash: Hash of the user prompt to look up.
        query_window_hours: Only rows created within this many hours before
            now are considered. Defaults to the prompt-cache window from
            settings.

    Returns:
        The newest matching TaskGeneration, or None when no row inside the
        window matches.
    """
    # created_at is populated with naive datetime.utcnow() by the model, so the
    # cutoff is computed the same way to keep the comparison consistent.
    before_time = datetime.utcnow() - timedelta(hours=query_window_hours)
    async with self.Session() as session:
        # NOTE(review): the lookup is not scoped by organization_id, so a hit may
        # come from another organization's generation - confirm this is intended.
        query = (
            select(TaskGenerationModel)
            .filter_by(user_prompt_hash=user_prompt_hash)
            # Rows without an llm value are never reused.
            .filter(TaskGenerationModel.llm.is_not(None))
            .filter(TaskGenerationModel.created_at > before_time)
            # Several rows can share a hash inside the window; pick the newest
            # deterministically instead of whatever the database returns first.
            .order_by(TaskGenerationModel.created_at.desc())
            .limit(1)
        )
        task_generation = (await session.scalars(query)).first()
        if not task_generation:
            return None
        return TaskGeneration.model_validate(task_generation)
5 changes: 4 additions & 1 deletion skyvern/forge/sdk/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,8 @@ class TaskGenerationModel(Base):

task_generation_id = Column(String, primary_key=True, default=generate_task_generation_id)
organization_id = Column(String, ForeignKey("organizations.organization_id"), nullable=False)
user_prompt = Column(String, nullable=False, index=True) # The prompt from the user
user_prompt = Column(String, nullable=False)
user_prompt_hash = Column(String, index=True)
url = Column(String)
navigation_goal = Column(String)
navigation_payload = Column(JSON)
Expand All @@ -386,5 +387,7 @@ class TaskGenerationModel(Base):
llm_prompt = Column(String) # The prompt sent to the language model
llm_response = Column(String) # The response from the language model

source_task_generation_id = Column(String, index=True)

created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
29 changes: 29 additions & 0 deletions skyvern/forge/sdk/routes/agent_protocol.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
import hashlib
import uuid
from typing import Annotated, Any

Expand Down Expand Up @@ -54,6 +55,7 @@
base_router = APIRouter()

LOG = structlog.get_logger()
# Look-back window, in hours, passed to the prompt-hash cache lookup in generate_task.
# NOTE(review): this duplicates PROMPT_CACHE_WINDOW_HOURS declared on Settings in
# skyvern/config.py (same value, 24); prefer reading the setting so the two cannot drift.
PROMPT_CACHE_WINDOW_HOURS = 24
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PROMPT_CACHE_WINDOW_HOURS is already defined in config.py. Consider importing it from there instead of redefining it.



@base_router.post("/webhook", tags=["server"])
Expand Down Expand Up @@ -766,6 +768,32 @@ async def generate_task(
data: GenerateTaskRequest,
current_org: Organization = Depends(org_auth_service.get_current_org),
) -> TaskGeneration:
user_prompt = data.prompt
hash_object = hashlib.sha256()
hash_object.update(user_prompt.encode("utf-8"))
user_prompt_hash = hash_object.hexdigest()
# check if there's a same user_prompt within the past x Hours
# in the future, we can use vector db to fetch similar prompts
existing_task_generation = await app.DATABASE.get_task_generation_by_prompt_hash(
user_prompt_hash=user_prompt_hash, query_window_hours=PROMPT_CACHE_WINDOW_HOURS
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider using settings.PROMPT_CACHE_WINDOW_HOURS instead of the hardcoded value 24 for consistency and maintainability.

)
if existing_task_generation:
new_task_generation = await app.DATABASE.create_task_generation(
organization_id=current_org.organization_id,
user_prompt=data.prompt,
user_prompt_hash=user_prompt_hash,
url=existing_task_generation.url,
navigation_goal=existing_task_generation.navigation_goal,
navigation_payload=existing_task_generation.navigation_payload,
data_extraction_goal=existing_task_generation.data_extraction_goal,
extracted_information_schema=existing_task_generation.extracted_information_schema,
llm=existing_task_generation.llm,
llm_prompt=existing_task_generation.llm_prompt,
llm_response=existing_task_generation.llm_response,
source_task_generation_id=existing_task_generation.task_generation_id,
)
return new_task_generation

llm_prompt = prompt_engine.load_prompt("generate-task", user_prompt=data.prompt)
try:
llm_response = await app.LLM_API_HANDLER(prompt=llm_prompt)
Expand All @@ -775,6 +803,7 @@ async def generate_task(
task_generation = await app.DATABASE.create_task_generation(
organization_id=current_org.organization_id,
user_prompt=data.prompt,
user_prompt_hash=user_prompt_hash,
url=parsed_task_generation_obj.url,
navigation_goal=parsed_task_generation_obj.navigation_goal,
navigation_payload=parsed_task_generation_obj.navigation_payload,
Expand Down
12 changes: 5 additions & 7 deletions skyvern/forge/sdk/schemas/task_generations.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from datetime import datetime
from typing import Any

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field


class TaskGenerationBase(BaseModel):
model_config = ConfigDict(from_attributes=True)

organization_id: str | None = None
user_prompt: str | None = None
user_prompt_hash: str | None = None
url: str | None = None
navigation_goal: str | None = None
navigation_payload: dict[str, Any] | None = None
Expand All @@ -20,19 +21,16 @@ class TaskGenerationBase(BaseModel):
suggested_title: str | None = None


class TaskGenerationCreate(TaskGenerationBase):
organization_id: str
user_prompt: str


# A stored task generation as returned by the database layer, which builds it
# with TaskGeneration.model_validate(...) from a TaskGenerationModel row.
# The fields re-declared here are required, whereas TaskGenerationBase declares
# organization_id, user_prompt and user_prompt_hash as optional.
class TaskGeneration(TaskGenerationBase):
task_generation_id: str
organization_id: str
user_prompt: str
# NOTE(review): the model column for user_prompt_hash is declared without
# nullable=False, so rows lacking a hash would fail validation here - confirm.
user_prompt_hash: str

# Timestamps copied from the row's created_at / modified_at columns.
created_at: datetime
modified_at: datetime


class GenerateTaskRequest(BaseModel):
    """Request body for task generation; carries the user's free-text prompt."""

    # The prompt must be at least 1 character long; empty strings are rejected
    # at validation time.
    prompt: str = Field(..., min_length=1)
Loading