From 9e04b7e59b3ee1d0eff70010077f659a4e32a86d Mon Sep 17 00:00:00 2001 From: ashari4 <70242157+ashari4@users.noreply.github.com> Date: Tue, 14 Dec 2021 08:28:12 -0800 Subject: [PATCH] Remove memcpy in in-place ATen ops (#9913) * Make ops in-place * Add comment --- orttraining/orttraining/eager/opgen/opgen/generator.py | 7 ++++++- orttraining/orttraining/eager/ort_aten.cpp | 5 ++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/orttraining/orttraining/eager/opgen/opgen/generator.py b/orttraining/orttraining/eager/opgen/opgen/generator.py index eaa3a8deef..3a8268eb12 100644 --- a/orttraining/orttraining/eager/opgen/opgen/generator.py +++ b/orttraining/orttraining/eager/opgen/opgen/generator.py @@ -348,6 +348,12 @@ class ORTGen: writer.write(f'std::vector {onnx_op.outputs}') writer.writeline(f'({onnx_op.outputs.count});') + if in_place_param: + assert(onnx_op.outputs.count == 1) + # TODO: This assumes that the first output corresponds to the first input. + # This may not work for more complicated ops. + writer.writeline(f'{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[0]};') + # Perform the invocation writer.writeline() if onnx_op_index == 0: @@ -404,7 +410,6 @@ class ORTGen: raise Exception(f'"{cpp_func.torch_func.torch_schema}" ' + 'has alias info on its return type but no associated parameter') - writer.writeline(f'copy(invoker, {return_outputs}[0], ort_input_{in_place_param.identifier.value});') writer.writeline(f'return {in_place_param.identifier.value};') def _write_function_registrations( diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index a698fccb69..9493c18f12 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -365,8 +365,8 @@ at::Tensor& zero_(at::Tensor& self){ auto* ort_flag_tensor = flag_val.GetMutable(); CopyVectorToTensor(invoker, {1}, *ort_flag_tensor); - std::vector ort_out(1); - + std::vector ort_out = {ort_in_self}; + auto status = invoker.Invoke( "ZeroGradient", { std::move(ort_in_self), @@ -377,7 +377,6 @@ at::Tensor& zero_(at::Tensor& self){ throw std::runtime_error( "ORT return failure status:" + status.ErrorMessage()); - copy(invoker, ort_out[0], ort_in_self); return self; }