mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-05 04:17:53 +00:00
[js/webgpu] fix heap access > 2GB (#19010)
This commit is contained in:
parent
975a315cd7
commit
a8bb1df331
9 changed files with 47 additions and 46 deletions
|
|
@ -67,6 +67,7 @@ namespace js {
|
|||
float value; \
|
||||
ORT_ENFORCE(info.GetAttr<float>(#attr_name, &value));, \
|
||||
, ({#attr_name : $1}), static_cast<double>(value))
|
||||
#define JSEP_HEAP_PTR(ptr) reinterpret_cast<uintptr_t>(ptr)
|
||||
|
||||
// TODO:
|
||||
// class JsMultiProgramKernel : public OpKernel { /* TBD */ };
|
||||
|
|
|
|||
|
|
@ -54,13 +54,13 @@ class ConvBase : public JsKernel {
|
|||
static_cast<int32_t>(conv_attrs_.group),
|
||||
static_cast<int32_t>(kernel_shape_0),
|
||||
static_cast<int32_t>(local_pads.size()),
|
||||
reinterpret_cast<int32_t>(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
|
||||
JSEP_HEAP_PTR(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
|
||||
static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
|
||||
static_cast<int32_t>(channels_last),
|
||||
reinterpret_cast<int32_t>(&w_is_const_),
|
||||
JSEP_HEAP_PTR(&w_is_const_),
|
||||
conv_attrs_.activation.c_str(),
|
||||
activation_params.size(),
|
||||
reinterpret_cast<int32_t>(activation_params_ptr) >> 2);
|
||||
JSEP_HEAP_PTR(activation_params_ptr) >> 2);
|
||||
} else {
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
|
||||
"format" : $11 ? "NHWC" : "NCHW",
|
||||
|
|
@ -81,14 +81,14 @@ class ConvBase : public JsKernel {
|
|||
static_cast<int32_t>(kernel_shape_0),
|
||||
static_cast<int32_t>(kernel_shape_1),
|
||||
static_cast<int32_t>(local_pads.size()),
|
||||
reinterpret_cast<int32_t>(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
|
||||
JSEP_HEAP_PTR(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
|
||||
static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
|
||||
static_cast<int32_t>(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0),
|
||||
static_cast<int32_t>(channels_last),
|
||||
reinterpret_cast<int32_t>(&w_is_const_),
|
||||
JSEP_HEAP_PTR(&w_is_const_),
|
||||
conv_attrs_.activation.c_str(),
|
||||
activation_params.size(),
|
||||
reinterpret_cast<int32_t>(activation_params_ptr) >> 2);
|
||||
JSEP_HEAP_PTR(activation_params_ptr) >> 2);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -64,11 +64,11 @@ class ConvTranspose : public JsKernel {
|
|||
static_cast<int32_t>(pads_1),
|
||||
static_cast<int32_t>(strides),
|
||||
static_cast<int32_t>(channels_last),
|
||||
reinterpret_cast<int32_t>(&w_is_const_),
|
||||
JSEP_HEAP_PTR(&w_is_const_),
|
||||
gsl::narrow_cast<int32_t>(local_output_padding.size()),
|
||||
reinterpret_cast<int32_t>(local_output_padding_ptr) >> 2,
|
||||
JSEP_HEAP_PTR(local_output_padding_ptr) >> 2,
|
||||
gsl::narrow_cast<int32_t>(local_output_shape.size()),
|
||||
reinterpret_cast<int32_t>(local_output_shape_ptr) >> 2,
|
||||
JSEP_HEAP_PTR(local_output_shape_ptr) >> 2,
|
||||
conv_transpose_attrs_.activation.c_str());
|
||||
} else {
|
||||
constexpr size_t pads_vec_size = 4;
|
||||
|
|
@ -114,17 +114,17 @@ class ConvTranspose : public JsKernel {
|
|||
"activation" : UTF8ToString($13)
|
||||
}),
|
||||
static_cast<int32_t>(conv_transpose_attrs_.auto_pad),
|
||||
reinterpret_cast<int32_t>(local_dilations.data()) >> 2,
|
||||
JSEP_HEAP_PTR(local_dilations.data()) >> 2,
|
||||
static_cast<int32_t>(conv_transpose_attrs_.group),
|
||||
reinterpret_cast<int32_t>(local_kernel_shape.data()) >> 2,
|
||||
reinterpret_cast<int32_t>(local_pads.data()) >> 2,
|
||||
reinterpret_cast<int32_t>(local_strides.data()) >> 2,
|
||||
JSEP_HEAP_PTR(local_kernel_shape.data()) >> 2,
|
||||
JSEP_HEAP_PTR(local_pads.data()) >> 2,
|
||||
JSEP_HEAP_PTR(local_strides.data()) >> 2,
|
||||
static_cast<int32_t>(channels_last),
|
||||
reinterpret_cast<int32_t>(&w_is_const_),
|
||||
JSEP_HEAP_PTR(&w_is_const_),
|
||||
gsl::narrow_cast<int32_t>(local_output_padding.size()),
|
||||
reinterpret_cast<int32_t>(local_output_padding_ptr) >> 2,
|
||||
JSEP_HEAP_PTR(local_output_padding_ptr) >> 2,
|
||||
gsl::narrow_cast<int32_t>(local_output_shape.size()),
|
||||
reinterpret_cast<int32_t>(local_output_shape_ptr) >> 2,
|
||||
JSEP_HEAP_PTR(local_output_shape_ptr) >> 2,
|
||||
conv_transpose_attrs_.activation.c_str());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class Pad : public JsKernel, public PadBase {
|
|||
static_cast<int32_t>(mode_),
|
||||
static_cast<double>(value_),
|
||||
gsl::narrow_cast<int32_t>(pads.size()),
|
||||
reinterpret_cast<int32_t>((pads.size() > 0) ? pads.data() : nullptr) >> 2);
|
||||
JSEP_HEAP_PTR((pads.size() > 0) ? pads.data() : nullptr) >> 2);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -8,29 +8,29 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
namespace js {
|
||||
#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \
|
||||
template <bool allow_multi_axes = true> \
|
||||
class ReduceKernel : public JsKernel, public ReduceKernelBase<allow_multi_axes> { \
|
||||
public: \
|
||||
using ReduceKernelBase<allow_multi_axes>::axes_; \
|
||||
using ReduceKernelBase<allow_multi_axes>::noop_with_empty_axes_; \
|
||||
using ReduceKernelBase<allow_multi_axes>::keepdims_; \
|
||||
ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase<allow_multi_axes>(info) { \
|
||||
std::vector<int32_t> axes(axes_.size()); \
|
||||
if (axes_.size() > 0) { \
|
||||
std::transform(axes_.begin(), axes_.end(), axes.begin(), \
|
||||
[](int64_t axis) { return gsl::narrow_cast<int32_t>(axis); }); \
|
||||
} \
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \
|
||||
"keepDims" : !!$1, \
|
||||
"noopWithEmptyAxes" : !!$2, \
|
||||
"axes" : $3 ? (Array.from(HEAP32.subarray($4, $4 + $3))) : [], \
|
||||
}), \
|
||||
static_cast<int32_t>(keepdims_), \
|
||||
static_cast<int32_t>(noop_with_empty_axes_), \
|
||||
gsl::narrow_cast<int32_t>(axes.size()), \
|
||||
reinterpret_cast<int32_t>((axes.size() > 0) ? axes.data() : nullptr) >> 2); \
|
||||
} \
|
||||
#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \
|
||||
template <bool allow_multi_axes = true> \
|
||||
class ReduceKernel : public JsKernel, public ReduceKernelBase<allow_multi_axes> { \
|
||||
public: \
|
||||
using ReduceKernelBase<allow_multi_axes>::axes_; \
|
||||
using ReduceKernelBase<allow_multi_axes>::noop_with_empty_axes_; \
|
||||
using ReduceKernelBase<allow_multi_axes>::keepdims_; \
|
||||
ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase<allow_multi_axes>(info) { \
|
||||
std::vector<int32_t> axes(axes_.size()); \
|
||||
if (axes_.size() > 0) { \
|
||||
std::transform(axes_.begin(), axes_.end(), axes.begin(), \
|
||||
[](int64_t axis) { return gsl::narrow_cast<int32_t>(axis); }); \
|
||||
} \
|
||||
JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \
|
||||
"keepDims" : !!$1, \
|
||||
"noopWithEmptyAxes" : !!$2, \
|
||||
"axes" : $3 ? (Array.from(HEAP32.subarray($4, $4 + $3))) : [], \
|
||||
}), \
|
||||
static_cast<int32_t>(keepdims_), \
|
||||
static_cast<int32_t>(noop_with_empty_axes_), \
|
||||
gsl::narrow_cast<int32_t>(axes.size()), \
|
||||
JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2); \
|
||||
} \
|
||||
};
|
||||
|
||||
JSEP_DEFINE_REDUCE_KERNEL(ReduceMax);
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class Resize : public JsKernel, public UpsampleBase {
|
|||
}),
|
||||
static_cast<int32_t>(antialias_),
|
||||
gsl::narrow_cast<int32_t>(axes.size()),
|
||||
reinterpret_cast<int32_t>((axes.size() > 0) ? axes.data() : nullptr) >> 2,
|
||||
JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2,
|
||||
resize_coordinate_transformation_mode.c_str(),
|
||||
static_cast<double>(cubic_coeff_a_),
|
||||
static_cast<int32_t>(exclude_outside_),
|
||||
|
|
|
|||
|
|
@ -24,11 +24,11 @@ class Slice : public JsKernel, public SliceBase {
|
|||
"ends" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : [],
|
||||
"axes" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : []}),
|
||||
gsl::narrow_cast<int32_t>(starts.size()),
|
||||
reinterpret_cast<int32_t>((starts.size() > 0) ? starts.data() : nullptr) >> 2,
|
||||
JSEP_HEAP_PTR((starts.size() > 0) ? starts.data() : nullptr) >> 2,
|
||||
gsl::narrow_cast<int32_t>(ends.size()),
|
||||
reinterpret_cast<int32_t>((ends.size() > 0) ? ends.data() : nullptr) >> 2,
|
||||
JSEP_HEAP_PTR((ends.size() > 0) ? ends.data() : nullptr) >> 2,
|
||||
gsl::narrow_cast<int32_t>(axes.size()),
|
||||
reinterpret_cast<int32_t>((axes.size() > 0) ? axes.data() : nullptr) >> 2);
|
||||
JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class Split : public JsKernel, public SplitBase {
|
|||
static_cast<int32_t>(axis_),
|
||||
static_cast<int32_t>(num_outputs_),
|
||||
gsl::narrow_cast<int32_t>(split_sizes.size()),
|
||||
reinterpret_cast<int32_t>((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2);
|
||||
JSEP_HEAP_PTR((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class Transpose final : public JsKernel, public TransposeBase {
|
|||
gsl::narrow_cast<int32_t>(perm_specified_ ? perm_.size() : 0),
|
||||
// $2: index to HEAP32 of the first int32 element. calculated from right shift memory
|
||||
// address by 2
|
||||
reinterpret_cast<int32_t>(perm_specified_ && !perm.empty() ? perm.data() : nullptr) >> 2);
|
||||
JSEP_HEAP_PTR(perm_specified_ && !perm.empty() ? perm.data() : nullptr) >> 2);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue