support gpu
This commit is contained in:
@@ -26,6 +26,6 @@ rayon = "1.10"
|
||||
tauri-plugin-dialog = "2"
|
||||
imageproc = "0.25"
|
||||
ab_glyph = "0.2.23"
|
||||
ort = { version = "=2.0.0-rc.11", features = ["load-dynamic"] }
|
||||
ort = { version = "=2.0.0-rc.11", features = ["load-dynamic", "directml"] }
|
||||
ndarray = "0.16"
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::ort_ops;
|
||||
use image::{Rgba, RgbaImage};
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use ort::value::Value;
|
||||
|
||||
pub fn run_lama_inpainting(
|
||||
@@ -8,14 +8,8 @@ pub fn run_lama_inpainting(
|
||||
mask_image: &image::GrayImage,
|
||||
) -> Result<RgbaImage, String> {
|
||||
// 1. Initialize Session
|
||||
let mut session = Session::builder()
|
||||
.map_err(|e| format!("Failed to create session builder: {}", e))?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.map_err(|e| format!("Failed to set opt level: {}", e))?
|
||||
.with_intra_threads(4)
|
||||
.map_err(|e| format!("Failed to set threads: {}", e))?
|
||||
.commit_from_file(model_path)
|
||||
.map_err(|e| format!("Failed to load model from {:?}: {}", model_path, e))?;
|
||||
let mut session = ort_ops::create_session(model_path)
|
||||
.map_err(|e| format!("Failed to create ORT session for LAMA: {}", e))?;
|
||||
|
||||
// 2. Preprocess
|
||||
let target_size = (512, 512);
|
||||
|
||||
@@ -94,8 +94,9 @@ use imageproc::drawing::draw_text_mut;
|
||||
use ab_glyph::{FontRef, PxScale};
|
||||
use tauri::{AppHandle, Manager};
|
||||
|
||||
mod lama;
|
||||
mod ocr;
|
||||
pub mod lama;
|
||||
pub mod ocr;
|
||||
pub mod ort_ops;
|
||||
|
||||
// Embed the font to ensure it's always available without path issues
|
||||
const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf");
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::ort_ops;
|
||||
use image::{GenericImageView, Rgba, RgbaImage};
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use ort::value::Value;
|
||||
use std::path::Path;
|
||||
|
||||
@@ -15,15 +15,9 @@ pub fn run_ocr_detection(
|
||||
model_path: &Path,
|
||||
input_image: &RgbaImage,
|
||||
) -> Result<Vec<DetectedBox>, String> {
|
||||
// 1. Load Model
|
||||
let mut session = Session::builder()
|
||||
.map_err(|e| format!("Failed to create session: {}", e))?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.map_err(|e| format!("Failed to set opt: {}", e))?
|
||||
.with_intra_threads(4)
|
||||
.map_err(|e| format!("Failed to set threads: {}", e))?
|
||||
.commit_from_file(model_path)
|
||||
.map_err(|e| format!("Failed to load OCR model: {}", e))?;
|
||||
// 1. Load Model using the shared function
|
||||
let mut session = ort_ops::create_session(model_path)
|
||||
.map_err(|e| format!("Failed to create ORT session: {}", e))?;
|
||||
|
||||
// 2. Preprocess
|
||||
// DBNet expects standard normalization: (img - mean) / std
|
||||
@@ -133,7 +127,7 @@ pub fn run_ocr_detection(
|
||||
let idx = (y * map_w + x) as usize;
|
||||
if binary_map[idx] && !visited[idx] {
|
||||
// Flood fill
|
||||
let mut stack = vec![(x, y)];
|
||||
let mut stack = vec![(x as u32, y as u32)];
|
||||
visited[idx] = true;
|
||||
|
||||
let mut min_x = x;
|
||||
|
||||
41
src-tauri/src/ort_ops.rs
Normal file
41
src-tauri/src/ort_ops.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use ort::execution_providers::DirectMLExecutionProvider;
|
||||
use std::path::Path;
|
||||
|
||||
/// Attempts to create an ORT session with GPU (DirectML) acceleration.
|
||||
/// If GPU initialization fails, it falls back to a CPU-only session.
|
||||
pub fn create_session(model_path: &Path) -> Result<Session, ort::Error> {
|
||||
|
||||
// Try to build with DirectML
|
||||
let dm_provider = DirectMLExecutionProvider::default().build();
|
||||
let session_builder = Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||
.with_intra_threads(4)?;
|
||||
|
||||
match session_builder.with_execution_providers([dm_provider]) {
|
||||
Ok(builder_with_dm) => {
|
||||
println!("Attempting to commit session with DirectML provider...");
|
||||
match builder_with_dm.commit_from_file(model_path) {
|
||||
Ok(session) => {
|
||||
println!("Successfully created ORT session with DirectML GPU acceleration.");
|
||||
return Ok(session);
|
||||
},
|
||||
Err(e) => {
|
||||
println!("Failed to create session with DirectML: {:?}. Falling back to CPU.", e);
|
||||
// Fall through to CPU execution
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
println!("Failed to build session with DirectML provider: {:?}. Falling back to CPU.", e);
|
||||
// Fall through to CPU execution
|
||||
}
|
||||
};
|
||||
|
||||
// Fallback to CPU
|
||||
println!("Creating ORT session with CPU provider.");
|
||||
Session::builder()?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||
.with_intra_threads(4)?
|
||||
.commit_from_file(model_path)
|
||||
}
|
||||
@@ -236,10 +236,14 @@ export const useGalleryStore = defineStore("gallery", () => {
|
||||
isProcessing.value = true;
|
||||
progress.value = { current: 0, total: candidates.length };
|
||||
try {
|
||||
// Sequential processing to avoid freezing UI or overloading backend
|
||||
for (const img of candidates) {
|
||||
await runInpaintingForImage(img);
|
||||
progress.value.current++;
|
||||
// Parallel processing in batches to avoid overwhelming backend
|
||||
const batchSize = 4;
|
||||
for (let i = 0; i < candidates.length; i += batchSize) {
|
||||
const batch = candidates.slice(i, i + batchSize).map(async (img) => {
|
||||
await runInpaintingForImage(img);
|
||||
progress.value.current++;
|
||||
});
|
||||
await Promise.all(batch);
|
||||
}
|
||||
alert("批量处理完成!");
|
||||
} catch (e) {
|
||||
|
||||
Reference in New Issue
Block a user