diff --git a/include/cddp-cpp/cddp_core/cddp_core.hpp b/include/cddp-cpp/cddp_core/cddp_core.hpp index 0f170b4..733d2da 100644 --- a/include/cddp-cpp/cddp_core/cddp_core.hpp +++ b/include/cddp-cpp/cddp_core/cddp_core.hpp @@ -162,7 +162,17 @@ class CDDP { * @brief Set the Reference state * @param reference_state Reference state */ - void setReferenceState(const Eigen::VectorXd& reference_state) { reference_state_ = reference_state; } + void setReferenceState(const Eigen::VectorXd& reference_state) { + reference_state_ = reference_state; + // Update the objective reference state + objective_->setReferenceState(reference_state); + } + + void setReferenceStates(const std::vector& reference_states) { + reference_states_ = reference_states; + // Update the objective reference states + objective_->setReferenceStates(reference_states); + } /** * @brief Set the time horizon for the problem @@ -282,6 +292,7 @@ class CDDP { std::unique_ptr log_barrier_; Eigen::VectorXd initial_state_; Eigen::VectorXd reference_state_; // Desired reference state + std::vector reference_states_; // Desired reference states (trajectory) int horizon_; // Time horizon for the problem double timestep_; // Time step for the problem CDDPOptions options_; // Options for the solver diff --git a/tests/cddp_core/test_logcddp_core.cpp b/tests/cddp_core/test_logcddp_core.cpp index 881ae0b..b2a203c 100644 --- a/tests/cddp_core/test_logcddp_core.cpp +++ b/tests/cddp_core/test_logcddp_core.cpp @@ -46,7 +46,7 @@ TEST(CDDPTest, SolveLogCDDP) { 0.0, 0.0, 10.0; Qf = 0.5 * Qf; Eigen::VectorXd goal_state(state_dim); - goal_state << 2.0, 2.0, M_PI/2.0; + goal_state << 3.0, 2.0, M_PI/2.0; // Create an empty vector of Eigen::VectorXd std::vector empty_reference_states; @@ -82,6 +82,10 @@ TEST(CDDPTest, SolveLogCDDP) { cddp_solver.setDynamicalSystem(std::move(system)); cddp_solver.setObjective(std::move(objective)); + // Update goal state (for test) + goal_state << 2.0, 2.0, M_PI/2.0; + cddp_solver.setReferenceState(goal_state); + // Define control box constraints Eigen::VectorXd control_lower_bound(control_dim); control_lower_bound << -1.0, -M_PI;