mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-06 00:03:28 +00:00
Fix learning starts
This commit is contained in:
parent
440166fe26
commit
12f854e1aa
2 changed files with 3 additions and 3 deletions
|
|
@ -251,7 +251,7 @@ class SAC(BaseRLModel):
|
|||
episode_num += n_episodes
|
||||
timesteps_since_eval += episode_timesteps
|
||||
|
||||
if self.num_timesteps > 0:
|
||||
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
|
||||
if self.verbose > 1:
|
||||
print("Total T: {} Episode Num: {} Episode T: {} Reward: {}".format(
|
||||
self.num_timesteps, episode_num, episode_timesteps, episode_reward))
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ class TD3(BaseRLModel):
|
|||
observation = np.array(observation)
|
||||
with th.no_grad():
|
||||
observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device)
|
||||
return self.actor(observation).cpu().data.numpy()
|
||||
return self.actor(observation).cpu().numpy()
|
||||
|
||||
def predict(self, observation, state=None, mask=None, deterministic=True):
|
||||
"""
|
||||
|
|
@ -222,7 +222,7 @@ class TD3(BaseRLModel):
|
|||
self.num_timesteps += episode_timesteps
|
||||
timesteps_since_eval += episode_timesteps
|
||||
|
||||
if self.num_timesteps > 0:
|
||||
if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
|
||||
if self.verbose > 1:
|
||||
print("Total T: {} Episode Num: {} Episode T: {} Reward: {}".format(
|
||||
self.num_timesteps, episode_num, episode_timesteps, episode_reward))
|
||||
|
|
|
|||
Loading…
Reference in a new issue