Skip to content

Commit

Permalink
mpi: retain operation and datatype in non blocking collectives
Browse files Browse the repository at this point in the history
MPI standard states a user MPI_Op and/or user MPI_Datatype can be free'd
after a call to a non blocking collective and before the non-blocking
collective completes.
Retain user (only) MPI_Op and MPI_Datatype when the non blocking call is
invoked, and set a request callback so they are free'd when the MPI_Request
completes.

Thanks Thomas Ponweiser for reporting this

Fixes #2151
Fixes #1304

Signed-off-by: Gilles Gouaillardet <gilles@rist.or.jp>
  • Loading branch information
ggouaillardet committed Sep 1, 2017
1 parent 59b9602 commit 4fe431b
Show file tree
Hide file tree
Showing 23 changed files with 282 additions and 33 deletions.
160 changes: 160 additions & 0 deletions ompi/mca/coll/base/coll_base_util.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,27 @@
#include "ompi/mca/pml/pml.h"
#include "coll_base_util.h"

struct retain_op_data {
ompi_request_complete_fn_t req_complete_cb;
void *req_complete_cb_data;
ompi_op_t *op;
ompi_datatype_t *datatype;
};

struct retain_datatypes_data {
ompi_request_complete_fn_t req_complete_cb;
void *req_complete_cb_data;
ompi_datatype_t *stype;
ompi_datatype_t *rtype;
};

struct retain_datatypes_w_data {
ompi_request_complete_fn_t req_complete_cb;
void *req_complete_cb_data;
int count;
ompi_datatype_t *types[];
};

int ompi_coll_base_sendrecv_actual( const void* sendbuf, size_t scount,
ompi_datatype_t* sdatatype,
int dest, int stag,
Expand Down Expand Up @@ -78,3 +99,142 @@ int ompi_coll_base_sendrecv_actual( const void* sendbuf, size_t scount,
return (err);
}

static int release_op_callback(struct ompi_request_t *request) {
struct retain_op_data * p = (struct retain_op_data *)request->req_complete_cb_data;
int rc = OMPI_SUCCESS;
assert (NULL != p);
if (NULL != p->req_complete_cb) {
request->req_complete_cb = p->req_complete_cb;
request->req_complete_cb_data = p->req_complete_cb_data;
rc = request->req_complete_cb(request);
}
if (NULL != p->op) {
OBJ_RELEASE(p->op);
}
if (NULL != p->datatype) {
OBJ_RELEASE(p->datatype);
}
free(p);
return rc;
}

int ompi_coll_base_retain_op( ompi_request_t *request, ompi_op_t *op,
ompi_datatype_t *type) {
bool retain = !ompi_op_is_intrinsic(op);
retain |= !ompi_datatype_is_predefined(type);
if (OPAL_UNLIKELY(retain)) {
struct retain_op_data *p = (struct retain_op_data *)calloc(1, sizeof(struct retain_op_data));
if (OPAL_UNLIKELY(NULL == p)) {
return OMPI_ERR_OUT_OF_RESOURCE;
}
if (!ompi_op_is_intrinsic(op)) {
OBJ_RETAIN(op);
p->op = op;
}
if (!ompi_datatype_is_predefined(type)) {
OBJ_RETAIN(type);
p->datatype = type;
}
p->req_complete_cb = request->req_complete_cb;
p->req_complete_cb_data = request->req_complete_cb_data;
request->req_complete_cb = release_op_callback;
request->req_complete_cb_data = p;
}
return OMPI_SUCCESS;
}

static int release_datatypes_callback(struct ompi_request_t *request) {
struct retain_datatypes_data * p = (struct retain_datatypes_data *)request->req_complete_cb_data;
int rc = OMPI_SUCCESS;
assert (NULL != p);
if (NULL != p->req_complete_cb) {
request->req_complete_cb = p->req_complete_cb;
request->req_complete_cb_data = p->req_complete_cb_data;
rc = request->req_complete_cb(request);
}
if (NULL != p->stype) {
OBJ_RELEASE(p->stype);
}
if (NULL != p->rtype) {
OBJ_RELEASE(p->rtype);
}
free(p);
return rc;
}

int ompi_coll_base_retain_datatypes( ompi_request_t *request, ompi_datatype_t *stype,
ompi_datatype_t *rtype) {
bool retain = NULL != stype && !ompi_datatype_is_predefined(stype);
retain |= NULL != rtype && !ompi_datatype_is_predefined(rtype);
if (OPAL_UNLIKELY(retain)) {
struct retain_datatypes_data *p = (struct retain_datatypes_data *)calloc(1, sizeof(struct retain_datatypes_data));
if (OPAL_UNLIKELY(NULL == p)) {
return OMPI_ERR_OUT_OF_RESOURCE;
}
if (NULL != stype && !ompi_datatype_is_predefined(stype)) {
OBJ_RETAIN(stype);
p->stype = stype;
}
if (NULL != rtype && !ompi_datatype_is_predefined(rtype)) {
OBJ_RETAIN(rtype);
p->rtype = rtype;
}
p->req_complete_cb = request->req_complete_cb;
p->req_complete_cb_data = request->req_complete_cb_data;
request->req_complete_cb = release_datatypes_callback;
request->req_complete_cb_data = p;
}
return OMPI_SUCCESS;
}

static int release_datatypes_w_callback(struct ompi_request_t *request) {
struct retain_datatypes_w_data * p = (struct retain_datatypes_w_data *)request->req_complete_cb_data;
int rc = OMPI_SUCCESS;
assert (NULL != p);
if (NULL != p->req_complete_cb) {
request->req_complete_cb = p->req_complete_cb;
request->req_complete_cb_data = p->req_complete_cb_data;
rc = request->req_complete_cb(request);
}
for (int i=0; i<p->count; i++) {
OBJ_RELEASE(p->types[i]);
}
free(p);
return rc;
}

int ompi_coll_base_retain_datatypes_w( ompi_request_t *request, int count,
ompi_datatype_t *const stypes[], ompi_datatype_t *const rtypes[]) {
int datatypes = 0;
for (int i=0; i<count; i++) {
if (NULL != stypes[i] && !ompi_datatype_is_predefined(stypes[i])) {
datatypes++;
}
if (NULL != rtypes[i] && !ompi_datatype_is_predefined(rtypes[i])) {
datatypes++;
}
}
if (OPAL_UNLIKELY(0 < datatypes)) {
struct retain_datatypes_w_data *p = (struct retain_datatypes_w_data *)calloc(1, sizeof(struct retain_datatypes_data)+(datatypes-1)*sizeof(ompi_datatype_t *));
if (OPAL_UNLIKELY(NULL == p)) {
return OMPI_ERR_OUT_OF_RESOURCE;
}
datatypes = 0;
for (int i=0; i<count; i++) {
if (NULL != stypes[i] && !ompi_datatype_is_predefined(stypes[i])) {
p->types[datatypes++] = stypes[i];
OBJ_RETAIN(stypes[i]);
}
if (NULL != rtypes[i] && !ompi_datatype_is_predefined(rtypes[i])) {
p->types[datatypes++] = rtypes[i];
OBJ_RETAIN(rtypes[i]);
}
}
p->req_complete_cb = request->req_complete_cb;
p->req_complete_cb_data = request->req_complete_cb_data;
request->req_complete_cb = release_datatypes_w_callback;
request->req_complete_cb_data = p;
}
return OMPI_SUCCESS;
}

10 changes: 10 additions & 0 deletions ompi/mca/coll/base/coll_base_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "ompi/mca/mca.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/request/request.h"
#include "ompi/op/op.h"
#include "ompi/mca/pml/pml.h"

BEGIN_C_DECLS
Expand Down Expand Up @@ -70,5 +71,14 @@ ompi_coll_base_sendrecv( void* sendbuf, size_t scount, ompi_datatype_t* sdatatyp
source, rtag, comm, status);
}

int ompi_coll_base_retain_op( ompi_request_t *request, ompi_op_t *op,
ompi_datatype_t *type);

int ompi_coll_base_retain_datatypes( ompi_request_t *request, ompi_datatype_t *stype,
ompi_datatype_t *rtype);

int ompi_coll_base_retain_datatypes_w( ompi_request_t *request, int count,
ompi_datatype_t *const stypes[],
ompi_datatype_t *const rtypes[]);
END_C_DECLS
#endif /* MCA_COLL_BASE_UTIL_EXPORT_H */
6 changes: 5 additions & 1 deletion ompi/mpi/c/iallgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* Copyright (c) 2012 Oak Ridge National Laboratory. All rights reserved.
* Copyright (c) 2013 Los Alamos National Security, LLC. All rights
* reserved.
* Copyright (c) 2015 Research Organization for Information Science
* Copyright (c) 2015-2017 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* $COPYRIGHT$
*
Expand All @@ -31,6 +31,7 @@
#include "ompi/communicator/communicator.h"
#include "ompi/errhandler/errhandler.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/mca/coll/base/coll_base_util.h"
#include "ompi/memchecker.h"

#if OMPI_BUILD_MPI_PROFILING
Expand Down Expand Up @@ -99,6 +100,9 @@ int MPI_Iallgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
err = comm->c_coll->coll_iallgather(sendbuf, sendcount, sendtype,
recvbuf, recvcount, recvtype, comm,
request, comm->c_coll->coll_iallgather_module);
if (OPAL_LIKELY(OMPI_SUCCESS == err)) {
ompi_coll_base_retain_datatypes(*request, sendtype, recvtype);
}

OMPI_ERRHANDLER_RETURN(err, comm, err, FUNC_NAME);
}
6 changes: 5 additions & 1 deletion ompi/mpi/c/iallgatherv.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* Copyright (c) 2012 Cisco Systems, Inc. All rights reserved.
* Copyright (c) 2012-2013 Los Alamos National Security, LLC. All rights
* reserved.
* Copyright (c) 2015 Research Organization for Information Science
* Copyright (c) 2015-2017 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* $COPYRIGHT$
*
Expand All @@ -31,6 +31,7 @@
#include "ompi/communicator/communicator.h"
#include "ompi/errhandler/errhandler.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/mca/coll/base/coll_base_util.h"
#include "ompi/memchecker.h"

#if OMPI_BUILD_MPI_PROFILING
Expand Down Expand Up @@ -123,6 +124,9 @@ int MPI_Iallgatherv(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
recvbuf, recvcounts, displs,
recvtype, comm, request,
comm->c_coll->coll_iallgatherv_module);
if (OPAL_LIKELY(OMPI_SUCCESS == err)) {
ompi_coll_base_retain_datatypes(*request, sendtype, recvtype);
}
OMPI_ERRHANDLER_RETURN(err, comm, err, FUNC_NAME);
}

8 changes: 5 additions & 3 deletions ompi/mpi/c/iallreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
* All rights reserved.
* Copyright (c) 2013 Los Alamos National Security, LLC. All rights
* reserved.
* Copyright (c) 2015 Research Organization for Information Science
* Copyright (c) 2015-2017 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2016 IBM Corporation. All rights reserved.
* $COPYRIGHT$
Expand All @@ -31,6 +31,7 @@
#include "ompi/errhandler/errhandler.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/op/op.h"
#include "ompi/mca/coll/base/coll_base_util.h"
#include "ompi/memchecker.h"

#if OMPI_BUILD_MPI_PROFILING
Expand Down Expand Up @@ -109,10 +110,11 @@ int MPI_Iallreduce(const void *sendbuf, void *recvbuf, int count,

/* Invoke the coll component to perform the back-end operation */

OBJ_RETAIN(op);
err = comm->c_coll->coll_iallreduce(sendbuf, recvbuf, count, datatype,
op, comm, request, comm->c_coll->coll_iallreduce_module);
OBJ_RELEASE(op);
if (OPAL_LIKELY(OMPI_SUCCESS == err)) {
ompi_coll_base_retain_op(*request, op, datatype);
}
OMPI_ERRHANDLER_RETURN(err, comm, err, FUNC_NAME);
}

6 changes: 5 additions & 1 deletion ompi/mpi/c/ialltoall.c
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* Copyright (c) 2012 Oak Ridge National Laboratory. All rights reserved.
* Copyright (c) 2013 Los Alamos National Security, LLC. All rights
* reserved.
* Copyright (c) 2014-2016 Research Organization for Information Science
* Copyright (c) 2014-2017 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* $COPYRIGHT$
*
Expand All @@ -31,6 +31,7 @@
#include "ompi/communicator/communicator.h"
#include "ompi/errhandler/errhandler.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/mca/coll/base/coll_base_util.h"
#include "ompi/memchecker.h"

#if OMPI_BUILD_MPI_PROFILING
Expand Down Expand Up @@ -98,5 +99,8 @@ int MPI_Ialltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
err = comm->c_coll->coll_ialltoall(sendbuf, sendcount, sendtype,
recvbuf, recvcount, recvtype, comm,
request, comm->c_coll->coll_ialltoall_module);
if (OPAL_LIKELY(OMPI_SUCCESS == err)) {
ompi_coll_base_retain_datatypes(*request, sendtype, recvtype);
}
OMPI_ERRHANDLER_RETURN(err, comm, err, FUNC_NAME);
}
6 changes: 5 additions & 1 deletion ompi/mpi/c/ialltoallv.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* Copyright (c) 2007 Cisco Systems, Inc. All rights reserved.
* Copyright (c) 2012-2013 Los Alamos National Security, LLC. All rights
* reserved.
* Copyright (c) 2014-2016 Research Organization for Information Science
* Copyright (c) 2014-2017 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* $COPYRIGHT$
*
Expand All @@ -30,6 +30,7 @@
#include "ompi/communicator/communicator.h"
#include "ompi/errhandler/errhandler.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/mca/coll/base/coll_base_util.h"
#include "ompi/memchecker.h"

#if OMPI_BUILD_MPI_PROFILING
Expand Down Expand Up @@ -127,6 +128,9 @@ int MPI_Ialltoallv(const void *sendbuf, const int sendcounts[], const int sdispl
err = comm->c_coll->coll_ialltoallv(sendbuf, sendcounts, sdispls,
sendtype, recvbuf, recvcounts, rdispls,
recvtype, comm, request, comm->c_coll->coll_ialltoallv_module);
if (OPAL_LIKELY(OMPI_SUCCESS == err)) {
ompi_coll_base_retain_datatypes(*request, sendtype, recvtype);
}
OMPI_ERRHANDLER_RETURN(err, comm, err, FUNC_NAME);
}

9 changes: 8 additions & 1 deletion ompi/mpi/c/ialltoallw.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* Copyright (c) 2007 Cisco Systems, Inc. All rights reserved.
* Copyright (c) 2012-2013 Los Alamos National Security, LLC. All rights
* reserved.
* Copyright (c) 2014-2016 Research Organization for Information Science
* Copyright (c) 2014-2017 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* $COPYRIGHT$
*
Expand All @@ -30,6 +30,7 @@
#include "ompi/communicator/communicator.h"
#include "ompi/errhandler/errhandler.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/mca/coll/base/coll_base_util.h"
#include "ompi/memchecker.h"

#if OMPI_BUILD_MPI_PROFILING
Expand Down Expand Up @@ -124,6 +125,12 @@ int MPI_Ialltoallw(const void *sendbuf, const int sendcounts[], const int sdispl
sendtypes, recvbuf, recvcounts,
rdispls, recvtypes, comm, request,
comm->c_coll->coll_ialltoallw_module);
if (OPAL_LIKELY(OMPI_SUCCESS == err)) {
ompi_coll_base_retain_datatypes_w(*request,
OMPI_COMM_IS_INTER(comm)?ompi_comm_remote_size(comm):ompi_comm_size(comm),
sendtypes,
recvtypes);
}
OMPI_ERRHANDLER_RETURN(err, comm, err, FUNC_NAME);
}

6 changes: 5 additions & 1 deletion ompi/mpi/c/ibcast.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
* Copyright (c) 2012 Oak Rigde National Laboratory. All rights reserved.
* Copyright (c) 2015 Research Organization for Information Science
* Copyright (c) 2015-2017 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2017 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
Expand All @@ -19,6 +19,7 @@
#include "ompi/communicator/communicator.h"
#include "ompi/errhandler/errhandler.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/mca/coll/base/coll_base_util.h"
#include "ompi/memchecker.h"

#if OMPI_BUILD_MPI_PROFILING
Expand Down Expand Up @@ -83,5 +84,8 @@ int MPI_Ibcast(void *buffer, int count, MPI_Datatype datatype,
err = comm->c_coll->coll_ibcast(buffer, count, datatype, root, comm,
request,
comm->c_coll->coll_ibcast_module);
if (OPAL_LIKELY(OMPI_SUCCESS == err)) {
ompi_coll_base_retain_datatypes(*request, datatype, NULL);
}
OMPI_ERRHANDLER_RETURN(err, comm, err, FUNC_NAME);
}
Loading

0 comments on commit 4fe431b

Please sign in to comment.