mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-26 22:45:15 +00:00
Fix predict method
This commit is contained in:
parent
42d50ed09b
commit
c0cb9fc9c5
1 changed file with 4 additions and 4 deletions
|
|
@@ -127,14 +127,14 @@ class PPO(BaseRLModel):
|
|||
if self.clip_range_vf is not None:
|
||||
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
|
||||
|
||||
def select_action(self, observation):
|
||||
def select_action(self, observation, deterministic=False):
|
||||
# Normally not needed
|
||||
observation = np.array(observation)
|
||||
with th.no_grad():
|
||||
observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device)
|
||||
return self.policy.actor_forward(observation, deterministic=False)
|
||||
return self.policy.actor_forward(observation, deterministic=deterministic)
|
||||
|
||||
def predict(self, observation, state=None, mask=None, deterministic=True):
|
||||
def predict(self, observation, state=None, mask=None, deterministic=False):
|
||||
"""
|
||||
Get the model's action from an observation
|
||||
|
||||
|
|
@@ -144,7 +144,7 @@ class PPO(BaseRLModel):
|
|||
:param deterministic: (bool) Whether or not to return deterministic actions.
|
||||
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
|
||||
"""
|
||||
clipped_actions = self.select_action(observation)
|
||||
clipped_actions = self.select_action(observation, deterministic=deterministic)
|
||||
if isinstance(self.action_space, gym.spaces.Box):
|
||||
clipped_actions = np.clip(clipped_actions, self.action_space.low, self.action_space.high)
|
||||
return clipped_actions
|
||||
|
|
|
|||
Loading…
Reference in a new issue