Skip to content

Commit

Permalink
Merge pull request BVLC#954 from geenux/dev-redundant-data
Browse files Browse the repository at this point in the history
Refactor data layers to avoid duplication of data transformation code
  • Loading branch information
shelhamer committed Aug 22, 2014
2 parents 5436ade + 110558e commit f868906
Show file tree
Hide file tree
Showing 11 changed files with 303 additions and 249 deletions.
14 changes: 8 additions & 6 deletions include/caffe/data_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/data_transformer.hpp"
#include "caffe/filler.hpp"
#include "caffe/internal_thread.hpp"
#include "caffe/layer.hpp"
Expand All @@ -24,12 +25,12 @@ namespace caffe {

// TODO: DataLayer, ImageDataLayer, and WindowDataLayer all have the
// same basic structure and a lot of duplicated code.

template <typename Dtype>
class DataLayer : public Layer<Dtype>, public InternalThread {
public:
explicit DataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
: Layer<Dtype>(param),
data_transformer_(param.data_param().transform_param()) {}
virtual ~DataLayer();
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
Expand All @@ -53,11 +54,10 @@ class DataLayer : public Layer<Dtype>, public InternalThread {

virtual void CreatePrefetchThread();
virtual void JoinPrefetchThread();
virtual unsigned int PrefetchRand();
// The thread's function
virtual void InternalThreadEntry();

shared_ptr<Caffe::RNG> prefetch_rng_;
DataTransformer<Dtype> data_transformer_;

// LEVELDB
shared_ptr<leveldb::DB> db_;
Expand Down Expand Up @@ -178,7 +178,8 @@ template <typename Dtype>
class ImageDataLayer : public Layer<Dtype>, public InternalThread {
public:
explicit ImageDataLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
: Layer<Dtype>(param),
data_transformer_(param.data_param().transform_param()) {}
virtual ~ImageDataLayer();
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
Expand All @@ -203,10 +204,11 @@ class ImageDataLayer : public Layer<Dtype>, public InternalThread {

virtual void CreatePrefetchThread();
virtual void JoinPrefetchThread();
virtual unsigned int PrefetchRand();
virtual void InternalThreadEntry();

DataTransformer<Dtype> data_transformer_;
shared_ptr<Caffe::RNG> prefetch_rng_;

vector<std::pair<std::string, int> > lines_;
int lines_id_;
int datum_channels_;
Expand Down
55 changes: 55 additions & 0 deletions include/caffe/data_transformer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#ifndef CAFFE_DATA_TRANSFORMER_HPP
#define CAFFE_DATA_TRANSFORMER_HPP

#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

/**
* @brief Applies common transformations to the input data, such as
* scaling, mirroring, substracting the image mean...
*/
template <typename Dtype>
class DataTransformer {
public:
explicit DataTransformer(const TransformationParameter& param)
: param_(param) {
phase_ = Caffe::phase();
}
virtual ~DataTransformer() {}

void InitRand();

/**
* @brief Applies the transformation defined in the data layer's
* transform_param block to the data.
*
* @param batch_item_id
* Datum position within the batch. This is used to compute the
* writing position in the top blob's data
* @param datum
* Datum containing the data to be transformed.
* @param mean
* @param top_data
* This is meant to be the top blob's data. The transformed data will be
* written at the appropriate place within the blob's data.
*/
void Transform(const int batch_item_id, const Datum& datum,
const Dtype* mean, Dtype* transformed_data);

protected:
virtual unsigned int Rand();

// Tranformation parameters
TransformationParameter param_;


shared_ptr<Caffe::RNG> rng_;
Caffe::Phase phase_;
};

} // namespace caffe

#endif // CAFFE_DATA_TRANSFORMER_HPP_

113 changes: 113 additions & 0 deletions src/caffe/data_transformer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#include <string>

#include "caffe/data_transformer.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"

namespace caffe {

template<typename Dtype>
void DataTransformer<Dtype>::Transform(const int batch_item_id,
const Datum& datum,
const Dtype* mean,
Dtype* transformed_data) {
const string& data = datum.data();
const int channels = datum.channels();
const int height = datum.height();
const int width = datum.width();
const int size = datum.channels() * datum.height() * datum.width();

const int crop_size = param_.crop_size();
const bool mirror = param_.mirror();
const Dtype scale = param_.scale();



if (mirror && crop_size == 0) {
LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
<< "set at the same time.";
}

if (crop_size) {
CHECK(data.size()) << "Image cropping only support uint8 data";
int h_off, w_off;
// We only do random crop when we do training.
if (phase_ == Caffe::TRAIN) {
h_off = Rand() % (height - crop_size);
w_off = Rand() % (width - crop_size);
} else {
h_off = (height - crop_size) / 2;
w_off = (width - crop_size) / 2;
}
if (mirror && Rand() % 2) {
// Copy mirrored version
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < crop_size; ++h) {
for (int w = 0; w < crop_size; ++w) {
int data_index = (c * height + h + h_off) * width + w + w_off;
int top_index = ((batch_item_id * channels + c) * crop_size + h)
* crop_size + (crop_size - 1 - w);
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
transformed_data[top_index] =
(datum_element - mean[data_index]) * scale;
}
}
}
} else {
// Normal copy
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < crop_size; ++h) {
for (int w = 0; w < crop_size; ++w) {
int top_index = ((batch_item_id * channels + c) * crop_size + h)
* crop_size + w;
int data_index = (c * height + h + h_off) * width + w + w_off;
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
transformed_data[top_index] =
(datum_element - mean[data_index]) * scale;
}
}
}
}
} else {
// we will prefer to use data() first, and then try float_data()
if (data.size()) {
for (int j = 0; j < size; ++j) {
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[j]));
transformed_data[j + batch_item_id * size] =
(datum_element - mean[j]) * scale;
}
} else {
for (int j = 0; j < size; ++j) {
transformed_data[j + batch_item_id * size] =
(datum.float_data(j) - mean[j]) * scale;
}
}
}
}

template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {
const bool needs_rand = (phase_ == Caffe::TRAIN) &&
(param_.mirror() || param_.crop_size());
if (needs_rand) {
const unsigned int rng_seed = caffe_rng_rand();
rng_.reset(new Caffe::RNG(rng_seed));
} else {
rng_.reset();
}
}

template <typename Dtype>
unsigned int DataTransformer<Dtype>::Rand() {
CHECK(rng_);
caffe::rng_t* rng =
static_cast<caffe::rng_t*>(rng_->generator());
return (*rng)();
}

INSTANTIATE_CLASS(DataTransformer);

} // namespace caffe
98 changes: 10 additions & 88 deletions src/caffe/layers/data_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,8 @@ void DataLayer<Dtype>::InternalThreadEntry() {
if (output_labels_) {
top_label = prefetch_label_.mutable_cpu_data();
}
const Dtype scale = this->layer_param_.data_param().scale();
const int batch_size = this->layer_param_.data_param().batch_size();
const int crop_size = this->layer_param_.data_param().crop_size();
const bool mirror = this->layer_param_.data_param().mirror();

if (mirror && crop_size == 0) {
LOG(FATAL) << "Current implementation requires mirror and crop_size to be "
<< "set at the same time.";
}
// datum scales
const int channels = datum_channels_;
const int height = datum_height_;
const int width = datum_width_;
const int size = datum_size_;
const Dtype* mean = data_mean_.cpu_data();
for (int item_id = 0; item_id < batch_size; ++item_id) {
// get a blob
Expand All @@ -56,66 +44,13 @@ void DataLayer<Dtype>::InternalThreadEntry() {
LOG(FATAL) << "Unknown database backend";
}

const string& data = datum.data();
if (crop_size) {
CHECK(data.size()) << "Image cropping only support uint8 data";
int h_off, w_off;
// We only do random crop when we do training.
if (phase_ == Caffe::TRAIN) {
h_off = PrefetchRand() % (height - crop_size);
w_off = PrefetchRand() % (width - crop_size);
} else {
h_off = (height - crop_size) / 2;
w_off = (width - crop_size) / 2;
}
if (mirror && PrefetchRand() % 2) {
// Copy mirrored version
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < crop_size; ++h) {
for (int w = 0; w < crop_size; ++w) {
int top_index = ((item_id * channels + c) * crop_size + h)
* crop_size + (crop_size - 1 - w);
int data_index = (c * height + h + h_off) * width + w + w_off;
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
top_data[top_index] = (datum_element - mean[data_index]) * scale;
}
}
}
} else {
// Normal copy
for (int c = 0; c < channels; ++c) {
for (int h = 0; h < crop_size; ++h) {
for (int w = 0; w < crop_size; ++w) {
int top_index = ((item_id * channels + c) * crop_size + h)
* crop_size + w;
int data_index = (c * height + h + h_off) * width + w + w_off;
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
top_data[top_index] = (datum_element - mean[data_index]) * scale;
}
}
}
}
} else {
// we will prefer to use data() first, and then try float_data()
if (data.size()) {
for (int j = 0; j < size; ++j) {
Dtype datum_element =
static_cast<Dtype>(static_cast<uint8_t>(data[j]));
top_data[item_id * size + j] = (datum_element - mean[j]) * scale;
}
} else {
for (int j = 0; j < size; ++j) {
top_data[item_id * size + j] =
(datum.float_data(j) - mean[j]) * scale;
}
}
}
// Apply data transformations (mirror, scale, crop...)
data_transformer_.Transform(item_id, datum, mean, top_data);

if (output_labels_) {
top_label[item_id] = datum.label();
}

// go to the next iter
switch (this->layer_param_.data_param().backend()) {
case DataParameter_DB_LEVELDB:
Expand Down Expand Up @@ -244,7 +179,7 @@ void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
}

// image
int crop_size = this->layer_param_.data_param().crop_size();
int crop_size = this->layer_param_.data_param().transform_param().crop_size();
if (crop_size > 0) {
(*top)[0]->Reshape(this->layer_param_.data_param().batch_size(),
datum.channels(), crop_size, crop_size);
Expand Down Expand Up @@ -274,8 +209,9 @@ void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CHECK_GT(datum_height_, crop_size);
CHECK_GT(datum_width_, crop_size);
// check if we want to have mean
if (this->layer_param_.data_param().has_mean_file()) {
const string& mean_file = this->layer_param_.data_param().mean_file();
if (this->layer_param_.data_param().transform_param().has_mean_file()) {
const string& mean_file =
this->layer_param_.data_param().transform_param().mean_file();
LOG(INFO) << "Loading mean file from" << mean_file;
BlobProto blob_proto;
ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
Expand Down Expand Up @@ -305,15 +241,9 @@ void DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
template <typename Dtype>
void DataLayer<Dtype>::CreatePrefetchThread() {
phase_ = Caffe::phase();
const bool prefetch_needs_rand = (phase_ == Caffe::TRAIN) &&
(this->layer_param_.data_param().mirror() ||
this->layer_param_.data_param().crop_size());
if (prefetch_needs_rand) {
const unsigned int prefetch_rng_seed = caffe_rng_rand();
prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));
} else {
prefetch_rng_.reset();
}

data_transformer_.InitRand();

CHECK(!StartInternalThread()) << "Pthread execution failed";
}

Expand All @@ -322,14 +252,6 @@ void DataLayer<Dtype>::JoinPrefetchThread() {
CHECK(!WaitForInternalThreadToExit()) << "Pthread joining failed";
}

template <typename Dtype>
unsigned int DataLayer<Dtype>::PrefetchRand() {
CHECK(prefetch_rng_);
caffe::rng_t* prefetch_rng =
static_cast<caffe::rng_t*>(prefetch_rng_->generator());
return (*prefetch_rng)();
}

template <typename Dtype>
void DataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
Expand Down
Loading

0 comments on commit f868906

Please sign in to comment.