FEAT: Adding Scores to the Database #195

Merged · 6 commits · May 9, 2024
Changes from 5 commits
19 changes: 18 additions & 1 deletion pyrit/memory/duckdb_memory.py
@@ -12,11 +12,12 @@
from sqlalchemy.engine.base import Engine
from contextlib import closing

from pyrit.memory.memory_models import EmbeddingData, PromptMemoryEntry, Base
from pyrit.memory.memory_models import EmbeddingData, PromptMemoryEntry, Base, ScoreEntry
from pyrit.memory.memory_interface import MemoryInterface
from pyrit.common.path import RESULTS_PATH
from pyrit.common.singleton import Singleton
from pyrit.models import PromptRequestPiece
from pyrit.models import Score

logger = logging.getLogger(__name__)

@@ -146,6 +147,22 @@ def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
"""
self._insert_entries(entries=embedding_data)

def add_scores_to_memory(self, *, scores: list[Score]) -> None:
"""
Inserts a list of scores into the memory storage.
"""
self._insert_entries(entries=[ScoreEntry(entry=score) for score in scores])

def get_scores_by_prompt_ids(self, *, prompt_request_response_ids: list[str]) -> list[Score]:
"""
Gets a list of scores based on prompt_request_response_ids.
"""
entries = self.query_entries(
ScoreEntry, conditions=ScoreEntry.prompt_request_response_id.in_(prompt_request_response_ids)
)

return [entry.get_score() for entry in entries]

def update_entries_by_conversation_id(self, *, conversation_id: str, update_fields: dict) -> bool:
"""
Updates entries for a given conversation ID with the specified field values.
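Taken together, the two new DuckDBMemory methods give scores a write path and a read path. A minimal usage sketch, assuming a default-constructed DuckDBMemory and illustrative field values (in practice prompt_request_response_id should reference a prompt entry that is already stored):

import uuid

from pyrit.memory import DuckDBMemory
from pyrit.models import Score

memory = DuckDBMemory()  # assumed to default to the local results database

# Normally this id comes from a PromptRequestPiece already in memory.
request_id = uuid.uuid4()

score = Score(
    score_value="true",
    score_value_description="The response was judged harmful.",
    score_type="true_false",
    score_category="hate",
    score_rationale="The response contains a slur.",
    score_metadata="",
    scorer_class_identifier={"__type__": "SubStringScorer"},
    prompt_request_response_id=request_id,
)

memory.add_scores_to_memory(scores=[score])
stored = memory.get_scores_by_prompt_ids(prompt_request_response_ids=[str(request_id)])
assert stored[0].get_value() is True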
14 changes: 13 additions & 1 deletion pyrit/memory/memory_interface.py
@@ -5,7 +5,7 @@
from pathlib import Path

from pyrit.memory.memory_models import EmbeddingData
from pyrit.models import PromptRequestResponse, PromptRequestPiece, PromptResponseError, PromptDataType
from pyrit.models import PromptRequestResponse, Score, PromptRequestPiece, PromptResponseError, PromptDataType

from pyrit.memory.memory_embedding import default_memory_embedding_factory
from pyrit.memory.memory_embedding import MemoryEmbedding
@@ -87,6 +87,18 @@ def _add_embeddings_to_memory(self, *, embedding_data: list[EmbeddingData]) -> None:
Inserts embedding data into memory storage
"""

@abc.abstractmethod
def add_scores_to_memory(self, *, scores: list[Score]) -> None:
"""
Inserts a list of scores into the memory storage.
"""

@abc.abstractmethod
def get_scores_by_prompt_ids(self, *, prompt_request_response_ids: list[str]) -> list[Score]:
"""
Gets a list of scores based on prompt_request_response_ids.
"""

def get_conversation(self, *, conversation_id: str) -> list[PromptRequestResponse]:
"""
Retrieves a list of PromptRequestResponse objects that have the specified conversation ID.
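Any concrete MemoryInterface must now implement both score methods. Below is a minimal sketch of what a non-database backend could look like; the class name and its list-based storage are hypothetical, and the interface's other abstract methods are elided:

from pyrit.memory import MemoryInterface
from pyrit.models import Score


class ListBackedMemory(MemoryInterface):  # hypothetical backend
    def __init__(self):
        super().__init__()
        self._scores: list[Score] = []

    def add_scores_to_memory(self, *, scores: list[Score]) -> None:
        # Keep the Score models in a plain list instead of ORM entries.
        self._scores.extend(scores)

    def get_scores_by_prompt_ids(self, *, prompt_request_response_ids: list[str]) -> list[Score]:
        # Match on the string form of the id, mirroring the DuckDB query.
        wanted = set(prompt_request_response_ids)
        return [s for s in self._scores if str(s.prompt_request_response_id) in wanted]

    # ... the remaining abstract methods would also need implementations
    # before this class could be instantiated.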
57 changes: 39 additions & 18 deletions pyrit/memory/memory_models.py
@@ -10,7 +10,7 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.dialects.postgresql import UUID

from pyrit.models import PromptRequestPiece
from pyrit.models import PromptRequestPiece, Score


Base = declarative_base()
@@ -142,30 +142,51 @@ def __str__(self):
return f"{self.id}"


class Score(Base): # type: ignore
class ScoreEntry(Base): # type: ignore
"""
Represents the Score
Represents the Score Memory Entry

Attributes:
uuid (UUID): The primary key, which is a foreign key referencing the UUID in the MemoryEntries table.
embedding (ARRAY(Float)): An array of floats representing the embedding vector.
embedding_type_name (String): The name or type of the embedding, indicating the model or method used.
"""

__tablename__ = "Score"
# Allows table redefinition if already defined.
__tablename__ = "ScoreEntries"
__table_args__ = {"extend_existing": True}

id = Column(UUID(as_uuid=True), nullable=False, primary_key=True)

scorer = Column(String) # identifier for the class

id = Column(UUID(as_uuid=True), ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"), primary_key=True)
embedding = Column(ARRAY(Float))
embedding_type_name = Column(String)

def __str__(self):
return f"{self.id}"
score_value = Column(String, nullable=False)
score_value_description = Column(String, nullable=True)
score_type = Column(String, nullable=False)
score_category = Column(String, nullable=False)
score_rationale = Column(String, nullable=True)
score_metadata = Column(String, nullable=True)
scorer_class_identifier = Column(JSON)
prompt_request_response_id = Column(UUID(as_uuid=True), ForeignKey(f"{PromptMemoryEntry.__tablename__}.id"))
date_time = Column(DateTime, nullable=False)

def __init__(self, *, entry: Score):
self.id = entry.id
self.score_value = entry.score_value
self.score_value_description = entry.score_value_description
self.score_type = entry.score_type
self.score_category = entry.score_category
self.score_rationale = entry.score_rationale
self.score_metadata = entry.score_metadata
self.scorer_class_identifier = entry.scorer_class_identifier
self.prompt_request_response_id = entry.prompt_request_response_id if entry.prompt_request_response_id else None
self.date_time = entry.date_time

def get_score(self) -> Score:
return Score(
id=self.id,
score_value=self.score_value,
score_value_description=self.score_value_description,
score_type=self.score_type,
score_category=self.score_category,
score_rationale=self.score_rationale,
score_metadata=self.score_metadata,
scorer_class_identifier=self.scorer_class_identifier,
prompt_request_response_id=self.prompt_request_response_id,
date_time=self.date_time,
)


class ConversationMessageWithSimilarity(BaseModel):
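The ScoreEntry constructor and get_score() form a round trip between the Score model and its ORM row, which is what add_scores_to_memory and get_scores_by_prompt_ids rely on. A small sketch with made-up values:

import uuid

from pyrit.memory.memory_models import ScoreEntry
from pyrit.models import Score

original = Score(
    score_value="0.75",
    score_value_description="Moderately violent content.",
    score_type="float_scale",
    score_category="violence",
    score_rationale="The response describes an assault in detail.",
    score_metadata="",
    scorer_class_identifier={"__type__": "SelfAskCategoryScorer"},
    prompt_request_response_id=uuid.uuid4(),
)

entry = ScoreEntry(entry=original)  # ORM row, ready for _insert_entries
restored = entry.get_score()        # plain Score model again
assert restored.get_value() == 0.75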
3 changes: 3 additions & 0 deletions pyrit/models/__init__.py
@@ -15,6 +15,7 @@
from pyrit.models.prompt_request_piece import PromptRequestPiece
from pyrit.models.prompt_request_response import PromptRequestResponse, group_conversation_request_pieces_by_sequence
from pyrit.models.identifiers import Identifier
from pyrit.models.score import Score, ScoreType


__all__ = [
@@ -31,5 +32,7 @@
"PromptResponseError",
"PromptDataType",
"PromptRequestResponse",
"Score",
"ScoreType",
"TextDataTypeSerializer",
]
214 changes: 111 additions & 103 deletions pyrit/score/score_class.py → pyrit/models/score.py
@@ -1,103 +1,111 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Dict, Literal
import uuid


ScoreType = Literal["true_false", "float_scale"]


class Score:

id: uuid.UUID | str

# The value the scorer ended up with; e.g. True (if true_false) or 0 (if float_scale)
score_value: str

# Value that can include a description of the score value
score_value_description: str

# The type of the scorer; e.g. "true_false" or "float_scale"
score_type: ScoreType

# The type of the harms category (e.g. "hate" or "violence")
score_category: str

# Extra data the scorer provides around the rationale of the score
score_rationale: str

# Custom metadata a scorer might use. This is left undefined other than for the
# specific scorer that uses it.
metadata: str

# The identifier of the scorer class, including relevant information
# e.g. {"scorer_name": "SelfAskScorer", "classifier": "current_events.yml"}
scorer_class_identifier: Dict[str, str]

# This is the prompt_request_response_id that the score is scoring
# Note a scorer can generate an additional request. This is NOT that, but
# the request associated with what we're scoring.
prompt_request_response_id: uuid.UUID | str

def __init__(
self,
score_value: str,
score_value_description: str,
score_type: ScoreType,
score_category: str,
score_rationale: str,
metadata: str,
scorer_class_identifier: Dict[str, str],
prompt_request_response_id: str | uuid.UUID,
):
self.id = uuid.uuid4()

self._validate(score_type, score_value)

self.score_value = score_value
self.score_value_description = score_value_description
self.score_type = score_type
self.score_category = score_category
self.score_rationale = score_rationale
self.metadata = metadata
self.scorer_class_identifier = scorer_class_identifier
self.prompt_request_response_id = prompt_request_response_id

def get_value(self):
"""
Returns the value of the score based on its type.

If the score type is "true_false", it returns True if the score value is "true" (case-insensitive),
otherwise it returns False.

If the score type is "float_scale", it returns the score value as a float.

Raises:
ValueError: If the score type is unknown.

Returns:
The value of the score based on its type.
"""
if self.score_type == "true_false":
return self.score_value.lower() == "true"
elif self.score_type == "float_scale":
return float(self.score_value)

raise ValueError(f"Unknown scorer type: {self.score_type}")

def __str__(self):
if self.scorer_class_identifier:
return f"{self.scorer_class_identifier['__type__']}: {self.score_value}"
return f": {self.score_value}"

def _validate(self, scorer_type, score_value):
if scorer_type == "true_false" and str(score_value).lower() not in ["true", "false"]:
raise ValueError(f"True False scorers must have a score value of 'true' or 'false' not {score_value}")
elif scorer_type == "float_scale":
try:
score = float(score_value)
if not (0 <= score <= 1):
raise ValueError(f"Float scale scorers must have a score value between 0 and 1. Got {score_value}")
except ValueError:
raise ValueError(f"Float scale scorers require a numeric score value. Got {score_value}")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from datetime import datetime
from typing import Dict, Literal, Optional
import uuid


ScoreType = Literal["true_false", "float_scale"]


class Score:

id: uuid.UUID | str

# The value the scorer ended up with; e.g. True (if true_false) or 0 (if float_scale)
score_value: str

# Value that can include a description of the score value
score_value_description: str

# The type of the scorer; e.g. "true_false" or "float_scale"
score_type: ScoreType

# The type of the harms category (e.g. "hate" or "violence")
score_category: str

# Extra data the scorer provides around the rationale of the score
score_rationale: str

# Custom metadata a scorer might use. This is left undefined other than for the
# specific scorer that uses it.
score_metadata: str

# The identifier of the scorer class, including relevant information
# e.g. {"scorer_name": "SelfAskScorer", "classifier": "current_events.yml"}
scorer_class_identifier: Dict[str, str]

# This is the prompt_request_response_id that the score is scoring
# Note a scorer can generate an additional request. This is NOT that, but
# the request associated with what we're scoring.
prompt_request_response_id: uuid.UUID | str

# Timestamp of when the score was created
date_time: datetime

def __init__(
self,
*,
id: Optional[uuid.UUID] = None,
score_value: str,
score_value_description: str,
score_type: ScoreType,
score_category: str,
score_rationale: str,
score_metadata: str,
scorer_class_identifier: Dict[str, str],
prompt_request_response_id: str | uuid.UUID,
date_time: Optional[datetime] = None,
):
self.id = id if id else uuid.uuid4()

self._validate(score_type, score_value)

self.score_value = score_value
self.score_value_description = score_value_description
self.score_type = score_type
self.score_category = score_category
self.score_rationale = score_rationale
self.score_metadata = score_metadata
self.scorer_class_identifier = scorer_class_identifier
self.prompt_request_response_id = prompt_request_response_id
self.date_time = date_time if date_time else datetime.now()

def get_value(self):
"""
Returns the value of the score based on its type.

If the score type is "true_false", it returns True if the score value is "true" (case-insensitive),
otherwise it returns False.

If the score type is "float_scale", it returns the score value as a float.

Raises:
ValueError: If the score type is unknown.

Returns:
The value of the score based on its type.
"""
if self.score_type == "true_false":
return self.score_value.lower() == "true"
elif self.score_type == "float_scale":
return float(self.score_value)

raise ValueError(f"Unknown scorer type: {self.score_type}")

def __str__(self):
if self.scorer_class_identifier:
return f"{self.scorer_class_identifier['__type__']}: {self.score_value}"
return f": {self.score_value}"

def _validate(self, scorer_type, score_value):
if scorer_type == "true_false" and str(score_value).lower() not in ["true", "false"]:
raise ValueError(f"True False scorers must have a score value of 'true' or 'false' not {score_value}")
elif scorer_type == "float_scale":
try:
score = float(score_value)
if not (0 <= score <= 1):
raise ValueError(f"Float scale scorers must have a score value between 0 and 1. Got {score_value}")
except ValueError:
raise ValueError(f"Float scale scorers require a numeric score value. Got {score_value}")
4 changes: 1 addition & 3 deletions pyrit/score/__init__.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pyrit.score.score_class import Score, ScoreType
from pyrit.models import Score, ScoreType # noqa: F401

from pyrit.score.scorer import Scorer

@@ -13,8 +13,6 @@
from pyrit.score.substring_scorer import SubStringScorer

__all__ = [
"Score",
"ScoreType",
"Scorer",
"SelfAskCategoryScorer",
"ContentClassifierPaths",
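Because pyrit/score/__init__.py now re-exports the model, pre-existing imports keep resolving; a quick compatibility check under that assumption:

from pyrit.models import Score as ModelScore
from pyrit.score import Score as LegacyScore

assert ModelScore is LegacyScore  # one class, re-exported for compatibility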