From adf98feb2c24eb266fa51f88939b192f6ed30028 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 10 Nov 2021 08:36:57 +0800 Subject: [PATCH] ATenOp Support for BCEWithLogitsLoss (#9670) --- .../python/tools/symbolic_shape_infer.py | 13 ++++++++ .../ortmodule/_custom_gradient_registry.py | 9 +++++- .../ortmodule/_custom_op_symbolic_registry.py | 12 +++++++ .../python/orttraining_test_ortmodule_api.py | 32 +++++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index a016fb3d32..0d80efbc64 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -196,6 +196,7 @@ class SymbolicShapeInference: 'aten::argmax': self._infer_aten_argmax, 'aten::avg_pool2d': self._infer_aten_pool2d, 'aten::_adaptive_avg_pool2d': self._infer_aten_pool2d, + 'aten::binary_cross_entropy_with_logits': self._infer_aten_bce, } self.run_ = True self.suggested_merge_ = {} @@ -1174,6 +1175,18 @@ class SymbolicShapeInference: vi = self.known_vi_[node.output[0]] vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape)) + def _infer_aten_bce(self, node): + reduction = self._try_get_value(node, 4) + if reduction is None: + reduction = 1 + elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type + vi = self.known_vi_[node.output[0]] + if reduction == 0: + vi.type.tensor_type.elem_type = elem_type + vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto()) + else: + vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, self._get_shape(node, 0))) + def _infer_BatchNormalization(self, node): self._propagate_shape_and_type(node) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 251bbef8bd..896a3464a6 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -133,4 +133,11 @@ def adaptive_avg_pool2d_gradient(): 'GI(0)'], {'name': {'value': 'aten::_adaptive_avg_pool2d_backward', 'dtype': 'string'}}), ] -CustomGradientRegistry.register_custom_stop_gradient_edges([0], 'com.microsoft', 'ATenOp', 'aten::multinomial', '') \ No newline at end of file +CustomGradientRegistry.register_custom_stop_gradient_edges([0], 'com.microsoft', 'ATenOp', 'aten::multinomial', '') + +@register_gradient('com.microsoft', 'ATenOp', 'aten::binary_cross_entropy_with_logits', '') +def binary_cross_entropy_with_logits_gradient(): + return [ + (('ATenOp', 'com.microsoft'), ['GO(0)', 'I(0)', 'I(1)', 'I(2)', 'I(3)', 'I(4)'], [ + 'GI(0)'], {'name': {'value': 'aten::binary_cross_entropy_with_logits_backward', 'dtype': 'string'}}), + ] diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index b4e54de19f..54fb28d021 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -116,3 +116,15 @@ def avg_pool2d(g, self, kernel_size, stride, padding, ceil_mode, count_include_p @register_symbolic('adaptive_avg_pool2d') def adaptive_avg_pool2d(g, self, output_size): return g.op("com.microsoft::ATenOp", self, output_size, name_s='aten::_adaptive_avg_pool2d') + + +@register_symbolic('binary_cross_entropy_with_logits') +def binary_cross_entropy_with_logits(g, self, target, weight, pos_weight, reduction): + # If weight is not None, we need to check if it requires grad and add gradient graph accordingly. + # But current custom_gradient_registry doesn't support such None checking, + # So doesn't support non-None weight for now. + if weight is None or sym_help._is_none(weight): + return g.op("com.microsoft::ATenOp", self, target, weight, pos_weight, reduction, + name_s='aten::binary_cross_entropy_with_logits') + from torch.onnx.symbolic_opset12 import binary_cross_entropy_with_logits as bce + return bce(g, self, target, weight, pos_weight, reduction) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index da7e849eae..312be267b4 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1110,6 +1110,38 @@ def test_aten_multinomial(input_shape, num_samples, replacement): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) +def test_gradient_correctness_bce_with_logits(): + class NeuralNetBCEWithLogitsLoss(torch.nn.Module): + def __init__(self, input_size, hidden_size): + super(NeuralNetBCEWithLogitsLoss, self).__init__() + self.linear = torch.nn.Linear(input_size, hidden_size) + + def forward(self, input, target): + loss_fct = torch.nn.BCEWithLogitsLoss() + return loss_fct(self.linear(input), target) + + N, D, H = 16, 256, 128 + device = 'cuda' + pt_model = NeuralNetBCEWithLogitsLoss(D, H).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, input, target): + prediction = model(input, target) + loss = prediction.sum() + loss.backward() + return prediction + + for _ in range(10): + pt_input = torch.rand((N, D), device=device, requires_grad=True) + ort_input = copy.deepcopy(pt_input) + pt_target = torch.rand((N, H), device=device) + ort_target = copy.deepcopy(pt_target) + pt_prediction = run_step(pt_model, pt_input, pt_target) + ort_prediction = run_step(ort_model, ort_input, ort_target) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + def test_module_with_non_differential_output(): device = 'cuda' N, D_in, H, D_out = 32, 128, 64, 10