onnxruntime/js/web/lib/wasm/proxy-messages.ts
Caroline Zhu 64de71c5e2
[js/web/training] Add CreateTrainingSession (#17891)
### Description
* Adds `TrainingSession.create()` functionality, following the "web bindings
for training" design doc
* Added 2 new training APIs to wasm/api.h:
   * OrtTrainingGetInputOutputName
   * OrtTrainingGetInputOutputCount
* Moved the `isOrtEnvInitialized` boolean into wasm-core-impl and added a
method that queries it

### Motivation and Context
* Adding web bindings for training

#### Related work
* #16521 allowed training artifacts to be built
* #17333 added interfaces for training
* #17474 allows the training package to be built and adds the training backend
to the web package **[MUST BE MERGED BEFORE THIS ONE]**

---------

Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Co-authored-by: Ashwini Khade <askhade@microsoft.com>
2023-10-26 09:22:10 -07:00

82 lines
2.5 KiB
TypeScript

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import type {Env, InferenceSession, Tensor} from 'onnxruntime-common';
/**
 * Metadata for a tensor whose `location` is `'cpu'`, packed as the labelled tuple
 * `[dataType, dims, data, location]`.
 *
 * NOTE(review): judging by the name, this is the form that can be handed across the proxy
 * boundary by value — confirm against the proxy implementation.
 */
export type SerializableTensorMetadata = [
  dataType: Tensor.Type,
  dims: ReadonlyArray<number>,
  data: Tensor.DataType,
  location: 'cpu',
];
/**
 * Stand-in for tensor data when the tensor lives in a GPU buffer.
 */
export type GpuBufferMetadata = {
  /** The GPU buffer that holds the tensor contents. */
  gpuBuffer: Tensor.GpuBufferType;
  /** Optional hook; presumably releases `gpuBuffer` — confirm with the producer of this object. */
  dispose?: () => void;
  /** Optional reader that resolves to the contents as the data array type of a GPU-buffer data type. */
  download?: () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
};
/**
 * Tensor metadata for the non-`'cpu'` locations:
 * - `'gpu-buffer'`: the `data` slot holds a {@link GpuBufferMetadata} wrapper instead of raw data;
 * - `'cpu-pinned'`: the `data` slot holds ordinary tensor data.
 *
 * NOTE(review): the name suggests these tuples cannot be passed across the proxy boundary by
 * value — confirm against the proxy implementation.
 */
export type UnserializableTensorMetadata =
  | [dataType: Tensor.Type, dims: ReadonlyArray<number>, data: GpuBufferMetadata, location: 'gpu-buffer']
  | [dataType: Tensor.Type, dims: ReadonlyArray<number>, data: Tensor.DataType, location: 'cpu-pinned'];
/**
 * Any tensor metadata tuple. The fourth element (`location`) discriminates between the variants.
 */
export type TensorMetadata = UnserializableTensorMetadata | SerializableTensorMetadata;
/**
 * Session description as the tuple `[sessionHandle, inputNames, outputNames]`.
 */
export type SerializableSessionMetadata = [
  sessionHandle: number,
  inputNames: Array<string>,
  outputNames: Array<string>,
];
/**
 * Location of a model buffer as the tuple `[modelDataOffset, modelDataLength]`.
 *
 * This tuple is the output of the `'create_allocate'` message and is fed back as `modeldata`
 * to `'create_finalize'`.
 * NOTE(review): presumably an offset/length pair into WebAssembly memory — confirm against the
 * wasm core implementation.
 */
export type SerializableModelData = [modelDataOffset: number, modelDataLength: number];

/**
 * @deprecated Use {@link SerializableModelData}. This lower-case-`d` spelling is kept as an
 * alias so existing imports continue to compile.
 */
export type SerializableModeldata = SerializableModelData;
/**
 * Base shape shared by every proxy message type.
 */
interface MessageError {
  /**
   * Optional error text.
   * NOTE(review): presumably filled in when the requested operation fails — confirm with the worker.
   */
  err?: string;
}
/**
 * Message tagged `'init-wasm'`. Its input is the WebAssembly flag set from the ORT environment,
 * presumably used to initialize the WebAssembly runtime.
 */
interface MessageInitWasm extends MessageError {
  type: 'init-wasm';
  /** WebAssembly configuration flags. */
  in?: Env.WebAssemblyFlags;
}
/**
 * Message tagged `'init-ort'`. Its input is the full ORT `Env` object, presumably used to
 * initialize the ORT environment.
 */
interface MessageInitOrt extends MessageError {
  type: 'init-ort';
  /** The ORT environment settings. */
  in?: Env;
}
/**
 * First step of two-step session creation (`'create_allocate'`). It takes the model bytes and
 * yields an `[offset, length]` tuple, which the `'create_finalize'` message accepts as `modeldata`.
 */
interface MessageCreateSessionAllocate extends MessageError {
  type: 'create_allocate';
  /** The serialized model. */
  in?: {
    model: Uint8Array;
  };
  /** Offset/length describing where the model data was placed. */
  out?: SerializableModeldata;
}
/**
 * Second step of two-step session creation (`'create_finalize'`). It takes the `[offset, length]`
 * tuple from `'create_allocate'` plus optional session options, and yields the session metadata.
 */
interface MessageCreateSessionFinalize extends MessageError {
  type: 'create_finalize';
  /** Previously allocated model location and optional session options. */
  in?: {
    modeldata: SerializableModeldata;
    options?: InferenceSession.SessionOptions;
  };
  /** `[sessionHandle, inputNames, outputNames]` of the created session. */
  out?: SerializableSessionMetadata;
}
/**
 * Single-step session creation (`'create'`). It takes the model bytes plus optional session
 * options and yields the session metadata.
 */
interface MessageCreateSession extends MessageError {
  type: 'create';
  /** The serialized model and optional session options. */
  in?: {
    model: Uint8Array;
    options?: InferenceSession.SessionOptions;
  };
  /** `[sessionHandle, inputNames, outputNames]` of the created session. */
  out?: SerializableSessionMetadata;
}
/**
 * Message tagged `'release'`. Its input is a number; presumably the `sessionHandle` reported in
 * the session metadata — confirm with the worker.
 */
interface MessageReleaseSession extends MessageError {
  type: 'release';
  /** Identifier of the session to release. */
  in?: number;
}
/**
 * Inference request (`'run'`). Inputs are restricted by type to serializable (`'cpu'`) tensor
 * metadata, and the outputs come back in the same form.
 */
interface MessageRun extends MessageError {
  type: 'run';
  in?: {
    /** Session to run. */
    sessionId: number;
    /** Indices of the model inputs being supplied, paired positionally with `inputs`. */
    inputIndices: number[];
    /** Tensor data for each entry of `inputIndices`. */
    inputs: SerializableTensorMetadata[];
    /** Indices of the model outputs to fetch. */
    outputIndices: number[];
    /** Per-run options. */
    options: InferenceSession.RunOptions;
  };
  /** Resulting output tensors. */
  out?: SerializableTensorMetadata[];
}
/**
 * Message tagged `'end-profiling'`. Its input is a number; presumably the session handle whose
 * profiling should be ended — confirm with the worker.
 */
interface MessageEndProfiling extends MessageError {
  type: 'end-profiling';
  in?: number;
}

/**
 * Message tagged `'is-ort-env-initialized'`. It carries no input. Its output is a boolean,
 * presumably reporting whether the ORT environment has already been initialized.
 */
interface MessageIsOrtEnvInitialized extends MessageError {
  type: 'is-ort-env-initialized';
  out?: boolean;
}

/**
 * Every message type exchanged through the wasm proxy, as a union discriminated on `type`.
 * Switching on `message.type` narrows `in`/`out` to the matching payload types.
 */
export type OrtWasmMessage =
  | MessageInitWasm
  | MessageInitOrt
  | MessageCreateSessionAllocate
  | MessageCreateSessionFinalize
  | MessageCreateSession
  | MessageReleaseSession
  | MessageRun
  | MessageEndProfiling
  | MessageIsOrtEnvInitialized;