mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
### Description 1. Renames all references to on-device training to training APIs. This is to keep the naming general. Nothing really prevents us from using the same APIs on servers/non-edge devices. 2. Update ENABLE_TRAINING option: With this PR, when this option is enabled, the training APIs and torch interop are also enabled. 3. Refactoring for onnxruntime_ENABLE_TRAINING_TORCH_INTEROP option: - Removed the user-facing option - Setting onnxruntime_ENABLE_TRAINING_TORCH_INTEROP to ON when onnxruntime_ENABLE_TRAINING is ON, as we always build with torch interop. Once this PR is merged, when --enable_training is selected we will do a "FULL Build" for training (with all the training entry points and features). Training entry points include: 1. ORTModule 2. Training APIs Features include: 1. ATen Fallback 2. All Training OPs, including communication and collectives 3. Strided Tensor Support 4. Python Op (torch interop) 5. ONNXBlock (front-end tools for training artifacts prep when using the training APIs) ### Motivation and Context The intention is to simplify the options for building training-enabled builds. This is part of the larger work item to create a dedicated build for learning-on-the-edge scenarios with just the training APIs enabled.
71 lines
2.3 KiB
C#
71 lines
2.3 KiB
C#
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
using System;
|
|
using System.Runtime.InteropServices;
|
|
|
|
namespace Microsoft.ML.OnnxRuntime
|
|
{
|
|
#if __ENABLE_TRAINING_APIS__
|
|
/// <summary>
/// Owns the native checkpoint state that is generated/consumed by the training APIs.
/// The native instance is freed through the <see cref="SafeHandle"/> release path
/// (see <see cref="ReleaseHandle"/>).
/// </summary>
public class CheckpointState : SafeHandle
{
    /// <summary>
    /// Raw native pointer currently stored in this SafeHandle, exposed to
    /// code inside this assembly.
    /// </summary>
    internal IntPtr Handle
    {
        get { return handle; }
    }

    /// <summary>
    /// Creates a CheckpointState by loading state from the given path.
    /// </summary>
    /// <param name="checkpointPath">absolute path to the checkpoint file.</param>
    /// <exception cref="InvalidOperationException">
    /// Thrown when NativeTrainingMethods.TrainingEnabled() reports that training
    /// is not available in the current build.
    /// </exception>
    public CheckpointState(string checkpointPath)
        : base(IntPtr.Zero, true)
    {
        // Bail out first: nothing below is attempted unless training is enabled.
        if (!NativeTrainingMethods.TrainingEnabled())
        {
            throw new InvalidOperationException("Training is disabled in the current build. Please build ONNXRuntime from source with the build flags enable_training. \n");
        }

        // The value itself is not used; the property is read just so the
        // environment is initialized before the checkpoint is loaded.
        var initializedEnv = OrtEnv.Handle;

        LoadCheckpoint(checkpointPath);
    }

    /// <summary>
    /// Overrides SafeHandle.IsInvalid.
    /// </summary>
    /// <value>true when the stored handle equals IntPtr.Zero; false otherwise</value>
    public override bool IsInvalid
    {
        get
        {
            return handle == IntPtr.Zero;
        }
    }

    /// <summary>
    /// Loads the checkpoint state from a path and stores the resulting native
    /// pointer in this SafeHandle.
    /// </summary>
    /// <param name="checkpointPath">absolute path to the checkpoint</param>
    private void LoadCheckpoint(string checkpointPath)
    {
        // Convert the managed path into the platform-specific serialized form
        // that is handed to the native call.
        var serializedPath = NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath);

        // The native call writes the checkpoint pointer directly into 'handle'.
        var status = NativeTrainingMethods.OrtLoadCheckpoint(serializedPath, out handle);

        // Status checking is delegated to NativeApiStatus.VerifySuccess.
        NativeApiStatus.VerifySuccess(status);
    }

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to release the native
    /// CheckpointState instance and clear the stored pointer.
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        NativeTrainingMethods.OrtReleaseCheckpointState(handle);
        handle = IntPtr.Zero;
        return true;
    }

    #endregion
}
|
|
#endif
|
|
}
|