mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Intel GPU] qconv_pointwise.binary XPU support
ghstack-source-id: cbaa17fd11
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135189
This commit is contained in:
parent
63ffab0c73
commit
b91481ba82
7 changed files with 250 additions and 71 deletions
|
|
@ -177,7 +177,7 @@ class Attr {
|
|||
float sum_q_scale = 1.f,
|
||||
int64_t zp = 0) {
|
||||
ops_params_.push_back(
|
||||
PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, kind_t::sum));
|
||||
PostOpParam(/*scale_sum*/ sum_scale * sum_q_scale, zp, kind_t::sum));
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
|
@ -261,10 +261,7 @@ class Attr {
|
|||
return *this;
|
||||
}
|
||||
|
||||
dnnl::post_ops extract_post_ops(
|
||||
const at::Tensor& dst,
|
||||
bool is_quantized = false,
|
||||
bool int8_output = false) {
|
||||
dnnl::post_ops extract_post_ops(const at::Tensor& dst) {
|
||||
// this function is used to extract post ops params from the ops_params_
|
||||
// and put them into onednn post ops
|
||||
for (size_t i = 0; i < ops_params_.size(); ++i) {
|
||||
|
|
@ -303,11 +300,6 @@ class Attr {
|
|||
}
|
||||
}
|
||||
|
||||
// if output is quantized, then append the eltwise linear to adjust the
|
||||
// output scale/zero_point
|
||||
if (is_quantized && int8_output) {
|
||||
dnnl_post_ops_.append_eltwise(kind_with_linear, q_scale_, q_zero_point_);
|
||||
}
|
||||
return dnnl_post_ops_;
|
||||
}
|
||||
|
||||
|
|
@ -410,6 +402,7 @@ static inline void construct_attr_by_post_op(
|
|||
double binary_alpha,
|
||||
double input1_scale,
|
||||
int64_t input1_zero_point,
|
||||
std::optional<at::Tensor> accum,
|
||||
const std::string_view& unary_post_op,
|
||||
const torch::List<std::optional<at::Scalar>>& unary_post_op_args,
|
||||
const std::string_view& unary_post_op_algorithm,
|
||||
|
|
@ -418,11 +411,46 @@ static inline void construct_attr_by_post_op(
|
|||
(binary_post_op == "none" && unary_post_op == "none"); // not post-ops
|
||||
bool is_unary_post_op_only =
|
||||
(binary_post_op == "none" && unary_post_op != "none"); // ex., conv + relu
|
||||
bool is_valid_binary_combination =
|
||||
(binary_post_op == "add" || binary_post_op == "sum") &&
|
||||
(unary_post_op == "none" || unary_post_op == "relu");
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
is_unary_post_op_only || is_none_post_op,
|
||||
"Currently, quantization backend for Intel GPU only supports convolution or convolution with unary post operation like ReLU");
|
||||
construct_attr_for_unary(
|
||||
unary_post_op, unary_post_op_args, unary_post_op_algorithm, attr);
|
||||
is_unary_post_op_only || is_none_post_op || is_valid_binary_combination,
|
||||
"Please provide valid combination of unary post operators and binary post operators");
|
||||
|
||||
if (binary_post_op == "none") {
|
||||
construct_attr_for_unary(
|
||||
unary_post_op, unary_post_op_args, unary_post_op_algorithm, attr);
|
||||
} else if (binary_post_op == "sum") {
|
||||
if (unary_post_op == "none") {
|
||||
if (input1_zero_point != 0)
|
||||
attr = attr.append_post_eltwise(
|
||||
/*scale*/ 1,
|
||||
/*alpha*/ 1,
|
||||
-input1_zero_point * input1_scale,
|
||||
attr.kind_with_linear);
|
||||
attr = attr.append_post_sum(1, input1_scale, /*input1_zero_point*/ 0);
|
||||
} else if (unary_post_op == "relu") {
|
||||
if (input1_zero_point != 0)
|
||||
attr = attr.append_post_eltwise(
|
||||
/*scale*/ 1,
|
||||
/*alpha*/ 1,
|
||||
-input1_zero_point * input1_scale,
|
||||
attr.kind_with_linear);
|
||||
attr = attr.append_post_sum(1, input1_scale, /*input1_zero_point*/ 0);
|
||||
attr = attr.append_post_eltwise(
|
||||
/* eltwise_scale */ 1.f,
|
||||
/* alpha */ 0.f,
|
||||
/* beta */ 0.f,
|
||||
attr.kind_with_relu);
|
||||
}
|
||||
} else if (binary_post_op == "add") {
|
||||
TORCH_CHECK(accum.has_value());
|
||||
attr = attr.append_post_binary(attr.kind_with_binary_add, accum.value());
|
||||
if (unary_post_op == "relu") {
|
||||
attr = attr.append_post_eltwise(1.f, 0.f, 0.f, attr.kind_with_relu);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace at::native::onednn
|
||||
|
|
|
|||
|
|
@ -11,14 +11,19 @@
|
|||
|
||||
namespace at::native::onednn {
|
||||
|
||||
static std::tuple<dnnl::memory::desc, dnnl::memory::desc, dnnl::memory::desc>
|
||||
static std::tuple<
|
||||
dnnl::memory::desc,
|
||||
dnnl::memory::desc,
|
||||
dnnl::memory::desc,
|
||||
dnnl::memory::desc>
|
||||
qconv_get_md(
|
||||
const at::Tensor& src,
|
||||
const at::Tensor& wgh,
|
||||
std::optional<at::Tensor> bias,
|
||||
const at::Tensor& dst,
|
||||
int64_t groups) {
|
||||
// create dnnl::memory desc from the src/wgh/dst tensors
|
||||
dnnl::memory::desc src_usr_md, wgh_usr_md, dst_usr_md;
|
||||
dnnl::memory::desc src_usr_md, wgh_usr_md, dst_usr_md, bias_usr_md;
|
||||
auto ndim = src.ndimension();
|
||||
bool src_is_cl =
|
||||
(src.suggest_memory_format() == at::MemoryFormat::ChannelsLast) ||
|
||||
|
|
@ -44,7 +49,14 @@ qconv_get_md(
|
|||
auto fmt_wgh = conv_weight_fmt(ndim, groups != 1, wgh_is_cl);
|
||||
wgh_usr_md = dnnl::memory::desc(wgh_tz, wei_data_t, fmt_wgh);
|
||||
|
||||
return {src_usr_md, wgh_usr_md, dst_usr_md};
|
||||
if (bias.has_value()) {
|
||||
bias_usr_md = dnnl::memory::desc(
|
||||
bias.value().sizes().vec(),
|
||||
dnnl::memory::data_type::f32,
|
||||
dnnl::memory::format_tag::x);
|
||||
}
|
||||
|
||||
return {src_usr_md, wgh_usr_md, bias_usr_md, dst_usr_md};
|
||||
}
|
||||
|
||||
at::Tensor quantized_convolution(
|
||||
|
|
@ -76,14 +88,12 @@ at::Tensor quantized_convolution(
|
|||
Attr(/*q_scale=*/1.0 / inv_output_scale, /*zp=*/output_zero_point);
|
||||
|
||||
auto ndim = act.ndimension();
|
||||
if (bias.has_value()) {
|
||||
attr = attr.append_bias(bias.value(), ndim - 2);
|
||||
}
|
||||
construct_attr_by_post_op(
|
||||
binary_attr.has_value() ? binary_attr.value() : "none",
|
||||
binary_alpha.has_value() ? binary_alpha.value().to<double>() : 1.0,
|
||||
accum_scale,
|
||||
accum_zero_point,
|
||||
accum,
|
||||
unary_attr.has_value() ? unary_attr.value() : "none",
|
||||
unary_scalars,
|
||||
unary_algorithm.has_value() ? unary_algorithm.value() : "",
|
||||
|
|
@ -110,10 +120,7 @@ at::Tensor quantized_convolution(
|
|||
dnnl::memory::dims _dilation = compatible_dilation(dilation);
|
||||
dnnl::post_ops po;
|
||||
// extract post ops
|
||||
po = attr.extract_post_ops(
|
||||
output,
|
||||
/*is_quantized*/ true,
|
||||
output.scalar_type() == at::kByte || output.scalar_type() == at::kChar);
|
||||
po = attr.extract_post_ops(output);
|
||||
int mask_ac = 0, mask_weight;
|
||||
// [Note: Per-channel quantization mask setting]
|
||||
// Per-channel quantization is on weight output channel mostly, mask_weight=
|
||||
|
|
@ -127,10 +134,11 @@ at::Tensor quantized_convolution(
|
|||
dnnl::primitive_attr pattr;
|
||||
|
||||
bool src_need_zp = (act_scale != 0);
|
||||
bool dst_need_zp = (output_zero_point != 0);
|
||||
|
||||
// create usr_md for tensors, and md for conv primitive
|
||||
auto [src_md, weight_md, output_md] =
|
||||
qconv_get_md(act, weight, output, groups);
|
||||
auto [src_md, weight_md, bias_md, output_md] =
|
||||
qconv_get_md(act, weight, bias, output, groups);
|
||||
|
||||
// get tensor md
|
||||
auto ic = act.size(1);
|
||||
|
|
@ -139,11 +147,14 @@ at::Tensor quantized_convolution(
|
|||
compatible_weight_dims(ndim, groups, oc, ic, weight.sizes());
|
||||
|
||||
pattr.set_scales_mask(DNNL_ARG_SRC, mask_ac);
|
||||
pattr.set_scales_mask(DNNL_ARG_DST, mask_ac);
|
||||
pattr.set_scales_mask(DNNL_ARG_WEIGHTS, mask_weight);
|
||||
pattr.set_post_ops(po);
|
||||
|
||||
if (src_need_zp)
|
||||
pattr.set_zero_points_mask(DNNL_ARG_SRC, mask_ac);
|
||||
if (dst_need_zp)
|
||||
pattr.set_zero_points_mask(DNNL_ARG_DST, mask_ac);
|
||||
pattr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
// create primitive
|
||||
|
|
@ -153,7 +164,7 @@ at::Tensor quantized_convolution(
|
|||
dnnl::algorithm::convolution_direct,
|
||||
src_md,
|
||||
weight_md,
|
||||
dnnl::memory::desc(),
|
||||
bias.has_value() ? bias_md : dnnl::memory::desc(),
|
||||
output_md,
|
||||
_stride,
|
||||
_dilation,
|
||||
|
|
@ -164,11 +175,14 @@ at::Tensor quantized_convolution(
|
|||
dnnl::convolution_forward conv_forward =
|
||||
dnnl::convolution_forward(conv_fwd_pd);
|
||||
|
||||
dnnl::memory src_m, weight_m, output_m;
|
||||
dnnl::memory src_m, weight_m, output_m, bias_m;
|
||||
|
||||
src_m = make_onednn_memory(src_md, engine, act.data_ptr());
|
||||
output_m = make_onednn_memory(output_md, engine, output.data_ptr());
|
||||
weight_m = make_onednn_memory(weight_md, engine, weight.data_ptr());
|
||||
if (bias.has_value()) {
|
||||
bias_m = make_onednn_memory(bias_md, engine, bias.value().data_ptr());
|
||||
}
|
||||
|
||||
std::unordered_map<int, dnnl::memory> args;
|
||||
if (attr.with_binary())
|
||||
|
|
@ -176,6 +190,9 @@ at::Tensor quantized_convolution(
|
|||
args.insert({DNNL_ARG_SRC, src_m});
|
||||
args.insert({DNNL_ARG_WEIGHTS, weight_m});
|
||||
args.insert({DNNL_ARG_DST, output_m});
|
||||
if (bias.has_value()) {
|
||||
args.insert({DNNL_ARG_BIAS, bias_m});
|
||||
}
|
||||
|
||||
dnnl::memory src_sc_m, src_zp_m;
|
||||
Tensor src_sc_tensor, src_zp_tensor;
|
||||
|
|
@ -188,7 +205,17 @@ at::Tensor quantized_convolution(
|
|||
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, src_zp_m});
|
||||
}
|
||||
|
||||
// dst scale is no need for setting, since it is fused in postop via linear
|
||||
dnnl::memory dst_sc_m, dst_zp_m;
|
||||
Tensor dst_sc_tensor, dst_zp_tensor;
|
||||
dst_sc_m = dnnl_memory_from_host_scalar(
|
||||
static_cast<float>(inv_output_scale), dst_sc_tensor, engine);
|
||||
dst_zp_m = dnnl_memory_from_host_scalar(
|
||||
static_cast<int32_t>(output_zero_point), dst_zp_tensor, engine);
|
||||
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_m});
|
||||
if (dst_need_zp) {
|
||||
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_m});
|
||||
}
|
||||
|
||||
size_t scratchpad_size = conv_fwd_pd.scratchpad_desc().get_size();
|
||||
Tensor scratchpad_tensor = at::empty(
|
||||
{static_cast<int64_t>(scratchpad_size)},
|
||||
|
|
|
|||
|
|
@ -117,6 +117,7 @@ void quantized_matmul(
|
|||
binary_alpha,
|
||||
input_scale,
|
||||
input_zero_point,
|
||||
other,
|
||||
unary_post_op,
|
||||
unary_post_op_args,
|
||||
unary_post_op_algorithm,
|
||||
|
|
@ -209,11 +210,9 @@ void quantized_matmul(
|
|||
std::unordered_map<int, dnnl::memory> args;
|
||||
|
||||
dnnl::post_ops po;
|
||||
po = attr.extract_post_ops(
|
||||
dst,
|
||||
true,
|
||||
dst.scalar_type() == at::kByte || dst.scalar_type() == at::kChar);
|
||||
po = attr.extract_post_ops(dst);
|
||||
bool m1_need_zp = (input_zero_point != 0);
|
||||
bool dst_need_zp = (output_zero_point != 0);
|
||||
bool wgh_is_per_channel = weight_scales.numel() > 1;
|
||||
|
||||
dnnl::matmul matmul_p;
|
||||
|
|
@ -241,6 +240,10 @@ void quantized_matmul(
|
|||
if (m1_need_zp) {
|
||||
pattr.set_zero_points_mask(DNNL_ARG_SRC, mask_ac);
|
||||
}
|
||||
pattr.set_scales_mask(DNNL_ARG_DST, mask_ac);
|
||||
if (dst_need_zp) {
|
||||
pattr.set_zero_points_mask(DNNL_ARG_DST, mask_ac);
|
||||
}
|
||||
|
||||
if (with_bias) {
|
||||
b_md = dnnl::memory::desc(bias_dims, bias_dt, bias_strides);
|
||||
|
|
@ -308,6 +311,17 @@ void quantized_matmul(
|
|||
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, m1_zp_m});
|
||||
}
|
||||
|
||||
dnnl::memory dst_sc_m, dst_zp_m;
|
||||
Tensor dst_sc_tensor, dst_zp_tensor;
|
||||
dst_sc_m = dnnl_memory_from_host_scalar(
|
||||
static_cast<float>(output_scale), dst_sc_tensor, engine);
|
||||
dst_zp_m = dnnl_memory_from_host_scalar(
|
||||
static_cast<int32_t>(output_zero_point), dst_zp_tensor, engine);
|
||||
args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, dst_sc_m});
|
||||
if (dst_need_zp) {
|
||||
args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_m});
|
||||
}
|
||||
|
||||
auto qmatmul_event = dnnl::sycl_interop::execute(matmul_p, stream, args);
|
||||
|
||||
if (!dst.is_same(result))
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class QConvoneDNNXPU final {
|
|||
TORCH_CHECK(
|
||||
attr == "none" || attr == "relu" || attr == "hardtanh" ||
|
||||
attr == "hardswish" || attr == "swish",
|
||||
"We support quantized convolution without any post-ops or combinations for Quantized Conv + ReLU, Hardtanh, and Hardswish are supported. However, encountered unsupported post operation:",
|
||||
"We support quantized convolution without any post-ops or combinations for Quantized Conv + ReLU, Hardtanh, GELU, Swish, and Hardswish are supported. However, encountered unsupported post operation:",
|
||||
attr,
|
||||
".");
|
||||
}
|
||||
|
|
@ -104,6 +104,96 @@ class QConvoneDNNXPU final {
|
|||
/*unary_scalars*/ scalars,
|
||||
/*unary_algorithm*/ algorithm);
|
||||
}
|
||||
|
||||
static at::Tensor run_pointwise_binary(
|
||||
at::Tensor act,
|
||||
double act_scale,
|
||||
int64_t act_zero_point,
|
||||
at::Tensor weight,
|
||||
at::Tensor weight_scales,
|
||||
at::Tensor weight_zero_points,
|
||||
at::Tensor accum,
|
||||
std::optional<at::Tensor> bias,
|
||||
torch::List<int64_t> stride,
|
||||
torch::List<int64_t> padding,
|
||||
torch::List<int64_t> dilation,
|
||||
int64_t groups,
|
||||
double output_scale,
|
||||
int64_t output_zero_point,
|
||||
std::optional<c10::ScalarType> output_dtype,
|
||||
double accum_scale,
|
||||
int64_t accum_zero_point,
|
||||
std::string_view binary_attr,
|
||||
std::optional<at::Scalar> alpha,
|
||||
std::optional<std::string_view> unary_attr,
|
||||
torch::List<std::optional<at::Scalar>> unary_scalars,
|
||||
std::optional<std::string_view> unary_algorithm) {
|
||||
TORCH_CHECK(
|
||||
act.dim() == 4 && binary_attr == "sum" &&
|
||||
(!unary_attr.has_value() ||
|
||||
(unary_attr.has_value() &&
|
||||
(unary_attr.value() == "none" || unary_attr.value() == "relu"))),
|
||||
"post_op sum or post_op sum_relu is supported for quantized pointwise conv2d. Got binary_post_op: ",
|
||||
binary_attr,
|
||||
" unary_post_op: ",
|
||||
unary_attr.has_value() ? unary_attr.value() : "none",
|
||||
".")
|
||||
|
||||
bool is_channels_last_suggested = use_channels_last_for_conv(act, weight);
|
||||
auto mfmt = is_channels_last_suggested
|
||||
? get_cl_tag_by_ndim(act.ndimension())
|
||||
: at::MemoryFormat::Contiguous;
|
||||
Tensor input_ = act.contiguous(mfmt);
|
||||
Tensor weight_ = weight.contiguous(mfmt);
|
||||
|
||||
auto dst_tz = conv_dst_size(
|
||||
input_.ndimension(),
|
||||
input_.sizes(),
|
||||
weight_.sizes(),
|
||||
padding.vec(),
|
||||
padding.vec(),
|
||||
stride.vec(),
|
||||
dilation.vec());
|
||||
|
||||
bool has_accum_postop_sum = binary_attr == "sum";
|
||||
Tensor output = has_accum_postop_sum
|
||||
? accum
|
||||
: at::empty(
|
||||
dst_tz,
|
||||
device(c10::kXPU).dtype(output_dtype).memory_format(mfmt));
|
||||
|
||||
output = quantized_convolution(
|
||||
act,
|
||||
act_scale,
|
||||
act_zero_point,
|
||||
weight,
|
||||
weight_scales,
|
||||
weight_zero_points,
|
||||
bias,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
/*transposed*/ false,
|
||||
groups,
|
||||
output,
|
||||
output_scale,
|
||||
output_zero_point,
|
||||
/*accum*/ accum,
|
||||
/*accum_scale*/ accum_scale,
|
||||
/*accum_zero_point*/ accum_zero_point,
|
||||
/*output_dtype*/ output_dtype,
|
||||
/*binary_attr*/ binary_attr,
|
||||
/*binary_alpha*/ alpha,
|
||||
/*unary_attr*/ unary_attr,
|
||||
/*unary_scalars*/ unary_scalars,
|
||||
/*unary_algorithm*/ unary_algorithm);
|
||||
|
||||
if (!has_accum_postop_sum) {
|
||||
return output;
|
||||
} else {
|
||||
return accum;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
TORCH_LIBRARY_IMPL(onednn, XPU, m) {
|
||||
|
|
@ -119,6 +209,9 @@ TORCH_LIBRARY_IMPL(onednn, XPU, m) {
|
|||
m.impl(
|
||||
TORCH_SELECTIVE_NAME("onednn::qconv3d_pointwise"),
|
||||
QConvoneDNNXPU::run_pointwise);
|
||||
m.impl(
|
||||
TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary"),
|
||||
QConvoneDNNXPU::run_pointwise_binary);
|
||||
}
|
||||
|
||||
} // namespace at::native::xpu
|
||||
|
|
|
|||
|
|
@ -1246,7 +1246,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
qconv2d_unary_matcher_nodes=11,
|
||||
)
|
||||
|
||||
def _qconv2d_add_cpu_test_helper(self, use_relu=False, int8_mixed_bf16=False):
|
||||
def _qconv2d_add_test_helper(
|
||||
self, device="cpu", use_relu=False, int8_mixed_bf16=False
|
||||
):
|
||||
r"""
|
||||
This testcase will quantize a Conv2d->Add pattern as:
|
||||
X
|
||||
|
|
@ -1292,9 +1294,11 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
return res
|
||||
|
||||
for add_fn in quantization_add_fn_list + quantization_inplace_add_fn_list:
|
||||
mod = M(add_fn, use_relu).eval()
|
||||
v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False).add(
|
||||
1
|
||||
mod = M(add_fn, use_relu).eval().to(device=device)
|
||||
v = (
|
||||
torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False)
|
||||
.add(1)
|
||||
.to(device=device)
|
||||
)
|
||||
|
||||
def matcher_check_fn():
|
||||
|
|
@ -1320,7 +1324,9 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float,
|
||||
)
|
||||
|
||||
def _qconv2d_add_cpu_test_helper2(self, use_relu=False, int8_mixed_bf16=False):
|
||||
def _qconv2d_add_test_helper2(
|
||||
self, device="cpu", use_relu=False, int8_mixed_bf16=False
|
||||
):
|
||||
r"""
|
||||
This testcase will quantize two Conv2d->Add patterns as:
|
||||
|
||||
|
|
@ -1381,10 +1387,16 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
for add_fn, swap_inputs in itertools.product(
|
||||
quantization_add_fn_list + quantization_inplace_add_fn_list, [False, True]
|
||||
):
|
||||
mod = M(add_fn, use_relu, swap_inputs).eval()
|
||||
x = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False)
|
||||
x2 = torch.randn((1, 6, 6, 6), dtype=torch.float32, requires_grad=False)
|
||||
x3 = torch.randn((1, 6, 4, 4), dtype=torch.float32, requires_grad=False)
|
||||
mod = M(add_fn, use_relu, swap_inputs).eval().to(device=device)
|
||||
x = torch.randn(
|
||||
(1, 3, 8, 8), dtype=torch.float32, requires_grad=False, device=device
|
||||
)
|
||||
x2 = torch.randn(
|
||||
(1, 6, 6, 6), dtype=torch.float32, requires_grad=False, device=device
|
||||
)
|
||||
x3 = torch.randn(
|
||||
(1, 6, 4, 4), dtype=torch.float32, requires_grad=False, device=device
|
||||
)
|
||||
|
||||
def matcher_check_fn():
|
||||
# 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2
|
||||
|
|
@ -1412,28 +1424,42 @@ class TestPatternMatcher(TestPatternMatcherBase):
|
|||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv2d_add_cpu(self):
|
||||
self._qconv2d_add_cpu_test_helper()
|
||||
self._qconv2d_add_cpu_test_helper2()
|
||||
self._qconv2d_add_test_helper()
|
||||
self._qconv2d_add_test_helper2()
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNN
|
||||
@skipIfNoXPU
|
||||
def test_qconv2d_add_xpu(self):
|
||||
self._qconv2d_add_test_helper(device="xpu")
|
||||
self._qconv2d_add_test_helper2(device="xpu")
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNNBF16
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv2d_add_int8_mixed_bf16(self):
|
||||
self._qconv2d_add_cpu_test_helper(int8_mixed_bf16=True)
|
||||
self._qconv2d_add_cpu_test_helper2(int8_mixed_bf16=True)
|
||||
self._qconv2d_add_test_helper(int8_mixed_bf16=True)
|
||||
self._qconv2d_add_test_helper2(int8_mixed_bf16=True)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv2d_add_relu_cpu(self):
|
||||
self._qconv2d_add_cpu_test_helper(use_relu=True)
|
||||
self._qconv2d_add_cpu_test_helper2(use_relu=True)
|
||||
self._qconv2d_add_test_helper(use_relu=True)
|
||||
self._qconv2d_add_test_helper2(use_relu=True)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNN
|
||||
@skipIfNoXPU
|
||||
def test_qconv2d_add_relu_xpu(self):
|
||||
self._qconv2d_add_test_helper(device="xpu", use_relu=True)
|
||||
self._qconv2d_add_test_helper2(device="xpu", use_relu=True)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNNBF16
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv2d_add_relu_int8_mixed_bf16(self):
|
||||
self._qconv2d_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True)
|
||||
self._qconv2d_add_cpu_test_helper2(use_relu=True, int8_mixed_bf16=True)
|
||||
self._qconv2d_add_test_helper(use_relu=True, int8_mixed_bf16=True)
|
||||
self._qconv2d_add_test_helper2(use_relu=True, int8_mixed_bf16=True)
|
||||
|
||||
@skipIfNoDynamoSupport
|
||||
@skipIfNoONEDNN
|
||||
|
|
|
|||
|
|
@ -710,7 +710,8 @@ def _is_valid_quantized_op_binary_optimization_pattern(
|
|||
if "other" in match.kwargs
|
||||
else (
|
||||
match.kwargs["accum"]
|
||||
if output_dtype == torch.uint8 or (not extra_input_from_dequant)
|
||||
if (output_dtype in OrderedSet([torch.uint8, torch.int8]))
|
||||
or (not extra_input_from_dequant)
|
||||
else match.kwargs["accum_after_dequant"]
|
||||
)
|
||||
)
|
||||
|
|
@ -2758,13 +2759,19 @@ def _register_qconv_post_op_fusion_pass(
|
|||
else:
|
||||
accum = (
|
||||
kwargs["accum"]
|
||||
if output_dtype == torch.uint8
|
||||
if output_dtype in OrderedSet([torch.uint8, torch.int8])
|
||||
else kwargs["accum_after_dequant"]
|
||||
)
|
||||
accum_scale = (
|
||||
kwargs["accum_scale"] if output_dtype == torch.uint8 else 1.0
|
||||
kwargs["accum_scale"]
|
||||
if output_dtype in OrderedSet([torch.uint8, torch.int8])
|
||||
else 1.0
|
||||
)
|
||||
accum_zp = (
|
||||
kwargs["accum_zp"]
|
||||
if output_dtype in OrderedSet([torch.uint8, torch.int8])
|
||||
else 0
|
||||
)
|
||||
accum_zp = kwargs["accum_zp"] if output_dtype == torch.uint8 else 0
|
||||
computation_args = (
|
||||
x,
|
||||
x_scale,
|
||||
|
|
|
|||
|
|
@ -96,22 +96,6 @@ class XPUInductorQuantizer(X86InductorQuantizer):
|
|||
):
|
||||
pass
|
||||
|
||||
def _annotate_conv2d_binary(
|
||||
self,
|
||||
gm: torch.fx.GraphModule,
|
||||
quantization_config: Optional[QuantizationConfig],
|
||||
filter_fn: Optional[FilterFn] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def _annotate_conv2d_binary_unary(
|
||||
self,
|
||||
gm: torch.fx.GraphModule,
|
||||
quantization_config: Optional[QuantizationConfig],
|
||||
filter_fn: Optional[FilterFn] = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def _annotate_linear_fusion_pattern(
|
||||
self,
|
||||
model: torch.fx.GraphModule,
|
||||
|
|
|
|||
Loading…
Reference in a new issue