Remove aten::binary_cross_entropy_with_logits from ATen Fallback (#12301)

This commit is contained in:
Vincent Wang 2022-07-26 07:29:56 +08:00 committed by GitHub
parent 3bf614fd47
commit c40f73ae0c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 0 additions and 78 deletions

View file

@ -203,7 +203,6 @@ class SymbolicShapeInference:
"argmax": self._infer_aten_argmax,
"avg_pool2d": self._infer_aten_pool2d,
"_adaptive_avg_pool2d": self._infer_aten_pool2d,
"binary_cross_entropy_with_logits": self._infer_aten_bce,
"numpy_T": self._infer_Transpose,
}
self.run_ = True
@ -1345,18 +1344,6 @@ class SymbolicShapeInference:
vi = self.known_vi_[node.output[0]]
vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape))
def _infer_aten_bce(self, node):
reduction = self._try_get_value(node, 4)
if reduction is None:
reduction = 1
elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
vi = self.known_vi_[node.output[0]]
if reduction == 0:
vi.type.tensor_type.elem_type = elem_type
vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
else:
vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, self._get_shape(node, 0)))
def _infer_BatchNormalization(self, node):
self._propagate_shape_and_type(node)

View file

@ -212,18 +212,6 @@ CustomGradientRegistry.register_custom_stop_gradient_edges([0], "org.pytorch.ate
CustomGradientRegistry.register_custom_stop_gradient_edges([0], "org.pytorch.aten", "ATen", "multinomial", "")
@register_gradient("org.pytorch.aten", "ATen", "binary_cross_entropy_with_logits", "")
def binary_cross_entropy_with_logits_gradient():
return [
(
("ATen", "org.pytorch.aten"),
["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "I(4)"],
["GI(0)"],
{"operator": {"value": "binary_cross_entropy_with_logits_backward", "dtype": "string"}},
),
]
@register_gradient("org.pytorch.aten", "ATen", "numpy_T", "")
def numpy_T_gradient():
return [

View file

@ -177,26 +177,6 @@ def adaptive_avg_pool2d(g, self, output_size):
return g.op("org.pytorch.aten::ATen", self, output_size, operator_s="_adaptive_avg_pool2d")
@register_symbolic("binary_cross_entropy_with_logits")
def binary_cross_entropy_with_logits(g, self, target, weight, pos_weight, reduction):
# If weight is not None, we need to check if it requires grad and add gradient graph accordingly.
# But current custom_gradient_registry doesn't support such None checking,
# So doesn't support non-None weight for now.
if weight is None or sym_help._is_none(weight):
return g.op(
"org.pytorch.aten::ATen",
self,
target,
weight,
pos_weight,
reduction,
operator_s="binary_cross_entropy_with_logits",
)
from torch.onnx.symbolic_opset12 import binary_cross_entropy_with_logits as bce
return bce(g, self, target, weight, pos_weight, reduction)
@register_symbolic("numpy_T")
def numpy_T(g, self):
# Numpy-style `a.T`: returns the tensor

View file

@ -1630,39 +1630,6 @@ def test_numpy_T(input_shape):
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
def test_gradient_correctness_bce_with_logits():
class NeuralNetBCEWithLogitsLoss(torch.nn.Module):
def __init__(self, input_size, hidden_size):
super(NeuralNetBCEWithLogitsLoss, self).__init__()
self.linear = torch.nn.Linear(input_size, hidden_size)
def forward(self, input, target):
loss_fct = torch.nn.BCEWithLogitsLoss()
return loss_fct(self.linear(input), target)
N, D, H = 16, 256, 128
device = "cuda"
pt_model = NeuralNetBCEWithLogitsLoss(D, H).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
def run_step(model, input, target):
prediction = model(input, target)
loss = prediction.sum()
loss.backward()
return prediction
for _ in range(10):
pt_input = torch.rand((N, D), device=device, requires_grad=True)
ort_input = copy.deepcopy(pt_input)
pt_target = torch.rand((N, H), device=device)
ort_target = copy.deepcopy(pt_target)
pt_prediction = run_step(pt_model, pt_input, pt_target)
ort_prediction = run_step(ort_model, ort_input, ort_target)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
def test_gradient_correctness_cast_chain():
class NeuralNetCast(torch.nn.Module):
def __init__(self, D):