mirror of
https://github.com/saymrwulf/stable-baselines3.git
synced 2026-05-30 23:18:47 +00:00
added some comments to _load_from_file
This commit is contained in:
parent
6560ae9952
commit
8460bfe397
1 changed files with 4 additions and 0 deletions
|
|
@ -329,9 +329,12 @@ class BaseRLModel(object):
|
|||
# go to start of file
|
||||
file_content.seek(0)
|
||||
params = th.load(file_content)
|
||||
|
||||
# check for all other .pth files
|
||||
other_file = [file_name for file_name in namelist if
|
||||
os.path.splitext(file_name)[1] == ".pth" and file_name != "params.pth"]
|
||||
# if there are any other files which end with .pth and aren't "params.pth"
|
||||
# assume that they each are optimizer parameters
|
||||
if len(other_file) > 0:
|
||||
opt_params = dict()
|
||||
for file in other_file:
|
||||
|
|
@ -341,6 +344,7 @@ class BaseRLModel(object):
|
|||
file_content.write(opt_param_file.read())
|
||||
# go to start of file
|
||||
file_content.seek(0)
|
||||
# save the parameters in dict with file name but trim file ending
|
||||
opt_params[os.path.splitext(file)[0]] = th.load(file_content)
|
||||
except zipfile.BadZipFile:
|
||||
# load_path wasn't a zip file
|
||||
|
|
|
|||
Loading…
Reference in a new issue