Skip to content

Commit

Permalink
Merge pull request #3089 from stan-dev/feature/vari-set-adj
Browse files Browse the repository at this point in the history
adds constructor to vari for passing both initial values and adjoints
  • Loading branch information
SteveBronder authored Jun 21, 2024
2 parents 87bb8a7 + 29b366e commit fe521e0
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 7 deletions.
22 changes: 21 additions & 1 deletion stan/math/opencl/rev/vari.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class vari_value<T, require_matrix_cl_t<T>> : public chainable_alloc,
require_vt_same<T, S>* = nullptr>
explicit vari_value(const S& x)
: chainable_alloc(), vari_cl_base<T>(x, constant(0, x.rows(), x.cols())) {
ChainableStack::instance_->var_stack_.push_back(this);
ChainableStack::instance_->var_nochain_stack_.push_back(this);
}

/**
Expand Down Expand Up @@ -259,6 +259,26 @@ class vari_value<T, require_matrix_cl_t<T>> : public chainable_alloc,
}
}

/**
* Construct a dense Eigen variable implementation from a
* preconstructed values and adjoints.
*
* All constructed variables are not added to the stack. Variables
* should be constructed before variables on which they depend
* to insure proper partial derivative propagation.
* @tparam S A dense Eigen type that is convertible to `value_type`
* @tparam K A dense Eigen type that is convertible to `value_type`
* @param val Matrix of values
* @param adj Matrix of adjoints
*/
template <typename S, typename K, require_convertible_t<T, S>* = nullptr,
require_convertible_t<T, K>* = nullptr>
explicit vari_value(S&& val, K&& adj)
: chainable_alloc(),
vari_cl_base<T>(std::forward<S>(val), std::forward<K>(adj)) {
ChainableStack::instance_->var_nochain_stack_.push_back(this);
}

/**
* Set the adjoint value of this variable to 0. This is used to
* reset adjoints before propagating derivatives again (for
Expand Down
2 changes: 1 addition & 1 deletion stan/math/rev/core/callback_vari.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct callback_vari : public vari_value<T> {
template <typename S,
require_same_t<plain_type_t<T>, plain_type_t<S>>* = nullptr>
explicit callback_vari(S&& value, F&& rev_functor)
: vari_value<T>(std::move(value)),
: vari_value<T>(std::move(value), true),
rev_functor_(std::forward<F>(rev_functor)) {}

inline void chain() final { rev_functor_(*this); }
Expand Down
26 changes: 21 additions & 5 deletions stan/math/rev/core/vari.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,11 +678,9 @@ class vari_value<T, require_all_t<is_plain_type<T>, is_eigen_dense_base<T>>>
* Construct a dense Eigen variable implementation from a value. The
* adjoint is initialized to zero.
*
* All constructed variables are added to the stack. Variables
* All constructed variables are added to the no chain stack. Variables
* should be constructed before variables on which they depend
* to insure proper partial derivative propagation. During
* derivative propagation, the chain() method of each variable
* will be called in the reverse order of construction.
* to insure proper partial derivative propagation.
*
* @tparam S A dense Eigen type that is convertible to `value_type`
* @param x Value of the constructed variable.
Expand All @@ -699,7 +697,7 @@ class vari_value<T, require_all_t<is_plain_type<T>, is_eigen_dense_base<T>>>
? x.rows()
: x.cols()) {
adj_.setZero();
ChainableStack::instance_->var_stack_.push_back(this);
ChainableStack::instance_->var_nochain_stack_.push_back(this);
}

/**
Expand Down Expand Up @@ -736,6 +734,24 @@ class vari_value<T, require_all_t<is_plain_type<T>, is_eigen_dense_base<T>>>
}
}

/**
* Construct a dense Eigen variable implementation from a
* preconstructed values and adjoints.
*
* All constructed variables are not added to the stack. Variables
* should be constructed before variables on which they depend
* to insure proper partial derivative propagation.
* @tparam S A dense Eigen type that is convertible to `value_type`
* @tparam K A dense Eigen type that is convertible to `value_type`
* @param val Matrix of values
* @param adj Matrix of adjoints
*/
template <typename S, typename K, require_assignable_t<T, S>* = nullptr,
require_assignable_t<T, K>* = nullptr>
explicit vari_value(const S& val, const K& adj) : val_(val), adj_(adj) {
ChainableStack::instance_->var_nochain_stack_.push_back(this);
}

protected:
template <typename S, require_not_same_t<T, S>* = nullptr>
explicit vari_value(const vari_value<S>* x) : val_(x->val_), adj_(x->adj_) {}
Expand Down
5 changes: 5 additions & 0 deletions test/unit/math/opencl/rev/vari_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ TEST(AgradRev, matrix_cl_vari_block) {
stan::math::from_matrix_cl(B.block(0, 1, 2, 2).val_));
EXPECT_MATRIX_EQ(b.block(0, 1, 2, 2),
stan::math::from_matrix_cl(B.block(0, 1, 2, 2).adj_));
vari_value<stan::math::matrix_cl<double>> C(a_cl, a_cl);
EXPECT_MATRIX_EQ(a.block(0, 1, 2, 2),
stan::math::from_matrix_cl(C.block(0, 1, 2, 2).val_));
EXPECT_MATRIX_EQ(a.block(0, 1, 2, 2),
stan::math::from_matrix_cl(C.block(0, 1, 2, 2).adj_));
}

#endif
7 changes: 7 additions & 0 deletions test/unit/math/rev/core/vari_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ TEST(AgradRevVari, arena_matrix_matrix_vari) {
EXPECT_MATRIX_FLOAT_EQ((*C).val(), x);
auto* D = new vari_value<Eigen::MatrixXd>(x_ref, true);
EXPECT_MATRIX_FLOAT_EQ((*D).val(), x);
auto* E = new vari_value<Eigen::MatrixXd>(x, (x.array() + 1.0).matrix());
EXPECT_MATRIX_FLOAT_EQ((*E).val(), x);
EXPECT_MATRIX_FLOAT_EQ((*E).adj(), (x.array() + 1.0).matrix());
auto* F = new vari_value<Eigen::MatrixXd>(x, x);
EXPECT_MATRIX_FLOAT_EQ((*F).val(), x);
EXPECT_MATRIX_FLOAT_EQ((*F).adj(), x);
EXPECT_EQ((*F).val().data(), (*F).adj().data());
}

TEST(AgradRevVari, dense_vari_matrix_views) {
Expand Down

0 comments on commit fe521e0

Please sign in to comment.