// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #pragma once #include "LearningModel.g.h" namespace _winml { struct IEngineFactory; struct IModel; struct IModelInfo; } // namespace _winml namespace WINMLP { struct LearningModel : LearningModelT { /* LearningModel constructors (MachineLearningContract 1). */ LearningModel() = default; LearningModel(const hstring& path, const winml::ILearningModelOperatorProvider operator_provider); LearningModel( const wss::IRandomAccessStreamReference stream, const winml::ILearningModelOperatorProvider operator_provider ); LearningModel( _winml::IEngineFactory* engine_factory, _winml::IModel* model, const winml::ILearningModelOperatorProvider operator_provider ); /* LearningModel properties (MachineLearningContract 1). */ hstring Author(); hstring Name(); hstring Domain(); hstring Description(); int64_t Version(); wfc::IMapView Metadata(); wfc::IVectorView InputFeatures(); wfc::IVectorView OutputFeatures(); void SetName(const hstring& name); /* IClosable methods. */ void Close(); /* LearningModel static methods (MachineLearningContract 1). */ static wf::IAsyncOperation LoadFromStorageFileAsync(ws::IStorageFile const model_file); static wf::IAsyncOperation LoadFromStorageFileAsync( ws::IStorageFile const model_file, winml::ILearningModelOperatorProvider const operator_provider ); static wf::IAsyncOperation LoadFromStreamAsync(wss::IRandomAccessStreamReference const stream); static wf::IAsyncOperation LoadFromStreamAsync( wss::IRandomAccessStreamReference const stream, winml::ILearningModelOperatorProvider const operator_provider ); static winml::LearningModel LoadFromFilePath(hstring const& path); static winml::LearningModel LoadFromFilePath( hstring const& path, winml::ILearningModelOperatorProvider const operator_provider ); static winml::LearningModel LoadFromStream(wss::IRandomAccessStreamReference const stream); static winml::LearningModel LoadFromStream( wss::IRandomAccessStreamReference const stream, winml::ILearningModelOperatorProvider const operator_provider ); public: /* Non-ABI methods */ bool IsDisposed(); IMLOperatorRegistry* GetOperatorRegistry(); _winml::IModel* DetachModel(); _winml::IModel* CloneModel(); _winml::IEngineFactory* GetEngineFactory(); void SaveToFile(const hstring& file_name); void JoinModel( winml::LearningModel other, const std::unordered_map& linkages, bool promote_unlinked_outputs, bool close_model_on_join, const winrt::hstring& join_node_prefix ); private: com_ptr<_winml::IEngineFactory> engine_factory_; com_ptr<_winml::IModel> model_; com_ptr<_winml::IModelInfo> model_info_; ILearningModelOperatorProvider operator_provider_; }; } // namespace WINMLP namespace WINML::factory_implementation { struct LearningModel : LearningModelT { STDMETHOD(Load) (const wchar_t* p_model_path, UINT32 model_path_size, IUnknown** pp_model_unk); }; } // namespace WINML::factory_implementation