From fb4f7dbbb7ca1793c170253758653d94b035205c Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Tue, 26 Oct 2021 09:48:57 +0800 Subject: [PATCH] Call ATenOp for ReduceSum on ORTModule (#9471) * call ATenOp for ReduceSum * Enable ReduceSum ATenOp for training only * always load extension --- .../core/providers/cpu/cpu_provider_shared.cc | 4 ++ .../core/providers/cpu/cpu_provider_shared.h | 8 ++++ .../providers/cuda/reduction/reduction_ops.cc | 28 ++++++++++---- .../python/orttraining_pybind_state.cc | 2 +- .../ortmodule/_graph_execution_manager.py | 6 ++- .../cpu/aten_op_executor/__init__.py | 9 +---- .../python/orttraining_test_ortmodule_api.py | 37 +++++++++++++++++++ .../training_ops/cpu/aten_ops/aten_op.cc | 30 +++++++++++++++ .../training_ops/cpu/aten_ops/aten_op.h | 3 ++ .../cpu/aten_ops/aten_op_executor.h | 32 +++++++--------- 10 files changed, 123 insertions(+), 36 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index 6ca4777342..2ca01eaf3f 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -176,6 +176,10 @@ struct ProviderHostCPUImpl : ProviderHostCPU { void contrib__GetPermutationAndShape(bool ncd_to_ndc, const TensorShape& tensor_shape, std::vector& new_shape, std::vector& permutations) override { contrib::GetPermutationAndShape(ncd_to_ndc, tensor_shape, new_shape, permutations); } Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector& split_sizes) override { return contrib::PrepareForTrainingCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes); } Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) override { return p->YieldOp::Compute(context); } + + // From aten_op.h (direct) + bool contrib__IsATenOperatorExecutorInitialized() override { return contrib::IsATenOperatorExecutorInitialized(); } + Status contrib__ExecuteReduceSumATenOp(OpKernelContext* p_ctx, const std::vector& axes, bool keepdims) override { return contrib::ExecuteReduceSumATenOp(p_ctx, axes, keepdims); } #endif #endif }; diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index ad9bb8a5ad..59b124b63c 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -138,6 +138,10 @@ struct ProviderHostCPU { virtual void contrib__GetPermutationAndShape(bool ncd_to_ndc, const TensorShape& tensor_shape, std::vector& new_shape, std::vector& permutations) = 0; virtual Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector& split_sizes) = 0; virtual Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) = 0; + + // From aten_op.h + virtual bool contrib__IsATenOperatorExecutorInitialized() = 0; + virtual Status contrib__ExecuteReduceSumATenOp(OpKernelContext* p_ctx, const std::vector& axes, bool keepdims) = 0; #endif #endif }; @@ -204,6 +208,10 @@ inline void VerifyLogitWeightAndLabelShape(const TensorShape& logit_shape, const inline void GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, int64_t& N_D, int64_t& C) { g_host_cpu.contrib__GetNDCFromLogitAndLabelShape(logit_shape, label_shape, N_D, C); } inline void GetPermutationAndShape(bool ncd_to_ndc, const TensorShape& tensor_shape, std::vector& new_shape, std::vector& permutations) { g_host_cpu.contrib__GetPermutationAndShape(ncd_to_ndc, tensor_shape, new_shape, permutations); } inline Status PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector& split_sizes) { return g_host_cpu.contrib__PrepareForTrainingCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes); } + +// From aten_op.h +inline bool IsATenOperatorExecutorInitialized() { return g_host_cpu.contrib__IsATenOperatorExecutorInitialized(); } +inline Status ExecuteReduceSumATenOp(OpKernelContext* p_ctx, const std::vector& axes, bool keepdims) { return g_host_cpu.contrib__ExecuteReduceSumATenOp(p_ctx, axes, keepdims); } } // namespace contrib #endif // ENABLE_TRAINING #endif // USE_CUDA || USE_ROCM diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc index af6a079abb..ba1ca2b039 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc +++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc @@ -8,6 +8,9 @@ #include "core/providers/cuda/math/binary_elementwise_ops_impl.h" #include "core/providers/cuda/math/binary_elementwise_ops.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" +#ifdef ENABLE_TRAINING +#include "orttraining/training_ops/cpu/aten_ops/aten_op.h" +#endif using namespace onnxruntime::common; namespace onnxruntime { @@ -703,7 +706,7 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnRe size_t num_inputs = ctx->InputCount(); if (num_inputs == 2) { - //override the attribute value with the input value for reduction_axes + // override the attribute value with the input value for reduction_axes const Tensor* axes_tensor = ctx->Input(1); ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null"); ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor."); @@ -717,18 +720,29 @@ Status ReduceKernel::ComputeImpl(OpKernelContext* ctx, cudnnRe // empty axes and no-op if (axes.empty() && noop_with_empty_axes_) { auto* Y = ctx->Output(0, X->Shape()); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->template MutableData(), X->template Data(), X->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream())); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->template MutableData(), X->template Data(), X->SizeInBytes(), + cudaMemcpyDeviceToDevice, Stream())); return Status::OK(); } +#ifdef ENABLE_TRAINING + // Use ATenOp for ReduceSum if possible. + const TensorShape& input_shape = X->Shape(); + if (contrib::IsATenOperatorExecutorInitialized() && cudnn_reduce_op == CUDNN_REDUCE_TENSOR_ADD && !calculate_log_ && + !calculate_sqt_ && !log_sum_exp_ && input_shape.Size() > 0) { + if (axes.empty()) { + axes.resize(input_shape.NumDimensions()); + std::iota(axes.begin(), axes.end(), 0); + } + ORT_RETURN_IF_ERROR(contrib::ExecuteReduceSumATenOp(ctx, axes, keepdims_)); + return Status::OK(); + } +#endif + PrepareReduceMetadata prepare_reduce_metadata; - ORT_RETURN_IF_ERROR(PrepareForReduce(X, - keepdims_, - axes, - prepare_reduce_metadata)); + ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, axes, prepare_reduce_metadata)); Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims); const bool fast_reduction = fast_reduction_ && !ctx->GetUseDeterministicCompute(); - return ReduceComputeCore(*cuda_ep_, *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes, calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction); } diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 682c2eed20..de7817f41a 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -469,7 +469,7 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int)); void* p_is_tensor_argument = reinterpret_cast(is_tensor_argument_address_int); void* p_aten_op_executor = reinterpret_cast(aten_op_executor_address_int); - contrib::aten_ops::ATenOperatorExecutor::Initialize(p_is_tensor_argument, p_aten_op_executor); + contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor); }); m.def("register_forward_runner", [](py::object obj) -> void { #ifdef ENABLE_TRAINING_TORCH_INTEROP diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 39836ca602..c4b71f19f9 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -9,7 +9,7 @@ from . import (_utils, _logger, _onnx_models, _are_deterministic_algorithms_enabled) -from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension_if_needed +from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension from ._custom_autograd_function import custom_autograd_function_enabler from ._custom_autograd_function_exporter import _post_process_after_export from ._graph_execution_interface import GraphExecutionInterface @@ -175,6 +175,9 @@ class GraphExecutionManager(GraphExecutionInterface): # Re-export will be avoided if _skip_check is enabled. self._original_model_has_changed = False + # Load ATenOp executor extension. + load_aten_op_executor_cpp_extension() + def _get_torch_gpu_allocator_function_addresses(self): if self._use_external_gpu_allocator and torch.cuda.is_available(): # CPP extension to get torch GPU allocator's alloc and free function addresses @@ -308,7 +311,6 @@ class GraphExecutionManager(GraphExecutionInterface): self._set_device_from_module(inputs, kwargs) self._onnx_models.exported_model = self._get_exported_model( schema, *inputs, **kwargs) - load_aten_op_executor_cpp_extension_if_needed(self._onnx_models.exported_model) if self._debug_options.save_onnx_models.save: self._onnx_models.save_exported_model(self._debug_options.save_onnx_models.path, self._debug_options.save_onnx_models.name_prefix, diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/__init__.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/__init__.py index 3b58c55051..a17f4369da 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/__init__.py @@ -24,14 +24,7 @@ def run_once_aten_op_executor(f): @run_once_aten_op_executor -def _load_aten_op_executor_cpp_extension(): +def load_aten_op_executor_cpp_extension(): from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor C.register_aten_op_executor(str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address())) - - -def load_aten_op_executor_cpp_extension_if_needed(onnx_model): - for node in onnx_model.graph.node: - if node.op_type == 'ATenOp' and node.domain == 'com.microsoft': - _load_aten_op_executor_cpp_extension() - break diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 8d644fa8bb..f069b8c01c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1015,6 +1015,43 @@ def test_gradient_correctness_argmax_diagonal(offset, dim1, dim2): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) +@pytest.mark.parametrize("dim", [None, 0, 1, (0, 1), (-1, 0), (0, 1, 2)]) +@pytest.mark.parametrize("keepdim", [True, False]) +def test_gradient_correctness_reducesum(dim, keepdim): + class NeuralNetReduceSum(torch.nn.Module): + def __init__(self, input_size, hidden_size, dim, keepdim): + super(NeuralNetReduceSum, self).__init__() + self.linear = torch.nn.Linear(input_size, hidden_size) + self.dim = dim + self.keepdim = keepdim + + def forward(self, input): + t = self.linear(input) + if self.dim is None: + return t.sum() + else: + return torch.sum(t, self.dim, keepdim=self.keepdim) + + N, D, H, W = 16, 256, 128, 64 + device = 'cuda' + pt_model = NeuralNetReduceSum(H, W, dim, keepdim).to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, input): + prediction = model(input) + loss = prediction.sum() + loss.backward() + return prediction + + for _ in range(10): + pt_input = torch.rand((N, D, H), device=device, requires_grad=True) + ort_input = copy.deepcopy(pt_input) + pt_prediction = run_step(pt_model, pt_input) + ort_prediction = run_step(ort_model, ort_input) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + # Since multinomial is a generator function, we do not have to test for gradient # Two consecutive calls on the torch.multinomail on a probability distribution with more # than one index with non-zero probability(eg, [0, 10, 3, 0]) will not result in diff --git a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.cc b/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.cc index 9581f3da17..8be5b15317 100644 --- a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.cc +++ b/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.cc @@ -35,5 +35,35 @@ Status ATenOp::Compute(OpKernelContext* p_ctx) const { return Status::OK(); } +bool IsATenOperatorExecutorInitialized() { + return aten_ops::ATenOperatorExecutor::Instance().IsInitialized(); +} + +Status ExecuteReduceSumATenOp(OpKernelContext* p_ctx, const std::vector& axes, bool keepdims) { + ORT_ENFORCE(aten_ops::ATenOperatorExecutor::Instance().IsInitialized() && !axes.empty()); + std::vector dlpacks; + auto* p_ctx_internal = static_cast(p_ctx); + OrtValue ort_value = *p_ctx_internal->GetInputMLValue(0); + dlpacks.emplace_back(dlpack::OrtValueToDlpack(ort_value)); + OrtValue axes_tensor; + OrtValue keepdims_tensor; + std::vector axes_tensor_shape(1, static_cast(axes.size())); + std::vector keepdims_tensor_shape(1, 1); + auto ml_tensor = DataTypeImpl::GetType(); + OrtMemoryInfo info("Cpu", OrtDeviceAllocator); + axes_tensor.Init(new Tensor(DataTypeImpl::GetType(), axes_tensor_shape, + const_cast(reinterpret_cast(&axes[0])), info), + ml_tensor, ml_tensor->GetDeleteFunc()); + keepdims_tensor.Init( + new Tensor(DataTypeImpl::GetType(), keepdims_tensor_shape, reinterpret_cast(&keepdims), info), + ml_tensor, ml_tensor->GetDeleteFunc()); + dlpacks.emplace_back(dlpack::OrtValueToDlpack(axes_tensor)); + dlpacks.emplace_back(dlpack::OrtValueToDlpack(keepdims_tensor)); + dlpacks.emplace_back(nullptr); + auto result = aten_ops::ATenOperatorExecutor::Instance()("aten::sum", "dim_IntList", dlpacks); + ORT_RETURN_IF_ERROR(p_ctx_internal->SetOutputMLValue(0, dlpack::DlpackToOrtValue(result[0]))); + return Status::OK(); +} + } // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.h b/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.h index 61ea5647d9..d7b802f77e 100644 --- a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.h +++ b/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op.h @@ -22,5 +22,8 @@ class ATenOp : public OpKernel { std::string overload_name_; }; +bool IsATenOperatorExecutorInitialized(); +Status ExecuteReduceSumATenOp(OpKernelContext* p_ctx, const std::vector& axes, bool keepdims); + } // namespace contrib } // namespace onnxruntime diff --git a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op_executor.h b/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op_executor.h index 10f348fb10..ce518bfea6 100644 --- a/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op_executor.h +++ b/orttraining/orttraining/training_ops/cpu/aten_ops/aten_op_executor.h @@ -16,16 +16,24 @@ typedef std::vector (*ExecuteATenOperatorFunc)(const char* op_ class ATenOperatorExecutor { public: - static ATenOperatorExecutor& Instance() { return InstanceImpl(); } - - static void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { - InstanceImpl(p_is_tensor_argument_func_raw, p_execute_aten_op_func_raw); + static ATenOperatorExecutor& Instance() { + static ATenOperatorExecutor instance; + return instance; } + void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { + ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw); + p_is_tensor_argument_func_ = reinterpret_cast(p_is_tensor_argument_func_raw); + p_execute_aten_op_func_ = reinterpret_cast(p_execute_aten_op_func_raw); + } + + bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; } + bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index) { ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized."); return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index); } + std::vector operator()(const std::string& op_name, const std::string& overload_name, const std::vector& dlpacks) { ORT_ENFORCE(p_execute_aten_op_func_, "ATenOperatorExecutor is not initialized."); @@ -33,20 +41,8 @@ class ATenOperatorExecutor { } private: - static ATenOperatorExecutor& InstanceImpl(void* p_is_tensor_argument_func_raw = nullptr, - void* p_execute_aten_op_func_raw = nullptr) { - static ATenOperatorExecutor instance(p_is_tensor_argument_func_raw, p_execute_aten_op_func_raw); - return instance; - } - - ATenOperatorExecutor(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { - ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw); - p_is_tensor_argument_func_ = reinterpret_cast(p_is_tensor_argument_func_raw); - p_execute_aten_op_func_ = reinterpret_cast(p_execute_aten_op_func_raw); - } - - IsTensorArgumentFunc p_is_tensor_argument_func_; - ExecuteATenOperatorFunc p_execute_aten_op_func_; + IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr; + ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr; }; } // namespace aten_ops