mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
training support for dynamo+torchxla integration (#88449)
We've already shown some promising perf result by integrating dynamo with torchxla for inference. To provide consistent UX for training and for inference, in this PR we try to enable training for dynamo/torchxla.
Training is trickier than inference and we may not expect much perf gains since
1. in training case, torchxla only generate a single combined graph for fwd/bwd/optimizer while in `torchxla_trace_once` bridge we added in dynamo, due to how AOT_Autograd works, we will generate 3 graphs: one for forward, one for backward and one for the optimizer. XLA favors larger graph to do more optimizations.
2. in training case, tracing overhead can be overlapped with computation. Tracing overhead is not as a big deal for training as for inference. After all training cares more about throughput while inference cares more about latency.
3. in training case, people can increase batch size to 'mitigate' the tracing overhead. Increase batch size does not change tracing overhead, thus it shows like the tracing overhead 'per example' reduces.
But we still want to add training support to dynamo/torchxla to make the work complete.
We added '--iterations-per-run' argument to control how may iterations we do per measure/device sync. This is to understand the impact of item 2 above.
Results:
With '--iterations-per-run' equals to 1, here are the perf numbers:
```
+-------------------------+--------------------+-------------------------+
| Model | XLA (trace once) | XLA (trace everytime) |
+=========================+====================+=========================+
| resnet18 | 0.91 | 0.959 |
+-------------------------+--------------------+-------------------------+
| resnet50 | 0.917 | 0.932 |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d | 0.912 | 0.905 |
+-------------------------+--------------------+-------------------------+
| alexnet | 1.038 | 0.974 |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2 | 0.881 | 0.835 |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0 | 0.903 | 0.931 |
+-------------------------+--------------------+-------------------------+
| vgg16 | 0.914 | 0.967 |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch | 1.359 | 0.84 |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer | 1.288 | 0.893 |
+-------------------------+--------------------+-------------------------+
| geomean | 1.0006 | 0.913794 |
+-------------------------+--------------------+-------------------------+
```
Overall it looks like graph break indeed cause perf loss. But for BERT_pytorch and timm_vision_transformer we still see perf gain. We need do more experiments with larger '--iterations-per-run'
NOTE:
In torchbench.py I added the following code to do a few workaround:
```
from myscripts import workaround # TODO will remove this line before landing
```
Here are the content of workaround.py:
```
import torch
from torch import nn
import os
# override max_pool2d with avg_pool2d
if os.environ.get("REPLACE_MAXPOOL", "0") == "1":
torch.nn.MaxPool2d = torch.nn.AvgPool2d
```
It work around a few issues we found
1. MaxPool2d does not work for training in dynamo/torchxla: https://github.com/pytorch/torchdynamo/issues/1837 . WIP fix from Brian in https://github.com/pytorch/pytorch/pull/90226 , https://github.com/pytorch/xla/pull/4276/files (WIP)
2. recent change ( this PR https://github.com/pytorch/pytorch/pull/88697 ) in op decomposition cause batch_norm ops to fallback in torchxla. Fix from jack in https://github.com/pytorch/xla/pull/4282#event-7969608134 . (confirmed the fix after adding Deduper to handle duplicated return from fx graph generated by AOTAutograd)
3. we have issue to handle dropout because of random seed out of sync issue. Here is the fix: https://github.com/pytorch/xla/pull/4293 (confirmed the fix)
Example command:
```
REPLACE_MAXPOOL=1 USE_FAKE_TENSOR=0 GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --backend=aot_torchxla_trace_once --only vgg16
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88449
Approved by: https://github.com/wconstab, https://github.com/qihqi, https://github.com/malfet
This commit is contained in:
parent
df4b3b13bc
commit
a5f32f8978
5 changed files with 322 additions and 31 deletions
|
|
@ -43,6 +43,11 @@ try:
|
|||
except ImportError:
|
||||
from microbenchmarks.operator_inp_utils import OperatorInputsMode
|
||||
|
||||
try:
|
||||
import torch_xla.core.xla_model as xm
|
||||
except ImportError:
|
||||
# ignore the error if torch_xla is not installed
|
||||
pass
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -285,30 +290,43 @@ def tensor_is_on_xla(tensors):
|
|||
return any(map(lambda x: x.device.type == "xla", tensors))
|
||||
|
||||
|
||||
def timed(model, model_iter_fn, example_inputs, times=1, return_result=False):
|
||||
def timed(
|
||||
model,
|
||||
model_iter_fn,
|
||||
example_inputs,
|
||||
times=1,
|
||||
return_result=False,
|
||||
collect_outputs=False,
|
||||
):
|
||||
use_xla = tensor_is_on_xla(example_inputs)
|
||||
synchronize()
|
||||
if tensor_is_on_xla(example_inputs):
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
if use_xla:
|
||||
xm.mark_step()
|
||||
xm.wait_device_ops()
|
||||
|
||||
reset_rng_state()
|
||||
t0 = time.perf_counter()
|
||||
# Dont collect outputs to correctly measure timing
|
||||
for _ in range(times):
|
||||
result = model_iter_fn(model, example_inputs, collect_outputs=False)
|
||||
if tensor_is_on_xla(result):
|
||||
# If the model is on XLA device, it's possible that after running
|
||||
# the model, the computation is accumulated but not performed yet.
|
||||
# Flush all the accumulated computations to make the time measurement
|
||||
# accurate.
|
||||
import torch_xla
|
||||
# Put this call inside the loop to reset the seed for each iteration.
|
||||
reset_rng_state(use_xla)
|
||||
result = model_iter_fn(model, example_inputs, collect_outputs=collect_outputs)
|
||||
|
||||
result_list = result
|
||||
if not isinstance(result, (tuple, list)):
|
||||
result_list = [result]
|
||||
torch_xla._XLAC._xla_sync_multi(result_list, [])
|
||||
synchronize()
|
||||
# instead of calling sync on result_list, we should call mark_step.
|
||||
# In training case, result_list may be empty, but we want to
|
||||
# send all the pending graphs for compilation.
|
||||
if use_xla:
|
||||
# For the model running on regular torchxla (baseline), we need the
|
||||
# mark step to send the accumulated graph for compilation.
|
||||
#
|
||||
# For the model running with dynamo/torchxla bridge, in training case,
|
||||
# we need the mark step to send the optimizer graph out for
|
||||
# compilation.
|
||||
xm.mark_step()
|
||||
|
||||
if use_xla:
|
||||
xm.wait_device_ops()
|
||||
synchronize()
|
||||
t1 = time.perf_counter()
|
||||
return (t1 - t0, result) if return_result else t1 - t0
|
||||
|
||||
|
|
@ -421,8 +439,6 @@ def randomize_input(inputs):
|
|||
|
||||
def maybe_mark_step(args):
|
||||
if args.trace_on_xla:
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
xm.mark_step()
|
||||
|
||||
|
||||
|
|
@ -460,6 +476,12 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
|
|||
else:
|
||||
yield
|
||||
|
||||
times = args.iterations_per_run
|
||||
|
||||
# Use higher tolerance for XLA since XLA cause numerical unstability when
|
||||
# graph size changes
|
||||
tolerance = args.xla_tolerance if args.trace_on_xla else 1e-4
|
||||
|
||||
with maybe_profile(enabled=args.export_profiler_trace) as p:
|
||||
frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
|
||||
for rep in range(args.repeat):
|
||||
|
|
@ -476,7 +498,12 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
|
|||
# interleave the runs to handle frequency scaling and load changes
|
||||
with maybe_mark_profile(p=p, mark="expected"):
|
||||
timings[rep, 0], expected_output = timed(
|
||||
model, model_iter_fn, inputs, return_result=True
|
||||
model,
|
||||
model_iter_fn,
|
||||
inputs,
|
||||
return_result=True,
|
||||
times=times,
|
||||
collect_outputs=args.collect_outputs,
|
||||
)
|
||||
|
||||
# call mark_step between the 2 calls to make the comparison fair.
|
||||
|
|
@ -484,11 +511,18 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
|
|||
|
||||
with maybe_mark_profile(p=p, mark="actual"):
|
||||
timings[rep, 1], actual_output = timed(
|
||||
model, frozen_model_iter_fn, inputs, return_result=True
|
||||
model,
|
||||
frozen_model_iter_fn,
|
||||
inputs,
|
||||
return_result=True,
|
||||
times=times,
|
||||
collect_outputs=args.collect_outputs,
|
||||
)
|
||||
|
||||
if should_check_result:
|
||||
is_correct = is_correct and same(expected_output, actual_output)
|
||||
is_correct = is_correct and same(
|
||||
expected_output, actual_output, tol=tolerance
|
||||
)
|
||||
|
||||
if args.export_profiler_trace:
|
||||
name = args.profiler_trace_name + "_" + model.name + ".json"
|
||||
|
|
@ -848,10 +882,12 @@ def cast_to_fp32(model, inputs):
|
|||
return cast_to(torch.float32, model, inputs)
|
||||
|
||||
|
||||
def reset_rng_state():
|
||||
def reset_rng_state(use_xla=False):
|
||||
torch.manual_seed(1337)
|
||||
random.seed(1337)
|
||||
np.random.seed(1337)
|
||||
if use_xla:
|
||||
xm.set_rng_state(1337, str(xm.xla_device()))
|
||||
|
||||
|
||||
class DummyGradScaler:
|
||||
|
|
@ -1420,6 +1456,15 @@ def parse_args(args=None):
|
|||
parser.add_argument(
|
||||
"--repeat", "-n", type=int, default=30, help="number of timing runs"
|
||||
)
|
||||
iterations_per_run_help = """
|
||||
Run this may iterations for each time measurement. This is mainly used for
|
||||
XLA training. We want to run multiple iterations per measurement so the
|
||||
tracing and computation for different iteartions can overlap with each
|
||||
other. This makes sure we have an accurate xla baseline.
|
||||
"""
|
||||
parser.add_argument(
|
||||
"--iterations-per-run", type=int, default=1, help=iterations_per_run_help
|
||||
)
|
||||
parser.add_argument(
|
||||
"--randomize-input",
|
||||
action="store_true",
|
||||
|
|
@ -1601,6 +1646,19 @@ def parse_args(args=None):
|
|||
action="store_true",
|
||||
help="Whether to trace the model on XLA or on eager device",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--xla-tolerance",
|
||||
type=float,
|
||||
default=1e-2,
|
||||
help="XLA needs a loose tolerance to pass the correctness check",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--collect-outputs",
|
||||
action="store_true",
|
||||
help="""Whether to collect outputs for training. Set this to true if we
|
||||
want to verify the numerical correctness of graidents. But that may
|
||||
cause time measurement not accurate""",
|
||||
)
|
||||
|
||||
group_fuser = parser.add_mutually_exclusive_group()
|
||||
# --nvfuser is now the default, keep the option to not break scripts
|
||||
|
|
@ -2063,8 +2121,6 @@ def run(runner, args, original_dir=None):
|
|||
continue # bad benchmark implementation
|
||||
|
||||
if args.trace_on_xla:
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
xla_dev = xm.xla_device()
|
||||
model = model.to(device=xla_dev)
|
||||
example_inputs = tree_map(
|
||||
|
|
|
|||
|
|
@ -192,6 +192,7 @@ class TorchBenchmarkRunner(BenchmarkRunner):
|
|||
def __init__(self):
|
||||
super(TorchBenchmarkRunner, self).__init__()
|
||||
self.suite_name = "torchbench"
|
||||
self.optimizer = None
|
||||
|
||||
@property
|
||||
def skip_models(self):
|
||||
|
|
@ -297,6 +298,10 @@ class TorchBenchmarkRunner(BenchmarkRunner):
|
|||
# global current_name, current_device
|
||||
# current_device = device
|
||||
# current_name = benchmark.name
|
||||
|
||||
if self.args.trace_on_xla:
|
||||
# work around for: https://github.com/pytorch/xla/issues/4174
|
||||
import torch_xla # noqa: F401
|
||||
self.validate_model(model, example_inputs)
|
||||
return device, benchmark.name, model, example_inputs, batch_size
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import torch
|
|||
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
from functorch.compile import aot_module_simplified, make_boxed_compiler
|
||||
from torch._dynamo import disable
|
||||
|
||||
try:
|
||||
from .test_torchxla_util import maybe_skip_torchxla_test
|
||||
|
|
@ -54,7 +56,21 @@ class LinearModule(nn.Module):
|
|||
return self.linear(x)
|
||||
|
||||
def get_random_inputs(self):
|
||||
return (torch.randn(10),)
|
||||
return (torch.randn(2, 10),)
|
||||
|
||||
|
||||
class MaxPoolModule(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(3, 6, kernel_size=3, stride=2)
|
||||
self.pool = nn.MaxPool2d(3, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return self.pool(x)
|
||||
|
||||
def get_random_inputs(self):
|
||||
return (torch.randn(2, 3, 10, 10),)
|
||||
|
||||
|
||||
class ModuleInplaceUpdate(nn.Module):
|
||||
|
|
@ -128,12 +144,73 @@ def make_reuse_graph_test(module_class, niter=100):
|
|||
return test_wrapper
|
||||
|
||||
|
||||
def training_compiler(gm, example_inputs):
|
||||
@make_boxed_compiler
|
||||
@disable
|
||||
def fw_compiler(graph, inputs, *args, **kwargs):
|
||||
# tracing time inputs are FakeTensors, we can not pass them
|
||||
# to extract_compiled_graph directly since we can not extract
|
||||
# xla tensor id from fake tensors. Call extract_compiled_graph
|
||||
# lazily and trigger that for the first call with non-fake tensors.
|
||||
compiled_graph = None
|
||||
|
||||
def optimized_mod(*args):
|
||||
nonlocal compiled_graph
|
||||
if compiled_graph is None:
|
||||
compiled_graph = integration.extract_compiled_graph(graph, args)
|
||||
return compiled_graph(*args)
|
||||
|
||||
return optimized_mod
|
||||
|
||||
return aot_module_simplified(gm, example_inputs, fw_compiler=fw_compiler)
|
||||
|
||||
|
||||
def model_iter_fn_train(mod, inputs):
|
||||
outputs = mod(*inputs)
|
||||
loss = outputs.mean()
|
||||
loss.backward()
|
||||
|
||||
param_list = list(mod.parameters())
|
||||
return [param.grad for param in param_list]
|
||||
|
||||
|
||||
def make_training_test(model_cls):
|
||||
@maybe_skip_torchxla_test
|
||||
def test_wrapper(self):
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
xla_dev = xm.xla_device()
|
||||
model = model_cls()
|
||||
inputs = model.get_random_inputs()
|
||||
|
||||
model = model.to(device=xla_dev)
|
||||
inputs = tuple(inp.to(device=xla_dev) for inp in inputs)
|
||||
|
||||
# do baseline
|
||||
baseline_model = copy.deepcopy(model)
|
||||
baseline_inputs = copy.deepcopy(inputs)
|
||||
expected_output = model_iter_fn_train(baseline_model, baseline_inputs)
|
||||
|
||||
compiler = training_compiler
|
||||
optimize_ctx = torch._dynamo.optimize(compiler, nopython=False)
|
||||
optimized_model_iter_fn = optimize_ctx(model_iter_fn_train)
|
||||
|
||||
actual_output = optimized_model_iter_fn(model, inputs)
|
||||
print(f"expected_output:\n{expected_output}\nactual_output:\n{actual_output}")
|
||||
assert allclose(expected_output, actual_output)
|
||||
|
||||
return test_wrapper
|
||||
|
||||
|
||||
class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase):
|
||||
test_basic = make_reuse_graph_test(BasicModule)
|
||||
test_matmul = make_reuse_graph_test(MatmulModule)
|
||||
test_linear = make_reuse_graph_test(LinearModule)
|
||||
test_inplace_update = make_reuse_graph_test(ModuleInplaceUpdate)
|
||||
|
||||
test_training_linear = make_training_test(LinearModule)
|
||||
test_training_maxpool = make_training_test(MaxPoolModule)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from typing import Any, Dict, List
|
|||
|
||||
import torch
|
||||
|
||||
debug = os.environ.get("debug_extract_compiled_graph") == "1"
|
||||
debug = os.environ.get("TORCH_XLA_DEBUG") == "1"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
@ -39,7 +39,13 @@ class GraphInputMatcher:
|
|||
self.graph_input_tensor_ids, self.graph_input_xla_values
|
||||
):
|
||||
arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None)
|
||||
if arg_idx is None:
|
||||
# Instead of use trace time base seed, use the runtime
|
||||
# base seed here.
|
||||
if tensor_id == torch_xla._XLAC._get_seed_info_id():
|
||||
inp = torch_xla._XLAC._get_base_seed_as_tensor(
|
||||
str(traced_xla_value.device)
|
||||
)
|
||||
elif arg_idx is None:
|
||||
inp = traced_xla_value
|
||||
else:
|
||||
inp = args[arg_idx]
|
||||
|
|
@ -73,6 +79,125 @@ def import_torchxla():
|
|||
import torch_xla.debug.metrics as metrics
|
||||
|
||||
|
||||
class Deduper:
|
||||
def __init__(self):
|
||||
# origlist index to dedupedlist index
|
||||
self.permute_for_orig = None
|
||||
|
||||
def dedup(self, origlist):
|
||||
self.permute_for_orig = []
|
||||
deduped_ids = dict()
|
||||
deduped_list = []
|
||||
for item in origlist:
|
||||
item_id = id(item)
|
||||
if item_id not in deduped_ids:
|
||||
deduped_ids[item_id] = len(deduped_ids)
|
||||
deduped_list.append(item)
|
||||
self.permute_for_orig.append(deduped_ids[item_id])
|
||||
|
||||
return deduped_list
|
||||
|
||||
def recover(self, deduped_list):
|
||||
assert len(self.permute_for_orig) >= len(deduped_list)
|
||||
return [deduped_list[i] for i in self.permute_for_orig]
|
||||
|
||||
|
||||
class DumbReturnHandler:
|
||||
"""
|
||||
Define dumb return as an output that is also an input.
|
||||
Torch xla does not return such tensors as its graph output. That breaks the
|
||||
API contract with the caller of the graph. Also AOTAutograd
|
||||
may generate such a graph quite often.
|
||||
|
||||
To avoid break the contract with the user of the GraphModule, we need
|
||||
add those outputs manually.
|
||||
|
||||
Check https://github.com/pytorch/pytorch/pull/89536 for details.
|
||||
|
||||
AOTAutograd may also generate graph with duplicated return item.
|
||||
E.g. https://gist.github.com/shunting314/e60df8ac21fbe2494337c10d02bd78dc
|
||||
(this is a graph generated for a model with a single BatchNorm2d)
|
||||
XLA will dedup those duplicate items, but we need recover the duplications to maintain
|
||||
the contract with the caller.
|
||||
"""
|
||||
|
||||
def __init__(self, trace_inputs, trace_outputs, trace_inputs_inplace_update_bool):
|
||||
self.trace_inputs = trace_inputs
|
||||
self.trace_outputs = trace_outputs
|
||||
|
||||
# dedup the traced outputs first
|
||||
self.deduper = Deduper()
|
||||
self.deduped_trace_outputs = self.deduper.dedup(self.trace_outputs)
|
||||
|
||||
if debug:
|
||||
print(
|
||||
f"Number of duplicated outputs {len(self.trace_outputs) - len(self.deduped_trace_outputs)})"
|
||||
)
|
||||
|
||||
# record the output that is also a input
|
||||
trace_inputs_id2pos = {id(x): pos for pos, x in enumerate(self.trace_inputs)}
|
||||
self.trace_outputs_pos_to_inputs_pos = []
|
||||
for out_pos, out in enumerate(self.deduped_trace_outputs):
|
||||
in_pos = trace_inputs_id2pos.get(id(out), None)
|
||||
if in_pos is not None and not trace_inputs_inplace_update_bool[in_pos]:
|
||||
self.trace_outputs_pos_to_inputs_pos.append((out_pos, in_pos))
|
||||
|
||||
if debug:
|
||||
print(
|
||||
f"Number trace input {len(trace_inputs)}, number trace output {len(trace_outputs)}"
|
||||
)
|
||||
print(
|
||||
f"Found {len(self.trace_outputs_pos_to_inputs_pos)} dumb returns: {self.trace_outputs_pos_to_inputs_pos}"
|
||||
)
|
||||
|
||||
def addDumbReturn(self, real_inputs, real_outputs):
|
||||
for out_pos, in_pos in self.trace_outputs_pos_to_inputs_pos:
|
||||
assert in_pos < len(real_inputs)
|
||||
# equals is fine since we can append an item at the end
|
||||
assert out_pos <= len(real_outputs)
|
||||
|
||||
real_outputs.insert(out_pos, real_inputs[in_pos])
|
||||
|
||||
ret = self.deduper.recover(real_outputs)
|
||||
return ret
|
||||
|
||||
|
||||
class NoneRemover:
|
||||
"""
|
||||
torchxla pybind APIs that accepts a Tensor list does not expect None value on
|
||||
the list. But some graph (e.g. backward graph generated by aot autograd) may
|
||||
return a None value. We need strip those None value before sending the list to
|
||||
those torchxla APIs. We need add None value back later after running the
|
||||
compiled graph from torchxla.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.none_poslist = []
|
||||
|
||||
def remove_nones(self, value_list):
|
||||
"""
|
||||
Remove none from value_list. value_list will be inplace updated.
|
||||
The original position of None values are recorded.
|
||||
"""
|
||||
num = len(value_list)
|
||||
|
||||
# work in reverse order
|
||||
for i in reversed(range(num)):
|
||||
if value_list[i] is None:
|
||||
self.none_poslist.append(i)
|
||||
del value_list[i]
|
||||
|
||||
self.none_poslist.reverse()
|
||||
|
||||
def add_nones(self, value_list):
|
||||
"""
|
||||
Add nones to value_list according to self.none_poslist. value_list
|
||||
is inplace updated.
|
||||
"""
|
||||
for pos in self.none_poslist:
|
||||
value_list.insert(pos, None)
|
||||
|
||||
|
||||
def is_xla_tensor(tensor: torch.Tensor) -> bool:
|
||||
return tensor.device.type == "xla"
|
||||
|
||||
|
|
@ -97,9 +222,15 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
|
|||
]
|
||||
|
||||
if debug:
|
||||
print(f"Graph module:\n{xla_model.code}")
|
||||
print(f"args_tensor_ids {args_tensor_ids}")
|
||||
|
||||
tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)}
|
||||
|
||||
# get_fallback_ops below uses counters to detect torch_xla fallbacks.
|
||||
# Clear the counters here so we ignore pre-existing fallbacks and
|
||||
# only detect fallbacks happening when running the xla_model below.
|
||||
metrics.clear_counters()
|
||||
xla_out = xla_model(*xla_args)
|
||||
|
||||
fallback_ops = get_fallback_ops()
|
||||
|
|
@ -111,6 +242,11 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
|
|||
if not isinstance(xla_out, (tuple, list)):
|
||||
xla_out = (xla_out,)
|
||||
|
||||
none_remover = NoneRemover()
|
||||
none_remover.remove_nones(xla_out)
|
||||
|
||||
xla_out_ids = {id(x) for x in xla_out}
|
||||
|
||||
# If a arg is being in place updated by model, we need to include arg as part of the graph result.
|
||||
xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization(
|
||||
xla_args
|
||||
|
|
@ -118,17 +254,22 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
|
|||
xla_args_need_update = []
|
||||
arg_index_to_need_update_index = {}
|
||||
for i, need_update in enumerate(xla_args_need_update_bool):
|
||||
if need_update:
|
||||
# Don't add inplace updated argument to the list if it's already
|
||||
# being returned
|
||||
if need_update and id(xla_args[i]) not in xla_out_ids:
|
||||
arg_index_to_need_update_index[i] = len(xla_args_need_update)
|
||||
xla_args_need_update.append(xla_args[i])
|
||||
|
||||
args_and_out = tuple(xla_args_need_update) + tuple(xla_out)
|
||||
|
||||
if debug:
|
||||
print(f"#inplace update: {len(xla_args_need_update)}")
|
||||
print(f"XLA IR Text: {torch_xla._XLAC._get_xla_tensors_text(args_and_out)}")
|
||||
print(f"XLA IR HLO: {torch_xla._XLAC._get_xla_tensors_hlo(args_and_out)}")
|
||||
|
||||
# calculate graph hash
|
||||
dumb_return_handler = DumbReturnHandler(
|
||||
xla_args, args_and_out, xla_args_need_update_bool
|
||||
)
|
||||
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
|
||||
if debug:
|
||||
print("graph_hash", graph_hash)
|
||||
|
|
@ -150,7 +291,6 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
|
|||
torch_xla._XLAC._xla_sync_multi(args_and_out, [])
|
||||
torch_xla._XLAC._clear_pending_irs(str(xm.xla_device()))
|
||||
|
||||
# input all cpu tensors
|
||||
def optimized_mod(*args):
|
||||
torch_xla._XLAC._xla_sync_multi(args, [])
|
||||
enter_ts = time.time()
|
||||
|
|
@ -161,13 +301,14 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
|
|||
graph_input = graph_input_matcher(args)
|
||||
start_ts = time.time()
|
||||
res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)
|
||||
res = dumb_return_handler.addDumbReturn(args, res)
|
||||
if debug:
|
||||
print(
|
||||
f"torchxla reuse compiled graph run_cached_graph takes {time.time() - start_ts} seconds"
|
||||
)
|
||||
|
||||
args_inplace_update_ts = time.time()
|
||||
assert len(res) == len(args_and_out)
|
||||
assert len(res) == len(args_and_out), f"{len(res)} v.s. {len(args_and_out)}"
|
||||
ncopy = 0
|
||||
|
||||
for arg_index, res_index in arg_index_to_need_update_index.items():
|
||||
|
|
@ -184,6 +325,7 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args):
|
|||
print(f"optimized_mod takes {time.time() - enter_ts} seconds overall")
|
||||
|
||||
xm.mark_step()
|
||||
none_remover.add_nones(result)
|
||||
return result
|
||||
|
||||
return optimized_mod
|
||||
|
|
|
|||
|
|
@ -363,6 +363,14 @@ def cudagraphs(model, inputs):
|
|||
|
||||
aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs)
|
||||
|
||||
aot_torchxla_trivial = aot_autograd(
|
||||
fw_compiler=BACKENDS["torchxla_trivial"],
|
||||
)
|
||||
|
||||
aot_torchxla_trace_once = aot_autograd(
|
||||
fw_compiler=BACKENDS["torchxla_trace_once"],
|
||||
)
|
||||
|
||||
|
||||
def create_aot_backends():
|
||||
"""
|
||||
|
|
@ -399,3 +407,6 @@ def create_aot_backends():
|
|||
# aot_inductor_debug just replaces the inductor compiler with nop to help
|
||||
# isolate inductor vs aot_eager errors
|
||||
BACKENDS["aot_inductor_debug"] = aot_inductor_debug
|
||||
|
||||
BACKENDS["aot_torchxla_trivial"] = aot_torchxla_trivial
|
||||
BACKENDS["aot_torchxla_trace_once"] = aot_torchxla_trace_once
|
||||
|
|
|
|||
Loading…
Reference in a new issue