From 4d0c033bf2ff53a1510b31e7b1446dda0d6d27fa Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 7 Oct 2019 16:36:48 +0200 Subject: [PATCH] Bug fix when randomly sampling actions --- README.md | 1 - torchy_baselines/common/base_class.py | 7 +++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b8fac6a..5f70f36 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,6 @@ PyTorch version of [Stable Baselines](https://github.com/hill-a/stable-baselines TODO: - save/load - predict -- better rescale (min + action * range) - flexible mlp - logger - better monitor wrapper? diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index a4b846b..310d2c8 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -251,9 +251,12 @@ class BaseRLModel(object): while not done: # Select action randomly or according to policy if num_timesteps < learning_starts: - action = [self.action_space.sample()] + action = np.array([self.action_space.sample()]) else: - action = self.scale_action(self.predict(obs, deterministic=deterministic)) + action = self.predict(obs, deterministic=deterministic) + + # Rescale the action from [low, high] to [-1, 1] + action = self.scale_action(action) # Add noise to the action (improve exploration) if action_noise is not None: