// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "onnxruntime_cxx_api.h" #include #include #include #include #include "common.h" #include "session_options_helper.h" #include "tensor_helper.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_options.h" #endif #ifdef USE_DML #include "core/providers/dml/dml_provider_factory.h" #endif #ifdef USE_WEBGPU #include "core/providers/webgpu/webgpu_provider_factory.h" #endif #ifdef USE_TENSORRT #include "core/providers/tensorrt/tensorrt_provider_options.h" #endif #ifdef USE_COREML #include "core/providers/coreml/coreml_provider_factory.h" #endif const std::unordered_map GRAPH_OPT_LEVEL_NAME_TO_ID_MAP = { {"disabled", ORT_DISABLE_ALL}, {"basic", ORT_ENABLE_BASIC}, {"extended", ORT_ENABLE_EXTENDED}, {"all", ORT_ENABLE_ALL}}; const std::unordered_map EXECUTION_MODE_NAME_TO_ID_MAP = {{"sequential", ORT_SEQUENTIAL}, {"parallel", ORT_PARALLEL}}; void ParseExecutionProviders(const Napi::Array epList, Ort::SessionOptions& sessionOptions) { for (uint32_t i = 0; i < epList.Length(); i++) { Napi::Value epValue = epList[i]; std::string name; int deviceId = 0; #ifdef USE_COREML int coreMlFlags = 0; #endif #ifdef USE_WEBGPU std::unordered_map webgpu_options; #endif if (epValue.IsString()) { name = epValue.As().Utf8Value(); } else if (!epValue.IsObject() || epValue.IsNull() || !epValue.As().Has("name") || !epValue.As().Get("name").IsString()) { ORT_NAPI_THROW_TYPEERROR(epList.Env(), "Invalid argument: sessionOptions.executionProviders[", i, "] must be either a string or an object with property 'name'."); } else { auto obj = epValue.As(); name = obj.Get("name").As().Utf8Value(); if (obj.Has("deviceId")) { deviceId = obj.Get("deviceId").As(); } #ifdef USE_COREML if (obj.Has("coreMlFlags")) { coreMlFlags = obj.Get("coreMlFlags").As(); } #endif #ifdef USE_WEBGPU for (const auto& nameIter : obj.GetPropertyNames()) { Napi::Value nameVar = nameIter.second; std::string name = nameVar.As().Utf8Value(); if (name != "name") { Napi::Value valueVar = obj.Get(nameVar); ORT_NAPI_THROW_TYPEERROR_IF(!valueVar.IsString(), epList.Env(), "Invalid argument: sessionOptions.executionProviders must be a string or an object with property 'name'."); std::string value = valueVar.As().Utf8Value(); webgpu_options[name] = value; } } #endif } // CPU execution provider if (name == "cpu") { // TODO: handling CPU EP options #ifdef USE_CUDA } else if (name == "cuda") { OrtCUDAProviderOptionsV2* options; Ort::GetApi().CreateCUDAProviderOptions(&options); options->device_id = deviceId; sessionOptions.AppendExecutionProvider_CUDA_V2(*options); Ort::GetApi().ReleaseCUDAProviderOptions(options); #endif #ifdef USE_TENSORRT } else if (name == "tensorrt") { OrtTensorRTProviderOptionsV2* options; Ort::GetApi().CreateTensorRTProviderOptions(&options); options->device_id = deviceId; sessionOptions.AppendExecutionProvider_TensorRT_V2(*options); Ort::GetApi().ReleaseTensorRTProviderOptions(options); #endif #ifdef USE_DML } else if (name == "dml") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(sessionOptions, deviceId)); #endif #ifdef USE_WEBGPU } else if (name == "webgpu") { sessionOptions.AppendExecutionProvider("WebGPU", webgpu_options); #endif #ifdef USE_COREML } else if (name == "coreml") { Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions, coreMlFlags)); #endif #ifdef USE_QNN } else if (name == "qnn") { std::unordered_map qnn_options; qnn_options["backend_path"] = "QnnHtp.dll"; qnn_options["enable_htp_fp16_precision"] = "1"; sessionOptions.AppendExecutionProvider("QNN", qnn_options); #endif } else { ORT_NAPI_THROW_ERROR(epList.Env(), "Invalid argument: sessionOptions.executionProviders[", i, "] is unsupported: '", name, "'."); } } } void IterateExtraOptions(const std::string& prefix, const Napi::Object& obj, Ort::SessionOptions& sessionOptions) { for (const auto& kvp : obj) { auto key = kvp.first.As().Utf8Value(); Napi::Value value = kvp.second; if (value.IsObject()) { IterateExtraOptions(prefix + key + ".", value.As(), sessionOptions); } else { ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString(), obj.Env(), "Invalid argument: sessionOptions.extra value must be a string in Node.js binding."); std::string entry = prefix + key; auto val = value.As().Utf8Value(); sessionOptions.AddConfigEntry(entry.c_str(), val.c_str()); } } } void ParseSessionOptions(const Napi::Object options, Ort::SessionOptions& sessionOptions) { // Execution provider if (options.Has("executionProviders")) { auto epsValue = options.Get("executionProviders"); ORT_NAPI_THROW_TYPEERROR_IF(!epsValue.IsArray(), options.Env(), "Invalid argument: sessionOptions.executionProviders must be an array."); ParseExecutionProviders(epsValue.As(), sessionOptions); } // Intra threads number if (options.Has("intraOpNumThreads")) { auto numValue = options.Get("intraOpNumThreads"); ORT_NAPI_THROW_TYPEERROR_IF(!numValue.IsNumber(), options.Env(), "Invalid argument: sessionOptions.intraOpNumThreads must be a number."); double num = numValue.As().DoubleValue(); ORT_NAPI_THROW_RANGEERROR_IF(std::floor(num) != num || num < 0 || num > 4294967295, options.Env(), "'intraOpNumThreads' is invalid: ", num); sessionOptions.SetIntraOpNumThreads(static_cast(num)); } // Inter threads number if (options.Has("interOpNumThreads")) { auto numValue = options.Get("interOpNumThreads"); ORT_NAPI_THROW_TYPEERROR_IF(!numValue.IsNumber(), options.Env(), "Invalid argument: sessionOptions.interOpNumThreads must be a number."); double num = numValue.As().DoubleValue(); ORT_NAPI_THROW_RANGEERROR_IF(std::floor(num) != num || num < 0 || num > 4294967295, options.Env(), "'interOpNumThreads' is invalid: ", num); sessionOptions.SetInterOpNumThreads(static_cast(num)); } // Optimization level if (options.Has("graphOptimizationLevel")) { auto optLevelValue = options.Get("graphOptimizationLevel"); ORT_NAPI_THROW_TYPEERROR_IF(!optLevelValue.IsString(), options.Env(), "Invalid argument: sessionOptions.graphOptimizationLevel must be a string."); auto optLevelString = optLevelValue.As().Utf8Value(); auto v = GRAPH_OPT_LEVEL_NAME_TO_ID_MAP.find(optLevelString); ORT_NAPI_THROW_TYPEERROR_IF(v == GRAPH_OPT_LEVEL_NAME_TO_ID_MAP.end(), options.Env(), "'graphOptimizationLevel' is not supported: ", optLevelString); sessionOptions.SetGraphOptimizationLevel(v->second); } // CPU memory arena if (options.Has("enableCpuMemArena")) { auto enableCpuMemArenaValue = options.Get("enableCpuMemArena"); ORT_NAPI_THROW_TYPEERROR_IF(!enableCpuMemArenaValue.IsBoolean(), options.Env(), "Invalid argument: sessionOptions.enableCpuMemArena must be a boolean value."); if (enableCpuMemArenaValue.As().Value()) { sessionOptions.EnableCpuMemArena(); } else { sessionOptions.DisableCpuMemArena(); } } // memory pattern if (options.Has("enableMemPattern")) { auto enableMemPatternValue = options.Get("enableMemPattern"); ORT_NAPI_THROW_TYPEERROR_IF(!enableMemPatternValue.IsBoolean(), options.Env(), "Invalid argument: sessionOptions.enableMemPattern must be a boolean value."); if (enableMemPatternValue.As().Value()) { sessionOptions.EnableMemPattern(); } else { sessionOptions.DisableMemPattern(); } } // optimizedModelFilePath if (options.Has("optimizedModelFilePath")) { auto optimizedModelFilePathValue = options.Get("optimizedModelFilePath"); ORT_NAPI_THROW_TYPEERROR_IF(!optimizedModelFilePathValue.IsString(), options.Env(), "Invalid argument: sessionOptions.optimizedModelFilePath must be a string."); #ifdef _WIN32 auto str = optimizedModelFilePathValue.As().Utf16Value(); std::filesystem::path optimizedModelFilePath{std::wstring{str.begin(), str.end()}}; #else std::filesystem::path optimizedModelFilePath{optimizedModelFilePathValue.As().Utf8Value()}; #endif sessionOptions.SetOptimizedModelFilePath(optimizedModelFilePath.c_str()); } // extra if (options.Has("extra")) { auto extraValue = options.Get("extra"); ORT_NAPI_THROW_TYPEERROR_IF(!extraValue.IsObject(), options.Env(), "Invalid argument: sessionOptions.extra must be an object."); IterateExtraOptions("", extraValue.As(), sessionOptions); } // execution mode if (options.Has("executionMode")) { auto executionModeValue = options.Get("executionMode"); ORT_NAPI_THROW_TYPEERROR_IF(!executionModeValue.IsString(), options.Env(), "Invalid argument: sessionOptions.executionMode must be a string."); auto executionModeString = executionModeValue.As().Utf8Value(); auto v = EXECUTION_MODE_NAME_TO_ID_MAP.find(executionModeString); ORT_NAPI_THROW_TYPEERROR_IF(v == EXECUTION_MODE_NAME_TO_ID_MAP.end(), options.Env(), "'executionMode' is not supported: ", executionModeString); sessionOptions.SetExecutionMode(v->second); } // log ID if (options.Has("logId")) { auto logIdValue = options.Get("logId"); ORT_NAPI_THROW_TYPEERROR_IF(!logIdValue.IsString(), options.Env(), "Invalid argument: sessionOptions.logId must be a string."); auto logIdString = logIdValue.As().Utf8Value(); sessionOptions.SetLogId(logIdString.c_str()); } // Log severity level if (options.Has("logSeverityLevel")) { auto logLevelValue = options.Get("logSeverityLevel"); ORT_NAPI_THROW_TYPEERROR_IF(!logLevelValue.IsNumber(), options.Env(), "Invalid argument: sessionOptions.logSeverityLevel must be a number."); double logLevelNumber = logLevelValue.As().DoubleValue(); ORT_NAPI_THROW_RANGEERROR_IF( std::floor(logLevelNumber) != logLevelNumber || logLevelNumber < 0 || logLevelNumber > 4, options.Env(), "Invalid argument: sessionOptions.logSeverityLevel must be one of the following: 0, 1, 2, 3, 4."); sessionOptions.SetLogSeverityLevel(static_cast(logLevelNumber)); } // Profiling if (options.Has("enableProfiling")) { auto enableProfilingValue = options.Get("enableProfiling"); ORT_NAPI_THROW_TYPEERROR_IF(!enableProfilingValue.IsBoolean(), options.Env(), "Invalid argument: sessionOptions.enableProfiling must be a boolean value."); if (enableProfilingValue.As().Value()) { ORT_NAPI_THROW_TYPEERROR_IF(!options.Has("profileFilePrefix"), options.Env(), "Invalid argument: sessionOptions.profileFilePrefix is required" " when sessionOptions.enableProfiling is set to true."); auto profileFilePrefixValue = options.Get("profileFilePrefix"); ORT_NAPI_THROW_TYPEERROR_IF(!profileFilePrefixValue.IsString(), options.Env(), "Invalid argument: sessionOptions.profileFilePrefix must be a string." " when sessionOptions.enableProfiling is set to true."); #ifdef _WIN32 auto str = profileFilePrefixValue.As().Utf16Value(); std::basic_string profileFilePrefix = std::wstring{str.begin(), str.end()}; #else std::basic_string profileFilePrefix = profileFilePrefixValue.As().Utf8Value(); #endif sessionOptions.EnableProfiling(profileFilePrefix.c_str()); } else { sessionOptions.DisableProfiling(); } } // external data if (options.Has("externalData")) { auto externalDataValue = options.Get("externalData"); ORT_NAPI_THROW_TYPEERROR_IF(!externalDataValue.IsArray(), options.Env(), "Invalid argument: sessionOptions.externalData must be an array."); auto externalData = externalDataValue.As(); std::vector> paths; std::vector buffs; std::vector sizes; for (const auto& kvp : externalData) { Napi::Value value = kvp.second; ORT_NAPI_THROW_TYPEERROR_IF(!value.IsObject(), options.Env(), "Invalid argument: sessionOptions.externalData value must be an object in Node.js binding."); Napi::Object obj = value.As(); ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("path") || !obj.Get("path").IsString(), options.Env(), "Invalid argument: sessionOptions.externalData value must have a 'path' property of type string in Node.js binding."); #ifdef _WIN32 auto path = obj.Get("path").As().Utf16Value(); paths.push_back(std::wstring{path.begin(), path.end()}); #else auto path = obj.Get("path").As().Utf8Value(); paths.push_back(path); #endif ORT_NAPI_THROW_TYPEERROR_IF(!obj.Has("data") || !obj.Get("data").IsBuffer() || !(obj.Get("data").IsTypedArray() && obj.Get("data").As().TypedArrayType() == napi_uint8_array), options.Env(), "Invalid argument: sessionOptions.externalData value must have an 'data' property of type buffer or typed array in Node.js binding."); auto data = obj.Get("data"); if (data.IsBuffer()) { buffs.push_back(data.As>().Data()); sizes.push_back(data.As>().Length()); } else { auto typedArray = data.As(); buffs.push_back(reinterpret_cast(typedArray.ArrayBuffer().Data()) + typedArray.ByteOffset()); sizes.push_back(typedArray.ByteLength()); } } sessionOptions.AddExternalInitializersFromFilesInMemory(paths, buffs, sizes); } } void ParsePreferredOutputLocations(const Napi::Object options, const std::vector& outputNames, std::vector& preferredOutputLocations) { if (options.Has("preferredOutputLocation")) { auto polValue = options.Get("preferredOutputLocation"); if (polValue.IsNull() || polValue.IsUndefined()) { return; } if (polValue.IsString()) { DataLocation location = ParseDataLocation(polValue.As().Utf8Value()); ORT_NAPI_THROW_TYPEERROR_IF(location == DATA_LOCATION_NONE, options.Env(), "Invalid argument: preferredOutputLocation must be an array or a valid string."); if (location == DATA_LOCATION_GPU_BUFFER || location == DATA_LOCATION_ML_TENSOR) { preferredOutputLocations.resize(outputNames.size(), location); } } else if (polValue.IsObject()) { preferredOutputLocations.resize(outputNames.size(), DATA_LOCATION_CPU); auto pol = polValue.As(); for (const auto& nameIter : pol.GetPropertyNames()) { Napi::Value nameVar = nameIter.second; std::string name = nameVar.As().Utf8Value(); // find the name in outputNames auto it = std::find(outputNames.begin(), outputNames.end(), name); ORT_NAPI_THROW_TYPEERROR_IF(it == outputNames.end(), options.Env(), "Invalid argument: \"", name, "\" is not a valid output name."); Napi::Value value = pol.Get(nameVar); DataLocation location = DATA_LOCATION_NONE; ORT_NAPI_THROW_TYPEERROR_IF(!value.IsString() || (location = ParseDataLocation(value.As().Utf8Value())) == DATA_LOCATION_NONE, options.Env(), "Invalid argument: preferredOutputLocation[\"", name, "\"] must be a valid string."); size_t index = it - outputNames.begin(); preferredOutputLocations[index] = location; } if (std::all_of(preferredOutputLocations.begin(), preferredOutputLocations.end(), [](int loc) { return loc == DATA_LOCATION_CPU; })) { preferredOutputLocations.clear(); } } else { ORT_NAPI_THROW_TYPEERROR(options.Env(), "Invalid argument: preferredOutputLocation must be an array or a valid string."); } } }