use crate::ort_ops; use image::{Rgba, RgbaImage}; use ort::value::Value; pub fn run_lama_inpainting( model_path: &std::path::Path, input_image: &RgbaImage, mask_image: &image::GrayImage, ) -> Result { // 1. Initialize Session let mut session = ort_ops::create_session(model_path) .map_err(|e| format!("Failed to create ORT session for LAMA: {}", e))?; // 2. Preprocess let target_size = (512, 512); let resized_img = image::imageops::resize(input_image, target_size.0, target_size.1, image::imageops::FilterType::Triangle); let resized_mask = image::imageops::resize(mask_image, target_size.0, target_size.1, image::imageops::FilterType::Triangle); // Flatten Image to Vec (NCHW: 1, 3, 512, 512) let channel_stride = (target_size.0 * target_size.1) as usize; let mut input_data: Vec = Vec::with_capacity(1 * 3 * channel_stride); // We need to fill R plane, then G plane, then B plane let mut r_plane: Vec = Vec::with_capacity(channel_stride); let mut g_plane: Vec = Vec::with_capacity(channel_stride); let mut b_plane: Vec = Vec::with_capacity(channel_stride); for (_x, _y, pixel) in resized_img.enumerate_pixels() { r_plane.push(pixel[0] as f32 / 255.0f32); g_plane.push(pixel[1] as f32 / 255.0f32); b_plane.push(pixel[2] as f32 / 255.0f32); } input_data.extend(r_plane); input_data.extend(g_plane); input_data.extend(b_plane); // Flatten Mask to Vec (NCHW: 1, 1, 512, 512) let mut mask_data: Vec = Vec::with_capacity(channel_stride); for (_x, _y, pixel) in resized_mask.enumerate_pixels() { let val = if pixel[0] > 127 { 1.0f32 } else { 0.0f32 }; mask_data.push(val); } // 3. Inference // Use (Shape, Data) tuple which implements OwnedTensorArrayData // Explicitly casting shape to i64 is correct for ORT let input_shape = vec![1, 3, target_size.1 as i64, target_size.0 as i64]; let input_value = Value::from_array((input_shape, input_data)) .map_err(|e| format!("Failed to create input tensor: {}", e))?; let mask_shape = vec![1, 1, target_size.1 as i64, target_size.0 as i64]; let mask_value = Value::from_array((mask_shape, mask_data)) .map_err(|e| format!("Failed to create mask tensor: {}", e))?; let inputs = ort::inputs![ "image" => input_value, "mask" => mask_value ]; let outputs = session.run(inputs).map_err(|e| format!("Inference failed: {}", e))?; // Get output tensor // Just take the first output. let output_tensor_ref = outputs.values().next() .ok_or("No output tensor produced by model")?; let (shape, data) = output_tensor_ref.try_extract_tensor::() .map_err(|e| format!("Failed to extract tensor: {}", e))?; // 4. Post-process let mut output_img_512 = RgbaImage::new(target_size.0, target_size.1); if shape.len() < 4 { return Err(format!("Unexpected output shape: {:?}", shape)); } let h = 512; let w = 512; let channel_stride = (h * w) as usize; // Safety check on data length if data.len() < (3 * h * w) as usize { return Err(format!("Output data size mismatch. Expected {}, got {}", 3*h*w, data.len())); } // Auto-detect output range // If values are already in 0-255 range, multiplying by 255 results in all white image. let mut max_val = 0.0f32; // Check a subset of pixels to avoid iterating everything if speed is key, but full scan is safer and fast enough. for v in data.iter().take(1000) { if *v > max_val { max_val = *v; } } // Heuristic: if max > 2.0, it's likely 0-255. If it's <= 1.0 (or slightly above due to overshoot), it's 0-1. // LaMa usually outputs -1..1 or 0..1. But some exports differ. // Let's assume if any value is > 5.0, it is definitely not 0-1 normalized. let scale_factor = if max_val > 2.0 { 1.0 } else { 255.0 }; for y in 0..h { for x in 0..w { let offset = (y * w + x) as usize; let r_idx = offset; let g_idx = offset + channel_stride; let b_idx = offset + 2 * channel_stride; let r = (data[r_idx] * scale_factor).clamp(0.0, 255.0) as u8; let g = (data[g_idx] * scale_factor).clamp(0.0, 255.0) as u8; let b = (data[b_idx] * scale_factor).clamp(0.0, 255.0) as u8; output_img_512.put_pixel(x, y, Rgba([r, g, b, 255])); } } // Resize back to original let (orig_w, orig_h) = input_image.dimensions(); let final_inpainted = image::imageops::resize(&output_img_512, orig_w, orig_h, image::imageops::FilterType::Lanczos3); // 5. Blending let mut result_image = input_image.clone(); for y in 0..orig_h { for x in 0..orig_w { if mask_image.get_pixel(x, y)[0] > 127 { result_image.put_pixel(x, y, *final_inpainted.get_pixel(x, y)); } } } Ok(result_image) }