Remove memcpy in in-place ATen ops (#9913)

* Make ops in-place

* Add comment
This commit is contained in:
ashari4 2021-12-14 08:28:12 -08:00 committed by GitHub
parent a7c2d1cb09
commit 9e04b7e59b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 4 deletions

View file

@ -348,6 +348,12 @@ class ORTGen:
writer.write(f'std::vector<OrtValue> {onnx_op.outputs}')
writer.writeline(f'({onnx_op.outputs.count});')
if in_place_param:
assert(onnx_op.outputs.count == 1)
# TODO: This assumes that the first output corresponds to the first input.
# This may not work for more complicated ops.
writer.writeline(f'{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[0]};')
# Perform the invocation
writer.writeline()
if onnx_op_index == 0:
@ -404,7 +410,6 @@ class ORTGen:
raise Exception(f'"{cpp_func.torch_func.torch_schema}" ' +
'has alias info on its return type but no associated parameter')
writer.writeline(f'copy(invoker, {return_outputs}[0], ort_input_{in_place_param.identifier.value});')
writer.writeline(f'return {in_place_param.identifier.value};')
def _write_function_registrations(

View file

@ -365,8 +365,8 @@ at::Tensor& zero_(at::Tensor& self){
auto* ort_flag_tensor = flag_val.GetMutable<onnxruntime::Tensor>();
CopyVectorToTensor<int64_t>(invoker, {1}, *ort_flag_tensor);
std::vector<OrtValue> ort_out(1);
std::vector<OrtValue> ort_out = {ort_in_self};
auto status = invoker.Invoke(
"ZeroGradient", {
std::move(ort_in_self),
@ -377,7 +377,6 @@ at::Tensor& zero_(at::Tensor& self){
throw std::runtime_error(
"ORT return failure status:" + status.ErrorMessage());
copy(invoker, ort_out[0], ort_in_self);
return self;
}