diff --git a/explainaboard/loaders/file_loader.py b/explainaboard/loaders/file_loader.py index c514a93a..73bcb851 100644 --- a/explainaboard/loaders/file_loader.py +++ b/explainaboard/loaders/file_loader.py @@ -367,7 +367,7 @@ def validate(self): def load_raw( self, data: str | DatalabLoaderOption, source: Source ) -> FileLoaderReturn: - data = narrow(data, str) + data = narrow(str, data) if source == Source.in_memory: file = StringIO(data) lines = list(csv.reader(file, delimiter='\t', quoting=csv.QUOTE_NONE)) @@ -396,7 +396,7 @@ def validate(self): def load_raw( self, data: str | DatalabLoaderOption, source: Source ) -> FileLoaderReturn: - data = narrow(data, str) + data = narrow(str, data) if source == Source.in_memory: return FileLoaderReturn(data.splitlines()) elif source == Source.local_filesystem: @@ -431,7 +431,7 @@ def add_sample(): field.src_name: [] for field in self._fields } # reset - max_field: int = max([narrow(x.src_name, int) for x in self._fields]) + max_field: int = max([narrow(int, x.src_name) for x in self._fields]) for line in raw_data.samples: # at sentence boundary if line.startswith("-DOCSTART-") or line == "" or line == "\n": @@ -447,7 +447,7 @@ def add_sample(): for field in self._fields: curr_sentence_fields[field.src_name].append( - self.parse_data(splits[narrow(field.src_name, int)], field) + self.parse_data(splits[narrow(int, field.src_name)], field) ) add_sample() # add last example @@ -458,7 +458,7 @@ class JSONFileLoader(FileLoader): def load_raw( self, data: str | DatalabLoaderOption, source: Source ) -> FileLoaderReturn: - data = narrow(data, str) + data = narrow(str, data) if source == Source.in_memory: loaded = json.loads(data) elif source == Source.local_filesystem: @@ -524,7 +524,7 @@ def replace_labels(cls, features: dict, example: dict) -> dict: def load_raw( self, data: str | DatalabLoaderOption, source: Source ) -> FileLoaderReturn: - config = narrow(data, DatalabLoaderOption) + config = narrow(DatalabLoaderOption, data) dataset = load_dataset( config.dataset, config.subdataset, split=config.split, streaming=False ) @@ -594,7 +594,7 @@ def __init__( def load_raw( cls, data: str | DatalabLoaderOption, source: Source ) -> FileLoaderReturn: - data = narrow(data, str) + data = narrow(str, data) if source == Source.in_memory: return FileLoaderReturn(data.splitlines()) elif source == Source.local_filesystem: diff --git a/explainaboard/utils/typing_utils.py b/explainaboard/utils/typing_utils.py index baccd480..3a151bbe 100644 --- a/explainaboard/utils/typing_utils.py +++ b/explainaboard/utils/typing_utils.py @@ -42,13 +42,27 @@ def unwrap_generator(obj: Optional[Iterable[T]]) -> Generator[T, None, None]: yield from obj -NarrowType = TypeVar("NarrowType") +def narrow(subcls: type[T], obj: Any) -> T: + """Narrow (downcast) an object with a type-safe manner. + This function does the same type casting with ``typing.cast()``, but additionally + checks the actual type of the given object. If the type of the given object is not + castable to the given type, this funtion raises a ``TypeError``. -def narrow(obj: Any, narrow_type: type[NarrowType]) -> NarrowType: - """returns the object with the narrowed type or raises a TypeError - (obj: Any, new_type: type[T]) -> T""" - if isinstance(obj, narrow_type): - return obj - else: - raise TypeError(f"{obj} is expected to be {narrow_type}") + :param subcls: The type that ``obj`` is casted to. + :type subcls: ``type[T]`` + :param obj: The object to be casted. + :type obj: ``Any`` + :return: ``obj`` itself + :rtype: ``T`` + :raises TypeError: ``obj`` is not an object of ``T``. + """ + if not isinstance(obj, subcls): + raise TypeError( + f"{obj.__class__.__name__} is not a subclass of {subcls.__name__}" + ) + + # NOTE(odashi): typing.cast() does not work with TypeVar. + # Simply returning the obj is correct because we already narrowed its type + # by the previous if-statement. + return obj diff --git a/explainaboard/utils/typing_utils_test.py b/explainaboard/utils/typing_utils_test.py index c0df09e7..ad0c9db2 100644 --- a/explainaboard/utils/typing_utils_test.py +++ b/explainaboard/utils/typing_utils_test.py @@ -8,5 +8,5 @@ class TestTypingUtils(unittest.TestCase): def test_narrow(self): a: str | int = 's' - self.assertEqual(narrow(a, str), a) - self.assertRaises(TypeError, lambda: narrow(a, int)) + self.assertEqual(narrow(str, a), a) + self.assertRaises(TypeError, lambda: narrow(int, a))