training support for dynamo+torchxla integration (#88449)

We have already shown promising perf results by integrating dynamo with torchxla for inference. To provide a consistent UX for both training and inference, this PR enables training for dynamo/torchxla.

Training is trickier than inference, and we may not expect as much of a perf gain, because:
1. in the training case, torchxla only generates a single combined graph for fwd/bwd/optimizer, while with the `torchxla_trace_once` bridge we added in dynamo, due to how AOT_Autograd works, we generate three graphs: one for the forward pass, one for the backward pass, and one for the optimizer. XLA favors larger graphs since they give it more room to optimize (see the sketch after this list).
2. in the training case, tracing overhead can be overlapped with computation, so tracing overhead matters less for training than for inference. After all, training cares more about throughput while inference cares more about latency.
3. in the training case, people can increase the batch size to 'mitigate' the tracing overhead. Increasing the batch size does not change the per-iteration tracing overhead, so the tracing overhead 'per example' shrinks; e.g., doubling the batch size roughly halves the per-example tracing cost.
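
To make point 1 concrete, here is a minimal sketch (not part of this PR; `inspect_graph` and `loss_fn` are illustrative names) showing that AOT_Autograd hands the backend separate forward and backward FX graphs, to which the optimizer step adds a third:

```
import torch
from functorch.compile import aot_function

def inspect_graph(label):
    def compiler(fx_module, example_inputs):
        # AOT_Autograd calls this once per extracted graph
        print(f"=== {label} graph ===")
        print(fx_module.code)
        return fx_module  # hand the graph back unchanged; we only inspect it
    return compiler

def loss_fn(w, x):
    return (x @ w).relu().sum()

step = aot_function(
    loss_fn, fw_compiler=inspect_graph("forward"), bw_compiler=inspect_graph("backward")
)
w = torch.randn(4, 4, requires_grad=True)
step(w, torch.randn(2, 4)).backward()  # prints two separate graphs
```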

But we still want to add training support to dynamo/torchxla to make the work complete.

We added an '--iterations-per-run' argument to control how many iterations we run per measurement/device sync. This lets us understand the impact of item 2 above; see the sketch below.
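
Roughly, the measurement loop behaves like this sketch (simplified from the `timed()` changes in the diff below; assumes torch_xla is installed):

```
import time
import torch_xla.core.xla_model as xm

def timed_sketch(model, model_iter_fn, example_inputs, times=1):
    xm.mark_step()
    xm.wait_device_ops()  # start from a quiesced device
    t0 = time.perf_counter()
    for _ in range(times):
        model_iter_fn(model, example_inputs)
        xm.mark_step()  # dispatch the accumulated graph without blocking
    xm.wait_device_ops()  # block only once, after all iterations are queued
    t1 = time.perf_counter()
    return (t1 - t0) / times  # amortized per-iteration time
```

With times > 1, the tracing of iteration i+1 overlaps with the device execution of iteration i, which is exactly the effect item 2 above predicts.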

Results:

With '--iterations-per-run' equal to 1, here are the perf numbers:
```
+-------------------------+--------------------+-------------------------+
| Model                   |   XLA (trace once) |   XLA (trace everytime) |
+=========================+====================+=========================+
| resnet18                |             0.91   |                0.959    |
+-------------------------+--------------------+-------------------------+
| resnet50                |             0.917  |                0.932    |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d         |             0.912  |                0.905    |
+-------------------------+--------------------+-------------------------+
| alexnet                 |             1.038  |                0.974    |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2            |             0.881  |                0.835    |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0              |             0.903  |                0.931    |
+-------------------------+--------------------+-------------------------+
| vgg16                   |             0.914  |                0.967    |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch            |             1.359  |                0.84     |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer |             1.288  |                0.893    |
+-------------------------+--------------------+-------------------------+
| geomean                 |             1.0006 |                0.913794 |
+-------------------------+--------------------+-------------------------+
```

Overall it looks like the graph breaks indeed cause a perf loss, but for BERT_pytorch and timm_vision_transformer we still see perf gains. We need to do more experiments with a larger '--iterations-per-run'.

NOTE:
In torchbench.py I added the following code to apply a few workarounds:
```
from myscripts import workaround # TODO will remove this line before landing
```

Here is the content of workaround.py:
```
import torch
from torch import nn
import os

# override max_pool2d with avg_pool2d
if os.environ.get("REPLACE_MAXPOOL", "0") == "1":
    torch.nn.MaxPool2d = torch.nn.AvgPool2d

```

It works around a few issues we found:
1. MaxPool2d does not work for training in dynamo/torchxla: https://github.com/pytorch/torchdynamo/issues/1837 . WIP fixes from Brian in https://github.com/pytorch/pytorch/pull/90226 and https://github.com/pytorch/xla/pull/4276/files .
2. A recent change in op decomposition (https://github.com/pytorch/pytorch/pull/88697) causes batch_norm ops to fall back in torchxla. Fix from Jack in https://github.com/pytorch/xla/pull/4282#event-7969608134 (confirmed the fix after adding the Deduper to handle duplicated returns from the fx graph generated by AOTAutograd).
3. We have an issue handling dropout because the random seed gets out of sync. Here is the fix: https://github.com/pytorch/xla/pull/4293 (confirmed the fix).

Example command:
```
REPLACE_MAXPOOL=1 USE_FAKE_TENSOR=0 GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --backend=aot_torchxla_trace_once --only vgg16
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88449
Approved by: https://github.com/wconstab, https://github.com/qihqi, https://github.com/malfet
Author: Shunting Zhang, committed by PyTorch MergeBot, 2023-01-05 19:59:34 +00:00
Commit: a5f32f8978 (parent: df4b3b13bc)
5 changed files with 322 additions and 31 deletions

File 1/5 — benchmark driver changes (timed, speedup_experiment, reset_rng_state, parse_args, run):

```
@@ -43,6 +43,11 @@ try:
 except ImportError:
     from microbenchmarks.operator_inp_utils import OperatorInputsMode
 
+try:
+    import torch_xla.core.xla_model as xm
+except ImportError:
+    # ignore the error if torch_xla is not installed
+    pass
 
 log = logging.getLogger(__name__)
@@ -285,30 +290,43 @@ def tensor_is_on_xla(tensors):
     return any(map(lambda x: x.device.type == "xla", tensors))
 
 
-def timed(model, model_iter_fn, example_inputs, times=1, return_result=False):
+def timed(
+    model,
+    model_iter_fn,
+    example_inputs,
+    times=1,
+    return_result=False,
+    collect_outputs=False,
+):
+    use_xla = tensor_is_on_xla(example_inputs)
     synchronize()
-    if tensor_is_on_xla(example_inputs):
-        import torch_xla.core.xla_model as xm
-
+    if use_xla:
         xm.mark_step()
         xm.wait_device_ops()
 
-    reset_rng_state()
     t0 = time.perf_counter()
     # Dont collect outputs to correctly measure timing
     for _ in range(times):
-        result = model_iter_fn(model, example_inputs, collect_outputs=False)
-    if tensor_is_on_xla(result):
-        # If the model is on XLA device, it's possible that after running
-        # the model, the computation is accumulated but not performed yet.
-        # Flush all the accumulated computations to make the time measurement
-        # accurate.
-        import torch_xla
-
-        result_list = result
-        if not isinstance(result, (tuple, list)):
-            result_list = [result]
-        torch_xla._XLAC._xla_sync_multi(result_list, [])
-    synchronize()
+        # Put this call inside the loop to reset the seed for each iteration.
+        reset_rng_state(use_xla)
+        result = model_iter_fn(model, example_inputs, collect_outputs=collect_outputs)
+
+        # Instead of calling sync on result_list, we should call mark_step.
+        # In the training case, result_list may be empty, but we want to
+        # send all the pending graphs for compilation.
+        if use_xla:
+            # For the model running on regular torchxla (baseline), we need the
+            # mark step to send the accumulated graph for compilation.
+            #
+            # For the model running with the dynamo/torchxla bridge, in the
+            # training case, we need the mark step to send the optimizer graph
+            # out for compilation.
+            xm.mark_step()
+    if use_xla:
+        xm.wait_device_ops()
+    synchronize()
     t1 = time.perf_counter()
     return (t1 - t0, result) if return_result else t1 - t0
@@ -421,8 +439,6 @@ def randomize_input(inputs):
 
 def maybe_mark_step(args):
     if args.trace_on_xla:
-        import torch_xla.core.xla_model as xm
-
         xm.mark_step()
@@ -460,6 +476,12 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
     else:
         yield
 
+    times = args.iterations_per_run
+
+    # Use a higher tolerance for XLA since XLA causes numerical instability
+    # when the graph size changes
+    tolerance = args.xla_tolerance if args.trace_on_xla else 1e-4
+
     with maybe_profile(enabled=args.export_profiler_trace) as p:
         frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
         for rep in range(args.repeat):
@@ -476,7 +498,12 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
             # interleave the runs to handle frequency scaling and load changes
             with maybe_mark_profile(p=p, mark="expected"):
                 timings[rep, 0], expected_output = timed(
-                    model, model_iter_fn, inputs, return_result=True
+                    model,
+                    model_iter_fn,
+                    inputs,
+                    return_result=True,
+                    times=times,
+                    collect_outputs=args.collect_outputs,
                 )
 
             # call mark_step between the 2 calls to make the comparison fair.
@@ -484,11 +511,18 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
             with maybe_mark_profile(p=p, mark="actual"):
                 timings[rep, 1], actual_output = timed(
-                    model, frozen_model_iter_fn, inputs, return_result=True
+                    model,
+                    frozen_model_iter_fn,
+                    inputs,
+                    return_result=True,
+                    times=times,
+                    collect_outputs=args.collect_outputs,
                 )
 
             if should_check_result:
-                is_correct = is_correct and same(expected_output, actual_output)
+                is_correct = is_correct and same(
+                    expected_output, actual_output, tol=tolerance
+                )
 
     if args.export_profiler_trace:
         name = args.profiler_trace_name + "_" + model.name + ".json"
@@ -848,10 +882,12 @@ def cast_to_fp32(model, inputs):
     return cast_to(torch.float32, model, inputs)
 
 
-def reset_rng_state():
+def reset_rng_state(use_xla=False):
     torch.manual_seed(1337)
     random.seed(1337)
     np.random.seed(1337)
+    if use_xla:
+        xm.set_rng_state(1337, str(xm.xla_device()))
 
 
 class DummyGradScaler:
@@ -1420,6 +1456,15 @@ def parse_args(args=None):
     parser.add_argument(
         "--repeat", "-n", type=int, default=30, help="number of timing runs"
     )
+    iterations_per_run_help = """
+        Run this many iterations for each time measurement. This is mainly used
+        for XLA training. We want to run multiple iterations per measurement so
+        the tracing and computation for different iterations can overlap with
+        each other. This makes sure we have an accurate xla baseline.
+    """
+    parser.add_argument(
+        "--iterations-per-run", type=int, default=1, help=iterations_per_run_help
+    )
     parser.add_argument(
         "--randomize-input",
         action="store_true",
@@ -1601,6 +1646,19 @@ def parse_args(args=None):
         action="store_true",
         help="Whether to trace the model on XLA or on eager device",
     )
+    parser.add_argument(
+        "--xla-tolerance",
+        type=float,
+        default=1e-2,
+        help="XLA needs a loose tolerance to pass the correctness check",
+    )
+    parser.add_argument(
+        "--collect-outputs",
+        action="store_true",
+        help="""Whether to collect outputs for training. Set this to true if we
+        want to verify the numerical correctness of gradients. But that may
+        make the time measurement inaccurate""",
+    )
 
     group_fuser = parser.add_mutually_exclusive_group()
     # --nvfuser is now the default, keep the option to not break scripts
@@ -2063,8 +2121,6 @@ def run(runner, args, original_dir=None):
                 continue  # bad benchmark implementation
 
             if args.trace_on_xla:
-                import torch_xla.core.xla_model as xm
-
                 xla_dev = xm.xla_device()
                 model = model.to(device=xla_dev)
                 example_inputs = tree_map(
```

File 2/5 — torchbench runner changes:

```
@@ -192,6 +192,7 @@ class TorchBenchmarkRunner(BenchmarkRunner):
     def __init__(self):
         super(TorchBenchmarkRunner, self).__init__()
         self.suite_name = "torchbench"
+        self.optimizer = None
 
     @property
     def skip_models(self):
@@ -297,6 +298,10 @@ class TorchBenchmarkRunner(BenchmarkRunner):
         # global current_name, current_device
         # current_device = device
         # current_name = benchmark.name
+        if self.args.trace_on_xla:
+            # work around for: https://github.com/pytorch/xla/issues/4174
+            import torch_xla  # noqa: F401
+
         self.validate_model(model, example_inputs)
         return device, benchmark.name, model, example_inputs, batch_size
```

File 3/5 — dynamo/torchxla integration tests:

```
@@ -5,6 +5,8 @@ import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
+from functorch.compile import aot_module_simplified, make_boxed_compiler
+from torch._dynamo import disable
 
 try:
     from .test_torchxla_util import maybe_skip_torchxla_test
@@ -54,7 +56,21 @@ class LinearModule(nn.Module):
         return self.linear(x)
 
     def get_random_inputs(self):
-        return (torch.randn(10),)
+        return (torch.randn(2, 10),)
+
+
+class MaxPoolModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Conv2d(3, 6, kernel_size=3, stride=2)
+        self.pool = nn.MaxPool2d(3, stride=2)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return self.pool(x)
+
+    def get_random_inputs(self):
+        return (torch.randn(2, 3, 10, 10),)
 
 
 class ModuleInplaceUpdate(nn.Module):
@@ -128,12 +144,73 @@ def make_reuse_graph_test(module_class, niter=100):
     return test_wrapper
 
 
+def training_compiler(gm, example_inputs):
+    @make_boxed_compiler
+    @disable
+    def fw_compiler(graph, inputs, *args, **kwargs):
+        # Tracing-time inputs are FakeTensors; we can not pass them
+        # to extract_compiled_graph directly since we can not extract
+        # an xla tensor id from fake tensors. Call extract_compiled_graph
+        # lazily, triggering it on the first call with non-fake tensors.
+        compiled_graph = None
+
+        def optimized_mod(*args):
+            nonlocal compiled_graph
+            if compiled_graph is None:
+                compiled_graph = integration.extract_compiled_graph(graph, args)
+            return compiled_graph(*args)
+
+        return optimized_mod
+
+    return aot_module_simplified(gm, example_inputs, fw_compiler=fw_compiler)
+
+
+def model_iter_fn_train(mod, inputs):
+    outputs = mod(*inputs)
+    loss = outputs.mean()
+    loss.backward()
+    param_list = list(mod.parameters())
+    return [param.grad for param in param_list]
+
+
+def make_training_test(model_cls):
+    @maybe_skip_torchxla_test
+    def test_wrapper(self):
+        import torch_xla.core.xla_model as xm
+
+        xla_dev = xm.xla_device()
+        model = model_cls()
+        inputs = model.get_random_inputs()
+        model = model.to(device=xla_dev)
+        inputs = tuple(inp.to(device=xla_dev) for inp in inputs)
+
+        # run the baseline
+        baseline_model = copy.deepcopy(model)
+        baseline_inputs = copy.deepcopy(inputs)
+        expected_output = model_iter_fn_train(baseline_model, baseline_inputs)
+
+        compiler = training_compiler
+        optimize_ctx = torch._dynamo.optimize(compiler, nopython=False)
+        optimized_model_iter_fn = optimize_ctx(model_iter_fn_train)
+        actual_output = optimized_model_iter_fn(model, inputs)
+        print(f"expected_output:\n{expected_output}\nactual_output:\n{actual_output}")
+
+        assert allclose(expected_output, actual_output)
+
+    return test_wrapper
+
+
 class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase):
     test_basic = make_reuse_graph_test(BasicModule)
     test_matmul = make_reuse_graph_test(MatmulModule)
     test_linear = make_reuse_graph_test(LinearModule)
     test_inplace_update = make_reuse_graph_test(ModuleInplaceUpdate)
 
+    test_training_linear = make_training_test(LinearModule)
+    test_training_maxpool = make_training_test(MaxPoolModule)
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
```

File 4/5 — the torchxla integration bridge (extract_compiled_graph):

```
@@ -8,7 +8,7 @@ from typing import Any, Dict, List
 import torch
 
-debug = os.environ.get("debug_extract_compiled_graph") == "1"
+debug = os.environ.get("TORCH_XLA_DEBUG") == "1"
 
 
 @dataclasses.dataclass
@@ -39,7 +39,13 @@ class GraphInputMatcher:
             self.graph_input_tensor_ids, self.graph_input_xla_values
         ):
             arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None)
-            if arg_idx is None:
+
+            # Instead of using the trace-time base seed, use the runtime
+            # base seed here.
+            if tensor_id == torch_xla._XLAC._get_seed_info_id():
+                inp = torch_xla._XLAC._get_base_seed_as_tensor(
+                    str(traced_xla_value.device)
+                )
+            elif arg_idx is None:
                 inp = traced_xla_value
             else:
                 inp = args[arg_idx]
@@ -73,6 +79,125 @@ def import_torchxla():
     import torch_xla.debug.metrics as metrics
 
 
+class Deduper:
+    def __init__(self):
+        # map from the original list index to the deduped list index
+        self.permute_for_orig = None
+
+    def dedup(self, origlist):
+        self.permute_for_orig = []
+        deduped_ids = dict()
+        deduped_list = []
+        for item in origlist:
+            item_id = id(item)
+            if item_id not in deduped_ids:
+                deduped_ids[item_id] = len(deduped_ids)
+                deduped_list.append(item)
+            self.permute_for_orig.append(deduped_ids[item_id])
+
+        return deduped_list
+
+    def recover(self, deduped_list):
+        assert len(self.permute_for_orig) >= len(deduped_list)
+        return [deduped_list[i] for i in self.permute_for_orig]
```
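
For illustration (not part of the commit), the Deduper round-trip behaves like this:

```
a, b = object(), object()
d = Deduper()
deduped = d.dedup([a, b, a])   # -> [a, b]; permute_for_orig == [0, 1, 0]
restored = d.recover(deduped)  # -> [a, b, a], the original ordering
assert restored == [a, b, a]
```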
The additions continue with DumbReturnHandler and NoneRemover:

```
+class DumbReturnHandler:
+    """
+    Define a dumb return as an output that is also an input.
+    Torch xla does not return such tensors as part of its graph output,
+    which breaks the API contract with the caller of the graph. AOTAutograd
+    also generates such graphs quite often. To avoid breaking the contract
+    with the user of the GraphModule, we need to add those outputs manually.
+    Check https://github.com/pytorch/pytorch/pull/89536 for details.
+
+    AOTAutograd may also generate a graph with duplicated return items.
+    E.g. https://gist.github.com/shunting314/e60df8ac21fbe2494337c10d02bd78dc
+    (this is a graph generated for a model with a single BatchNorm2d).
+    XLA will dedup those duplicated items, but we need to recover the
+    duplications to maintain the contract with the caller.
+    """
+
+    def __init__(self, trace_inputs, trace_outputs, trace_inputs_inplace_update_bool):
+        self.trace_inputs = trace_inputs
+        self.trace_outputs = trace_outputs
+
+        # dedup the traced outputs first
+        self.deduper = Deduper()
+        self.deduped_trace_outputs = self.deduper.dedup(self.trace_outputs)
+
+        if debug:
+            print(
+                f"Number of duplicated outputs {len(self.trace_outputs) - len(self.deduped_trace_outputs)}"
+            )
+
+        # record the outputs that are also inputs
+        trace_inputs_id2pos = {id(x): pos for pos, x in enumerate(self.trace_inputs)}
+        self.trace_outputs_pos_to_inputs_pos = []
+        for out_pos, out in enumerate(self.deduped_trace_outputs):
+            in_pos = trace_inputs_id2pos.get(id(out), None)
+            if in_pos is not None and not trace_inputs_inplace_update_bool[in_pos]:
+                self.trace_outputs_pos_to_inputs_pos.append((out_pos, in_pos))
+
+        if debug:
+            print(
+                f"Number trace input {len(trace_inputs)}, number trace output {len(trace_outputs)}"
+            )
+            print(
+                f"Found {len(self.trace_outputs_pos_to_inputs_pos)} dumb returns: {self.trace_outputs_pos_to_inputs_pos}"
+            )
+
+    def addDumbReturn(self, real_inputs, real_outputs):
+        for out_pos, in_pos in self.trace_outputs_pos_to_inputs_pos:
+            assert in_pos < len(real_inputs)
+            # equal is fine since we can append an item at the end
+            assert out_pos <= len(real_outputs)
+            real_outputs.insert(out_pos, real_inputs[in_pos])
+
+        ret = self.deduper.recover(real_outputs)
+        return ret
+
+
+class NoneRemover:
+    """
+    The torchxla pybind APIs that accept a tensor list do not expect None
+    values in the list. But some graphs (e.g. the backward graph generated
+    by aot autograd) may return None values. We need to strip those None
+    values before sending the list to those torchxla APIs, and add them
+    back after running the compiled graph from torchxla.
+    """
+
+    def __init__(self):
+        self.none_poslist = []
+
+    def remove_nones(self, value_list):
+        """
+        Remove Nones from value_list. value_list is updated in place.
+        The original positions of the None values are recorded.
+        """
+        num = len(value_list)
+
+        # work in reverse order
+        for i in reversed(range(num)):
+            if value_list[i] is None:
+                self.none_poslist.append(i)
+                del value_list[i]
+
+        self.none_poslist.reverse()
+
+    def add_nones(self, value_list):
+        """
+        Add Nones back to value_list according to self.none_poslist.
+        value_list is updated in place.
+        """
+        for pos in self.none_poslist:
+            value_list.insert(pos, None)
```
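
Again for illustration (not part of the commit), a NoneRemover round-trip:

```
vals = [1, None, 2, None]
nr = NoneRemover()
nr.remove_nones(vals)  # vals is updated in place -> [1, 2]
assert vals == [1, 2]
nr.add_nones(vals)     # Nones reinserted at their recorded positions
assert vals == [1, None, 2, None]
```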
The remaining hunks wire these helpers into extract_compiled_graph:

```
+def is_xla_tensor(tensor: torch.Tensor) -> bool:
+    return tensor.device.type == "xla"
@@ -97,9 +222,15 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
     ]
 
     if debug:
         print(f"Graph module:\n{xla_model.code}")
+        print(f"args_tensor_ids {args_tensor_ids}")
 
     tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)}
 
+    # get_fallback_ops below uses counters to detect torch_xla fallbacks.
+    # Clear the counters here so we ignore pre-existing fallbacks and
+    # only detect fallbacks happening when running the xla_model below.
+    metrics.clear_counters()
     xla_out = xla_model(*xla_args)
 
     fallback_ops = get_fallback_ops()
@@ -111,6 +242,11 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
     if not isinstance(xla_out, (tuple, list)):
         xla_out = (xla_out,)
 
+    none_remover = NoneRemover()
+    none_remover.remove_nones(xla_out)
+
+    xla_out_ids = {id(x) for x in xla_out}
+
     # If an arg is updated in place by the model, we need to include the arg as part of the graph result.
     xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
         xla_args
@@ -118,17 +254,22 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
     xla_args_need_update = []
     arg_index_to_need_update_index = {}
     for i, need_update in enumerate(xla_args_need_update_bool):
-        if need_update:
+        # Don't add an inplace-updated argument to the list if it's already
+        # being returned
+        if need_update and id(xla_args[i]) not in xla_out_ids:
             arg_index_to_need_update_index[i] = len(xla_args_need_update)
             xla_args_need_update.append(xla_args[i])
 
     args_and_out = tuple(xla_args_need_update) + tuple(xla_out)
 
     if debug:
         print(f"#inplace update: {len(xla_args_need_update)}")
         print(f"XLA IR Text: {torch_xla._XLAC._get_xla_tensors_text(args_and_out)}")
         print(f"XLA IR HLO: {torch_xla._XLAC._get_xla_tensors_hlo(args_and_out)}")
 
     # calculate the graph hash
+    dumb_return_handler = DumbReturnHandler(
+        xla_args, args_and_out, xla_args_need_update_bool
+    )
     graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
     if debug:
         print("graph_hash", graph_hash)
@@ -150,7 +291,6 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
     torch_xla._XLAC._xla_sync_multi(args_and_out, [])
     torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
 
-    # input all cpu tensors
     def optimized_mod(*args):
         torch_xla._XLAC._xla_sync_multi(args, [])
         enter_ts = time.time()
@@ -161,13 +301,14 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
         graph_input = graph_input_matcher(args)
         start_ts = time.time()
         res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)
+        res = dumb_return_handler.addDumbReturn(args, res)
         if debug:
             print(
                 f"torchxla reuse compiled graph run_cached_graph takes {time.time() - start_ts} seconds"
            )
         args_inplace_update_ts = time.time()
-        assert len(res) == len(args_and_out)
+        assert len(res) == len(args_and_out), f"{len(res)} v.s. {len(args_and_out)}"
         ncopy = 0
 
         for arg_index, res_index in arg_index_to_need_update_index.items():
@@ -184,6 +325,7 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
             print(f"optimized_mod takes {time.time() - enter_ts} seconds overall")
 
         xm.mark_step()
+        none_remover.add_nones(result)
        return result
 
     return optimized_mod
```
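
For context, a hedged end-to-end sketch of how extract_compiled_graph is meant to be called (the import path is an assumption based on the test file above; requires torch_xla and an XLA device):

```
import torch
import torch.fx as fx
import torch_xla.core.xla_model as xm
from torch._dynamo.optimizations import torchxla_integration as integration

def f(a, b):
    return torch.sin(a) + b

dev = xm.xla_device()
args = (torch.randn(4, device=dev), torch.randn(4, device=dev))
gm = fx.symbolic_trace(f)
optimized_mod = integration.extract_compiled_graph(gm, args)  # trace + compile once
print(optimized_mod(*args))  # later calls replay the cached XLA graph
```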

File 5/5 — backend registration:

```
@@ -363,6 +363,14 @@ def cudagraphs(model, inputs):
 aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs)
 
+aot_torchxla_trivial = aot_autograd(
+    fw_compiler=BACKENDS["torchxla_trivial"],
+)
+
+aot_torchxla_trace_once = aot_autograd(
+    fw_compiler=BACKENDS["torchxla_trace_once"],
+)
+
 
 def create_aot_backends():
     """
@@ -399,3 +407,6 @@ def create_aot_backends():
     # aot_inductor_debug just replaces the inductor compiler with nop to help
     # isolate inductor vs aot_eager errors
     BACKENDS["aot_inductor_debug"] = aot_inductor_debug
+
+    BACKENDS["aot_torchxla_trivial"] = aot_torchxla_trivial
+    BACKENDS["aot_torchxla_trace_once"] = aot_torchxla_trace_once
```