mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Some modifications to improve readability (#31352)
Summary: In the long string, formalstring thinks it is good to have a name. When using dict, literal is better for readability and faster than dict constructor. I always appreciate your efforts in creating the world's best frameworks. Pull Request resolved: https://github.com/pytorch/pytorch/pull/31352 Differential Revision: D19191967 Pulled By: ngimel fbshipit-source-id: 21f063b163b67de8cf9761a4db5991f74318e991
This commit is contained in:
parent
7078f4b27d
commit
d770fbc1d2
3 changed files with 23 additions and 19 deletions
|
|
@ -85,9 +85,12 @@ def trainbench(name, rnn_creator, nloops=100, warmup=10,
|
|||
return fwd_time, bwd_time
|
||||
|
||||
assert device == 'cuda'
|
||||
creator_args = dict(seqLength=seqLength, numLayers=numLayers,
|
||||
inputSize=inputSize, hiddenSize=hiddenSize,
|
||||
miniBatch=miniBatch, device=device, seed=seed)
|
||||
creator_args = creator_args = {
|
||||
'seqLength': seqLength, 'numLayers': numLayers,
|
||||
'inputSize': inputSize, 'hiddenSize': hiddenSize,
|
||||
'miniBatch': miniBatch, 'device': device, 'seed': seed
|
||||
}
|
||||
|
||||
modeldef = rnn_creator(**creator_args)
|
||||
|
||||
[train_batch(modeldef) for _ in range(warmup)]
|
||||
|
|
@ -217,7 +220,7 @@ if __name__ == '__main__':
|
|||
del bench_args['cnns']
|
||||
del bench_args['variable_lstms']
|
||||
|
||||
results = dict()
|
||||
results = {}
|
||||
if should_bench_varlen_lstms:
|
||||
if args.nloops + args.warmup > 30:
|
||||
print_stderr(
|
||||
|
|
|
|||
|
|
@ -161,10 +161,10 @@ def validate_cuda_device(location):
|
|||
'to map your storages to the CPU.')
|
||||
if device >= torch.cuda.device_count():
|
||||
raise RuntimeError('Attempting to deserialize object on CUDA device '
|
||||
'{} but torch.cuda.device_count() is {}. Please use '
|
||||
'{device} but torch.cuda.device_count() is {device_count}. Please use '
|
||||
'torch.load with map_location to map your storages '
|
||||
'to an existing device.'.format(
|
||||
device, torch.cuda.device_count()))
|
||||
device=device, device_count=torch.cuda.device_count()))
|
||||
return device
|
||||
|
||||
|
||||
|
|
@ -188,8 +188,8 @@ def location_tag(storage):
|
|||
location = tagger(storage)
|
||||
if location:
|
||||
return location
|
||||
raise RuntimeError("don't know how to determine data location of " +
|
||||
torch.typename(storage))
|
||||
raise RuntimeError("don't know how to determine data location of "
|
||||
+ torch.typename(storage))
|
||||
|
||||
|
||||
def default_restore_location(storage, location):
|
||||
|
|
@ -197,9 +197,9 @@ def default_restore_location(storage, location):
|
|||
result = fn(storage, location)
|
||||
if result is not None:
|
||||
return result
|
||||
raise RuntimeError("don't know how to restore data location of " +
|
||||
torch.typename(storage) + " (tagged with " +
|
||||
location + ")")
|
||||
raise RuntimeError("don't know how to restore data location of "
|
||||
+ torch.typename(storage) + " (tagged with "
|
||||
+ location + ")")
|
||||
|
||||
|
||||
def normalize_storage_type(storage_type):
|
||||
|
|
@ -320,9 +320,9 @@ def _check_seekable(f):
|
|||
def raise_err_msg(patterns, e):
|
||||
for p in patterns:
|
||||
if p in str(e):
|
||||
msg = (str(e) + ". You can only torch.load from a file that is seekable." +
|
||||
" Please pre-load the data into a buffer like io.BytesIO and" +
|
||||
" try to load from it instead.")
|
||||
msg = (str(e) + ". You can only torch.load from a file that is seekable."
|
||||
+ " Please pre-load the data into a buffer like io.BytesIO and"
|
||||
+ " try to load from it instead.")
|
||||
raise type(e)(msg)
|
||||
raise e
|
||||
|
||||
|
|
@ -656,8 +656,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
|||
"accessing the object's source attribute or set "
|
||||
"`torch.nn.Module.dump_patches = True` and use the "
|
||||
"patch tool to revert the changes.")
|
||||
msg = ("source code of class '{}' has changed. {}"
|
||||
.format(torch.typename(container_type), msg))
|
||||
msg = ("source code of class '{container_type}' has changed. {msg}"
|
||||
.format(container_type=torch.typename(container_type), msg=msg))
|
||||
warnings.warn(msg, SourceChangeWarning)
|
||||
|
||||
def legacy_load(f):
|
||||
|
|
@ -753,7 +753,8 @@ def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
|
|||
except tarfile.TarError:
|
||||
if _is_zipfile(f):
|
||||
# .zip is used for torch.jit.save and will throw an un-pickling error here
|
||||
raise RuntimeError("{} is a zip archive (did you mean to use torch.jit.load()?)".format(f.name))
|
||||
raise RuntimeError(
|
||||
"{filename} is a zip archive (did you mean to use torch.jit.load()?)".format(filename=f.name))
|
||||
# if not a tarfile, reset file offset and proceed
|
||||
f.seek(0)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
import collections
|
||||
from collections import OrderedDict
|
||||
import weakref
|
||||
import warnings
|
||||
|
||||
|
|
@ -25,7 +25,7 @@ class RemovableHandle(object):
|
|||
def __setstate__(self, state):
|
||||
if state[0] is None:
|
||||
# create a dead reference
|
||||
self.hooks_dict_ref = weakref.ref(collections.OrderedDict())
|
||||
self.hooks_dict_ref = weakref.ref(OrderedDict())
|
||||
else:
|
||||
self.hooks_dict_ref = weakref.ref(state[0])
|
||||
self.id = state[1]
|
||||
|
|
|
|||
Loading…
Reference in a new issue