mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-27 22:45:57 +00:00
Layer dev paulm (#2426)
* model moved over. everything builds clean. step ! * weak ref comment * added a wrapper for RoGetActivationFactory to hook back into winml for creating winml objects. fixes model load. * fixed some lifetime management. fixed the debug build. squeezenet passes using winmlrunner for CPU and GPU * PR feedback. * couple of fixes and coded getmutabledata() * fixed 2 more heap corruptions
This commit is contained in:
parent
8c28f0816f
commit
3f67aaaf81
3 changed files with 24 additions and 9 deletions
|
|
@ -777,7 +777,7 @@ class IOBinding : public Microsoft::WRL::RuntimeClass<
|
|||
return binding_->GetOutputNames();
|
||||
}
|
||||
std::vector<IOrtValue*>& STDMETHODCALLTYPE GetOutputs() override {
|
||||
auto output_inner = binding_->GetOutputs();
|
||||
auto& output_inner = binding_->GetOutputs();
|
||||
outputs_weak_.clear();
|
||||
outputs_.clear();
|
||||
for (unsigned i = 0; i < output_inner.size(); i++) {
|
||||
|
|
|
|||
|
|
@ -62,12 +62,13 @@ struct MapBase : winrt::implements<
|
|||
}
|
||||
|
||||
template <typename TRawType>
|
||||
static TRawType ConvertToABIType(typename ValidLotusType<TRawType>::Type lotusValue) {
|
||||
return lotusValue;
|
||||
static TRawType ConvertToABIType(const typename ValidLotusType<TRawType>::Type& lotusValue) {
|
||||
TRawType copy = lotusValue;
|
||||
return copy;
|
||||
}
|
||||
|
||||
template <>
|
||||
static typename winrt::hstring ConvertToABIType(typename ValidLotusType<winrt::hstring>::Type lotusValue) {
|
||||
static typename winrt::hstring ConvertToABIType(const typename ValidLotusType<winrt::hstring>::Type& lotusValue) {
|
||||
return WinML::Strings::HStringFromUTF8(lotusValue);
|
||||
}
|
||||
|
||||
|
|
@ -128,11 +129,19 @@ struct MapBase : winrt::implements<
|
|||
return adapter->GetMapType(TensorKindFrom<TLotusKey>::Type, TensorKindFrom<TLotusValue>::Type);
|
||||
}
|
||||
|
||||
STDMETHOD(GetOrtValue)(WinML::BindingContext& context, _winmla::IOrtValue** mlValue) {
|
||||
STDMETHOD(GetOrtValue)(WinML::BindingContext& context, _winmla::IOrtValue** ml_value) {
|
||||
// TODO: Tensorized data should be cached so multiple bindings work more efficiently
|
||||
|
||||
// TODO : we need to handle inputs. for now only handle outputs and don't pre allocate anything
|
||||
if (context.type == WinML::BindingType::kOutput) {
|
||||
*ml_value = nullptr;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
// Create a copy of the map
|
||||
auto map = context.type == WinML::BindingType::kInput ? std::make_unique<LotusMap>(ConvertToLotusMap(m_data)) : std::make_unique<LotusMap>();
|
||||
auto map = context.type == WinML::BindingType::kInput ?
|
||||
std::make_unique<LotusMap>(ConvertToLotusMap(m_data)) :
|
||||
std::make_unique<LotusMap>();
|
||||
|
||||
winrt::com_ptr<_winmla::IWinMLAdapter> adapter;
|
||||
RETURN_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));
|
||||
|
|
@ -142,7 +151,7 @@ struct MapBase : winrt::implements<
|
|||
winrt::com_ptr<_winmla::IOrtValue> ml_value_out;
|
||||
adapter->CreateOrtValue(map.release(), lotus_type, ml_value_out.put());
|
||||
|
||||
*mlValue = ml_value_out.detach();
|
||||
*ml_value = ml_value_out.detach();
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -177,6 +177,12 @@ struct SequenceBase : public winrt::implements<
|
|||
_winmla::IOrtValue** ml_value) {
|
||||
// TODO: Tensorized data should be cached so multiple bindings work more efficiently
|
||||
|
||||
// TODO : we need to handle inputs. for now only handle outputs and don't pre allocate anything
|
||||
if (context.type == WinML::BindingType::kOutput) {
|
||||
*ml_value = nullptr;
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
// Create a copy of the sequence
|
||||
auto sequence = context.type == WinML::BindingType::kInput
|
||||
? std::make_unique<LotusSequence>(ConvertToLotusSequence(data_))
|
||||
|
|
@ -256,12 +262,12 @@ struct SequenceBase : public winrt::implements<
|
|||
winrt::com_ptr<_winmla::IWinMLAdapter> adapter;
|
||||
RETURN_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));
|
||||
|
||||
const LotusSequence& sequence = *static_cast<LotusSequence*>(adapter->GetVectorData(
|
||||
const LotusSequence* sequence = static_cast<LotusSequence*>(adapter->GetVectorData(
|
||||
ml_value,
|
||||
TensorKindFrom<ValidLotusType<T>::TKey>::Type,
|
||||
TensorKindFrom<ValidLotusType<T>::TValue>::Type));
|
||||
|
||||
for (const auto& element : sequence) {
|
||||
for (const auto& element : *sequence) {
|
||||
writable_vector.Append(ConvertToABIType<T>(element));
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue