mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-25 22:26:24 +00:00
Remove memcpy in in-place ATen ops (#9913)
* Make ops in-place
* Add comment
This commit is contained in:
parent
a7c2d1cb09
commit
9e04b7e59b
2 changed files with 8 additions and 4 deletions
|
|
@@ -348,6 +348,12 @@ class ORTGen:
|
|||
writer.write(f'std::vector<OrtValue> {onnx_op.outputs}')
|
||||
writer.writeline(f'({onnx_op.outputs.count});')
|
||||
|
||||
if in_place_param:
|
||||
assert(onnx_op.outputs.count == 1)
|
||||
# TODO: This assumes that the first output corresponds to the first input.
|
||||
# This may not work for more complicated ops.
|
||||
writer.writeline(f'{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[0]};')
|
||||
|
||||
# Perform the invocation
|
||||
writer.writeline()
|
||||
if onnx_op_index == 0:
|
||||
|
|
@@ -404,7 +410,6 @@ class ORTGen:
|
|||
raise Exception(f'"{cpp_func.torch_func.torch_schema}" ' +
|
||||
'has alias info on its return type but no associated parameter')
|
||||
|
||||
writer.writeline(f'copy(invoker, {return_outputs}[0], ort_input_{in_place_param.identifier.value});')
|
||||
writer.writeline(f'return {in_place_param.identifier.value};')
|
||||
|
||||
def _write_function_registrations(
|
||||
|
|
|
|||
|
|
@@ -365,8 +365,8 @@ at::Tensor& zero_(at::Tensor& self){
|
|||
auto* ort_flag_tensor = flag_val.GetMutable<onnxruntime::Tensor>();
|
||||
CopyVectorToTensor<int64_t>(invoker, {1}, *ort_flag_tensor);
|
||||
|
||||
std::vector<OrtValue> ort_out(1);
|
||||
|
||||
std::vector<OrtValue> ort_out = {ort_in_self};
|
||||
|
||||
auto status = invoker.Invoke(
|
||||
"ZeroGradient", {
|
||||
std::move(ort_in_self),
|
||||
|
|
@@ -377,7 +377,6 @@ at::Tensor& zero_(at::Tensor& self){
|
|||
throw std::runtime_error(
|
||||
"ORT return failure status:" + status.ErrorMessage());
|
||||
|
||||
copy(invoker, ort_out[0], ort_in_self);
|
||||
return self;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue