fix f16 for attention, enable slice and flatten for more types (#19262)

This commit is contained in:
Guenther Schmuelling 2024-01-29 10:13:46 -08:00 committed by GitHub
parent e96a038f01
commit 9e69606360
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 9 additions and 13 deletions

View file

@ -297,7 +297,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
if (sum == 0) {
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
x[offset + i] = ${fillVector('f32', components, 'uniforms.d_inv')};
x[offset + i] = ${fillVector(elemValueType, components, 'uniforms.d_inv')};
}
} else {
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {

View file

@ -13,7 +13,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
.TypeConstraint("T", JsepSupportedFloatTypes()),
Flatten);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@ -23,7 +23,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
.TypeConstraint("T", JsepSupportedFloatTypes()),
Flatten);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@ -33,7 +33,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
.TypeConstraint("T", JsepSupportedFloatTypes()),
Flatten);
ONNX_OPERATOR_KERNEL_EX(
@ -43,7 +43,7 @@ ONNX_OPERATOR_KERNEL_EX(
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(0, 0)
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
.TypeConstraint("T", JsepSupportedFloatTypes()),
Flatten);
} // namespace js

View file

@ -12,8 +12,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
1, 9,
kJsExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>()}),
.TypeConstraint("T", JsepSupportedDataTypes()),
Slice_1);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@ -26,8 +25,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3)
.InputMemoryType(OrtMemTypeCPU, 4)
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>()}),
.TypeConstraint("T", JsepSupportedDataTypes()),
Slice);
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@ -40,8 +38,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3)
.InputMemoryType(OrtMemTypeCPU, 4)
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>()}),
.TypeConstraint("T", JsepSupportedDataTypes()),
Slice);
ONNX_OPERATOR_KERNEL_EX(
@ -54,8 +51,7 @@ ONNX_OPERATOR_KERNEL_EX(
.InputMemoryType(OrtMemTypeCPU, 2)
.InputMemoryType(OrtMemTypeCPU, 3)
.InputMemoryType(OrtMemTypeCPU, 4)
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<int32_t>()}),
.TypeConstraint("T", JsepSupportedDataTypes()),
Slice);
} // namespace js