onnxruntime/winml/adapter/winml_adapter_environment.cpp
Abhishek Udupa 83c59d2594
Session-aware and thread-safe CUDA profiler (#13706)
### Description
The existing CUDA profiler is neither session-aware nor thread-safe.
This PR ensures both.

### Motivation and Context
[PR 13549](https://github.com/microsoft/onnxruntime/pull/13549) brought
thread-safety and session-awareness to the ROCm profiler. This PR brings
the same improvements to the CUDA profiler.

Sample outputs of a profiling run of the StableDiffusion model on both
the CUDA and ROCm EPs are attached, along with a script that checks that
the trace files generated by the profiler are well-formed. (This model
was chosen because it requires orchestrating multiple sessions, so it
verifies that the profilers are now indeed session-aware.)

Update 11/29: Updated the profile outputs. The older profile outputs
exhibited an issue where some timestamps were wildly out of range,
which caused problems when visualizing the traces. The bug has been fixed
and the profile outputs regenerated; the check script has also been
updated to verify that timestamps are monotonically increasing.


[sd_profile_outputs_cuda.tar.gz](https://github.com/microsoft/onnxruntime/files/10118088/sd_profile_outputs_cuda.tar.gz)

[sd_profile_outputs_rocm.tar.gz](https://github.com/microsoft/onnxruntime/files/10118089/sd_profile_outputs_rocm.tar.gz)

[check_profile_output_well_formedness.zip](https://github.com/microsoft/onnxruntime/files/10118090/check_profile_output_well_formedness.zip)

Co-authored-by: Abhishek Udupa <abhishek.udupa@microsoft.com>
2022-12-09 13:22:12 -08:00

85 lines
4.2 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "adapter/pch.h"
#include "winml_adapter_c_api.h"
#include "core/session/ort_apis.h"
#include "winml_adapter_apis.h"
#include "core/framework/error_code_helper.h"
#include "core/session/ort_env.h"
#ifdef USE_DML
#include "abi_custom_registry_impl.h"
#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h"
#include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h"
#endif USE_DML
// Short alias for the WinML adapter namespace, used to qualify the adapter API implementations in this file.
namespace winmla = Windows::AI::MachineLearning::Adapter;
class WinmlAdapterLoggingWrapper : public LoggingWrapper {
public:
WinmlAdapterLoggingWrapper(OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, void* logger_param) : LoggingWrapper(logging_function, logger_param),
profiling_function_(profiling_function) {
}
void SendProfileEvent(onnxruntime::profiling::EventRecord& event_record) const override {
if (profiling_function_) {
OrtProfilerEventRecord ort_event_record = {};
ort_event_record.category_ = static_cast<OrtProfilerEventCategory>(event_record.cat);
ort_event_record.category_name_ = onnxruntime::profiling::event_category_names_[event_record.cat];
ort_event_record.duration_ = event_record.dur;
ort_event_record.event_name_ = event_record.name.c_str();
ort_event_record.execution_provider_ = (event_record.cat == onnxruntime::profiling::EventCategory::NODE_EVENT) ? event_record.args["provider"].c_str() : nullptr;
ort_event_record.op_name_ = (event_record.cat == onnxruntime::profiling::EventCategory::NODE_EVENT) ? event_record.args["op_name"].c_str() : nullptr;
ort_event_record.process_id_ = event_record.pid;
ort_event_record.thread_id_ = event_record.tid;
ort_event_record.time_span_ = event_record.ts;
profiling_function_(&ort_event_record);
}
}
private:
OrtProfilingFunction profiling_function_{};
};
// Replaces the default logging manager of `env` with one whose sink forwards log messages to
// `logging_function` and profiler events to `profiling_function` (see WinmlAdapterLoggingWrapper).
//   env                   - environment to reconfigure (required).
//   logging_function      - log callback handed to the sink.
//   profiling_function    - profiler-event callback; may be null to drop profile events.
//   logger_param          - opaque user pointer handed to the sink.
//   default_warning_level - minimum severity, cast to onnxruntime::logging::Severity.
//   logid                 - id of the default logger (required, copied).
//   out                   - on success, receives `env` (the same, now-reconfigured environment) when non-null.
// Returns nullptr on success; exceptions are converted to an OrtStatus by API_IMPL_END.
ORT_API_STATUS_IMPL(winmla::EnvConfigureCustomLoggerAndProfiler, _In_ OrtEnv* env, OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function,
                    _In_opt_ void* logger_param, OrtLoggingLevel default_warning_level,
                    _In_ const char* logid, _Outptr_ OrtEnv** out) {
  API_IMPL_BEGIN
  std::string name = logid;
  std::unique_ptr<onnxruntime::logging::ISink> logger = std::make_unique<WinmlAdapterLoggingWrapper>(logging_function, profiling_function, logger_param);

  // Clear the logging manager, since only one default instance of logging manager can exist at a time.
  // This must happen before the replacement is constructed below.
  env->SetLoggingManager(nullptr);

  auto winml_logging_manager = std::make_unique<onnxruntime::logging::LoggingManager>(std::move(logger),
                                                                                      static_cast<onnxruntime::logging::Severity>(default_warning_level),
                                                                                      false,
                                                                                      onnxruntime::logging::LoggingManager::InstanceType::Default,
                                                                                      &name);

  // Set a new default logging manager
  env->SetLoggingManager(std::move(winml_logging_manager));

  // Honor the _Outptr_ contract: the original implementation never wrote *out, leaving the caller's
  // pointer unset. The environment is reconfigured in place, so the output is simply `env`.
  if (out != nullptr) {
    *out = env;
  }
  return nullptr;
  API_IMPL_END
}
// Override select shape inference functions which are incomplete in ONNX with versions that are complete,
// and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being
// deferred until first evaluation. It also prevents a situation where inference functions in externally
// registered schema are reachable only after upstream schema have been revised in a later OS release,
// which would be a compatibility risk.
// Without USE_DML this is a no-op. Always returns nullptr on success; exceptions are converted to an
// OrtStatus by API_IMPL_END.
ORT_API_STATUS_IMPL(winmla::OverrideSchema) {
  API_IMPL_BEGIN
#ifdef USE_DML
  // The function-local static once_flag + std::call_once run the overrides at most once per process,
  // even when this API is called repeatedly or concurrently. If the overrider throws, the flag stays
  // unset and a later call will retry.
  static std::once_flag schema_override_once_flag;
  std::call_once(schema_override_once_flag, []() {
    SchemaInferenceOverrider::OverrideSchemaInferenceFunctions();
  });
#endif  // USE_DML (tokens after #endif must be a comment, not bare text)
  return nullptr;
  API_IMPL_END
}