From 4b826f2e2d2bcc1f5244aa89d2c05571ef66ae0f Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 22 Apr 2020 10:56:17 +0200 Subject: [PATCH] Filter out features extractor weights --- torchy_baselines/common/base_class.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index ac49453..bce9cad 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -362,9 +362,25 @@ class BaseRLModel(ABC): raise NotImplementedError(f"{cls} has no `_setup_model()` method") model._setup_model() - # put state_dicts back in place + # Retrieve features extractor parameters + features_extractor_params = {name:param for name, param in params['policy'].items() + if name.startswith('features_extractor')} + + # Put state_dicts back in place for name in params: attr = recursive_getattr(model, name) + + # Complete the missing keys that normally correspond to the features extractor + # parameters that were not duplicated at save time + missing_keys = set(attr.state_dict().keys()) - set(params[name].keys()) + for missing_key in missing_keys: + if 'features_extractor' not in missing_key: + continue + # Match with feature extractor + # Remove top-level, e.g. actor.features_extractor.cnn -> features_extractor.cnn + feature_key = '.'.join(missing_key.split('.')[1:]) + params[name][missing_key] = features_extractor_params[feature_key] + attr.load_state_dict(params[name]) # put tensors back in place @@ -649,8 +665,11 @@ class BaseRLModel(ABC): params_to_save = {} for name in state_dicts_names: attr = recursive_getattr(self, name) + # Filter out features_extractor weights that are not top level to save space + state_dict = {key:val for key, val in attr.state_dict().items() + if not ('features_extractor' in key and not key.startswith('features_extractor'))} # Retrieve state dict - params_to_save[name] = attr.state_dict() + params_to_save[name] = state_dict self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors)