Seed the environment and fix the max action computation

This commit is contained in:
Antonin Raffin 2019-09-06 11:09:56 +02:00
parent 9cf289b997
commit 68028c71a1
2 changed files with 9 additions and 3 deletions

View file

@ -91,7 +91,7 @@ class BaseRLModel(ABC):
raise NotImplementedError()
@abstractmethod
def learn(self, total_timesteps, callback=None, seed=None, log_interval=100, tb_log_name="run",
def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="run",
reset_num_timesteps=True):
"""
Return a trained model.

View file

@ -1,3 +1,5 @@
import sys
import torch as th
import torch.nn.functional as F
import numpy as np
@ -22,7 +24,7 @@ class TD3(BaseRLModel):
super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose)
self.max_action = float(self.action_space.high)
self.max_action = np.abs(self.action_space.high)
self.replay_buffer = None
self.device = device
self.action_noise_std = action_noise_std
@ -38,6 +40,9 @@ class TD3(BaseRLModel):
state_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
set_random_seed(self.seed, using_cuda=self.device != 'cpu')
if self.env is not None:
self.env.seed(self.seed)
self.replay_buffer = ReplayBuffer(self.buffer_size, state_dim, action_dim, self.device)
self.policy = self.policy(self.observation_space, self.action_space,
self.learning_rate, device=self.device, **self.policy_kwargs)
@ -113,7 +118,7 @@ class TD3(BaseRLModel):
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)
def learn(self, total_timesteps, callback=None, seed=None, log_interval=100,
def learn(self, total_timesteps, callback=None, log_interval=100,
eval_freq=-1, n_eval_episodes=5, tb_log_name="TD3", reset_num_timesteps=True):
timesteps_since_eval = 0
@ -141,6 +146,7 @@ class TD3(BaseRLModel):
evaluations.append(evaluate_policy(self, self.env, n_eval_episodes))
if self.verbose > 0:
print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1]))
sys.stdout.flush()
# Reset environment
obs = self.env.reset()