Skip to content

Commit

Permalink
fix: Add index to messages table (#2280)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattzh72 authored Dec 18, 2024
1 parent 5163f6b commit b135223
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Add composite index to messages table
Revision ID: d6632deac81d
Revises: 54dec07619c4
Create Date: 2024-12-18 13:38:56.511701
"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "d6632deac81d"
down_revision: Union[str, None] = "54dec07619c4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_index("ix_messages_agent_created_at", "messages", ["agent_id", "created_at"], unique=False)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_messages_agent_created_at", table_name="messages")
# ### end Alembic commands ###
3 changes: 2 additions & 1 deletion letta/orm/message.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional

from sqlalchemy import Index
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.custom_columns import ToolCallColumn
Expand All @@ -13,7 +14,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
"""Defines data model for storing Message objects"""

__tablename__ = "messages"
__table_args__ = {"extend_existing": True}
__table_args__ = (Index("ix_messages_agent_created_at", "agent_id", "created_at"),)
__pydantic_model__ = PydanticMessage

id: Mapped[str] = mapped_column(primary_key=True, doc="Unique message identifier")
Expand Down
29 changes: 14 additions & 15 deletions letta/orm/passage.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
from typing import TYPE_CHECKING
from sqlalchemy import Column, JSON, Index
from sqlalchemy.orm import Mapped, mapped_column, relationship, declared_attr

from letta.orm.mixins import FileMixin, OrganizationMixin
from sqlalchemy import JSON, Column, Index
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship

from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.orm.custom_columns import CommonVector, EmbeddingConfigColumn
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.passage import Passage as PydanticPassage
from letta.settings import settings

from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM

config = LettaConfig()

if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.agent import Agent
from letta.orm.organization import Organization


class BasePassage(SqlalchemyBase, OrganizationMixin):
"""Base class for all passage types with common fields"""

__abstract__ = True
__pydantic_model__ = PydanticPassage

Expand All @@ -45,17 +45,15 @@ def organization(cls) -> Mapped["Organization"]:
@declared_attr
def __table_args__(cls):
if settings.letta_pg_uri_no_default:
return (
Index(f'{cls.__tablename__}_org_idx', 'organization_id'),
{"extend_existing": True}
)
return (Index(f"{cls.__tablename__}_org_idx", "organization_id"), {"extend_existing": True})
return ({"extend_existing": True},)


class SourcePassage(BasePassage, FileMixin, SourceMixin):
"""Passages derived from external files/sources"""

__tablename__ = "source_passages"

@declared_attr
def file(cls) -> Mapped["FileMetadata"]:
"""Relationship to file"""
Expand All @@ -64,7 +62,7 @@ def file(cls) -> Mapped["FileMetadata"]:
@declared_attr
def organization(cls) -> Mapped["Organization"]:
return relationship("Organization", back_populates="source_passages", lazy="selectin")

@declared_attr
def source(cls) -> Mapped["Source"]:
"""Relationship to source"""
Expand All @@ -73,8 +71,9 @@ def source(cls) -> Mapped["Source"]:

class AgentPassage(BasePassage, AgentMixin):
"""Passages created by agents as archival memories"""

__tablename__ = "agent_passages"

@declared_attr
def organization(cls) -> Mapped["Organization"]:
return relationship("Organization", back_populates="agent_passages", lazy="selectin")
Expand Down

0 comments on commit b135223

Please sign in to comment.