mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-02 23:39:58 +00:00
Remove internal enforce for IO binding inputs (#18266)
### Description This PR removes an internal `ORT_ENFORCE` when binding `torch.tensor` inputs using IO binding for end-to-end scripts. ### Motivation and Context In merged exports of PyTorch models to ONNX, each past key and past value in the past KV cache has an input shape of `(batch_size, num_heads, past_sequence_length, head_size)`. In the first pass through the model to process the prompt, `past_sequence_length = 0`. Therefore, each of these inputs is of shape `(batch_size, num_heads, 0, head_size)`. In subsequent passes, `past_sequence_length > 0`. When binding a `torch.tensor` of shape `(batch_size, num_heads, 0, head_size)` with `io_binding.bind_input`, the tensor's `data_ptr()` must be passed. For a `torch.tensor` of this shape, its `data_ptr()` returns 0. Because it returns 0, the existing `ORT_ENFORCE` is therefore false and an error is raised. By removing the internal `ORT_ENFORCE`, no error is raised and the model runs successfully. LLaMA-2 Example: Input Name | Input Size | Device | Device ID | Torch Dtype | data_ptr() ------------- | ----------- | ------- | ----------- | ------------- | ----------- input_ids | torch.Size([1, 11]) | cuda | 7 | torch.int64 | 140639561842688 attention_mask | torch.Size([1, 11]) | cuda | 7 | torch.int64 | 140639561843200 position_ids | torch.Size([1, 11]) | cuda | 7 | torch.int64 | 140639561844224 past_key_values.0.key | torch.Size([1, 32, 0, 128]) | cuda | 7 | torch.float32 | 0 past_key_values.0.value | torch.Size([1, 32, 0, 128]) | cuda | 7 | torch.float32 | 0 ... | ... | ... | ... | ... | ...
This commit is contained in:
parent
84bdf04b25
commit
08eaa1c55d
1 changed files with 0 additions and 2 deletions
|
|
@ -60,8 +60,6 @@ void addIoBindingMethods(pybind11::module& m) {
|
|||
})
|
||||
// This binds input as a Tensor that wraps memory pointer along with the OrtMemoryInfo
|
||||
.def("bind_input", [](SessionIOBinding* io_binding, const std::string& name, const OrtDevice& device, py::object& element_type, const std::vector<int64_t>& shape, int64_t data_ptr) -> void {
|
||||
ORT_ENFORCE(data_ptr != 0, "Pointer to data memory is not valid");
|
||||
|
||||
PyArray_Descr* dtype;
|
||||
if (!PyArray_DescrConverter(element_type.ptr(), &dtype)) {
|
||||
throw std::runtime_error("Not a valid numpy type");
|
||||
|
|
|
|||
Loading…
Reference in a new issue