-
Notifications
You must be signed in to change notification settings - Fork 35
/
ops_debug.py
173 lines (146 loc) · 5.49 KB
/
ops_debug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
from abc import abstractmethod
from typing import Hashable, List, Optional, Sequence
import numpy
import torch
from fuse.data import OpBase
from fuse.data.utils.sample import get_sample_id
from fuse.utils import NDict
class OpDebugBase(OpBase):
"""
Base class for debug operations.
Provides the ability to limit samples to debug (will debug the first k samples).
Inherits and implements self.call_debug instead of self.__call__.
"""
def __init__(
self,
name: Optional[str] = None,
sample_ids: Optional[List[Hashable]] = None,
num_samples: int = 0,
):
"""
:param name: string identifier - might be useful when the debug op display or save information into a file
:param sample_ids: apply for the specified sample ids. To apply for all set to None.
:param num_samples: apply for the first num_samples (per process). if 0, will apply for all.
"""
super().__init__()
self._name = name
self._sample_ids = sample_ids
self._num_samples = num_samples
self._num_samples_done = 0
def reset(self, name: Optional[str] = None) -> None:
"""Reset operation state"""
self._num_samples_done = 0
self._name = name
def should_debug_sample(self, sample_dict: NDict) -> bool:
if self._num_samples and self._num_samples_done >= self._num_samples:
return False
if self._sample_ids is not None:
sid = get_sample_id(sample_dict)
if sid not in self._sample_ids:
return False
self._num_samples_done += 1
return True
def __call__(self, sample_dict: NDict, **kwargs: dict) -> NDict:
if self.should_debug_sample(sample_dict):
self.call_debug(sample_dict, **kwargs)
return sample_dict
@abstractmethod
def call_debug(self, sample_dict: NDict, **kwargs: dict) -> None:
"""The actual debug op implementation"""
raise NotImplementedError
class OpPrintKeys(OpDebugBase):
"""
Print list of available keys at a given point in the data pipeline
It's recommended, but not a must, to run it in a single process.
```
from fuse.utils.utils_debug import FuseDebug
FuseDebug("debug")
```
Example:
```
(OpPrintKeys(num_samples=1), dict()),
```
"""
def call_debug(self, sample_dict: NDict) -> None:
print(f"Sample {get_sample_id(sample_dict)} keys:")
for key in sample_dict.keypaths():
print(f"\t{key}")
class OpPrintShapes(OpDebugBase):
"""
Print the shapes/length of every torch tensor / numpy array / sequence
Add at the top your script to force single process:
```
from fuse.utils.utils_debug import FuseDebug
FuseDebug("debug")
```
Example:
```
(OpPrintShapes(num_samples=1), dict()),
```
"""
def call_debug(self, sample_dict: NDict, include_strings: bool = False) -> None:
"""
:param include_strings: set to True if you want to see strings lengths.
useful when dealing with sequences such as SMILES.
"""
print(f"Sample {get_sample_id(sample_dict)} shapes:")
for key in sample_dict.keypaths():
value = sample_dict[key]
if isinstance(value, torch.Tensor):
print(f"\t{key} is tensor with shape: {value.shape}")
elif isinstance(value, numpy.ndarray):
print(f"\t{key} is numpy array with shape: {value.shape}")
elif not isinstance(value, str) and isinstance(value, Sequence):
print(f"\t{key} is sequence with length: {len(value)}")
elif isinstance(value, str) and include_strings:
print(f"\t{key} is string with length: {len(value)}")
class OpPrintTypes(OpDebugBase):
"""
Print the the type of each key
Add at the top your script to force single process:
```
from fuse.utils.utils_debug import FuseDebug
FuseDebug("debug")
```
Example:
```
(OpPrintTypes(num_samples=1), dict()),
```
"""
def call_debug(self, sample_dict: NDict) -> None:
print(f"Sample {get_sample_id(sample_dict)} types:")
for key in sample_dict.keypaths():
value = sample_dict[key]
if isinstance(value, torch.Tensor) or isinstance(value, numpy.ndarray):
print(f"\t{key} - {type(value).__name__} with dtype: {value.dtype}")
else:
print(f"\t{key} - {type(value).__name__}")
class OpPrintKeysContent(OpDebugBase):
"""
Print sample's keys and contents.
It's recommended, but not a must, to run it in a single process.
```
from fuse.utils.utils_debug import FuseDebug
FuseDebug("debug")
```
Example:
```
(OpPrintKeysContent(num_samples=1), dict(keys=["data.input.mri_path", "data.label"])),
(OpPrintKeysContent(num_samples=1), dict(keys=None)),
```
"""
def call_debug(self, sample_dict: NDict, keys: List[str]) -> None:
"""
:param keys: List of keys to print. Set to 'None' to print all keys.
"""
print(
f"OpPrintKeysContent, sample '{get_sample_id(sample_dict)}', <key> = <content>:"
)
dict_keys = sample_dict.keypaths()
if keys is None:
keys = dict_keys # print all keys
for key in keys:
if key in dict_keys:
print(f"\t{key} = {sample_dict[key]}")
else:
print(f"\tWARNING! {key} not in sample_dict")