support gpu
This commit is contained in:
41
src-tauri/src/ort_ops.rs
Normal file
41
src-tauri/src/ort_ops.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use ort::execution_providers::DirectMLExecutionProvider;
|
||||
use std::path::Path;
|
||||
|
||||
/// Attempts to create an ORT session with GPU (DirectML) acceleration.
|
||||
/// If GPU initialization fails, it falls back to a CPU-only session.
|
||||
pub fn create_session(model_path: &Path) -> Result<Session, ort::Error> {
|
||||
|
||||
// Try to build with DirectML
|
||||
let dm_provider = DirectMLExecutionProvider::default().build();
|
||||
let session_builder = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||
.with_intra_threads(4)?;
|
||||
|
||||
match session_builder.with_execution_providers([dm_provider]) {
|
||||
Ok(builder_with_dm) => {
|
||||
println!("Attempting to commit session with DirectML provider...");
|
||||
match builder_with_dm.commit_from_file(model_path) {
|
||||
Ok(session) => {
|
||||
println!("Successfully created ORT session with DirectML GPU acceleration.");
|
||||
return Ok(session);
|
||||
},
|
||||
Err(e) => {
|
||||
println!("Failed to create session with DirectML: {:?}. Falling back to CPU.", e);
|
||||
// Fall through to CPU execution
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
println!("Failed to build session with DirectML provider: {:?}. Falling back to CPU.", e);
|
||||
// Fall through to CPU execution
|
||||
}
|
||||
};
|
||||
|
||||
// Fallback to CPU
|
||||
println!("Creating ORT session with CPU provider.");
|
||||
Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||
.with_intra_threads(4)?
|
||||
.commit_from_file(model_path)
|
||||
}
|
||||
Reference in New Issue
Block a user