mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
couple of fixes and coded getmutabledata()
This commit is contained in:
parent
a3542e1128
commit
acc6ea525b
3 changed files with 720 additions and 698 deletions
File diff suppressed because it is too large
Load diff
|
|
@ -32,7 +32,7 @@ MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") ITensor : IUnknown{
|
|||
|
||||
MIDL_INTERFACE("72aa5eee-100c-4146-9008-4643d3b8af23") IOrtValue : IUnknown{
|
||||
// these all return weak pointers
|
||||
virtual OrtValue& STDMETHODCALLTYPE get() = 0;
|
||||
virtual OrtValue* STDMETHODCALLTYPE get() = 0;
|
||||
virtual onnxruntime::MLDataType STDMETHODCALLTYPE Type() = 0;
|
||||
virtual bool STDMETHODCALLTYPE IsTensor() = 0;
|
||||
// end
|
||||
|
|
@ -126,9 +126,9 @@ MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db") IWinMLAdapter : IUnknown
|
|||
virtual onnxruntime::MLDataType STDMETHODCALLTYPE GetVectorMapType(winml::TensorKind key_kind, winml::TensorKind value_kind) = 0;
|
||||
|
||||
// Data getter
|
||||
virtual void * STDMETHODCALLTYPE GetTensorData(IOrtValue * ort_Value) = 0;
|
||||
virtual void * STDMETHODCALLTYPE GetMapData(IOrtValue * ort_Value, winml::TensorKind key_kind, winml::TensorKind value_kind) = 0;
|
||||
virtual void * STDMETHODCALLTYPE GetVectorData(IOrtValue * ort_Value, winml::TensorKind key_kind, winml::TensorKind value_kind) = 0;
|
||||
virtual void * STDMETHODCALLTYPE GetTensorData(IOrtValue * ort_value) = 0;
|
||||
virtual void * STDMETHODCALLTYPE GetMapData(IOrtValue * ort_value, winml::TensorKind key_kind, winml::TensorKind value_kind) = 0;
|
||||
virtual void * STDMETHODCALLTYPE GetVectorData(IOrtValue * ort_value, winml::TensorKind key_kind, winml::TensorKind value_kind) = 0;
|
||||
|
||||
// custom ops
|
||||
virtual HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) = 0;
|
||||
|
|
|
|||
|
|
@ -206,21 +206,24 @@ struct SequenceBase : public winrt::implements<
|
|||
template <typename TRawType>
|
||||
static TRawType
|
||||
ConvertToABIType(
|
||||
typename ValidLotusType<TRawType>::Type lotus_value) {
|
||||
return lotus_value;
|
||||
const typename ValidLotusType<TRawType>::Type& lotus_value) {
|
||||
// make a copy
|
||||
TRawType copy = lotus_value;
|
||||
return copy;
|
||||
}
|
||||
|
||||
template <>
|
||||
static winrt::hstring
|
||||
ConvertToABIType(
|
||||
typename ValidLotusType<winrt::hstring>::Type lotus_value) {
|
||||
const typename ValidLotusType<winrt::hstring>::Type& lotus_value) {
|
||||
return WinML::Strings::HStringFromUTF8(lotus_value);
|
||||
}
|
||||
|
||||
template <>
|
||||
static AbiMapStringToFloat
|
||||
ConvertToABIType(
|
||||
typename ValidLotusType<AbiMapStringToFloat>::Type lotus_value) {
|
||||
const typename ValidLotusType<AbiMapStringToFloat>::Type& lotus_value) {
|
||||
// need to make a copy to convert std::string to hstring
|
||||
std::map<winrt::hstring, float> copy;
|
||||
for (const auto& pair : lotus_value) {
|
||||
auto key = WinML::Strings::HStringFromUTF8(pair.first);
|
||||
|
|
@ -233,9 +236,14 @@ struct SequenceBase : public winrt::implements<
|
|||
template <>
|
||||
static AbiMapInt64BitToFloat
|
||||
ConvertToABIType(
|
||||
typename ValidLotusType<AbiMapInt64BitToFloat>::Type lotus_value) {
|
||||
const typename ValidLotusType<AbiMapInt64BitToFloat>::Type& lotus_value) {
|
||||
// need to make a copy since stl objects are not ABI safe.
|
||||
std::map<int64_t, float> copy;
|
||||
for (const auto& pair : lotus_value) {
|
||||
copy[pair.first] = pair.second;
|
||||
}
|
||||
return winrt::single_threaded_map<int64_t, float>(
|
||||
std::move(lotus_value));
|
||||
std::move(copy));
|
||||
}
|
||||
|
||||
STDMETHOD(UpdateSourceResourceData)(
|
||||
|
|
|
|||
Loading…
Reference in a new issue