Refactor ImageDataLayer to use DataTransformer
arntanguy committed Aug 21, 2014
1 parent 413940f commit 110558e
Showing 6 changed files with 38 additions and 117 deletions.
6 changes: 4 additions & 2 deletions include/caffe/data_layers.hpp
@@ -178,7 +178,8 @@ template <typename Dtype>
class ImageDataLayer : public Layer<Dtype>, public InternalThread {
public:
explicit ImageDataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
: Layer<Dtype>(param),
data_transformer_(param.image_data_param().transform_param()) {}
virtual ~ImageDataLayer();
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
@@ -203,10 +204,11 @@ class ImageDataLayer : public Layer<Dtype>, public InternalThread {

virtual void CreatePrefetchThread();
virtual void JoinPrefetchThread();
virtual unsigned int PrefetchRand();
virtual void InternalThreadEntry();

DataTransformer<Dtype> data_transformer_;
shared_ptr<Caffe::RNG> prefetch_rng_;

vector<std::pair<std::string, int> > lines_;
int lines_id_;
int datum_channels_;
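Taken together, this header change drops the layer-local RNG helper (PrefetchRand) and gives ImageDataLayer a DataTransformer<Dtype> member constructed from the nested TransformationParameter. A minimal sketch of the resulting division of labor, using only the constructor, InitRand, and Transform signatures visible in this commit (the free function and its setup are illustrative, not part of the commit):

#include <vector>

#include "caffe/data_transformer.hpp"
#include "caffe/proto/caffe.pb.h"

// Sketch: per-item preprocessing is delegated to the transformer, which now
// owns mirroring, cropping, scaling, and mean subtraction.
template <typename Dtype>
void PrefetchBatchSketch(const caffe::LayerParameter& param,
                         const std::vector<caffe::Datum>& batch,
                         const Dtype* mean, Dtype* top_data) {
  caffe::DataTransformer<Dtype> transformer(
      param.image_data_param().transform_param());
  transformer.InitRand();  // replaces prefetch_rng_ / PrefetchRand()
  for (int item_id = 0; item_id < static_cast<int>(batch.size()); ++item_id) {
    transformer.Transform(item_id, batch[item_id], mean, top_data);
  }
}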
1 change: 0 additions & 1 deletion src/caffe/data_transformer.cpp
@@ -11,7 +11,6 @@ void DataTransformer<Dtype>::Transform(const int batch_item_id,
const Datum& datum,
const Dtype* mean,
Dtype* transformed_data) {

const string& data = datum.data();
const int channels = datum.channels();
const int height = datum.height();
97 changes: 11 additions & 86 deletions src/caffe/layers/image_data_layer.cpp
@@ -20,22 +20,11 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
Dtype* top_data = prefetch_data_.mutable_cpu_data();
Dtype* top_label = prefetch_label_.mutable_cpu_data();
ImageDataParameter image_data_param = this->layer_param_.image_data_param();
const Dtype scale = image_data_param.scale();
const int batch_size = image_data_param.batch_size();
const int crop_size = image_data_param.crop_size();
const bool mirror = image_data_param.mirror();
const int new_height = image_data_param.new_height();
const int new_width = image_data_param.new_width();

if (mirror && crop_size == 0) {
LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
<< "set at the same time.";
}
// datum scales
const int channels = datum_channels_;
const int height = datum_height_;
const int width = datum_width_;
const int size = datum_size_;
const int lines_size = lines_.size();
const Dtype* mean = data_mean_.cpu_data();
for (int item_id = 0; item_id < batch_size; ++item_id) {
@@ -46,62 +35,9 @@ void ImageDataLayer<Dtype>::InternalThreadEntry() {
new_height, new_width, &datum)) {
continue;
}
const string& data = datum.data();
if (crop_size) {
CHECK(data.size()) << "Image cropping only support uint8 data";
int h_off, w_off;
// We only do random crop when we do training.
if (phase_ == Caffe::TRAIN) {
h_off = PrefetchRand() % (height - crop_size);
w_off = PrefetchRand() % (width - crop_size);
} else {
h_off = (height - crop_size) / 2;
w_off = (width - crop_size) / 2;
}
if (mirror && PrefetchRand() % 2) {
// Copy mirrored version
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < crop_size; ++h) {
for (int w = 0; w < crop_size; ++w) {
int top_index = ((item_id * channels + c) * crop_size + h)
* crop_size + (crop_size - 1 - w);
int data_index = (c * height + h + h_off) * width + w + w_off;
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
top_data[top_index] = (datum_element - mean[data_index]) * scale;
}
}
}
} else {
// Normal copy
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < crop_size; ++h) {
for (int w = 0; w < crop_size; ++w) {
int top_index = ((item_id * channels + c) * crop_size + h)
* crop_size + w;
int data_index = (c * height + h + h_off) * width + w + w_off;
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
top_data[top_index] = (datum_element - mean[data_index]) * scale;
}
}
}
}
} else {
// Just copy the whole data
if (data.size()) {
for (int j = 0; j < size; ++j) {
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[j]));
top_data[item_id * size + j] = (datum_element - mean[j]) * scale;
}
} else {
for (int j = 0; j < size; ++j) {
top_data[item_id * size + j] =
(datum.float_data(j) - mean[j]) * scale;
}
}
}

// Apply transformations (mirror, crop...) to the data
data_transformer_.Transform(item_id, datum, mean, top_data);

top_label[item_id] = datum.label();
// go to the next iter
@@ -163,9 +99,11 @@ void ImageDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CHECK(ReadImageToDatum(lines_[lines_id_].first, lines_[lines_id_].second,
new_height, new_width, &datum));
// image
const int crop_size = this->layer_param_.image_data_param().crop_size();
const int crop_size = this->layer_param_.image_data_param()
.transform_param().crop_size();
const int batch_size = this->layer_param_.image_data_param().batch_size();
const string& mean_file = this->layer_param_.image_data_param().mean_file();
const string& mean_file = this->layer_param_.image_data_param()
.transform_param().mean_file();
if (crop_size > 0) {
(*top)[0]->Reshape(batch_size, datum.channels(), crop_size, crop_size);
prefetch_data_.Reshape(batch_size, datum.channels(), crop_size, crop_size);
@@ -189,7 +127,7 @@ void ImageDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CHECK_GT(datum_height_, crop_size);
CHECK_GT(datum_width_, crop_size);
// check if we want to subtract a mean
if (this->layer_param_.image_data_param().has_mean_file()) {
if (this->layer_param_.image_data_param().transform_param().has_mean_file()) {
BlobProto blob_proto;
LOG(INFO) << "Loading mean file from" << mean_file;
ReadProtoFromBinaryFile(mean_file.c_str(), &blob_proto);
@@ -217,15 +155,9 @@ void ImageDataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
template <typename Dtype>
void ImageDataLayer<Dtype>::CreatePrefetchThread() {
phase_ = Caffe::phase();
const bool prefetch_needs_rand =
this->layer_param_.image_data_param().shuffle() ||
this->layer_param_.image_data_param().crop_size();
if (prefetch_needs_rand) {
const unsigned int prefetch_rng_seed = caffe_rng_rand();
prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));
} else {
prefetch_rng_.reset();
}

data_transformer_.InitRand();

// Create the thread.
CHECK(!StartInternalThread()) << "Pthread execution failed";
}
@@ -243,13 +175,6 @@ void ImageDataLayer<Dtype>::JoinPrefetchThread() {
CHECK(!WaitForInternalThreadToExit()) << "Pthread joining failed";
}

template <typename Dtype>
unsigned int ImageDataLayer<Dtype>::PrefetchRand() {
caffe::rng_t* prefetch_rng =
static_cast<caffe::rng_t*>(prefetch_rng_->generator());
return (*prefetch_rng)();
}

template <typename Dtype>
void ImageDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
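For reference, the mirror/crop logic deleted from InternalThreadEntry above collapses inside the transformer to a single copy loop of roughly this shape, reconstructed from the removed code (the actual DataTransformer implementation may differ, e.g. in the no-crop path; the function wrapper is illustrative):

#include <cstdint>
#include <string>

// Mirrored and normal copies merge by flipping the destination column when
// do_mirror is set; scale and mean subtraction are applied per element.
template <typename Dtype>
void CropMirrorSketch(int item_id, int channels, int height, int width,
                      int crop_size, int h_off, int w_off, bool do_mirror,
                      Dtype scale, const std::string& data,
                      const Dtype* mean, Dtype* top_data) {
  for (int c = 0; c < channels; ++c) {
    for (int h = 0; h < crop_size; ++h) {
      for (int w = 0; w < crop_size; ++w) {
        const int w_dst = do_mirror ? (crop_size - 1 - w) : w;
        const int top_index =
            ((item_id * channels + c) * crop_size + h) * crop_size + w_dst;
        const int data_index = (c * height + h + h_off) * width + w + w_off;
        const Dtype datum_element =
            static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
        top_data[top_index] = (datum_element - mean[data_index]) * scale;
      }
    }
  }
}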
25 changes: 9 additions & 16 deletions src/caffe/proto/caffe.proto
@@ -345,7 +345,7 @@ message ConvolutionParameter {
}

// Message that stores parameters used to apply transformation
// to the data layer's data
message TransformationParameter {
// For data pre-processing, we can do simple scaling and subtracting the
// data mean, if provided. Note that the mean subtraction is always carried
@@ -444,27 +444,20 @@ message HingeLossParameter {
message ImageDataParameter {
// Specify the data source.
optional string source = 1;
// For data pre-processing, we can do simple scaling and subtracting the
// data mean, if provided. Note that the mean subtraction is always carried
// out before scaling.
optional float scale = 2 [default = 1];
optional string mean_file = 3;
// Specify the batch size.
optional uint32 batch_size = 4;
// Specify if we would like to randomly crop an image.
optional uint32 crop_size = 5 [default = 0];
// Specify if we want to randomly mirror data.
optional bool mirror = 6 [default = false];
optional uint32 batch_size = 2;
// The rand_skip variable is for the data layer to skip a few data points
// so that asynchronous sgd clients do not all start at the same point. The
// skip point is set as rand_skip * rand(0,1). Note that rand_skip should not
// be larger than the number of entries in the image list.
optional uint32 rand_skip = 7 [default = 0];
optional uint32 rand_skip = 3 [default = 0];
// Whether or not the ImageDataLayer should shuffle the list of files at every epoch.
optional bool shuffle = 8 [default = false];
optional bool shuffle = 4 [default = false];
// The layer will also resize images if new_height or new_width is nonzero.
optional uint32 new_height = 9 [default = 0];
optional uint32 new_width = 10 [default = 0];
optional uint32 new_height = 5 [default = 0];
optional uint32 new_width = 6 [default = 0];
// Parameters for data pre-processing.
optional TransformationParameter transform_param = 7;
}

// Message that stores parameters used by InfogainLossLayer
@@ -505,7 +498,7 @@ message MemoryDataParameter {
message MVNParameter {
// This parameter can be set to false to normalize mean only
optional bool normalize_variance = 1 [default = true];

// This parameter can be set to true to perform DNN-like MVN
optional bool across_channels = 2 [default = false];
}
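With these proto changes, preprocessing settings move out of ImageDataParameter (hence the renumbered fields) into the nested TransformationParameter. A hypothetical new-style layer definition in prototxt, with placeholder paths and an illustrative scale:

layers {
  name: "data"
  type: IMAGE_DATA
  image_data_param {
    source: "train_file_list.txt"          # placeholder path
    batch_size: 256
    shuffle: true
    new_height: 256
    new_width: 256
    transform_param {
      scale: 0.00390625                    # i.e. 1/255, illustrative
      mean_file: "image_mean.binaryproto"  # placeholder path
      crop_size: 227
      mirror: true
    }
  }
}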
10 changes: 6 additions & 4 deletions src/caffe/test/test_upgrade_proto.cpp
@@ -1542,11 +1542,13 @@ TEST_F(V0UpgradeTest, TestAllParams) {
" type: IMAGE_DATA "
" image_data_param { "
" source: '/home/jiayq/Data/ILSVRC12/train-images' "
" mean_file: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' "
" batch_size: 256 "
" crop_size: 227 "
" mirror: true "
" scale: 0.25 "
" transform_param {"
" mean_file: '/home/jiayq/Data/ILSVRC12/image_mean.binaryproto' "
" crop_size: 227 "
" mirror: true "
" scale: 0.25 "
" } "
" rand_skip: 73 "
" shuffle: true "
" new_height: 40 "
16 changes: 8 additions & 8 deletions src/caffe/util/upgrade_proto.cpp
@@ -310,8 +310,8 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection,
layer_param->mutable_data_param()->mutable_transform_param()->
set_scale(v0_layer_param.scale());
} else if (type == "images") {
layer_param->mutable_image_data_param()->set_scale(
v0_layer_param.scale());
layer_param->mutable_image_data_param()->mutable_transform_param()->
set_scale(v0_layer_param.scale());
} else {
LOG(ERROR) << "Unknown parameter scale for layer type " << type;
is_fully_compatible = false;
@@ -322,8 +322,8 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection,
layer_param->mutable_data_param()->mutable_transform_param()->
set_mean_file(v0_layer_param.meanfile());
} else if (type == "images") {
layer_param->mutable_image_data_param()->set_mean_file(
v0_layer_param.meanfile());
layer_param->mutable_image_data_param()->mutable_transform_param()->
set_mean_file(v0_layer_param.meanfile());
} else if (type == "window_data") {
layer_param->mutable_window_data_param()->set_mean_file(
v0_layer_param.meanfile());
@@ -355,8 +355,8 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection,
layer_param->mutable_data_param()->mutable_transform_param()->
set_crop_size(v0_layer_param.cropsize());
} else if (type == "images") {
layer_param->mutable_image_data_param()->set_crop_size(
v0_layer_param.cropsize());
layer_param->mutable_image_data_param()->mutable_transform_param()->
set_crop_size(v0_layer_param.cropsize());
} else if (type == "window_data") {
layer_param->mutable_window_data_param()->set_crop_size(
v0_layer_param.cropsize());
@@ -370,8 +370,8 @@ bool UpgradeLayerParameter(const LayerParameter& v0_layer_connection,
layer_param->mutable_data_param()->mutable_transform_param()->
set_mirror(v0_layer_param.mirror());
} else if (type == "images") {
layer_param->mutable_image_data_param()->set_mirror(
v0_layer_param.mirror());
layer_param->mutable_image_data_param()->mutable_transform_param()->
set_mirror(v0_layer_param.mirror());
} else if (type == "window_data") {
layer_param->mutable_window_data_param()->set_mirror(
v0_layer_param.mirror());
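Each of the four V0 upgrade branches above ("scale", "meanfile", "cropsize", "mirror") now routes its value into the nested transform_param instead of a flat ImageDataParameter field. Condensed into one place for the "images" layer type, the pattern is (a sketch using only accessors that appear in this diff; the real code keeps the four fields in separate, individually guarded branches):

// Hypothetical consolidation of the repeated upgrade pattern for "images".
TransformationParameter* transform_param =
    layer_param->mutable_image_data_param()->mutable_transform_param();
transform_param->set_scale(v0_layer_param.scale());
transform_param->set_mean_file(v0_layer_param.meanfile());
transform_param->set_crop_size(v0_layer_param.cropsize());
transform_param->set_mirror(v0_layer_param.mirror());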
