ocr detect

Julian Freeman
2026-01-19 12:42:05 -04:00
parent 6439759b04
commit eb251b5eac
3 changed files with 207 additions and 40 deletions


@@ -95,6 +95,7 @@ use ab_glyph::{FontRef, PxScale};
 use tauri::{AppHandle, Manager};
 mod lama;
+mod ocr;
 // Embed the font to ensure it's always available without path issues
 const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf");
@@ -453,52 +454,63 @@ struct Rect {
 }
 #[tauri::command]
-async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
-    let img = image::open(&path).map_err(|e| e.to_string())?;
-    let (width, height) = img.dimensions();
-    let gray = img.to_luma8();
-    // "Stroke Detection" Algorithm
-    // Distinguishes "Text" (Thin White Strokes) from "Solid White Areas" (Walls, Sky)
-    // Logic: A white text pixel must be "sandwiched" by dark pixels within a short distance.
+async fn detect_watermark(app: AppHandle, path: String) -> Result<DetectionResult, String> {
+    let img = image::open(&path).map_err(|e| e.to_string())?.to_rgba8();
+    // 1. Try OCR Detection
+    let ocr_model_path = app.path().resource_dir()
+        .map_err(|e| e.to_string())?
+        .join("resources")
+        .join("en_PP-OCRv3_det_infer.onnx");
+    if ocr_model_path.exists() {
+        println!("Using OCR model for detection");
+        match ocr::run_ocr_detection(&ocr_model_path, &img) {
+            Ok(boxes) => {
+                let rects = boxes.into_iter().map(|b| Rect {
+                    x: b.x,
+                    y: b.y,
+                    width: b.width,
+                    height: b.height,
+                }).collect();
+                return Ok(DetectionResult { rects });
+            }
+            Err(e) => {
+                eprintln!("OCR Detection failed: {}", e);
+                // Fallthrough to legacy
+            }
+        }
+    } else {
+        eprintln!("OCR model not found at {:?}", ocr_model_path);
+    }
+    // 2. Legacy Detection (Fallback)
+    println!("Falling back to legacy detection");
+    let (width, height) = img.dimensions();
+    let gray = image::DynamicImage::ImageRgba8(img.clone()).to_luma8();
     let cell_size = 10;
     let grid_w = (width + cell_size - 1) / cell_size;
     let grid_h = (height + cell_size - 1) / cell_size;
     let mut grid = vec![false; (grid_w * grid_h) as usize];
-    // Focus Areas: Top 15%, Bottom 25%
     let top_limit = (height as f64 * 0.15) as u32;
     let bottom_start = (height as f64 * 0.75) as u32;
-    let max_stroke_width = 15; // Max pixels for a text stroke thickness
-    let contrast_threshold = 40; // How much darker the background must be
-    let brightness_threshold = 200; // Text must be at least this white
+    let max_stroke_width = 15;
+    let contrast_threshold = 40;
+    let brightness_threshold = 200;
     for y in 1..height-1 {
         if y > top_limit && y < bottom_start { continue; }
         for x in 1..width-1 {
             let p = gray.get_pixel(x, y)[0];
-            // 1. Must be Bright
             if p < brightness_threshold { continue; }
-            // 2. Stroke Check
-            // We check for "Vertical Stroke" (Dark - Bright - Dark vertically)
-            // OR "Horizontal Stroke" (Dark - Bright - Dark horizontally)
             let mut is_stroke = false;
-            // Terminology: a "Vertical Stroke" is like 'I'; it has left/right boundaries.
-            // A "Horizontal Stroke" is like '-'; it has up/down boundaries.
-            // Check left/right boundaries first (Vertical Stroke).
             let mut left_bound = false;
             let mut right_bound = false;
-            // Search Left
             for k in 1..=max_stroke_width {
                 if x < k { break; }
                 let neighbor = gray.get_pixel(x - k, y)[0];
@@ -507,7 +519,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
                     break;
                 }
             }
-            // Search Right
             if left_bound {
                 for k in 1..=max_stroke_width {
                     if x + k >= width { break; }
@@ -522,11 +533,8 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
             if left_bound && right_bound {
                 is_stroke = true;
             } else {
-                // Check Up/Down boundaries (Horizontal Stroke)
                 let mut up_bound = false;
                 let mut down_bound = false;
-                // Search Up
                 for k in 1..=max_stroke_width {
                     if y < k { break; }
                     let neighbor = gray.get_pixel(x, y - k)[0];
@@ -535,7 +543,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
                         break;
                     }
                 }
-                // Search Down
                 if up_bound {
                     for k in 1..=max_stroke_width {
                         if y + k >= height { break; }
@@ -546,10 +553,7 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
                         }
                     }
                 }
-                if up_bound && down_bound {
-                    is_stroke = true;
-                }
+                if up_bound && down_bound { is_stroke = true; }
             }
             if is_stroke {
@@ -560,7 +564,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
         }
     }
-    // Connected Components on Grid (Simple merging)
     let mut rects = Vec::new();
     let mut visited = vec![false; grid.len()];
@@ -568,7 +571,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
         for gx in 0..grid_w {
             let idx = (gy * grid_w + gx) as usize;
             if grid[idx] && !visited[idx] {
-                // Start a new component
                 let mut min_gx = gx;
                 let mut max_gx = gx;
                 let mut min_gy = gy;
@@ -583,7 +585,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
                     if cy < min_gy { min_gy = cy; }
                     if cy > max_gy { max_gy = cy; }
-                    // Neighbors
                     let neighbors = [
                         (cx.wrapping_sub(1), cy), (cx + 1, cy),
                         (cx, cy.wrapping_sub(1)), (cx, cy + 1)
@@ -600,8 +601,6 @@ async fn detect_watermark(path: String) -> Result<DetectionResult, String> {
                     }
                 }
-                // Convert grid rect to normalized image rect
-                // Add padding (1 cell)
                 let px = (min_gx * cell_size) as f64;
                 let py = (min_gy * cell_size) as f64;
                 let pw = ((max_gx - min_gx + 1) * cell_size) as f64;
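
The legacy path's core test, spelled out in the comments removed above, asks whether a bright pixel is "sandwiched" by darker pixels within max_stroke_width along one axis. A minimal standalone sketch of the horizontal variant of that test, with the commit's thresholds hard-coded (the helper name and free-function form are illustrative, not part of this commit):

use image::GrayImage;

// Does (x, y) sit on a thin bright stroke bounded left and right by
// darker background? Mirrors the left/right search in detect_watermark.
fn is_vertical_stroke(gray: &GrayImage, x: u32, y: u32) -> bool {
    const MAX_STROKE_WIDTH: u32 = 15;     // max text stroke thickness in pixels
    const CONTRAST_THRESHOLD: i32 = 40;   // background must be this much darker
    const BRIGHTNESS_THRESHOLD: u8 = 200; // text must be at least this bright
    let (width, _) = gray.dimensions();
    let p = gray.get_pixel(x, y)[0];
    if p < BRIGHTNESS_THRESHOLD { return false; }
    // A neighbor counts as a boundary when it is clearly darker than the stroke.
    let darker = |v: u8| p as i32 - v as i32 > CONTRAST_THRESHOLD;
    let left = (1..=MAX_STROKE_WIDTH).any(|k| x >= k && darker(gray.get_pixel(x - k, y)[0]));
    let right = (1..=MAX_STROKE_WIDTH).any(|k| x + k < width && darker(gray.get_pixel(x + k, y)[0]));
    left && right // dark-bright-dark: likely a glyph stroke, not a solid white area
}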

src-tauri/src/ocr.rs (new file)

@@ -0,0 +1,168 @@
use image::RgbaImage;
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Value;
use std::path::Path;

#[derive(Clone, Debug)]
pub struct DetectedBox {
    pub x: f64,
    pub y: f64,
    pub width: f64,
    pub height: f64,
}

pub fn run_ocr_detection(
    model_path: &Path,
    input_image: &RgbaImage,
) -> Result<Vec<DetectedBox>, String> {
    // 1. Load Model
    let mut session = Session::builder()
        .map_err(|e| format!("Failed to create session: {}", e))?
        .with_optimization_level(GraphOptimizationLevel::Level3)
        .map_err(|e| format!("Failed to set opt: {}", e))?
        .with_intra_threads(4)
        .map_err(|e| format!("Failed to set threads: {}", e))?
        .commit_from_file(model_path)
        .map_err(|e| format!("Failed to load OCR model: {}", e))?;

    // 2. Preprocess
    // DBNet expects standard normalization: (img - mean) / std
    // Mean: [0.485, 0.456, 0.406], Std: [0.229, 0.224, 0.225]
    // Sides must be multiples of 32; the longest side is capped for speed.
    let max_side = 1600;
    let (orig_w, orig_h) = input_image.dimensions();
    let mut resize_w = orig_w;
    let mut resize_h = orig_h;
    // Resize logic: limit the max side, preserve aspect ratio, ensure divisibility by 32
    if resize_w > max_side || resize_h > max_side {
        let ratio = max_side as f64 / (orig_w.max(orig_h) as f64);
        resize_w = (orig_w as f64 * ratio) as u32;
        resize_h = (orig_h as f64 * ratio) as u32;
    }
    // Align to 32 (round up)
    resize_w = (resize_w + 31) / 32 * 32;
    resize_h = (resize_h + 31) / 32 * 32;
    // Minimum size
    resize_w = resize_w.max(32);
    resize_h = resize_h.max(32);
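    // Worked example (hypothetical sizes): a 4000x3000 input scales by
    // 1600/4000 = 0.4 to 1600x1200, which aligns up to 1600x1216; a
    // 1500x1000 input is not downscaled and aligns up to 1504x1024.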
    let resized = image::imageops::resize(input_image, resize_w, resize_h, image::imageops::FilterType::Triangle);
    let channel_stride = (resize_w * resize_h) as usize;
    let mut input_data = Vec::with_capacity(3 * channel_stride);
    let mut r_plane = Vec::with_capacity(channel_stride);
    let mut g_plane = Vec::with_capacity(channel_stride);
    let mut b_plane = Vec::with_capacity(channel_stride);
    let mean = [0.485, 0.456, 0.406];
    let std = [0.229, 0.224, 0.225];
    for (_x, _y, pixel) in resized.enumerate_pixels() {
        let r = pixel[0] as f32 / 255.0;
        let g = pixel[1] as f32 / 255.0;
        let b = pixel[2] as f32 / 255.0;
        r_plane.push((r - mean[0]) / std[0]);
        g_plane.push((g - mean[1]) / std[1]);
        b_plane.push((b - mean[2]) / std[2]);
    }
    input_data.extend(r_plane);
    input_data.extend(g_plane);
    input_data.extend(b_plane);
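    // input_data is now planar NCHW: the full R plane, then G, then B,
    // matching the [1, 3, H, W] shape declared below.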
    // 3. Inference
    let input_shape = vec![1, 3, resize_h as i64, resize_w as i64];
    let input_value = Value::from_array((input_shape, input_data))
        .map_err(|e| format!("Failed to create input tensor: {}", e))?;
    let inputs = ort::inputs![input_value]; // For PP-OCR, usually just one input "x"
    let outputs = session.run(inputs).map_err(|e| format!("Inference failed: {}", e))?;
    let output_tensor = outputs.values().next().ok_or("No output")?;
    let (shape, data) = output_tensor.try_extract_tensor::<f32>()
        .map_err(|e| format!("Failed to extract output: {}", e))?;
    // 4. Post-process
    // Output shape is [1, 1, H, W] probability map
    if shape.len() < 4 {
        return Err("Unexpected output shape".to_string());
    }
    let map_w = shape[3] as u32;
    let map_h = shape[2] as u32;
    // Binarize the probability map. DBNet's usual threshold is 0.3; 0.2 is
    // used here to catch fainter text.
    let threshold = 0.2;
    let mut binary_map = vec![false; (map_w * map_h) as usize];
    for i in 0..binary_map.len() {
        if data[i] > threshold {
            binary_map[i] = true;
        }
    }
    // Find connected components (simple bounding-box extraction)
    let mut visited = vec![false; binary_map.len()];
    let mut boxes = Vec::new();
    for y in 0..map_h {
        for x in 0..map_w {
            let idx = (y * map_w + x) as usize;
            if binary_map[idx] && !visited[idx] {
                // Flood fill
                let mut stack = vec![(x, y)];
                visited[idx] = true;
                let mut min_x = x;
                let mut max_x = x;
                let mut min_y = y;
                let mut max_y = y;
                let mut pixel_count = 0;
                while let Some((cx, cy)) = stack.pop() {
                    pixel_count += 1;
                    if cx < min_x { min_x = cx; }
                    if cx > max_x { max_x = cx; }
                    if cy < min_y { min_y = cy; }
                    if cy > max_y { max_y = cy; }
                    let neighbors = [
                        (cx.wrapping_sub(1), cy), (cx + 1, cy),
                        (cx, cy.wrapping_sub(1)), (cx, cy + 1)
                    ];
                    for (nx, ny) in neighbors {
                        if nx < map_w && ny < map_h {
                            let nidx = (ny * map_w + nx) as usize;
                            if binary_map[nidx] && !visited[nidx] {
                                visited[nidx] = true;
                                stack.push((nx, ny));
                            }
                        }
                    }
                }
                // Filter small noise
                if pixel_count < 10 { continue; }
                // Scale map coordinates back to original pixels, then normalize
                let scale_x = orig_w as f64 / resize_w as f64;
                let scale_y = orig_h as f64 / resize_h as f64;
                // No brightness filter: every region DBNet reports is kept
                boxes.push(DetectedBox {
                    x: min_x as f64 * scale_x / orig_w as f64,
                    y: min_y as f64 * scale_y / orig_h as f64,
                    width: (max_x - min_x + 1) as f64 * scale_x / orig_w as f64,
                    height: (max_y - min_y + 1) as f64 * scale_y / orig_h as f64,
                });
            }
        }
    }
    Ok(boxes)
}
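
As a usage reference, a minimal sketch of driving this function outside the Tauri command (the file names are hypothetical; DetectedBox fields are fractions of the original image size):

fn main() -> Result<(), String> {
    let img = image::open("photo.jpg").map_err(|e| e.to_string())?.to_rgba8();
    let model = std::path::Path::new("en_PP-OCRv3_det_infer.onnx"); // local copy of the model
    let (w, h) = img.dimensions();
    for b in run_ocr_detection(model, &img)? {
        // Convert normalized box coordinates back to pixels for display
        println!("text region at ({:.0}, {:.0}), {:.0}x{:.0} px",
            b.x * w as f64, b.y * h as f64, b.width * w as f64, b.height * h as f64);
    }
    Ok(())
}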


@@ -28,7 +28,7 @@
   "bundle": {
     "active": true,
     "targets": "all",
-    "resources": ["resources/lama_fp32.onnx"],
+    "resources": ["resources/lama_fp32.onnx", "resources/en_PP-OCRv3_det_infer.onnx"],
     "icon": [
       "icons/32x32.png",
       "icons/128x128.png",