diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 1e78608..f29fc72 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -91,7 +91,7 @@ class BaseRLModel(ABC): raise NotImplementedError() @abstractmethod - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run", + def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="run", reset_num_timesteps=True): """ Return a trained model. diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 5702e87..476fc67 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -1,3 +1,5 @@ +import sys + import torch as th import torch.nn.functional as F import numpy as np @@ -22,7 +24,7 @@ class TD3(BaseRLModel): super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose) - self.max_action = float(self.action_space.high) + self.max_action = np.abs(self.action_space.high) self.replay_buffer = None self.device = device self.action_noise_std = action_noise_std @@ -38,6 +40,9 @@ class TD3(BaseRLModel): state_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0] set_random_seed(self.seed, using_cuda=self.device != 'cpu') + if self.env is not None: + self.env.seed(self.seed) + self.replay_buffer = ReplayBuffer(self.buffer_size, state_dim, action_dim, self.device) self.policy = self.policy(self.observation_space, self.action_space, self.learning_rate, device=self.device, **self.policy_kwargs) @@ -113,7 +118,7 @@ class TD3(BaseRLModel): for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()): target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) - def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, + def learn(self, total_timesteps, callback=None, log_interval=100, eval_freq=-1, n_eval_episodes=5, tb_log_name="TD3", reset_num_timesteps=True): timesteps_since_eval = 0 @@ -141,6 +146,7 @@ class TD3(BaseRLModel): evaluations.append(evaluate_policy(self, self.env, n_eval_episodes)) if self.verbose > 0: print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1])) + sys.stdout.flush() # Reset environment obs = self.env.reset()