mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-12 00:59:23 +00:00
fix f16 for attention, enable slice and flatten for more types (#19262)
This commit is contained in:
parent
e96a038f01
commit
9e69606360
3 changed files with 9 additions and 13 deletions
|
|
@ -297,7 +297,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
|
|||
|
||||
if (sum == 0) {
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
|
||||
x[offset + i] = ${fillVector('f32', components, 'uniforms.d_inv')};
|
||||
x[offset + i] = ${fillVector(elemValueType, components, 'uniforms.d_inv')};
|
||||
}
|
||||
} else {
|
||||
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes()),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
|
|
@ -23,7 +23,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes()),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
|
|
@ -33,7 +33,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes()),
|
||||
Flatten);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
|
|
@ -43,7 +43,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.Alias(0, 0)
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
|
||||
.TypeConstraint("T", JsepSupportedFloatTypes()),
|
||||
Flatten);
|
||||
|
||||
} // namespace js
|
||||
|
|
|
|||
|
|
@ -12,8 +12,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
1, 9,
|
||||
kJsExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>()}),
|
||||
.TypeConstraint("T", JsepSupportedDataTypes()),
|
||||
Slice_1);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
|
|
@ -26,8 +25,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
.InputMemoryType(OrtMemTypeCPU, 2)
|
||||
.InputMemoryType(OrtMemTypeCPU, 3)
|
||||
.InputMemoryType(OrtMemTypeCPU, 4)
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>()}),
|
||||
.TypeConstraint("T", JsepSupportedDataTypes()),
|
||||
Slice);
|
||||
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
||||
|
|
@ -40,8 +38,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
|
|||
.InputMemoryType(OrtMemTypeCPU, 2)
|
||||
.InputMemoryType(OrtMemTypeCPU, 3)
|
||||
.InputMemoryType(OrtMemTypeCPU, 4)
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>()}),
|
||||
.TypeConstraint("T", JsepSupportedDataTypes()),
|
||||
Slice);
|
||||
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
|
|
@ -54,8 +51,7 @@ ONNX_OPERATOR_KERNEL_EX(
|
|||
.InputMemoryType(OrtMemTypeCPU, 2)
|
||||
.InputMemoryType(OrtMemTypeCPU, 3)
|
||||
.InputMemoryType(OrtMemTypeCPU, 4)
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<int32_t>()}),
|
||||
.TypeConstraint("T", JsepSupportedDataTypes()),
|
||||
Slice);
|
||||
|
||||
} // namespace js
|
||||
|
|
|
|||
Loading…
Reference in a new issue