Skip to content


Merge pull request #855 from jeffdonahue/sgd-solver-test
Browse files (browse the repository at this point in the history)
SGDSolver tests
  • Loading branch information
jeffdonahue committed Aug 4, 2014
2 parents 5f520af + b08729e commit 7b54598
Show file tree
Hide file tree
Showing 3 changed files with 352 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ message SolverParameter {
// If true, print information about the state of the net that may help with
// debugging learning problems.
optional bool debug_info = 23 [default = false];

// If false, don't save a snapshot after training finishes.
optional bool snapshot_after_train = 28 [default = true];

// A message that stores the solver snapshots
Expand Down
5 changes: 3 additions & 2 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,9 @@ void Solver<Dtype>::Solve(const char* resume_file) {
// Always save a snapshot after optimization.
// Always save a snapshot after optimization, unless overridden by setting
// snapshot_after_train := false.
if (param_.snapshot_after_train()) { Snapshot(); }
// After the optimization is done, run an additional train and test pass to
// display the train and test loss/outputs if appropriate (based on the
// display and test_interval settings, respectively). Unlike in the rest of
Expand Down
346 changes: 346 additions & 0 deletions src/caffe/test/test_sgd_solver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,346 @@
// Copyright 2014 BVLC and contributors.

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "google/protobuf/text_format.h"

#include "gtest/gtest.h"
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"

#include "caffe/test/test_caffe_main.hpp"

using std::ostringstream;

namespace caffe {

template <typename TypeParam>
class SGDSolverTest : public MultiDeviceTest<TypeParam> {
typedef typename TypeParam::Dtype Dtype;

SGDSolverTest() :
seed_(1701), num_(5), channels_(3), height_(10), width_(10) {}

// MockSGDSolver: an SGDSolver with public history, so the tests can read the
// solver's stored update values and compare them with expected ones.
class MockSGDSolver : public SGDSolver<Dtype> {
 public:
  // Forwards the solver configuration unchanged to SGDSolver.
  explicit MockSGDSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) {}

  // Returns a mutable reference to the base class's history_ member.
  vector<shared_ptr<Blob<Dtype> > >& history() { return this->history_; }
};

// Solver under test; replaced with a new MockSGDSolver each time
// InitSolverFromProtoString runs.
shared_ptr<MockSGDSolver> solver_;
// Seed value; the constructor sets it to 1701.
int seed_;
// Dimensions written into the DUMMY_DATA layer's dummy_data_param and used to
// derive N (= num_) and D (= channels_ * height_ * width_) in the checks.
int num_, channels_, height_, width_;

virtual void InitSolverFromProtoString(const string& proto) {
SolverParameter param;
CHECK(google::protobuf::TextFormat::ParseFromString(proto, &param));
// Disable saving a final snapshot so the tests don't pollute the user's
// working directory with useless snapshots.
// Set the solver_mode according to current Caffe::mode.
switch (Caffe::mode()) {
case Caffe::CPU:
case Caffe::GPU:
LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode();
solver_.reset(new MockSGDSolver(param));

void RunLeastSquaresSolver(const Dtype learning_rate,
const Dtype weight_decay, const Dtype momentum, const int num_iters) {
ostringstream proto;
proto <<
"max_iter: " << num_iters << " "
"base_lr: " << learning_rate << " "
"lr_policy: 'fixed' "
"net_param { "
" name: 'TestNetwork' "
" layers: { "
" name: 'data' "
" type: DUMMY_DATA "
" dummy_data_param { "
" num: " << num_ << " "
" channels: " << channels_ << " "
" height: " << height_ << " "
" width: " << width_ << " "
" channels: 1 "
" height: 1 "
" width: 1 "
" data_filler { "
" type: 'gaussian' "
" std: 1.0 "
" } "
" } "
" top: 'data' "
" top: 'targets' "
" } "
" layers: { "
" name: 'innerprod' "
" inner_product_param { "
" num_output: 1 "
" weight_filler { "
" type: 'gaussian' "
" std: 1.0 "
" } "
" bias_filler { "
" type: 'gaussian' "
" std: 1.0 "
" } "
" } "
" bottom: 'data' "
" top: 'innerprod' "
" } "
" layers: { "
" name: 'loss' "
" bottom: 'innerprod' "
" bottom: 'targets' "
" } "
"} ";
if (weight_decay != 0) {
proto << "weight_decay: " << weight_decay << " ";
if (momentum != 0) {
proto << "momentum: " << momentum << " ";

// Compute an update value given the current state of the train net,
// using the analytical formula for the least squares gradient.
// updated_params will store the updated weight and bias results,
// using the blobs' diffs to hold the update values themselves.
void ComputeLeastSquaresUpdate(const Dtype learning_rate,
const Dtype weight_decay, const Dtype momentum,
vector<shared_ptr<Blob<Dtype> > >* updated_params) {
const int N = num_;
const int D = channels_ * height_ * width_;

// Run a forward pass, and manually compute the update values from the
// result.
Net<Dtype>& net = *this->solver_->net();
vector<Blob<Dtype>*> empty_bottom_vec;
const Blob<Dtype>& data = *net.blob_by_name("data");
const Blob<Dtype>& targets = *net.blob_by_name("targets");
const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
const int num_param_blobs = 2;
ASSERT_EQ(num_param_blobs, param_blobs.size());
const Blob<Dtype>& weights = *param_blobs[0];
const Blob<Dtype>& bias = *param_blobs[1];
ASSERT_EQ(D * N, data.count());
ASSERT_EQ(N, targets.count());
ASSERT_EQ(D, weights.count());
ASSERT_EQ(1, bias.count());

for (int i = 0; i < num_param_blobs; ++i) {
(*updated_params)[i].reset(new Blob<Dtype>());
Blob<Dtype>& updated_weights = *(*updated_params)[0];
Blob<Dtype>& updated_bias = *(*updated_params)[1];

for (int i = 0; i <= D; ++i) {
// Compute the derivative with respect to the ith weight (i.e., the ith
// element of the gradient).
Dtype grad = 0;
for (int j = 0; j <= D; ++j) {
// Compute element (i, j) of X^T * X.
Dtype element = 0;
for (int k = 0; k < N; ++k) {
// (i, k) in X^T (== (k, i) in X) times (k, j) in X.
const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j];
element += element_i * element_j;
if (j == D) {
grad += element * bias.cpu_data()[0];
} else {
grad += element * weights.cpu_data()[j];
for (int k = 0; k < N; ++k) {
const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i];
grad -= element_i * targets.cpu_data()[k];
// Scale the gradient over the N samples.
grad /= N;
// Add the weight decay to the gradient.
grad += weight_decay *
((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]);
// Finally, add any momentum.
const vector<shared_ptr<Blob<Dtype> > >& history = solver_->history();
ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias
Dtype update_value = learning_rate * grad;
if (i == D) {
update_value += momentum * history[1]->cpu_data()[0];
updated_bias.mutable_cpu_diff()[0] = update_value;
updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value;
} else {
update_value += momentum * history[0]->cpu_data()[i];
updated_weights.mutable_cpu_diff()[i] = update_value;
updated_weights.mutable_cpu_data()[i] =
weights.cpu_data()[i] - update_value;

void CheckLeastSquaresUpdate(
const vector<shared_ptr<Blob<Dtype> > >& updated_params) {
const int D = channels_ * height_ * width_;

const Blob<Dtype>& updated_weights = *updated_params[0];
const Blob<Dtype>& updated_bias = *updated_params[1];

Net<Dtype>& net = *this->solver_->net();
const vector<shared_ptr<Blob<Dtype> > >& param_blobs =
ASSERT_EQ(2, param_blobs.size());
const Blob<Dtype>& solver_updated_weights = *param_blobs[0];
ASSERT_EQ(D, solver_updated_weights.count());
const double kPrecision = 1e-3;
const double kMinPrecision = 1e-7;
for (int i = 0; i < D; ++i) {
const Dtype expected_updated_weight = updated_weights.cpu_data()[i];
const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i];
const Dtype error_margin = std::max(kMinPrecision, kPrecision *
std::min(fabs(expected_updated_weight), fabs(solver_updated_weight)));
EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin);
const Blob<Dtype>& solver_updated_bias_blob = *param_blobs[1];
ASSERT_EQ(1, solver_updated_bias_blob.count());
const Dtype expected_updated_bias = updated_bias.cpu_data()[0];
const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0];
const Dtype error_margin = std::max(kMinPrecision, kPrecision *
std::min(fabs(expected_updated_bias), fabs(solver_updated_bias)));
EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin);

// Check the solver's history -- should contain the previous update value.
vector<shared_ptr<Blob<Dtype> > >& history = this->solver_->history();
ASSERT_EQ(2, history.size());
for (int i = 0; i < D; ++i) {
const Dtype expected_history = updated_weights.cpu_diff()[i];
const Dtype solver_history = history[0]->cpu_data()[i];
const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
std::min(fabs(expected_history), fabs(solver_history)));
EXPECT_NEAR(expected_history, solver_history, error_margin_hist);
const Dtype expected_history = updated_bias.cpu_diff()[0];
const Dtype solver_history = history[1]->cpu_data()[0];
const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision *
std::min(fabs(expected_history), fabs(solver_history)));
EXPECT_NEAR(expected_history, solver_history, error_margin_hist);

// Test that the correct update is computed for a regularized least squares
// problem:
// E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2
// \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w
// X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1)
// w \in R^{(d+1) x 1} ((d+1)th element is the bias)
// y \in R^{n x 1}
// lambda is weight_decay
// TestLeastSquaresUpdate works "inductively", assuming that the solver
// correctly updates the net K (= iter_to_check) times, then given the history
// from the Kth update, we compute the (K+1)th update and check that it
// matches the solver's (K+1)th update.
void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0,
const Dtype weight_decay = 0.0, const Dtype momentum = 0.0,
const int iter_to_check = 0) {
// Initialize the solver and run K (= iter_to_check) solver iterations.
RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check);

// Compute the (K+1)th update using the analytic least squares gradient.
vector<shared_ptr<Blob<Dtype> > > updated_params;
ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum,

// Reinitialize the solver and run K+1 solver iterations.
RunLeastSquaresSolver(learning_rate, weight_decay, momentum,
iter_to_check + 1);

// Check that the solver's solution matches ours.

// Register SGDSolverTest as a gtest typed test case over the type list
// TestDtypesAndDevices (not declared in the visible source), so every
// TYPED_TEST below is instantiated once per type in that list.
TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices);

// Baseline case: all fixture defaults (learning rate 1, no weight decay, no
// momentum, checking the first update).
TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdate) {
  typedef typename TypeParam::Dtype Dtype;
  this->TestLeastSquaresUpdate();
}

// Same as the baseline, but with a learning rate of 0.1.
TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateLROneTenth) {
  typedef typename TypeParam::Dtype Dtype;
  const Dtype kLearningRate = 0.1;
  this->TestLeastSquaresUpdate(kLearningRate);
}

// Checks the first update with weight decay (0.5) and no momentum.
TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithWeightDecay) {
  typedef typename TypeParam::Dtype Dtype;
  const Dtype kLearningRate = 1.0;
  const Dtype kWeightDecay = 0.5;
  this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay);
}

// Checks updates with momentum (0.5) and no weight decay, for iter_to_check
// in {0, 1}. The second iteration is the first one where the history is
// non-trivially involved.
TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentum) {
  typedef typename TypeParam::Dtype Dtype;
  const Dtype kLearningRate = 1.0;
  const Dtype kWeightDecay = 0.0;
  const Dtype kMomentum = 0.5;
  const int kNumIters = 1;
  for (int i = 0; i <= kNumIters; ++i) {
    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
  }
}

// Same settings as the momentum test, but checks every iter_to_check from 0
// through 5 so that accumulated history is exercised over several steps.
TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) {
  typedef typename TypeParam::Dtype Dtype;
  const Dtype kLearningRate = 1.0;
  const Dtype kWeightDecay = 0.0;
  const Dtype kMomentum = 0.5;
  const int kNumIters = 5;
  for (int i = 0; i <= kNumIters; ++i) {
    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
  }
}

// Combines a small learning rate (0.01), weight decay (0.1) and momentum
// (0.9), checking every iter_to_check from 0 through 5.
TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) {
  typedef typename TypeParam::Dtype Dtype;
  const Dtype kLearningRate = 0.01;
  const Dtype kWeightDecay = 0.1;
  const Dtype kMomentum = 0.9;
  const int kNumIters = 5;
  for (int i = 0; i <= kNumIters; ++i) {
    this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i);
  }
}

} // namespace caffe

0 comments on commit 7b54598

Please sign in to comment.