diff --git a/winml/lib/Api/LearningModel.cpp b/winml/lib/Api/LearningModel.cpp index d638fb1702..b0b988adf8 100644 --- a/winml/lib/Api/LearningModel.cpp +++ b/winml/lib/Api/LearningModel.cpp @@ -15,6 +15,50 @@ #include namespace WINMLP { + +// IBuffer implementation to avoid calling into WinTypes.dll to create wss::Buffer. +// This will enable model creation on VTL1 without pulling in additional binaries on load. +template +class STLVectorBackedBuffer : public winrt::implements< + STLVectorBackedBuffer, + wss::IBuffer, + ::Windows::Storage::Streams::IBufferByteAccess> { + private: + std::vector data_; + size_t length_ = 0; + + public: + STLVectorBackedBuffer(size_t num_elements) : data_(num_elements) {} + + uint32_t Capacity() const try { + // Return the size of the backing vector in bytes + return static_cast(data_.size() * sizeof(T)); + } + WINML_CATCH_ALL + + uint32_t Length() const try { + // Return the used buffer in bytes + return static_cast(length_); + } + WINML_CATCH_ALL + + void Length(uint32_t value) try { + // Set the use buffer length in bytes + WINML_THROW_HR_IF_TRUE_MSG(E_INVALIDARG, value > Capacity(), "Parameter 'value' cannot be greater than the buffer's capacity."); + length_ = value; + } + WINML_CATCH_ALL + + STDMETHOD(Buffer) + (_Outptr_ BYTE** value) { + // Return the buffer + RETURN_HR_IF_NULL(E_POINTER, value); + *value = reinterpret_cast(data_.data()); + return S_OK; + } +}; + + LearningModel::LearningModel( const hstring& path, const winml::ILearningModelOperatorProvider op_provider) try : operator_provider_(op_provider) { @@ -90,12 +134,11 @@ static HRESULT CreateModelFromStream( _winml::IModel** model) { auto content = stream.OpenReadAsync().get(); - wss::Buffer buffer(static_cast(content.Size())); + auto buffer = winrt::make>(static_cast(content.Size())); auto result = content.ReadAsync( buffer, buffer.Capacity(), - wss::InputStreamOptions::None) - .get(); + wss::InputStreamOptions::None).get(); auto bytes = buffer.try_as<::Windows::Storage::Streams::IBufferByteAccess>(); WINML_THROW_HR_IF_NULL_MSG(E_UNEXPECTED, bytes, "Model stream is invalid."); diff --git a/winml/test/api/raw/buffer_backed_random_access_stream_reference.h b/winml/test/api/raw/buffer_backed_random_access_stream_reference.h new file mode 100644 index 0000000000..f1eb29a678 --- /dev/null +++ b/winml/test/api/raw/buffer_backed_random_access_stream_reference.h @@ -0,0 +1,363 @@ +// Copyright 2019 Microsoft Corporation. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. +#ifndef RANDOM_ACCESS_STREAM_H +#define RANDOM_ACCESS_STREAM_H + +#include +#include + +#include +#include + +#include + +namespace WinMLTest { + +struct BufferBackedRandomAccessStreamReadAsync + : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + __FIAsyncOperationWithProgress_2_Windows__CStorage__CStreams__CIBuffer_UINT32, + ABI::Windows::Foundation::IAsyncInfo> { + + InspectableClass(L"WinMLTest.BufferBackedRandomAccessStreamReadAsync", BaseTrust) + + Microsoft::WRL::ComPtr buffer_; + + Microsoft::WRL::ComPtr> completed_handler_; + Microsoft::WRL::ComPtr> progress_handler_; + + AsyncStatus status_ = AsyncStatus::Started; + +public: + + virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_Id( + /* [retval][out] */ __RPC__out unsigned __int32* id) override { + *id = 0; // Do we need to implement this? + return S_OK; + } + + virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_Status( + /* [retval][out] */ __RPC__out AsyncStatus* status) override { + *status = status_; + return S_OK; + } + + virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_ErrorCode( + /* [retval][out] */ __RPC__out HRESULT* /*errorCode*/) override { + return E_NOTIMPL; + } + + virtual HRESULT STDMETHODCALLTYPE Cancel(void) override { + return E_NOTIMPL; + } + + virtual HRESULT STDMETHODCALLTYPE Close(void) override { + return E_NOTIMPL; + } + + + + HRESULT SetBuffer(ABI::Windows::Storage::Streams::IBuffer* buffer) { + buffer_ = buffer; + status_ = AsyncStatus::Completed; + if (buffer_ != nullptr) { + if (completed_handler_ != nullptr) { + completed_handler_->Invoke(this, ABI::Windows::Foundation::AsyncStatus::Completed); + } + } + return S_OK; + } + + virtual HRESULT STDMETHODCALLTYPE put_Progress( + ABI::Windows::Foundation::IAsyncOperationProgressHandler* handler) override { + progress_handler_ = handler; + return S_OK; + } + + virtual HRESULT STDMETHODCALLTYPE get_Progress(ABI::Windows::Foundation::IAsyncOperationProgressHandler** handler)override { + progress_handler_.CopyTo(handler); + return S_OK; + } + + virtual HRESULT STDMETHODCALLTYPE put_Completed(ABI::Windows::Foundation::IAsyncOperationWithProgressCompletedHandler* handler) override { + completed_handler_ = handler; + return S_OK; + } + + virtual HRESULT STDMETHODCALLTYPE get_Completed(ABI::Windows::Foundation::IAsyncOperationWithProgressCompletedHandler** handler) override { + completed_handler_.CopyTo(handler); + return S_OK; + } + + virtual HRESULT STDMETHODCALLTYPE GetResults(ABI::Windows::Storage::Streams::IBuffer** results) override { + if (buffer_ == nullptr) { + return E_FAIL; + } + + buffer_.CopyTo(results); + return S_OK; + } +}; + +struct RandomAccessStream + : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType, + ABI::Windows::Storage::Streams::IContentTypeProvider, + ABI::Windows::Storage::Streams::IRandomAccessStream, + ABI::Windows::Storage::Streams::IInputStream, + ABI::Windows::Storage::Streams::IOutputStream, + ABI::Windows::Foundation::IClosable> { + InspectableClass(L"WinMLTest.RandomAccessStream", BaseTrust) + +private: + Microsoft::WRL::ComPtr buffer_ = nullptr; + UINT64 position_ = 0; + +public: + HRESULT RuntimeClassInitialize(ABI::Windows::Storage::Streams::IBuffer* buffer) { + buffer_ = buffer; + position_ = 0; + return S_OK; + } + + HRESULT RuntimeClassInitialize(ABI::Windows::Storage::Streams::IBuffer* buffer, UINT64 position) { + buffer_ = buffer; + position_ = position; + return S_OK; + } + + // Content Provider + + /* [propget] */virtual HRESULT STDMETHODCALLTYPE get_ContentType( + /* [retval, out] */__RPC__deref_out_opt HSTRING* value + ) override { + return WindowsCreateString(nullptr, 0, value); + } + + // IRandomAccessStream + + /* [propget] */virtual HRESULT STDMETHODCALLTYPE get_Size( + /* [retval, out] */__RPC__out UINT64* value + ) override { + *value = 0; + uint32_t length; + buffer_->get_Length(&length); + *value = static_cast(length); + return S_OK; + } + + /* [propput] */virtual HRESULT STDMETHODCALLTYPE put_Size( + /* [in] */UINT64 /*value*/ + ) override { + return E_NOTIMPL; + } + + virtual HRESULT STDMETHODCALLTYPE GetInputStreamAt( + /* [in] */UINT64 position, + /* [retval, out] */__RPC__deref_out_opt ABI::Windows::Storage::Streams::IInputStream** stream + ) override { + return Microsoft::WRL::MakeAndInitialize(stream, buffer_.Get(), position); + } + + virtual HRESULT STDMETHODCALLTYPE GetOutputStreamAt( + /* [in] */UINT64 /*position*/, + /* [retval, out] */__RPC__deref_out_opt ABI::Windows::Storage::Streams::IOutputStream** /*stream*/ + ) override { + return E_NOTIMPL; + } + + /* [propget] */virtual HRESULT STDMETHODCALLTYPE get_Position( + /* [retval, out] */__RPC__out UINT64* value + ) override { + *value = position_; + return S_OK; + } + + virtual HRESULT STDMETHODCALLTYPE Seek( + /* [in] */UINT64 position + ) override { + position_ = position; + return S_OK; + } + + virtual HRESULT STDMETHODCALLTYPE CloneStream( + /* [retval, out] */__RPC__deref_out_opt ABI::Windows::Storage::Streams::IRandomAccessStream** stream + ) override { + return Microsoft::WRL::MakeAndInitialize(stream, buffer_.Get(), 0); + } + + /* [propget] */virtual HRESULT STDMETHODCALLTYPE get_CanRead( + /* [retval, out] */__RPC__out::boolean* value + ) override { + UINT32 length; + buffer_->get_Length(&length); + *value = buffer_ != nullptr && position_ < static_cast(length); + return S_OK; + } + + /* [propget] */virtual HRESULT STDMETHODCALLTYPE get_CanWrite( + /* [retval, out] */__RPC__out::boolean* value + ) override { + *value = false; + return S_OK; + } + + // IInputStream + virtual HRESULT STDMETHODCALLTYPE ReadAsync( + /* [in] */__RPC__in_opt ABI::Windows::Storage::Streams::IBuffer* buffer, + /* [in] */UINT32 count, + /* [in] */ABI::Windows::Storage::Streams::InputStreamOptions /*options*/, + /* [retval, out] */__RPC__deref_out_opt __FIAsyncOperationWithProgress_2_Windows__CStorage__CStreams__CIBuffer_UINT32** operation + ) override { + auto read_async = Microsoft::WRL::Make(); + read_async.CopyTo(operation); + + // perform the "async work" which is actually synchronous atm + Microsoft::WRL::ComPtr spBuffer = buffer; + Microsoft::WRL::ComPtr out_buffer_byte_access; + spBuffer.As(&out_buffer_byte_access); + byte* out_bytes = nullptr; + out_buffer_byte_access->Buffer(&out_bytes); + + Microsoft::WRL::ComPtr in_buffer_byte_access; + buffer_.As(&in_buffer_byte_access); + byte* in_bytes = nullptr; + in_buffer_byte_access->Buffer(&in_bytes); + + memcpy(out_bytes, in_bytes + static_cast(position_), count); + + read_async->SetBuffer(buffer); + + return S_OK; + } + + // IOutputStream + virtual HRESULT STDMETHODCALLTYPE WriteAsync( + /* [in] */__RPC__in_opt ABI::Windows::Storage::Streams::IBuffer* /*buffer*/, + /* [retval, out] */__RPC__deref_out_opt __FIAsyncOperationWithProgress_2_UINT32_UINT32** /*operation*/ + ) override { + return E_NOTIMPL; + } + + virtual HRESULT STDMETHODCALLTYPE FlushAsync( + /* [retval, out] */__RPC__deref_out_opt __FIAsyncOperation_1_boolean** /*operation*/ + ) override { + return E_NOTIMPL; + } + + // IClosable + virtual HRESULT STDMETHODCALLTYPE Close(void) override { + buffer_ = nullptr; + return S_OK; + } + +}; + +struct BufferBackedRandomAccessStreamReferenceOpenReadAsync + : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + __FIAsyncOperation_1_Windows__CStorage__CStreams__CIRandomAccessStreamWithContentType, + ABI::Windows::Foundation::IAsyncInfo> { + + InspectableClass(L"WinMLTest.BufferBackedRandomAccessStreamReferenceOpenReadAsync", BaseTrust) +public: + Microsoft::WRL::ComPtr ras_; + Microsoft::WRL::ComPtr> completed_handler_; + AsyncStatus status_ = AsyncStatus::Started; + + HRESULT SetRandomAccessStream(ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType* ras) { + ras_ = ras; + status_ = AsyncStatus::Completed; + if (ras_ != nullptr) { + if (completed_handler_ != nullptr) { + completed_handler_->Invoke(this, status_); + } + } + return S_OK; + } + + virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_Id( + /* [retval][out] */ __RPC__out unsigned __int32* id) override { + *id = 0; // Do we need to implement this? + return S_OK; + } + + virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_Status( + /* [retval][out] */ __RPC__out AsyncStatus* status) override { + *status = status_; + return S_OK; + } + + virtual /* [propget] */ HRESULT STDMETHODCALLTYPE get_ErrorCode( + /* [retval][out] */ __RPC__out HRESULT* /*errorCode*/) override { + return E_NOTIMPL; + } + + virtual HRESULT STDMETHODCALLTYPE Cancel(void) override { + return E_NOTIMPL; + } + + virtual HRESULT STDMETHODCALLTYPE Close(void) override { + return E_NOTIMPL; + } + + virtual HRESULT STDMETHODCALLTYPE put_Completed( + ABI::Windows::Foundation::IAsyncOperationCompletedHandler* handler) override + { + completed_handler_ = handler; + return S_OK; + } + + virtual HRESULT STDMETHODCALLTYPE get_Completed( + ABI::Windows::Foundation::IAsyncOperationCompletedHandler** handler) override { + completed_handler_.CopyTo(handler); + return S_OK; + } + + virtual HRESULT STDMETHODCALLTYPE GetResults( + ABI::Windows::Storage::Streams::IRandomAccessStreamWithContentType** results) override { + if (ras_ == nullptr) { + return E_FAIL; + } + ras_.CopyTo(results); + return S_OK; + } +}; + +struct BufferBackedRandomAccessStreamReference + : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + ABI::Windows::Storage::Streams::IRandomAccessStreamReference> { + InspectableClass(L"WinMLTest.BufferBackedRandomAccessStreamReference", BaseTrust) + + Microsoft::WRL::ComPtr buffer_ = nullptr; + +public: + HRESULT RuntimeClassInitialize(ABI::Windows::Storage::Streams::IBuffer* buffer) { + buffer_ = buffer; + return S_OK; + } + + virtual HRESULT STDMETHODCALLTYPE OpenReadAsync( + /* [retval, out] */__RPC__deref_out_opt __FIAsyncOperation_1_Windows__CStorage__CStreams__CIRandomAccessStreamWithContentType** operation + ) override { + auto open_read_async = Microsoft::WRL::Make(); + open_read_async.CopyTo(operation); + + Microsoft::WRL::ComPtr ras; + Microsoft::WRL::MakeAndInitialize(&ras, buffer_.Get()); + + Microsoft::WRL::ComPtr ras_interface = nullptr; + ras.As(&ras_interface); + + open_read_async.Get()->SetRandomAccessStream(ras_interface.Get()); + return S_OK; + } + +}; + +} // namespace WinMLTest + +#endif // RANDOM_ACCESS_STREAM_H \ No newline at end of file diff --git a/winml/test/api/raw/weak_buffer.h b/winml/test/api/raw/weak_buffer.h index a2c3fdf1a5..a175273c4e 100644 --- a/winml/test/api/raw/weak_buffer.h +++ b/winml/test/api/raw/weak_buffer.h @@ -10,14 +10,16 @@ #include #include -namespace Microsoft { namespace AI { namespace MachineLearning { namespace Details { +namespace WinMLTest { template -struct weak_buffer +struct WeakBuffer : public Microsoft::WRL::RuntimeClass< Microsoft::WRL::RuntimeClassFlags, ABI::Windows::Storage::Streams::IBuffer, Windows::Storage::Streams::IBufferByteAccess> { + InspectableClass(L"WinMLTest.WeakBuffer", BaseTrust) + private: const T* m_p_begin; const T* m_p_end; @@ -42,9 +44,13 @@ public: } virtual HRESULT STDMETHODCALLTYPE get_Length( - UINT32 * /*value*/) + UINT32 * value) { - return E_NOTIMPL; + if (value == nullptr) { + return E_POINTER; + } + *value = static_cast(m_p_end - m_p_begin) * sizeof(T); + return S_OK; } virtual HRESULT STDMETHODCALLTYPE put_Length( @@ -64,6 +70,6 @@ public: } }; -}}}} // namespace Microsoft::AI::MachineLearning::Details +} // namespace WinMLTest #endif // WEAK_BUFFER_H diff --git a/winml/test/api/raw/winml_microsoft.h b/winml/test/api/raw/winml_microsoft.h index 497facf499..a247303ca7 100644 --- a/winml/test/api/raw/winml_microsoft.h +++ b/winml/test/api/raw/winml_microsoft.h @@ -4,6 +4,7 @@ #define WINML_H_ #include "weak_buffer.h" +#include "buffer_backed_random_access_stream_reference.h" #include "weak_single_threaded_iterable.h" #define RETURN_HR_IF_FAILED(expression) \ @@ -213,7 +214,7 @@ public: WinMLLearningModel(const char* bytes, size_t size) { - ML_FAIL_FAST_IF(0 != Initialize(bytes, size)); + ML_FAIL_FAST_IF(0 != Initialize(bytes, size, false /*dont copy*/)); } private: @@ -264,7 +265,7 @@ private: } }; - int32_t Initialize(const char* bytes, size_t size) + int32_t Initialize(const char* bytes, size_t size, bool with_copy = false) { auto hr = RoInitialize(RO_INIT_TYPE::RO_INIT_SINGLETHREADED); // https://docs.microsoft.com/en-us/windows/win32/api/roapi/nf-roapi-roinitialize#return-value @@ -273,15 +274,17 @@ private: return static_cast(hr); } - // Create in memory stream - Microsoft::WRL::ComPtr in_memory_random_access_stream_insp; - RETURN_HR_IF_FAILED(RoActivateInstance( - Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(), - in_memory_random_access_stream_insp.GetAddressOf())); + Microsoft::WRL::ComPtr random_access_stream_ref; + if (with_copy) { + // Create in memory stream + Microsoft::WRL::ComPtr in_memory_random_access_stream_insp; + RETURN_HR_IF_FAILED(RoActivateInstance( + Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(), + in_memory_random_access_stream_insp.GetAddressOf())); - // QI memory stream to output stream - Microsoft::WRL::ComPtr output_stream; - RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream)); + // QI memory stream to output stream + Microsoft::WRL::ComPtr output_stream; + RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream)); // Create data writer factory Microsoft::WRL::ComPtr activation_factory; @@ -297,7 +300,7 @@ private: // Write the model to the data writer and thus to the stream RETURN_HR_IF_FAILED( - data_writer->WriteBytes(static_cast(size), reinterpret_cast(const_cast(bytes)))); + data_writer->WriteBytes(static_cast(size), reinterpret_cast(const_cast(bytes)))); // QI the in memory stream to a random access stream Microsoft::WRL::ComPtr random_access_stream; @@ -311,25 +314,35 @@ private: // Create a random access stream reference from the random access stream view on top of // the in memory stream - Microsoft::WRL::ComPtr random_access_stream_ref; RETURN_HR_IF_FAILED(random_access_stream_ref_statics->CreateFromStream( random_access_stream.Get(), random_access_stream_ref.GetAddressOf())); - // Create a learning model factory - Microsoft::WRL::ComPtr learning_model; - RETURN_HR_IF_FAILED( - GetActivationFactory( - RuntimeClass_Microsoft_AI_MachineLearning_LearningModel, - ABI::Microsoft::AI::MachineLearning::IID_ILearningModelStatics, - &learning_model)); - Microsoft::WRL::ComPtr> async_operation; RETURN_HR_IF_FAILED(data_writer->StoreAsync(&async_operation)); auto store_completed_handler = Microsoft::WRL::Make(); RETURN_HR_IF_FAILED(async_operation->put_Completed(store_completed_handler.Get())); RETURN_HR_IF_FAILED(store_completed_handler->Wait()); + } else { + Microsoft::WRL::ComPtr> buffer; + RETURN_HR_IF_FAILED( + Microsoft::WRL::MakeAndInitialize>( + &buffer, bytes, bytes + size)); + + RETURN_HR_IF_FAILED( + Microsoft::WRL::MakeAndInitialize( + &random_access_stream_ref, buffer.Get())); + } + + // Create a learning model factory + Microsoft::WRL::ComPtr learning_model; + RETURN_HR_IF_FAILED( + GetActivationFactory( + RuntimeClass_Microsoft_AI_MachineLearning_LearningModel, + ABI::Microsoft::AI::MachineLearning::IID_ILearningModelStatics, + &learning_model)); + // Create a learning model from the factory with the random access stream reference that points // to the random access stream view on top of the in memory stream copy of the model RETURN_HR_IF_FAILED( @@ -459,9 +472,9 @@ public: TensorFactory2IID::IID, &tensor_factory)); - Microsoft::WRL::ComPtr> buffer; + Microsoft::WRL::ComPtr> buffer; RETURN_HR_IF_FAILED( - Microsoft::WRL::MakeAndInitialize>( + Microsoft::WRL::MakeAndInitialize>( &buffer, p_data, p_data + data_size)); Microsoft::WRL::ComPtr tensor; @@ -493,7 +506,7 @@ public: std::vector> vec_buffers(num_buffers); for (size_t i = 0; i < num_buffers; i++) { RETURN_HR_IF_FAILED( - Microsoft::WRL::MakeAndInitialize>( + Microsoft::WRL::MakeAndInitialize>( &vec_buffers.at(i), p_data[i], p_data[i] + data_sizes[i])); } diff --git a/winml/test/api/raw/winml_windows.h b/winml/test/api/raw/winml_windows.h index 7a493618d9..172d95850f 100644 --- a/winml/test/api/raw/winml_windows.h +++ b/winml/test/api/raw/winml_windows.h @@ -4,6 +4,7 @@ #define WINML_H_ #include "weak_buffer.h" +#include "buffer_backed_random_access_stream_reference.h" #include "weak_single_threaded_iterable.h" #define RETURN_HR_IF_FAILED(expression) \ @@ -209,7 +210,7 @@ public: WinMLLearningModel(const char* bytes, size_t size) { - ML_FAIL_FAST_IF(0 != Initialize(bytes, size)); + ML_FAIL_FAST_IF(0 != Initialize(bytes, size, false /*with_copy*/)); } private: @@ -260,52 +261,63 @@ private: } }; - int32_t Initialize(const char* bytes, size_t size) + int32_t Initialize(const char* bytes, size_t size, bool with_copy = false) { RoInitialize(RO_INIT_TYPE::RO_INIT_SINGLETHREADED); - // Create in memory stream - Microsoft::WRL::ComPtr in_memory_random_access_stream_insp; - RETURN_HR_IF_FAILED(RoActivateInstance( - Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(), - in_memory_random_access_stream_insp.GetAddressOf())); - - // QI memory stream to output stream - Microsoft::WRL::ComPtr output_stream; - RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream)); - - // Create data writer factory - Microsoft::WRL::ComPtr activation_factory; - RETURN_HR_IF_FAILED(RoGetActivationFactory( - Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_DataWriter).Get(), - IID_PPV_ARGS(activation_factory.GetAddressOf()))); - - // Create data writer object based on the in memory stream - Microsoft::WRL::ComPtr data_writer; - RETURN_HR_IF_FAILED(activation_factory->CreateDataWriter( - output_stream.Get(), - data_writer.GetAddressOf())); - - // Write the model to the data writer and thus to the stream - RETURN_HR_IF_FAILED( - data_writer->WriteBytes(static_cast(size), reinterpret_cast(const_cast(bytes)))); - - // QI the in memory stream to a random access stream - Microsoft::WRL::ComPtr random_access_stream; - RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&random_access_stream)); - - // Create a random access stream reference factory - Microsoft::WRL::ComPtr random_access_stream_ref_statics; - RETURN_HR_IF_FAILED(RoGetActivationFactory( - Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_RandomAccessStreamReference).Get(), - IID_PPV_ARGS(random_access_stream_ref_statics.GetAddressOf()))); - - // Create a random access stream reference from the random access stream view on top of - // the in memory stream Microsoft::WRL::ComPtr random_access_stream_ref; - RETURN_HR_IF_FAILED(random_access_stream_ref_statics->CreateFromStream( - random_access_stream.Get(), - random_access_stream_ref.GetAddressOf())); + if (with_copy) { + // Create in memory stream + Microsoft::WRL::ComPtr in_memory_random_access_stream_insp; + RETURN_HR_IF_FAILED(RoActivateInstance( + Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_InMemoryRandomAccessStream).Get(), + in_memory_random_access_stream_insp.GetAddressOf())); + + // QI memory stream to output stream + Microsoft::WRL::ComPtr output_stream; + RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&output_stream)); + + // Create data writer factory + Microsoft::WRL::ComPtr activation_factory; + RETURN_HR_IF_FAILED(RoGetActivationFactory( + Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_DataWriter).Get(), + IID_PPV_ARGS(activation_factory.GetAddressOf()))); + + // Create data writer object based on the in memory stream + Microsoft::WRL::ComPtr data_writer; + RETURN_HR_IF_FAILED(activation_factory->CreateDataWriter( + output_stream.Get(), + data_writer.GetAddressOf())); + + // Write the model to the data writer and thus to the stream + RETURN_HR_IF_FAILED( + data_writer->WriteBytes(static_cast(size), reinterpret_cast(const_cast(bytes)))); + + // QI the in memory stream to a random access stream + Microsoft::WRL::ComPtr random_access_stream; + RETURN_HR_IF_FAILED(in_memory_random_access_stream_insp.As(&random_access_stream)); + + // Create a random access stream reference factory + Microsoft::WRL::ComPtr random_access_stream_ref_statics; + RETURN_HR_IF_FAILED(RoGetActivationFactory( + Microsoft::WRL::Wrappers::HStringReference(RuntimeClass_Windows_Storage_Streams_RandomAccessStreamReference).Get(), + IID_PPV_ARGS(random_access_stream_ref_statics.GetAddressOf()))); + + // Create a random access stream reference from the random access stream view on top of + // the in memory stream + RETURN_HR_IF_FAILED(random_access_stream_ref_statics->CreateFromStream( + random_access_stream.Get(), + random_access_stream_ref.GetAddressOf())); + } else { + Microsoft::WRL::ComPtr> buffer; + RETURN_HR_IF_FAILED( + Microsoft::WRL::MakeAndInitialize>( + &buffer, bytes, bytes + size)); + + RETURN_HR_IF_FAILED( + Microsoft::WRL::MakeAndInitialize( + &random_access_stream_ref, buffer.Get())); + } // Create a learning model factory Microsoft::WRL::ComPtr learning_model; @@ -450,9 +462,9 @@ public: TensorFactory2IID::IID, &tensor_factory)); - Microsoft::WRL::ComPtr> buffer; + Microsoft::WRL::ComPtr> buffer; RETURN_HR_IF_FAILED( - Microsoft::WRL::MakeAndInitialize>( + Microsoft::WRL::MakeAndInitialize>( &buffer, p_data, p_data + data_size)); Microsoft::WRL::ComPtr tensor; @@ -484,7 +496,7 @@ public: std::vector> vec_buffers(num_buffers); for (size_t i = 0; i < num_buffers; i++) { RETURN_HR_IF_FAILED( - Microsoft::WRL::MakeAndInitialize>( + Microsoft::WRL::MakeAndInitialize>( &vec_buffers.at(i), p_data[i], p_data[i] + data_sizes[i])); }