mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Require at least one tensor to be marked dynamic with --dynamic-batch-only (#99620)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/99620 Approved by: https://github.com/voznesenskym
This commit is contained in:
parent
abdd1f4a38
commit
fc8fa6c356
2 changed files with 13 additions and 0 deletions
|
|
@ -2412,14 +2412,19 @@ def run(runner, args, original_dir=None):
|
|||
# NB: This must be done late enough so that we don't do more
|
||||
# conversions on the inputs
|
||||
# NB: Assumes only the first batch-y like dimension is the batch
|
||||
marked = False
|
||||
|
||||
def detect_and_mark_batch(t):
|
||||
nonlocal marked
|
||||
for i, s in enumerate(t.size()):
|
||||
if s == batch_size:
|
||||
torch._dynamo.mark_dynamic(t, i)
|
||||
marked = True
|
||||
break
|
||||
|
||||
if args.dynamic_batch_only:
|
||||
tree_map_only(torch.Tensor, detect_and_mark_batch, example_inputs)
|
||||
assert marked, f"nothing in example_inputs had a dim with {batch_size}"
|
||||
|
||||
if args.log_operator_inputs:
|
||||
log_operator_inputs(
|
||||
|
|
|
|||
|
|
@ -282,6 +282,10 @@ class TorchBenchmarkRunner(BenchmarkRunner):
|
|||
if self.args.accuracy and model_name in MAX_BATCH_SIZE_FOR_ACCURACY_CHECK:
|
||||
batch_size = min(batch_size, MAX_BATCH_SIZE_FOR_ACCURACY_CHECK[model_name])
|
||||
|
||||
# See https://github.com/pytorch/benchmark/issues/1560
|
||||
if model_name == "speech_transformer":
|
||||
batch_size = 10
|
||||
|
||||
# workaround "RuntimeError: not allowed to set torch.backends.cudnn flags"
|
||||
torch.backends.__allow_nonbracketed_mutation_flag = True
|
||||
extra_args = []
|
||||
|
|
@ -317,6 +321,10 @@ class TorchBenchmarkRunner(BenchmarkRunner):
|
|||
# the right example_inputs
|
||||
if model_name == "yolov3":
|
||||
example_inputs = (torch.rand(batch_size, 3, 384, 512).to(device),)
|
||||
# See https://github.com/pytorch/benchmark/issues/1561
|
||||
if model_name == "maml_omniglot":
|
||||
batch_size = 5
|
||||
assert example_inputs[0].shape[0] == batch_size
|
||||
# global current_name, current_device
|
||||
# current_device = device
|
||||
# current_name = benchmark.name
|
||||
|
|
|
|||
Loading…
Reference in a new issue