From 714b731d45959d1686997455841ae221be4f4aec Mon Sep 17 00:00:00 2001 From: Evan Ramos Date: Wed, 15 Sep 2021 13:28:04 -0500 Subject: [PATCH] Converse: Strengthen locking in node reductions (#3481) It is possible for node reduction messages to arrive at the root before it has completed its own CmiNodeReduce* call. Therefore, lock around all accesses of _nodereduce_info, and make _nodereduce_seqID* atomic. --- src/conv-core/convcore.C | 128 +++++++++++++++++++++++++++------------ src/conv-core/converse.h | 3 + 2 files changed, 91 insertions(+), 40 deletions(-) diff --git a/src/conv-core/convcore.C b/src/conv-core/convcore.C index 604d837aa4..1f0207da07 100644 --- a/src/conv-core/convcore.C +++ b/src/conv-core/convcore.C @@ -2414,9 +2414,8 @@ void CmiSyncVectorSendAndFree(int destPE, int n, int *sizes, char **msgs) { * merge call will not be deleted by the system, and the CmiHandler function * will be in charge of its deletion. * - * CmiReduce/CmiReduceStruct MUST be called once by every processor, - * CmiNodeReduce/CmiNodeReduceStruct MUST be called once by every node, and in - * particular by the rank zero in each node. + * CmiReduce/CmiReduceStruct MUST be called once by every processor. + * CmiNodeReduce/CmiNodeReduceStruct MUST be called once by every node. 
****************************************************************************/ #define REDUCTION_DEBUG 0 @@ -2454,6 +2453,20 @@ struct CmiNodeReduction { CmiReduction * red; }; +static inline CmiReductionID CmiReductionIDFetchAdd(CmiReductionID & id, CmiReductionID addend) { + const CmiReductionID oldid = id; + id = oldid + addend; + return oldid; +} +#if CMK_SMP +static inline CmiReductionID CmiReductionIDFetchAdd(std::atomic<CmiReductionID> & id, CmiReductionID addend) { + return id.fetch_add(addend); +} +using CmiNodeReductionID = std::atomic<CmiReductionID>; +#else +using CmiNodeReductionID = CmiReductionID; +#endif + CpvStaticDeclare(int, CmiReductionMessageHandler); CpvStaticDeclare(int, CmiReductionDynamicRequestHandler); @@ -2466,20 +2479,27 @@ CpvStaticDeclare(CmiReductionID, _reduce_seqID_request); CpvStaticDeclare(CmiReductionID, _reduce_seqID_dynamic); CsvStaticDeclare(CmiNodeReduction *, _nodereduce_info); -CsvStaticDeclare(CmiReductionID, _nodereduce_seqID_global); -CsvStaticDeclare(CmiReductionID, _nodereduce_seqID_request); -CsvStaticDeclare(CmiReductionID, _nodereduce_seqID_dynamic); +CsvStaticDeclare(CmiNodeReductionID, _nodereduce_seqID_global); +CsvStaticDeclare(CmiNodeReductionID, _nodereduce_seqID_request); +CsvStaticDeclare(CmiNodeReductionID, _nodereduce_seqID_dynamic); enum : CmiReductionID { CmiReductionID_globalOffset = 0, /* Reductions that involve the whole set of processors */ CmiReductionID_requestOffset = 1, /* Reductions IDs that are requested by all the processors (i.e during intialization) */ CmiReductionID_dynamicOffset = 2, /* Reductions IDs that are requested by only one processor (typically at runtime) */ - CmiReductionID_multiplier = 3 + + CmiReductionID_multiplier = 4 }; +static_assert(CmiIsPow2(CmiReductionID_multiplier), + "CmiReductionID_multiplier must be a power of two because seqID counters may overflow and wrap to 0"); + +static inline unsigned int CmiGetRedIndex(CmiReductionID id) { + return id & ~((~0u) << CmiLogMaxReductions); +} + static
CmiReduction* CmiGetReductionCreate(int id, short int numChildren) { - const int idx = id & ~((~0u) << CmiLogMaxReductions); - auto & redref = CpvAccess(_reduce_info)[idx]; + auto & redref = CpvAccess(_reduce_info)[CmiGetRedIndex(id)]; CmiReduction *red = redref; if (red != NULL && red->seqID != id) { /* The table needs to be expanded */ @@ -2506,33 +2526,27 @@ static CmiReduction* CmiGetReductionCreate(int id, short int numChildren) { } static void CmiClearReduction(int id) { - const int idx = id & ~((~0u) << CmiLogMaxReductions); - auto & redref = CpvAccess(_reduce_info)[idx]; + auto & redref = CpvAccess(_reduce_info)[CmiGetRedIndex(id)]; auto red = redref; redref = nullptr; free(red); } -static CmiReduction* CmiGetNextReduction(short int numChildren) { - int id = CpvAccess(_reduce_seqID_global); - int newid = id + CmiReductionID_multiplier; - if (id > 0xFFF0) newid = CmiReductionID_globalOffset; - CpvAccess(_reduce_seqID_global) = newid; - return CmiGetReductionCreate(id, numChildren); +static CmiReductionID CmiGetNextReductionID(void) { + return CmiReductionIDFetchAdd(CpvAccess(_reduce_seqID_global), CmiReductionID_multiplier); } CmiReductionID CmiGetGlobalReduction(void) { - return CpvAccess(_reduce_seqID_request)+=CmiReductionID_multiplier; + return CmiReductionIDFetchAdd(CpvAccess(_reduce_seqID_request), CmiReductionID_multiplier); } CmiReductionID CmiGetDynamicReduction(void) { if (CmiMyPe() != 0) CmiAbort("Cannot call CmiGetDynamicReduction on processors other than zero!\n"); - return CpvAccess(_reduce_seqID_dynamic)+=CmiReductionID_multiplier; + return CmiReductionIDFetchAdd(CpvAccess(_reduce_seqID_dynamic), CmiReductionID_multiplier); } static CmiReduction* CmiGetNodeReductionCreate(int id, short int numChildren) { - const int idx = id & ~((~0u) << CmiLogMaxReductions); - auto & redref = CsvAccess(_nodereduce_info)[idx].red; + auto & redref = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)].red; CmiReduction *red = redref; if (red != NULL && red->seqID != 
id) { /* The table needs to be expanded */ @@ -2559,28 +2573,23 @@ static CmiReduction* CmiGetNodeReductionCreate(int id, short int numChildren) { } static void CmiClearNodeReduction(int id) { - const int idx = id & ~((~0u) << CmiLogMaxReductions); - auto & redref = CsvAccess(_nodereduce_info)[idx].red; + auto & redref = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)].red; auto red = redref; redref = nullptr; free(red); } -static CmiReduction* CmiGetNextNodeReduction(short int numChildren) { - int id = CsvAccess(_nodereduce_seqID_global); - int newid = id + CmiReductionID_multiplier; - if (id > 0xFFF0) newid = CmiReductionID_globalOffset; - CsvAccess(_nodereduce_seqID_global) = newid; - return CmiGetNodeReductionCreate(id, numChildren); +static CmiReductionID CmiGetNextNodeReductionID(void) { + return CmiReductionIDFetchAdd(CsvAccess(_nodereduce_seqID_global), CmiReductionID_multiplier); } CmiReductionID CmiGetGlobalNodeReduction(void) { - return CsvAccess(_nodereduce_seqID_request)+=CmiReductionID_multiplier; + return CmiReductionIDFetchAdd(CsvAccess(_nodereduce_seqID_request), CmiReductionID_multiplier); } CmiReductionID CmiGetDynamicNodeReduction(void) { if (CmiMyNode() != 0) CmiAbort("Cannot call CmiGetDynamicNodeReduction on nodes other than zero!\n"); - return CsvAccess(_nodereduce_seqID_dynamic)+=CmiReductionID_multiplier; + return CmiReductionIDFetchAdd(CsvAccess(_nodereduce_seqID_dynamic), CmiReductionID_multiplier); } void CmiReductionHandleDynamicRequest(char *msg) { @@ -2797,14 +2806,16 @@ static void CmiGlobalNodeReduceStruct(void *data, CmiReducePupFn pupFn, } void CmiReduce(void *msg, int size, CmiReduceMergeFn mergeFn) { - CmiReduction *red = CmiGetNextReduction(CmiNumSpanTreeChildren(CmiMyPe())); + const CmiReductionID id = CmiGetNextReductionID(); + CmiReduction *red = CmiGetReductionCreate(id, CmiNumSpanTreeChildren(CmiMyPe())); CmiGlobalReduce(msg, size, mergeFn, red); } void CmiReduceStruct(void *data, CmiReducePupFn pupFn, CmiReduceMergeFn 
mergeFn, CmiHandler dest, CmiReduceDeleteFn deleteFn) { - CmiReduction *red = CmiGetNextReduction(CmiNumSpanTreeChildren(CmiMyPe())); + const CmiReductionID id = CmiGetNextReductionID(); + CmiReduction *red = CmiGetReductionCreate(id, CmiNumSpanTreeChildren(CmiMyPe())); CmiGlobalReduceStruct(data, pupFn, mergeFn, dest, deleteFn, red); } @@ -2884,27 +2895,65 @@ void CmiGroupReduceStruct(CmiGroup grp, void *data, CmiReducePupFn pupFn, } void CmiNodeReduce(void *msg, int size, CmiReduceMergeFn mergeFn) { - CmiReduction *red = CmiGetNextNodeReduction(CmiNumNodeSpanTreeChildren(CmiMyNode())); + const CmiReductionID id = CmiGetNextNodeReductionID(); +#if CMK_SMP + CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)]; + CmiLock(nodered.lock); +#endif + + CmiReduction *red = CmiGetNodeReductionCreate(id, CmiNumNodeSpanTreeChildren(CmiMyNode())); CmiGlobalNodeReduce(msg, size, mergeFn, red); + +#if CMK_SMP + CmiUnlock(nodered.lock); +#endif } void CmiNodeReduceStruct(void *data, CmiReducePupFn pupFn, CmiReduceMergeFn mergeFn, CmiHandler dest, CmiReduceDeleteFn deleteFn) { - CmiReduction *red = CmiGetNextNodeReduction(CmiNumNodeSpanTreeChildren(CmiMyNode())); + const CmiReductionID id = CmiGetNextNodeReductionID(); +#if CMK_SMP + CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)]; + CmiLock(nodered.lock); +#endif + + CmiReduction *red = CmiGetNodeReductionCreate(id, CmiNumNodeSpanTreeChildren(CmiMyNode())); CmiGlobalNodeReduceStruct(data, pupFn, mergeFn, dest, deleteFn, red); + +#if CMK_SMP + CmiUnlock(nodered.lock); +#endif } void CmiNodeReduceID(void *msg, int size, CmiReduceMergeFn mergeFn, CmiReductionID id) { +#if CMK_SMP + CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)]; + CmiLock(nodered.lock); +#endif + CmiReduction *red = CmiGetNodeReductionCreate(id, CmiNumNodeSpanTreeChildren(CmiMyNode())); CmiGlobalNodeReduce(msg, size, mergeFn, red); + +#if CMK_SMP + CmiUnlock(nodered.lock); +#endif } 
void CmiNodeReduceStructID(void *data, CmiReducePupFn pupFn, CmiReduceMergeFn mergeFn, CmiHandler dest, CmiReduceDeleteFn deleteFn, CmiReductionID id) { +#if CMK_SMP + CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)]; + CmiLock(nodered.lock); +#endif + CmiReduction *red = CmiGetNodeReductionCreate(id, CmiNumNodeSpanTreeChildren(CmiMyNode())); CmiGlobalNodeReduceStruct(data, pupFn, mergeFn, dest, deleteFn, red); + +#if CMK_SMP + CmiUnlock(nodered.lock); +#endif } static void CmiHandleReductionMessage(void *msg) { @@ -2921,8 +2970,7 @@ static void CmiHandleReductionMessage(void *msg) { static void CmiHandleNodeReductionMessage(void *msg) { const auto id = CmiGetRedID(msg); #if CMK_SMP - const int idx = id & ~((~0u) << CmiLogMaxReductions); - CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[idx]; + CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)]; CmiLock(nodered.lock); #endif @@ -2964,11 +3012,11 @@ void CmiReductionsInit(void) { if (CmiMyRank() == 0) { - CsvInitialize(CmiReductionID, _nodereduce_seqID_global); + CsvInitialize(CmiNodeReductionID, _nodereduce_seqID_global); CsvAccess(_nodereduce_seqID_global) = CmiReductionID_globalOffset; - CsvInitialize(CmiReductionID, _nodereduce_seqID_request); + CsvInitialize(CmiNodeReductionID, _nodereduce_seqID_request); CsvAccess(_nodereduce_seqID_request) = CmiReductionID_requestOffset; - CsvInitialize(CmiReductionID, _nodereduce_seqID_dynamic); + CsvInitialize(CmiNodeReductionID, _nodereduce_seqID_dynamic); CsvAccess(_nodereduce_seqID_dynamic) = CmiReductionID_dynamicOffset; CsvInitialize(CmiNodeReduction *, _nodereduce_info); auto noderedinfo = (CmiNodeReduction *)malloc(CmiMaxReductions * sizeof(CmiNodeReduction)); diff --git a/src/conv-core/converse.h b/src/conv-core/converse.h index c340c10fff..bd7b2bca42 100644 --- a/src/conv-core/converse.h +++ b/src/conv-core/converse.h @@ -78,6 +78,9 @@ #define CMI_MSG_NOKEEP(msg) ((CmiMsgHeaderBasic *)msg)->nokeep 
+#define CmiIsPow2OrZero(v) (((v) & ((v) - 1)) == 0) +#define CmiIsPow2(v) (CmiIsPow2OrZero(v) && (v)) + #define CMIALIGN(x,n) (size_t)((~((size_t)n-1))&((x)+(n-1))) /*#define ALIGN8(x) (size_t)((~7)&((x)+7)) */ #define ALIGN8(x) CMIALIGN(x,8)