onnxruntime/winml/lib/Api/LearningModelSession.h
Paul McDaniel 5350abe19d
LearningModelSession is cleaned up to use the adapter, and parts of b… (#2382)
This is a big PR. We are going to move it up to layer_dev, which is still an L3, so it is still safe to keep working there in an agile way.

We are going to move this into the L3 so that Ryan can start doing integration testing.

We will pause for a full code review and the integration test results before going into the L2.

>>>> raw comments from previous commits >>> 

* LearningModelSession is cleaned up to use the adapter, and parts of binding are.
* Moved everything into the winmladapter.
Made it all nano-COM, using WRL to construct objects on the ORT side.
Added base interfaces for everything WinML needs to call.
Cleaned up a bunch of WinML to use the base interfaces.
* more pieces
* GetData across the abi.
* Renamed some namespaces.
cleaned up OrtValue
cleaned up Tensor
cleaned up custom ops.
Everything *but* LearningModel should be clean.
* Made sure it's building. winml.dll is still a monolith.
2019-11-14 17:44:07 -08:00

130 lines
3.4 KiB
C++

// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include "LearningModelSession.g.h"
#include "LearningModelBinding.h"
#include "WinML_Lock.h"
#include "WinMLAdapter.h"
namespace winrt::Windows::AI::MachineLearning::implementation {

// C++/WinRT implementation type backing the projected winml::LearningModelSession.
// It keeps the model, device and session options it was created with, an
// adapter-side inference session (_winmla::IInferenceSession), and the locks
// and flags used while evaluating bindings.
struct LearningModelSession : LearningModelSessionT<LearningModelSession> {
  // A session cannot be created without a model.
  LearningModelSession() = delete;

  // --- Constructors, MachineLearningContract 1 ---
  LearningModelSession(winml::LearningModel const& model);
  LearningModelSession(winml::LearningModel const& model, winml::LearningModelDevice const& deviceToRunOn);

  // --- Constructor, MachineLearningContract 2 ---
  LearningModelSession(
      winml::LearningModel const& model,
      winml::LearningModelDevice const& deviceToRunOn,
      winml::LearningModelSessionOptions const& sessionOptions);

  // --- IClosable ---
  void Close();

  // --- Properties, MachineLearningContract 1 ---
  wfc::IPropertySet EvaluationProperties();
  winml::LearningModel Model();
  winml::LearningModelDevice Device();

  // --- Evaluation entry points, MachineLearningContract 1 ---
  // NOTE(review): only Evaluate takes correlationId by const&; the other three
  // take their hstring/IMap arguments by value. The signatures must stay in sync
  // with the definitions in the .cpp, so they are reproduced exactly here.
  winml::LearningModelEvaluationResult Evaluate(winml::LearningModelBinding binding, hstring const& correlationId);
  wf::IAsyncOperation<winml::LearningModelEvaluationResult> EvaluateAsync(
      winml::LearningModelBinding binding,
      hstring const correlationId);
  winml::LearningModelEvaluationResult EvaluateFeatures(
      wfc::IMap<hstring, wf::IInspectable> const features,
      hstring const correlationId);
  wf::IAsyncOperation<winml::LearningModelEvaluationResult> EvaluateFeaturesAsync(
      wfc::IMap<hstring, wf::IInspectable> const features,
      hstring const correlationId);

  // --- Non-ABI helpers (not part of the projected WinRT surface) ---
  onnxruntime::IExecutionProvider* GetExecutionProvider();
  _winmla::IIOBinding* CreateSessionBinding();

 private:
  // --- Internal helpers ---
  void Initialize();
  _winmla::IModelProto* GetOptimizedModel();
  _winmla::IModelProto* GetOptimizedModel(bool should_close_model);
  uint64_t Run(winrt::com_ptr<winmlp::LearningModelBinding> bindingImpl);
  winml::LearningModelEvaluationResult GetResults(
      winrt::com_ptr<winmlp::LearningModelBinding> bindingImpl,
      hstring const& correlationId,
      uint64_t fenceValueForDML);
  void ApplyEvaluationProperties();
  void ToggleProfiler();
  void CheckClosed();

  // --- State (declaration order preserved) ---
  com_ptr<_winmla::IInferenceSession> inference_session_;
  // Non-owning ("weak") cached pointer to the active execution provider; starts as nullptr.
  onnxruntime::IExecutionProvider* cached_execution_provider_{nullptr};
  winml::LearningModel model_;
  winml::LearningModelDevice device_;
  winml::LearningModelSessionOptions session_options_;
  wfc::IPropertySet evaluation_properties_;

  // Synchronization. The names suggest one lock guards session creation and the
  // other guards evaluation; confirm against the .cpp.
  CWinMLLock session_creation_lock_;
  CWinMLLock evaluate_lock_;

  // Heuristic used to decide when the DML upload heap can be trimmed.
  bool is_first_evaluate_{true};
};

}  // namespace winrt::Windows::AI::MachineLearning::implementation
namespace winrt::Windows::AI::MachineLearning::factory_implementation {

// Factory type for the LearningModelSession runtime class, following the C++/WinRT
// factory_implementation convention. The body is intentionally empty: all behavior
// comes from the generated LearningModelSessionT base, parameterized on the
// implementation type.
struct LearningModelSession
    : LearningModelSessionT<LearningModelSession, implementation::LearningModelSession> {};

}  // namespace winrt::Windows::AI::MachineLearning::factory_implementation