onnxruntime/onnxruntime/core/codegen/mti/math/gemm.cc
KeDengMS c9240f4e93
Implementation of Nuphar execution provider (#881)
* Implement Nuphar execution provider

Nuphar execution provider is a TVM-based compilation provider. It has shown great speedups for RNN models using Scan.
This PR is mainly for a preview of the shared codegen library for other TVM-based providers.

* Fix submodules

* Fix TVM submodule

* Update Nuphar to latest and resolve conflicts

* Remove stale files caused by merge -X theirs

* Revert heap buffer change to not introduce onnxruntime_framework into onnxruntime_perf_test

* Fix bad merge

* Merge from Nuphar

* Fix warning treated as error, revert some unnecessary changes

* Revert some more test changes

* Some more test revert or comments to make review easier
New tests could be added later

* One more revert of unnecessary changes

* Revert more changes. Tests could be added back later.
2019-09-01 23:01:47 -07:00


// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/codegen/mti/math/gemm.h"
#include "core/codegen/mti/math/matmul_ops.h"
#include "core/codegen/mti/mti_tvm_utils.h"
#include <topi/broadcast.h>
// Bring in namespace topi for its overloaded tensor operators (+, -, *, /)
using namespace topi;
namespace onnxruntime {
namespace tvm_codegen {
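// Builds Y = alpha * op(A) * op(B) + beta * C, where op(X) is X or its
// transpose depending on trans_A / trans_B. The beta * C term is skipped
// entirely when beta is 0.
tvm::Tensor Gemm(const tvm::Tensor& A, const tvm::Tensor& B, const tvm::Tensor& C,
                 bool trans_A, bool trans_B, float alpha, float beta,
                 const std::string& name) {
  auto A_dot_B = MatMul2D(A, B, trans_A, trans_B, name + "_matmul2d");
  // Scalars are created with A's dtype so they match the matmul operands.
  tvm::Expr alphaExpr = tvm::make_const(A->dtype, alpha);
  if (beta != 0) {
    tvm::Expr betaExpr = tvm::make_const(A->dtype, beta);
    return Rename(alphaExpr * A_dot_B + (betaExpr * C), name);
  } else {
    return Rename(alphaExpr * A_dot_B, name);
  }
}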
} // namespace tvm_codegen
} // namespace onnxruntime
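
The `Gemm` function above only builds a TVM expression, so the arithmetic it encodes is easy to lose sight of. The following is a minimal reference sketch of the same formula, `Y = alpha * op(A) * op(B) + beta * C`, in plain C++ with no TVM dependency. The `Matrix` struct and `ReferenceGemm` name are hypothetical helpers introduced only for this illustration and are not part of onnxruntime. The sketch also assumes `C` already has the full output shape, whereas the original leaves any shape handling of `C` to the `topi` operators.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper for this sketch: a dense row-major float matrix.
struct Matrix {
  std::size_t rows;
  std::size_t cols;
  std::vector<float> data;  // rows * cols elements, row-major

  float at(std::size_t r, std::size_t c) const { return data[r * cols + c]; }
};

// Reference GEMM: Y = alpha * op(A) * op(B) + beta * C,
// where op(X) is X or its transpose. Assumption: C is exactly M x N
// (no broadcasting is modeled here).
Matrix ReferenceGemm(const Matrix& A, const Matrix& B, const Matrix& C,
                     bool trans_A, bool trans_B, float alpha, float beta) {
  const std::size_t M = trans_A ? A.cols : A.rows;
  const std::size_t K = trans_A ? A.rows : A.cols;
  const std::size_t K_b = trans_B ? B.cols : B.rows;
  const std::size_t N = trans_B ? B.rows : B.cols;
  assert(K == K_b);  // inner dimensions must agree

  const bool use_C = (beta != 0);  // mirrors the beta != 0 branch above
  if (use_C) {
    assert(C.rows == M && C.cols == N);
  }

  Matrix Y{M, N, std::vector<float>(M * N, 0.0f)};
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) {
        const float a = trans_A ? A.at(k, i) : A.at(i, k);
        const float b = trans_B ? B.at(j, k) : B.at(k, j);
        acc += a * b;
      }
      float y = alpha * acc;
      if (use_C) {
        y += beta * C.at(i, j);
      }
      Y.data[i * N + j] = y;
    }
  }
  return Y;
}
```

As in the original, `C` is never read when `beta` is 0, so a caller of this sketch could pass an empty placeholder matrix in that case.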