From a8b458f56206c4c7d5238651b10c6dc9130bb6dc Mon Sep 17 00:00:00 2001 From: odashi Date: Tue, 26 Jul 2022 18:48:32 +0000 Subject: [PATCH 1/7] add downcast --- explainaboard/utils/typing_utils.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/explainaboard/utils/typing_utils.py b/explainaboard/utils/typing_utils.py index baccd480..d228e3b3 100644 --- a/explainaboard/utils/typing_utils.py +++ b/explainaboard/utils/typing_utils.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Generator, Iterable -from typing import Any, Optional, TypeVar +from typing import Any, cast, Optional, TypeVar T = TypeVar('T') @@ -42,13 +42,18 @@ def unwrap_generator(obj: Optional[Iterable[T]]) -> Generator[T, None, None]: yield from obj -NarrowType = TypeVar("NarrowType") +def downcast(obj: Any, subcls: type[T]) -> T: + """Downcast the object. + :param obj: The object to be downcasted. + :type obj: ``Any`` + :param subcls: The type that ``obj`` is casted to. + :type subcls: ``type[T]`` + :return: ``obj`` itself + :rtype: ``T`` + :raises TypeError: ``obj`` is not an object of ``T``. 
+ """ + if not isinstance(obj, subcls): + raise TypeError(f"{obj} is not an object of {subcls.__name__}") -def narrow(obj: Any, narrow_type: type[NarrowType]) -> NarrowType: - """returns the object with the narrowed type or raises a TypeError - (obj: Any, new_type: type[T]) -> T""" - if isinstance(obj, narrow_type): - return obj - else: - raise TypeError(f"{obj} is expected to be {narrow_type}") + return cast(obj, subcls) From dd7ca05ba75de4bbfbd9dfb5095f818d4ccff928 Mon Sep 17 00:00:00 2001 From: odashi Date: Tue, 26 Jul 2022 19:00:56 +0000 Subject: [PATCH 2/7] refactor --- explainaboard/utils/typing_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/explainaboard/utils/typing_utils.py b/explainaboard/utils/typing_utils.py index d228e3b3..4c8ad50e 100644 --- a/explainaboard/utils/typing_utils.py +++ b/explainaboard/utils/typing_utils.py @@ -42,10 +42,10 @@ def unwrap_generator(obj: Optional[Iterable[T]]) -> Generator[T, None, None]: yield from obj -def downcast(obj: Any, subcls: type[T]) -> T: - """Downcast the object. +def narrow(obj: Any, subcls: type[T]) -> T: + """Narrow (downcast) an object with a type-safe manner. - :param obj: The object to be downcasted. + :param obj: The object to be casted. :type obj: ``Any`` :param subcls: The type that ``obj`` is casted to. :type subcls: ``type[T]`` From 1fa4f37c7a785db4c4704e2aa9795e78a7c12392 Mon Sep 17 00:00:00 2001 From: odashi Date: Tue, 26 Jul 2022 19:05:43 +0000 Subject: [PATCH 3/7] fix --- explainaboard/utils/typing_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/explainaboard/utils/typing_utils.py b/explainaboard/utils/typing_utils.py index 4c8ad50e..5fc29c2a 100644 --- a/explainaboard/utils/typing_utils.py +++ b/explainaboard/utils/typing_utils.py @@ -54,6 +54,8 @@ def narrow(obj: Any, subcls: type[T]) -> T: :raises TypeError: ``obj`` is not an object of ``T``. 
""" if not isinstance(obj, subcls): - raise TypeError(f"{obj} is not an object of {subcls.__name__}") + raise TypeError( + f"{obj.__class__.__name__} is not an object of {subcls.__name__}" + ) return cast(obj, subcls) From f85309bedaa4967b045f89b21d6a0bc94c58ca31 Mon Sep 17 00:00:00 2001 From: odashi Date: Tue, 26 Jul 2022 19:06:30 +0000 Subject: [PATCH 4/7] fix --- explainaboard/utils/typing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/explainaboard/utils/typing_utils.py b/explainaboard/utils/typing_utils.py index 5fc29c2a..c0b9d074 100644 --- a/explainaboard/utils/typing_utils.py +++ b/explainaboard/utils/typing_utils.py @@ -55,7 +55,7 @@ def narrow(obj: Any, subcls: type[T]) -> T: """ if not isinstance(obj, subcls): raise TypeError( - f"{obj.__class__.__name__} is not an object of {subcls.__name__}" + f"{obj.__class__.__name__} is not a subclass of {subcls.__name__}" ) return cast(obj, subcls) From 3e656d91e53e505ef8c786eef3b47c6903eac934 Mon Sep 17 00:00:00 2001 From: odashi Date: Tue, 26 Jul 2022 20:37:44 +0000 Subject: [PATCH 5/7] fix --- explainaboard/utils/typing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/explainaboard/utils/typing_utils.py b/explainaboard/utils/typing_utils.py index c0b9d074..bc3281ce 100644 --- a/explainaboard/utils/typing_utils.py +++ b/explainaboard/utils/typing_utils.py @@ -58,4 +58,4 @@ def narrow(obj: Any, subcls: type[T]) -> T: f"{obj.__class__.__name__} is not a subclass of {subcls.__name__}" ) - return cast(obj, subcls) + return cast(subcls, obj) From 6202cee6662834e6ca12c45f52c9a18f1f3755c5 Mon Sep 17 00:00:00 2001 From: odashi Date: Tue, 26 Jul 2022 20:43:01 +0000 Subject: [PATCH 6/7] change argument order --- explainaboard/loaders/file_loader.py | 14 +++++++------- explainaboard/utils/typing_utils.py | 10 +++++++--- explainaboard/utils/typing_utils_test.py | 4 ++-- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/explainaboard/loaders/file_loader.py 
b/explainaboard/loaders/file_loader.py index c514a93a..73bcb851 100644 --- a/explainaboard/loaders/file_loader.py +++ b/explainaboard/loaders/file_loader.py @@ -367,7 +367,7 @@ def validate(self): def load_raw( self, data: str | DatalabLoaderOption, source: Source ) -> FileLoaderReturn: - data = narrow(data, str) + data = narrow(str, data) if source == Source.in_memory: file = StringIO(data) lines = list(csv.reader(file, delimiter='\t', quoting=csv.QUOTE_NONE)) @@ -396,7 +396,7 @@ def validate(self): def load_raw( self, data: str | DatalabLoaderOption, source: Source ) -> FileLoaderReturn: - data = narrow(data, str) + data = narrow(str, data) if source == Source.in_memory: return FileLoaderReturn(data.splitlines()) elif source == Source.local_filesystem: @@ -431,7 +431,7 @@ def add_sample(): field.src_name: [] for field in self._fields } # reset - max_field: int = max([narrow(x.src_name, int) for x in self._fields]) + max_field: int = max([narrow(int, x.src_name) for x in self._fields]) for line in raw_data.samples: # at sentence boundary if line.startswith("-DOCSTART-") or line == "" or line == "\n": @@ -447,7 +447,7 @@ def add_sample(): for field in self._fields: curr_sentence_fields[field.src_name].append( - self.parse_data(splits[narrow(field.src_name, int)], field) + self.parse_data(splits[narrow(int, field.src_name)], field) ) add_sample() # add last example @@ -458,7 +458,7 @@ class JSONFileLoader(FileLoader): def load_raw( self, data: str | DatalabLoaderOption, source: Source ) -> FileLoaderReturn: - data = narrow(data, str) + data = narrow(str, data) if source == Source.in_memory: loaded = json.loads(data) elif source == Source.local_filesystem: @@ -524,7 +524,7 @@ def replace_labels(cls, features: dict, example: dict) -> dict: def load_raw( self, data: str | DatalabLoaderOption, source: Source ) -> FileLoaderReturn: - config = narrow(data, DatalabLoaderOption) + config = narrow(DatalabLoaderOption, data) dataset = load_dataset( config.dataset, 
config.subdataset, split=config.split, streaming=False ) @@ -594,7 +594,7 @@ def __init__( def load_raw( cls, data: str | DatalabLoaderOption, source: Source ) -> FileLoaderReturn: - data = narrow(data, str) + data = narrow(str, data) if source == Source.in_memory: return FileLoaderReturn(data.splitlines()) elif source == Source.local_filesystem: diff --git a/explainaboard/utils/typing_utils.py b/explainaboard/utils/typing_utils.py index bc3281ce..ca43c97b 100644 --- a/explainaboard/utils/typing_utils.py +++ b/explainaboard/utils/typing_utils.py @@ -42,13 +42,17 @@ def unwrap_generator(obj: Optional[Iterable[T]]) -> Generator[T, None, None]: yield from obj -def narrow(obj: Any, subcls: type[T]) -> T: +def narrow(subcls: type[T], obj: Any) -> T: """Narrow (downcast) an object with a type-safe manner. - :param obj: The object to be casted. - :type obj: ``Any`` + This function does the same type casting as ``typing.cast()``, but additionally + checks the actual type of the given object. If the type of the given object is not + castable to the given type, this function raises a ``TypeError``. + :param subcls: The type that ``obj`` is casted to. :type subcls: ``type[T]`` + :param obj: The object to be casted. + :type obj: ``Any`` :return: ``obj`` itself :rtype: ``T`` :raises TypeError: ``obj`` is not an object of ``T``.
diff --git a/explainaboard/utils/typing_utils_test.py b/explainaboard/utils/typing_utils_test.py index c0df09e7..ad0c9db2 100644 --- a/explainaboard/utils/typing_utils_test.py +++ b/explainaboard/utils/typing_utils_test.py @@ -8,5 +8,5 @@ class TestTypingUtils(unittest.TestCase): def test_narrow(self): a: str | int = 's' - self.assertEqual(narrow(a, str), a) - self.assertRaises(TypeError, lambda: narrow(a, int)) + self.assertEqual(narrow(str, a), a) + self.assertRaises(TypeError, lambda: narrow(int, a)) From b178f66d25424978c47eccfc8b1f9cc9218b5061 Mon Sep 17 00:00:00 2001 From: odashi Date: Thu, 28 Jul 2022 04:46:54 +0000 Subject: [PATCH 7/7] remove cast --- explainaboard/utils/typing_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/explainaboard/utils/typing_utils.py b/explainaboard/utils/typing_utils.py index ca43c97b..3a151bbe 100644 --- a/explainaboard/utils/typing_utils.py +++ b/explainaboard/utils/typing_utils.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Generator, Iterable -from typing import Any, cast, Optional, TypeVar +from typing import Any, Optional, TypeVar T = TypeVar('T') @@ -62,4 +62,7 @@ def narrow(subcls: type[T], obj: Any) -> T: f"{obj.__class__.__name__} is not a subclass of {subcls.__name__}" ) - return cast(subcls, obj) + # NOTE(odashi): typing.cast() does not work with TypeVar. + # Simply returning the obj is correct because we already narrowed its type + # by the previous if-statement. + return obj