mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-02 03:55:34 +00:00
Optimize 4bit Qlora training (#18131)
### Optimize 4bit Qlora training Extent existing `MatmulBnb4bit` to its usage in training scenarios. The PR includes following changes: 1. Add special `torch.autograd.Function` export logic for `bitsandbytes.autograd._functions.MatMul4Bit` that is preferred before common PythonOp exporter. 2. Add `training_mode` optional attribute for op `MatmulBnb4bit`, which help skip some inference specific logic in implementation. 3. Add `transB` optional attribute, which is by default be 1; setting it to be 0 is needed by backward usage. Changing from `PythonOp` to this `MatmulBnb4bit` brings roughly ~2.9% throughput gains. The reason is: `bitsandbytes.autograd._functions.MatMul4Bit` has logic `ctx.save_for_backward`, which would need an additional copy in PythonOp, otherwise, the tensor might be released by ORT, while backward op still references it. Removing the clones also reduce the peak memory consumptions because `bitsandbytes.autograd._functions.MatMul4Bit` saved tensors that are not needed in backward compute.
This commit is contained in:
parent
e3b043ba17
commit
c8e1038eab
16 changed files with 297 additions and 67 deletions
|
|
@ -2580,8 +2580,30 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
|
||||
3. Input B's quantization constants or scales are specified by input 'absmax'.
|
||||
|
||||
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
|
||||
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
|
||||
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
|
||||
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
|
||||
|
||||
|
||||
1. (Default value) transB=True (Majorly used for forward pass)
|
||||
Shape of A: [D0, D1, ..., Dn, K]
|
||||
Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features].
|
||||
|
||||
The computation math:
|
||||
dequant_B = dequant(B, absmax, quant_type, block_size)
|
||||
transposed_dequant_B = dequant_B^T
|
||||
output = A @ transposed_dequant_B
|
||||
|
||||
Shape of output: [D0, D1, ..., Dn, N]
|
||||
|
||||
2. transB=False (Majorly used for backward pass)
|
||||
Shape of A: [D0, D1, ..., Dn, N]
|
||||
Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features].
|
||||
|
||||
The computation math:
|
||||
dequant_B = dequant(B, absmax, quant_type, block_size)
|
||||
output = A @ dequant_B
|
||||
|
||||
Shape of output: [D0, D1, ..., Dn, K]
|
||||
|
||||
|
||||
#### Version
|
||||
|
|
@ -2599,6 +2621,10 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dd>number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.</dd>
|
||||
<dt><tt>quant_type</tt> : int (required)</dt>
|
||||
<dd>quantization data type. 0 for FP4, 1 for NF4.</dd>
|
||||
<dt><tt>training_mode</tt> : int</dt>
|
||||
<dd>Indicate if the ops run in training_mode, by default, False.</dd>
|
||||
<dt><tt>transB</tt> : int</dt>
|
||||
<dd>Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ class MatMulBnb4 final : public OpKernel {
|
|||
ORT_ENFORCE(
|
||||
quant_type_ == FP4 || quant_type_ == NF4,
|
||||
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
|
||||
is_training_mode_ = static_cast<bool>(info.GetAttrOrDefault("training_mode", static_cast<int64_t>(0)));
|
||||
transB_ = static_cast<bool>(info.GetAttrOrDefault("transB", static_cast<int64_t>(1)));
|
||||
}
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
|
@ -30,6 +32,8 @@ class MatMulBnb4 final : public OpKernel {
|
|||
int64_t N_;
|
||||
int64_t block_size_;
|
||||
int64_t quant_type_;
|
||||
bool is_training_mode_;
|
||||
bool transB_;
|
||||
};
|
||||
|
||||
Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
|
||||
|
|
@ -58,7 +62,7 @@ Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
|
|||
thread_pool);
|
||||
|
||||
constexpr bool transa = false;
|
||||
constexpr bool transb = true;
|
||||
const bool transb = transB_;
|
||||
TensorShape b_shape({N_, K_});
|
||||
MatMulComputeHelper helper;
|
||||
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb));
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ class MatMulBnb4 final : public CudaKernel {
|
|||
ORT_ENFORCE(
|
||||
quant_type_ == FP4 || quant_type_ == NF4,
|
||||
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
|
||||
|
||||
is_training_mode_ = static_cast<bool>(info.GetAttrOrDefault("training_mode", static_cast<int64_t>(0)));
|
||||
transB_ = static_cast<bool>(info.GetAttrOrDefault("transB", static_cast<int64_t>(1)));
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
|
@ -34,6 +37,8 @@ class MatMulBnb4 final : public CudaKernel {
|
|||
int64_t N_;
|
||||
int64_t block_size_;
|
||||
int64_t quant_type_;
|
||||
bool is_training_mode_;
|
||||
bool transB_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -59,7 +64,7 @@ Status MatMulBnb4<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
|
||||
|
||||
constexpr bool transa = false;
|
||||
constexpr bool transb = true;
|
||||
const bool transb = transB_;
|
||||
MatMulComputeHelper helper;
|
||||
TensorShape b_shape({N_, K_});
|
||||
ORT_RETURN_IF_ERROR(
|
||||
|
|
@ -69,17 +74,18 @@ Status MatMulBnb4<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
// Bail out early if the output is going to be empty
|
||||
if (Y->Shape().Size() == 0) return Status::OK();
|
||||
|
||||
bool is_4bit_done = TryMatMulBnb4(
|
||||
reinterpret_cast<const CudaT*>(quant_map_buffer_data),
|
||||
reinterpret_cast<CudaT*>(Y->MutableData<T>()),
|
||||
reinterpret_cast<const CudaT*>(a_data),
|
||||
b_quant_data,
|
||||
reinterpret_cast<const CudaT*>(absmax_data),
|
||||
SafeInt<int>(helper.M()),
|
||||
SafeInt<int>(helper.N()),
|
||||
SafeInt<int>(helper.K()),
|
||||
SafeInt<int>(block_size_),
|
||||
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
|
||||
bool is_4bit_done = !is_training_mode_ // skip inference specific handle if in training mode
|
||||
&& TryMatMulBnb4(
|
||||
reinterpret_cast<const CudaT*>(quant_map_buffer_data),
|
||||
reinterpret_cast<CudaT*>(Y->MutableData<T>()),
|
||||
reinterpret_cast<const CudaT*>(a_data),
|
||||
b_quant_data,
|
||||
reinterpret_cast<const CudaT*>(absmax_data),
|
||||
SafeInt<int>(helper.M()),
|
||||
SafeInt<int>(helper.N()),
|
||||
SafeInt<int>(helper.K()),
|
||||
SafeInt<int>(block_size_),
|
||||
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
|
||||
|
||||
if (!is_4bit_done) {
|
||||
IAllocatorUniquePtr<T> b_dequant_ptr = GetScratchBuffer<T>(N_ * K_, ctx->GetComputeStream());
|
||||
|
|
@ -98,16 +104,16 @@ Status MatMulBnb4<T>::ComputeInternal(OpKernelContext* ctx) const {
|
|||
|
||||
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
|
||||
GetCublasHandle(ctx),
|
||||
CUBLAS_OP_T,
|
||||
CUBLAS_OP_N,
|
||||
transb ? CUBLAS_OP_T : CUBLAS_OP_N, // transB
|
||||
CUBLAS_OP_N, // transA
|
||||
SafeInt<int>(helper.N()),
|
||||
SafeInt<int>(helper.M()),
|
||||
SafeInt<int>(helper.K()),
|
||||
&alpha,
|
||||
reinterpret_cast<const CudaT*>(b_dequant_data),
|
||||
SafeInt<int>(K_),
|
||||
helper.Ldb(transb), // ldb
|
||||
reinterpret_cast<const CudaT*>(a_data),
|
||||
helper.Lda(transa),
|
||||
helper.Lda(transa), // lda
|
||||
&zero,
|
||||
reinterpret_cast<CudaT*>(Y->MutableData<T>()),
|
||||
helper.Ldc(),
|
||||
|
|
|
|||
|
|
@ -2693,7 +2693,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1,
|
|||
|
||||
static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx,
|
||||
int64_t K,
|
||||
int64_t N) {
|
||||
int64_t N,
|
||||
bool transB) {
|
||||
int input_a_idx = 0;
|
||||
if (!hasInputShape(ctx, input_a_idx)) {
|
||||
return;
|
||||
|
|
@ -2707,15 +2708,15 @@ static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext
|
|||
// TODO: check B shape
|
||||
|
||||
const auto& dim_last = a_shape.dim(a_shape.dim_size() - 1);
|
||||
if (dim_last.has_dim_value() && dim_last.dim_value() != K) {
|
||||
ONNX_NAMESPACE::TensorShapeProto resultShape;
|
||||
if (dim_last.has_dim_value() && dim_last.dim_value() != (transB ? K : N)) {
|
||||
fail_shape_inference("Incompatible dimensions for matrix multiplication");
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::TensorShapeProto resultShape;
|
||||
for (int i = 0; i < a_shape.dim_size() - 1; ++i) {
|
||||
*resultShape.add_dim() = a_shape.dim(i);
|
||||
}
|
||||
resultShape.add_dim()->set_dim_value(N);
|
||||
resultShape.add_dim()->set_dim_value(transB ? N : K);
|
||||
|
||||
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape;
|
||||
}
|
||||
|
|
@ -3354,7 +3355,7 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
|
|||
// Shape inference
|
||||
int64_t in_features = getAttribute(ctx, "K", -1);
|
||||
int64_t out_features = getAttribute(ctx, "N", -1);
|
||||
MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
|
||||
MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true);
|
||||
});
|
||||
|
||||
static const char* MatMulBnb4_ver1_doc = R"DOC(
|
||||
|
|
@ -3364,8 +3365,30 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4
|
|||
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
|
||||
3. Input B's quantization constants or scales are specified by input 'absmax'.
|
||||
|
||||
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
|
||||
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
|
||||
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
|
||||
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
|
||||
|
||||
|
||||
1. (Default value) transB=True (Majorly used for forward pass)
|
||||
Shape of A: [D0, D1, ..., Dn, K]
|
||||
Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features].
|
||||
|
||||
The computation math:
|
||||
dequant_B = dequant(B, absmax, quant_type, block_size)
|
||||
transposed_dequant_B = dequant_B^T
|
||||
output = A @ transposed_dequant_B
|
||||
|
||||
Shape of output: [D0, D1, ..., Dn, N]
|
||||
|
||||
2. transB=False (Majorly used for backward pass)
|
||||
Shape of A: [D0, D1, ..., Dn, N]
|
||||
Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features].
|
||||
|
||||
The computation math:
|
||||
dequant_B = dequant(B, absmax, quant_type, block_size)
|
||||
output = A @ dequant_B
|
||||
|
||||
Shape of output: [D0, D1, ..., Dn, K]
|
||||
|
||||
)DOC";
|
||||
|
||||
|
|
@ -3377,6 +3400,12 @@ Input absmax is stored in same type as original type of B(float32, float16) with
|
|||
.Attr("N", "size of each output feature", AttributeProto::INT)
|
||||
.Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
|
||||
.Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT)
|
||||
.Attr("training_mode",
|
||||
"Indicate if the ops run in training_mode, by default, False.",
|
||||
AttributeProto::INT,
|
||||
static_cast<int64_t>(0))
|
||||
.Attr("transB", "Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.",
|
||||
AttributeProto::INT, static_cast<int64_t>(1))
|
||||
.Input(0, "A", "The input tensor, not quantized", "T1")
|
||||
.Input(1, "B", "1-dimensional quantized data for weight", "T2")
|
||||
.Input(2, "absmax", "quantization constants", "T1")
|
||||
|
|
@ -3389,7 +3418,8 @@ Input absmax is stored in same type as original type of B(float32, float16) with
|
|||
// Shape inference
|
||||
int64_t in_features = getAttribute(ctx, "K", -1);
|
||||
int64_t out_features = getAttribute(ctx, "N", -1);
|
||||
MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
|
||||
bool transB = getAttribute(ctx, "transB", 1) != 0;
|
||||
MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, transB);
|
||||
});
|
||||
|
||||
#ifdef ENABLE_ATEN
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
|
|||
{"Split", {1}},
|
||||
{"Clip", {1, 2}},
|
||||
{"Pad", {1, 2}},
|
||||
{"MatMulBnb4", {1, 2}}, // quantified weight (non float) and absmax constant don't need gradients.
|
||||
{"Multinomial", {0}},
|
||||
{"RandomNormalLike", {0}},
|
||||
{"RandomUniformLike", {0}},
|
||||
|
|
|
|||
|
|
@ -494,6 +494,42 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
|
|||
return result;
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetMatmulBnb4Gradient) {
|
||||
auto attributes = SrcNodeAttributes();
|
||||
std::vector<AttributeProto> attrs;
|
||||
bool find_transB = false;
|
||||
for (auto& attr : attributes) {
|
||||
if (attr.first == "transB") {
|
||||
int64_t transB_value = attr.second.i();
|
||||
transB_value = (transB_value + 1) % 2; // revert the transpose
|
||||
attrs.push_back(MakeAttribute("transB", transB_value));
|
||||
find_transB = true;
|
||||
} else {
|
||||
attrs.push_back(attr.second);
|
||||
}
|
||||
}
|
||||
|
||||
if (!find_transB) {
|
||||
attrs.push_back(MakeAttribute("transB", int64_t(0))); // default is 1, so we need to set it to 0
|
||||
}
|
||||
|
||||
std::vector<NodeDef> result;
|
||||
// Y = A * B
|
||||
// dA = dY * B', dB = A' * dY
|
||||
if (IsGradientRequiredForSrcNodeInput(0)) {
|
||||
// B is 1-D, so don't need transpose here.
|
||||
result.push_back(NodeDef(OpDef{"MatMulBnb4", kMSDomain, 1},
|
||||
{GO(0), I(1), I(2)},
|
||||
{GI(0)},
|
||||
attrs));
|
||||
}
|
||||
|
||||
ORT_ENFORCE(!IsGradientRequiredForSrcNodeInput(1), "Gradient propagation to B is not supported yet.");
|
||||
ORT_ENFORCE(!IsGradientRequiredForSrcNodeInput(2), "Gradient propagation to absmax is not supported yet.");
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
IMPLEMENT_GRADIENT_BUILDER(GetSplitGradient) {
|
||||
std::vector<NodeDef> result = {};
|
||||
std::vector<ArgDef> input_args;
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossGradient)
|
|||
DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossInternalGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetGlobalAveragePoolGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetGemmGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetMatmulBnb4Gradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetDropoutGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetGatherNDGradient)
|
||||
DECLARE_GRADIENT_BUILDER(GetGatherElementsGradient)
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
|
|||
REGISTER_GRADIENT_BUILDER("Reshape", GetReshapeGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Transpose", GetTransposeGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient);
|
||||
REGISTER_GRADIENT_BUILDER("MatMulBnb4", GetMatmulBnb4Gradient);
|
||||
REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
|
||||
REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
|
||||
REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);
|
||||
|
|
|
|||
|
|
@ -3,7 +3,10 @@
|
|||
# Licensed under the MIT License.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
|
|
@ -18,13 +21,55 @@ from onnxruntime.capi._pybind_state import (
|
|||
register_torch_autograd_function,
|
||||
)
|
||||
from onnxruntime.training import ortmodule
|
||||
from onnxruntime.training.utils import pytorch_dtype_to_onnx
|
||||
from onnxruntime.training.utils import pytorch_scalar_type_to_pytorch_dtype, pytorch_type_to_onnx_dtype
|
||||
|
||||
from ._custom_op_symbolic_registry import wrap_custom_export_function
|
||||
from ._fallback import ORTModuleONNXModelException, wrap_exception
|
||||
from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version
|
||||
|
||||
|
||||
class _SpecialCustomFunctionHandler:
|
||||
"""A class to handle high priority export of torch.autograd.Function.
|
||||
`register_high_priority_handler` can be used as function decorator to register a handler for a torch.autograd.Function.
|
||||
"""
|
||||
|
||||
_HIGH_PRIORITY_EXPORT_HANDLER_MAP: ClassVar[dict[str, callable]] = {}
|
||||
|
||||
@staticmethod
|
||||
def add_handler(func_name: str, handler: callable) -> None:
|
||||
"""Add a handler for a function name.
|
||||
|
||||
Args:
|
||||
func_name (str): The function name.
|
||||
handler (callable): The handler.
|
||||
|
||||
"""
|
||||
_SpecialCustomFunctionHandler._HIGH_PRIORITY_EXPORT_HANDLER_MAP[func_name] = handler
|
||||
|
||||
@staticmethod
|
||||
def get_handler(func_name: str) -> callable | None:
|
||||
"""Get the handler for a function name.
|
||||
|
||||
Args:
|
||||
func_name (str): The function name.
|
||||
|
||||
Returns:
|
||||
callable | None: The handler.
|
||||
|
||||
"""
|
||||
return _SpecialCustomFunctionHandler._HIGH_PRIORITY_EXPORT_HANDLER_MAP.get(func_name, None)
|
||||
|
||||
|
||||
def register_high_priority_handler(func_name):
|
||||
"""Register a handler for a torch.autograd.Function using its full qualified class name."""
|
||||
|
||||
def symbolic_wrapper(fn):
|
||||
_SpecialCustomFunctionHandler.add_handler(func_name, fn)
|
||||
return fn
|
||||
|
||||
return symbolic_wrapper
|
||||
|
||||
|
||||
def register_custom_function_schema_supplementary(kclass: torch.autograd.Function) -> None:
|
||||
"""Register a shape inference function for a torch.autograd.Function if there is staticmethod
|
||||
"infer_shape" defined.
|
||||
|
|
@ -96,6 +141,30 @@ _UNSUPPORTED_CKPT_FUNC_NAMES = frozenset(
|
|||
)
|
||||
|
||||
|
||||
def _get_training_mode() -> bool:
|
||||
# TODO move to public API once the exporter team exposes that
|
||||
training_mode = None
|
||||
if get_runtime_pytorch_version() >= version.parse("1.12"):
|
||||
# FIXME: using private modules
|
||||
from torch.onnx import _globals
|
||||
|
||||
# before https://github.com/pytorch/pytorch/commit/c8b9b6266b505328e503b12f6a42fd88c56374f9,
|
||||
# training_mode is still a bool type
|
||||
if isinstance(_globals.GLOBALS.training_mode, bool):
|
||||
training_mode = _globals.GLOBALS.training_mode
|
||||
else:
|
||||
if _globals.GLOBALS.training_mode not in [
|
||||
torch.onnx.TrainingMode.EVAL,
|
||||
torch.onnx.TrainingMode.TRAINING,
|
||||
]:
|
||||
raise Exception(f"Unexpected training mode {_globals.GLOBALS.training_mode}")
|
||||
training_mode = _globals.GLOBALS.training_mode == torch.onnx.TrainingMode.TRAINING
|
||||
else:
|
||||
training_mode = symbolic_helper._training_mode
|
||||
|
||||
return bool(training_mode)
|
||||
|
||||
|
||||
def _export_pt_1_10(g, n, *args, **kwargs):
|
||||
"""Export torch.autograd.Function in ORT PythonOp.
|
||||
|
||||
|
|
@ -114,6 +183,15 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
func_class = n.pyobj().__self__
|
||||
func_full_qual_name = get_fully_qualified_class_name(func_class)
|
||||
|
||||
# Check if the function is handled by high priority exporter.
|
||||
hi_pri_handler = _SpecialCustomFunctionHandler.get_handler(func_full_qual_name)
|
||||
if hi_pri_handler:
|
||||
try_export = hi_pri_handler(g, n, *args, **kwargs)
|
||||
if try_export is not None:
|
||||
return try_export
|
||||
|
||||
# Fall back to common exporter if not handled by high priority exporter.
|
||||
|
||||
# Check if the checkpointing activation is allowed.
|
||||
is_ckpt_activation_allowed = ortmodule._defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0) == 1
|
||||
if is_ckpt_activation_allowed is False and func_full_qual_name in _UNSUPPORTED_CKPT_FUNC_NAMES:
|
||||
|
|
@ -123,26 +201,6 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
"wrap exportable sub-nn.Module's as ORTModule."
|
||||
)
|
||||
|
||||
# TODO move to public API once the exporter team exposes that
|
||||
training_mode = None
|
||||
if get_runtime_pytorch_version() >= version.parse("1.12"):
|
||||
# FIXME: using private modules
|
||||
from torch.onnx import _globals
|
||||
|
||||
# before https://github.com/pytorch/pytorch/commit/c8b9b6266b505328e503b12f6a42fd88c56374f9,
|
||||
# training_mode is still a bool type
|
||||
if isinstance(_globals.GLOBALS.training_mode, bool):
|
||||
training_mode = _globals.GLOBALS.training_mode
|
||||
else:
|
||||
if _globals.GLOBALS.training_mode not in [
|
||||
torch.onnx.TrainingMode.EVAL,
|
||||
torch.onnx.TrainingMode.TRAINING,
|
||||
]:
|
||||
raise Exception(f"Unexpected training mode {_globals.GLOBALS.training_mode}")
|
||||
training_mode = _globals.GLOBALS.training_mode == torch.onnx.TrainingMode.TRAINING
|
||||
else:
|
||||
training_mode = symbolic_helper._training_mode
|
||||
|
||||
cconv = n.cconv()
|
||||
|
||||
input_tensor_types = []
|
||||
|
|
@ -179,7 +237,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
if call_type == "d":
|
||||
# Got a tensor variable.
|
||||
tensor_args.append(arg)
|
||||
scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType())
|
||||
scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType())
|
||||
input_tensor_types.append(scalar_type)
|
||||
input_tensor_ranks.append(arg.type().dim())
|
||||
continue
|
||||
|
|
@ -258,7 +316,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
output_tensor_ranks = []
|
||||
for arg in n.outputs():
|
||||
# Type of tensor's elements.
|
||||
scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType())
|
||||
scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType())
|
||||
output_tensor_types.append(scalar_type)
|
||||
output_tensor_ranks.append(arg.type().dim())
|
||||
|
||||
|
|
@ -270,7 +328,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
|
|||
"input_tensor_ranks_i": input_tensor_ranks,
|
||||
"output_tensor_types_i": output_tensor_types,
|
||||
"output_tensor_ranks_i": output_tensor_ranks,
|
||||
"training_mode_i": 1 if training_mode else 0,
|
||||
"training_mode_i": 1 if _get_training_mode() else 0,
|
||||
"comment_s": debug_comment,
|
||||
}
|
||||
|
||||
|
|
@ -336,3 +394,55 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model
|
|||
index += 1
|
||||
|
||||
return exported_model
|
||||
|
||||
|
||||
@register_high_priority_handler("bitsandbytes.autograd._functions.MatMul4Bit")
|
||||
def _matmul4bit_export(g, n, *args, **kwargs):
|
||||
cconv = n.cconv()
|
||||
can_converted = cconv[0] == "d" and cconv[1] == "d" and cconv[2] == "c" and cconv[3] == "c" and cconv[4] == "c"
|
||||
can_converted = can_converted and (args[2] is None and args[3] is None and args[4] is not None)
|
||||
if not can_converted:
|
||||
return None
|
||||
|
||||
quant_state = args[4]
|
||||
absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = quant_state
|
||||
|
||||
# MatMulBnb4's blocksize needs to be a power of 2 and not smaller than 16
|
||||
if blocksize < 16 or blocksize & (blocksize - 1) != 0:
|
||||
return None
|
||||
|
||||
# MatMulBnb4 does not support double de-quantization (e.g. absmax is int, needs to be dequantized too)
|
||||
if compressed_stats is not None:
|
||||
return None
|
||||
|
||||
# The PyTorch linear weight shape is [out_feature, in_feature]
|
||||
in_feature = shape[1]
|
||||
out_feature = shape[0]
|
||||
if quant_type == "fp4":
|
||||
quant_type = 0
|
||||
elif quant_type == "nf4":
|
||||
quant_type = 1
|
||||
else:
|
||||
return None
|
||||
attrs = {
|
||||
"K_i": in_feature,
|
||||
"N_i": out_feature,
|
||||
"block_size_i": blocksize,
|
||||
"quant_type_i": quant_type,
|
||||
"training_mode_i": 1 if _get_training_mode() else 0,
|
||||
}
|
||||
|
||||
# Make sure the quant weight can be flatten to 1D tensor safely, which com.microsoft::MatMulBnb4 requires.
|
||||
found_dim1 = any(v == 1 for v in args[1].type().sizes())
|
||||
if not found_dim1:
|
||||
return None
|
||||
|
||||
absmax = g.op(
|
||||
"Constant",
|
||||
value_t=torch.tensor(absmax, dtype=pytorch_scalar_type_to_pytorch_dtype(args[0].type().scalarType())),
|
||||
)
|
||||
quant_weight = g.op(
|
||||
"Reshape", args[1], g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
|
||||
) # flatten to 1D
|
||||
tensor_args = [args[0], quant_weight, absmax]
|
||||
return g.op("com.microsoft::MatMulBnb4", *tensor_args, **attrs)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from packaging.version import Version
|
|||
from torch.onnx import register_custom_op_symbolic
|
||||
from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args
|
||||
|
||||
from onnxruntime.training.utils import pytorch_dtype_to_onnx
|
||||
from onnxruntime.training.utils import pytorch_type_to_onnx_dtype
|
||||
|
||||
from ._utils import get_runtime_pytorch_version
|
||||
|
||||
|
|
@ -145,7 +145,7 @@ def cross_entropy_loss(g, node, logits, target, weight, reduction, ignore_index,
|
|||
weight_casted,
|
||||
ignore_index,
|
||||
reduction_s=reduction,
|
||||
output_type_i=pytorch_dtype_to_onnx(output_type.scalarType()),
|
||||
output_type_i=pytorch_type_to_onnx_dtype(output_type.scalarType()),
|
||||
outputs=2,
|
||||
)
|
||||
output.setType(output_type)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from torch.utils.cpp_extension import ROCM_HOME
|
|||
import onnxruntime
|
||||
from onnxruntime.capi import _pybind_state as C
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch
|
||||
from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch_dtype
|
||||
from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
|
||||
|
||||
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
|
||||
|
|
@ -345,7 +345,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
cache_dir, f"{hash_fn(str(self._flattened_module).encode()).hexdigest()}_{get_rank()}.onnx"
|
||||
)
|
||||
if os.path.exists(cache_dir) and os.path.isfile(filename):
|
||||
self._logger.info(
|
||||
self._logger.warning(
|
||||
f"Cached model detected! Cached model will be used to save export and initialization time."
|
||||
f"If you want the model to be re-exported then DELETE {filename}."
|
||||
)
|
||||
|
|
@ -627,7 +627,7 @@ class GraphExecutionManager(GraphExecutionInterface):
|
|||
|
||||
kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros(
|
||||
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
|
||||
dtype=onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE),
|
||||
dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE),
|
||||
device=device,
|
||||
).requires_grad_()
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from onnxruntime.capi._pybind_state import (
|
|||
register_shape_inference_function,
|
||||
register_torch_autograd_function,
|
||||
)
|
||||
from onnxruntime.training.utils import pytorch_dtype_to_onnx
|
||||
from onnxruntime.training.utils import pytorch_type_to_onnx_dtype
|
||||
|
||||
from ._custom_autograd_function_exporter import register_custom_function_schema_supplementary
|
||||
from ._utils import get_fully_qualified_class_name
|
||||
|
|
@ -149,7 +149,7 @@ def post_processing_enable_zero_stage3_compat(
|
|||
c,
|
||||
graph_input.name,
|
||||
len(zero_stage3_named_params[graph_input.name].ds_shape), # new rank
|
||||
pytorch_dtype_to_onnx(zero_stage3_named_params[graph_input.name].dtype), # new data type
|
||||
pytorch_type_to_onnx_dtype(zero_stage3_named_params[graph_input.name].dtype), # new data type
|
||||
)
|
||||
|
||||
# Delete exported_model.graph.input
|
||||
|
|
|
|||
|
|
@ -366,7 +366,7 @@ class _RuntimeOptions:
|
|||
|
||||
# Cache exported model
|
||||
if "ORTMODULE_CACHE_DIR" in os.environ:
|
||||
self._logger.info("ORTModule cache optimization is ON.")
|
||||
self._logger.warning("ORTModule optimization for caching exported model is ON.")
|
||||
self.ortmodule_cache_dir = os.getenv("ORTMODULE_CACHE_DIR")
|
||||
|
||||
# Experimental features.
|
||||
|
|
|
|||
|
|
@ -9,7 +9,11 @@ from onnxruntime.training.utils.torch_io_helper import (
|
|||
extract_data_and_schema,
|
||||
unflatten_data_using_schema,
|
||||
)
|
||||
from onnxruntime.training.utils.torch_type_map import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx
|
||||
from onnxruntime.training.utils.torch_type_map import (
|
||||
onnx_dtype_to_pytorch_dtype,
|
||||
pytorch_scalar_type_to_pytorch_dtype,
|
||||
pytorch_type_to_onnx_dtype,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PrimitiveType",
|
||||
|
|
@ -17,6 +21,7 @@ __all__ = [
|
|||
"ORTModelInputOutputSchemaType",
|
||||
"extract_data_and_schema",
|
||||
"unflatten_data_using_schema",
|
||||
"pytorch_dtype_to_onnx",
|
||||
"onnx_dtype_to_pytorch",
|
||||
"pytorch_type_to_onnx_dtype",
|
||||
"onnx_dtype_to_pytorch_dtype",
|
||||
"pytorch_scalar_type_to_pytorch_dtype",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import torch
|
|||
from onnxruntime.training.utils import (
|
||||
ORTModelInputOutputType,
|
||||
extract_data_and_schema,
|
||||
pytorch_dtype_to_onnx,
|
||||
pytorch_type_to_onnx_dtype,
|
||||
unflatten_data_using_schema,
|
||||
)
|
||||
|
||||
|
|
@ -324,7 +324,7 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
|
|||
start_offset = len(tensor_input_shapes) - len(partitioned_params)
|
||||
for index, param in enumerate(partitioned_params):
|
||||
tensor_output_shapes[start_offset + index] = list(param.ds_shape)
|
||||
tensor_output_dtypes[start_offset + index] = int(pytorch_dtype_to_onnx(param.dtype))
|
||||
tensor_output_dtypes[start_offset + index] = int(pytorch_type_to_onnx_dtype(param.dtype))
|
||||
assert len(tensor_output_shapes) == len(tensor_input_shapes)
|
||||
assert len(tensor_output_dtypes) == len(tensor_input_dtypes)
|
||||
|
||||
|
|
|
|||
|
|
@ -36,8 +36,10 @@ _DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for k, (onnx_dtype, torch_dtype) in _C
|
|||
_ONNX_TO_DTYPE = {onnx_dtype: torch_dtype for torch_dtype, onnx_dtype in _DTYPE_TO_ONNX.items()}
|
||||
|
||||
|
||||
def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType:
|
||||
"""Converts a pytorch dtype or scalar type string to an onnx dtype."""
|
||||
def pytorch_type_to_onnx_dtype(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType:
|
||||
"""Converts a pytorch dtype or scalar type string to an onnx dtype.
|
||||
PyTorch type can be either a dtype or a scalar type string.
|
||||
"""
|
||||
dtype = dtype_or_scalar_type
|
||||
if isinstance(dtype, str):
|
||||
if dtype not in _CAST_PYTORCH_TO_ONNX:
|
||||
|
|
@ -49,7 +51,15 @@ def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torc
|
|||
return _DTYPE_TO_ONNX[dtype]
|
||||
|
||||
|
||||
def onnx_dtype_to_pytorch(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype:
|
||||
def pytorch_scalar_type_to_pytorch_dtype(dtype: str) -> torch.dtype:
|
||||
"""Converts a pytorch scalar type string to a pytorch dtype."""
|
||||
assert isinstance(dtype, str)
|
||||
if dtype not in _CAST_PYTORCH_TO_ONNX:
|
||||
raise RuntimeError(f"Unsupported dtype {dtype}")
|
||||
return _CAST_PYTORCH_TO_ONNX[dtype][1]
|
||||
|
||||
|
||||
def onnx_dtype_to_pytorch_dtype(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype:
|
||||
"""Converts an onnx dtype to a pytorch dtype."""
|
||||
if dtype not in _ONNX_TO_DTYPE:
|
||||
raise RuntimeError(f"Unsupported dtype {dtype}")
|
||||
|
|
|
|||
Loading…
Reference in a new issue