Skip to content

Commit

Permalink
cpu only build fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Roger Waleffe authored and Roger Waleffe committed Oct 15, 2023
1 parent 51aca8b commit ea44305
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 3 deletions.
21 changes: 21 additions & 0 deletions scripts/compile.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
mkdir build
cd build

#cmake ../ -DUSE_CUDA=1
#cmake ../ -DUSE_OMP=1
cmake ../ -DUSE_CUDA=1 -DUSE_OMP=1 #-DCMAKE_C_COMPILER=clang
#cmake ../ -DUSE_CUDA=1 -DUSE_OMP=1 -DCMAKE_BUILD_TYPE=Debug -DMARIUS_USE_ASAN=1

make marius_train -j
#make marius_train marius_eval -j
#make _pymarius -j
#make unit -j

#make bindings -j
#mkdir marius
#cp *.so marius
#cp -r ../src/python/* marius/

#cp *.so /usr/local/lib/python3.6/dist-packages/marius/

cd ..
33 changes: 33 additions & 0 deletions scripts/preprocess.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
export MARIUS_ONLY_PYTHON=1
export MARIUS_NO_BINDINGS=1

pip3 uninstall marius
pip3 install .

#mkdir ../datasets

#marius_preprocess --dataset fb15k_237 ../datasets/fb15k237/
#marius_preprocess --dataset fb15k_237 --num_partitions 16 ../datasets/fb15k237_partitioned/
#marius_preprocess --dataset fb15k_237 --num_partitions 128 ../datasets/fb15k237_partitioned_128/
#marius_preprocess --dataset fb15k_237 --num_partitions 256 ../datasets/fb15k237_partitioned_256/

#marius_preprocess --dataset live_journal ../datasets/livejournal/
#marius_preprocess --dataset live_journal --num_partitions 16 ../datasets/livejournal_16/

#marius_preprocess --dataset freebase86m ../datasets/freebase86m/
#marius_preprocess --dataset freebase86m --num_partitions 8 ../datasets/freebase86m_8/
#marius_preprocess --dataset freebase86m --num_partitions 16 ../datasets/freebase86m_16/
#marius_preprocess --dataset freebase86m --num_partitions 32 ../datasets/freebase86m_32/
#marius_preprocess --dataset freebase86m --num_partitions 64 ../datasets/freebase86m_64/
#marius_preprocess --dataset freebase86m --num_partitions 256 ../datasets/freebase86m_256/
#marius_preprocess --dataset freebase86m --num_partitions 1024 ../datasets/freebase86m_1024/
#marius_preprocess --dataset freebase86m --num_partitions 4096 ../datasets/freebase86m_4096/



#marius_preprocess --dataset cora ../datasets/cora/

#marius_preprocess --dataset OGBN_ARXIV ../datasets/arxiv/


#marius_preprocess --dataset fb15k_237 --output_dir datasets/fb15k_237_example/
9 changes: 7 additions & 2 deletions src/cpp/include/common/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ class DummyCudaStreamGuard {
class DummyCudaMultiStreamGuard {
public:
DummyCudaMultiStreamGuard(DummyCudaStream []) {}

DummyCudaMultiStreamGuard(std::initializer_list<DummyCudaStream>) {}

DummyCudaMultiStreamGuard(std::vector<DummyCudaStream *>) {}
};

#ifdef MARIUS_CUDA
Expand All @@ -75,7 +79,8 @@ typedef at::cuda::CUDAStreamGuard CudaStreamGuard;
typedef at::cuda::CUDAMultiStreamGuard CudaMultiStreamGuard;

using at::cuda::getStreamFromPool;
using at::cuda::getCurrentCUDAStream;
//using at::cuda::getCurrentCUDAStream;
inline CudaStream getCurrentCudaStream(int device_index = 0) { return at::cuda::getCurrentCUDAStream(device_index); }

#else
typedef DummyCudaEvent CudaEvent;
Expand All @@ -84,7 +89,7 @@ typedef DummyCudaStreamGuard CudaStreamGuard;
typedef DummyCudaMultiStreamGuard CudaMultiStreamGuard;

inline CudaStream getStreamFromPool(bool = false, int = 0) { return CudaStream(); }
inline CudaStream getCurrentCUDAStream(int = 0) { return CudaStream(); }
inline CudaStream getCurrentCudaStream(int = 0) { return CudaStream(); }
#endif

#ifndef IO_FLAGS
Expand Down
2 changes: 1 addition & 1 deletion src/cpp/src/data/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ void Batch::accumulateGradients(float learning_rate) {

status_ = BatchStatus::AccumulatedGradients;

getCurrentCUDAStream(node_gradients_.device().index()).synchronize();
getCurrentCudaStream(node_gradients_.device().index()).synchronize();
}

void Batch::embeddingsToHost() {
Expand Down

0 comments on commit ea44305

Please sign in to comment.