diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index d2f61e57f34..4e77863e7e6 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -140,6 +140,9 @@ message SolverParameter { // If true, print information about the state of the net that may help with // debugging learning problems. optional bool debug_info = 23 [default = false]; + + // If false, don't save a snapshot after training finishes. + optional bool snapshot_after_train = 28 [default = true]; } // A message that stores the solver snapshots diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 11795f058cd..6049ddfec7c 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -200,8 +200,9 @@ void Solver::Solve(const char* resume_file) { ComputeUpdateValue(); net_->Update(); } - // Always save a snapshot after optimization. - Snapshot(); + // Always save a snapshot after optimization, unless overridden by setting + // snapshot_after_train := false. + if (param_.snapshot_after_train()) { Snapshot(); } // After the optimization is done, run an additional train and test pass to // display the train and test loss/outputs if appropriate (based on the // display and test_interval settings, respectively). Unlike in the rest of diff --git a/src/caffe/test/test_sgd_solver.cpp b/src/caffe/test/test_sgd_solver.cpp new file mode 100644 index 00000000000..a28ed7b1817 --- /dev/null +++ b/src/caffe/test/test_sgd_solver.cpp @@ -0,0 +1,346 @@ +// Copyright 2014 BVLC and contributors. + +#include +#include +#include +#include + +#include "google/protobuf/text_format.h" + +#include "gtest/gtest.h" +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +using std::ostringstream; + +namespace caffe { + +template +class SGDSolverTest : public MultiDeviceTest { + typedef typename TypeParam::Dtype Dtype; + + protected: + SGDSolverTest() : + seed_(1701), num_(5), channels_(3), height_(10), width_(10) {} + + // MockSGDSolver: an SGDSolver with public history. + class MockSGDSolver : public SGDSolver { + public: + explicit MockSGDSolver(const SolverParameter& param) : + SGDSolver(param) {} + vector > >& history() { return this->history_; } + }; + + shared_ptr solver_; + int seed_; + int num_, channels_, height_, width_; + + virtual void InitSolverFromProtoString(const string& proto) { + SolverParameter param; + CHECK(google::protobuf::TextFormat::ParseFromString(proto, ¶m)); + // Disable saving a final snapshot so the tests don't pollute the user's + // working directory with useless snapshots. + param.set_snapshot_after_train(false); + // Set the solver_mode according to current Caffe::mode. + switch (Caffe::mode()) { + case Caffe::CPU: + param.set_solver_mode(SolverParameter_SolverMode_CPU); + break; + case Caffe::GPU: + param.set_solver_mode(SolverParameter_SolverMode_GPU); + break; + default: + LOG(FATAL) << "Unknown Caffe mode: " << Caffe::mode(); + } + solver_.reset(new MockSGDSolver(param)); + } + + void RunLeastSquaresSolver(const Dtype learning_rate, + const Dtype weight_decay, const Dtype momentum, const int num_iters) { + ostringstream proto; + proto << + "max_iter: " << num_iters << " " + "base_lr: " << learning_rate << " " + "lr_policy: 'fixed' " + "net_param { " + " name: 'TestNetwork' " + " layers: { " + " name: 'data' " + " type: DUMMY_DATA " + " dummy_data_param { " + " num: " << num_ << " " + " channels: " << channels_ << " " + " height: " << height_ << " " + " width: " << width_ << " " + " channels: 1 " + " height: 1 " + " width: 1 " + " data_filler { " + " type: 'gaussian' " + " std: 1.0 " + " } " + " } " + " top: 'data' " + " top: 'targets' " + " } " + " layers: { " + " name: 'innerprod' " + " type: INNER_PRODUCT " + " inner_product_param { " + " num_output: 1 " + " weight_filler { " + " type: 'gaussian' " + " std: 1.0 " + " } " + " bias_filler { " + " type: 'gaussian' " + " std: 1.0 " + " } " + " } " + " bottom: 'data' " + " top: 'innerprod' " + " } " + " layers: { " + " name: 'loss' " + " type: EUCLIDEAN_LOSS " + " bottom: 'innerprod' " + " bottom: 'targets' " + " } " + "} "; + if (weight_decay != 0) { + proto << "weight_decay: " << weight_decay << " "; + } + if (momentum != 0) { + proto << "momentum: " << momentum << " "; + } + Caffe::set_random_seed(this->seed_); + this->InitSolverFromProtoString(proto.str()); + this->solver_->Solve(); + } + + // Compute an update value given the current state of the train net, + // using the analytical formula for the least squares gradient. + // updated_params will store the updated weight and bias results, + // using the blobs' diffs to hold the update values themselves. + void ComputeLeastSquaresUpdate(const Dtype learning_rate, + const Dtype weight_decay, const Dtype momentum, + vector > >* updated_params) { + const int N = num_; + const int D = channels_ * height_ * width_; + + // Run a forward pass, and manually compute the update values from the + // result. + Net& net = *this->solver_->net(); + vector*> empty_bottom_vec; + net.Forward(empty_bottom_vec); + ASSERT_TRUE(net.has_blob("data")); + const Blob& data = *net.blob_by_name("data"); + ASSERT_TRUE(net.has_blob("targets")); + const Blob& targets = *net.blob_by_name("targets"); + ASSERT_TRUE(net.has_layer("innerprod")); + const vector > >& param_blobs = + net.layer_by_name("innerprod")->blobs(); + const int num_param_blobs = 2; + ASSERT_EQ(num_param_blobs, param_blobs.size()); + const Blob& weights = *param_blobs[0]; + const Blob& bias = *param_blobs[1]; + ASSERT_EQ(D * N, data.count()); + ASSERT_EQ(N, targets.count()); + ASSERT_EQ(D, weights.count()); + ASSERT_EQ(1, bias.count()); + + updated_params->clear(); + updated_params->resize(num_param_blobs); + for (int i = 0; i < num_param_blobs; ++i) { + (*updated_params)[i].reset(new Blob()); + } + Blob& updated_weights = *(*updated_params)[0]; + updated_weights.ReshapeLike(weights); + Blob& updated_bias = *(*updated_params)[1]; + updated_bias.ReshapeLike(bias); + + for (int i = 0; i <= D; ++i) { + // Compute the derivative with respect to the ith weight (i.e., the ith + // element of the gradient). + Dtype grad = 0; + for (int j = 0; j <= D; ++j) { + // Compute element (i, j) of X^T * X. + Dtype element = 0; + for (int k = 0; k < N; ++k) { + // (i, k) in X^T (== (k, i) in X) times (k, j) in X. + const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; + const Dtype element_j = (j == D) ? 1 : data.cpu_data()[k * D + j]; + element += element_i * element_j; + } + if (j == D) { + grad += element * bias.cpu_data()[0]; + } else { + grad += element * weights.cpu_data()[j]; + } + } + for (int k = 0; k < N; ++k) { + const Dtype element_i = (i == D) ? 1 : data.cpu_data()[k * D + i]; + grad -= element_i * targets.cpu_data()[k]; + } + // Scale the gradient over the N samples. + grad /= N; + // Add the weight decay to the gradient. + grad += weight_decay * + ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]); + // Finally, add any momentum. + const vector > >& history = solver_->history(); + ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias + Dtype update_value = learning_rate * grad; + if (i == D) { + update_value += momentum * history[1]->cpu_data()[0]; + updated_bias.mutable_cpu_diff()[0] = update_value; + updated_bias.mutable_cpu_data()[0] = bias.cpu_data()[0] - update_value; + } else { + update_value += momentum * history[0]->cpu_data()[i]; + updated_weights.mutable_cpu_diff()[i] = update_value; + updated_weights.mutable_cpu_data()[i] = + weights.cpu_data()[i] - update_value; + } + } + } + + void CheckLeastSquaresUpdate( + const vector > >& updated_params) { + const int D = channels_ * height_ * width_; + + const Blob& updated_weights = *updated_params[0]; + const Blob& updated_bias = *updated_params[1]; + + Net& net = *this->solver_->net(); + ASSERT_TRUE(net.has_layer("innerprod")); + const vector > >& param_blobs = + net.layer_by_name("innerprod")->blobs(); + ASSERT_EQ(2, param_blobs.size()); + const Blob& solver_updated_weights = *param_blobs[0]; + ASSERT_EQ(D, solver_updated_weights.count()); + const double kPrecision = 1e-3; + const double kMinPrecision = 1e-7; + for (int i = 0; i < D; ++i) { + const Dtype expected_updated_weight = updated_weights.cpu_data()[i]; + const Dtype solver_updated_weight = solver_updated_weights.cpu_data()[i]; + const Dtype error_margin = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_updated_weight), fabs(solver_updated_weight))); + EXPECT_NEAR(expected_updated_weight, solver_updated_weight, error_margin); + } + const Blob& solver_updated_bias_blob = *param_blobs[1]; + ASSERT_EQ(1, solver_updated_bias_blob.count()); + const Dtype expected_updated_bias = updated_bias.cpu_data()[0]; + const Dtype solver_updated_bias = solver_updated_bias_blob.cpu_data()[0]; + const Dtype error_margin = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_updated_bias), fabs(solver_updated_bias))); + EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin); + + // Check the solver's history -- should contain the previous update value. + vector > >& history = this->solver_->history(); + ASSERT_EQ(2, history.size()); + for (int i = 0; i < D; ++i) { + const Dtype expected_history = updated_weights.cpu_diff()[i]; + const Dtype solver_history = history[0]->cpu_data()[i]; + const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_history), fabs(solver_history))); + EXPECT_NEAR(expected_history, solver_history, error_margin_hist); + } + const Dtype expected_history = updated_bias.cpu_diff()[0]; + const Dtype solver_history = history[1]->cpu_data()[0]; + const Dtype error_margin_hist = std::max(kMinPrecision, kPrecision * + std::min(fabs(expected_history), fabs(solver_history))); + EXPECT_NEAR(expected_history, solver_history, error_margin_hist); + } + + // Test that the correct update is computed for a regularized least squares + // problem: + // + // E = (1/(2n)) || X w - y ||^2 + (lambda / 2) || w ||^2 + // \nabla_w E = (1/n) (X^T X w - X^T y) + lambda * w + // + // X \in R^{n x (d+1)} (each example is a row, (d+1)th element is always 1) + // w \in R^{(d+1) x 1} ((d+1)th element is the bias) + // y \in R^{n x 1} + // lambda is weight_decay + // + // TestLeastSquaresUpdate works "inductively", assuming that the solver + // correctly updates the net K (= iter_to_check) times, then given the history + // from the Kth update, we compute the (K+1)th update and check that it + // matches the solver's (K+1)th update. + void TestLeastSquaresUpdate(const Dtype learning_rate = 1.0, + const Dtype weight_decay = 0.0, const Dtype momentum = 0.0, + const int iter_to_check = 0) { + // Initialize the solver and run K (= iter_to_check) solver iterations. + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, iter_to_check); + + // Compute the (K+1)th update using the analytic least squares gradient. + vector > > updated_params; + ComputeLeastSquaresUpdate(learning_rate, weight_decay, momentum, + &updated_params); + + // Reinitialize the solver and run K+1 solver iterations. + RunLeastSquaresSolver(learning_rate, weight_decay, momentum, + iter_to_check + 1); + + // Check that the solver's solution matches ours. + CheckLeastSquaresUpdate(updated_params); + } +}; + +TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices); + +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdate) { + typedef typename TypeParam::Dtype Dtype; + this->TestLeastSquaresUpdate(); +} + +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateLROneTenth) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.1; + this->TestLeastSquaresUpdate(kLearningRate); +} + +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithWeightDecay) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.5; + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay); +} + +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentum) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 1; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithMomentumMultiIter) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 1.0; + const Dtype kWeightDecay = 0.0; + const Dtype kMomentum = 0.5; + const int kNumIters = 5; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +TYPED_TEST(SGDSolverTest, TestLeastSquaresUpdateWithEverything) { + typedef typename TypeParam::Dtype Dtype; + const Dtype kLearningRate = 0.01; + const Dtype kWeightDecay = 0.1; + const Dtype kMomentum = 0.9; + const int kNumIters = 5; + for (int i = 0; i <= kNumIters; ++i) { + this->TestLeastSquaresUpdate(kLearningRate, kWeightDecay, kMomentum, i); + } +} + +} // namespace caffe