diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 597072d..00d8093 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -207,7 +207,7 @@ class BaseRLModel(object): """ raise NotImplementedError() - def seed(self, seed=0): + def set_random_seed(self, seed=0): set_random_seed(seed, using_cuda=self.device == th.device('cuda')) self.action_space.seed(seed) if self.env is not None: diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index 76a6a73..5ef0793 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -40,7 +40,7 @@ class PPO(BaseRLModel): verbose, device, create_eval_env=create_eval_env, support_multi_env=True) self.learning_rate = learning_rate - self._seed = seed + self.seed = seed self.batch_size = batch_size self.n_optim = n_optim self.n_steps = n_steps @@ -63,7 +63,7 @@ class PPO(BaseRLModel): state_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] # TODO: different seed for each env when n_envs > 1 if self.n_envs == 1: - self.seed(self._seed) + self.set_random_seed(self.seed) self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device, gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs) diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index 8fb3708..1465e3a 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -55,7 +55,7 @@ class SAC(BaseRLModel): tau=0.005, ent_coef='auto', target_update_interval=1, gradient_steps=1, target_entropy='auto', action_noise=None, gamma=0.99, action_noise_std=0.0, create_eval_env=False, - policy_kwargs=None, verbose=0, seed=0, + policy_kwargs=None, verbose=0, seed=0, device='auto', _init_setup_model=True): super(SAC, self).__init__(policy, env, SACPolicy, policy_kwargs, verbose, device, @@ -64,7 +64,7 @@ class SAC(BaseRLModel): self.max_action = np.abs(self.action_space.high) self.action_noise_std = action_noise_std self.learning_rate = learning_rate - self._seed = seed + self.seed = seed self.target_entropy = target_entropy self.log_ent_coef = None # self.target_update_interval = target_update_interval @@ -89,7 +89,7 @@ class SAC(BaseRLModel): def _setup_model(self): obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] - self.seed(self._seed) + self.set_random_seed(self.seed) # Target entropy is used when learning the entropy coefficient if self.target_entropy == 'auto': diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 7ede8e3..34f1051 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -57,7 +57,7 @@ class TD3(BaseRLModel): self.max_action = np.abs(self.action_space.high) self.action_noise_std = action_noise_std self.buffer_size = buffer_size - self._seed = seed + self.seed = seed self.buffer_size = buffer_size # TODO: accept callables @@ -78,7 +78,7 @@ class TD3(BaseRLModel): def _setup_model(self): obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] - self.seed(self._seed) + self.set_random_seed(self.seed) self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device) self.policy = self.policy(self.observation_space, self.action_space, self.learning_rate, device=self.device, **self.policy_kwargs)