diff --git a/Cargo.lock b/Cargo.lock index e73bfbc..fe8b365 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -107,9 +107,9 @@ checksum = "e2d098ff73c1ca148721f37baad5ea6a465a13f9573aba8641fbbbae8164a54e" [[package]] name = "async-trait" -version = "0.1.69" +version = "0.1.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b2d0f03b3640e3a630367e40c468cb7f309529c708ed1d88597047b0e7c6ef7" +checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", @@ -325,9 +325,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.3" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" [[package]] name = "block-buffer" @@ -346,18 +346,18 @@ checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" [[package]] name = "burn" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "406c6dd1ae8e5a12e982b62db8cca7c152abaddcfc6bd0ecabd685a786599fae" +checksum = "e06bb3dfa90408228c879224e26a8bbf072aa2a68194c9b512f715624525c7cc" dependencies = [ "burn-core", ] [[package]] name = "burn-autodiff" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8284166dac96910f610a70b86240bd3b8162d560dbe2fe8d4a8b36af5759324" +checksum = "b20c3ba4141da32bbcc48a4ce33a0fbf09742dfb6f17d6e781f27e076bf06d82" dependencies = [ "burn-common", "burn-tensor", @@ -368,11 +368,13 @@ dependencies = [ [[package]] name = "burn-common" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ebcde23b1c85f6f2b9363c20cdc91a86a274c43dc5a045915aa93090647bc68" +checksum = "2ce227728da80c4c7f932e66900d285ff6e04b815d77d769fd2ade463acb0d52" dependencies = [ - "const-random", + "async-trait", + "derive-new", + "getrandom", "rand", "spin", "uuid", @@ -380,9 +382,9 @@ dependencies = [ [[package]] name = "burn-core" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b248672dcd5a554b26795fa8e25a3d3629c3240bfa613b03f35258ee466d9d9" +checksum = "fbbeb7c07436e89d3b7e2445198f5d4142255bf74564eb4dc6cb9898f00d2da5" dependencies = [ "bincode 2.0.0-rc.3", "burn-common", @@ -392,7 +394,7 @@ dependencies = [ "derive-new", "flate2", "half", - "hashbrown 0.14.0", + "hashbrown 0.14.2", "libm", "log", "rand", @@ -404,14 +406,13 @@ dependencies = [ [[package]] name = "burn-dataset" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c14640e8f06cbfeb67f8d389abbc8fa4fbdd53aeb9a2591308bbb7de83008793" +checksum = "1c7f12f9a55e82d327384e20c394c79e5a414b245fe59c37346e7a98234d1707" dependencies = [ "csv", "derive-new", "dirs", - "gix-tempfile", "rand", "rmp-serde", "sanitize-filename", @@ -425,9 +426,9 @@ dependencies = [ [[package]] name = "burn-derive" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "609e6c99fe67c8da80390c6432c07b13a9b0064548e915d4d769ca0ef7071101" +checksum = "0726d6006ab4f1c65b37b079a8663ad168110976fb57234764495ed8c49a94b6" dependencies = [ "derive-new", "proc-macro2", @@ -437,9 +438,9 @@ dependencies = [ [[package]] name = "burn-ndarray" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d76677931c670b5bb74bfd1ba698b6caf399b57393485f33f1bef224a3317dcb" +checksum = "0cbba6e5180a91a48e2e5da7e74107d82dae9cc94a772a73783eb1381dd71fa2" dependencies = [ "burn-autodiff", "burn-common", @@ -456,9 +457,9 @@ dependencies = [ [[package]] name = "burn-tch" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5235b745300abda1f92040972e170d16804ea2237ba5d7bed5ec78ac0300fb92" +checksum = "97d6a7c694493aeef181f495cdcc0e2e694804d1e5cce474d5689fdb91dae9b2" dependencies = [ "burn-tensor", "half", @@ -469,14 +470,15 @@ dependencies = [ [[package]] name = "burn-tensor" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "726b7665aeeb669364b47b9b741952dc3215d8d17611d4274381b7b5b8943f0b" +checksum = "c4ddf687c9e2ddf235bb1528530b5e6c04601d4240aca78de6484fd01cc81e9f" dependencies = [ + "burn-common", "burn-tensor-testgen", "derive-new", "half", - "hashbrown 0.14.0", + "hashbrown 0.14.2", "libm", "num-traits", "rand", @@ -486,9 +488,9 @@ dependencies = [ [[package]] name = "burn-tensor-testgen" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc2efb5c94891470910f0ae5bd824f9b265dd94289b6927d27c9617523a14f21" +checksum = "f92da000e738bcf20ec873d63c3b46ed4e0afa10fb14ea3133874886c0fb591f" dependencies = [ "proc-macro2", "quote", @@ -660,26 +662,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b" -[[package]] -name = "const-random" -version = "0.1.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aaf16c9c2c612020bcfd042e170f6e32de9b9d75adb5277cdbbd2e2c8c8299a" -dependencies = [ - "const-random-macro", -] - -[[package]] -name = "const-random-macro" -version = "0.1.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" -dependencies = [ - "getrandom", - "once_cell", - "tiny-keccak", -] - [[package]] name = "constant_time_eq" version = "0.1.5" @@ -1050,19 +1032,6 @@ dependencies = [ "syn 1.0.107", ] -[[package]] -name = "dashmap" -version = "5.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" -dependencies = [ - "cfg-if 1.0.0", - "hashbrown 0.12.3", - "lock_api 0.4.11", - "once_cell", - "parking_lot_core 0.9.9", -] - [[package]] name = "debugid" version = "0.8.0" @@ -1238,15 +1207,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" -[[package]] -name = "faster-hex" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "239f7bfb930f820ab16a9cd95afc26f88264cf6905c960b340a615384aa3338a" -dependencies = [ - "serde", -] - [[package]] name = "fastrand" version = "2.0.1" @@ -1458,7 +1418,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27d12c0aed7f1e24276a241aadc4cb8ea9f83000f34bc062b7cc2d51e3b0fabd" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.1", "debugid", "fxhash", "serde", @@ -1493,9 +1453,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.9" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if 1.0.0", "js-sys", @@ -1525,58 +1485,6 @@ dependencies = [ "stable_deref_trait", ] -[[package]] -name = "gix-features" -version = "0.33.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f77decb545f63a52852578ef5f66ecd71017ffc1983d551d5fa2328d6d9817f" -dependencies = [ - "gix-hash", - "gix-trace", - "libc", -] - -[[package]] -name = "gix-fs" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53d5089f3338647776733a75a800a664ab046f56f21c515fa4722e395f877ef8" -dependencies = [ - "gix-features", -] - -[[package]] -name = "gix-hash" -version = "0.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d4796bac3aaf0c2f8bea152ca924ae3bdc5f135caefe6431116bcd67e98eab9" -dependencies = [ - "faster-hex", - "thiserror", -] - -[[package]] -name = "gix-tempfile" -version = "8.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea558d3daf3b1d0001052b12218c66c8f84788852791333b633d7eeb6999db1" -dependencies = [ - "dashmap", - "gix-fs", - "libc", - "once_cell", - "parking_lot 0.12.1", - "signal-hook", - "signal-hook-registry", - "tempfile", -] - -[[package]] -name = "gix-trace" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b6d623a1152c3facb79067d6e2ecdae48130030cf27d6eb21109f13bd7b836" - [[package]] name = "glob" version = "0.3.1" @@ -1615,9 +1523,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "f93e7192158dbcda357bdec5fb5788eebf8bbac027f3f33e719d29135ae84156" dependencies = [ "ahash 0.8.3", "allocator-api2", @@ -1759,7 +1667,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" dependencies = [ "equivalent", - "hashbrown 0.14.0", + "hashbrown 0.14.2", ] [[package]] @@ -1931,9 +1839,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.147" +version = "0.2.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" [[package]] name = "libloading" @@ -2314,9 +2222,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.15" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", "libm", @@ -2379,20 +2287,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f842b1982eb6c2fe34036a4fbfb06dd185a3f5c8edfaacdf7d1ea10b07de6252" dependencies = [ "lock_api 0.3.4", - "parking_lot_core 0.6.3", + "parking_lot_core", "rustc_version 0.2.3", ] -[[package]] -name = "parking_lot" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" -dependencies = [ - "lock_api 0.4.11", - "parking_lot_core 0.9.9", -] - [[package]] name = "parking_lot_core" version = "0.6.3" @@ -2408,19 +2306,6 @@ dependencies = [ "winapi 0.3.9", ] -[[package]] -name = "parking_lot_core" -version = "0.9.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" -dependencies = [ - "cfg-if 1.0.0", - "libc", - "redox_syscall 0.4.1", - "smallvec 1.11.0", - "windows-targets", -] - [[package]] name = "password-hash" version = "0.4.2" @@ -2662,9 +2547,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.29" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "573015e8ab27661678357f27dc26460738fd2b6c86e46f386fde94cb5d913105" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" dependencies = [ "proc-macro2", ] @@ -2717,9 +2602,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" dependencies = [ "either", "rayon-core", @@ -2727,14 +2612,12 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" dependencies = [ - "crossbeam-channel", "crossbeam-deque 0.8.2", "crossbeam-utils 0.8.14", - "num_cpus", ] [[package]] @@ -2761,15 +2644,6 @@ dependencies = [ "bitflags 1.3.2", ] -[[package]] -name = "redox_syscall" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" -dependencies = [ - "bitflags 1.3.2", -] - [[package]] name = "redox_users" version = "0.4.3" @@ -2904,7 +2778,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19599f60a688b5160247ee9c37a6af8b0c742ee8b160c5b44acc0f0eb265a59f" dependencies = [ "csv", - "hashbrown 0.14.0", + "hashbrown 0.14.2", "itertools 0.11.0", "lazy_static", "protobuf", @@ -2976,7 +2850,7 @@ version = "0.38.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19ed4fa021d81c8392ce04db050a3da9a60299050b7ae1cf482d862b54a7218f" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.1", "errno", "libc", "linux-raw-sys 0.4.11", @@ -3175,25 +3049,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" -[[package]] -name = "signal-hook" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" -dependencies = [ - "libc", - "signal-hook-registry", -] - -[[package]] -name = "signal-hook-registry" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" -dependencies = [ - "libc", -] - [[package]] name = "simd-adler32" version = "0.3.7" @@ -3284,7 +3139,8 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "stablediffusion" -version = "0.1.0" +version = "0.2.0" +source = "git+https://github.com/tychedelia/stable-diffusion-burn.git#8d9dc1015ad0467a751d0405f8cb3910060a527b" dependencies = [ "burn", "burn-autodiff", @@ -3302,21 +3158,21 @@ dependencies = [ [[package]] name = "strum" -version = "0.24.1" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "063e6045c0e62079840579a7e47a355ae92f60eb74daaf156fb1e84ba164e63f" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" [[package]] name = "strum_macros" -version = "0.24.3" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" dependencies = [ "heck", "proc-macro2", "quote", "rustversion", - "syn 1.0.107", + "syn 2.0.38", ] [[package]] @@ -3383,9 +3239,9 @@ checksum = "df8e77cb757a61f51b947ec4a7e3646efd825b73561db1c232a8ccb639e611a0" [[package]] name = "tch" -version = "0.13.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cbd9ce6fb581a1b918db880b649d1364b50f7f6717eda8497bcdc929cddd4b9" +checksum = "0ed5dddab3812892bf5fb567136e372ea49f31672931e21cec967ca68aec03da" dependencies = [ "half", "lazy_static", @@ -3567,18 +3423,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", @@ -3632,15 +3488,6 @@ dependencies = [ "time-core", ] -[[package]] -name = "tiny-keccak" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" -dependencies = [ - "crunchy", -] - [[package]] name = "tinyvec" version = "1.6.0" @@ -3745,7 +3592,7 @@ dependencies = [ "log", "mio", "num_cpus", - "parking_lot 0.9.0", + "parking_lot", "slab", "tokio-executor", "tokio-io", @@ -3883,9 +3730,9 @@ dependencies = [ [[package]] name = "torch-sys" -version = "0.13.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42b2b81a479510717464df1d07c02cb4aebb26539a39b5db6637dda114a476cb" +checksum = "803446f89fb877a117503dbfb8375b6a29fa8b0e0f44810fac3863c798ecef22" dependencies = [ "anyhow", "cc", @@ -4021,9 +3868,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d023da39d1fde5a8a3fe1f3e01ca9632ada0a63e9797de55a879d6e2236277be" +checksum = "88ad59a7560b41a70d191093a945f0b87bc1deeda46fb237479708a1d6b6cdfc" [[package]] name = "version_check" diff --git a/plugins/top/stable-diffusion/Cargo.toml b/plugins/top/stable-diffusion/Cargo.toml index ae513cf..e9243a2 100644 --- a/plugins/top/stable-diffusion/Cargo.toml +++ b/plugins/top/stable-diffusion/Cargo.toml @@ -14,9 +14,9 @@ crate-type = ["staticlib"] [dependencies] td-rs-top = { path = "../../../td-rs-top" } td-rs-derive = { path = "../../../td-rs-derive" } -stablediffusion = { git = "https://github.com/tychedelia/stable-diffusion-burn.git"} -burn = "0.9" -burn-ndarray = "0.9" -burn-tch = "0.9" -burn-autodiff = "0.9" -tch = "0.13.0" \ No newline at end of file +stablediffusion = { git = "https://github.com/tychedelia/stable-diffusion-burn.git", version = "0.2.0" } +burn = "0.10" +burn-ndarray = "0.10" +burn-tch = "0.10" +burn-autodiff = "0.10" +tch = "0.14.0" \ No newline at end of file diff --git a/plugins/top/stable-diffusion/src/lib.rs b/plugins/top/stable-diffusion/src/lib.rs index b3c3139..449f6ba 100644 --- a/plugins/top/stable-diffusion/src/lib.rs +++ b/plugins/top/stable-diffusion/src/lib.rs @@ -7,10 +7,16 @@ use stablediffusion::model::stablediffusion::{StableDiffusion, StableDiffusionCo use stablediffusion::tokenizer::SimpleTokenizer; use std::fmt::format; use std::path::PathBuf; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, MutexGuard, RwLock}; +use std::sync::atomic::AtomicBool; +use std::sync::mpsc::{Receiver, SyncSender, TryRecvError, TrySendError}; +use std::thread::JoinHandle; use td_rs_derive::Params; use td_rs_top::*; +const WIDTH: usize = 512; +const HEIGHT: usize = 512; + #[derive(Params, Default, Clone, Debug)] struct StableDiffusionTopParams { #[param(label = "Reset")] @@ -26,38 +32,28 @@ pub struct StableDiffusionTop { params: StableDiffusionTopParams, execute_count: u32, context: TopContext, - sd: Option>>, - pub tokenizer: SimpleTokenizer, -} - -impl StableDiffusionTop { - fn load_stable_diffusion_model_file( - self: &Self, - filename: &PathBuf, - ) -> Result, record::RecorderError> { - BinFileRecorder::::new() - .load(filename.into()) - .map(|record| StableDiffusionConfig::new().init().load_record(record)) - } + sd_producer: StableDiffusionProducer, + init: bool, + prompt: String, } impl TopNew for StableDiffusionTop { fn new(_info: NodeInfo, context: TopContext) -> Self { - let tokenizer = SimpleTokenizer::new().unwrap(); Self { params: Default::default(), execute_count: 0, context, - sd: None, - tokenizer, + sd_producer: StableDiffusionProducer::new(), + init: false, + prompt: "".to_string(), } } } impl OpInfo for StableDiffusionTop { const OPERATOR_LABEL: &'static str = "Stable Diffusion"; - const OPERATOR_TYPE: &'static str = "Stable Diffusion"; - const MAX_INPUTS: usize = 1; + const OPERATOR_TYPE: &'static str = "Stablediffusion"; + const MAX_INPUTS: usize = 0; const MIN_INPUTS: usize = 0; } @@ -76,60 +72,48 @@ impl Op for StableDiffusionTop { } impl Top for StableDiffusionTop { + fn general_info(&self, _input: &OperatorInputs) -> TopGeneralInfo { + TopGeneralInfo { + cook_every_frame: false, + cook_every_frame_if_asked: true, + input_size_index: 0, + } + } + fn execute(&mut self, mut output: TopOutput, input: &OperatorInputs) { if !self.params.model.exists() && !self.params.model.is_file() { + self.set_warning("A model must be loaded!"); return; } + self.set_warning(""); - type Backend = TchBackend; - let device = TchDevice::Cuda(0); - - let unconditional_guidance_scale: f64 = 7.5; - let n_steps: usize = 20; - let prompt = self.params.prompt.as_str(); - if self.sd.is_none() { - let mut sd = self.load_stable_diffusion_model_file(&self.params.model); - match sd { - Err(err) => self.set_error(format!("Error loading model: {}", err).as_str()), - Ok(sd) => { - self.sd = Some(sd.clone()); - } - } - return; + if self.prompt != self.params.prompt { + self.sd_producer.set_prompt(&self.params.prompt); + self.prompt = self.params.prompt.clone(); + } + if !self.init { + self.sd_producer.init_model(&self.params.model); + self.init = true; } - if let Some(sd) = &self.sd { - let sd = sd.clone().to_device(&device); - let unconditional_context = sd.unconditional_context(&self.tokenizer); - let context = sd.context(&self.tokenizer, prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples - - let images = sd.sample_image( - context, - unconditional_context, - unconditional_guidance_scale, - n_steps, - ); - - let image = &images[0]; + if let Some(image) = self.sd_producer.get_image() { let mut buf = self .context .create_output_buffer(image.len(), TopBufferFlags::None); buf.data_mut().copy_from_slice(image.as_slice()); - let height = 512; - let width = 512; let info = UploadInfo { buffer_offset: 0, texture_desc: TextureDesc { tex_dim: TexDim::E2D, - width, - height, + width: WIDTH, + height: HEIGHT, pixel_format: PixelFormat::BGRA8Fixed, aspect_x: 0.0, depth: 1, aspect_y: 0.0, }, - first_pixel: FirstPixel::BottomLeft, + first_pixel: FirstPixel::TopLeft, color_buffer_index: 0, }; output.upload_buffer(&mut buf, &info); @@ -137,4 +121,136 @@ impl Top for StableDiffusionTop { } } +struct StableDiffusionProducer { + sd: Arc>>>>, + prompt: Arc>, + produce_loop: JoinHandle<()>, + rx: Receiver>, + trigger_tx: SyncSender<()>, +} + +impl StableDiffusionProducer { + fn new() -> Self { + let (tx, rx) = std::sync::mpsc::sync_channel(3); + let (trigger_tx, trigger_rx) = std::sync::mpsc::sync_channel(1); + let tokenizer = SimpleTokenizer::new().unwrap(); + let sd = Arc::new(RwLock::new(None::>>)); + let prompt = Arc::new(RwLock::new(String::new())); + let produce_loop_sd = sd.clone(); + let produce_loop_prompt = prompt.clone(); + let produce_loop = Self::produce_loop(tx, trigger_rx, tokenizer, produce_loop_sd, produce_loop_prompt); + + StableDiffusionProducer { + sd, + rx, + trigger_tx, + prompt, + produce_loop, + } + } + + fn produce_loop(tx: SyncSender>, trigger_rx: Receiver<()>, tokenizer: SimpleTokenizer, produce_loop_sd: Arc>>>>, produce_loop_prompt: Arc>) -> JoinHandle<()> { + let produce_loop = std::thread::spawn(move || { + loop { + // Wait for a frame to be requested + let _ = trigger_rx.recv().unwrap(); + + let sd = produce_loop_sd.read().unwrap(); + match sd.as_ref() { + None => {} + Some(sd) => { + let device = TchDevice::Cuda(0); + let sd = sd.clone(); + let sd = sd.to_device(&device); + let unconditional_context = sd.unconditional_context(&tokenizer); + let unconditional_guidance_scale: f64 = 7.5; + let n_steps: usize = 20; + + let prompt = produce_loop_prompt.read().unwrap(); + let context = sd.context(&tokenizer, &prompt).unsqueeze::<3>(); //.repeat(0, 2); // generate 2 samples + let images = sd.sample_image( + context, + unconditional_context, + unconditional_guidance_scale, + n_steps, + ); + + let image = &images[0]; + let layer_bytes = + (WIDTH * HEIGHT * 4 * std::mem::size_of::()) + as u64; + let mut pixels = Vec::with_capacity(layer_bytes as usize); + + for chunk in image.chunks(3) { + pixels.push(chunk[2]); // Blue + pixels.push(chunk[1]); // Green + pixels.push(chunk[0]); // Red + pixels.push(255); // Alpha (full opacity) + } + + tx.send(pixels).unwrap(); + } + } + } + }); + produce_loop + } + + fn init_model(&mut self, model_file: &PathBuf) { + let self_sd = self.sd.clone(); + let model_file = model_file.clone(); + // Load the model in a separate thread + std::thread::spawn(move || { + let sd = Self::load_stable_diffusion_model_file(&model_file) + .unwrap(); + *self_sd.write().unwrap() = Some(sd); + }); + return; + } + + fn set_prompt(&mut self, p: &str) { + let mut prompt = self.prompt.write().unwrap(); + *prompt = p.to_string(); + } + + fn get_image(&self) -> Option> { + match self.trigger_tx.try_send(()) { + Ok(_) => {} + Err(err) => { + match err { + TrySendError::Full(_) => { + // would block, so just return + } + TrySendError::Disconnected(_) => { + panic!("Stable Diffusion Producer thread disconnected!") + } + } + } + }; + + match self.rx.try_recv() { + Ok(img) => { + return Some(img); + } + Err(err) => { + match err { + TryRecvError::Empty => {} + TryRecvError::Disconnected => { + panic!("Stable Diffusion Producer thread disconnected!") + } + } + } + }; + + None + } + fn load_stable_diffusion_model_file( + filename: &PathBuf, + ) -> Result, record::RecorderError> { + BinFileRecorder::::new() + .load(filename.into()) + .map(|record| StableDiffusionConfig::new().init().load_record(record)) + } +} + top_plugin!(StableDiffusionTop); diff --git a/td-rs-base/src/lib.rs b/td-rs-base/src/lib.rs index 0fcdf34..d6d6ac0 100644 --- a/td-rs-base/src/lib.rs +++ b/td-rs-base/src/lib.rs @@ -87,7 +87,10 @@ pub trait Op { // are only ever called from the body of the plugin // and not exposed to C++. unsafe { - INFO_STR.get_mut().unwrap().replace_range(.., info); + let i = INFO_STR.get_mut(); + if let Some(i) = i { + i.replace_range(.., info); + } } } @@ -101,7 +104,10 @@ pub trait Op { // are only ever called from the body of the plugin // and not exposed to C++. unsafe { - ERROR_STR.get_mut().unwrap().replace_range(.., error); + let e = ERROR_STR.get_mut(); + if let Some(e) = e { + e.replace_range(.., error); + } } } @@ -115,7 +121,10 @@ pub trait Op { // are only ever called from the body of the plugin // and not exposed to C++. unsafe { - WARNING_STR.get_mut().unwrap().replace_range(.., warning); + let w = WARNING_STR.get_mut(); + if let Some(w) = w { + w.replace_range(.., warning); + } } } diff --git a/td-rs-xtask/msvc/top/RustTOP.vcxproj b/td-rs-xtask/msvc/top/RustTOP.vcxproj index 732f217..fcd87af 100644 --- a/td-rs-xtask/msvc/top/RustTOP.vcxproj +++ b/td-rs-xtask/msvc/top/RustTOP.vcxproj @@ -94,8 +94,8 @@ Windows true true - $(AdditionalLibraryDirectories) - python311.lib;bcrypt.lib;UserEnv.Lib;Ws2_32.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;msvcrt.lib;ntdll.lib;.\target\x86_64-pc-windows-msvc\release\td_rs.lib;.\target\x86_64-pc-windows-msvc\release\td_rs_top.lib;$(PLUGIN);%(AdditionalDependencies) + $(AdditionalLibraryDirectories);C:\Users\Charlotte\Downloads\libtorch-2.1.0\libtorch\lib + torch.lib;torch_cuda.lib;torch_cpu.lib;c10.lib;c10_cuda.lib;python311.lib;bcrypt.lib;UserEnv.Lib;Ws2_32.lib;kernel32.lib;user32.lib;gdi32.lib;winspool.lib;comdlg32.lib;advapi32.lib;shell32.lib;ole32.lib;oleaut32.lib;uuid.lib;odbc32.lib;odbccp32.lib;msvcrt.lib;ntdll.lib;.\target\x86_64-pc-windows-msvc\release\td_rs.lib;.\target\x86_64-pc-windows-msvc\release\td_rs_top.lib;$(PLUGIN);%(AdditionalDependencies)