use ort::session::{Session, builder::GraphOptimizationLevel}; use ort::execution_providers::DirectMLExecutionProvider; use std::path::Path; /// Attempts to create an ORT session with GPU (DirectML) acceleration. /// If GPU initialization fails, it falls back to a CPU-only session. pub fn create_session(model_path: &Path) -> Result { // Try to build with DirectML let dm_provider = DirectMLExecutionProvider::default().build(); let session_builder = Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? .with_intra_threads(4)?; match session_builder.with_execution_providers([dm_provider]) { Ok(builder_with_dm) => { println!("Attempting to commit session with DirectML provider..."); match builder_with_dm.commit_from_file(model_path) { Ok(session) => { println!("Successfully created ORT session with DirectML GPU acceleration."); return Ok(session); }, Err(e) => { println!("Failed to create session with DirectML: {:?}. Falling back to CPU.", e); // Fall through to CPU execution } } }, Err(e) => { println!("Failed to build session with DirectML provider: {:?}. Falling back to CPU.", e); // Fall through to CPU execution } }; // Fallback to CPU println!("Creating ORT session with CPU provider."); Session::builder()? .with_optimization_level(GraphOptimizationLevel::Level3)? .with_intra_threads(4)? .commit_from_file(model_path) }