Skip to content

Commit

Permalink
Add predict_trajectory to Vehicle
Browse files Browse the repository at this point in the history
This should enable predict_trajectory=True with continuous action spaces

Fix #239
  • Loading branch information
eleurent committed Sep 23, 2023
1 parent e942d0f commit f933a57
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 7 deletions.
17 changes: 10 additions & 7 deletions highway_env/envs/common/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,26 +126,29 @@ def space(self) -> spaces.Box:
def vehicle_class(self) -> Callable:
return Vehicle if not self.dynamical else BicycleVehicle

def act(self, action: np.ndarray) -> None:
def get_action(self, action: np.ndarray):
if self.clip:
action = np.clip(action, -1, 1)
if self.speed_range:
self.controlled_vehicle.MIN_SPEED, self.controlled_vehicle.MAX_SPEED = self.speed_range
if self.longitudinal and self.lateral:
self.controlled_vehicle.act({
return {
"acceleration": utils.lmap(action[0], [-1, 1], self.acceleration_range),
"steering": utils.lmap(action[1], [-1, 1], self.steering_range),
})
}
elif self.longitudinal:
self.controlled_vehicle.act({
return {
"acceleration": utils.lmap(action[0], [-1, 1], self.acceleration_range),
"steering": 0,
})
}
elif self.lateral:
self.controlled_vehicle.act({
return {
"acceleration": 0,
"steering": utils.lmap(action[0], [-1, 1], self.steering_range)
})
}

def act(self, action: np.ndarray) -> None:
self.controlled_vehicle.act(self.get_action(action))
self.last_action = action


Expand Down
2 changes: 2 additions & 0 deletions highway_env/envs/common/graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def set_agent_action_sequence(self, actions: List['Action']) -> None:
"""
if isinstance(self.env.action_type, DiscreteMetaAction):
actions = [self.env.action_type.actions[a] for a in actions]
elif isinstance(self.env.action_type, ContinuousAction):
actions = [self.env.action_type.get_action(a) for a in actions]
if len(actions) > 1:
self.vehicle_trajectory = self.env.vehicle.predict_trajectory(actions,
1 / self.env.config["policy_frequency"],
Expand Down
23 changes: 23 additions & 0 deletions highway_env/vehicle/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,26 @@ def __str__(self):

def __repr__(self):
return self.__str__()

def predict_trajectory(self, actions: List, action_duration: float, trajectory_timestep: float, dt: float) \
-> List['Vehicle']:
"""
Predict the future trajectory of the vehicle given a sequence of actions.
:param actions: a sequence of future actions.
:param action_duration: the duration of each action.
:param trajectory_timestep: the duration between each save of the vehicle state.
:param dt: the timestep of the simulation
:return: the sequence of future states
"""
states = []
v = copy.deepcopy(self)
t = 0
for action in actions:
v.act(action) # Low-level control action
for _ in range(int(action_duration / dt)):
t += 1
v.step(dt)
if (t % int(trajectory_timestep / dt)) == 0:
states.append(copy.deepcopy(v))
return states

0 comments on commit f933a57

Please sign in to comment.