Fixes for Python 2

This commit is contained in:
Antonin Raffin 2019-09-06 11:43:02 +02:00
parent 68028c71a1
commit 904742714d
4 changed files with 16 additions and 5 deletions

View file

@ -9,7 +9,7 @@ setup(name='torchy_baselines',
install_requires=[
'gym[classic_control]>=0.10.9',
'numpy',
'torch>=1.2.0+cpu' # torch>=1.2.0
'torch>=1.2.0' # torch>=1.2.0+cpu
],
extras_require={
'tests': [

View file

@ -1,4 +1,4 @@
from abc import ABC, abstractmethod
from abc import ABCMeta, abstractmethod
import numpy as np
@ -7,7 +7,7 @@ import gym
from torchy_baselines.common.policies import get_policy_from_name
class BaseRLModel(ABC):
class BaseRLModel(object):
"""
The base RL model
@ -17,6 +17,7 @@ class BaseRLModel(ABC):
:param verbose: (int) the verbosity level: 0 none, 1 training information, 2 debug
:param policy_base: (BasePolicy) the base policy used by this method
"""
__metaclass__ = ABCMeta
def __init__(self, policy, env, policy_base, policy_kwargs=None, verbose=0):
if isinstance(policy, str) and policy_base is not None:

View file

@ -11,6 +11,8 @@ class Actor(nn.Module):
if net_arch is None:
net_arch = [400, 300]
# TODO: orthogonal initialization?
self.actor_net = nn.Sequential(
nn.Linear(state_dim, net_arch[0]),
nn.ReLU(),
@ -52,7 +54,7 @@ class Critic(nn.Module):
return self.q1_net(qvalue_input), self.q2_net(qvalue_input)
def q1_forward(self, obs, action):
return self.q1_net( th.cat([obs, action], dim=1))
return self.q1_net(th.cat([obs, action], dim=1))
class TD3Policy(BasePolicy):

View file

@ -1,4 +1,5 @@
import sys
import time
import torch as th
import torch.nn.functional as F
@ -125,6 +126,7 @@ class TD3(BaseRLModel):
episode_num = 0
done = True
evaluations = []
start_time = time.time()
while self.num_timesteps < total_timesteps:
@ -146,6 +148,7 @@ class TD3(BaseRLModel):
evaluations.append(evaluate_policy(self, self.env, n_eval_episodes))
if self.verbose > 0:
print("Eval num_timesteps={}, mean_reward={:.2f}".format(self.num_timesteps, evaluations[-1]))
print("FPS: {:.2f}".format(self.num_timesteps / (time.time() - start_time)))
sys.stdout.flush()
# Reset environment
@ -167,7 +170,12 @@ class TD3(BaseRLModel):
# Rescale and perform action
new_obs, reward, done, _ = self.env.step(self.max_action * action)
done_bool = 0 if episode_timesteps + 1 == self.env._max_episode_steps else float(done)
if hasattr(self.env, '_max_episode_steps'):
done_bool = 0 if episode_timesteps + 1 == self.env._max_episode_steps else float(done)
else:
done_bool = float(done)
episode_reward += reward
# Store data in replay buffer