Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: DIA-1402: Add RefinedPromptResponse schema #331

Merged
merged 1 commit into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .mock/definition/__package__.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2347,6 +2347,22 @@ types:
organization: optional<PromptVersionOrganization>
source:
openapi: openapi/openapi.yaml
RefinedPromptResponse:
properties:
title:
type: string
docs: Title of the refined prompt
reasoning:
type: string
docs: Reasoning behind the refinement
prompt:
type: string
docs: The refined prompt text
refinement_job_id:
type: string
docs: Unique identifier for the refinement job
source:
openapi: openapi/openapi.yaml
InferenceRunOrganization:
discriminated: false
union:
Expand Down
12 changes: 3 additions & 9 deletions .mock/definition/prompts/versions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ service:
docs: Project ID to target the refined prompt for
response:
docs: ''
type: root.PromptVersion
type: root.RefinedPromptResponse
examples:
- path-parameters:
prompt_id: 1
Expand All @@ -201,15 +201,9 @@ service:
response:
body:
title: title
parent_model: 1
model_provider_connection: 1
reasoning: reasoning
prompt: prompt
provider: OpenAI
provider_model_id: provider_model_id
created_by: 1
created_at: '2024-01-15T09:30:00Z'
updated_at: '2024-01-15T09:30:00Z'
organization: 1
refinement_job_id: refinement_job_id
audiences:
- public
source:
Expand Down
14 changes: 7 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/label_studio_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
RedisExportStorageStatus,
RedisImportStorage,
RedisImportStorageStatus,
RefinedPromptResponse,
S3ExportStorage,
S3ExportStorageStatus,
S3ImportStorage,
Expand Down Expand Up @@ -275,6 +276,7 @@
"RedisExportStorageStatus",
"RedisImportStorage",
"RedisImportStorageStatus",
"RefinedPromptResponse",
"S3ExportStorage",
"S3ExportStorageStatus",
"S3ImportStorage",
Expand Down
13 changes: 7 additions & 6 deletions src/label_studio_sdk/prompts/versions/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ...types.prompt_version_created_by import PromptVersionCreatedBy
from ...types.prompt_version_organization import PromptVersionOrganization
from ...types.prompt_version_provider import PromptVersionProvider
from ...types.refined_prompt_response import RefinedPromptResponse

# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)
Expand Down Expand Up @@ -344,7 +345,7 @@ def refine_prompt(
teacher_model_name: typing.Optional[str] = OMIT,
project_id: typing.Optional[int] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> PromptVersion:
) -> RefinedPromptResponse:
"""
Refine a prompt version using a teacher model and save the refined prompt as a new version.

Expand All @@ -370,7 +371,7 @@ def refine_prompt(

Returns
-------
PromptVersion
RefinedPromptResponse


Examples
Expand Down Expand Up @@ -398,7 +399,7 @@ def refine_prompt(
)
try:
if 200 <= _response.status_code < 300:
return pydantic_v1.parse_obj_as(PromptVersion, _response.json()) # type: ignore
return pydantic_v1.parse_obj_as(RefinedPromptResponse, _response.json()) # type: ignore
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, body=_response.text)
Expand Down Expand Up @@ -735,7 +736,7 @@ async def refine_prompt(
teacher_model_name: typing.Optional[str] = OMIT,
project_id: typing.Optional[int] = OMIT,
request_options: typing.Optional[RequestOptions] = None,
) -> PromptVersion:
) -> RefinedPromptResponse:
"""
Refine a prompt version using a teacher model and save the refined prompt as a new version.

Expand All @@ -761,7 +762,7 @@ async def refine_prompt(

Returns
-------
PromptVersion
RefinedPromptResponse


Examples
Expand Down Expand Up @@ -789,7 +790,7 @@ async def refine_prompt(
)
try:
if 200 <= _response.status_code < 300:
return pydantic_v1.parse_obj_as(PromptVersion, _response.json()) # type: ignore
return pydantic_v1.parse_obj_as(RefinedPromptResponse, _response.json()) # type: ignore
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, body=_response.text)
Expand Down
2 changes: 2 additions & 0 deletions src/label_studio_sdk/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from .redis_export_storage_status import RedisExportStorageStatus
from .redis_import_storage import RedisImportStorage
from .redis_import_storage_status import RedisImportStorageStatus
from .refined_prompt_response import RefinedPromptResponse
from .s3export_storage import S3ExportStorage
from .s3export_storage_status import S3ExportStorageStatus
from .s3import_storage import S3ImportStorage
Expand Down Expand Up @@ -167,6 +168,7 @@
"RedisExportStorageStatus",
"RedisImportStorage",
"RedisImportStorageStatus",
"RefinedPromptResponse",
"S3ExportStorage",
"S3ExportStorageStatus",
"S3ImportStorage",
Expand Down
47 changes: 47 additions & 0 deletions src/label_studio_sdk/types/refined_prompt_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# This file was auto-generated by Fern from our API Definition.

import datetime as dt
import typing

from ..core.datetime_utils import serialize_datetime
from ..core.pydantic_utilities import deep_union_pydantic_dicts, pydantic_v1


class RefinedPromptResponse(pydantic_v1.BaseModel):
title: str = pydantic_v1.Field()
"""
Title of the refined prompt
"""

reasoning: str = pydantic_v1.Field()
"""
Reasoning behind the refinement
"""

prompt: str = pydantic_v1.Field()
"""
The refined prompt text
"""

refinement_job_id: str = pydantic_v1.Field()
"""
Unique identifier for the refinement job
"""

def json(self, **kwargs: typing.Any) -> str:
kwargs_with_defaults: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs}
return super().json(**kwargs_with_defaults)

def dict(self, **kwargs: typing.Any) -> typing.Dict[str, typing.Any]:
kwargs_with_defaults_exclude_unset: typing.Any = {"by_alias": True, "exclude_unset": True, **kwargs}
kwargs_with_defaults_exclude_none: typing.Any = {"by_alias": True, "exclude_none": True, **kwargs}

return deep_union_pydantic_dicts(
super().dict(**kwargs_with_defaults_exclude_unset), super().dict(**kwargs_with_defaults_exclude_none)
)

class Config:
frozen = True
smart_union = True
extra = pydantic_v1.Extra.allow
json_encoders = {dt.datetime: serialize_datetime}
23 changes: 3 additions & 20 deletions tests/prompts/test_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,28 +152,11 @@ async def test_update(client: LabelStudio, async_client: AsyncLabelStudio) -> No
async def test_refine_prompt(client: LabelStudio, async_client: AsyncLabelStudio) -> None:
expected_response: typing.Any = {
"title": "title",
"parent_model": 1,
"model_provider_connection": 1,
"reasoning": "reasoning",
"prompt": "prompt",
"provider": "OpenAI",
"provider_model_id": "provider_model_id",
"created_by": 1,
"created_at": "2024-01-15T09:30:00Z",
"updated_at": "2024-01-15T09:30:00Z",
"organization": 1,
}
expected_types: typing.Any = {
"title": None,
"parent_model": "integer",
"model_provider_connection": "integer",
"prompt": None,
"provider": None,
"provider_model_id": None,
"created_by": "integer",
"created_at": "datetime",
"updated_at": "datetime",
"organization": "integer",
"refinement_job_id": "refinement_job_id",
}
expected_types: typing.Any = {"title": None, "reasoning": None, "prompt": None, "refinement_job_id": None}
response = client.prompts.versions.refine_prompt(prompt_id=1, version_id=1)
validate_response(response, expected_response, expected_types)

Expand Down