From f96033e42103285982a8561e5f75dbe6f0a426e3 Mon Sep 17 00:00:00 2001 From: Julian Freeman Date: Mon, 19 Jan 2026 12:08:36 -0400 Subject: [PATCH] lama remove --- .gitignore | 3 + src-tauri/Cargo.lock | 368 +++++++++++++++++++++++++++++++++++++- src-tauri/Cargo.toml | 2 + src-tauri/src/lama.rs | 139 ++++++++++++++ src-tauri/src/lib.rs | 90 +++------- src-tauri/tauri.conf.json | 1 + 6 files changed, 529 insertions(+), 74 deletions(-) create mode 100644 src-tauri/src/lama.rs diff --git a/.gitignore b/.gitignore index a547bf3..d098a42 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ dist-ssr *.njsproj *.sln *.sw? + +*.onnx +*.dll \ No newline at end of file diff --git a/src-tauri/Cargo.lock b/src-tauri/Cargo.lock index 35c8410..cb20223 100644 --- a/src-tauri/Cargo.lock +++ b/src-tauri/Cargo.lock @@ -212,6 +212,12 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" + [[package]] name = "bit_field" version = "0.10.3" @@ -476,6 +482,16 @@ dependencies = [ "version_check", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation" version = "0.10.1" @@ -499,9 +515,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1" dependencies = [ "bitflags 2.10.0", - "core-foundation", + "core-foundation 0.10.1", "core-graphics-types", - "foreign-types", + "foreign-types 0.5.0", "libc", ] @@ -512,7 +528,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb" dependencies = [ "bitflags 2.10.0", - "core-foundation", + "core-foundation 0.10.1", "libc", ] @@ -665,6 +681,16 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.5.5" @@ -870,6 +896,16 @@ dependencies = [ "typeid", ] +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "exr" version = "1.74.0" @@ -885,6 +921,12 @@ dependencies = [ "zune-inflate", ] +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "fax" version = "0.2.6" @@ -946,6 +988,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared 0.1.1", +] + [[package]] name = "foreign-types" version = "0.5.0" @@ -953,7 +1004,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" dependencies = [ "foreign-types-macros", - "foreign-types-shared", + "foreign-types-shared 0.3.1", ] [[package]] @@ -967,6 +1018,12 @@ dependencies = [ "syn 2.0.114", ] +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "foreign-types-shared" version = "0.3.1" @@ -1417,6 +1474,12 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hmac-sha256" +version = "1.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad6880c8d4a9ebf39c6e8b77007ce223f646a4d21ce29d99f70cb16420545425" + [[package]] name = "html5ever" version = "0.29.1" @@ -1944,7 +2007,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e9ec52138abedcc58dc17a7c6c0c00a2bdb4f3427c7f63fa97fd0d859155caf" dependencies = [ "gtk-sys", - "libloading", + "libloading 0.7.4", "once_cell", ] @@ -1974,6 +2037,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "libloading" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60" +dependencies = [ + "cfg-if", + "windows-link 0.2.1", +] + [[package]] name = "libm" version = "0.2.15" @@ -1990,6 +2063,12 @@ dependencies = [ "libc", ] +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "litemap" version = "0.8.1" @@ -2020,6 +2099,12 @@ dependencies = [ "imgref", ] +[[package]] +name = "lzma-rust2" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1670343e58806300d87950e3401e820b519b9384281bbabfb15e3636689ffd69" + [[package]] name = "mac" version = "0.1.1" @@ -2165,6 +2250,53 @@ dependencies = [ "typenum", ] +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "ndk" version = "0.9.0" @@ -2553,12 +2685,81 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags 2.10.0", + "cfg-if", + "foreign-types 0.3.2", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "option-ext" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ort" +version = "2.0.0-rc.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5df903c0d2c07b56950f1058104ab0c8557159f2741782223704de9be73c3c" +dependencies = [ + "libloading 0.9.0", + "ndarray 0.17.2", + "ort-sys", + "smallvec", + "tracing", + "ureq", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06503bb33f294c5f1ba484011e053bfa6ae227074bdb841e9863492dc5960d4b" +dependencies = [ + "hmac-sha256", + "lzma-rust2", + "ureq", +] + [[package]] name = "owned_ttf_parser" version = "0.25.1" @@ -2628,6 +2829,15 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec" +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -2825,6 +3035,21 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "portable-atomic" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950" + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.4" @@ -3335,6 +3560,28 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags 2.10.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "zeroize", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -3365,6 +3612,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + "windows-sys 0.61.2", +] + [[package]] name = "schemars" version = "0.8.22" @@ -3422,6 +3678,29 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.10.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "selectors" version = "0.24.0" @@ -3699,6 +3978,17 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "softbuffer" version = "0.4.8" @@ -3858,7 +4148,7 @@ checksum = "f3a753bdc39c07b192151523a3f77cd0394aa75413802c883a0f6f6a0e5ee2e7" dependencies = [ "bitflags 2.10.0", "block2", - "core-foundation", + "core-foundation 0.10.1", "core-graphics", "crossbeam-channel", "dispatch", @@ -4180,6 +4470,19 @@ dependencies = [ "toml 0.9.11+spec-1.1.0", ] +[[package]] +name = "tempfile" +version = "3.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "tendril" version = "0.4.3" @@ -4572,6 +4875,36 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +[[package]] +name = "ureq" +version = "3.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a" +dependencies = [ + "base64 0.22.1", + "der", + "log", + "native-tls", + "percent-encoding", + "rustls-pki-types", + "socks", + "ureq-proto", + "utf-8", + "webpki-root-certs", +] + +[[package]] +name = "ureq-proto" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" +dependencies = [ + "base64 0.22.1", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.8" @@ -4632,6 +4965,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "version-compare" version = "0.2.1" @@ -4783,6 +5122,8 @@ dependencies = [ "ab_glyph", "image", "imageproc", + "ndarray 0.16.1", + "ort", "rayon", "serde", "serde_json", @@ -4845,6 +5186,15 @@ dependencies = [ "system-deps", ] +[[package]] +name = "webpki-root-certs" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36a29fc0408b113f68cf32637857ab740edfafdf460c326cd2afaa2d84cc05dc" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webview2-com" version = "0.38.2" @@ -5498,6 +5848,12 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + [[package]] name = "zerotrie" version = "0.2.3" diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml index 96aa39c..58926a5 100644 --- a/src-tauri/Cargo.toml +++ b/src-tauri/Cargo.toml @@ -26,4 +26,6 @@ rayon = "1.10" tauri-plugin-dialog = "2" imageproc = "0.25" ab_glyph = "0.2.23" +ort = { version = "=2.0.0-rc.11", features = ["load-dynamic"] } +ndarray = "0.16" diff --git a/src-tauri/src/lama.rs b/src-tauri/src/lama.rs new file mode 100644 index 0000000..716b91a --- /dev/null +++ b/src-tauri/src/lama.rs @@ -0,0 +1,139 @@ +use image::{Rgba, RgbaImage}; +use ort::session::{Session, builder::GraphOptimizationLevel}; +use ort::value::Value; + +pub fn run_lama_inpainting( + model_path: &std::path::Path, + input_image: &RgbaImage, + mask_image: &image::GrayImage, +) -> Result { + // 1. Initialize Session + let mut session = Session::builder() + .map_err(|e| format!("Failed to create session builder: {}", e))? + .with_optimization_level(GraphOptimizationLevel::Level3) + .map_err(|e| format!("Failed to set opt level: {}", e))? + .with_intra_threads(4) + .map_err(|e| format!("Failed to set threads: {}", e))? + .commit_from_file(model_path) + .map_err(|e| format!("Failed to load model from {:?}: {}", model_path, e))?; + + // 2. Preprocess + let target_size = (512, 512); + + let resized_img = image::imageops::resize(input_image, target_size.0, target_size.1, image::imageops::FilterType::Triangle); + let resized_mask = image::imageops::resize(mask_image, target_size.0, target_size.1, image::imageops::FilterType::Triangle); + + // Flatten Image to Vec (NCHW: 1, 3, 512, 512) + let channel_stride = (target_size.0 * target_size.1) as usize; + let mut input_data: Vec = Vec::with_capacity(1 * 3 * channel_stride); + + // We need to fill R plane, then G plane, then B plane + let mut r_plane: Vec = Vec::with_capacity(channel_stride); + let mut g_plane: Vec = Vec::with_capacity(channel_stride); + let mut b_plane: Vec = Vec::with_capacity(channel_stride); + + for (_x, _y, pixel) in resized_img.enumerate_pixels() { + r_plane.push(pixel[0] as f32 / 255.0f32); + g_plane.push(pixel[1] as f32 / 255.0f32); + b_plane.push(pixel[2] as f32 / 255.0f32); + } + + input_data.extend(r_plane); + input_data.extend(g_plane); + input_data.extend(b_plane); + + // Flatten Mask to Vec (NCHW: 1, 1, 512, 512) + let mut mask_data: Vec = Vec::with_capacity(channel_stride); + for (_x, _y, pixel) in resized_mask.enumerate_pixels() { + let val = if pixel[0] > 127 { 1.0f32 } else { 0.0f32 }; + mask_data.push(val); + } + + // 3. Inference + // Use (Shape, Data) tuple which implements OwnedTensorArrayData + // Explicitly casting shape to i64 is correct for ORT + let input_shape = vec![1, 3, target_size.1 as i64, target_size.0 as i64]; + let input_value = Value::from_array((input_shape, input_data)) + .map_err(|e| format!("Failed to create input tensor: {}", e))?; + + let mask_shape = vec![1, 1, target_size.1 as i64, target_size.0 as i64]; + let mask_value = Value::from_array((mask_shape, mask_data)) + .map_err(|e| format!("Failed to create mask tensor: {}", e))?; + + let inputs = ort::inputs![ + "image" => input_value, + "mask" => mask_value + ]; + + let outputs = session.run(inputs).map_err(|e| format!("Inference failed: {}", e))?; + + // Get output tensor + // Just take the first output. + let output_tensor_ref = outputs.values().next() + .ok_or("No output tensor produced by model")?; + + let (shape, data) = output_tensor_ref.try_extract_tensor::() + .map_err(|e| format!("Failed to extract tensor: {}", e))?; + + // 4. Post-process + let mut output_img_512 = RgbaImage::new(target_size.0, target_size.1); + + if shape.len() < 4 { + return Err(format!("Unexpected output shape: {:?}", shape)); + } + + let h = 512; + let w = 512; + let channel_stride = (h * w) as usize; + + // Safety check on data length + if data.len() < (3 * h * w) as usize { + return Err(format!("Output data size mismatch. Expected {}, got {}", 3*h*w, data.len())); + } + + // Auto-detect output range + // If values are already in 0-255 range, multiplying by 255 results in all white image. + let mut max_val = 0.0f32; + // Check a subset of pixels to avoid iterating everything if speed is key, but full scan is safer and fast enough. + for v in data.iter().take(1000) { + if *v > max_val { max_val = *v; } + } + + // Heuristic: if max > 2.0, it's likely 0-255. If it's <= 1.0 (or slightly above due to overshoot), it's 0-1. + // LaMa usually outputs -1..1 or 0..1. But some exports differ. + // Let's assume if any value is > 5.0, it is definitely not 0-1 normalized. + let scale_factor = if max_val > 2.0 { 1.0 } else { 255.0 }; + + for y in 0..h { + for x in 0..w { + let offset = (y * w + x) as usize; + + let r_idx = offset; + let g_idx = offset + channel_stride; + let b_idx = offset + 2 * channel_stride; + + let r = (data[r_idx] * scale_factor).clamp(0.0, 255.0) as u8; + let g = (data[g_idx] * scale_factor).clamp(0.0, 255.0) as u8; + let b = (data[b_idx] * scale_factor).clamp(0.0, 255.0) as u8; + + output_img_512.put_pixel(x, y, Rgba([r, g, b, 255])); + } + } + + // Resize back to original + let (orig_w, orig_h) = input_image.dimensions(); + let final_inpainted = image::imageops::resize(&output_img_512, orig_w, orig_h, image::imageops::FilterType::Lanczos3); + + // 5. Blending + let mut result_image = input_image.clone(); + + for y in 0..orig_h { + for x in 0..orig_w { + if mask_image.get_pixel(x, y)[0] > 127 { + result_image.put_pixel(x, y, *final_inpainted.get_pixel(x, y)); + } + } + } + + Ok(result_image) +} diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 4e8ce44..1f8b7d1 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -92,6 +92,9 @@ use rayon::prelude::*; use std::path::Path; use imageproc::drawing::draw_text_mut; use ab_glyph::{FontRef, PxScale}; +use tauri::{AppHandle, Manager}; + +mod lama; // Embed the font to ensure it's always available without path issues const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf"); @@ -646,12 +649,12 @@ enum MaskStroke { } #[tauri::command] -async fn run_inpainting(path: String, strokes: Vec) -> Result { +async fn run_inpainting(app: AppHandle, path: String, strokes: Vec) -> Result { let img = image::open(&path).map_err(|e| e.to_string())?.to_rgba8(); let (width, height) = img.dimensions(); - // 1. Create Mask - let mut mask = vec![false; (width * height) as usize]; + // 1. Create Gray Mask (0 = keep, 255 = remove) + let mut mask = image::GrayImage::new(width, height); for stroke in strokes { match stroke { @@ -661,10 +664,11 @@ async fn run_inpainting(path: String, strokes: Vec) -> Result= 0 && x < width as i32 && y >= 0 && y < height as i32 { - mask[(y as u32 * width + x as u32) as usize] = true; + mask.put_pixel(x as u32, y as u32, image::Luma([255])); } } } @@ -699,7 +703,7 @@ async fn run_inpainting(path: String, strokes: Vec) -> Result= 0 && nx < width as i32 && ny >= 0 && ny < height as i32 { - mask[(ny as u32 * width + nx as u32) as usize] = true; + mask.put_pixel(nx as u32, ny as u32, image::Luma([255])); } } } @@ -710,76 +714,26 @@ async fn run_inpainting(path: String, strokes: Vec) -> Result 0 { - let avg = image::Rgba([ - (sum_r / count) as u8, - (sum_g / count) as u8, - (sum_b / count) as u8, - 255 - ]); - next_img.put_pixel(x, y, avg); - changed = true; - } - } - } - } - current_img = next_img.clone(); - if !changed { break; } + // 2. Resolve Model Path + let model_path = app.path().resource_dir() + .map_err(|e| e.to_string())? + .join("resources") + .join("lama_fp32.onnx"); + + if !model_path.exists() { + return Err("Model file 'lama_fp32.onnx' not found in resources.".to_string()); } + + // 3. Run Inference + // This is computationally heavy, maybe run in thread? Tauri async commands are already threaded. + let result_img = lama::run_lama_inpainting(&model_path, &img, &mask)?; // Save to temp let cache_dir = get_cache_dir(); let file_name = format!("inpainted_{}.png", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis()); let out_path = cache_dir.join(file_name); - current_img.save(&out_path).map_err(|e| e.to_string())?; + result_img.save(&out_path).map_err(|e| e.to_string())?; Ok(out_path.to_string_lossy().to_string()) } diff --git a/src-tauri/tauri.conf.json b/src-tauri/tauri.conf.json index ac7820b..3662d18 100644 --- a/src-tauri/tauri.conf.json +++ b/src-tauri/tauri.conf.json @@ -28,6 +28,7 @@ "bundle": { "active": true, "targets": "all", + "resources": ["resources/lama_fp32.onnx"], "icon": [ "icons/32x32.png", "icons/128x128.png",