mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
[js/webnn] Enable user-supplied MLContext (#20600)
### Description
This PR enables the API added in #20816 and moves context creation to JS.
### Motivation and Context
In order to enable I/O Binding with the upcoming [MLBuffer](https://github.com/webmachinelearning/webnn/issues/542) API in the WebNN specification, we need to share the same `MLContext` across multiple sessions. This is because `MLBuffer`s are restricted to the `MLContext` where they were created. This PR enables developers to use the same `MLContext` across multiple sessions.
This commit is contained in:
parent
cd516a1677
commit
4c3c809bdb
8 changed files with 458 additions and 55 deletions
401
js/web/lib/wasm/jsep/webnn/webnn.d.ts
vendored
Normal file
401
js/web/lib/wasm/jsep/webnn/webnn.d.ts
vendored
Normal file
|
|
@ -0,0 +1,401 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
interface NavigatorML {
|
||||
readonly ml: ML;
|
||||
}
|
||||
interface Navigator extends NavigatorML {}
|
||||
interface WorkerNavigator extends NavigatorML {}
|
||||
type MLDeviceType = 'cpu'|'gpu'|'npu';
|
||||
type MLPowerPreference = 'default'|'high-performance'|'low-power';
|
||||
/**
 * Option bag accepted by the `ML.createContext(options?)` overload.
 * Every field is optional.
 */
interface MLContextOptions {
|
||||
// One of the MLDeviceType literals: 'cpu' | 'gpu' | 'npu'.
deviceType?: MLDeviceType;
|
||||
// One of the MLPowerPreference literals: 'default' | 'high-performance' | 'low-power'.
powerPreference?: MLPowerPreference;
|
||||
// Thread count hint; typed as a plain number with no range constraint in this declaration.
numThreads?: number;
|
||||
}
|
||||
/**
 * Type of the `ml` property that NavigatorML adds to Navigator and WorkerNavigator.
 * Both overloads resolve asynchronously to an MLContext.
 */
interface ML {
|
||||
// Create a context from an optional MLContextOptions bag.
createContext(options?: MLContextOptions): Promise<MLContext>;
|
||||
// Create a context from a WebGPU device (GPUDevice is declared outside this file).
createContext(gpuDevice: GPUDevice): Promise<MLContext>;
|
||||
}
|
||||
type MLNamedArrayBufferViews = Record<string, ArrayBufferView>;
|
||||
interface MLComputeResult {
|
||||
inputs?: MLNamedArrayBufferViews;
|
||||
outputs?: MLNamedArrayBufferViews;
|
||||
}
|
||||
/**
 * WebNN context. This declaration only carries `compute`; a second `interface MLContext`
 * block later in this file adds the experimental buffer methods, and TypeScript merges
 * the two declarations into one interface.
 */
interface MLContext {
|
||||
// Takes a built graph plus named input/output ArrayBufferViews and resolves to an
// MLComputeResult, whose `inputs` and `outputs` maps are both optional.
compute(graph: MLGraph, inputs: MLNamedArrayBufferViews, outputs: MLNamedArrayBufferViews): Promise<MLComputeResult>;
|
||||
}
|
||||
interface MLGraph {}
|
||||
type MLInputOperandLayout = 'nchw'|'nhwc';
|
||||
type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8';
|
||||
interface MLOperandDescriptor {
|
||||
dataType: MLOperandDataType;
|
||||
dimensions?: number[];
|
||||
}
|
||||
interface MLOperand {
|
||||
dataType(): MLOperandDataType;
|
||||
shape(): number[];
|
||||
}
|
||||
interface MLActivation {}
|
||||
type MLNamedOperands = Record<string, MLOperand>;
|
||||
/**
 * Core part of the graph builder typing: construction, graph inputs, constants and `build`.
 * The operator methods are declared in the many further `interface MLGraphBuilder` blocks
 * below; TypeScript merges all of them into a single interface.
 */
interface MLGraphBuilder {
|
||||
// eslint-disable-next-line @typescript-eslint/no-misused-new
|
||||
// Construct signature: a builder is created for a specific MLContext.
new(context: MLContext): MLGraphBuilder;
|
||||
// Declares a named graph input described by an MLOperandDescriptor.
input(name: string, descriptor: MLOperandDescriptor): MLOperand;
|
||||
// Constant operand backed by the data in `bufferView`.
constant(descriptor: MLOperandDescriptor, bufferView: ArrayBufferView): MLOperand;
|
||||
// Constant operand from a data type and a single numeric value.
constant(type: MLOperandDataType, value: number): MLOperand;
|
||||
// Asynchronously builds an MLGraph from the named output operands.
build(outputs: MLNamedOperands): Promise<MLGraph>;
|
||||
}
|
||||
interface MLArgMinMaxOptions {
|
||||
axes?: number[];
|
||||
keepDimensions?: boolean;
|
||||
selectLastIndex?: boolean;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
argMin(input: MLOperand, options?: MLArgMinMaxOptions): MLOperand;
|
||||
argMax(input: MLOperand, options?: MLArgMinMaxOptions): MLOperand;
|
||||
}
|
||||
interface MLBatchNormalizationOptions {
|
||||
scale?: MLOperand;
|
||||
bias?: MLOperand;
|
||||
axis?: number;
|
||||
epsilon?: number;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
batchNormalization(input: MLOperand, mean: MLOperand, variance: MLOperand, options?: MLBatchNormalizationOptions):
|
||||
MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
cast(input: MLOperand, type: MLOperandDataType): MLOperand;
|
||||
}
|
||||
interface MLClampOptions {
|
||||
minValue?: number;
|
||||
maxValue?: number;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
clamp(input: MLOperand, options?: MLClampOptions): MLOperand;
|
||||
clamp(options?: MLClampOptions): MLActivation;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
concat(inputs: MLOperand[], axis: number): MLOperand;
|
||||
}
|
||||
type MLConv2dFilterOperandLayout = 'oihw'|'hwio'|'ohwi'|'ihwo';
|
||||
interface MLConv2dOptions {
|
||||
padding?: number[];
|
||||
strides?: number[];
|
||||
dilations?: number[];
|
||||
groups?: number;
|
||||
inputLayout?: MLInputOperandLayout;
|
||||
filterLayout?: MLConv2dFilterOperandLayout;
|
||||
bias?: MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
conv2d(input: MLOperand, filter: MLOperand, options?: MLConv2dOptions): MLOperand;
|
||||
}
|
||||
type MLConvTranspose2dFilterOperandLayout = 'iohw'|'hwoi'|'ohwi';
|
||||
interface MLConvTranspose2dOptions {
|
||||
padding?: number[];
|
||||
strides?: number[];
|
||||
dilations?: number[];
|
||||
outputPadding?: number[];
|
||||
outputSizes?: number[];
|
||||
groups?: number;
|
||||
inputLayout?: MLInputOperandLayout;
|
||||
filterLayout?: MLConvTranspose2dFilterOperandLayout;
|
||||
bias?: MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
convTranspose2d(input: MLOperand, filter: MLOperand, options?: MLConvTranspose2dOptions): MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
add(a: MLOperand, b: MLOperand): MLOperand;
|
||||
sub(a: MLOperand, b: MLOperand): MLOperand;
|
||||
mul(a: MLOperand, b: MLOperand): MLOperand;
|
||||
div(a: MLOperand, b: MLOperand): MLOperand;
|
||||
max(a: MLOperand, b: MLOperand): MLOperand;
|
||||
min(a: MLOperand, b: MLOperand): MLOperand;
|
||||
pow(a: MLOperand, b: MLOperand): MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
equal(a: MLOperand, b: MLOperand): MLOperand;
|
||||
greater(a: MLOperand, b: MLOperand): MLOperand;
|
||||
greaterOrEqual(a: MLOperand, b: MLOperand): MLOperand;
|
||||
lesser(a: MLOperand, b: MLOperand): MLOperand;
|
||||
lesserOrEqual(a: MLOperand, b: MLOperand): MLOperand;
|
||||
logicalNot(a: MLOperand): MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
abs(input: MLOperand): MLOperand;
|
||||
ceil(input: MLOperand): MLOperand;
|
||||
cos(input: MLOperand): MLOperand;
|
||||
erf(input: MLOperand): MLOperand;
|
||||
exp(input: MLOperand): MLOperand;
|
||||
floor(input: MLOperand): MLOperand;
|
||||
identity(input: MLOperand): MLOperand;
|
||||
log(input: MLOperand): MLOperand;
|
||||
neg(input: MLOperand): MLOperand;
|
||||
reciprocal(input: MLOperand): MLOperand;
|
||||
sin(input: MLOperand): MLOperand;
|
||||
sqrt(input: MLOperand): MLOperand;
|
||||
tan(input: MLOperand): MLOperand;
|
||||
}
|
||||
interface MLEluOptions {
|
||||
alpha?: number;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
elu(input: MLOperand, options?: MLEluOptions): MLOperand;
|
||||
elu(options?: MLEluOptions): MLActivation;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
expand(input: MLOperand, newShape: number[]): MLOperand;
|
||||
}
|
||||
interface MLGatherOptions {
|
||||
axis?: number;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
gather(input: MLOperand, indices: MLOperand, options?: MLGatherOptions): MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
gelu(input: MLOperand): MLOperand;
|
||||
gelu(): MLActivation;
|
||||
}
|
||||
interface MLGemmOptions {
|
||||
c?: MLOperand;
|
||||
alpha?: number;
|
||||
beta?: number;
|
||||
aTranspose?: boolean;
|
||||
bTranspose?: boolean;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
gemm(a: MLOperand, b: MLOperand, options?: MLGemmOptions): MLOperand;
|
||||
}
|
||||
type MLGruWeightLayout = 'zrn'|'rzn';
|
||||
type MLRecurrentNetworkDirection = 'forward'|'backward'|'both';
|
||||
interface MLGruOptions {
|
||||
bias?: MLOperand;
|
||||
recurrentBias?: MLOperand;
|
||||
initialHiddenState?: MLOperand;
|
||||
resetAfter?: boolean;
|
||||
returnSequence?: boolean;
|
||||
direction?: MLRecurrentNetworkDirection;
|
||||
layout?: MLGruWeightLayout;
|
||||
activations?: MLActivation[];
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
gru(input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, steps: number, hiddenSize: number,
|
||||
options?: MLGruOptions): MLOperand[];
|
||||
}
|
||||
interface MLGruCellOptions {
|
||||
bias?: MLOperand;
|
||||
recurrentBias?: MLOperand;
|
||||
resetAfter?: boolean;
|
||||
layout?: MLGruWeightLayout;
|
||||
activations?: MLActivation[];
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
gruCell(
|
||||
input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, hiddenState: MLOperand, hiddenSize: number,
|
||||
options?: MLGruCellOptions): MLOperand;
|
||||
}
|
||||
interface MLHardSigmoidOptions {
|
||||
alpha?: number;
|
||||
beta?: number;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
hardSigmoid(input: MLOperand, options?: MLHardSigmoidOptions): MLOperand;
|
||||
hardSigmoid(options?: MLHardSigmoidOptions): MLActivation;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
hardSwish(input: MLOperand): MLOperand;
|
||||
hardSwish(): MLActivation;
|
||||
}
|
||||
interface MLInstanceNormalizationOptions {
|
||||
scale?: MLOperand;
|
||||
bias?: MLOperand;
|
||||
epsilon?: number;
|
||||
layout?: MLInputOperandLayout;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
instanceNormalization(input: MLOperand, options?: MLInstanceNormalizationOptions): MLOperand;
|
||||
}
|
||||
interface MLLayerNormalizationOptions {
|
||||
scale?: MLOperand;
|
||||
bias?: MLOperand;
|
||||
axes?: number[];
|
||||
epsilon?: number;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
layerNormalization(input: MLOperand, options?: MLLayerNormalizationOptions): MLOperand;
|
||||
}
|
||||
interface MLLeakyReluOptions {
|
||||
alpha?: number;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
leakyRelu(input: MLOperand, options?: MLLeakyReluOptions): MLOperand;
|
||||
leakyRelu(options?: MLLeakyReluOptions): MLActivation;
|
||||
}
|
||||
interface MLLinearOptions {
|
||||
alpha?: number;
|
||||
beta?: number;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
linear(input: MLOperand, options?: MLLinearOptions): MLOperand;
|
||||
linear(options?: MLLinearOptions): MLActivation;
|
||||
}
|
||||
type MLLstmWeightLayout = 'iofg'|'ifgo';
|
||||
interface MLLstmOptions {
|
||||
bias?: MLOperand;
|
||||
recurrentBias?: MLOperand;
|
||||
peepholeWeight?: MLOperand;
|
||||
initialHiddenState?: MLOperand;
|
||||
initialCellState?: MLOperand;
|
||||
returnSequence?: boolean;
|
||||
direction?: MLRecurrentNetworkDirection;
|
||||
layout?: MLLstmWeightLayout;
|
||||
activations?: MLActivation[];
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
lstm(
|
||||
input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, steps: number, hiddenSize: number,
|
||||
options?: MLLstmOptions): MLOperand[];
|
||||
}
|
||||
interface MLLstmCellOptions {
|
||||
bias?: MLOperand;
|
||||
recurrentBias?: MLOperand;
|
||||
peepholeWeight?: MLOperand;
|
||||
layout?: MLLstmWeightLayout;
|
||||
activations?: MLActivation[];
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
lstmCell(
|
||||
input: MLOperand, weight: MLOperand, recurrentWeight: MLOperand, hiddenState: MLOperand, cellState: MLOperand,
|
||||
hiddenSize: number, options?: MLLstmCellOptions): MLOperand[];
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
matmul(a: MLOperand, b: MLOperand): MLOperand;
|
||||
}
|
||||
type MLPaddingMode = 'constant'|'edge'|'reflection'|'symmetric';
|
||||
interface MLPadOptions {
|
||||
mode?: MLPaddingMode;
|
||||
value?: number;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
pad(input: MLOperand, beginningPadding: number[], endingPadding: number[], options?: MLPadOptions): MLOperand;
|
||||
}
|
||||
type MLRoundingType = 'floor'|'ceil';
|
||||
interface MLPool2dOptions {
|
||||
windowDimensions?: number[];
|
||||
padding?: number[];
|
||||
strides?: number[];
|
||||
dilations?: number[];
|
||||
layout?: MLInputOperandLayout;
|
||||
roundingType?: MLRoundingType;
|
||||
outputSizes?: number[];
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
averagePool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand;
|
||||
l2Pool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand;
|
||||
maxPool2d(input: MLOperand, options?: MLPool2dOptions): MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
prelu(input: MLOperand, slope: MLOperand): MLOperand;
|
||||
}
|
||||
interface MLReduceOptions {
|
||||
axes?: number[];
|
||||
keepDimensions?: boolean;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
reduceL1(input: MLOperand, options?: MLReduceOptions): MLOperand;
|
||||
reduceL2(input: MLOperand, options?: MLReduceOptions): MLOperand;
|
||||
reduceLogSum(input: MLOperand, options?: MLReduceOptions): MLOperand;
|
||||
reduceLogSumExp(input: MLOperand, options?: MLReduceOptions): MLOperand;
|
||||
reduceMax(input: MLOperand, options?: MLReduceOptions): MLOperand;
|
||||
reduceMean(input: MLOperand, options?: MLReduceOptions): MLOperand;
|
||||
reduceMin(input: MLOperand, options?: MLReduceOptions): MLOperand;
|
||||
reduceProduct(input: MLOperand, options?: MLReduceOptions): MLOperand;
|
||||
reduceSum(input: MLOperand, options?: MLReduceOptions): MLOperand;
|
||||
reduceSumSquare(input: MLOperand, options?: MLReduceOptions): MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
relu(input: MLOperand): MLOperand;
|
||||
relu(): MLActivation;
|
||||
}
|
||||
type MLInterpolationMode = 'nearest-neighbor'|'linear';
|
||||
interface MLResample2dOptions {
|
||||
mode?: MLInterpolationMode;
|
||||
scales?: number[];
|
||||
sizes?: number[];
|
||||
axes?: number[];
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
resample2d(input: MLOperand, options?: MLResample2dOptions): MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
reshape(input: MLOperand, newShape: number[]): MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
sigmoid(input: MLOperand): MLOperand;
|
||||
sigmoid(): MLActivation;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
slice(input: MLOperand, starts: number[], sizes: number[]): MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
softmax(input: MLOperand, axis: number): MLOperand;
|
||||
softmax(axis: number): MLActivation;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
softplus(input: MLOperand): MLOperand;
|
||||
softplus(): MLActivation;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
softsign(input: MLOperand): MLOperand;
|
||||
softsign(): MLActivation;
|
||||
}
|
||||
interface MLSplitOptions {
|
||||
axis?: number;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
split(input: MLOperand, splits: number|number[], options?: MLSplitOptions): MLOperand[];
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
tanh(input: MLOperand): MLOperand;
|
||||
tanh(): MLActivation;
|
||||
}
|
||||
interface MLTransposeOptions {
|
||||
permutation?: number[];
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
transpose(input: MLOperand, options?: MLTransposeOptions): MLOperand;
|
||||
}
|
||||
interface MLTriangularOptions {
|
||||
upper?: boolean;
|
||||
diagonal?: number;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
triangular(input: MLOperand, options?: MLTriangularOptions): MLOperand;
|
||||
}
|
||||
interface MLGraphBuilder {
|
||||
where(condition: MLOperand, input: MLOperand, other: MLOperand): MLOperand;
|
||||
}
|
||||
|
||||
// Experimental MLBuffer interface
|
||||
|
||||
type MLSize64Out = number;
|
||||
/**
 * Experimental buffer handle returned by `MLContext.createBuffer`.
 * Per the commit description, an MLBuffer is restricted to the MLContext that created it.
 */
interface MLBuffer {
|
||||
// Read-only size; MLSize64Out is an alias of `number` in this file.
readonly size: MLSize64Out;
|
||||
// Synchronous, returns nothing. Presumably releases the underlying resource - confirm against the spec.
destroy(): void;
|
||||
}
|
||||
type MLSize64 = number;
|
||||
interface MLBufferDescriptor {
|
||||
size: MLSize64;
|
||||
}
|
||||
type MLNamedBuffers = Record<string, MLBuffer>;
|
||||
/**
 * Experimental MLBuffer additions to MLContext. This block is merged by TypeScript with the
 * earlier `interface MLContext` declaration that carries `compute`.
 */
interface MLContext {
|
||||
// Synchronously returns a new MLBuffer for the given descriptor (descriptor carries `size`).
createBuffer(descriptor: MLBufferDescriptor): MLBuffer;
|
||||
// Writes `srcData` into `dstBuffer`; the source element offset and size are optional. Returns nothing.
writeBuffer(
|
||||
dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: MLSize64,
|
||||
srcElementSize?: MLSize64): void;
|
||||
// Asynchronously reads the buffer contents back as an ArrayBuffer.
readBuffer(srcBuffer: MLBuffer): Promise<ArrayBuffer>;
|
||||
// Buffer-based counterpart of `compute`: takes named MLBuffers and returns void rather than a promise.
dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void;
|
||||
}
|
||||
|
|
@ -66,8 +66,6 @@ const setExecutionProviders =
|
|||
const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption;
|
||||
// const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
|
||||
const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
|
||||
const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads;
|
||||
const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
|
||||
if (deviceType) {
|
||||
const keyDataOffset = allocWasmString('deviceType', allocs);
|
||||
const valueDataOffset = allocWasmString(deviceType, allocs);
|
||||
|
|
@ -76,26 +74,6 @@ const setExecutionProviders =
|
|||
checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`);
|
||||
}
|
||||
}
|
||||
if (numThreads !== undefined) {
|
||||
// Just ignore invalid webnnOptions.numThreads.
|
||||
const validatedNumThreads =
|
||||
(typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) ? 0 :
|
||||
numThreads;
|
||||
const keyDataOffset = allocWasmString('numThreads', allocs);
|
||||
const valueDataOffset = allocWasmString(validatedNumThreads.toString(), allocs);
|
||||
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
|
||||
0) {
|
||||
checkLastError(`Can't set a session config entry: 'numThreads' - ${numThreads}.`);
|
||||
}
|
||||
}
|
||||
if (powerPreference) {
|
||||
const keyDataOffset = allocWasmString('powerPreference', allocs);
|
||||
const valueDataOffset = allocWasmString(powerPreference, allocs);
|
||||
if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
|
||||
0) {
|
||||
checkLastError(`Can't set a session config entry: 'powerPreference' - ${powerPreference}.`);
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 'webgpu':
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
|
||||
// WebNN API specification.
|
||||
// https://github.com/webmachinelearning/webnn/issues/677
|
||||
/// <reference path="jsep/webnn/webnn.d.ts" />
|
||||
|
||||
import {Env, InferenceSession, Tensor} from 'onnxruntime-common';
|
||||
|
||||
import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
|
||||
|
|
@ -253,11 +258,43 @@ export const createSession = async(
|
|||
await Promise.all(loadingPromises);
|
||||
}
|
||||
|
||||
// Resolve the MLContext for the WebNN execution provider before the session is created and
// publish it as `wasm.currentContext`. Only the first 'webnn' entry in `executionProviders`
// is handled (the loop breaks after it). Context source, in order of precedence:
//   1. a user-supplied `context`,
//   2. a context created from a user-supplied `gpuDevice`,
//   3. a context created from {deviceType, numThreads, powerPreference},
//   4. a default context when the provider is given only by name.
// NOTE(review): the WebNN EP constructor in this commit reads the module property
// `currentContext`; presumably that is this same `wasm` object - confirm.
for (const provider of options?.executionProviders ?? []) {
|
||||
const providerName = typeof provider === 'string' ? provider : provider.name;
|
||||
if (providerName === 'webnn') {
|
||||
// A context that is still set means an earlier value was never cleared; refuse to overwrite it.
if (wasm.currentContext) {
|
||||
throw new Error('WebNN execution provider is already set.');
|
||||
}
|
||||
if (typeof provider !== 'string') {
|
||||
const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption;
|
||||
// The option object is probed through each option shape; fields that are absent read as undefined.
const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
|
||||
const gpuDevice = (webnnOptions as InferenceSession.WebNNOptionsWebGpu)?.gpuDevice;
|
||||
const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
|
||||
const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads;
|
||||
const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
|
||||
if (context) {
|
||||
wasm.currentContext = context as MLContext;
|
||||
} else if (gpuDevice) {
|
||||
wasm.currentContext = await navigator.ml.createContext(gpuDevice);
|
||||
} else {
|
||||
// NOTE(review): numThreads is forwarded as-is; the integer / non-negative check that appears
// in the removed setExecutionProviders code is not repeated here - confirm this is intended.
wasm.currentContext = await navigator.ml.createContext({deviceType, numThreads, powerPreference});
|
||||
}
|
||||
} else {
|
||||
wasm.currentContext = await navigator.ml.createContext();
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
|
||||
if (sessionHandle === 0) {
|
||||
checkLastError('Can\'t create a session.');
|
||||
}
|
||||
|
||||
// Clear the current MLContext after session creation, so that the next createSession call
// does not trip the 'WebNN execution provider is already set.' guard.
// NOTE(review): this reset is reached only if the preceding session creation did not throw;
// whether an enclosing try/catch/finally also clears it is not visible in this hunk - confirm.
|
||||
if (wasm.currentContext) {
|
||||
wasm.currentContext = undefined;
|
||||
}
|
||||
|
||||
const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);
|
||||
|
||||
const enableGraphCapture = !!options?.enableGraphCapture;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,11 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
|
||||
// WebNN API specification.
|
||||
// https://github.com/webmachinelearning/webnn/issues/677
|
||||
/// <reference path="jsep/webnn/webnn.d.ts" />
|
||||
|
||||
import type {Tensor} from 'onnxruntime-common';
|
||||
|
||||
/* eslint-disable @typescript-eslint/naming-convention */
|
||||
|
|
@ -19,7 +24,7 @@ export declare namespace JSEP {
|
|||
type CaptureEndFunction = () => void;
|
||||
type ReplayFunction = () => void;
|
||||
|
||||
export interface Module extends WebGpuModule {
|
||||
export interface Module extends WebGpuModule, WebNnModule {
|
||||
/**
|
||||
* Mount the external data file to an internal map, which will be used during session initialization.
|
||||
*
|
||||
|
|
@ -106,6 +111,13 @@ export declare namespace JSEP {
|
|||
*/
|
||||
jsepOnReleaseSession: (sessionId: number) => void;
|
||||
}
|
||||
|
||||
export interface WebNnModule {
|
||||
/**
|
||||
* Active MLContext used to create WebNN EP.
|
||||
*/
|
||||
currentContext: MLContext;
|
||||
}
|
||||
}
|
||||
|
||||
export interface OrtInferenceAPIs {
|
||||
|
|
|
|||
|
|
@ -17,24 +17,12 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags,
|
||||
const std::string& webnn_threads_number, const std::string& webnn_power_flags)
|
||||
WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags)
|
||||
: IExecutionProvider{onnxruntime::kWebNNExecutionProvider} {
|
||||
// Create WebNN context and graph builder.
|
||||
const emscripten::val ml = emscripten::val::global("navigator")["ml"];
|
||||
if (!ml.as<bool>()) {
|
||||
ORT_THROW("Failed to get ml from navigator.");
|
||||
}
|
||||
emscripten::val context_options = emscripten::val::object();
|
||||
context_options.set("deviceType", emscripten::val(webnn_device_flags));
|
||||
// WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend.
|
||||
if (webnn_device_flags.compare("cpu") == 0) {
|
||||
preferred_layout_ = DataLayout::NHWC;
|
||||
wnn_device_type_ = webnn::WebnnDeviceType::CPU;
|
||||
// Set "numThreads" if it's not default 0.
|
||||
if (webnn_threads_number.compare("0") != 0) {
|
||||
context_options.set("numThreads", stoi(webnn_threads_number));
|
||||
}
|
||||
} else {
|
||||
preferred_layout_ = DataLayout::NCHW;
|
||||
if (webnn_device_flags.compare("gpu") == 0) {
|
||||
|
|
@ -45,11 +33,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
|
|||
ORT_THROW("Unknown WebNN deviceType.");
|
||||
}
|
||||
}
|
||||
if (webnn_power_flags.compare("default") != 0) {
|
||||
context_options.set("powerPreference", emscripten::val(webnn_power_flags));
|
||||
}
|
||||
|
||||
wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
|
||||
wnn_context_ = emscripten::val::module_property("currentContext");
|
||||
if (!wnn_context_.as<bool>()) {
|
||||
ORT_THROW("Failed to create WebNN context.");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,8 +19,7 @@ class Model;
|
|||
|
||||
class WebNNExecutionProvider : public IExecutionProvider {
|
||||
public:
|
||||
WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
|
||||
const std::string& webnn_power_flags);
|
||||
explicit WebNNExecutionProvider(const std::string& webnn_device_flags);
|
||||
virtual ~WebNNExecutionProvider();
|
||||
|
||||
std::vector<std::unique_ptr<ComputeCapability>>
|
||||
|
|
|
|||
|
|
@ -10,27 +10,22 @@ using namespace onnxruntime;
|
|||
|
||||
namespace onnxruntime {
|
||||
struct WebNNProviderFactory : IExecutionProviderFactory {
|
||||
WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
|
||||
const std::string& webnn_power_flags)
|
||||
: webnn_device_flags_(webnn_device_flags), webnn_threads_number_(webnn_threads_number), webnn_power_flags_(webnn_power_flags) {}
|
||||
explicit WebNNProviderFactory(const std::string& webnn_device_flags)
|
||||
: webnn_device_flags_(webnn_device_flags) {}
|
||||
~WebNNProviderFactory() override {}
|
||||
|
||||
std::unique_ptr<IExecutionProvider> CreateProvider() override;
|
||||
|
||||
std::string webnn_device_flags_;
|
||||
std::string webnn_threads_number_;
|
||||
std::string webnn_power_flags_;
|
||||
};
|
||||
|
||||
std::unique_ptr<IExecutionProvider> WebNNProviderFactory::CreateProvider() {
|
||||
return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_, webnn_threads_number_, webnn_power_flags_);
|
||||
return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_);
|
||||
}
|
||||
|
||||
std::shared_ptr<IExecutionProviderFactory> WebNNProviderFactoryCreator::Create(
|
||||
const ProviderOptions& provider_options) {
|
||||
return std::make_shared<onnxruntime::WebNNProviderFactory>(provider_options.at("deviceType"),
|
||||
provider_options.at("numThreads"),
|
||||
provider_options.at("powerPreference"));
|
||||
return std::make_shared<onnxruntime::WebNNProviderFactory>(provider_options.at("deviceType"));
|
||||
}
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -127,11 +127,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
|
|||
} else if (strcmp(provider_name, "WEBNN") == 0) {
|
||||
#if defined(USE_WEBNN)
|
||||
std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu");
|
||||
std::string numThreads = options->value.config_options.GetConfigOrDefault("numThreads", "0");
|
||||
std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default");
|
||||
provider_options["deviceType"] = deviceType;
|
||||
provider_options["numThreads"] = numThreads;
|
||||
provider_options["powerPreference"] = powerPreference;
|
||||
options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options));
|
||||
#else
|
||||
status = create_not_supported_status();
|
||||
|
|
|
|||
Loading…
Reference in a new issue