42 lines
1.7 KiB
Rust
42 lines
1.7 KiB
Rust
use ort::session::{Session, builder::GraphOptimizationLevel};
|
|
use ort::execution_providers::DirectMLExecutionProvider;
|
|
use std::path::Path;
|
|
|
|
/// Attempts to create an ORT session with GPU (DirectML) acceleration.
|
|
/// If GPU initialization fails, it falls back to a CPU-only session.
|
|
pub fn create_session(model_path: &Path) -> Result<Session, ort::Error> {
|
|
|
|
// Try to build with DirectML
|
|
let dm_provider = DirectMLExecutionProvider::default().build();
|
|
let session_builder = Session::builder()?
|
|
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
|
.with_intra_threads(4)?;
|
|
|
|
match session_builder.with_execution_providers([dm_provider]) {
|
|
Ok(builder_with_dm) => {
|
|
println!("Attempting to commit session with DirectML provider...");
|
|
match builder_with_dm.commit_from_file(model_path) {
|
|
Ok(session) => {
|
|
println!("Successfully created ORT session with DirectML GPU acceleration.");
|
|
return Ok(session);
|
|
},
|
|
Err(e) => {
|
|
println!("Failed to create session with DirectML: {:?}. Falling back to CPU.", e);
|
|
// Fall through to CPU execution
|
|
}
|
|
}
|
|
},
|
|
Err(e) => {
|
|
println!("Failed to build session with DirectML provider: {:?}. Falling back to CPU.", e);
|
|
// Fall through to CPU execution
|
|
}
|
|
};
|
|
|
|
// Fallback to CPU
|
|
println!("Creating ORT session with CPU provider.");
|
|
Session::builder()?
|
|
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
|
.with_intra_threads(4)?
|
|
.commit_from_file(model_path)
|
|
}
|