Fix com ptr refcount (#5404)

This commit is contained in:
Tiago Koji Castro Shibata 2020-10-08 10:18:38 -07:00 committed by GitHub
parent b04cf2d229
commit 83ead3e2eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 5 deletions

View file

@ -146,6 +146,7 @@ ORT_API_STATUS_IMPL(winmla::DmlGetD3D12ResourceFromAllocation, _In_ OrtExecution
Dml::GetD3D12ResourceFromAllocation(
dml_provider_internal->GetAllocator(0, ::OrtMemType::OrtMemTypeDefault).get(),
allocation);
(*d3d_resource)->AddRef();
#endif // USE_DML USE_DML
return nullptr;
API_IMPL_END
@ -173,4 +174,4 @@ ORT_API_STATUS_IMPL(winmla::DmlCopyTensor, _In_ OrtExecutionProvider* dml_provid
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Out of memory");
#endif // USE_DML USE_DML
API_IMPL_END
}
}

View file

@ -179,11 +179,11 @@ HRESULT OnnxruntimeValue::GetResource(_winml::Resource& out) {
bool is_cpu = false;
if (SUCCEEDED(IsCpu(&is_cpu)) && !is_cpu) {
void* resource;
winrt::com_ptr<ID3D12Resource> resource;
RETURN_HR_IF_NOT_OK_MSG(winml_adapter_api->DmlGetD3D12ResourceFromAllocation(ort_provider, mutable_data,
reinterpret_cast<ID3D12Resource**>(&resource)),
resource.put()),
ort_api);
out = _winml::Resource(resource, [](void*) { /*do nothing, as this pointer is actually a com pointer! */ });
out = _winml::Resource(resource.get(), [](void*) { /*do nothing, as this pointer is actually a com pointer! */ });
} else {
int is_tensor;
RETURN_HR_IF_NOT_OK_MSG(ort_api->IsTensor(value_.get(), &is_tensor),
@ -1360,4 +1360,4 @@ STDAPI CreateOnnxruntimeEngineFactory(_Out_ _winml::IEngineFactory** engine_fact
RETURN_IF_FAILED(Microsoft::WRL::MakeAndInitialize<OnnxruntimeEngineFactory>(&onnxruntime_engine_factory));
RETURN_IF_FAILED(onnxruntime_engine_factory.CopyTo(engine_factory));
return S_OK;
}
}

View file

@ -135,6 +135,7 @@ winrt::com_ptr<ID3D12Resource> CreateD3D12Resource(ID3D12Device& device) {
return d3d12_resource;
}
void DmlCreateAndFreeGPUAllocationFromD3DResource() {
GPUTEST;
winrt::com_ptr<ID3D12Device> device;