[1/N] [Caffe2] Remove caffe2_aten_fallback code (#128675)

Fixes #ISSUE_NUMBER
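This change is part of removing Caffe2 from PyTorch. It deletes the Caffe2-specific ATen-fallback path from the ONNX exporter along with the tests that exercised it: `_CAFFE2_ATEN_FALLBACK` disappears from `torch._C._onnx`, `symbolic_helper.is_caffe2_aten_fallback()` becomes a stub that returns `False`, and an unspecified `operator_export_type` now always resolves to `OperatorExportTypes.ONNX`. The regular (non-Caffe2) ATen fallback is unchanged; a sketch adapted from the surviving `test_aten_fallback_must_fallback` test in this diff:

    import io
    import torch

    class ModelWithAtenNotONNXOp(torch.nn.Module):
        def forward(self, x, y):
            # linalg.qr has no ONNX symbolic at opset 9, so the exporter
            # emits an ATen fallback node for it under ONNX_ATEN_FALLBACK.
            return torch.linalg.qr(x + y)

    f = io.BytesIO()
    torch.onnx.export(
        ModelWithAtenNotONNXOp(),
        (torch.rand(3, 4), torch.rand(3, 4)),
        f,
        do_constant_folding=False,
        operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
        opset_version=9,
    )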

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128675
Approved by: https://github.com/r-barnes
cyy authored on 2024-06-17 21:25:55 +00:00; committed by PyTorch MergeBot
parent 8953725e6d
commit 163847b1bb
19 changed files with 33 additions and 529 deletions

View file

@@ -86,26 +86,6 @@ class TestExportModes(pytorch_test_common.ExportTestCase):
x = torch.ones(3)
torch.onnx.export(foo, (x,), f)
@common_utils.skipIfNoCaffe2
@common_utils.skipIfNoLapack
def test_caffe2_aten_fallback(self):
class ModelWithAtenNotONNXOp(nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg
x = torch.rand(3, 4)
y = torch.rand(3, 4)
torch.onnx.export_to_pretty_string(
ModelWithAtenNotONNXOp(),
(x, y),
add_node_names=False,
do_constant_folding=False,
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
)
@common_utils.skipIfCaffe2
@common_utils.skipIfNoLapack
def test_aten_fallback(self):
class ModelWithAtenNotONNXOp(nn.Module):

View file

@@ -39,7 +39,7 @@ from torch.onnx.symbolic_helper import (
parse_args,
)
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfCaffe2, skipIfNoLapack
from torch.testing._internal.common_utils import skipIfNoLapack
unittest.TestCase.maxDiff = None
@@ -414,7 +414,6 @@ class TestOperators(common_utils.TestCase):
x = torch.randn(20, 16, 50)
self.assertONNX(nn.MaxPool1d(3, stride=2, return_indices=True), x)
@skipIfCaffe2
def test_at_op(self):
x = torch.randn(3, 4)
@@ -694,7 +693,6 @@ class TestOperators(common_utils.TestCase):
keep_initializers_as_inputs=True,
)
@skipIfCaffe2
def test_embedding_bags(self):
emb_bag = nn.EmbeddingBag(10, 8)
input = torch.tensor([1, 2, 3, 4]).long()
@@ -949,7 +947,6 @@ class TestOperators(common_utils.TestCase):
other = torch.randint(-50, 50, (2, 3, 4), dtype=torch.int8)
self.assertONNX(BiwiseAndModel(), (input, other), opset_version=18)
@skipIfCaffe2
def test_layer_norm_aten(self):
model = torch.nn.LayerNorm([10, 10])
x = torch.randn(20, 5, 10, 10)
@@ -1203,7 +1200,6 @@ class TestOperators(common_utils.TestCase):
torch.onnx.unregister_custom_op_symbolic("::embedding", _onnx_opset_version)
# This is test_aten_embedding_1 with shape inference on custom symbolic aten::embedding.
@skipIfCaffe2
def test_aten_embedding_2(self):
_onnx_opset_version = 12

View file

@@ -20,7 +20,7 @@ import pytorch_test_common
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.onnx import OperatorExportTypes, symbolic_helper, utils
from torch.onnx import symbolic_helper, utils
from torch.onnx._internal import registration
from torch.testing._internal import common_quantization, common_utils, jit_utils
@@ -394,7 +394,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
for node in graph.nodes():
self.assertTrue(node.sourceRange())
@common_utils.skipIfCaffe2
def test_clip_aten_fallback_due_exception(self):
def bad_clamp(g, self, min, max):
return symbolic_helper._onnx_unsupported("Bad boy!")
@@ -411,7 +410,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
)
self.assertAtenOp(onnx_model, "clamp", "Tensor")
@common_utils.skipIfCaffe2
def test_clip_aten_fallback_explicit_request(self):
class MyClip(torch.nn.Module):
def forward(self, x):
@@ -961,60 +959,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5)))
@common_utils.skipIfNoCaffe2
def test_caffe2_aten_fallback_must_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
abcd = x + y
defg = torch.linalg.qr(abcd)
return defg
# TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize
for operator_export_type in (
OperatorExportTypes.ONNX_ATEN,
OperatorExportTypes.ONNX_ATEN_FALLBACK,
):
x = torch.rand(3, 4)
y = torch.rand(3, 4)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenNotONNXOp(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=operator_export_type,
# support for linalg.qr was added in later op set versions.
opset_version=9,
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")
@common_utils.skipIfNoCaffe2
def test_caffe2_onnx_aten_must_not_fallback(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
return torch.fmod(x, y)
# TODO: Refactor common_utils._decide_skip_caffe2 to support parametrize
for operator_export_type in (
OperatorExportTypes.ONNX_ATEN_FALLBACK,
OperatorExportTypes.ONNX_ATEN,
):
x = torch.randn(3, 4, dtype=torch.float32)
y = torch.randn(3, 4, dtype=torch.float32)
f = io.BytesIO()
torch.onnx.export(
ModelWithAtenFmod(),
(x, y),
f,
do_constant_folding=False,
operator_export_type=operator_export_type,
opset_version=10, # or higher
)
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
assert onnx_model.graph.node[0].op_type == "Mod"
@common_utils.skipIfCaffe2
def test_aten_fallback_must_fallback(self):
class ModelWithAtenNotONNXOp(torch.nn.Module):
def forward(self, x, y):
@@ -1037,7 +981,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "linalg_qr")
@common_utils.skipIfCaffe2
def test_onnx_aten(self):
class ModelWithAtenFmod(torch.nn.Module):
def forward(self, x, y):
@@ -1056,7 +999,6 @@ class TestONNXExport(pytorch_test_common.ExportTestCase):
onnx_model = onnx.load(io.BytesIO(f.getvalue()))
self.assertAtenOp(onnx_model, "fmod", "Tensor")
@common_utils.skipIfCaffe2
def test_onnx_aten_fallback_must_not_fallback(self):
# For BUILD_CAFFE2=0, aten fallback only when not exportable
class ONNXExportable(torch.nn.Module):
@@ -1233,26 +1175,6 @@ class TestQuantizeEagerONNXExport(common_utils.TestCase):
_export_to_onnx(model, data, input_names)
@common_quantization.skipIfNoFBGEMM
@common_utils.skipIfNoCaffe2
def test_lower_graph_linear(self):
model = torch.ao.quantization.QuantWrapper(
torch.nn.Linear(5, 10, bias=True)
).to(dtype=torch.float)
data_numpy = np.random.rand(1, 2, 5).astype(np.float32)
data = torch.from_numpy(data_numpy).to(dtype=torch.float)
self._test_lower_graph_impl(model, data)
@common_quantization.skipIfNoFBGEMM
@common_utils.skipIfNoCaffe2
def test_lower_graph_conv2d(self):
model = torch.ao.quantization.QuantWrapper(
torch.nn.Conv2d(3, 5, 2, bias=True)
).to(dtype=torch.float)
data_numpy = np.random.rand(1, 3, 6, 6).astype(np.float32)
data = torch.from_numpy(data_numpy).to(dtype=torch.float)
self._test_lower_graph_impl(model, data)
@common_quantization.skipIfNoFBGEMM
@unittest.skip(
"onnx opset9 does not support quantize_per_tensor and caffe2 \

View file

@@ -17,7 +17,6 @@ from pytorch_test_common import (
skipIfUnsupportedMaxOpsetVersion,
skipIfUnsupportedMinOpsetVersion,
)
from verify import verify
import torch
import torch.onnx
@@ -26,7 +25,7 @@ from torch.onnx import _constants, OperatorExportTypes, TrainingMode, utils
from torch.onnx._globals import GLOBALS
from torch.onnx.symbolic_helper import _unpack_list, parse_args
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import skipIfNoCaffe2, skipIfNoLapack
from torch.testing._internal.common_utils import skipIfNoLapack
def _remove_test_environment_prefix_from_scope_name(scope_name: str) -> str:
@@ -1623,25 +1622,6 @@ class TestUtilityFuns(_BaseTestCase):
"Graph parameter names does not match model parameters.",
)
@skipIfNoCaffe2
def test_modifying_params(self):
class MyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.tensor([2.0]))
def forward(self, x):
y = x * x
self.param.data.add_(1.0)
return y
x = torch.tensor([1, 2])
# Move import to local as caffe2 backend requires additional build flag,
# and is only used in this test case.
import caffe2.python.onnx.backend as backend
verify(MyModel(), x, backend, do_constant_folding=False)
def test_fuse_conv_bn(self):
class Fuse(torch.nn.Module):
def __init__(self):

View file

@@ -23,7 +23,7 @@ hu.assert_deadline_disabled()
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, BUILD_WITH_CAFFE2, IS_SANDCASTLE
from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, IS_SANDCASTLE
from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN
from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
override_quantized_engine, supported_qengines, override_qengines, _snr
@@ -4524,47 +4524,6 @@ class TestQuantizedEmbeddingOps(TestCase):
self._test_embedding_bag_unpack_impl(pack_fn, unpack_fn, bit_rate, optimized_qparams, weight)
""" Tests the correctness of the embedding_bag_8bit pack/unpack op against C2 """
@unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2")
@given(num_embeddings=st.integers(10, 100),
embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
num_batches=st.integers(1, 5),
data_type=st.sampled_from([np.float32, np.float16]),)
def test_embedding_bag_byte_unpack(self, num_embeddings, embedding_dim, num_batches, data_type):
pack_fn = torch.ops.quantized.embedding_bag_byte_prepack
unpack_fn = torch.ops.quantized.embedding_bag_byte_unpack
self._test_embedding_bag_unpack_fn(
pack_fn, unpack_fn, num_embeddings, embedding_dim, 8, False, num_batches, data_type=data_type)
""" Tests the correctness of the embedding_bag_4bit pack/unpack op against C2 """
@unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2")
@given(num_embeddings=st.integers(10, 100),
embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
optimized_qparams=st.booleans(),
data_type=st.sampled_from([np.float32, np.float16]),)
def test_embedding_bag_4bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams, data_type):
pack_fn = torch.ops.quantized.embedding_bag_4bit_prepack
unpack_fn = torch.ops.quantized.embedding_bag_4bit_unpack
# 4bit and 2bit quantization right now only works for 2D Tensor so we set the num_batches to 1
self._test_embedding_bag_unpack_fn(
pack_fn, unpack_fn, num_embeddings, embedding_dim, 4, optimized_qparams, 1, data_type=data_type)
""" Tests the correctness of the embedding_bag_2bit pack/unpack op against C2 """
@unittest.skipIf(not BUILD_WITH_CAFFE2, "Test needs Caffe2")
@given(num_embeddings=st.integers(10, 100),
embedding_dim=st.integers(5, 50).filter(lambda x: x % 8 == 0),
optimized_qparams=st.booleans(),
data_type=st.sampled_from([np.float32, np.float16]),)
def test_embedding_bag_2bit_unpack(self, num_embeddings, embedding_dim, optimized_qparams, data_type):
pack_fn = torch.ops.quantized.embedding_bag_2bit_prepack
unpack_fn = torch.ops.quantized.embedding_bag_2bit_unpack
# 4bit and 2bit quantization right now only works for 2D Tensor so we set the num_batches to 1
self._test_embedding_bag_unpack_fn(
pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams, 1, data_type=data_type)
def embedding_bag_rowwise_offsets_run(
self, bit_rate, num_embeddings,
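The three removed cases only cross-checked the pack/unpack ops against Caffe2; the quantized ops themselves survive. A usage sketch of the byte variant (the tolerance is an assumption, not taken from the deleted tests):

    import torch

    weights = torch.randn(10, 8, dtype=torch.float32)
    packed = torch.ops.quantized.embedding_bag_byte_prepack(weights)
    unpacked = torch.ops.quantized.embedding_bag_byte_unpack(packed)
    # unpack recovers the weights up to 8-bit row-wise quantization error
    assert torch.allclose(weights, unpacked, atol=0.1)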

View file

@@ -96,7 +96,7 @@ import torch.nn.functional as F
from torch.testing._internal import jit_utils
from torch.testing._internal.common_jit import check_against_reference
from torch.testing._internal.common_utils import run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
suppress_warnings, BUILD_WITH_CAFFE2, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \
suppress_warnings, IS_SANDCASTLE, GRAPH_EXECUTOR, ProfilingMode, TestCase, \
freeze_rng_state, slowTest, TemporaryFileName, \
enable_profiling_mode_for_profiling_tests, TEST_MKL, set_default_dtype, num_profiled_runs, \
skipIfCrossRef, skipIfTorchDynamo
@@ -15299,56 +15299,6 @@ dedent """
continue
self.assertEqual(value, getattr(loaded, "_" + name))
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
@unittest.skipIf(not BUILD_WITH_CAFFE2, "PyTorch is build without Caffe2 support")
def test_old_models_bc(self):
model = {
'archive/version': b'1',
'archive/code/archive.py':
b'''
op_version_set = 0
def forward(self,
_0: Tensor) -> Tensor:
_1 = torch.zeros([10], dtype=6, layout=0, device=torch.device("cpu"))
result = torch.to(torch.fill_(_1, 5), dtype=6, layout=0, device=torch.device("cpu"),
non_blocking=False, copy=False)
result2 = torch.rand([10], dtype=6, layout=0, device=torch.device("cpu"))
result3 = torch.rand_like(result2, dtype=6, layout=0, device=torch.device("cpu"))
_2 = torch.add(torch.add(result, result2, alpha=1), result3, alpha=1)
return _2
''',
'archive/attributes.pkl': b'\x80\x02](e.',
'archive/libs.py': b'op_version_set = 0\n',
'archive/model.json':
b'''
{
"protoVersion":"2",
"mainModule":{
"torchscriptArena":{
"key":"code/archive.py"
},
"name":"archive",
"optimize":true
},
"producerName":"pytorch",
"producerVersion":"1.0",
"libs":{
"torchscriptArena":{
"key":"libs.py"
}
}
}'''}
with TemporaryFileName() as fname:
archive_name = os.path.basename(os.path.normpath(fname))
with zipfile.ZipFile(fname, 'w') as archive:
for k, v in model.items():
archive.writestr(k, v)
with open(fname, "rb") as f:
fn = torch.jit.load(f)
x = torch.zeros(10)
fn(x)
def test_submodule_attribute_serialization(self):
class S(torch.jit.ScriptModule):

View file

@@ -2,7 +2,6 @@
from enum import Enum
_CAFFE2_ATEN_FALLBACK: bool
PRODUCER_VERSION: str
class TensorProtoDataType(Enum):

View file

@@ -292,7 +292,5 @@ void initONNXBindings(PyObject* module) {
.value("TRAINING", TrainingMode::TRAINING);
onnx.attr("PRODUCER_VERSION") = py::str(TORCH_VERSION);
onnx.attr("_CAFFE2_ATEN_FALLBACK") = false;
}
} // namespace torch::onnx

View file

@@ -1,12 +1,7 @@
# mypy: allow-untyped-defs
from torch import _C
from torch._C import _onnx as _C_onnx
from torch._C._onnx import (
_CAFFE2_ATEN_FALLBACK,
OperatorExportTypes,
TensorProtoDataType,
TrainingMode,
)
from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode
from . import ( # usort:skip. Keep the order instead of sorting lexicographically
_deprecation,
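For downstream code, the visible effect of this hunk is the narrowed import surface; a sketch:

    # Still importable after this change:
    from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode

    # Deleted by this change; this line now raises ImportError:
    # from torch._C._onnx import _CAFFE2_ATEN_FALLBACK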

View file

@@ -12,7 +12,6 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Un
import torch
from torch import _C
from torch._C import _onnx as _C_onnx
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import _beartype, registration
@@ -329,14 +328,6 @@ def _scalar(x: torch.Tensor):
return x[0]
@_beartype.beartype
def _is_caffe2_aten_fallback() -> bool:
return (
GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
and _C_onnx._CAFFE2_ATEN_FALLBACK
)
@_beartype.beartype
def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
r"""Initializes the right attribute based on type of value."""
@@ -350,16 +341,6 @@ def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
if _is_onnx_list(value):
kind += "s"
if aten and _is_caffe2_aten_fallback():
if isinstance(value, torch.Tensor):
# Caffe2 proto does not support tensor attribute.
if value.numel() > 1:
raise ValueError("Should not pass tensor attribute")
value = _scalar(value)
if isinstance(value, float):
kind = "f"
else:
kind = "i"
return getattr(node, f"{kind}_")(name, value)

View file

@@ -537,10 +537,7 @@ def is_complex_value(x: _C.Value) -> bool:
@_beartype.beartype
def is_caffe2_aten_fallback() -> bool:
return (
GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
and _C_onnx._CAFFE2_ATEN_FALLBACK
)
return False
@_beartype.beartype
@@ -592,9 +589,7 @@ def _get_dim_for_cross(x: _C.Value, dim: Optional[int]):
@_beartype.beartype
def _unimplemented(op: str, msg: str, value: Optional[_C.Value] = None) -> None:
# For BC reasons, the behavior for Caffe2 does not raise exception for unimplemented operators
if _C_onnx._CAFFE2_ATEN_FALLBACK:
warnings.warn(f"ONNX export failed on {op} because {msg} not supported")
elif GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
if GLOBALS.operator_export_type == _C_onnx.OperatorExportTypes.ONNX:
_onnx_unsupported(f"{op}, {msg}", value)
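The helper is kept as an always-False stub because call sites remain elsewhere in the exporter; every `if symbolic_helper.is_caffe2_aten_fallback():` guard in the symbolic opsets is thereby dead code, which is what the hunks below delete wholesale. The pattern being removed, using the `gather` symbolic from this diff as the example:

    if symbolic_helper.is_caffe2_aten_fallback():   # now always False
        return g.at("gather", self, dim, index, sparse_grad)  # ATen fallback node
    return g.op("GatherElements", self, index, axis_i=dim)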

View file

@@ -211,10 +211,6 @@ def index_put(
indices_list = symbolic_helper._unpack_list(indices_list_value)
else:
indices_list = [indices_list_value]
if symbolic_helper.is_caffe2_aten_fallback():
args = [self] + indices_list + [values, accumulate]
return g.at("index_put", *args)
accumulate = symbolic_helper._parse_arg(accumulate, "b")
if len(indices_list) == 0:
@@ -398,8 +394,6 @@ def __interpolate(
def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
if symbolic_helper._maybe_get_const(sparse_grad, "i"):
return symbolic_helper._unimplemented("gather", "sparse_grad == True")
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("gather", self, dim, index, sparse_grad)
return g.op("GatherElements", self, index, axis_i=dim)
@@ -407,8 +401,6 @@ def gather(g: jit_utils.GraphContext, self, dim, index, sparse_grad=False):
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter(g: jit_utils.GraphContext, self, dim, index, src):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("scatter", self, dim, index, src, overload_name="src")
src_type = _type_utils.JitScalarType.from_value(src)
src = symbolic_helper._maybe_get_scalar(src)
if symbolic_helper._is_value(src):
@@ -898,8 +890,6 @@ def _dim_arange(g: jit_utils.GraphContext, like, dim):
stop = g.op(
"Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0
)
if symbolic_helper.is_caffe2_aten_fallback():
return g.op("_caffe2::Range", stop)
return arange(g, stop, 4, None, None, None)
@@ -982,9 +972,6 @@ def mm(g: jit_utils.GraphContext, self, other):
@_onnx_symbolic("aten::index")
@_beartype.beartype
def index(g: jit_utils.GraphContext, self, index):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("index", self, index, overload_name="Tensor")
if symbolic_helper._is_packed_list(index):
indices = symbolic_helper._unpack_list(index)
else:
@@ -1007,16 +994,6 @@ def index(g: jit_utils.GraphContext, self, index):
@_beartype.beartype
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
dim_value = symbolic_helper._parse_arg(dim, "i")
if symbolic_helper.is_caffe2_aten_fallback():
return g.at(
"index_fill",
self,
index,
value,
overload_name="int_Scalar",
dim_i=dim_value,
)
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
g, self, dim, index
)
@@ -1030,8 +1007,6 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
@_beartype.beartype
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
dim_value = symbolic_helper._parse_arg(dim, "i")
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("index_copy", self, index, source, dim_i=dim_value)
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
g, self, dim, index
)

View file

@@ -330,8 +330,6 @@ def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
const_step
):
return opset9.unfold(g, input, dimension, const_size, const_step)
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)
sizedim = symbolic_helper._get_tensor_dim_size(input, dimension)
if sizedim is not None:

View file

@@ -71,9 +71,6 @@ def grid_sampler(
@symbolic_helper.parse_args("v", "i", "v", "v")
@_beartype.beartype
def scatter_add(g: jit_utils.GraphContext, self, dim, index, src):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("scatter", self, dim, index, src, overload_name="src")
src_type = _type_utils.JitScalarType.from_value(
src, _type_utils.JitScalarType.UNDEFINED
)

View file

@@ -841,36 +841,18 @@ def _reduce_with_dtype(onnx_op: str, name: str, allow_multi_dim_support: bool =
@symbolic_helper.parse_args("v", "i", "none")
@_beartype.beartype
def cumsum(g: jit_utils.GraphContext, input, dim, dtype):
if symbolic_helper.is_caffe2_aten_fallback():
if dtype.node().kind() != "prim::Constant":
return symbolic_helper._unimplemented("cumsum", "dtype", dtype)
return g.at("cumsum", input, dim_i=dim)
symbolic_helper._onnx_opset_unsupported("cumsum", 9, 11, input)
@_onnx_symbolic("aten::_sample_dirichlet")
@_beartype.beartype
def _sample_dirichlet(g: jit_utils.GraphContext, self, generator):
if symbolic_helper.is_caffe2_aten_fallback():
if not symbolic_helper._is_none(generator):
return symbolic_helper._unimplemented(
"_sample_dirichlet", "We are not able to export generator", self
)
return g.at("_sample_dirichlet", self)
return symbolic_helper._onnx_unsupported("_sample_dirichlet", self)
@_onnx_symbolic("aten::_standard_gamma")
@_beartype.beartype
def _standard_gamma(g: jit_utils.GraphContext, self, generator):
if symbolic_helper.is_caffe2_aten_fallback():
if not symbolic_helper._is_none(generator):
return symbolic_helper._unimplemented(
"_standard_gamma", "not able to export generator", self
)
return g.at("_standard_gamma", self)
return symbolic_helper._onnx_unsupported("_standard_gamma", self)
@@ -1007,19 +989,6 @@ def embedding_bag(
return symbolic_helper._onnx_unsupported(
"embedding_bag with per_sample_weights"
)
if symbolic_helper.is_caffe2_aten_fallback():
return g.at(
"embedding_bag",
embedding_matrix,
indices,
offsets,
outputs=4,
scale_grad_by_freq_i=scale_grad_by_freq,
mode_i=mode,
sparse_i=sparse,
include_last_offset_i=include_last_offset,
padding_idx_i=padding_idx,
)
return symbolic_helper._onnx_unsupported("embedding_bag", embedding_matrix)
@@ -1052,10 +1021,6 @@ def transpose(g: jit_utils.GraphContext, self, dim0, dim1):
axes = list(range(rank))
axes[dim0], axes[dim1] = axes[dim1], axes[dim0]
return g.op("Transpose", self, perm_i=axes)
elif symbolic_helper.is_caffe2_aten_fallback():
# if we don't have dim information we cannot
# output a permute so use ATen instead
return g.at("transpose", self, overload_name="int", dim0_i=dim0, dim1_i=dim1)
else:
raise errors.SymbolicValueError(
"Unsupported: ONNX export of transpose for tensor of unknown rank.",
@@ -2927,16 +2892,6 @@ def layer_norm(
eps: float,
cudnn_enable: bool,
) -> _C.Value:
if symbolic_helper.is_caffe2_aten_fallback():
return g.at(
"layer_norm",
input,
weight,
bias,
normalized_shape_i=normalized_shape,
eps_f=eps,
cudnn_enable_i=cudnn_enable,
)
normalized, _, _ = native_layer_norm(g, input, normalized_shape, weight, bias, eps)
return normalized
@@ -3043,8 +2998,6 @@ def instance_norm(
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def unfold(g: jit_utils.GraphContext, input, dimension, size, step):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("unfold", input, dimension_i=dimension, size_i=size, step_i=step)
sizes = symbolic_helper._get_tensor_sizes(input)
# FIXME(justinchuby): Get rid of the try catch here to improve readability
try:
@@ -3119,9 +3072,6 @@ def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accum
indices_list = symbolic_helper._unpack_list(indices_list_value)
else:
indices_list = [indices_list_value]
if symbolic_helper.is_caffe2_aten_fallback():
args = [self] + indices_list + [values, accumulate]
return g.at("index_put", *args)
accumulate = symbolic_helper._parse_arg(accumulate, "b")
@@ -3136,16 +3086,6 @@ def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accum
@_beartype.beartype
def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
dim_value = symbolic_helper._parse_arg(dim, "i")
if symbolic_helper.is_caffe2_aten_fallback():
return g.at(
"index_fill",
self,
index,
value,
overload_name="int_Scalar",
dim_i=dim_value,
)
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
g, self, dim, index
)
@@ -3160,8 +3100,6 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value):
@_beartype.beartype
def index_copy(g: jit_utils.GraphContext, self, dim, index, source):
dim_value = symbolic_helper._parse_arg(dim, "i")
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("index_copy", self, index, source, dim_i=dim_value)
expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper(
g, self, dim, index
)
@@ -3220,10 +3158,6 @@ def type_as(g: jit_utils.GraphContext, self, other):
to_i=other_dtype.onnx_type(),
)
if symbolic_helper.is_caffe2_aten_fallback():
# We don't know the type of other, bail by emitting ATen
return g.at("type_as", self, other)
raise errors.SymbolicValueError(
"Unsupported: ONNX export of type_as for tensor "
"of unknown dtype. Please check if the dtype of the "
@@ -3236,8 +3170,6 @@ def type_as(g: jit_utils.GraphContext, self, other):
@symbolic_helper.parse_args("v", "v", "i", "f")
@_beartype.beartype
def cosine_similarity(g: jit_utils.GraphContext, x1, x2, dim, eps):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("cosine_similarity", x1, x2, dim_i=dim, eps_f=eps)
cross = symbolic_helper._reducesum_helper(
g, mul(g, x1, x2), axes_i=[dim], keepdims_i=0
)
@@ -3516,50 +3448,28 @@ def norm(g: jit_utils.GraphContext, self, p, dim, keepdim, dtype=None):
@symbolic_helper.parse_args("v", "v", "v", "i")
@_beartype.beartype
def conv_tbc(g: jit_utils.GraphContext, input, weight, bias, pad):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("conv_tbc", input, weight, bias, pad_i=pad)
else:
# input must have 3 dimensions, see:
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
# input = (time, batch, in_channels)
# weight = (kernel_width, in_channels, out_channels)
# bias = (out_channels,)
input = g.op("Transpose", input, perm_i=[1, 2, 0])
weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1)
return g.op("Transpose", conv, perm_i=[2, 0, 1])
# input must have 3 dimensions, see:
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/ConvolutionTBC.cpp#L8-L10
# input = (time, batch, in_channels)
# weight = (kernel_width, in_channels, out_channels)
# bias = (out_channels,)
input = g.op("Transpose", input, perm_i=[1, 2, 0])
weight = g.op("Transpose", weight, perm_i=[2, 1, 0])
conv = conv1d(g, input, weight, bias, [1], [pad], [1], 1)
return g.op("Transpose", conv, perm_i=[2, 0, 1])
@_onnx_symbolic("aten::_unique")
@symbolic_helper.parse_args("v", "i", "i")
@_beartype.beartype
def _unique(g: jit_utils.GraphContext, input, sorted, return_inverse):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at(
"_unique",
input,
sorted_i=sorted,
return_inverse_i=return_inverse,
outputs=2,
)
else:
return symbolic_helper._onnx_unsupported("_unique", input)
return symbolic_helper._onnx_unsupported("_unique", input)
@_onnx_symbolic("aten::_unique2")
@symbolic_helper.parse_args("v", "i", "i", "i")
@_beartype.beartype
def _unique2(g: jit_utils.GraphContext, input, sorted, return_inverse, return_counts):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at(
"_unique2",
input,
sorted_i=sorted,
return_inverse_i=return_inverse,
return_counts_i=return_counts,
outputs=3,
)
symbolic_helper._onnx_opset_unsupported("_unique2", 9, 11, input)
@@ -4973,11 +4883,8 @@ def _dim_arange(g: jit_utils.GraphContext, like, dim):
stop = g.op(
"Gather", like_shape, g.op("Constant", value_t=torch.tensor(dim)), axis_i=0
)
if symbolic_helper.is_caffe2_aten_fallback():
return g.op("_caffe2::Range", stop)
else:
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
return arange(g, stop, 4, None, None, None)
# aten::arange(Scalar end, ScalarType dtype, Layout, Device, bool pin_memory)
return arange(g, stop, 4, None, None, None)
@_onnx_symbolic("aten::detach")
@@ -5543,9 +5450,6 @@ def logsumexp(g: jit_utils.GraphContext, input, dim, keepdim):
@_onnx_symbolic("aten::arange")
@_beartype.beartype
def arange(g: jit_utils.GraphContext, *args):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("arange", *args)
@_beartype.beartype
def _get_arange_dtype(dtype):
dtype = symbolic_helper._maybe_get_const(dtype, "i")
@@ -5665,9 +5569,6 @@ def masked_fill_(g: jit_utils.GraphContext, self, mask, value):
@_onnx_symbolic("aten::index")
@_beartype.beartype
def index(g: jit_utils.GraphContext, self, index):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("index", self, index, overload_name="Tensor")
if symbolic_helper._is_packed_list(index):
indices = symbolic_helper._unpack_list(index)
else:
@@ -6083,17 +5984,6 @@ def gelu(g: jit_utils.GraphContext, self: torch._C.Value, approximate: str = "no
def group_norm(
g: jit_utils.GraphContext, input, num_groups, weight, bias, eps, cudnn_enabled
):
if symbolic_helper.is_caffe2_aten_fallback():
return g.at(
"group_norm",
input,
weight,
bias,
num_groups_i=num_groups,
eps_f=eps,
cudnn_enabled_i=cudnn_enabled,
)
channel_size = symbolic_helper._get_tensor_dim_size(input, 1)
if channel_size is not None:
assert channel_size % num_groups == 0
@@ -6169,9 +6059,6 @@ def _weight_norm(g: jit_utils.GraphContext, weight_v, weight_g, dim):
norm_v = norm(g, weight_v, 2, axes, 1)
div = g.op("Div", weight_v, norm_v)
return g.op("Mul", div, weight_g)
if symbolic_helper.is_caffe2_aten_fallback():
return g.at("_weight_norm", weight_v, weight_g, dim_i=dim)
raise errors.SymbolicValueError(
"Unsupported: ONNX export of _weight_norm for tensor of unknown rank.",
weight_v,

View file

@@ -11,7 +11,6 @@ import copy
import inspect
import io
import re
import textwrap
import typing
import warnings
from typing import (
@@ -681,27 +680,6 @@ def _optimize_graph(
_C._jit_pass_onnx_unpack_quantized_weights(
graph, params_dict, symbolic_helper.is_caffe2_aten_fallback()
)
if symbolic_helper.is_caffe2_aten_fallback():
# Insert permutes before and after each conv op to ensure correct order.
_C._jit_pass_onnx_quantization_insert_permutes(graph, params_dict)
# Find consecutive permutes that are no-ops and remove them.
_C._jit_pass_custom_pattern_based_rewrite_graph(
textwrap.dedent(
"""\
graph(%Pi):
%Pq = quantized::nhwc2nchw(%Pi)
%Pr = quantized::nchw2nhwc(%Pq)
return (%Pr)"""
),
textwrap.dedent(
"""\
graph(%Ri):
return (%Ri)"""
),
graph,
)
# onnx only supports tensors, so we turn all out number types into tensors
_C._jit_pass_erase_number_types(graph)
if GLOBALS.onnx_shape_inference:
@@ -734,18 +712,9 @@ def _optimize_graph(
graph = _C._jit_pass_canonicalize(graph)
_C._jit_pass_lint(graph)
if GLOBALS.onnx_shape_inference:
try:
_C._jit_pass_onnx_graph_shape_type_inference(
graph, params_dict, GLOBALS.export_onnx_opset_version
)
except RuntimeError as exc:
if (
_C_onnx._CAFFE2_ATEN_FALLBACK
and exc.args[0]
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
):
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
pass
_C._jit_pass_onnx_graph_shape_type_inference(
graph, params_dict, GLOBALS.export_onnx_opset_version
)
return graph
@@ -783,17 +752,6 @@ def warn_on_static_input_change(input_states):
@_beartype.beartype
def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
"""Resolves the arguments that are ignored when export_type != operator_export_type.ONNX."""
if (
operator_export_type is not operator_export_type.ONNX
and _C_onnx._CAFFE2_ATEN_FALLBACK
):
if arg_value is True:
warnings.warn(
f"'{arg_name}' can be set to True only when 'operator_export_type' is "
"`ONNX`. Since 'operator_export_type' is not set to 'ONNX', "
f"'{arg_name}' argument will be ignored."
)
arg_value = False
return arg_value
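With the Caffe2-only warning gone, the helper reduces to an identity function; reconstructed from this hunk:

    @_beartype.beartype
    def _resolve_args_by_export_type(arg_name, arg_value, operator_export_type):
        """Resolves the arguments that are ignored when export_type != operator_export_type.ONNX."""
        return arg_value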
@@ -1298,18 +1256,9 @@ def _model_to_graph(
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
if GLOBALS.onnx_shape_inference:
try:
_C._jit_pass_onnx_graph_shape_type_inference(
graph, params_dict, GLOBALS.export_onnx_opset_version
)
except RuntimeError as exc:
if (
_C_onnx._CAFFE2_ATEN_FALLBACK
and exc.args[0]
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
):
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
pass
_C._jit_pass_onnx_graph_shape_type_inference(
graph, params_dict, GLOBALS.export_onnx_opset_version
)
params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
@@ -1612,15 +1561,6 @@ def _export(
if export_type is None:
export_type = _exporter_states.ExportTypes.PROTOBUF_FILE
# Discussed deprecation with Nikita Shulga and Sergii Dymchenko from Meta
if _C_onnx._CAFFE2_ATEN_FALLBACK:
warnings.warn(
"Caffe2 ONNX exporter is deprecated in version 2.0 and will be "
"removed in 2.2. Please use PyTorch 2.1 or older for this capability.",
category=FutureWarning,
stacklevel=2,
)
if isinstance(model, torch.nn.DataParallel):
raise ValueError(
"torch.nn.DataParallel is not supported by ONNX "
@@ -1655,10 +1595,7 @@
"no local function support. "
)
if not operator_export_type:
if _C_onnx._CAFFE2_ATEN_FALLBACK:
operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
else:
operator_export_type = _C_onnx.OperatorExportTypes.ONNX
operator_export_type = _C_onnx.OperatorExportTypes.ONNX
# By default, training=TrainingMode.EVAL,
# which is good because running a model in training mode could result in
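Since the build flag no longer exists, omitting operator_export_type now always selects OperatorExportTypes.ONNX; a minimal usage sketch:

    import io
    import torch

    f = io.BytesIO()
    # No operator_export_type passed: it resolves to OperatorExportTypes.ONNX
    # on every build after this change.
    torch.onnx.export(torch.nn.Linear(4, 2), torch.randn(1, 4), f)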
@@ -1904,21 +1841,12 @@ def _should_aten_fallback(
is_aten_fallback_export = (
operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
)
is_caffe2_build = _C_onnx._CAFFE2_ATEN_FALLBACK
if not name.startswith("aten::"):
return False
if is_caffe2_build:
if (
is_onnx_aten_export or is_aten_fallback_export
) and not is_exportable_aten_op:
return True
else:
if is_onnx_aten_export or (
is_aten_fallback_export and not is_exportable_aten_op
):
return True
if is_onnx_aten_export or (is_aten_fallback_export and not is_exportable_aten_op):
return True
return False
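On Caffe2 builds even plain ONNX_ATEN export fell back only for ops without a registered symbolic; with that branch gone, the decision collapses to a single expression. Reconstructed sketch (the registry lookup is assumed from surrounding code):

    def _should_aten_fallback(name, opset_version, operator_export_type):
        # assumed lookup: is a symbolic registered for this op at this opset?
        is_exportable_aten_op = registration.registry.is_registered_op(name, opset_version)
        is_onnx_aten_export = operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN
        is_aten_fallback_export = (
            operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
        )
        if not name.startswith("aten::"):
            return False
        return is_onnx_aten_export or (
            is_aten_fallback_export and not is_exportable_aten_op
        )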
@@ -1968,7 +1896,7 @@ def _symbolic_context_handler(symbolic_fn: Callable) -> Callable:
def _get_aten_op_overload_name(n: _C.Node) -> str:
# Returns `overload_name` attribute to ATen ops on non-Caffe2 builds
schema = n.schema()
if not schema.startswith("aten::") or symbolic_helper.is_caffe2_aten_fallback():
if not schema.startswith("aten::"):
return ""
return _C.parse_schema(schema).overload_name
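With the Caffe2 special case dropped, the overload name is returned for every aten:: schema; a quick check:

    import torch._C as _C

    schema = "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
    print(_C.parse_schema(schema).overload_name)  # prints: Tensor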
@@ -2032,14 +1960,7 @@ def _run_symbolic_function(
)
try:
# Caffe2-specific: Quantized op symbolics are registered for opset 9 only.
if symbolic_helper.is_caffe2_aten_fallback() and opset_version == 9:
symbolic_caffe2.register_quantized_ops("caffe2", opset_version)
if namespace == "quantized" and symbolic_helper.is_caffe2_aten_fallback():
domain = "caffe2"
else:
domain = namespace
domain = namespace
symbolic_function_name = f"{domain}::{op_name}"
symbolic_function_group = registration.registry.get_function_group(
@@ -2073,10 +1994,7 @@
except RuntimeError:
if operator_export_type == _C_onnx.OperatorExportTypes.ONNX_FALLTHROUGH:
return None
elif (
operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
and not symbolic_helper.is_caffe2_aten_fallback()
):
elif operator_export_type == _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
# Emit ATen op for non-Caffe2 builds when `operator_export_type==ONNX_ATEN_FALLBACK`
attrs = {
k + "_" + node.kindOf(k)[0]: symbolic_helper._node_get(node, k)

View file

@@ -633,10 +633,7 @@ def _onnx_graph_from_model(
utils._setup_trace_module_map(model, export_modules_as_functions)
if not operator_export_type:
if _C_onnx._CAFFE2_ATEN_FALLBACK:
operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
else:
operator_export_type = _C_onnx.OperatorExportTypes.ONNX
operator_export_type = _C_onnx.OperatorExportTypes.ONNX
GLOBALS.export_onnx_opset_version = opset_version
GLOBALS.operator_export_type = operator_export_type

View file

@@ -329,14 +329,6 @@ def skipIfNoQNNPACK(fn):
fn(*args, **kwargs)
return wrapper
@functools.wraps(fn)
def wrapper(*args, **kwargs):
if not torch.onnx._CAFFE2_ATEN_FALLBACK:
raise unittest.SkipTest(reason)
else:
fn(*args, **kwargs)
return wrapper
def withQNNPACKBackend(fn):
# TODO(future PR): consider combining with skipIfNoQNNPACK,
# will require testing of existing callsites

View file

@@ -1252,8 +1252,6 @@ TEST_OPT_EINSUM = _check_module_exists('opt_einsum')
TEST_Z3 = _check_module_exists('z3')
BUILD_WITH_CAFFE2 = torch.onnx._CAFFE2_ATEN_FALLBACK
def split_if_not_empty(x: str):
return x.split(",") if len(x) != 0 else []
@@ -1886,19 +1884,6 @@ def skipIfNotRegistered(op_name, message):
"""
return unittest.skip("Pytorch is compiled without Caffe2")
def _decide_skip_caffe2(expect_caffe2, reason):
def skip_dec(func):
@wraps(func)
def wrapper(self):
if torch.onnx._CAFFE2_ATEN_FALLBACK != expect_caffe2:
raise unittest.SkipTest(reason)
return func(self)
return wrapper
return skip_dec
skipIfCaffe2 = _decide_skip_caffe2(False, "Not compatible with Caffe2")
skipIfNoCaffe2 = _decide_skip_caffe2(True, "Caffe2 is not available")
def skipIfNoSciPy(fn):
@wraps(fn)
def wrapper(*args, **kwargs):