ocr detect
This commit is contained in:
@@ -95,6 +95,7 @@ use ab_glyph::{FontRef, PxScale};
|
||||
use tauri::{AppHandle, Manager};
|
||||
|
||||
mod lama;
|
||||
mod ocr;
|
||||
|
||||
// Embed the font to ensure it's always available without path issues
|
||||
const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf");
|
||||
@@ -453,52 +454,63 @@ struct Rect {
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
||||
let img = image::open(&path).map_err(|e| e.to_string())?;
|
||||
let (width, height) = img.dimensions();
|
||||
let gray = img.to_luma8();
|
||||
async fn detect_watermark(app: AppHandle, path: String) -> Result<DetectionResult, String> {
|
||||
let img = image::open(&path).map_err(|e| e.to_string())?.to_rgba8();
|
||||
|
||||
// "Stroke Detection" Algorithm
|
||||
// Distinguishes "Text" (Thin White Strokes) from "Solid White Areas" (Walls, Sky)
|
||||
// Logic: A white text pixel must be "sandwiched" by dark pixels within a short distance.
|
||||
// 1. Try OCR Detection
|
||||
let ocr_model_path = app.path().resource_dir()
|
||||
.map_err(|e| e.to_string())?
|
||||
.join("resources")
|
||||
.join("en_PP-OCRv3_det_infer.onnx");
|
||||
|
||||
if ocr_model_path.exists() {
|
||||
println!("Using OCR model for detection");
|
||||
match ocr::run_ocr_detection(&ocr_model_path, &img) {
|
||||
Ok(boxes) => {
|
||||
let rects = boxes.into_iter().map(|b| Rect {
|
||||
x: b.x,
|
||||
y: b.y,
|
||||
width: b.width,
|
||||
height: b.height,
|
||||
}).collect();
|
||||
return Ok(DetectionResult { rects });
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("OCR Detection failed: {}", e);
|
||||
// Fallthrough to legacy
|
||||
}
|
||||
}
|
||||
} else {
|
||||
eprintln!("OCR model not found at {:?}", ocr_model_path);
|
||||
}
|
||||
|
||||
// 2. Legacy Detection (Fallback)
|
||||
println!("Falling back to legacy detection");
|
||||
let (width, height) = img.dimensions();
|
||||
let gray = image::DynamicImage::ImageRgba8(img.clone()).to_luma8();
|
||||
|
||||
let cell_size = 10;
|
||||
let grid_w = (width + cell_size - 1) / cell_size;
|
||||
let grid_h = (height + cell_size - 1) / cell_size;
|
||||
let mut grid = vec![false; (grid_w * grid_h) as usize];
|
||||
|
||||
// Focus Areas: Top 15%, Bottom 25%
|
||||
let top_limit = (height as f64 * 0.15) as u32;
|
||||
let bottom_start = (height as f64 * 0.75) as u32;
|
||||
|
||||
let max_stroke_width = 15; // Max pixels for a text stroke thickness
|
||||
let contrast_threshold = 40; // How much darker the background must be
|
||||
let brightness_threshold = 200; // Text must be at least this white
|
||||
let max_stroke_width = 15;
|
||||
let contrast_threshold = 40;
|
||||
let brightness_threshold = 200;
|
||||
|
||||
for y in 1..height-1 {
|
||||
if y > top_limit && y < bottom_start { continue; }
|
||||
|
||||
for x in 1..width-1 {
|
||||
let p = gray.get_pixel(x, y)[0];
|
||||
|
||||
// 1. Must be Bright
|
||||
if p < brightness_threshold { continue; }
|
||||
|
||||
// 2. Stroke Check
|
||||
// We check for "Vertical Stroke" (Dark - Bright - Dark vertically)
|
||||
// OR "Horizontal Stroke" (Dark - Bright - Dark horizontally)
|
||||
|
||||
let mut is_stroke = false;
|
||||
|
||||
// Check Horizontal Stroke (Vertical boundaries? No, Vertical Stroke has Left/Right boundaries?
|
||||
// Terminology: "Vertical Stroke" is like 'I'. It has Left/Right boundaries.
|
||||
// "Horizontal Stroke" is like '-', It has Up/Down boundaries.
|
||||
|
||||
// Let's check Left/Right boundaries (Vertical Stroke)
|
||||
let mut left_bound = false;
|
||||
let mut right_bound = false;
|
||||
|
||||
// Search Left
|
||||
for k in 1..=max_stroke_width {
|
||||
if x < k { break; }
|
||||
let neighbor = gray.get_pixel(x - k, y)[0];
|
||||
@@ -507,7 +519,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Search Right
|
||||
if left_bound {
|
||||
for k in 1..=max_stroke_width {
|
||||
if x + k >= width { break; }
|
||||
@@ -522,11 +533,8 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
||||
if left_bound && right_bound {
|
||||
is_stroke = true;
|
||||
} else {
|
||||
// Check Up/Down boundaries (Horizontal Stroke)
|
||||
let mut up_bound = false;
|
||||
let mut down_bound = false;
|
||||
|
||||
// Search Up
|
||||
for k in 1..=max_stroke_width {
|
||||
if y < k { break; }
|
||||
let neighbor = gray.get_pixel(x, y - k)[0];
|
||||
@@ -535,7 +543,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Search Down
|
||||
if up_bound {
|
||||
for k in 1..=max_stroke_width {
|
||||
if y + k >= height { break; }
|
||||
@@ -546,10 +553,7 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if up_bound && down_bound {
|
||||
is_stroke = true;
|
||||
}
|
||||
if up_bound && down_bound { is_stroke = true; }
|
||||
}
|
||||
|
||||
if is_stroke {
|
||||
@@ -560,7 +564,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
||||
}
|
||||
}
|
||||
|
||||
// Connected Components on Grid (Simple merging)
|
||||
let mut rects = Vec::new();
|
||||
let mut visited = vec![false; grid.len()];
|
||||
|
||||
@@ -568,7 +571,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
||||
for gx in 0..grid_w {
|
||||
let idx = (gy * grid_w + gx) as usize;
|
||||
if grid[idx] && !visited[idx] {
|
||||
// Start a new component
|
||||
let mut min_gx = gx;
|
||||
let mut max_gx = gx;
|
||||
let mut min_gy = gy;
|
||||
@@ -583,7 +585,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
||||
if cy < min_gy { min_gy = cy; }
|
||||
if cy > max_gy { max_gy = cy; }
|
||||
|
||||
// Neighbors
|
||||
let neighbors = [
|
||||
(cx.wrapping_sub(1), cy), (cx + 1, cy),
|
||||
(cx, cy.wrapping_sub(1)), (cx, cy + 1)
|
||||
@@ -600,8 +601,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
||||
}
|
||||
}
|
||||
|
||||
// Convert grid rect to normalized image rect
|
||||
// Add padding (1 cell)
|
||||
let px = (min_gx * cell_size) as f64;
|
||||
let py = (min_gy * cell_size) as f64;
|
||||
let pw = ((max_gx - min_gx + 1) * cell_size) as f64;
|
||||
|
||||
168
src-tauri/src/ocr.rs
Normal file
168
src-tauri/src/ocr.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
use image::RgbaImage;
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use ort::value::Value;
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DetectedBox {
|
||||
pub x: f64,
|
||||
pub y: f64,
|
||||
pub width: f64,
|
||||
pub height: f64,
|
||||
}
|
||||
|
||||
pub fn run_ocr_detection(
|
||||
model_path: &Path,
|
||||
input_image: &RgbaImage,
|
||||
) -> Result<Vec<DetectedBox>, String> {
|
||||
// 1. Load Model
|
||||
let mut session = Session::builder()
|
||||
.map_err(|e| format!("Failed to create session: {}", e))?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.map_err(|e| format!("Failed to set opt: {}", e))?
|
||||
.with_intra_threads(4)
|
||||
.map_err(|e| format!("Failed to set threads: {}", e))?
|
||||
.commit_from_file(model_path)
|
||||
.map_err(|e| format!("Failed to load OCR model: {}", e))?;
|
||||
|
||||
// 2. Preprocess
|
||||
// DBNet expects standard normalization: (img - mean) / std
|
||||
// Mean: [0.485, 0.456, 0.406], Std: [0.229, 0.224, 0.225]
|
||||
// And usually resized to multiple of 32. Limit max size for speed.
|
||||
let max_side = 1600; // Increase resolution limit
|
||||
let (orig_w, orig_h) = input_image.dimensions();
|
||||
|
||||
let mut resize_w = orig_w;
|
||||
let mut resize_h = orig_h;
|
||||
|
||||
// Resize logic: Limit max side, preserve aspect, ensure divisible by 32
|
||||
if resize_w > max_side || resize_h > max_side {
|
||||
let ratio = max_side as f64 / (orig_w.max(orig_h) as f64);
|
||||
resize_w = (orig_w as f64 * ratio) as u32;
|
||||
resize_h = (orig_h as f64 * ratio) as u32;
|
||||
}
|
||||
|
||||
// Align to 32
|
||||
resize_w = (resize_w + 31) / 32 * 32;
|
||||
resize_h = (resize_h + 31) / 32 * 32;
|
||||
|
||||
// Minimum size
|
||||
resize_w = resize_w.max(32);
|
||||
resize_h = resize_h.max(32);
|
||||
|
||||
let resized = image::imageops::resize(input_image, resize_w, resize_h, image::imageops::FilterType::Triangle);
|
||||
|
||||
let channel_stride = (resize_w * resize_h) as usize;
|
||||
let mut input_data = Vec::with_capacity(1 * 3 * channel_stride);
|
||||
let mut r_plane = Vec::with_capacity(channel_stride);
|
||||
let mut g_plane = Vec::with_capacity(channel_stride);
|
||||
let mut b_plane = Vec::with_capacity(channel_stride);
|
||||
|
||||
let mean = [0.485, 0.456, 0.406];
|
||||
let std = [0.229, 0.224, 0.225];
|
||||
|
||||
for (_x, _y, pixel) in resized.enumerate_pixels() {
|
||||
let r = pixel[0] as f32 / 255.0;
|
||||
let g = pixel[1] as f32 / 255.0;
|
||||
let b = pixel[2] as f32 / 255.0;
|
||||
|
||||
r_plane.push((r - mean[0]) / std[0]);
|
||||
g_plane.push((g - mean[1]) / std[1]);
|
||||
b_plane.push((b - mean[2]) / std[2]);
|
||||
}
|
||||
input_data.extend(r_plane);
|
||||
input_data.extend(g_plane);
|
||||
input_data.extend(b_plane);
|
||||
|
||||
// 3. Inference
|
||||
let input_shape = vec![1, 3, resize_h as i64, resize_w as i64];
|
||||
let input_value = Value::from_array((input_shape, input_data))
|
||||
.map_err(|e| format!("Failed to create input tensor: {}", e))?;
|
||||
|
||||
let inputs = ort::inputs![input_value]; // For PP-OCR, usually just one input "x"
|
||||
let outputs = session.run(inputs).map_err(|e| format!("Inference failed: {}", e))?;
|
||||
|
||||
let output_tensor = outputs.values().next().ok_or("No output")?;
|
||||
let (shape, data) = output_tensor.try_extract_tensor::<f32>()
|
||||
.map_err(|e| format!("Failed to extract output: {}", e))?;
|
||||
|
||||
// 4. Post-process
|
||||
// Output shape is [1, 1, H, W] probability map
|
||||
if shape.len() < 4 {
|
||||
return Err("Unexpected output shape".to_string());
|
||||
}
|
||||
|
||||
let map_w = shape[3] as u32;
|
||||
let map_h = shape[2] as u32;
|
||||
|
||||
// Create binary map (threshold 0.3)
|
||||
let threshold = 0.2; // Lower threshold to catch more text
|
||||
let mut binary_map = vec![false; (map_w * map_h) as usize];
|
||||
|
||||
for i in 0..binary_map.len() {
|
||||
if data[i] > threshold {
|
||||
binary_map[i] = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Find Connected Components (Simple Bounding Box finding)
|
||||
let mut visited = vec![false; binary_map.len()];
|
||||
let mut boxes = Vec::new();
|
||||
|
||||
for y in 0..map_h {
|
||||
for x in 0..map_w {
|
||||
let idx = (y * map_w + x) as usize;
|
||||
if binary_map[idx] && !visited[idx] {
|
||||
// Flood fill
|
||||
let mut stack = vec![(x, y)];
|
||||
visited[idx] = true;
|
||||
|
||||
let mut min_x = x;
|
||||
let mut max_x = x;
|
||||
let mut min_y = y;
|
||||
let mut max_y = y;
|
||||
let mut pixel_count = 0;
|
||||
|
||||
while let Some((cx, cy)) = stack.pop() {
|
||||
pixel_count += 1;
|
||||
if cx < min_x { min_x = cx; }
|
||||
if cx > max_x { max_x = cx; }
|
||||
if cy < min_y { min_y = cy; }
|
||||
if cy > max_y { max_y = cy; }
|
||||
|
||||
let neighbors = [
|
||||
(cx.wrapping_sub(1), cy), (cx + 1, cy),
|
||||
(cx, cy.wrapping_sub(1)), (cx, cy + 1)
|
||||
];
|
||||
|
||||
for (nx, ny) in neighbors {
|
||||
if nx < map_w && ny < map_h {
|
||||
let nidx = (ny * map_w + nx) as usize;
|
||||
if binary_map[nidx] && !visited[nidx] {
|
||||
visited[nidx] = true;
|
||||
stack.push((nx, ny));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter small noise
|
||||
if pixel_count < 10 { continue; }
|
||||
|
||||
// Scale back to original
|
||||
let scale_x = orig_w as f64 / resize_w as f64;
|
||||
let scale_y = orig_h as f64 / resize_h as f64;
|
||||
|
||||
// Removed brightness check to allow detection of any text detected by DBNet
|
||||
boxes.push(DetectedBox {
|
||||
x: min_x as f64 * scale_x / orig_w as f64,
|
||||
y: min_y as f64 * scale_y / orig_h as f64,
|
||||
width: (max_x - min_x + 1) as f64 * scale_x / orig_w as f64,
|
||||
height: (max_y - min_y + 1) as f64 * scale_y / orig_h as f64,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(boxes)
|
||||
}
|
||||
Reference in New Issue
Block a user