Skip to content

Commit

Permalink
ops: add support for user-defined big count ops
Browse files Browse the repository at this point in the history
related to open-mpi#12226 and open-mpi#9194

remove incorrect and misleading comment about ompi_3buff_op_reduce.
See open-mpi#967

Signed-off-by: Howard Pritchard <howardp@lanl.gov>
  • Loading branch information
hppritcha committed Jan 9, 2025
1 parent 0bccfcd commit 5e653ab
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 18 deletions.
3 changes: 3 additions & 0 deletions ompi/mpi/c/op_create.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
* Copyright (c) 2008-2009 Cisco Systems, Inc. All rights reserved.
* Copyright (c) 2015 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2025 Triad National Security, LLC. All rights
* reserved.
* $COPYRIGHT$
*
* Additional copyrights may follow
Expand Down Expand Up @@ -57,6 +59,7 @@ int MPI_Op_create(MPI_User_function * function, int commute, MPI_Op * op)
/* Create and cache the op. Sets a refcount of 1. */

*op = ompi_op_create_user(OPAL_INT_TO_BOOL(commute),
false,
(ompi_op_fortran_handler_fn_t *) function);
if (NULL == *op) {
err = MPI_ERR_INTERN;
Expand Down
6 changes: 5 additions & 1 deletion ompi/op/op.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* Copyright (c) 2015 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2018 FUJITSU LIMITED. All rights reserved.
* Copyright (c) 2018 Triad National Security, LLC. All rights
* Copyright (c) 2018-2025 Triad National Security, LLC. All rights
* reserved.
* $COPYRIGHT$
*
Expand Down Expand Up @@ -353,6 +353,7 @@ static int ompi_op_finalize (void)
* Create a new MPI_Op
*/
ompi_op_t *ompi_op_create_user(bool commute,
bool bigcount,
ompi_op_fortran_handler_fn_t func)
{
ompi_op_t *new_op;
Expand Down Expand Up @@ -382,6 +383,9 @@ ompi_op_t *ompi_op_create_user(bool commute,
if (commute) {
new_op->o_flags |= OMPI_OP_FLAGS_COMMUTE;
}
if(bigcount) {
new_op->o_flags |= OMPI_OP_FLAGS_BIGCOUNT;
}

opal_string_copy(new_op->o_name, "USER OP", sizeof(new_op->o_name));
new_op->o_name[sizeof(new_op->o_name) - 1] = '\0';
Expand Down
78 changes: 61 additions & 17 deletions ompi/op/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
* Copyright (c) 2018 Amazon.com, Inc. or its affiliates. All Rights reserved.
* Copyright (c) 2019 Research Organization for Information Science
* and Technology (RIST). All rights reserved.
* Copyright (c) 2018 Triad National Security, LLC. All rights
* Copyright (c) 2018-2025 Triad National Security, LLC. All rights
* reserved.
* Copyright (c) 2021 IBM Corporation. All rights reserved.
* $COPYRIGHT$
Expand Down Expand Up @@ -61,12 +61,16 @@ BEGIN_C_DECLS
*/
typedef void (ompi_op_c_handler_fn_t)(const void *, void *, int *,
struct ompi_datatype_t **);
typedef void (ompi_op_c_handler_bc_fn_t)(const void *, void *, size_t *,
struct ompi_datatype_t **);

/**
* Typedef for fortran user-defined MPI_Ops.
*/
typedef void (ompi_op_fortran_handler_fn_t)(const void *, void *,
MPI_Fint *, MPI_Fint *);
typedef void (ompi_op_fortran_handler_bc_fn_t)(const void *, void *,
size_t *, MPI_Fint *);

/**
* Typedef for Java op functions intercept (used for user-defined
Expand Down Expand Up @@ -98,8 +102,8 @@ typedef void (ompi_op_java_handler_fn_t)(const void *, void *, int *,
#define OMPI_OP_FLAGS_FLOAT_ASSOC 0x0020
/** Set if the callback function is communative */
#define OMPI_OP_FLAGS_COMMUTE 0x0040


/** Set if the callback function is using bigcount */
#define OMPI_OP_FLAGS_BIGCOUNT 0x0080


/*
Expand Down Expand Up @@ -152,8 +156,12 @@ struct ompi_op_t {
ompi_op_base_op_fns_t intrinsic;
/** C handler function pointer */
ompi_op_c_handler_fn_t *c_fn;
/** C handler function pointer - bigcount*/
ompi_op_c_handler_bc_fn_t *c_fn_bc;
/** Fortran handler function pointer */
ompi_op_fortran_handler_fn_t *fort_fn;
/** Fortran handler function pointer - bigcount*/
ompi_op_fortran_handler_bc_fn_t *fort_fn_bc;
/** Java intercept function data */
struct {
/* The OMPI C++ callback/intercept function */
Expand Down Expand Up @@ -333,6 +341,8 @@ int ompi_op_init(void);
*
* @param commute Boolean indicating whether the operation is
* communative or not
* @param bigcount Boolean indicating whether or not the op is
* using the bigcount (MPI_Count) interface
* @param func Function pointer of the error handler
*
* @returns op Pointer to the ompi_op_t that will be
Expand All @@ -355,6 +365,7 @@ int ompi_op_init(void);
* manually.
*/
ompi_op_t *ompi_op_create_user(bool commute,
bool bigcount,
ompi_op_fortran_handler_fn_t func);

/**
Expand Down Expand Up @@ -512,11 +523,9 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source,
* in iterations of counts <= INT_MAX since it has an `int *len`
* parameter.
*
* Note: When we add BigCount support then we can distinguish between
* a reduction operation with `int *len` and `MPI_Count *len`. At which
* point we can avoid this loop.
*/
if( OPAL_UNLIKELY(full_count > INT_MAX) ) {
if(OPAL_UNLIKELY((full_count > INT_MAX) &&
(0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)))) {
size_t done_count = 0, shift;
int iter_count;
ptrdiff_t ext, lb;
Expand Down Expand Up @@ -578,8 +587,12 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source,
/* User-defined function */
if (0 != (op->o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC)) {
f_dtype = OMPI_INT_2_FINT(dtype->d_f_to_c_index);
f_count = OMPI_INT_2_FINT(count);
op->o_func.fort_fn(source, target, &f_count, &f_dtype);
if (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)) {
f_count = OMPI_INT_2_FINT(count);
op->o_func.fort_fn(source, target, &f_count, &f_dtype);
} else {
op->o_func.fort_fn_bc(source, target, &full_count, &f_dtype);
}
return;
} else if (0 != (op->o_flags & OMPI_OP_FLAGS_JAVA_FUNC)) {
op->o_func.java_data.intercept_fn(source, target, &count, &dtype,
Expand All @@ -588,15 +601,25 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source,
op->o_func.java_data.object);
return;
}
op->o_func.c_fn(source, target, &count, &dtype);
if (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)) {
op->o_func.c_fn(source, target, &count, &dtype);
} else {
op->o_func.c_fn_bc(source, target, &full_count, &dtype);
}
return;
}

static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, void * restrict source2,
void * restrict result, int count, struct ompi_datatype_t *dtype)
void * restrict result, size_t full_count, struct ompi_datatype_t *dtype)
{
ompi_datatype_copy_content_same_ddt (dtype, count, (char*)result, (char*)source1);
op->o_func.c_fn (source2, result, &count, &dtype);
ompi_datatype_copy_content_same_ddt (dtype, full_count, (char*)result, (char*)source1);
if (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)) {
int count = (int)full_count; /* protected by loop in only caller of this function */
assert(full_count <= INT_MAX);
op->o_func.c_fn (source2, result, &count, &dtype);
} else {
op->o_func.c_fn_bc (source2, result, &full_count, &dtype);
}
}

/**
Expand All @@ -618,13 +641,11 @@ static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, v
* with the values in the source buffer and the result is stored in
* the target buffer).
*
* This function will *only* be invoked on intrinsic MPI_Ops.
*
* Otherwise, this function is the same as ompi_op_reduce.
*/
static inline void ompi_3buff_op_reduce(ompi_op_t * op, void *source1,
void *source2, void *target,
int count, ompi_datatype_t * dtype)
size_t full_count, ompi_datatype_t * dtype)
{
void *restrict src1;
void *restrict src2;
Expand All @@ -633,13 +654,36 @@ static inline void ompi_3buff_op_reduce(ompi_op_t * op, void *source1,
src2 = source2;
tgt = target;

if(OPAL_UNLIKELY((full_count > INT_MAX) &&
(0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)))) {
size_t done_count = 0, shift, iter_count;
ptrdiff_t ext, lb;

ompi_datatype_get_extent(dtype, &lb, &ext);

while(done_count < full_count) {
if(done_count + INT_MAX > full_count) {
iter_count = full_count - done_count;
} else {
iter_count = INT_MAX;
}
shift = done_count * ext;
// Recurse one level in iterations of 'int'
ompi_3buff_op_reduce(op, (char*)source1 + shift, (char *)source2 + shift,
(char*)target + shift, iter_count, dtype);
done_count += iter_count;
}
return;
}

if (OPAL_LIKELY(ompi_op_is_intrinsic (op))) {
int count = (int)full_count;
op->o_3buff_intrinsic.fns[ompi_op_ddt_map[dtype->id]](src1, src2,
tgt, &count,
&dtype,
op->o_3buff_intrinsic.modules[ompi_op_ddt_map[dtype->id]]);
} else {
ompi_3buff_op_user (op, src1, src2, tgt, count, dtype);
ompi_3buff_op_user (op, src1, src2, tgt, full_count, dtype);
}
}

Expand Down

0 comments on commit 5e653ab

Please sign in to comment.