mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
ATenOp Support for BCEWithLogitsLoss (#9670)
This commit is contained in:
parent
1b70a14c51
commit
adf98feb2c
4 changed files with 65 additions and 1 deletions
|
|
@ -196,6 +196,7 @@ class SymbolicShapeInference:
|
|||
'aten::argmax': self._infer_aten_argmax,
|
||||
'aten::avg_pool2d': self._infer_aten_pool2d,
|
||||
'aten::_adaptive_avg_pool2d': self._infer_aten_pool2d,
|
||||
'aten::binary_cross_entropy_with_logits': self._infer_aten_bce,
|
||||
}
|
||||
self.run_ = True
|
||||
self.suggested_merge_ = {}
|
||||
|
|
@ -1174,6 +1175,18 @@ class SymbolicShapeInference:
|
|||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape))
|
||||
|
||||
def _infer_aten_bce(self, node):
|
||||
reduction = self._try_get_value(node, 4)
|
||||
if reduction is None:
|
||||
reduction = 1
|
||||
elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
|
||||
vi = self.known_vi_[node.output[0]]
|
||||
if reduction == 0:
|
||||
vi.type.tensor_type.elem_type = elem_type
|
||||
vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
|
||||
else:
|
||||
vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, self._get_shape(node, 0)))
|
||||
|
||||
def _infer_BatchNormalization(self, node):
|
||||
self._propagate_shape_and_type(node)
|
||||
|
||||
|
|
|
|||
|
|
@ -133,4 +133,11 @@ def adaptive_avg_pool2d_gradient():
|
|||
'GI(0)'], {'name': {'value': 'aten::_adaptive_avg_pool2d_backward', 'dtype': 'string'}}),
|
||||
]
|
||||
|
||||
CustomGradientRegistry.register_custom_stop_gradient_edges([0], 'com.microsoft', 'ATenOp', 'aten::multinomial', '')
|
||||
CustomGradientRegistry.register_custom_stop_gradient_edges([0], 'com.microsoft', 'ATenOp', 'aten::multinomial', '')
|
||||
|
||||
@register_gradient('com.microsoft', 'ATenOp', 'aten::binary_cross_entropy_with_logits', '')
|
||||
def binary_cross_entropy_with_logits_gradient():
|
||||
return [
|
||||
(('ATenOp', 'com.microsoft'), ['GO(0)', 'I(0)', 'I(1)', 'I(2)', 'I(3)', 'I(4)'], [
|
||||
'GI(0)'], {'name': {'value': 'aten::binary_cross_entropy_with_logits_backward', 'dtype': 'string'}}),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -116,3 +116,15 @@ def avg_pool2d(g, self, kernel_size, stride, padding, ceil_mode, count_include_p
|
|||
@register_symbolic('adaptive_avg_pool2d')
|
||||
def adaptive_avg_pool2d(g, self, output_size):
|
||||
return g.op("com.microsoft::ATenOp", self, output_size, name_s='aten::_adaptive_avg_pool2d')
|
||||
|
||||
|
||||
@register_symbolic('binary_cross_entropy_with_logits')
|
||||
def binary_cross_entropy_with_logits(g, self, target, weight, pos_weight, reduction):
|
||||
# If weight is not None, we need to check if it requires grad and add gradient graph accordingly.
|
||||
# But current custom_gradient_registry doesn't support such None checking,
|
||||
# So doesn't support non-None weight for now.
|
||||
if weight is None or sym_help._is_none(weight):
|
||||
return g.op("com.microsoft::ATenOp", self, target, weight, pos_weight, reduction,
|
||||
name_s='aten::binary_cross_entropy_with_logits')
|
||||
from torch.onnx.symbolic_opset12 import binary_cross_entropy_with_logits as bce
|
||||
return bce(g, self, target, weight, pos_weight, reduction)
|
||||
|
|
|
|||
|
|
@ -1110,6 +1110,38 @@ def test_aten_multinomial(input_shape, num_samples, replacement):
|
|||
|
||||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
|
||||
def test_gradient_correctness_bce_with_logits():
|
||||
class NeuralNetBCEWithLogitsLoss(torch.nn.Module):
|
||||
def __init__(self, input_size, hidden_size):
|
||||
super(NeuralNetBCEWithLogitsLoss, self).__init__()
|
||||
self.linear = torch.nn.Linear(input_size, hidden_size)
|
||||
|
||||
def forward(self, input, target):
|
||||
loss_fct = torch.nn.BCEWithLogitsLoss()
|
||||
return loss_fct(self.linear(input), target)
|
||||
|
||||
N, D, H = 16, 256, 128
|
||||
device = 'cuda'
|
||||
pt_model = NeuralNetBCEWithLogitsLoss(D, H).to(device)
|
||||
ort_model = ORTModule(copy.deepcopy(pt_model))
|
||||
|
||||
def run_step(model, input, target):
|
||||
prediction = model(input, target)
|
||||
loss = prediction.sum()
|
||||
loss.backward()
|
||||
return prediction
|
||||
|
||||
for _ in range(10):
|
||||
pt_input = torch.rand((N, D), device=device, requires_grad=True)
|
||||
ort_input = copy.deepcopy(pt_input)
|
||||
pt_target = torch.rand((N, H), device=device)
|
||||
ort_target = copy.deepcopy(pt_target)
|
||||
pt_prediction = run_step(pt_model, pt_input, pt_target)
|
||||
ort_prediction = run_step(ort_model, ort_input, ort_target)
|
||||
|
||||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
|
||||
def test_module_with_non_differential_output():
|
||||
device = 'cuda'
|
||||
N, D_in, H, D_out = 32, 128, 64, 10
|
||||
|
|
|
|||
Loading…
Reference in a new issue