Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add overloads for tuples of bounds to lub_free and lub_constrain #3087

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions stan/math/prim/fun/lub_constrain.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,33 @@ inline auto lub_constrain(const T& x, const L& lb, const U& ub,
}
}

/**
* Wrapper for tuple of bounds, simply delegates to the appropriate overload
*/
template <typename T, typename L, typename U>
inline auto lub_constrain(const T& x, const std::tuple<L, U>& bounds) {
return lub_constrain(x, std::get<0>(bounds), std::get<1>(bounds));
}

/**
* Wrapper for tuple of bounds, simply delegates to the appropriate overload
*/
template <typename T, typename L, typename U>
inline auto lub_constrain(const T& x, const std::tuple<L, U>& bounds,
return_type_t<T, L, U>& lp) {
return lub_constrain(x, std::get<0>(bounds), std::get<1>(bounds), lp);
}

/**
* Wrapper for tuple of bounds, simply delegates to the appropriate overload
*/
template <bool Jacobian, typename T, typename L, typename U>
inline auto lub_constrain(const T& x, const std::tuple<L, U>& bounds,
return_type_t<T, L, U>& lp) {
return lub_constrain<Jacobian>(x, std::get<0>(bounds), std::get<1>(bounds),
lp);
}

} // namespace math
} // namespace stan

Expand Down
8 changes: 8 additions & 0 deletions stan/math/prim/fun/lub_free.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,14 @@ inline auto lub_free(const std::vector<T> y, const std::vector<L>& lb,
}
return ret;
}

/**
* Wrapper for tuple of bounds, simply delegates to the appropriate overload
*/
template <typename T, typename L, typename U>
inline auto lub_free(T&& y, const std::tuple<L, U>& bounds) {
return lub_free(std::forward<T>(y), std::get<0>(bounds), std::get<1>(bounds));
}
///@}

} // namespace math
Expand Down
20 changes: 20 additions & 0 deletions test/unit/math/mix/fun/lub_constrain_helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,21 @@ void expect(const T1& x, const T2& lb, const T3& ub) {
auto xx = stan::math::lub_constrain<true>(x, lb, ub, lp);
return stan::math::add(lp, stan::math::sum(xx));
};
auto f5 = [](const auto& x, const auto& lb, const auto& ub) {
stan::return_type_t<decltype(x), decltype(lb), decltype(ub)> lp = 0;
return stan::math::lub_constrain<false>(x, std::make_tuple(lb, ub), lp);
};
auto f6 = [](const auto& x, const auto& lb, const auto& ub) {
stan::return_type_t<decltype(x), decltype(lb), decltype(ub)> lp = 0;
return stan::math::lub_constrain<true>(x, std::make_tuple(lb, ub), lp);
};

stan::test::expect_ad(f1, x, lb, ub);
stan::test::expect_ad(f2, x, lb, ub);
stan::test::expect_ad(f3, x, lb, ub);
stan::test::expect_ad(f4, x, lb, ub);
stan::test::expect_ad(f5, x, lb, ub);
stan::test::expect_ad(f6, x, lb, ub);
}
template <typename T1, typename T2, typename T3>
void expect_vec(const T1& x, const T2& lb, const T3& ub) {
Expand All @@ -52,11 +62,21 @@ void expect_vec(const T1& x, const T2& lb, const T3& ub) {
}
return stan::math::add(lp, xx_acc);
};
auto f5 = [](const auto& x, const auto& lb, const auto& ub) {
stan::return_type_t<decltype(x), decltype(lb), decltype(ub)> lp = 0;
return stan::math::lub_constrain<false>(x, std::make_tuple(lb, ub), lp);
};
auto f6 = [](const auto& x, const auto& lb, const auto& ub) {
stan::return_type_t<decltype(x), decltype(lb), decltype(ub)> lp = 0;
return stan::math::lub_constrain<true>(x, std::make_tuple(lb, ub), lp);
};

stan::test::expect_ad(f1, x, lb, ub);
stan::test::expect_ad(f2, x, lb, ub);
stan::test::expect_ad(f3, x, lb, ub);
stan::test::expect_ad(f4, x, lb, ub);
stan::test::expect_ad(f5, x, lb, ub);
stan::test::expect_ad(f6, x, lb, ub);
}
} // namespace lub_constrain_tests

Expand Down