// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "lib/Api/pch/pch.h" #include "LearningModel.h" #include "TelemetryEvent.h" #include "MapFeatureDescriptor.h" #include "SequenceFeatureDescriptor.h" #include "TensorFeatureDescriptor.h" #include "OnnxruntimeProvider.h" #include namespace WINMLP { LearningModel::LearningModel( const hstring& path, const winml::ILearningModelOperatorProvider op_provider) try : operator_provider_(op_provider) { _winmlt::TelemetryEvent loadModel_event(_winmlt::EventCategory::kModelLoad); WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory_.put())); wil::unique_handle file_handle{ #if WINVER >= _WIN32_WINNT_WIN8 CreateFile2(path.c_str(), GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, NULL)}; #else CreateFileW(path.c_str(), GENERIC_READ, FILE_SHARE_READ, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_READONLY, NULL)}; #endif WINML_THROW_HR_IF_TRUE_MSG(__HRESULT_FROM_WIN32(GetLastError()), file_handle.get() == INVALID_HANDLE_VALUE, "Model load failed!"); auto file_mapping = wil::unique_handle(CreateFileMappingW(file_handle.get(), // current file handle NULL, // default security PAGE_READONLY, // read/write permission 0, // size of mapping object, high 0, // size of mapping object, low NULL)); // name of mapping object WINML_THROW_HR_IF_TRUE_MSG(__HRESULT_FROM_WIN32(GetLastError()), file_mapping == nullptr, "Model load failed!"); auto buffer = MapViewOfFile(file_mapping.get(), // handle to mapping object FILE_MAP_READ, // read/write 0, // high-order 32 bits of file offset 0, // low-order 32 bits of file offset 0); // number of bytes to map. 0 means read whole file. WINML_THROW_HR_IF_TRUE_MSG(__HRESULT_FROM_WIN32(GetLastError()), file_mapping == nullptr, "Model load failed!"); LARGE_INTEGER file_size; WINML_THROW_HR_IF_FALSE_MSG(__HRESULT_FROM_WIN32(GetLastError()), GetFileSizeEx(file_handle.get(), &file_size), "GetFileSizeEx"); WINML_THROW_IF_FAILED(engine_factory_->CreateModel(buffer, static_cast(file_size.QuadPart), model_.put())); WINML_THROW_HR_IF_TRUE_MSG(E_UNEXPECTED, UnmapViewOfFile(buffer) == 0, "Could not unmap model file."); WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put())); } WINML_CATCH_ALL LearningModel::LearningModel( _winml::IEngineFactory* engine_factory, _winml::IModel* model, const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) { engine_factory_.copy_from(engine_factory); model_.copy_from(model); WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put())); } WINML_CATCH_ALL static HRESULT CreateModelFromStream( _winml::IEngineFactory* engine_factory, const wss::IRandomAccessStreamReference stream, _winml::IModel** model) { auto content = stream.OpenReadAsync().get(); wss::Buffer buffer(static_cast(content.Size())); auto result = content.ReadAsync( buffer, buffer.Capacity(), wss::InputStreamOptions::None) .get(); auto bytes = buffer.try_as<::Windows::Storage::Streams::IBufferByteAccess>(); WINML_THROW_HR_IF_NULL_MSG(E_UNEXPECTED, bytes, "Model stream is invalid."); void* data; WINML_THROW_IF_FAILED_MSG(bytes->Buffer(reinterpret_cast(&data)), "Failed to acquire buffer from model stream."); size_t len = static_cast(content.Size()); if (FAILED(engine_factory->CreateModel(data, len, model))) { WINML_THROW_HR(E_INVALIDARG); } return S_OK; } LearningModel::LearningModel( const wss::IRandomAccessStreamReference stream, const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) { _winmlt::TelemetryEvent loadModel_event(_winmlt::EventCategory::kModelLoad); WINML_THROW_IF_FAILED(CreateOnnxruntimeEngineFactory(engine_factory_.put())); WINML_THROW_IF_FAILED(CreateModelFromStream(engine_factory_.get(), stream, model_.put())); WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put())); } WINML_CATCH_ALL hstring LearningModel::Author() try { const char* out; size_t len; WINML_THROW_IF_FAILED(model_info_->GetAuthor(&out, &len)); return _winml::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL hstring LearningModel::Name() try { const char* out; size_t len; WINML_THROW_IF_FAILED(model_info_->GetName(&out, &len)); return _winml::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL hstring LearningModel::Domain() try { const char* out; size_t len; WINML_THROW_IF_FAILED(model_info_->GetDomain(&out, &len)); return _winml::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL hstring LearningModel::Description() try { const char* out; size_t len; WINML_THROW_IF_FAILED(model_info_->GetDescription(&out, &len)); return _winml::Strings::HStringFromUTF8(out); } WINML_CATCH_ALL int64_t LearningModel::Version() try { int64_t version; WINML_THROW_IF_FAILED(model_info_->GetVersion(&version)); return version; } WINML_CATCH_ALL wfc::IMapView LearningModel::Metadata() try { ABI::Windows::Foundation::Collections::IMapView* metadata = nullptr; wfc::IMapView out; WINML_THROW_IF_FAILED(model_info_->GetModelMetadata(&metadata)); winrt::attach_abi(out, metadata); return out; } WINML_CATCH_ALL IMLOperatorRegistry* LearningModel::GetOperatorRegistry() { if (operator_provider_ == nullptr) { return nullptr; } // Get the native winrt provider interface out of winrt operator provider. auto operator_provider_native = operator_provider_.as(); IMLOperatorRegistry* registry = nullptr; // Retrieve the "operator abi" registry. THROW_IF_FAILED(operator_provider_native->GetRegistry(®istry)); return registry; } wfc::IVectorView LearningModel::InputFeatures() try { ABI::Windows::Foundation::Collections::IVectorView* features = nullptr; wfc::IVectorView out; WINML_THROW_IF_FAILED(model_info_->GetInputFeatures(&features)); winrt::attach_abi(out, features); return out; } WINML_CATCH_ALL wfc::IVectorView LearningModel::OutputFeatures() try { ABI::Windows::Foundation::Collections::IVectorView* features = nullptr; wfc::IVectorView out; WINML_THROW_IF_FAILED(model_info_->GetOutputFeatures(&features)); winrt::attach_abi(out, features); return out; } WINML_CATCH_ALL void LearningModel::Close() try { // close the model model_ = nullptr; } WINML_CATCH_ALL bool LearningModel::IsDisposed() { return model_ == nullptr; } wf::IAsyncOperation LearningModel::LoadFromStorageFileAsync( ws::IStorageFile const modelFile) { return LoadFromStorageFileAsync(modelFile, nullptr); } wf::IAsyncOperation LearningModel::LoadFromStorageFileAsync( ws::IStorageFile const modelFile, winml::ILearningModelOperatorProvider const provider) { co_await resume_background(); return make(modelFile, provider); } wf::IAsyncOperation LearningModel::LoadFromStreamAsync( wss::IRandomAccessStreamReference const model_stream) { return LoadFromStreamAsync(model_stream, nullptr); } wf::IAsyncOperation LearningModel::LoadFromStreamAsync( wss::IRandomAccessStreamReference const model_stream, winml::ILearningModelOperatorProvider const provider) { co_await resume_background(); return make(model_stream, provider); } winml::LearningModel LearningModel::LoadFromFilePath( hstring const& path) try { return LoadFromFilePath(path, nullptr); } WINML_CATCH_ALL winml::LearningModel LearningModel::LoadFromFilePath( hstring const& path, winml::ILearningModelOperatorProvider const provider) try { return make(path, provider); } WINML_CATCH_ALL winml::LearningModel LearningModel::LoadFromStream( wss::IRandomAccessStreamReference const model_stream) try { return LoadFromStream(model_stream, nullptr); } WINML_CATCH_ALL winml::LearningModel LearningModel::LoadFromStream( wss::IRandomAccessStreamReference const model_stream, winml::ILearningModelOperatorProvider const provider) try { return make(model_stream, provider); } WINML_CATCH_ALL _winml::IModel* LearningModel::DetachModel() { com_ptr<_winml::IModel> detached_model; if (model_ != nullptr) { detached_model.attach(model_.detach()); // Close the model since we now own the model proto Close(); } return detached_model.detach(); } _winml::IModel* LearningModel::CloneModel() { if (model_ == nullptr) { return nullptr; } com_ptr<_winml::IModel> model_copy; WINML_THROW_IF_FAILED(model_->CloneModel(model_copy.put())); return model_copy.detach(); } _winml::IEngineFactory* LearningModel::GetEngineFactory() { return engine_factory_.get(); } void LearningModel::SaveToFile(const hstring& file_name) { model_->SaveModel(file_name.c_str(), file_name.size()); } void LearningModel::JoinModel( winml::LearningModel other, const std::unordered_map& linkages, bool promote_unlinked_outputs, bool close_model_on_join, const winrt::hstring& join_node_prefix) { auto otherp = other.as(); winrt::com_ptr<_winml::IModel> other_model; if (close_model_on_join) { other_model.attach(otherp->DetachModel()); } else { other_model.attach(otherp->CloneModel()); } std::vector raw_outputs(linkages.size()); std::vector raw_inputs(linkages.size()); std::transform(std::begin(linkages), std::end(linkages), std::begin(raw_outputs), [](auto& pair) { return pair.first.c_str(); }); std::transform(std::begin(linkages), std::end(linkages), std::begin(raw_inputs), [](auto& pair) { return pair.second.c_str(); }); auto prefix = winrt::to_string(join_node_prefix); WINML_THROW_IF_FAILED(model_->JoinModel(other_model.get(), raw_outputs.data(), raw_inputs.data(), linkages.size(), promote_unlinked_outputs, prefix.c_str())); model_info_ = nullptr; WINML_THROW_IF_FAILED(model_->GetModelInfo(model_info_.put())); } } // namespace WINMLP namespace WINML::factory_implementation { // copied from cppwinrt magic to create abi wrappers. Need to do it this way // since peeps underneath (like the constructor) will throw HRESULT __stdcall LearningModel::Load( const wchar_t* p_model_path, uint32_t model_path_size, IUnknown** pp_model_unk) { try { WINML_THROW_HR_IF_NULL_MSG(E_INVALIDARG, p_model_path, "Failed to create LearningModel. Ivalid argument p_model_path."); WINML_THROW_HR_IF_FALSE_MSG(E_INVALIDARG, model_path_size > 0, "Failed to create LearningModel. Ivalid argument model_path_size."); WINML_THROW_HR_IF_NULL_MSG(E_INVALIDARG, pp_model_unk, "Failed to create LearningModel. Ivalid argument pp_model_unk."); winrt::hstring path(p_model_path, model_path_size); auto model = make(path, nullptr); *pp_model_unk = model.as().detach(); return S_OK; } WINML_CATCH_ALL_COM } } // namespace WINML::factory_implementation