Skip to content

Commit

Permalink
MRA: Fix handling of function buffer
Browse files Browse the repository at this point in the history
The function buffer doesn't provide the interface to
query the initial level.

Signed-off-by: Joseph Schuchart <joseph.schuchart@stonybrook.edu>
  • Loading branch information
devreal committed Sep 17, 2024
1 parent 3c86e4c commit 50d0bcf
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions examples/madness/mra-device/mrattg-device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ auto make_start(const ttg::Edge<mra::Key<NDIM>, void>& ctl) {
template<typename FnT, typename T, mra::Dimension NDIM>
auto make_project(
mra::Domain<NDIM>& domain,
ttg::Buffer<FnT>& f,
ttg::Buffer<FnT>& fb,
std::size_t K,
const mra::FunctionData<T, NDIM>& functiondata,
const T thresh, /// should be scalar value not complex
Expand All @@ -46,6 +46,7 @@ auto make_project(
auto result = node_type(key, K);
tensor_type& coeffs = result.coeffs;
auto outputs = ttg::device::forward();
auto& f = *fb.host_ptr();

if (key.level() < initial_level(f)) {
std::vector<mra::Key<NDIM>> bcast_keys;
Expand All @@ -61,7 +62,7 @@ auto make_project(
coeffs.current_view() = T(1e7); // set to obviously bad value to detect incorrect use
result.is_leaf = false;
}
else if (mra::is_negligible<FnT,T,NDIM>(*f.host_ptr(), domain.template bounding_box<T>(key), mra::truncate_tol(key,thresh))) {
else if (mra::is_negligible<FnT,T,NDIM>(f, domain.template bounding_box<T>(key), mra::truncate_tol(key,thresh))) {
/* zero coeffs */
coeffs.current_view() = T(0.0);
result.is_leaf = true;
Expand All @@ -88,15 +89,15 @@ auto make_project(

/* TODO: cannot do this from a function, need to move it into the main task */
#ifndef TTG_ENABLE_HOST
co_await ttg::device::select(db, gl, f, coeffs.buffer(), phibar.buffer(),
co_await ttg::device::select(db, gl, fb, coeffs.buffer(), phibar.buffer(),
hgT.buffer(), tmp_scratch, is_leaf_scratch);
#endif
auto coeffs_view = coeffs.current_view();
auto phibar_view = phibar.current_view();
auto hgT_view = hgT.current_view();
T* tmp_device = tmp_scratch.device_ptr();
bool *is_leaf_device = is_leaf_scratch.device_ptr();
FnT* f_ptr = f.current_device_ptr();
FnT* f_ptr = fb.current_device_ptr();
auto& domain = *db.current_device_ptr();
auto gldata = gl.current_device_ptr();

Expand Down Expand Up @@ -390,7 +391,7 @@ auto make_reconstruct(
static std::mutex printer_guard;
template <typename keyT, typename valueT>
auto make_printer(const ttg::Edge<keyT, valueT>& in, const char* str = "", const bool doprint=true) {
auto func = [str,doprint](const keyT& key, const auto& value, auto& out) {
auto func = [str,doprint](const keyT& key, const valueT& value) {
if (doprint) {
std::lock_guard<std::mutex> obolus(printer_guard);
std::cout << str << " (" << key << "," << value << ")" << std::endl;
Expand Down

0 comments on commit 50d0bcf

Please sign in to comment.