mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-23 22:20:18 +00:00
Filter out features extractor weights
This commit is contained in:
parent
b289aca5fe
commit
4b826f2e2d
1 changed files with 21 additions and 2 deletions
|
|
@ -362,9 +362,25 @@ class BaseRLModel(ABC):
|
|||
raise NotImplementedError(f"{cls} has no `_setup_model()` method")
|
||||
model._setup_model()
|
||||
|
||||
# put state_dicts back in place
|
||||
# Retrieve features extractor parameters
|
||||
features_extractor_params = {name:param for name, param in params['policy'].items()
|
||||
if name.startswith('features_extractor')}
|
||||
|
||||
# Put state_dicts back in place
|
||||
for name in params:
|
||||
attr = recursive_getattr(model, name)
|
||||
|
||||
# Complete the missing keys that normally correspond to the features extractor
|
||||
# parameters that were not duplicated at save time
|
||||
missing_keys = set(attr.state_dict().keys()) - set(params[name].keys())
|
||||
for missing_key in missing_keys:
|
||||
if 'features_extractor' not in missing_key:
|
||||
continue
|
||||
# Match with feature extractor
|
||||
# Remove top-level, e.g. actor.features_extractor.cnn -> features_extractor.cnn
|
||||
feature_key = '.'.join(missing_key.split('.')[1:])
|
||||
params[name][missing_key] = features_extractor_params[feature_key]
|
||||
|
||||
attr.load_state_dict(params[name])
|
||||
|
||||
# put tensors back in place
|
||||
|
|
@ -649,8 +665,11 @@ class BaseRLModel(ABC):
|
|||
params_to_save = {}
|
||||
for name in state_dicts_names:
|
||||
attr = recursive_getattr(self, name)
|
||||
# Filter out features_extractor weights that are not top level to save space
|
||||
state_dict = {key:val for key, val in attr.state_dict().items()
|
||||
if not ('features_extractor' in key and not key.startswith('features_extractor'))}
|
||||
# Retrieve state dict
|
||||
params_to_save[name] = attr.state_dict()
|
||||
params_to_save[name] = state_dict
|
||||
|
||||
self._save_to_file_zip(path, data=data, params=params_to_save, tensors=tensors)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue