Fix LoadFromStream to not use wss::Buffer internally (#9918)

Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
This commit is contained in:
Sheil Kumar 2021-12-02 21:29:06 -08:00 committed by GitHub
parent 8db49e3d0f
commit 5edaa75ef6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 514 additions and 77 deletions

View file

@ -15,6 +15,50 @@
#include <robuffer.h>
namespace WINMLP {
// IBuffer implementation to avoid calling into WinTypes.dll to create wss::Buffer.
// This will enable model creation on VTL1 without pulling in additional binaries on load.
template <typename T>
class STLVectorBackedBuffer : public winrt::implements<
STLVectorBackedBuffer<T>,
wss::IBuffer,
::Windows::Storage::Streams::IBufferByteAccess> {
private:
std::vector<T> data_;
size_t length_ = 0;
public:
STLVectorBackedBuffer(size_t num_elements) : data_(num_elements) {}
uint32_t Capacity() const try {
// Return the size of the backing vector in bytes
return static_cast<uint32_t>(data_.size() * sizeof(T));
}
WINML_CATCH_ALL
uint32_t Length() const try {
// Return the used buffer in bytes
return static_cast<uint32_t>(length_);
}
WINML_CATCH_ALL
void Length(uint32_t value) try {
// Set the use buffer length in bytes
WINML_THROW_HR_IF_TRUE_MSG(E_INVALIDARG, value > Capacity(), "Parameter 'value' cannot be greater than the buffer's capacity.");
length_ = value;
}
WINML_CATCH_ALL
STDMETHOD(Buffer)
(_Outptr_ BYTE** value) {
// Return the buffer
RETURN_HR_IF_NULL(E_POINTER, value);
*value = reinterpret_cast<BYTE*>(data_.data());
return S_OK;
}
};
LearningModel::LearningModel(
const hstring& path,
const winml::ILearningModelOperatorProvider op_provider) try : operator_provider_(op_provider) {
@ -90,12 +134,11 @@ static HRESULT CreateModelFromStream(
_winml::IModel** model) {
auto content = stream.OpenReadAsync().get();
wss::Buffer buffer(static_cast<uint32_t>(content.Size()));
auto buffer = winrt::make<STLVectorBackedBuffer<BYTE>>(static_cast<size_t>(content.Size()));
auto result = content.ReadAsync(
buffer,
buffer.Capacity(),
wss::InputStreamOptions::None)
.get();
wss::InputStreamOptions::None).get();
auto bytes = buffer.try_as<::Windows::Storage::Streams::IBufferByteAccess>();
WINML_THROW_HR_IF_NULL_MSG(E_UNEXPECTED, bytes, "Model stream is invalid.");

View file

@ -0,0 +1,363 @@
// Copyright 2019 Microsoft Corporation. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef RANDOM_ACCESS_STREAM_H
#define RANDOM_ACCESS_STREAM_H
#include <wrl.h>
#include <wrl/client.h>
#include <windows.storage.streams.h>
#include <robuffer.h>
#include <istream>
namespace WinMLTest {
struct BufferBackedRandomAccessStreamReadAsync
: public Microsoft::WRL::RuntimeClass<
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::WinRtClassicComMix | Microsoft::WRL::InhibitRoOriginateError>,
__FIAsyncOperationWithProgress_2_Windows__CStorage__CStreams__CIBuffer_UINT32,
ABI::Windows::Foundation::IAsyncInfo> {
InspectableClass(L"WinMLTest.BufferBackedRandomAccessStreamReadAsync", BaseTrust)
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer> buffer_;
Microsoft::WRL::ComPtr<ABI::Windows::Foundation::IAsyncOperationWithProgressCompletedHandler<ABI::Windows::Storage::Streams::IBuffer*, UINT32>> completed_handler_;
Microsoft::WRL::ComPtr<ABI::Windows::Foundation::IAsyncOperationProgressHandler<ABI::Windows::Storage::Streams::IBuffer*, UINT32>> progress_handler_;
AsyncStatus status_ = AsyncStatus::Started;
public:
virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_Id(
/* [retval][out] */ __RPC__out unsigned __int32* id) override {
*id = 0; // Do we need to implement this?
return S_OK;
}
virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_Status(
/* [retval][out] */ __RPC__out AsyncStatus* status) override {
*status = status_;
return S_OK;
}
virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_ErrorCode(
/* [retval][out] */ __RPC__out HRESULT* /*errorCode*/) override {
return E_NOTIMPL;
}
virtual HRESULT STDMETHODCALLTYPE Cancel(void) override {
return E_NOTIMPL;
}
virtual HRESULT STDMETHODCALLTYPE Close(void) override {
return E_NOTIMPL;
}
HRESULT SetBuffer(ABI::Windows::Storage::Streams::IBuffer* buffer) {
buffer_ = buffer;
status_ = AsyncStatus::Completed;
if (buffer_ != nullptr) {
if (completed_handler_ != nullptr) {
completed_handler_->Invoke(this, ABI::Windows::Foundation::AsyncStatus::Completed);
}
}
return S_OK;
}
virtual HRESULT STDMETHODCALLTYPE put_Progress(
ABI::Windows::Foundation::IAsyncOperationProgressHandler<ABI::Windows::Storage::Streams::IBuffer*, UINT32>* handler) override {
progress_handler_ = handler;
return S_OK;
}
virtual HRESULT STDMETHODCALLTYPE get_Progress(ABI::Windows::Foundation::IAsyncOperationProgressHandler<ABI::Windows::Storage::Streams::IBuffer*, UINT32>** handler)override {
progress_handler_.CopyTo(handler);
return S_OK;
}
virtual HRESULT STDMETHODCALLTYPE put_Completed(ABI::Windows::Foundation::IAsyncOperationWithProgressCompletedHandler<ABI::Windows::Storage::Streams::IBuffer*, UINT32>* handler) override {
completed_handler_ = handler;
return S_OK;
}
virtual HRESULT STDMETHODCALLTYPE get_Completed(ABI::Windows::Foundation::IAsyncOperationWithProgressCompletedHandler<ABI::Windows::Storage::Streams::IBuffer*, UINT32>** handler) override {
completed_handler_.CopyTo(handler);
return S_OK;
}
virtual HRESULT STDMETHODCALLTYPE GetResults(ABI::Windows::Storage::Streams::IBuffer** results) override {
if (buffer_ == nullptr) {
return E_FAIL;
}
buffer_.CopyTo(results);
return S_OK;
}
};
struct RandomAccessStream
: public Microsoft::WRL::RuntimeClass<
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::WinRtClassicComMix | Microsoft::WRL::InhibitRoOriginateError>,
ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType,
ABI::Windows::Storage::Streams::IContentTypeProvider,
ABI::Windows::Storage::Streams::IRandomAccessStream,
ABI::Windows::Storage::Streams::IInputStream,
ABI::Windows::Storage::Streams::IOutputStream,
ABI::Windows::Foundation::IClosable> {
InspectableClass(L"WinMLTest.RandomAccessStream", BaseTrust)
private:
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer> buffer_ = nullptr;
UINT64 position_ = 0;
public:
HRESULT RuntimeClassInitialize(ABI::Windows::Storage::Streams::IBuffer* buffer) {
buffer_ = buffer;
position_ = 0;
return S_OK;
}
HRESULT RuntimeClassInitialize(ABI::Windows::Storage::Streams::IBuffer* buffer, UINT64 position) {
buffer_ = buffer;
position_ = position;
return S_OK;
}
// Content Provider
/* [propget] */virtual HRESULT STDMETHODCALLTYPE get_ContentType(
/* [retval, out] */__RPC__deref_out_opt HSTRING* value
) override {
return WindowsCreateString(nullptr, 0, value);
}
// IRandomAccessStream
/* [propget] */virtual HRESULT STDMETHODCALLTYPE get_Size(
/* [retval, out] */__RPC__out UINT64* value
) override {
*value = 0;
uint32_t length;
buffer_->get_Length(&length);
*value = static_cast<uint64_t>(length);
return S_OK;
}
/* [propput] */virtual HRESULT STDMETHODCALLTYPE put_Size(
/* [in] */UINT64 /*value*/
) override {
return E_NOTIMPL;
}
virtual HRESULT STDMETHODCALLTYPE GetInputStreamAt(
/* [in] */UINT64 position,
/* [retval, out] */__RPC__deref_out_opt ABI::Windows::Storage::Streams::IInputStream** stream
) override {
return Microsoft::WRL::MakeAndInitialize<RandomAccessStream>(stream, buffer_.Get(), position);
}
virtual HRESULT STDMETHODCALLTYPE GetOutputStreamAt(
/* [in] */UINT64 /*position*/,
/* [retval, out] */__RPC__deref_out_opt ABI::Windows::Storage::Streams::IOutputStream** /*stream*/
) override {
return E_NOTIMPL;
}
/* [propget] */virtual HRESULT STDMETHODCALLTYPE get_Position(
/* [retval, out] */__RPC__out UINT64* value
) override {
*value = position_;
return S_OK;
}
virtual HRESULT STDMETHODCALLTYPE Seek(
/* [in] */UINT64 position
) override {
position_ = position;
return S_OK;
}
virtual HRESULT STDMETHODCALLTYPE CloneStream(
/* [retval, out] */__RPC__deref_out_opt ABI::Windows::Storage::Streams::IRandomAccessStream** stream
) override {
return Microsoft::WRL::MakeAndInitialize<RandomAccessStream>(stream, buffer_.Get(), 0);
}
/* [propget] */virtual HRESULT STDMETHODCALLTYPE get_CanRead(
/* [retval, out] */__RPC__out::boolean* value
) override {
UINT32 length;
buffer_->get_Length(&length);
*value = buffer_ != nullptr && position_ < static_cast<UINT64>(length);
return S_OK;
}
/* [propget] */virtual HRESULT STDMETHODCALLTYPE get_CanWrite(
/* [retval, out] */__RPC__out::boolean* value
) override {
*value = false;
return S_OK;
}
// IInputStream
virtual HRESULT STDMETHODCALLTYPE ReadAsync(
/* [in] */__RPC__in_opt ABI::Windows::Storage::Streams::IBuffer* buffer,
/* [in] */UINT32 count,
/* [in] */ABI::Windows::Storage::Streams::InputStreamOptions /*options*/,
/* [retval, out] */__RPC__deref_out_opt __FIAsyncOperationWithProgress_2_Windows__CStorage__CStreams__CIBuffer_UINT32** operation
) override {
auto read_async = Microsoft::WRL::Make<BufferBackedRandomAccessStreamReadAsync>();
read_async.CopyTo(operation);
// perform the "async work" which is actually synchronous atm
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer> spBuffer = buffer;
Microsoft::WRL::ComPtr<Windows::Storage::Streams::IBufferByteAccess> out_buffer_byte_access;
spBuffer.As<Windows::Storage::Streams::IBufferByteAccess>(&out_buffer_byte_access);
byte* out_bytes = nullptr;
out_buffer_byte_access->Buffer(&out_bytes);
Microsoft::WRL::ComPtr<Windows::Storage::Streams::IBufferByteAccess> in_buffer_byte_access;
buffer_.As<Windows::Storage::Streams::IBufferByteAccess>(&in_buffer_byte_access);
byte* in_bytes = nullptr;
in_buffer_byte_access->Buffer(&in_bytes);
memcpy(out_bytes, in_bytes + static_cast<uint32_t>(position_), count);
read_async->SetBuffer(buffer);
return S_OK;
}
// IOutputStream
virtual HRESULT STDMETHODCALLTYPE WriteAsync(
/* [in] */__RPC__in_opt ABI::Windows::Storage::Streams::IBuffer* /*buffer*/,
/* [retval, out] */__RPC__deref_out_opt __FIAsyncOperationWithProgress_2_UINT32_UINT32** /*operation*/
) override {
return E_NOTIMPL;
}
virtual HRESULT STDMETHODCALLTYPE FlushAsync(
/* [retval, out] */__RPC__deref_out_opt __FIAsyncOperation_1_boolean** /*operation*/
) override {
return E_NOTIMPL;
}
// IClosable
virtual HRESULT STDMETHODCALLTYPE Close(void) override {
buffer_ = nullptr;
return S_OK;
}
};
struct BufferBackedRandomAccessStreamReferenceOpenReadAsync
: public Microsoft::WRL::RuntimeClass<
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::WinRtClassicComMix | Microsoft::WRL::InhibitRoOriginateError>,
__FIAsyncOperation_1_Windows__CStorage__CStreams__CIRandomAccessStreamWithContentType,
ABI::Windows::Foundation::IAsyncInfo> {
InspectableClass(L"WinMLTest.BufferBackedRandomAccessStreamReferenceOpenReadAsync", BaseTrust)
public:
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType> ras_;
Microsoft::WRL::ComPtr<ABI::Windows::Foundation::IAsyncOperationCompletedHandler<ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType*>> completed_handler_;
AsyncStatus status_ = AsyncStatus::Started;
HRESULT SetRandomAccessStream(ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType* ras) {
ras_ = ras;
status_ = AsyncStatus::Completed;
if (ras_ != nullptr) {
if (completed_handler_ != nullptr) {
completed_handler_->Invoke(this, status_);
}
}
return S_OK;
}
virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_Id(
/* [retval][out] */ __RPC__out unsigned __int32* id) override {
*id = 0; // Do we need to implement this?
return S_OK;
}
virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_Status(
/* [retval][out] */ __RPC__out AsyncStatus* status) override {
*status = status_;
return S_OK;
}
virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_ErrorCode(
/* [retval][out] */ __RPC__out HRESULT* /*errorCode*/) override {
return E_NOTIMPL;
}
virtual HRESULT STDMETHODCALLTYPE Cancel(void) override {
return E_NOTIMPL;
}
virtual HRESULT STDMETHODCALLTYPE Close(void) override {
return E_NOTIMPL;
}
virtual HRESULT STDMETHODCALLTYPE put_Completed(
ABI::Windows::Foundation::IAsyncOperationCompletedHandler<ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType*>* handler) override
{
completed_handler_ = handler;
return S_OK;
}
virtual HRESULT STDMETHODCALLTYPE get_Completed(
ABI::Windows::Foundation::IAsyncOperationCompletedHandler<ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType*>** handler) override {
completed_handler_.CopyTo(handler);
return S_OK;
}
virtual HRESULT STDMETHODCALLTYPE GetResults(
ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType** results) override {
if (ras_ == nullptr) {
return E_FAIL;
}
ras_.CopyTo(results);
return S_OK;
}
};
struct BufferBackedRandomAccessStreamReference
: public Microsoft::WRL::RuntimeClass<
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::WinRtClassicComMix | Microsoft::WRL::InhibitRoOriginateError>,
ABI::Windows::Storage::Streams::IRandomAccessStreamReference> {
InspectableClass(L"WinMLTest.BufferBackedRandomAccessStreamReference", BaseTrust)
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer> buffer_ = nullptr;
public:
HRESULT RuntimeClassInitialize(ABI::Windows::Storage::Streams::IBuffer* buffer) {
buffer_ = buffer;
return S_OK;
}
virtual HRESULT STDMETHODCALLTYPE OpenReadAsync(
/* [retval, out] */__RPC__deref_out_opt __FIAsyncOperation_1_Windows__CStorage__CStreams__CIRandomAccessStreamWithContentType** operation
) override {
auto open_read_async = Microsoft::WRL::Make<BufferBackedRandomAccessStreamReferenceOpenReadAsync>();
open_read_async.CopyTo(operation);
Microsoft::WRL::ComPtr<RandomAccessStream> ras;
Microsoft::WRL::MakeAndInitialize<RandomAccessStream>(&ras, buffer_.Get());
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType> ras_interface = nullptr;
ras.As<ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType>(&ras_interface);
open_read_async.Get()->SetRandomAccessStream(ras_interface.Get());
return S_OK;
}
};
} // namespace WinMLTest
#endif // RANDOM_ACCESS_STREAM_H

View file

@ -10,14 +10,16 @@
#include <windows.storage.streams.h>
#include <robuffer.h>
namespace Microsoft { namespace AI { namespace MachineLearning { namespace Details {
namespace WinMLTest {
template <typename T>
struct weak_buffer
struct WeakBuffer
: public Microsoft::WRL::RuntimeClass<
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::WinRtClassicComMix | Microsoft::WRL::InhibitRoOriginateError>,
ABI::Windows::Storage::Streams::IBuffer,
Windows::Storage::Streams::IBufferByteAccess> {
InspectableClass(L"WinMLTest.WeakBuffer", BaseTrust)
private:
const T* m_p_begin;
const T* m_p_end;
@ -42,9 +44,13 @@ public:
}
virtual HRESULT STDMETHODCALLTYPE get_Length(
UINT32 * /*value*/)
UINT32 * value)
{
return E_NOTIMPL;
if (value == nullptr) {
return E_POINTER;
}
*value = static_cast<uint32_t>(m_p_end - m_p_begin) * sizeof(T);
return S_OK;
}
virtual HRESULT STDMETHODCALLTYPE put_Length(
@ -64,6 +70,6 @@ public:
}
};
}}}} // namespace Microsoft::AI::MachineLearning::Details
} // namespace WinMLTest
#endif // WEAK_BUFFER_H

View file

@ -4,6 +4,7 @@
#define WINML_H_
#include "weak_buffer.h"
#include "buffer_backed_random_access_stream_reference.h"
#include "weak_single_threaded_iterable.h"
#define RETURN_HR_IF_FAILED(expression) \
@ -213,7 +214,7 @@ public:
WinMLLearningModel(const char* bytes, size_t size)
{
ML_FAIL_FAST_IF(0 != Initialize(bytes, size));
ML_FAIL_FAST_IF(0 != Initialize(bytes, size, false /*dont copy*/));
}
private:
@ -264,7 +265,7 @@ private:
}
};
int32_t Initialize(const char* bytes, size_t size)
int32_t Initialize(const char* bytes, size_t size, bool with_copy = false)
{
auto hr = RoInitialize(RO_INIT_TYPE::RO_INIT_SINGLETHREADED);
// https://docs.microsoft.com/en-us/windows/win32/api/roapi/nf-roapi-roinitialize#return-value
@ -273,15 +274,17 @@ private:
return static_cast<int32_t>(hr);
}
// Create in memory stream
Microsoft::WRL::ComPtr<IInspectable> in_memory_random_access_stream_insp;
RETURN_HR_IF_FAILED(RoActivateInstance(
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(),
in_memory_random_access_stream_insp.GetAddressOf()));
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamReference> random_access_stream_ref;
if (with_copy) {
// Create in memory stream
Microsoft::WRL::ComPtr<IInspectable> in_memory_random_access_stream_insp;
RETURN_HR_IF_FAILED(RoActivateInstance(
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(),
in_memory_random_access_stream_insp.GetAddressOf()));
// QI memory stream to output stream
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IOutputStream> output_stream;
RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream));
// QI memory stream to output stream
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IOutputStream> output_stream;
RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream));
// Create data writer factory
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IDataWriterFactory> activation_factory;
@ -297,7 +300,7 @@ private:
// Write the model to the data writer and thus to the stream
RETURN_HR_IF_FAILED(
data_writer->WriteBytes(static_cast<uint32_t>(size), reinterpret_cast<BYTE*>(const_cast<char *>(bytes))));
data_writer->WriteBytes(static_cast<uint32_t>(size), reinterpret_cast<BYTE*>(const_cast<char*>(bytes))));
// QI the in memory stream to a random access stream
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStream> random_access_stream;
@ -311,25 +314,35 @@ private:
// Create a random access stream reference from the random access stream view on top of
// the in memory stream
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamReference> random_access_stream_ref;
RETURN_HR_IF_FAILED(random_access_stream_ref_statics->CreateFromStream(
random_access_stream.Get(),
random_access_stream_ref.GetAddressOf()));
// Create a learning model factory
Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelStatics> learning_model;
RETURN_HR_IF_FAILED(
GetActivationFactory(
RuntimeClass_Microsoft_AI_MachineLearning_LearningModel,
ABI::Microsoft::AI::MachineLearning::IID_ILearningModelStatics,
&learning_model));
Microsoft::WRL::ComPtr<ABI::Windows::Foundation::IAsyncOperation<uint32_t>> async_operation;
RETURN_HR_IF_FAILED(data_writer->StoreAsync(&async_operation));
auto store_completed_handler = Microsoft::WRL::Make<StoreCompleted>();
RETURN_HR_IF_FAILED(async_operation->put_Completed(store_completed_handler.Get()));
RETURN_HR_IF_FAILED(store_completed_handler->Wait());
} else {
Microsoft::WRL::ComPtr<WinMLTest::WeakBuffer<char>> buffer;
RETURN_HR_IF_FAILED(
Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<char>>(
&buffer, bytes, bytes + size));
RETURN_HR_IF_FAILED(
Microsoft::WRL::MakeAndInitialize<WinMLTest::BufferBackedRandomAccessStreamReference>(
&random_access_stream_ref, buffer.Get()));
}
// Create a learning model factory
Microsoft::WRL::ComPtr<ABI::Microsoft::AI::MachineLearning::ILearningModelStatics> learning_model;
RETURN_HR_IF_FAILED(
GetActivationFactory(
RuntimeClass_Microsoft_AI_MachineLearning_LearningModel,
ABI::Microsoft::AI::MachineLearning::IID_ILearningModelStatics,
&learning_model));
// Create a learning model from the factory with the random access stream reference that points
// to the random access stream view on top of the in memory stream copy of the model
RETURN_HR_IF_FAILED(
@ -459,9 +472,9 @@ public:
TensorFactory2IID<T>::IID,
&tensor_factory));
Microsoft::WRL::ComPtr<weak_buffer<T>> buffer;
Microsoft::WRL::ComPtr<WinMLTest::WeakBuffer<T>> buffer;
RETURN_HR_IF_FAILED(
Microsoft::WRL::MakeAndInitialize<weak_buffer<T>>(
Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<T>>(
&buffer, p_data, p_data + data_size));
Microsoft::WRL::ComPtr<ITensor> tensor;
@ -493,7 +506,7 @@ public:
std::vector<Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer>> vec_buffers(num_buffers);
for (size_t i = 0; i < num_buffers; i++) {
RETURN_HR_IF_FAILED(
Microsoft::WRL::MakeAndInitialize<weak_buffer<T>>(
Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<T>>(
&vec_buffers.at(i), p_data[i], p_data[i] + data_sizes[i]));
}

View file

@ -4,6 +4,7 @@
#define WINML_H_
#include "weak_buffer.h"
#include "buffer_backed_random_access_stream_reference.h"
#include "weak_single_threaded_iterable.h"
#define RETURN_HR_IF_FAILED(expression) \
@ -209,7 +210,7 @@ public:
WinMLLearningModel(const char* bytes, size_t size)
{
ML_FAIL_FAST_IF(0 != Initialize(bytes, size));
ML_FAIL_FAST_IF(0 != Initialize(bytes, size, false /*with_copy*/));
}
private:
@ -260,52 +261,63 @@ private:
}
};
int32_t Initialize(const char* bytes, size_t size)
int32_t Initialize(const char* bytes, size_t size, bool with_copy = false)
{
RoInitialize(RO_INIT_TYPE::RO_INIT_SINGLETHREADED);
// Create in memory stream
Microsoft::WRL::ComPtr<IInspectable> in_memory_random_access_stream_insp;
RETURN_HR_IF_FAILED(RoActivateInstance(
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(),
in_memory_random_access_stream_insp.GetAddressOf()));
// QI memory stream to output stream
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IOutputStream> output_stream;
RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream));
// Create data writer factory
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IDataWriterFactory> activation_factory;
RETURN_HR_IF_FAILED(RoGetActivationFactory(
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_DataWriter).Get(),
IID_PPV_ARGS(activation_factory.GetAddressOf())));
// Create data writer object based on the in memory stream
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IDataWriter> data_writer;
RETURN_HR_IF_FAILED(activation_factory->CreateDataWriter(
output_stream.Get(),
data_writer.GetAddressOf()));
// Write the model to the data writer and thus to the stream
RETURN_HR_IF_FAILED(
data_writer->WriteBytes(static_cast<uint32_t>(size), reinterpret_cast<BYTE*>(const_cast<char *>(bytes))));
// QI the in memory stream to a random access stream
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStream> random_access_stream;
RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&random_access_stream));
// Create a random access stream reference factory
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamReferenceStatics> random_access_stream_ref_statics;
RETURN_HR_IF_FAILED(RoGetActivationFactory(
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_RandomAccessStreamReference).Get(),
IID_PPV_ARGS(random_access_stream_ref_statics.GetAddressOf())));
// Create a random access stream reference from the random access stream view on top of
// the in memory stream
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamReference> random_access_stream_ref;
RETURN_HR_IF_FAILED(random_access_stream_ref_statics->CreateFromStream(
random_access_stream.Get(),
random_access_stream_ref.GetAddressOf()));
if (with_copy) {
// Create in memory stream
Microsoft::WRL::ComPtr<IInspectable> in_memory_random_access_stream_insp;
RETURN_HR_IF_FAILED(RoActivateInstance(
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(),
in_memory_random_access_stream_insp.GetAddressOf()));
// QI memory stream to output stream
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IOutputStream> output_stream;
RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream));
// Create data writer factory
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IDataWriterFactory> activation_factory;
RETURN_HR_IF_FAILED(RoGetActivationFactory(
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_DataWriter).Get(),
IID_PPV_ARGS(activation_factory.GetAddressOf())));
// Create data writer object based on the in memory stream
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IDataWriter> data_writer;
RETURN_HR_IF_FAILED(activation_factory->CreateDataWriter(
output_stream.Get(),
data_writer.GetAddressOf()));
// Write the model to the data writer and thus to the stream
RETURN_HR_IF_FAILED(
data_writer->WriteBytes(static_cast<uint32_t>(size), reinterpret_cast<BYTE*>(const_cast<char*>(bytes))));
// QI the in memory stream to a random access stream
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStream> random_access_stream;
RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&random_access_stream));
// Create a random access stream reference factory
Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IRandomAccessStreamReferenceStatics> random_access_stream_ref_statics;
RETURN_HR_IF_FAILED(RoGetActivationFactory(
Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_RandomAccessStreamReference).Get(),
IID_PPV_ARGS(random_access_stream_ref_statics.GetAddressOf())));
// Create a random access stream reference from the random access stream view on top of
// the in memory stream
RETURN_HR_IF_FAILED(random_access_stream_ref_statics->CreateFromStream(
random_access_stream.Get(),
random_access_stream_ref.GetAddressOf()));
} else {
Microsoft::WRL::ComPtr<WinMLTest::WeakBuffer<BYTE>> buffer;
RETURN_HR_IF_FAILED(
Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<BYTE>>(
&buffer, bytes, bytes + size));
RETURN_HR_IF_FAILED(
Microsoft::WRL::MakeAndInitialize<WinMLTest::BufferBackedRandomAccessStreamReference>(
&random_access_stream_ref, buffer.Get()));
}
// Create a learning model factory
Microsoft::WRL::ComPtr<ABI::Windows::AI::MachineLearning::ILearningModelStatics> learning_model;
@ -450,9 +462,9 @@ public:
TensorFactory2IID<T>::IID,
&tensor_factory));
Microsoft::WRL::ComPtr<weak_buffer<T>> buffer;
Microsoft::WRL::ComPtr<WinMLTest::WeakBuffer<T>> buffer;
RETURN_HR_IF_FAILED(
Microsoft::WRL::MakeAndInitialize<weak_buffer<T>>(
Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<T>>(
&buffer, p_data, p_data + data_size));
Microsoft::WRL::ComPtr<ITensor> tensor;
@ -484,7 +496,7 @@ public:
std::vector<Microsoft::WRL::ComPtr<ABI::Windows::Storage::Streams::IBuffer>> vec_buffers(num_buffers);
for (size_t i = 0; i < num_buffers; i++) {
RETURN_HR_IF_FAILED(
Microsoft::WRL::MakeAndInitialize<weak_buffer<T>>(
Microsoft::WRL::MakeAndInitialize<WinMLTest::WeakBuffer<T>>(
&vec_buffers.at(i), p_data[i], p_data[i] + data_sizes[i]));
}