mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-30 03:37:44 +00:00
Fix LoadFromStream to not use wss::Buffer internally (#9918)
Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
This commit is contained in:
parent
8db49e3d0f
commit
5edaa75ef6
5 changed files with 514 additions and 77 deletions
|
|
@ -15,6 +15,50 @@
|
|||
#include <robuffer.h>
|
||||
|
||||
namespace WINMLP {
|
||||
|
||||
// IBuffer implementation to avoid calling into WinTypes.dll to create wss::Buffer.
|
||||
// This will enable model creation on VTL1 without pulling in additional binaries on load.
|
||||
template <typename T>
|
||||
class STLVectorBackedBuffer : public winrt::implements<
|
||||
STLVectorBackedBuffer<T>,
|
||||
wss::IBuffer,
|
||||
::Windows::Storage::Streams::IBufferByteAccess> {
|
||||
private:
|
||||
std::vector<T> data_;
|
||||
size_t length_ = 0;
|
||||
|
||||
public:
|
||||
STLVectorBackedBuffer(size_t num_elements) : data_(num_elements) {}
|
||||
|
||||
uint32_t Capacity() const try {
|
||||
// Return the size of the backing vector in bytes
|
||||
return static_cast<uint32_t>(data_.size() * sizeof(T));
|
||||
}
|
||||
WINML_CATCH_ALL
|
||||
|
||||
uint32_t Length() const try {
|
||||
// Return the used buffer in bytes
|
||||
return static_cast<uint32_t>(length_);
|
||||
}
|
||||
WINML_CATCH_ALL
|
||||
|
||||
void Length(uint32_t value) try {
|
||||
// Set the use buffer length in bytes
|
||||
WINML_THROW_HR_IF_TRUE_MSG(E_INVALIDARG, value > Capacity(), "Parameter 'value' cannot be greater than the buffer's capacity.");
|
||||
length_ = value;
|
||||
}
|
||||
WINML_CATCH_ALL
|
||||
|
||||
STDMETHOD(Buffer)
|
||||
(_Outptr_ BYTE** value) {
|
||||
// Return the buffer
|
||||
RETURN_HR_IF_NULL(E_POINTER, value);
|
||||
*value = reinterpret_cast<BYTE*>(data_.data());
|
||||
return S_OK;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
LearningModel::LearningModel(
|
||||
const hstring& path,
|
||||
const winml::ILearningModelOperatorProvider op_provider) try : operator_provider_(op_provider) {
|
||||
|
|
@ -90,12 +134,11 @@ static HRESULT CreateModelFromStream(
|
|||
_winml::IModel** model) {
|
||||
auto content = stream.OpenReadAsync().get();
|
||||
|
||||
wss::Buffer buffer(static_cast<uint32_t>(content.Size()));
|
||||
auto buffer = winrt::make<STLVectorBackedBuffer<BYTE>>(static_cast<size_t>(content.Size()));
|
||||
auto result = content.ReadAsync(
|
||||
buffer,
|
||||
buffer.Capacity(),
|
||||
wss::InputStreamOptions::None)
|
||||
.get();
|
||||
wss::InputStreamOptions::None).get();
|
||||
|
||||
auto bytes = buffer.try_as<::Windows::Storage::Streams::IBufferByteAccess>();
|
||||
WINML_THROW_HR_IF_NULL_MSG(E_UNEXPECTED, bytes, "Model stream is invalid.");
|
||||
|
|
|
|||
|
|
@ -0,0 +1,363 @@
|
|||
// Copyright 2019 Microsoft Corporation. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style license that can be
|
||||
// found in the LICENSE file.
|
||||
#ifndef RANDOM_ACCESS_STREAM_H
|
||||
#define RANDOM_ACCESS_STREAM_H
|
||||
|
||||
#include <wrl.h>
|
||||
#include <wrl/client.h>
|
||||
|
||||
#include <windows.storage.streams.h>
|
||||
#include <robuffer.h>
|
||||
|
||||
#include <istream>
|
||||
|
||||
namespace WinMLTest {
|
||||
|
||||
struct BufferBackedRandomAccessStreamReadAsync
|
||||
: public Microsoft::WRL::RuntimeClass<
|
||||
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::WinRtClassicComMix | Microsoft::WRL::InhibitRoOriginateError>,
|
||||
__FIAsyncOperationWithProgress_2_Windows__CStorage__CStreams__CIBuffer_UINT32,
|
||||
ABI::Windows::Foundation::IAsyncInfo> {
|
||||
|
||||
InspectableClass(L"WinMLTest.BufferBackedRandomAccessStreamReadAsync", BaseTrust)
|
||||
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer> buffer_;
|
||||
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Foundation::IAsyncOperationWithProgressCompletedHandler<ABI::Windows::Storage::Streams::IBuffer*, UINT32>> completed_handler_;
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Foundation::IAsyncOperationProgressHandler<ABI::Windows::Storage::Streams::IBuffer*, UINT32>> progress_handler_;
|
||||
|
||||
AsyncStatus status_ = AsyncStatus::Started;
|
||||
|
||||
public:
|
||||
|
||||
virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_Id(
|
||||
/* [retval][out] */ __RPC__out unsigned __int32* id) override {
|
||||
*id = 0; // Do we need to implement this?
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_Status(
|
||||
/* [retval][out] */ __RPC__out AsyncStatus* status) override {
|
||||
*status = status_;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_ErrorCode(
|
||||
/* [retval][out] */ __RPC__out HRESULT* /*errorCode*/) override {
|
||||
return E_NOTIMPL;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE Cancel(void) override {
|
||||
return E_NOTIMPL;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE Close(void) override {
|
||||
return E_NOTIMPL;
|
||||
}
|
||||
|
||||
|
||||
|
||||
HRESULT SetBuffer(ABI::Windows::Storage::Streams::IBuffer* buffer) {
|
||||
buffer_ = buffer;
|
||||
status_ = AsyncStatus::Completed;
|
||||
if (buffer_ != nullptr) {
|
||||
if (completed_handler_ != nullptr) {
|
||||
completed_handler_->Invoke(this, ABI::Windows::Foundation::AsyncStatus::Completed);
|
||||
}
|
||||
}
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE put_Progress(
|
||||
ABI::Windows::Foundation::IAsyncOperationProgressHandler<ABI::Windows::Storage::Streams::IBuffer*, UINT32>* handler) override {
|
||||
progress_handler_ = handler;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE get_Progress(ABI::Windows::Foundation::IAsyncOperationProgressHandler<ABI::Windows::Storage::Streams::IBuffer*, UINT32>** handler)override {
|
||||
progress_handler_.CopyTo(handler);
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE put_Completed(ABI::Windows::Foundation::IAsyncOperationWithProgressCompletedHandler<ABI::Windows::Storage::Streams::IBuffer*, UINT32>* handler) override {
|
||||
completed_handler_ = handler;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE get_Completed(ABI::Windows::Foundation::IAsyncOperationWithProgressCompletedHandler<ABI::Windows::Storage::Streams::IBuffer*, UINT32>** handler) override {
|
||||
completed_handler_.CopyTo(handler);
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE GetResults(ABI::Windows::Storage::Streams::IBuffer** results) override {
|
||||
if (buffer_ == nullptr) {
|
||||
return E_FAIL;
|
||||
}
|
||||
|
||||
buffer_.CopyTo(results);
|
||||
return S_OK;
|
||||
}
|
||||
};
|
||||
|
||||
struct RandomAccessStream
|
||||
: public Microsoft::WRL::RuntimeClass<
|
||||
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::WinRtClassicComMix | Microsoft::WRL::InhibitRoOriginateError>,
|
||||
ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType,
|
||||
ABI::Windows::Storage::Streams::IContentTypeProvider,
|
||||
ABI::Windows::Storage::Streams::IRandomAccessStream,
|
||||
ABI::Windows::Storage::Streams::IInputStream,
|
||||
ABI::Windows::Storage::Streams::IOutputStream,
|
||||
ABI::Windows::Foundation::IClosable> {
|
||||
InspectableClass(L"WinMLTest.RandomAccessStream", BaseTrust)
|
||||
|
||||
private:
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer> buffer_ = nullptr;
|
||||
UINT64 position_ = 0;
|
||||
|
||||
public:
|
||||
HRESULT RuntimeClassInitialize(ABI::Windows::Storage::Streams::IBuffer* buffer) {
|
||||
buffer_ = buffer;
|
||||
position_ = 0;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
HRESULT RuntimeClassInitialize(ABI::Windows::Storage::Streams::IBuffer* buffer, UINT64 position) {
|
||||
buffer_ = buffer;
|
||||
position_ = position;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
// Content Provider
|
||||
|
||||
/* [propget] */virtual HRESULT STDMETHODCALLTYPE get_ContentType(
|
||||
/* [retval, out] */__RPC__deref_out_opt HSTRING* value
|
||||
) override {
|
||||
return WindowsCreateString(nullptr, 0, value);
|
||||
}
|
||||
|
||||
// IRandomAccessStream
|
||||
|
||||
/* [propget] */virtual HRESULT STDMETHODCALLTYPE get_Size(
|
||||
/* [retval, out] */__RPC__out UINT64* value
|
||||
) override {
|
||||
*value = 0;
|
||||
uint32_t length;
|
||||
buffer_->get_Length(&length);
|
||||
*value = static_cast<uint64_t>(length);
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
/* [propput] */virtual HRESULT STDMETHODCALLTYPE put_Size(
|
||||
/* [in] */UINT64 /*value*/
|
||||
) override {
|
||||
return E_NOTIMPL;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE GetInputStreamAt(
|
||||
/* [in] */UINT64 position,
|
||||
/* [retval, out] */__RPC__deref_out_opt ABI::Windows::Storage::Streams::IInputStream** stream
|
||||
) override {
|
||||
return Microsoft::WRL::MakeAndInitialize<RandomAccessStream>(stream, buffer_.Get(), position);
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE GetOutputStreamAt(
|
||||
/* [in] */UINT64 /*position*/,
|
||||
/* [retval, out] */__RPC__deref_out_opt ABI::Windows::Storage::Streams::IOutputStream** /*stream*/
|
||||
) override {
|
||||
return E_NOTIMPL;
|
||||
}
|
||||
|
||||
/* [propget] */virtual HRESULT STDMETHODCALLTYPE get_Position(
|
||||
/* [retval, out] */__RPC__out UINT64* value
|
||||
) override {
|
||||
*value = position_;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE Seek(
|
||||
/* [in] */UINT64 position
|
||||
) override {
|
||||
position_ = position;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE CloneStream(
|
||||
/* [retval, out] */__RPC__deref_out_opt ABI::Windows::Storage::Streams::IRandomAccessStream** stream
|
||||
) override {
|
||||
return Microsoft::WRL::MakeAndInitialize<RandomAccessStream>(stream, buffer_.Get(), 0);
|
||||
}
|
||||
|
||||
/* [propget] */virtual HRESULT STDMETHODCALLTYPE get_CanRead(
|
||||
/* [retval, out] */__RPC__out::boolean* value
|
||||
) override {
|
||||
UINT32 length;
|
||||
buffer_->get_Length(&length);
|
||||
*value = buffer_ != nullptr && position_ < static_cast<UINT64>(length);
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
/* [propget] */virtual HRESULT STDMETHODCALLTYPE get_CanWrite(
|
||||
/* [retval, out] */__RPC__out::boolean* value
|
||||
) override {
|
||||
*value = false;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
// IInputStream
|
||||
virtual HRESULT STDMETHODCALLTYPE ReadAsync(
|
||||
/* [in] */__RPC__in_opt ABI::Windows::Storage::Streams::IBuffer* buffer,
|
||||
/* [in] */UINT32 count,
|
||||
/* [in] */ABI::Windows::Storage::Streams::InputStreamOptions /*options*/,
|
||||
/* [retval, out] */__RPC__deref_out_opt __FIAsyncOperationWithProgress_2_Windows__CStorage__CStreams__CIBuffer_UINT32** operation
|
||||
) override {
|
||||
auto read_async = Microsoft::WRL::Make<BufferBackedRandomAccessStreamReadAsync>();
|
||||
read_async.CopyTo(operation);
|
||||
|
||||
// perform the "async work" which is actually synchronous atm
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer> spBuffer = buffer;
|
||||
Microsoft::WRL::ComPtr<Windows::Storage::Streams::IBufferByteAccess> out_buffer_byte_access;
|
||||
spBuffer.As<Windows::Storage::Streams::IBufferByteAccess>(&out_buffer_byte_access);
|
||||
byte* out_bytes = nullptr;
|
||||
out_buffer_byte_access->Buffer(&out_bytes);
|
||||
|
||||
Microsoft::WRL::ComPtr<Windows::Storage::Streams::IBufferByteAccess> in_buffer_byte_access;
|
||||
buffer_.As<Windows::Storage::Streams::IBufferByteAccess>(&in_buffer_byte_access);
|
||||
byte* in_bytes = nullptr;
|
||||
in_buffer_byte_access->Buffer(&in_bytes);
|
||||
|
||||
memcpy(out_bytes, in_bytes + static_cast<uint32_t>(position_), count);
|
||||
|
||||
read_async->SetBuffer(buffer);
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
// IOutputStream
|
||||
virtual HRESULT STDMETHODCALLTYPE WriteAsync(
|
||||
/* [in] */__RPC__in_opt ABI::Windows::Storage::Streams::IBuffer* /*buffer*/,
|
||||
/* [retval, out] */__RPC__deref_out_opt __FIAsyncOperationWithProgress_2_UINT32_UINT32** /*operation*/
|
||||
) override {
|
||||
return E_NOTIMPL;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE FlushAsync(
|
||||
/* [retval, out] */__RPC__deref_out_opt __FIAsyncOperation_1_boolean** /*operation*/
|
||||
) override {
|
||||
return E_NOTIMPL;
|
||||
}
|
||||
|
||||
// IClosable
|
||||
virtual HRESULT STDMETHODCALLTYPE Close(void) override {
|
||||
buffer_ = nullptr;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct BufferBackedRandomAccessStreamReferenceOpenReadAsync
|
||||
: public Microsoft::WRL::RuntimeClass<
|
||||
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::WinRtClassicComMix | Microsoft::WRL::InhibitRoOriginateError>,
|
||||
__FIAsyncOperation_1_Windows__CStorage__CStreams__CIRandomAccessStreamWithContentType,
|
||||
ABI::Windows::Foundation::IAsyncInfo> {
|
||||
|
||||
InspectableClass(L"WinMLTest.BufferBackedRandomAccessStreamReferenceOpenReadAsync", BaseTrust)
|
||||
public:
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType> ras_;
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Foundation::IAsyncOperationCompletedHandler<ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType*>> completed_handler_;
|
||||
AsyncStatus status_ = AsyncStatus::Started;
|
||||
|
||||
HRESULT SetRandomAccessStream(ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType* ras) {
|
||||
ras_ = ras;
|
||||
status_ = AsyncStatus::Completed;
|
||||
if (ras_ != nullptr) {
|
||||
if (completed_handler_ != nullptr) {
|
||||
completed_handler_->Invoke(this, status_);
|
||||
}
|
||||
}
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_Id(
|
||||
/* [retval][out] */ __RPC__out unsigned __int32* id) override {
|
||||
*id = 0; // Do we need to implement this?
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_Status(
|
||||
/* [retval][out] */ __RPC__out AsyncStatus* status) override {
|
||||
*status = status_;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_ErrorCode(
|
||||
/* [retval][out] */ __RPC__out HRESULT* /*errorCode*/) override {
|
||||
return E_NOTIMPL;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE Cancel(void) override {
|
||||
return E_NOTIMPL;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE Close(void) override {
|
||||
return E_NOTIMPL;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE put_Completed(
|
||||
ABI::Windows::Foundation::IAsyncOperationCompletedHandler<ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType*>* handler) override
|
||||
{
|
||||
completed_handler_ = handler;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE get_Completed(
|
||||
ABI::Windows::Foundation::IAsyncOperationCompletedHandler<ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType*>** handler) override {
|
||||
completed_handler_.CopyTo(handler);
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE GetResults(
|
||||
ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType** results) override {
|
||||
if (ras_ == nullptr) {
|
||||
return E_FAIL;
|
||||
}
|
||||
ras_.CopyTo(results);
|
||||
return S_OK;
|
||||
}
|
||||
};
|
||||
|
||||
struct BufferBackedRandomAccessStreamReference
|
||||
: public Microsoft::WRL::RuntimeClass<
|
||||
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::WinRtClassicComMix | Microsoft::WRL::InhibitRoOriginateError>,
|
||||
ABI::Windows::Storage::Streams::IRandomAccessStreamReference> {
|
||||
InspectableClass(L"WinMLTest.BufferBackedRandomAccessStreamReference", BaseTrust)
|
||||
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer> buffer_ = nullptr;
|
||||
|
||||
public:
|
||||
HRESULT RuntimeClassInitialize(ABI::Windows::Storage::Streams::IBuffer* buffer) {
|
||||
buffer_ = buffer;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE OpenReadAsync(
|
||||
/* [retval, out] */__RPC__deref_out_opt __FIAsyncOperation_1_Windows__CStorage__CStreams__CIRandomAccessStreamWithContentType** operation
|
||||
) override {
|
||||
auto open_read_async = Microsoft::WRL::Make<BufferBackedRandomAccessStreamReferenceOpenReadAsync>();
|
||||
open_read_async.CopyTo(operation);
|
||||
|
||||
Microsoft::WRL::ComPtr<RandomAccessStream> ras;
|
||||
Microsoft::WRL::MakeAndInitialize<RandomAccessStream>(&ras, buffer_.Get());
|
||||
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType> ras_interface = nullptr;
|
||||
ras.As<ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType>(&ras_interface);
|
||||
|
||||
open_read_async.Get()->SetRandomAccessStream(ras_interface.Get());
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace WinMLTest
|
||||
|
||||
#endif // RANDOM_ACCESS_STREAM_H
|
||||
|
|
@ -10,14 +10,16 @@
|
|||
#include <windows.storage.streams.h>
|
||||
#include <robuffer.h>
|
||||
|
||||
namespace Microsoft { namespace AI { namespace MachineLearning { namespace Details {
|
||||
namespace WinMLTest {
|
||||
|
||||
template <typename T>
|
||||
struct weak_buffer
|
||||
struct WeakBuffer
|
||||
: public Microsoft::WRL::RuntimeClass<
|
||||
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::WinRtClassicComMix | Microsoft::WRL::InhibitRoOriginateError>,
|
||||
ABI::Windows::Storage::Streams::IBuffer,
|
||||
Windows::Storage::Streams::IBufferByteAccess> {
|
||||
InspectableClass(L"WinMLTest.WeakBuffer", BaseTrust)
|
||||
|
||||
private:
|
||||
const T* m_p_begin;
|
||||
const T* m_p_end;
|
||||
|
|
@ -42,9 +44,13 @@ public:
|
|||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE get_Length(
|
||||
UINT32 * /*value*/)
|
||||
UINT32 * value)
|
||||
{
|
||||
return E_NOTIMPL;
|
||||
if (value == nullptr) {
|
||||
return E_POINTER;
|
||||
}
|
||||
*value = static_cast<uint32_t>(m_p_end - m_p_begin) * sizeof(T);
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
virtual HRESULT STDMETHODCALLTYPE put_Length(
|
||||
|
|
@ -64,6 +70,6 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
}}}} // namespace Microsoft::AI::MachineLearning::Details
|
||||
} // namespace WinMLTest
|
||||
|
||||
#endif // WEAK_BUFFER_H
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#define WINML_H_
|
||||
|
||||
#include "weak_buffer.h"
|
||||
#include "buffer_backed_random_access_stream_reference.h"
|
||||
#include "weak_single_threaded_iterable.h"
|
||||
|
||||
#define RETURN_HR_IF_FAILED(expression) \
|
||||
|
|
@ -213,7 +214,7 @@ public:
|
|||
|
||||
WinMLLearningModel(const char* bytes, size_t size)
|
||||
{
|
||||
ML_FAIL_FAST_IF(0 != Initialize(bytes, size));
|
||||
ML_FAIL_FAST_IF(0 != Initialize(bytes, size, false /*dont copy*/));
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -264,7 +265,7 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
int32_t Initialize(const char* bytes, size_t size)
|
||||
int32_t Initialize(const char* bytes, size_t size, bool with_copy = false)
|
||||
{
|
||||
auto hr = RoInitialize(RO_INIT_TYPE::RO_INIT_SINGLETHREADED);
|
||||
// https://docs.microsoft.com/en-us/windows/win32/api/roapi/nf-roapi-roinitialize#return-value
|
||||
|
|
@ -273,15 +274,17 @@ private:
|
|||
return static_cast<int32_t>(hr);
|
||||
}
|
||||
|
||||
// Create in memory stream
|
||||
Microsoft::WRL::ComPtr<IInspectable> in_memory_random_access_stream_insp;
|
||||
RETURN_HR_IF_FAILED(RoActivateInstance(
|
||||
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(),
|
||||
in_memory_random_access_stream_insp.GetAddressOf()));
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamReference> random_access_stream_ref;
|
||||
if (with_copy) {
|
||||
// Create in memory stream
|
||||
Microsoft::WRL::ComPtr<IInspectable> in_memory_random_access_stream_insp;
|
||||
RETURN_HR_IF_FAILED(RoActivateInstance(
|
||||
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(),
|
||||
in_memory_random_access_stream_insp.GetAddressOf()));
|
||||
|
||||
// QI memory stream to output stream
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IOutputStream> output_stream;
|
||||
RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream));
|
||||
// QI memory stream to output stream
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IOutputStream> output_stream;
|
||||
RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream));
|
||||
|
||||
// Create data writer factory
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IDataWriterFactory> activation_factory;
|
||||
|
|
@ -297,7 +300,7 @@ private:
|
|||
|
||||
// Write the model to the data writer and thus to the stream
|
||||
RETURN_HR_IF_FAILED(
|
||||
data_writer->WriteBytes(static_cast<uint32_t>(size), reinterpret_cast<BYTE*>(const_cast<char *>(bytes))));
|
||||
data_writer->WriteBytes(static_cast<uint32_t>(size), reinterpret_cast<BYTE*>(const_cast<char*>(bytes))));
|
||||
|
||||
// QI the in memory stream to a random access stream
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStream> random_access_stream;
|
||||
|
|
@ -311,25 +314,35 @@ private:
|
|||
|
||||
// Create a random access stream reference from the random access stream view on top of
|
||||
// the in memory stream
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamReference> random_access_stream_ref;
|
||||
RETURN_HR_IF_FAILED(random_access_stream_ref_statics->CreateFromStream(
|
||||
random_access_stream.Get(),
|
||||
random_access_stream_ref.GetAddressOf()));
|
||||
|
||||
// Create a learning model factory
|
||||
Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelStatics> learning_model;
|
||||
RETURN_HR_IF_FAILED(
|
||||
GetActivationFactory(
|
||||
RuntimeClass_Microsoft_AI_MachineLearning_LearningModel,
|
||||
ABI::Microsoft::AI::MachineLearning::IID_ILearningModelStatics,
|
||||
&learning_model));
|
||||
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Foundation::IAsyncOperation<uint32_t>> async_operation;
|
||||
RETURN_HR_IF_FAILED(data_writer->StoreAsync(&async_operation));
|
||||
auto store_completed_handler = Microsoft::WRL::Make<StoreCompleted>();
|
||||
RETURN_HR_IF_FAILED(async_operation->put_Completed(store_completed_handler.Get()));
|
||||
RETURN_HR_IF_FAILED(store_completed_handler->Wait());
|
||||
|
||||
} else {
|
||||
Microsoft::WRL::ComPtr<WinMLTest::WeakBuffer<char>> buffer;
|
||||
RETURN_HR_IF_FAILED(
|
||||
Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<char>>(
|
||||
&buffer, bytes, bytes + size));
|
||||
|
||||
RETURN_HR_IF_FAILED(
|
||||
Microsoft::WRL::MakeAndInitialize<WinMLTest::BufferBackedRandomAccessStreamReference>(
|
||||
&random_access_stream_ref, buffer.Get()));
|
||||
}
|
||||
|
||||
// Create a learning model factory
|
||||
Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelStatics> learning_model;
|
||||
RETURN_HR_IF_FAILED(
|
||||
GetActivationFactory(
|
||||
RuntimeClass_Microsoft_AI_MachineLearning_LearningModel,
|
||||
ABI::Microsoft::AI::MachineLearning::IID_ILearningModelStatics,
|
||||
&learning_model));
|
||||
|
||||
// Create a learning model from the factory with the random access stream reference that points
|
||||
// to the random access stream view on top of the in memory stream copy of the model
|
||||
RETURN_HR_IF_FAILED(
|
||||
|
|
@ -459,9 +472,9 @@ public:
|
|||
TensorFactory2IID<T>::IID,
|
||||
&tensor_factory));
|
||||
|
||||
Microsoft::WRL::ComPtr<weak_buffer<T>> buffer;
|
||||
Microsoft::WRL::ComPtr<WinMLTest::WeakBuffer<T>> buffer;
|
||||
RETURN_HR_IF_FAILED(
|
||||
Microsoft::WRL::MakeAndInitialize<weak_buffer<T>>(
|
||||
Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<T>>(
|
||||
&buffer, p_data, p_data + data_size));
|
||||
|
||||
Microsoft::WRL::ComPtr<ITensor> tensor;
|
||||
|
|
@ -493,7 +506,7 @@ public:
|
|||
std::vector<Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer>> vec_buffers(num_buffers);
|
||||
for (size_t i = 0; i < num_buffers; i++) {
|
||||
RETURN_HR_IF_FAILED(
|
||||
Microsoft::WRL::MakeAndInitialize<weak_buffer<T>>(
|
||||
Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<T>>(
|
||||
&vec_buffers.at(i), p_data[i], p_data[i] + data_sizes[i]));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#define WINML_H_
|
||||
|
||||
#include "weak_buffer.h"
|
||||
#include "buffer_backed_random_access_stream_reference.h"
|
||||
#include "weak_single_threaded_iterable.h"
|
||||
|
||||
#define RETURN_HR_IF_FAILED(expression) \
|
||||
|
|
@ -209,7 +210,7 @@ public:
|
|||
|
||||
WinMLLearningModel(const char* bytes, size_t size)
|
||||
{
|
||||
ML_FAIL_FAST_IF(0 != Initialize(bytes, size));
|
||||
ML_FAIL_FAST_IF(0 != Initialize(bytes, size, false /*with_copy*/));
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
@ -260,52 +261,63 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
int32_t Initialize(const char* bytes, size_t size)
|
||||
int32_t Initialize(const char* bytes, size_t size, bool with_copy = false)
|
||||
{
|
||||
RoInitialize(RO_INIT_TYPE::RO_INIT_SINGLETHREADED);
|
||||
|
||||
// Create in memory stream
|
||||
Microsoft::WRL::ComPtr<IInspectable> in_memory_random_access_stream_insp;
|
||||
RETURN_HR_IF_FAILED(RoActivateInstance(
|
||||
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(),
|
||||
in_memory_random_access_stream_insp.GetAddressOf()));
|
||||
|
||||
// QI memory stream to output stream
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IOutputStream> output_stream;
|
||||
RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream));
|
||||
|
||||
// Create data writer factory
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IDataWriterFactory> activation_factory;
|
||||
RETURN_HR_IF_FAILED(RoGetActivationFactory(
|
||||
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_DataWriter).Get(),
|
||||
IID_PPV_ARGS(activation_factory.GetAddressOf())));
|
||||
|
||||
// Create data writer object based on the in memory stream
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IDataWriter> data_writer;
|
||||
RETURN_HR_IF_FAILED(activation_factory->CreateDataWriter(
|
||||
output_stream.Get(),
|
||||
data_writer.GetAddressOf()));
|
||||
|
||||
// Write the model to the data writer and thus to the stream
|
||||
RETURN_HR_IF_FAILED(
|
||||
data_writer->WriteBytes(static_cast<uint32_t>(size), reinterpret_cast<BYTE*>(const_cast<char *>(bytes))));
|
||||
|
||||
// QI the in memory stream to a random access stream
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStream> random_access_stream;
|
||||
RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&random_access_stream));
|
||||
|
||||
// Create a random access stream reference factory
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamReferenceStatics> random_access_stream_ref_statics;
|
||||
RETURN_HR_IF_FAILED(RoGetActivationFactory(
|
||||
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_RandomAccessStreamReference).Get(),
|
||||
IID_PPV_ARGS(random_access_stream_ref_statics.GetAddressOf())));
|
||||
|
||||
// Create a random access stream reference from the random access stream view on top of
|
||||
// the in memory stream
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamReference> random_access_stream_ref;
|
||||
RETURN_HR_IF_FAILED(random_access_stream_ref_statics->CreateFromStream(
|
||||
random_access_stream.Get(),
|
||||
random_access_stream_ref.GetAddressOf()));
|
||||
if (with_copy) {
|
||||
// Create in memory stream
|
||||
Microsoft::WRL::ComPtr<IInspectable> in_memory_random_access_stream_insp;
|
||||
RETURN_HR_IF_FAILED(RoActivateInstance(
|
||||
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(),
|
||||
in_memory_random_access_stream_insp.GetAddressOf()));
|
||||
|
||||
// QI memory stream to output stream
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IOutputStream> output_stream;
|
||||
RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream));
|
||||
|
||||
// Create data writer factory
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IDataWriterFactory> activation_factory;
|
||||
RETURN_HR_IF_FAILED(RoGetActivationFactory(
|
||||
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_DataWriter).Get(),
|
||||
IID_PPV_ARGS(activation_factory.GetAddressOf())));
|
||||
|
||||
// Create data writer object based on the in memory stream
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IDataWriter> data_writer;
|
||||
RETURN_HR_IF_FAILED(activation_factory->CreateDataWriter(
|
||||
output_stream.Get(),
|
||||
data_writer.GetAddressOf()));
|
||||
|
||||
// Write the model to the data writer and thus to the stream
|
||||
RETURN_HR_IF_FAILED(
|
||||
data_writer->WriteBytes(static_cast<uint32_t>(size), reinterpret_cast<BYTE*>(const_cast<char*>(bytes))));
|
||||
|
||||
// QI the in memory stream to a random access stream
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStream> random_access_stream;
|
||||
RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&random_access_stream));
|
||||
|
||||
// Create a random access stream reference factory
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamReferenceStatics> random_access_stream_ref_statics;
|
||||
RETURN_HR_IF_FAILED(RoGetActivationFactory(
|
||||
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_RandomAccessStreamReference).Get(),
|
||||
IID_PPV_ARGS(random_access_stream_ref_statics.GetAddressOf())));
|
||||
|
||||
// Create a random access stream reference from the random access stream view on top of
|
||||
// the in memory stream
|
||||
RETURN_HR_IF_FAILED(random_access_stream_ref_statics->CreateFromStream(
|
||||
random_access_stream.Get(),
|
||||
random_access_stream_ref.GetAddressOf()));
|
||||
} else {
|
||||
Microsoft::WRL::ComPtr<WinMLTest::WeakBuffer<BYTE>> buffer;
|
||||
RETURN_HR_IF_FAILED(
|
||||
Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<BYTE>>(
|
||||
&buffer, bytes, bytes + size));
|
||||
|
||||
RETURN_HR_IF_FAILED(
|
||||
Microsoft::WRL::MakeAndInitialize<WinMLTest::BufferBackedRandomAccessStreamReference>(
|
||||
&random_access_stream_ref, buffer.Get()));
|
||||
}
|
||||
|
||||
// Create a learning model factory
|
||||
Microsoft::WRL::ComPtr<ABI::Windows::AI::MachineLearning::ILearningModelStatics> learning_model;
|
||||
|
|
@ -450,9 +462,9 @@ public:
|
|||
TensorFactory2IID<T>::IID,
|
||||
&tensor_factory));
|
||||
|
||||
Microsoft::WRL::ComPtr<weak_buffer<T>> buffer;
|
||||
Microsoft::WRL::ComPtr<WinMLTest::WeakBuffer<T>> buffer;
|
||||
RETURN_HR_IF_FAILED(
|
||||
Microsoft::WRL::MakeAndInitialize<weak_buffer<T>>(
|
||||
Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<T>>(
|
||||
&buffer, p_data, p_data + data_size));
|
||||
|
||||
Microsoft::WRL::ComPtr<ITensor> tensor;
|
||||
|
|
@ -484,7 +496,7 @@ public:
|
|||
std::vector<Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer>> vec_buffers(num_buffers);
|
||||
for (size_t i = 0; i < num_buffers; i++) {
|
||||
RETURN_HR_IF_FAILED(
|
||||
Microsoft::WRL::MakeAndInitialize<weak_buffer<T>>(
|
||||
Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<T>>(
|
||||
&vec_buffers.at(i), p_data[i], p_data[i] + data_sizes[i]));
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue