couple of fixes and coded getmutabledata()

This commit is contained in:
Paul McDaniel 2019-11-18 18:31:04 -08:00
parent a3542e1128
commit acc6ea525b
3 changed files with 720 additions and 698 deletions

File diff suppressed because it is too large Load diff

View file

@ -32,7 +32,7 @@ MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") ITensor : IUnknown{
MIDL_INTERFACE("72aa5eee-100c-4146-9008-4643d3b8af23") IOrtValue : IUnknown{
// these all return weak pointers
virtual OrtValue& STDMETHODCALLTYPE get() = 0;
virtual OrtValue* STDMETHODCALLTYPE get() = 0;
virtual onnxruntime::MLDataType STDMETHODCALLTYPE Type() = 0;
virtual bool STDMETHODCALLTYPE IsTensor() = 0;
// end
@ -126,9 +126,9 @@ MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db") IWinMLAdapter : IUnknown
virtual onnxruntime::MLDataType STDMETHODCALLTYPE GetVectorMapType(winml::TensorKind key_kind, winml::TensorKind value_kind) = 0;
// Data getter
virtual void * STDMETHODCALLTYPE GetTensorData(IOrtValue * ort_Value) = 0;
virtual void * STDMETHODCALLTYPE GetMapData(IOrtValue * ort_Value, winml::TensorKind key_kind, winml::TensorKind value_kind) = 0;
virtual void * STDMETHODCALLTYPE GetVectorData(IOrtValue * ort_Value, winml::TensorKind key_kind, winml::TensorKind value_kind) = 0;
virtual void * STDMETHODCALLTYPE GetTensorData(IOrtValue * ort_value) = 0;
virtual void * STDMETHODCALLTYPE GetMapData(IOrtValue * ort_value, winml::TensorKind key_kind, winml::TensorKind value_kind) = 0;
virtual void * STDMETHODCALLTYPE GetVectorData(IOrtValue * ort_value, winml::TensorKind key_kind, winml::TensorKind value_kind) = 0;
// custom ops
virtual HRESULT STDMETHODCALLTYPE GetCustomRegistry(IMLOperatorRegistry** registry) = 0;

View file

@ -206,21 +206,24 @@ struct SequenceBase : public winrt::implements<
template <typename TRawType>
static TRawType
ConvertToABIType(
typename ValidLotusType<TRawType>::Type lotus_value) {
return lotus_value;
const typename ValidLotusType<TRawType>::Type& lotus_value) {
// make a copy
TRawType copy = lotus_value;
return copy;
}
template <>
static winrt::hstring
ConvertToABIType(
typename ValidLotusType<winrt::hstring>::Type lotus_value) {
const typename ValidLotusType<winrt::hstring>::Type& lotus_value) {
return WinML::Strings::HStringFromUTF8(lotus_value);
}
template <>
static AbiMapStringToFloat
ConvertToABIType(
typename ValidLotusType<AbiMapStringToFloat>::Type lotus_value) {
const typename ValidLotusType<AbiMapStringToFloat>::Type& lotus_value) {
// need to make a copy to convert std::string to hstring
std::map<winrt::hstring, float> copy;
for (const auto& pair : lotus_value) {
auto key = WinML::Strings::HStringFromUTF8(pair.first);
@ -233,9 +236,14 @@ struct SequenceBase : public winrt::implements<
template <>
static AbiMapInt64BitToFloat
ConvertToABIType(
typename ValidLotusType<AbiMapInt64BitToFloat>::Type lotus_value) {
const typename ValidLotusType<AbiMapInt64BitToFloat>::Type& lotus_value) {
// need to make a copy since stl objects are not ABI safe.
std::map<int64_t, float> copy;
for (const auto& pair : lotus_value) {
copy[pair.first] = pair.second;
}
return winrt::single_threaded_map<int64_t, float>(
std::move(lotus_value));
std::move(copy));
}
STDMETHOD(UpdateSourceResourceData)(