mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-24 22:17:32 +00:00
DML EP Register ScatterND-16 (#14240)
This PR registers ScatterND-16 to the DML EP - CPU fallback is added if the reduction attribute is in use, as this is not yet supported by DML. Co-authored-by: Numfor Mbiziwo-Tiapo <numform@microsoft.com>
This commit is contained in:
parent
8f7eb75c3e
commit
dee36f8ade
5 changed files with 22 additions and 1 deletions
|
|
@ -1080,7 +1080,8 @@ Do not modify directly.*
|
|||
|||9+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|ScatterElements|*in* data:**T**<br> *in* indices:**Tind**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
|
||||
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|16+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|Selu|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
|
||||
|Shape|*in* data:**T**<br> *out* shape:**T1**|15+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **T1** = tensor(int64)|
|
||||
|
|
|
|||
|
|
@ -108,6 +108,22 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
void CALLBACK QueryScatter(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
|
||||
{
|
||||
*isSupported = false;
|
||||
|
||||
MLOperatorAttributes attributes(context);
|
||||
|
||||
// DML does not support reduction.
|
||||
std::string reduction = attributes.GetOptionalAttribute<std::string>(AttrName::Reduction, "none");
|
||||
if (reduction != "none")
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
*isSupported = true;
|
||||
}
|
||||
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Scatter9, DmlOperatorScatter);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Scatter11, DmlOperatorScatter);
|
||||
DML_OP_DEFINE_CREATION_FUNCTION(Scatter13, DmlOperatorScatter);
|
||||
|
|
|
|||
|
|
@ -274,6 +274,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
|
|||
DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(Slice);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(Resize);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(Scatter);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(EinSum);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(RecurrentNeuralNetwork);
|
||||
DML_OP_EXTERN_QUERY_FUNCTION(BatchNormalization);
|
||||
|
|
@ -488,6 +489,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
|
|||
{REG_INFO( 13, ScatterElements, typeNameListScatterGather, supportedTypeListScatterGather, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 11, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 13, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 16, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryScatter)},
|
||||
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListEyeLike, DmlGraphSupport::Supported)},
|
||||
{REG_INFO( 14, Trilu, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
|
||||
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@ namespace AttrName
|
|||
static constexpr const char* OutputWidth = "output_width";
|
||||
static constexpr const char* Pads = "pads";
|
||||
static constexpr const char* PooledShape = "pooled_shape";
|
||||
static constexpr const char* Reduction = "reduction";
|
||||
static constexpr const char* Reverse = "reverse";
|
||||
static constexpr const char* SampleSize = "sample_size";
|
||||
static constexpr const char* SamplingRatio = "sampling_ratio";
|
||||
|
|
|
|||
|
|
@ -378,6 +378,7 @@ namespace OperatorHelper
|
|||
static const int sc_sinceVer_Where = 16;
|
||||
static const int sc_sinceVer_GreaterOrEqual = 16;
|
||||
static const int sc_sinceVer_LessOrEqual = 16;
|
||||
static const int sc_sinceVer_ScatterND = 16;
|
||||
} // namespace OnnxOperatorSet16
|
||||
|
||||
namespace OnnxOperatorSet17
|
||||
|
|
|
|||
Loading…
Reference in a new issue