Skip to content

Commit

Permalink
Merge branch 'add-log_sum_exp-func' of https://github.com/MichaScant/…
Browse files Browse the repository at this point in the history
…math into add-log_sum_exp-func
  • Loading branch information
MichaScant committed Nov 25, 2024
2 parents 6d39d20 + 85c4a94 commit c0d738b
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 90 deletions.
132 changes: 68 additions & 64 deletions stan/math/fwd/fun/log_add_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,109 +16,113 @@ namespace math {
/**
 * Forward-mode overload of log_add_exp for two fvar arguments.
 *
 * The value is log(exp(x1) + exp(x2)), obtained from the log_add_exp
 * overload for the inner type T. The tangent is
 * x1.d_ * w1 + x2.d_ * w2 with w_i = exp(x_i) / (exp(x1) + exp(x2)).
 *
 * The weights are evaluated as exp(x_i - value) instead of dividing raw
 * exponentials: exp(x_i) overflows to infinity for large inputs (giving
 * inf / inf = NaN) and underflows to zero for very negative inputs (giving
 * 0 / 0 = NaN), whereas x_i - value is never positive for finite inputs.
 *
 * @tparam T inner type of the fvar
 * @param x1 first argument
 * @param x2 second argument
 * @return fvar holding the value and its directional derivative
 */
template <typename T>
inline fvar<T> log_add_exp(const fvar<T>& x1, const fvar<T>& x2) {
  const auto val = stan::math::log_add_exp(x1.val_, x2.val_);

  // Partial derivatives with respect to x1 and x2 (they sum to one).
  const auto grad1 = stan::math::exp(x1.val_ - val);
  const auto grad2 = stan::math::exp(x2.val_ - val);

  return fvar<T>(val, x1.d_ * grad1 + x2.d_ * grad2);
}

/**
 * Forward-mode overload of log_add_exp for an fvar first argument and a
 * double second argument.
 *
 * When x1 is negative infinity the result is x2 with a zero tangent.
 * Otherwise the arguments are swapped and the call is forwarded to
 * log_add_exp(x2, x1); the operation is symmetric in its arguments.
 *
 * @tparam T inner type of the fvar
 * @param x1 first argument (carries the tangent)
 * @param x2 second argument (constant)
 * @return fvar holding the value and its directional derivative
 */
template <typename T>
inline fvar<T> log_add_exp(const fvar<T>& x1, double x2) {
  if (x1.val_ == NEGATIVE_INFTY) {
    return fvar<T>(x2, 0.0);  // log_add_exp(-inf, b) = b
  }
  return log_add_exp(x2, x1);
}

/**
 * Forward-mode overload of log_add_exp for a double first argument and an
 * fvar second argument.
 *
 * When x2 is negative infinity the result is x1 with a zero tangent.
 * Otherwise the value is log(exp(x1) + exp(x2)) and the tangent is
 * x2.d_ * exp(x2) / (exp(x1) + exp(x2)).
 *
 * The weight is evaluated as exp(x2 - value) instead of dividing raw
 * exponentials, which would produce inf / inf or 0 / 0 (NaN) when the
 * exponentials overflow or underflow.
 *
 * @tparam T inner type of the fvar
 * @param x1 first argument (constant)
 * @param x2 second argument (carries the tangent)
 * @return fvar holding the value and its directional derivative
 */
template <typename T>
inline fvar<T> log_add_exp(double x1, const fvar<T>& x2) {
  if (x2.val_ == NEGATIVE_INFTY) {
    return fvar<T>(x1, 0.0);  // log_add_exp(a, -inf) = a
  }
  const auto val = stan::math::log_add_exp(x1, x2.val_);
  // Partial derivative with respect to x2.
  const auto grad = stan::math::exp(x2.val_ - val);
  return fvar<T>(val, x2.d_ * grad);
}

// Overload for matrices of fvar
template <typename T>
inline fvar<T> log_add_exp(const Eigen::Matrix<fvar<T>, -1, -1>& a,
const Eigen::Matrix<fvar<T>, -1, -1>& b) {
inline fvar<T> log_add_exp(const Eigen::Matrix<fvar<T>, -1, -1>& a, const Eigen::Matrix<fvar<T>, -1, -1>& b) {

using fvar_mat_type = Eigen::Matrix<fvar<T>, -1, -1>;
fvar_mat_type result(a.rows(), a.cols());

// Check for empty inputs
if (a.size() == 0 || b.size() == 0) {
throw std::invalid_argument("Input containers must not be empty.");
throw std::invalid_argument("Input containers must not be empty.");
}

// Check for NaN
if (a.array().isNaN().any() || b.array().isNaN().any()) {
result.setConstant(fvar<T>(std::numeric_limits<double>::quiet_NaN()));
return result;
result.setConstant(fvar<T>(std::numeric_limits<double>::quiet_NaN()));
return result;
}

// Check for infinity
if (a.array().isInf().any() || b.array().isInf().any()) {
result.setConstant(fvar<T>(std::numeric_limits<double>::quiet_NaN()));
return result;
result.setConstant(fvar<T>(std::numeric_limits<double>::quiet_NaN()));
return result;
}

// Apply the log_add_exp operation directly
for (int i = 0; i < a.rows(); ++i) {
for (int j = 0; j < a.cols(); ++j) {
result(i, j) = stan::math::log_add_exp(a(i, j), b(i, j));
}
for (int j = 0; j < a.cols(); ++j) {
result(i, j) = stan::math::log_add_exp(a(i, j), b(i, j));
}
}

return result; // Return the result matrix
return result; // Return the result matrix
}

/**
 * Elementwise log_add_exp for two dynamically sized matrices of
 * fvar<fvar<double>>.
 *
 * This is a plain (non-template) overload. It used to be declared with a
 * template parameter that appeared in neither the parameter list nor the
 * body, so the parameter could not be deduced and an ordinary call
 * log_add_exp(a, b) could never select it.
 *
 * Each entry of the result is log_add_exp(a(i, j), b(i, j)), computed by
 * the scalar overload. Non-finite entries (NaN, +/-infinity) are passed to
 * the scalar overload like any other entry, so one such entry no longer
 * turns the entire result into NaN.
 *
 * @param a first matrix
 * @param b second matrix, must have the same dimensions as a
 * @return matrix with the same dimensions as a
 * @throw std::invalid_argument if either input is empty or if the
 *   dimensions of a and b differ
 */
inline auto log_add_exp(
    const Eigen::Matrix<stan::math::fvar<stan::math::fvar<double>>, -1, -1>& a,
    const Eigen::Matrix<stan::math::fvar<stan::math::fvar<double>>, -1, -1>&
        b) {
  using nested_fvar_mat_type
      = Eigen::Matrix<stan::math::fvar<stan::math::fvar<double>>, -1, -1>;

  if (a.size() == 0 || b.size() == 0) {
    throw std::invalid_argument("Input containers must not be empty.");
  }
  // b is indexed with a's bounds below, so the shapes have to agree.
  if (a.rows() != b.rows() || a.cols() != b.cols()) {
    throw std::invalid_argument(
        "Input containers must have the same dimensions.");
  }

  nested_fvar_mat_type result(a.rows(), a.cols());
  for (Eigen::Index j = 0; j < a.cols(); ++j) {
    for (Eigen::Index i = 0; i < a.rows(); ++i) {
      result(i, j) = stan::math::log_add_exp(a(i, j), b(i, j));
    }
  }
  return result;
}

}
}
} // namespace math
} // namespace stan

#endif
55 changes: 30 additions & 25 deletions stan/math/prim/fun/log_add_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,41 +57,46 @@ inline return_type_t<T1, T2> log_add_exp(const T2& a, const T1& b) {
*/
template <typename T, require_container_st<std::is_arithmetic, T>* = nullptr>
inline auto log_add_exp(const T& a, const T& b) {
if (a.size() != b.size()) {
throw std::invalid_argument("Binary function: size of x (" + std::to_string(a.size()) + ") and size of y (" + std::to_string(b.size()) + ") must match in size");
}
if (a.size() != b.size()) {
throw std::invalid_argument("Binary function: size of x ("
+ std::to_string(a.size()) + ") and size of y ("
+ std::to_string(b.size())
+ ") must match in size");
}

const size_t min_size = std::min(a.size(), b.size());
using return_t = return_type_t<T>;
const size_t min_size = std::min(a.size(), b.size());
using return_t = return_type_t<T>;

std::vector<return_t> result(min_size);
std::vector<return_t> result(min_size);

for (size_t i = 0; i < min_size; ++i) {
if (a[i] == NEGATIVE_INFTY) {
result[i] = b[i]; // log_add_exp(-∞, b) = b
} else if (b[i] == NEGATIVE_INFTY) {
result[i] = a[i]; // log_add_exp(a, -∞) = a
} else if (a[i] == INFTY || b[i] == INFTY) {
result[i] = INFTY; // log_add_exp(∞, b) = ∞
} else {
// Log-add-exp trick
const double max_val = std::max(a[i], b[i]);
result[i] = max_val + std::log(std::exp(a[i] - max_val) + std::exp(b[i] - max_val));
}
for (size_t i = 0; i < min_size; ++i) {
if (a[i] == NEGATIVE_INFTY) {
result[i] = b[i]; // log_add_exp(-∞, b) = b
} else if (b[i] == NEGATIVE_INFTY) {
result[i] = a[i]; // log_add_exp(a, -∞) = a
} else if (a[i] == INFTY || b[i] == INFTY) {
result[i] = INFTY; // log_add_exp(∞, b) = ∞
} else {
// Log-add-exp trick
const double max_val = std::max(a[i], b[i]);
result[i]
= max_val
+ std::log(std::exp(a[i] - max_val) + std::exp(b[i] - max_val));
}
}

return result;
return result;
}

/**
 * Enables the vectorized application of the log_add_exp function,
 * when the first and/or second arguments are containers.
 *
 * @tparam T1 type of the first argument
 * @tparam T2 type of the second argument
 * @param a first argument
 * @param b second argument
 * @return result of applying log_add_exp across the container argument(s)
 */
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
inline auto log_add_exp(const T1& a, const T2& b) {
Expand Down
2 changes: 1 addition & 1 deletion stan/math/rev/fun/log_add_exp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,4 @@ inline T log_add_exp(const T& x, const T& y) {
} // namespace math
} // namespace stan

#endif
#endif

0 comments on commit c0d738b

Please sign in to comment.