diff --git a/examples/madness/mra-device/kernels.cu b/examples/madness/mra-device/kernels.cu index 99ab0e516..c536b14f3 100644 --- a/examples/madness/mra-device/kernels.cu +++ b/examples/madness/mra-device/kernels.cu @@ -542,9 +542,9 @@ void submit_reconstruct_kernel( const TensorView& from_parent, const std::array::num_children()>& r_arr, T* tmp, + std::size_t K, cudaStream_t stream) { - const std::size_t K = node.dim(0); /* runs on a single block */ Dim3 thread_dims = Dim3(K, 1, 1); // figure out how to consider register usage CALL_KERNEL(reconstruct_kernel, 1, thread_dims, 0, stream)( @@ -562,4 +562,5 @@ void submit_reconstruct_kernel( const TensorView& from_parent, const std::array::num_children()>& r_arr, double* tmp, + std::size_t K, cudaStream_t stream); \ No newline at end of file diff --git a/examples/madness/mra-device/kernels.h b/examples/madness/mra-device/kernels.h index 8447a7f1d..7cc38ec25 100644 --- a/examples/madness/mra-device/kernels.h +++ b/examples/madness/mra-device/kernels.h @@ -73,6 +73,7 @@ void submit_reconstruct_kernel( const mra::TensorView& from_parent, const std::array::num_children()>& r_arr, T* tmp, + std::size_t K, ttg::device::Stream stream); #endif // HAVE_KERNELS_H \ No newline at end of file diff --git a/examples/madness/mra-device/mrattg-device.cc b/examples/madness/mra-device/mrattg-device.cc index 4a492cec4..e70807d81 100644 --- a/examples/madness/mra-device/mrattg-device.cc +++ b/examples/madness/mra-device/mrattg-device.cc @@ -174,6 +174,7 @@ static TASKTYPE do_send_leafs_up(const mra::Key& key, const mra::FunctionR /// Make a composite operator that implements compression for a single function template static auto make_compress( + const std::size_t K, const mra::FunctionData& functiondata, ttg::Edge, mra::FunctionReconstructedNode>& in, ttg::Edge, mra::FunctionCompressedNode>& out) @@ -190,7 +191,7 @@ static auto make_compress( /* append out edge to set of edges */ auto compress_out_edges = std::tuple_cat(send_to_compress_edges, std::make_tuple(out)); /* use the tuple variant to handle variable number of inputs while suppressing the output tuple */ - auto do_compress = [&](const mra::Key& key, + auto do_compress = [&, K](const mra::Key& key, //const std::tuple& input_frns const mra::FunctionReconstructedNode &in0, const mra::FunctionReconstructedNode &in1, @@ -204,7 +205,6 @@ static auto make_compress( //typename ::detail::tree_types::compress_out_type& out) { constexpr const auto num_children = mra::Key::num_children(); constexpr const auto out_terminal_id = num_children; - auto K = in0.coeffs.dim(0); mra::FunctionCompressedNode result(key, K); // The eventual result auto& d = result.coeffs; // allocate even though we might not need it @@ -300,7 +300,7 @@ auto make_reconstruct( { ttg::Edge, mra::Tensor> S("S"); // passes scaling functions down - auto do_reconstruct = [&](const mra::Key& key, + auto do_reconstruct = [&, K](const mra::Key& key, mra::FunctionCompressedNode&& node, const mra::Tensor& from_parent) -> TASKTYPE { const std::size_t K = from_parent.dim(0); @@ -340,7 +340,7 @@ auto make_reconstruct( auto hg_view = hg.current_view(); auto from_parent_view = from_parent.current_view(); submit_reconstruct_kernel(key, node_view, hg_view, from_parent_view, - r_ptrs, tmp_scratch.device_ptr(), ttg::device::current_stream()); + r_ptrs, tmp_scratch.device_ptr(), K, ttg::device::current_stream()); // forward() returns a vector that we can push into #ifndef TTG_ENABLE_HOST @@ -426,7 +426,7 @@ void test(std::size_t K) { auto gauss_buffer = ttg::Buffer>(&gaussian); auto start = make_start(project_control); auto project = make_project(D, gauss_buffer, K, functiondata, T(1e-6), project_control, project_result); - auto compress = make_compress(functiondata, project_result, compress_result); + auto compress = make_compress(K, functiondata, project_result, compress_result); auto reconstruct = make_reconstruct(K, functiondata, compress_result, reconstruct_result); auto printer = make_printer(project_result, "projected ", false); auto printer2 = make_printer(compress_result, "compressed ", false);