From e8a9d4f04d8dc49fba87d21d181d8e8c12dea5b7 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Thu, 10 Aug 2023 09:14:43 -0700 Subject: [PATCH] [JS/Web] Fix Resize kMSInternalNHWCDomain (#17023) ### Description Fix some Resize failing tests. ### Motivation and Context --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- js/web/docs/webgpu-operators.md | 2 +- js/web/test/suite-test-list.jsonc | 42 +++++++++---------- .../onnx_transpose_optimization.cc | 2 +- .../providers/js/js_execution_provider.cc | 2 - onnxruntime/core/providers/js/js_kernel.h | 34 +++++++++++---- .../core/providers/js/operators/resize.cc | 2 +- 6 files changed, 49 insertions(+), 35 deletions(-) diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 6df4cb61c6..84bf69b51f 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -63,7 +63,7 @@ Do not modify directly.* | ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | | | Relu | ai.onnx(6-12,13,14+) | | | Reshape | ai.onnx(5-12,13,14+) | no GPU kernel | -| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | +| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling | | Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix | | Sigmoid | ai.onnx(6-12,13+) | | | Sin | ai.onnx(7+) | | diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 0eb402019d..70e110d498 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -979,32 +979,32 @@ "test_reshape_zero_dim", "test_resize_downsample_linear", "test_resize_downsample_nearest", - // "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside", + "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside", // "test_resize_downsample_scales_cubic_align_corners", - // "test_resize_downsample_scales_cubic", + "test_resize_downsample_scales_cubic", // "test_resize_downsample_scales_linear_align_corners", - // "test_resize_downsample_scales_linear", - // "test_resize_downsample_scales_nearest", - // "test_resize_downsample_sizes_cubic", - // "test_resize_downsample_sizes_linear_pytorch_half_pixel", - // "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn", - // "test_resize_downsample_sizes_nearest", + "test_resize_downsample_scales_linear", + "test_resize_downsample_scales_nearest", + "test_resize_downsample_sizes_cubic", + "test_resize_downsample_sizes_linear_pytorch_half_pixel", + "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn", + "test_resize_downsample_sizes_nearest", "test_resize_nearest", - // "test_resize_tf_crop_and_resize", + "test_resize_tf_crop_and_resize", "test_resize_upsample_linear", "test_resize_upsample_nearest", - // "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside", - // "test_resize_upsample_scales_cubic_align_corners", - // "test_resize_upsample_scales_cubic_asymmetric", - // "test_resize_upsample_scales_cubic", - // "test_resize_upsample_scales_linear_align_corners", - // "test_resize_upsample_scales_linear", - // "test_resize_upsample_scales_nearest", - // "test_resize_upsample_sizes_cubic", - // "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel", - // "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners", - // "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", - // "test_resize_upsample_sizes_nearest", + "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside", + "test_resize_upsample_scales_cubic_align_corners", + "test_resize_upsample_scales_cubic_asymmetric", + "test_resize_upsample_scales_cubic", + "test_resize_upsample_scales_linear_align_corners", + "test_resize_upsample_scales_linear", + "test_resize_upsample_scales_nearest", + "test_resize_upsample_sizes_cubic", + "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel", + "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners", + "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric", + "test_resize_upsample_sizes_nearest", // // "test_reversesequence_batch", // // "test_reversesequence_time", // // "test_rnn_seq_length", diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index a54903a036..af4859fdbb 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -927,7 +927,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con } static bool HandleResize([[maybe_unused]] HandlerArgs& args) { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) +#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) // The CUDA Resize kernel requires that the input is NCHW, so we can't push a Transpose through a Resize // in ORT builds with CUDA enabled. // The ROCm EP is generated from the CUDA EP kernel so the same applies to builds with ROCm enabled. diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index b242c988d5..102b2f6cc0 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -270,7 +270,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather); -class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize); @@ -495,7 +494,6 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 98d26e15b1..b8fab3bbc5 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -79,16 +79,31 @@ class JsKernel : public OpKernel { } Status SerializeKernelContext(OpKernelContext* context, AllocatorPtr alloc, void* custom_data_ptr, size_t custom_data_size, void** ptr) const { + // + // An optional input may be a placeholder, which is nullptr. In this case, we still need to + // add the placeholder to the serialized data, with type, data_ptr and dim_size all zeros, + // so that the JS kernel can know the input count. // // temp_data_format (every item is (u)int32_t): - // context_ptr | input_count | custom_data_ptr | custom_data_size | [input_data_0] ... [input_data_N-1] + // context_ptr | input_count | custom_data_ptr | custom_data_size | [input_data_or_placeholder_0] ... [input_data_or_placeholder_N-1] + // + // input_data_or_placeholder_format: + // input_data OR placeholder // // input_data_format: // type | data_ptr | dim_size | dim[0] ... dim[N-1] // + // placeholder_format: + // 0 | 0 | 0 + // size_t temp_data_size = sizeof(size_t) * 5; for (int i = 0; i < context->InputCount(); i++) { - temp_data_size += sizeof(size_t) * (3 + context->Input(i)->Shape().NumDimensions()); + const auto* input_ptr = context->Input(i); + if (nullptr != input_ptr) { + temp_data_size += sizeof(size_t) * (3 + input_ptr->Shape().NumDimensions()); + } else { + temp_data_size += sizeof(size_t) * 3; + } } uint32_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size)); if (p_serialized_kernel_context == nullptr) { @@ -102,18 +117,19 @@ class JsKernel : public OpKernel { p_serialized_kernel_context[4] = static_cast(custom_data_size); size_t index = 5; for (int i = 0; i < context->InputCount(); i++) { - p_serialized_kernel_context[index++] = static_cast(context->Input(i)->GetElementType()); - const auto* ptr = context->Input(i); + const auto* input_ptr = context->Input(i); // Skip if the input is only a placeholder. - if (ptr == nullptr) { + if (input_ptr == nullptr) { + p_serialized_kernel_context[index++] = 0; p_serialized_kernel_context[index++] = 0; p_serialized_kernel_context[index++] = 0; continue; } - p_serialized_kernel_context[index++] = reinterpret_cast(ptr->DataRaw()); - p_serialized_kernel_context[index++] = static_cast(context->Input(i)->Shape().NumDimensions()); - for (size_t d = 0; d < context->Input(i)->Shape().NumDimensions(); d++) { - p_serialized_kernel_context[index++] = static_cast(context->Input(i)->Shape()[d]); + p_serialized_kernel_context[index++] = static_cast(input_ptr->GetElementType()); + p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw()); + p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions()); + for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) { + p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]); } } diff --git a/onnxruntime/core/providers/js/operators/resize.cc b/onnxruntime/core/providers/js/operators/resize.cc index e6d6d0a065..7619c33a47 100644 --- a/onnxruntime/core/providers/js/operators/resize.cc +++ b/onnxruntime/core/providers/js/operators/resize.cc @@ -45,12 +45,12 @@ namespace js { Resize); #define REGISTER_RESIZE_KERNEL_DOMAIN(domain) \ - REGISTER_RESIZE_VERSIONED_10_10_KERNEL(domain); \ REGISTER_RESIZE_VERSIONED_KERNEL(domain, 11, 12); \ REGISTER_RESIZE_VERSIONED_KERNEL(domain, 13, 17); \ REGISTER_RESIZE_VERSIONED_KERNEL(domain, 18, 18); \ REGISTER_RESIZE_KERNEL(domain, 19); +REGISTER_RESIZE_VERSIONED_10_10_KERNEL(kOnnxDomain); REGISTER_RESIZE_KERNEL_DOMAIN(kOnnxDomain); REGISTER_RESIZE_KERNEL_DOMAIN(kMSInternalNHWCDomain);