From 8f4155180ec01e7c231189e7cb5ef9c2a4aa7525 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Wed, 22 Apr 2020 10:56:33 +0200 Subject: [PATCH] Revert "Filter out features extractor weights" This reverts commit 93f9de799add0874878a91fe9eaf8162321066b5. --- torchy_baselines/common/base_class.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index bce9cad..ac49453 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -362,25 +362,9 @@ class BaseRLModel(ABC): raise NotImplementedError(f"{cls} has no `_setup_model()` method") model._setup_model() - # Retrieve features extractor parameters - features_extractor_params = {name:param for name, param in params['policy'].items() - if name.startswith('features_extractor')} - - # Put state_dicts back in place + # put state_dicts back in place for name in params: attr = recursive_getattr(model, name) - - # Complete the missing keys that normally correspond to the features extractor - # parameters that were not duplicated at save time - missing_keys = set(attr.state_dict().keys()) - set(params[name].keys()) - for missing_key in missing_keys: - if 'features_extractor' not in missing_key: - continue - # Match with feature extractor - # Remove top-level, e.g. actor.features_extractor.cnn -> features_extractor.cnn - feature_key = '.'.join(missing_key.split('.')[1:]) - params[name][missing_key] = features_extractor_params[feature_key] - attr.load_state_dict(params[name]) # put tensors back in place @@ -665,11 +649,8 @@ class BaseRLModel(ABC): params_to_save = {} for name in state_dicts_names: attr = recursive_getattr(self, name) - # Filter out features_extractor weights that are not top level to save space - state_dict = {key:val for key, val in attr.state_dict().items() - if not ('features_extractor' in key and not key.startswith('features_extractor'))} # Retrieve state dict - params_to_save[name] = state_dict + params_to_save[name] = attr.state_dict() self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors)