Rename seed

This commit is contained in:
Antonin Raffin 2019-09-24 16:59:47 +02:00
parent 32648d9029
commit 6bfbb7198a
4 changed files with 8 additions and 8 deletions

View file

@ -207,7 +207,7 @@ class BaseRLModel(object):
"""
raise NotImplementedError()
def seed(self, seed=0):
def set_random_seed(self, seed=0):
set_random_seed(seed, using_cuda=self.device == th.device('cuda'))
self.action_space.seed(seed)
if self.env is not None:

View file

@ -40,7 +40,7 @@ class PPO(BaseRLModel):
verbose, device, create_eval_env=create_eval_env, support_multi_env=True)
self.learning_rate = learning_rate
self._seed = seed
self.seed = seed
self.batch_size = batch_size
self.n_optim = n_optim
self.n_steps = n_steps
@ -63,7 +63,7 @@ class PPO(BaseRLModel):
state_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
# TODO: different seed for each env when n_envs > 1
if self.n_envs == 1:
self.seed(self._seed)
self.set_random_seed(self.seed)
self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)

View file

@ -55,7 +55,7 @@ class SAC(BaseRLModel):
tau=0.005, ent_coef='auto', target_update_interval=1,
gradient_steps=1, target_entropy='auto', action_noise=None,
gamma=0.99, action_noise_std=0.0, create_eval_env=False,
policy_kwargs=None, verbose=0, seed=0,
policy_kwargs=None, verbose=0, seed=0, device='auto',
_init_setup_model=True):
super(SAC, self).__init__(policy, env, SACPolicy, policy_kwargs, verbose, device,
@ -64,7 +64,7 @@ class SAC(BaseRLModel):
self.max_action = np.abs(self.action_space.high)
self.action_noise_std = action_noise_std
self.learning_rate = learning_rate
self._seed = seed
self.seed = seed
self.target_entropy = target_entropy
self.log_ent_coef = None
# self.target_update_interval = target_update_interval
@ -89,7 +89,7 @@ class SAC(BaseRLModel):
def _setup_model(self):
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
self.seed(self._seed)
self.set_random_seed(self.seed)
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == 'auto':

View file

@ -57,7 +57,7 @@ class TD3(BaseRLModel):
self.max_action = np.abs(self.action_space.high)
self.action_noise_std = action_noise_std
self.buffer_size = buffer_size
self._seed = seed
self.seed = seed
self.buffer_size = buffer_size
# TODO: accept callables
@ -78,7 +78,7 @@ class TD3(BaseRLModel):
def _setup_model(self):
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
self.seed(self._seed)
self.set_random_seed(self.seed)
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
self.policy = self.policy(self.observation_space, self.action_space,
self.learning_rate, device=self.device, **self.policy_kwargs)