mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-06-05 00:00:04 +00:00
Fixes for python 2
This commit is contained in:
parent
68028c71a1
commit
904742714d
4 changed files with 16 additions and 5 deletions
2
setup.py
2
setup.py
|
|
@ -9,7 +9,7 @@ setup(name='torchy_baselines',
|
|||
install_requires=[
|
||||
'gym[classic_control]>=0.10.9',
|
||||
'numpy',
|
||||
'torch>=1.2.0+cpu' # torch>=1.2.0
|
||||
'torch>=1.2.0' # torch>=1.2.0+cpu
|
||||
],
|
||||
extras_require={
|
||||
'tests': [
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -7,7 +7,7 @@ import gym
|
|||
from torchy_baselines.common.policies import get_policy_from_name
|
||||
|
||||
|
||||
class BaseRLModel(ABC):
|
||||
class BaseRLModel(object):
|
||||
"""
|
||||
The base RL model
|
||||
|
||||
|
|
@ -17,6 +17,7 @@ class BaseRLModel(ABC):
|
|||
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 debug
|
||||
:param policy_base: (BasePolicy) the base policy used by this method
|
||||
"""
|
||||
__metaclass__ = ABCMeta
|
||||
|
||||
def __init__(self, policy, env, policy_base, policy_kwargs=None, verbose=0):
|
||||
if isinstance(policy, str) and policy_base is not None:
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ class Actor(nn.Module):
|
|||
if net_arch is None:
|
||||
net_arch = [400, 300]
|
||||
|
||||
# TODO: orthogonal initialization?
|
||||
|
||||
self.actor_net = nn.Sequential(
|
||||
nn.Linear(state_dim, net_arch[0]),
|
||||
nn.ReLU(),
|
||||
|
|
@ -52,7 +54,7 @@ class Critic(nn.Module):
|
|||
return self.q1_net(qvalue_input), self.q2_net(qvalue_input)
|
||||
|
||||
def q1_forward(self, obs, action):
|
||||
return self.q1_net( th.cat([obs, action], dim=1))
|
||||
return self.q1_net(th.cat([obs, action], dim=1))
|
||||
|
||||
|
||||
class TD3Policy(BasePolicy):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import sys
|
||||
import time
|
||||
|
||||
import torch as th
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -125,6 +126,7 @@ class TD3(BaseRLModel):
|
|||
episode_num = 0
|
||||
done = True
|
||||
evaluations = []
|
||||
start_time = time.time()
|
||||
|
||||
while self.num_timesteps < total_timesteps:
|
||||
|
||||
|
|
@ -146,6 +148,7 @@ class TD3(BaseRLModel):
|
|||
evaluations.append(evaluate_policy(self, self.env, n_eval_episodes))
|
||||
if self.verbose > 0:
|
||||
print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1]))
|
||||
print("FPS: {:.2f}".format(self.num_timesteps / (time.time() - start_time)))
|
||||
sys.stdout.flush()
|
||||
|
||||
# Reset environment
|
||||
|
|
@ -167,7 +170,12 @@ class TD3(BaseRLModel):
|
|||
|
||||
# Rescale and perform action
|
||||
new_obs, reward, done, _ = self.env.step(self.max_action * action)
|
||||
done_bool = 0 if episode_timesteps + 1 == self.env._max_episode_steps else float(done)
|
||||
|
||||
if hasattr(self.env, '_max_episode_steps'):
|
||||
done_bool = 0 if episode_timesteps + 1 == self.env._max_episode_steps else float(done)
|
||||
else:
|
||||
done_bool = float(done)
|
||||
|
||||
episode_reward += reward
|
||||
|
||||
# Store data in replay buffer
|
||||
|
|
|
|||
Loading…
Reference in a new issue