diff --git a/torchy_baselines/common/base_class.py b/torchy_baselines/common/base_class.py index 8f0acca..09fe798 100644 --- a/torchy_baselines/common/base_class.py +++ b/torchy_baselines/common/base_class.py @@ -329,9 +329,12 @@ class BaseRLModel(object): # go to start of file file_content.seek(0) params = th.load(file_content) + # check for all other .pth files other_file = [file_name for file_name in namelist if os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"] + # if there are any other files which end with .pth and aren't "params.pth" + # assume that they each are optimizer parameters if len(other_file) > 0: opt_params = dict() for file in other_file: @@ -341,6 +344,7 @@ class BaseRLModel(object): file_content.write(opt_param_file.read()) # go to start of file file_content.seek(0) + # save the parameters in dict with file name but trim file ending opt_params[os.path.splitext(file)[0]] = th.load(file_content) except zipfile.BadZipFile: # load_path wasn't a zip file