diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp b/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp
index 65c4d4c645c..0289a5166cf 100644
--- a/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp
+++ b/aten/src/ATen/native/mkldnn/xpu/detail/QMatmul.cpp
@@ -9,6 +9,83 @@
 
 namespace at::native::onednn {
 
+at::Tensor broadcast_bias2D(
+    at::Tensor& dst,
+    at::Tensor& bias,
+    int64_t m,
+    int64_t n) {
+  switch (bias.dim()) {
+    case 1:
+      TORCH_CHECK(
+          bias.size(0) == n || bias.size(0) == 1,
+          "matmul supports [n] or [1] when bias dim is 1, but b.size() is:",
+          bias.size(0));
+      break;
+    case 2:
+      if ((bias.size(0) == m && bias.size(1) == n) ||
+          (bias.size(0) == 1 && bias.size(1) == n) ||
+          (bias.size(0) == m && bias.size(1) == 1))
+        return bias; // No need to broadcast
+      TORCH_CHECK(
+          bias.size(0) == 1 && bias.size(1) == 1,
+          "matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...");
+      break;
+    case 0:
+      TORCH_CHECK(
+          bias.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
+      break;
+    default:
+      TORCH_CHECK(0, "unsupported bias dim in matmul ...");
+  }
+  bias = bias.expand({1, n}).contiguous();
+  return bias;
+}
+
+at::Tensor broadcast_bias3D(
+    at::Tensor& dst,
+    at::Tensor bias,
+    int64_t mb,
+    int64_t m,
+    int64_t n) {
+  switch (bias.dim()) {
+    case 1:
+      TORCH_CHECK(
+          bias.size(0) == n || bias.size(0) == 1,
+          "matmul supports [n] or [1] when bias dim is 1, but b.size() is:",
+          bias.size(0));
+      break;
+    case 3:
+      TORCH_CHECK(
+          are_expandable({mb, m, n}, bias.sizes()),
+          "matmul bias must be expandable to:",
+          dst.sizes(),
+          " but got:",
+          bias.sizes());
+      break;
+    case 0:
+      TORCH_CHECK(
+          bias.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
+      break;
+    default:
+      TORCH_CHECK(0, "unsupported bias dim in matmul ...");
+  }
+  bias = bias.expand({mb, m, n}).contiguous();
+  return bias;
+}
+
+at::Tensor broadcast_bias(
+    at::Tensor& dst,
+    at::Tensor bias,
+    int64_t mb,
+    int64_t m,
+    int64_t n) {
+  if (dst.dim() == 2) {
+    return broadcast_bias2D(dst, bias, m, n);
+  } else {
+    return broadcast_bias3D(dst, bias, mb, m, n);
+  }
+}
+
 void quantized_matmul(
     at::Tensor mat1, // act
     double input_scale,
@@ -16,7 +93,7 @@ void quantized_matmul(
     at::Tensor mat2, // weight
     at::Tensor& weight_scales,
     at::Tensor& weight_zero_points,
-    at::Tensor& b_raw,
+    at::Tensor& bias,
     at::Tensor result, // output
     double output_scale,
     int64_t output_zero_point,
@@ -28,13 +105,13 @@ void quantized_matmul(
     double binary_alpha,
     const c10::string_view& unary_post_op,
     torch::List<std::optional<at::Scalar>>& unary_post_op_args,
-    c10::string_view unary_post_op_algorithm) {
+    c10::string_view unary_post_op_algorithm,
+    bool m2_trans) {
   // [Note] Quantized Matrix Multiplication at XPU
   // The following code integrates oneDNN quantized gemm. The quantization
   // config we support:
   // activation: s8&u8; per tensor calibrated; symmetric&asymmetric
   // weight: s8; per_tensor/per_channel calibrated; symmetric
-  bool m2_trans = true;
   auto attr = Attr(1.0 / output_scale, output_zero_point);
   construct_attr_by_post_op(
       binary_post_op,
@@ -75,51 +153,11 @@ void quantized_matmul(
   }
 
   bool with_bias = false;
-  at::Tensor b = b_raw;
+  at::Tensor b = bias;
   if (b.defined()) {
     with_bias = true;
-    if (b.dim() == 1) {
-      TORCH_CHECK(
-          b.size(0) == n || b.size(0) == 1,
-          "matmul supports [n] or [1] when bias dim is 1, but b.size() is:",
-          b.size(0));
-      if (b.size(0) == 0) {
-        with_bias = false;
-      } else if (m1.dim() == 3) {
-        b = b.expand({mb, m, n}).contiguous();
-      } else if (m1.dim() == 2) {
-        b = b.expand({1, n}).contiguous();
-      }
-    } else if (b.dim() == 2) {
-      TORCH_CHECK(
-          (b.size(0) == m && b.size(1) == n) ||
-              (b.size(0) == 1 && b.size(1) == n) ||
-              (b.size(0) == m && b.size(1) == 1) ||
-              (b.size(0) == 1 && b.size(1) == 1),
-          "matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...");
-      if (b.size(0) == 1 && b.size(1) == 1)
-        b = b.expand({1, n}).contiguous();
-    } else if (b.dim() == 3) {
-      TORCH_CHECK(
-          are_expandable({mb, m, n}, b.sizes()),
-          "matmul bias must be expandable to:",
-          dst.sizes(),
-          " but got:",
-          b.sizes());
-      b = b.expand({mb, m, n}).contiguous();
-    } else if (b.dim() == 0) {
-      TORCH_CHECK(
-          b.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
-      if (m1.dim() == 3) {
-        b = b.expand({mb, m, n}).contiguous();
-      } else {
-        b = b.expand({1, n}).contiguous();
-      }
-    } else {
-      TORCH_CHECK(0, "unsupported bias dim in matmul ...");
-    }
+    b = broadcast_bias(dst, b, mb, m, n);
   }
-
   // bias is fused in post-op for quantized path
   b = b.contiguous(); // avoid reorder 2 times
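For reference, the 2-D contract enforced by `broadcast_bias2D` can be exercised with a few lines of libtorch. The sketch below is illustrative only and not part of the patch; it assumes a plain libtorch build and mirrors the helper's pass-through/expand behavior.

```cpp
// Minimal sketch (not part of the patch) of the 2-D bias rule that
// broadcast_bias2D enforces, assuming a libtorch build: [m, n], [1, n]
// and [m, 1] biases pass through untouched, while [1, 1], [1], [n] and
// 0-dim biases are expanded to [1, n] and made contiguous before being
// fused into the oneDNN primitive.
#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t n = 8;
  torch::Tensor bias = torch::ones({1, 1});
  // The [1, 1] case falls through the switch to the final expand.
  torch::Tensor b = bias.expand({1, n}).contiguous();
  std::cout << b.sizes() << '\n'; // prints [1, 8]
  return 0;
}
```

Factoring the switch into helpers lets the 2-D and 3-D destination paths share the 1-D and 0-D cases instead of duplicating them inline in `quantized_matmul`.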
diff --git a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h
index 4987a40ffa6..d7012b0f2a9 100644
--- a/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h
+++ b/aten/src/ATen/native/mkldnn/xpu/detail/oneDNN.h
@@ -152,6 +152,7 @@ void quantized_matmul(
     double binary_alpha,
     const c10::string_view& unary_post_op,
     torch::List<std::optional<at::Scalar>>& unary_post_op_args,
-    c10::string_view unary_post_op_algorithm);
+    c10::string_view unary_post_op_algorithm,
+    bool m2_trans);
 
 } // namespace at::native::onednn
diff --git a/aten/src/ATen/native/mkldnn/xpu/qlinear.cpp b/aten/src/ATen/native/mkldnn/xpu/qlinear.cpp
index 18e11eb87a6..0b6dc8bbbfc 100644
--- a/aten/src/ATen/native/mkldnn/xpu/qlinear.cpp
+++ b/aten/src/ATen/native/mkldnn/xpu/qlinear.cpp
@@ -25,6 +25,11 @@ Tensor q_linear_pointwise(
 
   const int64_t dim = act.dim();
   TORCH_CHECK(dim == 2, "qliner XPU: input dim should be 2, but got", dim);
+  TORCH_CHECK(
+      act.device() == weight.device() &&
+          act.device() == weight_scales.device() &&
+          act.device() == weight_zero_points.device(),
+      "qlinear xpu: input tensors (act, weight, weight scales, weight zero-points) should be on the same device");
   int64_t K = act.size(dim - 1);
   int64_t M = act.numel() / K;
   // [M, K] x [K, N]
@@ -55,7 +60,8 @@ Tensor q_linear_pointwise(
       /*binary alpha*/ 1.0,
       post_op_name,
       post_op_args,
-      post_op_algorithm);
+      post_op_algorithm,
+      /*m2_trans*/ true);
 
   return qout;
 }
@@ -78,6 +84,11 @@ Tensor q_linear_pointwise_tensor(
 
   const int64_t dim = act.dim();
   TORCH_CHECK(dim == 2, "qliner XPU: input dim should be 2, but got", dim);
+  TORCH_CHECK(
+      act.device() == weight.device() &&
+          act.device() == weight_scales.device() &&
+          act.device() == weight_zero_points.device(),
+      "qlinear xpu: input tensors (act, weight, weight scales, weight zero-points) should be on the same device");
   int64_t K = act.size(dim - 1);
   int64_t M = act.numel() / K;
   // [M, K] x [K, N]
@@ -108,7 +119,8 @@ Tensor q_linear_pointwise_tensor(
       /*binary alpha*/ 1.0,
       post_op_name,
       post_op_args,
-      post_op_algorithm);
+      post_op_algorithm,
+      /*m2_trans*/ true);
 
   return qout;
 }
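The guard added to both qlinear entry points fails fast on mixed-device inputs instead of letting oneDNN hit an opaque error later. A standalone sketch of the same pattern follows (assuming a libtorch build; `check_same_device` is a hypothetical name used only for this illustration, not part of the patch).

```cpp
// Standalone illustration of the device-consistency guard added in
// qlinear.cpp. check_same_device is a hypothetical helper name.
#include <torch/torch.h>

void check_same_device(
    const torch::Tensor& act,
    const torch::Tensor& weight,
    const torch::Tensor& weight_scales,
    const torch::Tensor& weight_zero_points) {
  // TORCH_CHECK throws a c10::Error with the given message on failure,
  // which is what the XPU op surfaces to the caller.
  TORCH_CHECK(
      act.device() == weight.device() &&
          act.device() == weight_scales.device() &&
          act.device() == weight_zero_points.device(),
      "qlinear xpu: input tensors (act, weight, weight scales, "
      "weight zero-points) should be on the same device");
}

int main() {
  auto act = torch::randn({4, 16});
  auto weight = torch::randn({16, 8});
  auto scales = torch::ones({8});
  auto zps = torch::zeros({8});
  check_same_device(act, weight, scales, zps); // all CPU: passes
  return 0;
}
```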
diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py
index 86ab1354089..8bce4254b69 100644
--- a/test/inductor/test_mkldnn_pattern_matcher.py
+++ b/test/inductor/test_mkldnn_pattern_matcher.py
@@ -2060,6 +2060,12 @@ class TestPatternMatcher(TestPatternMatcherBase):
                 return self.linear2(self.linear(x))
 
         mod = M(bias, do_permute=do_permute).eval().to(device=device)
+        assert isinstance(inputs, tuple)
+
+        def __convert_tensor_to_device(input, device):
+            return input.to(device=device) if isinstance(input, torch.Tensor) else input
+
+        inputs = tuple(__convert_tensor_to_device(input, device) for input in inputs)
 
         def _default_matcher_check_fn():
             self.assertEqual(
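Returning to the qlinear entry points patched above: both compute `K` from the innermost activation dimension and fold every leading dimension into `M` before the [M, K] x [K, N] GEMM. A minimal sketch of that flattening (assuming a libtorch build; not part of the patch):

```cpp
// Minimal sketch of the [M, K] x [K, N] flattening used by the qlinear
// entry points, assuming a libtorch build. Not part of the patch.
#include <torch/torch.h>
#include <iostream>

int main() {
  // qlinear XPU currently requires a 2-D activation, enforced by
  // TORCH_CHECK(dim == 2, ...) in the patched functions.
  torch::Tensor act = torch::randn({4, 16});
  const int64_t dim = act.dim();
  const int64_t K = act.size(dim - 1); // innermost dim is the reduction dim
  const int64_t M = act.numel() / K;   // all leading dims collapse into M
  std::cout << "M=" << M << " K=" << K << '\n'; // prints M=4 K=16
  return 0;
}
```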