ocr detect
This commit is contained in:
@@ -95,6 +95,7 @@ use ab_glyph::{FontRef, PxScale};
|
|||||||
use tauri::{AppHandle, Manager};
|
use tauri::{AppHandle, Manager};
|
||||||
|
|
||||||
mod lama;
|
mod lama;
|
||||||
|
mod ocr;
|
||||||
|
|
||||||
// Embed the font to ensure it's always available without path issues
|
// Embed the font to ensure it's always available without path issues
|
||||||
const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf");
|
const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf");
|
||||||
@@ -453,52 +454,63 @@ struct Rect {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[tauri::command]
|
#[tauri::command]
|
||||||
async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
async fn detect_watermark(app: AppHandle, path: String) -> Result<DetectionResult, String> {
|
||||||
let img = image::open(&path).map_err(|e| e.to_string())?;
|
let img = image::open(&path).map_err(|e| e.to_string())?.to_rgba8();
|
||||||
let (width, height) = img.dimensions();
|
|
||||||
let gray = img.to_luma8();
|
|
||||||
|
|
||||||
// "Stroke Detection" Algorithm
|
// 1. Try OCR Detection
|
||||||
// Distinguishes "Text" (Thin White Strokes) from "Solid White Areas" (Walls, Sky)
|
let ocr_model_path = app.path().resource_dir()
|
||||||
// Logic: A white text pixel must be "sandwiched" by dark pixels within a short distance.
|
.map_err(|e| e.to_string())?
|
||||||
|
.join("resources")
|
||||||
|
.join("en_PP-OCRv3_det_infer.onnx");
|
||||||
|
|
||||||
|
if ocr_model_path.exists() {
|
||||||
|
println!("Using OCR model for detection");
|
||||||
|
match ocr::run_ocr_detection(&ocr_model_path, &img) {
|
||||||
|
Ok(boxes) => {
|
||||||
|
let rects = boxes.into_iter().map(|b| Rect {
|
||||||
|
x: b.x,
|
||||||
|
y: b.y,
|
||||||
|
width: b.width,
|
||||||
|
height: b.height,
|
||||||
|
}).collect();
|
||||||
|
return Ok(DetectionResult { rects });
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("OCR Detection failed: {}", e);
|
||||||
|
// Fallthrough to legacy
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
eprintln!("OCR model not found at {:?}", ocr_model_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Legacy Detection (Fallback)
|
||||||
|
println!("Falling back to legacy detection");
|
||||||
|
let (width, height) = img.dimensions();
|
||||||
|
let gray = image::DynamicImage::ImageRgba8(img.clone()).to_luma8();
|
||||||
|
|
||||||
let cell_size = 10;
|
let cell_size = 10;
|
||||||
let grid_w = (width + cell_size - 1) / cell_size;
|
let grid_w = (width + cell_size - 1) / cell_size;
|
||||||
let grid_h = (height + cell_size - 1) / cell_size;
|
let grid_h = (height + cell_size - 1) / cell_size;
|
||||||
let mut grid = vec![false; (grid_w * grid_h) as usize];
|
let mut grid = vec![false; (grid_w * grid_h) as usize];
|
||||||
|
|
||||||
// Focus Areas: Top 15%, Bottom 25%
|
|
||||||
let top_limit = (height as f64 * 0.15) as u32;
|
let top_limit = (height as f64 * 0.15) as u32;
|
||||||
let bottom_start = (height as f64 * 0.75) as u32;
|
let bottom_start = (height as f64 * 0.75) as u32;
|
||||||
|
let max_stroke_width = 15;
|
||||||
let max_stroke_width = 15; // Max pixels for a text stroke thickness
|
let contrast_threshold = 40;
|
||||||
let contrast_threshold = 40; // How much darker the background must be
|
let brightness_threshold = 200;
|
||||||
let brightness_threshold = 200; // Text must be at least this white
|
|
||||||
|
|
||||||
for y in 1..height-1 {
|
for y in 1..height-1 {
|
||||||
if y > top_limit && y < bottom_start { continue; }
|
if y > top_limit && y < bottom_start { continue; }
|
||||||
|
|
||||||
for x in 1..width-1 {
|
for x in 1..width-1 {
|
||||||
let p = gray.get_pixel(x, y)[0];
|
let p = gray.get_pixel(x, y)[0];
|
||||||
|
|
||||||
// 1. Must be Bright
|
|
||||||
if p < brightness_threshold { continue; }
|
if p < brightness_threshold { continue; }
|
||||||
|
|
||||||
// 2. Stroke Check
|
|
||||||
// We check for "Vertical Stroke" (Dark - Bright - Dark vertically)
|
|
||||||
// OR "Horizontal Stroke" (Dark - Bright - Dark horizontally)
|
|
||||||
|
|
||||||
let mut is_stroke = false;
|
let mut is_stroke = false;
|
||||||
|
|
||||||
// Check Horizontal Stroke (Vertical boundaries? No, Vertical Stroke has Left/Right boundaries?
|
|
||||||
// Terminology: "Vertical Stroke" is like 'I'. It has Left/Right boundaries.
|
|
||||||
// "Horizontal Stroke" is like '-', It has Up/Down boundaries.
|
|
||||||
|
|
||||||
// Let's check Left/Right boundaries (Vertical Stroke)
|
|
||||||
let mut left_bound = false;
|
let mut left_bound = false;
|
||||||
let mut right_bound = false;
|
let mut right_bound = false;
|
||||||
|
|
||||||
// Search Left
|
|
||||||
for k in 1..=max_stroke_width {
|
for k in 1..=max_stroke_width {
|
||||||
if x < k { break; }
|
if x < k { break; }
|
||||||
let neighbor = gray.get_pixel(x - k, y)[0];
|
let neighbor = gray.get_pixel(x - k, y)[0];
|
||||||
@@ -507,7 +519,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Search Right
|
|
||||||
if left_bound {
|
if left_bound {
|
||||||
for k in 1..=max_stroke_width {
|
for k in 1..=max_stroke_width {
|
||||||
if x + k >= width { break; }
|
if x + k >= width { break; }
|
||||||
@@ -522,11 +533,8 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
|||||||
if left_bound && right_bound {
|
if left_bound && right_bound {
|
||||||
is_stroke = true;
|
is_stroke = true;
|
||||||
} else {
|
} else {
|
||||||
// Check Up/Down boundaries (Horizontal Stroke)
|
|
||||||
let mut up_bound = false;
|
let mut up_bound = false;
|
||||||
let mut down_bound = false;
|
let mut down_bound = false;
|
||||||
|
|
||||||
// Search Up
|
|
||||||
for k in 1..=max_stroke_width {
|
for k in 1..=max_stroke_width {
|
||||||
if y < k { break; }
|
if y < k { break; }
|
||||||
let neighbor = gray.get_pixel(x, y - k)[0];
|
let neighbor = gray.get_pixel(x, y - k)[0];
|
||||||
@@ -535,7 +543,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Search Down
|
|
||||||
if up_bound {
|
if up_bound {
|
||||||
for k in 1..=max_stroke_width {
|
for k in 1..=max_stroke_width {
|
||||||
if y + k >= height { break; }
|
if y + k >= height { break; }
|
||||||
@@ -546,10 +553,7 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if up_bound && down_bound { is_stroke = true; }
|
||||||
if up_bound && down_bound {
|
|
||||||
is_stroke = true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_stroke {
|
if is_stroke {
|
||||||
@@ -560,7 +564,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connected Components on Grid (Simple merging)
|
|
||||||
let mut rects = Vec::new();
|
let mut rects = Vec::new();
|
||||||
let mut visited = vec![false; grid.len()];
|
let mut visited = vec![false; grid.len()];
|
||||||
|
|
||||||
@@ -568,7 +571,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
|||||||
for gx in 0..grid_w {
|
for gx in 0..grid_w {
|
||||||
let idx = (gy * grid_w + gx) as usize;
|
let idx = (gy * grid_w + gx) as usize;
|
||||||
if grid[idx] && !visited[idx] {
|
if grid[idx] && !visited[idx] {
|
||||||
// Start a new component
|
|
||||||
let mut min_gx = gx;
|
let mut min_gx = gx;
|
||||||
let mut max_gx = gx;
|
let mut max_gx = gx;
|
||||||
let mut min_gy = gy;
|
let mut min_gy = gy;
|
||||||
@@ -583,7 +585,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
|||||||
if cy < min_gy { min_gy = cy; }
|
if cy < min_gy { min_gy = cy; }
|
||||||
if cy > max_gy { max_gy = cy; }
|
if cy > max_gy { max_gy = cy; }
|
||||||
|
|
||||||
// Neighbors
|
|
||||||
let neighbors = [
|
let neighbors = [
|
||||||
(cx.wrapping_sub(1), cy), (cx + 1, cy),
|
(cx.wrapping_sub(1), cy), (cx + 1, cy),
|
||||||
(cx, cy.wrapping_sub(1)), (cx, cy + 1)
|
(cx, cy.wrapping_sub(1)), (cx, cy + 1)
|
||||||
@@ -600,8 +601,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert grid rect to normalized image rect
|
|
||||||
// Add padding (1 cell)
|
|
||||||
let px = (min_gx * cell_size) as f64;
|
let px = (min_gx * cell_size) as f64;
|
||||||
let py = (min_gy * cell_size) as f64;
|
let py = (min_gy * cell_size) as f64;
|
||||||
let pw = ((max_gx - min_gx + 1) * cell_size) as f64;
|
let pw = ((max_gx - min_gx + 1) * cell_size) as f64;
|
||||||
|
|||||||
168
src-tauri/src/ocr.rs
Normal file
168
src-tauri/src/ocr.rs
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
use image::RgbaImage;
|
||||||
|
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||||
|
use ort::value::Value;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct DetectedBox {
|
||||||
|
pub x: f64,
|
||||||
|
pub y: f64,
|
||||||
|
pub width: f64,
|
||||||
|
pub height: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn run_ocr_detection(
|
||||||
|
model_path: &Path,
|
||||||
|
input_image: &RgbaImage,
|
||||||
|
) -> Result<Vec<DetectedBox>, String> {
|
||||||
|
// 1. Load Model
|
||||||
|
let mut session = Session::builder()
|
||||||
|
.map_err(|e| format!("Failed to create session: {}", e))?
|
||||||
|
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||||
|
.map_err(|e| format!("Failed to set opt: {}", e))?
|
||||||
|
.with_intra_threads(4)
|
||||||
|
.map_err(|e| format!("Failed to set threads: {}", e))?
|
||||||
|
.commit_from_file(model_path)
|
||||||
|
.map_err(|e| format!("Failed to load OCR model: {}", e))?;
|
||||||
|
|
||||||
|
// 2. Preprocess
|
||||||
|
// DBNet expects standard normalization: (img - mean) / std
|
||||||
|
// Mean: [0.485, 0.456, 0.406], Std: [0.229, 0.224, 0.225]
|
||||||
|
// And usually resized to multiple of 32. Limit max size for speed.
|
||||||
|
let max_side = 1600; // Increase resolution limit
|
||||||
|
let (orig_w, orig_h) = input_image.dimensions();
|
||||||
|
|
||||||
|
let mut resize_w = orig_w;
|
||||||
|
let mut resize_h = orig_h;
|
||||||
|
|
||||||
|
// Resize logic: Limit max side, preserve aspect, ensure divisible by 32
|
||||||
|
if resize_w > max_side || resize_h > max_side {
|
||||||
|
let ratio = max_side as f64 / (orig_w.max(orig_h) as f64);
|
||||||
|
resize_w = (orig_w as f64 * ratio) as u32;
|
||||||
|
resize_h = (orig_h as f64 * ratio) as u32;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Align to 32
|
||||||
|
resize_w = (resize_w + 31) / 32 * 32;
|
||||||
|
resize_h = (resize_h + 31) / 32 * 32;
|
||||||
|
|
||||||
|
// Minimum size
|
||||||
|
resize_w = resize_w.max(32);
|
||||||
|
resize_h = resize_h.max(32);
|
||||||
|
|
||||||
|
let resized = image::imageops::resize(input_image, resize_w, resize_h, image::imageops::FilterType::Triangle);
|
||||||
|
|
||||||
|
let channel_stride = (resize_w * resize_h) as usize;
|
||||||
|
let mut input_data = Vec::with_capacity(1 * 3 * channel_stride);
|
||||||
|
let mut r_plane = Vec::with_capacity(channel_stride);
|
||||||
|
let mut g_plane = Vec::with_capacity(channel_stride);
|
||||||
|
let mut b_plane = Vec::with_capacity(channel_stride);
|
||||||
|
|
||||||
|
let mean = [0.485, 0.456, 0.406];
|
||||||
|
let std = [0.229, 0.224, 0.225];
|
||||||
|
|
||||||
|
for (_x, _y, pixel) in resized.enumerate_pixels() {
|
||||||
|
let r = pixel[0] as f32 / 255.0;
|
||||||
|
let g = pixel[1] as f32 / 255.0;
|
||||||
|
let b = pixel[2] as f32 / 255.0;
|
||||||
|
|
||||||
|
r_plane.push((r - mean[0]) / std[0]);
|
||||||
|
g_plane.push((g - mean[1]) / std[1]);
|
||||||
|
b_plane.push((b - mean[2]) / std[2]);
|
||||||
|
}
|
||||||
|
input_data.extend(r_plane);
|
||||||
|
input_data.extend(g_plane);
|
||||||
|
input_data.extend(b_plane);
|
||||||
|
|
||||||
|
// 3. Inference
|
||||||
|
let input_shape = vec![1, 3, resize_h as i64, resize_w as i64];
|
||||||
|
let input_value = Value::from_array((input_shape, input_data))
|
||||||
|
.map_err(|e| format!("Failed to create input tensor: {}", e))?;
|
||||||
|
|
||||||
|
let inputs = ort::inputs![input_value]; // For PP-OCR, usually just one input "x"
|
||||||
|
let outputs = session.run(inputs).map_err(|e| format!("Inference failed: {}", e))?;
|
||||||
|
|
||||||
|
let output_tensor = outputs.values().next().ok_or("No output")?;
|
||||||
|
let (shape, data) = output_tensor.try_extract_tensor::<f32>()
|
||||||
|
.map_err(|e| format!("Failed to extract output: {}", e))?;
|
||||||
|
|
||||||
|
// 4. Post-process
|
||||||
|
// Output shape is [1, 1, H, W] probability map
|
||||||
|
if shape.len() < 4 {
|
||||||
|
return Err("Unexpected output shape".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let map_w = shape[3] as u32;
|
||||||
|
let map_h = shape[2] as u32;
|
||||||
|
|
||||||
|
// Create binary map (threshold 0.3)
|
||||||
|
let threshold = 0.2; // Lower threshold to catch more text
|
||||||
|
let mut binary_map = vec![false; (map_w * map_h) as usize];
|
||||||
|
|
||||||
|
for i in 0..binary_map.len() {
|
||||||
|
if data[i] > threshold {
|
||||||
|
binary_map[i] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find Connected Components (Simple Bounding Box finding)
|
||||||
|
let mut visited = vec![false; binary_map.len()];
|
||||||
|
let mut boxes = Vec::new();
|
||||||
|
|
||||||
|
for y in 0..map_h {
|
||||||
|
for x in 0..map_w {
|
||||||
|
let idx = (y * map_w + x) as usize;
|
||||||
|
if binary_map[idx] && !visited[idx] {
|
||||||
|
// Flood fill
|
||||||
|
let mut stack = vec![(x, y)];
|
||||||
|
visited[idx] = true;
|
||||||
|
|
||||||
|
let mut min_x = x;
|
||||||
|
let mut max_x = x;
|
||||||
|
let mut min_y = y;
|
||||||
|
let mut max_y = y;
|
||||||
|
let mut pixel_count = 0;
|
||||||
|
|
||||||
|
while let Some((cx, cy)) = stack.pop() {
|
||||||
|
pixel_count += 1;
|
||||||
|
if cx < min_x { min_x = cx; }
|
||||||
|
if cx > max_x { max_x = cx; }
|
||||||
|
if cy < min_y { min_y = cy; }
|
||||||
|
if cy > max_y { max_y = cy; }
|
||||||
|
|
||||||
|
let neighbors = [
|
||||||
|
(cx.wrapping_sub(1), cy), (cx + 1, cy),
|
||||||
|
(cx, cy.wrapping_sub(1)), (cx, cy + 1)
|
||||||
|
];
|
||||||
|
|
||||||
|
for (nx, ny) in neighbors {
|
||||||
|
if nx < map_w && ny < map_h {
|
||||||
|
let nidx = (ny * map_w + nx) as usize;
|
||||||
|
if binary_map[nidx] && !visited[nidx] {
|
||||||
|
visited[nidx] = true;
|
||||||
|
stack.push((nx, ny));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter small noise
|
||||||
|
if pixel_count < 10 { continue; }
|
||||||
|
|
||||||
|
// Scale back to original
|
||||||
|
let scale_x = orig_w as f64 / resize_w as f64;
|
||||||
|
let scale_y = orig_h as f64 / resize_h as f64;
|
||||||
|
|
||||||
|
// Removed brightness check to allow detection of any text detected by DBNet
|
||||||
|
boxes.push(DetectedBox {
|
||||||
|
x: min_x as f64 * scale_x / orig_w as f64,
|
||||||
|
y: min_y as f64 * scale_y / orig_h as f64,
|
||||||
|
width: (max_x - min_x + 1) as f64 * scale_x / orig_w as f64,
|
||||||
|
height: (max_y - min_y + 1) as f64 * scale_y / orig_h as f64,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(boxes)
|
||||||
|
}
|
||||||
@@ -28,7 +28,7 @@
|
|||||||
"bundle": {
|
"bundle": {
|
||||||
"active": true,
|
"active": true,
|
||||||
"targets": "all",
|
"targets": "all",
|
||||||
"resources": ["resources/lama_fp32.onnx"],
|
"resources": ["resources/lama_fp32.onnx", "resources/en_PP-OCRv3_det_infer.onnx"],
|
||||||
"icon": [
|
"icon": [
|
||||||
"icons/32x32.png",
|
"icons/32x32.png",
|
||||||
"icons/128x128.png",
|
"icons/128x128.png",
|
||||||
|
|||||||
Reference in New Issue
Block a user