onnxruntime/rust/onnxruntime/examples/sample.rs
Boyd Johnson 96b95a24ee
Add rust bindings (#12606)
This adds updated Rust bindings that have been located at
[nbigaouette/onnxruntime-rs](https://github.com/nbigaouette/onnxruntime-rs).

check out the build instructions included in this PR at /rust/BUILD.md.

Changes to the bindings included in this PR:
- The bindings are generated with the build script on each build
- The onnxruntime shared library is built with ORT_RUST_STRATEGY=compile
which is now the default.
- A memory leak was fixed where a call to free wasn't called
- Several small memory errors were fixed
- Session is Send but not Sync, Environment is Send + Sync
- Inputs and Outputs can be ndarray::Arrays of many different types.

Some commits can be squashed, if wanted, but were left unsquashed to
show differences between old bindings and new bindings.

This PR does not cover packaging, nor does it integrate the Rust bindings
within the build system.

For those of you who have previous Rust code based on the bindings,
these new bindings
can be used as a `path` dependency or a `git` dependency (though I have
not tested this out).

The work addressed in this PR was discussed in #11992
2023-02-08 14:57:15 -08:00

83 lines
2.6 KiB
Rust

#![forbid(unsafe_code)]
use onnxruntime::{environment::Environment, ndarray::Array, GraphOptimizationLevel, LoggingLevel};
use std::env::var;
use tracing::Level;
use tracing_subscriber::FmtSubscriber;
type Error = Box<dyn std::error::Error>;
fn main() {
    // Delegate all work to `run` so the body can use `?` freely; report any
    // failure on stderr and exit with a non-zero status code.
    match run() {
        Ok(()) => {}
        Err(err) => {
            eprintln!("Error: {}", err);
            std::process::exit(1);
        }
    }
}
/// Runs the SqueezeNet 1.0 inference example end to end: sets up logging,
/// builds an ONNX Runtime environment/session, feeds a synthetic input tensor,
/// and prints the first few output scores.
///
/// # Errors
/// Returns any error from logger installation, environment/session
/// construction, model loading, tensor reshaping, or inference.
fn run() -> Result<(), Error> {
    // Set up the example's own log level.
    // NOTE: ONNX Runtime's log level is controlled separately when building the environment.
    let subscriber = FmtSubscriber::builder()
        .with_max_level(Level::TRACE)
        .finish();
    tracing::subscriber::set_global_default(subscriber)?;

    // Optional override for where the ONNX Runtime shared library lives.
    let path = var("RUST_ONNXRUNTIME_LIBRARY_PATH").ok();

    let builder = Environment::builder()
        .with_name("test")
        .with_log_level(LoggingLevel::Warning);
    // `path` is not used after this point, so move it instead of cloning.
    let builder = if let Some(path) = path {
        builder.with_library_path(path)
    } else {
        builder
    };
    // Propagate construction errors with `?` rather than panicking via
    // `.unwrap()`, consistent with this function's `Result` return type.
    let environment = builder.build()?;

    let session = environment
        .new_session_builder()?
        .with_graph_optimization_level(GraphOptimizationLevel::Basic)?
        .with_intra_op_num_threads(1)?
        // NOTE: The example uses SqueezeNet 1.0 (ONNX version: 1.3, Opset version: 8),
        // _not_ SqueezeNet 1.1 as downloaded by '.with_model_downloaded(ImageClassification::SqueezeNet)'
        // Obtain it with:
        // curl -LO "https://github.com/onnx/models/raw/main/vision/classification/squeezenet/model/squeezenet1.0-8.onnx"
        .with_model_from_file("squeezenet1.0-8.onnx")?;

    // This model's dimensions are all static, so unwrapping each Option is
    // safe; the asserts below pin the exact shapes the example relies on.
    let input0_shape: Vec<usize> = session.inputs[0]
        .dimensions()
        .map(std::option::Option::unwrap)
        .collect();
    let output0_shape: Vec<usize> = session.outputs[0]
        .dimensions()
        .map(std::option::Option::unwrap)
        .collect();
    assert_eq!(input0_shape, [1, 3, 224, 224]);
    assert_eq!(output0_shape, [1, 1000, 1, 1]);

    // Total element count derived from the shape validated above — avoids the
    // old second walk over `dimensions` through a lossy `u32` intermediate.
    let n: usize = input0_shape.iter().product();

    // Initialize input data with values evenly spaced in [0.0, 1.0]; the
    // reshape cannot fail since `n` is the product of `input0_shape`, but any
    // error is propagated rather than unwrapped.
    let array = Array::linspace(0.0_f32, 1.0, n).into_shape(input0_shape)?;

    let input_tensor_values = vec![array.into()];
    let outputs = session.run(input_tensor_values)?;

    // The model produces a single float output tensor.
    let output = outputs[0].float_array().unwrap();
    assert_eq!(output.shape(), output0_shape.as_slice());

    // Print the first few raw class scores (SqueezeNet emits 1000 of them).
    for i in 0..5 {
        println!("Score for class [{}] = {}", i, output[[0, i, 0, 0]]);
    }
    Ok(())
}