-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathgeneric.py
129 lines (106 loc) · 4.57 KB
/
generic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import inspect
from collections import OrderedDict, UserDict
from collections.abc import MutableMapping
from contextlib import ExitStack
from dataclasses import fields
from enum import Enum
from typing import Any, ContextManager, List, Tuple
import numpy as np
def is_tensor(x):
"""
Tests if `x` is a `torch.Tensor`, `tf.Tensor`, `jaxlib.xla_extension.DeviceArray` or `np.ndarray`.
"""
if is_torch_fx_proxy(x):
return True
if is_torch_available():
import torch
if isinstance(x, torch.Tensor):
return True
if is_tf_available():
import tensorflow as tf
if isinstance(x, tf.Tensor):
return True
if is_flax_available():
import jax.numpy as jnp
from jax.core import Tracer
if isinstance(x, (jnp.ndarray, Tracer)):
return True
return isinstance(x, np.ndarray)
class ModelOutput(OrderedDict):
"""
Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
python dictionary.
<Tip warning={true}>
You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple
before.
</Tip>
"""
def __post_init__(self):
class_fields = fields(self)
# Safety and consistency checks
if not len(class_fields):
raise ValueError(f"{self.__class__.__name__} has no fields.")
if not all(field.default is None for field in class_fields[1:]):
raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
first_field = getattr(self, class_fields[0].name)
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
if other_fields_are_none and not is_tensor(first_field):
if isinstance(first_field, dict):
iterator = first_field.items()
first_field_iterator = True
else:
try:
iterator = iter(first_field)
first_field_iterator = True
except TypeError:
first_field_iterator = False
# if we provided an iterator as first field and the iterator is a (key, value) iterator
# set the associated fields
if first_field_iterator:
for element in iterator:
if (
not isinstance(element, (list, tuple))
or not len(element) == 2
or not isinstance(element[0], str)
):
break
setattr(self, element[0], element[1])
if element[1] is not None:
self[element[0]] = element[1]
elif first_field is not None:
self[class_fields[0].name] = first_field
else:
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __getitem__(self, k):
if isinstance(k, str):
inner_dict = {k: v for (k, v) in self.items()}
return inner_dict[k]
else:
return self.to_tuple()[k]
def __setattr__(self, name, value):
if name in self.keys() and value is not None:
# Don't call self.__setitem__ to avoid recursion errors
super().__setitem__(name, value)
super().__setattr__(name, value)
def __setitem__(self, key, value):
# Will raise a KeyException if needed
super().__setitem__(key, value)
# Don't call self.__setattr__ to avoid recursion errors
super().__setattr__(key, value)
def to_tuple(self) -> Tuple[Any]:
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.
"""
return tuple(self[k] for k in self.keys())