lama remove

Julian Freeman
2026-01-19 12:08:36 -04:00
parent 64d6e770ad
commit f96033e421
6 changed files with 529 additions and 74 deletions

src-tauri/src/lama.rs (new file, 139 lines added)

@@ -0,0 +1,139 @@
use image::{Rgba, RgbaImage};
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Value;
pub fn run_lama_inpainting(
model_path: &std::path::Path,
input_image: &RgbaImage,
mask_image: &image::GrayImage,
) -> Result<RgbaImage, String> {
// 1. Initialize Session
let mut session = Session::builder()
.map_err(|e| format!("Failed to create session builder: {}", e))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| format!("Failed to set opt level: {}", e))?
.with_intra_threads(4)
.map_err(|e| format!("Failed to set threads: {}", e))?
.commit_from_file(model_path)
.map_err(|e| format!("Failed to load model from {:?}: {}", model_path, e))?;
// 2. Preprocess
let target_size = (512, 512);
let resized_img = image::imageops::resize(input_image, target_size.0, target_size.1, image::imageops::FilterType::Triangle);
let resized_mask = image::imageops::resize(mask_image, target_size.0, target_size.1, image::imageops::FilterType::Triangle);
// Flatten Image to Vec<f32> (NCHW: 1, 3, 512, 512)
let channel_stride = (target_size.0 * target_size.1) as usize;
let mut input_data: Vec<f32> = Vec::with_capacity(3 * channel_stride);
// We need to fill R plane, then G plane, then B plane
let mut r_plane: Vec<f32> = Vec::with_capacity(channel_stride);
let mut g_plane: Vec<f32> = Vec::with_capacity(channel_stride);
let mut b_plane: Vec<f32> = Vec::with_capacity(channel_stride);
for (_x, _y, pixel) in resized_img.enumerate_pixels() {
r_plane.push(pixel[0] as f32 / 255.0f32);
g_plane.push(pixel[1] as f32 / 255.0f32);
b_plane.push(pixel[2] as f32 / 255.0f32);
}
input_data.extend(r_plane);
input_data.extend(g_plane);
input_data.extend(b_plane);
// Flatten Mask to Vec<f32> (NCHW: 1, 1, 512, 512)
let mut mask_data: Vec<f32> = Vec::with_capacity(channel_stride);
for (_x, _y, pixel) in resized_mask.enumerate_pixels() {
let val = if pixel[0] > 127 { 1.0f32 } else { 0.0f32 };
mask_data.push(val);
}
// 3. Inference
// Build each input from a (shape, data) tuple, which implements OwnedTensorArrayData.
// ONNX Runtime expects tensor dimensions as i64, hence the casts.
let input_shape = vec![1, 3, target_size.1 as i64, target_size.0 as i64];
let input_value = Value::from_array((input_shape, input_data))
.map_err(|e| format!("Failed to create input tensor: {}", e))?;
let mask_shape = vec![1, 1, target_size.1 as i64, target_size.0 as i64];
let mask_value = Value::from_array((mask_shape, mask_data))
.map_err(|e| format!("Failed to create mask tensor: {}", e))?;
let inputs = ort::inputs![
"image" => input_value,
"mask" => mask_value
];
let outputs = session.run(inputs).map_err(|e| format!("Inference failed: {}", e))?;
// Take the first output tensor produced by the model.
let output_tensor_ref = outputs.values().next()
.ok_or("No output tensor produced by model")?;
let (shape, data) = output_tensor_ref.try_extract_tensor::<f32>()
.map_err(|e| format!("Failed to extract tensor: {}", e))?;
// 4. Post-process
let mut output_img_512 = RgbaImage::new(target_size.0, target_size.1);
if shape.len() < 4 {
return Err(format!("Unexpected output shape: {:?}", shape));
}
let h = 512;
let w = 512;
let channel_stride = (h * w) as usize;
// Safety check on data length
if data.len() < (3 * h * w) as usize {
return Err(format!("Output data size mismatch. Expected {}, got {}", 3*h*w, data.len()));
}
// Auto-detect the output range: some LaMa exports emit 0..255 values, others 0..1 (or -1..1).
// Scaling an already-0..255 output by 255 would blow the image out to pure white.
let mut max_val = 0.0f32;
// Sample the first 1000 values; a full scan would be more robust, but this is usually sufficient.
for v in data.iter().take(1000) {
if *v > max_val { max_val = *v; }
}
// 0..1 outputs can overshoot slightly, so anything with a max above 2.0 is treated
// as already scaled to 0..255 and left unscaled.
let scale_factor = if max_val > 2.0 { 1.0 } else { 255.0 };
for y in 0..h {
for x in 0..w {
let offset = (y * w + x) as usize;
let r_idx = offset;
let g_idx = offset + channel_stride;
let b_idx = offset + 2 * channel_stride;
let r = (data[r_idx] * scale_factor).clamp(0.0, 255.0) as u8;
let g = (data[g_idx] * scale_factor).clamp(0.0, 255.0) as u8;
let b = (data[b_idx] * scale_factor).clamp(0.0, 255.0) as u8;
output_img_512.put_pixel(x, y, Rgba([r, g, b, 255]));
}
}
// Resize back to original
let (orig_w, orig_h) = input_image.dimensions();
let final_inpainted = image::imageops::resize(&output_img_512, orig_w, orig_h, image::imageops::FilterType::Lanczos3);
// 5. Blending
let mut result_image = input_image.clone();
for y in 0..orig_h {
for x in 0..orig_w {
if mask_image.get_pixel(x, y)[0] > 127 {
result_image.put_pixel(x, y, *final_inpainted.get_pixel(x, y));
}
}
}
Ok(result_image)
}
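For reference, a minimal standalone driver for this helper (handy for exercising the module outside the Tauri command below) might look like the following sketch; the file names and model location are placeholder assumptions, and the mask follows the convention used above (white = area to remove):

mod lama;

fn main() -> Result<(), String> {
    // Hypothetical input files; adjust paths as needed.
    let img = image::open("input.png").map_err(|e| e.to_string())?.to_rgba8();
    let mask = image::open("mask.png").map_err(|e| e.to_string())?.to_luma8();
    // The ONNX export of LaMa, as loaded by run_lama_inpainting.
    let model = std::path::Path::new("lama_fp32.onnx");
    let result = lama::run_lama_inpainting(model, &img, &mask)?;
    result.save("inpainted.png").map_err(|e| e.to_string())?;
    Ok(())
}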


@@ -92,6 +92,9 @@ use rayon::prelude::*;
use std::path::Path;
use imageproc::drawing::draw_text_mut;
use ab_glyph::{FontRef, PxScale};
+ use tauri::{AppHandle, Manager};
+ mod lama;
// Embed the font to ensure it's always available without path issues
const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf");
@@ -646,12 +649,12 @@ enum MaskStroke {
}
#[tauri::command]
- async fn run_inpainting(path: String, strokes: Vec<MaskStroke>) -> Result<String, String> {
+ async fn run_inpainting(app: AppHandle, path: String, strokes: Vec<MaskStroke>) -> Result<String, String> {
let img = image::open(&path).map_err(|e| e.to_string())?.to_rgba8();
let (width, height) = img.dimensions();
- // 1. Create Mask
- let mut mask = vec![false; (width * height) as usize];
+ // 1. Create Gray Mask (0 = keep, 255 = remove)
+ let mut mask = image::GrayImage::new(width, height);
for stroke in strokes {
match stroke {
@@ -661,10 +664,11 @@ async fn run_inpainting(path: String, strokes: Vec<MaskStroke>) -> Result<String
let w = (rect.w * width as f64) as i32;
let h = (rect.h * height as f64) as i32;
// Draw 255 on mask
for y in y1..(y1 + h) {
for x in x1..(x1 + w) {
if x >= 0 && x < width as i32 && y >= 0 && y < height as i32 {
- mask[(y as u32 * width + x as u32) as usize] = true;
+ mask.put_pixel(x as u32, y as u32, image::Luma([255]));
}
}
}
@@ -699,7 +703,7 @@ async fn run_inpainting(path: String, strokes: Vec<MaskStroke>) -> Result<String
let nx = cx + dx;
let ny = cy + dy;
if nx >= 0 && nx < width as i32 && ny >= 0 && ny < height as i32 {
- mask[(ny as u32 * width + nx as u32) as usize] = true;
+ mask.put_pixel(nx as u32, ny as u32, image::Luma([255]));
}
}
}
@@ -710,76 +714,26 @@ async fn run_inpainting(path: String, strokes: Vec<MaskStroke>) -> Result<String
}
}
- // 2. Diffusion Inpainting (Simple)
- // Iteratively replace masked pixels with average of non-masked neighbors
- // To make it converge, we update 'mask' as we go (treating filled pixels as valid source)
- // But standard diffusion uses double buffering.
- // For "Removing Text", simple inward filling works well.
- let iterations = 30;
- let mut current_img = img.clone();
- let mut next_img = img.clone();
- // Convert mask to a distance map-like state?
- // Or just simple neighbor average.
- for _ in 0..iterations {
- let mut changed = false;
- for y in 0..height {
- for x in 0..width {
- let idx = (y * width + x) as usize;
- if mask[idx] {
- // It's a hole. Find valid neighbors.
- // Valid = Not in ORIGINAL mask (so we pull from original image)
- // OR processed in previous iteration?
- // Simple logic: Pull from 'current_img'.
- let mut sum_r = 0u32;
- let mut sum_g = 0u32;
- let mut sum_b = 0u32;
- let mut count = 0;
- // Check 4 neighbors
- let neighbors = [
- (x.wrapping_sub(1), y), (x + 1, y),
- (x, y.wrapping_sub(1)), (x, y + 1)
- ];
- for (nx, ny) in neighbors {
- if nx < width && ny < height {
- // Weighted check: If neighbor is ALSO masked, it contributes less?
- // Or just take everything.
- let pixel = current_img.get_pixel(nx, ny);
- sum_r += pixel[0] as u32;
- sum_g += pixel[1] as u32;
- sum_b += pixel[2] as u32;
- count += 1;
- }
- }
- if count > 0 {
- let avg = image::Rgba([
- (sum_r / count) as u8,
- (sum_g / count) as u8,
- (sum_b / count) as u8,
- 255
- ]);
- next_img.put_pixel(x, y, avg);
- changed = true;
- }
- }
- }
- }
- current_img = next_img.clone();
- if !changed { break; }
+ // 2. Resolve Model Path
+ let model_path = app.path().resource_dir()
+ .map_err(|e| e.to_string())?
+ .join("resources")
+ .join("lama_fp32.onnx");
+ if !model_path.exists() {
+ return Err("Model file 'lama_fp32.onnx' not found in resources.".to_string());
+ }
+ // 3. Run Inference
+ // Computationally heavy, but Tauri async commands already run off the main thread.
+ let result_img = lama::run_lama_inpainting(&model_path, &img, &mask)?;
// Save to temp
let cache_dir = get_cache_dir();
let file_name = format!("inpainted_{}.png", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis());
let out_path = cache_dir.join(file_name);
- current_img.save(&out_path).map_err(|e| e.to_string())?;
+ result_img.save(&out_path).map_err(|e| e.to_string())?;
Ok(out_path.to_string_lossy().to_string())
}
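For the resource lookup in run_inpainting to find the model under <resource dir>/resources/lama_fp32.onnx, the file also has to be declared in the Tauri bundle configuration. A minimal sketch of what that tauri.conf.json fragment could look like, assuming a Tauri v2 layout with the model stored at src-tauri/resources/lama_fp32.onnx (the actual config change is not part of the hunks shown here):

{
  "bundle": {
    "resources": ["resources/lama_fp32.onnx"]
  }
}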