diff --git a/CMakeLists.txt b/CMakeLists.txt index 60c40d7..0e5d076 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,8 +40,8 @@ option(GUROBI_ROOT "Path to Gurobi installation" "") set(GUROBI_ROOT $ENV{HOME}/.local/lib/gurobi1103/linux64) # LibTorch Configuration -set(CDDP_CPP_TORCH "Whether to use LibTorch" ON) -option(CDDP_CPP_TORCH_GPU "Whether to use GPU support in LibTorch" ON) +option(CDDP_CPP_TORCH "Whether to use LibTorch" OFF) +option(CDDP_CPP_TORCH_GPU "Whether to use GPU support in LibTorch" OFF) set(LIBTORCH_DIR $ENV{HOME}/.local/lib/libtorch CACHE PATH "Path to local LibTorch installation") # FIXME: Change this to your local LibTorch installation directory # Python Configuration @@ -177,6 +177,8 @@ if (CDDP_CPP_TORCH) # Export LibTorch variables for other parts of the build set(TORCH_INSTALL_PREFIX ${Torch_DIR}/../../../ CACHE PATH "LibTorch installation directory") + + target_compile_definitions(${PROJECT_NAME} PRIVATE CDDP_CPP_TORCH_ENABLED=1) endif() @@ -200,9 +202,12 @@ set(cddp_core_srcs src/cddp_core/clddp_core.cpp src/cddp_core/asddp_core.cpp src/cddp_core/logddp_core.cpp - src/cddp_core/neural_dynamical_system.cpp ) +if (CDDP_CPP_TORCH) + list(APPEND cddp_core_srcs src/cddp_core/torch_helper.cpp) +endif() + set(dynamics_model_srcs src/dynamics_model/pendulum.cpp src/dynamics_model/dubins_car.cpp @@ -228,9 +233,12 @@ target_link_libraries(${PROJECT_NAME} Python3::Python Python3::Module Python3::NumPy - ${TORCH_LIBRARIES} ) +if (CDDP_CPP_TORCH) + target_link_libraries(${PROJECT_NAME} ${TORCH_LIBRARIES}) +endif() + target_include_directories(${PROJECT_NAME} PUBLIC $ $ @@ -238,7 +246,7 @@ target_include_directories(${PROJECT_NAME} PUBLIC ) # Ensure proper CUDA support if enabled -if(CDDP_CPP_TORCH_GPU) +if(TORCH_FOUND AND CDDP_CPP_TORCH_GPU) set_property(TARGET ${PROJECT_NAME} PROPERTY CUDA_ARCHITECTURES native) endif() diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 082328b..b63e8f7 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -44,23 +44,25 @@ if (CDDP_CPP_CASADI) endif() # Neural dynamics examples -add_executable(prepare_pendulum neural_dynamics/prepare_pendulum.cpp) -target_link_libraries(prepare_pendulum cddp) +if (CDDP_CPP_TORCH) + add_executable(prepare_pendulum neural_dynamics/prepare_pendulum.cpp) + target_link_libraries(prepare_pendulum cddp) -# add_executable(prepare_cartpole neural_dynamics/prepare_cartpole.cpp) -# target_link_libraries(prepare_cartpole cddp) + # add_executable(prepare_cartpole neural_dynamics/prepare_cartpole.cpp) + # target_link_libraries(prepare_cartpole cddp) -add_executable(train_pendulum neural_dynamics/train_pendulum.cpp) -target_link_libraries(train_pendulum cddp) + add_executable(train_pendulum neural_dynamics/train_pendulum.cpp) + target_link_libraries(train_pendulum cddp) -# add_executable(train_cartpole neural_dynamics/train_cartpole.cpp) -# target_link_libraries(train_cartpole cddp) + # add_executable(train_cartpole neural_dynamics/train_cartpole.cpp) + # target_link_libraries(train_cartpole cddp) -add_executable(run_pendulum neural_dynamics/run_pendulum.cpp) -target_link_libraries(run_pendulum cddp) + add_executable(run_pendulum neural_dynamics/run_pendulum.cpp) + target_link_libraries(run_pendulum cddp) -# add_executable(run_cartpole neural_dynamics/run_cartpole.cpp) -# target_link_libraries(run_cartpole cddp) + # add_executable(run_cartpole neural_dynamics/run_cartpole.cpp) + # target_link_libraries(run_cartpole cddp) -# add_executable(cddp_pendulum_neural _cddp_pendulum_neural.cpp) -# target_link_libraries(cddp_pendulum_neural cddp) + # add_executable(cddp_pendulum_neural _cddp_pendulum_neural.cpp) + # target_link_libraries(cddp_pendulum_neural cddp) +endif() diff --git a/include/cddp-cpp/cddp.hpp b/include/cddp-cpp/cddp.hpp index 0f4520a..73caa41 100644 --- a/include/cddp-cpp/cddp.hpp +++ b/include/cddp-cpp/cddp.hpp @@ -30,7 +30,10 @@ #include "cddp_core/helper.hpp" #include "cddp_core/boxqp.hpp" #include "cddp_core/qp_solver.hpp" + +#ifdef CDDP_CPP_TORCH_ENABLED #include "cddp_core/neural_dynamical_system.hpp" +#endif // Models #include "dynamics_model/pendulum.hpp"