[ghstack-poisoned]
This commit is contained in:
zhiweiya 2025-02-09 23:37:19 -08:00
commit f22cfc1281
4 changed files with 103 additions and 48 deletions

View file

@ -9,6 +9,82 @@
namespace at::native::onednn {
at::Tensor broadcast_bias2D(
at::Tensor& dst,
at::Tensor& bias,
int64_t m,
int64_t n) {
switch (bias.dim()) {
case 1:
TORCH_CHECK(
bias.size(0) == n || bias.size(0) == 1,
"matmul supports [n] or [1] when bias dim is 1, but b.size() is:",
bias.size(0));
break;
case 2:
if ((bias.size(0) == m && bias.size(1) == n) ||
(bias.size(0) == m && bias.size(1) == 1) ||
(bias.size(0) == m && bias.size(1) == 1))
return bias; // No need to broadcast
TORCH_CHECK(
bias.size(0) == 1 && bias.size(1) == 1,
"matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...")
break;
case 0:
TORCH_CHECK(
bias.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
default:
TORCH_CHECK(0, "unsupported bias dim in matmul ...");
}
bias = bias.expand({1, n}).contiguous();
return bias;
}
at::Tensor broadcast_bias3D(
at::Tensor& dst,
at::Tensor bias,
int64_t mb,
int64_t m,
int64_t n) {
switch (bias.dim()) {
case 1:
TORCH_CHECK(
bias.size(0) == n || bias.size(0) == 1,
"matmul supports [n] or [1] when bias dim is 1, but b.size() is:",
bias.size(0));
break;
case 3:
TORCH_CHECK(
are_expandable({mb, m, n}, bias.sizes()),
"matmul bias must be expandable to:",
dst.sizes(),
" but got:",
bias.sizes());
break;
case 0:
TORCH_CHECK(
bias.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
break;
default:
TORCH_CHECK(0, "unsupported bias dim in matmul ...");
}
bias = bias.expand({mb, m, n}).contiguous();
return bias;
}
at::Tensor broadcast_bias(
at::Tensor& dst,
at::Tensor bias,
int64_t mb,
int64_t m,
int64_t n) {
if (dst.dim() == 2) {
return broadcast_bias2D(dst, bias, m, n);
} else {
return broadcast_bias3D(dst, bias, mb, m, n);
}
}
void quantized_matmul(
at::Tensor mat1, // act
double input_scale,
@ -16,7 +92,7 @@ void quantized_matmul(
at::Tensor mat2, // weight
at::Tensor& weight_scales,
at::Tensor& weight_zero_points,
at::Tensor& b_raw,
at::Tensor& bias,
at::Tensor result, // output
double output_scale,
int64_t output_zero_point,
@ -28,13 +104,13 @@ void quantized_matmul(
double binary_alpha,
const c10::string_view& unary_post_op,
torch::List<std::optional<at::Scalar>>& unary_post_op_args,
c10::string_view unary_post_op_algorithm) {
c10::string_view unary_post_op_algorithm,
bool m2_trans) {
// [Note] Quantized Matrix Multiplication at XPU
// The following code integrates oneDNN quantized gemm. The quantization
// config we support:
// activation: s8&u8; per tensor calibrated; symmetric&asymmetric
// weight: s8; per_tensor/per_channel calibrated; symmetric
bool m2_trans = true;
auto attr = Attr(1.0 / output_scale, output_zero_point);
construct_attr_by_post_op(
binary_post_op,
@ -75,51 +151,11 @@ void quantized_matmul(
}
bool with_bias = false;
at::Tensor b = b_raw;
at::Tensor b = bias;
if (b.defined()) {
with_bias = true;
if (b.dim() == 1) {
TORCH_CHECK(
b.size(0) == n || b.size(0) == 1,
"matmul supports [n] or [1] when bias dim is 1, but b.size() is:",
b.size(0));
if (b.size(0) == 0) {
with_bias = false;
} else if (m1.dim() == 3) {
b = b.expand({mb, m, n}).contiguous();
} else if (m1.dim() == 2) {
b = b.expand({1, n}).contiguous();
}
} else if (b.dim() == 2) {
TORCH_CHECK(
(b.size(0) == m && b.size(1) == n) ||
(b.size(0) == 1 && b.size(1) == n) ||
(b.size(0) == m && b.size(1) == 1) ||
(b.size(0) == 1 && b.size(1) == 1),
"matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...");
if (b.size(0) == 1 && b.size(1) == 1)
b = b.expand({1, n}).contiguous();
} else if (b.dim() == 3) {
TORCH_CHECK(
are_expandable({mb, m, n}, b.sizes()),
"matmul bias must be expandable to:",
dst.sizes(),
" but got:",
b.sizes());
b = b.expand({mb, m, n}).contiguous();
} else if (b.dim() == 0) {
TORCH_CHECK(
b.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
if (m1.dim() == 3) {
b = b.expand({mb, m, n}).contiguous();
} else {
b = b.expand({1, n}).contiguous();
}
} else {
TORCH_CHECK(0, "unsupported bias dim in matmul ...");
}
b = broadcast_bias(dst, b, mb, m, n);
}
// bias is fused in post-op for quantized path
b = b.contiguous(); // avoid reorder 2 times

View file

@ -152,6 +152,7 @@ void quantized_matmul(
double binary_alpha,
const c10::string_view& unary_post_op,
torch::List<std::optional<at::Scalar>>& unary_post_op_args,
c10::string_view unary_post_op_algorithm);
c10::string_view unary_post_op_algorithm,
bool m2_trnas);
} // namespace at::native::onednn

View file

@ -25,6 +25,11 @@ Tensor q_linear_pointwise(
const int64_t dim = act.dim();
TORCH_CHECK(dim == 2, "qliner XPU: input dim should be 2, but got", dim);
TORCH_CHECK(
act.device() == weight.device() &&
act.device() == weight_scales.device() &&
act.device() == weight_zero_points.device(),
"qlinear xpu: input tensors(act, weight, weight scale, weight zero-points) should be on the same device");
int64_t K = act.size(dim - 1);
int64_t M = act.numel() / K;
// [M, K] x [K, N]
@ -55,7 +60,8 @@ Tensor q_linear_pointwise(
/*binary alpha*/ 1.0,
post_op_name,
post_op_args,
post_op_algorithm);
post_op_algorithm,
/*m2_trans*/ true);
return qout;
}
@ -78,6 +84,11 @@ Tensor q_linear_pointwise_tensor(
const int64_t dim = act.dim();
TORCH_CHECK(dim == 2, "qliner XPU: input dim should be 2, but got", dim);
TORCH_CHECK(
act.device() == weight.device() &&
act.device() == weight_scales.device() &&
act.device() == weight_zero_points.device(),
"qlinear xpu: input tensors(act, weight, weight scale, weight zero-points) should be on the same device");
int64_t K = act.size(dim - 1);
int64_t M = act.numel() / K;
// [M, K] x [K, N]
@ -108,7 +119,8 @@ Tensor q_linear_pointwise_tensor(
/*binary alpha*/ 1.0,
post_op_name,
post_op_args,
post_op_algorithm);
post_op_algorithm,
/*m2_trans*/ true);
return qout;
}

View file

@ -2060,6 +2060,12 @@ class TestPatternMatcher(TestPatternMatcherBase):
return self.linear2(self.linear(x))
mod = M(bias, do_permute=do_permute).eval().to(device=device)
assert isinstance(inputs, tuple)
def __convert_tensor_to_device(input, device):
return input.to(device=device) if isinstance(input, torch.Tensor) else input
inputs = tuple(__convert_tensor_to_device(input, device) for input in inputs)
def _default_matcher_check_fn():
self.assertEqual(