diff --git a/torchy_baselines/td3/td3.py b/torchy_baselines/td3/td3.py index 90ab127..2dbdd5f 100644 --- a/torchy_baselines/td3/td3.py +++ b/torchy_baselines/td3/td3.py @@ -206,6 +206,7 @@ class TD3(BaseRLModel): # Normalize returns # returns = (returns - returns.mean()) / (returns.std() + 1e-8) + returns = (returns - returns.mean()) policy_loss = -(returns * log_prob).mean()