Improve predict method

This commit is contained in:
Antonin Raffin 2020-02-12 15:25:05 +01:00
parent 9caea35a11
commit f1a4fa2d3f
12 changed files with 134 additions and 90 deletions

View file

@ -3,7 +3,7 @@
Changelog
==========
Pre-Release 0.2.0a1 (WIP)
Pre-Release 0.2.0a2 (WIP)
------------------------------
Breaking Changes:
@ -20,7 +20,8 @@ New Features:
- Add methods for saving and loading replay buffer
- Add `extend()` method to the buffers
- Add `get_vec_normalize_env()` to `BaseRLModel` to retrieve `VecNormalize` wrapper when it exists
- Add `¶results_plotter` from Stable Baselines
- Add `results_plotter` from Stable Baselines
- Improve `predict()` method to handle different type of observations (single, vectorized, ...)
Bug Fixes:
^^^^^^^^^^

View file

@ -12,7 +12,9 @@ setup(name='torchy_baselines',
'torch>=1.2.0',
'cloudpickle',
# For reading logs
'pandas'
'pandas',
# Plotting learning curves
'matplotlib'
],
extras_require={
'tests': [
@ -45,7 +47,7 @@ setup(name='torchy_baselines',
license="MIT",
long_description="",
long_description_content_type='text/markdown',
version="0.2.0a1",
version="0.2.0a2",
)
# python setup.py sdist

View file

@ -34,7 +34,7 @@ def test_save_load(model_class):
env.reset()
observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)])
observations = np.squeeze(observations)
observations = observations.reshape(10, -1)
# Get dictionary of current parameters
params = deepcopy(model.policy.state_dict())
@ -53,7 +53,7 @@ def test_save_load(model_class):
params = new_params
# get selected actions
selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
selected_actions = model.predict(observations, deterministic=True)
# Check
model.save("test_save.zip")
@ -68,7 +68,7 @@ def test_save_load(model_class):
assert th.allclose(params[key], new_params[key]), "Model parameters not the same after save and load."
# check if model still selects the same actions
new_selected_actions = [model.predict(observation, deterministic=True) for observation in observations]
new_selected_actions = model.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)
# check if learn still works

View file

@ -4,4 +4,4 @@ from torchy_baselines.ppo import PPO
from torchy_baselines.sac import SAC
from torchy_baselines.td3 import TD3
__version__ = "0.2.0a1"
__version__ = "0.2.0a2"

View file

@ -306,21 +306,104 @@ class BaseRLModel(ABC):
"""
raise NotImplementedError()
@abstractmethod
@staticmethod
def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool:
"""
For every observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: (np.ndarray) the input observation to validate
:param observation_space: (gym.spaces) the observation space
:return: (bool) whether the given observation is vectorized or not
"""
if isinstance(observation_space, gym.spaces.Box):
if observation.shape == observation_space.shape:
return False
elif observation.shape[1:] == observation_space.shape:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
"Box environment, please use {} ".format(observation_space.shape) +
"or (n_env, {}) for the observation shape."
.format(", ".join(map(str, observation_space.shape))))
elif isinstance(observation_space, gym.spaces.Discrete):
if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
return False
elif len(observation.shape) == 1:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
"Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
elif isinstance(observation_space, gym.spaces.MultiDiscrete):
if observation.shape == (len(observation_space.nvec),):
return False
elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
return True
else:
raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) +
"environment, please use ({},) or ".format(len(observation_space.nvec)) +
"(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
elif isinstance(observation_space, gym.spaces.MultiBinary):
if observation.shape == (observation_space.n,):
return False
elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
"environment, please use ({},) or ".format(observation_space.n) +
"(n_env, {}) for the observation shape.".format(observation_space.n))
else:
raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
.format(observation_space))
def predict(self, observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
deterministic: bool = False) -> np.ndarray:
"""
Get the model's action from an observation
Get the model's action(s) from an observation
:param observation: the input observation
:param state: The last states (can be None, used in recurrent policies)
:param mask: The last masks (can be None, used in recurrent policies)
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next state (used in recurrent policies)
:param observation: (np.ndarray) the input observation
:param state: (Optional[np.ndarray]) The last states (can be None, used in recurrent policies)
:param mask: (Optional[np.ndarray]) The last masks (can be None, used in recurrent policies)
:param deterministic: (bool) Whether or not to return deterministic actions.
:return: (np.ndarray) the model's action and the next state (used in recurrent policies)
"""
raise NotImplementedError()
# if state is None:
# state = self.initial_state
# if mask is None:
# mask = [False for _ in range(self.n_envs)]
observation = np.array(observation)
vectorized_env = self._is_vectorized_observation(observation, self.observation_space)
observation = observation.reshape((-1,) + self.observation_space.shape)
# Convert to float pytorch
# TODO: replace with preprocessing
observation = th.as_tensor(observation).float().to(self.device)
with th.no_grad():
actions = self.policy.predict(observation, deterministic=deterministic)
# Convert to numpy
actions = actions.cpu().numpy()
# Rescale to proper domain when using squashing
# TODO: should not be used for a Gaussian distribution?
if isinstance(self.action_space, gym.spaces.Box):
actions = self.unscale_action(actions)
clipped_actions = actions
# Clip the actions to avoid out of bound error when using gaussian distribution
if isinstance(self.action_space, gym.spaces.Box):
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
if not vectorized_env:
if state is not None:
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
clipped_actions = clipped_actions[0]
# TODO: switch to stable baselines API
# return clipped_actions, state
return clipped_actions
@classmethod
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs):
@ -806,7 +889,9 @@ class OffPolicyRLModel(BaseRLModel):
# Warmup phase
unscaled_action = np.array([self.action_space.sample()])
else:
unscaled_action = self.predict(obs, deterministic=not self.use_sde)
# Note: we assume that the policy uses tanh to scale the action
# We use non-deterministic action in the case of SAC, for TD3, it does not matter
unscaled_action = self.predict(obs, deterministic=False)
# Rescale the action from [low, high] to [-1, 1]
scaled_action = self.scale_action(unscaled_action)

View file

@ -1,25 +1,30 @@
from typing import Union
from itertools import zip_longest
import gym
import torch as th
import torch.nn as nn
import numpy as np
class BasePolicy(nn.Module):
"""
The base policy object
:param observation_space: (Gym Space) The observation space of the environment
:param action_space: (Gym Space) The action space of the environment
:param observation_space: (gym.spaces.Space) The observation space of the environment
:param action_space: (gym.spaces.Space) The action space of the environment
"""
def __init__(self, observation_space, action_space, device='cpu'):
def __init__(self, observation_space: gym.spaces.Space,
action_space: gym.spaces.Space, device: Union[th.device, str] = 'cpu'):
super(BasePolicy, self).__init__()
self.observation_space = observation_space
self.action_space = action_space
self.device = device
@staticmethod
def init_weights(module, gain=1):
def init_weights(module: nn.Module, gain: float = 1):
if type(module) == nn.Linear:
nn.init.orthogonal_(module.weight, gain=gain)
module.bias.data.fill_(0.0)
@ -27,7 +32,13 @@ class BasePolicy(nn.Module):
def forward(self, *_args, **kwargs):
raise NotImplementedError()
def save(self, path):
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
"""
Get the action according to the policy for a given observation.
"""
raise NotImplementedError()
def save(self, path: str) -> None:
"""
Save model to a given location.
@ -35,7 +46,7 @@ class BasePolicy(nn.Module):
"""
th.save(self.state_dict(), path)
def load(self, path):
def load(self, path: str) -> None:
"""
Load saved model from path.
@ -43,7 +54,7 @@ class BasePolicy(nn.Module):
"""
self.load_state_dict(th.load(path))
def load_from_vector(self, vector):
def load_from_vector(self, vector: np.ndarray):
"""
Load parameters from a 1D vector.
@ -51,7 +62,7 @@ class BasePolicy(nn.Module):
"""
th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters())
def parameters_to_vector(self):
def parameters_to_vector(self) -> np.ndarray:
"""
Convert the parameters to a 1D vector.

View file

@ -162,10 +162,10 @@ class PPOPolicy(BasePolicy):
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde,
deterministic=deterministic)
def actor_forward(self, obs, deterministic=False):
latent_pi, _, latent_sde = self._get_latent(obs)
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
latent_pi, _, latent_sde = self._get_latent(observation)
action, _ = self._get_action_dist_from_latent(latent_pi, latent_sde, deterministic=deterministic)
return action.detach().cpu().numpy()
return action
def evaluate_actions(self, obs, action, deterministic=False):
"""

View file

@ -129,28 +129,6 @@ class PPO(BaseRLModel):
if self.clip_range_vf is not None:
self.clip_range_vf = get_schedule_fn(self.clip_range_vf)
def select_action(self, observation, deterministic=False):
# Normally not needed
observation = np.array(observation)
with th.no_grad():
observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device)
return self.policy.actor_forward(observation, deterministic=deterministic)
def predict(self, observation, state=None, mask=None, deterministic=False):
"""
Get the model's action from an observation
:param observation: (np.ndarray) the input observation
:param state: (np.ndarray) The last states (can be None, used in recurrent policies)
:param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
:param deterministic: (bool) Whether or not to return deterministic actions.
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
"""
clipped_actions = self.select_action(observation, deterministic=deterministic)
if isinstance(self.action_space, gym.spaces.Box):
clipped_actions = np.clip(clipped_actions, self.action_space.low, self.action_space.high)
return clipped_actions
def collect_rollouts(self,
env: VecEnv,
callback: BaseCallback,

View file

@ -129,11 +129,11 @@ class Actor(BaseNetwork):
def forward(self, obs, deterministic=False):
mean_actions, log_std, latent_sde = self.get_action_dist_params(obs)
if self.use_sde:
# Note the action is squashed
# Note: the action is squashed
action, _ = self.action_dist.proba_distribution(mean_actions, log_std, latent_sde,
deterministic=deterministic)
else:
# Note the action is squashed
# Note: the action is squashed
action, _ = self.action_dist.proba_distribution(mean_actions, log_std,
deterministic=deterministic)
return action
@ -246,6 +246,8 @@ class SACPolicy(BasePolicy):
def forward(self, obs):
return self.actor(obs)
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.actor.forward(observation, deterministic)
MlpPolicy = SACPolicy

View file

@ -148,25 +148,6 @@ class SAC(OffPolicyRLModel):
self.critic = self.policy.critic
self.critic_target = self.policy.critic_target
def select_action(self, observation):
# Normally not needed
observation = np.array(observation)
with th.no_grad():
observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device)
return self.actor(observation).cpu().data.numpy()
def predict(self, observation, state=None, mask=None, deterministic=True):
"""
Get the model's action from an observation
:param observation: (np.ndarray) the input observation
:param state: (np.ndarray) The last states (can be None, used in recurrent policies)
:param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
:param deterministic: (bool) Whether or not to return deterministic actions.
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
"""
return self.unscale_action(self.select_action(observation))
def train(self, gradient_steps: int, batch_size: int = 64):
# Update optimizers learning rate
optimizers = [self.actor.optimizer, self.critic.optimizer]

View file

@ -277,6 +277,9 @@ class TD3Policy(BasePolicy):
def forward(self, obs, deterministic=True):
return self.actor(obs, deterministic=deterministic)
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
return self.forward(observation, deterministic)
MlpPolicy = TD3Policy

View file

@ -114,25 +114,6 @@ class TD3(OffPolicyRLModel):
self.critic_target = self.policy.critic_target
self.vf_net = self.policy.vf_net
def select_action(self, observation, deterministic=True):
# Normally not needed
observation = np.array(observation)
with th.no_grad():
observation = th.FloatTensor(observation.reshape(1, -1)).to(self.device)
return self.actor(observation, deterministic=deterministic).cpu().numpy()
def predict(self, observation, state=None, mask=None, deterministic=True):
"""
Get the model's action from an observation
:param observation: (np.ndarray) the input observation
:param state: (np.ndarray) The last states (can be None, used in recurrent policies)
:param mask: (np.ndarray) The last masks (can be None, used in recurrent policies)
:param deterministic: (bool) Whether or not to return deterministic actions.
:return: (np.ndarray, np.ndarray) the model's action and the next state (used in recurrent policies)
"""
return self.unscale_action(self.select_action(observation, deterministic=deterministic))
def train_critic(self, gradient_steps: int = 1,
batch_size: int = 100,
replay_data: Optional[ReplayBufferSamples] = None,