Adding platform telemetry (#2109)

This commit is contained in:
Paul McDaniel 2019-10-19 18:25:57 -07:00 committed by Changming Sun
parent b1096424f0
commit d1159b7008
22 changed files with 551 additions and 35 deletions

View file

@ -30,9 +30,12 @@
* [Technical Design Details](#technical-design-details)
* [Extensibility Options](#extensibility-options)
**[Data/Telemetry](#Data/Telemetry)**
**[Contributions and Feedback](#contribute)**
**[License](#license)**
***
# Key Features
## Run any ONNX model
@ -160,6 +163,10 @@ To tune performance for ONNX models, the [ONNX Go Live tool "OLive"](https://git
transform](include/onnxruntime/core/optimizer/graph_transformer.h)
* [Add a new rewrite rule](include/onnxruntime/core/optimizer/rewrite_rule.h)
***
# Data/Telemetry
This project may collect usage data and send it to Microsoft to help improve our products and services. See the [privacy statement](docs/Privacy.md) for more details.
***
# Contribute
We welcome contributions! Please see the [contribution guidelines](CONTRIBUTING.md).
@ -171,6 +178,7 @@ For any feedback or to report a bug, please file a [GitHub Issue](https://github
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
***
# License
[MIT License](LICENSE)

View file

@ -16,6 +16,8 @@ set(onnxruntime_common_src_patterns
"${ONNXRUNTIME_ROOT}/core/platform/env.cc"
"${ONNXRUNTIME_ROOT}/core/platform/env_time.h"
"${ONNXRUNTIME_ROOT}/core/platform/env_time.cc"
"${ONNXRUNTIME_ROOT}/core/platform/telemetry.h"
"${ONNXRUNTIME_ROOT}/core/platform/telemetry.cc"
)
if(WIN32)

View file

@ -22,6 +22,8 @@ namespace Microsoft.ML.OnnxRuntime
public IntPtr GetErrorMessage;
public IntPtr CreateEnv;
public IntPtr CreateEnvWithCustomLogger;
public IntPtr EnableTelemetryEvents;
public IntPtr DisableTelemetryEvents;
public IntPtr CreateSession;
public IntPtr CreateSessionFromArray;
public IntPtr Run;

View file

@ -29,3 +29,6 @@
The example below shows a sample run using the SqueezeNet model from ONNX model zoo, including dynamically reading model inputs, outputs, shape and type information, as well as running a sample vector and fetching the resulting class probabilities for inspection.
* [../csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp](../csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests.Capi/C_Api_Sample.cpp)
## Telemetry
To turn on/off telemetry collection on official Windows builds, please use Enable/DisableTelemetryEvents() in the C API. See the [Privacy](./Privacy.md) page for more information on telemetry collection and Microsoft's privacy policy.

18
docs/Privacy.md Normal file
View file

@ -0,0 +1,18 @@
# Privacy
## Data Collection
The software may collect information about you and your use of the software and send it to Microsoft. Microsoft may use this information to provide services and improve our products and services. You may turn off the telemetry as described in the repository. There are also some features in the software that may enable you and Microsoft to collect data from users of your applications. If you use these features, you must comply with applicable law, including providing appropriate notices to users of your applications together with a copy of Microsoft's privacy statement. Our privacy statement is located at https://go.microsoft.com/fwlink/?LinkID=824704. You can learn more about data collection and use in the help documentation and our privacy statement. Your use of the software operates as your consent to these practices.
***
### Private Builds
No data collection is performed when using your private builds.
### Official Builds
Currently telemetry is only implemented for Windows builds, but may be expanded in the future to cover other platforms. Telemetry is turned OFF by default while this feature is in BETA. When the feature moves from BETA to RELEASE, developers should expect telemetry to be ON by default when using the Official Builds. This is implemented via 'Platform Telemetry' per vendor platform providers (see telemetry.h).
#### Technical Details
The Windows provider uses the [TraceLogging](https://docs.microsoft.com/en-us/windows/win32/tracelogging/trace-logging-about) API for its implementation.
For API usage details to turn this on/off, please check the API pages:
* [C API](./C_API.md#telemetry)

View file

@ -74,6 +74,9 @@ using common::Status;
static_cast<void>(fn)
std::vector<std::string> GetStackTrace();
// these is a helper function that gets defined by platform/Telemetry
void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
const char* function, uint32_t line);
// __PRETTY_FUNCTION__ isn't a macro on gcc, so use a check for _MSC_VER
// so we only define it as one for MSVC
@ -137,16 +140,25 @@ std::vector<std::string> GetStackTrace();
ORT_DISALLOW_COPY_AND_ASSIGNMENT(TypeName); \
ORT_DISALLOW_MOVE(TypeName)
#define ORT_RETURN_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if ((!_status.IsOK())) return _status; \
#define ORT_RETURN_IF_ERROR_SESSIONID(expr, session_id) \
do { \
auto _status = (expr); \
if ((!_status.IsOK())) { \
::onnxruntime::LogRuntimeError(session_id, _status, __FILE__, __FUNCTION__, __LINE__); \
return _status; \
} \
} while (0)
#define ORT_RETURN_IF_ERROR_SESSIONID_(expr) ORT_RETURN_IF_ERROR_SESSIONID(expr, session_id_)
#define ORT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR_SESSIONID(expr, 0)
#define ORT_THROW_IF_ERROR(expr) \
do { \
auto _status = (expr); \
if ((!_status.IsOK())) ORT_THROW(_status); \
if ((!_status.IsOK())) { \
::onnxruntime::LogRuntimeError(0, _status, __FILE__, __FUNCTION__, __LINE__); \
ORT_THROW(_status); \
} \
} while (0)
// use this macro when cannot early return

View file

@ -0,0 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#define ONNXRUNTIME_VERSION_STRING "1.0"

View file

@ -249,6 +249,10 @@ struct OrtApi {
_In_ const char* logid,
_Outptr_ OrtEnv** out)NO_EXCEPTION;
// Platform telemetry events are on by default since they are lightweight. You can manually turn them off.
OrtStatus*(ORT_API_CALL* EnableTelemetryEvents)(_In_ const OrtEnv* env)NO_EXCEPTION;
OrtStatus*(ORT_API_CALL* DisableTelemetryEvents)(_In_ const OrtEnv* env)NO_EXCEPTION;
// TODO: document the path separator convention? '/' vs '\'
// TODO: should specify the access characteristics of model_path. Is this read only during the
// execution of OrtCreateSession, or does the OrtSession retain a handle to the file/directory

View file

@ -107,6 +107,9 @@ struct Env : Base<OrtEnv> {
Env(OrtLoggingLevel default_logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
Env& EnableTelemetryEvents();
Env& DisableTelemetryEvents();
static const OrtApi* s_api;
};

View file

@ -93,6 +93,16 @@ inline Env::Env(OrtLoggingLevel default_warning_level, const char* logid, OrtLog
ThrowOnError(g_api->CreateEnvWithCustomLogger(logging_function, logger_param, default_warning_level, logid, &p_));
}
inline Env& Env::EnableTelemetryEvents() {
ThrowOnError(g_api->EnableTelemetryEvents(p_));
return *this;
}
inline Env& Env::DisableTelemetryEvents() {
ThrowOnError(g_api->DisableTelemetryEvents(p_));
return *this;
}
inline CustomOpDomain::CustomOpDomain(const char* domain) {
ThrowOnError(g_api->CreateCustomOpDomain(domain, &p_));
}

View file

@ -26,6 +26,7 @@ limitations under the License.
#include "core/common/common.h"
#include "core/framework/callback.h"
#include "core/platform/env_time.h"
#include "core/platform/telemetry.h"
#ifndef _WIN32
#include <sys/types.h>
@ -133,6 +134,9 @@ class Env {
// returns the name that LoadDynamicLibrary() can use
virtual std::string FormatLibraryFileName(const std::string& name, const std::string& version) const = 0;
// \brief returns a provider that will handle telemetry on the current platform
virtual const Telemetry& GetTelemetryProvider() const = 0;
protected:
Env();

View file

@ -266,8 +266,14 @@ class PosixEnv : public Env {
return filename;
}
// \brief returns a provider that will handle telemetry on the current platform
const Telemetry& GetTelemetryProvider() const override {
return telemetry_provider_;
}
private:
PosixEnv() = default;
Telemetry telemetry_provider_;
};
} // namespace

View file

@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/platform/telemetry.h"
#include "core/platform/env.h"
namespace onnxruntime {
void LogRuntimeError(uint32_t sessionId, const common::Status& status, const char* file,
const char* function, uint32_t line)
{
const Env& env = Env::Default();
env.GetTelemetryProvider().LogRuntimeError(sessionId, status, file, function, line);
}
void Telemetry::EnableTelemetryEvents() const {
}
void Telemetry::DisableTelemetryEvents() const {
}
void Telemetry::LogProcessInfo() const {
}
void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
const std::string& model_producer_version, const std::string& model_domain,
const std::unordered_map<std::string, int>& domain_to_version_map,
const std::string& model_graph_name,
const std::unordered_map<std::string, std::string>& model_metadata,
const std::string& loadedFrom, const std::vector<std::string>& execution_provider_ids) const {
ORT_UNUSED_PARAMETER(session_id);
ORT_UNUSED_PARAMETER(ir_version);
ORT_UNUSED_PARAMETER(model_producer_name);
ORT_UNUSED_PARAMETER(model_producer_version);
ORT_UNUSED_PARAMETER(model_domain);
ORT_UNUSED_PARAMETER(domain_to_version_map);
ORT_UNUSED_PARAMETER(model_graph_name);
ORT_UNUSED_PARAMETER(model_metadata);
ORT_UNUSED_PARAMETER(loadedFrom);
ORT_UNUSED_PARAMETER(execution_provider_ids);
}
void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
const char* function, uint32_t line) const {
ORT_UNUSED_PARAMETER(session_id);
ORT_UNUSED_PARAMETER(status);
ORT_UNUSED_PARAMETER(file);
ORT_UNUSED_PARAMETER(function);
ORT_UNUSED_PARAMETER(line);
}
void Telemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const {
ORT_UNUSED_PARAMETER(session_id);
ORT_UNUSED_PARAMETER(total_runs_since_last);
ORT_UNUSED_PARAMETER(total_run_duration_since_last);
}
} // namespace onnxruntime

View file

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <string>
#include <vector>
#include <unordered_map>
#include "core/common/status.h"
#include "core/common/common.h"
namespace onnxruntime {
/**
* Configuration information for a session.
* An interface used by the onnxruntime implementation to
* access operating system functionality for telemetry
*
* look at env.h and the Env objection which is the activation factory
* for telemetry instances
*
* All Telemetry implementations are safe for concurrent access from
* multiple threads without any external synchronization.
*/
class Telemetry {
public:
// don't create these, use Env::GetTelemetryProvider() instead
// this constructor is made public so that other platform Env providers can
// use this base class as a "stub" implementation
Telemetry() = default;
virtual ~Telemetry() = default;
virtual void EnableTelemetryEvents() const;
virtual void DisableTelemetryEvents() const;
virtual void LogProcessInfo() const;
virtual void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
const std::string& model_producer_version, const std::string& model_domain,
const std::unordered_map<std::string, int>& domain_to_version_map,
const std::string& model_graph_name,
const std::unordered_map<std::string, std::string>& model_metadata,
const std::string& loadedFrom, const std::vector<std::string>& execution_provider_ids) const;
virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
const char* function, uint32_t line) const;
virtual void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const;
private:
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Telemetry);
};
} // namespace onnxruntime

View file

@ -25,6 +25,7 @@ limitations under the License.
#include "core/common/logging/logging.h"
#include "core/platform/env.h"
#include "core/platform/windows/telemetry.h"
namespace onnxruntime {
@ -211,6 +212,11 @@ class WindowsEnv : public Env {
ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented");
}
// \brief returns a provider that will handle telemetry on the current platform
const Telemetry& GetTelemetryProvider() const override {
return telemetry_provider_;
}
private:
WindowsEnv()
: GetSystemTimePreciseAsFileTime_(nullptr) {
@ -228,8 +234,8 @@ class WindowsEnv : public Env {
typedef VOID(WINAPI* FnGetSystemTimePreciseAsFileTime)(LPFILETIME);
FnGetSystemTimePreciseAsFileTime GetSystemTimePreciseAsFileTime_;
WindowsTelemetry telemetry_provider_;
};
} // namespace
#if defined(PLATFORM_WINDOWS)

View file

@ -0,0 +1,192 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/platform/windows/telemetry.h"
#include "core/common/version.h"
// ETW includes
// need space after Windows.h to prevent clang-format re-ordering breaking the build.
// TraceLoggingProvider.h must follow Windows.h
#include <Windows.h>
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 26440) // Warning C26440 from TRACELOGGING_DEFINE_PROVIDER
#endif
#include <TraceLoggingProvider.h>
#include <evntrace.h>
namespace onnxruntime {
namespace {
TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntime",
// {3a26b1ff-7484-7484-7484-15261f42614d}
(0x3a26b1ff, 0x7484, 0x7484, 0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d),
TraceLoggingOptionMicrosoftTelemetry());
} // namespace
#ifdef _MSC_VER
#pragma warning(pop)
#endif
OrtMutex WindowsTelemetry::mutex_;
uint32_t WindowsTelemetry::global_register_count_ = 0;
bool WindowsTelemetry::enabled_ = false;
WindowsTelemetry::WindowsTelemetry() {
std::lock_guard<OrtMutex> lock(mutex_);
if (global_register_count_ == 0) {
// TraceLoggingRegister is fancy in that you can only register once GLOBALLY for the whole process
HRESULT hr = TraceLoggingRegister(telemetry_provider_handle);
if (SUCCEEDED(hr)) {
global_register_count_ += 1;
}
}
}
WindowsTelemetry::~WindowsTelemetry() {
std::lock_guard<OrtMutex> lock(mutex_);
if (global_register_count_ > 0) {
global_register_count_ -= 1;
if (global_register_count_ == 0) {
TraceLoggingUnregister(telemetry_provider_handle);
}
}
}
void WindowsTelemetry::EnableTelemetryEvents() const {
enabled_ = true;
}
void WindowsTelemetry::DisableTelemetryEvents() const {
enabled_ = false;
}
void WindowsTelemetry::LogProcessInfo() const {
if (global_register_count_ == 0 || enabled_ == false)
return;
static std::atomic<bool> process_info_logged;
// did we already log the process info? we only need to log it once
if (process_info_logged.exchange(true))
return;
TraceLoggingWrite(telemetry_provider_handle,
"ProcessInfo",
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingString(ONNXRUNTIME_VERSION_STRING, "runtimeVersion"),
TraceLoggingBool(true, "isRedist"));
process_info_logged = true;
}
void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
const std::string& model_producer_version, const std::string& model_domain,
const std::unordered_map<std::string, int>& domain_to_version_map,
const std::string& model_graph_name,
const std::unordered_map<std::string, std::string>& model_metadata,
const std::string& loadedFrom, const std::vector<std::string>& execution_provider_ids) const {
if (global_register_count_ == 0 || enabled_ == false)
return;
// build the strings we need
std::string domain_to_verison_string;
bool first = true;
for (auto& i : domain_to_version_map) {
if (first) {
first = false;
} else {
domain_to_verison_string += ',';
}
domain_to_verison_string += i.first;
domain_to_verison_string += '=';
domain_to_verison_string += std::to_string(i.second);
}
std::string model_metadata_string;
first = true;
for (auto& i : model_metadata) {
if (first) {
first = false;
} else {
model_metadata_string += ',';
}
model_metadata_string += i.first;
model_metadata_string += '=';
model_metadata_string += i.second;
}
std::string execution_provider_string;
first = true;
for (auto& i : execution_provider_ids) {
if (first) {
first = false;
} else {
execution_provider_string += ',';
}
execution_provider_string += i;
}
TraceLoggingWrite(telemetry_provider_handle,
"SessionCreation",
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingInt64(ir_version, "irVersion"),
TraceLoggingString(model_producer_name.c_str(), "modelProducerName"),
TraceLoggingString(model_producer_version.c_str(), "modelProducerVersion"),
TraceLoggingString(model_domain.c_str(), "modelDomain"),
TraceLoggingString(domain_to_verison_string.c_str(), "domainToVersionMap"),
TraceLoggingString(model_graph_name.c_str(), "modelGraphName"),
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
TraceLoggingString(loadedFrom.c_str(), "loadedFrom"),
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
}
void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
const char* function, uint32_t line) const {
if (global_register_count_ == 0 || enabled_ == false)
return;
TraceLoggingWrite(telemetry_provider_handle,
"RuntimeError",
TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingUInt32(status.Code(), "errorCode"),
TraceLoggingUInt32(status.Category(), "errorCategory"),
TraceLoggingString(status.ErrorMessage().c_str(), "errorMessage"),
TraceLoggingString(file, "file"),
TraceLoggingString(function, "function"),
TraceLoggingInt32(line, "line"));
}
void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const {
if (global_register_count_ == 0 || enabled_ == false)
return;
TraceLoggingWrite(telemetry_provider_handle,
"RuntimePerf",
TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingUInt32(total_runs_since_last, "totalRuns"),
TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"));
}
} // namespace onnxruntime

View file

@ -0,0 +1,57 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/platform/telemetry.h"
#include "core/platform/ort_mutex.h"
#include <atomic>
// ***
// platform specific control bits
#ifndef TraceLoggingOptionMicrosoftTelemetry
#define TraceLoggingOptionMicrosoftTelemetry() \
TraceLoggingOptionGroup(0000000000, 00000, 00000, 0000, 0000, 0000, 0000, 0000, 000, 0000, 0000)
#endif
#define MICROSOFT_KEYWORD_MEASURES 0x0000400000000000 // Bit 46
#define TelemetryPrivacyDataTag(tag) TraceLoggingUInt64((tag), "PartA_PrivTags")
#define PDT_ProductAndServicePerformance 0x0000000001000000u
#define PDT_ProductAndServiceUsage 0x0000000002000000u
// ***
namespace onnxruntime {
/**
* derives and implments a Telemetry provider on Windows
*/
class WindowsTelemetry : public Telemetry {
public:
// these are allowed to be created, WindowsEnv will create one
WindowsTelemetry();
~WindowsTelemetry();
void EnableTelemetryEvents() const override;
void DisableTelemetryEvents() const override;
void LogProcessInfo() const override;
void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name,
const std::string& model_producer_version, const std::string& model_domain,
const std::unordered_map<std::string, int>& domain_to_version_map,
const std::string& model_graph_name,
const std::unordered_map<std::string, std::string>& model_metadata,
const std::string& loadedFrom, const std::vector<std::string>& execution_provider_ids) const override;
void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
const char* function, uint32_t line) const override;
void LogRuntimePerf(uint32_t session_id, uint32_t total_runs_since_last, int64_t total_run_duration_since_last) const override;
private:
static OrtMutex mutex_;
static uint32_t global_register_count_;
static bool enabled_;
};
} // namespace onnxruntime

View file

@ -17,6 +17,8 @@
#include "core/graph/dml_ops/dml_defs.h"
#endif
#include "core/platform/env.h"
namespace onnxruntime {
using namespace ::onnxruntime::common;
using namespace ONNX_NAMESPACE;
@ -85,6 +87,10 @@ Internal copy node
Internal copy node
)DOC");
// fire off startup telemetry (this call is idempotent)
const Env& env = Env::Default();
env.GetTelemetryProvider().LogProcessInfo();
is_initialized_ = true;
} catch (std::exception& ex) {
status = Status{ONNXRUNTIME, common::RUNTIME_EXCEPTION, std::string{"Exception caught: "} + ex.what()};

View file

@ -95,6 +95,8 @@ inline std::basic_string<T> GetCurrentTimeString() {
} // namespace
std::atomic<uint32_t> InferenceSession::global_session_id_{1};
InferenceSession::InferenceSession(const SessionOptions& session_options,
logging::LoggingManager* logging_manager)
: session_options_(session_options),
@ -122,6 +124,9 @@ InferenceSession::InferenceSession(const SessionOptions& session_options,
if (session_options.enable_profiling) {
StartProfiling(session_options.profile_file_prefix);
}
// a monotonically increasing session id for use in telemetry
session_id_ = global_session_id_.fetch_add(1);
}
InferenceSession::~InferenceSession() {
@ -188,7 +193,7 @@ common::Status InferenceSession::AddCustomTransformerList(const std::vector<std:
common::Status InferenceSession::AddCustomOpDomains(const std::vector<OrtCustomOpDomain*>& op_domains) {
std::shared_ptr<CustomRegistry> custom_registry;
ORT_RETURN_IF_ERROR(CreateCustomRegistry(op_domains, custom_registry));
ORT_RETURN_IF_ERROR_SESSIONID_(CreateCustomRegistry(op_domains, custom_registry));
RegisterCustomRegistry(custom_registry);
return Status::OK();
}
@ -221,15 +226,22 @@ common::Status InferenceSession::Load(std::function<common::Status(std::shared_p
std::shared_ptr<onnxruntime::Model> p_tmp_model;
status = loader(p_tmp_model);
ORT_RETURN_IF_ERROR(status);
ORT_RETURN_IF_ERROR_SESSIONID_(status);
model_ = p_tmp_model;
status = DoPostLoadProcessing(*model_);
ORT_RETURN_IF_ERROR(status);
ORT_RETURN_IF_ERROR_SESSIONID_(status);
// all steps complete, mark the model as loaded.
is_model_loaded_ = true;
// and log telemetry
const Env& env = Env::Default();
env.GetTelemetryProvider().LogSessionCreation(session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(),
model_->Domain(), model_->MainGraph().DomainToVersionMap(), model_->MainGraph().Name(),
model_->MetaData(), event_name, execution_providers_.GetIds());
} catch (const std::exception& ex) {
status = Status(common::ONNXRUNTIME, common::FAIL, "Exception during loading: " + std::string(ex.what()));
} catch (...) {
@ -363,7 +375,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
// 5. insert cast nodes.
// first apply global(execution provider independent), level 1(default/system/basic) graph to graph optimizations
ORT_RETURN_IF_ERROR(graph_transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1));
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, TransformerLevel::Level1));
#ifdef USE_DML
// TODO: this is a temporary workaround to apply the DML EP's custom graph transformer prior to partitioning. This
@ -384,17 +396,17 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
// Do partitioning based on execution providers' capability.
GraphPartitioner partitioner(kernel_registry_manager, providers);
ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr()));
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state.ExportDll(), session_state.GetMutableFuncMgr()));
// apply transformers except default transformers
// Default transformers are required for correctness and they are owned and run by inference session
for (int i = static_cast<int>(TransformerLevel::Level1); i < static_cast<int>(TransformerLevel::MaxTransformerLevel); i++) {
ORT_RETURN_IF_ERROR(graph_transformer_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i)));
ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr.ApplyTransformers(graph, static_cast<TransformerLevel>(i)));
}
bool modified = false;
// Insert cast node/s.
ORT_RETURN_IF_ERROR(insert_cast_transformer.Apply(graph, modified));
ORT_RETURN_IF_ERROR_SESSIONID_(insert_cast_transformer.Apply(graph, modified));
// Now every node should be already assigned to an execution provider
for (auto& node : graph.Nodes()) {
@ -417,7 +429,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph,
// Insert copy node/s.
MemcpyTransformer copy_transformer{provider_types, kernel_registry_manager};
ORT_RETURN_IF_ERROR(copy_transformer.Apply(graph, modified));
ORT_RETURN_IF_ERROR_SESSIONID_(copy_transformer.Apply(graph, modified));
return common::Status::OK();
}
@ -443,7 +455,7 @@ common::Status InferenceSession::CreateSubgraphSessionState(Graph& graph, Sessio
subgraph_session_state->GetMutableFuncMgr().SetFusedFuncs(session_state.GetFuncMgr());
// recurse
ORT_RETURN_IF_ERROR(CreateSubgraphSessionState(*subgraph, *subgraph_session_state));
ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(*subgraph, *subgraph_session_state));
// add the subgraph SessionState instance to the parent graph SessionState so it can be retrieved
// by Compute() via OpKernelContextInternal.
@ -479,8 +491,8 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio
*subgraph_session_state, execution_providers_, kernel_registry_manager_);
const auto implicit_inputs = node.ImplicitInputDefs();
ORT_RETURN_IF_ERROR(initializer.CreatePlan(&node, &implicit_inputs,
session_options_.execution_mode));
ORT_RETURN_IF_ERROR_SESSIONID_(initializer.CreatePlan(&node, &implicit_inputs,
session_options_.execution_mode));
// LOGS(*session_logger_, VERBOSE) << std::make_pair(subgraph_info.session_state->GetExecutionPlan(),
// &*subgraph_info.session_state);
@ -489,10 +501,10 @@ common::Status InferenceSession::InitializeSubgraphSessions(Graph& graph, Sessio
auto* p_op_kernel = session_state.GetMutableKernel(node.Index());
ORT_ENFORCE(p_op_kernel);
auto& control_flow_kernel = dynamic_cast<controlflow::IControlFlowKernel&>(*p_op_kernel);
ORT_RETURN_IF_ERROR(control_flow_kernel.SetupSubgraphExecutionInfo(session_state, name, *subgraph_session_state));
ORT_RETURN_IF_ERROR_SESSIONID_(control_flow_kernel.SetupSubgraphExecutionInfo(session_state, name, *subgraph_session_state));
// recurse
ORT_RETURN_IF_ERROR(InitializeSubgraphSessions(subgraph, *subgraph_session_state));
ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(subgraph, *subgraph_session_state));
}
}
@ -521,7 +533,7 @@ common::Status InferenceSession::Initialize() {
LOGS(*session_logger_, INFO) << "Adding default CPU execution provider.";
CPUExecutionProviderInfo epi{session_options_.enable_cpu_mem_arena};
auto p_cpu_exec_provider = onnxruntime::make_unique<CPUExecutionProvider>(epi);
ORT_RETURN_IF_ERROR(RegisterExecutionProvider(std::move(p_cpu_exec_provider)));
ORT_RETURN_IF_ERROR_SESSIONID_(RegisterExecutionProvider(std::move(p_cpu_exec_provider)));
}
if (session_options_.execution_mode == ExecutionMode::ORT_PARALLEL &&
@ -546,37 +558,37 @@ common::Status InferenceSession::Initialize() {
// The 1st ones should have already been registered via session-level API into KernelRegistryManager.
//
// Register 2nd registries into KernelRegistryManager.
ORT_RETURN_IF_ERROR(kernel_registry_manager_.RegisterKernels(execution_providers_));
ORT_RETURN_IF_ERROR_SESSIONID_(kernel_registry_manager_.RegisterKernels(execution_providers_));
SessionStateInitializer session_initializer(session_options_.enable_mem_pattern, model_location_, graph,
session_state_, execution_providers_, kernel_registry_manager_);
// create SessionState for subgraphs as it's needed by the transformers
ORT_RETURN_IF_ERROR(CreateSubgraphSessionState(graph, session_state_));
ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(graph, session_state_));
// apply any transformations to the main graph and any subgraphs
ORT_RETURN_IF_ERROR(TransformGraph(graph, graph_transformation_mgr_,
execution_providers_, kernel_registry_manager_,
insert_cast_transformer_,
session_state_));
ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, graph_transformation_mgr_,
execution_providers_, kernel_registry_manager_,
insert_cast_transformer_,
session_state_));
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
ORT_RETURN_IF_ERROR(graph.Resolve());
ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());
if (!session_options_.optimized_model_filepath.empty()) {
if (session_options_.graph_optimization_level < TransformerLevel::Level3) {
// Serialize optimized ONNX model.
ORT_RETURN_IF_ERROR(Model::Save(*model_, session_options_.optimized_model_filepath));
ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath));
} else {
LOGS(*session_logger_, WARNING) << "Serializing Optimized ONNX model with Graph Optimization"
" level greater than 2 is not supported.";
}
}
ORT_RETURN_IF_ERROR(session_initializer.CreatePlan(nullptr, nullptr, session_options_.execution_mode));
ORT_RETURN_IF_ERROR_SESSIONID_(session_initializer.CreatePlan(nullptr, nullptr, session_options_.execution_mode));
// handle any subgraphs
ORT_RETURN_IF_ERROR(InitializeSubgraphSessions(graph, session_state_));
ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(graph, session_state_));
is_inited_ = true;
LOGS(*session_logger_, INFO) << "Session successfully initialized.";
@ -685,17 +697,17 @@ common::Status InferenceSession::ValidateInputs(const std::vector<std::string>&
auto expected_element_type = expected_type->AsTensorType()->GetElementType();
auto input_element_type = input_ml_value.Get<Tensor>().DataType();
ORT_RETURN_IF_ERROR(CheckTypes(input_element_type, expected_element_type));
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_element_type, expected_element_type));
// check for shape
const auto& expected_shape = iter->second.tensor_shape;
if (expected_shape.NumDimensions() > 0) {
const auto& input_shape = input_ml_value.Get<Tensor>().Shape();
ORT_RETURN_IF_ERROR(CheckShapes(feed_name, input_shape, expected_shape));
ORT_RETURN_IF_ERROR_SESSIONID_(CheckShapes(feed_name, input_shape, expected_shape));
}
} else {
auto input_type = input_ml_value.Type();
ORT_RETURN_IF_ERROR(CheckTypes(input_type, expected_type));
ORT_RETURN_IF_ERROR_SESSIONID_(CheckTypes(input_type, expected_type));
}
}
@ -746,8 +758,8 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
return Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized.");
}
ORT_RETURN_IF_ERROR(ValidateInputs(feed_names, feeds));
ORT_RETURN_IF_ERROR(ValidateOutputs(output_names, p_fetches));
ORT_RETURN_IF_ERROR_SESSIONID_(ValidateInputs(feed_names, feeds));
ORT_RETURN_IF_ERROR_SESSIONID_(ValidateOutputs(output_names, p_fetches));
FeedsFetchesInfo info(feed_names, output_names, session_state_.GetOrtValueNameIdxMap());
FeedsFetchesManager feeds_fetches_manager{std::move(info)};
@ -789,6 +801,23 @@ Status InferenceSession::Run(const RunOptions& run_options, const std::vector<st
}
--current_num_runs_;
// keep track of telemetry
++total_runs_since_last_;
total_run_duration_since_last_ += TimeDiffMicroSeconds(tp);
// time to send telemetry?
if (TimeDiffMicroSeconds(time_sent_last_) > kDurationBetweenSending) {
// send the telemetry
const Env& env = Env::Default();
env.GetTelemetryProvider().LogRuntimePerf(session_id_, total_runs_since_last_, total_run_duration_since_last_);
// reset counters
time_sent_last_ = std::chrono::high_resolution_clock::now();
total_runs_since_last_ = 0;
total_run_duration_since_last_ = 0;
}
// send out profiling events (optional)
if (session_profiler_.IsEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "model_run", tp);
}

View file

@ -434,5 +434,13 @@ class InferenceSession {
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
InterOpDomains interop_domains_;
#endif
// used to support platform telemetry
static std::atomic<uint32_t> global_session_id_; // a monotonically increasing session id
uint32_t session_id_; // the current session's id
uint32_t total_runs_since_last_; // the total number of Run() calls since the last report
long long total_run_duration_since_last_; // the total duration (us) of Run() calls since the last report
TimePoint time_sent_last_; // the TimePoint of the last report
const long long kDurationBetweenSending = 1000* 1000 * 60 * 10; // duration in (us). send a report every 10 mins
};
} // namespace onnxruntime

View file

@ -128,6 +128,27 @@ ORT_API_STATUS_IMPL(OrtApis::CreateEnv, OrtLoggingLevel default_warning_level,
API_IMPL_END
}
// enable platform telemetry
ORT_API_STATUS_IMPL(OrtApis::EnableTelemetryEvents, _In_ const OrtEnv* ort_env) {
API_IMPL_BEGIN
ORT_UNUSED_PARAMETER(ort_env);
// note telemetry is controlled via the platform Env object, not the OrtEnv object instance
const Env& env = Env::Default();
env.GetTelemetryProvider().EnableTelemetryEvents();
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(OrtApis::DisableTelemetryEvents, _In_ const OrtEnv* ort_env) {
API_IMPL_BEGIN
ORT_UNUSED_PARAMETER(ort_env);
// note telemetry is controlled via the platform Env object, not the OrtEnv object instance
const Env& env = Env::Default();
env.GetTelemetryProvider().DisableTelemetryEvents();
return nullptr;
API_IMPL_END
}
template <typename T>
OrtStatus* CreateTensorImpl(const int64_t* shape, size_t shape_len, OrtAllocator* allocator,
std::unique_ptr<Tensor>* out) {
@ -1266,6 +1287,9 @@ static constexpr OrtApi ort_api_1 = {
&OrtApis::CreateEnv,
&OrtApis::CreateEnvWithCustomLogger,
&OrtApis::EnableTelemetryEvents,
&OrtApis::DisableTelemetryEvents,
&OrtApis::CreateSession,
&OrtApis::CreateSessionFromArray,
&OrtApis::Run,

View file

@ -24,6 +24,8 @@ const char* ORT_API_CALL GetErrorMessage(_In_ const OrtStatus* status) NO_EXCEPT
ORT_API_STATUS_IMPL(CreateEnv, OrtLoggingLevel default_logging_level, _In_ const char* logid, _Outptr_ OrtEnv** out)
ORT_ALL_ARGS_NONNULL;
ORT_API_STATUS_IMPL(CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level, _In_ const char* logid, _Outptr_ OrtEnv** out);
ORT_API_STATUS_IMPL(EnableTelemetryEvents, _In_ const OrtEnv* env);
ORT_API_STATUS_IMPL(DisableTelemetryEvents, _In_ const OrtEnv* env);
ORT_API_STATUS_IMPL(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
_In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out);