stable-baselines3/torchy_baselines/common/evaluation.py
Antonin Raffin d4e2dc8a9c Add CEM-RL
2019-09-06 14:01:10 +02:00

22 lines
635 B
Python

import numpy as np
def evaluate_policy(model, env, n_eval_episodes=10, deterministic=True, render=False):
"""
Runs policy for n episodes and returns average reward
"""
mean_reward, n_steps = 0.0, 0
for _ in range(n_eval_episodes):
obs = env.reset()
done = False
while not done:
action = model.predict(np.array(obs), deterministic=deterministic)
obs, reward, done, _ = env.step(action)
mean_reward += reward
n_steps += 1
if render:
env.render()
mean_reward /= n_eval_episodes
return mean_reward, n_steps