From c40f73ae0cb986fd5e57c54764643adffbe2e37c Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Tue, 26 Jul 2022 07:29:56 +0800 Subject: [PATCH] Remove aten::binary_cross_entropy_with_logits from ATen Fallback (#12301) --- .../python/tools/symbolic_shape_infer.py | 13 -------- .../ortmodule/_custom_gradient_registry.py | 12 ------- .../ortmodule/_custom_op_symbolic_registry.py | 20 ----------- .../python/orttraining_test_ortmodule_api.py | 33 ------------------- 4 files changed, 78 deletions(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 4a553325f9..e5157f90ee 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -203,7 +203,6 @@ class SymbolicShapeInference: "argmax": self._infer_aten_argmax, "avg_pool2d": self._infer_aten_pool2d, "_adaptive_avg_pool2d": self._infer_aten_pool2d, - "binary_cross_entropy_with_logits": self._infer_aten_bce, "numpy_T": self._infer_Transpose, } self.run_ = True @@ -1345,18 +1344,6 @@ class SymbolicShapeInference: vi = self.known_vi_[node.output[0]] vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape)) - def _infer_aten_bce(self, node): - reduction = self._try_get_value(node, 4) - if reduction is None: - reduction = 1 - elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type - vi = self.known_vi_[node.output[0]] - if reduction == 0: - vi.type.tensor_type.elem_type = elem_type - vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) - else: - vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, self._get_shape(node, 0))) - def _infer_BatchNormalization(self, node): self._propagate_shape_and_type(node) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index c0d1e2cb40..42de356ae2 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -212,18 +212,6 @@ CustomGradientRegistry.register_custom_stop_gradient_edges([0], "org.pytorch.ate CustomGradientRegistry.register_custom_stop_gradient_edges([0], "org.pytorch.aten", "ATen", "multinomial", "") -@register_gradient("org.pytorch.aten", "ATen", "binary_cross_entropy_with_logits", "") -def binary_cross_entropy_with_logits_gradient(): - return [ - ( - ("ATen", "org.pytorch.aten"), - ["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "I(4)"], - ["GI(0)"], - {"operator": {"value": "binary_cross_entropy_with_logits_backward", "dtype": "string"}}, - ), - ] - - @register_gradient("org.pytorch.aten", "ATen", "numpy_T", "") def numpy_T_gradient(): return [ diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 0b1c4ec151..634922379f 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -177,26 +177,6 @@ def adaptive_avg_pool2d(g, self, output_size): return g.op("org.pytorch.aten::ATen", self, output_size, operator_s="_adaptive_avg_pool2d") -@register_symbolic("binary_cross_entropy_with_logits") -def binary_cross_entropy_with_logits(g, self, target, weight, pos_weight, reduction): - # If weight is not None, we need to check if it requires grad and add gradient graph accordingly. - # But current custom_gradient_registry doesn't support such None checking, - # So doesn't support non-None weight for now. - if weight is None or sym_help._is_none(weight): - return g.op( - "org.pytorch.aten::ATen", - self, - target, - weight, - pos_weight, - reduction, - operator_s="binary_cross_entropy_with_logits", - ) - from torch.onnx.symbolic_opset12 import binary_cross_entropy_with_logits as bce - - return bce(g, self, target, weight, pos_weight, reduction) - - @register_symbolic("numpy_T") def numpy_T(g, self): # Numpy-style `a.T`: returns the tensor diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 26cf8f6df6..32543b7226 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1630,39 +1630,6 @@ def test_numpy_T(input_shape): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) -def test_gradient_correctness_bce_with_logits(): - class NeuralNetBCEWithLogitsLoss(torch.nn.Module): - def __init__(self, input_size, hidden_size): - super(NeuralNetBCEWithLogitsLoss, self).__init__() - self.linear = torch.nn.Linear(input_size, hidden_size) - - def forward(self, input, target): - loss_fct = torch.nn.BCEWithLogitsLoss() - return loss_fct(self.linear(input), target) - - N, D, H = 16, 256, 128 - device = "cuda" - pt_model = NeuralNetBCEWithLogitsLoss(D, H).to(device) - ort_model = ORTModule(copy.deepcopy(pt_model)) - - def run_step(model, input, target): - prediction = model(input, target) - loss = prediction.sum() - loss.backward() - return prediction - - for _ in range(10): - pt_input = torch.rand((N, D), device=device, requires_grad=True) - ort_input = copy.deepcopy(pt_input) - pt_target = torch.rand((N, H), device=device) - ort_target = copy.deepcopy(pt_target) - pt_prediction = run_step(pt_model, pt_input, pt_target) - ort_prediction = run_step(ort_model, ort_input, ort_target) - - _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) - _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) - - def test_gradient_correctness_cast_chain(): class NeuralNetCast(torch.nn.Module): def __init__(self, D):