support gpu
This commit is contained in:
@@ -26,6 +26,6 @@ rayon = "1.10"
|
|||||||
tauri-plugin-dialog = "2"
|
tauri-plugin-dialog = "2"
|
||||||
imageproc = "0.25"
|
imageproc = "0.25"
|
||||||
ab_glyph = "0.2.23"
|
ab_glyph = "0.2.23"
|
||||||
ort = { version = "=2.0.0-rc.11", features = ["load-dynamic"] }
|
ort = { version = "=2.0.0-rc.11", features = ["load-dynamic", "directml"] }
|
||||||
ndarray = "0.16"
|
ndarray = "0.16"
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
use crate::ort_ops;
|
||||||
use image::{Rgba, RgbaImage};
|
use image::{Rgba, RgbaImage};
|
||||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
|
||||||
use ort::value::Value;
|
use ort::value::Value;
|
||||||
|
|
||||||
pub fn run_lama_inpainting(
|
pub fn run_lama_inpainting(
|
||||||
@@ -8,14 +8,8 @@ pub fn run_lama_inpainting(
|
|||||||
mask_image: &image::GrayImage,
|
mask_image: &image::GrayImage,
|
||||||
) -> Result<RgbaImage, String> {
|
) -> Result<RgbaImage, String> {
|
||||||
// 1. Initialize Session
|
// 1. Initialize Session
|
||||||
let mut session = Session::builder()
|
let mut session = ort_ops::create_session(model_path)
|
||||||
.map_err(|e| format!("Failed to create session builder: {}", e))?
|
.map_err(|e| format!("Failed to create ORT session for LAMA: {}", e))?;
|
||||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
|
||||||
.map_err(|e| format!("Failed to set opt level: {}", e))?
|
|
||||||
.with_intra_threads(4)
|
|
||||||
.map_err(|e| format!("Failed to set threads: {}", e))?
|
|
||||||
.commit_from_file(model_path)
|
|
||||||
.map_err(|e| format!("Failed to load model from {:?}: {}", model_path, e))?;
|
|
||||||
|
|
||||||
// 2. Preprocess
|
// 2. Preprocess
|
||||||
let target_size = (512, 512);
|
let target_size = (512, 512);
|
||||||
|
|||||||
@@ -94,8 +94,9 @@ use imageproc::drawing::draw_text_mut;
|
|||||||
use ab_glyph::{FontRef, PxScale};
|
use ab_glyph::{FontRef, PxScale};
|
||||||
use tauri::{AppHandle, Manager};
|
use tauri::{AppHandle, Manager};
|
||||||
|
|
||||||
mod lama;
|
pub mod lama;
|
||||||
mod ocr;
|
pub mod ocr;
|
||||||
|
pub mod ort_ops;
|
||||||
|
|
||||||
// Embed the font to ensure it's always available without path issues
|
// Embed the font to ensure it's always available without path issues
|
||||||
const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf");
|
const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf");
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
use crate::ort_ops;
|
||||||
use image::{GenericImageView, Rgba, RgbaImage};
|
use image::{GenericImageView, Rgba, RgbaImage};
|
||||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
|
||||||
use ort::value::Value;
|
use ort::value::Value;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
@@ -15,15 +15,9 @@ pub fn run_ocr_detection(
|
|||||||
model_path: &Path,
|
model_path: &Path,
|
||||||
input_image: &RgbaImage,
|
input_image: &RgbaImage,
|
||||||
) -> Result<Vec<DetectedBox>, String> {
|
) -> Result<Vec<DetectedBox>, String> {
|
||||||
// 1. Load Model
|
// 1. Load Model using the shared function
|
||||||
let mut session = Session::builder()
|
let mut session = ort_ops::create_session(model_path)
|
||||||
.map_err(|e| format!("Failed to create session: {}", e))?
|
.map_err(|e| format!("Failed to create ORT session: {}", e))?;
|
||||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
|
||||||
.map_err(|e| format!("Failed to set opt: {}", e))?
|
|
||||||
.with_intra_threads(4)
|
|
||||||
.map_err(|e| format!("Failed to set threads: {}", e))?
|
|
||||||
.commit_from_file(model_path)
|
|
||||||
.map_err(|e| format!("Failed to load OCR model: {}", e))?;
|
|
||||||
|
|
||||||
// 2. Preprocess
|
// 2. Preprocess
|
||||||
// DBNet expects standard normalization: (img - mean) / std
|
// DBNet expects standard normalization: (img - mean) / std
|
||||||
@@ -133,7 +127,7 @@ pub fn run_ocr_detection(
|
|||||||
let idx = (y * map_w + x) as usize;
|
let idx = (y * map_w + x) as usize;
|
||||||
if binary_map[idx] && !visited[idx] {
|
if binary_map[idx] && !visited[idx] {
|
||||||
// Flood fill
|
// Flood fill
|
||||||
let mut stack = vec![(x, y)];
|
let mut stack = vec![(x as u32, y as u32)];
|
||||||
visited[idx] = true;
|
visited[idx] = true;
|
||||||
|
|
||||||
let mut min_x = x;
|
let mut min_x = x;
|
||||||
|
|||||||
41
src-tauri/src/ort_ops.rs
Normal file
41
src-tauri/src/ort_ops.rs
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||||
|
use ort::execution_providers::DirectMLExecutionProvider;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
/// Attempts to create an ORT session with GPU (DirectML) acceleration.
|
||||||
|
/// If GPU initialization fails, it falls back to a CPU-only session.
|
||||||
|
pub fn create_session(model_path: &Path) -> Result<Session, ort::Error> {
|
||||||
|
|
||||||
|
// Try to build with DirectML
|
||||||
|
let dm_provider = DirectMLExecutionProvider::default().build();
|
||||||
|
let session_builder = Session::builder()?
|
||||||
|
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||||
|
.with_intra_threads(4)?;
|
||||||
|
|
||||||
|
match session_builder.with_execution_providers([dm_provider]) {
|
||||||
|
Ok(builder_with_dm) => {
|
||||||
|
println!("Attempting to commit session with DirectML provider...");
|
||||||
|
match builder_with_dm.commit_from_file(model_path) {
|
||||||
|
Ok(session) => {
|
||||||
|
println!("Successfully created ORT session with DirectML GPU acceleration.");
|
||||||
|
return Ok(session);
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
println!("Failed to create session with DirectML: {:?}. Falling back to CPU.", e);
|
||||||
|
// Fall through to CPU execution
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
println!("Failed to build session with DirectML provider: {:?}. Falling back to CPU.", e);
|
||||||
|
// Fall through to CPU execution
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Fallback to CPU
|
||||||
|
println!("Creating ORT session with CPU provider.");
|
||||||
|
Session::builder()?
|
||||||
|
.with_optimization_level(GraphOptimizationLevel::Level3)?
|
||||||
|
.with_intra_threads(4)?
|
||||||
|
.commit_from_file(model_path)
|
||||||
|
}
|
||||||
@@ -236,10 +236,14 @@ export const useGalleryStore = defineStore("gallery", () => {
|
|||||||
isProcessing.value = true;
|
isProcessing.value = true;
|
||||||
progress.value = { current: 0, total: candidates.length };
|
progress.value = { current: 0, total: candidates.length };
|
||||||
try {
|
try {
|
||||||
// Sequential processing to avoid freezing UI or overloading backend
|
// Parallel processing in batches to avoid overwhelming backend
|
||||||
for (const img of candidates) {
|
const batchSize = 4;
|
||||||
await runInpaintingForImage(img);
|
for (let i = 0; i < candidates.length; i += batchSize) {
|
||||||
progress.value.current++;
|
const batch = candidates.slice(i, i + batchSize).map(async (img) => {
|
||||||
|
await runInpaintingForImage(img);
|
||||||
|
progress.value.current++;
|
||||||
|
});
|
||||||
|
await Promise.all(batch);
|
||||||
}
|
}
|
||||||
alert("批量处理完成!");
|
alert("批量处理完成!");
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
|
|||||||
Reference in New Issue
Block a user