[js/web] unify resolve rules for "Clip" (#18527)

### Description
It was a mistake to use 2 different names for Clip operator in
op-resolve-rules.ts for different opset. An optimized implementation can
handle both cases (opset < 11 and opset >=11). Remove "ClipV10" as an
entry from the table.
This commit is contained in:
Yulong Wang 2023-11-20 23:18:06 -08:00 committed by GitHub
parent abdf8b7c3f
commit c7fd930330
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 13 deletions

View file

@ -55,7 +55,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['BiasSplitGelu', [biasSplitGelu]],
['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]],
['Ceil', [unaryOps.ceil]],
['ClipV10', [unaryOps.clipV10]],
['Clip', [unaryOps.clip]],
['Concat', [concat, parseConcatAttributes]],
['Conv', [conv, parseConvAttributes]],

View file

@ -124,7 +124,14 @@ export interface ClipAttributes extends AttributeWithCacheKey {
readonly max: number;
}
export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): void => {
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
return createAttributeWithCacheKey({min, max});
};
export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfo(
@ -135,16 +142,6 @@ export const clipV10 = (context: ComputeContext, attributes: ClipAttributes): vo
attributes.cacheKey),
{inputs: [0]});
};
const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAttributes => {
const min = (inputs.length >= 2) ? inputs[1].getFloat32Array()[0] : MIN_CLIP;
const max = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : MAX_CLIP;
return createAttributeWithCacheKey({min, max});
};
export const clip = (context: ComputeContext): void => {
const attributes = generateClipAttributesFromInputs(context.inputs);
clipV10(context, attributes);
};
export const ceil = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Ceil', 'ceil'));

View file

@ -123,7 +123,7 @@ JSEP_ELEMENTWISE_TYPED_KERNEL(Not, 1, bool, Not)
// activation
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, ClipV10, min, 3.402823e+38f, max, -3.402823e+38f)
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(ClipV10, Clip, min, 3.402823e+38f, max, -3.402823e+38f)
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Clip, 6, 10, ClipV10)
JSEP_KERNEL_IMPL(Clip, Clip)
ONNX_OPERATOR_VERSIONED_KERNEL_EX(Clip, kOnnxDomain, 11, 11, kJsExecutionProvider,