[oneDNN EP] Optimized DynamicQuantizeLinear operator (#12403)

* Removed unnecessary reorders
* Removed unnecessary element-wise clip
This commit is contained in:
Erick Muñoz 2022-08-03 13:36:42 -06:00 committed by GitHub
parent 7f58bd7236
commit d1497bdf62
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -17,137 +17,123 @@ Y_ZeroPoint = np.clip(round((0 - x_min) / Y_Scale), 0, 255).astype(np.uint8)
Y = np.clip(np.round(X / Y_Scale) + Y_ZeroPoint, 0, 255).astype(np.uint8)
*/
// Registers the oneDNN primitives that implement DynamicQuantizeLinear on the
// subgraph primitive `sp` for `node`:
//   1. a min reduction and a max reduction over X (each with 0 folded into the
//      range through a binary post-op) produce x_min and y_scale,
//   2. a binary division of x_min by y_scale (with src0 scaled by -1 and a
//      round post-op) produces the zero point,
//   3. a binary division of X by y_scale, followed by round and "+ zero point"
//      post-ops, produces Y.
// The resulting memories are bound to OUT_Y, OUT_Y_SCALE and OUT_Y_ZP.
//
// NOTE(review): this listing looks like a diff hunk rendered without +/-
// markers, so pre-change and post-change lines are interleaved. Several locals
// (max_reduction_d, min_reduction_d, zero_mem, max_reduction_prim,
// min_reduction_prim, y_scale_mem, y_pd) are declared twice and some primitives
// are added twice as a result. Confirm against the committed file before
// treating any single line as live code.
void DnnlDynamicQuantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
// Get engine
auto eng = sp.GetEngine();
// Variant that looks the input memory up by name and reshapes it via the subgraph helper
auto x_memory = sp.GetMemory(node.Input(IN_X).Name());
x_memory = sp.GetMemoryAndReshape(node.Input(IN_X), x_memory.get_desc(), eng);
auto x_memory_desc = x_memory.get_desc();
auto x_memory_dims = x_memory_desc.dims();
auto x_memory_dt = x_memory_desc.data_type();
// Dims of all ones, one entry per dimension of X (reduction destination shape)
dnnl::memory::dims min_max_dst_dims(x_memory_dims.size(), 1);
// Get src mem
auto x_mem = sp.GetMemory(node.Input(IN_X));
auto x_md = x_mem.get_desc();
auto x_size = x_md.dims().size();
auto x_format = sp.GetDnnlFormat(x_size);
auto min_max_dst_mem_desc = dnnl::memory::desc(min_max_dst_dims, x_memory_dt, sp.GetDnnlFormat(x_memory_dims.size()));
// Dims of all ones with the same rank as X (single-element tensor)
dnnl::memory::dims one_dim(x_size, 1);
// max_reduction is responsible for producing the scale
auto max_reduction_d = dnnl::reduction::desc(
dnnl::algorithm::reduction_max, x_memory_desc, min_max_dst_mem_desc, 0.f, 0.f);
auto min_reduction_d = dnnl::reduction::desc(
dnnl::algorithm::reduction_min, x_memory_desc, min_max_dst_mem_desc, 0.f, 0.f);
// Y_SCALE COMPUTATION
// Create descriptor for reduction max and min
auto y_scale_md = dnnl::memory::desc(one_dim, x_md.data_type(), x_format);
auto max_reduction_d = dnnl::reduction::desc(dnnl::algorithm::reduction_max, x_md, y_scale_md, 0.f, 0.f);
auto min_reduction_d = dnnl::reduction::desc(dnnl::algorithm::reduction_min, x_md, y_scale_md, 0.f, 0.f);
// Prepare a zero memory, used for adding 0 to the data range for the min/max operations
auto zero_mem = dnnl::memory(min_max_dst_mem_desc, eng);
// Fill memory with 0's, needed for min and max binary
auto zero_mem = dnnl::memory(y_scale_md, eng);
WriteZeroToMem(zero_mem);
// max(x) with 0 added to range -> sub min(x) -> div 255
// Generate post ops to calc y_scale
dnnl::primitive_attr max_reduction_attr;
{
dnnl::post_ops sub_min_div_255;
// max(0, reduce_max(x))
sub_min_div_255.append_binary(dnnl::algorithm::binary_max, zero_mem.get_desc());
// max - min
sub_min_div_255.append_binary(dnnl::algorithm::binary_sub, min_max_dst_mem_desc);
// / 255
sub_min_div_255.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f / 255.0f, 0.0f);
max_reduction_attr.set_post_ops(sub_min_div_255);
// y_scale = ((x_max - x_min) / (255 - 0)).astype(np.float32) # uint8->[0, 255]
dnnl::post_ops calc_y_scale;
// x_max = max(0, reduce_max(x))
calc_y_scale.append_binary(dnnl::algorithm::binary_max, zero_mem.get_desc());
// y_scale = x_max - x_min
calc_y_scale.append_binary(dnnl::algorithm::binary_sub, y_scale_md);
// y_scale /= 255 (linear eltwise with alpha = 1/255, beta = 0)
calc_y_scale.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f / 255.0f, 0.0f);
max_reduction_attr.set_post_ops(calc_y_scale);
}
// Add 0 to the reduce-min range
// x_min = min(0, reduce_min(x))
dnnl::primitive_attr min_reduction_attr;
{
dnnl::post_ops add_0_to_range;
add_0_to_range.append_binary(dnnl::algorithm::binary_min, zero_mem.get_desc());
min_reduction_attr.set_post_ops(add_0_to_range);
dnnl::post_ops calc_min;
calc_min.append_binary(dnnl::algorithm::binary_min, zero_mem.get_desc());
min_reduction_attr.set_post_ops(calc_min);
}
auto max_reduction_pd = dnnl::reduction::primitive_desc(max_reduction_d, max_reduction_attr, eng);
auto min_reduction_pd = dnnl::reduction::primitive_desc(min_reduction_d, min_reduction_attr, eng);
// Create reduction primitive
auto max_reduction_prim = dnnl::reduction(dnnl::reduction::primitive_desc(max_reduction_d, max_reduction_attr, eng));
auto min_reduction_prim = dnnl::reduction(dnnl::reduction::primitive_desc(min_reduction_d, min_reduction_attr, eng));
auto max_reduction_prim = dnnl::reduction(max_reduction_pd);
auto min_reduction_prim = dnnl::reduction(min_reduction_pd);
// Create y_scale and min memory
auto y_scale_mem = dnnl::memory(y_scale_md, eng);
auto min_reduction_mem = dnnl::memory(y_scale_md, eng);
auto y_scale_mem = dnnl::memory(min_max_dst_mem_desc, eng);
auto min_reduction_dst_mem = dnnl::memory(min_max_dst_mem_desc, eng);
// Compute min first since max_reduction needs min as input
// Post-op 0 of the min reduction is the binary_min against zero_mem
sp.AddPrimitive(min_reduction_prim, {{DNNL_ARG_SRC, x_mem},
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, zero_mem},
{DNNL_ARG_DST, min_reduction_mem}});
std::unordered_map<int, dnnl::memory> min_reduction_args = {{DNNL_ARG_SRC, x_memory}, {DNNL_ARG_DST, min_reduction_dst_mem}};
min_reduction_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] = zero_mem;
std::unordered_map<int, dnnl::memory> max_reduction_args = {{DNNL_ARG_SRC, x_memory}, {DNNL_ARG_DST, y_scale_mem}};
// Post-op 0 (binary_max) takes zero_mem; post-op 1 (binary_sub) takes the min reduction result
max_reduction_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] = zero_mem;
max_reduction_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1] = min_reduction_dst_mem;
// Compute min first since max_reduction needs the min dst as a post-op argument
// compute x min
sp.AddPrimitive(min_reduction_prim, min_reduction_args);
// compute y scale f32
sp.AddPrimitive(max_reduction_prim, max_reduction_args);
// Compute y_scale in fp32
sp.AddPrimitive(max_reduction_prim, {{DNNL_ARG_SRC, x_mem},
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, zero_mem},
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, min_reduction_mem},
{DNNL_ARG_DST, y_scale_mem}});
// Prepare y zero point kernel
auto y_zero_point_d = dnnl::binary::desc(dnnl::algorithm::binary_div, min_reduction_dst_mem.get_desc(), y_scale_mem.get_desc(), min_reduction_dst_mem.get_desc());
// Y_ZERO_POINT COMPUTATION
// Create memory and primitive descriptors (destination is declared as u8)
auto y_zp_md = dnnl::memory::desc(one_dim, dnnl::memory::data_type::u8, x_format);
auto zp_prim_d = dnnl::binary::desc(dnnl::algorithm::binary_div, y_scale_md, y_scale_md, y_zp_md);
dnnl::primitive_attr y_zero_point_attr;
// Add round and clip post ops
dnnl::primitive_attr zp_prim_attr;
{
// Scale src0 (x_min) by -1 so the division computes (0 - x_min) / y_scale
y_zero_point_attr.set_scales(DNNL_ARG_SRC_0, 0, {-1.0f});
zp_prim_attr.set_scales(DNNL_ARG_SRC_0, 0, {-1.0f});
dnnl::post_ops div_saturate_round;
div_saturate_round.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
// Clip might not be necessary as the reorder cast will saturate on lower precision;
// it might still be needed as computing y requires an already saturated zero point
div_saturate_round.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip_v2, 0.0f, 255.0f);
y_zero_point_attr.set_post_ops(div_saturate_round);
zp_prim_attr.set_post_ops(div_saturate_round);
}
auto y_zero_point_pd = dnnl::binary::primitive_desc(y_zero_point_d, y_zero_point_attr, eng);
auto y_zero_point_prim = dnnl::binary(y_zero_point_pd);
auto y_zero_point_dst_mem = dnnl::memory(y_zero_point_pd.dst_desc(), eng);
std::unordered_map<int, dnnl::memory> y_zero_point_args = {{DNNL_ARG_SRC_0, min_reduction_dst_mem}, {DNNL_ARG_SRC_1, y_scale_mem}, {DNNL_ARG_DST, y_zero_point_dst_mem}};
// Create primitives
auto zp_prim_pd = dnnl::binary::primitive_desc(zp_prim_d, zp_prim_attr, eng);
auto zp_prim = dnnl::binary(zp_prim_pd);
// y zero point f32
// np.clip(round((0 - x_min) / Y_Scale), 0, 255)
sp.AddPrimitive(y_zero_point_prim, y_zero_point_args);
// Create zp memory dst
auto y_zp_mem = dnnl::memory(zp_prim_pd.dst_desc(), eng);
// Calc zp
sp.AddPrimitive(zp_prim,{{DNNL_ARG_SRC_0, min_reduction_mem},
{DNNL_ARG_SRC_1, y_scale_mem},
{DNNL_ARG_DST, y_zp_mem}});
// Prepare y kernel
// x / y_scale -> round() -> + y_zp -> clip 0,255
auto y_d = dnnl::binary::desc(dnnl::algorithm::binary_div, x_memory.get_desc(), y_scale_mem.get_desc(), x_memory.get_desc());
dnnl::primitive_attr y_attr;
// Y COMPUTATION
// Create y md and binary desc (destination is declared as u8 with X's dims)
auto y_md = dnnl::memory::desc(x_md.dims(), dnnl::memory::data_type::u8, x_format);
auto y_bin_d = dnnl::binary::desc(dnnl::algorithm::binary_div, x_mem.get_desc(), y_scale_mem.get_desc(), y_md);
// Add post ops
dnnl::primitive_attr y_bin_attr;
{
dnnl::post_ops round_zp_saturate;
round_zp_saturate.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
round_zp_saturate.append_binary(dnnl::algorithm::binary_add, y_zero_point_dst_mem.get_desc());
// Clip might not be necessary as the reorder cast will saturate on lower precision
round_zp_saturate.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip_v2, 0.0f, 255.0f);
y_attr.set_post_ops(round_zp_saturate);
// Post-op 0: round; post-op 1: add the zero point
dnnl::post_ops round_add;
round_add.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
round_add.append_binary(dnnl::algorithm::binary_add, y_zp_mem.get_desc());
y_bin_attr.set_post_ops(round_add);
}
auto y_pd = dnnl::binary::primitive_desc(y_d, y_attr, eng);
// Create binary primitive with post ops
auto y_pd = dnnl::binary::primitive_desc(y_bin_d, y_bin_attr, eng);
auto y_prim = dnnl::binary(y_pd);
// Create y_dst mem
auto y_mem = dnnl::memory(y_pd.dst_desc(), eng);
std::unordered_map<int, dnnl::memory> y_args = {{DNNL_ARG_SRC_0, x_memory}, {DNNL_ARG_SRC_1, y_scale_mem}, {DNNL_ARG_DST, y_mem}};
// Post-op index 1 is the binary_add, whose second operand is the zero point
y_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1] = y_zero_point_dst_mem;
// Compute y
sp.AddPrimitive(y_prim, {{DNNL_ARG_SRC_0, x_mem},
{DNNL_ARG_SRC_1, y_scale_mem},
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, y_zp_mem},
{DNNL_ARG_DST, y_mem}});
// x / y_scale -> round() -> + y_zp -> clip 0,255
// quantized output tensor f32
sp.AddPrimitive(y_prim, y_args);
// Set output y scale
// Set outputs
sp.SetMemory(node.Output(OUT_Y), y_mem);
sp.SetMemory(node.Output(OUT_Y_SCALE), y_scale_mem, false, true);
// Data type change for y_zp and set memory
// Data type change is needed for the onnxruntime spec
// zp for onednn is currently in s32, any downstream node might need to convert from u8 to s32
auto y_zero_point_dst_md_uint8 = ChangeMemoryDescDataType(y_zero_point_dst_mem.get_desc(), dnnl::memory::data_type::u8);
auto y_zero_point_dst_mem_uint8 = dnnl::memory(y_zero_point_dst_md_uint8, eng);
sp.AddPrimitive(dnnl::reorder(y_zero_point_dst_mem, y_zero_point_dst_mem_uint8), {{DNNL_ARG_FROM, y_zero_point_dst_mem}, {DNNL_ARG_TO, y_zero_point_dst_mem_uint8}});
sp.SetMemory(node.Output(OUT_Y_ZP), y_zero_point_dst_mem_uint8, false, true);
// Data type change for y and set memory
auto y_md_uint8 = ChangeMemoryDescDataType(y_mem.get_desc(), dnnl::memory::data_type::u8);
auto y_mem_uint8 = dnnl::memory(y_md_uint8, eng);
sp.AddPrimitive(dnnl::reorder(y_mem, y_mem_uint8), {{DNNL_ARG_FROM, y_mem}, {DNNL_ARG_TO, y_mem_uint8}});
sp.SetMemory(node.Output(OUT_Y), y_mem_uint8);
sp.SetMemory(node.Output(OUT_Y_ZP), y_zp_mem, false, true);
}
//change md to targeted data type of cast op dst