-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
93 lines (82 loc) · 3.4 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from __future__ import print_function
import torch
from torch._six import string_classes
import collections
import errno
import os
from PIL import Image
import torch
import torch.nn as nn
import re
import json
import pickle as cPickle
import numpy as np
import utils
import h5py
import operator
import functools
from torch._six import string_classes
import torch.nn.functional as F
import collections
from torch.utils.data.dataloader import default_collate
def save_model(path, model, epoch, optimizer=None):
    """Serialize a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict with the keys ``'epoch'`` and ``'model_state'``
    (the result of ``model.state_dict()``). When ``optimizer`` is given, its
    ``state_dict()`` is stored under ``'optimizer_state'`` as well.

    Args:
        path: Destination handed to ``torch.save`` unchanged.
        model: Object exposing ``state_dict()``.
        epoch: Value stored under ``'epoch'``.
        optimizer: Optional object exposing ``state_dict()``; skipped when
            ``None``.
    """
    checkpoint = {'epoch': epoch, 'model_state': model.state_dict()}
    if optimizer is not None:
        checkpoint['optimizer_state'] = optimizer.state_dict()
    torch.save(checkpoint, path)
def trim_collate(batch, pad=True):
    """Put each data field into a tensor with outer dimension batch size.

    Unlike a plain stack, tensors with more than one dimension are
    zero-padded along dim 0 up to the largest dim-0 size in the batch
    before being stacked, so samples with a different number of rows
    (e.g. image-feature boxes) can share one batch tensor.

    Dispatch is on the type of ``batch[0]``:

    * tensor with ``dim() > 1``: zero-pad dim 0, then stack.
    * other tensors: stacked as they are.
    * NumPy ndarray: converted with ``torch.from_numpy`` and stacked;
      string/object dtypes are rejected.
    * NumPy scalar: converted to a 1-D tensor that keeps the NumPy dtype.
    * ``int`` -> ``torch.LongTensor``; ``float`` -> ``torch.DoubleTensor``.
    * ``str`` / ``bytes``: ``batch`` is returned unchanged.
    * mapping: each key is collated with ``default_collate``.
    * sequence: samples are transposed and every field is collated
      recursively with ``trim_collate``.

    Args:
        batch: Non-empty sequence of samples that share one type.
        pad: Accepted for interface compatibility; currently not consulted.

    Returns:
        A tensor, the unchanged ``batch`` (strings), a dict, or a list,
        depending on the element type.

    Raises:
        TypeError: If the element type is not one of the supported kinds.
    """
    # The ``collections.Mapping`` / ``collections.Sequence`` aliases were
    # removed in Python 3.10; the ABCs live in ``collections.abc``.
    from collections.abc import Mapping, Sequence

    error_msg = "batch must contain tensors, numbers, dicts or lists; found {}"
    elem = batch[0]
    elem_type = type(elem)

    if torch.is_tensor(elem):
        if elem.dim() > 1:  # image features
            max_num_boxes = max(x.size(0) for x in batch)
            # F.pad reads its spec from the last dimension backwards, so the
            # (0, n) pair for dim 0 must follow a (0, 0) pair for every other
            # dimension. For 2-D inputs this is the familiar (0, 0, 0, n).
            return torch.stack(
                [
                    F.pad(
                        x,
                        (0, 0) * (x.dim() - 1)
                        + (0, max_num_boxes - x.size(0)),
                    ).detach()
                    for x in batch
                ],
                0,
            )
        return torch.stack(batch, 0)

    # NumPy string scalars are left to the str/bytes branch further down.
    if elem_type.__module__ == 'numpy' and elem_type.__name__ not in (
            'str_', 'string_', 'bytes_'):
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if re.search('[SaUO]', elem.dtype.str) is not None:
                raise TypeError(error_msg.format(elem.dtype))
            return torch.stack([torch.from_numpy(b) for b in batch], 0)
        if elem.shape == ():  # scalars
            # Going through one ndarray keeps the NumPy dtype of the scalars.
            return torch.as_tensor(np.asarray(batch))
        raise TypeError(error_msg.format(elem_type))

    if isinstance(elem, int):
        return torch.LongTensor(batch)
    if isinstance(elem, float):
        return torch.DoubleTensor(batch)
    if isinstance(elem, (str, bytes)):
        return batch
    if isinstance(elem, Mapping):
        # NOTE(review): dict values go through default_collate, so ragged
        # tensors nested in dicts are NOT padded. Kept as in the original;
        # confirm whether that is intended.
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    if isinstance(elem, Sequence):
        return [trim_collate(samples) for samples in zip(*batch)]

    raise TypeError(error_msg.format(elem_type))