diff --git a/torchy_baselines/a2c/a2c.py b/torchy_baselines/a2c/a2c.py index d355e09..fed4290 100644 --- a/torchy_baselines/a2c/a2c.py +++ b/torchy_baselines/a2c/a2c.py @@ -64,6 +64,8 @@ class A2C(PPO): self.normalize_advantage = normalize_advantage self.rms_prop_eps = rms_prop_eps self.use_rms_prop = use_rms_prop + self.actions = [] + self.states = [] if _init_setup_model: self._setup_model() @@ -125,12 +127,39 @@ class A2C(PPO): logger.logkv("value_loss", value_loss.item()) logger.logkv("std", th.exp(self.policy.log_std).mean().item()) - if self.use_sde: - pass - # print(th.exp(self.policy.log_std).detach()) + self.states.append(self.rollout_buffer.observations.cpu().numpy()) + self.actions.append(self.rollout_buffer.actions.cpu().numpy()) + + # Plot for MountainCarContinuous-v0 + if True: + if len(self.actions) > 10: + import matplotlib.pyplot as plt + import numpy as np + actions = np.concatenate(self.actions) + x = np.arange(len(actions)) + plt.figure("actions") + start = 0 + for i in range(len(self.actions)): + end = start + len(self.actions[i]) + # plt.plot(x[start:end], self.actions[i]) + # Clipped actions: real behavior, note that it is between [-2, 2] for the Pendulum + plt.scatter(x[start:end], np.clip(self.actions[i], -1, 1), s=1) + # plt.scatter(x[start:end], self.actions[i], s=1) + start = end + + plt.figure("states") + for i in range(len(self.states)): + if len(self.states[i].shape) > 1: + # plt.plot(self.states[i][:, 0], self.states[i][:, 1]) + plt.scatter(self.states[i][:, 0], self.states[i][:, 1], s=1) + else: + plt.scatter(x[start:end], self.states[i], s=1) + + plt.show() + import ipdb; ipdb.set_trace() - def learn(self, total_timesteps, callback=None, log_interval=100, + def learn(self, total_timesteps, callback=None, log_interval=5, eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="A2C", reset_num_timesteps=True): return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,