diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp index b6e6a9c051..8343cd1b2a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp @@ -86,9 +86,19 @@ namespace Dml } else { + auto operatorDescCopy = operatorDesc; + + // TODO: Change as new header is ingested + if (operatorDescCopy.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING) + operatorDescCopy.Type = (DML_OPERATOR_TYPE) 169; + + // TODO: Change as new header is ingested + if (operatorDescCopy.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT) + operatorDescCopy.Type = (DML_OPERATOR_TYPE) 170; + // Create and compile the operator. ComPtr dmlOperator; - ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&operatorDesc, IID_PPV_ARGS(&dmlOperator))); + ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&operatorDescCopy, IID_PPV_ARGS(&dmlOperator))); DML_EXECUTION_FLAGS executionFlags = GetExecutionFlags(); ORT_THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), executionFlags, IID_PPV_ARGS(&m_compiledOperator)));