diff --git a/LICENSE b/LICENSE index 5937b116a..7320037d2 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2022 NVIDIA CORPORATION + Copyright 2023 NVIDIA CORPORATION Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -198,4 +198,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file + limitations under the License. diff --git a/README.md b/README.md index 606f37f78..672e0ab2d 100644 --- a/README.md +++ b/README.md @@ -6,3 +6,19 @@ WholeMemory is a Tensor like storage and provide multi-GPU support. It is optimized for NVLink systems like DGX A100 servers. By working together with cuGraph, cuGraph-Ops, cuGraph-DGL, cuGraph-PyG, and upstream DGL and PyG, it will be easy to build GNN applications. + +## Table of content +- Installation + - [Getting WholeGraph Packages](./docs/wholegraph/source/installation/getting_wholegraph.md) + - [Building from Source](./docs/wholegraph/source/installation/source_build.md) +- General + - [WholeGraph Introduction](./docs/wholegraph/source/basics/wholegraph_intro.md) +- Packages + - libwholegraph (C/CUDA) + - pylibwholegraph +- API Docs + - Python + - C +- Reference + - [RAPIDS](https://rapids.ai) + - [cuGraph](https://github.com/rapidsai/cugraph) diff --git a/build.sh b/build.sh index 209023761..d7abdd813 100755 --- a/build.sh +++ b/build.sh @@ -29,6 +29,7 @@ VALIDARGS=" -v -g -n + --allgpuarch --native --cmake-args --compile-cmd @@ -81,7 +82,7 @@ CMAKE_VERBOSE_OPTION="" BUILD_TYPE=Release BUILD_ALL_GPU_ARCH=0 INSTALL_TARGET="--target install" -PYTHON="python" +PYTHON=${PYTHON:-python} # Set defaults for vars that may not have been defined externally # FIXME: if INSTALL_PREFIX is not set, check PREFIX, then check @@ -299,4 +300,4 @@ if hasArg docs; then cmake --build "${LIBWHOLEGRAPH_BUILD_DIR}" -j${PARALLEL_LEVEL} --target docs_wholegraph ${VERBOSE_FLAG} cd ${REPODIR}/docs/wholegraph make html -fi \ No newline at end of file +fi diff --git a/ci/test_clang_tidy.sh b/ci/test_clang_tidy.sh index aa0e9a494..ca7efc96e 100644 --- a/ci/test_clang_tidy.sh +++ b/ci/test_clang_tidy.sh @@ -28,12 +28,12 @@ env PATH=${PATH}:/usr/local/cuda/bin # library in the second run. CMAKE_EXTRA_ARGS="--cmake-args=\"-DBUILD_OPS_WITH_TORCH_C10_API=OFF\"" rapids-logger "Generate compilation databases for C++ library and tests" -./build.sh clean libwholegraph tests pylibwholegraph --compile-cmd ${CMAKE_EXTRA_ARGS} +./build.sh clean libwholegraph tests pylibwholegraph --allgpuarch --compile-cmd ${CMAKE_EXTRA_ARGS} # -git_modified_only -v rapids-logger "Run clang-tidy" python scripts/checks/run-clang-tidy.py \ -ignore wholememory_binding \ - build/compile_commands.json \ - pylibwholegraph/_skbuild/build/compile_commands.json \ + cpp/build/compile_commands.json \ + python/pylibwholegraph/_skbuild/build/compile_commands.json \ -v diff --git a/ci/test_python.sh b/ci/test_python.sh index 79ea14ecf..de66213d3 100644 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -50,7 +50,7 @@ PYTEST_PATH=${PYLIBWHOLEGRAPH_INSTALL_PATH}/tests pytest \ --cache-clear \ --forked \ - ${PYTEST_PATH}/pylibwholegraph/ ${PYTEST_PATH}/wholegraph_torch/ops/test_wholegraph_gather_scatter.py + ${PYTEST_PATH} echo "test_python is exiting with value: ${EXITCODE}" exit ${EXITCODE} diff --git a/conda/recipes/libwholegraph/build.sh b/conda/recipes/libwholegraph/build.sh index dc1dd59d9..86672e13c 100644 --- a/conda/recipes/libwholegraph/build.sh +++ b/conda/recipes/libwholegraph/build.sh @@ -1,4 +1,4 @@ #!/usr/bin/env bash # Copyright (c) 2021-2023, NVIDIA CORPORATION. -./build.sh -n libwholegraph tests -v +./build.sh -n libwholegraph tests -v --allgpuarch diff --git a/conda/recipes/libwholegraph/install_libwholegraph.sh b/conda/recipes/libwholegraph/install_libwholegraph.sh index dbabc2796..649b35621 100644 --- a/conda/recipes/libwholegraph/install_libwholegraph.sh +++ b/conda/recipes/libwholegraph/install_libwholegraph.sh @@ -1,4 +1,4 @@ #!/bin/bash # Copyright (c) 2022-2023, NVIDIA CORPORATION. -cmake --install build +cmake --install cpp/build diff --git a/conda/recipes/libwholegraph/install_libwholegraph_tests.sh b/conda/recipes/libwholegraph/install_libwholegraph_tests.sh index 48ec4242a..0522596af 100644 --- a/conda/recipes/libwholegraph/install_libwholegraph_tests.sh +++ b/conda/recipes/libwholegraph/install_libwholegraph_tests.sh @@ -1,4 +1,4 @@ #!/bin/bash # Copyright (c) 2022-2023, NVIDIA CORPORATION. -cmake --install build --component testing +cmake --install cpp/build --component testing diff --git a/conda/recipes/pylibwholegraph/build.sh b/conda/recipes/pylibwholegraph/build.sh index 645923b86..c7cc18764 100644 --- a/conda/recipes/pylibwholegraph/build.sh +++ b/conda/recipes/pylibwholegraph/build.sh @@ -3,4 +3,4 @@ CMAKE_EXTRA_ARGS="--cmake-args=\"-DBUILD_OPS_WITH_TORCH_C10_API=OFF\"" -./build.sh pylibwholegraph -v ${CMAKE_EXTRA_ARGS} +./build.sh pylibwholegraph --allgpuarch -v ${CMAKE_EXTRA_ARGS} diff --git a/cpp/cmake/thirdparty/nanobind.cmake b/cpp/cmake/thirdparty/nanobind.cmake deleted file mode 100644 index 33ca2dd4e..000000000 --- a/cpp/cmake/thirdparty/nanobind.cmake +++ /dev/null @@ -1,58 +0,0 @@ -#============================================================================= -# Copyright (c) 2021-2022, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#============================================================================= - -if (NOT SKBUILD) - message(WARNING "WHOLEGRAPH: This CMake file should be executed via scikit-build.") -endif() - -if (SKBUILD) - # Constrain FindPython to find the Python version used by scikit-build - set(Python_VERSION "${PYTHON_VERSION_STRING}") - set(Python_EXECUTABLE "${PYTHON_EXECUTABLE}") - set(Python_INCLUDE_DIR "${PYTHON_INCLUDE_DIR}") - set(Python_LIBRARIES "${PYTHON_LIBRARY}") -elseif (MSVC) - # MSVC needs a little extra help finding the Python library - find_package(PythonInterp) - find_package(Python) -endif() - -find_package(Python COMPONENTS Interpreter Development REQUIRED) - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(DEFAULT_CXX_FLAGS "") - -# reset the default flags if we have DEBUG_CXXFLAGS -if(CMAKE_BUILD_TYPE MATCHES Debug) - if (DEFINED ENV{DEBUG_CXXFLAGS}) - set(DEFAULT_CXX_FLAGS "$ENV{DEBUG_CXXFLAGS}") - separate_arguments(DEFAULT_CXX_FLAGS) - add_compile_options(${DEFAULT_CXX_FLAGS}) - endif() -endif() - -execute_process( - COMMAND - "${Python_EXECUTABLE}" -c "import nanobind; print(nanobind.cmake_dir())" - OUTPUT_VARIABLE _tmp_dir - OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ECHO STDOUT -) -message(STATUS "WHOLEGRAPH: nanobind dir='${_tmp_dir}'") -list(APPEND CMAKE_PREFIX_PATH "${_tmp_dir}") - -# Now import nanobind from CMake -find_package(nanobind CONFIG REQUIRED) diff --git a/cpp/include/wholememory/wholegraph_op.h b/cpp/include/wholememory/wholegraph_op.h index bbc79664b..b6a65bd20 100644 --- a/cpp/include/wholememory/wholegraph_op.h +++ b/cpp/include/wholememory/wholegraph_op.h @@ -87,26 +87,22 @@ wholememory_error_code_t wholegraph_csr_weighted_sample_without_replacement( * raft_pcg_generator_random_int cpu op * @param random_seed : random seed * @param subsequence : subsequence for generating random value - * @param output : Wholememory Tensor of output + * @param output : Wholememory Tensor of output * @return : wholememory_error_code_t */ -wholememory_error_code_t generate_random_positive_int_cpu( - int64_t random_seed, - int64_t subsequence, - wholememory_tensor_t output -); +wholememory_error_code_t generate_random_positive_int_cpu(int64_t random_seed, + int64_t subsequence, + wholememory_tensor_t output); /** * raft_pcg_generator_random_float cpu op * @param random_seed : random seed * @param subsequence : subsequence for generating random value - * @param output : Wholememory Tensor of output + * @param output : Wholememory Tensor of output * @return : wholememory_error_code_t */ wholememory_error_code_t generate_exponential_distribution_negative_float_cpu( - int64_t random_seed, - int64_t subsequence, - wholememory_tensor_t output); + int64_t random_seed, int64_t subsequence, wholememory_tensor_t output); #ifdef __cplusplus } diff --git a/cpp/include/wholememory/wholememory.h b/cpp/include/wholememory/wholememory.h index df8b04a6e..2a12f761a 100644 --- a/cpp/include/wholememory/wholememory.h +++ b/cpp/include/wholememory/wholememory.h @@ -59,8 +59,17 @@ enum wholememory_memory_location_t { WHOLEMEMORY_ML_HOST, }; +/** + * Initialize WholeMemory library + * @param flags : reserved should be 0 + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_init(unsigned int flags); +/** + * Finalize WholeMemory library + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_finalize(); /* Opaque handle to communicator */ @@ -71,23 +80,68 @@ struct wholememory_unique_id_t { char internal[WHOLEMEMORY_UNIQUE_ID_BYTES]; }; +/** + * Create UniqueID for WholeMemory Communicator + * @param unique_id : returned UniqueID + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_create_unique_id(wholememory_unique_id_t* unique_id); +/** + * Create WholeMemory Communicator + * @param comm : returned WholeMemory Communicator + * @param unique_id : UniqueID + * @param rank : rank of this process. + * @param size : number of processes in this Communicator + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_create_communicator(wholememory_comm_t* comm, wholememory_unique_id_t unique_id, int rank, int size); +/** + * Destroy WholeMemory Communicator + * @param comm : WholeMemory Communicator to destroy + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_destroy_communicator(wholememory_comm_t comm); +/** + * Get the rank of current process in the WholeMemory Communicator + * @param rank : returned rank + * @param comm : WholeMemory Communicator + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_communicator_get_rank(int* rank, wholememory_comm_t comm); +/** + * Get the size of WholeMemory Communicator + * @param size : returned size + * @param comm : WholeMemory Communicator + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_communicator_get_size(int* size, wholememory_comm_t comm); +/** + * Barrier on WholeMemory Communicator + * @param comm : WholeMemory Communicator + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_communicator_barrier(wholememory_comm_t comm); typedef struct wholememory_handle_* wholememory_handle_t; +/** + * Malloc WholeMemory + * @param wholememory_handle_ptr : returned WholeMemory Handle + * @param total_size : total allocated size in bytes. + * @param comm : WholeMemory Communicator + * @param memory_type : WholeMemory type + * @param memory_location : memory location, host or device + * @param data_granularity : granularity size of data, which is guaranteed not to be partitioned. + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_malloc(wholememory_handle_t* wholememory_handle_ptr, size_t total_size, wholememory_comm_t comm, @@ -95,51 +149,150 @@ wholememory_error_code_t wholememory_malloc(wholememory_handle_t* wholememory_ha wholememory_memory_location_t memory_location, size_t data_granularity); +/** + * Free allocated WholeMemory Handle + * @param wholememory_handle : WholeMemory Handle to free + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_free(wholememory_handle_t wholememory_handle); +/** + * Get underlying WholeMemory Communicator from WholeMemory Handle + * @param comm : returned WholeMemory Communicator + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_get_communicator(wholememory_comm_t* comm, wholememory_handle_t wholememory_handle); +/** + * Get WholeMemory Type + * @param wholememory_handle : WholeMemory Handle + * @return : WholeMemory Type + */ wholememory_memory_type_t wholememory_get_memory_type(wholememory_handle_t wholememory_handle); +/** + * Get WholeMemory Location + * @param wholememory_handle : WholeMemory Handle + * @return : WholeMemory Location + */ wholememory_memory_location_t wholememory_get_memory_location( wholememory_handle_t wholememory_handle); +/** + * Get total size of WholeMemory + * @param wholememory_handle : WholeMemory Handle + * @return : total size + */ size_t wholememory_get_total_size(wholememory_handle_t wholememory_handle); +/** + * Get data granularity of WholeMemory Handle + * @param wholememory_handle : WholeMemory Handle + * @return : data granularity size + */ size_t wholememory_get_data_granularity(wholememory_handle_t wholememory_handle); +/** + * Get local memory from WholeMemory Handle of current rank, local memory has direct access to the + * memory. But local memory doesn't have to be on local GPU. + * @param local_ptr : returned local memory pointer + * @param local_size : returned local memory size + * @param local_offset : returned local memory offset from WholeMemory + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_get_local_memory(void** local_ptr, size_t* local_size, size_t* local_offset, wholememory_handle_t wholememory_handle); +/** + * Get local memory of specified rank from WholeMemory Handle + * @param rank_memory_ptr : returned local memory pointer of specified rank + * @param rank_memory_size : returned local memory size of specified rank + * @param rank_memory_offset : returned local memory offset of specified rank from WholeMemory + * @param rank : rank specified + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_get_rank_memory(void** rank_memory_ptr, size_t* rank_memory_size, size_t* rank_memory_offset, int rank, wholememory_handle_t wholememory_handle); +/** + * Get global memory pointer from WholeMemory Handle. + * Only Continuous memory type or Chunked Host memory has global pointer. + * @param global_ptr : returned pointer of WholeMemory + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_get_global_pointer(void** global_ptr, wholememory_handle_t wholememory_handle); +/** + * Get global reference from WholeMemory Handle + * WholeMemory global reference is common data structure for Continuous and Chunked Memory Types. + * @param wholememory_gref : returned WholeMemory global reference + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_get_global_reference(wholememory_gref_t* wholememory_gref, wholememory_handle_t wholememory_handle); +/** + * Get the partition plan WholeMemory will use + * @param size_per_rank : returned size per rank + * @param total_size : total size + * @param data_granularity : data granularity + * @param world_size : communicator world size + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_determine_partition_plan(size_t* size_per_rank, size_t total_size, size_t data_granularity, int world_size); +/** + * Get the partition plan WholeMemory will use based on entry count. + * Entry is number of data granularity + * @param entry_per_rank : returned entry count per rank + * @param total_entry_count : total entry count + * @param world_size : communicator world size + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_determine_entry_partition_plan(size_t* entry_per_rank, size_t total_entry_count, int world_size); +/** + * Get the partition plan used in WholeMemory Handle + * @param size_per_rank : returned size per rank + * @param wholememory_handle : WholeMemory Handle + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_get_partition_plan(size_t* size_per_rank, wholememory_handle_t wholememory_handle); +/** + * Fork a new process and get device count. Should be called before other CUDA call + * @return : CUDA device count, -1 on error + */ int fork_get_device_count(); +/** + * Load WholeMemory from binary files, all rank should be called together + * @param wholememory_handle : WholeMemory Handle + * @param memory_offset : load to memory offset + * @param memory_entry_size : entry size of WholeMemory + * @param file_entry_size : entry size in file, should be less than or equal to memory_entry_size + * @param file_names : file names, all binary files will be logically concatenated and loaded. + * @param file_count : number of files. + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_load_from_file(wholememory_handle_t wholememory_handle, size_t memory_offset, size_t memory_entry_size, @@ -147,6 +300,16 @@ wholememory_error_code_t wholememory_load_from_file(wholememory_handle_t wholeme const char** file_names, int file_count); +/** + * Store local WholeMemory to file, this should be called by all ranks, with different + * local_file_name. + * @param wholememory_handle : WholeMemory Handle + * @param memory_offset : memory offset to store + * @param memory_entry_stride : entry size of WholeMemory + * @param file_entry_size : entry size in file, should be less than or equal to memory_entry_size + * @param local_file_name : local file to store to + * @return : wholememory_error_code_t + */ wholememory_error_code_t wholememory_store_to_file(wholememory_handle_t wholememory_handle, size_t memory_offset, size_t memory_entry_stride, diff --git a/cpp/src/parallel_utils.cpp b/cpp/src/parallel_utils.cpp index 96839f343..2251f0655 100644 --- a/cpp/src/parallel_utils.cpp +++ b/cpp/src/parallel_utils.cpp @@ -93,6 +93,8 @@ void MultiProcessRun(int world_size, std::function f, bool inlin int ForkGetDeviceCount() { + static int s_device_count = -1; + if (s_device_count >= 0) { return s_device_count; } int pipes[2]; if (pipe(pipes) == -1) { WHOLEMEMORY_ERROR("Create pipe failed."); @@ -120,6 +122,7 @@ int ForkGetDeviceCount() int wstatus; pid_t pid_ret = waitpid(pid, &wstatus, 0); if (pid_ret != pid) { WHOLEMEMORY_FATAL("wait dev_count process failed."); } + s_device_count = dev_count; return dev_count; } } diff --git a/cpp/src/wholegraph_ops/raft_random_gen.cu b/cpp/src/wholegraph_ops/raft_random_gen.cu index 7bbd27b8e..b7277781f 100644 --- a/cpp/src/wholegraph_ops/raft_random_gen.cu +++ b/cpp/src/wholegraph_ops/raft_random_gen.cu @@ -14,38 +14,36 @@ * limitations under the License. */ -#include #include +#include #include - #include "error.hpp" #include "logger.hpp" -wholememory_error_code_t generate_random_positive_int_cpu( - int64_t random_seed, - int64_t subsequence, - wholememory_tensor_t output -) { +wholememory_error_code_t generate_random_positive_int_cpu(int64_t random_seed, + int64_t subsequence, + wholememory_tensor_t output) +{ auto output_tensor_desc = *wholememory_tensor_get_tensor_description(output); if (output_tensor_desc.dim != 1) { WHOLEMEMORY_ERROR("output should be 1D tensor."); return WHOLEMEMORY_INVALID_INPUT; } - if (output_tensor_desc.dtype != WHOLEMEMORY_DT_INT64 && output_tensor_desc.dtype != WHOLEMEMORY_DT_INT) { + if (output_tensor_desc.dtype != WHOLEMEMORY_DT_INT64 && + output_tensor_desc.dtype != WHOLEMEMORY_DT_INT) { WHOLEMEMORY_ERROR("output should be int64 or int32 tensor."); return WHOLEMEMORY_INVALID_INPUT; } auto* output_ptr = wholememory_tensor_get_data_pointer(output); PCGenerator rng((unsigned long long)random_seed, subsequence, 0); - for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) { + for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) { if (output_tensor_desc.dtype == WHOLEMEMORY_DT_INT) { int32_t random_num; - rng.next(random_num); + rng.next(random_num); static_cast(output_ptr)[i] = random_num; - } - else { + } else { int64_t random_num; rng.next(random_num); static_cast(output_ptr)[i] = random_num; @@ -55,10 +53,8 @@ wholememory_error_code_t generate_random_positive_int_cpu( } wholememory_error_code_t generate_exponential_distribution_negative_float_cpu( - int64_t random_seed, - int64_t subsequence, - wholememory_tensor_t output -) { + int64_t random_seed, int64_t subsequence, wholememory_tensor_t output) +{ auto output_tensor_desc = *wholememory_tensor_get_tensor_description(output); if (output_tensor_desc.dim != 1) { WHOLEMEMORY_ERROR("output should be 1D tensor."); @@ -71,9 +67,9 @@ wholememory_error_code_t generate_exponential_distribution_negative_float_cpu( auto* output_ptr = wholememory_tensor_get_data_pointer(output); PCGenerator rng((unsigned long long)random_seed, subsequence, 0); for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) { - float u = -rng.next_float(1.0f, 0.5f); + float u = -rng.next_float(1.0f, 0.5f); uint64_t random_num2 = 0; - int seed_count = -1; + int seed_count = -1; do { rng.next(random_num2); seed_count++; @@ -89,7 +85,7 @@ wholememory_error_code_t generate_exponential_distribution_negative_float_cpu( int32_t one_bit = count_one(random_num2) + seed_count * 64; u *= pow(2, -one_bit); // float logk = (log1pf(u) / logf(2.0)) * (1.0f / (float)weight); - float logk = (log1p(u) / log(2.0)); + float logk = (log1p(u) / log(2.0)); static_cast(output_ptr)[i] = logk; } return WHOLEMEMORY_SUCCESS; diff --git a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh index c71fc3353..0581090c3 100644 --- a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh @@ -196,7 +196,7 @@ __global__ void unweighted_sample_without_replacement_kernel( int32_t random_num; rng.next(random_num); int32_t r = idx < M ? (random_num % (N - idx)) : N; - sa_p[i] = ((uint64_t)r << 32UL) | idx; + sa_p[i] = ((uint64_t)r << 32UL) | idx; } __syncthreads(); BlockRadixSort(shared_data.temp_storage).SortBlockedToStriped(sa_p); diff --git a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh index ca0df4591..22a97fd19 100644 --- a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh @@ -39,9 +39,9 @@ namespace wholegraph_ops { template __device__ __forceinline__ float gen_key_from_weight(const WeightType weight, PCGenerator& rng) { - float u = -rng.next_float(1.0f, 0.5f); + float u = -rng.next_float(1.0f, 0.5f); uint64_t random_num2 = 0; - int seed_count = -1; + int seed_count = -1; do { rng.next(random_num2); seed_count++; diff --git a/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu b/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu index 60a4d6984..d8f8c7e28 100644 --- a/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu +++ b/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu @@ -319,7 +319,6 @@ void random_sample_without_replacement_cpu_base(std::vector* a, } } - template void host_unweighted_sample_without_replacement( void* host_csr_row_ptr, @@ -545,9 +544,9 @@ inline int count_one(unsigned long long num) template float host_gen_key_from_weight(const WeightType weight, PCGenerator& rng) { - float u = -rng.next_float(1.0f, 0.5f); + float u = -rng.next_float(1.0f, 0.5f); uint64_t random_num2 = 0; - int seed_count = -1; + int seed_count = -1; do { rng.next(random_num2); seed_count++; diff --git a/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx deleted file mode 100644 index d2105c85d..000000000 --- a/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx +++ /dev/null @@ -1,1926 +0,0 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# cython: profile=False -# distutils: language = c++ -# cython: embedsignature = True -# cython: language_level = 3 - -cimport cpython -from libc cimport stdlib -from libc.stdio cimport printf, fprintf, stdout, stderr, fflush -import functools -import cython -from libc.stdint cimport * -from libcpp.cast cimport * -from libcpp cimport bool -from cpython cimport Py_buffer -from cpython cimport array -import array -import numpy as np -from cpython.ref cimport PyObject, Py_INCREF, Py_DECREF -from cpython.object cimport Py_TYPE, PyObject_CallObject -from cpython.tuple cimport * -from cpython.long cimport PyLong_AsLongLong - - -cdef extern from "Python.h": - void Py_INCREF(PyObject *o) - void Py_DECREF(PyObject *o) - - const char * PyUnicode_AsUTF8(object unicode) - - PyObject * PyUnicode_FromString(const char * u) - - -cdef extern from "wholememory/wholememory.h": - ctypedef enum wholememory_error_code_t: - WHOLEMEMORY_SUCCESS "WHOLEMEMORY_SUCCESS" # success - WHOLEMEMORY_UNKNOW_ERROR "WHOLEMEMORY_UNKNOW_ERROR" # unknown error - WHOLEMEMORY_NOT_IMPLEMENTED "WHOLEMEMORY_NOT_IMPLEMENTED" # method is not implemented - WHOLEMEMORY_LOGIC_ERROR "WHOLEMEMORY_LOGIC_ERROR" # logic error - WHOLEMEMORY_CUDA_ERROR "WHOLEMEMORY_CUDA_ERROR" # CUDA error - WHOLEMEMORY_COMMUNICATION_ERROR "WHOLEMEMORY_COMMUNICATION_ERROR" # communication error - WHOLEMEMORY_INVALID_INPUT "WHOLEMEMORY_INVALID_INPUT" # invalid input, e.g. nullptr - WHOLEMEMORY_INVALID_VALUE "WHOLEMEMORY_INVALID_VALUE" # input value is invalid - WHOLEMEMORY_OUT_OF_MEMORY "WHOLEMEMORY_OUT_OF_MEMORY" # out of memory - - ctypedef enum wholememory_memory_type_t: - WHOLEMEMORY_MT_NONE "WHOLEMEMORY_MT_NONE" - WHOLEMEMORY_MT_CONTINUOUS "WHOLEMEMORY_MT_CONTINUOUS" - WHOLEMEMORY_MT_CHUNKED "WHOLEMEMORY_MT_CHUNKED" - WHOLEMEMORY_MT_DISTRIBUTED "WHOLEMEMORY_MT_DISTRIBUTED" - - ctypedef enum wholememory_memory_location_t: - WHOLEMEMORY_ML_NONE "WHOLEMEMORY_ML_NONE" - WHOLEMEMORY_ML_DEVICE "WHOLEMEMORY_ML_DEVICE" - WHOLEMEMORY_ML_HOST "WHOLEMEMORY_ML_HOST" - - cdef wholememory_error_code_t wholememory_init(unsigned int flags) - - cdef wholememory_error_code_t wholememory_finalize() - - cdef struct wholememory_unique_id_t: - char internal[128] - - cdef struct wholememory_comm_: - pass - - ctypedef wholememory_comm_ * wholememory_comm_t - - cdef wholememory_error_code_t wholememory_create_unique_id(wholememory_unique_id_t * unique_id) - - cdef wholememory_error_code_t wholememory_create_communicator(wholememory_comm_t * comm, - wholememory_unique_id_t unique_id, - int rank, - int size) - - cdef wholememory_error_code_t wholememory_destroy_communicator(wholememory_comm_t comm) - - cdef wholememory_error_code_t wholememory_communicator_get_rank(int * rank, wholememory_comm_t comm) - - cdef wholememory_error_code_t wholememory_communicator_get_size(int * size, wholememory_comm_t comm) - - cdef wholememory_error_code_t wholememory_communicator_barrier(wholememory_comm_t comm) - - cdef struct wholememory_handle_: - pass - - ctypedef wholememory_handle_ * wholememory_handle_t - - cdef wholememory_error_code_t wholememory_malloc(wholememory_handle_t * wholememory_handle_ptr, - size_t total_size, - wholememory_comm_t comm, - wholememory_memory_type_t memory_type, - wholememory_memory_location_t memory_location, - size_t data_granularity) - - cdef wholememory_error_code_t wholememory_free(wholememory_handle_t wholememory_handle) - - cdef wholememory_error_code_t wholememory_get_communicator(wholememory_comm_t * comm, - wholememory_handle_t wholememory_handle) - - cdef wholememory_memory_type_t wholememory_get_memory_type(wholememory_handle_t wholememory_handle) - - cdef wholememory_memory_location_t wholememory_get_memory_location(wholememory_handle_t wholememory_handle) - - cdef size_t wholememory_get_total_size(wholememory_handle_t wholememory_handle) - - cdef wholememory_error_code_t wholememory_get_local_memory(void** local_ptr, - size_t * local_size, - size_t * local_offset, - wholememory_handle_t wholememory_handle) - - cdef wholememory_error_code_t wholememory_get_rank_memory(void** rank_memory_ptr, - size_t * rank_memory_size, - size_t * rank_memory_offset, - int rank, - wholememory_handle_t wholememory_handle) - - cdef wholememory_error_code_t wholememory_get_global_pointer(void** global_ptr, - wholememory_handle_t wholememory_handle) - - cdef wholememory_error_code_t wholememory_determine_partition_plan(size_t * size_per_rank, - size_t total_size, - size_t data_granularity, - int world_size) - - cdef wholememory_error_code_t wholememory_determine_entry_partition_plan(size_t * entry_per_rank, - size_t total_entry_count, - int world_size) - - cdef wholememory_error_code_t wholememory_get_partition_plan(size_t * size_per_rank, - wholememory_handle_t wholememory_handle) - - cdef int fork_get_device_count() - - cdef wholememory_error_code_t wholememory_load_from_file(wholememory_handle_t wholememory_handle, - size_t memory_offset, - size_t memory_entry_size, - size_t file_entry_size, - const char** file_names, - int file_count) - - cdef wholememory_error_code_t wholememory_store_to_file(wholememory_handle_t wholememory_handle, - size_t memory_offset, - size_t memory_entry_stride, - size_t file_entry_size, - const char *local_file_name) - - -cpdef enum WholeMemoryErrorCode: - Success = WHOLEMEMORY_SUCCESS - UnknowError = WHOLEMEMORY_UNKNOW_ERROR - NotImplemented = WHOLEMEMORY_NOT_IMPLEMENTED - LogicError = WHOLEMEMORY_LOGIC_ERROR - CUDAError = WHOLEMEMORY_CUDA_ERROR - CommunicationError = WHOLEMEMORY_COMMUNICATION_ERROR - InvalidInput = WHOLEMEMORY_INVALID_INPUT - InvalidValue = WHOLEMEMORY_INVALID_VALUE - OutOfMemory = WHOLEMEMORY_OUT_OF_MEMORY - -cpdef enum WholeMemoryMemoryType: - MtNone = WHOLEMEMORY_MT_NONE - MtContinuous = WHOLEMEMORY_MT_CONTINUOUS - MtChunked = WHOLEMEMORY_MT_CHUNKED - MtDistributed = WHOLEMEMORY_MT_DISTRIBUTED - -cpdef enum WholeMemoryMemoryLocation: - MlNone = WHOLEMEMORY_ML_NONE - MlDevice = WHOLEMEMORY_ML_DEVICE - MlHost = WHOLEMEMORY_ML_HOST - -cdef check_wholememory_error_code(wholememory_error_code_t err): - cdef WholeMemoryErrorCode err_code = int(err) - if err_code == Success: - return - elif err_code == UnknowError: - raise Exception('Unknown error') - elif err_code == NotImplemented: - raise NotImplementedError('Not implemented') - elif err_code == LogicError: - raise RuntimeError('Logic error') - elif err_code == CUDAError: - raise RuntimeError('CUDA error') - elif err_code == CommunicationError: - raise RuntimeError('Communication error') - elif err_code == InvalidInput: - raise ValueError('Invalid input') - elif err_code == InvalidValue: - raise ValueError('Invalid value') - elif err_code == OutOfMemory: - raise MemoryError('Out of memory') - else: - raise NotImplementedError('Error code %d not recognized' % (int(err),)) - - -cdef extern from "wholememory/tensor_description.h": - ctypedef enum wholememory_dtype_t: - WHOLEMEMORY_DT_UNKNOWN "WHOLEMEMORY_DT_UNKNOWN" - WHOLEMEMORY_DT_FLOAT "WHOLEMEMORY_DT_FLOAT" - WHOLEMEMORY_DT_HALF "WHOLEMEMORY_DT_HALF" - WHOLEMEMORY_DT_DOUBLE "WHOLEMEMORY_DT_DOUBLE" - WHOLEMEMORY_DT_BF16 "WHOLEMEMORY_DT_BF16" - WHOLEMEMORY_DT_INT "WHOLEMEMORY_DT_INT" - WHOLEMEMORY_DT_INT64 "WHOLEMEMORY_DT_INT64" - WHOLEMEMORY_DT_INT16 "WHOLEMEMORY_DT_INT16" - WHOLEMEMORY_DT_INT8 "WHOLEMEMORY_DT_INT8" - WHOLEMEMORY_DT_COUNT "WHOLEMEMORY_DT_COUNT" - - cdef struct wholememory_tensor_description_t: - int64_t sizes[8] - int64_t strides[8] - int64_t storage_offset - int dim - wholememory_dtype_t dtype - - cdef size_t wholememory_dtype_get_element_size(wholememory_dtype_t dtype) - - cdef int64_t wholememory_get_memory_element_count_from_tensor( - wholememory_tensor_description_t * p_tensor_description) - - -cdef extern from "wholememory/env_func_ptrs.h": - ctypedef enum wholememory_memory_allocation_type_t: - WHOLEMEMORY_MA_NONE "WHOLEMEMORY_MA_NONE" - WHOLEMEMORY_MA_DEVICE "WHOLEMEMORY_MA_DEVICE" - WHOLEMEMORY_MA_HOST "WHOLEMEMORY_MA_HOST" - WHOLEMEMORY_MA_PINNED "WHOLEMEMORY_MA_PINNED" - - ctypedef void (*wholememory_create_memory_context_func_t)(void ** memory_context, - void * global_context) - - ctypedef void (*wholememory_destroy_memory_context_func_t)(void * memory_context, - void * global_context) - - ctypedef void * (*wholememory_malloc_func_t)(wholememory_tensor_description_t * desc, - wholememory_memory_allocation_type_t memory_allocation_type, - void * memory_context, - void * global_context) - - ctypedef void (*wholememory_free_func_t)(void * memory_context, void * global_context) - - cdef struct wholememory_temp_memory_func_t: - wholememory_create_memory_context_func_t create_memory_context_fn - wholememory_destroy_memory_context_func_t destroy_memory_context_fn - wholememory_malloc_func_t malloc_fn - wholememory_free_func_t free_fn - void * global_context - - cdef struct wholememory_output_memory_func_t: - wholememory_malloc_func_t malloc_fn - wholememory_free_func_t free_fn - void * global_context - - cdef struct wholememory_env_func_t: - wholememory_temp_memory_func_t temporary_fns - wholememory_output_memory_func_t output_fns - - -cpdef enum WholeMemoryMemoryAllocType: - MatNone = WHOLEMEMORY_MA_NONE - MatDevice = WHOLEMEMORY_MA_DEVICE - MatHost = WHOLEMEMORY_MA_HOST - MatPinned = WHOLEMEMORY_MA_PINNED - -cdef class PyMemoryAllocType: - cdef wholememory_memory_allocation_type_t alloc_type - - def __cinit__(self): - self.alloc_type = WHOLEMEMORY_MA_NONE - - def set_type(self, WholeMemoryMemoryAllocType new_type): - self.alloc_type = new_type - - def get_type(self): - return self.alloc_type - - def set_ctype(self, wholememory_memory_allocation_type_t alloc_type): - self.alloc_type = alloc_type - - def get_ctype(self): - return self.alloc_type - -cdef class GlobalContextWrapper: - cdef PyObject * temp_create_context_fn - cdef PyObject * temp_destroy_context_fn - cdef PyObject * temp_malloc_fn - cdef PyObject * temp_free_fn - cdef PyObject * temp_global_context - cdef PyObject * output_malloc_fn - cdef PyObject * output_free_fn - cdef PyObject * output_global_context - cdef wholememory_env_func_t env_func - - def __cinit__(self): - self.temp_create_context_fn = NULL - self.temp_destroy_context_fn = NULL - self.temp_malloc_fn = NULL - self.temp_free_fn = NULL - self.temp_global_context = NULL - self.output_malloc_fn = NULL - self.output_free_fn = NULL - self.output_global_context = NULL - - def __dealloc__(self): - Py_DECREF(self.self.temp_create_context_fn) - Py_DECREF(self.self.temp_destroy_context_fn) - Py_DECREF(self.self.temp_malloc_fn) - Py_DECREF(self.self.temp_free_fn) - if self.temp_global_context: - Py_DECREF(self.self.temp_global_context) - Py_DECREF(self.self.output_malloc_fn) - Py_DECREF(self.self.output_free_fn) - if self.output_global_context: - Py_DECREF(self.self.output_global_context) - - cpdef create_context(self, - temp_create_context_fn, - temp_destroy_context_fn, - temp_malloc_fn, - temp_free_fn, - temp_global_context, - output_malloc_fn, - output_free_fn, - output_global_context): - self.temp_create_context_fn = temp_create_context_fn - Py_INCREF(self.temp_create_context_fn) - self.temp_destroy_context_fn = temp_destroy_context_fn - Py_INCREF(self.temp_destroy_context_fn) - self.temp_malloc_fn = temp_malloc_fn - Py_INCREF(self.temp_malloc_fn) - self.temp_free_fn = temp_free_fn - Py_INCREF(self.temp_free_fn) - if temp_global_context: - self.temp_global_context = temp_global_context - Py_INCREF(self.temp_global_context) - self.output_malloc_fn = output_malloc_fn - Py_INCREF(self.output_malloc_fn) - self.output_free_fn = output_free_fn - Py_INCREF(self.output_free_fn) - if output_global_context: - self.output_global_context = output_global_context - Py_INCREF(self.output_global_context) - self.env_func.temporary_fns.create_memory_context_fn = &python_cb_wrapper_temp_create_context - self.env_func.temporary_fns.destroy_memory_context_fn = &python_cb_wrapper_temp_destroy_context - self.env_func.temporary_fns.malloc_fn = &python_cb_wrapper_temp_malloc - self.env_func.temporary_fns.free_fn = &python_cb_wrapper_temp_free - self.env_func.temporary_fns.global_context = self - self.env_func.output_fns.malloc_fn = &python_cb_wrapper_output_malloc - self.env_func.output_fns.free_fn = &python_cb_wrapper_output_free - self.env_func.output_fns.global_context = self - - cpdef int64_t get_env_fns(self): - return (&self.env_func) - -cdef void python_cb_wrapper_temp_create_context(void** memory_context, - void * global_context) nogil: - cdef PyObject * ret_memory_context = NULL - with gil: - wrapped_global_context = global_context - python_fn = wrapped_global_context.temp_create_context_fn - python_global_context = wrapped_global_context.temp_global_context - args = PyTuple_New(1) - Py_INCREF( python_global_context) - PyTuple_SetItem(args, 0, python_global_context) - py_memory_context = PyObject_CallObject( python_fn, args) - ret_memory_context = py_memory_context - Py_DECREF(args) - Py_INCREF(ret_memory_context) - ( memory_context)[0] = ret_memory_context - return - -cdef void python_cb_wrapper_temp_destroy_context(void * memory_context, - void * global_context) nogil: - with gil: - wrapped_global_context = global_context - python_fn = wrapped_global_context.temp_destroy_context_fn - python_global_context = wrapped_global_context.temp_global_context - args = PyTuple_New(2) - Py_INCREF( memory_context) - PyTuple_SetItem(args, 0, memory_context) - Py_INCREF( python_global_context) - PyTuple_SetItem(args, 1, python_global_context) - PyObject_CallObject( python_fn, args) - Py_DECREF(args) - Py_DECREF( memory_context) - return - -cdef void * python_cb_wrapper_temp_malloc(wholememory_tensor_description_t * tensor_desc, - wholememory_memory_allocation_type_t malloc_type, - void * memory_context, - void * global_context) nogil: - cdef int64_t res_ptr = 0 - with gil: - wrapped_global_context = global_context - py_tensor_desc = PyWholeMemoryTensorDescription() - py_tensor_desc.set_by_tensor_desc(tensor_desc) - py_malloc_type = PyMemoryAllocType() - py_malloc_type.set_type(malloc_type) - python_fn = wrapped_global_context.temp_malloc_fn - python_global_context = wrapped_global_context.temp_global_context - args = PyTuple_New(4) - Py_INCREF(py_tensor_desc) - PyTuple_SetItem(args, 0, py_tensor_desc) - Py_INCREF(py_malloc_type) - PyTuple_SetItem(args, 1, py_malloc_type) - Py_INCREF( memory_context) - PyTuple_SetItem(args, 2, memory_context) - Py_INCREF( python_global_context) - PyTuple_SetItem(args, 3, python_global_context) - res_ptr = PyLong_AsLongLong(PyObject_CallObject( python_fn, args)) - Py_DECREF(args) - return res_ptr - -cdef void python_cb_wrapper_temp_free(void * memory_context, - void * global_context) nogil: - with gil: - wrapped_global_context = global_context - python_fn = wrapped_global_context.temp_free_fn - python_global_context = wrapped_global_context.temp_global_context - args = PyTuple_New(2) - Py_INCREF( memory_context) - PyTuple_SetItem(args, 0, memory_context) - Py_INCREF( python_global_context) - PyTuple_SetItem(args, 1, python_global_context) - PyObject_CallObject( python_fn, args) - Py_DECREF(args) - return - -cdef void * python_cb_wrapper_output_malloc(wholememory_tensor_description_t * tensor_desc, - wholememory_memory_allocation_type_t malloc_type, - void * memory_context, - void * global_context) nogil: - cdef int64_t res_ptr = 0 - with gil: - wrapped_global_context = global_context - py_tensor_desc = PyWholeMemoryTensorDescription() - py_tensor_desc.set_by_tensor_desc(tensor_desc) - py_malloc_type = PyMemoryAllocType() - py_malloc_type.set_type(malloc_type) - python_fn = wrapped_global_context.output_malloc_fn - python_global_context = wrapped_global_context.output_global_context - args = PyTuple_New(4) - Py_INCREF(py_tensor_desc) - PyTuple_SetItem(args, 0, py_tensor_desc) - Py_INCREF(py_malloc_type) - PyTuple_SetItem(args, 1, py_malloc_type) - Py_INCREF( memory_context) - PyTuple_SetItem(args, 2, memory_context) - Py_INCREF( python_global_context) - PyTuple_SetItem(args, 3, python_global_context) - res_ptr = PyLong_AsLongLong(PyObject_CallObject( python_fn, args)) - Py_DECREF(args) - return res_ptr - -cdef void python_cb_wrapper_output_free(void * memory_context, - void * global_context) nogil: - with gil: - wrapped_global_context = global_context - python_fn = wrapped_global_context.output_free_fn - python_global_context = wrapped_global_context.output_global_context - args = PyTuple_New(2) - Py_INCREF( memory_context) - PyTuple_SetItem(args, 0, memory_context) - Py_INCREF( python_global_context) - PyTuple_SetItem(args, 1, python_global_context) - PyObject_CallObject( python_fn, args) - Py_DECREF(args) - return - - -cdef extern from "wholememory/wholememory_tensor.h": - cdef struct wholememory_tensor_: - pass - - ctypedef wholememory_tensor_ * wholememory_tensor_t - - cdef wholememory_error_code_t wholememory_create_tensor(wholememory_tensor_t *wholememory_tensor, - wholememory_tensor_description_t *tensor_description, - wholememory_comm_t comm, - wholememory_memory_type_t memory_type, - wholememory_memory_location_t memory_location) - - cdef wholememory_error_code_t wholememory_destroy_tensor(wholememory_tensor_t wholememory_tensor) - - cdef wholememory_error_code_t wholememory_make_tensor_from_pointer(wholememory_tensor_t *wholememory_tensor, - void *data_ptr, - wholememory_tensor_description_t *tensor_description) - - cdef wholememory_error_code_t wholememory_make_tensor_from_handle(wholememory_tensor_t *wholememory_tensor, - wholememory_handle_t wholememory_handle, - wholememory_tensor_description_t *tensor_description) - - cdef bool wholememory_tensor_has_handle(wholememory_tensor_t wholememory_tensor) - - cdef wholememory_handle_t wholememory_tensor_get_memory_handle(wholememory_tensor_t wholememory_tensor) - - cdef wholememory_tensor_description_t * wholememory_tensor_get_tensor_description( - wholememory_tensor_t wholememory_tensor) - - cdef wholememory_error_code_t wholememory_tensor_get_subtensor(wholememory_tensor_t wholememory_tensor, - int64_t *starts, - int64_t *ends, - wholememory_tensor_t *sub_wholememory_tensor) - - int64_t get_wholememory_tensor_count() - - -def py_get_wholememory_tensor_count(): - return get_wholememory_tensor_count() - -cpdef enum WholeMemoryDataType: - DtUnknown = WHOLEMEMORY_DT_UNKNOWN - DtFloat = WHOLEMEMORY_DT_FLOAT - DtHalf = WHOLEMEMORY_DT_HALF - DtDouble = WHOLEMEMORY_DT_DOUBLE - DtBF16 = WHOLEMEMORY_DT_BF16 - DtInt = WHOLEMEMORY_DT_INT - DtInt64 = WHOLEMEMORY_DT_INT64 - DtInt16 = WHOLEMEMORY_DT_INT16 - DtInt8 = WHOLEMEMORY_DT_INT8 - DtCount = WHOLEMEMORY_DT_COUNT - -cdef extern from "wholememory/embedding.h": - cdef struct wholememory_embedding_cache_policy_: - pass - - cdef struct wholememory_embedding_optimizer_: - pass - - cdef struct wholememory_embedding_: - pass - - ctypedef wholememory_embedding_cache_policy_ * wholememory_embedding_cache_policy_t - ctypedef wholememory_embedding_optimizer_ * wholememory_embedding_optimizer_t - ctypedef wholememory_embedding_ * wholememory_embedding_t - - ctypedef enum wholememory_access_type_t: - WHOLEMEMORY_AT_NONE "WHOLEMEMORY_AT_NONE" - WHOLEMEMORY_AT_READONLY "WHOLEMEMORY_AT_READONLY" - WHOLEMEMORY_AT_READWRITE "WHOLEMEMORY_AT_READWRITE" - - ctypedef enum wholememory_optimizer_type_t: - WHOLEMEMORY_OPT_NONE "WHOLEMEMORY_OPT_NONE" - WHOLEMEMORY_OPT_SGD "WHOLEMEMORY_OPT_SGD" - WHOLEMEMORY_OPT_LAZY_ADAM "WHOLEMEMORY_OPT_LAZY_ADAM" - WHOLEMEMORY_OPT_RMSPROP "WHOLEMEMORY_OPT_RMSPROP" - WHOLEMEMORY_OPT_ADAGRAD "WHOLEMEMORY_OPT_ADAGRAD" - - cdef wholememory_error_code_t wholememory_create_embedding_optimizer( - wholememory_embedding_optimizer_t * optimizer, wholememory_optimizer_type_t optimizer_type) - - cdef wholememory_error_code_t wholememory_optimizer_set_parameter( - wholememory_embedding_optimizer_t optimizer, const char * parameter_name, void * value) - - cdef void wholememory_destroy_embedding_optimizer(wholememory_embedding_optimizer_t optimizer) - - cdef wholememory_error_code_t wholememory_create_embedding_cache_policy( - wholememory_embedding_cache_policy_t * cache_policy, - wholememory_comm_t cache_level_comm, - wholememory_memory_type_t memory_type, - wholememory_memory_location_t memory_location, - wholememory_access_type_t access_type, - float cache_ratio) - - cdef wholememory_error_code_t wholememory_destroy_embedding_cache_policy( - wholememory_embedding_cache_policy_t cache_policy) - - cdef wholememory_error_code_t wholememory_create_embedding( - wholememory_embedding_t * wholememory_embedding, - wholememory_tensor_description_t * embedding_tensor_description, - wholememory_comm_t comm, - wholememory_memory_type_t memory_type, - wholememory_memory_location_t memory_location, - wholememory_embedding_optimizer_t optimizer, - wholememory_embedding_cache_policy_t cache_policy) - - cdef wholememory_error_code_t wholememory_destroy_embedding( - wholememory_embedding_t wholememory_embedding) - - cdef wholememory_error_code_t wholememory_embedding_gather(wholememory_embedding_t wholememory_embedding, - wholememory_tensor_t indices, - wholememory_tensor_t output, - bool adjust_cache, - wholememory_env_func_t * p_env_fns, - int64_t stream_int) - - cdef wholememory_error_code_t wholememory_embedding_gather_gradient_apply( - wholememory_embedding_t wholememory_embedding, - wholememory_tensor_t indices, - wholememory_tensor_t grads, - bool adjust_cache, - float lr, - wholememory_env_func_t * p_env_fns, - int64_t stream_int) - - cdef wholememory_tensor_t wholememory_embedding_get_embedding_tensor( - wholememory_embedding_t wholememory_embedding) - - cdef const char * const * wholememory_embedding_get_optimizer_state_names( - wholememory_embedding_t wholememory_embedding) - - cdef wholememory_tensor_t wholememory_embedding_get_optimizer_state( - wholememory_embedding_t wholememory_embedding, const char * name) - - cdef wholememory_error_code_t wholememory_embedding_writeback_cache( - wholememory_embedding_t wholememory_embedding, int64_t stream_int) - - cdef wholememory_error_code_t wholememory_embedding_drop_all_cache( - wholememory_embedding_t wholememory_embedding, int64_t stream_int) - - -cpdef enum WholeMemoryAccessType: - AtNone = WHOLEMEMORY_AT_NONE - AtReadOnly = WHOLEMEMORY_AT_READONLY - AtReadWrite = WHOLEMEMORY_AT_READWRITE - -cpdef enum WholeMemoryOptimizerType: - OptNone = WHOLEMEMORY_OPT_NONE - OptSgd = WHOLEMEMORY_OPT_SGD - OptLazyAdam = WHOLEMEMORY_OPT_LAZY_ADAM - OptAdaGrad = WHOLEMEMORY_OPT_ADAGRAD - OptRmsProp = WHOLEMEMORY_OPT_RMSPROP - -cdef class WholeMemoryOptimizer: - cdef wholememory_embedding_optimizer_t wm_optimizer - cdef wholememory_optimizer_type_t optimizer_type - cdef public dict param_dict - - def __cinit__(self): - self.wm_optimizer = NULL - self.optimizer_type = WHOLEMEMORY_OPT_NONE - - def __init__(self): - self.param_dict = {} - - def create_optimizer(self, - WholeMemoryOptimizerType optimizer_type, - dict param_dict): - cdef str param_key - cdef float param_value - self.optimizer_type = optimizer_type - self.param_dict = param_dict - check_wholememory_error_code(wholememory_create_embedding_optimizer(&self.wm_optimizer, self.optimizer_type)) - for param_key, param_value in self.param_dict.items(): - key_bytes = param_key.encode('utf-8') - check_wholememory_error_code( - wholememory_optimizer_set_parameter(self.wm_optimizer, key_bytes, ¶m_value)) - - def destroy_optimizer(self): - if self.wm_optimizer == NULL: - return - wholememory_destroy_embedding_optimizer(self.wm_optimizer) - self.wm_optimizer = NULL - self.optimizer_type = WHOLEMEMORY_OPT_NONE - self.param_dict = None - -def create_optimizer(WholeMemoryOptimizerType optimizer_type, - dict param_dict): - wm_optimizer = WholeMemoryOptimizer() - wm_optimizer.create_optimizer(optimizer_type, param_dict) - return wm_optimizer - -def create_non_optimizer(): - return WholeMemoryOptimizer() - -cdef class WholeMemoryCachePolicy: - cdef wholememory_embedding_cache_policy_t cache_policy - cdef wholememory_memory_type_t memory_type - cdef wholememory_memory_location_t memory_location - cdef wholememory_access_type_t access_type - cdef float ratio - cdef PyWholeMemoryComm comm - - def __cinit__(self): - self.cache_policy = NULL - self.memory_type = WHOLEMEMORY_MT_NONE - self.memory_location = WHOLEMEMORY_ML_NONE - self.access_type = WHOLEMEMORY_AT_NONE - self.ratio = 0.5 - self.comm = None - - def create_policy(self, - PyWholeMemoryComm comm, - WholeMemoryMemoryType memory_type, - WholeMemoryMemoryLocation memory_location, - WholeMemoryAccessType access_type, - float ratio): - self.memory_type = memory_type - self.memory_location = memory_location - self.access_type = access_type - self.ratio = ratio - check_wholememory_error_code(wholememory_create_embedding_cache_policy(&self.cache_policy, - comm.comm_id, - self.memory_type, - self.memory_location, - self.access_type, - self.ratio)) - - def destroy_policy(self): - if self.cache_policy == NULL: - return - check_wholememory_error_code(wholememory_destroy_embedding_cache_policy(self.cache_policy)) - self.cache_policy = NULL - self.memory_type = WHOLEMEMORY_MT_NONE - self.memory_location = WHOLEMEMORY_ML_NONE - self.access_type = WHOLEMEMORY_AT_NONE - self.ratio = 0.5 - self.comm = None - -def create_cache_policy(PyWholeMemoryComm comm, - WholeMemoryMemoryType memory_type, - WholeMemoryMemoryLocation memory_location, - WholeMemoryAccessType access_type, - float ratio): - cache_policy = WholeMemoryCachePolicy() - cache_policy.create_policy(comm, memory_type, memory_location, access_type, ratio) - return cache_policy - -def create_non_cache_policy(): - return WholeMemoryCachePolicy() - -cdef class PyWholeMemoryEmbedding: - cdef wholememory_embedding_t wm_embedding - cdef wholememory_memory_type_t memory_type - cdef wholememory_memory_location_t memory_location - - def __cinit__(self): - self.wm_embedding = NULL - self.memory_type = WHOLEMEMORY_MT_NONE - self.memory_location = WHOLEMEMORY_ML_NONE - - def create_embedding(self, - PyWholeMemoryTensorDescription tensor_desc, - PyWholeMemoryComm comm, - WholeMemoryMemoryType memory_type, - WholeMemoryMemoryLocation memory_location, - WholeMemoryOptimizer optimizer, - WholeMemoryCachePolicy cache_policy): - self.memory_type = memory_type - self.memory_location = memory_location - check_wholememory_error_code(wholememory_create_embedding(&self.wm_embedding, - &tensor_desc.tensor_description, - comm.comm_id, - self.memory_type, - self.memory_location, - optimizer.wm_optimizer, - cache_policy.cache_policy)) - - def destroy_embedding(self): - check_wholememory_error_code(wholememory_destroy_embedding(self.wm_embedding)) - - def writeback_all_cache(self, - int64_t stream): - check_wholememory_error_code(wholememory_embedding_writeback_cache(self.wm_embedding, stream)) - - def drop_all_cache(self, - int64_t stream): - check_wholememory_error_code(wholememory_embedding_drop_all_cache(self.wm_embedding, stream)) - - def get_embedding_tensor(self): - cdef wholememory_tensor_t wm_tensor - wm_tensor = wholememory_embedding_get_embedding_tensor(self.wm_embedding) - py_wm_tensor = PyWholeMemoryTensor() - py_wm_tensor.from_c_handle(wm_tensor) - return py_wm_tensor - - def get_optimizer_state_names(self): - cdef int i = 0 - result = [] - cdef const char * const * state_names - state_names = wholememory_embedding_get_optimizer_state_names(self.wm_embedding) - while state_names[i] != NULL: - result.append( PyUnicode_FromString(state_names[i])) - i += 1 - return result - - def get_optimizer_state(self, - state_name): - cdef wholememory_tensor_t state_tensor - state_tensor = wholememory_embedding_get_optimizer_state( - self.wm_embedding, - PyUnicode_AsUTF8(state_name)) - py_state_tensor = PyWholeMemoryTensor() - py_state_tensor.from_c_handle(state_tensor) - return py_state_tensor - -def create_embedding(PyWholeMemoryTensorDescription tensor_desc, - PyWholeMemoryComm comm, - WholeMemoryMemoryType memory_type, - WholeMemoryMemoryLocation memory_location, - WholeMemoryOptimizer optimizer, - WholeMemoryCachePolicy cache_policy): - wm_embedding = PyWholeMemoryEmbedding() - wm_embedding.create_embedding(tensor_desc, - comm, - memory_type, - memory_location, - optimizer, - cache_policy) - return wm_embedding - -cpdef void EmbeddingGatherForward(PyWholeMemoryEmbedding wm_embedding, - WrappedLocalTensor indice, - WrappedLocalTensor output, - bool adjust_cache, - int64_t p_env_fns_int, - int64_t stream_int): - check_wholememory_error_code(wholememory_embedding_gather(wm_embedding.wm_embedding, - indice.get_c_handle(), - output.get_c_handle(), - adjust_cache, - p_env_fns_int, - stream_int)) - -cpdef void EmbeddingGatherGradientApply(PyWholeMemoryEmbedding wm_embedding, - WrappedLocalTensor indice, - WrappedLocalTensor grads, - bool adjust_cache, - float lr, - int64_t p_env_fns_int, - int64_t stream_int): - check_wholememory_error_code(wholememory_embedding_gather_gradient_apply( - wm_embedding.wm_embedding, - indice.get_c_handle(), - grads.get_c_handle(), - adjust_cache, - lr, - p_env_fns_int, - stream_int)) - -###################################################################### -# dlpack -# https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h -# https://github.com/cupy/cupy/blob/master/cupy/_core/dlpack.pyx - -cpdef enum DLDeviceType: - kDLCPU = 1 - kDLCUDA = 2 - kDLCUDAHost = 3 - -ctypedef struct DLDevice: - DLDeviceType device_type - int device_id - -cdef enum DLDataTypeCode: - kDLInt = 0 - kDLUInt = 1 - kDLFloat = 2 - kDLBfloat = 4 - -ctypedef struct DLDataType: - uint8_t code - uint8_t bits - uint16_t lanes - -ctypedef struct DLTensor: - void * data - DLDevice device - int ndim - DLDataType dtype - int64_t * shape - int64_t * strides - uint64_t byte_offset - -ctypedef struct DLManagedTensor: - DLTensor dl_tensor - void * manager_ctx - void (*deleter)(DLManagedTensor *) - -cdef void pycapsule_deleter(object dltensor): - cdef DLManagedTensor * dlm_tensor - # Do not invoke the deleter on a used capsule - if cpython.PyCapsule_IsValid(dltensor, 'dltensor'): - dlm_tensor = cpython.PyCapsule_GetPointer( - dltensor, 'dltensor') - dlm_tensor.deleter(dlm_tensor) - -cdef void deleter(DLManagedTensor * tensor) with gil: - if tensor.manager_ctx is NULL: - return - cpython.Py_DECREF( tensor.manager_ctx) - tensor.manager_ctx = NULL - stdlib.free(tensor) - -# end dlpack -###################################################################### - -cdef class PyWholeMemoryUniqueID: - cdef wholememory_unique_id_t wholememory_unique_id - cdef Py_ssize_t shape[1] - cdef Py_ssize_t strides[1] - cdef int64_t shape_int64_t[1] - cdef int64_t strides_int64_t[1] - - def __cinit__(self): - self.shape[0] = sizeof(self.wholememory_unique_id.internal) - self.strides[0] = 1 - self.shape_int64_t[0] = self.shape[0] - self.strides_int64_t[0] = self.strides[0] - - def __len__(self): - return self.shape[0] - - def __getbuffer__(self, Py_buffer *buffer, int flags): - buffer.buf = &self.wholememory_unique_id.internal[0] - buffer.format = 'c' - buffer.internal = NULL - buffer.itemsize = 1 - buffer.len = self.shape[0] - buffer.ndim = 1 - buffer.obj = self - buffer.readonly = 0 - buffer.shape = self.shape - buffer.strides = self.strides - buffer.suboffsets = NULL - - def __releasebuffer__(self, Py_buffer *buffer): - buffer.buf = NULL - buffer.format = 'c' - buffer.len = 0 - buffer.ndim = 0 - buffer.obj = None - buffer.shape = NULL - buffer.strides = NULL - - def __dlpack__(self, stream=None): - cdef DLManagedTensor * dlm_tensor = \ - stdlib.malloc(sizeof(DLManagedTensor)) - cdef DLTensor * dl_tensor = &dlm_tensor.dl_tensor - dl_tensor.data = &self.wholememory_unique_id.internal[0] - dl_tensor.ndim = 1 - dl_tensor.shape = &self.shape_int64_t[0] - dl_tensor.strides = &self.strides_int64_t[0] - dl_tensor.byte_offset = 0 - dl_tensor.device.device_type, dl_tensor.device.device_id = self.__dlpack_device__() - cdef DLDataType * dtype = &dl_tensor.dtype - dtype.code = kDLInt - dtype.lanes = 1 - dtype.bits = 8 - - dlm_tensor.manager_ctx = self - cpython.Py_INCREF(self) - dlm_tensor.deleter = deleter - return cpython.PyCapsule_New(dlm_tensor, 'dltensor', pycapsule_deleter) - - def __dlpack_device__(self): - return (kDLCPU, 0) - -def init(unsigned int flags): - check_wholememory_error_code(wholememory_init(flags)) - -def finalize(): - check_wholememory_error_code(wholememory_finalize()) - -def create_unique_id(): - py_uid = PyWholeMemoryUniqueID() - check_wholememory_error_code(wholememory_create_unique_id(&py_uid.wholememory_unique_id)) - return py_uid - -cpdef enum WholeMemoryViewType: - VtNone = 0 - VtLocal = 1 - VtGlobal = 2 - VtRemote = 3 - -def get_type_string(WholeMemoryDataType data_type): - # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.interface.html#__array_interface__ - if data_type == DtFloat: - return ' 8: - raise ValueError('data_type not supported') - self.typestr = get_type_string(data_type) - cdef WholeMemoryMemoryType mem_type - cdef WholeMemoryMemoryLocation mem_location - mem_type = int(wholememory_get_memory_type(handle.wholememory_handle)) - mem_location = int(wholememory_get_memory_location(handle.wholememory_handle)) - if self.device_type == MlHost and mem_location == MlDevice: - raise ValueError('Device WholeMemory cannot get view from host.') - if mem_type == MtDistributed and (view_type == VtGlobal or view_type == VtRemote): - raise ValueError('Distributed WholeMemory have no view of Global or Remote') - cdef size_t map_size - cdef size_t map_offset - cdef size_t global_size - cdef wholememory_comm_t comm - cdef int world_rank - cdef int world_size - global_size = wholememory_get_total_size(handle.wholememory_handle) - if global_size % elt_size != 0: - raise ValueError('global_size=%d not multiple of elt_size=%d' % (global_size, elt_size)) - global_elt_count = global_size // elt_size - if view_type == VtLocal: - check_wholememory_error_code( - wholememory_get_local_memory(&self.c_ptr, &map_size, &map_offset, handle.wholememory_handle)) - if map_size % elt_size != 0 or map_offset % elt_size != 0: - raise ValueError('map_size=%d, map_offset=%d not multiple of elt_size=%d' - % (map_size, map_offset, elt_size)) - local_elt_count = map_size // elt_size - local_start = map_offset // elt_size - self.shape[0] = map_size // elt_size - self.shape_int64_t[0] = map_size // elt_size - return local_elt_count, local_start - elif view_type == VtGlobal: - check_wholememory_error_code(wholememory_get_global_pointer(&self.c_ptr, handle.wholememory_handle)) - self.shape[0] = global_size // elt_size - self.shape_int64_t[0] = global_size // elt_size - global_elt_count - return global_elt_count, 0 - elif view_type == VtRemote: - check_wholememory_error_code(wholememory_get_communicator(&comm, handle.wholememory_handle)) - check_wholememory_error_code(wholememory_communicator_get_rank(&world_rank, comm)) - check_wholememory_error_code(wholememory_communicator_get_size(&world_size, comm)) - if target_rank < 0 or target_rank >= world_size: - raise IndexError('target_rank=%d but world_size=%d' % (target_rank, int(world_size))) - check_wholememory_error_code(wholememory_get_rank_memory( - &self.c_ptr, &map_size, &map_offset, target_rank, handle.wholememory_handle)) - if map_size % elt_size != 0 or map_offset % elt_size != 0: - raise ValueError('target_rank=%d map_size=%d, map_offset=%d not multiple of elt_size=%d' - % (target_rank, map_size, map_offset, elt_size)) - target_elt_count = map_size // elt_size - target_start = map_offset // elt_size - self.shape[0] = map_size // elt_size - self.shape_int64_t[0] = map_size // elt_size - return target_elt_count, target_start - else: - raise ValueError('view type should be VtLocal or VtGlobal or VtRemote') - - def __len__(self): - return self.shape[0] - - def __getbuffer__(self, Py_buffer *buffer, int flags): - buffer.buf = self.c_ptr - buffer.format = 'c' - buffer.internal = NULL - buffer.itemsize = self.itemsize - buffer.len = self.shape[0] - buffer.ndim = 1 - buffer.obj = self - buffer.readonly = 0 - buffer.shape = self.shape - buffer.strides = self.strides - buffer.suboffsets = NULL - - def __releasebuffer__(self, Py_buffer *buffer): - buffer.buf = NULL - buffer.format = 'c' - buffer.len = 0 - buffer.ndim = 0 - buffer.obj = None - buffer.shape = NULL - buffer.strides = NULL - - @property - def ptr(self): - return int( self.c_ptr) - - @property - def __cuda_array_interface__(self): - """See - https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.interface.html#__array_interface__ - and - https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html - """ - cdef dict intf = { - "data": (self.ptr, False), - "shape": (self.shape[0],), - "strides": None, - "typestr": self.typestr, - "version": 2 - } - return intf - - def __dlpack__(self, stream=None): - cdef DLManagedTensor * dlm_tensor = \ - stdlib.malloc(sizeof(DLManagedTensor)) - cdef DLTensor * dl_tensor = &dlm_tensor.dl_tensor - dl_tensor.data = self.c_ptr - dl_tensor.ndim = 1 - dl_tensor.shape = &self.shape_int64_t[0] - dl_tensor.strides = &self.strides_int64_t[0] - dl_tensor.byte_offset = 0 - dl_tensor.device.device_type, dl_tensor.device.device_id = self.__dlpack_device__() - cdef DLDataType * dtype = &dl_tensor.dtype - if self.data_type == DtInt or self.data_type == DtInt64 \ - or self.data_type == DtInt16 or self.data_type == DtInt8: - dtype.code = kDLInt - elif self.data_type == DtFloat or self.data_type == DtDouble \ - or self.data_type == DtHalf: - dtype.code = kDLFloat - elif self.data_type == DtHalf: - dtype.code = kDLBfloat - else: - raise ValueError('Invalid data_type') - dtype.lanes = 1 - dtype.bits = (self.itemsize * 8) - - dlm_tensor.manager_ctx = self - cpython.Py_INCREF(self) - dlm_tensor.deleter = deleter - return cpython.PyCapsule_New(dlm_tensor, 'dltensor', pycapsule_deleter) - - def __dlpack_device__(self): - if self.device_type == MlHost: - return (kDLCPU, 0) - elif self.device_type == MlDevice: - return (kDLCUDA, self.device_id) - else: - raise ValueError('self.device_type=%d' % (int(self.device_type),)) - -cdef class PyWholeMemoryComm: - cdef wholememory_comm_t comm_id - - def __cinit__(self): - self.comm_id = NULL - - def get_c_handle(self): - return self.comm_id - - def get_rank(self): - cdef int world_rank = -1 - check_wholememory_error_code(wholememory_communicator_get_rank(&world_rank, self.comm_id)) - return world_rank - def get_size(self): - cdef int world_size = -1 - check_wholememory_error_code(wholememory_communicator_get_size(&world_size, self.comm_id)) - return world_size - def barrier(self): - check_wholememory_error_code(wholememory_communicator_barrier(self.comm_id)) - -cdef class PyWholeMemoryHandle: - cdef wholememory_handle_t wholememory_handle - - def __cinit__(self): - self.wholememory_handle = NULL - - def get_c_handle(self): - return self.wholememory_handle - - def get_communicator(self): - py_comm = PyWholeMemoryComm() - check_wholememory_error_code(wholememory_get_communicator(&py_comm.comm_id, self.wholememory_handle)) - return py_comm - - def get_memory_type(self): - return WholeMemoryMemoryType(wholememory_get_memory_type(self.wholememory_handle)) - - def get_memory_location(self): - return WholeMemoryMemoryLocation(wholememory_get_memory_location(self.wholememory_handle)) - - def get_partition_plan(self): - cdef size_t size_per_rank - check_wholememory_error_code(wholememory_get_partition_plan(&size_per_rank, self.wholememory_handle)) - return size_per_rank - - def get_global_flatten_tensor(self, - object import_dlpack_fn, - WholeMemoryDataType data_type, - WholeMemoryMemoryLocation view_from_device, - int view_from_device_id): - tb = PyWholeMemoryFlattenDlpack() - tb.set_view_device(view_from_device, view_from_device_id) - tsize, toffset = tb.get_view(self, data_type, VtGlobal, 0) - assert toffset == 0 - return import_dlpack_fn(tb), toffset - - def get_local_flatten_tensor(self, - object import_dlpack_fn, - WholeMemoryDataType data_type, - WholeMemoryMemoryLocation view_from_device, - int view_from_device_id): - tb = PyWholeMemoryFlattenDlpack() - tb.set_view_device(view_from_device, view_from_device_id) - tsize, toffset = tb.get_view(self, data_type, VtLocal, 0) - return import_dlpack_fn(tb), toffset - - def get_all_chunked_flatten_tensor(self, - object import_dlpack_fn, - WholeMemoryDataType data_type, - WholeMemoryMemoryLocation view_from_device, - int view_from_device_id): - cdef Whole - cdef int world_rank - cdef int world_size - cdef wholememory_comm_t comm - check_wholememory_error_code(wholememory_get_communicator(&comm, self.wholememory_handle)) - check_wholememory_error_code(wholememory_communicator_get_rank(&world_rank, comm)) - check_wholememory_error_code(wholememory_communicator_get_size(&world_size, comm)) - chunked_tensors = [] - toffsets = [] - for r in range(world_size): - tb = PyWholeMemoryFlattenDlpack() - tb.set_view_device(view_from_device, view_from_device_id) - tsize, toffset = tb.get_view(self, data_type, VtRemote, r) - chunked_tensors.append(import_dlpack_fn(tb)) - toffsets.append(toffset) - return chunked_tensors, toffsets - - def from_filelist(self, - int64_t memory_offset, - int64_t memory_entry_size, - int64_t file_entry_size, - file_list): - load_wholememory_handle_from_filelist( self.wholememory_handle, - memory_offset, - memory_entry_size, - file_entry_size, - file_list) - - def to_file(self, - int64_t memory_offset, - int64_t memory_entry_size, - int64_t file_entry_size, - file_name): - store_wholememory_handle_to_file( self.wholememory_handle, - memory_offset, - memory_entry_size, - file_entry_size, - file_name) - -cdef class PyWholeMemoryTensorDescription: - cdef wholememory_tensor_description_t tensor_description - - def __cinit__(self): - self.tensor_description.dim = 0 - self.tensor_description.dtype = int(0) - self.tensor_description.storage_offset = 0 - - cdef set_by_tensor_desc(self, wholememory_tensor_description_t * td): - self.tensor_description = td[0] - - def set_dtype(self, WholeMemoryDataType dtype): - self.tensor_description.dtype = int(dtype) - - def set_shape(self, shape): - assert 0 < len(shape) < 8 - dim = len(shape) - self.tensor_description.dim = dim - for i in range(dim): - self.tensor_description.sizes[i] = shape[i] - - def set_stride(self, strides): - assert len(strides) == self.tensor_description.dim - for i in range(self.tensor_description.dim): - self.tensor_description.strides[i] = strides[i] - - def set_storage_offset(self, storage_offset): - self.tensor_description.storage_offset = storage_offset - - @property - def dtype(self): - return WholeMemoryDataType(self.tensor_description.dtype) - - def dim(self): - return self.tensor_description.dim - - @property - def shape(self): - ret_shape = tuple([self.tensor_description.sizes[i] for i in range(self.tensor_description.dim)]) - return ret_shape - - def stride(self): - return tuple([self.tensor_description.strides[i] for i in range(self.dim())]) - - def storage_offset(self): - return self.tensor_description.storage_offset - -cdef class WrappedLocalTensor: - cdef wholememory_tensor_t wm_tensor - - def __cinit__(self): - self.wm_tensor = NULL - - def __dealloc__(self): - if self.wm_tensor: - check_wholememory_error_code(wholememory_destroy_tensor(self.wm_tensor)) - self.wm_tensor = NULL - - def wrap_tensor(self, - PyWholeMemoryTensorDescription py_desc, - int64_t data_ptr): - check_wholememory_error_code(wholememory_make_tensor_from_pointer(&self.wm_tensor, - data_ptr, - &py_desc.tensor_description)) - - return self - - def get_c_handle(self) -> int: - if self.wm_tensor: - return self.wm_tensor - else: - return 0 - -cdef class PyWholeMemoryTensor: - cdef wholememory_tensor_t wholememory_tensor - cdef wholememory_tensor_description_t tensor_description - - def __cinit__(self): - self.wholememory_tensor = NULL - - cdef from_c_handle(self, - wholememory_tensor_t wm_tensor): - self.wholememory_tensor = wm_tensor - self.tensor_description = wholememory_tensor_get_tensor_description(wm_tensor)[0] - - def get_c_handle(self): - return self.wholememory_tensor - - def get_wholememory_handle(self): - handle = PyWholeMemoryHandle() - handle.wholememory_handle = wholememory_tensor_get_memory_handle(self.wholememory_tensor) - return handle - - @property - def dtype(self): - return WholeMemoryDataType(self.tensor_description.dtype) - - def dim(self): - return self.tensor_description.dim - - @property - def shape(self): - if self.dim() == 1: - return (self.tensor_description.sizes[0],) - elif self.dim() == 2: - return (self.tensor_description.sizes[0], self.tensor_description.sizes[1]) - else: - raise ValueError('self.dim()=%d' % (self.dim(),)) - - def stride(self): - if self.dim() == 1: - return (self.tensor_description.strides[0],) - elif self.dim() == 2: - return (self.tensor_description.strides[0], self.tensor_description.strides[1]) - else: - raise ValueError('self.dim()=%d' % (self.dim(),)) - - def storage_offset(self): - return self.tensor_description.storage_offset - - def get_partition_plan(self): - mem_size_per_rank = self.get_wholememory_handle().get_partition_plan() - element_size = wholememory_dtype_get_element_size(self.tensor_description.dtype) - vector_size = element_size * self.stride()[0] - assert mem_size_per_rank % vector_size == 0 - return mem_size_per_rank // vector_size - - def get_sub_tensor(self, starts, ends): - cdef int64_t start_array[2] - cdef int64_t end_array[2] - start_array[0] = starts[0] - end_array[0] = ends[0] - if self.dim() == 1: - pass - elif self.dim() == 2: - start_array[1] = starts[1] - end_array[1] = ends[1] - else: - raise ValueError('self.dim()=%d' % (self.dim(),)) - sub_tensor = PyWholeMemoryTensor() - check_wholememory_error_code( - wholememory_tensor_get_subtensor(self.wholememory_tensor, start_array, end_array, - &sub_tensor.wholememory_tensor)) - sub_tensor.from_c_handle(sub_tensor.wholememory_tensor) - return sub_tensor - - def get_tensor_in_window(self, - flatten_tensor, - int64_t storage_window_offset): - if self.tensor_description.dim == 1: - start_indice = max(0, self.tensor_description.storage_offset - storage_window_offset) - end_indice = min(flatten_tensor.shape[0], - self.tensor_description.storage_offset + self.tensor_description.sizes[ - 0] - storage_window_offset) - return flatten_tensor[start_indice: end_indice], max(0, - storage_window_offset - self.tensor_description.storage_offset) - elif self.tensor_description.dim == 2: - embedding_stride = self.tensor_description.strides[0] - storage_offset0 = self.tensor_description.storage_offset // embedding_stride - storage_offset1 = self.tensor_description.storage_offset % embedding_stride - mat_tensor = flatten_tensor.reshape(-1, embedding_stride) - assert storage_window_offset % self.tensor_description.strides[0] == 0 - vector_start_offset = storage_window_offset // self.tensor_description.strides[0] - start_indice0 = max(0, storage_offset0 - vector_start_offset) - end_indice0 = min(mat_tensor.shape[0], - storage_offset0 + self.tensor_description.sizes[0] - vector_start_offset) - start_indice_1 = storage_offset1 - assert mat_tensor.shape[1] >= storage_offset1 + self.tensor_description.sizes[1] - end_indice_1 = storage_offset1 + self.tensor_description.sizes[1] - return mat_tensor[start_indice0:end_indice0, start_indice_1:end_indice_1], max(0, - vector_start_offset - storage_offset0) - else: - raise ValueError('tensor dim should be 1 or 2') - - def get_local_tensor(self, - object import_dlpack_fn, - WholeMemoryMemoryLocation view_from_device, - int view_from_device_id): - flatten_tensor, element_offset = self.get_wholememory_handle().get_local_flatten_tensor(import_dlpack_fn, - self.tensor_description.dtype, - view_from_device, - view_from_device_id) - return self.get_tensor_in_window(flatten_tensor, element_offset) - - def get_global_tensor(self, - object import_dlpack_fn, - WholeMemoryMemoryLocation view_from_device, - int view_from_device_id): - global_flatten_tensor, _ = self.get_wholememory_handle().get_global_flatten_tensor(import_dlpack_fn, - self.tensor_description.dtype, - view_from_device, - view_from_device_id) - return self.get_tensor_in_window(global_flatten_tensor, 0)[0] - - def get_all_chunked_tensor(self, - object import_dlpack_fn, - WholeMemoryMemoryLocation view_from_device, - int view_from_device_id): - chunked_flatten_tensors, element_offsets = self.get_wholememory_handle().get_all_chunked_flatten_tensor( - import_dlpack_fn, - self.tensor_description.dtype, - view_from_device, - view_from_device_id) - chunked_tensors = [] - for i in range(len(chunked_flatten_tensors)): - chunked_tensors.append(self.get_tensor_in_window(chunked_flatten_tensors[i], element_offsets[i])[0]) - return chunked_tensors - - def from_filelist(self, filelist): - handle = self.get_wholememory_handle() - strides = self.stride() - shape = self.shape - cdef size_t elt_size = wholememory_dtype_get_element_size(self.tensor_description.dtype) - - cdef size_t memory_offset - cdef size_t memory_entry_size - cdef size_t file_entry_size - memory_offset = self.storage_offset() * elt_size - memory_entry_size = elt_size * strides[0] - if self.dim() == 1: - file_entry_size = elt_size - elif self.dim() == 2: - file_entry_size = elt_size * shape[1] - else: - raise ValueError('tensor dim should be 1 or 2') - handle.from_filelist(memory_offset, memory_entry_size, file_entry_size, filelist) - - def to_file(self, filename): - handle = self.get_wholememory_handle() - strides = self.stride() - shape = self.shape - cdef size_t elt_size = wholememory_dtype_get_element_size(self.tensor_description.dtype) - - cdef size_t memory_offset - cdef size_t memory_entry_size - cdef size_t file_entry_size - memory_offset = self.storage_offset() * elt_size - memory_entry_size = elt_size * strides[0] - if self.dim() == 1: - file_entry_size = elt_size - elif self.dim() == 2: - file_entry_size = elt_size * shape[1] - else: - raise ValueError('tensor dim should be 1 or 2') - handle.to_file(memory_offset, memory_entry_size, file_entry_size, filename) - -############################################################################### - - -def create_communicator(PyWholeMemoryUniqueID py_uid, int world_rank, int world_size): - py_comm = PyWholeMemoryComm() - check_wholememory_error_code(wholememory_create_communicator(&py_comm.comm_id, - py_uid.wholememory_unique_id, - world_rank, - world_size)) - return py_comm - -def destroy_communicator(PyWholeMemoryComm py_comm): - check_wholememory_error_code(wholememory_destroy_communicator(py_comm.comm_id)) - -def determine_partition_plan(int64_t entry_count, - int world_size): - cdef size_t per_rank_count - check_wholememory_error_code(wholememory_determine_entry_partition_plan(&per_rank_count, - entry_count, - world_size)) - return per_rank_count - -def malloc(cython.size_t total_size, - PyWholeMemoryComm py_comm, - WholeMemoryMemoryType memory_type, - WholeMemoryMemoryLocation memory_location, - cython.size_t data_granularity): - handle = PyWholeMemoryHandle() - check_wholememory_error_code(wholememory_malloc(&handle.wholememory_handle, total_size, py_comm.comm_id, - int(memory_type), int(memory_location), - data_granularity)) - return handle - -def free(PyWholeMemoryHandle handle): - check_wholememory_error_code(wholememory_free(handle.wholememory_handle)) - -def create_wholememory_array(WholeMemoryDataType dtype, - int64_t size, - PyWholeMemoryComm comm, - WholeMemoryMemoryType mem_type, - WholeMemoryMemoryLocation mem_location): - wholememory_tensor = PyWholeMemoryTensor() - wholememory_tensor.tensor_description.dtype = int(dtype) - wholememory_tensor.tensor_description.storage_offset = 0 - wholememory_tensor.tensor_description.dim = 1 - wholememory_tensor.tensor_description.strides[0] = 1 - wholememory_tensor.tensor_description.sizes[0] = size - check_wholememory_error_code(wholememory_create_tensor(&wholememory_tensor.wholememory_tensor, - &wholememory_tensor.tensor_description, - comm.comm_id, - int(mem_type), - int(mem_location))) - return wholememory_tensor - -def create_wholememory_matrix(WholeMemoryDataType dtype, - int64_t row, - int64_t column, - int64_t stride, - PyWholeMemoryComm comm, - WholeMemoryMemoryType mem_type, - WholeMemoryMemoryLocation mem_location): - wholememory_tensor = PyWholeMemoryTensor() - wholememory_tensor.tensor_description.dtype = int(dtype) - wholememory_tensor.tensor_description.storage_offset = 0 - wholememory_tensor.tensor_description.dim = 2 - if stride == -1: - stride = column - wholememory_tensor.tensor_description.strides[0] = stride - wholememory_tensor.tensor_description.strides[1] = 1 - wholememory_tensor.tensor_description.sizes[0] = row - wholememory_tensor.tensor_description.sizes[1] = column - check_wholememory_error_code(wholememory_create_tensor(&wholememory_tensor.wholememory_tensor, - &wholememory_tensor.tensor_description, - comm.comm_id, - int(mem_type), - int(mem_location))) - return wholememory_tensor - -def create_wholememory_tensor(PyWholeMemoryTensorDescription tensor_description, - PyWholeMemoryComm comm, - WholeMemoryMemoryType mem_type, - WholeMemoryMemoryLocation mem_location): - if tensor_description.dim() != 1 and tensor_description.dim() != 2: - raise NotImplementedError('WholeMemory currently only support 1D or 2D tensor') - if tensor_description.stride()[tensor_description.dim() - 1] != 1: - raise ValueError('last stride should be 1') - if tensor_description.storage_offset() != 0: - raise ValueError('storage_offset be 0 when created') - wholememory_tensor = PyWholeMemoryTensor() - wholememory_tensor.tensor_description = tensor_description.tensor_description - check_wholememory_error_code(wholememory_create_tensor(&wholememory_tensor.wholememory_tensor, - &wholememory_tensor.tensor_description, - comm.comm_id, - int(mem_type), - int(mem_location))) - return wholememory_tensor - -def make_tensor_as_wholememory(PyWholeMemoryTensorDescription tensor_description, - int64_t data_ptr): - if tensor_description.stride()[tensor_description.dim() - 1] != 1: - raise ValueError('last stride should be 1') - wholememory_tensor = PyWholeMemoryTensor() - check_wholememory_error_code(wholememory_make_tensor_from_pointer(&wholememory_tensor.wholememory_tensor, - data_ptr, - &tensor_description.tensor_description)) - wholememory_tensor.from_c_handle(wholememory_tensor.wholememory_tensor) - return wholememory_tensor - -def make_handle_as_wholememory(PyWholeMemoryTensorDescription tensor_description, - PyWholeMemoryHandle handle): - if tensor_description.stride()[tensor_description.dim() - 1] != 1: - raise ValueError('last stride should be 1') - wholememory_tensor = PyWholeMemoryTensor() - check_wholememory_error_code(wholememory_make_tensor_from_handle(&wholememory_tensor.wholememory_tensor, - handle.wholememory_handle, - &tensor_description.tensor_description)) - wholememory_tensor.from_c_handle(wholememory_tensor.wholememory_tensor) - return wholememory_tensor - -def destroy_wholememory_tensor(PyWholeMemoryTensor wholememory_tensor): - check_wholememory_error_code(wholememory_destroy_tensor(wholememory_tensor.wholememory_tensor)) - -def fork_get_gpu_count(): - return fork_get_device_count() - -cpdef load_wholememory_handle_from_filelist(int64_t wholememory_handle_int_ptr, - int64_t memory_offset, - int64_t memory_entry_size, - int64_t file_entry_size, - file_list): - cdef const char ** filenames - cdef int num_files = len(file_list) - cdef int i - - filenames = stdlib.malloc(num_files * sizeof(char *)) - - try: - for i in range(num_files): - filenames[i] = PyUnicode_AsUTF8(file_list[i]) - - check_wholememory_error_code(wholememory_load_from_file( - wholememory_handle_int_ptr, - memory_offset, - memory_entry_size, - file_entry_size, - filenames, - num_files)) - finally: - stdlib.free(filenames) - -cpdef store_wholememory_handle_to_file(int64_t wholememory_handle_int_ptr, - int64_t memory_offset, - int64_t memory_entry_size, - int64_t file_entry_size, - file_name): - check_wholememory_error_code(wholememory_store_to_file( - wholememory_handle_int_ptr, - memory_offset, - memory_entry_size, - file_entry_size, - PyUnicode_AsUTF8(file_name))) - -cdef extern from "wholememory/wholememory_op.h": - cdef wholememory_error_code_t wholememory_gather(wholememory_tensor_t wholememory_tensor, - wholememory_tensor_t indices_tensor, - wholememory_tensor_t output_tensor, - wholememory_env_func_t * p_env_fns, - void * stream) - - cdef wholememory_error_code_t wholememory_scatter(wholememory_tensor_t input_tensor, - wholememory_tensor_t indices_tensor, - wholememory_tensor_t wholememory_tensor, - wholememory_env_func_t * p_env_fns, - void * stream) - cdef wholememory_error_code_t wholememory_env_test_op(wholememory_tensor_t input_tensor, - wholememory_tensor_t output_fixed_tensor, - void *output_variable_device_tensor_handle, - void *output_variable_pinned_tensor_handle, - void *output_variable_host_tensor_handle, - int64_t output_variable_entry_count, - wholememory_env_func_t *p_env_fns, - void *stream) - - -cpdef void wholememory_gather_op(PyWholeMemoryTensor wholememory_tensor, - WrappedLocalTensor indices_tensor, - WrappedLocalTensor output_tensor, - int64_t p_env_fns_int, - int64_t stream_int): - check_wholememory_error_code(wholememory_gather( wholememory_tensor.get_c_handle(), - indices_tensor.get_c_handle(), - output_tensor.get_c_handle(), - p_env_fns_int, - stream_int)) - -cpdef void wholememory_scatter_op(WrappedLocalTensor input_tensor, - WrappedLocalTensor indices_tensor, - PyWholeMemoryTensor wholememory_tensor, - int64_t p_env_fns_int, - int64_t stream_int): - check_wholememory_error_code(wholememory_scatter( input_tensor.get_c_handle(), - indices_tensor.get_c_handle(), - wholememory_tensor.get_c_handle(), - p_env_fns_int, - stream_int)) - -cpdef void wholememory_env_test_cython_op(WrappedLocalTensor input, - WrappedLocalTensor output, - int64_t output_variable_device_tensor_handle, - int64_t output_variable_pinned_tensor_handle, - int64_t output_variable_host_tensor_handle, - int64_t output_variable_entry_count, - int64_t p_env_fns_int, - int64_t stream_int): - check_wholememory_error_code(wholememory_env_test_op( input.get_c_handle(), - output.get_c_handle(), - output_variable_device_tensor_handle, - output_variable_pinned_tensor_handle, - output_variable_host_tensor_handle, - output_variable_entry_count, - p_env_fns_int, - stream_int)) - return - -cdef extern from "wholememory/wholegraph_op.h": - cdef wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement( - wholememory_tensor_t wm_csr_row_ptr_tensor, - wholememory_tensor_t wm_csr_col_ptr_tensor, - wholememory_tensor_t center_nodes_tensor, - int max_sample_count, - wholememory_tensor_t output_sample_offset_tensor, - void * output_dest_memory_context, - void * output_center_localid_memory_context, - void * output_edge_gid_memory_context, - unsigned long long random_seed, - wholememory_env_func_t * p_env_fns, - void * stream) - - cdef wholememory_error_code_t wholegraph_csr_weighted_sample_without_replacement( - wholememory_tensor_t wm_csr_row_ptr_tensor, - wholememory_tensor_t wm_csr_col_ptr_tensor, - wholememory_tensor_t wm_csr_weight_ptr_tensor, - wholememory_tensor_t center_nodes_tensor, - int max_sample_count, - wholememory_tensor_t output_sample_offset_tensor, - void * output_dest_memory_context, - void * output_center_localid_memory_context, - void * output_edge_gid_memory_context, - unsigned long long random_seed, - wholememory_env_func_t * p_env_fns, - void * stream) - - cdef wholememory_error_code_t generate_random_positive_int_cpu( - int64_t random_seed, - int64_t subsequence, - wholememory_tensor_t output) - - cdef wholememory_error_code_t generate_exponential_distribution_negative_float_cpu( - int64_t random_seed, - int64_t subsequence, - wholememory_tensor_t output) - -cpdef void csr_unweighted_sample_without_replacement( - PyWholeMemoryTensor wm_csr_row_ptr_tensor, - PyWholeMemoryTensor wm_csr_col_ptr_tensor, - WrappedLocalTensor center_nodes_tensor, - int max_sample_count, - WrappedLocalTensor output_sample_offset_tensor, - int64_t output_dest_memory_handle, - int64_t output_center_localid_memory_handle, - int64_t output_edge_gid_memory_handle, - unsigned long long random_seed, - int64_t p_env_fns_int, - int64_t stream_int -): - check_wholememory_error_code(wholegraph_csr_unweighted_sample_without_replacement( - wm_csr_row_ptr_tensor.get_c_handle(), - wm_csr_col_ptr_tensor.get_c_handle(), - center_nodes_tensor.get_c_handle(), - max_sample_count, - output_sample_offset_tensor.get_c_handle(), - output_dest_memory_handle, - output_center_localid_memory_handle, - output_edge_gid_memory_handle, - random_seed, - p_env_fns_int, - stream_int)) - -cpdef void csr_weighted_sample_without_replacement( - PyWholeMemoryTensor wm_csr_row_ptr_tensor, - PyWholeMemoryTensor wm_csr_col_ptr_tensor, - PyWholeMemoryTensor wm_csr_weight_ptr_tensor, - WrappedLocalTensor center_nodes_tensor, - int max_sample_count, - WrappedLocalTensor output_sample_offset_tensor, - int64_t output_dest_memory_handle, - int64_t output_center_localid_memory_handle, - int64_t output_edge_gid_memory_handle, - unsigned long long random_seed, - int64_t p_env_fns_int, - int64_t stream_int -): - check_wholememory_error_code(wholegraph_csr_weighted_sample_without_replacement( - wm_csr_row_ptr_tensor.get_c_handle(), - wm_csr_col_ptr_tensor.get_c_handle(), - wm_csr_weight_ptr_tensor.get_c_handle(), - center_nodes_tensor.get_c_handle(), - max_sample_count, - output_sample_offset_tensor.get_c_handle(), - output_dest_memory_handle, - output_center_localid_memory_handle, - output_edge_gid_memory_handle, - random_seed, - p_env_fns_int, - stream_int)) - -cpdef void host_generate_random_positive_int( - int64_t random_seed, - int64_t subsequence, - WrappedLocalTensor output -): - check_wholememory_error_code(generate_random_positive_int_cpu( - random_seed, - subsequence, - output.get_c_handle() - )) - - -cpdef void host_generate_exponential_distribution_negative_float( - int64_t random_seed, - int64_t subsequence, - WrappedLocalTensor output -): - check_wholememory_error_code(generate_exponential_distribution_negative_float_cpu( - random_seed, - subsequence, - output.get_c_handle() - )) - - -cdef extern from "wholememory/graph_op.h": - cdef wholememory_error_code_t graph_append_unique(wholememory_tensor_t target_nodes_tensor, - wholememory_tensor_t neighbor_nodes_tensor, - void * output_unique_node_memory_context, - wholememory_tensor_t output_neighbor_raw_to_unique_mapping_tensor, - wholememory_env_func_t * p_env_fns, - void * stream) - - - cdef wholememory_error_code_t csr_add_self_loop(wholememory_tensor_t csr_row_ptr_tensor, - wholememory_tensor_t csr_col_ptr_tensor, - wholememory_tensor_t output_csr_row_ptr_tensor, - wholememory_tensor_t output_csr_col_ptr_tensor, - void* stream) - - -cpdef void append_unique( - WrappedLocalTensor target_node_tensor, - WrappedLocalTensor neighbor_node_tensor, - int64_t output_unique_node_memory_handle, - WrappedLocalTensor output_neighbor_raw_to_unique_mapping_tensor, - int64_t p_env_fns_int, - int64_t stream_int): - check_wholememory_error_code(graph_append_unique( - target_node_tensor.get_c_handle(), - neighbor_node_tensor.get_c_handle(), - output_unique_node_memory_handle, - output_neighbor_raw_to_unique_mapping_tensor.get_c_handle(), - p_env_fns_int, - stream_int - )) - - -cpdef void add_csr_self_loop( - WrappedLocalTensor csr_row_ptr_tensor, - WrappedLocalTensor csr_col_ptr_tensor, - WrappedLocalTensor csr_row_ptr_self_tensor, - WrappedLocalTensor csr_col_ptr_self_tensor, - int64_t stream_int): - check_wholememory_error_code(csr_add_self_loop( - csr_row_ptr_tensor.get_c_handle(), - csr_col_ptr_tensor.get_c_handle(), - csr_row_ptr_self_tensor.get_c_handle(), - csr_col_ptr_self_tensor.get_c_handle(), - stream_int)) diff --git a/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py b/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py deleted file mode 100644 index 63e2a2f07..000000000 --- a/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py +++ /dev/null @@ -1,402 +0,0 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import pylibwholegraph.binding.wholememory_binding as wmb -from pylibwholegraph.utils.multiprocess import multiprocess_run -from pylibwholegraph.torch.initialize import ( - init_torch_env_and_create_wm_comm, - load_wholegraph_op_libraries, -) -import torch -from functools import partial -from pylibwholegraph.test_utils.test_comm import ( - gen_csr_graph, - copy_host_1D_tensor_to_wholememory, - host_get_sample_offset_tensor, - host_sample_all_neighbors, - int_to_wholememory_datatype, - int_to_wholememory_location, - int_to_wholememory_type, -) -import pylibwholegraph.torch.wholegraph_ops as wg_ops -import random - - -def unweighte_sample_without_replacement_base(random_values, M, N): - a = torch.empty((M,), dtype=torch.int32) - Q = torch.arange(N, dtype=torch.int32) - for i in range(M): - a[i] = Q[random_values[i]] - Q[random_values[i]] = Q[N - i - 1] - return a - - -def host_unweighted_sample_without_replacement_func( - host_csr_row_ptr, - host_csr_col_ptr, - center_nodes, - output_sample_offset_tensor, - col_id_dtype, - total_sample_count, - max_sample_count, - random_seed, -): - output_dest_tensor = torch.empty((total_sample_count,), dtype=col_id_dtype) - output_center_localid_tensor = torch.empty((total_sample_count,), dtype=torch.int32) - output_edge_gid_tensor = torch.empty((total_sample_count,), dtype=torch.int64) - center_nodes_count = center_nodes.size(0) - - M = max_sample_count - - warp_count = [ - 1, - 1, - 1, - 2, - 2, - 2, - 4, - 4, - 4, - 4, - 4, - 4, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - 8, - ] - total_items_per_thread = [ - 1, - 2, - 3, - 2, - 3, - 3, - 2, - 2, - 3, - 3, - 3, - 3, - 2, - 2, - 2, - 2, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 3, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - 4, - ] - func_idx = int((max_sample_count - 1) / 32) - block_threads = warp_count[func_idx] * 32 - items_per_thread = total_items_per_thread[func_idx] - - for i in range(center_nodes_count): - node_id = center_nodes[i] - start = host_csr_row_ptr[node_id] - end = host_csr_row_ptr[node_id + 1] - neighbor_count = end - start - N = neighbor_count - output_id = output_sample_offset_tensor[i] - gidx = i * block_threads - if neighbor_count <= max_sample_count: - for j in range(end - start): - output_dest_tensor[output_id + j] = host_csr_col_ptr[start + j] - output_center_localid_tensor[output_id + j] = i - output_edge_gid_tensor[output_id + j] = start + j - else: - random_values = torch.empty((N,), dtype=torch.int32) - for j in range(block_threads): - local_gidx = gidx + j - random_nums = wg_ops.generate_random_positive_int_cpu(random_seed, local_gidx, items_per_thread) - for k in range(items_per_thread): - id = k * block_threads + j - if id < neighbor_count: - if id < M: - random_values[id] = random_nums[k] % (N - id) - else: - random_values[id] = N - random_sample_ids = unweighte_sample_without_replacement_base( - random_values, M, N - ) - for sample_id in range(M): - output_dest_tensor[output_id + sample_id] = host_csr_col_ptr[ - start + random_sample_ids[sample_id] - ] - output_center_localid_tensor[output_id + sample_id] = i - output_edge_gid_tensor[output_id + sample_id] = ( - start + random_sample_ids[sample_id] - ) - return output_dest_tensor, output_center_localid_tensor, output_edge_gid_tensor - - -def host_unweighted_sample_without_replacement( - host_csr_row_ptr, - host_csr_col_ptr, - center_nodes, - max_sample_count, - col_id_dtype, - random_seed, -): - center_nodes_count = center_nodes.size(0) - output_sample_offset_tensor = host_get_sample_offset_tensor( - host_csr_row_ptr, center_nodes, max_sample_count - ) - total_sample_count = output_sample_offset_tensor[center_nodes_count] - - if max_sample_count <= 0: - return host_sample_all_neighbors( - host_csr_row_ptr, - host_csr_col_ptr, - center_nodes, - output_sample_offset_tensor, - col_id_dtype, - total_sample_count, - ) - if max_sample_count > 1024: - raise ValueError( - "invalid host_unweighted_sample_without_replacement test max_sample_count" - ) - - ( - output_dest_tensor, - output_center_localid_tensor, - output_edge_gid_tensor, - ) = host_unweighted_sample_without_replacement_func( - host_csr_row_ptr, - host_csr_col_ptr, - center_nodes, - output_sample_offset_tensor, - col_id_dtype, - total_sample_count, - max_sample_count, - random_seed, - ) - - return ( - output_sample_offset_tensor, - output_dest_tensor, - output_center_localid_tensor, - output_edge_gid_tensor, - ) - - -def routine_func(world_rank: int, world_size: int, **kwargs): - wm_comm, _ = init_torch_env_and_create_wm_comm( - world_rank, world_size, world_rank, world_size - ) - wm_comm = wm_comm.wmb_comm - load_wholegraph_op_libraries() - host_csr_row_ptr = kwargs["host_csr_row_ptr"] - host_csr_col_ptr = kwargs["host_csr_col_ptr"] - graph_node_count = kwargs["graph_node_count"] - graph_edge_count = kwargs["graph_edge_count"] - max_sample_count = kwargs["max_sample_count"] - center_node_count = kwargs["center_node_count"] - center_node_dtype = kwargs["center_node_dtype"] - int_col_id_dtype = kwargs["col_id_dtype"] - int_wholememory_location = kwargs["wholememory_location"] - int_wholememory_type = kwargs["wholememory_type"] - need_center_local_output = kwargs["need_center_local_output"] - need_edge_output = kwargs["need_edge_output"] - - world_rank = wm_comm.get_rank() - world_size = wm_comm.get_size() - - col_id_dtype = int_to_wholememory_datatype(int_col_id_dtype) - wholememory_location = int_to_wholememory_location(int_wholememory_location) - wholememory_type = int_to_wholememory_type(int_wholememory_type) - - wm_csr_row_ptr = wmb.create_wholememory_array( - wmb.WholeMemoryDataType.DtInt64, - graph_node_count + 1, - wm_comm, - wholememory_type, - wholememory_location, - ) - wm_csr_col_ptr = wmb.create_wholememory_array( - col_id_dtype, graph_edge_count, wm_comm, wholememory_type, wholememory_location - ) - copy_host_1D_tensor_to_wholememory( - wm_csr_row_ptr, host_csr_row_ptr, world_rank, world_size, wm_comm - ) - copy_host_1D_tensor_to_wholememory( - wm_csr_col_ptr, host_csr_col_ptr, world_rank, world_size, wm_comm - ) - - wm_comm.barrier() - - center_node_tensor = torch.randint( - 0, graph_node_count, (center_node_count,), dtype=center_node_dtype - ) - center_node_tensor_cuda = center_node_tensor.cuda() - random_seed = random.randint(1, 10000) - - # output_sample_offset_tensor_cuda, - # output_dest_tensor_cuda, - # output_center_localid_tensor_cuda, - # output_edge_gid_tensor_cuda = - # torch.ops.wholegraph.unweighted_sample_without_replacement(wm_csr_row_ptr.get_c_handle(), - # wm_csr_col_ptr.get_c_handle(), - # center_node_tensor_cuda, - # max_sample_count, - # random_seed) - output_sample_offset_tensor = None - output_dest_tensor = None - output_center_localid_tensor = None - output_edge_gid_tensor = None - output_tensors = wg_ops.unweighted_sample_without_replacement( - wm_csr_row_ptr, - wm_csr_col_ptr, - center_node_tensor_cuda, - max_sample_count, - random_seed, - need_center_local_output=need_center_local_output, - need_edge_output=need_edge_output, - ) - output_cpu_tensors = tuple(tensor.cpu() for tensor in output_tensors) - torch_col_id_dtype = torch.int32 - if col_id_dtype == wmb.WholeMemoryDataType.DtInt64: - torch_col_id_dtype = torch.int64 - ( - output_sample_offset_tensor_ref, - output_dest_tensor_ref, - output_center_localid_tensor_ref, - output_edge_gid_tensor_ref, - ) = host_unweighted_sample_without_replacement( - host_csr_row_ptr, - host_csr_col_ptr, - center_node_tensor, - max_sample_count, - torch_col_id_dtype, - random_seed, - ) - - if need_edge_output and need_center_local_output: - ( - output_sample_offset_tensor, - output_dest_tensor, - output_center_localid_tensor, - output_edge_gid_tensor, - ) = output_cpu_tensors - assert torch.equal(output_sample_offset_tensor, output_sample_offset_tensor_ref) - assert torch.equal(output_dest_tensor, output_dest_tensor_ref) - assert torch.equal( - output_center_localid_tensor, output_center_localid_tensor_ref - ) - assert torch.equal(output_edge_gid_tensor, output_edge_gid_tensor_ref) - elif need_center_local_output: - ( - output_sample_offset_tensor, - output_dest_tensor, - output_center_localid_tensor, - ) = output_cpu_tensors - assert torch.equal(output_sample_offset_tensor, output_sample_offset_tensor_ref) - assert torch.equal(output_dest_tensor, output_dest_tensor_ref) - assert torch.equal( - output_center_localid_tensor, output_center_localid_tensor_ref - ) - elif need_edge_output: - ( - output_sample_offset_tensor, - output_dest_tensor, - output_edge_gid_tensor, - ) = output_cpu_tensors - assert torch.equal(output_sample_offset_tensor, output_sample_offset_tensor_ref) - assert torch.equal(output_dest_tensor, output_dest_tensor_ref) - assert torch.equal(output_edge_gid_tensor, output_edge_gid_tensor_ref) - else: - output_sample_offset_tensor, output_dest_tensor = output_cpu_tensors - assert torch.equal(output_sample_offset_tensor, output_sample_offset_tensor_ref) - assert torch.equal(output_dest_tensor, output_dest_tensor_ref) - - wmb.destroy_wholememory_tensor(wm_csr_row_ptr) - wmb.destroy_wholememory_tensor(wm_csr_col_ptr) - - -@pytest.mark.parametrize("graph_node_count", [103]) -@pytest.mark.parametrize("graph_edge_count", [1043]) -@pytest.mark.parametrize("max_sample_count", [11]) -@pytest.mark.parametrize("center_node_count", [13]) -@pytest.mark.parametrize("center_node_dtype", [torch.int32, torch.int64]) -@pytest.mark.parametrize("col_id_dtype", [0, 1]) -@pytest.mark.parametrize("wholememory_location", ([0, 1])) -@pytest.mark.parametrize("wholememory_type", ([0, 1])) -@pytest.mark.parametrize("need_center_local_output", [True, False]) -@pytest.mark.parametrize("need_edge_output", [True, False]) -def test_wholegraph_unweighted_sample( - graph_node_count, - graph_edge_count, - max_sample_count, - center_node_count, - center_node_dtype, - col_id_dtype, - wholememory_location, - wholememory_type, - need_center_local_output, - need_edge_output, -): - gpu_count = wmb.fork_get_gpu_count() - assert gpu_count > 0 - csr_col_dtype = torch.int32 - if col_id_dtype == wmb.WholeMemoryDataType.DtInt64: - csr_col_dtype = torch.int64 - host_csr_row_ptr, host_csr_col_ptr, _ = gen_csr_graph( - graph_node_count, graph_edge_count, csr_col_dtype=csr_col_dtype - ) - routine_func_partial = partial( - routine_func, - host_csr_row_ptr=host_csr_row_ptr, - host_csr_col_ptr=host_csr_col_ptr, - graph_node_count=graph_node_count, - graph_edge_count=graph_edge_count, - max_sample_count=max_sample_count, - center_node_count=center_node_count, - center_node_dtype=center_node_dtype, - col_id_dtype=col_id_dtype, - wholememory_location=wholememory_location, - wholememory_type=wholememory_type, - need_center_local_output=need_center_local_output, - need_edge_output=need_edge_output, - ) - multiprocess_run(gpu_count, routine_func_partial, True) diff --git a/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py b/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py deleted file mode 100644 index 09239ffb7..000000000 --- a/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py +++ /dev/null @@ -1,403 +0,0 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -from pylibwholegraph.utils.multiprocess import multiprocess_run -from pylibwholegraph.torch.initialize import ( - init_torch_env_and_create_wm_comm, - load_wholegraph_op_libraries, -) -import pylibwholegraph.binding.wholememory_binding as wmb -import torch -import random -from functools import partial -from pylibwholegraph.test_utils.test_comm import ( - gen_csr_graph, - copy_host_1D_tensor_to_wholememory, - host_get_sample_offset_tensor, - host_sample_all_neighbors, - int_to_wholememory_datatype, - int_to_wholememory_location, - int_to_wholememory_type, -) -import pylibwholegraph.torch.wholegraph_ops as wg_ops - - -def host_weighted_sample_without_replacement_func( - host_csr_row_ptr, - host_csr_col_ptr, - host_csr_weight_ptr, - center_nodes, - output_sample_offset_tensor, - col_id_dtype, - csr_weight_dtype, - total_sample_count, - max_sample_count, - random_seed, -): - output_dest_tensor = torch.empty((total_sample_count,), dtype=col_id_dtype) - output_center_localid_tensor = torch.empty((total_sample_count,), dtype=torch.int32) - output_edge_gid_tensor = torch.empty((total_sample_count,), dtype=torch.int64) - center_nodes_count = center_nodes.size(0) - block_sizes = [128, 256, 256, 512] - items_per_threads = [4, 4, 8, 8] - fun_idx = int((max_sample_count - 1) / 128) - if fun_idx > 3: - fun_idx = 3 - - block_size = block_sizes[fun_idx] - items_per_thread = items_per_threads[fun_idx] - - for i in range(center_nodes_count): - node_id = center_nodes[i] - start = host_csr_row_ptr[node_id] - end = host_csr_row_ptr[node_id + 1] - neighbor_count = end - start - output_id = output_sample_offset_tensor[i] - gidx = i * block_size - if neighbor_count <= max_sample_count: - for j in range(end - start): - output_dest_tensor[output_id + j] = host_csr_col_ptr[start + j] - output_center_localid_tensor[output_id + j] = i - output_edge_gid_tensor[output_id + j] = start + j - else: - total_neighbor_generated_weights = torch.tensor([], dtype=csr_weight_dtype) - edge_weight_corresponding_ids = torch.tensor([], dtype=col_id_dtype) - for j in range(block_size): - local_gidx = gidx + j - local_edge_weights = torch.empty( - (items_per_thread,), dtype=csr_weight_dtype - ) - generated_edge_weight_count = 0 - for k in range(items_per_thread): - id = k * block_size + j - if id < neighbor_count: - local_edge_weights[k] = host_csr_weight_ptr[start + id] - generated_edge_weight_count += 1 - edge_weight_corresponding_ids = torch.cat( - ( - edge_weight_corresponding_ids, - torch.tensor([id], dtype=col_id_dtype), - ) - ) - random_values = wg_ops.generate_exponential_distribution_negative_float_cpu(random_seed, local_gidx, generated_edge_weight_count) - generated_random_weight = torch.tensor([(1.0/local_edge_weights[i]) * random_values[i] for i in range(generated_edge_weight_count)]) - - total_neighbor_generated_weights = torch.cat( - (total_neighbor_generated_weights, generated_random_weight) - ) - assert total_neighbor_generated_weights.size(0) == neighbor_count - _, sorted_weight_ids = torch.sort( - total_neighbor_generated_weights, descending=True - ) - sorted_top_m_weight_ids = edge_weight_corresponding_ids[ - sorted_weight_ids[0:max_sample_count] - ] - for sample_id in range(max_sample_count): - output_dest_tensor[output_id + sample_id] = host_csr_col_ptr[ - start + sorted_top_m_weight_ids[sample_id] - ] - output_center_localid_tensor[output_id + sample_id] = i - output_edge_gid_tensor[output_id + sample_id] = ( - start + sorted_top_m_weight_ids[sample_id] - ) - return output_dest_tensor, output_center_localid_tensor, output_edge_gid_tensor - - -def host_weighted_sample_without_replacement( - host_csr_row_ptr, - host_csr_col_ptr, - host_csr_weight_ptr, - center_nodes, - max_sample_count, - col_id_dtype, - random_seed, -): - center_nodes_count = center_nodes.size(0) - output_sample_offset_tensor = host_get_sample_offset_tensor( - host_csr_row_ptr, center_nodes, max_sample_count - ) - total_sample_count = output_sample_offset_tensor[center_nodes_count] - - if max_sample_count <= 0: - return host_sample_all_neighbors( - host_csr_row_ptr, - host_csr_col_ptr, - center_nodes, - output_sample_offset_tensor, - col_id_dtype, - total_sample_count, - ) - if max_sample_count > 1024: - raise ValueError( - "invalid host_unweighted_sample_without_replacement test max_sample_count" - ) - - torch_col_id_dtype = torch.int32 - if col_id_dtype == wmb.WholeMemoryDataType.DtInt64: - torch_col_id_dtype = torch.int64 - - ( - output_dest_tensor, - output_center_localid_tensor, - output_edge_gid_tensor, - ) = host_weighted_sample_without_replacement_func( - host_csr_row_ptr, - host_csr_col_ptr, - host_csr_weight_ptr, - center_nodes, - output_sample_offset_tensor, - torch_col_id_dtype, - host_csr_weight_ptr.dtype, - total_sample_count, - max_sample_count, - random_seed, - ) - - return ( - output_sample_offset_tensor, - output_dest_tensor, - output_center_localid_tensor, - output_edge_gid_tensor, - ) - - -def routine_func(world_rank: int, world_size: int, **kwargs): - wm_comm, _ = init_torch_env_and_create_wm_comm( - world_rank, world_size, world_rank, world_size - ) - wm_comm = wm_comm.wmb_comm - load_wholegraph_op_libraries() - host_csr_row_ptr = kwargs["host_csr_row_ptr"] - host_csr_col_ptr = kwargs["host_csr_col_ptr"] - host_csr_weight_ptr = kwargs["host_csr_weight_ptr"] - graph_node_count = kwargs["graph_node_count"] - graph_edge_count = kwargs["graph_edge_count"] - max_sample_count = kwargs["max_sample_count"] - center_node_count = kwargs["center_node_count"] - center_node_dtype = kwargs["center_node_dtype"] - int_col_id_dtype = kwargs["col_id_dtype"] - int_csr_weight_dtype = kwargs["csr_weight_dtype"] - int_wholememory_location = kwargs["wholememory_location"] - int_wholememory_type = kwargs["wholememory_type"] - need_center_local_output = kwargs["need_center_local_output"] - need_edge_output = kwargs["need_edge_output"] - - world_rank = wm_comm.get_rank() - world_size = wm_comm.get_size() - - col_id_dtype = int_to_wholememory_datatype(int_col_id_dtype) - csr_weight_dtype = int_to_wholememory_datatype(int_csr_weight_dtype) - wholememory_location = int_to_wholememory_location(int_wholememory_location) - wholememory_type = int_to_wholememory_type(int_wholememory_type) - - wm_csr_row_ptr = wmb.create_wholememory_array( - wmb.WholeMemoryDataType.DtInt64, - graph_node_count + 1, - wm_comm, - wholememory_type, - wholememory_location, - ) - wm_csr_col_ptr = wmb.create_wholememory_array( - col_id_dtype, graph_edge_count, wm_comm, wholememory_type, wholememory_location - ) - wm_csr_weight_ptr = wmb.create_wholememory_array( - csr_weight_dtype, - graph_edge_count, - wm_comm, - wholememory_type, - wholememory_location, - ) - - copy_host_1D_tensor_to_wholememory( - wm_csr_row_ptr, host_csr_row_ptr, world_rank, world_size, wm_comm - ) - copy_host_1D_tensor_to_wholememory( - wm_csr_col_ptr, host_csr_col_ptr, world_rank, world_size, wm_comm - ) - copy_host_1D_tensor_to_wholememory( - wm_csr_weight_ptr, host_csr_weight_ptr, world_rank, world_size, wm_comm - ) - - wm_comm.barrier() - - center_node_tensor = torch.randint( - 0, graph_node_count, (center_node_count,), dtype=center_node_dtype - ) - center_node_tensor_cuda = center_node_tensor.cuda() - random_seed = random.randint(1, 10000) - - # output_sample_offset_tensor_cuda, - # output_dest_tensor_cuda, - # output_center_localid_tensor_cuda, - # output_edge_gid_tensor_cuda = - # torch.ops.wholegraph.weighted_sample_without_replacement(wm_csr_row_ptr.get_c_handle(), - # wm_csr_col_ptr.get_c_handle(), - # wm_csr_weight_ptr.get_c_handle(), - # center_node_tensor_cuda, - # max_sample_count, - # random_seed) - output_sample_offset_tensor = None - output_dest_tensor = None - output_center_localid_tensor = None - output_edge_gid_tensor = None - - output_tensors = wg_ops.weighted_sample_without_replacement( - wm_csr_row_ptr, - wm_csr_col_ptr, - wm_csr_weight_ptr, - center_node_tensor_cuda, - max_sample_count, - random_seed, - need_center_local_output=need_center_local_output, - need_edge_output=need_edge_output, - ) - output_cpu_tensors = tuple(tensor.cpu() for tensor in output_tensors) - if need_edge_output and need_center_local_output: - ( - output_sample_offset_tensor, - output_dest_tensor, - output_center_localid_tensor, - output_edge_gid_tensor, - ) = output_cpu_tensors - elif need_center_local_output: - ( - output_sample_offset_tensor, - output_dest_tensor, - output_center_localid_tensor, - ) = output_cpu_tensors - elif need_edge_output: - ( - output_sample_offset_tensor, - output_dest_tensor, - output_edge_gid_tensor, - ) = output_cpu_tensors - else: - output_sample_offset_tensor, output_dest_tensor = output_cpu_tensors - - ( - output_sample_offset_tensor_ref, - output_dest_tensor_ref, - output_center_localid_tensor_ref, - output_edge_gid_tensor_ref, - ) = host_weighted_sample_without_replacement( - host_csr_row_ptr, - host_csr_col_ptr, - host_csr_weight_ptr, - center_node_tensor, - max_sample_count, - col_id_dtype, - random_seed, - ) - - assert torch.equal(output_sample_offset_tensor, output_sample_offset_tensor_ref) - - for i in range(center_node_count): - start = output_sample_offset_tensor[i] - end = output_sample_offset_tensor[i + 1] - output_dest_tensor[start:end], sorted_ids = torch.sort( - output_dest_tensor[start:end] - ) - - output_dest_tensor_ref[start:end], ref_sorted_ids = torch.sort( - output_dest_tensor_ref[start:end] - ) - output_center_localid_tensor_ref[start:end] = output_center_localid_tensor_ref[ - start:end - ][ref_sorted_ids] - output_edge_gid_tensor_ref[start:end] = output_edge_gid_tensor_ref[start:end][ - ref_sorted_ids - ] - if need_edge_output and need_center_local_output: - output_center_localid_tensor[start:end] = output_center_localid_tensor[ - start:end - ][sorted_ids] - output_edge_gid_tensor[start:end] = output_edge_gid_tensor[start:end][ - sorted_ids - ] - elif need_center_local_output: - output_center_localid_tensor[start:end] = output_center_localid_tensor[ - start:end - ][sorted_ids] - elif need_edge_output: - output_edge_gid_tensor[start:end] = output_edge_gid_tensor[start:end][ - sorted_ids - ] - - assert torch.equal(output_dest_tensor, output_dest_tensor_ref) - if need_edge_output and need_center_local_output: - assert torch.equal( - output_center_localid_tensor, output_center_localid_tensor_ref - ) - assert torch.equal(output_edge_gid_tensor, output_edge_gid_tensor_ref) - elif need_center_local_output: - assert torch.equal( - output_center_localid_tensor, output_center_localid_tensor_ref - ) - elif need_edge_output: - assert torch.equal(output_edge_gid_tensor, output_edge_gid_tensor_ref) - - wmb.destroy_wholememory_tensor(wm_csr_row_ptr) - wmb.destroy_wholememory_tensor(wm_csr_col_ptr) - - -@pytest.mark.parametrize("graph_node_count", [113]) -@pytest.mark.parametrize("graph_edge_count", [1043]) -@pytest.mark.parametrize("max_sample_count", [11]) -@pytest.mark.parametrize("center_node_count", [13]) -@pytest.mark.parametrize("center_node_dtype", [torch.int32, torch.int64]) -@pytest.mark.parametrize("col_id_dtype", [0, 1]) -@pytest.mark.parametrize("csr_weight_dtype", [2, 3]) -@pytest.mark.parametrize("wholememory_location", ([0, 1])) -@pytest.mark.parametrize("wholememory_type", ([0, 1])) -@pytest.mark.parametrize("need_center_local_output", [True, False]) -@pytest.mark.parametrize("need_edge_output", [True, False]) -def test_wholegraph_weighted_sample( - graph_node_count, - graph_edge_count, - max_sample_count, - center_node_count, - center_node_dtype, - col_id_dtype, - csr_weight_dtype, - wholememory_location, - wholememory_type, - need_center_local_output, - need_edge_output, -): - gpu_count = wmb.fork_get_gpu_count() - assert gpu_count > 0 - csr_col_dtype = torch.int32 - if col_id_dtype == 1: - csr_col_dtype = torch.int64 - host_csr_row_ptr, host_csr_col_ptr, host_csr_weight_ptr = gen_csr_graph( - graph_node_count, graph_edge_count, csr_col_dtype=csr_col_dtype - ) - routine_func_partial = partial( - routine_func, - host_csr_row_ptr=host_csr_row_ptr, - host_csr_col_ptr=host_csr_col_ptr, - host_csr_weight_ptr=host_csr_weight_ptr, - graph_node_count=graph_node_count, - graph_edge_count=graph_edge_count, - max_sample_count=max_sample_count, - center_node_count=center_node_count, - center_node_dtype=center_node_dtype, - col_id_dtype=col_id_dtype, - csr_weight_dtype=csr_weight_dtype, - wholememory_location=wholememory_location, - wholememory_type=wholememory_type, - need_center_local_output=need_center_local_output, - need_edge_output=need_edge_output, - ) - multiprocess_run(gpu_count, routine_func_partial, True) diff --git a/pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py b/pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py deleted file mode 100644 index 8e315c985..000000000 --- a/pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) 2019-2023, NVIDIA CORPORATION. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import pylibwholegraph.binding.wholememory_binding as wmb -from .wholegraph_env import ( - get_stream, - TorchMemoryContext, - get_wholegraph_env_fns, - wrap_torch_tensor, -) -from typing import Union -import random - - -def unweighted_sample_without_replacement( - wm_csr_row_ptr_tensor: wmb.PyWholeMemoryTensor, - wm_csr_col_ptr_tensor: wmb.PyWholeMemoryTensor, - center_nodes_tensor: torch.Tensor, - max_sample_count: int, - random_seed: Union[int, None] = None, - need_center_local_output: bool = False, - need_edge_output: bool = False, -): - assert wm_csr_row_ptr_tensor.dim() == 1 - assert wm_csr_col_ptr_tensor.dim() == 1 - assert center_nodes_tensor.dim() == 1 - if random_seed is None: - random_seed = random.getrandbits(64) - output_sample_offset_tensor = torch.empty( - center_nodes_tensor.shape[0] + 1, device="cuda", dtype=torch.int - ) - output_dest_context = TorchMemoryContext() - output_dest_tensor_id = id(output_dest_context) - output_center_localid_context = None - output_center_localid_tensor_id = 0 - output_edge_gid_context = None - output_edge_gid_tensor_id = 0 - if need_center_local_output: - output_center_localid_context = TorchMemoryContext() - output_center_localid_tensor_id = id(output_center_localid_context) - if need_edge_output: - output_edge_gid_context = TorchMemoryContext() - output_edge_gid_tensor_id = id(output_edge_gid_context) - wmb.csr_unweighted_sample_without_replacement( - wm_csr_row_ptr_tensor, - wm_csr_col_ptr_tensor, - wrap_torch_tensor(center_nodes_tensor), - max_sample_count, - wrap_torch_tensor(output_sample_offset_tensor), - output_dest_tensor_id, - output_center_localid_tensor_id, - output_edge_gid_tensor_id, - random_seed, - get_wholegraph_env_fns(), - get_stream(), - ) - if need_edge_output and need_center_local_output: - return ( - output_sample_offset_tensor, - output_dest_context.get_tensor(), - output_center_localid_context.get_tensor(), - output_edge_gid_context.get_tensor(), - ) - elif need_center_local_output: - return ( - output_sample_offset_tensor, - output_dest_context.get_tensor(), - output_center_localid_context.get_tensor(), - ) - elif need_edge_output: - return ( - output_sample_offset_tensor, - output_dest_context.get_tensor(), - output_edge_gid_context.get_tensor(), - ) - else: - return output_sample_offset_tensor, output_dest_context.get_tensor() - - -def weighted_sample_without_replacement( - wm_csr_row_ptr_tensor: wmb.PyWholeMemoryTensor, - wm_csr_col_ptr_tensor: wmb.PyWholeMemoryTensor, - wm_csr_weight_ptr_tensor: wmb.PyWholeMemoryTensor, - center_nodes_tensor: torch.Tensor, - max_sample_count: int, - random_seed: Union[int, None] = None, - need_center_local_output: bool = False, - need_edge_output: bool = False, -): - assert wm_csr_row_ptr_tensor.dim() == 1 - assert wm_csr_col_ptr_tensor.dim() == 1 - assert wm_csr_weight_ptr_tensor.dim() == 1 - assert wm_csr_weight_ptr_tensor.shape[0] == wm_csr_col_ptr_tensor.shape[0] - assert center_nodes_tensor.dim() == 1 - if random_seed is None: - random_seed = random.getrandbits(64) - output_sample_offset_tensor = torch.empty( - center_nodes_tensor.shape[0] + 1, device="cuda", dtype=torch.int - ) - output_dest_context = TorchMemoryContext() - output_dest_tensor_id = id(output_dest_context) - output_center_localid_context = None - output_center_localid_tensor_id = 0 - output_edge_gid_context = None - output_edge_gid_tensor_id = 0 - if need_center_local_output: - output_center_localid_context = TorchMemoryContext() - output_center_localid_tensor_id = id(output_center_localid_context) - if need_edge_output: - output_edge_gid_context = TorchMemoryContext() - output_edge_gid_tensor_id = id(output_edge_gid_context) - wmb.csr_weighted_sample_without_replacement( - wm_csr_row_ptr_tensor, - wm_csr_col_ptr_tensor, - wm_csr_weight_ptr_tensor, - wrap_torch_tensor(center_nodes_tensor), - max_sample_count, - wrap_torch_tensor(output_sample_offset_tensor), - output_dest_tensor_id, - output_center_localid_tensor_id, - output_edge_gid_tensor_id, - random_seed, - get_wholegraph_env_fns(), - get_stream(), - ) - if need_edge_output and need_center_local_output: - return ( - output_sample_offset_tensor, - output_dest_context.get_tensor(), - output_center_localid_context.get_tensor(), - output_edge_gid_context.get_tensor(), - ) - elif need_center_local_output: - return ( - output_sample_offset_tensor, - output_dest_context.get_tensor(), - output_center_localid_context.get_tensor(), - ) - elif need_edge_output: - return ( - output_sample_offset_tensor, - output_dest_context.get_tensor(), - output_edge_gid_context.get_tensor(), - ) - else: - return output_sample_offset_tensor, output_dest_context.get_tensor() - - -def generate_random_positive_int_cpu(random_seed, - sub_sequence, - output_random_value_count): - output = torch.empty((output_random_value_count,), dtype=torch.int) - wmb.host_generate_random_positive_int(random_seed, sub_sequence, wrap_torch_tensor(output)) - return output - -def generate_exponential_distribution_negative_float_cpu(random_seed: int, - sub_sequence: int, - output_random_value_count: int): - output = torch.empty((output_random_value_count,), dtype = torch.float) - wmb.host_generate_exponential_distribution_negative_float(random_seed, sub_sequence, wrap_torch_tensor(output)) - return output diff --git a/python/pylibwholegraph/CMakeLists.txt b/python/pylibwholegraph/CMakeLists.txt index fb86e2d21..347c0d95b 100644 --- a/python/pylibwholegraph/CMakeLists.txt +++ b/python/pylibwholegraph/CMakeLists.txt @@ -152,41 +152,10 @@ message(STATUS "PYLIBWHOLEGRAPH: CXX_DEFINITIONS='${CXX_DEFINITIONS}'") ############################################################################## # - Variables ---------------------------------------------------------------- -#set(WHOLEGRAPH_PY_TARGET "pylibwholegraph_ext" CACHE STRING "wholegraph nanobind target name") - set(WHOLEGRAPH_CPP_TARGET "wholegraph::wholegraph" CACHE STRING "libwholegraph target name") -############################################################################## -# - nanobind targets --------------------------------------------------------- - -# Build the actual extension module -#nanobind_add_module( -# ${WHOLEGRAPH_PY_TARGET} -# NB_STATIC # Build static libnanobind (the extension module itself remains a shared library) -# ${CMAKE_CURRENT_LIST_DIR}/cpp/pylibwholegraph_ext.cpp -# ${CMAKE_CURRENT_LIST_DIR}/cpp/wholegraph_types.cpp -# ${CMAKE_CURRENT_LIST_DIR}/cpp/wholegraph_functions.cpp -#) -# -## this adds includes from the C++ target as well -#target_link_libraries(${WHOLEGRAPH_PY_TARGET} -# PUBLIC -# ${WHOLEGRAPH_CPP_TARGET} -# ) - -############################################################################## -# - Install Targets --------------------------------------------------------- - -#install( -# TARGETS ${WHOLEGRAPH_PY_TARGET} -# DESTINATION ${CMAKE_INSTALL_PREFIX}/${PYTHON_RELATIVE_SITE_PACKAGES_DIR}) - add_subdirectory(pylibwholegraph/binding) # when used without setup.py, command is like: # export LIBWHOLEGRAPH_DIR=`pwd`/../../cpp/build/install # cmake ../ -DSKBUILD=ON - -if (BUILD_OPS_WITH_TORCH_C10_API) - add_subdirectory(wholegraph_torch) -endif() diff --git a/python/pylibwholegraph/examples/node_classfication.py b/python/pylibwholegraph/examples/node_classfication.py index 7abd97a76..5f960f52a 100644 --- a/python/pylibwholegraph/examples/node_classfication.py +++ b/python/pylibwholegraph/examples/node_classfication.py @@ -131,6 +131,9 @@ def main_func(): wgth.get_local_size(), ) + if options.use_cpp_ext: + wgth.compile_cpp_extension() + train_ds, valid_ds, test_ds = wgth.create_node_claffication_datasets( options.pickle_data_path ) diff --git a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx index 30e1ee951..feddfaff4 100644 --- a/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx +++ b/python/pylibwholegraph/pylibwholegraph/binding/wholememory_binding.pyx @@ -1791,6 +1791,16 @@ cdef extern from "wholememory/wholegraph_op.h": wholememory_env_func_t * p_env_fns, void * stream) + cdef wholememory_error_code_t generate_random_positive_int_cpu( + int64_t random_seed, + int64_t subsequence, + wholememory_tensor_t output) + + cdef wholememory_error_code_t generate_exponential_distribution_negative_float_cpu( + int64_t random_seed, + int64_t subsequence, + wholememory_tensor_t output) + cpdef void csr_unweighted_sample_without_replacement( PyWholeMemoryTensor wm_csr_row_ptr_tensor, PyWholeMemoryTensor wm_csr_col_ptr_tensor, @@ -1845,6 +1855,29 @@ cpdef void csr_weighted_sample_without_replacement( p_env_fns_int, stream_int)) +cpdef void host_generate_random_positive_int( + int64_t random_seed, + int64_t subsequence, + WrappedLocalTensor output +): + check_wholememory_error_code(generate_random_positive_int_cpu( + random_seed, + subsequence, + output.get_c_handle() + )) + + +cpdef void host_generate_exponential_distribution_negative_float( + int64_t random_seed, + int64_t subsequence, + WrappedLocalTensor output +): + check_wholememory_error_code(generate_exponential_distribution_negative_float_cpu( + random_seed, + subsequence, + output.get_c_handle() + )) + cdef extern from "wholememory/graph_op.h": cdef wholememory_error_code_t graph_append_unique(wholememory_tensor_t target_nodes_tensor, diff --git a/python/pylibwholegraph/pylibwholegraph/test_utils/__init__.py b/python/pylibwholegraph/pylibwholegraph/test_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py index 4544607ad..828033598 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/pylibwholegraph/test_wholememory_io.py @@ -183,7 +183,7 @@ def store_routine_func( embedding_stride, storage_offset, ): - (wm_comm,) = init_torch_env_and_create_wm_comm( + (wm_comm, _) = init_torch_env_and_create_wm_comm( world_rank, world_size, world_rank, world_size ) wm_comm = wm_comm.wmb_comm diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_add_csr_self_loop.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_add_csr_self_loop.py index 288603aa6..925ac2d8b 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_add_csr_self_loop.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_add_csr_self_loop.py @@ -12,7 +12,6 @@ # limitations under the License. import pytest -from pylibwholegraph.torch.initialize import load_wholegraph_op_libraries import torch from pylibwholegraph.test_utils.test_comm import gen_csr_graph import pylibwholegraph.torch.graph_ops as wg_ops @@ -39,7 +38,6 @@ def host_add_csr_self_loop(csr_row_ptr_tensor, csr_col_ptr_tensor): def routine_func(**kwargs): - load_wholegraph_op_libraries() target_node_count = kwargs["target_node_count"] neighbor_node_count = kwargs["neighbor_node_count"] edge_num = kwargs["edge_num"] diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_append_unique.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_append_unique.py index a49b18d8d..a188dda01 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_append_unique.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_graph_append_unique.py @@ -12,7 +12,6 @@ # limitations under the License. import pytest -from pylibwholegraph.torch.initialize import load_wholegraph_op_libraries import torch import pylibwholegraph.torch.graph_ops as wg_ops @@ -30,7 +29,6 @@ def host_neighbor_raw_to_unique(unique_node_tensor, neighbor_node_tensor): def routine_func(**kwargs): - load_wholegraph_op_libraries() target_node_count = kwargs["target_node_count"] neighbor_node_count = kwargs["neighbor_node_count"] target_node_dtype = kwargs["target_node_dtype"] diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py index 88d5a74d3..ae4d8381e 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_unweighted_sample_without_replacement.py @@ -14,10 +14,7 @@ import pytest import pylibwholegraph.binding.wholememory_binding as wmb from pylibwholegraph.utils.multiprocess import multiprocess_run -from pylibwholegraph.torch.initialize import ( - init_torch_env_and_create_wm_comm, - load_wholegraph_op_libraries, -) +from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm import torch from functools import partial from pylibwholegraph.test_utils.test_comm import ( @@ -148,7 +145,7 @@ def host_unweighted_sample_without_replacement_func( random_values = torch.empty((N,), dtype=torch.int32) for j in range(block_threads): local_gidx = gidx + j - random_nums = torch.ops.wholegraph_test.raft_pcg_generator_random( + random_nums = wg_ops.generate_random_positive_int_cpu( random_seed, local_gidx, items_per_thread ) for k in range(items_per_thread): @@ -228,7 +225,6 @@ def routine_func(world_rank: int, world_size: int, **kwargs): world_rank, world_size, world_rank, world_size ) wm_comm = wm_comm.wmb_comm - load_wholegraph_op_libraries() host_csr_row_ptr = kwargs["host_csr_row_ptr"] host_csr_col_ptr = kwargs["host_csr_col_ptr"] graph_node_count = kwargs["graph_node_count"] diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py index 3f59dcbb0..c2369b475 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholegraph_weighted_sample_without_replacement.py @@ -13,10 +13,7 @@ import pytest from pylibwholegraph.utils.multiprocess import multiprocess_run -from pylibwholegraph.torch.initialize import ( - init_torch_env_and_create_wm_comm, - load_wholegraph_op_libraries, -) +from pylibwholegraph.torch.initialize import init_torch_env_and_create_wm_comm import pylibwholegraph.binding.wholememory_binding as wmb import torch import random @@ -90,14 +87,18 @@ def host_weighted_sample_without_replacement_func( torch.tensor([id], dtype=col_id_dtype), ) ) - generated_random_weight = ( - torch.ops.wholegraph_test.raft_pcg_generator_random_from_weight( - random_seed, - local_gidx, - local_edge_weights, - generated_edge_weight_count, + random_values = ( + wg_ops.generate_exponential_distribution_negative_float_cpu( + random_seed, local_gidx, generated_edge_weight_count ) ) + generated_random_weight = torch.tensor( + [ + (1.0 / local_edge_weights[i]) * random_values[i] + for i in range(generated_edge_weight_count) + ] + ) + total_neighbor_generated_weights = torch.cat( (total_neighbor_generated_weights, generated_random_weight) ) @@ -182,7 +183,6 @@ def routine_func(world_rank: int, world_size: int, **kwargs): world_rank, world_size, world_rank, world_size ) wm_comm = wm_comm.wmb_comm - load_wholegraph_op_libraries() host_csr_row_ptr = kwargs["host_csr_row_ptr"] host_csr_col_ptr = kwargs["host_csr_col_ptr"] host_csr_weight_ptr = kwargs["host_csr_weight_ptr"] diff --git a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholememory_cython_binding.py b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholememory_cython_binding.py index 2dac3c3c3..a247fdad0 100644 --- a/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholememory_cython_binding.py +++ b/python/pylibwholegraph/pylibwholegraph/tests/wholegraph_torch/ops/test_wholememory_cython_binding.py @@ -48,9 +48,9 @@ def test_smoke(): wmb.wholememory_env_test_cython_op( wrapped_input, wrapped_output, - id(output_device_context), - id(output_pinned_context), - id(output_host_context), + output_device_context.get_c_context(), + output_pinned_context.get_c_context(), + output_host_context.get_c_context(), output_len, env_func_int_ptr, stream_int_ptr, @@ -81,9 +81,9 @@ def test_loop_memory(): wmb.wholememory_env_test_cython_op( wrapped_input, wrapped_output, - id(output_device_context), - id(output_pinned_context), - id(output_host_context), + output_device_context.get_c_context(), + output_pinned_context.get_c_context(), + output_host_context.get_c_context(), output_len, env_func_int_ptr, stream_int_ptr, @@ -101,7 +101,7 @@ def test_loop_memory(): wmb.wholememory_env_test_cython_op( wrapped_input, wrapped_output, - id(output_device_context), + output_device_context.get_c_context(), 0, 0, output_len, @@ -136,9 +136,9 @@ def test_random_alloc(output_len, embed_dim): wmb.wholememory_env_test_cython_op( wrapped_input, wrapped_output, - id(output_device_context), - id(output_pinned_context), - id(output_host_context), + output_device_context.get_c_context(), + output_pinned_context.get_c_context(), + output_host_context.get_c_context(), output_len, env_func_int_ptr, stream_int_ptr, diff --git a/python/pylibwholegraph/pylibwholegraph/torch/__init__.py b/python/pylibwholegraph/pylibwholegraph/torch/__init__.py index edef113d9..e354bd262 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/__init__.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/__init__.py @@ -73,3 +73,4 @@ get_train_dataloader, get_valid_test_dataloader, ) +from .wholegraph_env import compile_cpp_extension diff --git a/python/pylibwholegraph/pylibwholegraph/torch/comm.py b/python/pylibwholegraph/pylibwholegraph/torch/comm.py index ed33fe30f..8fb0500dc 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/comm.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/comm.py @@ -28,6 +28,16 @@ def set_world_info(world_rank: int, world_size: int, local_rank: int, local_size: int): + """ + Set the global world's information. This is used for create common used communicators, like local node communicator, + global communicator, or local device communicator. + + :param world_rank: world rank of current process. + :param world_size: world size + :param local_rank: local rank of current process in current machine node. + :param local_size: local size of each machine node + :return: None + """ global all_comm_world_rank, all_comm_world_size, all_comm_local_rank, all_comm_local_size all_comm_world_rank = world_rank all_comm_world_size = world_size @@ -36,19 +46,31 @@ def set_world_info(world_rank: int, world_size: int, local_rank: int, local_size class WholeMemoryCommunicator(object): - r"""WholeMemory Communicator""" + """ + WholeMemory Communicator. + You should not create object of this class directly, use create_group_communicator, get_global_communicator, + get_local_node_communicator or get_local_device_communicator instead. + """ def __init__(self, wmb_comm: wmb.PyWholeMemoryComm): super().__init__() self.wmb_comm = wmb_comm def get_rank(self): + """Get rank of current process in this communicator""" return self.wmb_comm.get_rank() def get_size(self): + """Get world size of this communicator""" return self.wmb_comm.get_size() def barrier(self): + """ + Barrier on WholeMemory Communicator. + This function will use internal communicator associated CUDA stream. And synchronized with host. + So if you have work in other CUDA stream, and expect that to be done before barrier, you may need to + synchrionze that stream before calling this function. + """ return self.wmb_comm.barrier() def destroy(self): @@ -57,7 +79,7 @@ def destroy(self): def create_group_communicator(group_size: int = -1, comm_stride: int = 1): - r"""Create WholeMemory Communicator. + """Create WholeMemory Communicator. For example: 24 ranks with group_size = 4 and comm_stride = 2 will create following groups: [0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15], [16, 18, 20, 22], [17, 19, 21, 23] :param group_size: Size of each group, -1 means to use all ranks in just one single group. @@ -95,12 +117,21 @@ def create_group_communicator(group_size: int = -1, comm_stride: int = 1): def destroy_communicator(wm_comm: WholeMemoryCommunicator): + """ + Destroy WholeMemoryCommunicator + :param wm_comm: WholeMemoryCommunicator to destroy + :return: None + """ if wm_comm is not None and wm_comm.wmb_comm is not None: wmb.destroy_communicator(wm_comm.wmb_comm) wm_comm.wmb_comm = None def get_global_communicator(): + """ + Get the global communicator of this job + :return: WholeMemoryCommunicator that has all GPUs in it. + """ global global_communicator, local_node_communicator, local_device_communicator global all_comm_local_size, all_comm_world_size if global_communicator is None: @@ -115,6 +146,10 @@ def get_global_communicator(): def get_local_node_communicator(): + """ + Get the local node communicator of this job + :return: WholeMemoryCommunicator that has GPUs in the same node. + """ global global_communicator, local_node_communicator, local_device_communicator global all_comm_local_size, all_comm_world_size if local_node_communicator is None: @@ -129,6 +164,10 @@ def get_local_node_communicator(): def get_local_device_communicator(): + """ + Get the local device communicator of this job + :return: WholeMemoryCommunicator that has only the GPU belonging to current process. + """ global global_communicator, local_node_communicator, local_device_communicator global all_comm_local_size, all_comm_world_size if local_device_communicator is None: diff --git a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py index f7fcef35d..e79aaa122 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/common_options.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/common_options.py @@ -48,6 +48,13 @@ def add_training_options(parser: OptionParser): default=0.5, help="cache ratio", ) + parser.add_option( + "--use-cpp-ext", + action="store_true", + dest="use_cpp_ext", + default=False, + help="Whether to use cpp extension for pytorch" + ) parser.add_option( "--train-embedding", action="store_true", diff --git a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py index 732485b26..f4cfa0550 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/embedding.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/embedding.py @@ -31,6 +31,12 @@ class WholeMemoryOptimizer(object): + """ + Sparse Optimizer for WholeMemoryEmbedding. + Many WholeMemoryEmbedding can share same WholeMemoryOptimizer + You should not create WholeMemoryOptimizer object directly, but use :func:`create_wholememory_optimizer` instead. + """ + def __init__(self, global_comm: WholeMemoryCommunicator): super().__init__() self.wmb_opt = wmb.WholeMemoryOptimizer() @@ -38,7 +44,7 @@ def __init__(self, global_comm: WholeMemoryCommunicator): self.global_comm = global_comm def add_embedding(self, wm_embedding): - r"""Add WholeMemory Embedding to this optimizer + """Add WholeMemory Embedding to this optimizer NOTE: you don't need to call this method, it is automatic called when WholeMemory Embedding is created. :param wm_embedding: WholeMemory Embedding that use this optimizer :return: None @@ -56,6 +62,12 @@ def step(self, lr: float): def create_wholememory_optimizer(optimizer_type: str, param_dict: dict): + """ + Create WholeMemoryOptimizer. + :param optimizer_type: Type of the Optimizer + :param param_dict: parameters of the optimizer + :return: WholeMemoryOptimizer + """ wm_optimizer = WholeMemoryOptimizer(get_global_communicator()) wm_optimizer.wmb_opt.create_optimizer( str_to_wmb_wholememory_optimizer_type(optimizer_type), param_dict @@ -64,11 +76,22 @@ def create_wholememory_optimizer(optimizer_type: str, param_dict: dict): def destroy_wholememory_optimizer(optimizer: WholeMemoryOptimizer): + """ + Destroy WholeMemoryOptimizer + :param optimizer: WholeMemoryOptimizer to destroy + :return: None + """ optimizer.wmb_opt.destroy_optimizer() optimizer.wmb_opt = None class WholeMemoryCachePolicy(object): + """ + Cache policy to create WholeMemoryEmbedding. + NOTE: You should not create WholeMemoryCachePolicy object directly, + use :func:`create_wholememory_cache_policy` instead. + """ + def __init__(self, wmb_cache_policy: wmb.WholeMemoryCachePolicy): super().__init__() self.wmb_cache_policy = wmb_cache_policy diff --git a/python/pylibwholegraph/pylibwholegraph/torch/graph_ops.py b/python/pylibwholegraph/pylibwholegraph/torch/graph_ops.py index f479f6510..464b6db0e 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/graph_ops.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/graph_ops.py @@ -26,13 +26,23 @@ def append_unique( neighbor_node_tensor: torch.Tensor, need_neighbor_raw_to_unique: bool = False, ): + """ + Append neighbor_node_tenosr to target_node_tensor, keep target_node_tensor unchanged and do unique + e.g. if target_node_tensor is [3, 11, 2, 10], neighbor_node_tensor is [4, 5, 2, 11, 6, 9, 10, 5], + output_unique_node may be [3, 11, 2, 10, 6, 4, 9, 5], order of 6, 4, 9, 5 may change. + neighbor_raw_to_unique_mapping will be [5, 7, 2, 1, 4, 6, 3, 7] + :param target_node_tensor: target node tensor + :param neighbor_node_tensor: neighbor node tensor + :param need_neighbor_raw_to_unique: if need to output neighbor_raw_to_unique_mapping + :return: output_unique_node and neighbor_raw_to_unique_mapping + """ assert target_node_tensor.dim() == 1 assert neighbor_node_tensor.dim() == 1 assert target_node_tensor.is_cuda assert neighbor_node_tensor.is_cuda output_unique_node_context = TorchMemoryContext() - output_unique_node_tensor_id = id(output_unique_node_context) + output_unique_node_c_context = output_unique_node_context.get_c_context() output_neighbor_raw_to_unique_mapping_tensor = None if need_neighbor_raw_to_unique: output_neighbor_raw_to_unique_mapping_tensor = torch.empty( @@ -42,7 +52,7 @@ def append_unique( wmb.append_unique( wrap_torch_tensor(target_node_tensor), wrap_torch_tensor(neighbor_node_tensor), - output_unique_node_tensor_id, + output_unique_node_c_context, wrap_torch_tensor(output_neighbor_raw_to_unique_mapping_tensor), get_wholegraph_env_fns(), get_stream(), @@ -59,6 +69,13 @@ def append_unique( def add_csr_self_loop( csr_row_ptr_tensor: torch.Tensor, csr_col_ptr_tensor: torch.Tensor ): + """ + Add self loop to sampled CSR graph + NOTE: this function will not check if there is already self loop in the raw CSR graph. + :param csr_row_ptr_tensor: CSR row pointer tensor + :param csr_col_ptr_tensor: CSR column index tensor + :return: CSR graph added self loop + """ assert csr_row_ptr_tensor.dim() == 1 assert csr_col_ptr_tensor.dim() == 1 assert csr_row_ptr_tensor.is_cuda diff --git a/python/pylibwholegraph/pylibwholegraph/torch/graph_structure.py b/python/pylibwholegraph/pylibwholegraph/torch/graph_structure.py index 0028571eb..241ad4012 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/graph_structure.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/graph_structure.py @@ -21,6 +21,7 @@ class GraphStructure(object): r"""Graph structure storage Actually, it is the graph structure of one relation, represented in CSR format. + It contains CSR representation of Graph structure, and also attributes associated with nodes and edges. """ def __init__(self): @@ -35,6 +36,12 @@ def __init__(self): def set_csr_graph( self, csr_row_ptr: WholeMemoryTensor, csr_col_ind: WholeMemoryTensor ): + """ + Set the CSR graph structure + :param csr_row_ptr: CSR graph row pointer + :param csr_col_ind: CSR graph column index + :return: None + """ assert csr_row_ptr.dim() == 1 assert csr_row_ptr.dtype == torch.int64 assert csr_row_ptr.shape[0] > 1 @@ -46,28 +53,50 @@ def set_csr_graph( self.csr_col_ind = csr_col_ind def set_node_attribute(self, attr_name: str, attr_tensor: WholeMemoryTensor): + """ + Set attribute for node + :param attr_name: attribute name for node + :param attr_tensor: attribute tensor + :return: None + """ assert attr_name not in self.node_attributes assert attr_tensor.shape[0] == self.node_count self.node_attributes[attr_name] = attr_tensor def set_edge_attribute(self, attr_name: str, attr_tensor: WholeMemoryTensor): + """ + Set attribute for edge + :param attr_name: attribute name for edge + :param attr_tensor: attribute tensor + :return: None + """ assert attr_name not in self.edge_attributes assert attr_tensor.shape[0] == self.edge_count self.edge_attributes[attr_name] = attr_tensor def unweighted_sample_without_replacement_one_hop( self, - centor_nodes_tensor: torch.Tensor, + center_nodes_tensor: torch.Tensor, max_sample_count: int, *, random_seed: Union[int, None] = None, need_center_local_output: bool = False, need_edge_output: bool = False ): + """ + Unweighted Sample without replacement on CSR graph structure + :param center_nodes_tensor: center node ids + :param max_sample_count: max sample count for each center node + :param random_seed: random seed for the sampler + :param need_center_local_output: If True, output a tensor same length as sampled nodes but each element is the + center node index in center_nodes_tensor. + :param need_edge_output: If True, output the edge index of each sampled node + :return: csr_row_ptr, sampled_nodes[, center_node_local_id, edge_index] + """ return wholegraph_ops.unweighted_sample_without_replacement( self.csr_row_ptr.wmb_tensor, self.csr_col_ind.wmb_tensor, - centor_nodes_tensor, + center_nodes_tensor, max_sample_count, random_seed, need_center_local_output, @@ -84,6 +113,17 @@ def weighted_sample_without_replacement_one_hop( need_center_local_output: bool = False, need_edge_output: bool = False ): + """ + Weighted Sample without replacement on CSR graph structure with edge weights attribute + :param weight_name: edge attribute name for weight + :param center_nodes_tensor: center node ids + :param max_sample_count: max sample count for each center node + :param random_seed: random seed for the sampler + :param need_center_local_output: If True, output a tensor same length as sampled nodes but each element is the + center node index in center_nodes_tensor. + :param need_edge_output: If True, output the edge index of each sampled node + :return: csr_row_ptr, sampled_nodes[, center_node_local_id, edge_index] + """ assert weight_name in self.edge_attributes weight_tensor = self.edge_attributes[weight_name] return wholegraph_ops.weighted_sample_without_replacement( @@ -103,6 +143,13 @@ def multilayer_sample_without_replacement( max_neighbors: List[int], weight_name: Union[str, None] = None, ): + """ + Multilayer sample without replacement + :param node_ids: initial node ids + :param max_neighbors: maximum neighbor for each layer + :param weight_name: edge attribute name for weight, if None, use unweighted sample + :return: target_gids, edge_indice, csr_row_ptr, csr_col_ind + """ hops = len(max_neighbors) edge_indice = [None] * hops csr_row_ptr = [None] * hops diff --git a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py index 106066634..4dc35c3b1 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/initialize.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/initialize.py @@ -68,16 +68,3 @@ def finalize(): :return: None """ wmb.finalize() - - -def load_wholegraph_op_libraries(): - cxx11abi = torch.torch.compiled_with_cxx11_abi() - if cxx11abi is True: - lib_path = "wholegraph_torch/libwholegraph_torch_cxx11abi.so" - else: - lib_path = "wholegraph_torch/libwholegraph_torch_precxx11abi.so" - torch.ops.load_library(lib_path) - - -def jit_load_wholegraph_op_libraries(): - pass diff --git a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py index a67085459..f58ef60a6 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/tensor.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/tensor.py @@ -58,6 +58,12 @@ def get_comm(self): ) def get_sub_tensor(self, starts, ends): + """ + Get sub tensor of WholeMemory Tensor + :param starts: An array of the start indices of each dim + :param ends: An array of the end indices of each dim, -1 means to the last element + :return: WholeMemory Tensor + """ return WholeMemoryTensor(self.wmb_tensor.get_sub_tensor(starts, ends)) def get_local_tensor(self, host_view: bool = False): @@ -109,20 +115,42 @@ def get_all_chunked_tensor(self, host_view: bool = False): ) def from_filelist(self, filelist: Union[List[str], str]): + """ + Load WholeMemory Tensor from file lists + :param filelist: file list to load from + :return: None + """ if isinstance(filelist, str): filelist = [filelist] self.wmb_tensor.from_filelist(filelist) def from_file_prefix(self, file_prefix: str, part_count: Union[int, None] = None): + """ + Load WholeMemory tensor from files with same prefix, files has format + "%s_part_%d_of_%d" % (prefix, part_id, part_count) + :param file_prefix: file name prefix + :param part_count: part count of file + :return: None + """ if part_count is None: part_count = self.get_comm().get_size() file_list = get_part_file_list(file_prefix, part_count) self.from_filelist(file_list) def local_to_file(self, filename: str): + """ + Store local tensor of WholeMemory Tensor to file, all ranks should call this together with different filename + :param filename: file name of local tensor file. + :return: None + """ self.wmb_tensor.to_file(filename) def to_file_prefix(self, file_prefix: str): + """ + Store WholeMemory Tensor to files with same prefix. + :param file_prefix: file name prefix + :return: None + """ wm_comm = self.get_comm() filename = get_part_file_name( file_prefix, wm_comm.get_rank(), wm_comm.get_size() @@ -138,7 +166,7 @@ def create_wholememory_tensor( dtype: torch.dtype, strides: List[int], ): - r""" + """ Create empty WholeMemory Tensor. Now only support dim = 1 or 2 :param comm: WholeMemoryCommunicator :param memory_type: WholeMemory type, should be continuous, chunked or distributed @@ -181,7 +209,7 @@ def create_wholememory_tensor_from_filelist( last_dim_size: int = 0, last_dim_strides: int = -1, ): - r""" + """ Create WholeMemory Tensor from list of binary files. :param comm: WholeMemoryCommunicator :param memory_type: WholeMemory type, should be continuous, chunked or distributed @@ -224,5 +252,10 @@ def create_wholememory_tensor_from_filelist( def destroy_wholememory_tensor(wm_tensor: WholeMemoryTensor): + """ + Destroy allocated WholeMemory Tensor + :param wm_tensor: WholeMemory Tensor + :return: None + """ wmb.destroy_wholememory_tensor(wm_tensor.wmb_tensor) wm_tensor.wmb_tensor = None diff --git a/python/pylibwholegraph/pylibwholegraph/torch/utils.py b/python/pylibwholegraph/pylibwholegraph/torch/utils.py index 5077c9942..ee112953c 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/utils.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/utils.py @@ -20,6 +20,11 @@ def torch_dtype_to_wholememory_dtype(torch_dtype: torch.dtype): + """ + Convert torch.dtype to WholeMemoryDataType + :param torch_dtype: torch.dtype + :return: WholeMemoryDataType + """ if torch_dtype == torch.float: return WholeMemoryDataType.DtFloat elif torch_dtype == torch.half: @@ -41,6 +46,11 @@ def torch_dtype_to_wholememory_dtype(torch_dtype: torch.dtype): def wholememory_dtype_to_torch_dtype(wm_dtype: WholeMemoryDataType): + """ + Convert WholeMemoryDataType to torch.dtype + :param wm_dtype: WholeMemoryDataType + :return: torch.dtype + """ if wm_dtype == WholeMemoryDataType.DtFloat: return torch.float elif wm_dtype == WholeMemoryDataType.DtHalf: @@ -62,6 +72,11 @@ def wholememory_dtype_to_torch_dtype(wm_dtype: WholeMemoryDataType): def get_file_size(filename: str): + """ + Get file size. + :param filename: file name + :return: size of file + """ if not os.path.isfile(filename): raise ValueError("File %s not found or not file" % (filename,)) if not os.access(filename, os.R_OK): diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py index 25c124322..7c7fe4e00 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_env.py @@ -49,29 +49,51 @@ def __init__(self): class TorchMemoryContext(object): def __init__(self): self.tensor = None + if torch_cpp_ext_loaded: + self.handle = torch_cpp_ext_lib.create_output_context() + else: + self.handle = 0 + + def __del__(self): + self.free() + + def get_c_context(self): + if torch_cpp_ext_loaded: + return self.handle + else: + return id(self) def set_tensor(self, t: torch.Tensor): self.tensor = t + def get_handle(self): + return self.handle + def get_tensor(self): - return self.tensor + if torch_cpp_ext_loaded: + self.tensor = torch_cpp_ext_lib.get_tensor_from_context(self.handle) + return self.tensor + else: + return self.tensor def free(self): self.tensor = None + if torch_cpp_ext_loaded and self.get_handle() != 0: + torch_cpp_ext_lib.destroy_output_context(self.get_handle()) + self.handle = 0 def torch_create_memory_context_env_fn( global_context: TorchEmptyGlobalContext, ) -> TorchMemoryContext: t = TorchMemoryContext() - # print('torch_create_memory_context_env_fn t=%d' % (id(t), )) return t def torch_destroy_memory_context_env_fn( memory_context: TorchMemoryContext, global_context: TorchEmptyGlobalContext ): - pass + memory_context.free() def torch_malloc_env_fn( @@ -168,6 +190,7 @@ def get_cpp_extension_src_path(): def compile_cpp_extension(): import torch.utils.cpp_extension + global torch_cpp_ext_loaded global torch_cpp_ext_lib cpp_extension_path = os.path.join(get_cpp_extension_src_path(), "torch_cpp_ext") @@ -180,6 +203,13 @@ def compile_cpp_extension(): extra_ldflags.append( "".join(["-L", os.path.join(os.environ["CONDA_PREFIX"], "lib")]) ) + if "LIBWHOLEGRAPH_DIR" in os.environ: + extra_cflags.append( + "".join(["-I", os.path.join(os.environ["LIBWHOLEGRAPH_DIR"], "include")]) + ) + extra_ldflags.append( + "".join(["-L", os.path.join(os.environ["LIBWHOLEGRAPH_DIR"], "lib")]) + ) torch.utils.cpp_extension.load( name="pylibwholegraph.pylibwholegraph_torch_ext", sources=[ @@ -192,5 +222,7 @@ def compile_cpp_extension(): with_cuda=True, verbose=True, ) - torch_cpp_ext_lib = importlib.import_module('pylibwholegraph.pylibwholegraph_torch_ext') + torch_cpp_ext_lib = importlib.import_module( + "pylibwholegraph.pylibwholegraph_torch_ext" + ) torch_cpp_ext_loaded = True diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py index 535cf7b20..2a49e2fab 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/wholegraph_ops.py @@ -41,26 +41,26 @@ def unweighted_sample_without_replacement( center_nodes_tensor.shape[0] + 1, device="cuda", dtype=torch.int ) output_dest_context = TorchMemoryContext() - output_dest_tensor_id = id(output_dest_context) + output_dest_c_context = output_dest_context.get_c_context() output_center_localid_context = None - output_center_localid_tensor_id = 0 + output_center_localid_c_context = 0 output_edge_gid_context = None - output_edge_gid_tensor_id = 0 + output_edge_gid_c_context = 0 if need_center_local_output: output_center_localid_context = TorchMemoryContext() - output_center_localid_tensor_id = id(output_center_localid_context) + output_center_localid_c_context = output_center_localid_context.get_c_context() if need_edge_output: output_edge_gid_context = TorchMemoryContext() - output_edge_gid_tensor_id = id(output_edge_gid_context) + output_edge_gid_c_context = output_edge_gid_context.get_c_context() wmb.csr_unweighted_sample_without_replacement( wm_csr_row_ptr_tensor, wm_csr_col_ptr_tensor, wrap_torch_tensor(center_nodes_tensor), max_sample_count, wrap_torch_tensor(output_sample_offset_tensor), - output_dest_tensor_id, - output_center_localid_tensor_id, - output_edge_gid_tensor_id, + output_dest_c_context, + output_center_localid_c_context, + output_edge_gid_c_context, random_seed, get_wholegraph_env_fns(), get_stream(), @@ -109,17 +109,17 @@ def weighted_sample_without_replacement( center_nodes_tensor.shape[0] + 1, device="cuda", dtype=torch.int ) output_dest_context = TorchMemoryContext() - output_dest_tensor_id = id(output_dest_context) + output_dest_c_context = output_dest_context.get_c_context() output_center_localid_context = None - output_center_localid_tensor_id = 0 + output_center_localid_c_context = 0 output_edge_gid_context = None - output_edge_gid_tensor_id = 0 + output_edge_gid_c_context = 0 if need_center_local_output: output_center_localid_context = TorchMemoryContext() - output_center_localid_tensor_id = id(output_center_localid_context) + output_center_localid_c_context = output_center_localid_context.get_c_context() if need_edge_output: output_edge_gid_context = TorchMemoryContext() - output_edge_gid_tensor_id = id(output_edge_gid_context) + output_edge_gid_c_context = output_edge_gid_context.get_c_context() wmb.csr_weighted_sample_without_replacement( wm_csr_row_ptr_tensor, wm_csr_col_ptr_tensor, @@ -127,9 +127,9 @@ def weighted_sample_without_replacement( wrap_torch_tensor(center_nodes_tensor), max_sample_count, wrap_torch_tensor(output_sample_offset_tensor), - output_dest_tensor_id, - output_center_localid_tensor_id, - output_edge_gid_tensor_id, + output_dest_c_context, + output_center_localid_c_context, + output_edge_gid_c_context, random_seed, get_wholegraph_env_fns(), get_stream(), @@ -155,3 +155,23 @@ def weighted_sample_without_replacement( ) else: return output_sample_offset_tensor, output_dest_context.get_tensor() + + +def generate_random_positive_int_cpu( + random_seed, sub_sequence, output_random_value_count +): + output = torch.empty((output_random_value_count,), dtype=torch.int) + wmb.host_generate_random_positive_int( + random_seed, sub_sequence, wrap_torch_tensor(output) + ) + return output + + +def generate_exponential_distribution_negative_float_cpu( + random_seed: int, sub_sequence: int, output_random_value_count: int +): + output = torch.empty((output_random_value_count,), dtype=torch.float) + wmb.host_generate_exponential_distribution_negative_float( + random_seed, sub_sequence, wrap_torch_tensor(output) + ) + return output diff --git a/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py b/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py index 02904b977..5bc25d4ca 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py +++ b/python/pylibwholegraph/pylibwholegraph/torch/wholememory_ops.py @@ -27,6 +27,14 @@ def wholememory_gather_forward_functor( requires_grad=False, torch_output_dtype=None, ): + """ + Wrapper functor for gather op of WholeMemory Tensor + :param wholememory_tensor: PyWholeMemoryTensor + :param indices_tensor: Indices to gather from + :param requires_grad: if requires gradients + :param torch_output_dtype: output dtype, None for same as wholememory_tensor + :return: Gathered tensor + """ assert indices_tensor.dim() == 1 assert indices_tensor.dtype == torch.int32 or indices_tensor.dtype == torch.int64 if torch_output_dtype is None: @@ -52,6 +60,13 @@ def wholememory_scatter_functor( indices_tensor: torch.Tensor, wholememory_tensor: wmb.PyWholeMemoryTensor, ): + """ + Wrapper functor for scatter op of WholeMemory Tensor + :param input_tensor: Input tensor to scater to WholeMemory Tensor + :param indices_tensor: Indices to scatter to + :param wholememory_tensor: WholeMemory Tensor + :return: None + """ assert indices_tensor.dim() == 1 assert indices_tensor.dtype == torch.int32 or indices_tensor.dtype == torch.int64 wmb.wholememory_scatter_op( diff --git a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp index c3191bbdf..15d2e5160 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp +++ b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.cpp @@ -50,4 +50,14 @@ wholememory_env_func_t* get_pytorch_env_func() { return &pytorch_env_func; } cudaStream_t get_current_stream() { return at::cuda::getCurrentCUDAStream(); } +void* create_output_context() { + void* output_context = nullptr; + create_torch_memory_context_func(&output_context, nullptr); + return output_context; +} + +void destroy_output_context(void* output_context) { + destroy_torch_memory_context_func(output_context, nullptr); +} + } // namespace wholegraph_torch diff --git a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.h b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.h index 90bdcc4cb..56b966953 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.h +++ b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/torch_env_func_ptrs.h @@ -29,4 +29,8 @@ wholememory_env_func_t* get_pytorch_env_func(); cudaStream_t get_current_stream(); +void* create_output_context(); + +void destroy_output_context(void* output_context); + } // namespace wholegraph_torch diff --git a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp index ebc396371..f1dcbecdb 100644 --- a/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp +++ b/python/pylibwholegraph/pylibwholegraph/torch_cpp_ext/wholegraph_torch_ext.cpp @@ -1,4 +1,5 @@ #include +#include #include "torch_env_func_ptrs.h" #include "torch_utils.h" @@ -13,10 +14,32 @@ int64_t wrapped_get_stream() return reinterpret_cast(static_cast(wholegraph_torch::get_current_stream())); } +int64_t wrapped_create_output_context() +{ + return reinterpret_cast(wholegraph_torch::create_output_context()); +} + +void wrapped_destroy_output_context(int64_t output_context) +{ + wholegraph_torch::destroy_output_context(reinterpret_cast(output_context)); +} + +torch::Tensor get_torch_tensor_from_output_context(int64_t output_context) +{ + auto* torch_output_context = + static_cast(reinterpret_cast(output_context)); + return torch_output_context->tensor; +} + PYBIND11_MODULE(pylibwholegraph_torch_ext, m) { m.def("get_wholegraph_env_fns", &wrapped_get_wholegraph_env_fns, "Get WholeGraph Environment functions."); m.def("get_stream", &wrapped_get_stream, "Get current CUDA stream."); + m.def("create_output_context", &wrapped_create_output_context, "Create output memory context."); + m.def("destroy_output_context", &wrapped_destroy_output_context, "Destroy output memory context."); + m.def("get_tensor_from_context", + &get_torch_tensor_from_output_context, + "Get PyTorch Tensor from output memory context"); } diff --git a/python/pylibwholegraph/pylibwholegraph/utils/multiprocess.py b/python/pylibwholegraph/pylibwholegraph/utils/multiprocess.py index f1b0aedc8..445d5fb38 100644 --- a/python/pylibwholegraph/pylibwholegraph/utils/multiprocess.py +++ b/python/pylibwholegraph/pylibwholegraph/utils/multiprocess.py @@ -15,6 +15,13 @@ def multiprocess_run(world_size: int, func, inline_single_process=False): + """ + Run func in multiple process + :param world_size: process count + :param func: function to run + :param inline_single_process: when only one process, whether to use current process to run. + :return: None + """ assert world_size > 0 if world_size == 1 and inline_single_process: func(0, 1) diff --git a/python/pylibwholegraph/wholegraph_torch/CMakeLists.txt b/python/pylibwholegraph/wholegraph_torch/CMakeLists.txt deleted file mode 100644 index 8aa8d3d16..000000000 --- a/python/pylibwholegraph/wholegraph_torch/CMakeLists.txt +++ /dev/null @@ -1,70 +0,0 @@ -#============================================================================= -# Copyright (c) 2018-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -#============================================================================= -# Find PyTorch -# Get PyTorch cmake path -set(PY_EXE ${Python_EXECUTABLE}) -execute_process(COMMAND ${PY_EXE} -c "import torch.utils; print(torch.utils.cmake_prefix_path)" - OUTPUT_VARIABLE TORCH_CMAKE_PREFIX OUTPUT_STRIP_TRAILING_WHITESPACE ERROR_QUIET) -message("TORCH_CMAKE_PREFIX=${TORCH_CMAKE_PREFIX}") -set(Torch_ROOT "${TORCH_CMAKE_PREFIX}/Torch") -set(TORCH_CUDA_ARCH_LIST "7.0;8.0;8.6") -find_package(Torch "1.8.0" "REQUIRED") -#execute_process(COMMAND ${PY_EXE} -c "from torch.utils.cpp_extension import CUDAExtension as ext; e = ext('', []); print(';'.join(e.library_dirs))" -# OUTPUT_VARIABLE Torch_LIBRARY_DIRS OUTPUT_STRIP_TRAILING_WHITESPACE) -#message(STATUS "Torch_LIBRARY_DIRS=${Torch_LIBRARY_DIRS}") -#string(REGEX REPLACE "No CUDA runtime[^\n]*\n?" "" Torch_LIBRARY_DIRS "${Torch_LIBRARY_DIRS}") -#execute_process(COMMAND ${PY_EXE} -c "from torch.utils.cpp_extension import CUDAExtension as ext; e = ext('', []); print(';'.join(e.libraries))" -# OUTPUT_VARIABLE _Torch_LIBRARIES OUTPUT_STRIP_TRAILING_WHITESPACE) -#string(REGEX REPLACE "No CUDA runtime[^\n]*\n?" "" _Torch_LIBRARIES "${_Torch_LIBRARIES}") -#foreach (_TLIB IN LISTS _Torch_LIBRARIES) -# find_library(FOUND_LIB_${_TLIB} -# NAMES ${_TLIB} -# HINTS ${Torch_LIBRARY_DIRS}) -# list(APPEND TORCH_LIBRARIES ${FOUND_LIB_${_TLIB}}) -#endforeach () -if (NOT TORCH_FOUND) - message(FATAL_ERROR "Torch not found.") - return() -endif () -#execute_process(COMMAND ${PY_EXE} -c "import torch; print(torch.torch.compiled_with_cxx11_abi())" -# OUTPUT_VARIABLE Torch_CXX11 OUTPUT_STRIP_TRAILING_WHITESPACE) -#string(TOUPPER ${Torch_CXX11} Torch_CXX11) -#message(STATUS "Torch_CXX11: ${Torch_CXX11}") -#set(USE_CXX11_ABI ${Torch_CXX11}) - -file(GLOB WHOLEGRAPH_PYTORCH_SOURCES - "*.cpp" - "ops/*.cpp") - -message(STATUS "TORCH_CXX_FLAGS=${TORCH_CXX_FLAGS}") - -add_library(wholegraph_torch_cxx11abi SHARED "") -target_sources(wholegraph_torch_cxx11abi PRIVATE ${WHOLEGRAPH_PYTORCH_SOURCES}) -set_property(TARGET torch PROPERTY INTERFACE_COMPILE_OPTIONS "-D_GLIBCXX_USE_CXX11_ABI=1") -target_include_directories(wholegraph_torch_cxx11abi - PRIVATE - "${wholegraph_ROOT}/include") -target_link_libraries(wholegraph_torch_cxx11abi "${TORCH_LIBRARIES}" "${WHOLEGRAPH_CPP_TARGET}") - -add_library(wholegraph_torch_precxx11abi SHARED "") -target_sources(wholegraph_torch_precxx11abi PRIVATE ${WHOLEGRAPH_PYTORCH_SOURCES}) -set_property(TARGET torch PROPERTY INTERFACE_COMPILE_OPTIONS "-D_GLIBCXX_USE_CXX11_ABI=0") -target_include_directories(wholegraph_torch_precxx11abi - PRIVATE - "${wholegraph_ROOT}/include") -target_link_libraries(wholegraph_torch_precxx11abi "${TORCH_LIBRARIES}" "${WHOLEGRAPH_CPP_TARGET}") - -set_property(TARGET torch PROPERTY INTERFACE_COMPILE_OPTIONS ${TORCH_CXX_FLAGS}) diff --git a/python/pylibwholegraph/wholegraph_torch/ops/gather_scatter_ops.cpp b/python/pylibwholegraph/wholegraph_torch/ops/gather_scatter_ops.cpp deleted file mode 100644 index 4b230eda2..000000000 --- a/python/pylibwholegraph/wholegraph_torch/ops/gather_scatter_ops.cpp +++ /dev/null @@ -1,91 +0,0 @@ -#include - -#include -#include -#include - -#include "../torch_env_func_ptrs.h" -#include "../torch_utils.h" - -namespace wholegraph_torch { - -torch::Tensor gather(int64_t wholememory_tensor_handle, - const torch::Tensor& indices, - torch::optional output_type, - torch::optional requires_grad) -{ - torch_tensor_check_dim(indices, 1, "gather indice"); - torch_tensor_check_dtype_is_index(indices, "gather, indices"); - - wrapped_torch_tensor const wrapped_indices_tensor(indices); - auto wt = reinterpret_cast(wholememory_tensor_handle); - auto* p_wm_tensor_desc = wholememory_tensor_get_tensor_description(wt); - auto* p_indices_desc = - wholememory_tensor_get_tensor_description(wrapped_indices_tensor.get_wholememory_tensor()); - TORCH_CHECK(p_wm_tensor_desc->dim == 1 || p_wm_tensor_desc->dim == 2, - "wholememory_tensor_handle should be 1D or 2D WholeMemory Tensor.") - - wholememory_dtype_t wm_output_type = p_wm_tensor_desc->dtype; - if (output_type.has_value()) { wm_output_type = get_wholememory_dtype(output_type.value()); } - - wholememory_tensor_description_t output_alloc_tensor_desc; - output_alloc_tensor_desc.dtype = wm_output_type; - output_alloc_tensor_desc.dim = p_wm_tensor_desc->dim; - output_alloc_tensor_desc.storage_offset = 0; - output_alloc_tensor_desc.sizes[0] = p_indices_desc->sizes[0]; - output_alloc_tensor_desc.strides[output_alloc_tensor_desc.dim - 1] = 1; - if (p_wm_tensor_desc->dim == 2) { - output_alloc_tensor_desc.sizes[1] = output_alloc_tensor_desc.strides[0] = - p_wm_tensor_desc->sizes[1]; - } - - pytorch_memory_context output_context; - if (requires_grad.has_value()) { set_need_grad(&output_context, requires_grad.value()); } - torch_common_malloc_func(&output_alloc_tensor_desc, &output_context); - - auto output_tensor = output_context.tensor; - wrapped_torch_tensor const wrapped_output_tensor(output_tensor); - - TORCH_CHECK(wholememory_gather(wt, - wrapped_indices_tensor.get_wholememory_tensor(), - wrapped_output_tensor.get_wholememory_tensor(), - wholegraph_torch::get_pytorch_env_func(), - wholegraph_torch::get_current_stream()) == WHOLEMEMORY_SUCCESS) - return output_tensor; -} - -void scatter(const torch::Tensor& input, - const torch::Tensor& indices, - int64_t wholememory_tensor_handle) -{ - torch_tensor_check_dim_in_range(input, 1, 2, "scatter input"); - torch_tensor_check_dim(indices, 1, "scatter indice"); - torch_tensor_check_dtype_is_index(indices, "scatter, indices"); - - wrapped_torch_tensor const wrapped_indices_tensor(indices); - wrapped_torch_tensor const wrapped_input_tensor(input); - auto wt = reinterpret_cast(wholememory_tensor_handle); - auto* p_wm_tensor_desc = wholememory_tensor_get_tensor_description(wt); - TORCH_CHECK(p_wm_tensor_desc->dim == input.dim(), - "input and wholememory_tensor_hand should be same dim.") - - if (input.dim() == 2) { - TORCH_CHECK(input.size(1) == p_wm_tensor_desc->sizes[1], - "input and wholememory should have same embedding size but input.size(1)=%ld, " - "wholememory.size(1)=%ld", - input.size(1), - p_wm_tensor_desc->sizes[1]) - } - - TORCH_CHECK(wholememory_scatter(wrapped_input_tensor.get_wholememory_tensor(), - wrapped_indices_tensor.get_wholememory_tensor(), - wt, - wholegraph_torch::get_pytorch_env_func(), - wholegraph_torch::get_current_stream()) == WHOLEMEMORY_SUCCESS) -} - -} // namespace wholegraph_torch - -static auto registry = torch::RegisterOperators() - .op("wholegraph::gather", &wholegraph_torch::gather) - .op("wholegraph::scatter", &wholegraph_torch::scatter); diff --git a/python/pylibwholegraph/wholegraph_torch/ops/raft_ramdom_generator_ops.cpp b/python/pylibwholegraph/wholegraph_torch/ops/raft_ramdom_generator_ops.cpp deleted file mode 100644 index aae454ebf..000000000 --- a/python/pylibwholegraph/wholegraph_torch/ops/raft_ramdom_generator_ops.cpp +++ /dev/null @@ -1,74 +0,0 @@ -#include - -#include - -#include "test_raft_random.cuh" - -namespace wholegraph_torch_test { - -torch::Tensor raft_pcg_generator_random(int64_t random_seed, - int64_t subsequence, - int64_t generated_random_number) -{ - auto to = torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64).requires_grad(false); - torch::Tensor output = torch::empty({(long)(generated_random_number)}, to); - - TestPCGenerator rng((unsigned long long)random_seed, subsequence, 0); - for (int64_t i = 0; i < generated_random_number; i++) { - uint32_t random_num; - rng.next(random_num); - output[i].data_ptr()[0] = (int64_t)random_num; - } - - return output; -} - -torch::Tensor raft_pcg_generator_random_from_weight(int64_t random_seed, - int64_t subsequence, - torch::Tensor edge_weight, - int64_t generated_random_number) -{ - auto to = - torch::TensorOptions().device(torch::kCPU).dtype(edge_weight.dtype()).requires_grad(false); - torch::Tensor output = torch::empty({(long)(generated_random_number)}, to); - TestPCGenerator rng((unsigned long long)random_seed, subsequence, 0); - for (int64_t i = 0; i < generated_random_number; i++) { - float u = -rng.next_float(1.0f, 0.5f); - int64_t random_num2 = 0; - int seed_count = -1; - do { - rng.next(random_num2); - seed_count++; - } while (!random_num2); - auto count_one = [](unsigned long long num) { - int c = 0; - while (num) { - num >>= 1; - c++; - } - return 64 - c; - }; - int one_bit = count_one(random_num2) + seed_count * 64; - u *= pow(2, -one_bit); - // float logk = (log1pf(u) / logf(2.0)) * (1.0f / (float)weight); - if (edge_weight.dtype() == torch::kFloat32) { - float weight = edge_weight[i].data_ptr()[0]; - float logk = (1 / weight) * (log1p(u) / log(2.0)); - output[i].data_ptr()[0] = logk; - } else if (edge_weight.dtype() == torch::kFloat64) { - double weight = edge_weight[i].data_ptr()[0]; - double logk = (1 / weight) * (log1p(u) / log(2.0)); - output[i].data_ptr()[0] = logk; - } - } - - return output; -} - -} // namespace wholegraph_torch_test - -static auto registry = torch::RegisterOperators() - .op("wholegraph_test::raft_pcg_generator_random", - &wholegraph_torch_test::raft_pcg_generator_random) - .op("wholegraph_test::raft_pcg_generator_random_from_weight", - &wholegraph_torch_test::raft_pcg_generator_random_from_weight); diff --git a/python/pylibwholegraph/wholegraph_torch/ops/test_raft_random.cuh b/python/pylibwholegraph/wholegraph_torch/ops/test_raft_random.cuh deleted file mode 100644 index 6b3036012..000000000 --- a/python/pylibwholegraph/wholegraph_torch/ops/test_raft_random.cuh +++ /dev/null @@ -1,128 +0,0 @@ -#pragma once - -#include - -/** PCG random number generator from raft */ -struct TestPCGenerator { - /** - * @brief ctor. Initializes the state for RNG. This code is derived from PCG basic code - * @param seed the seed (can be same across all threads). Same as PCG's initstate - * @param subsequence is same as PCG's initseq - * @param offset unused - */ - __host__ __device__ __forceinline__ TestPCGenerator(uint64_t seed, - uint64_t subsequence, - uint64_t offset) - { - pcg_state = uint64_t(0); - inc = (subsequence << 1u) | 1u; - uint32_t discard; - next(discard); - pcg_state += seed; - next(discard); - skipahead(offset); - } - - // Based on "Random Number Generation with Arbitrary Strides" F. B. Brown - // Link https://mcnp.lanl.gov/pdf_files/anl-rn-arb-stride.pdf - __host__ __device__ __forceinline__ void skipahead(uint64_t offset) - { - uint64_t G = 1; - uint64_t h = 6364136223846793005ULL; - uint64_t C = 0; - uint64_t f = inc; - while (offset) { - if (offset & 1) { - G = G * h; - C = C * h + f; - } - f = f * (h + 1); - h = h * h; - offset >>= 1; - } - pcg_state = pcg_state * G + C; - } - - /** - * @defgroup NextRand Generate the next random number - * @brief This code is derived from PCG basic code - * @{ - */ - __host__ __device__ __forceinline__ uint32_t next_u32() - { - uint32_t ret; - uint64_t oldstate = pcg_state; - pcg_state = oldstate * 6364136223846793005ULL + inc; - uint32_t xorshifted = ((oldstate >> 18u) ^ oldstate) >> 27u; - uint32_t rot = oldstate >> 59u; - ret = (xorshifted >> rot) | (xorshifted << ((-rot) & 31)); - return ret; - } - __host__ __device__ __forceinline__ uint64_t next_u64() - { - uint64_t ret; - uint32_t a, b; - a = next_u32(); - b = next_u32(); - ret = uint64_t(a) | (uint64_t(b) << 32); - return ret; - } - - __host__ __device__ __forceinline__ int32_t next_i32() - { - int32_t ret; - uint32_t val; - val = next_u32(); - ret = int32_t(val & 0x7fffffff); - return ret; - } - - __host__ __device__ __forceinline__ int64_t next_i64() - { - int64_t ret; - uint64_t val; - val = next_u64(); - ret = int64_t(val & 0x7fffffffffffffff); - return ret; - } - - __host__ __device__ __forceinline__ float next_float() - { - float ret; - uint32_t val = next_u32() >> 8; - ret = static_cast(val) / (1U << 24); - return ret; - } - - __host__ __device__ __forceinline__ float next_float(float max, float min) - { - float ret; - uint32_t val = next_u32() >> 8; - ret = static_cast(val) / (1U << 24); - ret *= (max - min); - ret += min; - return ret; - } - - __host__ __device__ __forceinline__ double next_double() - { - double ret; - uint64_t val = next_u64() >> 11; - ret = static_cast(val) / (1LU << 53); - return ret; - } - - __host__ __device__ __forceinline__ void next(uint32_t& ret) { ret = next_u32(); } - __host__ __device__ __forceinline__ void next(uint64_t& ret) { ret = next_u64(); } - __host__ __device__ __forceinline__ void next(int32_t& ret) { ret = next_i32(); } - __host__ __device__ __forceinline__ void next(int64_t& ret) { ret = next_i64(); } - - __host__ __device__ __forceinline__ void next(float& ret) { ret = next_float(); } - __host__ __device__ __forceinline__ void next(double& ret) { ret = next_double(); } - - /** @} */ - - private: - uint64_t pcg_state; - uint64_t inc; -}; diff --git a/python/pylibwholegraph/wholegraph_torch/ops/unweighted_sample_without_replacement_ops.cpp b/python/pylibwholegraph/wholegraph_torch/ops/unweighted_sample_without_replacement_ops.cpp deleted file mode 100644 index dc7d4200c..000000000 --- a/python/pylibwholegraph/wholegraph_torch/ops/unweighted_sample_without_replacement_ops.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include -#include - -#include -#include -#include - -#include "../torch_env_func_ptrs.h" -#include "../torch_utils.h" - -using torch::autograd::variable_list; - -namespace wholegraph_torch { - -variable_list unweighted_sample_without_replacement(int64_t csr_row_ptr_wholememory_tensor_handle, - int64_t csr_col_ptr_wholememory_tensor_handle, - torch::Tensor& input_nodes, - int64_t max_sample_count, - torch::optional random_seed) -{ - torch_tensor_check_dim(input_nodes, 1, "unweighted_sample_without_replacement, input_nodes"); - torch_tensor_check_dtype_is_index(input_nodes, - "unweighted_sample_without_replacement, input_nodes"); - wrapped_torch_tensor const wrapped_input_nodes_tensor(input_nodes); - - auto wm_csr_row_ptr_tensor = - reinterpret_cast(csr_row_ptr_wholememory_tensor_handle); - auto wm_csr_col_ptr_tensor = - reinterpret_cast(csr_col_ptr_wholememory_tensor_handle); - auto* p_wm_csr_row_ptr_tensor_desc = - wholememory_tensor_get_tensor_description(wm_csr_row_ptr_tensor); - auto* p_wm_csr_col_ptr_tensor_desc = - wholememory_tensor_get_tensor_description(wm_csr_col_ptr_tensor); - auto* p_input_nodes_desc = - wholememory_tensor_get_tensor_description(wrapped_input_nodes_tensor.get_wholememory_tensor()); - - TORCH_CHECK(p_wm_csr_row_ptr_tensor_desc->dim == 1, - "csr_row_ptr_wholememory_tensor_handle should be 1D WholeMemory Tensor.") - TORCH_CHECK(p_wm_csr_col_ptr_tensor_desc->dim == 1, - "csr_col_ptr_wholememory_tensor_handle should be 1D WholeMemory Tensor.") - - TORCH_CHECK(p_wm_csr_row_ptr_tensor_desc->dtype == WHOLEMEMORY_DT_INT64, - "csr_row_ptr_wholememory_tensor_handle should be int64 WholeMemory Tensor.") - TORCH_CHECK(p_wm_csr_col_ptr_tensor_desc->dtype == WHOLEMEMORY_DT_INT || WHOLEMEMORY_DT_INT64, - "csr_col_ptr_wholememory_tensor_handle should be int or int64 WholeMemory Tensor.") - - wholememory_tensor_description_t output_sample_offset_tensor_desc; - output_sample_offset_tensor_desc.dtype = WHOLEMEMORY_DT_INT; - output_sample_offset_tensor_desc.dim = 1; - output_sample_offset_tensor_desc.storage_offset = 0; - output_sample_offset_tensor_desc.sizes[0] = p_input_nodes_desc->sizes[0] + 1; - output_sample_offset_tensor_desc.strides[0] = 1; - - pytorch_memory_context output_sample_offset_context, output_dest_memory_context, - output_center_localid_memory_context, output_edge_gid_memory_context; - - torch_common_malloc_func(&output_sample_offset_tensor_desc, &output_sample_offset_context); - auto output_sample_offset_tensor = output_sample_offset_context.tensor; - wrapped_torch_tensor const wrapped_output_sample_offset_tensor(output_sample_offset_tensor); - - unsigned long long random_seed_value; - if (random_seed.has_value()) { - random_seed_value = random_seed.value(); - } else { - thread_local std::random_device rd; - thread_local std::mt19937 gen(rd()); - thread_local std::uniform_int_distribution distrib; - random_seed_value = distrib(gen); - } - - TORCH_CHECK(wholegraph_csr_unweighted_sample_without_replacement( - wm_csr_row_ptr_tensor, - wm_csr_col_ptr_tensor, - wrapped_input_nodes_tensor.get_wholememory_tensor(), - max_sample_count, - wrapped_output_sample_offset_tensor.get_wholememory_tensor(), - &output_dest_memory_context, - &output_center_localid_memory_context, - &output_edge_gid_memory_context, - random_seed_value, - wholegraph_torch::get_pytorch_env_func(), - wholegraph_torch::get_current_stream()) == WHOLEMEMORY_SUCCESS) - - auto output_dest_tensor = output_dest_memory_context.tensor; - auto output_center_localid_tensor = output_center_localid_memory_context.tensor; - auto output_edge_gid_tensor = output_edge_gid_memory_context.tensor; - - return {output_sample_offset_tensor, - output_dest_tensor, - output_center_localid_tensor, - output_edge_gid_tensor}; -} - -} // namespace wholegraph_torch - -static auto registry = - torch::RegisterOperators().op("wholegraph::unweighted_sample_without_replacement", - &wholegraph_torch::unweighted_sample_without_replacement); diff --git a/python/pylibwholegraph/wholegraph_torch/ops/weighted_sample_without_replacement_ops.cpp b/python/pylibwholegraph/wholegraph_torch/ops/weighted_sample_without_replacement_ops.cpp deleted file mode 100644 index c2c508e51..000000000 --- a/python/pylibwholegraph/wholegraph_torch/ops/weighted_sample_without_replacement_ops.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include -#include -#include -#include -#include - -#include "../torch_env_func_ptrs.h" -#include "../torch_utils.h" - -using torch::autograd::variable_list; - -namespace wholegraph_torch { - -variable_list weighted_sample_without_replacement(int64_t csr_row_ptr_wholememory_tensor_handle, - int64_t csr_col_ptr_wholememory_tensor_handle, - int64_t csr_weight_ptr_wholememory_tensor_handle, - torch::Tensor& input_nodes, - int64_t max_sample_count, - torch::optional random_seed) -{ - torch_tensor_check_dim(input_nodes, 1, "weighted_sample_without_replacement, input_nodes"); - torch_tensor_check_dtype_is_index(input_nodes, - "weighted_sample_without_replacement, input_nodes"); - wrapped_torch_tensor const wrapped_input_nodes_tensor(input_nodes); - - auto wm_csr_row_ptr_tensor = - reinterpret_cast(csr_row_ptr_wholememory_tensor_handle); - auto wm_csr_col_ptr_tensor = - reinterpret_cast(csr_col_ptr_wholememory_tensor_handle); - auto wm_csr_weight_ptr_tensor = - reinterpret_cast(csr_weight_ptr_wholememory_tensor_handle); - auto* p_wm_csr_row_ptr_tensor_desc = - wholememory_tensor_get_tensor_description(wm_csr_row_ptr_tensor); - auto* p_wm_csr_col_ptr_tensor_desc = - wholememory_tensor_get_tensor_description(wm_csr_col_ptr_tensor); - auto* p_wm_csr_weight_ptr_tensor_desc = - wholememory_tensor_get_tensor_description(wm_csr_weight_ptr_tensor); - auto* p_input_nodes_desc = - wholememory_tensor_get_tensor_description(wrapped_input_nodes_tensor.get_wholememory_tensor()); - - TORCH_CHECK(p_wm_csr_row_ptr_tensor_desc->dim == 1, - "csr_row_ptr_wholememory_tensor_handle should be 1D WholeMemory Tensor.") - TORCH_CHECK(p_wm_csr_col_ptr_tensor_desc->dim == 1, - "csr_col_ptr_wholememory_tensor_handle should be 1D WholeMemory Tensor.") - TORCH_CHECK(p_wm_csr_weight_ptr_tensor_desc->dim == 1, - "csr_weight_ptr_wholememory_tensor_handle should be 1D WholeMemory Tensor.") - - TORCH_CHECK(p_wm_csr_row_ptr_tensor_desc->dtype == WHOLEMEMORY_DT_INT64, - "csr_row_ptr_wholememory_tensor_handle should be int64 WholeMemory Tensor.") - TORCH_CHECK(p_wm_csr_col_ptr_tensor_desc->dtype == WHOLEMEMORY_DT_INT || WHOLEMEMORY_DT_INT64, - "csr_col_ptr_wholememory_tensor_handle should be int or int64 WholeMemory Tensor.") - TORCH_CHECK( - p_wm_csr_weight_ptr_tensor_desc->dtype == WHOLEMEMORY_DT_FLOAT || WHOLEMEMORY_DT_DOUBLE, - "csr_weight_ptr_wholememory_tensor_handle should be 1D WholeMemory Tensor.") - - wholememory_tensor_description_t output_sample_offset_tensor_desc; - output_sample_offset_tensor_desc.dtype = WHOLEMEMORY_DT_INT; - output_sample_offset_tensor_desc.dim = 1; - output_sample_offset_tensor_desc.storage_offset = 0; - output_sample_offset_tensor_desc.sizes[0] = p_input_nodes_desc->sizes[0] + 1; - output_sample_offset_tensor_desc.strides[0] = 1; - - pytorch_memory_context output_sample_offset_context, output_dest_memory_context, - output_center_localid_memory_context, output_edge_gid_memory_context; - - torch_common_malloc_func(&output_sample_offset_tensor_desc, &output_sample_offset_context); - auto output_sample_offset_tensor = output_sample_offset_context.tensor; - wrapped_torch_tensor const wrapped_output_sample_offset_tensor(output_sample_offset_tensor); - - unsigned long long random_seed_value; - if (random_seed.has_value()) { - random_seed_value = random_seed.value(); - } else { - thread_local std::random_device rd; - thread_local std::mt19937 gen(rd()); - thread_local std::uniform_int_distribution distrib; - random_seed_value = distrib(gen); - } - - TORCH_CHECK(wholegraph_csr_weighted_sample_without_replacement( - wm_csr_row_ptr_tensor, - wm_csr_col_ptr_tensor, - wm_csr_weight_ptr_tensor, - wrapped_input_nodes_tensor.get_wholememory_tensor(), - max_sample_count, - wrapped_output_sample_offset_tensor.get_wholememory_tensor(), - &output_dest_memory_context, - &output_center_localid_memory_context, - &output_edge_gid_memory_context, - random_seed_value, - wholegraph_torch::get_pytorch_env_func(), - wholegraph_torch::get_current_stream()) == WHOLEMEMORY_SUCCESS) - - auto output_dest_tensor = output_dest_memory_context.tensor; - auto output_center_localid_tensor = output_center_localid_memory_context.tensor; - auto output_edge_gid_tensor = output_edge_gid_memory_context.tensor; - - return {output_sample_offset_tensor, - output_dest_tensor, - output_center_localid_tensor, - output_edge_gid_tensor}; -} - -} // namespace wholegraph_torch - -static auto registry = - torch::RegisterOperators().op("wholegraph::weighted_sample_without_replacement", - &wholegraph_torch::weighted_sample_without_replacement); diff --git a/python/pylibwholegraph/wholegraph_torch/torch_env_func_ptrs.cpp b/python/pylibwholegraph/wholegraph_torch/torch_env_func_ptrs.cpp deleted file mode 100644 index c3191bbdf..000000000 --- a/python/pylibwholegraph/wholegraph_torch/torch_env_func_ptrs.cpp +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "torch_env_func_ptrs.h" - -#include - -#include "torch_utils.h" - -namespace wholegraph_torch { - -void* torch_malloc_func(wholememory_tensor_description_t* tensor_description, - wholememory_memory_allocation_type_t memory_allocation_type, - void* memory_context, - void* /*global_context*/) -{ - bool gpu_memory = memory_allocation_type == WHOLEMEMORY_MA_DEVICE; - bool pinned_memory = memory_allocation_type == WHOLEMEMORY_MA_PINNED; - return torch_common_malloc_func(tensor_description, memory_context, gpu_memory, pinned_memory); -} - -static wholememory_env_func_t pytorch_env_func = { - .temporary_fns = - { - .create_memory_context_fn = create_torch_memory_context_func, - .destroy_memory_context_fn = destroy_torch_memory_context_func, - .malloc_fn = torch_malloc_func, - .free_fn = torch_common_free_func, - .global_context = nullptr, - }, - .output_fns = { - .malloc_fn = torch_malloc_func, - .free_fn = torch_common_free_func, - .global_context = nullptr, - }}; - -wholememory_env_func_t* get_pytorch_env_func() { return &pytorch_env_func; } - -cudaStream_t get_current_stream() { return at::cuda::getCurrentCUDAStream(); } - -} // namespace wholegraph_torch diff --git a/python/pylibwholegraph/wholegraph_torch/torch_env_func_ptrs.h b/python/pylibwholegraph/wholegraph_torch/torch_env_func_ptrs.h deleted file mode 100644 index 90bdcc4cb..000000000 --- a/python/pylibwholegraph/wholegraph_torch/torch_env_func_ptrs.h +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include - -namespace wholegraph_torch { - -/** - * @brief : PyTorch environment functions for memory allocation. - * - * @return : pointers to the functions of current CUDA device - */ -wholememory_env_func_t* get_pytorch_env_func(); - -cudaStream_t get_current_stream(); - -} // namespace wholegraph_torch diff --git a/python/pylibwholegraph/wholegraph_torch/torch_utils.cpp b/python/pylibwholegraph/wholegraph_torch/torch_utils.cpp deleted file mode 100644 index e31aeb00e..000000000 --- a/python/pylibwholegraph/wholegraph_torch/torch_utils.cpp +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "torch_utils.h" - -#include - -namespace wholegraph_torch { - -c10::ScalarType get_c10_scalar_type(wholememory_dtype_t wm_dtype) -{ - switch (wm_dtype) { - case WHOLEMEMORY_DT_FLOAT: return c10::ScalarType::Float; - case WHOLEMEMORY_DT_HALF: return c10::ScalarType::Half; - case WHOLEMEMORY_DT_DOUBLE: return c10::ScalarType::Double; - case WHOLEMEMORY_DT_BF16: return c10::ScalarType::BFloat16; - case WHOLEMEMORY_DT_INT: return c10::ScalarType::Int; - case WHOLEMEMORY_DT_INT64: return c10::ScalarType::Long; - case WHOLEMEMORY_DT_INT16: return c10::ScalarType::Short; - case WHOLEMEMORY_DT_INT8: return c10::ScalarType::Char; - default: return c10::ScalarType::Undefined; - } -} - -wholememory_dtype_t get_wholememory_dtype(torch::ScalarType ts_dtype) -{ - switch (ts_dtype) { - case c10::ScalarType::Float: return WHOLEMEMORY_DT_FLOAT; - case c10::ScalarType::Half: return WHOLEMEMORY_DT_HALF; - case c10::ScalarType::Double: return WHOLEMEMORY_DT_DOUBLE; - case c10::ScalarType::BFloat16: return WHOLEMEMORY_DT_BF16; - case c10::ScalarType::Int: return WHOLEMEMORY_DT_INT; - case c10::ScalarType::Long: return WHOLEMEMORY_DT_INT64; - case c10::ScalarType::Short: return WHOLEMEMORY_DT_INT16; - case c10::ScalarType::Char: return WHOLEMEMORY_DT_INT8; - default: return WHOLEMEMORY_DT_UNKNOWN; - } -} - -void set_need_grad(pytorch_memory_context* memory_context, bool require_grad) -{ - memory_context->options = memory_context->options.requires_grad(require_grad); -} - -void create_torch_memory_context_func(void** memory_context, void* /*global_context*/) -{ - *memory_context = new pytorch_memory_context(); -} - -void destroy_torch_memory_context_func(void* memory_context, void* /*global_context*/) -{ - if (memory_context != nullptr) { delete static_cast(memory_context); } -} - -void* torch_common_malloc_func(wholememory_tensor_description_t* tensor_description, - void* memory_context, - bool gpu_memory, - bool pinned) -{ - auto* pytorch_context = static_cast(memory_context); - pytorch_context->desc = *tensor_description; - std::vector shape(tensor_description->dim); - for (int i = 0; i < tensor_description->dim; i++) { - shape[i] = tensor_description->sizes[i]; - } - pytorch_context->options = - pytorch_context->options.dtype(get_c10_scalar_type(tensor_description->dtype)); - if (gpu_memory) { - pytorch_context->options = - pytorch_context->options.device(c10::Device(c10::kCUDA, c10::cuda::current_device())); - } else { - pytorch_context->options = pytorch_context->options.device(c10::Device(c10::kCPU)); - pytorch_context->options = pytorch_context->options.pinned_memory(pinned); - } - try { - pytorch_context->tensor = torch::empty(shape, pytorch_context->options); - } catch (c10::Error& err) { - fprintf(stderr, "torch_common_malloc_func allocation failed. Reasion=%s", err.what()); - throw err; - } - return pytorch_context->tensor.data_ptr(); -} - -void torch_common_free_func(void* memory_context, void* /*global_context*/) -{ - static_cast(memory_context)->tensor = torch::Tensor(); - static_cast(memory_context)->options = torch::TensorOptions(); - wholememory_initialize_tensor_desc(&static_cast(memory_context)->desc); -} - -void get_tensor_desc_from_torch_tensor(wholememory_tensor_description_t* tensor_desc, - const torch::Tensor& t) -{ - tensor_desc->dim = t.dim(); - tensor_desc->dtype = get_wholememory_dtype(t.dtype().toScalarType()); - TORCH_CHECK(tensor_desc->dtype != WHOLEMEMORY_DT_UNKNOWN); - tensor_desc->storage_offset = t.storage_offset(); - for (int i = 0; i < tensor_desc->dim; i++) { - tensor_desc->sizes[i] = t.size(i); - tensor_desc->strides[i] = t.stride(i); - } -} - -void get_array_desc_from_torch_tensor(wholememory_array_description_t* array_desc, - const torch::Tensor& t) -{ - TORCH_CHECK(t.dim() == 1, "get_array_desc_from_torch_tensor: should be 1-dim tensor"); - array_desc->dtype = get_wholememory_dtype(t.dtype().toScalarType()); - TORCH_CHECK(array_desc->dtype != WHOLEMEMORY_DT_UNKNOWN); - array_desc->size = t.size(0); - array_desc->storage_offset = t.storage_offset(); -} - -void get_matrix_desc_from_torch_tensor(wholememory_matrix_description_t* matrix_desc, - const torch::Tensor& t) -{ - TORCH_CHECK(t.dim() == 2, "get_matrix_desc_from_torch_tensor: should be 2-dim tensor"); - matrix_desc->dtype = get_wholememory_dtype(t.dtype().toScalarType()); - TORCH_CHECK(matrix_desc->dtype != WHOLEMEMORY_DT_UNKNOWN); - matrix_desc->sizes[0] = t.size(0); - matrix_desc->sizes[1] = t.size(1); - matrix_desc->stride = t.stride(0); - matrix_desc->storage_offset = t.storage_offset(); -} - -wrapped_torch_tensor::wrapped_torch_tensor(const torch::Tensor& torch_tensor) -{ - wholememory_tensor_description_t tensor_description; - get_tensor_desc_from_torch_tensor(&tensor_description, torch_tensor); - wholememory_make_tensor_from_pointer( - &wholememory_tensor_, torch_tensor.storage().data(), &tensor_description); -} - -wrapped_torch_tensor::~wrapped_torch_tensor() -{ - wholememory_destroy_tensor(wholememory_tensor_); - wholememory_tensor_ = nullptr; -} - -wholememory_tensor_t wrapped_torch_tensor::get_wholememory_tensor() const -{ - return wholememory_tensor_; -} - -void wrapped_torch_tensor::unsqueeze(int dim) -{ - auto* tensor_desc = wholememory_tensor_get_tensor_description(wholememory_tensor_); - TORCH_CHECK(dim >= -tensor_desc->dim - 1 && dim <= tensor_desc->dim, - "dim = ", - dim, - " but t.dim()=", - tensor_desc->dim, - ", should in range [", - -tensor_desc->dim - 1, - ", ", - tensor_desc->dim, - "]") - if (dim < 0) { dim += tensor_desc->dim + 1; } - TORCH_CHECK(wholememory_unsqueeze_tensor(tensor_desc, dim), "unsqueeze failed.") -} - -void wrapped_torch_tensor::squeeze(int dim) -{ - auto* tensor_desc = wholememory_tensor_get_tensor_description(wholememory_tensor_); - TORCH_CHECK(dim >= -tensor_desc->dim && dim < tensor_desc->dim, - "dim = ", - dim, - " but t.dim()=", - tensor_desc->dim, - ", should in range [", - -tensor_desc->dim, - ", ", - tensor_desc->dim, - ")") - if (dim < 0) { dim += tensor_desc->dim; } - TORCH_CHECK(tensor_desc->sizes[dim] == 1, "dim size should be 1") - TORCH_CHECK( - dim == tensor_desc->dim - 1 || tensor_desc->strides[dim] == tensor_desc->strides[dim + 1], - "stride should be same as next dim") - TORCH_CHECK(wholememory_squeeze_tensor(tensor_desc, dim)) -} - -void torch_tensor_check_dim_in_range(const torch::Tensor& t, - int min_dim, - int max_dim, - const char* info) -{ - TORCH_CHECK(t.dim() >= min_dim && t.dim() <= max_dim, - std::string(info), - " dim=", - t.dim(), - ", should in range [", - min_dim, - ", ", - max_dim, - "]") -} - -void torch_tensor_check_dtype(const torch::Tensor& t, torch::Dtype dtype, const char* info) -{ - TORCH_CHECK(t.dtype() == dtype, std::string(info), " should be ", dtype, " but got ", t.dtype()); -} - -void torch_tensor_check_dtype_is_int(const torch::Tensor& t, const char* info) -{ - TORCH_CHECK(t.dtype() == torch::kInt8 || t.dtype() == torch::kInt16 || - t.dtype() == torch::kInt32 || t.dtype() == torch::kInt64, - std::string(info), - " should be integer.") -} - -// int32 or int64 -void torch_tensor_check_dtype_is_index(const torch::Tensor& t, const char* info) -{ - TORCH_CHECK(t.dtype() == torch::kInt32 || t.dtype() == torch::kInt64, - std::string(info), - " should be int32 or int64.") -} - -void torch_tensor_check_dtype_is_float(const torch::Tensor& t, const char* info) -{ - TORCH_CHECK(t.dtype() == torch::kFloat16 || t.dtype() == torch::kBFloat16 || - t.dtype() == torch::kFloat32 || t.dtype() == torch::kFloat64, - std::string(info), - " should be float tensor.") -} - -} // namespace wholegraph_torch diff --git a/python/pylibwholegraph/wholegraph_torch/torch_utils.h b/python/pylibwholegraph/wholegraph_torch/torch_utils.h deleted file mode 100644 index 45774ac4a..000000000 --- a/python/pylibwholegraph/wholegraph_torch/torch_utils.h +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include - -namespace wholegraph_torch { - -c10::ScalarType get_c10_scalar_type(wholememory_dtype_t wm_dtype); - -wholememory_dtype_t get_wholememory_dtype(torch::ScalarType ts_dtype); - -struct pytorch_memory_context { - torch::Tensor tensor; - torch::TensorOptions options; - wholememory_tensor_description_t desc; -}; - -void set_need_grad(pytorch_memory_context* memory_context, bool require_grad); - -void create_torch_memory_context_func(void** memory_context, void* /*global_context*/); - -void destroy_torch_memory_context_func(void* memory_context, void* /*global_context*/); - -void* torch_common_malloc_func(wholememory_tensor_description_t* tensor_description, - void* memory_context, - bool gpu_memory = true, - bool pinned = false); - -void torch_common_free_func(void* memory_context, void* /*global_context*/); - -void get_tensor_desc_from_torch_tensor(wholememory_tensor_description_t* tensor_desc, - const torch::Tensor& t); - -void get_array_desc_from_torch_tensor(wholememory_array_description_t* array_desc, - const torch::Tensor& t); - -void get_matrix_desc_from_torch_tensor(wholememory_matrix_description_t* matrix_desc, - const torch::Tensor& t); - -class wrapped_torch_tensor { - public: - explicit wrapped_torch_tensor(const torch::Tensor& torch_tensor); - ~wrapped_torch_tensor(); - wholememory_tensor_t get_wholememory_tensor() const; - void unsqueeze(int dim = -1); - void squeeze(int dim = -1); - - private: - wholememory_tensor_t wholememory_tensor_ = nullptr; -}; - -void torch_tensor_check_dim_in_range(const torch::Tensor& t, - int min_dim, - int max_dim, - const char* info); - -inline void torch_tensor_check_dim(const torch::Tensor& t, int dim, const char* info) -{ - return torch_tensor_check_dim_in_range(t, dim, dim, info); -} - -void torch_tensor_check_dtype(const torch::Tensor& t, torch::Dtype dtype, const char* info); - -void torch_tensor_check_dtype_is_int(const torch::Tensor& t, const char* info); - -// int32 or int64 -void torch_tensor_check_dtype_is_index(const torch::Tensor& t, const char* info); - -void torch_tensor_check_dtype_is_float(const torch::Tensor& t, const char* info); - -} // namespace wholegraph_torch