Skip to content

Commit

Permalink
Converse: Strengthen locking in node reductions (#3481)
Browse files Browse the repository at this point in the history
It is possible for node reduction messages to arrive at the root before
it has completed its own CmiNodeReduce* call. Therefore, lock around
all accesses of _nodereduce_info, and make _nodereduce_seqID* atomic.
  • Loading branch information
evan-charmworks committed Sep 15, 2021
1 parent 0746e1b commit 714b731
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 40 deletions.
128 changes: 88 additions & 40 deletions src/conv-core/convcore.C
Original file line number Diff line number Diff line change
Expand Up @@ -2414,9 +2414,8 @@ void CmiSyncVectorSendAndFree(int destPE, int n, int *sizes, char **msgs) {
* merge call will not be deleted by the system, and the CmiHandler function
* will be in charge of its deletion.
*
* CmiReduce/CmiReduceStruct MUST be called once by every processor,
* CmiNodeReduce/CmiNodeReduceStruct MUST be called once by every node, and in
* particular by the rank zero in each node.
* CmiReduce/CmiReduceStruct MUST be called once by every processor.
* CmiNodeReduce/CmiNodeReduceStruct MUST be called once by every node.
****************************************************************************/

#define REDUCTION_DEBUG 0
Expand Down Expand Up @@ -2454,6 +2453,20 @@ struct CmiNodeReduction {
CmiReduction * red;
};

/* Plain (non-atomic) fetch-and-add on a reduction ID counter:
 * advances `id` by `addend` and returns the value `id` held beforehand. */
static inline CmiReductionID CmiReductionIDFetchAdd(CmiReductionID & id, CmiReductionID addend) {
  const CmiReductionID previous = id;
  id += addend;
  return previous;
}
#if CMK_SMP
/* Overload for std::atomic counters: one atomic read-modify-write that
 * returns the value held before `addend` was added. The memory order is
 * spelled out; it is the same sequentially-consistent default fetch_add uses. */
static inline CmiReductionID CmiReductionIDFetchAdd(std::atomic<CmiReductionID> & id, CmiReductionID addend) {
  const CmiReductionID previous = id.fetch_add(addend, std::memory_order_seq_cst);
  return previous;
}
/* Counter type used for the _nodereduce_seqID_* variables when CMK_SMP is set. */
using CmiNodeReductionID = std::atomic<CmiReductionID>;
#else
/* Without CMK_SMP the counters are plain CmiReductionID values. */
using CmiNodeReductionID = CmiReductionID;
#endif

CpvStaticDeclare(int, CmiReductionMessageHandler);
CpvStaticDeclare(int, CmiReductionDynamicRequestHandler);

Expand All @@ -2466,20 +2479,27 @@ CpvStaticDeclare(CmiReductionID, _reduce_seqID_request);
CpvStaticDeclare(CmiReductionID, _reduce_seqID_dynamic);

CsvStaticDeclare(CmiNodeReduction *, _nodereduce_info);
CsvStaticDeclare(CmiReductionID, _nodereduce_seqID_global);
CsvStaticDeclare(CmiReductionID, _nodereduce_seqID_request);
CsvStaticDeclare(CmiReductionID, _nodereduce_seqID_dynamic);
CsvStaticDeclare(CmiNodeReductionID, _nodereduce_seqID_global);
CsvStaticDeclare(CmiNodeReductionID, _nodereduce_seqID_request);
CsvStaticDeclare(CmiNodeReductionID, _nodereduce_seqID_dynamic);

/* Layout of the reduction ID space: each category of ID starts at its own
 * offset and its counter is advanced in steps of CmiReductionID_multiplier.
 * NOTE(review): this text is a diff rendering, so CmiReductionID_multiplier
 * appears with both 3 and 4. Only 4 satisfies the power-of-two static_assert
 * below, so 3 is presumably the removed line - confirm against the real file. */
enum : CmiReductionID {
CmiReductionID_globalOffset = 0, /* Reductions that involve the whole set of processors */
CmiReductionID_requestOffset = 1, /* Reduction IDs that are requested by all the processors (i.e. during initialization) */
CmiReductionID_dynamicOffset = 2, /* Reduction IDs that are requested by only one processor (typically at runtime) */
CmiReductionID_multiplier = 3

CmiReductionID_multiplier = 4
};

/* Compile-time guard: the step between consecutive IDs must be a power of two
 * (see the message for the reason). */
static_assert(CmiIsPow2(CmiReductionID_multiplier),
"CmiReductionID_multiplier must be a power of two because seqID counters may overflow and wrap to 0");

/* Maps a reduction ID to its slot in the reduction tables by keeping only the
 * low CmiLogMaxReductions bits of the ID. */
static inline unsigned int CmiGetRedIndex(CmiReductionID id) {
  const unsigned int lowBitsMask = ~((~0u) << CmiLogMaxReductions);
  return id & lowBitsMask;
}

static CmiReduction* CmiGetReductionCreate(int id, short int numChildren) {
const int idx = id & ~((~0u) << CmiLogMaxReductions);
auto & redref = CpvAccess(_reduce_info)[idx];
auto & redref = CpvAccess(_reduce_info)[CmiGetRedIndex(id)];
CmiReduction *red = redref;
if (red != NULL && red->seqID != id) {
/* The table needs to be expanded */
Expand All @@ -2506,33 +2526,27 @@ static CmiReduction* CmiGetReductionCreate(int id, short int numChildren) {
}

/* Drops the reduction record stored for `id` in the _reduce_info table
 * (reached through CpvAccess): the table slot is reset to nullptr and the
 * record is released with free().
 * NOTE(review): this is a diff rendering, so `redref` is declared twice. The
 * `idx`-based pair looks like the removed version and the CmiGetRedIndex(id)
 * line like its replacement - confirm against the real file. */
static void CmiClearReduction(int id) {
const int idx = id & ~((~0u) << CmiLogMaxReductions);
auto & redref = CpvAccess(_reduce_info)[idx];
auto & redref = CpvAccess(_reduce_info)[CmiGetRedIndex(id)];
auto red = redref;
/* Clear the slot before freeing so the table never holds a dangling pointer. */
redref = nullptr;
free(red);
}

/* NOTE(review): two function versions are interleaved here because this is a
 * diff rendering. CmiGetNextReduction appears to be the removed helper and
 * CmiGetNextReductionID its replacement - confirm against the real file.
 *
 * CmiGetNextReduction: reads _reduce_seqID_global, stores the next value
 * (current + CmiReductionID_multiplier, or CmiReductionID_globalOffset once
 * the current ID exceeds 0xFFF0) and returns the reduction record for the
 * current ID via CmiGetReductionCreate. */
static CmiReduction* CmiGetNextReduction(short int numChildren) {
int id = CpvAccess(_reduce_seqID_global);
int newid = id + CmiReductionID_multiplier;
/* Explicit wrap-around back to the first global ID. */
if (id > 0xFFF0) newid = CmiReductionID_globalOffset;
CpvAccess(_reduce_seqID_global) = newid;
return CmiGetReductionCreate(id, numChildren);
/* CmiGetNextReductionID: returns the current value of _reduce_seqID_global and
 * advances the counter by CmiReductionID_multiplier. */
static CmiReductionID CmiGetNextReductionID(void) {
return CmiReductionIDFetchAdd(CpvAccess(_reduce_seqID_global), CmiReductionID_multiplier);
}

/* Advances the _reduce_seqID_request counter by CmiReductionID_multiplier and
 * returns an ID from it.
 * NOTE(review): two return statements are shown because this is a diff
 * rendering. The `+=` form yields the counter value after the addition, while
 * the CmiReductionIDFetchAdd form yields the value before it. Presumably the
 * first is the removed line and the second the added one - confirm. */
CmiReductionID CmiGetGlobalReduction(void) {
return CpvAccess(_reduce_seqID_request)+=CmiReductionID_multiplier;
return CmiReductionIDFetchAdd(CpvAccess(_reduce_seqID_request), CmiReductionID_multiplier);
}

/* Advances the _reduce_seqID_dynamic counter by CmiReductionID_multiplier and
 * returns an ID from it. Aborts when called on any PE other than 0.
 * NOTE(review): two return statements are shown because this is a diff
 * rendering. The `+=` form yields the value after the addition, the
 * CmiReductionIDFetchAdd form the value before it. Presumably the first is
 * the removed line and the second the added one - confirm. */
CmiReductionID CmiGetDynamicReduction(void) {
if (CmiMyPe() != 0) CmiAbort("Cannot call CmiGetDynamicReduction on processors other than zero!\n");
return CpvAccess(_reduce_seqID_dynamic)+=CmiReductionID_multiplier;
return CmiReductionIDFetchAdd(CpvAccess(_reduce_seqID_dynamic), CmiReductionID_multiplier);
}

static CmiReduction* CmiGetNodeReductionCreate(int id, short int numChildren) {
const int idx = id & ~((~0u) << CmiLogMaxReductions);
auto & redref = CsvAccess(_nodereduce_info)[idx].red;
auto & redref = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)].red;
CmiReduction *red = redref;
if (red != NULL && red->seqID != id) {
/* The table needs to be expanded */
Expand All @@ -2559,28 +2573,23 @@ static CmiReduction* CmiGetNodeReductionCreate(int id, short int numChildren) {
}

/* Drops the node-level reduction record stored for `id`: the `.red` pointer of
 * the matching _nodereduce_info entry is reset to nullptr and the record is
 * released with free(). No locking is done here.
 * NOTE(review): this is a diff rendering, so `redref` is declared twice. The
 * `idx`-based pair looks like the removed version and the CmiGetRedIndex(id)
 * line like its replacement - confirm against the real file. */
static void CmiClearNodeReduction(int id) {
const int idx = id & ~((~0u) << CmiLogMaxReductions);
auto & redref = CsvAccess(_nodereduce_info)[idx].red;
auto & redref = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)].red;
auto red = redref;
/* Clear the slot before freeing so the table never holds a dangling pointer. */
redref = nullptr;
free(red);
}

/* NOTE(review): two function versions are interleaved here because this is a
 * diff rendering. CmiGetNextNodeReduction appears to be the removed helper and
 * CmiGetNextNodeReductionID its replacement - confirm against the real file.
 *
 * CmiGetNextNodeReduction: reads _nodereduce_seqID_global, stores the next
 * value (current + CmiReductionID_multiplier, or CmiReductionID_globalOffset
 * once the current ID exceeds 0xFFF0) and returns the node reduction record
 * for the current ID via CmiGetNodeReductionCreate. */
static CmiReduction* CmiGetNextNodeReduction(short int numChildren) {
int id = CsvAccess(_nodereduce_seqID_global);
int newid = id + CmiReductionID_multiplier;
/* Explicit wrap-around back to the first global ID. */
if (id > 0xFFF0) newid = CmiReductionID_globalOffset;
CsvAccess(_nodereduce_seqID_global) = newid;
return CmiGetNodeReductionCreate(id, numChildren);
/* CmiGetNextNodeReductionID: returns the current value of
 * _nodereduce_seqID_global and advances the counter by
 * CmiReductionID_multiplier through CmiReductionIDFetchAdd. */
static CmiReductionID CmiGetNextNodeReductionID(void) {
return CmiReductionIDFetchAdd(CsvAccess(_nodereduce_seqID_global), CmiReductionID_multiplier);
}

/* Advances the _nodereduce_seqID_request counter by CmiReductionID_multiplier
 * and returns an ID from it.
 * NOTE(review): two return statements are shown because this is a diff
 * rendering. The `+=` form yields the value after the addition, the
 * CmiReductionIDFetchAdd form the value before it. Presumably the first is
 * the removed line and the second the added one - confirm. */
CmiReductionID CmiGetGlobalNodeReduction(void) {
return CsvAccess(_nodereduce_seqID_request)+=CmiReductionID_multiplier;
return CmiReductionIDFetchAdd(CsvAccess(_nodereduce_seqID_request), CmiReductionID_multiplier);
}

/* Advances the _nodereduce_seqID_dynamic counter by CmiReductionID_multiplier
 * and returns an ID from it. Aborts when called on any node other than 0.
 * NOTE(review): two return statements are shown because this is a diff
 * rendering. The `+=` form yields the value after the addition, the
 * CmiReductionIDFetchAdd form the value before it. Presumably the first is
 * the removed line and the second the added one - confirm. */
CmiReductionID CmiGetDynamicNodeReduction(void) {
if (CmiMyNode() != 0) CmiAbort("Cannot call CmiGetDynamicNodeReduction on nodes other than zero!\n");
return CsvAccess(_nodereduce_seqID_dynamic)+=CmiReductionID_multiplier;
return CmiReductionIDFetchAdd(CsvAccess(_nodereduce_seqID_dynamic), CmiReductionID_multiplier);
}

void CmiReductionHandleDynamicRequest(char *msg) {
Expand Down Expand Up @@ -2797,14 +2806,16 @@ static void CmiGlobalNodeReduceStruct(void *data, CmiReducePupFn pupFn,
}

/* Contributes `msg` (of `size` bytes) to the next reduction in the global
 * sequence: obtains the reduction record, sized for this PE's number of
 * spanning-tree children, and passes it with `mergeFn` to CmiGlobalReduce.
 * NOTE(review): `red` is declared twice because this is a diff rendering. The
 * CmiGetNextReduction line looks like the removed version and the
 * CmiGetNextReductionID / CmiGetReductionCreate pair like its replacement -
 * confirm against the real file. */
void CmiReduce(void *msg, int size, CmiReduceMergeFn mergeFn) {
CmiReduction *red = CmiGetNextReduction(CmiNumSpanTreeChildren(CmiMyPe()));
const CmiReductionID id = CmiGetNextReductionID();
CmiReduction *red = CmiGetReductionCreate(id, CmiNumSpanTreeChildren(CmiMyPe()));
CmiGlobalReduce(msg, size, mergeFn, red);
}

/* Struct variant of CmiReduce: contributes `data` to the next reduction in the
 * global sequence, forwarding `pupFn`, `mergeFn`, `dest` and `deleteFn`
 * unchanged to CmiGlobalReduceStruct together with the reduction record.
 * NOTE(review): `red` is declared twice because this is a diff rendering. The
 * CmiGetNextReduction line looks like the removed version and the
 * CmiGetNextReductionID / CmiGetReductionCreate pair like its replacement -
 * confirm against the real file. */
void CmiReduceStruct(void *data, CmiReducePupFn pupFn,
CmiReduceMergeFn mergeFn, CmiHandler dest,
CmiReduceDeleteFn deleteFn) {
CmiReduction *red = CmiGetNextReduction(CmiNumSpanTreeChildren(CmiMyPe()));
const CmiReductionID id = CmiGetNextReductionID();
CmiReduction *red = CmiGetReductionCreate(id, CmiNumSpanTreeChildren(CmiMyPe()));
CmiGlobalReduceStruct(data, pupFn, mergeFn, dest, deleteFn, red);
}

Expand Down Expand Up @@ -2884,27 +2895,65 @@ void CmiGroupReduceStruct(CmiGroup grp, void *data, CmiReducePupFn pupFn,
}

/* Contributes `msg` (of `size` bytes) to the next node-level reduction in the
 * global sequence and passes it with `mergeFn` to CmiGlobalNodeReduce.
 * NOTE(review): `red` is declared twice because this is a diff rendering. The
 * CmiGetNextNodeReduction line looks like the removed version and the rest
 * like its replacement - confirm against the real file. */
void CmiNodeReduce(void *msg, int size, CmiReduceMergeFn mergeFn) {
CmiReduction *red = CmiGetNextNodeReduction(CmiNumNodeSpanTreeChildren(CmiMyNode()));
/* The ID is taken before locking because it selects which slot lock to use. */
const CmiReductionID id = CmiGetNextNodeReductionID();
#if CMK_SMP
/* SMP builds only: hold the slot's lock across lookup/creation and the contribution. */
CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)];
CmiLock(nodered.lock);
#endif

CmiReduction *red = CmiGetNodeReductionCreate(id, CmiNumNodeSpanTreeChildren(CmiMyNode()));
CmiGlobalNodeReduce(msg, size, mergeFn, red);

#if CMK_SMP
CmiUnlock(nodered.lock);
#endif
}

/* Struct variant of CmiNodeReduce: contributes `data` to the next node-level
 * reduction in the global sequence, forwarding `pupFn`, `mergeFn`, `dest` and
 * `deleteFn` unchanged to CmiGlobalNodeReduceStruct.
 * NOTE(review): `red` is declared twice because this is a diff rendering. The
 * CmiGetNextNodeReduction line looks like the removed version and the rest
 * like its replacement - confirm against the real file. */
void CmiNodeReduceStruct(void *data, CmiReducePupFn pupFn,
CmiReduceMergeFn mergeFn, CmiHandler dest,
CmiReduceDeleteFn deleteFn) {
CmiReduction *red = CmiGetNextNodeReduction(CmiNumNodeSpanTreeChildren(CmiMyNode()));
/* The ID is taken before locking because it selects which slot lock to use. */
const CmiReductionID id = CmiGetNextNodeReductionID();
#if CMK_SMP
/* SMP builds only: hold the slot's lock across lookup/creation and the contribution. */
CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)];
CmiLock(nodered.lock);
#endif

CmiReduction *red = CmiGetNodeReductionCreate(id, CmiNumNodeSpanTreeChildren(CmiMyNode()));
CmiGlobalNodeReduceStruct(data, pupFn, mergeFn, dest, deleteFn, red);

#if CMK_SMP
CmiUnlock(nodered.lock);
#endif
}

void CmiNodeReduceID(void *msg, int size, CmiReduceMergeFn mergeFn, CmiReductionID id) {
#if CMK_SMP
CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)];
CmiLock(nodered.lock);
#endif

CmiReduction *red = CmiGetNodeReductionCreate(id, CmiNumNodeSpanTreeChildren(CmiMyNode()));
CmiGlobalNodeReduce(msg, size, mergeFn, red);

#if CMK_SMP
CmiUnlock(nodered.lock);
#endif
}

void CmiNodeReduceStructID(void *data, CmiReducePupFn pupFn,
CmiReduceMergeFn mergeFn, CmiHandler dest,
CmiReduceDeleteFn deleteFn, CmiReductionID id) {
#if CMK_SMP
CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)];
CmiLock(nodered.lock);
#endif

CmiReduction *red = CmiGetNodeReductionCreate(id, CmiNumNodeSpanTreeChildren(CmiMyNode()));
CmiGlobalNodeReduceStruct(data, pupFn, mergeFn, dest, deleteFn, red);

#if CMK_SMP
CmiUnlock(nodered.lock);
#endif
}

static void CmiHandleReductionMessage(void *msg) {
Expand All @@ -2921,8 +2970,7 @@ static void CmiHandleReductionMessage(void *msg) {
static void CmiHandleNodeReductionMessage(void *msg) {
const auto id = CmiGetRedID(msg);
#if CMK_SMP
const int idx = id & ~((~0u) << CmiLogMaxReductions);
CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[idx];
CmiNodeReduction & nodered = CsvAccess(_nodereduce_info)[CmiGetRedIndex(id)];
CmiLock(nodered.lock);
#endif

Expand Down Expand Up @@ -2964,11 +3012,11 @@ void CmiReductionsInit(void) {

if (CmiMyRank() == 0)
{
CsvInitialize(CmiReductionID, _nodereduce_seqID_global);
CsvInitialize(CmiNodeReductionID, _nodereduce_seqID_global);
CsvAccess(_nodereduce_seqID_global) = CmiReductionID_globalOffset;
CsvInitialize(CmiReductionID, _nodereduce_seqID_request);
CsvInitialize(CmiNodeReductionID, _nodereduce_seqID_request);
CsvAccess(_nodereduce_seqID_request) = CmiReductionID_requestOffset;
CsvInitialize(CmiReductionID, _nodereduce_seqID_dynamic);
CsvInitialize(CmiNodeReductionID, _nodereduce_seqID_dynamic);
CsvAccess(_nodereduce_seqID_dynamic) = CmiReductionID_dynamicOffset;
CsvInitialize(CmiNodeReduction *, _nodereduce_info);
auto noderedinfo = (CmiNodeReduction *)malloc(CmiMaxReductions * sizeof(CmiNodeReduction));
Expand Down
3 changes: 3 additions & 0 deletions src/conv-core/converse.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@

#define CMI_MSG_NOKEEP(msg) ((CmiMsgHeaderBasic *)msg)->nokeep

/* Non-zero when at most one bit of v is set, i.e. v is zero or a power of two.
 * The argument is expanded more than once, so it must be free of side effects. */
#define CmiIsPow2OrZero(v) (!((v) & ((v) - 1)))
/* Non-zero only when v is a non-zero power of two. */
#define CmiIsPow2(v) (CmiIsPow2OrZero(v) && ((v) != 0))

#define CMIALIGN(x,n) (size_t)((~((size_t)n-1))&((x)+(n-1)))
/*#define ALIGN8(x) (size_t)((~7)&((x)+7)) */
#define ALIGN8(x) CMIALIGN(x,8)
Expand Down

0 comments on commit 714b731

Please sign in to comment.