[js/webgpu] fix heap access > 2GB (#19010)

This commit is contained in:
Guenther Schmuelling 2024-01-08 17:58:38 -08:00 committed by GitHub
parent 975a315cd7
commit a8bb1df331
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 47 additions and 46 deletions

View file

@ -67,6 +67,7 @@ namespace js {
float value; \
ORT_ENFORCE(info.GetAttr<float>(#attr_name, &value));, \
, ({#attr_name : $1}), static_cast<double>(value))
// Reinterprets a pointer as an unsigned, pointer-sized integer (uintptr_t) so it can be handed
// to the JavaScript side as a heap address. Call sites that need an index into HEAP32 shift
// the result right by 2. The value is unsigned, in contrast to the int32_t casts it replaces
// at the call sites.
#define JSEP_HEAP_PTR(native_ptr) reinterpret_cast<uintptr_t>(native_ptr)
// TODO:
// class JsMultiProgramKernel : public OpKernel { /* TBD */ };

View file

@ -54,13 +54,13 @@ class ConvBase : public JsKernel {
static_cast<int32_t>(conv_attrs_.group),
static_cast<int32_t>(kernel_shape_0),
static_cast<int32_t>(local_pads.size()),
reinterpret_cast<int32_t>(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
JSEP_HEAP_PTR(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
static_cast<int32_t>(channels_last),
reinterpret_cast<int32_t>(&w_is_const_),
JSEP_HEAP_PTR(&w_is_const_),
conv_attrs_.activation.c_str(),
activation_params.size(),
reinterpret_cast<int32_t>(activation_params_ptr) >> 2);
JSEP_HEAP_PTR(activation_params_ptr) >> 2);
} else {
JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({
"format" : $11 ? "NHWC" : "NCHW",
@ -81,14 +81,14 @@ class ConvBase : public JsKernel {
static_cast<int32_t>(kernel_shape_0),
static_cast<int32_t>(kernel_shape_1),
static_cast<int32_t>(local_pads.size()),
reinterpret_cast<int32_t>(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
JSEP_HEAP_PTR(local_pads.size() > 0 ? local_pads.data() : nullptr) >> 2,
static_cast<int32_t>(conv_attrs_.strides.size() > 0 ? conv_attrs_.strides[0] : 0),
static_cast<int32_t>(conv_attrs_.strides.size() > 1 ? conv_attrs_.strides[1] : 0),
static_cast<int32_t>(channels_last),
reinterpret_cast<int32_t>(&w_is_const_),
JSEP_HEAP_PTR(&w_is_const_),
conv_attrs_.activation.c_str(),
activation_params.size(),
reinterpret_cast<int32_t>(activation_params_ptr) >> 2);
JSEP_HEAP_PTR(activation_params_ptr) >> 2);
}
}

View file

@ -64,11 +64,11 @@ class ConvTranspose : public JsKernel {
static_cast<int32_t>(pads_1),
static_cast<int32_t>(strides),
static_cast<int32_t>(channels_last),
reinterpret_cast<int32_t>(&w_is_const_),
JSEP_HEAP_PTR(&w_is_const_),
gsl::narrow_cast<int32_t>(local_output_padding.size()),
reinterpret_cast<int32_t>(local_output_padding_ptr) >> 2,
JSEP_HEAP_PTR(local_output_padding_ptr) >> 2,
gsl::narrow_cast<int32_t>(local_output_shape.size()),
reinterpret_cast<int32_t>(local_output_shape_ptr) >> 2,
JSEP_HEAP_PTR(local_output_shape_ptr) >> 2,
conv_transpose_attrs_.activation.c_str());
} else {
constexpr size_t pads_vec_size = 4;
@ -114,17 +114,17 @@ class ConvTranspose : public JsKernel {
"activation" : UTF8ToString($13)
}),
static_cast<int32_t>(conv_transpose_attrs_.auto_pad),
reinterpret_cast<int32_t>(local_dilations.data()) >> 2,
JSEP_HEAP_PTR(local_dilations.data()) >> 2,
static_cast<int32_t>(conv_transpose_attrs_.group),
reinterpret_cast<int32_t>(local_kernel_shape.data()) >> 2,
reinterpret_cast<int32_t>(local_pads.data()) >> 2,
reinterpret_cast<int32_t>(local_strides.data()) >> 2,
JSEP_HEAP_PTR(local_kernel_shape.data()) >> 2,
JSEP_HEAP_PTR(local_pads.data()) >> 2,
JSEP_HEAP_PTR(local_strides.data()) >> 2,
static_cast<int32_t>(channels_last),
reinterpret_cast<int32_t>(&w_is_const_),
JSEP_HEAP_PTR(&w_is_const_),
gsl::narrow_cast<int32_t>(local_output_padding.size()),
reinterpret_cast<int32_t>(local_output_padding_ptr) >> 2,
JSEP_HEAP_PTR(local_output_padding_ptr) >> 2,
gsl::narrow_cast<int32_t>(local_output_shape.size()),
reinterpret_cast<int32_t>(local_output_shape_ptr) >> 2,
JSEP_HEAP_PTR(local_output_shape_ptr) >> 2,
conv_transpose_attrs_.activation.c_str());
}
}

View file

@ -26,7 +26,7 @@ class Pad : public JsKernel, public PadBase {
static_cast<int32_t>(mode_),
static_cast<double>(value_),
gsl::narrow_cast<int32_t>(pads.size()),
reinterpret_cast<int32_t>((pads.size() > 0) ? pads.data() : nullptr) >> 2);
JSEP_HEAP_PTR((pads.size() > 0) ? pads.data() : nullptr) >> 2);
}
};

View file

@ -8,29 +8,29 @@
namespace onnxruntime {
namespace js {
#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \
template <bool allow_multi_axes = true> \
class ReduceKernel : public JsKernel, public ReduceKernelBase<allow_multi_axes> { \
public: \
using ReduceKernelBase<allow_multi_axes>::axes_; \
using ReduceKernelBase<allow_multi_axes>::noop_with_empty_axes_; \
using ReduceKernelBase<allow_multi_axes>::keepdims_; \
ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase<allow_multi_axes>(info) { \
std::vector<int32_t> axes(axes_.size()); \
if (axes_.size() > 0) { \
std::transform(axes_.begin(), axes_.end(), axes.begin(), \
[](int64_t axis) { return gsl::narrow_cast<int32_t>(axis); }); \
} \
JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \
"keepDims" : !!$1, \
"noopWithEmptyAxes" : !!$2, \
"axes" : $3 ? (Array.from(HEAP32.subarray($4, $4 + $3))) : [], \
}), \
static_cast<int32_t>(keepdims_), \
static_cast<int32_t>(noop_with_empty_axes_), \
gsl::narrow_cast<int32_t>(axes.size()), \
reinterpret_cast<int32_t>((axes.size() > 0) ? axes.data() : nullptr) >> 2); \
} \
// Defines a class template named `ReduceKernel` that derives from JsKernel and
// ReduceKernelBase<allow_multi_axes>.
//
// The generated constructor:
//   1. narrows every int64_t entry of axes_ to int32_t into a constructor-local vector `axes`;
//   2. invokes JSEP_INIT_KERNEL_ATTRIBUTE with a JS attribute object and these positional
//      arguments:
//        $1 - keepdims_ (coerced to a boolean in JS via `!!`)
//        $2 - noop_with_empty_axes_ (coerced to a boolean in JS via `!!`)
//        $3 - number of axes
//        $4 - address of the first axis shifted right by 2, used as the start index of the
//             HEAP32.subarray() read; a null pointer is passed when `axes` is empty, in which
//             case the JS side yields [] because $3 is 0.
//
// NOTE(review): `axes` lives only for the duration of the constructor, so this presumably
// relies on JSEP_INIT_KERNEL_ATTRIBUTE evaluating the JS snippet (and Array.from copying the
// values) before the constructor returns - confirm against the macro's definition.
#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \
template <bool allow_multi_axes = true> \
class ReduceKernel : public JsKernel, public ReduceKernelBase<allow_multi_axes> { \
public: \
using ReduceKernelBase<allow_multi_axes>::axes_; \
using ReduceKernelBase<allow_multi_axes>::noop_with_empty_axes_; \
using ReduceKernelBase<allow_multi_axes>::keepdims_; \
ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase<allow_multi_axes>(info) { \
std::vector<int32_t> axes(axes_.size()); \
if (axes_.size() > 0) { \
std::transform(axes_.begin(), axes_.end(), axes.begin(), \
[](int64_t axis) { return gsl::narrow_cast<int32_t>(axis); }); \
} \
JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \
"keepDims" : !!$1, \
"noopWithEmptyAxes" : !!$2, \
"axes" : $3 ? (Array.from(HEAP32.subarray($4, $4 + $3))) : [], \
}), \
static_cast<int32_t>(keepdims_), \
static_cast<int32_t>(noop_with_empty_axes_), \
gsl::narrow_cast<int32_t>(axes.size()), \
JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2); \
} \
};
JSEP_DEFINE_REDUCE_KERNEL(ReduceMax);

View file

@ -34,7 +34,7 @@ class Resize : public JsKernel, public UpsampleBase {
}),
static_cast<int32_t>(antialias_),
gsl::narrow_cast<int32_t>(axes.size()),
reinterpret_cast<int32_t>((axes.size() > 0) ? axes.data() : nullptr) >> 2,
JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2,
resize_coordinate_transformation_mode.c_str(),
static_cast<double>(cubic_coeff_a_),
static_cast<int32_t>(exclude_outside_),

View file

@ -24,11 +24,11 @@ class Slice : public JsKernel, public SliceBase {
"ends" : $3 ? Array.from(HEAP32.subarray($4, $4 + $3)) : [],
"axes" : $5 ? Array.from(HEAP32.subarray($6, $6 + $5)) : []}),
gsl::narrow_cast<int32_t>(starts.size()),
reinterpret_cast<int32_t>((starts.size() > 0) ? starts.data() : nullptr) >> 2,
JSEP_HEAP_PTR((starts.size() > 0) ? starts.data() : nullptr) >> 2,
gsl::narrow_cast<int32_t>(ends.size()),
reinterpret_cast<int32_t>((ends.size() > 0) ? ends.data() : nullptr) >> 2,
JSEP_HEAP_PTR((ends.size() > 0) ? ends.data() : nullptr) >> 2,
gsl::narrow_cast<int32_t>(axes.size()),
reinterpret_cast<int32_t>((axes.size() > 0) ? axes.data() : nullptr) >> 2);
JSEP_HEAP_PTR((axes.size() > 0) ? axes.data() : nullptr) >> 2);
}
};

View file

@ -53,7 +53,7 @@ class Split : public JsKernel, public SplitBase {
static_cast<int32_t>(axis_),
static_cast<int32_t>(num_outputs_),
gsl::narrow_cast<int32_t>(split_sizes.size()),
reinterpret_cast<int32_t>((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2);
JSEP_HEAP_PTR((split_sizes.size() > 0) ? split_sizes.data() : nullptr) >> 2);
}
};

View file

@ -27,7 +27,7 @@ class Transpose final : public JsKernel, public TransposeBase {
gsl::narrow_cast<int32_t>(perm_specified_ ? perm_.size() : 0),
// $2: index into HEAP32 of the first int32 element, calculated by right-shifting the memory
// address by 2
reinterpret_cast<int32_t>(perm_specified_ && !perm.empty() ? perm.data() : nullptr) >> 2);
JSEP_HEAP_PTR(perm_specified_ && !perm.empty() ? perm.data() : nullptr) >> 2);
}
};