mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-31 23:28:05 +00:00
Seed env + fix max action
This commit is contained in:
parent
9cf289b997
commit
68028c71a1
2 changed files with 9 additions and 3 deletions
|
|
@@ -91,7 +91,7 @@ class BaseRLModel(ABC):
|
|||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run",
|
||||
def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="run",
|
||||
reset_num_timesteps=True):
|
||||
"""
|
||||
Return a trained model.
|
||||
|
|
|
|||
|
|
@@ -1,3 +1,5 @@
|
|||
import sys
|
||||
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
|
@@ -22,7 +24,7 @@ class TD3(BaseRLModel):
|
|||
|
||||
super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose)
|
||||
|
||||
self.max_action = float(self.action_space.high)
|
||||
self.max_action = np.abs(self.action_space.high)
|
||||
self.replay_buffer = None
|
||||
self.device = device
|
||||
self.action_noise_std = action_noise_std
|
||||
|
|
@@ -38,6 +40,9 @@ class TD3(BaseRLModel):
|
|||
state_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
|
||||
set_random_seed(self.seed, using_cuda=self.device != 'cpu')
|
||||
|
||||
if self.env is not None:
|
||||
self.env.seed(self.seed)
|
||||
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, state_dim, action_dim, self.device)
|
||||
self.policy = self.policy(self.observation_space, self.action_space,
|
||||
self.learning_rate, device=self.device, **self.policy_kwargs)
|
||||
|
|
@@ -113,7 +118,7 @@ class TD3(BaseRLModel):
|
|||
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
|
||||
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
|
||||
|
||||
def learn(self, total_timesteps, callback=None, seed=None, log_interval=100,
|
||||
def learn(self, total_timesteps, callback=None, log_interval=100,
|
||||
eval_freq=-1, n_eval_episodes=5, tb_log_name="TD3", reset_num_timesteps=True):
|
||||
|
||||
timesteps_since_eval = 0
|
||||
|
|
@@ -141,6 +146,7 @@ class TD3(BaseRLModel):
|
|||
evaluations.append(evaluate_policy(self, self.env, n_eval_episodes))
|
||||
if self.verbose > 0:
|
||||
print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1]))
|
||||
sys.stdout.flush()
|
||||
|
||||
# Reset environment
|
||||
obs = self.env.reset()
|
||||
|
|
|
|||
Loading…
Reference in a new issue