mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
Add SetLanguageProjection C Api and use it in four projections (#5023)
* Add SetLanguageProjection C Api and use it in four projections * static cast enum languageprojection to uint32_t * resolve comments * fix typo and line added unintentionally * revert unecessary change * reorder c# api * add TensorAt and CreateAndRegisterAllocator in Csharp to keep the same order as C apis
This commit is contained in:
parent
6dd4af3936
commit
0dad79b495
13 changed files with 105 additions and 3 deletions
|
|
@ -174,6 +174,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public IntPtr GetBoundOutputValues;
|
||||
public IntPtr ClearBoundInputs;
|
||||
public IntPtr ClearBoundOutputs;
|
||||
public IntPtr TensorAt;
|
||||
public IntPtr CreateAndRegisterAllocator;
|
||||
public IntPtr SetLanguageProjection;
|
||||
}
|
||||
|
||||
internal static class NativeMethods
|
||||
|
|
@ -235,7 +238,7 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
OrtSetSessionGraphOptimizationLevel = (DOrtSetSessionGraphOptimizationLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionGraphOptimizationLevel, typeof(DOrtSetSessionGraphOptimizationLevel));
|
||||
OrtRegisterCustomOpsLibrary = (DOrtRegisterCustomOpsLibrary)Marshal.GetDelegateForFunctionPointer(api_.RegisterCustomOpsLibrary, typeof(DOrtRegisterCustomOpsLibrary));
|
||||
OrtAddSessionConfigEntry = (DOrtAddSessionConfigEntry)Marshal.GetDelegateForFunctionPointer(api_.AddSessionConfigEntry, typeof(DOrtAddSessionConfigEntry));
|
||||
|
||||
|
||||
OrtCreateRunOptions = (DOrtCreateRunOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateRunOptions, typeof(DOrtCreateRunOptions));
|
||||
OrtReleaseRunOptions = (DOrtReleaseRunOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseRunOptions, typeof(DOrtReleaseRunOptions));
|
||||
OrtRunOptionsSetRunLogVerbosityLevel = (DOrtRunOptionsSetRunLogVerbosityLevel)Marshal.GetDelegateForFunctionPointer(api_.RunOptionsSetRunLogVerbosityLevel, typeof(DOrtRunOptionsSetRunLogVerbosityLevel));
|
||||
|
|
@ -272,6 +275,9 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
OrtGetBoundOutputValues = (DOrtGetBoundOutputValues)Marshal.GetDelegateForFunctionPointer(api_.GetBoundOutputValues, typeof(DOrtGetBoundOutputValues));
|
||||
OrtClearBoundInputs = (DOrtClearBoundInputs)Marshal.GetDelegateForFunctionPointer(api_.ClearBoundInputs, typeof(DOrtClearBoundInputs));
|
||||
OrtClearBoundOutputs = (DOrtClearBoundOutputs)Marshal.GetDelegateForFunctionPointer(api_.ClearBoundOutputs, typeof(DOrtClearBoundOutputs));
|
||||
OrtTensorAt = (DOrtTensorAt)Marshal.GetDelegateForFunctionPointer(api_.TensorAt, typeof(DOrtTensorAt));
|
||||
OrtCreateAndRegisterAllocator = (DOrtCreateAndRegisterAllocator)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocator, typeof(DOrtCreateAndRegisterAllocator));
|
||||
OrtSetLanguageProjection = (DOrtSetLanguageProjection)Marshal.GetDelegateForFunctionPointer(api_.SetLanguageProjection, typeof(DOrtSetLanguageProjection));
|
||||
|
||||
OrtGetValue = (DOrtGetValue)Marshal.GetDelegateForFunctionPointer(api_.GetValue, typeof(DOrtGetValue));
|
||||
OrtGetValueType = (DOrtGetValueType)Marshal.GetDelegateForFunctionPointer(api_.GetValueType, typeof(DOrtGetValueType));
|
||||
|
|
@ -776,6 +782,33 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
public delegate void DOrtClearBoundOutputs(IntPtr /*(OrtIoBinding)*/ io_binding);
|
||||
public static DOrtClearBoundOutputs OrtClearBoundOutputs;
|
||||
|
||||
/// <summary>
|
||||
/// Provides element-level access into a tensor.
|
||||
/// </summary>
|
||||
/// <param name="location_values">a pointer to an array of index values that specify an element's location in the tensor data blob</param>
|
||||
/// <param name="location_values_count">length of location_values</param>
|
||||
/// <param name="out">a pointer to the element specified by location_values</param>
|
||||
public delegate void DOrtTensorAt(IntPtr /*(OrtIoBinding)*/ io_binding);
|
||||
public static DOrtTensorAt OrtTensorAt;
|
||||
|
||||
/// <summary>
|
||||
/// Creates an allocator instance and registers it with the env to enable
|
||||
///sharing between multiple sessions that use the same env instance.
|
||||
///Lifetime of the created allocator will be valid for the duration of the environment.
|
||||
///Returns an error if an allocator with the same OrtMemoryInfo is already registered.
|
||||
/// </summary>
|
||||
/// <param name="mem_info">must be non-null</param>
|
||||
/// <param name="arena_cfg">if nullptr defaults will be used</param>
|
||||
public delegate void DOrtCreateAndRegisterAllocator(IntPtr /*(OrtIoBinding)*/ io_binding);
|
||||
public static DOrtCreateAndRegisterAllocator OrtCreateAndRegisterAllocator;
|
||||
|
||||
/// <summary>
|
||||
/// Set the language projection for collecting telemetry data when Env is created
|
||||
/// </summary>
|
||||
/// <param name="projection">the source projected language</param>
|
||||
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetLanguageProjection(IntPtr /* (OrtEnv*) */ environment, OrtLanguageProjection projection);
|
||||
public static DOrtSetLanguageProjection OrtSetLanguageProjection;
|
||||
|
||||
#endregion IoBinding API
|
||||
|
||||
#region ModelMetadata API
|
||||
|
|
|
|||
|
|
@ -24,6 +24,19 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
Fatal = 4
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Language projection property for telemetry event for tracking the source usage of ONNXRUNTIME
|
||||
/// </summary>
|
||||
public enum OrtLanguageProjection
|
||||
{
|
||||
ORT_PROJECTION_C = 0,
|
||||
ORT_PROJECTION_CPLUSPLUS = 1 ,
|
||||
ORT_PROJECTION_CSHARP = 2,
|
||||
ORT_PROJECTION_PYTHON = 3,
|
||||
ORT_PROJECTION_JAVA = 4,
|
||||
ORT_PROJECTION_WINML = 5,
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// This class intializes the process-global ONNX runtime
|
||||
/// C# API users do not need to access this, thus kept as internal
|
||||
|
|
@ -52,6 +65,15 @@ namespace Microsoft.ML.OnnxRuntime
|
|||
:base(IntPtr.Zero, true)
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateEnv(LogLevel.Warning, @"CSharpOnnxRuntime", out handle));
|
||||
try
|
||||
{
|
||||
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetLanguageProjection(handle, OrtLanguageProjection.ORT_PROJECTION_CSHARP));
|
||||
}
|
||||
catch (OnnxRuntimeException e)
|
||||
{
|
||||
ReleaseHandle();
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
protected override bool ReleaseHandle()
|
||||
|
|
|
|||
|
|
@ -225,6 +225,16 @@ typedef enum ExecutionMode {
|
|||
ORT_PARALLEL = 1,
|
||||
} ExecutionMode;
|
||||
|
||||
// Set the language projection, default is C, which means it will classify the language not in the list to C also.
|
||||
typedef enum OrtLanguageProjection {
|
||||
ORT_PROJECTION_C = 0, // default
|
||||
ORT_PROJECTION_CPLUSPLUS = 1,
|
||||
ORT_PROJECTION_CSHARP = 2,
|
||||
ORT_PROJECTION_PYTHON = 3,
|
||||
ORT_PROJECTION_JAVA = 4,
|
||||
ORT_PROJECTION_WINML = 5,
|
||||
} OrtLanguageProjection;
|
||||
|
||||
struct OrtKernelInfo;
|
||||
typedef struct OrtKernelInfo OrtKernelInfo;
|
||||
struct OrtKernelContext;
|
||||
|
|
@ -1012,6 +1022,13 @@ struct OrtApi {
|
|||
*/
|
||||
ORT_API2_STATUS(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info,
|
||||
_In_ const OrtArenaCfg* arena_cfg);
|
||||
|
||||
/**
|
||||
* Set the language projection for collecting telemetry data when Env is created
|
||||
* \param projection the source projected language.
|
||||
*/
|
||||
ORT_API2_STATUS(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection);
|
||||
|
||||
};
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -288,14 +288,17 @@ inline void IoBinding::ClearBoundOutputs() {
|
|||
|
||||
inline Env::Env(OrtLoggingLevel default_warning_level, _In_ const char* logid) {
|
||||
ThrowOnError(GetApi().CreateEnv(default_warning_level, logid, &p_));
|
||||
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
||||
}
|
||||
|
||||
inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
|
||||
ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_));
|
||||
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
||||
}
|
||||
|
||||
inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel default_warning_level, _In_ const char* logid) {
|
||||
ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(default_warning_level, logid, tp_options, &p_));
|
||||
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
|
||||
}
|
||||
|
||||
inline Env& Env::EnableTelemetryEvents() {
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtEnvironment_createHandle(JNIEnv *
|
|||
const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, ©);
|
||||
checkOrtStatus(jniEnv,api,api->CreateEnv(convertLoggingLevel(loggingLevel), cName, &env));
|
||||
(*jniEnv)->ReleaseStringUTFChars(jniEnv,name,cName);
|
||||
checkOrtStatus(jniEnv, api, api->SetLanguageProjection(env, ORT_PROJECTION_JAVA));
|
||||
return (jlong) env;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,10 @@ void Telemetry::EnableTelemetryEvents() const {
|
|||
void Telemetry::DisableTelemetryEvents() const {
|
||||
}
|
||||
|
||||
void Telemetry::SetLanguageProjection(uint32_t projection) const {
|
||||
ORT_UNUSED_PARAMETER(projection);
|
||||
}
|
||||
|
||||
void Telemetry::LogProcessInfo() const {
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ class Telemetry {
|
|||
|
||||
virtual void EnableTelemetryEvents() const;
|
||||
virtual void DisableTelemetryEvents() const;
|
||||
virtual void SetLanguageProjection(uint32_t projection) const;
|
||||
|
||||
virtual void LogProcessInfo() const;
|
||||
|
||||
|
|
|
|||
|
|
@ -57,6 +57,7 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim
|
|||
OrtMutex WindowsTelemetry::mutex_;
|
||||
uint32_t WindowsTelemetry::global_register_count_ = 0;
|
||||
bool WindowsTelemetry::enabled_ = true;
|
||||
uint32_t WindowsTelemetry::projection_ = 0;
|
||||
|
||||
|
||||
WindowsTelemetry::WindowsTelemetry() {
|
||||
|
|
@ -88,6 +89,10 @@ void WindowsTelemetry::DisableTelemetryEvents() const {
|
|||
enabled_ = false;
|
||||
}
|
||||
|
||||
void WindowsTelemetry::SetLanguageProjection(uint32_t projection) const {
|
||||
projection_ = projection;
|
||||
}
|
||||
|
||||
void WindowsTelemetry::LogProcessInfo() const {
|
||||
if (global_register_count_ == 0 || enabled_ == false)
|
||||
return;
|
||||
|
|
@ -196,6 +201,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
|
|||
TraceLoggingUInt8(0, "schemaVersion"),
|
||||
TraceLoggingUInt32(session_id, "sessionId"),
|
||||
TraceLoggingInt64(ir_version, "irVersion"),
|
||||
TraceLoggingUInt32(projection_, "OrtProgrammingProjection"),
|
||||
TraceLoggingString(model_producer_name.c_str(), "modelProducerName"),
|
||||
TraceLoggingString(model_producer_version.c_str(), "modelProducerVersion"),
|
||||
TraceLoggingString(model_domain.c_str(), "modelDomain"),
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ class WindowsTelemetry : public Telemetry {
|
|||
|
||||
void EnableTelemetryEvents() const override;
|
||||
void DisableTelemetryEvents() const override;
|
||||
void SetLanguageProjection(uint32_t projection) const override;
|
||||
|
||||
void LogProcessInfo() const override;
|
||||
|
||||
|
|
@ -48,6 +49,7 @@ class WindowsTelemetry : public Telemetry {
|
|||
static OrtMutex mutex_;
|
||||
static uint32_t global_register_count_;
|
||||
static bool enabled_;
|
||||
static uint32_t projection_;
|
||||
};
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1752,6 +1752,16 @@ ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, size_t* location
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(OrtApis::SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection) {
|
||||
API_IMPL_BEGIN
|
||||
ORT_UNUSED_PARAMETER(ort_env);
|
||||
// note telemetry is controlled via the platform Env object, not the OrtEnv object instance
|
||||
const Env& env = Env::Default();
|
||||
env.GetTelemetryProvider().SetLanguageProjection(static_cast<uint32_t>(projection));
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
// End support for non-tensor types
|
||||
|
||||
static constexpr OrtApiBase ort_api_base = {
|
||||
|
|
@ -1972,6 +1982,7 @@ static constexpr OrtApi ort_api_1_to_5 = {
|
|||
&OrtApis::ClearBoundOutputs,
|
||||
&OrtApis::TensorAt,
|
||||
&OrtApis::CreateAndRegisterAllocator,
|
||||
&OrtApis::SetLanguageProjection,
|
||||
};
|
||||
|
||||
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)
|
||||
|
|
|
|||
|
|
@ -231,4 +231,6 @@ ORT_API_STATUS_IMPL(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options,
|
|||
ORT_API_STATUS_IMPL(TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, _Outptr_ void** out);
|
||||
|
||||
ORT_API_STATUS_IMPL(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg);
|
||||
|
||||
ORT_API_STATUS_IMPL(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection);
|
||||
} // namespace OrtApis
|
||||
|
|
|
|||
|
|
@ -1462,7 +1462,7 @@ void InitializeEnv() {
|
|||
// import_array1() forces a void return value.
|
||||
import_array1();
|
||||
})();
|
||||
|
||||
Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON);
|
||||
OrtPybindThrowIfError(Environment::Create(onnxruntime::make_unique<LoggingManager>(
|
||||
std::unique_ptr<ISink>{new CLogSink{}},
|
||||
Severity::kWARNING, false, LoggingManager::InstanceType::Default,
|
||||
|
|
|
|||
|
|
@ -190,8 +190,8 @@ OnnxruntimeEnvironment::OnnxruntimeEnvironment(const OrtApi* ort_api) : ort_env_
|
|||
OrtEnv* ort_env = nullptr;
|
||||
THROW_IF_NOT_OK_MSG(ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env),
|
||||
ort_api);
|
||||
THROW_IF_NOT_OK_MSG(ort_api->SetLanguageProjection(ort_env, OrtLanguageProjection::ORT_PROJECTION_WINML), ort_api);
|
||||
ort_env_ = UniqueOrtEnv(ort_env, ort_api->ReleaseEnv);
|
||||
|
||||
// Configure the environment with the winml logger
|
||||
auto winml_adapter_api = GetVersionedWinmlAdapterApi(ort_api);
|
||||
THROW_IF_NOT_OK_MSG(winml_adapter_api->EnvConfigureCustomLoggerAndProfiler(ort_env_.get(),
|
||||
|
|
|
|||
Loading…
Reference in a new issue