support gpu

This commit is contained in:
Julian Freeman
2026-01-19 15:35:53 -04:00
parent 330a62027c
commit 2c89ef42b3
6 changed files with 61 additions and 27 deletions

View File

@@ -26,6 +26,6 @@ rayon = "1.10"
tauri-plugin-dialog = "2"
imageproc = "0.25"
ab_glyph = "0.2.23"
ort = { version = "=2.0.0-rc.11", features = ["load-dynamic"] }
ort = { version = "=2.0.0-rc.11", features = ["load-dynamic", "directml"] }
ndarray = "0.16"

View File

@@ -1,5 +1,5 @@
use crate::ort_ops;
use image::{Rgba, RgbaImage};
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Value;
pub fn run_lama_inpainting(
@@ -8,14 +8,8 @@ pub fn run_lama_inpainting(
mask_image: &image::GrayImage,
) -> Result<RgbaImage, String> {
// 1. Initialize Session
let mut session = Session::builder()
.map_err(|e| format!("Failed to create session builder: {}", e))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| format!("Failed to set opt level: {}", e))?
.with_intra_threads(4)
.map_err(|e| format!("Failed to set threads: {}", e))?
.commit_from_file(model_path)
.map_err(|e| format!("Failed to load model from {:?}: {}", model_path, e))?;
let mut session = ort_ops::create_session(model_path)
.map_err(|e| format!("Failed to create ORT session for LAMA: {}", e))?;
// 2. Preprocess
let target_size = (512, 512);

View File

@@ -94,8 +94,9 @@ use imageproc::drawing::draw_text_mut;
use ab_glyph::{FontRef, PxScale};
use tauri::{AppHandle, Manager};
mod lama;
mod ocr;
pub mod lama;
pub mod ocr;
pub mod ort_ops;
// Embed the font to ensure it's always available without path issues
const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf");

View File

@@ -1,5 +1,5 @@
use crate::ort_ops;
use image::{GenericImageView, Rgba, RgbaImage};
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Value;
use std::path::Path;
@@ -15,15 +15,9 @@ pub fn run_ocr_detection(
model_path: &Path,
input_image: &RgbaImage,
) -> Result<Vec<DetectedBox>, String> {
// 1. Load Model
let mut session = Session::builder()
.map_err(|e| format!("Failed to create session: {}", e))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| format!("Failed to set opt: {}", e))?
.with_intra_threads(4)
.map_err(|e| format!("Failed to set threads: {}", e))?
.commit_from_file(model_path)
.map_err(|e| format!("Failed to load OCR model: {}", e))?;
// 1. Load Model using the shared function
let mut session = ort_ops::create_session(model_path)
.map_err(|e| format!("Failed to create ORT session: {}", e))?;
// 2. Preprocess
// DBNet expects standard normalization: (img - mean) / std
@@ -133,7 +127,7 @@ pub fn run_ocr_detection(
let idx = (y * map_w + x) as usize;
if binary_map[idx] && !visited[idx] {
// Flood fill
let mut stack = vec![(x, y)];
let mut stack = vec![(x as u32, y as u32)];
visited[idx] = true;
let mut min_x = x;

41
src-tauri/src/ort_ops.rs Normal file
View File

@@ -0,0 +1,41 @@
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::execution_providers::DirectMLExecutionProvider;
use std::path::Path;
/// Attempts to create an ORT session with GPU (DirectML) acceleration.
/// If GPU initialization fails, it falls back to a CPU-only session.
pub fn create_session(model_path: &Path) -> Result<Session, ort::Error> {
// Try to build with DirectML
let dm_provider = DirectMLExecutionProvider::default().build();
let session_builder = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(4)?;
match session_builder.with_execution_providers([dm_provider]) {
Ok(builder_with_dm) => {
println!("Attempting to commit session with DirectML provider...");
match builder_with_dm.commit_from_file(model_path) {
Ok(session) => {
println!("Successfully created ORT session with DirectML GPU acceleration.");
return Ok(session);
},
Err(e) => {
println!("Failed to create session with DirectML: {:?}. Falling back to CPU.", e);
// Fall through to CPU execution
}
}
},
Err(e) => {
println!("Failed to build session with DirectML provider: {:?}. Falling back to CPU.", e);
// Fall through to CPU execution
}
};
// Fallback to CPU
println!("Creating ORT session with CPU provider.");
Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(4)?
.commit_from_file(model_path)
}

View File

@@ -236,10 +236,14 @@ export const useGalleryStore = defineStore("gallery", () => {
isProcessing.value = true;
progress.value = { current: 0, total: candidates.length };
try {
// Sequential processing to avoid freezing UI or overloading backend
for (const img of candidates) {
// Parallel processing in batches to avoid overwhelming backend
const batchSize = 4;
for (let i = 0; i < candidates.length; i += batchSize) {
const batch = candidates.slice(i, i + batchSize).map(async (img) => {
await runInpaintingForImage(img);
progress.value.current++;
});
await Promise.all(batch);
}
alert("批量处理完成!");
} catch (e) {