diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml
index 58926a5..089ff45 100644
--- a/src-tauri/Cargo.toml
+++ b/src-tauri/Cargo.toml
@@ -26,6 +26,6 @@ rayon = "1.10"
 tauri-plugin-dialog = "2"
 imageproc = "0.25"
 ab_glyph = "0.2.23"
-ort = { version = "=2.0.0-rc.11", features = ["load-dynamic"] }
+ort = { version = "=2.0.0-rc.11", features = ["load-dynamic", "directml"] }
 ndarray = "0.16"
 
diff --git a/src-tauri/src/lama.rs b/src-tauri/src/lama.rs
index 716b91a..8017b42 100644
--- a/src-tauri/src/lama.rs
+++ b/src-tauri/src/lama.rs
@@ -1,5 +1,5 @@
+use crate::ort_ops;
 use image::{Rgba, RgbaImage};
-use ort::session::{Session, builder::GraphOptimizationLevel};
 use ort::value::Value;
 
 pub fn run_lama_inpainting(
@@ -8,14 +8,8 @@ pub fn run_lama_inpainting(
     mask_image: &image::GrayImage,
 ) -> Result {
     // 1. Initialize Session
-    let mut session = Session::builder()
-        .map_err(|e| format!("Failed to create session builder: {}", e))?
-        .with_optimization_level(GraphOptimizationLevel::Level3)
-        .map_err(|e| format!("Failed to set opt level: {}", e))?
-        .with_intra_threads(4)
-        .map_err(|e| format!("Failed to set threads: {}", e))?
-        .commit_from_file(model_path)
-        .map_err(|e| format!("Failed to load model from {:?}: {}", model_path, e))?;
+    let mut session = ort_ops::create_session(model_path)
+        .map_err(|e| format!("Failed to create ORT session for LAMA: {}", e))?;
 
     // 2. Preprocess
     let target_size = (512, 512);
diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs
index 00e4411..1ae3ba8 100644
--- a/src-tauri/src/lib.rs
+++ b/src-tauri/src/lib.rs
@@ -94,8 +94,9 @@ use imageproc::drawing::draw_text_mut;
 use ab_glyph::{FontRef, PxScale};
 use tauri::{AppHandle, Manager};
 
-mod lama;
-mod ocr;
+pub mod lama;
+pub mod ocr;
+pub mod ort_ops;
 
 // Embed the font to ensure it's always available without path issues
 const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf");
diff --git a/src-tauri/src/ocr.rs b/src-tauri/src/ocr.rs
index 9b7eda8..0671217 100644
--- a/src-tauri/src/ocr.rs
+++ b/src-tauri/src/ocr.rs
@@ -1,5 +1,5 @@
+use crate::ort_ops;
 use image::{GenericImageView, Rgba, RgbaImage};
-use ort::session::{Session, builder::GraphOptimizationLevel};
 use ort::value::Value;
 use std::path::Path;
 
@@ -15,15 +15,9 @@ pub fn run_ocr_detection(
     model_path: &Path,
     input_image: &RgbaImage,
 ) -> Result, String> {
-    // 1. Load Model
-    let mut session = Session::builder()
-        .map_err(|e| format!("Failed to create session: {}", e))?
-        .with_optimization_level(GraphOptimizationLevel::Level3)
-        .map_err(|e| format!("Failed to set opt: {}", e))?
-        .with_intra_threads(4)
-        .map_err(|e| format!("Failed to set threads: {}", e))?
-        .commit_from_file(model_path)
-        .map_err(|e| format!("Failed to load OCR model: {}", e))?;
+    // 1. Load Model using the shared function
+    let mut session = ort_ops::create_session(model_path)
+        .map_err(|e| format!("Failed to create ORT session: {}", e))?;
 
     // 2. Preprocess
     // DBNet expects standard normalization: (img - mean) / std
@@ -133,7 +127,7 @@ pub fn run_ocr_detection(
             let idx = (y * map_w + x) as usize;
             if binary_map[idx] && !visited[idx] {
                 // Flood fill
-                let mut stack = vec![(x, y)];
+                let mut stack = vec![(x as u32, y as u32)];
                 visited[idx] = true;
 
                 let mut min_x = x;
diff --git a/src-tauri/src/ort_ops.rs b/src-tauri/src/ort_ops.rs
new file mode 100644
index 0000000..1f416d2
--- /dev/null
+++ b/src-tauri/src/ort_ops.rs
@@ -0,0 +1,47 @@
+use ort::session::{Session, builder::GraphOptimizationLevel};
+use ort::execution_providers::DirectMLExecutionProvider;
+use std::path::Path;
+
+/// Attempts to create an ORT session with GPU (DirectML) acceleration.
+/// If GPU initialization fails, it falls back to a CPU-only session.
+///
+/// # Errors
+///
+/// Returns the ORT error when the session builder cannot be configured or the
+/// CPU fallback fails to load the model at `model_path`.
+pub fn create_session(model_path: &Path) -> Result<Session, ort::Error> {
+    // Try to build with DirectML. `error_on_failure()` turns a failed provider
+    // registration into an `Err`; without it ORT silently falls back to CPU and
+    // the success message below would be printed for a CPU-only session.
+    let dm_provider = DirectMLExecutionProvider::default().build().error_on_failure();
+    let session_builder = Session::builder()?
+        .with_optimization_level(GraphOptimizationLevel::Level3)?
+        .with_intra_threads(4)?;
+
+    match session_builder.with_execution_providers([dm_provider]) {
+        Ok(builder_with_dm) => {
+            println!("Attempting to commit session with DirectML provider...");
+            match builder_with_dm.commit_from_file(model_path) {
+                Ok(session) => {
+                    println!("Successfully created ORT session with DirectML GPU acceleration.");
+                    return Ok(session);
+                },
+                Err(e) => {
+                    println!("Failed to create session with DirectML: {:?}. Falling back to CPU.", e);
+                    // Fall through to CPU execution
+                }
+            }
+        },
+        Err(e) => {
+            println!("Failed to build session with DirectML provider: {:?}. Falling back to CPU.", e);
+            // Fall through to CPU execution
+        }
+    };
+
+    // Fallback to CPU
+    println!("Creating ORT session with CPU provider.");
+    Session::builder()?
+        .with_optimization_level(GraphOptimizationLevel::Level3)?
+        .with_intra_threads(4)?
+        .commit_from_file(model_path)
+}
diff --git a/src/stores/gallery.ts b/src/stores/gallery.ts
index f43a3c7..e372d24 100644
--- a/src/stores/gallery.ts
+++ b/src/stores/gallery.ts
@@ -236,10 +236,14 @@ export const useGalleryStore = defineStore("gallery", () => {
     isProcessing.value = true;
     progress.value = { current: 0, total: candidates.length };
     try {
-      // Sequential processing to avoid freezing UI or overloading backend
-      for (const img of candidates) {
-        await runInpaintingForImage(img);
-        progress.value.current++;
+      // Parallel processing in batches to avoid overwhelming backend
+      const batchSize = 4;
+      for (let i = 0; i < candidates.length; i += batchSize) {
+        const batch = candidates.slice(i, i + batchSize).map(async (img) => {
+          await runInpaintingForImage(img);
+          progress.value.current++;
+        });
+        await Promise.all(batch);
       }
       alert("批量处理完成!");
     } catch (e) {