mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-19 02:03:52 +00:00
Flush and trim resources in DML EP in new OnSessionInitializationEnd method
This commit is contained in:
parent
b63349c8d6
commit
e89dd92387
15 changed files with 41 additions and 103 deletions
|
|
@ -136,6 +136,13 @@ class IExecutionProvider {
|
|||
*/
|
||||
virtual common::Status OnRunEnd();
|
||||
|
||||
/**
|
||||
Called when session creation is complete
|
||||
This provides an opportunity for execution providers to optionally synchronize and
|
||||
clean up their temporary resources to reduce memory and ensure the first run is fast.
|
||||
*/
|
||||
virtual common::Status OnSessionInitializationEnd();
|
||||
|
||||
void InsertAllocator(AllocatorPtr allocator);
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -49,6 +49,8 @@ common::Status IExecutionProvider::OnRunStart() { return Status::OK(); }
|
|||
|
||||
common::Status IExecutionProvider::OnRunEnd() { return Status::OK(); }
|
||||
|
||||
common::Status IExecutionProvider::OnSessionInitializationEnd() { return Status::OK(); }
|
||||
|
||||
void IExecutionProvider::InsertAllocator(AllocatorPtr allocator) {
|
||||
const OrtMemoryInfo& info = allocator->Info();
|
||||
const int key = MakeKey(info.id, info.mem_type);
|
||||
|
|
|
|||
|
|
@ -34,8 +34,6 @@ namespace Dml
|
|||
void FlushContext(onnxruntime::IExecutionProvider* provider);
|
||||
void SetDefaultRoundingMode(onnxruntime::IExecutionProvider* provider, AllocatorRoundingMode roundingMode);
|
||||
void ReleaseCompletedReferences(onnxruntime::IExecutionProvider* provider);
|
||||
void TrimUploadHeap(onnxruntime::IExecutionProvider * provider);
|
||||
void WaitForGpuCompletion(onnxruntime::IExecutionProvider * provider);
|
||||
|
||||
onnxruntime::common::Status CopyTensor(
|
||||
onnxruntime::IExecutionProvider* provider,
|
||||
|
|
|
|||
|
|
@ -532,17 +532,6 @@ namespace Dml
|
|||
return onnxruntime::common::Status::OK();
|
||||
}
|
||||
|
||||
Status ExecutionProviderImpl::WaitForGpuCompletion()
|
||||
{
|
||||
assert(!m_closed);
|
||||
|
||||
Flush();
|
||||
m_context->GetCurrentCompletionEvent().WaitForSignal();
|
||||
m_context->ReleaseCompletedReferences();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void __stdcall ExecutionProviderImpl::Flush() const
|
||||
{
|
||||
assert(!m_closed);
|
||||
|
|
@ -558,11 +547,6 @@ namespace Dml
|
|||
{
|
||||
m_context->ReleaseCompletedReferences();
|
||||
}
|
||||
|
||||
void ExecutionProviderImpl::TrimUploadHeap()
|
||||
{
|
||||
m_uploadHeap->Trim();
|
||||
}
|
||||
|
||||
void ExecutionProviderImpl::QueueReference(IUnknown* object)
|
||||
{
|
||||
|
|
@ -701,6 +685,20 @@ namespace Dml
|
|||
return m_cpuOutputAllocator;
|
||||
}
|
||||
|
||||
|
||||
onnxruntime::common::Status ExecutionProviderImpl::OnSessionInitializationEnd()
|
||||
{
|
||||
// Flush and trim resources, including staging memory used to upload weights.
|
||||
// This reduces memory usage immediately after session creation, and avoids
|
||||
// performance impact of deallocation during first evaluation.
|
||||
Flush();
|
||||
m_context->GetCurrentCompletionEvent().WaitForSignal();
|
||||
m_context->ReleaseCompletedReferences();
|
||||
m_uploadHeap->Trim();
|
||||
|
||||
return onnxruntime::common::Status::OK();
|
||||
}
|
||||
|
||||
std::unique_ptr<onnxruntime::IExecutionProvider> CreateExecutionProvider(
|
||||
IDMLDevice* dmlDevice,
|
||||
ID3D12CommandQueue* commandQueue,
|
||||
|
|
@ -733,18 +731,6 @@ namespace Dml
|
|||
dmlexecutionprovider->ReleaseCompletedReferences();
|
||||
}
|
||||
|
||||
void TrimUploadHeap(onnxruntime::IExecutionProvider * provider)
|
||||
{
|
||||
ExecutionProvider* dmlexecutionprovider = static_cast<Dml::ExecutionProvider*>(provider);
|
||||
dmlexecutionprovider->TrimUploadHeap();
|
||||
}
|
||||
|
||||
void WaitForGpuCompletion(onnxruntime::IExecutionProvider * provider)
|
||||
{
|
||||
ExecutionProvider* dmlexecutionprovider = static_cast<Dml::ExecutionProvider*>(provider);
|
||||
dmlexecutionprovider->WaitForGpuCompletion();
|
||||
}
|
||||
|
||||
onnxruntime::common::Status CopyTensor(
|
||||
onnxruntime::IExecutionProvider* provider,
|
||||
const onnxruntime::Tensor& src,
|
||||
|
|
|
|||
|
|
@ -39,8 +39,6 @@ namespace Dml
|
|||
|
||||
void ReleaseCompletedReferences();
|
||||
|
||||
void TrimUploadHeap();
|
||||
|
||||
public: // implements Dml::IExecutionProvider
|
||||
STDMETHOD(GetD3DDevice)(_COM_Outptr_ ID3D12Device** d3dDevice) const noexcept final;
|
||||
|
||||
|
|
@ -92,7 +90,6 @@ namespace Dml
|
|||
uint32_t GetSuppportedDeviceDataTypeMask() const;
|
||||
|
||||
onnxruntime::common::Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const;
|
||||
onnxruntime::common::Status WaitForGpuCompletion();
|
||||
|
||||
// IWinmlExecutionProvider methods
|
||||
void QueueReference(IUnknown* object) override;
|
||||
|
|
@ -157,7 +154,9 @@ namespace Dml
|
|||
std::shared_ptr<onnxruntime::IAllocator> GetCpuOutputAllocator();
|
||||
|
||||
std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap>
|
||||
GetInternalRegistrationInfoMap() const;
|
||||
GetInternalRegistrationInfoMap() const;
|
||||
|
||||
onnxruntime::common::Status OnSessionInitializationEnd();
|
||||
|
||||
private:
|
||||
void Initialize(ID3D12CommandQueue* queue, ExecutionProvider& executionProvider);
|
||||
|
|
@ -221,31 +220,28 @@ namespace Dml
|
|||
bool enableMetacommands = true
|
||||
);
|
||||
|
||||
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const final
|
||||
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const final override
|
||||
{
|
||||
return std::make_unique<DataTransfer>(m_impl.Get());
|
||||
}
|
||||
|
||||
const void* GetExecutionHandle() const noexcept final
|
||||
const void* GetExecutionHandle() const noexcept final override
|
||||
{
|
||||
return m_impl.Get();
|
||||
}
|
||||
|
||||
std::shared_ptr<onnxruntime::KernelRegistry> GetKernelRegistry() const final
|
||||
std::shared_ptr<onnxruntime::KernelRegistry> GetKernelRegistry() const final override
|
||||
{
|
||||
return m_impl->GetKernelRegistry();
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
|
||||
GetCapability(const onnxruntime::GraphViewer& graph,
|
||||
const std::vector<const onnxruntime::KernelRegistry*>& kernel_registries) const final;
|
||||
const std::vector<const onnxruntime::KernelRegistry*>& kernel_registries) const final override;
|
||||
|
||||
// Not to be confused with IExecutionProvider::Sync() const. The DML provider handles
|
||||
// synchronization when copying inputs and outputs, therefore doesn't override the
|
||||
// default ORT method, which does nothing.
|
||||
onnxruntime::common::Status WaitForGpuCompletion()
|
||||
{
|
||||
return m_impl->WaitForGpuCompletion();
|
||||
onnxruntime::common::Status OnSessionInitializationEnd() override
|
||||
{
|
||||
return m_impl->OnSessionInitializationEnd();
|
||||
}
|
||||
|
||||
void Flush()
|
||||
|
|
@ -262,11 +258,6 @@ namespace Dml
|
|||
{
|
||||
return m_impl->ReleaseCompletedReferences();
|
||||
}
|
||||
|
||||
void TrimUploadHeap()
|
||||
{
|
||||
m_impl->TrimUploadHeap();
|
||||
}
|
||||
|
||||
ExecutionProviderImpl* GetImpl()
|
||||
{
|
||||
|
|
|
|||
|
|
@ -929,6 +929,14 @@ common::Status InferenceSession::Initialize() {
|
|||
if (session_profiler_.IsEnabled()) {
|
||||
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "session_initialization", tp);
|
||||
}
|
||||
|
||||
if (status.IsOK()) {
|
||||
auto retval = status;
|
||||
for (auto& xp : execution_providers_) {
|
||||
auto status = xp->OnSessionInitializationEnd();
|
||||
ORT_CHECK_AND_SET_RETVAL(status);
|
||||
}
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -58,7 +58,6 @@ ORT_API_STATUS(SessionCopyOneInputAcrossDevices, _In_ OrtSession* session, _In_
|
|||
// Dml methods (TODO need to figure out how these need to move to session somehow...)
|
||||
ORT_API_STATUS(DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled);
|
||||
ORT_API_STATUS(DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider* dml_provider);
|
||||
ORT_API_STATUS(DmlExecutionProviderTrimUploadHeap, _In_ OrtExecutionProvider* dml_provider);
|
||||
ORT_API_STATUS(DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider);
|
||||
ORT_API_STATUS(DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void** dml_resource);
|
||||
ORT_API_STATUS(DmlGetD3D12ResourceFromAllocation, _In_ OrtExecutionProvider* provider, _In_ void* allocation, _Out_ ID3D12Resource** resource);
|
||||
|
|
|
|||
|
|
@ -59,7 +59,6 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = {
|
|||
// Dml methods (TODO need to figure out how these need to move to session somehow...)
|
||||
&winmla::DmlExecutionProviderSetDefaultRoundingMode,
|
||||
&winmla::DmlExecutionProviderFlushContext,
|
||||
&winmla::DmlExecutionProviderTrimUploadHeap,
|
||||
&winmla::DmlExecutionProviderReleaseCompletedReferences,
|
||||
&winmla::DmlCreateGPUAllocationFromD3DResource,
|
||||
&winmla::DmlFreeGPUAllocation,
|
||||
|
|
|
|||
|
|
@ -303,14 +303,6 @@ struct WinmlAdapterApi {
|
|||
*/
|
||||
OrtStatus*(ORT_API_CALL* DmlExecutionProviderFlushContext)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* DmlExecutionProviderTrimUploadHeap
|
||||
* This api is used to trim the upload heap in the DML EP.
|
||||
*
|
||||
* WinML communicates directly with DML to perform this as an optimization.
|
||||
*/
|
||||
OrtStatus*(ORT_API_CALL* DmlExecutionProviderTrimUploadHeap)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION;
|
||||
|
||||
/**
|
||||
* DmlExecutionProviderReleaseCompletedReferences
|
||||
* This api is used to release completed references after the first run of the DML EP.
|
||||
|
|
|
|||
|
|
@ -93,16 +93,6 @@ ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderFlushContext, _In_ OrtExecutionP
|
|||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderTrimUploadHeap, _In_ OrtExecutionProvider* dml_provider) {
|
||||
API_IMPL_BEGIN
|
||||
#ifdef USE_DML
|
||||
auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider);
|
||||
Dml::TrimUploadHeap(dml_provider_internal);
|
||||
#endif // USE_DML
|
||||
return nullptr;
|
||||
API_IMPL_END
|
||||
}
|
||||
|
||||
ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider) {
|
||||
API_IMPL_BEGIN
|
||||
#ifdef USE_DML
|
||||
|
|
|
|||
|
|
@ -493,19 +493,6 @@ HRESULT OnnxruntimeEngine::FlushContext() {
|
|||
return S_OK;
|
||||
}
|
||||
|
||||
HRESULT OnnxruntimeEngine::TrimUploadHeap() {
|
||||
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
|
||||
|
||||
OrtExecutionProvider* ort_provider;
|
||||
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
|
||||
engine_factory_->UseOrtApi());
|
||||
|
||||
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderTrimUploadHeap(ort_provider),
|
||||
engine_factory_->UseOrtApi());
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
HRESULT OnnxruntimeEngine::ReleaseCompletedReferences() {
|
||||
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
|
||||
|
||||
|
|
|
|||
|
|
@ -76,8 +76,6 @@ class OnnxruntimeEngine : public Microsoft::WRL::RuntimeClass<
|
|||
() override;
|
||||
STDMETHOD(FlushContext)
|
||||
() override;
|
||||
STDMETHOD(TrimUploadHeap)
|
||||
() override;
|
||||
STDMETHOD(ReleaseCompletedReferences)
|
||||
() override;
|
||||
STDMETHOD(Sync)
|
||||
|
|
|
|||
|
|
@ -277,16 +277,6 @@ LearningModelSession::GetResults(
|
|||
// Update output providers
|
||||
auto outputs = binding_impl->UpdateProviders();
|
||||
|
||||
// Once the first evaluation following initialization is complete, and therefore the
|
||||
// initialization work is also complete, trim the upload heap. This is only done once
|
||||
// to avoid requiring the extra allocation during each evaluation.
|
||||
if (is_first_evaluate_) {
|
||||
if (is_gpu_evaluation) {
|
||||
engine_->TrimUploadHeap();
|
||||
}
|
||||
is_first_evaluate_ = false;
|
||||
}
|
||||
|
||||
// Create the return status object
|
||||
auto result = winrt::make<LearningModelEvaluationResult>();
|
||||
auto result_impl = result.as<winmlp::LearningModelEvaluationResult>();
|
||||
|
|
|
|||
|
|
@ -121,12 +121,6 @@ struct LearningModelSession : LearningModelSessionT<LearningModelSession> {
|
|||
// Synchronization
|
||||
CWinMLLock session_creation_lock_;
|
||||
CWinMLLock dml_ep_lock_;
|
||||
|
||||
// is_first_evaluate_ is used as a heuristic to determine
|
||||
// when the dml upload heap can be trimmed.
|
||||
bool is_first_evaluate_ = true;
|
||||
|
||||
|
||||
};
|
||||
|
||||
} // namespace winrt::Windows::AI::MachineLearning::implementation
|
||||
|
|
|
|||
|
|
@ -99,9 +99,6 @@ IEngine : IUnknown {
|
|||
STDMETHOD(FlushContext)
|
||||
() PURE;
|
||||
|
||||
STDMETHOD(TrimUploadHeap)
|
||||
() PURE;
|
||||
|
||||
STDMETHOD(ReleaseCompletedReferences)
|
||||
() PURE;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue