diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index 04d91c90d32..fb11ac61ad6 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -43,6 +43,11 @@ try:
 except ImportError:
     from microbenchmarks.operator_inp_utils import OperatorInputsMode
 
+try:
+    import torch_xla.core.xla_model as xm
+except ImportError:
+    # ignore the error if torch_xla is not installed
+    pass
 
 log = logging.getLogger(__name__)
 
@@ -285,30 +290,43 @@ def tensor_is_on_xla(tensors):
     return any(map(lambda x: x.device.type == "xla", tensors))
 
 
-def timed(model, model_iter_fn, example_inputs, times=1, return_result=False):
+def timed(
+    model,
+    model_iter_fn,
+    example_inputs,
+    times=1,
+    return_result=False,
+    collect_outputs=False,
+):
+    use_xla = tensor_is_on_xla(example_inputs)
     synchronize()
-    if tensor_is_on_xla(example_inputs):
-        import torch_xla.core.xla_model as xm
 
+    if use_xla:
         xm.mark_step()
+        xm.wait_device_ops()
 
-    reset_rng_state()
     t0 = time.perf_counter()
     # Dont collect outputs to correctly measure timing
     for _ in range(times):
-        result = model_iter_fn(model, example_inputs, collect_outputs=False)
-        if tensor_is_on_xla(result):
-            # If the model is on XLA device, it's possible that after running
-            # the model, the computation is accumulated but not performed yet.
-            # Flush all the accumulated computations to make the time measurement
-            # accurate.
-            import torch_xla
+        # Put this call inside the loop to reset the seed for each iteration.
+        reset_rng_state(use_xla)
+        result = model_iter_fn(model, example_inputs, collect_outputs=collect_outputs)
 
-            result_list = result
-            if not isinstance(result, (tuple, list)):
-                result_list = [result]
-            torch_xla._XLAC._xla_sync_multi(result_list, [])
-    synchronize()
+        # Instead of calling sync on result_list, we should call mark_step.
+        # In the training case, result_list may be empty, but we still want to
+        # send all the pending graphs for compilation.
+        if use_xla:
+            # For the model running on regular torchxla (the baseline), we need
+            # the mark_step to send the accumulated graph for compilation.
+            #
+            # For the model running with the dynamo/torchxla bridge, in the
+            # training case, we need the mark_step to send the optimizer graph
+            # out for compilation.
+            xm.mark_step()
+
+    if use_xla:
+        xm.wait_device_ops()
+    synchronize()
     t1 = time.perf_counter()
     return (t1 - t0, result) if return_result else t1 - t0
 
@@ -421,8 +439,6 @@ def randomize_input(inputs):
 
 def maybe_mark_step(args):
     if args.trace_on_xla:
-        import torch_xla.core.xla_model as xm
-
         xm.mark_step()
 
 
@@ -460,6 +476,12 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
         else:
             yield
 
+    times = args.iterations_per_run
+
+    # Use a higher tolerance for XLA since XLA causes numerical instability when
+    # the graph size changes.
+    tolerance = args.xla_tolerance if args.trace_on_xla else 1e-4
+
     with maybe_profile(enabled=args.export_profiler_trace) as p:
         frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
         for rep in range(args.repeat):
@@ -476,7 +498,12 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
             # interleave the runs to handle frequency scaling and load changes
             with maybe_mark_profile(p=p, mark="expected"):
                 timings[rep, 0], expected_output = timed(
-                    model, model_iter_fn, inputs, return_result=True
+                    model,
+                    model_iter_fn,
+                    inputs,
+                    return_result=True,
+                    times=times,
+                    collect_outputs=args.collect_outputs,
                 )
 
             # call mark_step between the 2 calls to make the comparison fair.
@@ -484,11 +511,18 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
 
             with maybe_mark_profile(p=p, mark="actual"):
                 timings[rep, 1], actual_output = timed(
-                    model, frozen_model_iter_fn, inputs, return_result=True
+                    model,
+                    frozen_model_iter_fn,
+                    inputs,
+                    return_result=True,
+                    times=times,
+                    collect_outputs=args.collect_outputs,
                 )
 
             if should_check_result:
-                is_correct = is_correct and same(expected_output, actual_output)
+                is_correct = is_correct and same(
+                    expected_output, actual_output, tol=tolerance
+                )
 
     if args.export_profiler_trace:
         name = args.profiler_trace_name + "_" + model.name + ".json"
@@ -848,10 +882,12 @@ def cast_to_fp32(model, inputs):
     return cast_to(torch.float32, model, inputs)
 
 
-def reset_rng_state():
+def reset_rng_state(use_xla=False):
     torch.manual_seed(1337)
     random.seed(1337)
     np.random.seed(1337)
+    if use_xla:
+        xm.set_rng_state(1337, str(xm.xla_device()))
 
 
 class DummyGradScaler:
@@ -1420,6 +1456,15 @@ def parse_args(args=None):
     parser.add_argument(
         "--repeat", "-n", type=int, default=30, help="number of timing runs"
    )
+    iterations_per_run_help = """
+        Run this many iterations for each time measurement. This is mainly used
+        for XLA training. We want to run multiple iterations per measurement so
+        that the tracing and computation for different iterations can overlap
+        with each other. This makes sure we have an accurate XLA baseline.
+    """
+    parser.add_argument(
+        "--iterations-per-run", type=int, default=1, help=iterations_per_run_help
+    )
     parser.add_argument(
         "--randomize-input",
         action="store_true",
@@ -1601,6 +1646,19 @@ def parse_args(args=None):
         action="store_true",
         help="Whether to trace the model on XLA or on eager device",
     )
+    parser.add_argument(
+        "--xla-tolerance",
+        type=float,
+        default=1e-2,
+        help="XLA needs a loose tolerance to pass the correctness check",
+    )
+    parser.add_argument(
+        "--collect-outputs",
+        action="store_true",
+        help="""Whether to collect outputs for training. Set this to true if we
+        want to verify the numerical correctness of gradients. But that may
+        make the time measurement less accurate""",
+    )
 
     group_fuser = parser.add_mutually_exclusive_group()
     # --nvfuser is now the default, keep the option to not break scripts
@@ -2063,8 +2121,6 @@ def run(runner, args, original_dir=None):
                 continue  # bad benchmark implementation
 
             if args.trace_on_xla:
-                import torch_xla.core.xla_model as xm
-
                 xla_dev = xm.xla_device()
                 model = model.to(device=xla_dev)
                 example_inputs = tree_map(
diff --git a/benchmarks/dynamo/torchbench.py b/benchmarks/dynamo/torchbench.py
index cbc29a535cc..a9b13571bbb 100755
--- a/benchmarks/dynamo/torchbench.py
+++ b/benchmarks/dynamo/torchbench.py
@@ -192,6 +192,7 @@ class TorchBenchmarkRunner(BenchmarkRunner):
     def __init__(self):
         super(TorchBenchmarkRunner, self).__init__()
         self.suite_name = "torchbench"
+        self.optimizer = None
 
     @property
     def skip_models(self):
@@ -297,6 +298,10 @@ class TorchBenchmarkRunner(BenchmarkRunner):
         # global current_name, current_device
         # current_device = device
         # current_name = benchmark.name
+
+        if self.args.trace_on_xla:
+            # workaround for: https://github.com/pytorch/xla/issues/4174
+            import torch_xla  # noqa: F401
 
         self.validate_model(model, example_inputs)
         return device, benchmark.name, model, example_inputs, batch_size
diff --git a/test/dynamo/test_torchxla_integration.py b/test/dynamo/test_torchxla_integration.py
index a1cd3da4718..831a5818c0b 100644
--- a/test/dynamo/test_torchxla_integration.py
+++ b/test/dynamo/test_torchxla_integration.py
@@ -5,6 +5,8 @@ import torch
 import torch._dynamo.test_case
 import torch._dynamo.testing
 
+from functorch.compile import aot_module_simplified, make_boxed_compiler
+from torch._dynamo import disable
 
 try:
     from .test_torchxla_util import maybe_skip_torchxla_test
@@ -54,7 +56,21 @@ class LinearModule(nn.Module):
         return self.linear(x)
 
     def get_random_inputs(self):
-        return (torch.randn(10),)
+        return (torch.randn(2, 10),)
+
+
+class MaxPoolModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Conv2d(3, 6, kernel_size=3, stride=2)
+        self.pool = nn.MaxPool2d(3, stride=2)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return self.pool(x)
+
+    def get_random_inputs(self):
+        return (torch.randn(2, 3, 10, 10),)
 
 
 class ModuleInplaceUpdate(nn.Module):
@@ -128,12 +144,73 @@ def make_reuse_graph_test(module_class, niter=100):
     return test_wrapper
 
 
+def training_compiler(gm, example_inputs):
+    @make_boxed_compiler
+    @disable
+    def fw_compiler(graph, inputs, *args, **kwargs):
+        # Tracing-time inputs are FakeTensors; we cannot pass them
+        # to extract_compiled_graph directly since we cannot extract
+        # xla tensor ids from fake tensors. Call extract_compiled_graph
+        # lazily and trigger it on the first call with non-fake tensors.
+        compiled_graph = None
+
+        def optimized_mod(*args):
+            nonlocal compiled_graph
+            if compiled_graph is None:
+                compiled_graph = integration.extract_compiled_graph(graph, args)
+            return compiled_graph(*args)
+
+        return optimized_mod
+
+    return aot_module_simplified(gm, example_inputs, fw_compiler=fw_compiler)
+
+
+def model_iter_fn_train(mod, inputs):
+    outputs = mod(*inputs)
+    loss = outputs.mean()
+    loss.backward()
+
+    param_list = list(mod.parameters())
+    return [param.grad for param in param_list]
+
+
+def make_training_test(model_cls):
+    @maybe_skip_torchxla_test
+    def test_wrapper(self):
+        import torch_xla.core.xla_model as xm
+
+        xla_dev = xm.xla_device()
+        model = model_cls()
+        inputs = model.get_random_inputs()
+
+        model = model.to(device=xla_dev)
+        inputs = tuple(inp.to(device=xla_dev) for inp in inputs)
+
+        # run the baseline
+        baseline_model = copy.deepcopy(model)
+        baseline_inputs = copy.deepcopy(inputs)
+        expected_output = model_iter_fn_train(baseline_model, baseline_inputs)
+
+        compiler = training_compiler
+        optimize_ctx = torch._dynamo.optimize(compiler, nopython=False)
+        optimized_model_iter_fn = optimize_ctx(model_iter_fn_train)
+
+        actual_output = optimized_model_iter_fn(model, inputs)
+        print(f"expected_output:\n{expected_output}\nactual_output:\n{actual_output}")
+        assert allclose(expected_output, actual_output)
+
+    return test_wrapper
+
+
 class TorchXLAReuseGraphTest(torch._dynamo.test_case.TestCase):
     test_basic = make_reuse_graph_test(BasicModule)
     test_matmul = make_reuse_graph_test(MatmulModule)
     test_linear = make_reuse_graph_test(LinearModule)
     test_inplace_update = make_reuse_graph_test(ModuleInplaceUpdate)
 
+    test_training_linear = make_training_test(LinearModule)
+    test_training_maxpool = make_training_test(MaxPoolModule)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/optimizations/torchxla_integration.py b/torch/_dynamo/optimizations/torchxla_integration.py
index f93e4d385ad..9db5351b70d 100644
--- a/torch/_dynamo/optimizations/torchxla_integration.py
+++ b/torch/_dynamo/optimizations/torchxla_integration.py
@@ -8,7 +8,7 @@ from typing import Any, Dict, List
 
 import torch
 
-debug = os.environ.get("debug_extract_compiled_graph") == "1"
+debug = os.environ.get("TORCH_XLA_DEBUG") == "1"
 
 
 @dataclasses.dataclass
@@ -39,7 +39,13 @@ class GraphInputMatcher:
             self.graph_input_tensor_ids, self.graph_input_xla_values
         ):
             arg_idx = self.tensor_id_to_arg_idx.get(tensor_id, None)
-            if arg_idx is None:
+            # Instead of using the trace-time base seed, use the runtime
+            # base seed here.
+            if tensor_id == torch_xla._XLAC._get_seed_info_id():
+                inp = torch_xla._XLAC._get_base_seed_as_tensor(
+                    str(traced_xla_value.device)
+                )
+            elif arg_idx is None:
                 inp = traced_xla_value
             else:
                 inp = args[arg_idx]
@@ -73,6 +79,125 @@ def import_torchxla():
     import torch_xla.debug.metrics as metrics
 
 
+class Deduper:
+    def __init__(self):
+        # map from origlist index to dedupedlist index
+        self.permute_for_orig = None
+
+    def dedup(self, origlist):
+        self.permute_for_orig = []
+        deduped_ids = dict()
+        deduped_list = []
+        for item in origlist:
+            item_id = id(item)
+            if item_id not in deduped_ids:
+                deduped_ids[item_id] = len(deduped_ids)
+                deduped_list.append(item)
+            self.permute_for_orig.append(deduped_ids[item_id])
+
+        return deduped_list
+
+    def recover(self, deduped_list):
+        assert len(self.permute_for_orig) >= len(deduped_list)
+        return [deduped_list[i] for i in self.permute_for_orig]
+
+
+class DumbReturnHandler:
+    """
+    A dumb return is an output that is also an input. Torch xla does not return
+    such tensors as graph outputs, which breaks the API contract with the caller
+    of the graph. AOTAutograd also generates such graphs quite often.
+
+    To avoid breaking the contract with the user of the GraphModule, we need to
+    add those outputs back manually.
+
+    Check https://github.com/pytorch/pytorch/pull/89536 for details.
+
+    AOTAutograd may also generate a graph with duplicated return items.
+    E.g. https://gist.github.com/shunting314/e60df8ac21fbe2494337c10d02bd78dc
+    (this is a graph generated for a model with a single BatchNorm2d).
+    XLA will dedup those duplicated items, but we need to recover the
+    duplications to maintain the contract with the caller.
+    """
+
+    def __init__(self, trace_inputs, trace_outputs, trace_inputs_inplace_update_bool):
+        self.trace_inputs = trace_inputs
+        self.trace_outputs = trace_outputs
+
+        # dedup the traced outputs first
+        self.deduper = Deduper()
+        self.deduped_trace_outputs = self.deduper.dedup(self.trace_outputs)
+
+        if debug:
+            print(
+                f"Number of duplicated outputs {len(self.trace_outputs) - len(self.deduped_trace_outputs)}"
+            )
+
+        # record the outputs that are also inputs
+        trace_inputs_id2pos = {id(x): pos for pos, x in enumerate(self.trace_inputs)}
+        self.trace_outputs_pos_to_inputs_pos = []
+        for out_pos, out in enumerate(self.deduped_trace_outputs):
+            in_pos = trace_inputs_id2pos.get(id(out), None)
+            if in_pos is not None and not trace_inputs_inplace_update_bool[in_pos]:
+                self.trace_outputs_pos_to_inputs_pos.append((out_pos, in_pos))
+
+        if debug:
+            print(
+                f"Number of trace inputs {len(trace_inputs)}, number of trace outputs {len(trace_outputs)}"
+            )
+            print(
+                f"Found {len(self.trace_outputs_pos_to_inputs_pos)} dumb returns: {self.trace_outputs_pos_to_inputs_pos}"
+            )
+
+    def addDumbReturn(self, real_inputs, real_outputs):
+        for out_pos, in_pos in self.trace_outputs_pos_to_inputs_pos:
+            assert in_pos < len(real_inputs)
+            # equal is fine since we can append an item at the end
+            assert out_pos <= len(real_outputs)
+
+            real_outputs.insert(out_pos, real_inputs[in_pos])
+
+        ret = self.deduper.recover(real_outputs)
+        return ret
+
+
+class NoneRemover:
+    """
+    torchxla pybind APIs that accept a Tensor list do not expect None values in
+    the list. But some graphs (e.g. the backward graph generated by aot autograd)
+    may return None values. We need to strip those None values before sending the
+    list to those torchxla APIs, and add the None values back after running the
+    compiled graph from torchxla.
+ """ + + def __init__(self): + self.none_poslist = [] + + def remove_nones(self, value_list): + """ + Remove none from value_list. value_list will be inplace updated. + The original position of None values are recorded. + """ + num = len(value_list) + + # work in reverse order + for i in reversed(range(num)): + if value_list[i] is None: + self.none_poslist.append(i) + del value_list[i] + + self.none_poslist.reverse() + + def add_nones(self, value_list): + """ + Add nones to value_list according to self.none_poslist. value_list + is inplace updated. + """ + for pos in self.none_poslist: + value_list.insert(pos, None) + + def is_xla_tensor(tensor: torch.Tensor) -> bool: return tensor.device.type == "xla" @@ -97,9 +222,15 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): ] if debug: + print(f"Graph module:\n{xla_model.code}") print(f"args_tensor_ids {args_tensor_ids}") tensor_id_to_arg_idx = {tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)} + + # get_fallback_ops below uses counters to detect torch_xla fallbacks. + # Clear the counters here so we ignore pre-existing fallbacks and + # only detect fallbacks happening when running the xla_model below. + metrics.clear_counters() xla_out = xla_model(*xla_args) fallback_ops = get_fallback_ops() @@ -111,6 +242,11 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): if not isinstance(xla_out, (tuple, list)): xla_out = (xla_out,) + none_remover = NoneRemover() + none_remover.remove_nones(xla_out) + + xla_out_ids = {id(x) for x in xla_out} + # If a arg is being in place updated by model, we need to include arg as part of the graph result. xla_args_need_update_bool = torch_xla._XLAC._check_tensor_need_materialization( xla_args @@ -118,17 +254,22 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): xla_args_need_update = [] arg_index_to_need_update_index = {} for i, need_update in enumerate(xla_args_need_update_bool): - if need_update: + # Don't add inplace updated argument to the list if it's already + # being returned + if need_update and id(xla_args[i]) not in xla_out_ids: arg_index_to_need_update_index[i] = len(xla_args_need_update) xla_args_need_update.append(xla_args[i]) args_and_out = tuple(xla_args_need_update) + tuple(xla_out) if debug: + print(f"#inplace update: {len(xla_args_need_update)}") print(f"XLA IR Text: {torch_xla._XLAC._get_xla_tensors_text(args_and_out)}") - print(f"XLA IR HLO: {torch_xla._XLAC._get_xla_tensors_hlo(args_and_out)}") # calculate graph hash + dumb_return_handler = DumbReturnHandler( + xla_args, args_and_out, xla_args_need_update_bool + ) graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out) if debug: print("graph_hash", graph_hash) @@ -150,7 +291,6 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): torch_xla._XLAC._xla_sync_multi(args_and_out, []) torch_xla._XLAC._clear_pending_irs(str(xm.xla_device())) - # input all cpu tensors def optimized_mod(*args): torch_xla._XLAC._xla_sync_multi(args, []) enter_ts = time.time() @@ -161,13 +301,14 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): graph_input = graph_input_matcher(args) start_ts = time.time() res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input) + res = dumb_return_handler.addDumbReturn(args, res) if debug: print( f"torchxla reuse compiled graph run_cached_graph takes {time.time() - start_ts} seconds" ) args_inplace_update_ts = time.time() - assert len(res) == len(args_and_out) + assert len(res) == len(args_and_out), 
f"{len(res)} v.s. {len(args_and_out)}" ncopy = 0 for arg_index, res_index in arg_index_to_need_update_index.items(): @@ -184,6 +325,7 @@ def extract_compiled_graph(xla_model: torch.fx.GraphModule, xla_args): print(f"optimized_mod takes {time.time() - enter_ts} seconds overall") xm.mark_step() + none_remover.add_nones(result) return result return optimized_mod diff --git a/torch/_dynamo/optimizations/training.py b/torch/_dynamo/optimizations/training.py index a1854037610..20cfb3759ad 100644 --- a/torch/_dynamo/optimizations/training.py +++ b/torch/_dynamo/optimizations/training.py @@ -363,6 +363,14 @@ def cudagraphs(model, inputs): aot_cudagraphs = aot_autograd(fw_compiler=cudagraphs, bw_compiler=cudagraphs) +aot_torchxla_trivial = aot_autograd( + fw_compiler=BACKENDS["torchxla_trivial"], +) + +aot_torchxla_trace_once = aot_autograd( + fw_compiler=BACKENDS["torchxla_trace_once"], +) + def create_aot_backends(): """ @@ -399,3 +407,6 @@ def create_aot_backends(): # aot_inductor_debug just replaces the inductor compiler with nop to help # isolate inductor vs aot_eager errors BACKENDS["aot_inductor_debug"] = aot_inductor_debug + + BACKENDS["aot_torchxla_trivial"] = aot_torchxla_trivial + BACKENDS["aot_torchxla_trace_once"] = aot_torchxla_trace_once