Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX Fixed mypy Type Failures #269

Merged
merged 6 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging

from contextlib import closing
from typing import Optional
from typing import Optional, Sequence

from sqlalchemy import create_engine, func, and_
from sqlalchemy.engine.base import Engine
Expand Down Expand Up @@ -96,7 +96,7 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: int) -> list[Pr
func.ISJSON(PromptMemoryEntry.orchestrator_identifier) > 0,
func.JSON_VALUE(PromptMemoryEntry.orchestrator_identifier, "$.id") == str(orchestrator_id),
),
)
) # type: ignore
except Exception as e:
logger.exception(
f"Unexpected error: Failed to retrieve ConversationData with orchestrator {orchestrator_id}. {e}"
Expand All @@ -117,12 +117,12 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li
return self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.conversation_id == conversation_id,
)
) # type: ignore
except Exception as e:
logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
return []

def add_request_pieces_to_memory(self, *, request_pieces: list[PromptRequestPiece]) -> None:
def add_request_pieces_to_memory(self, *, request_pieces: Sequence[PromptRequestPiece]) -> None:
"""
Inserts a list of prompt request pieces into the memory storage.

Expand Down Expand Up @@ -171,7 +171,7 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[Prom
return self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.id.in_(prompt_ids),
)
) # type: ignore
except Exception as e:
logger.exception(
f"Unexpected error: Failed to retrieve ConversationData with orchestrator {prompt_ids}. {e}"
Expand Down Expand Up @@ -239,7 +239,7 @@ def query_entries(self, model, *, conditions: Optional = None) -> list[Base]: #
query = session.query(model)
if conditions is not None:
query = query.filter(conditions)
return query.all()
return query.all() # TODO: use generics to make types work
elgertam marked this conversation as resolved.
Show resolved Hide resolved
except SQLAlchemyError as e:
logger.exception(f"Error fetching data from table {model.__tablename__}: {e}")

Expand Down
27 changes: 15 additions & 12 deletions pyrit/memory/duckdb_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

from pathlib import Path
from typing import Union, Optional
from typing import MutableSequence, Union, Optional, Sequence
import logging

from sqlalchemy import create_engine, MetaData
Expand Down Expand Up @@ -84,14 +84,15 @@ def get_all_prompt_pieces(self) -> list[PromptRequestPiece]:
"""
Fetches all entries from the specified table and returns them as model instances.
"""
result: list[PromptMemoryEntry] = self.query_entries(PromptMemoryEntry)
return [entry.get_prompt_request_piece() for entry in result]
entries = self.query_entries(PromptMemoryEntry)
result: list[PromptRequestPiece] = [entry.get_prompt_request_piece() for entry in entries]
return result

def get_all_embeddings(self) -> list[EmbeddingData]:
"""
Fetches all entries from the specified table and returns them as model instances.
"""
result = self.query_entries(EmbeddingData)
result: list[EmbeddingData] = self.query_entries(EmbeddingData)
return result

def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> list[PromptRequestPiece]:
Expand All @@ -105,9 +106,10 @@ def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> li
list[PromptRequestPiece]: A list of PromptRequestPieces with the specified conversation ID.
"""
try:
return self.query_entries(
result: list[PromptRequestPiece] = self.query_entries(
PromptMemoryEntry, conditions=PromptMemoryEntry.conversation_id == conversation_id
)
) # type: ignore
return result
except Exception as e:
logger.exception(f"Failed to retrieve conversation_id {conversation_id} with error {e}")
return []
Expand All @@ -126,7 +128,7 @@ def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[Prom
return self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.id.in_(prompt_ids),
)
) # type: ignore
except Exception as e:
logger.exception(
f"Unexpected error: Failed to retrieve ConversationData with orchestrator {prompt_ids}. {e}"
Expand All @@ -145,17 +147,18 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: int) -> list[Pr
list[PromptRequestPiece]: A list of PromptRequestPiece objects matching the specified orchestrator ID.
"""
try:
return self.query_entries(
result: list[PromptRequestPiece] = self.query_entries(
PromptMemoryEntry,
conditions=PromptMemoryEntry.orchestrator_identifier.op("->>")("id") == str(orchestrator_id),
)
) # type: ignore
return result
except Exception as e:
logger.exception(
f"Unexpected error: Failed to retrieve ConversationData with orchestrator {orchestrator_id}. {e}"
)
return []

def add_request_pieces_to_memory(self, *, request_pieces: list[PromptRequestPiece]) -> None:
def add_request_pieces_to_memory(self, *, request_pieces: Sequence[PromptRequestPiece]) -> None:
"""
Inserts a list of prompt request pieces into the memory storage.

Expand Down Expand Up @@ -266,11 +269,11 @@ def query_entries(self, model, *, conditions: Optional = None) -> list[Base]: #
query = session.query(model)
if conditions is not None:
query = query.filter(conditions)
return query.all()
return query.all() # TODO: use generics to make types work
except SQLAlchemyError as e:
logger.exception(f"Error fetching data from table {model.__tablename__}: {e}")

def update_entries(self, *, entries: list[Base], update_fields: dict) -> bool: # type: ignore
def update_entries(self, *, entries: MutableSequence[Base], update_fields: dict) -> bool: # type: ignore
"""
Updates the given entries with the specified field values.

Expand Down
24 changes: 12 additions & 12 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import abc
from pathlib import Path

from typing import Optional
from typing import MutableSequence, Optional, Sequence
from uuid import uuid4

from pyrit.common.path import RESULTS_PATH
Expand Down Expand Up @@ -45,31 +45,31 @@ def disable_embedding(self):
self.memory_embedding = None

@abc.abstractmethod
def get_all_prompt_pieces(self) -> list[PromptRequestPiece]:
def get_all_prompt_pieces(self) -> Sequence[PromptRequestPiece]:
"""
Loads all PromptRequestPiece entries from the memory storage handler.
"""

@abc.abstractmethod
def get_all_embeddings(self) -> list[EmbeddingData]:
def get_all_embeddings(self) -> Sequence[EmbeddingData]:
"""
Loads all EmbeddingData from the memory storage handler.
"""

@abc.abstractmethod
def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> list[PromptRequestPiece]:
def _get_prompt_pieces_with_conversation_id(self, *, conversation_id: str) -> MutableSequence[PromptRequestPiece]:
"""
Retrieves a list of PromptRequestPiece objects that have the specified conversation ID.

Args:
conversation_id (str): The conversation ID to match.

Returns:
list[PromptRequestPiece]: A list of chat memory entries with the specified conversation ID.
MutableSequence[PromptRequestPiece]: A list of chat memory entries with the specified conversation ID.
"""

@abc.abstractmethod
def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: int) -> list[PromptRequestPiece]:
def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: int) -> Sequence[PromptRequestPiece]:
"""
Retrieves a list of PromptRequestPiece objects that have the specified orchestrator ID.

Expand All @@ -78,11 +78,11 @@ def _get_prompt_pieces_by_orchestrator(self, *, orchestrator_id: int) -> list[Pr
Can be retrieved by calling orchestrator.get_identifier()["id"]

Returns:
list[PromptRequestPiece]: A list of PromptMemoryEntry objects matching the specified orchestrator ID.
Sequence[PromptRequestPiece]: A list of PromptRequestPiece objects matching the specified orchestrator ID.
"""

@abc.abstractmethod
def add_request_pieces_to_memory(self, *, request_pieces: list[PromptRequestPiece]) -> None:
def add_request_pieces_to_memory(self, *, request_pieces: Sequence[PromptRequestPiece]) -> None:
"""
Inserts a list of prompt request pieces into the memory storage.
"""
Expand All @@ -105,29 +105,29 @@ def get_scores_by_prompt_ids(self, *, prompt_request_response_ids: list[str]) ->
Gets a list of scores based on prompt_request_response_ids.
"""

def get_conversation(self, *, conversation_id: str) -> list[PromptRequestResponse]:
def get_conversation(self, *, conversation_id: str) -> MutableSequence[PromptRequestResponse]:
"""
Retrieves a list of PromptRequestResponse objects that have the specified conversation ID.

Args:
conversation_id (str): The conversation ID to match.

Returns:
list[PromptRequestResponse]: A list of chat memory entries with the specified conversation ID.
MutableSequence[PromptRequestResponse]: A list of chat memory entries with the specified conversation ID.
"""
request_pieces = self._get_prompt_pieces_with_conversation_id(conversation_id=conversation_id)
return group_conversation_request_pieces_by_sequence(request_pieces=request_pieces)

@abc.abstractmethod
def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> list[PromptRequestPiece]:
def get_prompt_request_pieces_by_id(self, *, prompt_ids: list[str]) -> Sequence[PromptRequestPiece]:
"""
Retrieves a list of PromptRequestPiece objects that have the specified prompt ids.

Args:
prompt_ids (list[str]): The prompt IDs to match.

Returns:
list[PromptRequestPiece]: A list of PromptRequestPiece with the specified conversation ID.
Sequence[PromptRequestPiece]: A list of PromptRequestPiece objects with the specified prompt IDs.
"""

def get_prompt_request_piece_by_orchestrator_id(self, *, orchestrator_id: int) -> list[PromptRequestPiece]:
Expand Down
40 changes: 23 additions & 17 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# mypy: ignore-errors

import uuid
from typing import Literal

from pydantic import BaseModel, ConfigDict
from sqlalchemy import Column, String, DateTime, Float, JSON, ForeignKey, Index, INTEGER, ARRAY, Uuid
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, String, DateTime, Float, JSON, ForeignKey, Index, INTEGER, ARRAY
from sqlalchemy.types import Uuid # type: ignore
from sqlalchemy.orm import DeclarativeBase # type: ignore
from sqlalchemy.orm import Mapped # type: ignore

from pyrit.models import PromptRequestPiece, Score


Base = declarative_base()
class Base(DeclarativeBase):
pass


class PromptMemoryEntry(Base):
Expand Down Expand Up @@ -52,22 +54,26 @@ class PromptMemoryEntry(Base):
__tablename__ = "PromptMemoryEntries"
__table_args__ = {"extend_existing": True}
id = Column(Uuid, nullable=False, primary_key=True)
role = Column(String, nullable=False)
role: Mapped[Literal["system", "user", "assistant"]] = Column(String, nullable=False)
conversation_id = Column(String, nullable=False)
sequence = Column(INTEGER, nullable=False)
timestamp = Column(DateTime, nullable=False)
labels = Column(JSON)
labels: Mapped[dict[str, str]] = Column(JSON)
prompt_metadata = Column(String, nullable=True)
converter_identifiers = Column(JSON)
prompt_target_identifier = Column(JSON)
orchestrator_identifier = Column(JSON)
response_error = Column(String, nullable=True)

original_value_data_type = Column(String, nullable=False)
converter_identifiers: Mapped[dict[str, str]] = Column(JSON)
prompt_target_identifier: Mapped[dict[str, str]] = Column(JSON)
orchestrator_identifier: Mapped[dict[str, str]] = Column(JSON)
response_error: Mapped[Literal["blocked", "none", "processing", "unknown"]] = Column(String, nullable=True)

original_value_data_type: Mapped[Literal["text", "image_path", "audio_path", "url", "error"]] = Column(
String, nullable=False
)
original_value = Column(String, nullable=False)
original_value_sha256 = Column(String)

converted_value_data_type = Column(String, nullable=False)
converted_value_data_type: Mapped[Literal["text", "image_path", "audio_path", "url", "error"]] = Column(
String, nullable=False
)
converted_value = Column(String)
converted_value_sha256 = Column(String)

Expand Down Expand Up @@ -134,7 +140,7 @@ class EmbeddingData(Base): # type: ignore
# Allows table redefinition if already defined.
__table_args__ = {"extend_existing": True}
id = Column(Uuid(as_uuid=True), ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"), primary_key=True)
embedding = Column(ARRAY(Float).with_variant(JSON, "mssql"))
embedding = Column(ARRAY(Float).with_variant(JSON, "mssql")) # type: ignore
embedding_type_name = Column(String)

def __str__(self):
Expand All @@ -153,11 +159,11 @@ class ScoreEntry(Base): # type: ignore
id = Column(Uuid(as_uuid=True), nullable=False, primary_key=True)
score_value = Column(String, nullable=False)
score_value_description = Column(String, nullable=True)
score_type = Column(String, nullable=False)
score_type: Mapped[Literal["true_false", "float_scale"]] = Column(String, nullable=False)
score_category = Column(String, nullable=False)
score_rationale = Column(String, nullable=True)
score_metadata = Column(String, nullable=True)
scorer_class_identifier = Column(JSON)
scorer_class_identifier: Mapped[dict[str, str]] = Column(JSON)
prompt_request_response_id = Column(Uuid(as_uuid=True), ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"))
date_time = Column(DateTime, nullable=False)

Expand Down
12 changes: 6 additions & 6 deletions pyrit/models/prompt_request_response.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional
from typing import MutableSequence, Optional, Sequence

from pyrit.models import PromptRequestPiece
from pyrit.models.literals import PromptDataType, PromptResponseError
Expand Down Expand Up @@ -51,20 +51,20 @@ def __str__(self):


def group_conversation_request_pieces_by_sequence(
request_pieces: list[PromptRequestPiece],
) -> list[PromptRequestResponse]:
request_pieces: Sequence[PromptRequestPiece],
) -> MutableSequence[PromptRequestResponse]:
"""
Groups prompt request pieces from the same conversation into PromptRequestResponses.

This is done using the sequence number and conversation ID.

Args:
request_pieces (list[PromptRequestPiece]): A list of PromptRequestPiece objects representing individual
request_pieces (Sequence[PromptRequestPiece]): A list of PromptRequestPiece objects representing individual
request pieces.

Returns:
list[PromptRequestResponse]: A list of PromptRequestResponse objects representing grouped request pieces.
this is ordered by the sequence number
MutableSequence[PromptRequestResponse]: A list of PromptRequestResponse objects representing grouped request
pieces. This is ordered by the sequence number

Raises:
ValueError: If the conversation ID of any request piece does not match the conversation ID of the first
Expand Down
2 changes: 1 addition & 1 deletion pyrit/orchestrator/orchestrator_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

def dispose_db_engine(self) -> None:
"""
Dispose DuckDB database engine to release database connections and resources.
Dispose database engine to release database connections and resources.
"""
self._memory.dispose_engine()

Expand Down
5 changes: 3 additions & 2 deletions pyrit/orchestrator/scoring_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import asyncio
import logging
from typing import Sequence

from pyrit.memory import MemoryInterface
from pyrit.models.prompt_request_piece import PromptRequestPiece
Expand Down Expand Up @@ -60,7 +61,7 @@ async def score_prompts_by_request_id_async(self, *, scorer: Scorer, prompt_ids:
Scores the prompts identified by prompt_ids using the given Scorer.
"""

requests: list[PromptRequestPiece] = []
requests: Sequence[PromptRequestPiece] = []
requests = self._memory.get_prompt_request_pieces_by_id(prompt_ids=prompt_ids)

return await self._score_prompts_batch_async(prompts=requests, scorer=scorer)
Expand All @@ -71,7 +72,7 @@ def _extract_responses_only(self, request_responses: list[PromptRequestPiece]) -
"""
return [response for response in request_responses if response.role == "assistant"]

async def _score_prompts_batch_async(self, prompts: list[PromptRequestPiece], scorer: Scorer) -> list[Score]:
async def _score_prompts_batch_async(self, prompts: Sequence[PromptRequestPiece], scorer: Scorer) -> list[Score]:
results = []

for prompts_batch in self._chunked_prompts(prompts, self._batch_size):
Expand Down
Loading