mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-28 03:20:58 +00:00
[js/web] Refine conv attributes (#20684)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
This commit is contained in:
parent
e81c8676e3
commit
6b58fcc00b
1 changed files with 46 additions and 59 deletions
|
|
@ -17,78 +17,65 @@ class ConvBase : public JsKernel {
|
|||
ConvBase(const OpKernelInfo& info, bool is_channels_last, bool is_fused_conv) : JsKernel(info),
|
||||
conv_attrs_(info),
|
||||
w_is_const_(false) {
|
||||
TensorShapeVector kernel_shape;
|
||||
const size_t pads_vec_size = conv_attrs_.pads.size() == 0 ? 4 : conv_attrs_.pads.size();
|
||||
std::vector<int32_t> local_pads(pads_vec_size, 0);
|
||||
for (size_t i = 0; i < conv_attrs_.pads.size() && i < pads_vec_size; ++i) {
|
||||
local_pads[i] = gsl::narrow_cast<int32_t>(conv_attrs_.pads[i]);
|
||||
}
|
||||
|
||||
TensorShapeVector kernel_shape;
|
||||
if (conv_attrs_.kernel_shape_specified) {
|
||||
ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape).IsOK());
|
||||
}
|
||||
std::vector<int32_t> kernel_shapes(kernel_shape.size(), 0);
|
||||
if (conv_attrs_.kernel_shape_specified) {
|
||||
for (size_t i = 0; i < kernel_shape.size(); ++i) {
|
||||
kernel_shapes[i] = gsl::narrow_cast<int32_t>(kernel_shape[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int32_t> strides(conv_attrs_.strides.size(), 0);
|
||||
for (size_t i = 0; i < conv_attrs_.strides.size(); ++i) {
|
||||
strides[i] = gsl::narrow_cast<int32_t>(conv_attrs_.strides[i]);
|
||||
}
|
||||
|
||||
std::vector<int32_t> dilations(conv_attrs_.dilations.size(), 0);
|
||||
for (size_t i = 0; i < conv_attrs_.dilations.size(); ++i) {
|
||||
dilations[i] = gsl::narrow_cast<int32_t>(conv_attrs_.dilations[i]);
|
||||
}
|
||||
|
||||
conv_attrs_.activation = info.GetAttrOrDefault<std::string>("activation", "");
|
||||
std::vector<float> activation_params = info.GetAttrsOrDefault<float>("activation_params");
|
||||
int64_t channels_last = is_channels_last ? 1 : info.GetAttrOrDefault<int64_t>("channels_last", 0);
|
||||
auto kernel_shape_0 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 0 ? kernel_shape[0] : 0;
|
||||
auto kernel_shape_1 = conv_attrs_.kernel_shape_specified && kernel_shape.size() > 1 ? kernel_shape[1] : 0;
|
||||
|
||||
// currently only support Conv 1D/2D. TODO: support Conv3D and other
|
||||
if (conv_attrs_.dilations.size() == 1 ||
|
||||
(conv_attrs_.kernel_shape_specified && kernel_shape.size() == 1) ||
|
||||
conv_attrs_.strides.size() == 1) {
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
|
||||
"format" : $8 ? "NHWC" : "NCHW",
|
||||
"auto_pad" : $1,
|
||||
"dilations" : [$2],
|
||||
"group" : $3,
|
||||
"kernel_shape" : [$4],
|
||||
"pads" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [],
|
||||
"strides" : [$7],
|
||||
"w_is_const" : () JS_ARROW(!!HEAP8[$9]),
|
||||
"activation" : UTF8ToString($10),
|
||||
"activation_params" : $11 ? Array.from(HEAPF32.subarray($11, $12)) : []
|
||||
}),
|
||||
static_cast<int32_t>(conv_attrs_.auto_pad),
|
||||
static_cast<int32_t>(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0),
|
||||
static_cast<int32_t>(conv_attrs_.group),
|
||||
static_cast<int32_t>(kernel_shape_0),
|
||||
JSEP_HEAP32_INDEX_START(local_pads),
|
||||
JSEP_HEAP32_INDEX_END(local_pads),
|
||||
static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
|
||||
static_cast<int32_t>(channels_last),
|
||||
JSEP_HEAP8_INDEX(&w_is_const_),
|
||||
conv_attrs_.activation.c_str(),
|
||||
JSEP_HEAP32_INDEX_START(activation_params),
|
||||
JSEP_HEAP32_INDEX_END(activation_params));
|
||||
} else {
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
|
||||
"format" : $11 ? "NHWC" : "NCHW",
|
||||
"auto_pad" : $1,
|
||||
"dilations" : [ $2, $3 ],
|
||||
"group" : $4,
|
||||
"kernel_shape" : [ $5, $6 ],
|
||||
"pads" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [],
|
||||
"strides" : [ $9, $10 ],
|
||||
"w_is_const" : () JS_ARROW(!!HEAP8[$12]),
|
||||
"activation" : UTF8ToString($13),
|
||||
"activation_params" : $14 ? Array.from(HEAPF32.subarray($14, $15)) : []
|
||||
}),
|
||||
static_cast<int32_t>(conv_attrs_.auto_pad),
|
||||
static_cast<int32_t>(conv_attrs_.dilations.size() > 0 ? conv_attrs_.dilations[0] : 0),
|
||||
static_cast<int32_t>(conv_attrs_.dilations.size() > 1 ? conv_attrs_.dilations[1] : 0),
|
||||
static_cast<int32_t>(conv_attrs_.group),
|
||||
static_cast<int32_t>(kernel_shape_0),
|
||||
static_cast<int32_t>(kernel_shape_1),
|
||||
JSEP_HEAP32_INDEX_START(local_pads),
|
||||
JSEP_HEAP32_INDEX_END(local_pads),
|
||||
static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
|
||||
static_cast<int32_t>(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0),
|
||||
static_cast<int32_t>(channels_last),
|
||||
JSEP_HEAP8_INDEX(&w_is_const_),
|
||||
conv_attrs_.activation.c_str(),
|
||||
JSEP_HEAP32_INDEX_START(activation_params),
|
||||
JSEP_HEAP32_INDEX_END(activation_params));
|
||||
}
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
|
||||
"format" : $11 ? "NHWC" : "NCHW",
|
||||
"auto_pad" : $1,
|
||||
"dilations" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [],
|
||||
"group" : $4,
|
||||
"kernel_shape" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [],
|
||||
"pads" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [],
|
||||
"strides" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [],
|
||||
"w_is_const" : () JS_ARROW(!!HEAP8[$12]),
|
||||
"activation" : UTF8ToString($13),
|
||||
"activation_params" : $14 ? Array.from(HEAPF32.subarray($14, $15)) : []
|
||||
}),
|
||||
static_cast<int32_t>(conv_attrs_.auto_pad),
|
||||
JSEP_HEAP32_INDEX_START(dilations),
|
||||
JSEP_HEAP32_INDEX_END(dilations),
|
||||
static_cast<int32_t>(conv_attrs_.group),
|
||||
JSEP_HEAP32_INDEX_START(kernel_shape),
|
||||
JSEP_HEAP32_INDEX_END(kernel_shape),
|
||||
JSEP_HEAP32_INDEX_START(local_pads),
|
||||
JSEP_HEAP32_INDEX_END(local_pads),
|
||||
JSEP_HEAP32_INDEX_START(strides),
|
||||
JSEP_HEAP32_INDEX_END(strides),
|
||||
static_cast<int32_t>(channels_last),
|
||||
JSEP_HEAP8_INDEX(&w_is_const_),
|
||||
conv_attrs_.activation.c_str(),
|
||||
JSEP_HEAP32_INDEX_START(activation_params),
|
||||
JSEP_HEAP32_INDEX_END(activation_params));
|
||||
}
|
||||
|
||||
Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
|
||||
|
|
|
|||
Loading…
Reference in a new issue