mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-04 23:59:56 +00:00
[JS/Web] Fix Resize kMSInternalNHWCDomain (#17023)
### Description Fix some Resize failing tests. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
This commit is contained in:
parent
ef6f4a4aa1
commit
e8a9d4f04d
6 changed files with 49 additions and 35 deletions
|
|
@ -63,7 +63,7 @@ Do not modify directly.*
|
|||
| ReduceSumSquare | ai.onnx(1-10,11-12,13-17,18+) | |
|
||||
| Relu | ai.onnx(6-12,13,14+) | |
|
||||
| Reshape | ai.onnx(5-12,13,14+) | no GPU kernel |
|
||||
| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling |
|
||||
| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling |
|
||||
| Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix |
|
||||
| Sigmoid | ai.onnx(6-12,13+) | |
|
||||
| Sin | ai.onnx(7+) | |
|
||||
|
|
|
|||
|
|
@ -979,32 +979,32 @@
|
|||
"test_reshape_zero_dim",
|
||||
"test_resize_downsample_linear",
|
||||
"test_resize_downsample_nearest",
|
||||
// "test_resize_downsample_scales_cubic_A_n0p5_exclude_outside",
|
||||
"test_resize_downsample_scales_cubic_A_n0p5_exclude_outside",
|
||||
// "test_resize_downsample_scales_cubic_align_corners",
|
||||
// "test_resize_downsample_scales_cubic",
|
||||
"test_resize_downsample_scales_cubic",
|
||||
// "test_resize_downsample_scales_linear_align_corners",
|
||||
// "test_resize_downsample_scales_linear",
|
||||
// "test_resize_downsample_scales_nearest",
|
||||
// "test_resize_downsample_sizes_cubic",
|
||||
// "test_resize_downsample_sizes_linear_pytorch_half_pixel",
|
||||
// "test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn",
|
||||
// "test_resize_downsample_sizes_nearest",
|
||||
"test_resize_downsample_scales_linear",
|
||||
"test_resize_downsample_scales_nearest",
|
||||
"test_resize_downsample_sizes_cubic",
|
||||
"test_resize_downsample_sizes_linear_pytorch_half_pixel",
|
||||
"test_resize_downsample_sizes_nearest_tf_half_pixel_for_nn",
|
||||
"test_resize_downsample_sizes_nearest",
|
||||
"test_resize_nearest",
|
||||
// "test_resize_tf_crop_and_resize",
|
||||
"test_resize_tf_crop_and_resize",
|
||||
"test_resize_upsample_linear",
|
||||
"test_resize_upsample_nearest",
|
||||
// "test_resize_upsample_scales_cubic_A_n0p5_exclude_outside",
|
||||
// "test_resize_upsample_scales_cubic_align_corners",
|
||||
// "test_resize_upsample_scales_cubic_asymmetric",
|
||||
// "test_resize_upsample_scales_cubic",
|
||||
// "test_resize_upsample_scales_linear_align_corners",
|
||||
// "test_resize_upsample_scales_linear",
|
||||
// "test_resize_upsample_scales_nearest",
|
||||
// "test_resize_upsample_sizes_cubic",
|
||||
// "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel",
|
||||
// "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners",
|
||||
// "opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric",
|
||||
// "test_resize_upsample_sizes_nearest",
|
||||
"test_resize_upsample_scales_cubic_A_n0p5_exclude_outside",
|
||||
"test_resize_upsample_scales_cubic_align_corners",
|
||||
"test_resize_upsample_scales_cubic_asymmetric",
|
||||
"test_resize_upsample_scales_cubic",
|
||||
"test_resize_upsample_scales_linear_align_corners",
|
||||
"test_resize_upsample_scales_linear",
|
||||
"test_resize_upsample_scales_nearest",
|
||||
"test_resize_upsample_sizes_cubic",
|
||||
"opset{12,13,17,18}/test_resize_upsample_sizes_nearest_ceil_half_pixel",
|
||||
"opset{12,13,17,18}/test_resize_upsample_sizes_nearest_floor_align_corners",
|
||||
"opset{12,13,17,18}/test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric",
|
||||
"test_resize_upsample_sizes_nearest",
|
||||
// // "test_reversesequence_batch",
|
||||
// // "test_reversesequence_time",
|
||||
// // "test_rnn_seq_length",
|
||||
|
|
|
|||
|
|
@ -927,7 +927,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con
|
|||
}
|
||||
|
||||
static bool HandleResize([[maybe_unused]] HandlerArgs& args) {
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN)
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN)
|
||||
// The CUDA Resize kernel requires that the input is NCHW, so we can't push a Transpose through a Resize
|
||||
// in ORT builds with CUDA enabled.
|
||||
// The ROCm EP is generated from the CUDA EP kernel so the same applies to builds with ROCm enabled.
|
||||
|
|
|
|||
|
|
@ -270,7 +270,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
|
|||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gather);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gather);
|
||||
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize);
|
||||
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize);
|
||||
|
|
@ -495,7 +494,6 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, 18, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 19, Resize)>,
|
||||
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 10, 10, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 12, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 13, 17, Resize)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 18, 18, Resize)>,
|
||||
|
|
|
|||
|
|
@ -79,16 +79,31 @@ class JsKernel : public OpKernel {
|
|||
}
|
||||
|
||||
Status SerializeKernelContext(OpKernelContext* context, AllocatorPtr alloc, void* custom_data_ptr, size_t custom_data_size, void** ptr) const {
|
||||
//
|
||||
// An optional input may be a placeholder, which is nullptr. In this case, we still need to
|
||||
// add the placeholder to the serialized data, with type, data_ptr and dim_size all zeros,
|
||||
// so that the JS kernel can know the input count.
|
||||
//
|
||||
// temp_data_format (every item is (u)int32_t):
|
||||
// context_ptr | input_count | custom_data_ptr | custom_data_size | [input_data_0] ... [input_data_N-1]
|
||||
// context_ptr | input_count | custom_data_ptr | custom_data_size | [input_data_or_placeholder_0] ... [input_data_or_placeholder_N-1]
|
||||
//
|
||||
// input_data_or_placeholder_format:
|
||||
// input_data OR placeholder
|
||||
//
|
||||
// input_data_format:
|
||||
// type | data_ptr | dim_size | dim[0] ... dim[N-1]
|
||||
//
|
||||
// placeholder_format:
|
||||
// 0 | 0 | 0
|
||||
//
|
||||
size_t temp_data_size = sizeof(size_t) * 5;
|
||||
for (int i = 0; i < context->InputCount(); i++) {
|
||||
temp_data_size += sizeof(size_t) * (3 + context->Input<Tensor>(i)->Shape().NumDimensions());
|
||||
const auto* input_ptr = context->Input<Tensor>(i);
|
||||
if (nullptr != input_ptr) {
|
||||
temp_data_size += sizeof(size_t) * (3 + input_ptr->Shape().NumDimensions());
|
||||
} else {
|
||||
temp_data_size += sizeof(size_t) * 3;
|
||||
}
|
||||
}
|
||||
uint32_t* p_serialized_kernel_context = reinterpret_cast<uint32_t*>(alloc->Alloc(temp_data_size));
|
||||
if (p_serialized_kernel_context == nullptr) {
|
||||
|
|
@ -102,18 +117,19 @@ class JsKernel : public OpKernel {
|
|||
p_serialized_kernel_context[4] = static_cast<uint32_t>(custom_data_size);
|
||||
size_t index = 5;
|
||||
for (int i = 0; i < context->InputCount(); i++) {
|
||||
p_serialized_kernel_context[index++] = static_cast<uint32_t>(context->Input<Tensor>(i)->GetElementType());
|
||||
const auto* ptr = context->Input<Tensor>(i);
|
||||
const auto* input_ptr = context->Input<Tensor>(i);
|
||||
// Skip if the input is only a placeholder.
|
||||
if (ptr == nullptr) {
|
||||
if (input_ptr == nullptr) {
|
||||
p_serialized_kernel_context[index++] = 0;
|
||||
p_serialized_kernel_context[index++] = 0;
|
||||
p_serialized_kernel_context[index++] = 0;
|
||||
continue;
|
||||
}
|
||||
p_serialized_kernel_context[index++] = reinterpret_cast<uint32_t>(ptr->DataRaw());
|
||||
p_serialized_kernel_context[index++] = static_cast<uint32_t>(context->Input<Tensor>(i)->Shape().NumDimensions());
|
||||
for (size_t d = 0; d < context->Input<Tensor>(i)->Shape().NumDimensions(); d++) {
|
||||
p_serialized_kernel_context[index++] = static_cast<uint32_t>(context->Input<Tensor>(i)->Shape()[d]);
|
||||
p_serialized_kernel_context[index++] = static_cast<uint32_t>(input_ptr->GetElementType());
|
||||
p_serialized_kernel_context[index++] = reinterpret_cast<uint32_t>(input_ptr->DataRaw());
|
||||
p_serialized_kernel_context[index++] = static_cast<uint32_t>(input_ptr->Shape().NumDimensions());
|
||||
for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) {
|
||||
p_serialized_kernel_context[index++] = static_cast<uint32_t>(input_ptr->Shape()[d]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -45,12 +45,12 @@ namespace js {
|
|||
Resize);
|
||||
|
||||
#define REGISTER_RESIZE_KERNEL_DOMAIN(domain) \
|
||||
REGISTER_RESIZE_VERSIONED_10_10_KERNEL(domain); \
|
||||
REGISTER_RESIZE_VERSIONED_KERNEL(domain, 11, 12); \
|
||||
REGISTER_RESIZE_VERSIONED_KERNEL(domain, 13, 17); \
|
||||
REGISTER_RESIZE_VERSIONED_KERNEL(domain, 18, 18); \
|
||||
REGISTER_RESIZE_KERNEL(domain, 19);
|
||||
|
||||
REGISTER_RESIZE_VERSIONED_10_10_KERNEL(kOnnxDomain);
|
||||
REGISTER_RESIZE_KERNEL_DOMAIN(kOnnxDomain);
|
||||
REGISTER_RESIZE_KERNEL_DOMAIN(kMSInternalNHWCDomain);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue