diff --git a/ompi/mpi/c/op_create.c b/ompi/mpi/c/op_create.c index 28e00222ad9..1fafbea899a 100644 --- a/ompi/mpi/c/op_create.c +++ b/ompi/mpi/c/op_create.c @@ -12,6 +12,8 @@ * Copyright (c) 2008-2009 Cisco Systems, Inc. All rights reserved. * Copyright (c) 2015 Research Organization for Information Science * and Technology (RIST). All rights reserved. + * Copyright (c) 2025 Triad National Security, LLC. All rights + * reserved. * $COPYRIGHT$ * * Additional copyrights may follow @@ -57,6 +59,7 @@ int MPI_Op_create(MPI_User_function * function, int commute, MPI_Op * op) /* Create and cache the op. Sets a refcount of 1. */ *op = ompi_op_create_user(OPAL_INT_TO_BOOL(commute), + false, (ompi_op_fortran_handler_fn_t *) function); if (NULL == *op) { err = MPI_ERR_INTERN; diff --git a/ompi/op/op.c b/ompi/op/op.c index 3977fa8b97b..c800dc0a1cb 100644 --- a/ompi/op/op.c +++ b/ompi/op/op.c @@ -17,7 +17,7 @@ * Copyright (c) 2015 Research Organization for Information Science * and Technology (RIST). All rights reserved. * Copyright (c) 2018 FUJITSU LIMITED. All rights reserved. - * Copyright (c) 2018 Triad National Security, LLC. All rights + * Copyright (c) 2018-2025 Triad National Security, LLC. All rights * reserved. * $COPYRIGHT$ * @@ -353,6 +353,7 @@ static int ompi_op_finalize (void) * Create a new MPI_Op */ ompi_op_t *ompi_op_create_user(bool commute, + bool bigcount, ompi_op_fortran_handler_fn_t func) { ompi_op_t *new_op; @@ -382,6 +383,9 @@ ompi_op_t *ompi_op_create_user(bool commute, if (commute) { new_op->o_flags |= OMPI_OP_FLAGS_COMMUTE; } + if(bigcount) { + new_op->o_flags |= OMPI_OP_FLAGS_BIGCOUNT; + } opal_string_copy(new_op->o_name, "USER OP", sizeof(new_op->o_name)); new_op->o_name[sizeof(new_op->o_name) - 1] = '\0'; diff --git a/ompi/op/op.h b/ompi/op/op.h index f3cf5b53636..ffa48f4f15b 100644 --- a/ompi/op/op.h +++ b/ompi/op/op.h @@ -18,7 +18,7 @@ * Copyright (c) 2018 Amazon.com, Inc. or its affiliates. All Rights reserved. * Copyright (c) 2019 Research Organization for Information Science * and Technology (RIST). All rights reserved. - * Copyright (c) 2018 Triad National Security, LLC. All rights + * Copyright (c) 2018-2025 Triad National Security, LLC. All rights * reserved. * Copyright (c) 2021 IBM Corporation. All rights reserved. * $COPYRIGHT$ @@ -61,12 +61,16 @@ BEGIN_C_DECLS */ typedef void (ompi_op_c_handler_fn_t)(const void *, void *, int *, struct ompi_datatype_t **); +typedef void (ompi_op_c_handler_bc_fn_t)(const void *, void *, size_t *, + struct ompi_datatype_t **); /** * Typedef for fortran user-defined MPI_Ops. */ typedef void (ompi_op_fortran_handler_fn_t)(const void *, void *, MPI_Fint *, MPI_Fint *); +typedef void (ompi_op_fortran_handler_bc_fn_t)(const void *, void *, + size_t *, MPI_Fint *); /** * Typedef for Java op functions intercept (used for user-defined @@ -98,8 +102,8 @@ typedef void (ompi_op_java_handler_fn_t)(const void *, void *, int *, #define OMPI_OP_FLAGS_FLOAT_ASSOC 0x0020 /** Set if the callback function is communative */ #define OMPI_OP_FLAGS_COMMUTE 0x0040 - - +/** Set if the callback function is using bigcount */ +#define OMPI_OP_FLAGS_BIGCOUNT 0x0080 /* @@ -152,8 +156,12 @@ struct ompi_op_t { ompi_op_base_op_fns_t intrinsic; /** C handler function pointer */ ompi_op_c_handler_fn_t *c_fn; + /** C handler function pointer - bigcount*/ + ompi_op_c_handler_bc_fn_t *c_fn_bc; /** Fortran handler function pointer */ ompi_op_fortran_handler_fn_t *fort_fn; + /** Fortran handler function pointer - bigcount*/ + ompi_op_fortran_handler_bc_fn_t *fort_fn_bc; /** Java intercept function data */ struct { /* The OMPI C++ callback/intercept function */ @@ -333,6 +341,8 @@ int ompi_op_init(void); * * @param commute Boolean indicating whether the operation is * communative or not + * @param bigcount Boolean indicating whether or not the op is + * using the bigcount (MPI_Count) interface * @param func Function pointer of the error handler * * @returns op Pointer to the ompi_op_t that will be @@ -355,6 +365,7 @@ int ompi_op_init(void); * manually. */ ompi_op_t *ompi_op_create_user(bool commute, + bool bigcount, ompi_op_fortran_handler_fn_t func); /** @@ -512,11 +523,9 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source, * in iterations of counts <= INT_MAX since it has an `int *len` * parameter. * - * Note: When we add BigCount support then we can distinguish between - * a reduction operation with `int *len` and `MPI_Count *len`. At which - * point we can avoid this loop. */ - if( OPAL_UNLIKELY(full_count > INT_MAX) ) { + if(OPAL_UNLIKELY((full_count > INT_MAX) && + (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)))) { size_t done_count = 0, shift; int iter_count; ptrdiff_t ext, lb; @@ -578,8 +587,12 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source, /* User-defined function */ if (0 != (op->o_flags & OMPI_OP_FLAGS_FORTRAN_FUNC)) { f_dtype = OMPI_INT_2_FINT(dtype->d_f_to_c_index); - f_count = OMPI_INT_2_FINT(count); - op->o_func.fort_fn(source, target, &f_count, &f_dtype); + if (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)) { + f_count = OMPI_INT_2_FINT(count); + op->o_func.fort_fn(source, target, &f_count, &f_dtype); + } else { + op->o_func.fort_fn_bc(source, target, &full_count, &f_dtype); + } return; } else if (0 != (op->o_flags & OMPI_OP_FLAGS_JAVA_FUNC)) { op->o_func.java_data.intercept_fn(source, target, &count, &dtype, @@ -588,15 +601,25 @@ static inline void ompi_op_reduce(ompi_op_t * op, const void *source, op->o_func.java_data.object); return; } - op->o_func.c_fn(source, target, &count, &dtype); + if (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)) { + op->o_func.c_fn(source, target, &count, &dtype); + } else { + op->o_func.c_fn_bc(source, target, &full_count, &dtype); + } return; } static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, void * restrict source2, - void * restrict result, int count, struct ompi_datatype_t *dtype) + void * restrict result, size_t full_count, struct ompi_datatype_t *dtype) { - ompi_datatype_copy_content_same_ddt (dtype, count, (char*)result, (char*)source1); - op->o_func.c_fn (source2, result, &count, &dtype); + ompi_datatype_copy_content_same_ddt (dtype, full_count, (char*)result, (char*)source1); + if (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)) { + int count = (int)full_count; /* protected by loop in only caller of this function */ + assert(full_count <= INT_MAX); + op->o_func.c_fn (source2, result, &count, &dtype); + } else { + op->o_func.c_fn_bc (source2, result, &full_count, &dtype); + } } /** @@ -618,13 +641,11 @@ static inline void ompi_3buff_op_user (ompi_op_t *op, void * restrict source1, v * with the values in the source buffer and the result is stored in * the target buffer). * - * This function will *only* be invoked on intrinsic MPI_Ops. - * * Otherwise, this function is the same as ompi_op_reduce. */ static inline void ompi_3buff_op_reduce(ompi_op_t * op, void *source1, void *source2, void *target, - int count, ompi_datatype_t * dtype) + size_t full_count, ompi_datatype_t * dtype) { void *restrict src1; void *restrict src2; @@ -633,13 +654,36 @@ static inline void ompi_3buff_op_reduce(ompi_op_t * op, void *source1, src2 = source2; tgt = target; + if(OPAL_UNLIKELY((full_count > INT_MAX) && + (0 == (op->o_flags & OMPI_OP_FLAGS_BIGCOUNT)))) { + size_t done_count = 0, shift, iter_count; + ptrdiff_t ext, lb; + + ompi_datatype_get_extent(dtype, &lb, &ext); + + while(done_count < full_count) { + if(done_count + INT_MAX > full_count) { + iter_count = full_count - done_count; + } else { + iter_count = INT_MAX; + } + shift = done_count * ext; + // Recurse one level in iterations of 'int' + ompi_3buff_op_reduce(op, (char*)source1 + shift, (char *)source2 + shift, + (char*)target + shift, iter_count, dtype); + done_count += iter_count; + } + return; + } + if (OPAL_LIKELY(ompi_op_is_intrinsic (op))) { + int count = (int)full_count; op->o_3buff_intrinsic.fns[ompi_op_ddt_map[dtype->id]](src1, src2, tgt, &count, &dtype, op->o_3buff_intrinsic.modules[ompi_op_ddt_map[dtype->id]]); } else { - ompi_3buff_op_user (op, src1, src2, tgt, count, dtype); + ompi_3buff_op_user (op, src1, src2, tgt, full_count, dtype); } }