diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py index 75efd71f7f..58bb63fb97 100644 --- a/orttraining/orttraining/eager/opgen/opgen/atenops.py +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -159,7 +159,9 @@ hand_implemented = { "aten::eq.Scalar_out": Cast(Equal("self", "other"), to="GetONNXTensorProtoDataType(out.scalar_type())"), "aten::bitwise_and.Tensor_out": MakeTorchFallback(), "aten::masked_select": GatherND("self", Transpose(NonZero(Expand("mask", Shape("self"))))), - "aten::_local_scalar_dense": MakeTorchFallback(), + "aten::_local_scalar_dense": MakeTorchFallback(), # This function extracts a scalar value from + # a tensor with exactly one value; there's no need to try to do this on an ORT device. + # See CPU impl at pytorch/blob/master/aten/src/ATen/native/Scalar.cpp "aten::gt.Scalar_out": MakeTorchFallback(), "aten::lt.Scalar_out": MakeTorchFallback(), "aten::equal": SignatureOnly(),