Call ATenOp for ReduceSum on ORTModule (#9471)

* call ATenOp for ReduceSum

* Enable ReduceSum ATenOp for training only

* always load extension
This commit is contained in:
Vincent Wang 2021-10-26 09:48:57 +08:00 committed by GitHub
parent 651955d3c9
commit fb4f7dbbb7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 123 additions and 36 deletions

View file

@ -176,6 +176,10 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
void contrib__GetPermutationAndShape(bool ncd_to_ndc, const TensorShape& tensor_shape, std::vector<int64_t>& new_shape, std::vector<size_t>& permutations) override { contrib::GetPermutationAndShape(ncd_to_ndc, tensor_shape, new_shape, permutations); }
Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector<int64_t>& split_sizes) override { return contrib::PrepareForTrainingCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes); }
Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) override { return p->YieldOp::Compute(context); }
// From aten_op.h (direct)
bool contrib__IsATenOperatorExecutorInitialized() override { return contrib::IsATenOperatorExecutorInitialized(); }
Status contrib__ExecuteReduceSumATenOp(OpKernelContext* p_ctx, const std::vector<int64_t>& axes, bool keepdims) override { return contrib::ExecuteReduceSumATenOp(p_ctx, axes, keepdims); }
#endif
#endif
};

View file

@ -138,6 +138,10 @@ struct ProviderHostCPU {
virtual void contrib__GetPermutationAndShape(bool ncd_to_ndc, const TensorShape& tensor_shape, std::vector<int64_t>& new_shape, std::vector<size_t>& permutations) = 0;
virtual Status contrib__PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector<int64_t>& split_sizes) = 0;
virtual Status contrib__YieldOp__Compute(const contrib::YieldOp* p, OpKernelContext* context) = 0;
// From aten_op.h
virtual bool contrib__IsATenOperatorExecutorInitialized() = 0;
virtual Status contrib__ExecuteReduceSumATenOp(OpKernelContext* p_ctx, const std::vector<int64_t>& axes, bool keepdims) = 0;
#endif
#endif
};
@ -204,6 +208,10 @@ inline void VerifyLogitWeightAndLabelShape(const TensorShape& logit_shape, const
inline void GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorShape& label_shape, int64_t& N_D, int64_t& C) { g_host_cpu.contrib__GetNDCFromLogitAndLabelShape(logit_shape, label_shape, N_D, C); }
inline void GetPermutationAndShape(bool ncd_to_ndc, const TensorShape& tensor_shape, std::vector<int64_t>& new_shape, std::vector<size_t>& permutations) { g_host_cpu.contrib__GetPermutationAndShape(ncd_to_ndc, tensor_shape, new_shape, permutations); }
inline Status PrepareForTrainingCompute(const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, int& after_dims_including_split_axis, int& after_dims_excluding_split, std::vector<int64_t>& split_sizes) { return g_host_cpu.contrib__PrepareForTrainingCompute(input_shape, num_outputs, axis, before_dims, after_dims_including_split_axis, after_dims_excluding_split, split_sizes); }
// From aten_op.h
inline bool IsATenOperatorExecutorInitialized() { return g_host_cpu.contrib__IsATenOperatorExecutorInitialized(); }
inline Status ExecuteReduceSumATenOp(OpKernelContext* p_ctx, const std::vector<int64_t>& axes, bool keepdims) { return g_host_cpu.contrib__ExecuteReduceSumATenOp(p_ctx, axes, keepdims); }
} // namespace contrib
#endif // ENABLE_TRAINING
#endif // USE_CUDA || USE_ROCM

View file

@ -8,6 +8,9 @@
#include "core/providers/cuda/math/binary_elementwise_ops_impl.h"
#include "core/providers/cuda/math/binary_elementwise_ops.h"
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
#ifdef ENABLE_TRAINING
#include "orttraining/training_ops/cpu/aten_ops/aten_op.h"
#endif
using namespace onnxruntime::common;
namespace onnxruntime {
@ -703,7 +706,7 @@ Status ReduceKernel<allow_multi_axes>::ComputeImpl(OpKernelContext* ctx, cudnnRe
size_t num_inputs = ctx->InputCount();
if (num_inputs == 2) {
//override the attribute value with the input value for reduction_axes
// override the attribute value with the input value for reduction_axes
const Tensor* axes_tensor = ctx->Input<Tensor>(1);
ORT_ENFORCE(axes_tensor != nullptr, "Axes input is null");
ORT_ENFORCE(axes_tensor->Shape().NumDimensions() == 1, "An axes tensor must be a vector tensor.");
@ -717,18 +720,29 @@ Status ReduceKernel<allow_multi_axes>::ComputeImpl(OpKernelContext* ctx, cudnnRe
// empty axes and no-op
if (axes.empty() && noop_with_empty_axes_) {
auto* Y = ctx->Output(0, X->Shape());
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->template MutableData<T>(), X->template Data<T>(), X->SizeInBytes(), cudaMemcpyDeviceToDevice, Stream()));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(Y->template MutableData<T>(), X->template Data<T>(), X->SizeInBytes(),
cudaMemcpyDeviceToDevice, Stream()));
return Status::OK();
}
#ifdef ENABLE_TRAINING
// Use ATenOp for ReduceSum if possible.
const TensorShape& input_shape = X->Shape();
if (contrib::IsATenOperatorExecutorInitialized() && cudnn_reduce_op == CUDNN_REDUCE_TENSOR_ADD && !calculate_log_ &&
!calculate_sqt_ && !log_sum_exp_ && input_shape.Size() > 0) {
if (axes.empty()) {
axes.resize(input_shape.NumDimensions());
std::iota(axes.begin(), axes.end(), 0);
}
ORT_RETURN_IF_ERROR(contrib::ExecuteReduceSumATenOp(ctx, axes, keepdims_));
return Status::OK();
}
#endif
PrepareReduceMetadata prepare_reduce_metadata;
ORT_RETURN_IF_ERROR(PrepareForReduce(X,
keepdims_,
axes,
prepare_reduce_metadata));
ORT_RETURN_IF_ERROR(PrepareForReduce(X, keepdims_, axes, prepare_reduce_metadata));
Tensor* Y = ctx->Output(0, prepare_reduce_metadata.squeezed_output_dims);
const bool fast_reduction = fast_reduction_ && !ctx->GetUseDeterministicCompute();
return ReduceComputeCore<T, ReduceTensorIndices>(*cuda_ep_, *X, prepare_reduce_metadata, *Y, cudnn_reduce_op, axes,
calculate_log_, calculate_sqt_, log_sum_exp_, fast_reduction);
}

View file

@ -469,7 +469,7 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int));
void* p_is_tensor_argument = reinterpret_cast<void*>(is_tensor_argument_address_int);
void* p_aten_op_executor = reinterpret_cast<void*>(aten_op_executor_address_int);
contrib::aten_ops::ATenOperatorExecutor::Initialize(p_is_tensor_argument, p_aten_op_executor);
contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor);
});
m.def("register_forward_runner", [](py::object obj) -> void {
#ifdef ENABLE_TRAINING_TORCH_INTEROP

View file

@ -9,7 +9,7 @@ from . import (_utils,
_logger,
_onnx_models,
_are_deterministic_algorithms_enabled)
from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension_if_needed
from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension
from ._custom_autograd_function import custom_autograd_function_enabler
from ._custom_autograd_function_exporter import _post_process_after_export
from ._graph_execution_interface import GraphExecutionInterface
@ -175,6 +175,9 @@ class GraphExecutionManager(GraphExecutionInterface):
# Re-export will be avoided if _skip_check is enabled.
self._original_model_has_changed = False
# Load ATenOp executor extension.
load_aten_op_executor_cpp_extension()
def _get_torch_gpu_allocator_function_addresses(self):
if self._use_external_gpu_allocator and torch.cuda.is_available():
# CPP extension to get torch GPU allocator's alloc and free function addresses
@ -308,7 +311,6 @@ class GraphExecutionManager(GraphExecutionInterface):
self._set_device_from_module(inputs, kwargs)
self._onnx_models.exported_model = self._get_exported_model(
schema, *inputs, **kwargs)
load_aten_op_executor_cpp_extension_if_needed(self._onnx_models.exported_model)
if self._debug_options.save_onnx_models.save:
self._onnx_models.save_exported_model(self._debug_options.save_onnx_models.path,
self._debug_options.save_onnx_models.name_prefix,

View file

@ -24,14 +24,7 @@ def run_once_aten_op_executor(f):
@run_once_aten_op_executor
def _load_aten_op_executor_cpp_extension():
def load_aten_op_executor_cpp_extension():
from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor
C.register_aten_op_executor(str(aten_op_executor.is_tensor_argument_address()),
str(aten_op_executor.execute_aten_operator_address()))
def load_aten_op_executor_cpp_extension_if_needed(onnx_model):
for node in onnx_model.graph.node:
if node.op_type == 'ATenOp' and node.domain == 'com.microsoft':
_load_aten_op_executor_cpp_extension()
break

View file

@ -1015,6 +1015,43 @@ def test_gradient_correctness_argmax_diagonal(offset, dim1, dim2):
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
@pytest.mark.parametrize("dim", [None, 0, 1, (0, 1), (-1, 0), (0, 1, 2)])
@pytest.mark.parametrize("keepdim", [True, False])
def test_gradient_correctness_reducesum(dim, keepdim):
class NeuralNetReduceSum(torch.nn.Module):
def __init__(self, input_size, hidden_size, dim, keepdim):
super(NeuralNetReduceSum, self).__init__()
self.linear = torch.nn.Linear(input_size, hidden_size)
self.dim = dim
self.keepdim = keepdim
def forward(self, input):
t = self.linear(input)
if self.dim is None:
return t.sum()
else:
return torch.sum(t, self.dim, keepdim=self.keepdim)
N, D, H, W = 16, 256, 128, 64
device = 'cuda'
pt_model = NeuralNetReduceSum(H, W, dim, keepdim).to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))
def run_step(model, input):
prediction = model(input)
loss = prediction.sum()
loss.backward()
return prediction
for _ in range(10):
pt_input = torch.rand((N, D, H), device=device, requires_grad=True)
ort_input = copy.deepcopy(pt_input)
pt_prediction = run_step(pt_model, pt_input)
ort_prediction = run_step(ort_model, ort_input)
_test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
_test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
# Since multinomial is a generator function, we do not have to test for gradient
# Two consecutive calls on the torch.multinomail on a probability distribution with more
# than one index with non-zero probability(eg, [0, 10, 3, 0]) will not result in

View file

@ -35,5 +35,35 @@ Status ATenOp::Compute(OpKernelContext* p_ctx) const {
return Status::OK();
}
bool IsATenOperatorExecutorInitialized() {
return aten_ops::ATenOperatorExecutor::Instance().IsInitialized();
}
Status ExecuteReduceSumATenOp(OpKernelContext* p_ctx, const std::vector<int64_t>& axes, bool keepdims) {
ORT_ENFORCE(aten_ops::ATenOperatorExecutor::Instance().IsInitialized() && !axes.empty());
std::vector<DLManagedTensor*> dlpacks;
auto* p_ctx_internal = static_cast<OpKernelContextInternal*>(p_ctx);
OrtValue ort_value = *p_ctx_internal->GetInputMLValue(0);
dlpacks.emplace_back(dlpack::OrtValueToDlpack(ort_value));
OrtValue axes_tensor;
OrtValue keepdims_tensor;
std::vector<int64_t> axes_tensor_shape(1, static_cast<int64_t>(axes.size()));
std::vector<int64_t> keepdims_tensor_shape(1, 1);
auto ml_tensor = DataTypeImpl::GetType<Tensor>();
OrtMemoryInfo info("Cpu", OrtDeviceAllocator);
axes_tensor.Init(new Tensor(DataTypeImpl::GetType<int64_t>(), axes_tensor_shape,
const_cast<void*>(reinterpret_cast<const void*>(&axes[0])), info),
ml_tensor, ml_tensor->GetDeleteFunc());
keepdims_tensor.Init(
new Tensor(DataTypeImpl::GetType<bool>(), keepdims_tensor_shape, reinterpret_cast<void*>(&keepdims), info),
ml_tensor, ml_tensor->GetDeleteFunc());
dlpacks.emplace_back(dlpack::OrtValueToDlpack(axes_tensor));
dlpacks.emplace_back(dlpack::OrtValueToDlpack(keepdims_tensor));
dlpacks.emplace_back(nullptr);
auto result = aten_ops::ATenOperatorExecutor::Instance()("aten::sum", "dim_IntList", dlpacks);
ORT_RETURN_IF_ERROR(p_ctx_internal->SetOutputMLValue(0, dlpack::DlpackToOrtValue(result[0])));
return Status::OK();
}
} // namespace contrib
} // namespace onnxruntime

View file

@ -22,5 +22,8 @@ class ATenOp : public OpKernel {
std::string overload_name_;
};
bool IsATenOperatorExecutorInitialized();
Status ExecuteReduceSumATenOp(OpKernelContext* p_ctx, const std::vector<int64_t>& axes, bool keepdims);
} // namespace contrib
} // namespace onnxruntime

View file

@ -16,16 +16,24 @@ typedef std::vector<DLManagedTensor*> (*ExecuteATenOperatorFunc)(const char* op_
class ATenOperatorExecutor {
public:
static ATenOperatorExecutor& Instance() { return InstanceImpl(); }
static void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) {
InstanceImpl(p_is_tensor_argument_func_raw, p_execute_aten_op_func_raw);
static ATenOperatorExecutor& Instance() {
static ATenOperatorExecutor instance;
return instance;
}
void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) {
ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw);
p_is_tensor_argument_func_ = reinterpret_cast<IsTensorArgumentFunc>(p_is_tensor_argument_func_raw);
p_execute_aten_op_func_ = reinterpret_cast<ExecuteATenOperatorFunc>(p_execute_aten_op_func_raw);
}
bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; }
bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index) {
ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized.");
return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index);
}
std::vector<DLManagedTensor*> operator()(const std::string& op_name, const std::string& overload_name,
const std::vector<DLManagedTensor*>& dlpacks) {
ORT_ENFORCE(p_execute_aten_op_func_, "ATenOperatorExecutor is not initialized.");
@ -33,20 +41,8 @@ class ATenOperatorExecutor {
}
private:
static ATenOperatorExecutor& InstanceImpl(void* p_is_tensor_argument_func_raw = nullptr,
void* p_execute_aten_op_func_raw = nullptr) {
static ATenOperatorExecutor instance(p_is_tensor_argument_func_raw, p_execute_aten_op_func_raw);
return instance;
}
ATenOperatorExecutor(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) {
ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw);
p_is_tensor_argument_func_ = reinterpret_cast<IsTensorArgumentFunc>(p_is_tensor_argument_func_raw);
p_execute_aten_op_func_ = reinterpret_cast<ExecuteATenOperatorFunc>(p_execute_aten_op_func_raw);
}
IsTensorArgumentFunc p_is_tensor_argument_func_;
ExecuteATenOperatorFunc p_execute_aten_op_func_;
IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr;
ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr;
};
} // namespace aten_ops