remove ipex cpu module's python dependency. (#2911)
* correct all_reduce schema

* remove ipex cpu module's python dependency.

---------

Co-authored-by: blzheng <beilei.zheng@intel.com>
xuhancn and blzheng authored May 21, 2024
1 parent 9a192ef commit 38573f2
Showing 3 changed files with 12 additions and 38 deletions.
csrc/cpu/CMakeLists.txt (8 changes: 0 additions & 8 deletions)
@@ -246,14 +246,6 @@ if(BUILD_STRIPPED_BIN)
   set_target_properties(${PLUGIN_NAME_CPU} PROPERTIES LINK_FLAGS_RELEASE -s)
 endif()
 
-find_package(PythonLibs)
-if(${PYTHONLIBS_FOUND})
-  target_link_libraries(${PLUGIN_NAME_CPU} PUBLIC ${PYTHON_LIBRARIES})
-endif()
-
-find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib")
-target_link_libraries(${PLUGIN_NAME_CPU} PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
-
 install(TARGETS ${PLUGIN_NAME_CPU}
   ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
   LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
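
Note: with PythonLibs and torch_python dropped from the plugin's link line, nothing in libintel-ext-pt-cpu may reference CPython or pybind11 symbols anymore. Operators that are still registered from the Python side (such as deepspeed_comm::all_reduce) therefore have to be reached through the process-global c10 dispatcher, which is exactly what the kernel change below does. The following is only a minimal sketch of that pattern using the non-throwing findSchema lookup; the helper name maybe_all_reduce and the graceful fallback are illustrative assumptions, not part of this commit, which uses findSchemaOrThrow instead.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Sketch only (not in this commit): same dispatcher route as the patch,
// but with a non-throwing lookup so a missing registration degrades
// gracefully instead of throwing.
at::Tensor maybe_all_reduce(const at::Tensor& self) {
  auto op = c10::Dispatcher::singleton().findSchema(
      {"deepspeed_comm::all_reduce", ""});
  if (!op.has_value()) {
    return self; // no distributed backend registered in this process
  }
  return op->typed<at::Tensor(const at::Tensor&)>().call(self);
}
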
csrc/cpu/aten/kernels/MoEKrnl.cpp (34 changes: 12 additions & 22 deletions)
@@ -7,7 +7,6 @@
 #include <c10/util/Exception.h>
 #include <immintrin.h>
 #include <torch/csrc/autograd/function.h>
-#include <torch/csrc/utils/python_symnode.h>
 #include <algorithm>
 #include "tpp/kernels/TPPGEMMKrnl.h"
 
@@ -16,6 +15,15 @@ namespace cpu {
 
 namespace {
 
+at::Tensor call_AllReduce(const at::Tensor& self) {
+  static auto op_allreduce =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("deepspeed_comm::all_reduce", "")
+          .typed<at::Tensor(const at::Tensor& self)>();
+  auto ret = op_allreduce.call(self);
+  return ret;
+}
+
 at::Tensor mixtral_moe_tpp_kernl_impl(
     const at::Tensor& hidden_states,
     const at::Tensor& top_x,
@@ -46,13 +54,7 @@ at::Tensor mixtral_moe_tpp_kernl_impl(
         tpp_linear_nobias_forward_cpu(curr_state, down_wei, c10::nullopt);
   }
   if (is_distributed) {
-    py::gil_scoped_acquire acquire;
-    py::function allreduce = py::module_::import("torch")
-                                 .attr("ops")
-                                 .attr("deepspeed_comm")
-                                 .attr("all_reduce");
-    allreduce(curr_state);
-    py::gil_scoped_release release;
+    call_AllReduce(curr_state);
   }
   curr_state = curr_state * routing_w;
   output.index_add_(0, top_x, curr_state.squeeze(0).to(hidden_states.dtype()));
@@ -98,13 +100,7 @@ at::Tensor mixtral_moe_kernl_impl(
         c10::nullopt);
   }
   if (is_distributed) {
-    py::gil_scoped_acquire acquire;
-    py::function allreduce = py::module_::import("torch")
-                                 .attr("ops")
-                                 .attr("deepspeed_comm")
-                                 .attr("all_reduce");
-    allreduce(curr_state);
-    py::gil_scoped_release release;
+    call_AllReduce(curr_state);
  }
   curr_state = curr_state * routing_w;
   output.index_add_(0, top_x, curr_state.squeeze(0).to(hidden_states.dtype()));
@@ -130,13 +126,7 @@ at::Tensor mixtral_moe_woq_kernl_impl(
       down_wei);
 
   if (is_distributed) {
-    py::gil_scoped_acquire acquire;
-    py::function allreduce = py::module_::import("torch")
-                                 .attr("ops")
-                                 .attr("deepspeed_comm")
-                                 .attr("all_reduce");
-    allreduce(curr_state);
-    py::gil_scoped_release release;
+    call_AllReduce(curr_state);
   }
   curr_state = curr_state * routing_w;
   output.index_add_(0, top_x, curr_state.squeeze(0).to(hidden_states.dtype()));
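
Note: the first commit bullet ("correct all_reduce schema") matters here because the typed dispatcher lookup above only resolves if the registered schema matches at::Tensor(const at::Tensor&). The actual registration lives outside this diff, on the DeepSpeed/Python side, so the following is only a hypothetical sketch of what a matching C++ registration could look like; the namespace fragment and the commented-out impl binding are assumptions for illustration.

#include <torch/library.h>

// Hypothetical sketch, not part of this commit: a schema whose signature
// matches the .typed<at::Tensor(const at::Tensor& self)>() lookup used in
// call_AllReduce, i.e. one Tensor in, one Tensor out.
TORCH_LIBRARY_FRAGMENT(deepspeed_comm, m) {
  m.def("all_reduce(Tensor self) -> Tensor");
}

// A backend implementation would then be attached with something like:
// TORCH_LIBRARY_IMPL(deepspeed_comm, CPU, m) {
//   m.impl("all_reduce", my_all_reduce_cpu); // my_all_reduce_cpu is assumed
// }
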
tests/cpu/cpp/CMakeLists.txt (8 changes: 0 additions & 8 deletions)
@@ -69,13 +69,5 @@ target_link_libraries(${CPU_CPP_TEST_NAME} PUBLIC c10)
 # Link IPEX
 target_link_libraries(${CPU_CPP_TEST_NAME} PUBLIC intel-ext-pt-cpu)
 
-find_package(PythonLibs)
-if(${PYTHONLIBS_FOUND})
-  target_link_libraries(${CPU_CPP_TEST_NAME} PUBLIC ${PYTHON_LIBRARIES})
-endif()
-
-find_library(TORCH_PYTHON_LIBRARY torch_python PATH "${TORCH_INSTALL_PREFIX}/lib")
-target_link_libraries(${CPU_CPP_TEST_NAME} PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY})
-
 install(TARGETS ${CPU_CPP_TEST_NAME}
   RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
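
Note: the C++ test binary likewise no longer links libpython or libtorch_python, so tests cannot embed an interpreter and any operator interaction has to use the same dispatcher route. A minimal hypothetical GTest-style smoke check is sketched below (assuming the suite's existing GTest setup; the test name and its expectations are not part of this commit).

#include <gtest/gtest.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Hypothetical test sketch: look up the op without touching Python.
TEST(DispatcherSmoke, AllReduceLookupNeedsNoPython) {
  auto handle = c10::Dispatcher::singleton().findSchema(
      {"deepspeed_comm::all_reduce", ""});
  // Whether the op is registered depends on what this process has loaded;
  // if it is, verify the schema name round-trips.
  if (handle.has_value()) {
    EXPECT_EQ(handle->schema().name(), "deepspeed_comm::all_reduce");
  }
}
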
