diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs
index 1f8b7d1..ae11618 100644
--- a/src-tauri/src/lib.rs
+++ b/src-tauri/src/lib.rs
@@ -95,6 +95,7 @@ use ab_glyph::{FontRef, PxScale};
 use tauri::{AppHandle, Manager};
 
 mod lama;
+mod ocr;
 
 // Embed the font to ensure it's always available without path issues
 const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf");
@@ -453,52 +454,63 @@ struct Rect {
 }
 
 #[tauri::command]
-async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
-    let img = image::open(&path).map_err(|e| e.to_string())?;
-    let (width, height) = img.dimensions();
-    let gray = img.to_luma8();
+async fn detect_watermark(app: AppHandle, path: String) -> Result<DetectionResult, String> {
+    let img = image::open(&path).map_err(|e| e.to_string())?.to_rgba8();
 
-    // "Stroke Detection" Algorithm
-    // Distinguishes "Text" (Thin White Strokes) from "Solid White Areas" (Walls, Sky)
-    // Logic: A white text pixel must be "sandwiched" by dark pixels within a short distance.
+    // 1. Try OCR Detection
+    let ocr_model_path = app.path().resource_dir()
+        .map_err(|e| e.to_string())?
+        .join("resources")
+        .join("en_PP-OCRv3_det_infer.onnx");
+
+    if ocr_model_path.exists() {
+        println!("Using OCR model for detection");
+        match ocr::run_ocr_detection(&ocr_model_path, &img) {
+            Ok(boxes) => {
+                let rects = boxes.into_iter().map(|b| Rect {
+                    x: b.x,
+                    y: b.y,
+                    width: b.width,
+                    height: b.height,
+                }).collect();
+                return Ok(DetectionResult { rects });
+            }
+            Err(e) => {
+                eprintln!("OCR Detection failed: {}", e);
+                // Fall through to legacy detection
+            }
+        }
+    } else {
+        eprintln!("OCR model not found at {:?}", ocr_model_path);
+    }
+
+    // 2. Legacy Detection (Fallback)
+    println!("Falling back to legacy detection");
+    let (width, height) = img.dimensions();
+    let gray = image::DynamicImage::ImageRgba8(img.clone()).to_luma8();
 
     let cell_size = 10;
     let grid_w = (width + cell_size - 1) / cell_size;
     let grid_h = (height + cell_size - 1) / cell_size;
     let mut grid = vec![false; (grid_w * grid_h) as usize];
 
-    // Focus Areas: Top 15%, Bottom 25%
     let top_limit = (height as f64 * 0.15) as u32;
     let bottom_start = (height as f64 * 0.75) as u32;
-
-    let max_stroke_width = 15; // Max pixels for a text stroke thickness
-    let contrast_threshold = 40; // How much darker the background must be
-    let brightness_threshold = 200; // Text must be at least this white
+    let max_stroke_width = 15;
+    let contrast_threshold = 40;
+    let brightness_threshold = 200;
 
     for y in 1..height-1 {
         if y > top_limit && y < bottom_start { continue; }
         for x in 1..width-1 {
             let p = gray.get_pixel(x, y)[0];
-
-            // 1. Must be Bright
             if p < brightness_threshold { continue; }
 
-            // 2. Stroke Check
-            // We check for "Vertical Stroke" (Dark - Bright - Dark vertically)
-            // OR "Horizontal Stroke" (Dark - Bright - Dark horizontally)
             let mut is_stroke = false;
-
-            // Check Horizontal Stroke (Vertical boundaries? No, Vertical Stroke has Left/Right boundaries?
-            // Terminology: "Vertical Stroke" is like 'I'. It has Left/Right boundaries.
-            // "Horizontal Stroke" is like '-', It has Up/Down boundaries.
-
-            // Let's check Left/Right boundaries (Vertical Stroke)
             let mut left_bound = false;
             let mut right_bound = false;
-            // Search Left
             for k in 1..=max_stroke_width {
                 if x < k { break; }
                 let neighbor = gray.get_pixel(x - k, y)[0];
@@ -507,7 +519,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
                     break;
                 }
             }
-            // Search Right
             if left_bound {
                 for k in 1..=max_stroke_width {
                     if x + k >= width { break; }
@@ -522,11 +533,8 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
             if left_bound && right_bound {
                 is_stroke = true;
             } else {
-                // Check Up/Down boundaries (Horizontal Stroke)
                 let mut up_bound = false;
                 let mut down_bound = false;
-
-                // Search Up
                 for k in 1..=max_stroke_width {
                     if y < k { break; }
                     let neighbor = gray.get_pixel(x, y - k)[0];
@@ -535,7 +543,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
                         break;
                     }
                 }
-                // Search Down
                 if up_bound {
                     for k in 1..=max_stroke_width {
                         if y + k >= height { break; }
@@ -546,10 +553,7 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
                         }
                     }
                 }
-
-                if up_bound && down_bound {
-                    is_stroke = true;
-                }
+                if up_bound && down_bound { is_stroke = true; }
             }
 
             if is_stroke {
@@ -560,7 +564,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
         }
     }
 
-    // Connected Components on Grid (Simple merging)
     let mut rects = Vec::new();
     let mut visited = vec![false; grid.len()];
 
@@ -568,7 +571,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
         for gx in 0..grid_w {
             let idx = (gy * grid_w + gx) as usize;
             if grid[idx] && !visited[idx] {
-                // Start a new component
                 let mut min_gx = gx;
                 let mut max_gx = gx;
                 let mut min_gy = gy;
@@ -583,7 +585,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
                     if cy < min_gy { min_gy = cy; }
                     if cy > max_gy { max_gy = cy; }
 
-                    // Neighbors
                     let neighbors = [
                         (cx.wrapping_sub(1), cy), (cx + 1, cy),
                         (cx, cy.wrapping_sub(1)), (cx, cy + 1)
@@ -600,8 +601,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
                     }
                 }
 
-                // Convert grid rect to normalized image rect
-                // Add padding (1 cell)
                 let px = (min_gx * cell_size) as f64;
                 let py = (min_gy * cell_size) as f64;
                 let pw = ((max_gx - min_gx + 1) * cell_size) as f64;
diff --git a/src-tauri/src/ocr.rs b/src-tauri/src/ocr.rs
new file mode 100644
index 0000000..ed703ef
--- /dev/null
+++ b/src-tauri/src/ocr.rs
@@ -0,0 +1,168 @@
+use image::RgbaImage;
+use ort::session::{Session, builder::GraphOptimizationLevel};
+use ort::value::Value;
+use std::path::Path;
+
+#[derive(Clone, Debug)]
+pub struct DetectedBox {
+    pub x: f64,
+    pub y: f64,
+    pub width: f64,
+    pub height: f64,
+}
+
+pub fn run_ocr_detection(
+    model_path: &Path,
+    input_image: &RgbaImage,
+) -> Result<Vec<DetectedBox>, String> {
+    // 1. Load Model
+    let mut session = Session::builder()
+        .map_err(|e| format!("Failed to create session: {}", e))?
+        .with_optimization_level(GraphOptimizationLevel::Level3)
+        .map_err(|e| format!("Failed to set opt: {}", e))?
+        .with_intra_threads(4)
+        .map_err(|e| format!("Failed to set threads: {}", e))?
+        .commit_from_file(model_path)
+        .map_err(|e| format!("Failed to load OCR model: {}", e))?;
+
+    // 2. Preprocess
+    // DBNet expects standard normalization: (img - mean) / std
+    // Mean: [0.485, 0.456, 0.406], Std: [0.229, 0.224, 0.225]
+    // Input is usually resized to a multiple of 32; limit max size for speed.
+    let max_side = 1600; // Cap on the longest side, in pixels
+    let (orig_w, orig_h) = input_image.dimensions();
+
+    let mut resize_w = orig_w;
+    let mut resize_h = orig_h;
+
+    // Resize logic: limit the max side, preserve aspect ratio, ensure divisibility by 32
+    if resize_w > max_side || resize_h > max_side {
+        let ratio = max_side as f64 / (orig_w.max(orig_h) as f64);
+        resize_w = (orig_w as f64 * ratio) as u32;
+        resize_h = (orig_h as f64 * ratio) as u32;
+    }
+
+    // Align to 32
+    resize_w = (resize_w + 31) / 32 * 32;
+    resize_h = (resize_h + 31) / 32 * 32;
+
+    // Minimum size
+    resize_w = resize_w.max(32);
+    resize_h = resize_h.max(32);
+
+    let resized = image::imageops::resize(input_image, resize_w, resize_h, image::imageops::FilterType::Triangle);
+
+    let channel_stride = (resize_w * resize_h) as usize;
+    let mut input_data = Vec::with_capacity(3 * channel_stride);
+    let mut r_plane = Vec::with_capacity(channel_stride);
+    let mut g_plane = Vec::with_capacity(channel_stride);
+    let mut b_plane = Vec::with_capacity(channel_stride);
+
+    let mean = [0.485, 0.456, 0.406];
+    let std = [0.229, 0.224, 0.225];
+
+    for (_x, _y, pixel) in resized.enumerate_pixels() {
+        let r = pixel[0] as f32 / 255.0;
+        let g = pixel[1] as f32 / 255.0;
+        let b = pixel[2] as f32 / 255.0;
+
+        r_plane.push((r - mean[0]) / std[0]);
+        g_plane.push((g - mean[1]) / std[1]);
+        b_plane.push((b - mean[2]) / std[2]);
+    }
+    input_data.extend(r_plane);
+    input_data.extend(g_plane);
+    input_data.extend(b_plane);
+
+    // 3. Inference
+    let input_shape = vec![1, 3, resize_h as i64, resize_w as i64];
+    let input_value = Value::from_array((input_shape, input_data))
+        .map_err(|e| format!("Failed to create input tensor: {}", e))?;
+
+    let inputs = ort::inputs![input_value]; // For PP-OCR, usually just one input "x"
+    let outputs = session.run(inputs).map_err(|e| format!("Inference failed: {}", e))?;
+
+    let output_tensor = outputs.values().next().ok_or("No output")?;
+    let (shape, data) = output_tensor.try_extract_tensor::<f32>()
+        .map_err(|e| format!("Failed to extract output: {}", e))?;
+
+    // 4. Post-process
+    // Output shape is [1, 1, H, W] probability map
+    if shape.len() < 4 {
+        return Err("Unexpected output shape".to_string());
+    }
+
+    let map_w = shape[3] as u32;
+    let map_h = shape[2] as u32;
+
+    // Binarize the probability map (0.2 is slightly below the usual 0.3, to catch fainter text)
+    let threshold = 0.2;
+    let mut binary_map = vec![false; (map_w * map_h) as usize];
+
+    for i in 0..binary_map.len() {
+        if data[i] > threshold {
+            binary_map[i] = true;
+        }
+    }
+
+    // Find connected components (simple bounding-box extraction)
+    let mut visited = vec![false; binary_map.len()];
+    let mut boxes = Vec::new();
+
+    for y in 0..map_h {
+        for x in 0..map_w {
+            let idx = (y * map_w + x) as usize;
+            if binary_map[idx] && !visited[idx] {
+                // Flood fill
+                let mut stack = vec![(x, y)];
+                visited[idx] = true;
+
+                let mut min_x = x;
+                let mut max_x = x;
+                let mut min_y = y;
+                let mut max_y = y;
+                let mut pixel_count = 0;
+
+                while let Some((cx, cy)) = stack.pop() {
+                    pixel_count += 1;
+                    if cx < min_x { min_x = cx; }
+                    if cx > max_x { max_x = cx; }
+                    if cy < min_y { min_y = cy; }
+                    if cy > max_y { max_y = cy; }
+
+                    let neighbors = [
+                        (cx.wrapping_sub(1), cy), (cx + 1, cy),
+                        (cx, cy.wrapping_sub(1)), (cx, cy + 1)
+                    ];
+
+                    for (nx, ny) in neighbors {
+                        if nx < map_w && ny < map_h {
+                            let nidx = (ny * map_w + nx) as usize;
+                            if binary_map[nidx] && !visited[nidx] {
+                                visited[nidx] = true;
+                                stack.push((nx, ny));
+                            }
+                        }
+                    }
+                }
+
+                // Filter small noise
+                if pixel_count < 10 { continue; }
+
+                // Scale back to original
+                let scale_x = orig_w as f64 / resize_w as f64;
+                let scale_y = orig_h as f64 / resize_h as f64;
+
+                // No brightness filter here: keep every region DBNet reports
+                boxes.push(DetectedBox {
+                    x: min_x as f64 * scale_x / orig_w as f64,
+                    y: min_y as f64 * scale_y / orig_h as f64,
+                    width: (max_x - min_x + 1) as f64 * scale_x / orig_w as f64,
+                    height: (max_y - min_y + 1) as f64 * scale_y / orig_h as f64,
+                });
+            }
+        }
+    }
+
+    Ok(boxes)
+}
diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json
index 3662d18..de9a70a 100644
--- a/src-tauri/tauri.conf.json
+++ b/src-tauri/tauri.conf.json
@@ -28,7 +28,7 @@
   "bundle": {
     "active": true,
     "targets": "all",
-    "resources": ["resources/lama_fp32.onnx"],
+    "resources": ["resources/lama_fp32.onnx", "resources/en_PP-OCRv3_det_infer.onnx"],
     "icon": [
       "icons/32x32.png",
      "icons/128x128.png",
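
A quick way to sanity-check the new detector in isolation, outside the Tauri command, is a minimal sketch like the one below. It is not part of the patch: it assumes it runs from `src-tauri/` so the bundled model path resolves, and the fixture image path is a hypothetical placeholder. The boxes come back normalized to [0, 1] against the original image dimensions, which is exactly what `detect_watermark` forwards to the frontend as `Rect` values.

```rust
// Illustrative harness only. `mod ocr;` pulls in src/ocr.rs; the image
// path below is a placeholder for any local test image.
mod ocr;

use std::path::Path;

fn main() -> Result<(), String> {
    let model = Path::new("resources/en_PP-OCRv3_det_infer.onnx");
    let img = image::open("../test-images/watermarked.jpg") // hypothetical fixture
        .map_err(|e| e.to_string())?
        .to_rgba8();

    for b in ocr::run_ocr_detection(model, &img)? {
        // Normalized coordinates: multiply by the image size to get pixels.
        println!("text region: x={:.3} y={:.3} w={:.3} h={:.3}", b.x, b.y, b.width, b.height);
    }
    Ok(())
}
```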