From ee0f409f13e03b96adc98500431b4cbf00243ad8 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 20 Jun 2024 19:44:07 -0400 Subject: [PATCH] enable pytorch ops in cuda test --- .github/workflows/test_cuda.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index d97b1f9431..81ec974e33 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -49,11 +49,13 @@ jobs: - run: python -m pip install -U uv - run: python -m uv pip install --system "tensorflow>=2.15.0rc0" "torch>=2.2.0" - run: | + export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') python -m uv pip install --system -v -e .[gpu,test,lmp,cu12,torch] mpi4py env: DP_VARIANT: cuda DP_ENABLE_NATIVE_OPTIMIZATION: 1 + DP_ENABLE_PYTORCH: 1 - run: dp --version - run: python -m pytest source/tests --durations=0 env: