[js/webgpu] Fix NAN caused by un-initialized buffer in instance-norm (#19387)

The added case will be NAN because of the un-initialized buffer.
This commit is contained in:
Xu Xing 2024-03-19 13:59:32 +08:00 committed by GitHub
parent 6bb64683f8
commit 4c6a6a37f7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 81 additions and 1 deletions

View file

@ -207,7 +207,7 @@ const computeMean =
let offset = currentImageNumber * uniforms.image_size;
var sum = ${fillVector('f32', components)};
var squaredSum = ${fillVector('f32', components)};
for (var i: u32 = 0; i < ${WG}; i++) {
for (var i: u32 = 0; i < min(${WG}, uniforms.H); i++) {
let value = input[offset + i + currentChannelNumber * ${WG}];
sum += value[0];
squaredSum += value[1];

View file

@ -224,5 +224,85 @@
]
}
]
},
{
"name": "Simple test with NHWC, components 1, buffer reuse",
"operator": "InstanceNormalization",
"inputShapeDefinitions": "rankOnly",
"opset": {
"domain": "",
"version": 17
},
"cases": [
{
"name": "Simple test",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [2, 3, 1, 1],
"type": "float32"
},
{
"data": [1, 2, 3],
"dims": [3],
"type": "float32"
},
{
"data": [4, 5, 6],
"dims": [3],
"type": "float32"
}
],
"outputs": [
{
"data": [4, 5, 6, 4, 5, 6],
"dims": [2, 3, 1, 1],
"type": "float32"
}
]
}
]
},
{
"name": "Simple test with NHWC, components 2, buffer reuse",
"operator": "InstanceNormalization",
"inputShapeDefinitions": "rankOnly",
"opset": {
"domain": "",
"version": 17
},
"cases": [
{
"name": "Simple test",
"inputs": [
{
"data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 8, 7, 6, 5, 4, 3, 2],
"dims": [1, 6, 1, 3],
"type": "float32"
},
{
"data": [1, 2, 3, 4, 5, 6],
"dims": [6],
"type": "float32"
},
{
"data": [4, 5, 6, 7, 8, 9],
"dims": [6],
"type": "float32"
}
],
"outputs": [
{
"data": [
2.775264263153076, 4, 5.224735260009766, 2.5505285263061523, 5, 7.449470520019531, 2.325794219970703, 6,
9.674205780029297, 11.898944854736328, 7, 2.1010589599609375, 14.123676300048828, 8, 1.876321792602539,
16.348413467407227, 9, 1.6515865325927734
],
"dims": [1, 6, 1, 3],
"type": "float32"
}
]
}
]
}
]