diff --git a/include/ngp_utils/SmartFieldRef.h b/include/ngp_utils/SmartFieldRef.h index 8baedc067..fb75e8c0e 100644 --- a/include/ngp_utils/SmartFieldRef.h +++ b/include/ngp_utils/SmartFieldRef.h @@ -14,193 +14,307 @@ #include namespace tags { -//clang-format off -struct READ -{ -}; -struct WRITE -{ -}; -struct READ_WRITE -{ -}; +// clang-format off -struct HOST -{ -}; -struct DEVICE -{ -}; -//clang-format on +//ACCESS TYPES +struct READ{}; +struct WRITE{}; +struct READ_WRITE{}; + +// MEMSPACE +struct HOST{}; +struct DEVICE{}; +struct LEGACY{}; + +// clang-format on } // namespace tags namespace sierra::nalu { using namespace tags; -template +template class SmartFieldRef +{}; + +template +class SmartFieldRef::value>> { +public: + using T = typename FieldType::value_type; + + SmartFieldRef(FieldType& fieldRef) : fieldRef_(fieldRef) {} + SmartFieldRef(const SmartFieldRef& src) + : fieldRef_(src.fieldRef_) + { + if (is_read_){ + fieldRef_.sync_to_host(); + } + else{ + fieldRef_.clear_sync_state(); + } + } + // --- Default Accessors + template + typename std::enable_if_t::value, T>& + get(const stk::mesh::Entity& entity) const + { + return *stk::mesh::field_data(fieldRef_, entity); + } + + template + typename std::enable_if_t::value, T>& + operator()(const stk::mesh::Entity& entity) const + { + return *stk::mesh::field_data(fieldRef_, entity); + } + + // --- Const Accessors + template + const typename std::enable_if_t::value, T>& + get(const stk::mesh::Entity& entity) const + { + return *stk::mesh::field_data(fieldRef_, entity); + } + + template + const typename std::enable_if_t::value, T>& + operator()(const stk::mesh::Entity& entity) const + { + return *stk::mesh::field_data(fieldRef_, entity); + } + + + ~SmartFieldRef() + { + if (is_write_) { + fieldRef_.modify_on_host(); + } + } +private: + static constexpr bool is_read_ + { + std::is_same::value || + std::is_same::value + }; + + static constexpr bool is_write_ + { + std::is_same::value || + std::is_same::value + }; + + FieldType& fieldRef_; }; -template -class SmartFieldRef +template +class SmartFieldRef::value>> { public: - SmartFieldRef(stk::mesh::NgpField& ngpField) : fieldRef_(ngpField) {} + using T = typename FieldType::value_type; + + SmartFieldRef(FieldType fieldRef) : fieldRef_(fieldRef) {} SmartFieldRef(const SmartFieldRef& src) : fieldRef_(src.fieldRef_), is_copy_constructed_(true) { - if (is_read()) - fieldRef_.sync_to_device(); - else + if (is_read_){ + if(is_device_space){ + fieldRef_.sync_to_device(); + } + else{ + fieldRef_.sync_to_host(); + } + } + else{ fieldRef_.clear_sync_state(); + } } - // device implementations should only ever execute inside a - // kokkos::paralle_for and hence be captured by a lambda. Therefore we only - // ever need to sync copies that will have been snatched up through lambda - // capture. ~SmartFieldRef() { - if (is_copy_constructed_ && is_write()) { - fieldRef_.modify_on_device(); + if (is_write_) { + if(is_copy_constructed_){ + // device implementations should only ever execute inside a + // kokkos::paralle_for and hence be captured by a lambda. Therefore we only + // ever need to sync copies that will have been snatched up through lambda + // capture. + fieldRef_.modify_on_device(); + } + else{ + // try not requiring copy mechanism for host + fieldRef_.modify_on_host(); + } } } - KOKKOS_INLINE_FUNCTION - unsigned get_ordinal() const { return fieldRef_.get_ordinal(); } + //************************************************************ + // Host functions (Remove KOKKOS_FUNCTION decorators) + //************************************************************ + template + std::enable_if_t::value, unsigned> + get_ordinal() const { return fieldRef_.get_ordinal(); } - // TODO make it so these accessors are read only for read type i.e. const - // correct and give clear compile or runtime error for programming mistakes - KOKKOS_INLINE_FUNCTION - T& get(stk::mesh::FastMeshIndex& index, int component) const + // --- Default Accessors + template + std::enable_if_t::value && !std::is_same::value, T>& + get(stk::mesh::FastMeshIndex& index, int component) const { return fieldRef_.get(index, component); } - template - KOKKOS_INLINE_FUNCTION T& get(MeshIndex index, int component) const + template + std::enable_if_t::value && !std::is_same::value, T>& + get(MeshIndex index, int component) const { return fieldRef_.get(index, component); } - KOKKOS_INLINE_FUNCTION - T& operator()(const stk::mesh::FastMeshIndex& index, int component) const + template + std::enable_if_t::value && !std::is_same::value, T>& + operator()(const stk::mesh::FastMeshIndex& index, int component) const { return fieldRef_.get(index, component); } - template - KOKKOS_INLINE_FUNCTION T& + template + std::enable_if_t::value && !std::is_same::value, T>& operator()(const MeshIndex index, int component) const { return fieldRef_.operator()(index, component); } -private: - bool is_read() + // --- Const Accessors + template + const std::enable_if_t::value && std::is_same::value, T>& + get(stk::mesh::FastMeshIndex& index, int component) const { - return std::is_same::value || - std::is_same::value; + return fieldRef_.get(index, component); } - bool is_write() + template + const std::enable_if_t::value && std::is_same::value, T>& + get(MeshIndex index, int component) const { - return std::is_same::value || - std::is_same::value; + return fieldRef_.get(index, component); } - stk::mesh::NgpField& fieldRef_; - const bool is_copy_constructed_{false}; -}; + template + const std::enable_if_t::value && std::is_same::value, T>& + operator()(const stk::mesh::FastMeshIndex& index, int component) const + { + return fieldRef_.get(index, component); + } -// HOST specialization using legacy bucket loops -// TODO would we ever/can we use stk::mesh::HostField's inside a device enabled -// build? -// If so I think we should change this to LEGACY instead of HOST -template -class SmartFieldRef -{ -public: - SmartFieldRef(stk::mesh::Field& field) : fieldRef_(field) + template + const std::enable_if_t::value && std::is_same::value, T>& + operator()(const MeshIndex index, int component) const { - if (is_read()) - fieldRef_.sync_to_host(); - else - fieldRef_.clear_sync_state(); + return fieldRef_.operator()(index, component); } + //************************************************************ + // Device functions + //************************************************************ + KOKKOS_FUNCTION + template + std::enable_if_t::value, unsigned> + get_ordinal() const { return fieldRef_.get_ordinal(); } - SmartFieldRef(const SmartFieldRef& src) - : fieldRef_(src.fieldRef_), is_copy_constructed_(true) + // --- Default Accessors + KOKKOS_FUNCTION + template + std::enable_if_t::value && !std::is_same::value, T>& + get(stk::mesh::FastMeshIndex& index, int component) const { - if (is_read()) - fieldRef_.sync_to_host(); - else - fieldRef_.clear_sync_state(); + return fieldRef_.get(index, component); } - // try removing the copy constructor requirement for host fields - ~SmartFieldRef() + KOKKOS_FUNCTION + template + std::enable_if_t::value && !std::is_same::value, T>& + get(MeshIndex index, int component) const { - if (is_write()) { - fieldRef_.modify_on_host(); - } + return fieldRef_.get(index, component); } - template - const typename std::enable_if_t::value, T>& - get(const stk::mesh::Entity& entity) const + KOKKOS_FUNCTION + template + std::enable_if_t::value && !std::is_same::value, T>& + operator()(const stk::mesh::FastMeshIndex& index, int component) const { - return *stk::mesh::field_data(fieldRef_, entity); + return fieldRef_.get(index, component); } - template - const typename std::enable_if_t::value, T>& - operator()(const stk::mesh::Entity& entity) const + KOKKOS_FUNCTION + template + std::enable_if_t::value && !std::is_same::value, T>& + operator()(const MeshIndex index, int component) const { - return *stk::mesh::field_data(fieldRef_, entity); + return fieldRef_.operator()(index, component); } - template - typename std::enable_if_t::value, T>& - get(const stk::mesh::Entity& entity) const + // --- Const Accessors + KOKKOS_FUNCTION + template + const std::enable_if_t::value && std::is_same::value, T>& + get(stk::mesh::FastMeshIndex& index, int component) const { - return *stk::mesh::field_data(fieldRef_, entity); + return fieldRef_.get(index, component); } - template - typename std::enable_if_t::value, T>& - operator()(const stk::mesh::Entity& entity) const + KOKKOS_FUNCTION + template + const std::enable_if_t::value && std::is_same::value, T>& + get(MeshIndex index, int component) const { - return *stk::mesh::field_data(fieldRef_, entity); + return fieldRef_.get(index, component); } -private: - bool is_read() + KOKKOS_FUNCTION + template + const std::enable_if_t::value && std::is_same::value, T>& + operator()(const stk::mesh::FastMeshIndex& index, int component) const { - return std::is_same::value || - std::is_same::value; + return fieldRef_.get(index, component); } - bool is_write() + KOKKOS_FUNCTION + template + const std::enable_if_t::value && std::is_same::value, T>& + operator()(const MeshIndex index, int component) const { - return std::is_same::value || - std::is_same::value; + return fieldRef_.operator()(index, component); } - stk::mesh::Field& fieldRef_; +private: + static constexpr bool is_device_space + { + std::is_same::value + }; + + static constexpr bool is_read_ + { + std::is_same::value || + std::is_same::value + }; + + static constexpr bool is_write_ + { + std::is_same::value || + std::is_same::value + }; + + FieldType fieldRef_; const bool is_copy_constructed_{false}; }; template struct MakeFieldRef { - template - SmartFieldRef operator()(T& field) + template + SmartFieldRef operator()(FieldType& field) { - return SmartFieldRef(field); + return SmartFieldRef(field); } }; diff --git a/unit_tests/UnitTestSmartFieldRef.C b/unit_tests/UnitTestSmartFieldRef.C index 82003bd22..f6ac9bdf6 100644 --- a/unit_tests/UnitTestSmartFieldRef.C +++ b/unit_tests/UnitTestSmartFieldRef.C @@ -73,7 +73,6 @@ TEST_F(TestSmartFieldRef, device_read_write_mod_sync_with_lambda) ASSERT_TRUE(ngpField_->need_sync_to_device()); - // TODO can we get rid of the double template param some how? auto sPtr = MakeFieldRef()(*ngpField_); lambda_ordinal(sPtr); @@ -131,7 +130,7 @@ TEST_F(TestSmartFieldRef, update_field_on_device_check_on_host) int counter = 0; auto* field = fieldManager->get_field_ptr("scalarQ"); auto fieldRef = - sierra::nalu::MakeFieldRef()(*field); + sierra::nalu::MakeFieldRef()(*field); stk::mesh::Selector sel = stk::mesh::selectUnion(partVec); const auto& buckets = bulk->get_buckets(stk::topology::NODE_RANK, sel); for (auto b : buckets) {