diff --git a/README.md b/README.md
index b8fac6a..5f70f36 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,6 @@ PyTorch version of [Stable Baselines](https://github.com/hill-a/stable-baselines
 TODO:
 - save/load
 - predict
-- better rescale (min + action * range)
 - flexible mlp
 - logger
 - better monitor wrapper?
diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py
index a4b846b..310d2c8 100644
--- a/torchy_baselines/common/base_class.py
+++ b/torchy_baselines/common/base_class.py
@@ -251,9 +251,12 @@ class BaseRLModel(object):
             while not done:
                 # Select action randomly or according to policy
                 if num_timesteps < learning_starts:
-                    action = [self.action_space.sample()]
+                    action = np.array([self.action_space.sample()])
                 else:
-                    action = self.scale_action(self.predict(obs, deterministic=deterministic))
+                    action = self.predict(obs, deterministic=deterministic)
+
+                # Rescale the action from [low, high] to [-1, 1]
+                action = self.scale_action(action)
 
                 # Add noise to the action (improve exploration)
                 if action_noise is not None: