mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-16 21:10:08 +00:00
commit
7f7b288dce
17 changed files with 208 additions and 152 deletions
|
|
@ -50,7 +50,7 @@ import torchy_baselines
|
|||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = 'Torchy Baselines'
|
||||
copyright = '2019, Torchy Baselines'
|
||||
copyright = '2020, Torchy Baselines'
|
||||
author = 'Torchy Baselines Contributors'
|
||||
|
||||
# The short X.Y version
|
||||
|
|
@ -70,7 +70,7 @@ release = torchy_baselines.__version__
|
|||
# ones.
|
||||
extensions = [
|
||||
'sphinx.ext.autodoc',
|
||||
'sphinx_autodoc_typehints',
|
||||
# 'sphinx_autodoc_typehints',
|
||||
'sphinx.ext.autosummary',
|
||||
'sphinx.ext.mathjax',
|
||||
'sphinx.ext.ifconfig',
|
||||
|
|
|
|||
|
|
@ -14,10 +14,12 @@ Breaking Changes:
|
|||
|
||||
New Features:
|
||||
^^^^^^^^^^^^^
|
||||
- Better logging for ``SAC`` and ``PPO``
|
||||
|
||||
Bug Fixes:
|
||||
^^^^^^^^^^
|
||||
- Synced callbacks with Stable-Baselines
|
||||
- Fixed colors in ``results_plotter``
|
||||
|
||||
Deprecations:
|
||||
^^^^^^^^^^^^^
|
||||
|
|
@ -29,9 +31,11 @@ Others:
|
|||
- Buffers now return ``NamedTuple``
|
||||
- More typing
|
||||
- Add test for ``expln``
|
||||
- Renamed ``learning_rate`` to ``lr_schedule``
|
||||
|
||||
Documentation:
|
||||
^^^^^^^^^^^^^^
|
||||
- Deactivated ``sphinx_autodoc_typehints`` extension
|
||||
|
||||
|
||||
Pre-Release 0.2.0 (2020-02-14)
|
||||
|
|
|
|||
5
setup.py
5
setup.py
|
|
@ -10,6 +10,7 @@ setup(name='torchy_baselines',
|
|||
'gym[classic_control]>=0.11',
|
||||
'numpy',
|
||||
'torch>=1.4.0',
|
||||
# For saving models
|
||||
'cloudpickle',
|
||||
# For reading logs
|
||||
'pandas',
|
||||
|
|
@ -31,7 +32,7 @@ setup(name='torchy_baselines',
|
|||
# For spelling
|
||||
'sphinxcontrib.spelling',
|
||||
# Type hints support
|
||||
'sphinx-autodoc-typehints'
|
||||
# 'sphinx-autodoc-typehints'
|
||||
],
|
||||
'extra': [
|
||||
# For render
|
||||
|
|
@ -47,7 +48,7 @@ setup(name='torchy_baselines',
|
|||
license="MIT",
|
||||
long_description="",
|
||||
long_description_content_type='text/markdown',
|
||||
version="0.2.3",
|
||||
version="0.2.4",
|
||||
)
|
||||
|
||||
# python setup.py sdist
|
||||
|
|
|
|||
|
|
@ -4,4 +4,4 @@ from torchy_baselines.ppo import PPO
|
|||
from torchy_baselines.sac import SAC
|
||||
from torchy_baselines.td3 import TD3
|
||||
|
||||
__version__ = "0.2.3"
|
||||
__version__ = "0.2.4"
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ class A2C(PPO):
|
|||
super(A2C, self)._setup_model()
|
||||
if self.use_rms_prop:
|
||||
self.policy.optimizer = th.optim.RMSprop(self.policy.parameters(),
|
||||
lr=self.learning_rate(1), alpha=0.99,
|
||||
lr=self.lr_schedule(1), alpha=0.99,
|
||||
eps=self.rms_prop_eps, weight_decay=0)
|
||||
|
||||
def train(self, gradient_steps: int, batch_size: Optional[int] = None) -> None:
|
||||
|
|
@ -144,6 +144,8 @@ class A2C(PPO):
|
|||
explained_var = explained_variance(self.rollout_buffer.returns.flatten(),
|
||||
self.rollout_buffer.values.flatten())
|
||||
|
||||
self._n_updates += 1
|
||||
logger.logkv("n_updates", self._n_updates)
|
||||
logger.logkv("explained_variance", explained_var)
|
||||
logger.logkv("entropy_loss", entropy_loss.item())
|
||||
logger.logkv("policy_loss", policy_loss.item())
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ class CEMRL(TD3):
|
|||
self.actor.load_from_vector(self.es_params[i])
|
||||
self.actor_target.load_from_vector(self.es_params[i])
|
||||
self.actor.optimizer = th.optim.Adam(self.actor.parameters(),
|
||||
lr=self.learning_rate(self._current_progress))
|
||||
lr=self.lr_schedule(self._current_progress))
|
||||
|
||||
# In the paper: 2 * actor_steps // self.n_grad
|
||||
# In the original implementation: actor_steps // self.n_grad
|
||||
|
|
|
|||
|
|
@ -27,25 +27,25 @@ class BaseRLModel(ABC):
|
|||
"""
|
||||
The base RL model
|
||||
|
||||
:param policy: Policy object
|
||||
:param env: The environment to learn from
|
||||
:param policy: (Type[BasePolicy]) Policy object
|
||||
:param env: (Union[GymEnv, str]) The environment to learn from
|
||||
(if registered in Gym, can be str. Can be None for loading trained models)
|
||||
:param policy_base: The base policy used by this method
|
||||
:param policy_kwargs: Additional arguments to be passed to the policy on creation
|
||||
:param verbose: The verbosity level: 0 none, 1 training information, 2 debug
|
||||
:param device: Device on which the code should run.
|
||||
:param policy_base: (Type[BasePolicy]) The base policy used by this method
|
||||
:param policy_kwargs: (Dict[str, Any]) Additional arguments to be passed to the policy on creation
|
||||
:param verbose: (int) The verbosity level: 0 none, 1 training information, 2 debug
|
||||
:param device: (Union[th.device, str]) Device on which the code should run.
|
||||
By default, it will try to use a Cuda compatible device and fallback to cpu
|
||||
if it is not possible.
|
||||
:param support_multi_env: Whether the algorithm supports training
|
||||
:param support_multi_env: (bool) Whether the algorithm supports training
|
||||
with multiple environments (as in A2C)
|
||||
:param create_eval_env: Whether to create a second environment that will be
|
||||
:param create_eval_env: (bool) Whether to create a second environment that will be
|
||||
used for evaluating the agent periodically. (Only available when passing string for the environment)
|
||||
:param monitor_wrapper: When creating an environment, whether to wrap it
|
||||
:param monitor_wrapper: (bool) When creating an environment, whether to wrap it
|
||||
or not in a Monitor wrapper.
|
||||
:param seed: Seed for the pseudo random generators
|
||||
:param use_sde: Whether to use State Dependent Exploration (SDE)
|
||||
:param seed: (Optional[int]) Seed for the pseudo random generators
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration (SDE)
|
||||
instead of action noise exploration (default: False)
|
||||
:param sde_sample_freq: Sample a new noise matrix every n steps when using SDE
|
||||
:param sde_sample_freq: (int) Sample a new noise matrix every n steps when using SDE
|
||||
Default: -1 (only sample at the beginning of the rollout)
|
||||
"""
|
||||
|
||||
|
|
@ -80,8 +80,8 @@ class BaseRLModel(ABC):
|
|||
self._vec_normalize_env = unwrap_vec_normalize(env)
|
||||
self.verbose = verbose
|
||||
self.policy_kwargs = {} if policy_kwargs is None else policy_kwargs
|
||||
self.observation_space = None
|
||||
self.action_space = None
|
||||
self.observation_space = None # type: Optional[gym.spaces.Space]
|
||||
self.action_space = None # type: Optional[gym.spaces.Space]
|
||||
self.n_envs = None
|
||||
self.num_timesteps = 0
|
||||
self.eval_env = None
|
||||
|
|
@ -89,7 +89,8 @@ class BaseRLModel(ABC):
|
|||
self.action_noise = None # type: Optional[ActionNoise]
|
||||
self.start_time = None
|
||||
self.policy = None
|
||||
self.learning_rate = None
|
||||
self.learning_rate = None # type: Optional[float]
|
||||
self.lr_schedule = None # type: Optional[Callable]
|
||||
# Used for SDE only
|
||||
self.use_sde = use_sde
|
||||
self.sde_sample_freq = sde_sample_freq
|
||||
|
|
@ -99,6 +100,8 @@ class BaseRLModel(ABC):
|
|||
# Buffers for logging
|
||||
self.ep_info_buffer = None # type: Optional[deque]
|
||||
self.ep_success_buffer = None # type: Optional[deque]
|
||||
# For logging
|
||||
self._n_updates = 0 # type: int
|
||||
|
||||
# Create and wrap the env if needed
|
||||
if env is not None:
|
||||
|
|
@ -132,13 +135,16 @@ class BaseRLModel(ABC):
|
|||
@abstractmethod
|
||||
def _setup_model(self) -> None:
|
||||
"""
|
||||
Setup model so state_dict can be loaded
|
||||
Create networks and optimizers
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
|
||||
"""
|
||||
Return the environment that will be used for evaluation.
|
||||
|
||||
:param eval_env: (Optional[GymEnv])
|
||||
:return: (Optional[GymEnv])
|
||||
"""
|
||||
if eval_env is None:
|
||||
eval_env = self.eval_env
|
||||
|
|
@ -154,7 +160,8 @@ class BaseRLModel(ABC):
|
|||
Rescale the action from [low, high] to [-1, 1]
|
||||
(no need for symmetric action space)
|
||||
|
||||
:param action: Action to scale
|
||||
:param action: (np.ndarray) Action to scale
|
||||
:return: (np.ndarray) Scaled action
|
||||
"""
|
||||
low, high = self.action_space.low, self.action_space.high
|
||||
return 2.0 * ((action - low) / (high - low)) - 1.0
|
||||
|
|
@ -169,9 +176,9 @@ class BaseRLModel(ABC):
|
|||
low, high = self.action_space.low, self.action_space.high
|
||||
return low + (0.5 * (scaled_action + 1.0) * (high - low))
|
||||
|
||||
def _setup_learning_rate(self) -> None:
|
||||
def _setup_lr_schedule(self) -> None:
|
||||
"""Transform to callable if needed."""
|
||||
self.learning_rate = get_schedule_fn(self.learning_rate)
|
||||
self.lr_schedule = get_schedule_fn(self.learning_rate)
|
||||
|
||||
def _update_current_progress(self, num_timesteps: int, total_timesteps: int) -> None:
|
||||
"""
|
||||
|
|
@ -187,15 +194,16 @@ class BaseRLModel(ABC):
|
|||
Update the optimizers learning rate using the current learning rate schedule
|
||||
and the current progress (from 1 to 0).
|
||||
|
||||
:param optimizers: An optimizer or a list of optimizer.
|
||||
:param optimizers: (Union[List[th.optim.Optimizer], th.optim.Optimizer])
|
||||
An optimizer or a list of optimizers.
|
||||
"""
|
||||
# Log the current learning rate
|
||||
logger.logkv("learning_rate", self.learning_rate(self._current_progress))
|
||||
logger.logkv("learning_rate", self.lr_schedule(self._current_progress))
|
||||
|
||||
if not isinstance(optimizers, list):
|
||||
optimizers = [optimizers]
|
||||
for optimizer in optimizers:
|
||||
update_learning_rate(optimizer, self.learning_rate(self._current_progress))
|
||||
update_learning_rate(optimizer, self.lr_schedule(self._current_progress))
|
||||
|
||||
@staticmethod
|
||||
def safe_mean(arr: Union[np.ndarray, list, deque]) -> np.ndarray:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
"""
|
||||
Taken from stable-baselines
|
||||
"""
|
||||
from typing import Optional
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
|
@ -13,34 +12,34 @@ class ActionNoise(ABC):
|
|||
def __init__(self):
|
||||
super(ActionNoise, self).__init__()
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
call end of episode reset for the noise
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self):
|
||||
pass
|
||||
def __call__(self) -> np.ndarray:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class NormalActionNoise(ActionNoise):
|
||||
"""
|
||||
A Gaussian action noise
|
||||
|
||||
:param mean: (float) the mean value of the noise
|
||||
:param sigma: (float) the scale of the noise (std here)
|
||||
:param mean: (np.ndarray) the mean value of the noise
|
||||
:param sigma: (np.ndarray) the scale of the noise (std here)
|
||||
"""
|
||||
|
||||
def __init__(self, mean, sigma):
|
||||
def __init__(self, mean: np.ndarray, sigma: np.ndarray):
|
||||
self._mu = mean
|
||||
self._sigma = sigma
|
||||
super(NormalActionNoise, self).__init__()
|
||||
|
||||
def __call__(self):
|
||||
def __call__(self) -> np.ndarray:
|
||||
return np.random.normal(self._mu, self._sigma)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'NormalActionNoise(mu={self._mu}, sigma={self._sigma})'
|
||||
|
||||
|
||||
|
|
@ -50,34 +49,38 @@ class OrnsteinUhlenbeckActionNoise(ActionNoise):
|
|||
|
||||
Based on http://math.stackexchange.com/questions/1287634/implementing-ornstein-uhlenbeck-in-matlab
|
||||
|
||||
:param mean: (float) the mean of the noise
|
||||
:param sigma: (float) the scale of the noise
|
||||
:param mean: (np.ndarray) the mean of the noise
|
||||
:param sigma: (np.ndarray) the scale of the noise
|
||||
:param theta: (float) the rate of mean reversion
|
||||
:param dt: (float) the timestep for the noise
|
||||
:param initial_noise: ([float]) the initial value for the noise output, (if None: 0)
|
||||
:param initial_noise: (Optional[np.ndarray]) the initial value for the noise output, (if None: 0)
|
||||
"""
|
||||
|
||||
def __init__(self, mean, sigma, theta=.15, dt=1e-2, initial_noise=None):
|
||||
def __init__(self, mean: np.ndarray,
|
||||
sigma: np.ndarray,
|
||||
theta: float = .15,
|
||||
dt: float = 1e-2,
|
||||
initial_noise: Optional[np.ndarray] = None):
|
||||
self._theta = theta
|
||||
self._mu = mean
|
||||
self._sigma = sigma
|
||||
self._dt = dt
|
||||
self.initial_noise = initial_noise
|
||||
self.noise_prev = None
|
||||
self.noise_prev = np.zeros_like(self._mu)
|
||||
self.reset()
|
||||
super(OrnsteinUhlenbeckActionNoise, self).__init__()
|
||||
|
||||
def __call__(self):
|
||||
def __call__(self) -> np.ndarray:
|
||||
noise = self.noise_prev + self._theta * (self._mu - self.noise_prev) * self._dt + \
|
||||
self._sigma * np.sqrt(self._dt) * np.random.normal(size=self._mu.shape)
|
||||
self.noise_prev = noise
|
||||
return noise
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
reset the Ornstein Uhlenbeck noise, to the initial position
|
||||
"""
|
||||
self.noise_prev = self.initial_noise if self.initial_noise is not None else np.zeros_like(self._mu)
|
||||
|
||||
def __repr__(self):
|
||||
def __repr__(self) -> str:
|
||||
return f'OrnsteinUhlenbeckActionNoise(mu={self._mu}, sigma={self._sigma})'
|
||||
|
|
|
|||
|
|
@ -8,16 +8,12 @@ import matplotlib.pyplot as plt
|
|||
|
||||
from torchy_baselines.common.monitor import load_results
|
||||
|
||||
plt.rcParams['svg.fonttype'] = 'none'
|
||||
|
||||
X_TIMESTEPS = 'timesteps'
|
||||
X_EPISODES = 'episodes'
|
||||
X_WALLTIME = 'walltime_hrs'
|
||||
POSSIBLE_X_AXES = [X_TIMESTEPS, X_EPISODES, X_WALLTIME]
|
||||
EPISODES_WINDOW = 100
|
||||
COLORS = ['blue', 'green', 'red', 'cyan', 'magenta', 'yellow', 'black', 'purple', 'pink',
|
||||
'brown', 'orange', 'teal', 'coral', 'lightblue', 'lime', 'lavender', 'turquoise',
|
||||
'darkgreen', 'tan', 'salmon', 'gold', 'lightpurple', 'darkred', 'darkblue']
|
||||
|
||||
|
||||
def rolling_window(array: np.ndarray, window: int) -> np.ndarray:
|
||||
|
|
@ -49,25 +45,25 @@ def window_func(var_1: np.ndarray, var_2: np.ndarray,
|
|||
return var_1[window - 1:], function_on_var2
|
||||
|
||||
|
||||
def ts2xy(timesteps: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
def ts2xy(data_frame: pd.DataFrame, x_axis: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Decompose a timesteps variable to x ans ys
|
||||
Decompose a data frame variable to x and ys
|
||||
|
||||
:param timesteps: (pd.DataFrame) the input data
|
||||
:param data_frame: (pd.DataFrame) the input data
|
||||
:param x_axis: (str) the axis for the x and y output
|
||||
(can be X_TIMESTEPS='timesteps', X_EPISODES='episodes' or X_WALLTIME='walltime_hrs')
|
||||
:return: (Tuple[np.ndarray, np.ndarray]) the x and y output
|
||||
"""
|
||||
if x_axis == X_TIMESTEPS:
|
||||
x_var = np.cumsum(timesteps.l.values)
|
||||
y_var = timesteps.r.values
|
||||
x_var = np.cumsum(data_frame.l.values)
|
||||
y_var = data_frame.r.values
|
||||
elif x_axis == X_EPISODES:
|
||||
x_var = np.arange(len(timesteps))
|
||||
y_var = timesteps.r.values
|
||||
x_var = np.arange(len(data_frame))
|
||||
y_var = data_frame.r.values
|
||||
elif x_axis == X_WALLTIME:
|
||||
# Convert to hours
|
||||
x_var = timesteps.t.values / 3600.
|
||||
y_var = timesteps.r.values
|
||||
x_var = data_frame.t.values / 3600.
|
||||
y_var = data_frame.r.values
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return x_var, y_var
|
||||
|
|
@ -85,17 +81,16 @@ def plot_curves(xy_list: List[Tuple[np.ndarray, np.ndarray]],
|
|||
:param figsize: (Tuple[int, int]) Size of the figure (width, height)
|
||||
"""
|
||||
|
||||
plt.figure(figsize=figsize)
|
||||
plt.figure(title, figsize=figsize)
|
||||
max_x = max(xy[0][-1] for xy in xy_list)
|
||||
min_x = 0
|
||||
for (i, (x, y)) in enumerate(xy_list):
|
||||
color = COLORS[i]
|
||||
plt.scatter(x, y, s=2)
|
||||
# Do not plot the smoothed curve at all if the timeseries is shorter than window size.
|
||||
if x.shape[0] >= EPISODES_WINDOW:
|
||||
# Compute and plot rolling mean with window of size EPISODES_WINDOW
|
||||
x, y_mean = window_func(x, y, EPISODES_WINDOW, np.mean)
|
||||
plt.plot(x, y_mean, color=color)
|
||||
plt.plot(x, y_mean)
|
||||
plt.xlim(min_x, max_x)
|
||||
plt.title(title)
|
||||
plt.xlabel(x_axis)
|
||||
|
|
@ -106,7 +101,7 @@ def plot_curves(xy_list: List[Tuple[np.ndarray, np.ndarray]],
|
|||
def plot_results(dirs: List[str], num_timesteps: Optional[int],
|
||||
x_axis: str, task_name: str, figsize: Tuple[int, int] = (8, 2)) -> None:
|
||||
"""
|
||||
plot the results
|
||||
Plot the results using csv files from ``Monitor`` wrapper.
|
||||
|
||||
:param dirs: ([str]) the save location of the results to plot
|
||||
:param num_timesteps: (int or None) only plot the points below this value
|
||||
|
|
@ -116,11 +111,11 @@ def plot_results(dirs: List[str], num_timesteps: Optional[int],
|
|||
:param figsize: (Tuple[int, int]) Size of the figure (width, height)
|
||||
"""
|
||||
|
||||
timesteps_list = []
|
||||
data_frames = []
|
||||
for folder in dirs:
|
||||
timesteps = load_results(folder)
|
||||
data_frame = load_results(folder)
|
||||
if num_timesteps is not None:
|
||||
timesteps = timesteps[timesteps.l.cumsum() <= num_timesteps]
|
||||
timesteps_list.append(timesteps)
|
||||
xy_list = [ts2xy(timesteps_item, x_axis) for timesteps_item in timesteps_list]
|
||||
data_frame = data_frame[data_frame.l.cumsum() <= num_timesteps]
|
||||
data_frames.append(data_frame)
|
||||
xy_list = [ts2xy(data_frame, x_axis) for data_frame in data_frames]
|
||||
plot_curves(xy_list, x_axis, task_name, figsize)
|
||||
|
|
|
|||
|
|
@ -1,26 +1,30 @@
|
|||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class RunningMeanStd(object):
|
||||
def __init__(self, epsilon=1e-4, shape=()):
|
||||
def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] = ()):
|
||||
"""
|
||||
calulates the running mean and std of a data stream
|
||||
Calculates the running mean and std of a data stream
|
||||
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
|
||||
|
||||
:param epsilon: (float) helps with arithmetic issues
|
||||
:param shape: (tuple) the shape of the data stream's output
|
||||
"""
|
||||
self.mean = np.zeros(shape, 'float64')
|
||||
self.var = np.ones(shape, 'float64')
|
||||
self.mean = np.zeros(shape, np.float64)
|
||||
self.var = np.ones(shape, np.float64)
|
||||
self.count = epsilon
|
||||
|
||||
def update(self, arr):
|
||||
def update(self, arr: np.ndarray) -> None:
|
||||
batch_mean = np.mean(arr, axis=0)
|
||||
batch_var = np.var(arr, axis=0)
|
||||
batch_count = arr.shape[0]
|
||||
self.update_from_moments(batch_mean, batch_var, batch_count)
|
||||
|
||||
def update_from_moments(self, batch_mean, batch_var, batch_count):
|
||||
def update_from_moments(self, batch_mean: np.ndarray,
|
||||
batch_var: np.ndarray,
|
||||
batch_count: int) -> None:
|
||||
delta = batch_mean - self.mean
|
||||
tot_count = self.count + batch_count
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from typing import Callable, Union
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
|
||||
|
||||
def set_random_seed(seed, using_cuda=False):
|
||||
def set_random_seed(seed: int, using_cuda: bool = False) -> None:
|
||||
"""
|
||||
Seed the different random generators
|
||||
:param seed: (int)
|
||||
|
|
@ -21,7 +22,7 @@ def set_random_seed(seed, using_cuda=False):
|
|||
|
||||
|
||||
# From stable baselines
|
||||
def explained_variance(y_pred, y_true):
|
||||
def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Computes fraction of variance that ypred explains about y.
|
||||
Returns 1 - Var[y-ypred] / Var[y]
|
||||
|
|
@ -40,7 +41,7 @@ def explained_variance(y_pred, y_true):
|
|||
return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
|
||||
|
||||
|
||||
def update_learning_rate(optimizer, learning_rate):
|
||||
def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None:
|
||||
"""
|
||||
Update the learning rate for a given optimizer.
|
||||
Useful when doing linear schedule.
|
||||
|
|
@ -52,7 +53,7 @@ def update_learning_rate(optimizer, learning_rate):
|
|||
param_group['lr'] = learning_rate
|
||||
|
||||
|
||||
def get_schedule_fn(value_schedule):
|
||||
def get_schedule_fn(value_schedule: Union[Callable, float]) -> Callable:
|
||||
"""
|
||||
Transform (if needed) learning rate and clip range (for PPO)
|
||||
to callable.
|
||||
|
|
@ -70,13 +71,13 @@ def get_schedule_fn(value_schedule):
|
|||
return value_schedule
|
||||
|
||||
|
||||
def constant_fn(val):
|
||||
def constant_fn(val: float) -> Callable:
|
||||
"""
|
||||
Create a function that returns a constant
|
||||
It is useful for learning rate schedule (to avoid code duplication)
|
||||
|
||||
:param val: (float)
|
||||
:return: (function)
|
||||
:return: (Callable)
|
||||
"""
|
||||
|
||||
def func(_):
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ class PPOPolicy(BasePolicy):
|
|||
|
||||
:param observation_space: (gym.spaces.Space) Observation space
|
||||
:param action_space: (gym.spaces.Space) Action space
|
||||
:param learning_rate: (callable) Learning rate schedule (could be constant)
|
||||
:param lr_schedule: (Callable) Learning rate schedule (could be constant)
|
||||
:param net_arch: ([int or dict]) The specification of the policy and value networks.
|
||||
:param device: (str or th.device) Device on which the code should run.
|
||||
:param activation_fn: (nn.Module) Activation function
|
||||
|
|
@ -41,7 +41,7 @@ class PPOPolicy(BasePolicy):
|
|||
def __init__(self,
|
||||
observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
learning_rate: Callable,
|
||||
lr_schedule: Callable,
|
||||
net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
|
||||
device: Union[th.device, str] = 'cpu',
|
||||
activation_fn: nn.Module = nn.Tanh,
|
||||
|
|
@ -93,7 +93,7 @@ class PPOPolicy(BasePolicy):
|
|||
# Action distribution
|
||||
self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs)
|
||||
|
||||
self._build(learning_rate)
|
||||
self._build(lr_schedule)
|
||||
|
||||
def reset_noise(self, n_envs: int = 1) -> None:
|
||||
"""
|
||||
|
|
@ -104,7 +104,7 @@ class PPOPolicy(BasePolicy):
|
|||
assert isinstance(self.action_dist, StateDependentNoiseDistribution), 'reset_noise() is only available when using SDE'
|
||||
self.action_dist.sample_weights(self.log_std, batch_size=n_envs)
|
||||
|
||||
def _build(self, learning_rate: Callable) -> None:
|
||||
def _build(self, lr_schedule: Callable) -> None:
|
||||
self.mlp_extractor = MlpExtractor(self.features_dim, net_arch=self.net_arch,
|
||||
activation_fn=self.activation_fn, device=self.device)
|
||||
|
||||
|
|
@ -139,7 +139,7 @@ class PPOPolicy(BasePolicy):
|
|||
self.value_net: 1
|
||||
}[module]
|
||||
module.apply(partial(self.init_weights, gain=gain))
|
||||
self.optimizer = th.optim.Adam(self.parameters(), lr=learning_rate(1), eps=self.adam_epsilon)
|
||||
self.optimizer = th.optim.Adam(self.parameters(), lr=lr_schedule(1), eps=self.adam_epsilon)
|
||||
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
|
||||
if not isinstance(obs, th.Tensor):
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ class PPO(BaseRLModel):
|
|||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
self._setup_learning_rate()
|
||||
self._setup_lr_schedule()
|
||||
# TODO: preprocessing: one hot vector for obs discrete
|
||||
state_dim = self.observation_space.shape[0]
|
||||
if isinstance(self.action_space, spaces.Box):
|
||||
|
|
@ -137,7 +137,7 @@ class PPO(BaseRLModel):
|
|||
self.rollout_buffer = RolloutBuffer(self.n_steps, state_dim, action_dim, self.device,
|
||||
gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.n_envs)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.learning_rate, use_sde=self.use_sde, device=self.device,
|
||||
self.lr_schedule, use_sde=self.use_sde, device=self.device,
|
||||
**self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
|
||||
|
|
@ -198,7 +198,7 @@ class PPO(BaseRLModel):
|
|||
|
||||
return obs, continue_training
|
||||
|
||||
def train(self, gradient_steps: int, batch_size: int = 64) -> None:
|
||||
def train(self, n_epochs: int, batch_size: int = 64) -> None:
|
||||
# Update optimizer learning rate
|
||||
self._update_learning_rate(self.policy.optimizer)
|
||||
# Compute current clip range
|
||||
|
|
@ -207,9 +207,14 @@ class PPO(BaseRLModel):
|
|||
if self.clip_range_vf is not None:
|
||||
clip_range_vf = self.clip_range_vf(self._current_progress)
|
||||
|
||||
for gradient_step in range(gradient_steps):
|
||||
entropy_losses, all_kl_divs = [], []
|
||||
pg_losses, value_losses = [], []
|
||||
clip_fractions = []
|
||||
|
||||
# train for n_epochs epochs
|
||||
for epoch in range(n_epochs):
|
||||
approx_kl_divs = []
|
||||
# Sample replay buffer
|
||||
# Do a complete pass on the rollout buffer
|
||||
for rollout_data in self.rollout_buffer.get(batch_size):
|
||||
|
||||
actions = rollout_data.actions
|
||||
|
|
@ -236,6 +241,11 @@ class PPO(BaseRLModel):
|
|||
policy_loss_2 = advantages * th.clamp(ratio, 1 - clip_range, 1 + clip_range)
|
||||
policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
|
||||
|
||||
# Logging
|
||||
pg_losses.append(policy_loss.item())
|
||||
clip_fraction = th.mean((th.abs(ratio - 1) > clip_range).float()).item()
|
||||
clip_fractions.append(clip_fraction)
|
||||
|
||||
if self.clip_range_vf is None:
|
||||
# No clipping
|
||||
values_pred = values
|
||||
|
|
@ -246,6 +256,7 @@ class PPO(BaseRLModel):
|
|||
clip_range_vf)
|
||||
# Value loss using the TD(gae_lambda) target
|
||||
value_loss = F.mse_loss(rollout_data.returns, values_pred)
|
||||
value_losses.append(value_loss.item())
|
||||
|
||||
# Entropy loss favor exploration
|
||||
if entropy is None:
|
||||
|
|
@ -254,6 +265,8 @@ class PPO(BaseRLModel):
|
|||
else:
|
||||
entropy_loss = -th.mean(entropy)
|
||||
|
||||
entropy_losses.append(entropy_loss.item())
|
||||
|
||||
loss = policy_loss + self.ent_coef * entropy_loss + self.vf_coef * value_loss
|
||||
|
||||
# Optimization step
|
||||
|
|
@ -264,23 +277,27 @@ class PPO(BaseRLModel):
|
|||
self.policy.optimizer.step()
|
||||
approx_kl_divs.append(th.mean(rollout_data.old_log_prob - log_prob).detach().cpu().numpy())
|
||||
|
||||
all_kl_divs.append(np.mean(approx_kl_divs))
|
||||
|
||||
if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl:
|
||||
print("Early stopping at step {} due to reaching max kl: {:.2f}".format(gradient_step,
|
||||
np.mean(approx_kl_divs)))
|
||||
print(f"Early stopping at step {epoch} due to reaching max kl: {np.mean(approx_kl_divs):.2f}")
|
||||
break
|
||||
|
||||
self._n_updates += n_epochs
|
||||
explained_var = explained_variance(self.rollout_buffer.returns.flatten(),
|
||||
self.rollout_buffer.values.flatten())
|
||||
|
||||
logger.logkv("n_updates", self._n_updates)
|
||||
logger.logkv("clip_fraction", np.mean(clip_fraction))
|
||||
logger.logkv("clip_range", clip_range)
|
||||
if self.clip_range_vf is not None:
|
||||
logger.logkv("clip_range_vf", clip_range_vf)
|
||||
|
||||
logger.logkv("approx_kl", np.mean(approx_kl_divs))
|
||||
logger.logkv("explained_variance", explained_var)
|
||||
# TODO: gather stats for the entropy and other losses?
|
||||
logger.logkv("entropy_loss", entropy_loss.item())
|
||||
logger.logkv("policy_loss", policy_loss.item())
|
||||
logger.logkv("value_loss", value_loss.item())
|
||||
logger.logkv("entropy_loss", np.mean(entropy_losses))
|
||||
logger.logkv("policy_gradient_loss", np.mean(pg_losses))
|
||||
logger.logkv("value_loss", np.mean(value_losses))
|
||||
if hasattr(self.policy, 'log_std'):
|
||||
logger.logkv("std", th.exp(self.policy.log_std).mean().item())
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ import gym
|
|||
import torch as th
|
||||
import torch.nn as nn
|
||||
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, \
|
||||
create_sde_feature_extractor
|
||||
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
|
||||
create_sde_feature_extractor)
|
||||
from torchy_baselines.common.distributions import SquashedDiagGaussianDistribution, StateDependentNoiseDistribution
|
||||
|
||||
# CAP the standard deviation of the actor
|
||||
|
|
@ -161,8 +161,8 @@ class SACPolicy(BasePolicy):
|
|||
|
||||
:param observation_space: (gym.spaces.Space) Observation space
|
||||
:param action_space: (gym.spaces.Space) Action space
|
||||
:param learning_rate: (callable) Learning rate schedule (could be constant)
|
||||
:param net_arch: ([int or dict]) The specification of the policy and value networks.
|
||||
:param lr_schedule: (callable) Learning rate schedule (could be constant)
|
||||
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
|
||||
:param device: (str or th.device) Device on which the code should run.
|
||||
:param activation_fn: (nn.Module) Activation function
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration or not
|
||||
|
|
@ -177,7 +177,7 @@ class SACPolicy(BasePolicy):
|
|||
"""
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
learning_rate: Callable,
|
||||
lr_schedule: Callable,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
device: Union[th.device, str] = 'cpu',
|
||||
activation_fn: nn.Module = nn.ReLU,
|
||||
|
|
@ -213,16 +213,16 @@ class SACPolicy(BasePolicy):
|
|||
self.actor, self.actor_target = None, None
|
||||
self.critic, self.critic_target = None, None
|
||||
|
||||
self._build(learning_rate)
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _build(self, learning_rate: Callable) -> None:
|
||||
def _build(self, lr_schedule: Callable) -> None:
|
||||
self.actor = self.make_actor()
|
||||
self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=learning_rate(1))
|
||||
self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=lr_schedule(1))
|
||||
|
||||
self.critic = self.make_critic()
|
||||
self.critic_target = self.make_critic()
|
||||
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate(1))
|
||||
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=lr_schedule(1))
|
||||
|
||||
def make_actor(self) -> Actor:
|
||||
return Actor(**self.actor_kwargs).to(self.device)
|
||||
|
|
@ -234,7 +234,7 @@ class SACPolicy(BasePolicy):
|
|||
return self.predict(obs, deterministic=False)
|
||||
|
||||
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self.actor.forward(observation, deterministic)
|
||||
return self.actor(observation, deterministic)
|
||||
|
||||
|
||||
MlpPolicy = SACPolicy
|
||||
|
|
|
|||
|
|
@ -94,7 +94,6 @@ class SAC(OffPolicyRLModel):
|
|||
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
|
||||
use_sde_at_warmup=use_sde_at_warmup)
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self.target_entropy = target_entropy
|
||||
self.log_ent_coef = None # type: Optional[th.Tensor]
|
||||
self.target_update_interval = target_update_interval
|
||||
|
|
@ -119,7 +118,7 @@ class SAC(OffPolicyRLModel):
|
|||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
self._setup_learning_rate()
|
||||
self._setup_lr_schedule()
|
||||
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
|
||||
if self.seed is not None:
|
||||
self.set_random_seed(self.seed)
|
||||
|
|
@ -146,7 +145,7 @@ class SAC(OffPolicyRLModel):
|
|||
# Note: we optimize the log of the entropy coeff which is slightly different from the paper
|
||||
# as discussed in https://github.com/rail-berkeley/softlearning/issues/37
|
||||
self.log_ent_coef = th.log(th.ones(1, device=self.device) * init_value).requires_grad_(True)
|
||||
self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.learning_rate(1))
|
||||
self.ent_coef_optimizer = th.optim.Adam([self.log_ent_coef], lr=self.lr_schedule(1))
|
||||
else:
|
||||
# Force conversion to float
|
||||
# this will throw an error if a malformed string (different from 'auto')
|
||||
|
|
@ -155,7 +154,7 @@ class SAC(OffPolicyRLModel):
|
|||
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.learning_rate, use_sde=self.use_sde,
|
||||
self.lr_schedule, use_sde=self.use_sde,
|
||||
device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
self._create_aliases()
|
||||
|
|
@ -173,8 +172,8 @@ class SAC(OffPolicyRLModel):
|
|||
|
||||
self._update_learning_rate(optimizers)
|
||||
|
||||
ent_coef_loss, ent_coef = th.zeros(1), th.zeros(1)
|
||||
actor_loss, critic_loss = th.zeros(1), th.zeros(1)
|
||||
ent_coef_losses, ent_coefs = [], []
|
||||
actor_losses, critic_losses = [], []
|
||||
|
||||
for gradient_step in range(gradient_steps):
|
||||
# Sample replay buffer
|
||||
|
|
@ -195,9 +194,12 @@ class SAC(OffPolicyRLModel):
|
|||
# see https://github.com/rail-berkeley/softlearning/issues/60
|
||||
ent_coef = th.exp(self.log_ent_coef.detach())
|
||||
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
|
||||
ent_coef_losses.append(ent_coef_loss.item())
|
||||
else:
|
||||
ent_coef = self.ent_coef_tensor
|
||||
|
||||
ent_coefs.append(ent_coef.item())
|
||||
|
||||
# Optimize entropy coefficient, also called
|
||||
# entropy temperature or alpha in the paper
|
||||
if ent_coef_loss is not None:
|
||||
|
|
@ -221,6 +223,7 @@ class SAC(OffPolicyRLModel):
|
|||
|
||||
# Compute critic loss
|
||||
critic_loss = 0.5 * (F.mse_loss(current_q1, q_backup) + F.mse_loss(current_q2, q_backup))
|
||||
critic_losses.append(critic_loss.item())
|
||||
|
||||
# Optimize the critic
|
||||
self.critic.optimizer.zero_grad()
|
||||
|
|
@ -232,6 +235,7 @@ class SAC(OffPolicyRLModel):
|
|||
qf1_pi, qf2_pi = self.critic.forward(replay_data.observations, actions_pi)
|
||||
min_qf_pi = th.min(qf1_pi, qf2_pi)
|
||||
actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
|
||||
actor_losses.append(actor_loss.item())
|
||||
|
||||
# Optimize the actor
|
||||
self.actor.optimizer.zero_grad()
|
||||
|
|
@ -243,12 +247,14 @@ class SAC(OffPolicyRLModel):
|
|||
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
|
||||
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
|
||||
|
||||
# TODO: average
|
||||
logger.logkv("ent_coef", ent_coef.item())
|
||||
logger.logkv("actor_loss", actor_loss.item())
|
||||
logger.logkv("critic_loss", critic_loss.item())
|
||||
if ent_coef_loss is not None:
|
||||
logger.logkv("ent_coef_loss", ent_coef_loss.item())
|
||||
self._n_updates += gradient_steps
|
||||
|
||||
logger.logkv("n_updates", self._n_updates)
|
||||
logger.logkv("ent_coef", np.mean(ent_coefs))
|
||||
logger.logkv("actor_loss", np.mean(actor_losses))
|
||||
logger.logkv("critic_loss", np.mean(critic_losses))
|
||||
if len(ent_coef_losses) > 0:
|
||||
logger.logkv("ent_coef_loss", np.mean(ent_coef_losses))
|
||||
|
||||
def learn(self,
|
||||
total_timesteps: int,
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import torch
|
||||
from typing import Optional, List, Tuple, Callable, Union
|
||||
|
||||
import gym
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from torchy_baselines.common.policies import (BasePolicy, register_policy, create_mlp, BaseNetwork,
|
||||
create_sde_feature_extractor)
|
||||
from torchy_baselines.common.distributions import StateDependentNoiseDistribution
|
||||
from torchy_baselines.common.policies import BasePolicy, register_policy, create_mlp, BaseNetwork, \
|
||||
create_sde_feature_extractor
|
||||
|
||||
|
||||
class Actor(BaseNetwork):
|
||||
|
|
@ -76,7 +77,7 @@ class Actor(BaseNetwork):
|
|||
actor_net = create_mlp(obs_dim, action_dim, net_arch, activation_fn, squash_output=True)
|
||||
self.mu = nn.Sequential(*actor_net)
|
||||
|
||||
def get_std(self) -> torch.Tensor:
|
||||
def get_std(self) -> th.Tensor:
|
||||
"""
|
||||
Retrieve the standard deviation of the action distribution.
|
||||
Only useful when using SDE.
|
||||
|
|
@ -92,7 +93,7 @@ class Actor(BaseNetwork):
|
|||
mean_actions = self.mu(latent_pi)
|
||||
return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde)
|
||||
|
||||
def _get_latent(self, obs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def _get_latent(self, obs) -> Tuple[th.Tensor, th.Tensor]:
|
||||
latent_pi = self.latent_pi(obs)
|
||||
|
||||
if self.sde_feature_extractor is not None:
|
||||
|
|
@ -101,7 +102,7 @@ class Actor(BaseNetwork):
|
|||
latent_sde = latent_pi
|
||||
return latent_pi, latent_sde
|
||||
|
||||
def evaluate_actions(self, obs: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def evaluate_actions(self, obs: th.Tensor, action: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
"""
|
||||
Evaluate actions according to the current policy,
|
||||
given the observations. Only useful when using SDE.
|
||||
|
|
@ -123,7 +124,7 @@ class Actor(BaseNetwork):
|
|||
"""
|
||||
self.action_dist.sample_weights(self.log_std)
|
||||
|
||||
def forward(self, obs: torch.Tensor, deterministic: bool = True) -> torch.Tensor:
|
||||
def forward(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
|
||||
if self.use_sde:
|
||||
latent_pi, latent_sde = self._get_latent(obs)
|
||||
if deterministic:
|
||||
|
|
@ -162,11 +163,11 @@ class Critic(BaseNetwork):
|
|||
q2_net = create_mlp(obs_dim + action_dim, 1, net_arch, activation_fn)
|
||||
self.q2_net = nn.Sequential(*q2_net)
|
||||
|
||||
def forward(self, obs: torch.Tensor, action: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
def forward(self, obs: th.Tensor, action: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
|
||||
qvalue_input = th.cat([obs, action], dim=1)
|
||||
return self.q1_net(qvalue_input), self.q2_net(qvalue_input)
|
||||
|
||||
def q1_forward(self, obs: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
|
||||
def q1_forward(self, obs: th.Tensor, action: th.Tensor) -> th.Tensor:
|
||||
return self.q1_net(th.cat([obs, action], dim=1))
|
||||
|
||||
|
||||
|
|
@ -175,10 +176,11 @@ class ValueFunction(BaseNetwork):
|
|||
Value function for TD3 when doing on-policy exploration with SDE.
|
||||
|
||||
:param obs_dim: (int) Dimension of the observation
|
||||
:param net_arch: ([int]) Network architecture
|
||||
:param net_arch: (Optional[List[int]]) Network architecture
|
||||
:param activation_fn: (nn.Module) Activation function
|
||||
"""
|
||||
def __init__(self, obs_dim, net_arch=None, activation_fn=nn.Tanh):
|
||||
def __init__(self, obs_dim: int, net_arch: Optional[List[int]] = None,
|
||||
activation_fn: nn.Module = nn.Tanh):
|
||||
super(ValueFunction, self).__init__()
|
||||
|
||||
if net_arch is None:
|
||||
|
|
@ -187,7 +189,7 @@ class ValueFunction(BaseNetwork):
|
|||
vf_net = create_mlp(obs_dim, 1, net_arch, activation_fn)
|
||||
self.vf_net = nn.Sequential(*vf_net)
|
||||
|
||||
def forward(self, obs):
|
||||
def forward(self, obs: th.Tensor) -> th.Tensor:
|
||||
return self.vf_net(obs)
|
||||
|
||||
|
||||
|
|
@ -197,8 +199,8 @@ class TD3Policy(BasePolicy):
|
|||
|
||||
:param observation_space: (gym.spaces.Space) Observation space
|
||||
:param action_space: (gym.spaces.Space) Action space
|
||||
:param learning_rate: (callable) Learning rate schedule (could be constant)
|
||||
:param net_arch: ([int or dict]) The specification of the policy and value networks.
|
||||
:param lr_schedule: (Callable) Learning rate schedule (could be constant)
|
||||
:param net_arch: (Optional[List[int]]) The specification of the policy and value networks.
|
||||
:param device: (str or th.device) Device on which the code should run.
|
||||
:param activation_fn: (nn.Module) Activation function
|
||||
:param use_sde: (bool) Whether to use State Dependent Exploration or not
|
||||
|
|
@ -210,10 +212,18 @@ class TD3Policy(BasePolicy):
|
|||
a positive standard deviation (cf. paper). It keeps the variance
|
||||
above zero and prevents it from growing too fast. In practice, ``exp()`` is usually enough.
|
||||
"""
|
||||
def __init__(self, observation_space, action_space,
|
||||
learning_rate, net_arch=None, device='cpu',
|
||||
activation_fn=nn.ReLU, use_sde=False, log_std_init=-3,
|
||||
clip_noise=None, lr_sde=3e-4, sde_net_arch=None, use_expln=False):
|
||||
def __init__(self, observation_space: gym.spaces.Space,
|
||||
action_space: gym.spaces.Space,
|
||||
lr_schedule: Callable,
|
||||
net_arch: Optional[List[int]] = None,
|
||||
device: Union[th.device, str] = 'cpu',
|
||||
activation_fn: nn.Module = nn.ReLU,
|
||||
use_sde: bool = False,
|
||||
log_std_init: float = -3,
|
||||
clip_noise: Optional[float] = None,
|
||||
lr_sde: float = 3e-4,
|
||||
sde_net_arch: Optional[List[int]] = None,
|
||||
use_expln: bool = False):
|
||||
super(TD3Policy, self).__init__(observation_space, action_space, device, squash_output=True)
|
||||
|
||||
# Default network architecture, from the original paper
|
||||
|
|
@ -247,37 +257,37 @@ class TD3Policy(BasePolicy):
|
|||
self.use_sde = use_sde
|
||||
self.vf_net = None
|
||||
self.log_std_init = log_std_init
|
||||
self._build(learning_rate)
|
||||
self._build(lr_schedule)
|
||||
|
||||
def _build(self, learning_rate):
|
||||
def _build(self, lr_schedule: Callable) -> None:
|
||||
self.actor = self.make_actor()
|
||||
self.actor_target = self.make_actor()
|
||||
self.actor_target.load_state_dict(self.actor.state_dict())
|
||||
self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=learning_rate(1))
|
||||
self.actor.optimizer = th.optim.Adam(self.actor.parameters(), lr=lr_schedule(1))
|
||||
|
||||
self.critic = self.make_critic()
|
||||
self.critic_target = self.make_critic()
|
||||
self.critic_target.load_state_dict(self.critic.state_dict())
|
||||
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=learning_rate(1))
|
||||
self.critic.optimizer = th.optim.Adam(self.critic.parameters(), lr=lr_schedule(1))
|
||||
|
||||
if self.use_sde:
|
||||
self.vf_net = ValueFunction(self.obs_dim)
|
||||
self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()})
|
||||
self.actor.sde_optimizer.add_param_group({'params': self.vf_net.parameters()}) # pytype: disable=attribute-error
|
||||
|
||||
def reset_noise(self):
|
||||
def reset_noise(self) -> None:
|
||||
return self.actor.reset_noise()
|
||||
|
||||
def make_actor(self):
|
||||
def make_actor(self) -> Actor:
|
||||
return Actor(**self.actor_kwargs).to(self.device)
|
||||
|
||||
def make_critic(self):
|
||||
def make_critic(self) -> Critic:
|
||||
return Critic(**self.net_args).to(self.device)
|
||||
|
||||
def forward(self, obs, deterministic=True):
|
||||
return self.actor(obs, deterministic=deterministic)
|
||||
def forward(self, observation: th.Tensor, deterministic: bool = False):
|
||||
return self.predict(observation, deterministic=deterministic)
|
||||
|
||||
def predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor:
|
||||
return self.forward(observation, deterministic)
|
||||
return self.actor(observation, deterministic=deterministic)
|
||||
|
||||
|
||||
MlpPolicy = TD3Policy
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import torch as th
|
|||
import torch.nn.functional as F
|
||||
from typing import List, Tuple, Type, Union, Callable, Optional, Dict, Any
|
||||
|
||||
from torchy_baselines.common import logger
|
||||
from torchy_baselines.common.base_class import OffPolicyRLModel
|
||||
from torchy_baselines.common.buffers import ReplayBuffer
|
||||
from torchy_baselines.common.noise import ActionNoise
|
||||
|
|
@ -117,12 +118,12 @@ class TD3(OffPolicyRLModel):
|
|||
self._setup_model()
|
||||
|
||||
def _setup_model(self) -> None:
|
||||
self._setup_learning_rate()
|
||||
self._setup_lr_schedule()
|
||||
obs_dim, action_dim = self.observation_space.shape[0], self.action_space.shape[0]
|
||||
self.set_random_seed(self.seed)
|
||||
self.replay_buffer = ReplayBuffer(self.buffer_size, obs_dim, action_dim, self.device)
|
||||
self.policy = self.policy_class(self.observation_space, self.action_space,
|
||||
self.learning_rate, use_sde=self.use_sde,
|
||||
self.lr_schedule, use_sde=self.use_sde,
|
||||
device=self.device, **self.policy_kwargs)
|
||||
self.policy = self.policy.to(self.device)
|
||||
self._create_aliases()
|
||||
|
|
@ -215,6 +216,10 @@ class TD3(OffPolicyRLModel):
|
|||
if gradient_step % policy_delay == 0:
|
||||
self.train_actor(replay_data=replay_data, tau_actor=self.tau, tau_critic=self.tau)
|
||||
|
||||
self._n_updates += gradient_steps
|
||||
logger.logkv("n_updates", self._n_updates)
|
||||
|
||||
|
||||
def train_sde(self) -> None:
|
||||
# Update optimizer learning rate
|
||||
# self._update_learning_rate(self.policy.optimizer)
|
||||
|
|
|
|||
Loading…
Reference in a new issue