onnxruntime/js/node/lib/binding.ts
Arthur Islamov c262879214
Added DML and CUDA provider support in onnxruntime-node (#16050)
### Description
This change adds support for the CUDA and DML (DirectML) execution providers. DML is
available only on Windows; on other platforms, requesting it throws an error.



### Motivation and Context
This resolves the feature request
https://github.com/microsoft/onnxruntime/issues/14127, which is tracked in
https://github.com/microsoft/onnxruntime/issues/14529.

I was working on a Stable Diffusion implementation for Node.js. It is
very slow on the CPU, so GPU support is essential.

A working demo using a patched, precompiled build is available here:
https://github.com/dakenf/stable-diffusion-nodejs

---------
2023-08-25 16:57:06 -07:00

50 lines
1.5 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {InferenceSession, OnnxValue} from 'onnxruntime-common';
type SessionOptions = InferenceSession.SessionOptions;

/** Input values passed to `run()`, keyed by name (presumably the model's `inputNames`). */
type FeedsType = {
  [name: string]: OnnxValue;
};

/**
 * Outputs requested from `run()`, keyed by name (presumably the model's `outputNames`).
 * An entry may be `null`. NOTE(review): presumably `null` means "produce this output
 * without a caller-supplied value"; confirm against the native implementation.
 */
type FetchesType = {
  [name: string]: OnnxValue|null;
};

/**
 * Output values returned by `run()`, keyed by name.
 *
 * Deliberately not called `ReturnType`: a module-level alias with that name would shadow
 * TypeScript's built-in `ReturnType<T>` utility type everywhere in this file.
 */
type RunResultType = {
  [name: string]: OnnxValue;
};

type RunOptions = InferenceSession.RunOptions;

/**
 * Type declarations for the native binding: a thin, synchronous wrapper around a native
 * inference session object. Methods return their results directly (no Promises).
 */
export declare namespace Binding {
  // Inside this namespace, `InferenceSession` refers to the interface below,
  // not to the `InferenceSession` imported from 'onnxruntime-common'.
  export interface InferenceSession {
    /** Loads a model from the file at `modelPath`. */
    loadModel(modelPath: string, options: SessionOptions): void;
    /** Loads a model from `byteLength` bytes of `buffer`, starting at `byteOffset`. */
    loadModel(buffer: ArrayBuffer, byteOffset: number, byteLength: number, options: SessionOptions): void;

    /** Names of the model's inputs. NOTE(review): presumably valid only after `loadModel()`; confirm. */
    readonly inputNames: string[];
    /** Names of the model's outputs. NOTE(review): presumably valid only after `loadModel()`; confirm. */
    readonly outputNames: string[];

    /** Runs inference synchronously and returns the produced output values keyed by name. */
    run(feeds: FeedsType, fetches: FetchesType, options: RunOptions): RunResultType;
  }

  /** Constructor for a native session; `new` takes no arguments and the model is loaded afterwards. */
  export interface InferenceSessionConstructor {
    new(): InferenceSession;
  }

  /** One entry of the list returned by the binding's `listSupportedBackends()`. */
  export interface SupportedBackend {
    /** Backend identifier. */
    name: string;
    /** NOTE(review): presumably whether the backend ships with this package; confirm. */
    bundled: boolean;
  }
}
/**
 * Shape of the exports of the native `.node` addon, as consumed by this package.
 */
type NativeBinding = {
  // eslint-disable-next-line @typescript-eslint/naming-convention
  InferenceSession: Binding.InferenceSessionConstructor;
  listSupportedBackends: () => Binding.SupportedBackend[];
};

/**
 * The native binding, loaded eagerly when this module is first imported.
 *
 * The prebuilt addon is selected by `process.platform` and `process.arch`. The relative
 * path is resolved from this module's location. `require` throws if no binary exists for
 * the current platform/architecture.
 */
export const binding =
    // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires
    require(`../bin/napi-v3/${process.platform}/${process.arch}/onnxruntime_binding.node`) as NativeBinding;