Skip to content

Commit

Permalink
[ENHANCEMENT] argilla: add record status property (#5184)
Browse files Browse the repository at this point in the history
# Description
<!-- Please include a summary of the changes and the related issue.
Please also include relevant motivation and context. List any
dependencies that are required for this change. -->

This PR adds the record status as a read-only property in the `Record`
resource class.

Closes #5141

**Type of change**
<!-- Please delete options that are not relevant. Remember to title the
PR according to the type of change -->

- Improvement (change that improves existing
functionality)

**How Has This Been Tested**
<!-- Please add some reference about how your feature has been tested.
-->

**Checklist**
<!-- Please go over the list and make sure you've taken everything into
account -->

- I added relevant documentation
- I followed the style guidelines of this project
- I did a self-review of my code
- I made corresponding changes to the documentation
- I confirm my changes generate no new warnings
- I have added tests that prove my fix is effective or that my feature
works
- I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Jul 12, 2024
1 parent 22263d8 commit c219764
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 7 deletions.
4 changes: 2 additions & 2 deletions argilla/src/argilla/_models/_record/_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, Literal

from pydantic import Field, field_serializer, field_validator

Expand All @@ -30,12 +30,12 @@
class RecordModel(ResourceModel):
"""Schema for the records of a `Dataset`"""

status: Literal["pending", "completed"] = "pending"
fields: Optional[Dict[str, FieldValue]] = None
metadata: Optional[Union[List[MetadataModel], Dict[str, MetadataValue]]] = Field(default_factory=dict)
vectors: Optional[List[VectorModel]] = Field(default_factory=list)
responses: Optional[List[UserResponseModel]] = Field(default_factory=list)
suggestions: Optional[Union[Tuple[SuggestionModel], List[SuggestionModel]]] = Field(default_factory=tuple)

external_id: Optional[Any] = None

@field_serializer("external_id", when_used="unless-none")
Expand Down
22 changes: 17 additions & 5 deletions argilla/src/argilla/records/_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(

def __repr__(self) -> str:
return (
f"Record(id={self.id},fields={self.fields},metadata={self.metadata},"
f"Record(id={self.id},status={self.status},fields={self.fields},metadata={self.metadata},"
f"suggestions={self.suggestions},responses={self.responses})"
)

Expand Down Expand Up @@ -147,6 +147,10 @@ def metadata(self) -> "RecordMetadata":
def vectors(self) -> "RecordVectors":
return self.__vectors

@property
def status(self) -> str:
return self._model.status

@property
def _server_id(self) -> Optional[UUID]:
return self._model.id
Expand All @@ -164,6 +168,7 @@ def api_model(self) -> RecordModel:
vectors=self.vectors.api_models(),
responses=self.responses.api_models(),
suggestions=self.suggestions.api_models(),
status=self.status,
)

def serialize(self) -> Dict[str, Any]:
Expand All @@ -185,6 +190,7 @@ def to_dict(self) -> Dict[str, Dict]:
"""
id = str(self.id) if self.id else None
server_id = str(self._model.id) if self._model.id else None
status = self.status
fields = self.fields.to_dict()
metadata = self.metadata.to_dict()
suggestions = self.suggestions.to_dict()
Expand All @@ -198,6 +204,7 @@ def to_dict(self) -> Dict[str, Dict]:
"suggestions": suggestions,
"responses": responses,
"vectors": vectors,
"status": status,
"_server_id": server_id,
}

Expand Down Expand Up @@ -245,7 +252,7 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record":
Returns:
A Record object.
"""
return cls(
instance = cls(
id=model.external_id,
fields=model.fields,
metadata={meta.name: meta.value for meta in model.metadata},
Expand All @@ -257,10 +264,15 @@ def from_model(cls, model: RecordModel, dataset: "Dataset") -> "Record":
for response in UserResponse.from_model(response_model, dataset=dataset)
],
suggestions=[Suggestion.from_model(model=suggestion, dataset=dataset) for suggestion in model.suggestions],
_dataset=dataset,
_server_id=model.id,
)

# set private attributes
instance._dataset = dataset
instance._model.id = model.id
instance._model.status = model.status

return instance


class RecordFields(dict):
"""This is a container class for the fields of a Record.
Expand Down Expand Up @@ -335,7 +347,7 @@ def to_dict(self) -> Dict[str, List[Dict]]:
response_dict = defaultdict(list)
for response in self.__responses:
response_dict[response.question_name].append({"value": response.value, "user_id": str(response.user_id)})
return response_dict
return dict(response_dict)

def api_models(self) -> List[UserResponseModel]:
"""Returns a list of ResponseModel objects."""
Expand Down
13 changes: 13 additions & 0 deletions argilla/tests/integration/test_list_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ def test_list_records_with_start_offset(client: Argilla, dataset: Dataset):
records = list(dataset.records(start_offset=1))
assert len(records) == 1

assert [record.to_dict() for record in records] == [
{
"_server_id": str(records[0]._server_id),
"fields": {"text": "The record text field"},
"id": "2",
"status": "pending",
"metadata": {},
"responses": {},
"suggestions": {},
"vectors": {},
}
]


def test_list_records_with_responses(client: Argilla, dataset: Dataset):
dataset.records.log(
Expand Down
1 change: 1 addition & 0 deletions argilla/tests/unit/test_io/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_to_list_flatten(self):
assert records_list == [
{
"id": str(record.id),
"status": "pending",
"_server_id": None,
"field": "The field",
"key": "value",
Expand Down
1 change: 1 addition & 0 deletions argilla/tests/unit/test_io/test_hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_to_datasets_with_partial_values_in_records(self):

ds = HFDatasetsIO.to_datasets(records)
assert ds.features == {
"status": Value(dtype="string", id=None),
"_server_id": Value(dtype="null", id=None),
"a": Value(dtype="string", id=None),
"b": Value(dtype="string", id=None),
Expand Down
10 changes: 10 additions & 0 deletions argilla/tests/unit/test_resources/test_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import uuid

import pytest

from argilla import Record, Suggestion, Response
from argilla._models import MetadataModel

Expand All @@ -31,6 +33,7 @@ def test_record_repr(self):
)
assert (
record.__repr__() == f"Record(id={record_id},"
"status=pending,"
"fields={'name': 'John', 'age': '30'},"
"metadata={'key': 'value'},"
"suggestions={'question': {'value': 'answer', 'score': None, 'agent': None}},"
Expand Down Expand Up @@ -62,3 +65,10 @@ def test_update_record_vectors(self):

record.vectors["new-vector"] = [1.0, 2.0, 3.0]
assert record.vectors == {"vector": [1.0, 2.0, 3.0], "new-vector": [1.0, 2.0, 3.0]}

def test_prevent_update_record(self):
record = Record(fields={"name": "John"})
assert record.status == "pending"

with pytest.raises(AttributeError):
record.status = "completed"

0 comments on commit c219764

Please sign in to comment.