mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Add comment (#6860)
Co-authored-by: Jingyan Wang <jingywa@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
This commit is contained in:
parent
6285ee2398
commit
f22f04a109
10 changed files with 19 additions and 38 deletions
|
|
@ -20,7 +20,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
|
|||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MaxpoolWithMask);
|
||||
|
|
@ -178,7 +178,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, WordConvEmbedding)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, GatherND)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MurmurHash3)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MaxpoolWithMask)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Pad)>,
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
// TransposedMatMul is kept for backward compatibility
|
||||
ONNX_OPERATOR_KERNEL_EX(
|
||||
TransposeMatMul,
|
||||
kMSDomain,
|
||||
|
|
|
|||
|
|
@ -17,9 +17,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, FusedMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul);
|
||||
|
|
@ -83,7 +83,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul); // backward compatibility
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float, LayerNormalization);
|
||||
#endif
|
||||
|
|
@ -105,9 +105,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul)>,
|
||||
|
|
@ -171,7 +171,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
|
||||
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>,
|
||||
// TransposedMatMul is still here for backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, TransposeMatMul)>, // backward compatibility
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float, LayerNormalization)>,
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ namespace cuda {
|
|||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
onnxruntime::cuda::MatMul<T>);
|
||||
|
||||
// TransposeMatMul is kept here for backward compatibility
|
||||
REGISTER_KERNEL_TYPED(TransposeMatMul, float)
|
||||
REGISTER_KERNEL_TYPED(TransposeMatMul, double)
|
||||
REGISTER_KERNEL_TYPED(TransposeMatMul, MLFloat16)
|
||||
|
|
|
|||
|
|
@ -17,9 +17,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul);
|
||||
|
|
@ -92,9 +89,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, BiasGelu)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, TransposeMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, TransposeMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, TransposeMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, FusedMatMul)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, FusedMatMul)>,
|
||||
|
|
|
|||
|
|
@ -187,8 +187,7 @@ Status ProcessNode(
|
|||
}
|
||||
|
||||
if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "FusedMatMul", {1}, kMSDomain) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeMatMul", {1}, kMSDomain)) {
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "FusedMatMul", {1}, kMSDomain)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
@ -206,7 +205,7 @@ Status ProcessNode(
|
|||
}
|
||||
|
||||
NodeAttributes fused_node_attrs =
|
||||
(node.OpType() == "TransposeMatMul") || (node.OpType() == "FusedMatMul") ? node.GetAttributes() : NodeAttributes{};
|
||||
node.OpType() == "FusedMatMul" ? node.GetAttributes() : NodeAttributes{};
|
||||
|
||||
{
|
||||
ONNX_NAMESPACE::AttributeProto& alpha_attr = fused_node_attrs["alpha"];
|
||||
|
|
|
|||
|
|
@ -98,8 +98,7 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
|
|||
ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger));
|
||||
|
||||
if ((!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {9, 13}) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "FusedMatMul", {1}, kMSDomain) &&
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "TransposeMatMul", {1}, kMSDomain)) ||
|
||||
!graph_utils::IsSupportedOptypeVersionAndDomain(node, "FusedMatMul", {1}, kMSDomain)) ||
|
||||
!graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) {
|
||||
continue;
|
||||
}
|
||||
|
|
@ -139,7 +138,7 @@ Status MatmulTransposeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_
|
|||
bool transpose_left = (left != nullptr);
|
||||
bool transpose_right = (right != nullptr);
|
||||
float alpha = 1.0f;
|
||||
if ((node.OpType() == "TransposeMatMul") || (node.OpType() == "FusedMatMul")) {
|
||||
if (node.OpType() == "FusedMatMul") {
|
||||
transpose_left ^= static_cast<bool>(node.GetAttributes().at("transA").i());
|
||||
transpose_right ^= static_cast<bool>(node.GetAttributes().at("transB").i());
|
||||
alpha = node.GetAttributes().at("alpha").f();
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class MatMul<float> final : public OpKernel {
|
|||
TensorShape b_shape_;
|
||||
BufferUniquePtr packed_b_;
|
||||
|
||||
// For FusedMatMul and TransposeMatMul contrib ops
|
||||
// For FusedMatMul contrib ops
|
||||
float alpha_attr_;
|
||||
int64_t trans_a_attr_;
|
||||
int64_t trans_b_attr_;
|
||||
|
|
|
|||
|
|
@ -170,9 +170,6 @@ TEST(FusedMatMulOpTest, FloatTypeNoTranspose) {
|
|||
}
|
||||
|
||||
#if defined(USE_CUDA) || defined(USE_ROCM) // double support only implemented in CUDA/ROCM kernel
|
||||
TEST(TransposeMatMulOpTest, DoubleTypeNoTranspose) {
|
||||
RunFusedMatMulTest<double>("TransposeMatMul", 1);
|
||||
}
|
||||
|
||||
TEST(FusedMatMulOpTest, DoubleTypeNoTranspose) {
|
||||
RunFusedMatMulTest<double>("FusedMatMul", 1);
|
||||
|
|
@ -207,17 +204,6 @@ TEST(FusedMatMulOpTest, FloatTypeScale) {
|
|||
RunFusedMatMulTest<float>("FusedMatMul", 1, true, true, 4.0f, true);
|
||||
}
|
||||
|
||||
TEST(TransposeMatMulOpTest, FloatTypeScale) {
|
||||
RunFusedMatMulTest<float>("TransposeMatMul", 1, false, false, 0.5f);
|
||||
RunFusedMatMulTest<float>("TransposeMatMul", 1, true, false, 2.0f);
|
||||
RunFusedMatMulTest<float>("TransposeMatMul", 1, true, true, 4.0f);
|
||||
|
||||
// now run tests with b constant.
|
||||
RunFusedMatMulTest<float>("TransposeMatMul", 1, false, false, 0.5f, true);
|
||||
RunFusedMatMulTest<float>("TransposeMatMul", 1, true, false, 2.0f, true);
|
||||
RunFusedMatMulTest<float>("TransposeMatMul", 1, true, true, 4.0f, true);
|
||||
}
|
||||
|
||||
} // namespace transpose_matmul
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -263,7 +263,7 @@ def _create_operator_type_usage_processors():
|
|||
# ai.onnx: If, Loop, Reshape, Scan, Shape, Squeeze, Unsqueeze
|
||||
# com.microsoft: DynamicQuantizeMatMul, MatMulIntegerToFloat
|
||||
# - Only one type supported in the ORT implementation:
|
||||
# com.microsoft: FusedConv, FusedGemm, FusedMatMul, TransposeMatMul
|
||||
# com.microsoft: FusedConv, FusedGemm, FusedMatMul
|
||||
# - Implementation does not have any significant type specific code:
|
||||
# ai.onnx: Concat, Flatten, Not, QLinearConv, Reshape, Shape, Squeeze, Unsqueeze
|
||||
#
|
||||
|
|
|
|||
Loading…
Reference in a new issue