-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtensor_data_dict.py
113 lines (92 loc) · 4.64 KB
/
tensor_data_dict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import dataclasses
from typing import List, Tuple, Dict, TypeVar, final, Iterable, Iterator, Mapping, Optional

from .tensors_data_class_base import TensorsDataClass
from .misc import CollateData, compose_fns
# Type variable for the keys of a `TensorsDataDict`.
DictKeyT = TypeVar('DictKeyT')
# Type variable for the values of a `TensorsDataDict`.
DictValueT = TypeVar('DictValueT')
@final
@dataclasses.dataclass
class TensorsDataDict(TensorsDataClass, Mapping[DictKeyT, DictValueT]):
    """A `TensorsDataClass` whose data fields live in an internal dictionary.

    The entries are stored in `_dict` and exposed through the read-only
    `Mapping` protocol. Item access (`obj[key]`) applies any lazy map that is
    pending for that key before returning the value.
    """

    _dict: Dict[DictKeyT, DictValueT] = dataclasses.field(default_factory=dict)

    def __init__(self, dct: Optional[Dict[DictKeyT, DictValueT]] = None, /, **kwargs):
        """Run the base-class initializer.

        The arguments are accepted only because Python passes the same
        arguments to `__new__` and `__init__`; the dictionary itself is built
        in `__new__`.
        """
        super(TensorsDataDict, self).__init__()

    def __new__(cls, dct: Optional[Dict[DictKeyT, DictValueT]] = None, /, **kwargs):
        """Create the instance and populate `_dict`.

        Args:
            dct: Optional initial entries (positional-only). It is copied
                shallowly, so the caller's dict is never mutated.
            **kwargs: Additional entries; on a key collision these override
                the entries of `dct`.
        """
        obj = super(TensorsDataDict, cls).__new__(cls)
        dct = {} if dct is None else dct
        obj._dict = {**dct, **kwargs}
        return obj

    @classmethod
    def factory(cls, kwargs: Dict) -> 'TensorsDataDict':
        """Build an instance from a flat mapping of names to values.

        Entries named after a non-init dataclass field are assigned as
        attributes on the new object; all other entries become dictionary
        items.

        Args:
            kwargs: Mapping of names to values. Must not contain `'_dict'`.
        """
        assert {field.name for field in dataclasses.fields(cls) if field.init} == {'_dict'}
        assert '_dict' not in kwargs.keys()
        dataclass_fields_not_to_init = {
            field.name: field for field in dataclasses.fields(cls) if not field.init}
        new_obj = cls({
            key: val for key, val in kwargs.items()
            if key not in dataclass_fields_not_to_init})
        for field_name in set(dataclass_fields_not_to_init.keys()) & set(kwargs.keys()):
            setattr(new_obj, field_name, kwargs[field_name])
        return new_obj

    def get_all_fields(self) -> Tuple[str, ...]:
        """Return the dictionary keys plus the base-class fields, without `'_dict'`.

        Duplicates are removed; keys come first (in dictionary order), followed
        by the base-class fields that were not already listed.
        """
        # A dict is used as an ordered set so the result is deterministic.
        names = dict.fromkeys(self._dict)
        names.update(dict.fromkeys(super(TensorsDataDict, self).get_all_fields()))
        names.pop('_dict', None)
        return tuple(names)

    def get_field_names_by_group(self, group: str = 'all') -> Tuple[str, ...]:
        """Return the field names that belong to `group`.

        For `'data'` these are the dictionary keys; every other group is
        delegated to the base class.
        """
        if group == 'data':
            return tuple(self._dict.keys())
        # Forward the requested group; dropping it would make every
        # non-'data' request fall back to the base-class default group.
        return super(TensorsDataDict, self).get_field_names_by_group(group)

    @classmethod
    def _collate_first_pass(cls, inputs: List['TensorsDataDict'], collate_data: CollateData) -> 'TensorsDataDict':
        """Collate several dictionaries into one batched dictionary.

        For every key that appears in at least one input, the values of the
        inputs that have this key are collated with `cls.collate_values`.
        The result's `_batch_size` is set to the number of inputs.
        """
        assert all(isinstance(inp, TensorsDataDict) for inp in inputs)
        # De-duplicate keys while keeping first-seen order, so the key order of
        # the batched dict does not depend on hash randomization.
        all_keys = dict.fromkeys(key for dct in inputs for key in dct._dict)
        batched_obj = TensorsDataDict({
            key: cls.collate_values(
                tuple(dct._dict[key] for dct in inputs if key in dct._dict),
                collate_data=collate_data)
            for key in all_keys})
        batched_obj._batch_size = len(inputs)
        return batched_obj

    def __getitem__(self, item):
        """Return the value stored under `item`, applying pending lazy maps first.

        Raises:
            KeyError: If `item` is not a key of the dictionary.
        """
        # `__getitem__` is overridden to support `lazy_map()`.
        if hasattr(self, '_lazy_map_fns_per_field') and item in self._lazy_map_fns_per_field:
            old_val = self._dict[item]
            composed_map_fn = compose_fns(*(fn for fn, _ in self._lazy_map_fns_per_field[item]))
            # Store the mapped value so the functions are applied only once.
            self._dict[item] = composed_map_fn(old_val)
            del self._lazy_map_fns_per_field[item]
            if len(self._lazy_map_fns_per_field) == 0:
                del self._lazy_map_fns_per_field
            self._lazy_map_usage_history.add(item)
        return self._dict[item]

    def access_field(self, name: str):
        """Return the field `name`: a base-class attribute if one exists, else a dict item."""
        if isinstance(name, str) and hasattr(super(TensorsDataDict, self), name):
            return super(TensorsDataDict, self).access_field(name)
        return self[name]

    def access_field_wo_applying_lazy_maps(self, name: str):
        """Like `access_field`, but dict items are read directly from `_dict`.

        Reading `_dict` directly bypasses `__getitem__`, so no pending lazy
        map is applied to the returned item.
        """
        if isinstance(name, str) and hasattr(super(TensorsDataDict, self), name):
            return super(TensorsDataDict, self).access_field_wo_applying_lazy_maps(name)
        return self._dict[name]

    def items(self) -> Iterable[Tuple[DictKeyT, DictValueT]]:
        """Lazily yield `(key, value)` pairs; values are read via `self[key]`."""
        for key in self._dict.keys():
            yield key, self[key]

    def values(self) -> Iterable[DictValueT]:
        """Lazily yield the values; each one is read via `self[key]`."""
        return (self[key] for key in self.keys())

    def keys(self) -> Iterable[DictKeyT]:
        """Return a view of the dictionary keys."""
        return self._dict.keys()

    def __len__(self):
        """Return the number of dictionary entries."""
        return len(self._dict)

    def __iter__(self) -> Iterator[DictKeyT]:
        """Return an iterator over the dictionary keys."""
        # `__iter__` must return an iterator; a `dict_keys` view is only an
        # iterable and would make `iter(self)` raise `TypeError`.
        return iter(self._dict)