lama remove
This commit is contained in:
139
src-tauri/src/lama.rs
Normal file
139
src-tauri/src/lama.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use image::{Rgba, RgbaImage};
|
||||
use ort::session::{Session, builder::GraphOptimizationLevel};
|
||||
use ort::value::Value;
|
||||
|
||||
pub fn run_lama_inpainting(
|
||||
model_path: &std::path::Path,
|
||||
input_image: &RgbaImage,
|
||||
mask_image: &image::GrayImage,
|
||||
) -> Result<RgbaImage, String> {
|
||||
// 1. Initialize Session
|
||||
let mut session = Session::builder()
|
||||
.map_err(|e| format!("Failed to create session builder: {}", e))?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.map_err(|e| format!("Failed to set opt level: {}", e))?
|
||||
.with_intra_threads(4)
|
||||
.map_err(|e| format!("Failed to set threads: {}", e))?
|
||||
.commit_from_file(model_path)
|
||||
.map_err(|e| format!("Failed to load model from {:?}: {}", model_path, e))?;
|
||||
|
||||
// 2. Preprocess
|
||||
let target_size = (512, 512);
|
||||
|
||||
let resized_img = image::imageops::resize(input_image, target_size.0, target_size.1, image::imageops::FilterType::Triangle);
|
||||
let resized_mask = image::imageops::resize(mask_image, target_size.0, target_size.1, image::imageops::FilterType::Triangle);
|
||||
|
||||
// Flatten Image to Vec<f32> (NCHW: 1, 3, 512, 512)
|
||||
let channel_stride = (target_size.0 * target_size.1) as usize;
|
||||
let mut input_data: Vec<f32> = Vec::with_capacity(1 * 3 * channel_stride);
|
||||
|
||||
// We need to fill R plane, then G plane, then B plane
|
||||
let mut r_plane: Vec<f32> = Vec::with_capacity(channel_stride);
|
||||
let mut g_plane: Vec<f32> = Vec::with_capacity(channel_stride);
|
||||
let mut b_plane: Vec<f32> = Vec::with_capacity(channel_stride);
|
||||
|
||||
for (_x, _y, pixel) in resized_img.enumerate_pixels() {
|
||||
r_plane.push(pixel[0] as f32 / 255.0f32);
|
||||
g_plane.push(pixel[1] as f32 / 255.0f32);
|
||||
b_plane.push(pixel[2] as f32 / 255.0f32);
|
||||
}
|
||||
|
||||
input_data.extend(r_plane);
|
||||
input_data.extend(g_plane);
|
||||
input_data.extend(b_plane);
|
||||
|
||||
// Flatten Mask to Vec<f32> (NCHW: 1, 1, 512, 512)
|
||||
let mut mask_data: Vec<f32> = Vec::with_capacity(channel_stride);
|
||||
for (_x, _y, pixel) in resized_mask.enumerate_pixels() {
|
||||
let val = if pixel[0] > 127 { 1.0f32 } else { 0.0f32 };
|
||||
mask_data.push(val);
|
||||
}
|
||||
|
||||
// 3. Inference
|
||||
// Use (Shape, Data) tuple which implements OwnedTensorArrayData
|
||||
// Explicitly casting shape to i64 is correct for ORT
|
||||
let input_shape = vec![1, 3, target_size.1 as i64, target_size.0 as i64];
|
||||
let input_value = Value::from_array((input_shape, input_data))
|
||||
.map_err(|e| format!("Failed to create input tensor: {}", e))?;
|
||||
|
||||
let mask_shape = vec![1, 1, target_size.1 as i64, target_size.0 as i64];
|
||||
let mask_value = Value::from_array((mask_shape, mask_data))
|
||||
.map_err(|e| format!("Failed to create mask tensor: {}", e))?;
|
||||
|
||||
let inputs = ort::inputs![
|
||||
"image" => input_value,
|
||||
"mask" => mask_value
|
||||
];
|
||||
|
||||
let outputs = session.run(inputs).map_err(|e| format!("Inference failed: {}", e))?;
|
||||
|
||||
// Get output tensor
|
||||
// Just take the first output.
|
||||
let output_tensor_ref = outputs.values().next()
|
||||
.ok_or("No output tensor produced by model")?;
|
||||
|
||||
let (shape, data) = output_tensor_ref.try_extract_tensor::<f32>()
|
||||
.map_err(|e| format!("Failed to extract tensor: {}", e))?;
|
||||
|
||||
// 4. Post-process
|
||||
let mut output_img_512 = RgbaImage::new(target_size.0, target_size.1);
|
||||
|
||||
if shape.len() < 4 {
|
||||
return Err(format!("Unexpected output shape: {:?}", shape));
|
||||
}
|
||||
|
||||
let h = 512;
|
||||
let w = 512;
|
||||
let channel_stride = (h * w) as usize;
|
||||
|
||||
// Safety check on data length
|
||||
if data.len() < (3 * h * w) as usize {
|
||||
return Err(format!("Output data size mismatch. Expected {}, got {}", 3*h*w, data.len()));
|
||||
}
|
||||
|
||||
// Auto-detect output range
|
||||
// If values are already in 0-255 range, multiplying by 255 results in all white image.
|
||||
let mut max_val = 0.0f32;
|
||||
// Check a subset of pixels to avoid iterating everything if speed is key, but full scan is safer and fast enough.
|
||||
for v in data.iter().take(1000) {
|
||||
if *v > max_val { max_val = *v; }
|
||||
}
|
||||
|
||||
// Heuristic: if max > 2.0, it's likely 0-255. If it's <= 1.0 (or slightly above due to overshoot), it's 0-1.
|
||||
// LaMa usually outputs -1..1 or 0..1. But some exports differ.
|
||||
// Let's assume if any value is > 5.0, it is definitely not 0-1 normalized.
|
||||
let scale_factor = if max_val > 2.0 { 1.0 } else { 255.0 };
|
||||
|
||||
for y in 0..h {
|
||||
for x in 0..w {
|
||||
let offset = (y * w + x) as usize;
|
||||
|
||||
let r_idx = offset;
|
||||
let g_idx = offset + channel_stride;
|
||||
let b_idx = offset + 2 * channel_stride;
|
||||
|
||||
let r = (data[r_idx] * scale_factor).clamp(0.0, 255.0) as u8;
|
||||
let g = (data[g_idx] * scale_factor).clamp(0.0, 255.0) as u8;
|
||||
let b = (data[b_idx] * scale_factor).clamp(0.0, 255.0) as u8;
|
||||
|
||||
output_img_512.put_pixel(x, y, Rgba([r, g, b, 255]));
|
||||
}
|
||||
}
|
||||
|
||||
// Resize back to original
|
||||
let (orig_w, orig_h) = input_image.dimensions();
|
||||
let final_inpainted = image::imageops::resize(&output_img_512, orig_w, orig_h, image::imageops::FilterType::Lanczos3);
|
||||
|
||||
// 5. Blending
|
||||
let mut result_image = input_image.clone();
|
||||
|
||||
for y in 0..orig_h {
|
||||
for x in 0..orig_w {
|
||||
if mask_image.get_pixel(x, y)[0] > 127 {
|
||||
result_image.put_pixel(x, y, *final_inpainted.get_pixel(x, y));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result_image)
|
||||
}
|
||||
Reference in New Issue
Block a user