Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2212 add message properties for generated messages #2213

Merged
merged 12 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/vt/collective/reduce/reduce_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@

namespace vt { namespace collective { namespace reduce {

static std::unique_ptr<Reduce> makeReduceScope(detail::ReduceScope const& scope) {
return std::make_unique<Reduce>(scope);
}

ReduceManager::ReduceManager()
: reducers_( // default cons reducer for non-group
[](detail::ReduceScope const& scope) {
return std::make_unique<Reduce>(scope);
}
)
: reducers_(makeReduceScope)
{
// insert the default reducer scope
reducers_.make(
Expand Down
6 changes: 4 additions & 2 deletions src/vt/messaging/active.h
Original file line number Diff line number Diff line change
Expand Up @@ -766,7 +766,8 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
PendingSendType send(Node dest, Params&&... params) {
using Tuple = typename FuncTraits<decltype(f)>::TupleType;
using MsgT = ParamMsg<Tuple>;
auto msg = vt::makeMessage<MsgT>(std::forward<Params>(params)...);
auto msg = vt::makeMessage<MsgT>();
msg->setParams(std::forward<Params>(params)...);
auto han = auto_registry::makeAutoHandlerParam<decltype(f), f, MsgT>();
return sendMsg<MsgT>(dest.get(), han, msg, no_tag);
}
Expand All @@ -782,7 +783,8 @@ struct ActiveMessenger : runtime::component::PollableComponent<ActiveMessenger>
PendingSendType broadcast(Params&&... params) {
using Tuple = typename FuncTraits<decltype(f)>::TupleType;
using MsgT = ParamMsg<Tuple>;
auto msg = vt::makeMessage<MsgT>(std::forward<Params>(params)...);
auto msg = vt::makeMessage<MsgT>();
msg->setParams(std::forward<Params>(params)...);
auto han = auto_registry::makeAutoHandlerParam<decltype(f), f, MsgT>();
constexpr bool deliver_to_sender = true;
return broadcastMsg<MsgT>(han, msg, deliver_to_sender, no_tag);
Expand Down
2 changes: 2 additions & 0 deletions src/vt/messaging/active.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,4 +512,6 @@ inline EpochType ActiveMessenger::setupEpochMsg(MsgSharedPtr<MsgT> const& msg) {

}} //end namespace vt::messaging

#include "vt/messaging/param_msg.impl.h"

#endif /*INCLUDED_VT_MESSAGING_ACTIVE_IMPL_H*/
146 changes: 132 additions & 14 deletions src/vt/messaging/param_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,103 @@

#include "vt/messaging/message/message_serialize.h"

namespace vt { namespace messaging {
namespace vt {

struct MsgProps {

MsgProps() = default;

MsgProps&& asLocationMsg(bool set = true) {
as_location_msg_ = set;
return std::move(*this);
}

MsgProps&& asTerminationMsg(bool set = true) {
as_termination_msg_ = set;
return std::move(*this);
}

MsgProps&& asCollectionMsg(bool set = true) {
as_collection_msg_ = set;
return std::move(*this);
}

MsgProps&& asSerializationMsg(bool set = true) {
as_serial_msg_ = set;
return std::move(*this);
}

MsgProps&& withEpoch(EpochType in_ep) {
ep_ = in_ep;
return std::move(*this);
}

MsgProps&& withPriority(PriorityType in_priority) {
#if vt_check_enabled(priorities)
priority_ = in_priority;
#endif
return std::move(*this);
}

MsgProps&& withPriorityLevel(PriorityLevelType in_priority_level) {
#if vt_check_enabled(priorities)
priority_level_ = in_priority_level;
#endif
return std::move(*this);
}

template <typename MsgPtrT>
void apply(MsgPtrT msg);

private:
bool as_location_msg_ = false;
bool as_termination_msg_ = false;
bool as_serial_msg_ = false;
bool as_collection_msg_ = false;
EpochType ep_ = no_epoch;
#if vt_check_enabled(priorities)
PriorityType priority_ = no_priority;
PriorityLevelType priority_level_ = no_priority_level;
#endif
};

} /* end namespace vt */

namespace vt::messaging::detail {

template <typename enabled_, typename... Params>
struct GetTraits;

template <>
struct GetTraits<std::enable_if_t<std::is_same_v<void, void>>> {
using TupleType = std::tuple<>;
};

template <typename Param, typename... Params>
struct GetTraits<
std::enable_if_t<std::is_same_v<MsgProps, Param>>, Param, Params...
> {
using TupleType = std::tuple<Params...>;
};

template <typename Param, typename... Params>
struct GetTraits<
std::enable_if_t<not std::is_same_v<MsgProps, Param>>, Param, Params...
> {
using TupleType = std::tuple<Param, Params...>;
};

template <typename Tuple>
struct GetTraitsTuple;

template <typename... Params>
struct GetTraitsTuple<std::tuple<Params...>> {
using TupleType = typename GetTraits<void, Params...>::TupleType;
};

} /* end namespace vt::messaging::detail */

namespace vt::messaging {

template <typename Tuple, typename enabled = void>
struct ParamMsg;
Expand All @@ -56,15 +152,24 @@ struct ParamMsg<
Tuple, std::enable_if_t<is_byte_copyable_t<Tuple>::value>
> : vt::Message
{
using TupleType = typename detail::GetTraitsTuple<Tuple>::TupleType;

ParamMsg() = default;

template <typename... Params>
explicit ParamMsg(Params&&... in_params)
: params(std::forward<Params>(in_params)...)
{ }
void setParams() { }

template <typename Param, typename... Params>
void setParams(Param&& p, Params&&... in_params) {
if constexpr (std::is_same_v<std::decay_t<Param>, MsgProps>) {
params = TupleType{std::forward<Params>(in_params)...};
p.apply(this);
} else {
params = TupleType{std::forward<Param>(p), std::forward<Params>(in_params)...};
}
}

Tuple params;
Tuple& getTuple() { return params; }
TupleType params;
TupleType& getTuple() { return params; }
};

template <typename Tuple>
Expand All @@ -75,16 +180,29 @@ struct ParamMsg<
using MessageParentType = vt::Message;
vt_msg_serialize_if_needed_by_parent_or_type1(Tuple); // by tup

using TupleType = typename detail::GetTraitsTuple<Tuple>::TupleType;

ParamMsg() = default;

template <typename... Params>
explicit ParamMsg(Params&&... in_params)
: params(std::make_unique<Tuple>(std::forward<Params>(in_params)...))
{ }
void setParams() {
params = std::make_unique<TupleType>();
}

template <typename Param, typename... Params>
void setParams(Param&& p, Params&&... in_params) {
if constexpr (std::is_same_v<std::decay_t<Param>, MsgProps>) {
params = std::make_unique<TupleType>(std::forward<Params>(in_params)...);
p.apply(this);
} else {
params = std::make_unique<TupleType>(
std::forward<Param>(p), std::forward<Params>(in_params)...
);
}
}

std::unique_ptr<Tuple> params;
std::unique_ptr<TupleType> params;

Tuple& getTuple() { return *params.get(); }
TupleType& getTuple() { return *params.get(); }

template <typename SerializerT>
void serialize(SerializerT& s) {
Expand All @@ -93,6 +211,6 @@ struct ParamMsg<
}
};

}} /* end namespace vt::messaging */
} /* end namespace vt::messaging */

#endif /*INCLUDED_VT_MESSAGING_PARAM_MSG_H*/
80 changes: 80 additions & 0 deletions src/vt/messaging/param_msg.impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
//@HEADER
// *****************************************************************************
//
// param_msg.impl.h
// DARMA/vt => Virtual Transport
//
// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC
// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
// Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact darma@sandia.gov
//
// *****************************************************************************
//@HEADER
*/

#if !defined INCLUDED_VT_MESSAGING_PARAM_MSG_IMPL_H
#define INCLUDED_VT_MESSAGING_PARAM_MSG_IMPL_H

#include "vt/messaging/param_msg.h"

namespace vt {

template <typename MsgPtrT>
void MsgProps::apply(MsgPtrT msg) {
if (as_location_msg_) {
theMsg()->markAsLocationMessage(msg);
}
if (as_termination_msg_) {
theMsg()->markAsTermMessage(msg);
}
if (as_serial_msg_) {
theMsg()->markAsSerialMsgMessage(msg);
}
if (as_collection_msg_) {
theMsg()->markAsCollectionMessage(msg);
}
if (ep_ != no_epoch) {
envelopeSetEpoch(msg->env, ep_);
}
#if vt_check_enabled(priorities)
if (priority_ != no_priority) {
envelopeSetPriority(msg->env, priority_);
}
if (priority_level_ != no_priority_level) {
envelopeSetPriorityLevel(msg->env, priority_level_);
}
#endif
}

} /* end namespace vt */

#endif /*INCLUDED_VT_MESSAGING_PARAM_MSG_IMPL_H*/
9 changes: 9 additions & 0 deletions src/vt/messaging/pending_send.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ struct PendingSend final {
*/
void release();

/**
* \internal \brief Get the message stored in the pending send
*
* \note Used for testing purposes
*
* \return a reference to the message
*/
MsgPtr<BaseMsgType>& getMsg() { return msg_; }

private:

/**
Expand Down
6 changes: 4 additions & 2 deletions src/vt/objgroup/proxy/proxy_objgroup.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ Proxy<ObjT>::broadcast(Params&&... params) const {
if constexpr (std::is_same_v<MsgT, NoMsg>) {
using Tuple = typename ObjFuncTraits<decltype(f)>::TupleType;
using SendMsgT = messaging::ParamMsg<Tuple>;
auto msg = vt::makeMessage<SendMsgT>(std::forward<Params>(params)...);
auto msg = vt::makeMessage<SendMsgT>();
msg->setParams(std::forward<Params>(params)...);
auto const ctrl = proxy::ObjGroupProxy::getID(proxy_);
auto const han = auto_registry::makeAutoHandlerObjGroupParam<
ObjT, decltype(f), f, SendMsgT
Expand All @@ -117,7 +118,8 @@ Proxy<ObjT>::multicast(GroupType type, Params&&... params) const{
if constexpr (std::is_same_v<MsgT, NoMsg>) {
using Tuple = typename ObjFuncTraits<decltype(f)>::TupleType;
using SendMsgT = messaging::ParamMsg<Tuple>;
auto msg = vt::makeMessage<SendMsgT>(std::forward<Params>(params)...);
auto msg = vt::makeMessage<SendMsgT>();
msg->setParams(std::forward<Params>(params)...);
vt::envelopeSetGroup(msg->env, type);
auto const ctrl = proxy::ObjGroupProxy::getID(proxy_);
auto const han = auto_registry::makeAutoHandlerObjGroupParam<
Expand Down
3 changes: 2 additions & 1 deletion src/vt/objgroup/proxy/proxy_objgroup_elm.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ ProxyElm<ObjT>::send(Params&&... params) const {
if constexpr (std::is_same_v<MsgT, NoMsg>) {
using Tuple = typename ObjFuncTraits<decltype(f)>::TupleType;
using SendMsgT = messaging::ParamMsg<Tuple>;
auto msg = vt::makeMessage<SendMsgT>(std::forward<Params>(params)...);
auto msg = vt::makeMessage<SendMsgT>();
msg->setParams(std::forward<Params>(params)...);
auto const ctrl = proxy::ObjGroupProxy::getID(proxy_);
auto const han = auto_registry::makeAutoHandlerObjGroupParam<
ObjT, decltype(f), f, SendMsgT
Expand Down
30 changes: 26 additions & 4 deletions src/vt/pipe/callback/cb_union/cb_raw_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,17 +234,39 @@ struct CallbackTyped : CallbackRawBaseSingle {
void sendTuple(std::tuple<Params...> tup) {
using Trait = CBTraits<Args...>;
using MsgT = messaging::ParamMsg<typename Trait::TupleType>;
auto msg = vt::makeMessage<MsgT>(std::move(tup));
auto msg = vt::makeMessage<MsgT>();
msg->setParams(std::move(tup));
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
}

template <typename... Params>
void send(Params&&... params) {
using Trait = CBTraits<Args...>;
if constexpr (std::is_same_v<typename Trait::MsgT, NoMsg>) {
using MsgT = messaging::ParamMsg<typename Trait::TupleType>;
auto msg = vt::makeMessage<MsgT>(std::forward<Params>(params)...);
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
// We have to go through some tricky code to make the MsgProps case work
// If we use the type for Params to send, it's possible that we have a
// type mismatch in the actual handler type. A possible edge case is when
// a char const* is sent, but the handler is a std::string. In this case,
// the ParamMsg will be cast incorrectly during the virual dispatch to a
// collection because callbacks don't have the collection type. Thus, the
// wrong ParamMsg will be cast to which requires serialization, leading to
// a failure.
if constexpr (sizeof...(Params) == sizeof...(Args) + 1) {
using MsgT = messaging::ParamMsg<
std::tuple<
std::decay_t<std::tuple_element_t<0, std::tuple<Params...>>>,
std::decay_t<Args>...
>
>;
auto msg = vt::makeMessage<MsgT>();
msg->setParams(std::forward<Params>(params)...);
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
} else {
using MsgT = messaging::ParamMsg<typename Trait::TupleType>;
auto msg = vt::makeMessage<MsgT>();
msg->setParams(std::forward<Params>(params)...);
CallbackRawBaseSingle::sendMsg<MsgT>(msg);
}
} else {
using MsgT = typename Trait::MsgT;
auto msg = makeMessage<MsgT>(std::forward<Params>(params)...);
Expand Down
Loading
Loading