diff --git a/torchy_baselines/ppo/ppo.py b/torchy_baselines/ppo/ppo.py index e9d8bdf..d9496f0 100644 --- a/torchy_baselines/ppo/ppo.py +++ b/torchy_baselines/ppo/ppo.py @@ -135,6 +135,10 @@ class PPO(BaseRLModel): self.clip_range = get_schedule_fn(self.clip_range) if self.clip_range_vf is not None: + if isinstance(self.clip_range_vf, (float, int)): + assert self.clip_range_vf > 0, ('`clip_range_vf` must be positive, ' + 'pass `None` to deactivate vf clipping') + self.clip_range_vf = get_schedule_fn(self.clip_range_vf) def collect_rollouts(self,