onnxruntime/winml/adapter/abi_custom_registry_impl.h
Dwayne Robinson 1de79d3158 Merged PR 4781392: Add 64-bit tensor types to DML EP
- Add 64-bit for those which now support true 64-bit (Gather, Scatter, OneHot, Cast), and update others (ArgMin, ArgMax, ReverseSequence, TopK, MaxPool, MaxUnpool) which take 64-bit indices but use strided 32-bit fallback.
- Stop forcibly coercing all 64-bit tensors in TensorDesc. Instead, decide in the respective kernels how to behave.
- Update graph partitioning code with enough registration information to know whether (a) 64-bit tensors are not supported at all, (b) they are supported via strided 32-bit fallback, or (c) they are supported both via fallback and directly (preferred when the device is capable). Unfortunately this introduces a lot of flag parameters :/.

Related work items: #22265955
2020-06-17 09:00:54 +00:00

45 lines
1.9 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#ifdef USE_DML
#include "core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h"

namespace Windows::AI::MachineLearning::Adapter {

// An implementation of AbiCustomRegistry that emits telemetry events when operator kernels or schemas are registered.
// Every method below overrides the corresponding AbiCustomRegistry method; the definitions live outside this header.
//
// NOTE: default arguments on a virtual override are bound statically (by the static type of the object expression
// used at the call site), not dynamically. The defaults declared here must therefore be kept identical to the ones
// declared on AbiCustomRegistry, or callers will see different defaults depending on which type they call through.
class AbiCustomRegistryImpl : public AbiCustomRegistry {
 public:
  // Override of AbiCustomRegistry::RegisterOperatorSetSchema (schema registration).
  //   op_set_id        - operator set the schemas belong to.
  //   baseline_version - passed through to the base implementation.
  //   schema           - NOTE(review): presumably an array of `schema_count` schema pointers; confirm in base class.
  //   schema_count     - number of entries in `schema`.
  //   type_inferrer    - optional (may be null, per _In_opt_).
  //   shape_inferrer   - optional (may be null, per _In_opt_).
  HRESULT STDMETHODCALLTYPE RegisterOperatorSetSchema(
      const MLOperatorSetId* op_set_id,
      int baseline_version,
      const MLOperatorSchemaDescription* const* schema,
      uint32_t schema_count,
      _In_opt_ IMLOperatorTypeInferrer* type_inferrer,
      _In_opt_ IMLOperatorShapeInferrer* shape_inferrer) const noexcept override;

  // Extended override of AbiCustomRegistry::RegisterOperatorKernel (only compiled when USE_DML is defined).
  //   shape_inferrer, support_query - optional (may be null, per _In_opt_).
  //   supports_64bit_directly / allows_64bit_via_strides / allows_64bit_via_strides_from_any_ep
  //       - registration flags describing how the kernel handles 64-bit tensors (natively, via a strided 32-bit
  //         fallback, or not at all). NOTE(review): the exact meaning of the `_from_any_ep` variant is defined by
  //         the base class / graph partitioner; confirm there.
  //   required_constant_cpu_inputs - points to `constant_cpu_input_count` entries (per _In_reads_); may be null
  //       when the count is 0, which is the default.
  HRESULT STDMETHODCALLTYPE RegisterOperatorKernel(
      const MLOperatorKernelDescription* operator_kernel,
      IMLOperatorKernelFactory* operator_kernel_factory,
      _In_opt_ IMLOperatorShapeInferrer* shape_inferrer,
      _In_opt_ IMLOperatorSupportQueryPrivate* support_query,
      bool is_internal_operator,
      bool can_alias_first_input,
      bool supports_graph,
      const uint32_t* required_input_count_for_graph = nullptr,
      bool requires_float_formats_for_graph = false,
      bool supports_64bit_directly = false,
      bool allows_64bit_via_strides = false,
      bool allows_64bit_via_strides_from_any_ep = false,
      _In_reads_(constant_cpu_input_count) const uint32_t* required_constant_cpu_inputs = nullptr,
      uint32_t constant_cpu_input_count = 0) const noexcept override;

  // Basic override of AbiCustomRegistry::RegisterOperatorKernel.
  //   shape_inferrer - optional (may be null, per _In_opt_).
  HRESULT STDMETHODCALLTYPE RegisterOperatorKernel(
      const MLOperatorKernelDescription* operator_kernel,
      IMLOperatorKernelFactory* operator_kernel_factory,
      _In_opt_ IMLOperatorShapeInferrer* shape_inferrer) const noexcept override;
};

}  // namespace Windows::AI::MachineLearning::Adapter

#endif  // USE_DML