From ab64ff464e5c75c995530b68276c2697744717c8 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 14 Oct 2019 11:09:22 +0200 Subject: [PATCH] Add tensorboard_log dummy arg --- torchy_baselines/cem_rl/cem_rl.py | 11 ++++++----- torchy_baselines/sac/sac.py | 2 +- torchy_baselines/td3/td3.py | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/torchy_baselines/cem_rl/cem_rl.py b/torchy_baselines/cem_rl/cem_rl.py index 90119c2..2116d3f 100644 --- a/torchy_baselines/cem_rl/cem_rl.py +++ b/torchy_baselines/cem_rl/cem_rl.py @@ -15,13 +15,14 @@ class CEMRL(TD3): Code: https://github.com/apourchot/CEM-RL """ - def __init__(self, policy, env, policy_kwargs=None, verbose=0, - sigma_init=1e-3, pop_size=10, damp=1e-3, damp_limit=1e-5, - elitism=False, n_grad=5, policy_delay=2, batch_size=100, - buffer_size=int(1e6), learning_rate=1e-3, seed=0, device='auto', + def __init__(self, policy, env, sigma_init=1e-3, pop_size=10, + damp=1e-3, damp_limit=1e-5, elitism=False, n_grad=5, + policy_delay=2, batch_size=100, + buffer_size=int(1e6), learning_rate=1e-3, action_noise=None, learning_starts=100, tau=0.005, n_episodes_rollout=1, update_style='original', - create_eval_env=False, + tensorboard_log=None, create_eval_env=False, + policy_kwargs=None, verbose=0, seed=0, device='auto', _init_setup_model=True): super(CEMRL, self).__init__(policy, env, diff --git a/torchy_baselines/sac/sac.py b/torchy_baselines/sac/sac.py index c8a74ef..dca15f8 100644 --- a/torchy_baselines/sac/sac.py +++ b/torchy_baselines/sac/sac.py @@ -57,7 +57,7 @@ class SAC(BaseRLModel): tau=0.005, ent_coef='auto', target_update_interval=1, train_freq=1, gradient_steps=1, n_episodes_rollout=-1, target_entropy='auto', action_noise=None, - gamma=0.99, create_eval_env=False, + gamma=0.99, tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0, seed=0, device='auto', _init_setup_model=True): diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 3fdd6da..7a13c0a 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -50,7 +50,7 @@ class TD3(BaseRLModel): policy_delay=2, learning_starts=100, gamma=0.99, batch_size=100, train_freq=-1, gradient_steps=-1, n_episodes_rollout=1, tau=0.005, action_noise=None, target_policy_noise=0.2, target_noise_clip=0.5, - create_eval_env=False, policy_kwargs=None, verbose=0, + tensorboard_log=None, create_eval_env=False, policy_kwargs=None, verbose=0, seed=0, device='auto', _init_setup_model=True): super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose, device,