mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Remove aten::binary_cross_entropy_with_logits from ATen Fallback (#12301)
This commit is contained in:
parent
3bf614fd47
commit
c40f73ae0c
4 changed files with 0 additions and 78 deletions
|
|
@ -203,7 +203,6 @@ class SymbolicShapeInference:
|
|||
"argmax": self._infer_aten_argmax,
|
||||
"avg_pool2d": self._infer_aten_pool2d,
|
||||
"_adaptive_avg_pool2d": self._infer_aten_pool2d,
|
||||
"binary_cross_entropy_with_logits": self._infer_aten_bce,
|
||||
"numpy_T": self._infer_Transpose,
|
||||
}
|
||||
self.run_ = True
|
||||
|
|
@ -1345,18 +1344,6 @@ class SymbolicShapeInference:
|
|||
vi = self.known_vi_[node.output[0]]
|
||||
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape))
|
||||
|
||||
def _infer_aten_bce(self, node):
    """Infer the output type and shape of an ATen ``binary_cross_entropy_with_logits`` node.

    The reduction mode is read from input 4 of the node. When it cannot be
    resolved to a value it falls back to 1. The integer follows PyTorch's
    reduction enum (0 = none, 1 = mean, 2 = sum).

    The output element type is always copied from input 0. The output shape is
    the shape of input 0 when no reduction is applied, and a scalar otherwise.
    """
    reduction = self._try_get_value(node, 4)
    if reduction is None:
        reduction = 1  # fall back to "mean"

    elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi = self.known_vi_[node.output[0]]

    if reduction == 0:
        # "none": the loss is element-wise, so the output keeps the shape of the logits.
        vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, self._get_shape(node, 0)))
    else:
        # "mean" / "sum": the loss collapses to a scalar, i.e. a shape with zero dimensions.
        vi.type.tensor_type.elem_type = elem_type
        vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
|
||||
|
||||
def _infer_BatchNormalization(self, node):
|
||||
self._propagate_shape_and_type(node)
|
||||
|
||||
|
|
|
|||
|
|
@ -212,18 +212,6 @@ CustomGradientRegistry.register_custom_stop_gradient_edges([0], "org.pytorch.ate
|
|||
CustomGradientRegistry.register_custom_stop_gradient_edges([0], "org.pytorch.aten", "ATen", "multinomial", "")
|
||||
|
||||
|
||||
@register_gradient("org.pytorch.aten", "ATen", "binary_cross_entropy_with_logits", "")
def binary_cross_entropy_with_logits_gradient():
    """Return the gradient definition for ATen ``binary_cross_entropy_with_logits``.

    The definition is a one-element list holding a single node description:
    an ``ATen`` node in the ``org.pytorch.aten`` domain whose ``operator``
    attribute is ``binary_cross_entropy_with_logits_backward``.
    """
    # NOTE(review): GO(i) / I(i) / GI(i) presumably denote the gradient of output i,
    # forward input i and the gradient of input i -- confirm against the registry docs.
    op_type_and_domain = ("ATen", "org.pytorch.aten")
    node_inputs = ["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "I(4)"]
    node_outputs = ["GI(0)"]
    node_attributes = {
        "operator": {"value": "binary_cross_entropy_with_logits_backward", "dtype": "string"},
    }
    return [(op_type_and_domain, node_inputs, node_outputs, node_attributes)]
|
||||
|
||||
|
||||
@register_gradient("org.pytorch.aten", "ATen", "numpy_T", "")
|
||||
def numpy_T_gradient():
|
||||
return [
|
||||
|
|
|
|||
|
|
@ -177,26 +177,6 @@ def adaptive_avg_pool2d(g, self, output_size):
|
|||
return g.op("org.pytorch.aten::ATen", self, output_size, operator_s="_adaptive_avg_pool2d")
|
||||
|
||||
|
||||
@register_symbolic("binary_cross_entropy_with_logits")
def binary_cross_entropy_with_logits(g, self, target, weight, pos_weight, reduction):
    """Export ``binary_cross_entropy_with_logits``, using the ATen fallback only without a weight.

    A non-None weight would require checking whether it needs a gradient and
    extending the gradient graph accordingly. The custom gradient registry
    cannot do that None check today, so a weighted loss is delegated to the
    stock opset-12 symbolic from PyTorch instead.
    """
    has_weight = weight is not None and not sym_help._is_none(weight)
    if has_weight:
        from torch.onnx.symbolic_opset12 import binary_cross_entropy_with_logits as bce

        return bce(g, self, target, weight, pos_weight, reduction)

    # No weight: emit a single ATen node that forwards all five inputs unchanged.
    aten_inputs = (self, target, weight, pos_weight, reduction)
    return g.op("org.pytorch.aten::ATen", *aten_inputs, operator_s="binary_cross_entropy_with_logits")
|
||||
|
||||
|
||||
@register_symbolic("numpy_T")
|
||||
def numpy_T(g, self):
|
||||
# Numpy-style `a.T`: returns the tensor
|
||||
|
|
|
|||
|
|
@ -1630,39 +1630,6 @@ def test_numpy_T(input_shape):
|
|||
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
|
||||
|
||||
|
||||
def test_gradient_correctness_bce_with_logits():
    """Compare ORTModule with PyTorch on a linear layer followed by BCEWithLogitsLoss.

    Ten rounds of fresh random data are run through both models; the forward
    result and the gradient with respect to the input must be close each time.
    """

    class NeuralNetBCEWithLogitsLoss(torch.nn.Module):
        """A single linear layer whose output is scored with BCEWithLogitsLoss."""

        def __init__(self, input_size, hidden_size):
            super().__init__()
            self.linear = torch.nn.Linear(input_size, hidden_size)

        def forward(self, input, target):
            criterion = torch.nn.BCEWithLogitsLoss()
            logits = self.linear(input)
            return criterion(logits, target)

    batch_size, input_dim, hidden_dim = 16, 256, 128
    device = "cuda"
    pt_model = NeuralNetBCEWithLogitsLoss(input_dim, hidden_dim).to(device)
    ort_model = ORTModule(copy.deepcopy(pt_model))

    def run_step(model, input, target):
        # Forward pass, then backpropagate the summed output so the input receives a gradient.
        prediction = model(input, target)
        prediction.sum().backward()
        return prediction

    for _ in range(10):
        # Draw the input first and the target second; each ORT copy is taken right after its draw.
        pt_input = torch.rand((batch_size, input_dim), device=device, requires_grad=True)
        ort_input = copy.deepcopy(pt_input)
        pt_target = torch.rand((batch_size, hidden_dim), device=device)
        ort_target = copy.deepcopy(pt_target)

        pt_prediction = run_step(pt_model, pt_input, pt_target)
        ort_prediction = run_step(ort_model, ort_input, ort_target)

        _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
        _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
|
||||
|
||||
|
||||
def test_gradient_correctness_cast_chain():
|
||||
class NeuralNetCast(torch.nn.Module):
|
||||
def __init__(self, D):
|
||||
|
|
|
|||
Loading…
Reference in a new issue