Filter out duplicated (non-top-level) features extractor weights at save time and restore them at load time

This commit is contained in:
Antonin RAFFIN 2020-04-22 10:56:17 +02:00
parent b289aca5fe
commit 4b826f2e2d

View file

@ -362,9 +362,25 @@ class BaseRLModel(ABC):
raise NotImplementedError(f"{cls} has no `_setup_model()` method")
model._setup_model()
# put state_dicts back in place
# Retrieve features extractor parameters
features_extractor_params = {name:param for name, param in params['policy'].items()
if name.startswith('features_extractor')}
# Put state_dicts back in place
for name in params:
attr = recursive_getattr(model, name)
# Complete the missing keys that normally correspond to the features extractor
# parameters that were not duplicated at save time
missing_keys = set(attr.state_dict().keys()) - set(params[name].keys())
for missing_key in missing_keys:
if 'features_extractor' not in missing_key:
continue
# Match with feature extractor
# Remove top-level, e.g. actor.features_extractor.cnn -> features_extractor.cnn
feature_key = '.'.join(missing_key.split('.')[1:])
params[name][missing_key] = features_extractor_params[feature_key]
attr.load_state_dict(params[name])
# put tensors back in place
@ -649,8 +665,11 @@ class BaseRLModel(ABC):
params_to_save = {}
for name in state_dicts_names:
attr = recursive_getattr(self, name)
# Filter out features_extractor weights that are not top level to save space
state_dict = {key:val for key, val in attr.state_dict().items()
if not ('features_extractor' in key and not key.startswith('features_extractor'))}
# Retrieve state dict
params_to_save[name] = attr.state_dict()
params_to_save[name] = state_dict
self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors)