-
Notifications
You must be signed in to change notification settings - Fork 22
/
netmisc.py
58 lines (48 loc) · 1.43 KB
/
netmisc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Miscellaneous functions for the network
import torch
from torch import nn
import vconv
from sys import stderr
import sys
import re
import collections as col
def xavier_init(mod):
    """Xavier-uniform initialize ``mod.weight`` and zero ``mod.bias``.

    Modules without these attributes (or with them set to None) are left
    untouched, so this can be passed to ``nn.Module.apply``.

    Weights with fewer than 2 dimensions (e.g. LayerNorm / BatchNorm scale
    vectors) are skipped: Xavier init needs a fan-in and fan-out, which are
    undefined for such tensors, and ``nn.init.xavier_uniform_`` raises
    ValueError on them. Non-tensor attributes are skipped as well.

    Args:
        mod: any object, typically an ``nn.Module``.
    """
    weight = getattr(mod, 'weight', None)
    if isinstance(weight, torch.Tensor) and weight.dim() >= 2:
        nn.init.xavier_uniform_(weight)
    bias = getattr(mod, 'bias', None)
    if isinstance(bias, torch.Tensor):
        nn.init.constant_(bias, 0)
# Handle on this module object itself; the functions below read and update
# module-level state through it (this.print_iter) instead of using `global`.
this = sys.modules[__name__]
# Call counter for print_metrics(); decides when the header row is printed.
print_iter = 0
def set_print_iter(pos):
    """Overwrite the module-level ``print_iter`` counter.

    print_metrics() prints its header row whenever this counter is a
    multiple of its ``hdr_frequency`` argument, so this controls where in
    that cycle the next call falls.

    Args:
        pos: new counter value.
    """
    setattr(this, 'print_iter', pos)
def _format_metric_value(v, max_width):
    """Render one metric value as a single-line string for print_metrics.

    One-element tensors are unwrapped with ``.item()``. Ints are printed in
    full, floats with 3 decimals (or 3 significant digits when their
    magnitude is below 1e-2), and anything else via ``'{}'.format``. A
    newline plus the whitespace following it (e.g. inside a multi-line
    tensor repr) is collapsed to a single space. Non-numeric, non-tensor
    values longer than ``max_width`` are cut to their last
    ``max_width - 1`` characters and prefixed with '~'.
    """
    if isinstance(v, torch.Tensor) and v.numel() == 1:
        v = v.item()
    if isinstance(v, int):
        text = '{:d}'.format(v)
    elif isinstance(v, float):
        # Choose the format by magnitude: comparing v itself sent every
        # negative value down the "tiny number" path (-1234.5 -> -1.23e+03).
        text = '{:.3}'.format(v) if abs(v) < 1e-2 else '{:.3f}'.format(v)
    else:
        text = '{}'.format(v)
    text = re.sub(r'\n\s+', ' ', text)
    # Keeping the tail suits long strings such as paths, but would silently
    # drop the leading (most significant) digits of a number, so numbers and
    # tensors are never truncated.
    if len(text) > max_width and not isinstance(v, (int, float, torch.Tensor)):
        text = '~' + text[-(max_width - 1):]
    return text


def print_metrics(metrics, worker_index, hdr_frequency):
    """
    Flexibly prints a polymorphic set of metrics as one tab-separated row
    on stderr.

    The first column is always ``w_idx`` (the worker index), followed by the
    entries of ``metrics`` in iteration order. A header row of column names
    is printed first when ``worker_index == 0`` and the module-level
    ``this.print_iter`` counter is a multiple of ``hdr_frequency``. The
    counter is incremented on every call, whatever the worker index.

    Args:
        metrics: name -> value entries, in any form accepted by
            ``OrderedDict.update``. Values may be ints, floats, tensors,
            or anything else with a ``str.format`` representation.
        worker_index: index printed in the ``w_idx`` column; only worker 0
            prints the header row.
        hdr_frequency: print the header every this many calls.

    Raises:
        ZeroDivisionError: if ``hdr_frequency`` is 0.
    """
    max_width = 12
    row = col.OrderedDict({'w_idx': worker_index})
    row.update(metrics)
    header = '\t'.join(f'{k}' for k in row)
    values = '\t'.join(_format_metric_value(v, max_width) for v in row.values())
    if this.print_iter % hdr_frequency == 0 and worker_index == 0:
        print(header, file=stderr)
    print(values, file=stderr)
    this.print_iter += 1
    stderr.flush()