diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index d92e3e9bb3a..5cee9303f18 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -545,8 +545,8 @@ Tensor mkldnn_convolution_pointwise_binary( // op, such as "hardtanh" has scalar parameters "gelu" has algorithm parameters. Tensor& mkldnn_convolution_pointwise_binary_( - const Tensor& input_t, Tensor& other_t, + const Tensor& input_t, const Tensor& weight_t, const c10::optional& bias_opt, IntArrayRef padding, diff --git a/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp b/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp index dbff070fc6f..9b0b8ce23d1 100644 --- a/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp +++ b/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp @@ -50,7 +50,7 @@ TORCH_LIBRARY(mkldnn, m) { m.def(TORCH_SELECTIVE_SCHEMA( "mkldnn::_convolution_pointwise.binary(Tensor X, Tensor other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA( - "mkldnn::_convolution_pointwise_.binary(Tensor X, Tensor(a!) other, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) Y")); + "mkldnn::_convolution_pointwise_.binary(Tensor(a!) other, Tensor X, Tensor W, Tensor? B, int[] padding, int[] stride, int[] dilation, int groups, str binary_attr, Scalar? alpha, str? unary_attr, Scalar?[] unary_scalars, str? unary_algorithm) -> Tensor(a!) Y")); m.def(TORCH_SELECTIVE_SCHEMA( "mkldnn::_convolution_transpose_pointwise(Tensor X, Tensor W, Tensor? B, int[] padding, int[] output_padding, int[] stride, int[] dilation, int groups, str attr, Scalar?[] scalars, str? algorithm) -> Tensor Y")); m.def(TORCH_SELECTIVE_SCHEMA( diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 7d088c2ae5c..ed320c82f77 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -362,6 +362,7 @@ ALLOW_LIST = [ ("aten::_nested_view_from_buffer_copy", datetime.date(2023, 5, 1)), ("aten::_nested_view_from_buffer", datetime.date(2023, 5, 1)), ("aten::_scaled_dot_product_flash_attention_backward", datetime.date(2023, 6, 1)), + ("mkldnn::_convolution_pointwise_.binary", datetime.date(2023, 7, 1)), # These ops were moved to python under the c10d_functional namespace ("aten::wait_tensor", datetime.date(9999, 1, 30)), ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)), diff --git a/test/inductor/test_cpp_wrapper.py b/test/inductor/test_cpp_wrapper.py index 42f58350c97..d5897568d13 100644 --- a/test/inductor/test_cpp_wrapper.py +++ b/test/inductor/test_cpp_wrapper.py @@ -72,6 +72,9 @@ test_failures_cpp_wrapper = { "test_conv2d_binary_inplace_fusion_failed_cpu_dynamic_shapes": test_torchinductor.TestFailure( ("cpp_wrapper",), is_skip=True ), + "test_conv2d_binary_inplace_fusion_pass_cpu_dynamic_shapes": test_torchinductor.TestFailure( + ("cpp_wrapper",), is_skip=True + ), } @@ -129,6 +132,16 @@ if RUN_CPU: ["op_convolution_pointwise_binary_.call"], ], ), + BaseTest( + "test_conv2d_binary_inplace_fusion_pass", + "cpu", + test_mkldnn_pattern_matcher.TestPaternMatcher(), + condition=torch._C.has_mkldnn, + func_inputs=[ + ["op_convolution_pointwise_binary_.call"], + ["op_convolution_pointwise_binary.call"], + ], + ), BaseTest( "test_conv2d_unary", "cpu", diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index 82bf2927eca..b334c1d9e12 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -345,7 +345,9 @@ class TestPaternMatcher(TestCase): v = torch.randn(1, 3, 28, 28) self._test_common(mod, (v,), 0, 0) - def test_conv2d_binary_inplace_fusion_pass(self): + def test_conv2d_binary_inplace_fusion_pass_cpu( + self, include_ops=None, exclude_ops=None + ): class Model(torch.nn.Module): def __init__(self): super().__init__() @@ -362,8 +364,12 @@ class TestPaternMatcher(TestCase): torch.randn(1, 32, 28, 28).to(memory_format=torch.channels_last), ] mod = Model().to(memory_format=torch.channels_last).eval() - include_ops = ["mkldnn._convolution_pointwise_.binary"] - exclude_ops = ["mkldnn._convolution_pointwise.binary"] + + if include_ops is None: + include_ops = ["mkldnn._convolution_pointwise_.binary"] + if exclude_ops is None: + exclude_ops = ["mkldnn._convolution_pointwise.binary"] + self._test_code_common(mod, inputs, include_ops, exclude_ops) def test_conv2d_binary_inplace_fusion_failed_cpu( diff --git a/test/test_mkldnn_fusion.py b/test/test_mkldnn_fusion.py index fad3e77dcca..ac5521815b0 100644 --- a/test/test_mkldnn_fusion.py +++ b/test/test_mkldnn_fusion.py @@ -295,7 +295,7 @@ class TestMkldnnFusion(JitTestCase): # for binary add, we support inplace version. if attr == "add": fused_inplace = torch.ops.mkldnn._convolution_pointwise_( - x, other, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation, + other, x, mod.conv.weight, mod.conv.bias, mod.conv.padding, mod.conv.stride, mod.conv.dilation, mod.conv.groups, attr, None, unary_attr, [], None ) self.assertEqual(ref, other) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 19800656180..49e8ce67009 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3689,21 +3689,50 @@ class ConvolutionBinary(ExternKernelAlloc): class ConvolutionBinaryInplace(ExternKernelAlloc): - kernel = "torch.ops.mkldnn._convolution_pointwise_.binary" - def __init__( self, kernel_layout, inputs, constant_args=(), - kernel="torch.ops.mkldnn._convolution_pointwise_.binary", ): - super().__init__(kernel_layout, inputs, constant_args) - self.kernel = kernel + # Due to constrain of op.call, other (Tensor&) should be at input[0] + reordered_inputs = [inputs[1], inputs[0]] + inputs[2:] + + super().__init__( + kernel_layout, + reordered_inputs, + constant_args, + None, + kernel="torch.ops.mkldnn._convolution_pointwise_.binary", + cpp_kernel="mkldnn::_convolution_pointwise_", + ) + self.cpp_kernel_overlad_name = "binary" + self.cpp_kernel_key = "convolution_pointwise_binary_" + # TODO: op.call: input[0] should be at::Tensor& + self.cpp_op_schema = """ + at::Tensor&( + at::Tensor& other_t, + const at::Tensor& input_t, + const at::Tensor& weight_t, + const c10::optional& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups, + c10::string_view binary_attr, + c10::optional alpha, + c10::optional unary_attr, + torch::List> unary_scalars, + c10::optional unary_algorithm)""" def codegen(self, wrapper): - wrapper.writeline( - f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" + wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( + self.get_name(), + self.kernel, + self.codegen_args(), + self.cpp_op_schema, + self.cpp_kernel_key, + self.cpp_kernel_overlad_name, ) def get_mutation_names(self): @@ -3727,7 +3756,6 @@ class ConvolutionBinaryInplace(ExternKernelAlloc): unary_scalars: Optional[List[Any]], unary_algorithm: Optional[str], ): - kernel = "torch.ops.mkldnn._convolution_pointwise_.binary" ( inputs, constant_args, @@ -3738,18 +3766,20 @@ class ConvolutionBinaryInplace(ExternKernelAlloc): ) other = cls.require_stride_order(other, req_stride_order) inputs.insert(1, other) + optional_scalar = OptionalScalar() + optional_string = OptionalString() + optional_list = OptionalList() constant_args = constant_args + [ binary_attr, - binary_alpha, - unary_attr, - unary_scalars, - unary_algorithm, + may_convert_to_optional(optional_scalar, binary_alpha), + may_convert_to_optional(optional_string, unary_attr), + may_convert_to_optional(optional_list, unary_scalars), + may_convert_to_optional(optional_string, unary_algorithm), ] return ConvolutionBinaryInplace( kernel_layout=MutationLayout(inputs[1]), inputs=inputs, constant_args=constant_args, - kernel=kernel, )