From 86e30a2db6f02ab4d2b92800fbf398f7ea641978 Mon Sep 17 00:00:00 2001 From: Jesse Benson Date: Wed, 18 Nov 2020 11:27:50 -0800 Subject: [PATCH] Update CUDA IsAllFinite kernel --- .../core/graph/optimizer_graph_builder.cc | 2 +- .../core/graph/training_op_defs.cc | 6 ++++- .../training_ops/cuda/math/isfinite.cc | 25 ++++++------------- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/orttraining/orttraining/core/graph/optimizer_graph_builder.cc b/orttraining/orttraining/core/graph/optimizer_graph_builder.cc index 4925b009a9..465a6af8f0 100644 --- a/orttraining/orttraining/core/graph/optimizer_graph_builder.cc +++ b/orttraining/orttraining/core/graph/optimizer_graph_builder.cc @@ -348,7 +348,7 @@ Status OptimizerGraphBuilder::AddFiniteGradientCheck( ArgDef& grad_norm_finite_argdef, const std::string& node_name) { const TypeProto* const grad_norm_finite_type = - graph_defs.CreateTypeProto({1}, ONNX_NAMESPACE::TensorProto_DataType_BOOL); + graph_defs.CreateTypeProto({}, ONNX_NAMESPACE::TensorProto_DataType_BOOL); grad_norm_finite_argdef = ArgDef{nodearg_name_generator(node_name), grad_norm_finite_type}; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 6130afb64c..6585940feb 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -1843,7 +1843,11 @@ Example 4: "The output scalar. Its value is true if all input " "tensors are finite. Otherwise, the output value would " "be false.", - "T"); + "T") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + updateOutputShape(ctx, 0, {}); + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + }); static const char* All_doc = R"DOC( Return true if all elements are true and false otherwise. diff --git a/orttraining/orttraining/training_ops/cuda/math/isfinite.cc b/orttraining/orttraining/training_ops/cuda/math/isfinite.cc index 749a50da2b..e0ecf15dd2 100644 --- a/orttraining/orttraining/training_ops/cuda/math/isfinite.cc +++ b/orttraining/orttraining/training_ops/cuda/math/isfinite.cc @@ -44,7 +44,6 @@ REGISTER_ISFINITE_KERNEL_TYPED(double) T, \ kCudaExecutionProvider, \ KernelDefBuilder() \ - .OutputMemoryType(0) \ .TypeConstraint("V", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ IsAllFiniteOp); @@ -56,11 +55,11 @@ Status IsAllFiniteOp::ComputeInternal(OpKernelContext* context) const { // Get Input tensor count. const auto total_tensor_count = context->InputCount(); - // Allocate GPU memory to capture the result computed by GPU kernel. - // The GPU result will be copied later to the output which locates - // on CPU memory. - IAllocatorUniquePtr deviceOutput = GetScratchBuffer(1); - CUDA_RETURN_IF_ERROR(cudaMemsetAsync(deviceOutput.get(), int(true), sizeof(bool))); + // Initialize the output to true. GPU kernel will set it to false + // if any value in any tensor is non-finite. + Tensor& output = *context->Output(0, {}); + auto output_data = reinterpret_cast::MappedType*>(output.template MutableData()); + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(output_data, int(true), sizeof(bool))); std::vector> grouped_tensor_pointers(total_tensor_count); std::vector tensor_sizes(total_tensor_count); @@ -74,20 +73,10 @@ Status IsAllFiniteOp::ComputeInternal(OpKernelContext* context) const { typedef IsAllFiniteFunctor TFunctor; TFunctor functor; - // Check if all values are finite and write true to deviceOutput. + // Check if all values are finite and write true to output. // Otherwise, false will be written. launch_multi_tensor_functor<1, TFunctor, bool*>( - 2048 * 32, tensor_sizes, grouped_tensor_pointers, functor, deviceOutput.get()); - - // Copy GPU result in deviceOutput to CPU memory. - // Per this operator's schema, it's output is in CPU memory. - Tensor& output = *context->Output(0, {}); - CUDA_RETURN_IF_ERROR( - cudaMemcpy( - output.MutableData(), - deviceOutput.get(), - sizeof(bool), - cudaMemcpyDeviceToHost)); + 2048 * 32, tensor_sizes, grouped_tensor_pointers, functor, output_data); return Status::OK(); }