diff --git a/orttraining/orttraining/eager/ort_aten.cpp b/orttraining/orttraining/eager/ort_aten.cpp index 8fbdc51086..79f33c41a2 100644 --- a/orttraining/orttraining/eager/ort_aten.cpp +++ b/orttraining/orttraining/eager/ort_aten.cpp @@ -181,7 +181,7 @@ at::Tensor empty__memory_format( auto& invoker = GetORTInvoker(*device_opt); CreateMLValue( invoker.GetCurrentExecutionProvider().GetAllocator(0, OrtMemTypeDefault), - ort_scalar_type_from_aten(at::kFloat), + ort_scalar_type_from_aten(*dtype_opt), size.vec(), &ot);