mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Update (base update)
[ghstack-poisoned]
This commit is contained in:
parent
221a8bff42
commit
9c1fe1c398
4 changed files with 103 additions and 48 deletions
|
|
@ -9,6 +9,82 @@
|
|||
|
||||
namespace at::native::onednn {
|
||||
|
||||
at::Tensor broadcast_bias2D(
|
||||
at::Tensor& dst,
|
||||
at::Tensor& bias,
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
switch (bias.dim()) {
|
||||
case 1:
|
||||
TORCH_CHECK(
|
||||
bias.size(0) == n || bias.size(0) == 1,
|
||||
"matmul supports [n] or [1] when bias dim is 1, but b.size() is:",
|
||||
bias.size(0));
|
||||
break;
|
||||
case 2:
|
||||
if ((bias.size(0) == m && bias.size(1) == n) ||
|
||||
(bias.size(0) == m && bias.size(1) == 1) ||
|
||||
(bias.size(0) == m && bias.size(1) == 1))
|
||||
return bias; // No need to broadcast
|
||||
TORCH_CHECK(
|
||||
bias.size(0) == 1 && bias.size(1) == 1,
|
||||
"matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...")
|
||||
break;
|
||||
case 0:
|
||||
TORCH_CHECK(
|
||||
bias.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
|
||||
default:
|
||||
TORCH_CHECK(0, "unsupported bias dim in matmul ...");
|
||||
}
|
||||
bias = bias.expand({1, n}).contiguous();
|
||||
return bias;
|
||||
}
|
||||
|
||||
at::Tensor broadcast_bias3D(
|
||||
at::Tensor& dst,
|
||||
at::Tensor bias,
|
||||
int64_t mb,
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
switch (bias.dim()) {
|
||||
case 1:
|
||||
TORCH_CHECK(
|
||||
bias.size(0) == n || bias.size(0) == 1,
|
||||
"matmul supports [n] or [1] when bias dim is 1, but b.size() is:",
|
||||
bias.size(0));
|
||||
break;
|
||||
case 3:
|
||||
TORCH_CHECK(
|
||||
are_expandable({mb, m, n}, bias.sizes()),
|
||||
"matmul bias must be expandable to:",
|
||||
dst.sizes(),
|
||||
" but got:",
|
||||
bias.sizes());
|
||||
break;
|
||||
case 0:
|
||||
TORCH_CHECK(
|
||||
bias.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(0, "unsupported bias dim in matmul ...");
|
||||
}
|
||||
bias = bias.expand({mb, m, n}).contiguous();
|
||||
return bias;
|
||||
}
|
||||
|
||||
at::Tensor broadcast_bias(
|
||||
at::Tensor& dst,
|
||||
at::Tensor bias,
|
||||
int64_t mb,
|
||||
int64_t m,
|
||||
int64_t n) {
|
||||
if (dst.dim() == 2) {
|
||||
return broadcast_bias2D(dst, bias, m, n);
|
||||
} else {
|
||||
return broadcast_bias3D(dst, bias, mb, m, n);
|
||||
}
|
||||
}
|
||||
|
||||
void quantized_matmul(
|
||||
at::Tensor mat1, // act
|
||||
double input_scale,
|
||||
|
|
@ -16,7 +92,7 @@ void quantized_matmul(
|
|||
at::Tensor mat2, // weight
|
||||
at::Tensor& weight_scales,
|
||||
at::Tensor& weight_zero_points,
|
||||
at::Tensor& b_raw,
|
||||
at::Tensor& bias,
|
||||
at::Tensor result, // output
|
||||
double output_scale,
|
||||
int64_t output_zero_point,
|
||||
|
|
@ -28,13 +104,13 @@ void quantized_matmul(
|
|||
double binary_alpha,
|
||||
const c10::string_view& unary_post_op,
|
||||
torch::List<std::optional<at::Scalar>>& unary_post_op_args,
|
||||
c10::string_view unary_post_op_algorithm) {
|
||||
c10::string_view unary_post_op_algorithm,
|
||||
bool m2_trans) {
|
||||
// [Note] Quantized Matrix Multiplication at XPU
|
||||
// The following code integrates oneDNN quantized gemm. The quantization
|
||||
// config we support:
|
||||
// activation: s8&u8; per tensor calibrated; symmetric&asymmetric
|
||||
// weight: s8; per_tensor/per_channel calibrated; symmetric
|
||||
bool m2_trans = true;
|
||||
auto attr = Attr(1.0 / output_scale, output_zero_point);
|
||||
construct_attr_by_post_op(
|
||||
binary_post_op,
|
||||
|
|
@ -75,51 +151,11 @@ void quantized_matmul(
|
|||
}
|
||||
|
||||
bool with_bias = false;
|
||||
at::Tensor b = b_raw;
|
||||
at::Tensor b = bias;
|
||||
if (b.defined()) {
|
||||
with_bias = true;
|
||||
if (b.dim() == 1) {
|
||||
TORCH_CHECK(
|
||||
b.size(0) == n || b.size(0) == 1,
|
||||
"matmul supports [n] or [1] when bias dim is 1, but b.size() is:",
|
||||
b.size(0));
|
||||
if (b.size(0) == 0) {
|
||||
with_bias = false;
|
||||
} else if (m1.dim() == 3) {
|
||||
b = b.expand({mb, m, n}).contiguous();
|
||||
} else if (m1.dim() == 2) {
|
||||
b = b.expand({1, n}).contiguous();
|
||||
}
|
||||
} else if (b.dim() == 2) {
|
||||
TORCH_CHECK(
|
||||
(b.size(0) == m && b.size(1) == n) ||
|
||||
(b.size(0) == 1 && b.size(1) == n) ||
|
||||
(b.size(0) == m && b.size(1) == 1) ||
|
||||
(b.size(0) == 1 && b.size(1) == 1),
|
||||
"matmul supports [m, n] or [1, n] or [m, 1] or [1, 1] when bias dim is 2 ...");
|
||||
if (b.size(0) == 1 && b.size(1) == 1)
|
||||
b = b.expand({1, n}).contiguous();
|
||||
} else if (b.dim() == 3) {
|
||||
TORCH_CHECK(
|
||||
are_expandable({mb, m, n}, b.sizes()),
|
||||
"matmul bias must be expandable to:",
|
||||
dst.sizes(),
|
||||
" but got:",
|
||||
b.sizes());
|
||||
b = b.expand({mb, m, n}).contiguous();
|
||||
} else if (b.dim() == 0) {
|
||||
TORCH_CHECK(
|
||||
b.numel() == 1, "matmul supports 1 numel when bias dim is [] ...");
|
||||
if (m1.dim() == 3) {
|
||||
b = b.expand({mb, m, n}).contiguous();
|
||||
} else {
|
||||
b = b.expand({1, n}).contiguous();
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(0, "unsupported bias dim in matmul ...");
|
||||
}
|
||||
b = broadcast_bias(dst, b, mb, m, n);
|
||||
}
|
||||
|
||||
// bias is fused in post-op for quantized path
|
||||
b = b.contiguous(); // avoid reorder 2 times
|
||||
|
||||
|
|
|
|||
|
|
@ -152,6 +152,7 @@ void quantized_matmul(
|
|||
double binary_alpha,
|
||||
const c10::string_view& unary_post_op,
|
||||
torch::List<std::optional<at::Scalar>>& unary_post_op_args,
|
||||
c10::string_view unary_post_op_algorithm);
|
||||
c10::string_view unary_post_op_algorithm,
|
||||
bool m2_trnas);
|
||||
|
||||
} // namespace at::native::onednn
|
||||
|
|
|
|||
|
|
@ -25,6 +25,11 @@ Tensor q_linear_pointwise(
|
|||
|
||||
const int64_t dim = act.dim();
|
||||
TORCH_CHECK(dim == 2, "qliner XPU: input dim should be 2, but got", dim);
|
||||
TORCH_CHECK(
|
||||
act.device() == weight.device() &&
|
||||
act.device() == weight_scales.device() &&
|
||||
act.device() == weight_zero_points.device(),
|
||||
"qlinear xpu: input tensors(act, weight, weight scale, weight zero-points) should be on the same device");
|
||||
int64_t K = act.size(dim - 1);
|
||||
int64_t M = act.numel() / K;
|
||||
// [M, K] x [K, N]
|
||||
|
|
@ -55,7 +60,8 @@ Tensor q_linear_pointwise(
|
|||
/*binary alpha*/ 1.0,
|
||||
post_op_name,
|
||||
post_op_args,
|
||||
post_op_algorithm);
|
||||
post_op_algorithm,
|
||||
/*m2_trans*/ true);
|
||||
|
||||
return qout;
|
||||
}
|
||||
|
|
@ -78,6 +84,11 @@ Tensor q_linear_pointwise_tensor(
|
|||
|
||||
const int64_t dim = act.dim();
|
||||
TORCH_CHECK(dim == 2, "qliner XPU: input dim should be 2, but got", dim);
|
||||
TORCH_CHECK(
|
||||
act.device() == weight.device() &&
|
||||
act.device() == weight_scales.device() &&
|
||||
act.device() == weight_zero_points.device(),
|
||||
"qlinear xpu: input tensors(act, weight, weight scale, weight zero-points) should be on the same device");
|
||||
int64_t K = act.size(dim - 1);
|
||||
int64_t M = act.numel() / K;
|
||||
// [M, K] x [K, N]
|
||||
|
|
@ -108,7 +119,8 @@ Tensor q_linear_pointwise_tensor(
|
|||
/*binary alpha*/ 1.0,
|
||||
post_op_name,
|
||||
post_op_args,
|
||||
post_op_algorithm);
|
||||
post_op_algorithm,
|
||||
/*m2_trans*/ true);
|
||||
|
||||
return qout;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2152,6 +2152,12 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
return self.linear2(self.linear(x))
|
||||
|
||||
mod = M(bias, do_permute=do_permute).eval().to(device=device)
|
||||
assert isinstance(inputs, tuple)
|
||||
|
||||
def __convert_tensor_to_device(input, device):
|
||||
return input.to(device=device) if isinstance(input, torch.Tensor) else input
|
||||
|
||||
inputs = tuple(__convert_tensor_to_device(input, device) for input in inputs)
|
||||
|
||||
def _default_matcher_check_fn():
|
||||
self.assertEqual(
|
||||
|
|
|
|||
Loading…
Reference in a new issue