From 3cdc6d7775d2df2b836c5d39c2582daa2151b324 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 20 Jul 2022 17:00:41 +0800 Subject: [PATCH] [ORTModule] Bugfix of torch.chunk's Custom Symbolic when chunks==1 (#12249) handle custom chunk with chunks==1 --- .../ortmodule/_custom_op_symbolic_registry.py | 2 + .../python/orttraining_test_ortmodule_api.py | 61 ++++++++++--------- 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index a3192b5404..0b1c4ec151 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -228,6 +228,8 @@ def squeeze(g, self, dim=None): # exporting to Split with SplitGrad as gradient graph. @register_symbolic("ConstantChunk", "prim") def prim_ConstantChunk(g, self, chunks, dim): + if chunks == 1: + return self input_shape_dim = g.op( "Gather", g.op("Shape", self), g.op("Constant", value_t=torch.tensor([dim], dtype=torch.long)), axis_i=0 ) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 01a47459e3..26cf8f6df6 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1326,51 +1326,56 @@ def test_gradient_correctness_reducesum(dim, keepdim): # Before PyTorch 1.11.0, the exporter will fail to register symbolic with non-empty domain. @pytest.mark.skipif(Version(torch.__version__) < Version("1.11.0"), reason="PyTorch 1.10 incompatible") @pytest.mark.parametrize("dim", [0, 1, -1]) -def test_gradient_correctness_chunk(dim): +@pytest.mark.parametrize("chunks", [1, 3]) +def test_gradient_correctness_chunk(dim, chunks): class NeuralNetChunk(torch.nn.Module): def __init__(self, dim): super(NeuralNetChunk, self).__init__() self.dim = dim def forward(self, input): - return input.chunk(3, dim=self.dim) + return input.chunk(chunks, dim=self.dim) device = "cuda" pt_model = NeuralNetChunk(dim).to(device) - ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="chunk_model")) + ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=(chunks > 1), onnx_prefix="chunk_model")) def run_step(model, input): - y1, y2, y3 = model(input) - loss = y1.sum() + y2.sum() + y3.sum() + results = model(input) + loss = results[0].sum() + for i in range(1, len(results)): + loss = loss + results[i].sum() loss.backward() - return y1, y2, y3 + return results N, D, H = 16, 17, 18 for _ in range(10): - input = torch.rand((N, D, H), device=device, requires_grad=True) - pt_y1, pt_y2, pt_y3 = run_step(pt_model, input) - ort_y1, ort_y2, ort_y3 = run_step(ort_model, input) + pt_input = torch.rand((N, D, H), device=device, requires_grad=True) + ort_input = copy.deepcopy(pt_input) + pt_results = run_step(pt_model, pt_input) + ort_results = run_step(ort_model, ort_input) - _test_helpers.assert_values_are_close(ort_y1, pt_y1) - _test_helpers.assert_values_are_close(ort_y2, pt_y2) - _test_helpers.assert_values_are_close(ort_y3, pt_y3) - _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + assert len(ort_results) == len(pt_results) + for i in range(len(ort_results)): + _test_helpers.assert_values_are_close(ort_results[i], pt_results[i]) + _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) - assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx")) - assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx")) - assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx")) - assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx")) - model = onnx.load(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx")) - has_split = False - for node in model.graph.node: - if node.op_type == "Split": - has_split = True - break - assert has_split - os.remove(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx")) - os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx")) - os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx")) - os.remove(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx")) + if chunks > 1: + assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx")) + assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx")) + assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx")) + assert os.path.exists(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx")) + model = onnx.load(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx")) + has_split = False + for node in model.graph.node: + if node.op_type == "Split": + has_split = True + break + assert has_split + os.remove(os.path.join(os.getcwd(), "chunk_model_torch_exported_training.onnx")) + os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_training.onnx")) + os.remove(os.path.join(os.getcwd(), "chunk_model_optimized_pre_grad_training.onnx")) + os.remove(os.path.join(os.getcwd(), "chunk_model_execution_model_training.onnx")) # In PyTorch 1.11.0, there is issue during reduce node shape handling for exporter, so any sub-graph that