mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
[oneDNN EP] Optimized DynamicQuantizeLinear operator (#12403)
* Removed unnecesary reorders * Removed unnecesary element wise clip
This commit is contained in:
parent
7f58bd7236
commit
d1497bdf62
1 changed files with 82 additions and 96 deletions
|
|
@ -17,137 +17,123 @@ Y_ZeroPoint = np.clip(round((0 - x_min) / Y_Scale), 0, 255).astype(np.uint8)
|
|||
Y = np.clip(np.round(X / Y_Scale) + Y_ZeroPoint, 0, 255).astype(np.uint8)
|
||||
*/
|
||||
void DnnlDynamicQuantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) {
|
||||
// Get engine
|
||||
auto eng = sp.GetEngine();
|
||||
auto x_memory = sp.GetMemory(node.Input(IN_X).Name());
|
||||
x_memory = sp.GetMemoryAndReshape(node.Input(IN_X), x_memory.get_desc(), eng);
|
||||
auto x_memory_desc = x_memory.get_desc();
|
||||
auto x_memory_dims = x_memory_desc.dims();
|
||||
auto x_memory_dt = x_memory_desc.data_type();
|
||||
|
||||
//dims of all ones
|
||||
dnnl::memory::dims min_max_dst_dims(x_memory_dims.size(), 1);
|
||||
// Get src mem
|
||||
auto x_mem = sp.GetMemory(node.Input(IN_X));
|
||||
auto x_md = x_mem.get_desc();
|
||||
auto x_size = x_md.dims().size();
|
||||
auto x_format = sp.GetDnnlFormat(x_size);
|
||||
|
||||
auto min_max_dst_mem_desc = dnnl::memory::desc(min_max_dst_dims, x_memory_dt, sp.GetDnnlFormat(x_memory_dims.size()));
|
||||
// Dims for one dimensional tensor
|
||||
dnnl::memory::dims one_dim(x_size, 1);
|
||||
|
||||
//max_reduction responsible for producing scale
|
||||
auto max_reduction_d = dnnl::reduction::desc(
|
||||
dnnl::algorithm::reduction_max, x_memory_desc, min_max_dst_mem_desc, 0.f, 0.f);
|
||||
auto min_reduction_d = dnnl::reduction::desc(
|
||||
dnnl::algorithm::reduction_min, x_memory_desc, min_max_dst_mem_desc, 0.f, 0.f);
|
||||
// Y_SCALE COMPUTATION
|
||||
// Create descriptor for reduction max and min
|
||||
auto y_scale_md = dnnl::memory::desc(one_dim, x_md.data_type(), x_format);
|
||||
auto max_reduction_d = dnnl::reduction::desc(dnnl::algorithm::reduction_max, x_md, y_scale_md, 0.f, 0.f);
|
||||
auto min_reduction_d = dnnl::reduction::desc(dnnl::algorithm::reduction_min, x_md, y_scale_md, 0.f, 0.f);
|
||||
|
||||
//prepare a zero memory, used for adding 0 to data range for min max operation
|
||||
auto zero_mem = dnnl::memory(min_max_dst_mem_desc, eng);
|
||||
// Fill memory with 0's, needed for min and max binary
|
||||
auto zero_mem = dnnl::memory(y_scale_md, eng);
|
||||
WriteZeroToMem(zero_mem);
|
||||
|
||||
//max(x) with 0 added to range -> sub min(x) -> div 255
|
||||
// Generate post ops to calc y_scale
|
||||
dnnl::primitive_attr max_reduction_attr;
|
||||
{
|
||||
dnnl::post_ops sub_min_div_255;
|
||||
//max(0,reduce_max(x))
|
||||
sub_min_div_255.append_binary(dnnl::algorithm::binary_max, zero_mem.get_desc());
|
||||
//max - min
|
||||
sub_min_div_255.append_binary(dnnl::algorithm::binary_sub, min_max_dst_mem_desc);
|
||||
// /255
|
||||
sub_min_div_255.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f / 255.0f, 0.0f);
|
||||
max_reduction_attr.set_post_ops(sub_min_div_255);
|
||||
// y_scale = ((x_max - x_min) / (255 - 0)).astype(np.float32) # uint8->[0, 255]
|
||||
dnnl::post_ops calc_y_scale;
|
||||
// x_max = max(0, reduce_max(x))
|
||||
calc_y_scale.append_binary(dnnl::algorithm::binary_max, zero_mem.get_desc());
|
||||
// y_scale = x_max - x_min
|
||||
calc_y_scale.append_binary(dnnl::algorithm::binary_sub, y_scale_md);
|
||||
// y_scale =/ 255
|
||||
calc_y_scale.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f / 255.0f, 0.0f);
|
||||
max_reduction_attr.set_post_ops(calc_y_scale);
|
||||
}
|
||||
|
||||
//add 0 to reduce min range
|
||||
// x_min = min(0, reduce_min(x))
|
||||
dnnl::primitive_attr min_reduction_attr;
|
||||
{
|
||||
dnnl::post_ops add_0_to_range;
|
||||
add_0_to_range.append_binary(dnnl::algorithm::binary_min, zero_mem.get_desc());
|
||||
min_reduction_attr.set_post_ops(add_0_to_range);
|
||||
dnnl::post_ops calc_min;
|
||||
calc_min.append_binary(dnnl::algorithm::binary_min, zero_mem.get_desc());
|
||||
min_reduction_attr.set_post_ops(calc_min);
|
||||
}
|
||||
|
||||
auto max_reduction_pd = dnnl::reduction::primitive_desc(max_reduction_d, max_reduction_attr, eng);
|
||||
auto min_reduction_pd = dnnl::reduction::primitive_desc(min_reduction_d, min_reduction_attr, eng);
|
||||
// Create reduction primitive
|
||||
auto max_reduction_prim = dnnl::reduction(dnnl::reduction::primitive_desc(max_reduction_d, max_reduction_attr, eng));
|
||||
auto min_reduction_prim = dnnl::reduction(dnnl::reduction::primitive_desc(min_reduction_d, min_reduction_attr, eng));
|
||||
|
||||
auto max_reduction_prim = dnnl::reduction(max_reduction_pd);
|
||||
auto min_reduction_prim = dnnl::reduction(min_reduction_pd);
|
||||
// Create y_scale and min memory
|
||||
auto y_scale_mem = dnnl::memory(y_scale_md, eng);
|
||||
auto min_reduction_mem = dnnl::memory(y_scale_md, eng);
|
||||
|
||||
auto y_scale_mem = dnnl::memory(min_max_dst_mem_desc, eng);
|
||||
auto min_reduction_dst_mem = dnnl::memory(min_max_dst_mem_desc, eng);
|
||||
// Compute min first since max_reduction needs min as input
|
||||
sp.AddPrimitive(min_reduction_prim, {{DNNL_ARG_SRC, x_mem},
|
||||
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, zero_mem},
|
||||
{DNNL_ARG_DST, min_reduction_mem}});
|
||||
|
||||
std::unordered_map<int, dnnl::memory> min_reduction_args = {{DNNL_ARG_SRC, x_memory}, {DNNL_ARG_DST, min_reduction_dst_mem}};
|
||||
min_reduction_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] = zero_mem;
|
||||
|
||||
std::unordered_map<int, dnnl::memory> max_reduction_args = {{DNNL_ARG_SRC, x_memory}, {DNNL_ARG_DST, y_scale_mem}};
|
||||
max_reduction_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] = zero_mem;
|
||||
max_reduction_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1] = min_reduction_dst_mem;
|
||||
|
||||
//compute min first since max_reduction needs min dst as post op arg
|
||||
// compute x min
|
||||
sp.AddPrimitive(min_reduction_prim, min_reduction_args);
|
||||
// compute y scale f32
|
||||
sp.AddPrimitive(max_reduction_prim, max_reduction_args);
|
||||
// Compute y_scale in fp32
|
||||
sp.AddPrimitive(max_reduction_prim, {{DNNL_ARG_SRC, x_mem},
|
||||
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, zero_mem},
|
||||
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, min_reduction_mem},
|
||||
{DNNL_ARG_DST, y_scale_mem}});
|
||||
|
||||
|
||||
//prepare y zero point kernel
|
||||
auto y_zero_point_d = dnnl::binary::desc(dnnl::algorithm::binary_div, min_reduction_dst_mem.get_desc(), y_scale_mem.get_desc(), min_reduction_dst_mem.get_desc());
|
||||
// Y_ZERO_POINT COMPUTATION
|
||||
// Create memory and primitive descriptors
|
||||
auto y_zp_md = dnnl::memory::desc(one_dim, dnnl::memory::data_type::u8, x_format);
|
||||
auto zp_prim_d = dnnl::binary::desc(dnnl::algorithm::binary_div, y_scale_md, y_scale_md, y_zp_md);
|
||||
|
||||
dnnl::primitive_attr y_zero_point_attr;
|
||||
// Add round and clip post ops
|
||||
dnnl::primitive_attr zp_prim_attr;
|
||||
{
|
||||
y_zero_point_attr.set_scales(DNNL_ARG_SRC_0, 0, {-1.0f});
|
||||
zp_prim_attr.set_scales(DNNL_ARG_SRC_0, 0, {-1.0f});
|
||||
dnnl::post_ops div_saturate_round;
|
||||
div_saturate_round.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
|
||||
//clip might not be necessary as reorder cast will saturate on lower precision
|
||||
//might still need it as compute y needs saturated zero point already
|
||||
div_saturate_round.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip_v2, 0.0f, 255.0f);
|
||||
y_zero_point_attr.set_post_ops(div_saturate_round);
|
||||
zp_prim_attr.set_post_ops(div_saturate_round);
|
||||
}
|
||||
auto y_zero_point_pd = dnnl::binary::primitive_desc(y_zero_point_d, y_zero_point_attr, eng);
|
||||
auto y_zero_point_prim = dnnl::binary(y_zero_point_pd);
|
||||
|
||||
auto y_zero_point_dst_mem = dnnl::memory(y_zero_point_pd.dst_desc(), eng);
|
||||
std::unordered_map<int, dnnl::memory> y_zero_point_args = {{DNNL_ARG_SRC_0, min_reduction_dst_mem}, {DNNL_ARG_SRC_1, y_scale_mem}, {DNNL_ARG_DST, y_zero_point_dst_mem}};
|
||||
// Create primitives
|
||||
auto zp_prim_pd = dnnl::binary::primitive_desc(zp_prim_d, zp_prim_attr, eng);
|
||||
auto zp_prim = dnnl::binary(zp_prim_pd);
|
||||
|
||||
//y zero point f32
|
||||
//np.clip(round((0 - x_min) / Y_Scale), 0, 255)
|
||||
sp.AddPrimitive(y_zero_point_prim, y_zero_point_args);
|
||||
// Create zp memory dst
|
||||
auto y_zp_mem = dnnl::memory(zp_prim_pd.dst_desc(), eng);
|
||||
|
||||
// Calc zp
|
||||
sp.AddPrimitive(zp_prim,{{DNNL_ARG_SRC_0, min_reduction_mem},
|
||||
{DNNL_ARG_SRC_1, y_scale_mem},
|
||||
{DNNL_ARG_DST, y_zp_mem}});
|
||||
|
||||
//prepare y kernel
|
||||
//x/y -> round() -> + y_zp -> clip 0,255
|
||||
auto y_d = dnnl::binary::desc(dnnl::algorithm::binary_div, x_memory.get_desc(), y_scale_mem.get_desc(), x_memory.get_desc());
|
||||
dnnl::primitive_attr y_attr;
|
||||
// Y COMPUTATION
|
||||
// Create y md and binary desc
|
||||
auto y_md = dnnl::memory::desc(x_md.dims(), dnnl::memory::data_type::u8, x_format);
|
||||
auto y_bin_d = dnnl::binary::desc(dnnl::algorithm::binary_div, x_mem.get_desc(), y_scale_mem.get_desc(), y_md);
|
||||
// Add post ops
|
||||
dnnl::primitive_attr y_bin_attr;
|
||||
{
|
||||
dnnl::post_ops round_zp_saturate;
|
||||
round_zp_saturate.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
|
||||
round_zp_saturate.append_binary(dnnl::algorithm::binary_add, y_zero_point_dst_mem.get_desc());
|
||||
//clip might not be necessary as reorder cast will saturate on lower precision
|
||||
round_zp_saturate.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip_v2, 0.0f, 255.0f);
|
||||
y_attr.set_post_ops(round_zp_saturate);
|
||||
dnnl::post_ops round_add;
|
||||
round_add.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f);
|
||||
round_add.append_binary(dnnl::algorithm::binary_add, y_zp_mem.get_desc());
|
||||
y_bin_attr.set_post_ops(round_add);
|
||||
}
|
||||
auto y_pd = dnnl::binary::primitive_desc(y_d, y_attr, eng);
|
||||
// Create binary primitive with post ops
|
||||
auto y_pd = dnnl::binary::primitive_desc(y_bin_d, y_bin_attr, eng);
|
||||
auto y_prim = dnnl::binary(y_pd);
|
||||
|
||||
// Create y_dst mem
|
||||
auto y_mem = dnnl::memory(y_pd.dst_desc(), eng);
|
||||
std::unordered_map<int, dnnl::memory> y_args = {{DNNL_ARG_SRC_0, x_memory}, {DNNL_ARG_SRC_1, y_scale_mem}, {DNNL_ARG_DST, y_mem}};
|
||||
y_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1] = y_zero_point_dst_mem;
|
||||
// Compute y
|
||||
sp.AddPrimitive(y_prim, {{DNNL_ARG_SRC_0, x_mem},
|
||||
{DNNL_ARG_SRC_1, y_scale_mem},
|
||||
{DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, y_zp_mem},
|
||||
{DNNL_ARG_DST, y_mem}});
|
||||
|
||||
// x/y -> round() -> + y_zp -> clip 0,255
|
||||
// quantized output tensor f32
|
||||
sp.AddPrimitive(y_prim, y_args);
|
||||
|
||||
|
||||
//set output y scale
|
||||
// Set outputs
|
||||
sp.SetMemory(node.Output(OUT_Y), y_mem);
|
||||
sp.SetMemory(node.Output(OUT_Y_SCALE), y_scale_mem, false, true);
|
||||
|
||||
//data type change for y_zp and set memory
|
||||
//data type change is needed for onnxruntime spec
|
||||
//zp for onednn is currently in s32, any downstream node might need to convert from u8 to s32
|
||||
auto y_zero_point_dst_md_uint8 = ChangeMemoryDescDataType(y_zero_point_dst_mem.get_desc(), dnnl::memory::data_type::u8);
|
||||
auto y_zero_point_dst_mem_uint8 = dnnl::memory(y_zero_point_dst_md_uint8, eng);
|
||||
sp.AddPrimitive(dnnl::reorder(y_zero_point_dst_mem, y_zero_point_dst_mem_uint8), {{DNNL_ARG_FROM, y_zero_point_dst_mem}, {DNNL_ARG_TO, y_zero_point_dst_mem_uint8}});
|
||||
sp.SetMemory(node.Output(OUT_Y_ZP), y_zero_point_dst_mem_uint8, false, true);
|
||||
|
||||
//data type change for y and set memory
|
||||
auto y_md_uint8 = ChangeMemoryDescDataType(y_mem.get_desc(), dnnl::memory::data_type::u8);
|
||||
auto y_mem_uint8 = dnnl::memory(y_md_uint8, eng);
|
||||
sp.AddPrimitive(dnnl::reorder(y_mem, y_mem_uint8), {{DNNL_ARG_FROM, y_mem}, {DNNL_ARG_TO, y_mem_uint8}});
|
||||
sp.SetMemory(node.Output(OUT_Y), y_mem_uint8);
|
||||
|
||||
sp.SetMemory(node.Output(OUT_Y_ZP), y_zp_mem, false, true);
|
||||
}
|
||||
|
||||
//change md to targeted data type of cast op dst
|
||||
|
|
|
|||
Loading…
Reference in a new issue