Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit dc5d315

Browse files
committed
Support integer type in ImageIter
1 parent 64d2e8b commit dc5d315

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

python/mxnet/image/image.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1057,14 +1057,16 @@ class ImageIter(io.DataIter):
10571057
Data name for provided symbols.
10581058
label_name : str
10591059
Label name for provided symbols.
1060+
dtype : str
1061+
Label data type. Default: float32. Other options: int32, int64, float64
10601062
kwargs : ...
10611063
More arguments for creating augmenter. See mx.image.CreateAugmenter.
10621064
"""
10631065

10641066
def __init__(self, batch_size, data_shape, label_width=1,
10651067
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
10661068
shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
1067-
data_name='data', label_name='softmax_label', **kwargs):
1069+
data_name='data', label_name='softmax_label', dtype='float32', **kwargs):
10681070
super(ImageIter, self).__init__()
10691071
assert path_imgrec or path_imglist or (isinstance(imglist, list))
10701072
num_threads = os.environ.get('MXNET_CPU_WORKER_NTHREADS', 1)
@@ -1091,7 +1093,7 @@ def __init__(self, batch_size, data_shape, label_width=1,
10911093
imgkeys = []
10921094
for line in iter(fin.readline, ''):
10931095
line = line.strip().split('\t')
1094-
label = nd.array([float(i) for i in line[1:-1]])
1096+
label = nd.array(line[1:-1], dtype=dtype)
10951097
key = int(line[0])
10961098
imglist[key] = (label, line[-1])
10971099
imgkeys.append(key)
@@ -1105,11 +1107,11 @@ def __init__(self, batch_size, data_shape, label_width=1,
11051107
key = str(index) # pylint: disable=redefined-variable-type
11061108
index += 1
11071109
if len(img) > 2:
1108-
label = nd.array(img[:-1])
1110+
label = nd.array(img[:-1], dtype=dtype)
11091111
elif isinstance(img[0], numeric_types):
1110-
label = nd.array([img[0]])
1112+
label = nd.array([img[0]], dtype=dtype)
11111113
else:
1112-
label = nd.array(img[0])
1114+
label = nd.array(img[0], dtype=dtype)
11131115
result[key] = (label, img[-1])
11141116
imgkeys.append(str(key))
11151117
self.imglist = result

tests/python/unittest/test_image.py

+21-17
Original file line numberDiff line numberDiff line change
@@ -132,26 +132,30 @@ def test_color_normalize(self):
132132

133133

134134
def test_imageiter(self):
135-
im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
136-
test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list,
137-
path_root='')
138-
for _ in range(3):
135+
def check_imageiter(dtype='float32'):
136+
im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
137+
test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list,
138+
path_root='', dtype=dtype)
139+
for _ in range(3):
140+
for batch in test_iter:
141+
pass
142+
test_iter.reset()
143+
144+
# test with list file
145+
fname = './data/test_imageiter.lst'
146+
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \
147+
for k, x in enumerate(TestImage.IMAGES)]
148+
with open(fname, 'w') as f:
149+
for line in file_list:
150+
f.write(line + '\n')
151+
152+
test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname,
153+
path_root='', dtype=dtype)
139154
for batch in test_iter:
140155
pass
141-
test_iter.reset()
142-
143-
# test with list file
144-
fname = './data/test_imageiter.lst'
145-
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \
146-
for k, x in enumerate(TestImage.IMAGES)]
147-
with open(fname, 'w') as f:
148-
for line in file_list:
149-
f.write(line + '\n')
150156

151-
test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname,
152-
path_root='')
153-
for batch in test_iter:
154-
pass
157+
for dtype in ['int32', 'float32', 'int64', 'float64']:
158+
check_imageiter(dtype)
155159

156160
@with_seed()
157161
def test_augmenters(self):

0 commit comments

Comments
 (0)