diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 557e901..645ecf2 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -149,9 +149,9 @@ class TD3(BaseRLModel): for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()): target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) - def train_actor(self, gradient_steps: object = 1, batch_size: object = 100, tau_actor: object = 0.005, - tau_critic: object = 0.005, - replay_data: object = None) -> object: + def train_actor(self, gradient_steps=1, batch_size=100, tau_actor=0.005, + tau_critic=0.005, + replay_data=None): # Update optimizer learning rate self._update_learning_rate(self.actor.optimizer)