diff --git a/scenic/dataset_lib/big_transfer/bit.py b/scenic/dataset_lib/big_transfer/bit.py index 9fcaaf8d..543f844d 100644 --- a/scenic/dataset_lib/big_transfer/bit.py +++ b/scenic/dataset_lib/big_transfer/bit.py @@ -85,7 +85,16 @@ def pp_fn(x, how): pp = builder.get_preprocess_fn(how) example = pp(x) # to scenic format - return {'inputs': example['image'], 'label': example['labels']} + if dataset_configs.dataset == 'imagenet2012' and 'file_name' in example: + return { + 'inputs': example['image'], + 'label': example['labels'], + 'file_name': example['file_name'], + } + return { + 'inputs': example['image'], + 'label': example['labels'], + } # E.g. for testing with TAP. shuffle_buffer_size = (1000 if num_shards == 1 else diff --git a/scenic/dataset_lib/big_transfer/builder.py b/scenic/dataset_lib/big_transfer/builder.py index 79dd8f85..b1f0df91 100644 --- a/scenic/dataset_lib/big_transfer/builder.py +++ b/scenic/dataset_lib/big_transfer/builder.py @@ -23,7 +23,7 @@ TPU_SUPPORTED_DTYPES = [ tf.bool, tf.int32, tf.int64, tf.bfloat16, tf.float32, tf.complex64, - tf.uint32 + tf.uint32, tf.string, ]