diff --git a/onnxruntime/core/providers/dnnl/subgraph/dnnl_dynamicquantizelinear.cc b/onnxruntime/core/providers/dnnl/subgraph/dnnl_dynamicquantizelinear.cc index 6527b0b36c..96061c201a 100644 --- a/onnxruntime/core/providers/dnnl/subgraph/dnnl_dynamicquantizelinear.cc +++ b/onnxruntime/core/providers/dnnl/subgraph/dnnl_dynamicquantizelinear.cc @@ -17,137 +17,123 @@ Y_ZeroPoint = np.clip(round((0 - x_min) / Y_Scale), 0, 255).astype(np.uint8) Y = np.clip(np.round(X / Y_Scale) + Y_ZeroPoint, 0, 255).astype(np.uint8) */ void DnnlDynamicQuantizeLinear::CreatePrimitive(DnnlSubgraphPrimitive& sp, DnnlNode& node) { + // Get engine auto eng = sp.GetEngine(); - auto x_memory = sp.GetMemory(node.Input(IN_X).Name()); - x_memory = sp.GetMemoryAndReshape(node.Input(IN_X), x_memory.get_desc(), eng); - auto x_memory_desc = x_memory.get_desc(); - auto x_memory_dims = x_memory_desc.dims(); - auto x_memory_dt = x_memory_desc.data_type(); - //dims of all ones - dnnl::memory::dims min_max_dst_dims(x_memory_dims.size(), 1); + // Get src mem + auto x_mem = sp.GetMemory(node.Input(IN_X)); + auto x_md = x_mem.get_desc(); + auto x_size = x_md.dims().size(); + auto x_format = sp.GetDnnlFormat(x_size); - auto min_max_dst_mem_desc = dnnl::memory::desc(min_max_dst_dims, x_memory_dt, sp.GetDnnlFormat(x_memory_dims.size())); + // Dims for one dimensional tensor + dnnl::memory::dims one_dim(x_size, 1); - //max_reduction responsible for producing scale - auto max_reduction_d = dnnl::reduction::desc( - dnnl::algorithm::reduction_max, x_memory_desc, min_max_dst_mem_desc, 0.f, 0.f); - auto min_reduction_d = dnnl::reduction::desc( - dnnl::algorithm::reduction_min, x_memory_desc, min_max_dst_mem_desc, 0.f, 0.f); + // Y_SCALE COMPUTATION + // Create descriptor for reduction max and min + auto y_scale_md = dnnl::memory::desc(one_dim, x_md.data_type(), x_format); + auto max_reduction_d = dnnl::reduction::desc(dnnl::algorithm::reduction_max, x_md, y_scale_md, 0.f, 0.f); + auto min_reduction_d = dnnl::reduction::desc(dnnl::algorithm::reduction_min, x_md, y_scale_md, 0.f, 0.f); - //prepare a zero memory, used for adding 0 to data range for min max operation - auto zero_mem = dnnl::memory(min_max_dst_mem_desc, eng); + // Fill memory with 0's, needed for min and max binary + auto zero_mem = dnnl::memory(y_scale_md, eng); WriteZeroToMem(zero_mem); - //max(x) with 0 added to range -> sub min(x) -> div 255 + // Generate post ops to calc y_scale dnnl::primitive_attr max_reduction_attr; { - dnnl::post_ops sub_min_div_255; - //max(0,reduce_max(x)) - sub_min_div_255.append_binary(dnnl::algorithm::binary_max, zero_mem.get_desc()); - //max - min - sub_min_div_255.append_binary(dnnl::algorithm::binary_sub, min_max_dst_mem_desc); - // /255 - sub_min_div_255.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f / 255.0f, 0.0f); - max_reduction_attr.set_post_ops(sub_min_div_255); + // y_scale = ((x_max - x_min) / (255 - 0)).astype(np.float32) # uint8->[0, 255] + dnnl::post_ops calc_y_scale; + // x_max = max(0, reduce_max(x)) + calc_y_scale.append_binary(dnnl::algorithm::binary_max, zero_mem.get_desc()); + // y_scale = x_max - x_min + calc_y_scale.append_binary(dnnl::algorithm::binary_sub, y_scale_md); + // y_scale =/ 255 + calc_y_scale.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, 1.0f / 255.0f, 0.0f); + max_reduction_attr.set_post_ops(calc_y_scale); } - //add 0 to reduce min range + // x_min = min(0, reduce_min(x)) dnnl::primitive_attr min_reduction_attr; { - dnnl::post_ops add_0_to_range; - add_0_to_range.append_binary(dnnl::algorithm::binary_min, zero_mem.get_desc()); - min_reduction_attr.set_post_ops(add_0_to_range); + dnnl::post_ops calc_min; + calc_min.append_binary(dnnl::algorithm::binary_min, zero_mem.get_desc()); + min_reduction_attr.set_post_ops(calc_min); } - auto max_reduction_pd = dnnl::reduction::primitive_desc(max_reduction_d, max_reduction_attr, eng); - auto min_reduction_pd = dnnl::reduction::primitive_desc(min_reduction_d, min_reduction_attr, eng); + // Create reduction primitive + auto max_reduction_prim = dnnl::reduction(dnnl::reduction::primitive_desc(max_reduction_d, max_reduction_attr, eng)); + auto min_reduction_prim = dnnl::reduction(dnnl::reduction::primitive_desc(min_reduction_d, min_reduction_attr, eng)); - auto max_reduction_prim = dnnl::reduction(max_reduction_pd); - auto min_reduction_prim = dnnl::reduction(min_reduction_pd); + // Create y_scale and min memory + auto y_scale_mem = dnnl::memory(y_scale_md, eng); + auto min_reduction_mem = dnnl::memory(y_scale_md, eng); - auto y_scale_mem = dnnl::memory(min_max_dst_mem_desc, eng); - auto min_reduction_dst_mem = dnnl::memory(min_max_dst_mem_desc, eng); + // Compute min first since max_reduction needs min as input + sp.AddPrimitive(min_reduction_prim, {{DNNL_ARG_SRC, x_mem}, + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, zero_mem}, + {DNNL_ARG_DST, min_reduction_mem}}); - std::unordered_map min_reduction_args = {{DNNL_ARG_SRC, x_memory}, {DNNL_ARG_DST, min_reduction_dst_mem}}; - min_reduction_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] = zero_mem; - - std::unordered_map max_reduction_args = {{DNNL_ARG_SRC, x_memory}, {DNNL_ARG_DST, y_scale_mem}}; - max_reduction_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] = zero_mem; - max_reduction_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1] = min_reduction_dst_mem; - - //compute min first since max_reduction needs min dst as post op arg - // compute x min - sp.AddPrimitive(min_reduction_prim, min_reduction_args); - // compute y scale f32 - sp.AddPrimitive(max_reduction_prim, max_reduction_args); + // Compute y_scale in fp32 + sp.AddPrimitive(max_reduction_prim, {{DNNL_ARG_SRC, x_mem}, + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1, zero_mem}, + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, min_reduction_mem}, + {DNNL_ARG_DST, y_scale_mem}}); - //prepare y zero point kernel - auto y_zero_point_d = dnnl::binary::desc(dnnl::algorithm::binary_div, min_reduction_dst_mem.get_desc(), y_scale_mem.get_desc(), min_reduction_dst_mem.get_desc()); + // Y_ZERO_POINT COMPUTATION + // Create memory and primitive descriptors + auto y_zp_md = dnnl::memory::desc(one_dim, dnnl::memory::data_type::u8, x_format); + auto zp_prim_d = dnnl::binary::desc(dnnl::algorithm::binary_div, y_scale_md, y_scale_md, y_zp_md); - dnnl::primitive_attr y_zero_point_attr; + // Add round and clip post ops + dnnl::primitive_attr zp_prim_attr; { - y_zero_point_attr.set_scales(DNNL_ARG_SRC_0, 0, {-1.0f}); + zp_prim_attr.set_scales(DNNL_ARG_SRC_0, 0, {-1.0f}); dnnl::post_ops div_saturate_round; div_saturate_round.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f); - //clip might not be necessary as reorder cast will saturate on lower precision - //might still need it as compute y needs saturated zero point already - div_saturate_round.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip_v2, 0.0f, 255.0f); - y_zero_point_attr.set_post_ops(div_saturate_round); + zp_prim_attr.set_post_ops(div_saturate_round); } - auto y_zero_point_pd = dnnl::binary::primitive_desc(y_zero_point_d, y_zero_point_attr, eng); - auto y_zero_point_prim = dnnl::binary(y_zero_point_pd); - auto y_zero_point_dst_mem = dnnl::memory(y_zero_point_pd.dst_desc(), eng); - std::unordered_map y_zero_point_args = {{DNNL_ARG_SRC_0, min_reduction_dst_mem}, {DNNL_ARG_SRC_1, y_scale_mem}, {DNNL_ARG_DST, y_zero_point_dst_mem}}; + // Create primitives + auto zp_prim_pd = dnnl::binary::primitive_desc(zp_prim_d, zp_prim_attr, eng); + auto zp_prim = dnnl::binary(zp_prim_pd); - //y zero point f32 - //np.clip(round((0 - x_min) / Y_Scale), 0, 255) - sp.AddPrimitive(y_zero_point_prim, y_zero_point_args); + // Create zp memory dst + auto y_zp_mem = dnnl::memory(zp_prim_pd.dst_desc(), eng); + // Calc zp + sp.AddPrimitive(zp_prim,{{DNNL_ARG_SRC_0, min_reduction_mem}, + {DNNL_ARG_SRC_1, y_scale_mem}, + {DNNL_ARG_DST, y_zp_mem}}); - //prepare y kernel - //x/y -> round() -> + y_zp -> clip 0,255 - auto y_d = dnnl::binary::desc(dnnl::algorithm::binary_div, x_memory.get_desc(), y_scale_mem.get_desc(), x_memory.get_desc()); - dnnl::primitive_attr y_attr; + // Y COMPUTATION + // Create y md and binary desc + auto y_md = dnnl::memory::desc(x_md.dims(), dnnl::memory::data_type::u8, x_format); + auto y_bin_d = dnnl::binary::desc(dnnl::algorithm::binary_div, x_mem.get_desc(), y_scale_mem.get_desc(), y_md); + // Add post ops + dnnl::primitive_attr y_bin_attr; { - dnnl::post_ops round_zp_saturate; - round_zp_saturate.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f); - round_zp_saturate.append_binary(dnnl::algorithm::binary_add, y_zero_point_dst_mem.get_desc()); - //clip might not be necessary as reorder cast will saturate on lower precision - round_zp_saturate.append_eltwise(1.0f, dnnl::algorithm::eltwise_clip_v2, 0.0f, 255.0f); - y_attr.set_post_ops(round_zp_saturate); + dnnl::post_ops round_add; + round_add.append_eltwise(1.0f, dnnl::algorithm::eltwise_round, 0.0f, 0.0f); + round_add.append_binary(dnnl::algorithm::binary_add, y_zp_mem.get_desc()); + y_bin_attr.set_post_ops(round_add); } - auto y_pd = dnnl::binary::primitive_desc(y_d, y_attr, eng); + // Create binary primitive with post ops + auto y_pd = dnnl::binary::primitive_desc(y_bin_d, y_bin_attr, eng); auto y_prim = dnnl::binary(y_pd); - + // Create y_dst mem auto y_mem = dnnl::memory(y_pd.dst_desc(), eng); - std::unordered_map y_args = {{DNNL_ARG_SRC_0, x_memory}, {DNNL_ARG_SRC_1, y_scale_mem}, {DNNL_ARG_DST, y_mem}}; - y_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1] = y_zero_point_dst_mem; + // Compute y + sp.AddPrimitive(y_prim, {{DNNL_ARG_SRC_0, x_mem}, + {DNNL_ARG_SRC_1, y_scale_mem}, + {DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1, y_zp_mem}, + {DNNL_ARG_DST, y_mem}}); - // x/y -> round() -> + y_zp -> clip 0,255 - // quantized output tensor f32 - sp.AddPrimitive(y_prim, y_args); - - - //set output y scale + // Set outputs + sp.SetMemory(node.Output(OUT_Y), y_mem); sp.SetMemory(node.Output(OUT_Y_SCALE), y_scale_mem, false, true); - - //data type change for y_zp and set memory - //data type change is needed for onnxruntime spec - //zp for onednn is currently in s32, any downstream node might need to convert from u8 to s32 - auto y_zero_point_dst_md_uint8 = ChangeMemoryDescDataType(y_zero_point_dst_mem.get_desc(), dnnl::memory::data_type::u8); - auto y_zero_point_dst_mem_uint8 = dnnl::memory(y_zero_point_dst_md_uint8, eng); - sp.AddPrimitive(dnnl::reorder(y_zero_point_dst_mem, y_zero_point_dst_mem_uint8), {{DNNL_ARG_FROM, y_zero_point_dst_mem}, {DNNL_ARG_TO, y_zero_point_dst_mem_uint8}}); - sp.SetMemory(node.Output(OUT_Y_ZP), y_zero_point_dst_mem_uint8, false, true); - - //data type change for y and set memory - auto y_md_uint8 = ChangeMemoryDescDataType(y_mem.get_desc(), dnnl::memory::data_type::u8); - auto y_mem_uint8 = dnnl::memory(y_md_uint8, eng); - sp.AddPrimitive(dnnl::reorder(y_mem, y_mem_uint8), {{DNNL_ARG_FROM, y_mem}, {DNNL_ARG_TO, y_mem_uint8}}); - sp.SetMemory(node.Output(OUT_Y), y_mem_uint8); - + sp.SetMemory(node.Output(OUT_Y_ZP), y_zp_mem, false, true); } //change md to targeted data type of cast op dst