ATenOp Support for BCEWithLogitsLoss (#9670)

This commit is contained in:
Vincent Wang 2021-11-10 08:36:57 +08:00 committed by GitHub
parent 1b70a14c51
commit adf98feb2c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 65 additions and 1 deletions

View file

@ -196,6 +196,7 @@ class SymbolicShapeInference:
'aten::argmax': self._infer_aten_argmax,
'aten::avg_pool2d': self._infer_aten_pool2d,
'aten::_adaptive_avg_pool2d': self._infer_aten_pool2d,
'aten::binary_cross_entropy_with_logits': self._infer_aten_bce,
}
self.run_ = True
self.suggested_merge_ = {}
@ -1174,6 +1175,18 @@ class SymbolicShapeInference:
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape))
def _infer_aten_bce(self, node):
reduction = self._try_get_value(node, 4)
if reduction is None:
reduction = 1
elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
if reduction == 0:
vi.type.tensor_type.elem_type = elem_type
vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
else:
vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, self._get_shape(node, 0)))
def _infer_BatchNormalization(self, node):
self._propagate_shape_and_type(node)

View file

@ -133,4 +133,11 @@ def adaptive_avg_pool2d_gradient():
'GI(0)'], {'name': {'value': 'aten::_adaptive_avg_pool2d_backward', 'dtype': 'string'}}),
]
CustomGradientRegistry.register_custom_stop_gradient_edges([0], 'com.microsoft', 'ATenOp', 'aten::multinomial', '')
CustomGradientRegistry.register_custom_stop_gradient_edges([0], 'com.microsoft', 'ATenOp', 'aten::multinomial', '')
@register_gradient('com.microsoft', 'ATenOp', 'aten::binary_cross_entropy_with_logits', '')
def binary_cross_entropy_with_logits_gradient():
return [
(('ATenOp', 'com.microsoft'), ['GO(0)', 'I(0)', 'I(1)', 'I(2)', 'I(3)', 'I(4)'], [
'GI(0)'], {'name': {'value': 'aten::binary_cross_entropy_with_logits_backward', 'dtype': 'string'}}),
]

View file

@ -116,3 +116,15 @@ def avg_pool2d(g, self, kernel_size, stride, padding, ceil_mode, count_include_p
@register_symbolic('adaptive_avg_pool2d')
def adaptive_avg_pool2d(g, self, output_size):
return g.op("com.microsoft::ATenOp", self, output_size, name_s='aten::_adaptive_avg_pool2d')
@register_symbolic('binary_cross_entropy_with_logits')
def binary_cross_entropy_with_logits(g, self, target, weight, pos_weight, reduction):
# If weight is not None, we need to check if it requires grad and add gradient graph accordingly.
# But current custom_gradient_registry doesn't support such None checking,
# So doesn't support non-None weight for now.
if weight is None or sym_help._is_none(weight):
return g.op("com.microsoft::ATenOp", self, target, weight, pos_weight, reduction,
name_s='aten::binary_cross_entropy_with_logits')
from torch.onnx.symbolic_opset12 import binary_cross_entropy_with_logits as bce
return bce(g, self, target, weight, pos_weight, reduction)

View file

@ -1110,6 +1110,38 @@ def test_aten_multinomial(input_shape, num_samples, replacement):
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
def test_gradient_correctness_bce_with_logits():
class NeuralNetBCEWithLogitsLoss(torch.nn.Module):
def __init__(self, input_size, hidden_size):
super(NeuralNetBCEWithLogitsLoss, self).__init__()
self.linear = torch.nn.Linear(input_size, hidden_size)
def forward(self, input, target):
loss_fct = torch.nn.BCEWithLogitsLoss()
return loss_fct(self.linear(input), target)
N, D, H = 16, 256, 128
device = 'cuda'
pt_model = NeuralNetBCEWithLogitsLoss(D, H).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
def run_step(model, input, target):
prediction = model(input, target)
loss = prediction.sum()
loss.backward()
return prediction
for _ in range(10):
pt_input = torch.rand((N, D), device=device, requires_grad=True)
ort_input = copy.deepcopy(pt_input)
pt_target = torch.rand((N, H), device=device)
ort_target = copy.deepcopy(pt_target)
pt_prediction = run_step(pt_model, pt_input, pt_target)
ort_prediction = run_step(ort_model, ort_input, ort_target)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
def test_module_with_non_differential_output():
device = 'cuda'
N, D_in, H, D_out = 32, 128, 64, 10