Fix softmax cpu code for double type (#3065)

This commit is contained in:
Changming Sun 2020-02-21 12:06:13 -08:00 committed by GitHub
parent 179603775f
commit 7ffb36be44
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -61,10 +61,10 @@ class Softmax final : public OpKernel {
int N = static_cast<int>(input_shape.SizeToDimension(axis));
int D = static_cast<int>(input_shape.SizeFromDimension(axis));
Eigen::TensorMap<Eigen::Tensor<const float, 2, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> X_tensor(
X.Data<float>(), N, D);
Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> Y_tensor(
Y->MutableData<float>(), N, D);
Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> X_tensor(
X.Data<T>(), N, D);
Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, Eigen::DenseIndex>, Eigen::Aligned> Y_tensor(
Y->MutableData<T>(), N, D);
#ifndef USE_OPENMP
if (tp == nullptr)
#endif