onnxruntime/rust/onnxruntime/examples/issue22.rs
Boyd Johnson 96b95a24ee
Add rust bindings (#12606)
This adds updated Rust bindings that were previously located at
[nbigaouette/onnxruntime-rs](https://github.com/nbigaouette/onnxruntime-rs).

Check out the build instructions included in this PR at /rust/BUILD.md.

Changes to the bindings included in this PR:
- The bindings are generated by the build script on every build.
- The onnxruntime shared library is built with `ORT_RUST_STRATEGY=compile`,
which is now the default.
- A memory leak caused by a missing call to `free` was fixed.
- Several small memory errors were fixed.
- `Session` is `Send` but not `Sync`; `Environment` is `Send + Sync`.
- Inputs and outputs can be `ndarray::Array`s of many different types.
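To make the `Send`/`Sync` bullet concrete, here is a minimal sketch using hypothetical stand-in types. The `Session` and `Environment` structs below are not the onnxruntime types, and the `Cell` field is only an assumed way to give the stand-in `Session` the same `Send`-but-not-`Sync` bounds. The sketch shows three things: a `Send + Sync` value can be shared through an `Arc`, a `Send`-only value can be moved into another thread, and sharing a `Send`-only value across threads requires wrapping it in a `Mutex`.

```rust
use std::cell::Cell;
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical stand-in for `Environment`: owned plain data is `Send + Sync`.
struct Environment {
    name: String,
}

/// Hypothetical stand-in for `Session`: the `Cell` field makes this type
/// `Send` but not `Sync`, matching the bounds stated for the real `Session`.
struct Session {
    runs: Cell<u64>,
}

impl Session {
    /// Mutates through `&self` via the `Cell`, so `&Session` must not be
    /// shared between threads without synchronization.
    fn run(&self) -> u64 {
        self.runs.set(self.runs.get() + 1);
        self.runs.get()
    }
}

/// Compile-time checks: calls only compile if `T` satisfies the bound.
fn assert_send<T: Send>() {}
fn assert_sync<T: Sync>() {}

/// Returns (environment name length read from another thread,
/// run count of a session moved into another thread,
/// total runs of a session shared by four threads).
fn demo() -> (usize, u64, u64) {
    assert_send::<Environment>();
    assert_sync::<Environment>();
    assert_send::<Session>();
    // assert_sync::<Session>(); // would not compile: Cell<u64> is not Sync

    // Send + Sync: one value can be shared between threads through an Arc.
    let env = Arc::new(Environment { name: "env".to_string() });
    let env_clone = Arc::clone(&env);
    let name_len = thread::spawn(move || env_clone.name.len()).join().unwrap();

    // Send only: the session can be moved into another thread...
    let session = Session { runs: Cell::new(0) };
    let moved_runs = thread::spawn(move || session.run()).join().unwrap();

    // ...but sharing it requires a Mutex, which is Sync whenever T is Send.
    let shared = Arc::new(Mutex::new(Session { runs: Cell::new(0) }));
    let handles: Vec<_> = (0..4)
        .map(|_| {
            let s = Arc::clone(&shared);
            thread::spawn(move || {
                s.lock().unwrap().run();
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    let total_runs = shared.lock().unwrap().runs.get();

    (name_len, moved_runs, total_runs)
}

fn main() {
    let (name_len, moved_runs, total_runs) = demo();
    println!("{name_len} {moved_runs} {total_runs}"); // prints "3 1 4"
}
```

With the real bindings, the same reasoning suggests wrapping a `Session` in an `Arc<Mutex<_>>` (or creating one session per thread) when several threads need to run the same model.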

Some commits can be squashed if wanted, but they were left unsquashed to
show the differences between the old and new bindings.

This PR does not cover packaging, nor does it include the Rust bindings
within the build system.

For those of you who have existing Rust code based on the previous bindings,
these new bindings can be used as a `path` dependency or a `git` dependency
(though I have not tested this).

The work addressed in this PR was discussed in #11992
2023-02-08 14:57:15 -08:00


//! Example reproducing issue #22.
//!
//! `model.onnx` is available to download here:
//! <https://drive.google.com/file/d/1FmL-Wpm06V-8wgRqvV3Skey_X98Ue4D_/view?usp=sharing>
use ndarray::Array2;
use onnxruntime::{environment::Environment, GraphOptimizationLevel, LoggingLevel};
use std::env::var;
use tracing::Level;
use tracing_subscriber::FmtSubscriber;

fn main() {
    // A builder for `FmtSubscriber`.
    let subscriber = FmtSubscriber::builder()
        // All spans/events at TRACE level or above (e.g., debug, info, warn, etc.)
        // will be written to stdout.
        .with_max_level(Level::TRACE)
        // Completes the builder.
        .finish();

    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");

    // Use a custom onnxruntime library path if RUST_ONNXRUNTIME_LIBRARY_PATH is set.
    let path = var("RUST_ONNXRUNTIME_LIBRARY_PATH").ok();

    let builder = Environment::builder()
        .with_name("env")
        .with_log_level(LoggingLevel::Warning);

    // `path` is not used again, so it can be moved instead of cloned.
    let builder = if let Some(path) = path {
        builder.with_library_path(path)
    } else {
        builder
    };

    let env = builder.build().unwrap();

    let session = env
        .new_session_builder()
        .unwrap()
        .with_graph_optimization_level(GraphOptimizationLevel::Basic)
        .unwrap()
        .with_model_from_file("model.onnx")
        .unwrap();

    println!("{:#?}", session.inputs);
    println!("{:#?}", session.outputs);

    // Two i64 input tensors, each of shape (1, 3).
    let input_ids = Array2::<i64>::from_shape_vec((1, 3), vec![1, 2, 3]).unwrap();
    let attention_mask = Array2::<i64>::from_shape_vec((1, 3), vec![1, 1, 1]).unwrap();

    let inputs = vec![input_ids.into(), attention_mask.into()];
    let outputs = session.run(inputs).unwrap();

    println!("outputs: {:#?}", outputs[0].float_array().unwrap());
}
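The example calls `with_library_path` only when `RUST_ONNXRUNTIME_LIBRARY_PATH` is set. Below is a minimal, std-only sketch of that optional-builder pattern. `EnvBuilder` and `configure` are hypothetical stand-ins that only record their arguments; they are not part of the onnxruntime bindings, and the real builder may have a different shape.

```rust
use std::env;

/// Hypothetical stand-in for the environment builder: it only records what
/// was configured, so the optional-argument pattern can run without onnxruntime.
#[derive(Debug, Default, PartialEq)]
struct EnvBuilder {
    name: String,
    library_path: Option<String>,
}

impl EnvBuilder {
    fn with_name(mut self, name: &str) -> Self {
        self.name = name.to_string();
        self
    }

    fn with_library_path(mut self, path: String) -> Self {
        self.library_path = Some(path);
        self
    }
}

/// Calls `with_library_path` only when a path was supplied; the `Option`
/// is consumed by the `match`, so no `clone` is needed.
fn configure(path: Option<String>) -> EnvBuilder {
    let builder = EnvBuilder::default().with_name("env");
    match path {
        Some(path) => builder.with_library_path(path),
        None => builder,
    }
}

fn main() {
    // `env::var` returns `Err` when the variable is unset or not valid Unicode;
    // `.ok()` turns either case into `None`.
    let path = env::var("RUST_ONNXRUNTIME_LIBRARY_PATH").ok();
    println!("{:?}", configure(path));
}
```

Rebinding `builder` through a `match` (or the `if let` used in the example) keeps the builder chain linear, with no mutable variable that is patched afterwards.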