From c0cb9fc9c57e0a3fd14c5675c2bc206d64f562f4 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 29 Oct 2019 18:30:36 +0100 Subject: [PATCH] Fix predict method --- torchy_baselines/ppo/ppo.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 8171eb5..8127307 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -127,14 +127,14 @@ class PPO(BaseRLModel): if self.clip_range_vf is not None: self.clip_range_vf = get_schedule_fn(self.clip_range_vf) - def select_action(self, observation): + def select_action(self, observation, deterministic=False): # Normally not needed observation = np.array(observation) with th.no_grad(): observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device) - return self.policy.actor_forward(observation, deterministic=False) + return self.policy.actor_forward(observation, deterministic=deterministic) - def predict(self, observation, state=None, mask=None, deterministic=True): + def predict(self, observation, state=None, mask=None, deterministic=False): """ Get the model's action from an observation @@ -144,7 +144,7 @@ class PPO(BaseRLModel): :param deterministic: (bool) Whether or not to return deterministic actions. :return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies) """ - clipped_actions = self.select_action(observation) + clipped_actions = self.select_action(observation, deterministic=deterministic) if isinstance(self.action_space, gym.spaces.Box): clipped_actions = np.clip(clipped_actions, self.action_space.low, self.action_space.high) return clipped_actions