Optimize 4bit Qlora training (#18131)

### Optimize 4bit Qlora training

Extend the existing `MatMulBnb4` op so that it can be used in training scenarios.

The PR includes the following changes:
1. Add special `torch.autograd.Function` export logic for
`bitsandbytes.autograd._functions.MatMul4Bit`, which takes priority over the
common PythonOp exporter.
2. Add an optional `training_mode` attribute to the `MatMulBnb4` op, which
helps skip some inference-specific logic in the implementation.
3. Add an optional `transB` attribute, which defaults to 1; setting it
to 0 is needed for the backward pass.

Switching from `PythonOp` to `MatMulBnb4` brings a throughput gain of
roughly 2.9%. The reason is that
`bitsandbytes.autograd._functions.MatMul4Bit` calls
`ctx.save_for_backward`, which requires an additional copy in
PythonOp; otherwise the tensor might be released by ORT while the backward
op still references it.

Removing the clones also reduces peak memory consumption, because
`bitsandbytes.autograd._functions.MatMul4Bit` saves tensors that are not
needed in the backward computation.
This commit is contained in:
pengwa 2023-11-03 00:46:11 +08:00 committed by GitHub
parent e3b043ba17
commit c8e1038eab
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 297 additions and 67 deletions

View file

@ -2580,8 +2580,30 @@ This version of the operator has been available since version 1 of the 'com.micr
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's quantization constants or scales are specified by input 'absmax'.
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
1. (Default value) transB=True (Majorly used for forward pass)
Shape of A: [D0, D1, ..., Dn, K]
Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features].
The computation math:
dequant_B = dequant(B, absmax, quant_type, block_size)
transposed_dequant_B = dequant_B^T
output = A @ transposed_dequant_B
Shape of output: [D0, D1, ..., Dn, N]
2. transB=False (Majorly used for backward pass)
Shape of A: [D0, D1, ..., Dn, N]
Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features].
The computation math:
dequant_B = dequant(B, absmax, quant_type, block_size)
output = A @ dequant_B
Shape of output: [D0, D1, ..., Dn, K]
#### Version
@ -2599,6 +2621,10 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.</dd>
<dt><tt>quant_type</tt> : int (required)</dt>
<dd>quantization data type. 0 for FP4, 1 for NF4.</dd>
<dt><tt>training_mode</tt> : int</dt>
<dd>Indicate if the ops run in training_mode, by default, False.</dd>
<dt><tt>transB</tt> : int</dt>
<dd>Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.</dd>
</dl>
#### Inputs

View file

@ -21,6 +21,8 @@ class MatMulBnb4 final : public OpKernel {
ORT_ENFORCE(
quant_type_ == FP4 || quant_type_ == NF4,
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
is_training_mode_ = static_cast<bool>(info.GetAttrOrDefault("training_mode", static_cast<int64_t>(0)));
transB_ = static_cast<bool>(info.GetAttrOrDefault("transB", static_cast<int64_t>(1)));
}
Status Compute(OpKernelContext* context) const override;
@ -30,6 +32,8 @@ class MatMulBnb4 final : public OpKernel {
int64_t N_;
int64_t block_size_;
int64_t quant_type_;
bool is_training_mode_;
bool transB_;
};
Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
@ -58,7 +62,7 @@ Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
thread_pool);
constexpr bool transa = false;
constexpr bool transb = true;
const bool transb = transB_;
TensorShape b_shape({N_, K_});
MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb));

View file

@ -25,6 +25,9 @@ class MatMulBnb4 final : public CudaKernel {
ORT_ENFORCE(
quant_type_ == FP4 || quant_type_ == NF4,
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
is_training_mode_ = static_cast<bool>(info.GetAttrOrDefault("training_mode", static_cast<int64_t>(0)));
transB_ = static_cast<bool>(info.GetAttrOrDefault("transB", static_cast<int64_t>(1)));
}
Status ComputeInternal(OpKernelContext* context) const override;
@ -34,6 +37,8 @@ class MatMulBnb4 final : public CudaKernel {
int64_t N_;
int64_t block_size_;
int64_t quant_type_;
bool is_training_mode_;
bool transB_;
};
template <typename T>
@ -59,7 +64,7 @@ Status MatMulBnb4<T>::ComputeInternal(OpKernelContext* ctx) const {
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle())));
constexpr bool transa = false;
constexpr bool transb = true;
const bool transb = transB_;
MatMulComputeHelper helper;
TensorShape b_shape({N_, K_});
ORT_RETURN_IF_ERROR(
@ -69,17 +74,18 @@ Status MatMulBnb4<T>::ComputeInternal(OpKernelContext* ctx) const {
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0) return Status::OK();
bool is_4bit_done = TryMatMulBnb4(
reinterpret_cast<const CudaT*>(quant_map_buffer_data),
reinterpret_cast<CudaT*>(Y->MutableData<T>()),
reinterpret_cast<const CudaT*>(a_data),
b_quant_data,
reinterpret_cast<const CudaT*>(absmax_data),
SafeInt<int>(helper.M()),
SafeInt<int>(helper.N()),
SafeInt<int>(helper.K()),
SafeInt<int>(block_size_),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
bool is_4bit_done = !is_training_mode_ // skip inference specific handle if in training mode
&& TryMatMulBnb4(
reinterpret_cast<const CudaT*>(quant_map_buffer_data),
reinterpret_cast<CudaT*>(Y->MutableData<T>()),
reinterpret_cast<const CudaT*>(a_data),
b_quant_data,
reinterpret_cast<const CudaT*>(absmax_data),
SafeInt<int>(helper.M()),
SafeInt<int>(helper.N()),
SafeInt<int>(helper.K()),
SafeInt<int>(block_size_),
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()));
if (!is_4bit_done) {
IAllocatorUniquePtr<T> b_dequant_ptr = GetScratchBuffer<T>(N_ * K_, ctx->GetComputeStream());
@ -98,16 +104,16 @@ Status MatMulBnb4<T>::ComputeInternal(OpKernelContext* ctx) const {
CUBLAS_RETURN_IF_ERROR(cublasGemmHelper(
GetCublasHandle(ctx),
CUBLAS_OP_T,
CUBLAS_OP_N,
transb ? CUBLAS_OP_T : CUBLAS_OP_N, // transB
CUBLAS_OP_N, // transA
SafeInt<int>(helper.N()),
SafeInt<int>(helper.M()),
SafeInt<int>(helper.K()),
&alpha,
reinterpret_cast<const CudaT*>(b_dequant_data),
SafeInt<int>(K_),
helper.Ldb(transb), // ldb
reinterpret_cast<const CudaT*>(a_data),
helper.Lda(transa),
helper.Lda(transa), // lda
&zero,
reinterpret_cast<CudaT*>(Y->MutableData<T>()),
helper.Ldc(),

View file

@ -2693,7 +2693,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1,
static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx,
int64_t K,
int64_t N) {
int64_t N,
bool transB) {
int input_a_idx = 0;
if (!hasInputShape(ctx, input_a_idx)) {
return;
@ -2707,15 +2708,15 @@ static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext
// TODO: check B shape
const auto& dim_last = a_shape.dim(a_shape.dim_size() - 1);
if (dim_last.has_dim_value() && dim_last.dim_value() != K) {
ONNX_NAMESPACE::TensorShapeProto resultShape;
if (dim_last.has_dim_value() && dim_last.dim_value() != (transB ? K : N)) {
fail_shape_inference("Incompatible dimensions for matrix multiplication");
}
ONNX_NAMESPACE::TensorShapeProto resultShape;
for (int i = 0; i < a_shape.dim_size() - 1; ++i) {
*resultShape.add_dim() = a_shape.dim(i);
}
resultShape.add_dim()->set_dim_value(N);
resultShape.add_dim()->set_dim_value(transB ? N : K);
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape() = resultShape;
}
@ -3354,7 +3355,7 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored
// Shape inference
int64_t in_features = getAttribute(ctx, "K", -1);
int64_t out_features = getAttribute(ctx, "N", -1);
MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, true);
});
static const char* MatMulBnb4_ver1_doc = R"DOC(
@ -3364,8 +3365,30 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4
And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,..
3. Input B's quantization constants or scales are specified by input 'absmax'.
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
1. (Default value) transB=True (Majorly used for forward pass)
Shape of A: [D0, D1, ..., Dn, K]
Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features].
The computation math:
dequant_B = dequant(B, absmax, quant_type, block_size)
transposed_dequant_B = dequant_B^T
output = A @ transposed_dequant_B
Shape of output: [D0, D1, ..., Dn, N]
2. transB=False (Majorly used for backward pass)
Shape of A: [D0, D1, ..., Dn, N]
Shape of Dequanted B: [N, K], this is aligned with how PyTorch defined the linear weight, .e.g [out_features, in_features].
The computation math:
dequant_B = dequant(B, absmax, quant_type, block_size)
output = A @ dequant_B
Shape of output: [D0, D1, ..., Dn, K]
)DOC";
@ -3377,6 +3400,12 @@ Input absmax is stored in same type as original type of B(float32, float16) with
.Attr("N", "size of each output feature", AttributeProto::INT)
.Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT)
.Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT)
.Attr("training_mode",
"Indicate if the ops run in training_mode, by default, False.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Attr("transB", "Whether B should be transposed on the last two dimensions before doing multiplication. Default to be 1.",
AttributeProto::INT, static_cast<int64_t>(1))
.Input(0, "A", "The input tensor, not quantized", "T1")
.Input(1, "B", "1-dimensional quantized data for weight", "T2")
.Input(2, "absmax", "quantization constants", "T1")
@ -3389,7 +3418,8 @@ Input absmax is stored in same type as original type of B(float32, float16) with
// Shape inference
int64_t in_features = getAttribute(ctx, "K", -1);
int64_t out_features = getAttribute(ctx, "N", -1);
MatmulWithQuantWeightShapeInference(ctx, in_features, out_features);
bool transB = getAttribute(ctx, "transB", 1) != 0;
MatmulWithQuantWeightShapeInference(ctx, in_features, out_features, transB);
});
#ifdef ENABLE_ATEN

View file

@ -70,6 +70,7 @@ static std::unordered_map<std::string, std::unordered_set<size_t>>
{"Split", {1}},
{"Clip", {1, 2}},
{"Pad", {1, 2}},
{"MatMulBnb4", {1, 2}}, // quantified weight (non float) and absmax constant don't need gradients.
{"Multinomial", {0}},
{"RandomNormalLike", {0}},
{"RandomUniformLike", {0}},

View file

@ -494,6 +494,42 @@ IMPLEMENT_GRADIENT_BUILDER(GetGemmGradient) {
return result;
}
// Builds the gradient graph for com.microsoft::MatMulBnb4.
//
// Only the gradient of input 0 (A) is produced. It is computed by another MatMulBnb4 node that takes
// the incoming output gradient together with the source node's quantized weight (input 1) and absmax
// (input 2). Every attribute of the source node is forwarded unchanged except `transB`, which is
// inverted so the dequantized weight is applied in the opposite orientation.
// Requesting a gradient for input 1 (B) or input 2 (absmax) fails an ORT_ENFORCE.
IMPLEMENT_GRADIENT_BUILDER(GetMatmulBnb4Gradient) {
  // Bind by reference: the attributes are only read, so the map does not need to be copied.
  const auto& src_attributes = SrcNodeAttributes();

  std::vector<AttributeProto> attrs;
  attrs.reserve(src_attributes.size() + 1);  // +1 for a `transB` that may have to be appended below
  bool find_transB = false;
  for (const auto& attr : src_attributes) {
    if (attr.first == "transB") {
      // The kernels interpret transB as a boolean (any non-zero value means "transpose"), so invert
      // on zero/non-zero. A modulo based toggle would leave values such as 2 or -2 non-zero, i.e.
      // not inverted at all.
      const int64_t inverted_transB = attr.second.i() != 0 ? 0 : 1;
      attrs.push_back(MakeAttribute("transB", inverted_transB));
      find_transB = true;
    } else {
      attrs.push_back(attr.second);
    }
  }

  if (!find_transB) {
    // The attribute was omitted on the source node, so it ran with the default of 1; the inverse is 0.
    attrs.push_back(MakeAttribute("transB", static_cast<int64_t>(0)));
  }

  std::vector<NodeDef> result;
  // Y = A * B
  // dA = dY * B', dB = A' * dY
  if (IsGradientRequiredForSrcNodeInput(0)) {
    // B is stored as a 1-D quantized buffer, so the transpose is expressed through `transB`
    // rather than through an explicit Transpose node.
    result.push_back(NodeDef(OpDef{"MatMulBnb4", kMSDomain, 1},
                             {GO(0), I(1), I(2)},
                             {GI(0)},
                             attrs));
  }

  ORT_ENFORCE(!IsGradientRequiredForSrcNodeInput(1), "Gradient propagation to B is not supported yet.");
  ORT_ENFORCE(!IsGradientRequiredForSrcNodeInput(2), "Gradient propagation to absmax is not supported yet.");
  return result;
}
IMPLEMENT_GRADIENT_BUILDER(GetSplitGradient) {
std::vector<NodeDef> result = {};
std::vector<ArgDef> input_args;

View file

@ -54,6 +54,7 @@ DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossGradient)
DECLARE_GRADIENT_BUILDER(GetSoftmaxCrossEntropyLossInternalGradient)
DECLARE_GRADIENT_BUILDER(GetGlobalAveragePoolGradient)
DECLARE_GRADIENT_BUILDER(GetGemmGradient)
DECLARE_GRADIENT_BUILDER(GetMatmulBnb4Gradient)
DECLARE_GRADIENT_BUILDER(GetDropoutGradient)
DECLARE_GRADIENT_BUILDER(GetGatherNDGradient)
DECLARE_GRADIENT_BUILDER(GetGatherElementsGradient)

View file

@ -68,6 +68,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
REGISTER_GRADIENT_BUILDER("Reshape", GetReshapeGradient);
REGISTER_GRADIENT_BUILDER("Transpose", GetTransposeGradient);
REGISTER_GRADIENT_BUILDER("Gemm", GetGemmGradient);
REGISTER_GRADIENT_BUILDER("MatMulBnb4", GetMatmulBnb4Gradient);
REGISTER_GRADIENT_BUILDER("MaxPool", GetMaxPoolGradient);
REGISTER_GRADIENT_BUILDER("Gather", GetGatherGradient);
REGISTER_GRADIENT_BUILDER("PadAndUnflatten", GetPadAndUnflattenGradient);

View file

@ -3,7 +3,10 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations
import sys
from typing import ClassVar
import torch
import torch.utils.checkpoint
@ -18,13 +21,55 @@ from onnxruntime.capi._pybind_state import (
register_torch_autograd_function,
)
from onnxruntime.training import ortmodule
from onnxruntime.training.utils import pytorch_dtype_to_onnx
from onnxruntime.training.utils import pytorch_scalar_type_to_pytorch_dtype, pytorch_type_to_onnx_dtype
from ._custom_op_symbolic_registry import wrap_custom_export_function
from ._fallback import ORTModuleONNXModelException, wrap_exception
from ._utils import get_fully_qualified_class_name, get_runtime_pytorch_version
class _SpecialCustomFunctionHandler:
"""A class to handle high priority export of torch.autograd.Function.
`register_high_priority_handler` can be used as function decorator to register a handler for a torch.autograd.Function.
"""
_HIGH_PRIORITY_EXPORT_HANDLER_MAP: ClassVar[dict[str, callable]] = {}
@staticmethod
def add_handler(func_name: str, handler: callable) -> None:
"""Add a handler for a function name.
Args:
func_name (str): The function name.
handler (callable): The handler.
"""
_SpecialCustomFunctionHandler._HIGH_PRIORITY_EXPORT_HANDLER_MAP[func_name] = handler
@staticmethod
def get_handler(func_name: str) -> callable | None:
"""Get the handler for a function name.
Args:
func_name (str): The function name.
Returns:
callable | None: The handler.
"""
return _SpecialCustomFunctionHandler._HIGH_PRIORITY_EXPORT_HANDLER_MAP.get(func_name, None)
def register_high_priority_handler(func_name):
    """Build a decorator that registers a high priority export handler.

    The decorated callable is stored in `_SpecialCustomFunctionHandler` under `func_name`
    (the fully qualified class name of a torch.autograd.Function) and returned unchanged.
    """

    def _register(handler):
        _SpecialCustomFunctionHandler.add_handler(func_name, handler)
        return handler

    return _register
def register_custom_function_schema_supplementary(kclass: torch.autograd.Function) -> None:
"""Register a shape inference function for a torch.autograd.Function if there is staticmethod
"infer_shape" defined.
@ -96,6 +141,30 @@ _UNSUPPORTED_CKPT_FUNC_NAMES = frozenset(
)
def _get_training_mode() -> bool:
    """Return True when the ONNX exporter is currently exporting in training mode.

    Raises:
        Exception: If the exporter reports a non-bool mode that is neither
            `TrainingMode.EVAL` nor `TrainingMode.TRAINING`.
    """
    # TODO move to public API once the exporter team exposes that
    if get_runtime_pytorch_version() < version.parse("1.12"):
        # Older PyTorch keeps the flag on the symbolic helper module.
        return bool(symbolic_helper._training_mode)

    # FIXME: using private modules
    from torch.onnx import _globals

    mode = _globals.GLOBALS.training_mode

    # before https://github.com/pytorch/pytorch/commit/c8b9b6266b505328e503b12f6a42fd88c56374f9,
    # training_mode is still a bool type
    if isinstance(mode, bool):
        return mode

    if mode not in [torch.onnx.TrainingMode.EVAL, torch.onnx.TrainingMode.TRAINING]:
        raise Exception(f"Unexpected training mode {mode}")

    return bool(mode == torch.onnx.TrainingMode.TRAINING)
def _export_pt_1_10(g, n, *args, **kwargs):
"""Export torch.autograd.Function in ORT PythonOp.
@ -114,6 +183,15 @@ def _export_pt_1_10(g, n, *args, **kwargs):
func_class = n.pyobj().__self__
func_full_qual_name = get_fully_qualified_class_name(func_class)
# Check if the function is handled by high priority exporter.
hi_pri_handler = _SpecialCustomFunctionHandler.get_handler(func_full_qual_name)
if hi_pri_handler:
try_export = hi_pri_handler(g, n, *args, **kwargs)
if try_export is not None:
return try_export
# Fall back to common exporter if not handled by high priority exporter.
# Check if the checkpointing activation is allowed.
is_ckpt_activation_allowed = ortmodule._defined_from_envvar("ORTMODULE_ALLOW_AUTOGRAD_CHECKPOINT", 0) == 1
if is_ckpt_activation_allowed is False and func_full_qual_name in _UNSUPPORTED_CKPT_FUNC_NAMES:
@ -123,26 +201,6 @@ def _export_pt_1_10(g, n, *args, **kwargs):
"wrap exportable sub-nn.Module's as ORTModule."
)
# TODO move to public API once the exporter team exposes that
training_mode = None
if get_runtime_pytorch_version() >= version.parse("1.12"):
# FIXME: using private modules
from torch.onnx import _globals
# before https://github.com/pytorch/pytorch/commit/c8b9b6266b505328e503b12f6a42fd88c56374f9,
# training_mode is still a bool type
if isinstance(_globals.GLOBALS.training_mode, bool):
training_mode = _globals.GLOBALS.training_mode
else:
if _globals.GLOBALS.training_mode not in [
torch.onnx.TrainingMode.EVAL,
torch.onnx.TrainingMode.TRAINING,
]:
raise Exception(f"Unexpected training mode {_globals.GLOBALS.training_mode}")
training_mode = _globals.GLOBALS.training_mode == torch.onnx.TrainingMode.TRAINING
else:
training_mode = symbolic_helper._training_mode
cconv = n.cconv()
input_tensor_types = []
@ -179,7 +237,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
if call_type == "d":
# Got a tensor variable.
tensor_args.append(arg)
scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType())
scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType())
input_tensor_types.append(scalar_type)
input_tensor_ranks.append(arg.type().dim())
continue
@ -258,7 +316,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
output_tensor_ranks = []
for arg in n.outputs():
# Type of tensor's elements.
scalar_type = pytorch_dtype_to_onnx(arg.type().scalarType())
scalar_type = pytorch_type_to_onnx_dtype(arg.type().scalarType())
output_tensor_types.append(scalar_type)
output_tensor_ranks.append(arg.type().dim())
@ -270,7 +328,7 @@ def _export_pt_1_10(g, n, *args, **kwargs):
"input_tensor_ranks_i": input_tensor_ranks,
"output_tensor_types_i": output_tensor_types,
"output_tensor_ranks_i": output_tensor_ranks,
"training_mode_i": 1 if training_mode else 0,
"training_mode_i": 1 if _get_training_mode() else 0,
"comment_s": debug_comment,
}
@ -336,3 +394,55 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model
index += 1
return exported_model
@register_high_priority_handler("bitsandbytes.autograd._functions.MatMul4Bit")
def _matmul4bit_export(g, n, *args, **kwargs):
    """Export a `bitsandbytes.autograd._functions.MatMul4Bit` call as `com.microsoft::MatMulBnb4`.

    Args:
        g: Graph being built; new nodes are created through `g.op(...)`.
        n: Node of the autograd function call; `n.cconv()` gives the calling convention per argument.
        *args: Arguments of the call. Only calls whose first two arguments are tensor variables, whose
            third and fourth arguments are None and whose fifth argument (the quant state) is set
            are converted.
        **kwargs: Unused.

    Returns:
        The output of the created MatMulBnb4 node, or None when the call cannot be converted. A None
        result makes the caller fall back to the common exporter.
    """
    cconv = n.cconv()

    # Check the arity first so an unexpected call shape falls back instead of raising IndexError.
    if len(cconv) < 5 or len(args) < 5:
        return None

    # "d" marks a tensor variable; "c" presumably marks a non-tensor argument.
    if tuple(cconv[:5]) != ("d", "d", "c", "c", "c"):
        return None

    if args[2] is not None or args[3] is not None or args[4] is None:
        return None

    quant_state = args[4]
    try:
        absmax, shape, _dtype, blocksize, compressed_stats, quant_type, _data_type = quant_state
    except (TypeError, ValueError):
        # The quant state is not the 7-item sequence this handler understands
        # (NOTE(review): other bitsandbytes versions may use a different layout, confirm),
        # so leave the call to the common exporter.
        return None

    # MatMulBnb4's blocksize needs to be a power of 2 and not smaller than 16
    if blocksize < 16 or (blocksize & (blocksize - 1)) != 0:
        return None

    # MatMulBnb4 does not support double de-quantization (e.g. absmax is int, needs to be dequantized too)
    if compressed_stats is not None:
        return None

    if quant_type == "fp4":
        quant_type = 0
    elif quant_type == "nf4":
        quant_type = 1
    else:
        return None

    # Make sure the quant weight can be flatten to 1D tensor safely, which com.microsoft::MatMulBnb4 requires.
    # sizes() is treated as possibly None (shape not known at export time); fall back in that case too.
    weight_sizes = args[1].type().sizes()
    if weight_sizes is None or not any(v == 1 for v in weight_sizes):
        return None

    # The PyTorch linear weight shape is [out_feature, in_feature]
    out_feature, in_feature = shape[0], shape[1]

    attrs = {
        "K_i": in_feature,
        "N_i": out_feature,
        "block_size_i": blocksize,
        "quant_type_i": quant_type,
        "training_mode_i": 1 if _get_training_mode() else 0,
    }

    # absmax becomes a graph constant with the same element type as the activation input.
    absmax = g.op(
        "Constant",
        value_t=torch.tensor(absmax, dtype=pytorch_scalar_type_to_pytorch_dtype(args[0].type().scalarType())),
    )
    quant_weight = g.op(
        "Reshape", args[1], g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))
    )  # flatten to 1D

    tensor_args = [args[0], quant_weight, absmax]
    return g.op("com.microsoft::MatMulBnb4", *tensor_args, **attrs)

View file

@ -12,7 +12,7 @@ from packaging.version import Version
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import _get_tensor_dim_size, _get_tensor_sizes, parse_args
from onnxruntime.training.utils import pytorch_dtype_to_onnx
from onnxruntime.training.utils import pytorch_type_to_onnx_dtype
from ._utils import get_runtime_pytorch_version
@ -145,7 +145,7 @@ def cross_entropy_loss(g, node, logits, target, weight, reduction, ignore_index,
weight_casted,
ignore_index,
reduction_s=reduction,
output_type_i=pytorch_dtype_to_onnx(output_type.scalarType()),
output_type_i=pytorch_type_to_onnx_dtype(output_type.scalarType()),
outputs=2,
)
output.setType(output_type)

View file

@ -19,7 +19,7 @@ from torch.utils.cpp_extension import ROCM_HOME
import onnxruntime
from onnxruntime.capi import _pybind_state as C
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch
from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch_dtype
from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3
from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils
@ -345,7 +345,7 @@ class GraphExecutionManager(GraphExecutionInterface):
cache_dir, f"{hash_fn(str(self._flattened_module).encode()).hexdigest()}_{get_rank()}.onnx"
)
if os.path.exists(cache_dir) and os.path.isfile(filename):
self._logger.info(
self._logger.warning(
f"Cached model detected! Cached model will be used to save export and initialization time."
f"If you want the model to be re-exported then DELETE {filename}."
)
@ -627,7 +627,7 @@ class GraphExecutionManager(GraphExecutionInterface):
kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros(
STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE,
dtype=onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE),
dtype=onnx_dtype_to_pytorch_dtype(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE),
device=device,
).requires_grad_()

View file

@ -14,7 +14,7 @@ from onnxruntime.capi._pybind_state import (
register_shape_inference_function,
register_torch_autograd_function,
)
from onnxruntime.training.utils import pytorch_dtype_to_onnx
from onnxruntime.training.utils import pytorch_type_to_onnx_dtype
from ._custom_autograd_function_exporter import register_custom_function_schema_supplementary
from ._utils import get_fully_qualified_class_name
@ -149,7 +149,7 @@ def post_processing_enable_zero_stage3_compat(
c,
graph_input.name,
len(zero_stage3_named_params[graph_input.name].ds_shape), # new rank
pytorch_dtype_to_onnx(zero_stage3_named_params[graph_input.name].dtype), # new data type
pytorch_type_to_onnx_dtype(zero_stage3_named_params[graph_input.name].dtype), # new data type
)
# Delete exported_model.graph.input

View file

@ -366,7 +366,7 @@ class _RuntimeOptions:
# Cache exported model
if "ORTMODULE_CACHE_DIR" in os.environ:
self._logger.info("ORTModule cache optimization is ON.")
self._logger.warning("ORTModule optimization for caching exported model is ON.")
self.ortmodule_cache_dir = os.getenv("ORTMODULE_CACHE_DIR")
# Experimental features.

View file

@ -9,7 +9,11 @@ from onnxruntime.training.utils.torch_io_helper import (
extract_data_and_schema,
unflatten_data_using_schema,
)
from onnxruntime.training.utils.torch_type_map import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx
from onnxruntime.training.utils.torch_type_map import (
onnx_dtype_to_pytorch_dtype,
pytorch_scalar_type_to_pytorch_dtype,
pytorch_type_to_onnx_dtype,
)
__all__ = [
"PrimitiveType",
@ -17,6 +21,7 @@ __all__ = [
"ORTModelInputOutputSchemaType",
"extract_data_and_schema",
"unflatten_data_using_schema",
"pytorch_dtype_to_onnx",
"onnx_dtype_to_pytorch",
"pytorch_type_to_onnx_dtype",
"onnx_dtype_to_pytorch_dtype",
"pytorch_scalar_type_to_pytorch_dtype",
]

View file

@ -17,7 +17,7 @@ import torch
from onnxruntime.training.utils import (
ORTModelInputOutputType,
extract_data_and_schema,
pytorch_dtype_to_onnx,
pytorch_type_to_onnx_dtype,
unflatten_data_using_schema,
)
@ -324,7 +324,7 @@ class ORTZeROOffloadPreForwardFunction(torch.autograd.Function):
start_offset = len(tensor_input_shapes) - len(partitioned_params)
for index, param in enumerate(partitioned_params):
tensor_output_shapes[start_offset + index] = list(param.ds_shape)
tensor_output_dtypes[start_offset + index] = int(pytorch_dtype_to_onnx(param.dtype))
tensor_output_dtypes[start_offset + index] = int(pytorch_type_to_onnx_dtype(param.dtype))
assert len(tensor_output_shapes) == len(tensor_input_shapes)
assert len(tensor_output_dtypes) == len(tensor_input_dtypes)

View file

@ -36,8 +36,10 @@ _DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for k, (onnx_dtype, torch_dtype) in _C
_ONNX_TO_DTYPE = {onnx_dtype: torch_dtype for torch_dtype, onnx_dtype in _DTYPE_TO_ONNX.items()}
def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType:
"""Converts a pytorch dtype or scalar type string to an onnx dtype."""
def pytorch_type_to_onnx_dtype(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType:
"""Converts a pytorch dtype or scalar type string to an onnx dtype.
PyTorch type can be either a dtype or a scalar type string.
"""
dtype = dtype_or_scalar_type
if isinstance(dtype, str):
if dtype not in _CAST_PYTORCH_TO_ONNX:
@ -49,7 +51,15 @@ def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torc
return _DTYPE_TO_ONNX[dtype]
def onnx_dtype_to_pytorch(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype:
def pytorch_scalar_type_to_pytorch_dtype(dtype: str) -> torch.dtype:
    """Map a PyTorch scalar type string to the matching `torch.dtype`.

    Raises:
        RuntimeError: If the scalar type string has no entry in the cast table.
    """
    assert isinstance(dtype, str)
    entry = _CAST_PYTORCH_TO_ONNX.get(dtype)
    if entry is None:
        raise RuntimeError(f"Unsupported dtype {dtype}")
    # Each table entry is an (onnx dtype, torch dtype) pair; return the torch side.
    _, torch_dtype = entry
    return torch_dtype
def onnx_dtype_to_pytorch_dtype(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype:
"""Converts an onnx dtype to a pytorch dtype."""
if dtype not in _ONNX_TO_DTYPE:
raise RuntimeError(f"Unsupported dtype {dtype}")