Refactor handling of obs and action space

+ remove duplicated code
This commit is contained in:
Antonin RAFFIN 2020-03-20 10:09:09 +01:00
parent 7251b9d2c2
commit 57b37513b6
12 changed files with 152 additions and 76 deletions

View file

@ -3,6 +3,29 @@
Changelog
==========
Pre-Release 0.4.0a0 (WIP)
------------------------------
Breaking Changes:
^^^^^^^^^^^^^^^^^
New Features:
^^^^^^^^^^^^^
Bug Fixes:
^^^^^^^^^^
Deprecations:
^^^^^^^^^^^^^
Others:
^^^^^^^
- Refactor handling of observation and action spaces
Documentation:
^^^^^^^^^^^^^^
Pre-Release 0.3.0 (2020-02-14)
------------------------------

View file

@ -31,6 +31,8 @@ class BaseRLModel(ABC):
:param env: (Union[GymEnv, str]) The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param policy_base: (Type[BasePolicy]) The base policy used by this method
:param learning_rate: (float or callable) learning rate for the optimizer,
it can be a function of the current progress (from 1 to 0)
:param policy_kwargs: (Dict[str, Any]) Additional arguments to be passed to the policy on creation
:param verbose: (int) The verbosity level: 0 none, 1 training information, 2 debug
:param device: (Union[th.device, str]) Device on which the code should run.
@ -53,6 +55,7 @@ class BaseRLModel(ABC):
policy: Type[BasePolicy],
env: Union[GymEnv, str],
policy_base: Type[BasePolicy],
learning_rate: Union[float, Callable],
policy_kwargs: Dict[str, Any] = None,
verbose: int = 0,
device: Union[th.device, str] = 'auto',
@ -89,7 +92,7 @@ class BaseRLModel(ABC):
self.action_noise = None # type: Optional[ActionNoise]
self.start_time = None
self.policy = None
self.learning_rate = None # type: Optional[float]
self.learning_rate = learning_rate
self.lr_schedule = None # type: Optional[Callable]
# Used for SDE only
self.use_sde = use_sde
@ -135,7 +138,7 @@ class BaseRLModel(ABC):
@abstractmethod
def _setup_model(self) -> None:
"""
Create networks and optimizers
Create networks, buffer and optimizers
"""
raise NotImplementedError()
@ -761,6 +764,11 @@ class OffPolicyRLModel(BaseRLModel):
:param env: The environment to learn from
(if registered in Gym, can be str. Can be None for loading trained models)
:param policy_base: The base policy used by this method
:param learning_rate: (float or callable) learning rate for the optimizer,
it can be a function of the current progress (from 1 to 0)
:param buffer_size: (int) size of the replay buffer
:param learning_starts: (int) how many steps of the model to collect transitions for before learning starts
:param batch_size: (int) Minibatch size for each gradient update
:param policy_kwargs: Additional arguments to be passed to the policy on creation
:param verbose: The verbosity level: 0 none, 1 training information, 2 debug
:param device: Device on which the code should run.
@ -785,6 +793,10 @@ class OffPolicyRLModel(BaseRLModel):
policy: Type[BasePolicy],
env: Union[GymEnv, str],
policy_base: Type[BasePolicy],
learning_rate: Union[float, Callable],
buffer_size: int = int(1e6),
learning_starts: int = 100,
batch_size: int = 256,
policy_kwargs: Dict[str, Any] = None,
verbose: int = 0,
device: Union[th.device, str] = 'auto',
@ -796,16 +808,32 @@ class OffPolicyRLModel(BaseRLModel):
sde_sample_freq: int = -1,
use_sde_at_warmup: bool = False):
super(OffPolicyRLModel, self).__init__(policy, env, policy_base, policy_kwargs, verbose,
super(OffPolicyRLModel, self).__init__(policy, env, policy_base, learning_rate,
policy_kwargs, verbose,
device, support_multi_env, create_eval_env, monitor_wrapper,
seed, use_sde, sde_sample_freq)
self.buffer_size = buffer_size
self.batch_size = batch_size
self.learning_starts = learning_starts
self.actor = None
self.replay_buffer = None # type: Optional[ReplayBuffer]
# Update policy keyword arguments
self.policy_kwargs['use_sde'] = self.use_sde
self.policy_kwargs['device'] = self.device
# For SDE only
self.rollout_data = None
self.on_policy_exploration = False
self.actor = None
self.replay_buffer = None # type: Optional[ReplayBuffer]
self.use_sde_at_warmup = use_sde_at_warmup
def _setup_model(self):
    """
    Create the learning rate schedule, the replay buffer and the policy.

    The random seed is set before the buffer and the policy are built
    (presumably so that network initialization is reproducible -- confirm).
    """
    self._setup_lr_schedule()
    self.set_random_seed(self.seed)

    # The buffer receives the spaces themselves, not pre-computed dimensions
    buffer = ReplayBuffer(self.buffer_size, self.observation_space,
                          self.action_space, self.device)
    self.replay_buffer = buffer

    # Build the policy from the stored keyword arguments, then move it
    # to the target device
    policy = self.policy_class(self.observation_space, self.action_space,
                               self.lr_schedule, **self.policy_kwargs)
    self.policy = policy.to(self.device)
def save_replay_buffer(self, path: str):
"""
Save the replay buffer as a pickle file.

View file

@ -2,9 +2,11 @@ from typing import Union, Optional, Generator
import numpy as np
import torch as th
from gym import spaces
from torchy_baselines.common.vec_env import VecNormalize
from torchy_baselines.common.type_aliases import RolloutBufferSamples, ReplayBufferSamples
from torchy_baselines.common.preprocessing import get_obs_dim, get_action_dim
class BaseBuffer(object):
@ -12,23 +14,24 @@ class BaseBuffer(object):
Base class that represents a buffer (rollout or replay)
:param buffer_size: (int) Max number of elements in the buffer
:param obs_dim: (int) Dimension of the observation
:param action_dim: (int) Dimension of the action space
:param observation_space: (spaces.Space) Observation space
:param action_space: (spaces.Space) Action space
:param device: (Union[th.device, str]) PyTorch device
to which the values will be converted
:param n_envs: (int) Number of parallel environments
"""
def __init__(self,
buffer_size: int,
obs_dim: int,
action_dim: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = 'cpu',
n_envs: int = 1):
super(BaseBuffer, self).__init__()
self.buffer_size = buffer_size
self.obs_dim = obs_dim
self.action_dim = action_dim
self.observation_space = observation_space
self.action_space = action_space
self.obs_dim = get_obs_dim(observation_space)
self.action_dim = get_action_dim(action_space)
self.pos = 0
self.full = False
self.device = device
@ -137,19 +140,20 @@ class ReplayBuffer(BaseBuffer):
Replay buffer used in off-policy algorithms like SAC/TD3.
:param buffer_size: (int) Max number of elements in the buffer
:param obs_dim: (int) Dimension of the observation
:param action_dim: (int) Dimension of the action space
:param observation_space: (spaces.Space) Observation space
:param action_space: (spaces.Space) Action space
:param device: (th.device)
:param n_envs: (int) Number of parallel environments
"""
def __init__(self,
buffer_size: int,
obs_dim: int,
action_dim: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = 'cpu',
n_envs: int = 1):
super(ReplayBuffer, self).__init__(buffer_size, obs_dim, action_dim, device, n_envs=n_envs)
super(ReplayBuffer, self).__init__(buffer_size, observation_space,
action_space, device, n_envs=n_envs)
assert n_envs == 1, "Replay buffer only support single environment for now"
@ -194,8 +198,8 @@ class RolloutBuffer(BaseBuffer):
Rollout buffer used in on-policy algorithms like A2C/PPO.
:param buffer_size: (int) Max number of elements in the buffer
:param obs_dim: (int) Dimension of the observation
:param action_dim: (int) Dimension of the action space
:param observation_space: (spaces.Space) Observation space
:param action_space: (spaces.Space) Action space
:param device: (th.device)
:param gae_lambda: (float) Factor for trade-off of bias vs variance for Generalized Advantage Estimator
Equivalent to classic advantage when set to 1.
@ -205,14 +209,15 @@ class RolloutBuffer(BaseBuffer):
def __init__(self,
buffer_size: int,
obs_dim: int,
action_dim: int,
observation_space: spaces.Space,
action_space: spaces.Space,
device: Union[th.device, str] = 'cpu',
gae_lambda: float = 1,
gamma: float = 0.99,
n_envs: int = 1):
super(RolloutBuffer, self).__init__(buffer_size, obs_dim, action_dim, device, n_envs=n_envs)
super(RolloutBuffer, self).__init__(buffer_size, observation_space,
action_space, device, n_envs=n_envs)
self.gae_lambda = gae_lambda
self.gamma = gamma
self.observations, self.actions, self.rewards, self.advantages = None, None, None, None

View file

@ -6,6 +6,8 @@ import torch.nn as nn
from torch.distributions import Normal, Categorical
from gym import spaces
from torchy_baselines.common.preprocessing import get_action_dim
class Distribution(object):
def __init__(self):
@ -59,7 +61,7 @@ class DiagGaussianDistribution(Distribution):
Gaussian distribution with diagonal covariance matrix,
for continuous actions.
:param action_dim: (int) Number of continuous actions
:param action_dim: (int) Dimension of the action space.
"""
def __init__(self, action_dim: int):
@ -144,7 +146,7 @@ class SquashedDiagGaussianDistribution(DiagGaussianDistribution):
Gaussian distribution with diagonal covariance matrix,
followed by a squashing function (tanh) to ensure bounds.
:param action_dim: (int) Number of continuous actions
:param action_dim: (int) Dimension of the action space.
:param epsilon: (float) small value to avoid NaN due to numerical imprecision.
"""
@ -252,7 +254,7 @@ class StateDependentNoiseDistribution(Distribution):
It is used to create the noise exploration matrix and
compute the log probability of an action with that noise.
:param action_dim: (int) Number of continuous actions
:param action_dim: (int) Dimension of the action space.
:param full_std: (bool) Whether to use (n_features x n_actions) parameters
for the std instead of only (n_features,)
:param use_expln: (bool) Use ``expln()`` function instead of ``exp()`` to ensure
@ -495,8 +497,8 @@ def make_proba_distribution(action_space: gym.spaces.Space,
if isinstance(action_space, spaces.Box):
assert len(action_space.shape) == 1, "Error: the action space must be a vector"
if use_sde:
return StateDependentNoiseDistribution(action_space.shape[0], **dist_kwargs)
return DiagGaussianDistribution(action_space.shape[0], **dist_kwargs)
return StateDependentNoiseDistribution(get_action_dim(action_space), **dist_kwargs)
return DiagGaussianDistribution(get_action_dim(action_space), **dist_kwargs)
elif isinstance(action_space, spaces.Discrete):
return CategoricalDistribution(action_space.n, **dist_kwargs)
# elif isinstance(action_space, spaces.MultiDiscrete):

View file

@ -0,0 +1,44 @@
from typing import Tuple, Union
import numpy as np
import torch as th
from gym import spaces
def is_image(observation_space):
    """
    Tell whether the given observation space is an image space.

    The current implementation does not inspect its argument and always
    answers ``False``.

    :param observation_space: Observation space (currently unused)
    :return: (bool) Always ``False``
    """
    image_space_detected = False
    return image_space_detected
def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space) -> th.Tensor:
    """
    Preprocess an observation according to its observation space.

    :param obs: (th.Tensor) Observation
    :param observation_space: (spaces.Space) Observation space
    :return: (th.Tensor) ``obs / 255.0`` when the ``Box`` space is an image
        (according to ``is_image``), otherwise ``obs`` unchanged
    :raises NotImplementedError: If the space is not a ``spaces.Box``
    """
    if not isinstance(observation_space, spaces.Box):
        # TODO: one hot encoding (Discrete)
        # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict
        raise NotImplementedError()

    return obs / 255.0 if is_image(observation_space) else obs
def get_obs_dim(observation_space: spaces.Space) -> Union[int, Tuple[int, ...]]:
    """
    Get the dimension of the observation space.

    :param observation_space: (spaces.Space) Observation space
    :return: (Union[int, Tuple[int, ...]]) The full shape for an image
        ``Box`` space (according to ``is_image``), the product of the shape
        for any other ``Box`` space, and 1 for a ``Discrete`` space
    :raises NotImplementedError: If the space is neither ``Box`` nor ``Discrete``
    """
    if isinstance(observation_space, spaces.Box):
        if is_image(observation_space):
            return observation_space.shape
        # ``np.prod`` returns a NumPy scalar (a float for an empty shape);
        # cast it so the result is a plain ``int`` as annotated,
        # the same way ``get_action_dim`` does.
        return int(np.prod(observation_space.shape))
    elif isinstance(observation_space, spaces.Discrete):
        return 1
    else:
        # TODO: Multidiscrete, Binary, MultiBinary, Tuple, Dict
        raise NotImplementedError()
def get_action_dim(action_space: spaces.Space) -> int:
    """
    Get the dimension of the action space.

    :param action_space: (spaces.Space) Action space
    :return: (int) The product of the shape for a ``Box`` space,
        1 for a ``Discrete`` space
    :raises NotImplementedError: If the space is neither ``Box`` nor ``Discrete``
    """
    if isinstance(action_space, spaces.Discrete):
        # Action is an int
        return 1

    if not isinstance(action_space, spaces.Box):
        raise NotImplementedError()

    n_components = np.prod(action_space.shape)
    return int(n_components)

View file

@ -6,6 +6,7 @@ import torch as th
import torch.nn as nn
import numpy as np
from torchy_baselines.common.preprocessing import get_obs_dim
from torchy_baselines.common.policies import (BasePolicy, register_policy, MlpExtractor,
create_sde_feature_extractor)
from torchy_baselines.common.distributions import (make_proba_distribution, Distribution,
@ -54,7 +55,7 @@ class PPOPolicy(BasePolicy):
use_expln: bool = False,
squash_output: bool = False):
super(PPOPolicy, self).__init__(observation_space, action_space, device, squash_output=squash_output)
self.obs_dim = self.observation_space.shape[0]
self.obs_dim = get_obs_dim(self.observation_space)
# Default network architecture, from stable-baselines
if net_arch is None:

View file

@ -98,11 +98,10 @@ class PPO(BaseRLModel):
device: Union[th.device, str] = 'auto',
_init_setup_model: bool = True):
super(PPO, self).__init__(policy, env, PPOPolicy, policy_kwargs=policy_kwargs,
super(PPO, self).__init__(policy, env, PPOPolicy, learning_rate, policy_kwargs=policy_kwargs,
verbose=verbose, device=device, use_sde=use_sde, sde_sample_freq=sde_sample_freq,
create_eval_env=create_eval_env, support_multi_env=True, seed=seed)
self.learning_rate = learning_rate
self.batch_size = batch_size
self.n_epochs = n_epochs
self.n_steps = n_steps
@ -123,19 +122,12 @@ class PPO(BaseRLModel):
def _setup_model(self) -> None:
self._setup_lr_schedule()
# TODO: preprocessing: one hot vector for obs discrete
state_dim = self.observation_space.shape[0]
if isinstance(self.action_space, spaces.Box):
# Action is a 1D vector
action_dim = self.action_space.shape[0]
elif isinstance(self.action_space, spaces.Discrete):
# Action is a scalar
action_dim = 1
self.set_random_seed(self.seed)
self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
self.rollout_buffer = RolloutBuffer(self.n_steps, self.observation_space,
self.action_space, self.device,
gamma=self.gamma, gae_lambda=self.gae_lambda,
n_envs=self.n_envs)
self.policy = self.policy_class(self.observation_space, self.action_space,
self.lr_schedule, use_sde=self.use_sde, device=self.device,
**self.policy_kwargs)

View file

@ -4,6 +4,7 @@ import gym
import torch as th
import torch.nn as nn
from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
create_sde_feature_extractor)
from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
@ -191,8 +192,8 @@ class SACPolicy(BasePolicy):
if net_arch is None:
net_arch = [256, 256]
self.obs_dim = self.observation_space.shape[0]
self.action_dim = self.action_space.shape[0]
self.obs_dim = get_obs_dim(self.observation_space)
self.action_dim = get_action_dim(self.action_space)
self.net_arch = net_arch
self.activation_fn = activation_fn
self.net_args = {

View file

@ -89,7 +89,9 @@ class SAC(OffPolicyRLModel):
device: Union[th.device, str] = 'auto',
_init_setup_model: bool = True):
super(SAC, self).__init__(policy, env, SACPolicy, policy_kwargs, verbose, device,
super(SAC, self).__init__(policy, env, SACPolicy, learning_rate,
buffer_size, learning_starts, batch_size,
policy_kwargs, verbose, device,
create_eval_env=create_eval_env, seed=seed,
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
use_sde_at_warmup=use_sde_at_warmup)
@ -97,11 +99,6 @@ class SAC(OffPolicyRLModel):
self.target_entropy = target_entropy
self.log_ent_coef = None # type: Optional[th.Tensor]
self.target_update_interval = target_update_interval
self.buffer_size = buffer_size
# In the original paper, same learning rate is used for all networks
self.learning_rate = learning_rate
self.learning_starts = learning_starts
self.batch_size = batch_size
self.tau = tau
# Entropy coefficient / Entropy temperature
# Inverse of the reward scale
@ -118,10 +115,8 @@ class SAC(OffPolicyRLModel):
self._setup_model()
def _setup_model(self) -> None:
self._setup_lr_schedule()
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
if self.seed is not None:
self.set_random_seed(self.seed)
super(SAC, self)._setup_model()
self._create_aliases()
# Target entropy is used when learning the entropy coefficient
if self.target_entropy == 'auto':
@ -152,13 +147,6 @@ class SAC(OffPolicyRLModel):
# is passed
self.ent_coef_tensor = th.tensor(float(self.ent_coef)).to(self.device)
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
self.policy = self.policy_class(self.observation_space, self.action_space,
self.lr_schedule, use_sde=self.use_sde,
device=self.device, **self.policy_kwargs)
self.policy = self.policy.to(self.device)
self._create_aliases()
def _create_aliases(self) -> None:
self.actor = self.policy.actor
self.critic = self.policy.critic

View file

@ -4,6 +4,7 @@ import gym
import torch as th
import torch.nn as nn
from torchy_baselines.common.preprocessing import get_action_dim, get_obs_dim
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
create_sde_feature_extractor)
from torchy_baselines.common.distributions import StateDependentNoiseDistribution
@ -230,8 +231,8 @@ class TD3Policy(BasePolicy):
if net_arch is None:
net_arch = [400, 300]
self.obs_dim = self.observation_space.shape[0]
self.action_dim = self.action_space.shape[0]
self.obs_dim = get_obs_dim(self.observation_space)
self.action_dim = get_action_dim(self.action_space)
self.net_arch = net_arch
self.activation_fn = activation_fn
self.net_args = {

View file

@ -88,18 +88,16 @@ class TD3(OffPolicyRLModel):
device: Union[th.device, str] = 'auto',
_init_setup_model: bool = True):
super(TD3, self).__init__(policy, env, TD3Policy, policy_kwargs, verbose, device,
super(TD3, self).__init__(policy, env, TD3Policy, learning_rate,
buffer_size, learning_starts, batch_size,
policy_kwargs, verbose, device,
create_eval_env=create_eval_env, seed=seed,
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
use_sde_at_warmup=use_sde_at_warmup)
self.buffer_size = buffer_size
self.learning_rate = learning_rate
self.learning_starts = learning_starts
self.train_freq = train_freq
self.gradient_steps = gradient_steps
self.n_episodes_rollout = n_episodes_rollout
self.batch_size = batch_size
self.tau = tau
self.gamma = gamma
self.action_noise = action_noise
@ -118,14 +116,7 @@ class TD3(OffPolicyRLModel):
self._setup_model()
def _setup_model(self) -> None:
self._setup_lr_schedule()
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
self.set_random_seed(self.seed)
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
self.policy = self.policy_class(self.observation_space, self.action_space,
self.lr_schedule, use_sde=self.use_sde,
device=self.device, **self.policy_kwargs)
self.policy = self.policy.to(self.device)
super(TD3, self)._setup_model()
self._create_aliases()
def _create_aliases(self) -> None:

View file

@ -1 +1 @@
0.3.0
0.4.0a0