merge layer_dev_paulm

This commit is contained in:
Paul McDaniel 2019-11-19 10:58:27 -08:00
commit 94fc7bccff
3 changed files with 13 additions and 10 deletions

View file

@ -705,15 +705,20 @@ HRESULT STDMETHODCALLTYPE CreateMLValue(
return S_OK;
}
static void Delete(void* p) {
// do nothing
}
HRESULT STDMETHODCALLTYPE CreateOrtValue(
void* data,
onnxruntime::MLDataType data_type,
IOrtValue** ort_value) override {
auto ort_value_out = wil::MakeOrThrow<AbiSafeOrtValue>();
// pass the data in as a weak ref, don't let it delete it
ort_value_out->get()->Init(
data,
data_type,
data_type->GetDeleteFunc());
&Delete);
*ort_value = ort_value_out.Detach();
return S_OK;

View file

@ -139,9 +139,7 @@ struct MapBase : winrt::implements<
}
// Create a copy of the map
auto map = context.type == WinML::BindingType::kInput ?
std::make_unique<LotusMap>(ConvertToLotusMap(m_data)) :
std::make_unique<LotusMap>();
lotus_data_ = std::make_unique<LotusMap>(ConvertToLotusMap(m_data));
winrt::com_ptr<_winmla::IWinMLAdapter> adapter;
RETURN_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));
@ -149,7 +147,7 @@ struct MapBase : winrt::implements<
auto lotus_type = GetLotusType<TKey, TValue>(adapter.get());
winrt::com_ptr<_winmla::IOrtValue> ml_value_out;
adapter->CreateOrtValue(map.release(), lotus_type, ml_value_out.put());
adapter->CreateOrtValue(lotus_data_.get(), lotus_type, ml_value_out.put());
*ml_value = ml_value_out.detach();
return S_OK;
@ -191,6 +189,7 @@ struct MapBase : winrt::implements<
private:
ABIMap m_data;
std::unique_ptr<LotusMap> lotus_data_;
};
} // namespace Windows::AI::MachineLearning

View file

@ -183,10 +183,8 @@ struct SequenceBase : public winrt::implements<
return S_OK;
}
// Create a copy of the sequence
auto sequence = context.type == WinML::BindingType::kInput
? std::make_unique<LotusSequence>(ConvertToLotusSequence(data_))
: std::make_unique<LotusSequence>();
// handle inputs, Create a copy of the sequence
lotus_data_ = std::make_unique<LotusSequence>(ConvertToLotusSequence(data_));
winrt::com_ptr<_winmla::IWinMLAdapter> adapter;
RETURN_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));
@ -195,7 +193,7 @@ struct SequenceBase : public winrt::implements<
TensorKindFrom<ValidLotusType<T>::TValue>::Type);
winrt::com_ptr<_winmla::IOrtValue> ml_value_out;
adapter->CreateOrtValue(sequence.release(), lotus_type, ml_value_out.put());
adapter->CreateOrtValue(lotus_data_.get(), lotus_type, ml_value_out.put());
*ml_value = ml_value_out.detach();
return S_OK;
@ -283,6 +281,7 @@ struct SequenceBase : public winrt::implements<
private:
ABISequence data_;
std::unique_ptr<LotusSequence> lotus_data_;
};
} // namespace Windows::AI::MachineLearning