mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-17 21:20:11 +00:00
Buf fixes for python 2
This commit is contained in:
parent
ab64ff464e
commit
64de9923d6
4 changed files with 20 additions and 9 deletions
|
|
@ -282,7 +282,7 @@ def getkvs():
|
|||
return Logger.CURRENT.name2val
|
||||
|
||||
|
||||
def log(*args, level=INFO):
|
||||
def log(*args, **kwargs):
|
||||
"""
|
||||
Write the sequence of args, with no separators,
|
||||
to the console and output files (if you've configured an output file).
|
||||
|
|
@ -293,6 +293,7 @@ def log(*args, level=INFO):
|
|||
:param args: (list) log the arguments
|
||||
:param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
|
||||
"""
|
||||
level = kwargs.get('level', INFO)
|
||||
Logger.CURRENT.log(*args, level=level)
|
||||
|
||||
|
||||
|
|
@ -433,7 +434,7 @@ class Logger(object):
|
|||
self.name2val.clear()
|
||||
self.name2cnt.clear()
|
||||
|
||||
def log(self, *args, level=INFO):
|
||||
def log(self, *args, **kwargs):
|
||||
"""
|
||||
Write the sequence of args, with no separators,
|
||||
to the console and output files (if you've configured an output file).
|
||||
|
|
@ -444,6 +445,7 @@ class Logger(object):
|
|||
:param args: (list) log the arguments
|
||||
:param level: (int) the logging level (can be DEBUG=10, INFO=20, WARN=30, ERROR=40, DISABLED=50)
|
||||
"""
|
||||
level = kwargs.get('level', INFO)
|
||||
if self.level <= level:
|
||||
self._do_log(args)
|
||||
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ class VecEnv(object):
|
|||
pass
|
||||
|
||||
@abstractmethod
|
||||
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
|
||||
def env_method(self, method_name, *method_args, **method_kwargs):
|
||||
"""
|
||||
Call instance methods of vectorized environments.
|
||||
|
||||
|
|
@ -222,8 +222,8 @@ class VecEnvWrapper(VecEnv):
|
|||
def set_attr(self, attr_name, value, indices=None):
|
||||
return self.venv.set_attr(attr_name, value, indices)
|
||||
|
||||
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
|
||||
return self.venv.env_method(method_name, *method_args, indices=indices, **method_kwargs)
|
||||
def env_method(self, method_name, *method_args, **method_kwargs):
|
||||
return self.venv.env_method(method_name, *method_args, **method_kwargs)
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Find attribute from wrapped venv(s) if this wrapper does not have it.
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from torchy_baselines.common.vec_env import VecEnv
|
||||
|
|
@ -44,7 +46,7 @@ class DummyVecEnv(VecEnv):
|
|||
obs = self.envs[env_idx].reset()
|
||||
self._save_obs(env_idx, obs)
|
||||
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
|
||||
self.buf_infos.copy())
|
||||
deepcopy(self.buf_infos))
|
||||
|
||||
def reset(self):
|
||||
for env_idx in range(self.num_envs):
|
||||
|
|
@ -97,8 +99,11 @@ class DummyVecEnv(VecEnv):
|
|||
for env_i in target_envs:
|
||||
setattr(env_i, attr_name, value)
|
||||
|
||||
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
|
||||
def env_method(self, method_name, *method_args, **method_kwargs):
|
||||
"""Call instance methods of vectorized environments."""
|
||||
indices = method_kwargs.get('indices')
|
||||
if 'indices' in method_kwargs:
|
||||
del method_kwargs['indices']
|
||||
target_envs = self._get_target_envs(indices)
|
||||
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
|
||||
|
||||
|
|
|
|||
|
|
@ -153,7 +153,8 @@ class SubprocVecEnv(VecEnv):
|
|||
for pipe in self.remotes:
|
||||
# gather images from subprocesses
|
||||
# `mode` will be taken into account later
|
||||
pipe.send(('render', (args, {'mode': 'rgb_array', **kwargs})))
|
||||
kwargs.update({'mode': 'rgb_array'})
|
||||
pipe.send(('render', (args, kwargs)))
|
||||
imgs = [pipe.recv() for pipe in self.remotes]
|
||||
# Create a big image by tiling images from subprocesses
|
||||
bigimg = tile_images(imgs)
|
||||
|
|
@ -187,8 +188,11 @@ class SubprocVecEnv(VecEnv):
|
|||
for remote in target_remotes:
|
||||
remote.recv()
|
||||
|
||||
def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
|
||||
def env_method(self, method_name, *method_args, **method_kwargs):
|
||||
"""Call instance methods of vectorized environments."""
|
||||
indices = method_kwargs.get('indices')
|
||||
if 'indices' in method_kwargs:
|
||||
del method_kwargs['indices']
|
||||
target_remotes = self._get_target_remotes(indices)
|
||||
for remote in target_remotes:
|
||||
remote.send(('env_method', (method_name, method_args, method_kwargs)))
|
||||
|
|
|
|||
Loading…
Reference in a new issue