mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-27 03:11:28 +00:00
[js/web] unify resolve rules for "Clip" (#18527)
### Description It was a mistake to use 2 different names for Clip operator in op-resolve-rules.ts for different opset. An optimized implementation can handle both cases (opset < 11 and opset >=11). Remove "ClipV10" as an entry from the table.
This commit is contained in:
parent
abdf8b7c3f
commit
c7fd930330
3 changed files with 9 additions and 13 deletions
|
|
@ -55,7 +55,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
|
|||
['BiasSplitGelu', [biasSplitGelu]],
|
||||
['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]],
|
||||
['Ceil', [unaryOps.ceil]],
|
||||
['ClipV10', [unaryOps.clipV10]],
|
||||
['Clip', [unaryOps.clip]],
|
||||
['Concat', [concat, parseConcatAttributes]],
|
||||
['Conv', [conv, parseConvAttributes]],
|
||||
|
|
|
|||
|
|
@ -124,7 +124,14 @@ export interface ClipAttributes extends AttributeWithCacheKey {
|
|||
readonly max: number;
|
||||
}
|
||||
|
||||
export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
|
||||
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
|
||||
const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
|
||||
const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
|
||||
return createAttributeWithCacheKey({min, max});
|
||||
};
|
||||
|
||||
export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
|
||||
const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
|
||||
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
|
||||
context.compute(
|
||||
createElementwiseProgramInfo(
|
||||
|
|
@ -135,16 +142,6 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): vo
|
|||
attributes.cacheKey),
|
||||
{inputs: [0]});
|
||||
};
|
||||
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
|
||||
const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
|
||||
const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
|
||||
return createAttributeWithCacheKey({min, max});
|
||||
};
|
||||
|
||||
export const clip = (context: ComputeContext): void => {
|
||||
const attributes = generateClipAttributesFromInputs(context.inputs);
|
||||
clipV10(context, attributes);
|
||||
};
|
||||
|
||||
export const ceil = (context: ComputeContext): void => {
|
||||
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil'));
|
||||
|
|
|
|||
|
|
@ -123,7 +123,7 @@ JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not)
|
|||
|
||||
// activation
|
||||
|
||||
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f)
|
||||
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.402823e+38f, max, -3.402823e+38f)
|
||||
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10)
|
||||
JSEP_KERNEL_IMPL(Clip, Clip)
|
||||
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider,
|
||||
|
|
|
|||
Loading…
Reference in a new issue