Add SetLanguageProjection C Api and use it in four projections (#5023)

* Add SetLanguageProjection C Api and use it in four projections

* static cast enum languageprojection to uint32_t

* resolve comments

* fix typo and line added unintentionally

* revert unecessary change

* reorder c# api

* add TensorAt and CreateAndRegisterAllocator in Csharp to keep the same order as C apis
This commit is contained in:
Xiang Zhang 2020-09-04 14:26:39 -07:00 committed by GitHub
parent 6dd4af3936
commit 0dad79b495
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 105 additions and 3 deletions

View file

@ -174,6 +174,9 @@ namespace Microsoft.ML.OnnxRuntime
public IntPtr GetBoundOutputValues;
public IntPtr ClearBoundInputs;
public IntPtr ClearBoundOutputs;
public IntPtr TensorAt;
public IntPtr CreateAndRegisterAllocator;
public IntPtr SetLanguageProjection;
}
internal static class NativeMethods
@ -235,7 +238,7 @@ namespace Microsoft.ML.OnnxRuntime
OrtSetSessionGraphOptimizationLevel = (DOrtSetSessionGraphOptimizationLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionGraphOptimizationLevel, typeof(DOrtSetSessionGraphOptimizationLevel));
OrtRegisterCustomOpsLibrary = (DOrtRegisterCustomOpsLibrary)Marshal.GetDelegateForFunctionPointer(api_.RegisterCustomOpsLibrary, typeof(DOrtRegisterCustomOpsLibrary));
OrtAddSessionConfigEntry = (DOrtAddSessionConfigEntry)Marshal.GetDelegateForFunctionPointer(api_.AddSessionConfigEntry, typeof(DOrtAddSessionConfigEntry));
OrtCreateRunOptions = (DOrtCreateRunOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateRunOptions, typeof(DOrtCreateRunOptions));
OrtReleaseRunOptions = (DOrtReleaseRunOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseRunOptions, typeof(DOrtReleaseRunOptions));
OrtRunOptionsSetRunLogVerbosityLevel = (DOrtRunOptionsSetRunLogVerbosityLevel)Marshal.GetDelegateForFunctionPointer(api_.RunOptionsSetRunLogVerbosityLevel, typeof(DOrtRunOptionsSetRunLogVerbosityLevel));
@ -272,6 +275,9 @@ namespace Microsoft.ML.OnnxRuntime
OrtGetBoundOutputValues = (DOrtGetBoundOutputValues)Marshal.GetDelegateForFunctionPointer(api_.GetBoundOutputValues, typeof(DOrtGetBoundOutputValues));
OrtClearBoundInputs = (DOrtClearBoundInputs)Marshal.GetDelegateForFunctionPointer(api_.ClearBoundInputs, typeof(DOrtClearBoundInputs));
OrtClearBoundOutputs = (DOrtClearBoundOutputs)Marshal.GetDelegateForFunctionPointer(api_.ClearBoundOutputs, typeof(DOrtClearBoundOutputs));
OrtTensorAt = (DOrtTensorAt)Marshal.GetDelegateForFunctionPointer(api_.TensorAt, typeof(DOrtTensorAt));
OrtCreateAndRegisterAllocator = (DOrtCreateAndRegisterAllocator)Marshal.GetDelegateForFunctionPointer(api_.CreateAndRegisterAllocator, typeof(DOrtCreateAndRegisterAllocator));
OrtSetLanguageProjection = (DOrtSetLanguageProjection)Marshal.GetDelegateForFunctionPointer(api_.SetLanguageProjection, typeof(DOrtSetLanguageProjection));
OrtGetValue = (DOrtGetValue)Marshal.GetDelegateForFunctionPointer(api_.GetValue, typeof(DOrtGetValue));
OrtGetValueType = (DOrtGetValueType)Marshal.GetDelegateForFunctionPointer(api_.GetValueType, typeof(DOrtGetValueType));
@ -776,6 +782,33 @@ namespace Microsoft.ML.OnnxRuntime
public delegate void DOrtClearBoundOutputs(IntPtr /*(OrtIoBinding)*/ io_binding);
public static DOrtClearBoundOutputs OrtClearBoundOutputs;
/// <summary>
/// Provides element-level access into a tensor.
/// </summary>
/// <param name="location_values">a pointer to an array of index values that specify an element's location in the tensor data blob</param>
/// <param name="location_values_count">length of location_values</param>
/// <param name="out">a pointer to the element specified by location_values</param>
public delegate void DOrtTensorAt(IntPtr /*(OrtIoBinding)*/ io_binding);
public static DOrtTensorAt OrtTensorAt;
/// <summary>
/// Creates an allocator instance and registers it with the env to enable
///sharing between multiple sessions that use the same env instance.
///Lifetime of the created allocator will be valid for the duration of the environment.
///Returns an error if an allocator with the same OrtMemoryInfo is already registered.
/// </summary>
/// <param name="mem_info">must be non-null</param>
/// <param name="arena_cfg">if nullptr defaults will be used</param>
public delegate void DOrtCreateAndRegisterAllocator(IntPtr /*(OrtIoBinding)*/ io_binding);
public static DOrtCreateAndRegisterAllocator OrtCreateAndRegisterAllocator;
/// <summary>
/// Set the language projection for collecting telemetry data when Env is created
/// </summary>
/// <param name="projection">the source projected language</param>
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetLanguageProjection(IntPtr /* (OrtEnv*) */ environment, OrtLanguageProjection projection);
public static DOrtSetLanguageProjection OrtSetLanguageProjection;
#endregion IoBinding API
#region ModelMetadata API

View file

@ -24,6 +24,19 @@ namespace Microsoft.ML.OnnxRuntime
Fatal = 4
}
/// <summary>
/// Language projection property for telemetry event for tracking the source usage of ONNXRUNTIME
/// </summary>
public enum OrtLanguageProjection
{
ORT_PROJECTION_C = 0,
ORT_PROJECTION_CPLUSPLUS = 1 ,
ORT_PROJECTION_CSHARP = 2,
ORT_PROJECTION_PYTHON = 3,
ORT_PROJECTION_JAVA = 4,
ORT_PROJECTION_WINML = 5,
}
/// <summary>
/// This class intializes the process-global ONNX runtime
/// C# API users do not need to access this, thus kept as internal
@ -52,6 +65,15 @@ namespace Microsoft.ML.OnnxRuntime
:base(IntPtr.Zero, true)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateEnv(LogLevel.Warning, @"CSharpOnnxRuntime", out handle));
try
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetLanguageProjection(handle, OrtLanguageProjection.ORT_PROJECTION_CSHARP));
}
catch (OnnxRuntimeException e)
{
ReleaseHandle();
throw e;
}
}
protected override bool ReleaseHandle()

View file

@ -225,6 +225,16 @@ typedef enum ExecutionMode {
ORT_PARALLEL = 1,
} ExecutionMode;
// Set the language projection, default is C, which means it will classify the language not in the list to C also.
typedef enum OrtLanguageProjection {
ORT_PROJECTION_C = 0, // default
ORT_PROJECTION_CPLUSPLUS = 1,
ORT_PROJECTION_CSHARP = 2,
ORT_PROJECTION_PYTHON = 3,
ORT_PROJECTION_JAVA = 4,
ORT_PROJECTION_WINML = 5,
} OrtLanguageProjection;
struct OrtKernelInfo;
typedef struct OrtKernelInfo OrtKernelInfo;
struct OrtKernelContext;
@ -1012,6 +1022,13 @@ struct OrtApi {
*/
ORT_API2_STATUS(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info,
_In_ const OrtArenaCfg* arena_cfg);
/**
* Set the language projection for collecting telemetry data when Env is created
* \param projection the source projected language.
*/
ORT_API2_STATUS(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection);
};
/*

View file

@ -288,14 +288,17 @@ inline void IoBinding::ClearBoundOutputs() {
inline Env::Env(OrtLoggingLevel default_warning_level, _In_ const char* logid) {
ThrowOnError(GetApi().CreateEnv(default_warning_level, logid, &p_));
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
}
inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param) {
ThrowOnError(GetApi().CreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_));
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
}
inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel default_warning_level, _In_ const char* logid) {
ThrowOnError(GetApi().CreateEnvWithGlobalThreadPools(default_warning_level, logid, tp_options, &p_));
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
}
inline Env& Env::EnableTelemetryEvents() {

View file

@ -20,6 +20,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtEnvironment_createHandle(JNIEnv *
const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, &copy);
checkOrtStatus(jniEnv,api,api->CreateEnv(convertLoggingLevel(loggingLevel), cName, &env));
(*jniEnv)->ReleaseStringUTFChars(jniEnv,name,cName);
checkOrtStatus(jniEnv, api, api->SetLanguageProjection(env, ORT_PROJECTION_JAVA));
return (jlong) env;
}

View file

@ -19,6 +19,10 @@ void Telemetry::EnableTelemetryEvents() const {
void Telemetry::DisableTelemetryEvents() const {
}
void Telemetry::SetLanguageProjection(uint32_t projection) const {
ORT_UNUSED_PARAMETER(projection);
}
void Telemetry::LogProcessInfo() const {
}

View file

@ -36,6 +36,7 @@ class Telemetry {
virtual void EnableTelemetryEvents() const;
virtual void DisableTelemetryEvents() const;
virtual void SetLanguageProjection(uint32_t projection) const;
virtual void LogProcessInfo() const;

View file

@ -57,6 +57,7 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim
OrtMutex WindowsTelemetry::mutex_;
uint32_t WindowsTelemetry::global_register_count_ = 0;
bool WindowsTelemetry::enabled_ = true;
uint32_t WindowsTelemetry::projection_ = 0;
WindowsTelemetry::WindowsTelemetry() {
@ -88,6 +89,10 @@ void WindowsTelemetry::DisableTelemetryEvents() const {
enabled_ = false;
}
void WindowsTelemetry::SetLanguageProjection(uint32_t projection) const {
projection_ = projection;
}
void WindowsTelemetry::LogProcessInfo() const {
if (global_register_count_ == 0 || enabled_ == false)
return;
@ -196,6 +201,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingInt64(ir_version, "irVersion"),
TraceLoggingUInt32(projection_, "OrtProgrammingProjection"),
TraceLoggingString(model_producer_name.c_str(), "modelProducerName"),
TraceLoggingString(model_producer_version.c_str(), "modelProducerVersion"),
TraceLoggingString(model_domain.c_str(), "modelDomain"),

View file

@ -20,6 +20,7 @@ class WindowsTelemetry : public Telemetry {
void EnableTelemetryEvents() const override;
void DisableTelemetryEvents() const override;
void SetLanguageProjection(uint32_t projection) const override;
void LogProcessInfo() const override;
@ -48,6 +49,7 @@ class WindowsTelemetry : public Telemetry {
static OrtMutex mutex_;
static uint32_t global_register_count_;
static bool enabled_;
static uint32_t projection_;
};
} // namespace onnxruntime

View file

@ -1752,6 +1752,16 @@ ORT_API_STATUS_IMPL(OrtApis::TensorAt, _Inout_ OrtValue* value, size_t* location
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection) {
API_IMPL_BEGIN
ORT_UNUSED_PARAMETER(ort_env);
// note telemetry is controlled via the platform Env object, not the OrtEnv object instance
const Env& env = Env::Default();
env.GetTelemetryProvider().SetLanguageProjection(static_cast<uint32_t>(projection));
return nullptr;
API_IMPL_END
}
// End support for non-tensor types
static constexpr OrtApiBase ort_api_base = {
@ -1972,6 +1982,7 @@ static constexpr OrtApi ort_api_1_to_5 = {
&OrtApis::ClearBoundOutputs,
&OrtApis::TensorAt,
&OrtApis::CreateAndRegisterAllocator,
&OrtApis::SetLanguageProjection,
};
// Assert to do a limited check to ensure Version 1 of OrtApi never changes (will detect an addition or deletion but not if they cancel out each other)

View file

@ -231,4 +231,6 @@ ORT_API_STATUS_IMPL(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options,
ORT_API_STATUS_IMPL(TensorAt, _Inout_ OrtValue* value, size_t* location_values, size_t location_values_count, _Outptr_ void** out);
ORT_API_STATUS_IMPL(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg);
ORT_API_STATUS_IMPL(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection);
} // namespace OrtApis

View file

@ -1462,7 +1462,7 @@ void InitializeEnv() {
// import_array1() forces a void return value.
import_array1();
})();
Env::Default().GetTelemetryProvider().SetLanguageProjection(OrtLanguageProjection::ORT_PROJECTION_PYTHON);
OrtPybindThrowIfError(Environment::Create(onnxruntime::make_unique<LoggingManager>(
std::unique_ptr<ISink>{new CLogSink{}},
Severity::kWARNING, false, LoggingManager::InstanceType::Default,

View file

@ -190,8 +190,8 @@ OnnxruntimeEnvironment::OnnxruntimeEnvironment(const OrtApi* ort_api) : ort_env_
OrtEnv* ort_env = nullptr;
THROW_IF_NOT_OK_MSG(ort_api->CreateEnv(OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, "Default", &ort_env),
ort_api);
THROW_IF_NOT_OK_MSG(ort_api->SetLanguageProjection(ort_env, OrtLanguageProjection::ORT_PROJECTION_WINML), ort_api);
ort_env_ = UniqueOrtEnv(ort_env, ort_api->ReleaseEnv);
// Configure the environment with the winml logger
auto winml_adapter_api = GetVersionedWinmlAdapterApi(ort_api);
THROW_IF_NOT_OK_MSG(winml_adapter_api->EnvConfigureCustomLoggerAndProfiler(ort_env_.get(),