mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-22 22:10:16 +00:00
Rename seed: the `seed()` method becomes `set_random_seed()` and the `_seed` attribute becomes `seed`
This commit is contained in:
parent
32648d9029
commit
6bfbb7198a
4 changed files with 8 additions and 8 deletions
|
|
@ -207,7 +207,7 @@ class BaseRLModel(object):
|
|||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def seed(self, seed=0):
|
||||
def set_random_seed(self, seed=0):
|
||||
set_random_seed(seed, using_cuda=self.device == th.device('cuda'))
|
||||
self.action_space.seed(seed)
|
||||
if self.env is not None:
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class PPO(BaseRLModel):
|
|||
verbose, device, create_eval_env=create_eval_env, support_multi_env=True)
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self._seed = seed
|
||||
self.seed = seed
|
||||
self.batch_size = batch_size
|
||||
self.n_optim = n_optim
|
||||
self.n_steps = n_steps
|
||||
|
|
@ -63,7 +63,7 @@ class PPO(BaseRLModel):
|
|||
state_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
|
||||
# TODO: different seed for each env when n_envs > 1
|
||||
if self.n_envs == 1:
|
||||
self.seed(self._seed)
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
|
||||
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ class SAC(BaseRLModel):
|
|||
tau=0.005, ent_coef='auto', target_update_interval=1,
|
||||
gradient_steps=1, target_entropy='auto', action_noise=None,
|
||||
gamma=0.99, action_noise_std=0.0, create_eval_env=False,
|
||||
policy_kwargs=None, verbose=0, seed=0,
|
||||
policy_kwargs=None, verbose=0, seed=0, device='auto',
|
||||
_init_setup_model=True):
|
||||
|
||||
super(SAC, self).__init__(policy, env, SACPolicy, policy_kwargs, verbose, device,
|
||||
|
|
@ -64,7 +64,7 @@ class SAC(BaseRLModel):
|
|||
self.max_action = np.abs(self.action_space.high)
|
||||
self.action_noise_std = action_noise_std
|
||||
self.learning_rate = learning_rate
|
||||
self._seed = seed
|
||||
self.seed = seed
|
||||
self.target_entropy = target_entropy
|
||||
self.log_ent_coef = None
|
||||
# self.target_update_interval = target_update_interval
|
||||
|
|
@ -89,7 +89,7 @@ class SAC(BaseRLModel):
|
|||
|
||||
def _setup_model(self):
|
||||
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
|
||||
self.seed(self._seed)
|
||||
self.set_random_seed(self.seed)
|
||||
|
||||
# Target entropy is used when learning the entropy coefficient
|
||||
if self.target_entropy == 'auto':
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class TD3(BaseRLModel):
|
|||
self.max_action = np.abs(self.action_space.high)
|
||||
self.action_noise_std = action_noise_std
|
||||
self.buffer_size = buffer_size
|
||||
self._seed = seed
|
||||
self.seed = seed
|
||||
|
||||
self.buffer_size = buffer_size
|
||||
# TODO: accept callables
|
||||
|
|
@ -78,7 +78,7 @@ class TD3(BaseRLModel):
|
|||
|
||||
def _setup_model(self):
|
||||
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
|
||||
self.seed(self._seed)
|
||||
self.set_random_seed(self.seed)
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy(self.observation_space, self.action_space,
|
||||
self.learning_rate, device=self.device, **self.policy_kwargs)
|
||||
|
|
|
|||
Loading…
Reference in a new issue