Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: DIA-1685: [sdk] Create example predictions and annotations from a LabelConfig #360

Merged
merged 8 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 150 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ typing_extensions = ">= 4.0.0"
ujson = ">=5.8.0"
xmljson = "0.2.1"

jsf = "^0.11.2"
[tool.poetry.dev-dependencies]
mypy = "1.0.1"
pytest = "^7.4.0"
Expand Down
66 changes: 62 additions & 4 deletions src/label_studio_sdk/label_interface/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from collections import defaultdict, OrderedDict
from lxml import etree
import xmljson
from jsf import JSF

from label_studio_sdk._legacy.exceptions import (
LSConfigParseException,
Expand Down Expand Up @@ -770,7 +771,7 @@ def validate_region(self, region) -> bool:
return False

# type of the region should match the tag name
if control.tag.lower() != region["type"]:
if control.tag.lower() != region["type"].lower():
return False

# make sure that in config it connects to the same tag as
Expand Down Expand Up @@ -839,9 +840,66 @@ def generate_sample_task(self, mode="upload", secure_mode=False):

return task

def generate_sample_annotation(self):
""" """
raise NotImplemented()
def generate_sample_prediction(self) -> Optional[dict]:
"""Generates a sample prediction that is valid for this label config.

Example:
{'model_version': 'sample model version',
'score': 0.0,
'result': [{'id': 'e7bd76e6-4e88-4eb3-b433-55e03661bf5d',
'from_name': 'sentiment',
'to_name': 'text',
'type': 'choices',
'value': {'choices': ['Neutral']}}]}

NOTE: `id` field in result is not required when importing predictions; it will be generated automatically.
NOTE: for each control tag, depends on tag.to_json_schema() being implemented correctly
"""
prediction = PredictionValue(
model_version='sample model version',
result=self.create_regions({
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved
control.name: JSF(control.to_json_schema()).generate()
for control in self.controls
})
)
prediction_dct = prediction.model_dump()
if self.validate_prediction(prediction_dct):
return prediction_dct
else:
logger.debug(f'Sample prediction {prediction_dct} failed validation for label config {self.config}')
return None

def generate_sample_annotation(self) -> Optional[dict]:
"""Generates a sample annotation that is valid for this label config.

Example:
{'was_cancelled': False,
'ground_truth': False,
'lead_time': 0.0,
'result_count': 0,
'completed_by': -1,
'result': [{'id': 'b05da11d-3ffc-4657-8b8d-f5bc37cd59ac',
'from_name': 'sentiment',
'to_name': 'text',
'type': 'choices',
'value': {'choices': ['Negative']}}]}

NOTE: `id` field in result is not required when importing predictions; it will be generated automatically.
NOTE: for each control tag, depends on tag.to_json_schema() being implemented correctly
"""
annotation = AnnotationValue(
completed_by=-1, # annotator's user id
result=self.create_regions({
control.name: JSF(control.to_json_schema()).generate()
for control in self.controls
})
)
annotation_dct = annotation.model_dump()
if self.validate_annotation(annotation_dct):
return annotation_dct
else:
logger.debug(f'Sample annotation {annotation_dct} failed validation for label config {self.config}')
return None

#####
##### COMPATIBILITY LAYER
Expand Down
11 changes: 1 addition & 10 deletions src/label_studio_sdk/label_interface/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,7 @@

class Region(BaseModel):
"""
Class for Region Tag

Attributes:
-----------
id: str
The unique identifier of the region
x: int
The x coordinate of the region
y: int

A Region is an item in the `result` list of a PredictionValue or AnnotationValue.
"""

id: str = Field(default_factory=lambda: str(uuid4()))
Expand Down
Loading