mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-01 23:30:35 +00:00
[js/webgpu] Fix max pool shape end with 0 (#21698)
Bug: https://github.com/microsoft/onnxruntime/issues/21386 ### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
e32e3575d8
commit
7172aff1cf
4 changed files with 106 additions and 28 deletions
67
js/web/test/data/ops/max-pool.jsonc
Normal file
67
js/web/test/data/ops/max-pool.jsonc
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
[
|
||||
{
|
||||
"name": "MaxPool",
|
||||
"operator": "MaxPool",
|
||||
"attributes": [
|
||||
{ "name": "kernel_shape", "data": [3], "type": "ints" },
|
||||
{ "name": "dilations", "data": [1], "type": "ints" }
|
||||
],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[3,5,5] T[3,5,3]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [
|
||||
1.764052391052246, 0.40015721321105957, 0.978738009929657, 2.2408931255340576, 1.8675580024719238,
|
||||
-0.9772778749465942, 0.9500884413719177, -0.15135720372200012, -0.10321885347366333, 0.4105985164642334,
|
||||
0.14404356479644775, 1.4542734622955322, 0.7610377073287964, 0.12167501449584961, 0.44386324286460876,
|
||||
0.3336743414402008, 1.4940791130065918, -0.2051582634449005, 0.3130677044391632, -0.8540957570075989,
|
||||
-2.5529897212982178, 0.653618574142456, 0.8644362092018127, -0.7421650290489197, 2.269754648208618, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, 100, 100, 100, 100,
|
||||
100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100
|
||||
],
|
||||
"dims": [3, 5, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [
|
||||
1.764052391052246, 2.2408931255340576, 2.2408931255340576, 0.9500884413719177, 0.9500884413719177,
|
||||
0.4105985164642334, 1.4542734622955322, 1.4542734622955322, 0.7610377073287964, 1.4940791130065918,
|
||||
1.4940791130065918, 0.3130677044391632, 0.8644362092018127, 0.8644362092018127, 2.269754648208618, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100,
|
||||
100, 100
|
||||
],
|
||||
"dims": [3, 5, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "MaxPool",
|
||||
"operator": "MaxPool",
|
||||
"attributes": [{ "name": "kernel_shape", "data": [3], "type": "ints" }],
|
||||
"cases": [
|
||||
{
|
||||
"name": "T[1,1,5] T[1,1,3]",
|
||||
"inputs": [
|
||||
{
|
||||
"data": [1.764052391052246, 0.40015721321105957, 0.978738009929657, 2.2408931255340576, 1.8675580024719238],
|
||||
"dims": [1, 1, 5],
|
||||
"type": "float32"
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"data": [1.764052391052246, 2.2408931255340576, 2.2408931255340576],
|
||||
"dims": [1, 1, 3],
|
||||
"type": "float32"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
|
@ -1371,6 +1371,7 @@
|
|||
"matmul.jsonc",
|
||||
"matmulnbits.jsonc",
|
||||
"matmul-broadcast.jsonc",
|
||||
"max-pool.jsonc",
|
||||
"mul.jsonc",
|
||||
"mul_int32.jsonc",
|
||||
"multihead-attention.jsonc",
|
||||
|
|
|
|||
|
|
@ -48,7 +48,6 @@ class ConvBase : public JsKernel {
|
|||
std::vector<float> activation_params = info.GetAttrsOrDefault<float>("activation_params");
|
||||
int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault<int64_t>("channels_last", 0);
|
||||
|
||||
// currently only support Conv 1D/2D. TODO: support Conv3D and other
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
|
||||
"format" : $11 ? "NHWC" : "NCHW",
|
||||
"auto_pad" : $1,
|
||||
|
|
@ -65,8 +64,8 @@ class ConvBase : public JsKernel {
|
|||
JSEP_HEAP32_INDEX_START(dilations),
|
||||
JSEP_HEAP32_INDEX_END(dilations),
|
||||
static_cast<int32_t>(conv_attrs_.group),
|
||||
JSEP_HEAP32_INDEX_START(kernel_shape),
|
||||
JSEP_HEAP32_INDEX_END(kernel_shape),
|
||||
JSEP_HEAP32_INDEX_START(kernel_shapes),
|
||||
JSEP_HEAP32_INDEX_END(kernel_shapes),
|
||||
JSEP_HEAP32_INDEX_START(local_pads),
|
||||
JSEP_HEAP32_INDEX_END(local_pads),
|
||||
JSEP_HEAP32_INDEX_START(strides),
|
||||
|
|
|
|||
|
|
@ -9,38 +9,45 @@
|
|||
namespace onnxruntime {
|
||||
namespace js {
|
||||
|
||||
#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \
|
||||
"format" : $15 ? "NHWC" : "NCHW", \
|
||||
"auto_pad" : $1, \
|
||||
"ceil_mode" : $2, \
|
||||
"count_include_pad" : $3, \
|
||||
"storage_order" : $4, \
|
||||
"dilations" : [ $5, $6 ], \
|
||||
"kernel_shape" : [ $7, $8 ], \
|
||||
"pads" : [ $9, $10, $11, $12 ], \
|
||||
"strides" : [ $13, $14 ] \
|
||||
#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \
|
||||
"format" : $13 ? "NHWC" : "NCHW", \
|
||||
"auto_pad" : $1, \
|
||||
"ceil_mode" : $2, \
|
||||
"count_include_pad" : $3, \
|
||||
"storage_order" : $4, \
|
||||
"dilations" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [], \
|
||||
"kernel_shape" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [], \
|
||||
"pads" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], \
|
||||
"strides" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [] \
|
||||
})
|
||||
|
||||
#define POOL_ATTRIBUTES_PARAM_LIST \
|
||||
static_cast<int32_t>(pool_attrs_.auto_pad), \
|
||||
static_cast<int32_t>(pool_attrs_.ceil_mode), \
|
||||
static_cast<int32_t>(pool_attrs_.count_include_pad), \
|
||||
static_cast<int32_t>(pool_attrs_.storage_order), \
|
||||
static_cast<int32_t>(pool_attrs_.dilations.size() > 0 ? pool_attrs_.dilations[0] : 0), \
|
||||
static_cast<int32_t>(pool_attrs_.dilations.size() > 1 ? pool_attrs_.dilations[1] : 0), \
|
||||
static_cast<int32_t>(pool_attrs_.kernel_shape.size() > 0 ? pool_attrs_.kernel_shape[0] : 0), \
|
||||
static_cast<int32_t>(pool_attrs_.kernel_shape.size() > 1 ? pool_attrs_.kernel_shape[1] : 0), \
|
||||
static_cast<int32_t>(pool_attrs_.pads.size() > 0 ? pool_attrs_.pads[0] : 0), \
|
||||
static_cast<int32_t>(pool_attrs_.pads.size() > 1 ? pool_attrs_.pads[1] : 0), \
|
||||
static_cast<int32_t>(pool_attrs_.pads.size() > 2 ? pool_attrs_.pads[2] : 0), \
|
||||
static_cast<int32_t>(pool_attrs_.pads.size() > 3 ? pool_attrs_.pads[3] : 0), \
|
||||
static_cast<int32_t>(pool_attrs_.strides.size() > 0 ? pool_attrs_.strides[0] : 0), \
|
||||
static_cast<int32_t>(pool_attrs_.strides.size() > 1 ? pool_attrs_.strides[1] : 0), \
|
||||
#define POOL_ATTRIBUTES_PARAM_LIST \
|
||||
static_cast<int32_t>(pool_attrs_.auto_pad), \
|
||||
static_cast<int32_t>(pool_attrs_.ceil_mode), \
|
||||
static_cast<int32_t>(pool_attrs_.count_include_pad), \
|
||||
static_cast<int32_t>(pool_attrs_.storage_order), \
|
||||
JSEP_HEAP32_INDEX_START(dilations), \
|
||||
JSEP_HEAP32_INDEX_END(dilations), \
|
||||
JSEP_HEAP32_INDEX_START(kernel_shapes), \
|
||||
JSEP_HEAP32_INDEX_END(kernel_shapes), \
|
||||
JSEP_HEAP32_INDEX_START(pads), \
|
||||
JSEP_HEAP32_INDEX_END(pads), \
|
||||
JSEP_HEAP32_INDEX_START(strides), \
|
||||
JSEP_HEAP32_INDEX_END(strides), \
|
||||
static_cast<int32_t>(is_channels_last)
|
||||
|
||||
#define GLOBAL_POOL_ATTRIBUTES_JS_OBJ_MAPPING ({"format" : $1 ? "NHWC" : "NCHW"})
|
||||
#define GLOBAL_POOL_ATTRIBUTES_PARAM_LIST static_cast<int32_t>(is_channels_last)
|
||||
|
||||
template <typename Type>
|
||||
inline const std::vector<Type> CastTensorShapeVector(const TensorShapeVector& shape) {
|
||||
std::vector<Type> castedShapes(shape.size(), 0);
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
castedShapes[i] = gsl::narrow_cast<Type>(shape[i]);
|
||||
}
|
||||
return castedShapes;
|
||||
}
|
||||
|
||||
template <typename PoolType, bool is_channels_last>
|
||||
class Pool : public JsKernel, public PoolBase {
|
||||
public:
|
||||
|
|
@ -54,6 +61,10 @@ class Pool : public JsKernel, public PoolBase {
|
|||
// TODO: GlobalLpPool
|
||||
}
|
||||
} else {
|
||||
auto kernel_shapes{CastTensorShapeVector<int32_t>(pool_attrs_.kernel_shape)};
|
||||
auto strides{CastTensorShapeVector<int32_t>(pool_attrs_.strides)};
|
||||
auto dilations{CastTensorShapeVector<int32_t>(pool_attrs_.dilations)};
|
||||
auto pads{CastTensorShapeVector<int32_t>(pool_attrs_.pads)};
|
||||
if constexpr (PoolType::type == onnxruntime::PoolType::kAveragePool) {
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(AveragePool, POOL_ATTRIBUTES_JS_OBJ_MAPPING, POOL_ATTRIBUTES_PARAM_LIST);
|
||||
} else if constexpr (PoolType::type == onnxruntime::PoolType::kMaxPool) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue