Fix predict method

This commit is contained in:
Antonin Raffin 2019-10-29 18:30:36 +01:00
parent 42d50ed09b
commit c0cb9fc9c5

View file

@@ -127,14 +127,14 @@ class PPO(BaseRLModel):
if self.clip_range_vf is not None:
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
def select_action(self, observation):
def select_action(self, observation, deterministic=False):
# Normally not needed
observation = np.array(observation)
with th.no_grad():
observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device)
return self.policy.actor_forward(observation, deterministic=False)
return self.policy.actor_forward(observation, deterministic=deterministic)
def predict(self, observation, state=None, mask=None, deterministic=True):
def predict(self, observation, state=None, mask=None, deterministic=False):
"""
Get the model's action from an observation
@@ -144,7 +144,7 @@ class PPO(BaseRLModel):
:param deterministic: (bool) Whether or not to return deterministic actions.
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
"""
clipped_actions = self.select_action(observation)
clipped_actions = self.select_action(observation, deterministic=deterministic)
if isinstance(self.action_space, gym.spaces.Box):
clipped_actions = np.clip(clipped_actions, self.action_space.low, self.action_space.high)
return clipped_actions