diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 09c8fde..a44a03f 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -13,6 +13,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Add support for Discrete observation spaces Bug Fixes: ^^^^^^^^^^ diff --git a/tests/test_identity.py b/tests/test_identity.py new file mode 100644 index 0000000..64f6780 --- /dev/null +++ b/tests/test_identity.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest + +from torchy_baselines import A2C, PPO, SAC, TD3 +from torchy_baselines.common.identity_env import IdentityEnvBox, IdentityEnv +from torchy_baselines.common.evaluation import evaluate_policy +from torchy_baselines.common.noise import NormalActionNoise + + +@pytest.mark.parametrize("model_class", [A2C, PPO]) +def test_discrete(model_class): + env = IdentityEnv(10) + model = model_class('MlpPolicy', env, gamma=0.5, seed=0).learn(3000) + + evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) + + +@pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3]) +def test_continuous(model_class): + env = IdentityEnvBox(eps=0.5) + + n_steps = { + A2C: 3000, + PPO: 3000, + SAC: 500, + TD3: 500 + }[model_class] + + kwargs = dict( + policy_kwargs=dict(net_arch=[64, 64]), + seed=0, + gamma=0.95 + ) + if model_class in [TD3]: + n_actions = 1 + action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) + kwargs['action_noise'] = action_noise + + model = model_class('MlpPolicy', env, **kwargs).learn(n_steps) + + evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90) diff --git a/tests/test_logger.py b/tests/test_logger.py index 5ca0437..df6d059 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -5,8 +5,8 @@ import pytest import numpy as np from torchy_baselines.common.logger import (make_output_format, read_csv, read_json, DEBUG, ScopedConfigure, - info, debug, set_level, configure, logkv, logkvs, dumpkvs, logkv_mean, warn, - error, reset) + info, debug, set_level, configure, logkv, logkvs, + dumpkvs, logkv_mean, warn, error, reset) KEY_VALUES = { "test": 1, diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 9c528f4..076ec92 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -1,9 +1,10 @@ -import numpy as np import os -import pytest -import torch as th from copy import deepcopy +import pytest +import numpy as np +import torch as th + from torchy_baselines import A2C, PPO, SAC, TD3 from torchy_baselines.common.identity_env import IdentityEnvBox from torchy_baselines.common.vec_env import DummyVecEnv diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index cda1866..c171a4c 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -338,15 +338,15 @@ class BaseRLModel(ABC): "Box environment, please use {} ".format(observation_space.shape) + "or (n_env, {}) for the observation shape." .format(", ".join(map(str, observation_space.shape)))) - # TODO: add support for Discrete, MultiDiscrete and MultiBinary observation spaces - # elif isinstance(observation_space, gym.spaces.Discrete): - # if observation.shape == (): # A numpy array of a number, has shape empty tuple '()' - # return False - # elif len(observation.shape) == 1: - # return True - # else: - # raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + - # "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.") + elif isinstance(observation_space, gym.spaces.Discrete): + if observation.shape == (): # A numpy array of a number, has shape empty tuple '()' + return False + elif len(observation.shape) == 1: + return True + else: + raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) + + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape.") + # TODO: add support for MultiDiscrete and MultiBinary observation spaces # elif isinstance(observation_space, gym.spaces.MultiDiscrete): # if observation.shape == (len(observation_space.nvec),): # return False diff --git a/torchy_baselines/common/evaluation.py b/torchy_baselines/common/evaluation.py index 704057c..b5017c9 100644 --- a/torchy_baselines/common/evaluation.py +++ b/torchy_baselines/common/evaluation.py @@ -49,8 +49,8 @@ def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True, mean_reward = np.mean(episode_rewards) std_reward = np.std(episode_rewards) if reward_threshold is not None: - assert mean_reward > reward_threshold, (f'Mean reward below threshold: ' - '{mean_reward:.2f} < {reward_threshold:.2f}') + assert mean_reward > reward_threshold, ('Mean reward below threshold: ' + f'{mean_reward:.2f} < {reward_threshold:.2f}') if return_episode_rewards: return episode_rewards, episode_lengths return mean_reward, std_reward diff --git a/torchy_baselines/common/preprocessing.py b/torchy_baselines/common/preprocessing.py index fa312bf..ebef581 100644 --- a/torchy_baselines/common/preprocessing.py +++ b/torchy_baselines/common/preprocessing.py @@ -52,7 +52,7 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, return obs.float() elif isinstance(observation_space, spaces.Discrete): # One hot encoding and convert to float to avoid errors - return F.one_hot(obs, num_classes=observation_space.n).float() + return F.one_hot(obs.long(), num_classes=observation_space.n).float() else: # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict raise NotImplementedError() @@ -88,8 +88,8 @@ def get_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]: # raise NotImplementedError() return np.prod(observation_space.shape) elif isinstance(observation_space, spaces.Discrete): - # Observation is an int - return 1 + # Observation is a one hot vector + return observation_space.n else: # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict raise NotImplementedError()