[Intel GPU] qconv_pointwise.binary XPU support

ghstack-source-id: cbaa17fd11
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135189
Yan Zhiwei 2025-02-10 04:25:50 -08:00 committed by guangyey
parent 63ffab0c73
commit b91481ba82
7 changed files with 250 additions and 71 deletions


@@ -177,7 +177,7 @@ class Attr {
float sum_q_scale = 1.f,
int64_t zp = 0) {
ops_params_.push_back(
PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, kind_t::sum));
PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, zp, kind_t::sum));
return *this;
}
@@ -261,10 +261,7 @@ class Attr {
return *this;
}
dnnl::post_ops extract_post_ops(
const at::Tensor& dst,
bool is_quantized = false,
bool int8_output = false) {
dnnl::post_ops extract_post_ops(const at::Tensor& dst) {
// this function is used to extract post ops params from the ops_params_
// and put them into onednn post ops
for (size_t i = 0; i < ops_params_.size(); ++i) {
@@ -303,11 +300,6 @@ class Attr {
}
}
// if output is quantized, then append the eltwise linear to adjust the
// output scale/zero_point
if (is_quantized && int8_output) {
dnnl_post_ops_.append_eltwise(kind_with_linear, q_scale_, q_zero_point_);
}
return dnnl_post_ops_;
}
@@ -410,6 +402,7 @@ static inline void construct_attr_by_post_op(
double binary_alpha,
double input1_scale,
int64_t input1_zero_point,
std::optional<at::Tensor> accum,
const std::string_view& unary_post_op,
const torch::List<std::optional<at::Scalar>>& unary_post_op_args,
const std::string_view& unary_post_op_algorithm,
@@ -418,11 +411,46 @@ static inline void construct_attr_by_post_op(
(binary_post_op == "none" && unary_post_op == "none"); // no post-ops
bool is_unary_post_op_only =
(binary_post_op == "none" && unary_post_op != "none"); // ex., conv + relu
bool is_valid_binary_combination =
(binary_post_op == "add" || binary_post_op == "sum") &&
(unary_post_op == "none" || unary_post_op == "relu");
TORCH_INTERNAL_ASSERT(
is_unary_post_op_only || is_none_post_op,
"Currently, quantization backend for Intel GPU only supports convolution or convolution with unary post operation like ReLU");
construct_attr_for_unary(
unary_post_op, unary_post_op_args, unary_post_op_algorithm, attr);
is_unary_post_op_only || is_none_post_op || is_valid_binary_combination,
"Please provide valid combination of unary post operators and binary post operators");
if (binary_post_op == "none") {
construct_attr_for_unary(
unary_post_op, unary_post_op_args, unary_post_op_algorithm, attr);
} else if (binary_post_op == "sum") {
if (unary_post_op == "none") {
if (input1_zero_point != 0)
attr = attr.append_post_eltwise(
/*scale*/ 1,
/*alpha*/ 1,
-input1_zero_point * input1_scale,
attr.kind_with_linear);
attr = attr.append_post_sum(1, input1_scale, /*input1_zero_point*/ 0);
} else if (unary_post_op == "relu") {
if (input1_zero_point != 0)
attr = attr.append_post_eltwise(
/*scale*/ 1,
/*alpha*/ 1,
-input1_zero_point * input1_scale,
attr.kind_with_linear);
attr = attr.append_post_sum(1, input1_scale, /*input1_zero_point*/ 0);
attr = attr.append_post_eltwise(
/* eltwise_scale */ 1.f,
/* alpha */ 0.f,
/* beta */ 0.f,
attr.kind_with_relu);
}
} else if (binary_post_op == "add") {
TORCH_CHECK(accum.has_value());
attr = attr.append_post_binary(attr.kind_with_binary_add, accum.value());
if (unary_post_op == "relu") {
attr = attr.append_post_eltwise(1.f, 0.f, 0.f, attr.kind_with_relu);
}
}
}
} // namespace at::native::onednn
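For intuition, the "sum" branch above folds the accumulator's zero point into an eltwise linear post-op placed before a plain sum: the dequantized accumulator s * (q - zp) equals s * q plus the constant -s * zp. A minimal NumPy sketch of that arithmetic, with hypothetical values (this is the math, not the kernel code):

import numpy as np

conv_out = np.array([1.5, -0.25, 3.0])    # hypothetical f32 conv accumulator
q_accum = np.array([10.0, 130.0, 255.0])  # hypothetical quantized accum values
s, zp = 0.05, 128                         # hypothetical accum scale / zero point

reference = conv_out + s * (q_accum - zp)  # dequantize-then-add

acc = 1.0 * conv_out + (-zp * s)  # eltwise linear post-op: alpha=1, beta=-zp*s
acc = acc + s * q_accum           # plain sum post-op: scale=s, zero point 0
assert np.allclose(acc, reference)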


@@ -11,14 +11,19 @@
namespace at::native::onednn {
static std::tuple<dnnl::memory::desc, dnnl::memory::desc, dnnl::memory::desc>
static std::tuple<
dnnl::memory::desc,
dnnl::memory::desc,
dnnl::memory::desc,
dnnl::memory::desc>
qconv_get_md(
const at::Tensor& src,
const at::Tensor& wgh,
std::optional<at::Tensor> bias,
const at::Tensor& dst,
int64_t groups) {
// create dnnl::memory desc from the src/wgh/dst tensors
dnnl::memory::desc src_usr_md, wgh_usr_md, dst_usr_md;
dnnl::memory::desc src_usr_md, wgh_usr_md, dst_usr_md, bias_usr_md;
auto ndim = src.ndimension();
bool src_is_cl =
(src.suggest_memory_format() == at::MemoryFormat::ChannelsLast) ||
@@ -44,7 +49,14 @@ qconv_get_md(
auto fmt_wgh = conv_weight_fmt(ndim, groups != 1, wgh_is_cl);
wgh_usr_md = dnnl::memory::desc(wgh_tz, wei_data_t, fmt_wgh);
return {src_usr_md, wgh_usr_md, dst_usr_md};
if (bias.has_value()) {
bias_usr_md = dnnl::memory::desc(
bias.value().sizes().vec(),
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::x);
}
return {src_usr_md, wgh_usr_md, bias_usr_md, dst_usr_md};
}
at::Tensor quantized_convolution(
@@ -76,14 +88,12 @@ at::Tensor quantized_convolution(
Attr(/*q_scale=*/1.0 / inv_output_scale, /*zp=*/output_zero_point);
auto ndim = act.ndimension();
if (bias.has_value()) {
attr = attr.append_bias(bias.value(), ndim - 2);
}
construct_attr_by_post_op(
binary_attr.has_value() ? binary_attr.value() : "none",
binary_alpha.has_value() ? binary_alpha.value().to<double>() : 1.0,
accum_scale,
accum_zero_point,
accum,
unary_attr.has_value() ? unary_attr.value() : "none",
unary_scalars,
unary_algorithm.has_value() ? unary_algorithm.value() : "",
@@ -110,10 +120,7 @@ at::Tensor quantized_convolution(
dnnl::memory::dims _dilation = compatible_dilation(dilation);
dnnl::post_ops po;
// extract post ops
po = attr.extract_post_ops(
output,
/*is_quantized*/ true,
output.scalar_type() == at::kByte || output.scalar_type() == at::kChar);
po = attr.extract_post_ops(output);
int mask_ac = 0, mask_weight;
// [Note: Per-channel quantization mask setting]
// Per-channel quantization is on weight output channel mostly, mask_weight=
@@ -127,10 +134,11 @@ at::Tensor quantized_convolution(
dnnl::primitive_attr pattr;
bool src_need_zp = (act_scale != 0);
bool dst_need_zp = (output_zero_point != 0);
// create usr_md for tensors, and md for conv primitive
auto [src_md, weight_md, output_md] =
qconv_get_md(act, weight, output, groups);
auto [src_md, weight_md, bias_md, output_md] =
qconv_get_md(act, weight, bias, output, groups);
// get tensor md
auto ic = act.size(1);
@@ -139,11 +147,14 @@ at::Tensor quantized_convolution(
compatible_weight_dims(ndim, groups, oc, ic, weight.sizes());
pattr.set_scales_mask(DNNL_ARG_SRC, mask_ac);
pattr.set_scales_mask(DNNL_ARG_DST, mask_ac);
pattr.set_scales_mask(DNNL_ARG_WEIGHTS, mask_weight);
pattr.set_post_ops(po);
if (src_need_zp)
pattr.set_zero_points_mask(DNNL_ARG_SRC, mask_ac);
if (dst_need_zp)
pattr.set_zero_points_mask(DNNL_ARG_DST, mask_ac);
pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
// create primitive
@@ -153,7 +164,7 @@ at::Tensor quantized_convolution(
dnnl::algorithm::convolution_direct,
src_md,
weight_md,
dnnl::memory::desc(),
bias.has_value() ? bias_md : dnnl::memory::desc(),
output_md,
_stride,
_dilation,
@@ -164,11 +175,14 @@ at::Tensor quantized_convolution(
dnnl::convolution_forward conv_forward =
dnnl::convolution_forward(conv_fwd_pd);
dnnl::memory src_m, weight_m, output_m;
dnnl::memory src_m, weight_m, output_m, bias_m;
src_m = make_onednn_memory(src_md, engine, act.data_ptr());
output_m = make_onednn_memory(output_md, engine, output.data_ptr());
weight_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
if (bias.has_value()) {
bias_m = make_onednn_memory(bias_md, engine, bias.value().data_ptr());
}
std::unordered_map<int, dnnl::memory> args;
if (attr.with_binary())
@@ -176,6 +190,9 @@ at::Tensor quantized_convolution(
args.insert({DNNL_ARG_SRC, src_m});
args.insert({DNNL_ARG_WEIGHTS, weight_m});
args.insert({DNNL_ARG_DST, output_m});
if (bias.has_value()) {
args.insert({DNNL_ARG_BIAS, bias_m});
}
dnnl::memory src_sc_m, src_zp_m;
Tensor src_sc_tensor, src_zp_tensor;
@@ -188,7 +205,17 @@ at::Tensor quantized_convolution(
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_m});
}
// dst scale does not need to be set, since it is fused into the post-ops via linear
dnnl::memory dst_sc_m, dst_zp_m;
Tensor dst_sc_tensor, dst_zp_tensor;
dst_sc_m = dnnl_memory_from_host_scalar(
static_cast<float>(inv_output_scale), dst_sc_tensor, engine);
dst_zp_m = dnnl_memory_from_host_scalar(
static_cast<int32_t>(output_zero_point), dst_zp_tensor, engine);
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_m});
if (dst_need_zp) {
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_m});
}
size_t scratchpad_size = conv_fwd_pd.scratchpad_desc().get_size();
Tensor scratchpad_tensor = at::empty(
{static_cast<int64_t>(scratchpad_size)},


@@ -117,6 +117,7 @@ void quantized_matmul(
binary_alpha,
input_scale,
input_zero_point,
other,
unary_post_op,
unary_post_op_args,
unary_post_op_algorithm,
@@ -209,11 +210,9 @@ void quantized_matmul(
std::unordered_map<int, dnnl::memory> args;
dnnl::post_ops po;
po = attr.extract_post_ops(
dst,
true,
dst.scalar_type() == at::kByte || dst.scalar_type() == at::kChar);
po = attr.extract_post_ops(dst);
bool m1_need_zp = (input_zero_point != 0);
bool dst_need_zp = (output_zero_point != 0);
bool wgh_is_per_channel = weight_scales.numel() > 1;
dnnl::matmul matmul_p;
@@ -241,6 +240,10 @@ void quantized_matmul(
if (m1_need_zp) {
pattr.set_zero_points_mask(DNNL_ARG_SRC, mask_ac);
}
pattr.set_scales_mask(DNNL_ARG_DST, mask_ac);
if (dst_need_zp) {
pattr.set_zero_points_mask(DNNL_ARG_DST, mask_ac);
}
if (with_bias) {
b_md = dnnl::memory::desc(bias_dims, bias_dt, bias_strides);
@@ -308,6 +311,17 @@ void quantized_matmul(
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, m1_zp_m});
}
dnnl::memory dst_sc_m, dst_zp_m;
Tensor dst_sc_tensor, dst_zp_tensor;
dst_sc_m = dnnl_memory_from_host_scalar(
static_cast<float>(output_scale), dst_sc_tensor, engine);
dst_zp_m = dnnl_memory_from_host_scalar(
static_cast<int32_t>(output_zero_point), dst_zp_tensor, engine);
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_m});
if (dst_need_zp) {
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_m});
}
auto qmatmul_event = dnnl::sycl_interop::execute(matmul_p, stream, args);
if (!dst.is_same(result))


@@ -54,7 +54,7 @@ class QConvoneDNNXPU final {
TORCH_CHECK(
attr == "none" || attr == "relu" || attr == "hardtanh" ||
attr == "hardswish" || attr == "swish",
"We support quantized convolution without any post-ops or combinations for Quantized Conv + ReLU, Hardtanh, and Hardswish are supported. However, encountered unsupported post operation:",
"We support quantized convolution without any post-ops or combinations for Quantized Conv + ReLU, Hardtanh, GELU, Swish, and Hardswish are supported. However, encountered unsupported post operation:",
attr,
".");
}
@@ -104,6 +104,96 @@ class QConvoneDNNXPU final {
/*unary_scalars*/ scalars,
/*unary_algorithm*/ algorithm);
}
static at::Tensor run_pointwise_binary(
at::Tensor act,
double act_scale,
int64_t act_zero_point,
at::Tensor weight,
at::Tensor weight_scales,
at::Tensor weight_zero_points,
at::Tensor accum,
std::optional<at::Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
int64_t groups,
double output_scale,
int64_t output_zero_point,
std::optional<c10::ScalarType> output_dtype,
double accum_scale,
int64_t accum_zero_point,
std::string_view binary_attr,
std::optional<at::Scalar> alpha,
std::optional<std::string_view> unary_attr,
torch::List<std::optional<at::Scalar>> unary_scalars,
std::optional<std::string_view> unary_algorithm) {
TORCH_CHECK(
act.dim() == 4 && binary_attr == "sum" &&
(!unary_attr.has_value() ||
(unary_attr.has_value() &&
(unary_attr.value() == "none" || unary_attr.value() == "relu"))),
"post_op sum or post_op sum_relu is supported for quantized pointwise conv2d. Got binary_post_op: ",
binary_attr,
" unary_post_op: ",
unary_attr.has_value() ? unary_attr.value() : "none",
".")
bool is_channels_last_suggested = use_channels_last_for_conv(act, weight);
auto mfmt = is_channels_last_suggested
? get_cl_tag_by_ndim(act.ndimension())
: at::MemoryFormat::Contiguous;
Tensor input_ = act.contiguous(mfmt);
Tensor weight_ = weight.contiguous(mfmt);
auto dst_tz = conv_dst_size(
input_.ndimension(),
input_.sizes(),
weight_.sizes(),
padding.vec(),
padding.vec(),
stride.vec(),
dilation.vec());
bool has_accum_postop_sum = binary_attr == "sum";
Tensor output = has_accum_postop_sum
? accum
: at::empty(
dst_tz,
device(c10::kXPU).dtype(output_dtype).memory_format(mfmt));
output = quantized_convolution(
act,
act_scale,
act_zero_point,
weight,
weight_scales,
weight_zero_points,
bias,
stride,
padding,
dilation,
/*transposed*/ false,
groups,
output,
output_scale,
output_zero_point,
/*accum*/ accum,
/*accum_scale*/ accum_scale,
/*accum_zero_point*/ accum_zero_point,
/*output_dtype*/ output_dtype,
/*binary_attr*/ binary_attr,
/*binary_alpha*/ alpha,
/*unary_attr*/ unary_attr,
/*unary_scalars*/ unary_scalars,
/*unary_algorithm*/ unary_algorithm);
if (!has_accum_postop_sum) {
return output;
} else {
return accum;
}
}
};
TORCH_LIBRARY_IMPL(onednn, XPU, m) {
@@ -119,6 +209,9 @@ TORCH_LIBRARY_IMPL(onednn, XPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("onednn::qconv3d_pointwise"),
QConvoneDNNXPU::run_pointwise);
m.impl(
TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary"),
QConvoneDNNXPU::run_pointwise_binary);
}
} // namespace at::native::xpu
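With the registration above, the binary variant becomes callable on XPU through the dispatcher. A hypothetical call sketch, with the argument order read from run_pointwise_binary and all shapes and quantization parameters invented for illustration (requires a PyTorch build with XPU and oneDNN support):

import torch

act = torch.randint(0, 256, (1, 3, 8, 8), dtype=torch.uint8, device="xpu")
weight = torch.randint(-128, 128, (6, 3, 3, 3), dtype=torch.int8, device="xpu")
weight_scales = torch.full((6,), 0.02, device="xpu")
weight_zero_points = torch.zeros(6, dtype=torch.int64, device="xpu")
accum = torch.randint(0, 256, (1, 6, 6, 6), dtype=torch.uint8, device="xpu")

out = torch.ops.onednn.qconv2d_pointwise.binary(
    act, 0.1, 128,                        # act, act_scale, act_zero_point
    weight, weight_scales, weight_zero_points,
    accum, None,                          # accum (fused via "sum"), optional bias
    [1, 1], [0, 0], [1, 1], 1,            # stride, padding, dilation, groups
    0.05, 128, torch.uint8,               # output scale / zero point / dtype
    0.05, 128,                            # accum scale / zero point
    "sum", None, "relu", [], None,        # binary, alpha, unary, scalars, algorithm
)

Note that with binary_attr "sum" the op accumulates in place and returns accum, as the has_accum_postop_sum branch above shows.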


@@ -1246,7 +1246,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
qconv2d_unary_matcher_nodes=11,
)
def _qconv2d_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False):
def _qconv2d_add_test_helper(
self, device="cpu", use_relu=False, int8_mixed_bf16=False
):
r"""
This testcase will quantize a Conv2d->Add pattern as:
X
@@ -1292,9 +1294,11 @@ class TestPatternMatcher(TestPatternMatcherBase):
return res
for add_fn in quantization_add_fn_list + quantization_inplace_add_fn_list:
mod = M(add_fn, use_relu).eval()
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
1
mod = M(add_fn, use_relu).eval().to(device=device)
v = (
torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False)
.add(1)
.to(device=device)
)
def matcher_check_fn():
@@ -1320,7 +1324,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
)
def _qconv2d_add_cpu_test_helper2(self, use_relu=False, int8_mixed_bf16=False):
def _qconv2d_add_test_helper2(
self, device="cpu", use_relu=False, int8_mixed_bf16=False
):
r"""
This testcase will quantize two Conv2d->Add patterns as:
@@ -1381,10 +1387,16 @@ class TestPatternMatcher(TestPatternMatcherBase):
for add_fn, swap_inputs in itertools.product(
quantization_add_fn_list + quantization_inplace_add_fn_list, [False, True]
):
mod = M(add_fn, use_relu, swap_inputs).eval()
x = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False)
x2 = torch.randn((1, 6, 6, 6), dtype=torch.float32, requires_grad=False)
x3 = torch.randn((1, 6, 4, 4), dtype=torch.float32, requires_grad=False)
mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device)
x = torch.randn(
(1, 3, 8, 8), dtype=torch.float32, requires_grad=False, device=device
)
x2 = torch.randn(
(1, 6, 6, 6), dtype=torch.float32, requires_grad=False, device=device
)
x3 = torch.randn(
(1, 6, 4, 4), dtype=torch.float32, requires_grad=False, device=device
)
def matcher_check_fn():
# 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2
@@ -1412,28 +1424,42 @@ class TestPatternMatcher(TestPatternMatcherBase):
@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_qconv2d_add_cpu(self):
self._qconv2d_add_cpu_test_helper()
self._qconv2d_add_cpu_test_helper2()
self._qconv2d_add_test_helper()
self._qconv2d_add_test_helper2()
@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoXPU
def test_qconv2d_add_xpu(self):
self._qconv2d_add_test_helper(device="xpu")
self._qconv2d_add_test_helper2(device="xpu")
@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
def test_qconv2d_add_int8_mixed_bf16(self):
self._qconv2d_add_cpu_test_helper(int8_mixed_bf16=True)
self._qconv2d_add_cpu_test_helper2(int8_mixed_bf16=True)
self._qconv2d_add_test_helper(int8_mixed_bf16=True)
self._qconv2d_add_test_helper2(int8_mixed_bf16=True)
@skipIfNoDynamoSupport
@skipIfNoONEDNN
def test_qconv2d_add_relu_cpu(self):
self._qconv2d_add_cpu_test_helper(use_relu=True)
self._qconv2d_add_cpu_test_helper2(use_relu=True)
self._qconv2d_add_test_helper(use_relu=True)
self._qconv2d_add_test_helper2(use_relu=True)
@skipIfNoDynamoSupport
@skipIfNoONEDNN
@skipIfNoXPU
def test_qconv2d_add_relu_xpu(self):
self._qconv2d_add_test_helper(device="xpu", use_relu=True)
self._qconv2d_add_test_helper2(device="xpu", use_relu=True)
@skipIfNoDynamoSupport
@skipIfNoONEDNNBF16
@skipIfNoONEDNN
def test_qconv2d_add_relu_int8_mixed_bf16(self):
self._qconv2d_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True)
self._qconv2d_add_cpu_test_helper2(use_relu=True, int8_mixed_bf16=True)
self._qconv2d_add_test_helper(use_relu=True, int8_mixed_bf16=True)
self._qconv2d_add_test_helper2(use_relu=True, int8_mixed_bf16=True)
@skipIfNoDynamoSupport
@skipIfNoONEDNN


@@ -710,7 +710,8 @@ def _is_valid_quantized_op_binary_optimization_pattern(
if "other" in match.kwargs
else (
match.kwargs["accum"]
if output_dtype == torch.uint8 or (not extra_input_from_dequant)
if (output_dtype in OrderedSet([torch.uint8, torch.int8]))
or (not extra_input_from_dequant)
else match.kwargs["accum_after_dequant"]
)
)
@@ -2758,13 +2759,19 @@ def _register_qconv_post_op_fusion_pass(
else:
accum = (
kwargs["accum"]
if output_dtype == torch.uint8
if output_dtype in OrderedSet([torch.uint8, torch.int8])
else kwargs["accum_after_dequant"]
)
accum_scale = (
kwargs["accum_scale"] if output_dtype == torch.uint8 else 1.0
kwargs["accum_scale"]
if output_dtype in OrderedSet([torch.uint8, torch.int8])
else 1.0
)
accum_zp = (
kwargs["accum_zp"]
if output_dtype in OrderedSet([torch.uint8, torch.int8])
else 0
)
accum_zp = kwargs["accum_zp"] if output_dtype == torch.uint8 else 0
computation_args = (
x,
x_scale,


@@ -96,22 +96,6 @@ class XPUInductorQuantizer(X86InductorQuantizer):
):
pass
def _annotate_conv2d_binary(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[FilterFn] = None,
) -> None:
pass
def _annotate_conv2d_binary_unary(
self,
gm: torch.fx.GraphModule,
quantization_config: Optional[QuantizationConfig],
filter_fn: Optional[FilterFn] = None,
) -> None:
pass
def _annotate_linear_fusion_pattern(
self,
model: torch.fx.GraphModule,
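The deletions above drop XPUInductorQuantizer's empty overrides, so the parent X86InductorQuantizer's conv-binary and conv-binary-unary annotations now take effect on XPU, which is what lets Inductor lower the pattern to qconv2d_pointwise.binary. A hedged sketch of the standard PT2E flow exercising it (API names as in recent PyTorch; an XPU device and oneDNN support are assumed):

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xpu_inductor_quantizer import (
    XPUInductorQuantizer,
    get_default_xpu_inductor_quantization_config,
)

class ConvAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 6, 3)

    def forward(self, x, y):
        return torch.relu(self.conv(x) + y)  # Conv2d -> add -> ReLU pattern

mod = ConvAdd().eval().to("xpu")
x = torch.randn(1, 3, 8, 8, device="xpu")
y = torch.randn(1, 6, 6, 6, device="xpu")

quantizer = XPUInductorQuantizer()
quantizer.set_global(get_default_xpu_inductor_quantization_config())

exported = torch.export.export(mod, (x, y)).module()
prepared = prepare_pt2e(exported, quantizer)
prepared(x, y)                       # calibration pass
converted = convert_pt2e(prepared)
compiled = torch.compile(converted)  # Inductor fuses into qconv2d_pointwise.binary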