diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp index 64ff21a1dd..e809a20cc0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/FusedGraphKernel.cpp @@ -26,7 +26,7 @@ namespace Dml std::vector initInputBindings, std::vector& isInputsUploadedByDmlEP, std::vector& inputsUsed) : - OpKernel(kernelInfo), + OpKernel(kernelInfo), m_compiledExecutionPlanOperator(compiledExecutionPlanOperator), m_inputsUsed(inputsUsed), m_outputShapes(outputShapes), @@ -40,7 +40,7 @@ namespace Dml // We assume the execution object inherits IUnknown as its first base ComPtr providerExecutionObject = const_cast(static_cast(m_executionHandle)); - // Get the WinML-specific execution provider interface from the execution object. + // Get the WinML-specific execution provider interface from the execution object. ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_provider)); ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider)); } @@ -82,10 +82,10 @@ namespace Dml m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); std::for_each( - initializeResourceRefs.begin(), - initializeResourceRefs.end(), + initializeResourceRefs.begin(), + initializeResourceRefs.end(), [&](ComPtr& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); } - ); + ); if (reuseCommandList) { @@ -97,7 +97,7 @@ namespace Dml { // Only re-use the cached command list if its prior execution is complete on the GPU. // This requirement can be avoided by mantaining ring buffers. - if (!m_graphicsCommandList || + if (!m_graphicsCommandList || (m_fence != nullptr && m_fence->GetCompletedValue() < m_completionValue)) { // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator @@ -109,7 +109,7 @@ namespace Dml ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); - // Get input resources for execution, excluding those which were specified as owned by DML and provided + // Get input resources for execution, excluding those which were specified as owned by DML and provided // at initialization instead. std::vector> inputTensors(kernelContext->InputCount()); std::vector inputPtrs(kernelContext->InputCount()); @@ -140,7 +140,7 @@ namespace Dml aux); ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); - + // Queue references to objects which must be kept alive until resulting GPU work completes m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); m_winmlProvider->QueueReference(m_persistentResourceAllocatorUnk.Get()); @@ -157,7 +157,7 @@ namespace Dml IDMLCompiledOperator* op, _In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding, gsl::span inputTensors, - gsl::span outputTensors) const + gsl::span outputTensors) const { auto FillBindingsFromTensors = [this](auto& bufferBindings, auto& bindingDescs, gsl::span& tensors) { @@ -210,7 +210,7 @@ namespace Dml FillBindingsFromTensors(outputBufferBindings, outputBindings, outputTensors); ORT_THROW_IF_FAILED(m_provider->ExecuteOperator( - op, + op, persistentResourceBinding, inputBindings, outputBindings)); @@ -228,7 +228,7 @@ namespace Dml desc.Flags = D3D12_DESCRIPTOR_HEAP_FLAG_SHADER_VISIBLE; desc.NumDescriptors = execBindingProps.RequiredDescriptorCount; desc.Type = D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV; - + ComPtr d3dDevice; ORT_THROW_IF_FAILED(m_provider->GetD3DDevice(d3dDevice.GetAddressOf())); @@ -253,7 +253,7 @@ namespace Dml m_commandAllocator.Get(), nullptr, IID_GRAPHICS_PPV_ARGS(m_graphicsCommandList.ReleaseAndGetAddressOf()))); - + if (m_persistentResource) { DML_BINDING_DESC persistentResourceBindingDesc = @@ -275,7 +275,7 @@ namespace Dml void ExecuteReusableCommandList(onnxruntime::OpKernelContext* kernelContext) const { DML_BINDING_PROPERTIES execBindingProps = m_compiledExecutionPlanOperator->GetBindingProperties(); - + std::vector inputBindings(kernelContext->InputCount()); std::vector inputBindingDescs(kernelContext->InputCount()); @@ -285,7 +285,7 @@ namespace Dml true, nullptr); - // Populate input bindings, excluding those which were specified as owned by DML and provided + // Populate input bindings, excluding those which were specified as owned by DML and provided // at initialization instead. m_inputBindingAllocIds.resize(inputBindings.size()); bool inputBindingsChanged = false; @@ -314,7 +314,7 @@ namespace Dml } } } - + if (inputBindingsChanged) { m_bindingTable->BindInputs(gsl::narrow_cast(inputBindingDescs.size()), inputBindingDescs.data()); @@ -326,7 +326,7 @@ namespace Dml m_outputBindingAllocIds.resize(outputBindings.size()); bool outputBindingsChanged = false; - + for (uint32_t i = 0; i < outputBindings.size(); ++i) { std::vector outputDims; @@ -337,7 +337,7 @@ namespace Dml } onnxruntime::Tensor* tensor = kernelContext->Output( - static_cast(i), + static_cast(i), onnxruntime::TensorShape::FromExistingBuffer(outputDims) ); @@ -357,7 +357,7 @@ namespace Dml if (execBindingProps.TemporaryResourceSize > 0) { - // Allocate temporary data which will automatically be freed when the GPU work + // Allocate temporary data which will automatically be freed when the GPU work // which is scheduled up to the point that this method returns has completed. ComPtr tempAlloc; uint64_t tempAllocId = 0; @@ -365,7 +365,7 @@ namespace Dml ComPtr tempResourceUnk; m_winmlProvider->GetABIDataInterface(false, tempAlloc.Get(), &tempResourceUnk); - + // Bind the temporary resource. ComPtr tempResource; ORT_THROW_IF_FAILED(tempResourceUnk->QueryInterface(tempResource.GetAddressOf())); @@ -376,7 +376,7 @@ namespace Dml { m_bindingTable->BindTemporaryResource(&tempBindingDesc); } - + m_tempBindingAllocId = tempAllocId; } @@ -384,7 +384,16 @@ namespace Dml // re-used. ComPtr fence; uint64_t completionValue; - ORT_THROW_IF_FAILED(m_provider->ExecuteCommandList(m_graphicsCommandList.Get(), fence.GetAddressOf(), &completionValue)); + HRESULT hr = m_provider->ExecuteCommandList(m_graphicsCommandList.Get(), fence.GetAddressOf(), &completionValue); + + if (hr == DXGI_ERROR_DEVICE_REMOVED) + { + ComPtr device; + ORT_THROW_IF_FAILED(m_provider->GetD3DDevice(&device)); + ORT_THROW_IF_FAILED(device->GetDeviceRemovedReason()); + } + + ORT_THROW_IF_FAILED(hr); m_fence = fence; m_completionValue = completionValue; @@ -410,13 +419,13 @@ namespace Dml std::optional m_persistentResourceBinding; ComPtr m_persistentResource; ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator - + // Bindings from previous executions of a re-used command list mutable std::vector m_inputBindingAllocIds; mutable std::vector m_outputBindingAllocIds; mutable uint64_t m_tempBindingAllocId = 0; - // Fence tracking the status of the command list's last execution, and whether its descriptor heap + // Fence tracking the status of the command list's last execution, and whether its descriptor heap // can safely be updated. mutable ComPtr m_fence; mutable uint64_t m_completionValue = 0; @@ -438,7 +447,7 @@ namespace Dml ) { return new FusedGraphKernel( - info, + info, compiledExecutionPlanOperator, outputShapes, reuseCommandList,