lama remove

This commit is contained in:
Julian Freeman
2026-01-19 12:08:36 -04:00
parent 64d6e770ad
commit f96033e421
6 changed files with 529 additions and 74 deletions

3
.gitignore vendored
View File

@@ -22,3 +22,6 @@ dist-ssr
*.njsproj *.njsproj
*.sln *.sln
*.sw? *.sw?
*.onnx
*.dll

368
src-tauri/Cargo.lock generated
View File

@@ -212,6 +212,12 @@ version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "base64ct"
version = "1.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06"
[[package]] [[package]]
name = "bit_field" name = "bit_field"
version = "0.10.3" version = "0.10.3"
@@ -476,6 +482,16 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "core-foundation" name = "core-foundation"
version = "0.10.1" version = "0.10.1"
@@ -499,9 +515,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1" checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1"
dependencies = [ dependencies = [
"bitflags 2.10.0", "bitflags 2.10.0",
"core-foundation", "core-foundation 0.10.1",
"core-graphics-types", "core-graphics-types",
"foreign-types", "foreign-types 0.5.0",
"libc", "libc",
] ]
@@ -512,7 +528,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb" checksum = "3d44a101f213f6c4cdc1853d4b78aef6db6bdfa3468798cc1d9912f4735013eb"
dependencies = [ dependencies = [
"bitflags 2.10.0", "bitflags 2.10.0",
"core-foundation", "core-foundation 0.10.1",
"libc", "libc",
] ]
@@ -665,6 +681,16 @@ dependencies = [
"syn 2.0.114", "syn 2.0.114",
] ]
[[package]]
name = "der"
version = "0.7.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb"
dependencies = [
"pem-rfc7468",
"zeroize",
]
[[package]] [[package]]
name = "deranged" name = "deranged"
version = "0.5.5" version = "0.5.5"
@@ -870,6 +896,16 @@ dependencies = [
"typeid", "typeid",
] ]
[[package]]
name = "errno"
version = "0.3.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [
"libc",
"windows-sys 0.61.2",
]
[[package]] [[package]]
name = "exr" name = "exr"
version = "1.74.0" version = "1.74.0"
@@ -885,6 +921,12 @@ dependencies = [
"zune-inflate", "zune-inflate",
] ]
[[package]]
name = "fastrand"
version = "2.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
[[package]] [[package]]
name = "fax" name = "fax"
version = "0.2.6" version = "0.2.6"
@@ -946,6 +988,15 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foreign-types"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
dependencies = [
"foreign-types-shared 0.1.1",
]
[[package]] [[package]]
name = "foreign-types" name = "foreign-types"
version = "0.5.0" version = "0.5.0"
@@ -953,7 +1004,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965"
dependencies = [ dependencies = [
"foreign-types-macros", "foreign-types-macros",
"foreign-types-shared", "foreign-types-shared 0.3.1",
] ]
[[package]] [[package]]
@@ -967,6 +1018,12 @@ dependencies = [
"syn 2.0.114", "syn 2.0.114",
] ]
[[package]]
name = "foreign-types-shared"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
[[package]] [[package]]
name = "foreign-types-shared" name = "foreign-types-shared"
version = "0.3.1" version = "0.3.1"
@@ -1417,6 +1474,12 @@ version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "hmac-sha256"
version = "1.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad6880c8d4a9ebf39c6e8b77007ce223f646a4d21ce29d99f70cb16420545425"
[[package]] [[package]]
name = "html5ever" name = "html5ever"
version = "0.29.1" version = "0.29.1"
@@ -1944,7 +2007,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e9ec52138abedcc58dc17a7c6c0c00a2bdb4f3427c7f63fa97fd0d859155caf" checksum = "6e9ec52138abedcc58dc17a7c6c0c00a2bdb4f3427c7f63fa97fd0d859155caf"
dependencies = [ dependencies = [
"gtk-sys", "gtk-sys",
"libloading", "libloading 0.7.4",
"once_cell", "once_cell",
] ]
@@ -1974,6 +2037,16 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "libloading"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "754ca22de805bb5744484a5b151a9e1a8e837d5dc232c2d7d8c2e3492edc8b60"
dependencies = [
"cfg-if",
"windows-link 0.2.1",
]
[[package]] [[package]]
name = "libm" name = "libm"
version = "0.2.15" version = "0.2.15"
@@ -1990,6 +2063,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "linux-raw-sys"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039"
[[package]] [[package]]
name = "litemap" name = "litemap"
version = "0.8.1" version = "0.8.1"
@@ -2020,6 +2099,12 @@ dependencies = [
"imgref", "imgref",
] ]
[[package]]
name = "lzma-rust2"
version = "0.15.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1670343e58806300d87950e3401e820b519b9384281bbabfb15e3636689ffd69"
[[package]] [[package]]
name = "mac" name = "mac"
version = "0.1.1" version = "0.1.1"
@@ -2165,6 +2250,53 @@ dependencies = [
"typenum", "typenum",
] ]
[[package]]
name = "native-tls"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e"
dependencies = [
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
name = "ndarray"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer",
]
[[package]]
name = "ndarray"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"portable-atomic",
"portable-atomic-util",
"rawpointer",
]
[[package]] [[package]]
name = "ndk" name = "ndk"
version = "0.9.0" version = "0.9.0"
@@ -2553,12 +2685,81 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "openssl"
version = "0.10.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
dependencies = [
"bitflags 2.10.0",
"cfg-if",
"foreign-types 0.3.2",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.114",
]
[[package]]
name = "openssl-probe"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "openssl-sys"
version = "0.9.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [
"cc",
"libc",
"pkg-config",
"vcpkg",
]
[[package]] [[package]]
name = "option-ext" name = "option-ext"
version = "0.2.0" version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "ort"
version = "2.0.0-rc.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5df903c0d2c07b56950f1058104ab0c8557159f2741782223704de9be73c3c"
dependencies = [
"libloading 0.9.0",
"ndarray 0.17.2",
"ort-sys",
"smallvec",
"tracing",
"ureq",
]
[[package]]
name = "ort-sys"
version = "2.0.0-rc.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06503bb33f294c5f1ba484011e053bfa6ae227074bdb841e9863492dc5960d4b"
dependencies = [
"hmac-sha256",
"lzma-rust2",
"ureq",
]
[[package]] [[package]]
name = "owned_ttf_parser" name = "owned_ttf_parser"
version = "0.25.1" version = "0.25.1"
@@ -2628,6 +2829,15 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec" checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec"
[[package]]
name = "pem-rfc7468"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412"
dependencies = [
"base64ct",
]
[[package]] [[package]]
name = "percent-encoding" name = "percent-encoding"
version = "2.3.2" version = "2.3.2"
@@ -2825,6 +3035,21 @@ dependencies = [
"miniz_oxide", "miniz_oxide",
] ]
[[package]]
name = "portable-atomic"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f89776e4d69bb58bc6993e99ffa1d11f228b839984854c7daeb5d37f87cbe950"
[[package]]
name = "portable-atomic-util"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507"
dependencies = [
"portable-atomic",
]
[[package]] [[package]]
name = "potential_utf" name = "potential_utf"
version = "0.1.4" version = "0.1.4"
@@ -3335,6 +3560,28 @@ dependencies = [
"semver", "semver",
] ]
[[package]]
name = "rustix"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34"
dependencies = [
"bitflags 2.10.0",
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.61.2",
]
[[package]]
name = "rustls-pki-types"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd"
dependencies = [
"zeroize",
]
[[package]] [[package]]
name = "rustversion" name = "rustversion"
version = "1.0.22" version = "1.0.22"
@@ -3365,6 +3612,15 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "schannel"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1"
dependencies = [
"windows-sys 0.61.2",
]
[[package]] [[package]]
name = "schemars" name = "schemars"
version = "0.8.22" version = "0.8.22"
@@ -3422,6 +3678,29 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "security-framework"
version = "2.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
dependencies = [
"bitflags 2.10.0",
"core-foundation 0.9.4",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "selectors" name = "selectors"
version = "0.24.0" version = "0.24.0"
@@ -3699,6 +3978,17 @@ dependencies = [
"windows-sys 0.60.2", "windows-sys 0.60.2",
] ]
[[package]]
name = "socks"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b"
dependencies = [
"byteorder",
"libc",
"winapi",
]
[[package]] [[package]]
name = "softbuffer" name = "softbuffer"
version = "0.4.8" version = "0.4.8"
@@ -3858,7 +4148,7 @@ checksum = "f3a753bdc39c07b192151523a3f77cd0394aa75413802c883a0f6f6a0e5ee2e7"
dependencies = [ dependencies = [
"bitflags 2.10.0", "bitflags 2.10.0",
"block2", "block2",
"core-foundation", "core-foundation 0.10.1",
"core-graphics", "core-graphics",
"crossbeam-channel", "crossbeam-channel",
"dispatch", "dispatch",
@@ -4180,6 +4470,19 @@ dependencies = [
"toml 0.9.11+spec-1.1.0", "toml 0.9.11+spec-1.1.0",
] ]
[[package]]
name = "tempfile"
version = "3.24.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c"
dependencies = [
"fastrand",
"getrandom 0.3.4",
"once_cell",
"rustix",
"windows-sys 0.61.2",
]
[[package]] [[package]]
name = "tendril" name = "tendril"
version = "0.4.3" version = "0.4.3"
@@ -4572,6 +4875,36 @@ version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493"
[[package]]
name = "ureq"
version = "3.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a"
dependencies = [
"base64 0.22.1",
"der",
"log",
"native-tls",
"percent-encoding",
"rustls-pki-types",
"socks",
"ureq-proto",
"utf-8",
"webpki-root-certs",
]
[[package]]
name = "ureq-proto"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f"
dependencies = [
"base64 0.22.1",
"http",
"httparse",
"log",
]
[[package]] [[package]]
name = "url" name = "url"
version = "2.5.8" version = "2.5.8"
@@ -4632,6 +4965,12 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "vcpkg"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
[[package]] [[package]]
name = "version-compare" name = "version-compare"
version = "0.2.1" version = "0.2.1"
@@ -4783,6 +5122,8 @@ dependencies = [
"ab_glyph", "ab_glyph",
"image", "image",
"imageproc", "imageproc",
"ndarray 0.16.1",
"ort",
"rayon", "rayon",
"serde", "serde",
"serde_json", "serde_json",
@@ -4845,6 +5186,15 @@ dependencies = [
"system-deps", "system-deps",
] ]
[[package]]
name = "webpki-root-certs"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36a29fc0408b113f68cf32637857ab740edfafdf460c326cd2afaa2d84cc05dc"
dependencies = [
"rustls-pki-types",
]
[[package]] [[package]]
name = "webview2-com" name = "webview2-com"
version = "0.38.2" version = "0.38.2"
@@ -5498,6 +5848,12 @@ dependencies = [
"synstructure", "synstructure",
] ]
[[package]]
name = "zeroize"
version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0"
[[package]] [[package]]
name = "zerotrie" name = "zerotrie"
version = "0.2.3" version = "0.2.3"

View File

@@ -26,4 +26,6 @@ rayon = "1.10"
tauri-plugin-dialog = "2" tauri-plugin-dialog = "2"
imageproc = "0.25" imageproc = "0.25"
ab_glyph = "0.2.23" ab_glyph = "0.2.23"
ort = { version = "=2.0.0-rc.11", features = ["load-dynamic"] }
ndarray = "0.16"

139
src-tauri/src/lama.rs Normal file
View File

@@ -0,0 +1,139 @@
use image::{Rgba, RgbaImage};
use ort::session::{Session, builder::GraphOptimizationLevel};
use ort::value::Value;
pub fn run_lama_inpainting(
model_path: &std::path::Path,
input_image: &RgbaImage,
mask_image: &image::GrayImage,
) -> Result<RgbaImage, String> {
// 1. Initialize Session
let mut session = Session::builder()
.map_err(|e| format!("Failed to create session builder: {}", e))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.map_err(|e| format!("Failed to set opt level: {}", e))?
.with_intra_threads(4)
.map_err(|e| format!("Failed to set threads: {}", e))?
.commit_from_file(model_path)
.map_err(|e| format!("Failed to load model from {:?}: {}", model_path, e))?;
// 2. Preprocess
let target_size = (512, 512);
let resized_img = image::imageops::resize(input_image, target_size.0, target_size.1, image::imageops::FilterType::Triangle);
let resized_mask = image::imageops::resize(mask_image, target_size.0, target_size.1, image::imageops::FilterType::Triangle);
// Flatten Image to Vec<f32> (NCHW: 1, 3, 512, 512)
let channel_stride = (target_size.0 * target_size.1) as usize;
let mut input_data: Vec<f32> = Vec::with_capacity(1 * 3 * channel_stride);
// We need to fill R plane, then G plane, then B plane
let mut r_plane: Vec<f32> = Vec::with_capacity(channel_stride);
let mut g_plane: Vec<f32> = Vec::with_capacity(channel_stride);
let mut b_plane: Vec<f32> = Vec::with_capacity(channel_stride);
for (_x, _y, pixel) in resized_img.enumerate_pixels() {
r_plane.push(pixel[0] as f32 / 255.0f32);
g_plane.push(pixel[1] as f32 / 255.0f32);
b_plane.push(pixel[2] as f32 / 255.0f32);
}
input_data.extend(r_plane);
input_data.extend(g_plane);
input_data.extend(b_plane);
// Flatten Mask to Vec<f32> (NCHW: 1, 1, 512, 512)
let mut mask_data: Vec<f32> = Vec::with_capacity(channel_stride);
for (_x, _y, pixel) in resized_mask.enumerate_pixels() {
let val = if pixel[0] > 127 { 1.0f32 } else { 0.0f32 };
mask_data.push(val);
}
// 3. Inference
// Use (Shape, Data) tuple which implements OwnedTensorArrayData
// Explicitly casting shape to i64 is correct for ORT
let input_shape = vec![1, 3, target_size.1 as i64, target_size.0 as i64];
let input_value = Value::from_array((input_shape, input_data))
.map_err(|e| format!("Failed to create input tensor: {}", e))?;
let mask_shape = vec![1, 1, target_size.1 as i64, target_size.0 as i64];
let mask_value = Value::from_array((mask_shape, mask_data))
.map_err(|e| format!("Failed to create mask tensor: {}", e))?;
let inputs = ort::inputs![
"image" => input_value,
"mask" => mask_value
];
let outputs = session.run(inputs).map_err(|e| format!("Inference failed: {}", e))?;
// Get output tensor
// Just take the first output.
let output_tensor_ref = outputs.values().next()
.ok_or("No output tensor produced by model")?;
let (shape, data) = output_tensor_ref.try_extract_tensor::<f32>()
.map_err(|e| format!("Failed to extract tensor: {}", e))?;
// 4. Post-process
let mut output_img_512 = RgbaImage::new(target_size.0, target_size.1);
if shape.len() < 4 {
return Err(format!("Unexpected output shape: {:?}", shape));
}
let h = 512;
let w = 512;
let channel_stride = (h * w) as usize;
// Safety check on data length
if data.len() < (3 * h * w) as usize {
return Err(format!("Output data size mismatch. Expected {}, got {}", 3*h*w, data.len()));
}
// Auto-detect output range
// If values are already in 0-255 range, multiplying by 255 results in all white image.
let mut max_val = 0.0f32;
// Check a subset of pixels to avoid iterating everything if speed is key, but full scan is safer and fast enough.
for v in data.iter().take(1000) {
if *v > max_val { max_val = *v; }
}
// Heuristic: if max > 2.0, it's likely 0-255. If it's <= 1.0 (or slightly above due to overshoot), it's 0-1.
// LaMa usually outputs -1..1 or 0..1. But some exports differ.
// Let's assume if any value is > 5.0, it is definitely not 0-1 normalized.
let scale_factor = if max_val > 2.0 { 1.0 } else { 255.0 };
for y in 0..h {
for x in 0..w {
let offset = (y * w + x) as usize;
let r_idx = offset;
let g_idx = offset + channel_stride;
let b_idx = offset + 2 * channel_stride;
let r = (data[r_idx] * scale_factor).clamp(0.0, 255.0) as u8;
let g = (data[g_idx] * scale_factor).clamp(0.0, 255.0) as u8;
let b = (data[b_idx] * scale_factor).clamp(0.0, 255.0) as u8;
output_img_512.put_pixel(x, y, Rgba([r, g, b, 255]));
}
}
// Resize back to original
let (orig_w, orig_h) = input_image.dimensions();
let final_inpainted = image::imageops::resize(&output_img_512, orig_w, orig_h, image::imageops::FilterType::Lanczos3);
// 5. Blending
let mut result_image = input_image.clone();
for y in 0..orig_h {
for x in 0..orig_w {
if mask_image.get_pixel(x, y)[0] > 127 {
result_image.put_pixel(x, y, *final_inpainted.get_pixel(x, y));
}
}
}
Ok(result_image)
}

View File

@@ -92,6 +92,9 @@ use rayon::prelude::*;
use std::path::Path; use std::path::Path;
use imageproc::drawing::draw_text_mut; use imageproc::drawing::draw_text_mut;
use ab_glyph::{FontRef, PxScale}; use ab_glyph::{FontRef, PxScale};
use tauri::{AppHandle, Manager};
mod lama;
// Embed the font to ensure it's always available without path issues // Embed the font to ensure it's always available without path issues
const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf"); const FONT_DATA: &[u8] = include_bytes!("../assets/fonts/Roboto-Regular.ttf");
@@ -646,12 +649,12 @@ enum MaskStroke {
} }
#[tauri::command] #[tauri::command]
async fn run_inpainting(path: String, strokes: Vec<MaskStroke>) -> Result<String, String> { async fn run_inpainting(app: AppHandle, path: String, strokes: Vec<MaskStroke>) -> Result<String, String> {
let img = image::open(&path).map_err(|e| e.to_string())?.to_rgba8(); let img = image::open(&path).map_err(|e| e.to_string())?.to_rgba8();
let (width, height) = img.dimensions(); let (width, height) = img.dimensions();
// 1. Create Mask // 1. Create Gray Mask (0 = keep, 255 = remove)
let mut mask = vec![false; (width * height) as usize]; let mut mask = image::GrayImage::new(width, height);
for stroke in strokes { for stroke in strokes {
match stroke { match stroke {
@@ -661,10 +664,11 @@ async fn run_inpainting(path: String, strokes: Vec<MaskStroke>) -> Result<String
let w = (rect.w * width as f64) as i32; let w = (rect.w * width as f64) as i32;
let h = (rect.h * height as f64) as i32; let h = (rect.h * height as f64) as i32;
// Draw 255 on mask
for y in y1..(y1 + h) { for y in y1..(y1 + h) {
for x in x1..(x1 + w) { for x in x1..(x1 + w) {
if x >= 0 && x < width as i32 && y >= 0 && y < height as i32 { if x >= 0 && x < width as i32 && y >= 0 && y < height as i32 {
mask[(y as u32 * width + x as u32) as usize] = true; mask.put_pixel(x as u32, y as u32, image::Luma([255]));
} }
} }
} }
@@ -699,7 +703,7 @@ async fn run_inpainting(path: String, strokes: Vec<MaskStroke>) -> Result<String
let nx = cx + dx; let nx = cx + dx;
let ny = cy + dy; let ny = cy + dy;
if nx >= 0 && nx < width as i32 && ny >= 0 && ny < height as i32 { if nx >= 0 && nx < width as i32 && ny >= 0 && ny < height as i32 {
mask[(ny as u32 * width + nx as u32) as usize] = true; mask.put_pixel(nx as u32, ny as u32, image::Luma([255]));
} }
} }
} }
@@ -710,76 +714,26 @@ async fn run_inpainting(path: String, strokes: Vec<MaskStroke>) -> Result<String
} }
} }
// 2. Diffusion Inpainting (Simple) // 2. Resolve Model Path
// Iteratively replace masked pixels with average of non-masked neighbors let model_path = app.path().resource_dir()
// To make it converge, we update 'mask' as we go (treating filled pixels as valid source) .map_err(|e| e.to_string())?
// But standard diffusion uses double buffering. .join("resources")
// For "Removing Text", simple inward filling works well. .join("lama_fp32.onnx");
let iterations = 30; if !model_path.exists() {
let mut current_img = img.clone(); return Err("Model file 'lama_fp32.onnx' not found in resources.".to_string());
let mut next_img = img.clone();
// Convert mask to a distance map-like state?
// Or just simple neighbor average.
for _ in 0..iterations {
let mut changed = false;
for y in 0..height {
for x in 0..width {
let idx = (y * width + x) as usize;
if mask[idx] {
// It's a hole. Find valid neighbors.
// Valid = Not in ORIGINAL mask (so we pull from original image)
// OR processed in previous iteration?
// Simple logic: Pull from 'current_img'.
let mut sum_r = 0u32;
let mut sum_g = 0u32;
let mut sum_b = 0u32;
let mut count = 0;
// Check 4 neighbors
let neighbors = [
(x.wrapping_sub(1), y), (x + 1, y),
(x, y.wrapping_sub(1)), (x, y + 1)
];
for (nx, ny) in neighbors {
if nx < width && ny < height {
// Weighted check: If neighbor is ALSO masked, it contributes less?
// Or just take everything.
let pixel = current_img.get_pixel(nx, ny);
sum_r += pixel[0] as u32;
sum_g += pixel[1] as u32;
sum_b += pixel[2] as u32;
count += 1;
}
}
if count > 0 {
let avg = image::Rgba([
(sum_r / count) as u8,
(sum_g / count) as u8,
(sum_b / count) as u8,
255
]);
next_img.put_pixel(x, y, avg);
changed = true;
}
}
}
}
current_img = next_img.clone();
if !changed { break; }
} }
// 3. Run Inference
// This is computationally heavy, maybe run in thread? Tauri async commands are already threaded.
let result_img = lama::run_lama_inpainting(&model_path, &img, &mask)?;
// Save to temp // Save to temp
let cache_dir = get_cache_dir(); let cache_dir = get_cache_dir();
let file_name = format!("inpainted_{}.png", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis()); let file_name = format!("inpainted_{}.png", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis());
let out_path = cache_dir.join(file_name); let out_path = cache_dir.join(file_name);
current_img.save(&out_path).map_err(|e| e.to_string())?; result_img.save(&out_path).map_err(|e| e.to_string())?;
Ok(out_path.to_string_lossy().to_string()) Ok(out_path.to_string_lossy().to_string())
} }

View File

@@ -28,6 +28,7 @@
"bundle": { "bundle": {
"active": true, "active": true,
"targets": "all", "targets": "all",
"resources": ["resources/lama_fp32.onnx"],
"icon": [ "icon": [
"icons/32x32.png", "icons/32x32.png",
"icons/128x128.png", "icons/128x128.png",