Flush and trim resources in DML EP in new OnSessionInitializationEnd method

This commit is contained in:
Jeff 2020-04-14 21:34:49 -07:00
parent b63349c8d6
commit e89dd92387
15 changed files with 41 additions and 103 deletions

View file

@ -136,6 +136,13 @@ class IExecutionProvider {
*/
virtual common::Status OnRunEnd();
/**
Called when session creation is complete
This provides an opportunity for execution providers to optionally synchronize and
clean up its temporary resources to reduce memory and ensure the first run is fast.
*/
virtual common::Status OnSessionInitializationEnd();
void InsertAllocator(AllocatorPtr allocator);
/**

View file

@ -49,6 +49,8 @@ common::Status IExecutionProvider::OnRunStart() { return Status::OK(); }
common::Status IExecutionProvider::OnRunEnd() { return Status::OK(); }
common::Status IExecutionProvider::OnSessionInitializationEnd() { return Status::OK(); }
void IExecutionProvider::InsertAllocator(AllocatorPtr allocator) {
const OrtMemoryInfo& info = allocator->Info();
const int key = MakeKey(info.id, info.mem_type);

View file

@ -34,8 +34,6 @@ namespace Dml
void FlushContext(onnxruntime::IExecutionProvider* provider);
void SetDefaultRoundingMode(onnxruntime::IExecutionProvider* provider, AllocatorRoundingMode roundingMode);
void ReleaseCompletedReferences(onnxruntime::IExecutionProvider* provider);
void TrimUploadHeap(onnxruntime::IExecutionProvider * provider);
void WaitForGpuCompletion(onnxruntime::IExecutionProvider * provider);
onnxruntime::common::Status CopyTensor(
onnxruntime::IExecutionProvider* provider,

View file

@ -532,17 +532,6 @@ namespace Dml
return onnxruntime::common::Status::OK();
}
Status ExecutionProviderImpl::WaitForGpuCompletion()
{
assert(!m_closed);
Flush();
m_context->GetCurrentCompletionEvent().WaitForSignal();
m_context->ReleaseCompletedReferences();
return Status::OK();
}
void __stdcall ExecutionProviderImpl::Flush() const
{
assert(!m_closed);
@ -558,11 +547,6 @@ namespace Dml
{
m_context->ReleaseCompletedReferences();
}
void ExecutionProviderImpl::TrimUploadHeap()
{
m_uploadHeap->Trim();
}
void ExecutionProviderImpl::QueueReference(IUnknown* object)
{
@ -701,6 +685,20 @@ namespace Dml
return m_cpuOutputAllocator;
}
onnxruntime::common::Status ExecutionProviderImpl::OnSessionInitializationEnd()
{
// Flush and trim resources, including staging memory used to upload weights.
// This reduces memory usage immediately after session creation, and avoids
// performance impact of deallocation during first evaluation.
Flush();
m_context->GetCurrentCompletionEvent().WaitForSignal();
m_context->ReleaseCompletedReferences();
m_uploadHeap->Trim();
return onnxruntime::common::Status::OK();
}
std::unique_ptr<onnxruntime::IExecutionProvider> CreateExecutionProvider(
IDMLDevice* dmlDevice,
ID3D12CommandQueue* commandQueue,
@ -733,18 +731,6 @@ namespace Dml
dmlexecutionprovider->ReleaseCompletedReferences();
}
void TrimUploadHeap(onnxruntime::IExecutionProvider * provider)
{
ExecutionProvider* dmlexecutionprovider = static_cast<Dml::ExecutionProvider*>(provider);
dmlexecutionprovider->TrimUploadHeap();
}
void WaitForGpuCompletion(onnxruntime::IExecutionProvider * provider)
{
ExecutionProvider* dmlexecutionprovider = static_cast<Dml::ExecutionProvider*>(provider);
dmlexecutionprovider->WaitForGpuCompletion();
}
onnxruntime::common::Status CopyTensor(
onnxruntime::IExecutionProvider* provider,
const onnxruntime::Tensor& src,

View file

@ -39,8 +39,6 @@ namespace Dml
void ReleaseCompletedReferences();
void TrimUploadHeap();
public: // implements Dml::IExecutionProvider
STDMETHOD(GetD3DDevice)(_COM_Outptr_ ID3D12Device** d3dDevice) const noexcept final;
@ -92,7 +90,6 @@ namespace Dml
uint32_t GetSuppportedDeviceDataTypeMask() const;
onnxruntime::common::Status CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const;
onnxruntime::common::Status WaitForGpuCompletion();
// IWinmlExecutionProvider methods
void QueueReference(IUnknown* object) override;
@ -157,7 +154,9 @@ namespace Dml
std::shared_ptr<onnxruntime::IAllocator> GetCpuOutputAllocator();
std::shared_ptr<const Windows::AI::MachineLearning::Adapter::InternalRegistrationInfoMap>
GetInternalRegistrationInfoMap() const;
GetInternalRegistrationInfoMap() const;
onnxruntime::common::Status OnSessionInitializationEnd();
private:
void Initialize(ID3D12CommandQueue* queue, ExecutionProvider& executionProvider);
@ -221,31 +220,28 @@ namespace Dml
bool enableMetacommands = true
);
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const final
std::unique_ptr<onnxruntime::IDataTransfer> GetDataTransfer() const final override
{
return std::make_unique<DataTransfer>(m_impl.Get());
}
const void* GetExecutionHandle() const noexcept final
const void* GetExecutionHandle() const noexcept final override
{
return m_impl.Get();
}
std::shared_ptr<onnxruntime::KernelRegistry> GetKernelRegistry() const final
std::shared_ptr<onnxruntime::KernelRegistry> GetKernelRegistry() const final override
{
return m_impl->GetKernelRegistry();
}
std::vector<std::unique_ptr<onnxruntime::ComputeCapability>>
GetCapability(const onnxruntime::GraphViewer& graph,
const std::vector<const onnxruntime::KernelRegistry*>& kernel_registries) const final;
const std::vector<const onnxruntime::KernelRegistry*>& kernel_registries) const final override;
// Not to be confused with IExecutionProvider::Sync() const. The DML provider handles
// synchronization when copying inputs and outputs, therefore doesn't override the
// default ORT method, which does nothin.
onnxruntime::common::Status WaitForGpuCompletion()
{
return m_impl->WaitForGpuCompletion();
onnxruntime::common::Status OnSessionInitializationEnd() override
{
return m_impl->OnSessionInitializationEnd();
}
void Flush()
@ -262,11 +258,6 @@ namespace Dml
{
return m_impl->ReleaseCompletedReferences();
}
void TrimUploadHeap()
{
m_impl->TrimUploadHeap();
}
ExecutionProviderImpl* GetImpl()
{

View file

@ -929,6 +929,14 @@ common::Status InferenceSession::Initialize() {
if (session_profiler_.IsEnabled()) {
session_profiler_.EndTimeAndRecordEvent(profiling::SESSION_EVENT, "session_initialization", tp);
}
if (status.IsOK()) {
auto retval = status;
for (auto& xp : execution_providers_) {
auto status = xp->OnSessionInitializationEnd();
ORT_CHECK_AND_SET_RETVAL(status);
}
}
return status;
}

View file

@ -58,7 +58,6 @@ ORT_API_STATUS(SessionCopyOneInputAcrossDevices, _In_ OrtSession* session, _In_
// Dml methods (TODO need to figure out how these need to move to session somehow...)
ORT_API_STATUS(DmlExecutionProviderSetDefaultRoundingMode, _In_ OrtExecutionProvider* dml_provider, _In_ bool is_enabled);
ORT_API_STATUS(DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider* dml_provider);
ORT_API_STATUS(DmlExecutionProviderTrimUploadHeap, _In_ OrtExecutionProvider* dml_provider);
ORT_API_STATUS(DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider);
ORT_API_STATUS(DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void** dml_resource);
ORT_API_STATUS(DmlGetD3D12ResourceFromAllocation, _In_ OrtExecutionProvider* provider, _In_ void* allocation, _Out_ ID3D12Resource** resource);

View file

@ -59,7 +59,6 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = {
// Dml methods (TODO need to figure out how these need to move to session somehow...)
&winmla::DmlExecutionProviderSetDefaultRoundingMode,
&winmla::DmlExecutionProviderFlushContext,
&winmla::DmlExecutionProviderTrimUploadHeap,
&winmla::DmlExecutionProviderReleaseCompletedReferences,
&winmla::DmlCreateGPUAllocationFromD3DResource,
&winmla::DmlFreeGPUAllocation,

View file

@ -303,14 +303,6 @@ struct WinmlAdapterApi {
*/
OrtStatus*(ORT_API_CALL* DmlExecutionProviderFlushContext)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION;
/**
* DmlExecutionProviderTrimUploadHeap
* This api is used to trim the upload heap in the DML EP.
*
* WinML communicates directly with DML to perform this as an optimization.
*/
OrtStatus*(ORT_API_CALL* DmlExecutionProviderTrimUploadHeap)(_In_ OrtExecutionProvider* dml_provider)NO_EXCEPTION;
/**
* DmlExecutionProviderReleaseCompletedReferences
* This api is used to release completed references after first run the DML EP.

View file

@ -93,16 +93,6 @@ ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderFlushContext, _In_ OrtExecutionP
API_IMPL_END
}
ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderTrimUploadHeap, _In_ OrtExecutionProvider* dml_provider) {
API_IMPL_BEGIN
#ifdef USE_DML
auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider);
Dml::TrimUploadHeap(dml_provider_internal);
#endif // USE_DML
return nullptr;
API_IMPL_END
}
ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider) {
API_IMPL_BEGIN
#ifdef USE_DML

View file

@ -493,19 +493,6 @@ HRESULT OnnxruntimeEngine::FlushContext() {
return S_OK;
}
HRESULT OnnxruntimeEngine::TrimUploadHeap() {
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();
OrtExecutionProvider* ort_provider;
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider),
engine_factory_->UseOrtApi());
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlExecutionProviderTrimUploadHeap(ort_provider),
engine_factory_->UseOrtApi());
return S_OK;
}
HRESULT OnnxruntimeEngine::ReleaseCompletedReferences() {
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

View file

@ -76,8 +76,6 @@ class OnnxruntimeEngine : public Microsoft::WRL::RuntimeClass<
() override;
STDMETHOD(FlushContext)
() override;
STDMETHOD(TrimUploadHeap)
() override;
STDMETHOD(ReleaseCompletedReferences)
() override;
STDMETHOD(Sync)

View file

@ -277,16 +277,6 @@ LearningModelSession::GetResults(
// Update output providers
auto outputs = binding_impl->UpdateProviders();
// Once the first evaluation following initialization is complete, and therefore the
// initialization work is also complete, trim the upload heap. This is only done once
// to avoid requiring the extra allocation during each evaluation.
if (is_first_evaluate_) {
if (is_gpu_evaluation) {
engine_->TrimUploadHeap();
}
is_first_evaluate_ = false;
}
// Create the return status object
auto result = winrt::make<LearningModelEvaluationResult>();
auto result_impl = result.as<winmlp::LearningModelEvaluationResult>();

View file

@ -121,12 +121,6 @@ struct LearningModelSession : LearningModelSessionT<LearningModelSession> {
// Synchronization
CWinMLLock session_creation_lock_;
CWinMLLock dml_ep_lock_;
// is_first_evaluate_ is used as a heuristic to determine
// when the dml upload heap can be trimmed.
bool is_first_evaluate_ = true;
};
} // namespace winrt::Windows::AI::MachineLearning::implementation

View file

@ -99,9 +99,6 @@ IEngine : IUnknown {
STDMETHOD(FlushContext)
() PURE;
STDMETHOD(TrimUploadHeap)
() PURE;
STDMETHOD(ReleaseCompletedReferences)
() PURE;