mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-22 22:10:16 +00:00
Add plotting script
This commit is contained in:
parent
9644ae89cf
commit
0e092f7c52
1 changed files with 33 additions and 4 deletions
|
|
@ -64,6 +64,8 @@ class A2C(PPO):
|
|||
self.normalize_advantage = normalize_advantage
|
||||
self.rms_prop_eps = rms_prop_eps
|
||||
self.use_rms_prop = use_rms_prop
|
||||
self.actions = []
|
||||
self.states = []
|
||||
|
||||
if _init_setup_model:
|
||||
self._setup_model()
|
||||
|
|
@ -125,12 +127,39 @@ class A2C(PPO):
|
|||
logger.logkv("value_loss", value_loss.item())
|
||||
logger.logkv("std", th.exp(self.policy.log_std).mean().item())
|
||||
|
||||
if self.use_sde:
|
||||
pass
|
||||
# print(th.exp(self.policy.log_std).detach())
|
||||
self.states.append(self.rollout_buffer.observations.cpu().numpy())
|
||||
self.actions.append(self.rollout_buffer.actions.cpu().numpy())
|
||||
|
||||
# Plot for MountainCarContinuous-v0
|
||||
if True:
|
||||
if len(self.actions) > 10:
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
actions = np.concatenate(self.actions)
|
||||
x = np.arange(len(actions))
|
||||
plt.figure("actions")
|
||||
start = 0
|
||||
for i in range(len(self.actions)):
|
||||
end = start + len(self.actions[i])
|
||||
# plt.plot(x[start:end], self.actions[i])
|
||||
# Clipped actions: real behavior, note that it is between [-2, 2] for the Pendulum
|
||||
plt.scatter(x[start:end], np.clip(self.actions[i], -1, 1), s=1)
|
||||
# plt.scatter(x[start:end], self.actions[i], s=1)
|
||||
start = end
|
||||
|
||||
plt.figure("states")
|
||||
for i in range(len(self.states)):
|
||||
if len(self.states[i].shape) > 1:
|
||||
# plt.plot(self.states[i][:, 0], self.states[i][:, 1])
|
||||
plt.scatter(self.states[i][:, 0], self.states[i][:, 1], s=1)
|
||||
else:
|
||||
plt.scatter(x[start:end], self.states[i], s=1)
|
||||
|
||||
plt.show()
|
||||
import ipdb; ipdb.set_trace()
|
||||
|
||||
|
||||
def learn(self, total_timesteps, callback=None, log_interval=100,
|
||||
def learn(self, total_timesteps, callback=None, log_interval=5,
|
||||
eval_env=None, eval_freq=-1, n_eval_episodes=5, tb_log_name="A2C", reset_num_timesteps=True):
|
||||
|
||||
return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,
|
||||
|
|
|
|||
Loading…
Reference in a new issue