Add plotting script

This commit is contained in:
Antonin Raffin 2019-10-31 16:59:35 +01:00
parent 9644ae89cf
commit 0e092f7c52

View file

@ -64,6 +64,8 @@ class A2C(PPO):
self.normalize_advantage = normalize_advantage
self.rms_prop_eps = rms_prop_eps
self.use_rms_prop = use_rms_prop
self.actions = []
self.states = []
if _init_setup_model:
self._setup_model()
@ -125,12 +127,39 @@ class A2C(PPO):
logger.logkv("value_loss", value_loss.item())
logger.logkv("std", th.exp(self.policy.log_std).mean().item())
if self.use_sde:
pass
# print(th.exp(self.policy.log_std).detach())
self.states.append(self.rollout_buffer.observations.cpu().numpy())
self.actions.append(self.rollout_buffer.actions.cpu().numpy())
# Plot for MountainCarContinuous-v0
if True:
if len(self.actions) > 10:
import matplotlib.pyplot as plt
import numpy as np
actions = np.concatenate(self.actions)
x = np.arange(len(actions))
plt.figure("actions")
start = 0
for i in range(len(self.actions)):
end = start + len(self.actions[i])
# plt.plot(x[start:end], self.actions[i])
# Clipped actions: real behavior, note that it is between [-2, 2] for the Pendulum
plt.scatter(x[start:end], np.clip(self.actions[i], -1, 1), s=1)
# plt.scatter(x[start:end], self.actions[i], s=1)
start = end
plt.figure("states")
for i in range(len(self.states)):
if len(self.states[i].shape) > 1:
# plt.plot(self.states[i][:, 0], self.states[i][:, 1])
plt.scatter(self.states[i][:, 0], self.states[i][:, 1], s=1)
else:
plt.scatter(x[start:end], self.states[i], s=1)
plt.show()
import ipdb; ipdb.set_trace()
def learn(self, total_timesteps, callback=None, log_interval=100,
def learn(self, total_timesteps, callback=None, log_interval=5,
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="A2C", reset_num_timesteps=True):
return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,