From ed1d2483ecda6ce7ce97e435c928cd92b8d6319b Mon Sep 17 00:00:00 2001 From: usamoi Date: Mon, 25 Mar 2024 16:12:13 +0800 Subject: [PATCH 01/16] chore: add another implementation of multiversion Signed-off-by: usamoi --- .cargo/config.toml | 3 + Cargo.lock | 238 +++++++-------- Cargo.toml | 1 - crates/base/Cargo.toml | 1 - crates/base/src/lib.rs | 1 + crates/base/src/vector/bvecf32.rs | 78 ++--- crates/base/src/vector/svecf32.rs | 88 ++---- crates/base/src/vector/vecf16.rs | 90 ++---- crates/base/src/vector/vecf32.rs | 72 +---- crates/base/src/vector/veci8.rs | 51 +--- crates/c/tests/f16.rs | 24 +- crates/detect/Cargo.toml | 2 + crates/detect/src/lib.rs | 22 +- crates/detect/src/linux.rs | 21 -- crates/detect/src/x86_64.rs | 120 -------- crates/detect/tests/linux.rs | 7 - crates/detect/tests/x86_64.rs | 21 -- crates/detect_macros/Cargo.toml | 21 ++ crates/detect_macros/src/lib.rs | 318 ++++++++++++++++++++ crates/memfd/src/lib.rs | 55 ++-- crates/quantization/Cargo.toml | 2 +- crates/quantization/src/lib.rs | 1 + crates/quantization/src/product/operator.rs | 126 ++------ crates/quantization/src/scalar/operator.rs | 84 +----- src/datatype/memory_veci8.rs | 2 +- src/lib.rs | 2 +- 26 files changed, 617 insertions(+), 834 deletions(-) delete mode 100644 crates/detect/src/linux.rs delete mode 100644 crates/detect/src/x86_64.rs delete mode 100644 crates/detect/tests/linux.rs delete mode 100644 crates/detect/tests/x86_64.rs create mode 100644 crates/detect_macros/Cargo.toml create mode 100644 crates/detect_macros/src/lib.rs diff --git a/.cargo/config.toml b/.cargo/config.toml index 13c456b5d..c4db64902 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,3 +1,6 @@ +[build] +rustdocflags = ["--document-private-items"] + [target.'cfg(target_os="macos")'] # Postgres symbols won't be available until runtime rustflags = ["-Clink-arg=-Wl,-undefined,dynamic_lookup"] diff --git a/Cargo.lock b/Cargo.lock index 62dc6ecc0..514953e87 100644 --- a/Cargo.lock +++ b/Cargo.lock 
@@ -19,9 +19,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] name = "aho-corasick" -version = "1.1.2" +version = "1.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" dependencies = [ "memchr", ] @@ -76,15 +76,15 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.80" +version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" +checksum = "0952808a6c2afd1aa8947271f3a60f1a6763c7b912d210184c5149b5cf147247" [[package]] name = "arc-swap" -version = "1.7.0" +version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b3d0060af21e8d11a926981cc00c6c1541aa91dd64b9f881985c3da1094425f" +checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" [[package]] name = "arrayvec" @@ -154,8 +154,8 @@ dependencies = [ "async-lock 3.3.0", "async-task", "concurrent-queue", - "fastrand 2.0.1", - "futures-lite 2.2.0", + "fastrand 2.0.2", + "futures-lite 2.3.0", "slab", ] @@ -170,7 +170,7 @@ dependencies = [ "async-io 2.3.2", "async-lock 3.3.0", "blocking", - "futures-lite 2.2.0", + "futures-lite 2.3.0", "once_cell", ] @@ -204,10 +204,10 @@ dependencies = [ "cfg-if", "concurrent-queue", "futures-io", - "futures-lite 2.2.0", + "futures-lite 2.3.0", "parking", - "polling 3.5.0", - "rustix 0.38.31", + "polling 3.6.0", + "rustix 0.38.32", "slab", "tracing", "windows-sys 0.52.0", @@ -255,7 +255,7 @@ dependencies = [ "cfg-if", "event-listener 3.1.0", "futures-lite 1.13.0", - "rustix 0.38.31", + "rustix 0.38.32", "windows-sys 0.48.0", ] @@ -271,7 +271,7 @@ dependencies = [ "cfg-if", "futures-core", "futures-io", - "rustix 0.38.31", + "rustix 0.38.32", "signal-hook-registry", "slab", "windows-sys 0.48.0", 
@@ -313,13 +313,13 @@ checksum = "fbb36e985947064623dbd357f727af08ffd077f93d696782f3c56365fa2e2799" [[package]] name = "async-trait" -version = "0.1.77" +version = "0.1.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" +checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -346,9 +346,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.69" +version = "0.3.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" dependencies = [ "addr2line", "cc", @@ -368,7 +368,6 @@ dependencies = [ "detect", "half 2.4.0", "libc", - "multiversion", "num-traits", "rand", "serde", @@ -409,7 +408,7 @@ version = "0.69.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "cexpr", "clang-sys", "itertools 0.12.1", @@ -420,7 +419,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -446,9 +445,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.2" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed570934406eb16438a4e976b1b4500774099c13b8cb96eec99f620f05090ddf" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "bitvec" @@ -471,9 +470,9 @@ dependencies = [ "async-channel 2.2.0", "async-lock 3.3.0", "async-task", - "fastrand 2.0.1", + "fastrand 2.0.2", "futures-io", - 
"futures-lite 2.2.0", + "futures-lite 2.3.0", "piper", "tracing", ] @@ -486,9 +485,9 @@ checksum = "7ff69b9dd49fd426c69a0db9fc04dd934cdb6645ff000864d98f7e2af8830eaa" [[package]] name = "bytemuck" -version = "1.14.3" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f" +checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" dependencies = [ "bytemuck_derive", ] @@ -501,7 +500,7 @@ checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -512,9 +511,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] name = "c" @@ -581,7 +580,7 @@ dependencies = [ "bytemuck", "log", "memmap2", - "rustix 0.38.31", + "rustix 0.38.32", "serde", "serde_json", ] @@ -718,7 +717,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -729,7 +728,7 @@ checksum = "a668eda54683121533a393014d8692171709ff57a7d61f187b6e782719f8933f" dependencies = [ "darling_core", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -749,10 +748,20 @@ dependencies = [ name = "detect" version = "0.0.0" dependencies = [ - "rustix 0.38.31", + "detect_macros", + "rustix 0.38.32", "std_detect", ] +[[package]] +name = "detect_macros" +version = "0.0.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.55", +] + [[package]] name = "dirs-next" version = "2.0.0" @@ -838,7 +847,7 @@ checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" dependencies = [ "proc-macro2", 
"quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -960,9 +969,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.0.1" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" +checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" [[package]] name = "fixedbitset" @@ -1040,11 +1049,11 @@ dependencies = [ [[package]] name = "futures-lite" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445ba825b27408685aaecefd65178908c36c6e96aaf6d8599419d46e624192ba" +checksum = "52527eb5074e35e9339c6b4e8d12600c7128b68fb25dcb9fa9dec18f7c25f3a5" dependencies = [ - "fastrand 2.0.1", + "fastrand 2.0.2", "futures-core", "futures-io", "parking", @@ -1059,7 +1068,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -1127,9 +1136,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.3.24" +version = "0.3.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bb2c4422095b67ee78da96fbb51a4cc413b3b25883c7717ff7ca1ab31022c9c9" +checksum = "4fbd2820c5e49886948654ab546d0688ff24530286bdcf8fca3cefb16d4618eb" dependencies = [ "bytes", "fnv", @@ -1377,9 +1386,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.2.5" +version = "2.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" dependencies = [ "equivalent", "hashbrown", @@ -1553,7 +1562,7 @@ version = "0.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85c833ca1e66078851dba29046874e38f08b2c883700aa29a03ddd3b23814ee8" dependencies = [ - "bitflags 2.4.2", + 
"bitflags 2.5.0", "libc", "redox_syscall", ] @@ -1601,7 +1610,7 @@ version = "0.0.0" dependencies = [ "detect", "rand", - "rustix 0.38.31", + "rustix 0.38.32", ] [[package]] @@ -1654,33 +1663,11 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "multiversion" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2c7b9d7fe61760ce5ea19532ead98541f6b4c495d87247aff9826445cf6872a" -dependencies = [ - "multiversion-macros", - "target-features", -] - -[[package]] -name = "multiversion-macros" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26a83d8500ed06d68877e9de1dde76c1dbb83885dcdbda4ef44ccbc3fbda2ac8" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", - "target-features", -] - [[package]] name = "new_debug_unreachable" -version = "1.0.4" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4a24736216ec316047a1fc4252e27dabb04218aa4a3f37c6e7ddbf1f9782b54" +checksum = "650eef8c711430f1a879fdd01d4745a7deea475becfb90269c06775983bbf086" [[package]] name = "nom" @@ -1811,7 +1798,7 @@ version = "0.12.0-alpha.1" source = "git+https://github.com/tensorchord/pgrx.git?branch=v0.12.0-alpha.1-patch#1a3459f597396a8d3dad0947a1d646f4cbe8e1ae" dependencies = [ "atomic-traits", - "bitflags 2.4.2", + "bitflags 2.5.0", "bitvec", "enum-map", "heapless", @@ -1838,7 +1825,7 @@ dependencies = [ "pgrx-sql-entity-graph", "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -1877,7 +1864,7 @@ dependencies = [ "serde", "shlex", "sptr", - "syn 2.0.52", + "syn 2.0.55", "walkdir", ] @@ -1891,7 +1878,7 @@ dependencies = [ "petgraph", "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", "thiserror", "unescape", ] @@ -1930,7 +1917,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "668d31b1c4eba19242f2088b2bf3316b82ca31082a8335764db4e083db7485d4" dependencies = [ "atomic-waker", - 
"fastrand 2.0.1", + "fastrand 2.0.2", "futures-io", ] @@ -1952,14 +1939,15 @@ dependencies = [ [[package]] name = "polling" -version = "3.5.0" +version = "3.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24f040dee2588b4963afb4e420540439d126f73fdacf4a9c486a96d840bac3c9" +checksum = "e0c976a60b2d7e99d6f229e414670a9b85d13ac305cc6d1e9c134de58c5aaaf6" dependencies = [ "cfg-if", "concurrent-queue", + "hermit-abi", "pin-project-lite", - "rustix 0.38.31", + "rustix 0.38.32", "tracing", "windows-sys 0.52.0", ] @@ -2002,9 +1990,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" dependencies = [ "unicode-ident", ] @@ -2015,8 +2003,8 @@ version = "0.0.0" dependencies = [ "base", "common", + "detect", "elkan_k_means", - "multiversion", "num-traits", "rand", "serde_json", @@ -2082,14 +2070,14 @@ name = "rayon" version = "0.0.0" dependencies = [ "log", - "rayon 1.9.0", + "rayon 1.10.0", ] [[package]] name = "rayon" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4963ed1bc86e4f3ee217022bd855b297cef07fb9eac5dfa1f788b220b49b3bd" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" dependencies = [ "either", "rayon-core", @@ -2127,9 +2115,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.3" +version = "1.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", @@ -2156,9 +2144,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" 
[[package]] name = "reqwest" -version = "0.11.25" +version = "0.11.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eea5a9eb898d3783f17c6407670e3592fd174cb81a10e51d4c37f49450b9946" +checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" dependencies = [ "base64", "bytes", @@ -2247,11 +2235,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.31" +version = "0.38.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" +checksum = "65e04861e65f21776e67888bfbea442b3642beaa0138fdb1dd7a84a52dffdb89" dependencies = [ - "bitflags 2.4.2", + "bitflags 2.5.0", "errno", "libc", "linux-raw-sys 0.4.13", @@ -2356,7 +2344,7 @@ version = "0.0.0" dependencies = [ "libc", "log", - "rustix 0.38.31", + "rustix 0.38.32", ] [[package]] @@ -2392,7 +2380,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -2489,9 +2477,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.1" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6ecd384b10a64542d77071bd64bd7b231f4ed5940fba55e98c3de13824cf3d7" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "socket2" @@ -2580,9 +2568,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.52" +version = "2.0.55" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b699d15b36d1f02c3e7c69f8ffef53de37aefae075d8488d4ba1a7788d574a07" +checksum = "002a1b3dbf967edfafc32655d0f377ab0bb7b994aa1d32c8cc7e9b8bf3ebb8f0" dependencies = [ "proc-macro2", "quote", @@ -2597,20 +2585,20 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "system-configuration" -version = "0.6.0" +version = "0.5.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "658bc6ee10a9b4fcf576e9b0819d95ec16f4d2c02d39fd83ac1c8789785c4a42" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ - "bitflags 2.4.2", + "bitflags 1.3.2", "core-foundation", "system-configuration-sys", ] [[package]] name = "system-configuration-sys" -version = "0.6.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" dependencies = [ "core-foundation-sys", "libc", @@ -2622,12 +2610,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" -[[package]] -name = "target-features" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfb5fa503293557c5158bd215fdc225695e567a77e453f5d4452a50a193969bd" - [[package]] name = "term" version = "0.7.0" @@ -2641,22 +2623,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -2729,7 +2711,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", 
"quote", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] @@ -2758,9 +2740,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.10" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a9aad4a3066010876e8dcf5a8a06e70a558751117a145c6ce2b82c2e2054290" +checksum = "e9dd1545e8208b4a5af1aa9bbd0b4cf7e9ea08fabc5d0a5c67fcaafa17433aa3" dependencies = [ "serde", "serde_spanned", @@ -2779,9 +2761,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.6" +version = "0.22.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c1b5fd4128cc8d3e0cb74d4ed9a9cc7c7284becd4df68f5f940e1ad123606f6" +checksum = "8e40bb779c5187258fd7aad0eb68cb8706a0a81fa712fbea808ab43c4b8374c4" dependencies = [ "indexmap", "serde", @@ -2900,9 +2882,9 @@ checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "uuid" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f00cc9702ca12d3c81455259621e676d0f7251cec66a21e98fe2e9a37db93b2a" +checksum = "a183cf7feeba97b4dd1c0d46788634f6221d87fa961b305bed08c851829efcc0" dependencies = [ "getrandom", "serde", @@ -2935,14 +2917,14 @@ dependencies = [ "proc-macro2", "quote", "regex", - "syn 2.0.52", + "syn 2.0.55", ] [[package]] name = "value-bag" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "126e423afe2dd9ac52142e7e9d5ce4135d7e13776c529d27fd6bc49f19e3280b" +checksum = "74797339c3b98616c009c7c3eb53a0ce41e85c8ec66bd3db96ed132d20cfdee8" [[package]] name = "vectors" @@ -2965,7 +2947,7 @@ dependencies = [ "paste", "pgrx", "rand", - "rustix 0.38.31", + "rustix 0.38.32", "scopeguard", "send_fd", "serde", @@ -3035,7 +3017,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", "wasm-bindgen-shared", ] @@ -3069,7 +3051,7 @@ checksum = 
"e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.52", + "syn 2.0.55", "wasm-bindgen-backend", "wasm-bindgen-shared", ] diff --git a/Cargo.toml b/Cargo.toml index 68c2912c8..ab50e25f3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,7 +80,6 @@ half = { version = "2.4.0", features = [ libc = "0.2.153" log = "0.4.21" memmap2 = "0.9.4" -multiversion = "0.7.3" num-traits = "0.2.18" parking_lot = "0.12.1" paste = "1.0.14" diff --git a/crates/base/Cargo.toml b/crates/base/Cargo.toml index 29ec893f7..a0eb840d0 100644 --- a/crates/base/Cargo.toml +++ b/crates/base/Cargo.toml @@ -7,7 +7,6 @@ edition.workspace = true bytemuck.workspace = true half.workspace = true libc.workspace = true -multiversion.workspace = true num-traits.workspace = true rand.workspace = true serde.workspace = true diff --git a/crates/base/src/lib.rs b/crates/base/src/lib.rs index 3d93c8f21..b36f744f7 100644 --- a/crates/base/src/lib.rs +++ b/crates/base/src/lib.rs @@ -1,4 +1,5 @@ #![feature(core_intrinsics)] +#![feature(doc_cfg)] #![feature(avx512_target_feature)] #![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))] #![allow(internal_features)] diff --git a/crates/base/src/vector/bvecf32.rs b/crates/base/src/vector/bvecf32.rs index 12b9371bf..c892e34d9 100644 --- a/crates/base/src/vector/bvecf32.rs +++ b/crates/base/src/vector/bvecf32.rs @@ -170,13 +170,7 @@ pub fn cosine<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { let rhs = rhs.data(); assert!(lhs.len() == rhs.len()); - #[inline(always)] - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn cosine(lhs: &[usize], rhs: &[usize]) -> F32 { let mut xy = 0; let mut xx = 0; @@ -194,11 +188,11 @@ pub fn cosine<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { #[inline] #[cfg(target_arch = "x86_64")] - 
#[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] unsafe fn cosine_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { use std::arch::x86_64::*; #[inline] - #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { let mut dst: __m512i; unsafe { @@ -247,7 +241,7 @@ pub fn cosine<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_avx512vpopcntdq() { + if detect::v4_avx512vpopcntdq::detect() { unsafe { return cosine_avx512vpopcntdq(lhs, rhs); } @@ -261,13 +255,7 @@ pub fn dot<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { let rhs = rhs.data(); assert!(lhs.len() == rhs.len()); - #[inline(always)] - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn dot(lhs: &[usize], rhs: &[usize]) -> F32 { let mut xy = 0; for i in 0..lhs.len() { @@ -278,11 +266,11 @@ pub fn dot<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { #[inline] #[cfg(target_arch = "x86_64")] - #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] unsafe fn dot_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { use std::arch::x86_64::*; #[inline] - #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { let mut dst: __m512i; unsafe { @@ -323,7 +311,7 @@ pub fn dot<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_avx512vpopcntdq() { + if 
detect::v4_avx512vpopcntdq::detect() { unsafe { return dot_avx512vpopcntdq(lhs, rhs); } @@ -337,13 +325,7 @@ pub fn sl2<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { let rhs = rhs.data(); assert!(lhs.len() == rhs.len()); - #[inline(always)] - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn sl2(lhs: &[usize], rhs: &[usize]) -> F32 { let mut dd = 0; for i in 0..lhs.len() { @@ -354,11 +336,11 @@ pub fn sl2<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { #[inline] #[cfg(target_arch = "x86_64")] - #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] unsafe fn sl2_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { use std::arch::x86_64::*; #[inline] - #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { let mut dst: __m512i; unsafe { @@ -399,7 +381,7 @@ pub fn sl2<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_avx512vpopcntdq() { + if detect::v4_avx512vpopcntdq::detect() { unsafe { return sl2_avx512vpopcntdq(lhs, rhs); } @@ -413,13 +395,7 @@ pub fn jaccard<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { let rhs = rhs.data(); assert!(lhs.len() == rhs.len()); - #[inline(always)] - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn jaccard(lhs: &[usize], rhs: &[usize]) -> F32 { let mut inter = 0; let mut union = 0; @@ -432,11 +408,11 @@ pub fn jaccard<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { #[inline] #[cfg(target_arch = "x86_64")] 
- #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] unsafe fn jaccard_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { use std::arch::x86_64::*; #[inline] - #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { let mut dst: __m512i; unsafe { @@ -481,7 +457,7 @@ pub fn jaccard<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_avx512vpopcntdq() { + if detect::v4_avx512vpopcntdq::detect() { unsafe { return jaccard_avx512vpopcntdq(lhs, rhs); } @@ -493,13 +469,7 @@ pub fn jaccard<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { pub fn length(vector: BVecf32Borrowed<'_>) -> F32 { let vector = vector.data(); - #[inline(always)] - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn length(vector: &[usize]) -> F32 { let mut l = 0; for i in 0..vector.len() { @@ -510,11 +480,11 @@ pub fn length(vector: BVecf32Borrowed<'_>) -> F32 { #[inline] #[cfg(target_arch = "x86_64")] - #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] unsafe fn length_avx512vpopcntdq(lhs: &[usize]) -> F32 { use std::arch::x86_64::*; #[inline] - #[target_feature(enable = "avx512vpopcntdq,avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { let mut dst: __m512i; unsafe { @@ -550,7 +520,7 @@ pub fn length(vector: BVecf32Borrowed<'_>) -> F32 { } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_avx512vpopcntdq() { + if detect::v4_avx512vpopcntdq::detect() { unsafe 
{ return length_avx512vpopcntdq(vector); } @@ -558,13 +528,7 @@ pub fn length(vector: BVecf32Borrowed<'_>) -> F32 { length(vector) } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn l2_normalize<'a>(vector: BVecf32Borrowed<'a>) -> Vecf32Owned { let l = length(vector); Vecf32Owned::new(vector.iter().map(|i| F32(i as u32 as f32) / l).collect()) diff --git a/crates/base/src/vector/svecf32.rs b/crates/base/src/vector/svecf32.rs index 1df950ed6..49bb7cadb 100644 --- a/crates/base/src/vector/svecf32.rs +++ b/crates/base/src/vector/svecf32.rs @@ -184,13 +184,7 @@ impl<'a> SVecf32Borrowed<'a> { } } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] fn cosine_fallback<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { let mut lhs_pos = 0; let mut rhs_pos = 0; @@ -231,12 +225,12 @@ fn cosine_fallback<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F3 #[inline] #[cfg(target_arch = "x86_64")] -#[target_feature(enable = "avx512bw,avx512f,bmi2")] +#[detect::target_cpu(enable = "v4")] unsafe fn cosine_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { use std::arch::x86_64::*; use std::cmp::min; #[inline] - #[target_feature(enable = "avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4")] pub unsafe fn _mm512_maskz_loadu_epi32(k: __mmask16, mem_addr: *const i32) -> __m512i { let mut dst: __m512i; unsafe { @@ -251,7 +245,7 @@ unsafe fn cosine_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F dst } #[inline] - #[target_feature(enable = "avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4")] pub unsafe fn 
_mm512_maskz_loadu_ps(k: __mmask16, mem_addr: *const f32) -> __m512 { let mut dst: __m512; unsafe { @@ -368,19 +362,13 @@ unsafe fn cosine_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F pub fn cosine<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { assert_eq!(lhs.dims(), rhs.dims()); #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_v4() { + if detect::v4::detect() { return unsafe { cosine_v4(lhs, rhs) }; } cosine_fallback(lhs, rhs) } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] fn dot_fallback<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { let mut lhs_pos = 0; let mut rhs_pos = 0; @@ -409,12 +397,12 @@ fn dot_fallback<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { #[inline] #[cfg(target_arch = "x86_64")] -#[target_feature(enable = "avx512bw,avx512f,bmi2")] +#[detect::target_cpu(enable = "v4")] unsafe fn dot_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { use std::arch::x86_64::*; use std::cmp::min; #[inline] - #[target_feature(enable = "avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4")] pub unsafe fn _mm512_maskz_loadu_epi32(k: __mmask16, mem_addr: *const i32) -> __m512i { let mut dst: __m512i; unsafe { @@ -429,7 +417,7 @@ unsafe fn dot_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 dst } #[inline] - #[target_feature(enable = "avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4")] pub unsafe fn _mm512_maskz_loadu_ps(k: __mmask16, mem_addr: *const f32) -> __m512 { let mut dst: __m512; unsafe { @@ -516,19 +504,13 @@ unsafe fn dot_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 pub fn dot<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { assert_eq!(lhs.dims(), rhs.dims()); #[cfg(target_arch = "x86_64")] 
- if detect::x86_64::detect_v4() { + if detect::v4::detect() { return unsafe { dot_v4(lhs, rhs) }; } dot_fallback(lhs, rhs) } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn dot_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 { let mut xy = F32::zero(); for i in 0..lhs.len() as usize { @@ -537,13 +519,7 @@ pub fn dot_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 { xy } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] fn sl2_fallback<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { let mut lhs_pos = 0; let mut rhs_pos = 0; @@ -581,12 +557,12 @@ fn sl2_fallback<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { #[inline] #[cfg(target_arch = "x86_64")] -#[target_feature(enable = "avx512bw,avx512f,bmi2")] +#[detect::target_cpu(enable = "v4")] unsafe fn sl2_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { use std::arch::x86_64::*; use std::cmp::min; #[inline] - #[target_feature(enable = "avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4")] pub unsafe fn _mm512_maskz_loadu_epi32(k: __mmask16, mem_addr: *const i32) -> __m512i { let mut dst: __m512i; unsafe { @@ -601,7 +577,7 @@ unsafe fn sl2_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 dst } #[inline] - #[target_feature(enable = "avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4")] pub unsafe fn _mm512_maskz_loadu_ps(k: __mmask16, mem_addr: *const f32) -> __m512 { let mut dst: __m512; unsafe { @@ -718,19 +694,13 @@ unsafe fn sl2_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 pub fn sl2<'a>(lhs: 
SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { assert_eq!(lhs.dims(), rhs.dims()); #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_v4() { + if detect::v4::detect() { return unsafe { sl2_v4(lhs, rhs) }; } sl2_fallback(lhs, rhs) } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn sl2_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 { let mut d2 = F32::zero(); let mut lhs_pos = 0; @@ -749,13 +719,7 @@ pub fn sl2_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 { d2 } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn length<'a>(vector: SVecf32Borrowed<'a>) -> F32 { let mut dot = F32::zero(); for &i in vector.values() { @@ -764,13 +728,7 @@ pub fn length<'a>(vector: SVecf32Borrowed<'a>) -> F32 { dot.sqrt() } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn l2_normalize(vector: &mut SVecf32Owned) { let l = length(vector.for_borrow()); let dims = vector.dims(); @@ -787,7 +745,7 @@ pub fn l2_normalize(vector: &mut SVecf32Owned) { // Instructions. arXiv preprint arXiv:2112.06342. 
#[inline] #[cfg(target_arch = "x86_64")] -#[target_feature(enable = "avx512bw,avx512f")] +#[detect::target_cpu(enable = "v4")] unsafe fn emulate_mm512_2intersect_epi32( a: std::arch::x86_64::__m512i, b: std::arch::x86_64::__m512i, @@ -868,7 +826,7 @@ mod tests { let y = random_svector(RHS_SIZE); let cosine_fallback = cosine_fallback(x.for_borrow(), y.for_borrow()); #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_v4() { + if detect::v4::detect() { let cosine_v4 = unsafe { cosine_v4(x.for_borrow(), y.for_borrow()) }; assert!( cosine_fallback - cosine_v4 < EPS, @@ -885,7 +843,7 @@ mod tests { let y = random_svector(RHS_SIZE); let dot_fallback = dot_fallback(x.for_borrow(), y.for_borrow()); #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_v4() { + if detect::v4::detect() { let dot_v4 = unsafe { dot_v4(x.for_borrow(), y.for_borrow()) }; assert!( dot_fallback - dot_v4 < EPS, @@ -902,7 +860,7 @@ mod tests { let y = random_svector(RHS_SIZE); let sl2_fallback = sl2_fallback(x.for_borrow(), y.for_borrow()); #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_v4() { + if detect::v4::detect() { let sl2_v4 = unsafe { sl2_v4(x.for_borrow(), y.for_borrow()) }; assert!( sl2_fallback - sl2_v4 < EPS, diff --git a/crates/base/src/vector/vecf16.rs b/crates/base/src/vector/vecf16.rs index f26d6c27d..6d2389371 100644 --- a/crates/base/src/vector/vecf16.rs +++ b/crates/base/src/vector/vecf16.rs @@ -102,13 +102,7 @@ impl<'a> VectorBorrowed for Vecf16Borrowed<'a> { } pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { - #[inline(always)] - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -123,7 +117,7 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { xy / (x2 * y2).sqrt() } #[cfg(target_arch = "x86_64")] - if 
detect::x86_64::detect_avx512fp16() { + if detect::v4_avx512fp16::detect() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -131,7 +125,7 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { } } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_v4() { + if detect::v4::detect() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -139,7 +133,7 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { } } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_v3() { + if detect::v3::detect() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -150,13 +144,7 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { } pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { - #[inline(always)] - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -167,7 +155,7 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { xy } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_avx512fp16() { + if detect::v4_avx512fp16::detect() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -175,7 +163,7 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { } } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_v4() { + if detect::v4::detect() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -183,7 +171,7 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { } } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_v3() { + if detect::v3::detect() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -194,13 +182,7 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { } pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { - #[inline(always)] - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, 
v3, v2, neon, fallback)] fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -212,7 +194,7 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { d2 } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_avx512fp16() { + if detect::v4_avx512fp16::detect() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -220,7 +202,7 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { } } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_v4() { + if detect::v4::detect() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -228,7 +210,7 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { } } #[cfg(target_arch = "x86_64")] - if detect::x86_64::detect_v3() { + if detect::v3::detect() { assert!(lhs.len() == rhs.len()); let n = lhs.len(); unsafe { @@ -238,13 +220,7 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { sl2(lhs, rhs) } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] fn length(vector: &[F16]) -> F16 { let n = vector.len(); let mut dot = F16::zero(); @@ -254,13 +230,7 @@ fn length(vector: &[F16]) -> F16 { dot.sqrt() } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn l2_normalize(vector: &mut [F16]) { let n = vector.len(); let l = length(vector); @@ -269,13 +239,7 @@ pub fn l2_normalize(vector: &mut [F16]) { } } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) { 
assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -290,13 +254,7 @@ pub fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) { (xy, x2, y2) } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -311,13 +269,7 @@ pub fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) (xy, x2, y2) } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n: usize = lhs.len(); @@ -328,13 +280,7 @@ pub fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { xy } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn distance_squared_l2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); diff --git a/crates/base/src/vector/vecf32.rs b/crates/base/src/vector/vecf32.rs index 27aff1cd9..1167daf78 100644 --- a/crates/base/src/vector/vecf32.rs +++ b/crates/base/src/vector/vecf32.rs @@ -101,13 +101,7 @@ impl<'a> VectorBorrowed for Vecf32Borrowed<'a> { } } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn 
cosine(lhs: &[F32], rhs: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -122,13 +116,7 @@ pub fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { xy / (x2 * y2).sqrt() } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -139,13 +127,7 @@ pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 { xy } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn sl2(lhs: &[F32], rhs: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -157,13 +139,7 @@ pub fn sl2(lhs: &[F32], rhs: &[F32]) -> F32 { d2 } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn length(vector: &[F32]) -> F32 { let n = vector.len(); let mut dot = F32::zero(); @@ -173,13 +149,7 @@ pub fn length(vector: &[F32]) -> F32 { dot.sqrt() } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn l2_normalize(vector: &mut [F32]) { let n = vector.len(); let l = length(vector); @@ -188,13 +158,7 @@ pub fn l2_normalize(vector: &mut [F32]) { } } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = 
export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -209,13 +173,7 @@ pub fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) { (xy, x2, y2) } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn xy_x2_y2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> (F32, F32, F32) { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -230,13 +188,7 @@ pub fn xy_x2_y2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> (F32, F32, F32) (xy, x2, y2) } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn dot_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n: usize = lhs.len(); @@ -247,13 +199,7 @@ pub fn dot_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { xy } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn distance_squared_l2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); diff --git a/crates/base/src/vector/veci8.rs b/crates/base/src/vector/veci8.rs index 73810f05d..7577fe649 100644 --- a/crates/base/src/vector/veci8.rs +++ b/crates/base/src/vector/veci8.rs @@ -289,12 +289,7 @@ impl<'a> From<&'a Veci8Owned> for Veci8Borrowed<'a> { } } -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - 
"aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn i8_quantization(vector: &[F32]) -> (Vec, F32, F32) { let min = vector.iter().copied().fold(F32::infinity(), Float::min); let max = vector.iter().copied().fold(F32::neg_infinity(), Float::max); @@ -307,12 +302,7 @@ pub fn i8_quantization(vector: &[F32]) -> (Vec, F32, F32) { (result, alpha, offset) } -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn i8_dequantization(vector: &[I8], alpha: F32, offset: F32) -> Vec { vector .iter() @@ -320,13 +310,7 @@ pub fn i8_dequantization(vector: &[I8], alpha: F32, offset: F32) -> Vec { .collect() } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn i8_precompute(data: &[I8], alpha: F32, offset: F32) -> (F32, F32) { let sum = data.iter().map(|&x| x.to_f32() * alpha).sum(); let l2_norm = data @@ -358,19 +342,14 @@ mod tests_0 { pub fn dot(x: &[I8], y: &[I8]) -> F32 { #[cfg(target_arch = "x86_64")] { - if detect::x86_64::test_avx512vnni() { + if detect::v4_avx512vnni::detect() { return unsafe { dot_i8_avx512vnni(x, y) }; } } dot_i8_fallback(x, y) } -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] fn dot_i8_fallback(x: &[I8], y: &[I8]) -> F32 { // i8 * i8 fall in range of i16. Since our length is less than (2^16 - 1), the result won't overflow. 
let mut sum = 0; @@ -384,11 +363,11 @@ fn dot_i8_fallback(x: &[I8], y: &[I8]) -> F32 { } #[cfg(target_arch = "x86_64")] -#[target_feature(enable = "avx512vnni,avx512bw,avx512f,bmi2")] +#[detect::target_cpu(enable = "v4_avx512vnni")] unsafe fn dot_i8_avx512vnni(x: &[I8], y: &[I8]) -> F32 { use std::arch::x86_64::*; #[inline] - #[target_feature(enable = "avx512vnni,avx512bw,avx512f,bmi2")] + #[detect::target_cpu(enable = "v4_avx512vnni")] pub unsafe fn _mm512_maskz_loadu_epi8(k: __mmask64, mem_addr: *const i8) -> __m512i { let mut dst: __m512i; unsafe { @@ -463,13 +442,7 @@ pub fn cosine_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { dot_xy / (l2_x * l2_y) } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn l2_2<'a>(lhs: Veci8Borrowed<'a>, rhs: &[F32]) -> F32 { let data = lhs.data(); assert_eq!(data.len(), rhs.len()); @@ -482,13 +455,7 @@ pub fn l2_2<'a>(lhs: Veci8Borrowed<'a>, rhs: &[F32]) -> F32 { .sum::() } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] +#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] pub fn dot_2<'a>(lhs: Veci8Borrowed<'a>, rhs: &[F32]) -> F32 { let data = lhs.data(); assert_eq!(data.len(), rhs.len()); diff --git a/crates/c/tests/f16.rs b/crates/c/tests/f16.rs index 3363dec58..837d96068 100644 --- a/crates/c/tests/f16.rs +++ b/crates/c/tests/f16.rs @@ -2,7 +2,7 @@ #[test] fn test_v_f16_cosine() { - detect::initialize(); + detect::init(); const EPSILON: f32 = f16::EPSILON.to_f32_const(); use half::f16; unsafe fn v_f16_cosine(a: *const u16, b: *const u16, n: usize) -> f32 { @@ -22,21 +22,21 @@ fn test_v_f16_cosine() { let a = (0..n).map(|_| rand::random::()).collect::>(); let b = 
(0..n).map(|_| rand::random::()).collect::>(); let r = unsafe { v_f16_cosine(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - if detect::x86_64::detect_avx512fp16() { + if detect::v4_avx512fp16::detect() { println!("detected avx512fp16"); let c = unsafe { c::v_f16_cosine_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) }; assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); } else { println!("detected no avx512fp16, skipped"); } - if detect::x86_64::detect_v4() { + if detect::v4::detect() { println!("detected v4"); let c = unsafe { c::v_f16_cosine_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) }; assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); } else { println!("detected no v4, skipped"); } - if detect::x86_64::detect_v3() { + if detect::v3::detect() { println!("detected v3"); let c = unsafe { c::v_f16_cosine_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) }; assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); @@ -47,7 +47,7 @@ fn test_v_f16_cosine() { #[test] fn test_v_f16_dot() { - detect::initialize(); + detect::init(); const EPSILON: f32 = 1.0f32; use half::f16; unsafe fn v_f16_dot(a: *const u16, b: *const u16, n: usize) -> f32 { @@ -63,21 +63,21 @@ fn test_v_f16_dot() { let a = (0..n).map(|_| rand::random::()).collect::>(); let b = (0..n).map(|_| rand::random::()).collect::>(); let r = unsafe { v_f16_dot(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - if detect::x86_64::detect_avx512fp16() { + if detect::v4_avx512fp16::detect() { println!("detected avx512fp16"); let c = unsafe { c::v_f16_dot_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) }; assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); } else { println!("detected no avx512fp16, skipped"); } - if detect::x86_64::detect_v4() { + if detect::v4::detect() { println!("detected v4"); let c = unsafe { c::v_f16_dot_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) }; assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); } else { println!("detected no v4, skipped"); } - if detect::x86_64::detect_v3() { + if 
detect::v3::detect() { println!("detected v3"); let c = unsafe { c::v_f16_dot_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) }; assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); @@ -88,7 +88,7 @@ fn test_v_f16_dot() { #[test] fn test_v_f16_sl2() { - detect::initialize(); + detect::init(); const EPSILON: f32 = 1.0f32; use half::f16; unsafe fn v_f16_sl2(a: *const u16, b: *const u16, n: usize) -> f32 { @@ -105,21 +105,21 @@ fn test_v_f16_sl2() { let a = (0..n).map(|_| rand::random::()).collect::>(); let b = (0..n).map(|_| rand::random::()).collect::>(); let r = unsafe { v_f16_sl2(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - if detect::x86_64::detect_avx512fp16() { + if detect::v4_avx512fp16::detect() { println!("detected avx512fp16"); let c = unsafe { c::v_f16_sl2_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) }; assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); } else { println!("detected no avx512fp16, skipped"); } - if detect::x86_64::detect_v4() { + if detect::v4::detect() { println!("detected v4"); let c = unsafe { c::v_f16_sl2_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) }; assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); } else { println!("detected no v4, skipped"); } - if detect::x86_64::detect_v3() { + if detect::v3::detect() { println!("detected v3"); let c = unsafe { c::v_f16_sl2_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) }; assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); diff --git a/crates/detect/Cargo.toml b/crates/detect/Cargo.toml index 4cd453d57..d7cc84b0f 100644 --- a/crates/detect/Cargo.toml +++ b/crates/detect/Cargo.toml @@ -7,5 +7,7 @@ edition.workspace = true rustix.workspace = true std_detect = { git = "https://github.com/tensorchord/stdarch", rev = "e50b2f6fa7f8a9a0081c88b1793d8560462d5848" } +detect_macros = { path = "../detect_macros" } + [lints] workspace = true diff --git a/crates/detect/src/lib.rs b/crates/detect/src/lib.rs index 419a3b666..6716aee27 100644 --- a/crates/detect/src/lib.rs +++ 
b/crates/detect/src/lib.rs @@ -1,21 +1 @@ -#[cfg(target_os = "linux")] -pub mod linux; - -#[cfg(target_arch = "x86_64")] -pub mod x86_64; - -pub fn initialize() { - #[cfg(target_os = "linux")] - { - self::linux::ctor_memfd(); - } - #[cfg(target_arch = "x86_64")] - { - self::x86_64::ctor_avx512fp16(); - self::x86_64::ctor_avx512vpopcntdq(); - self::x86_64::ctor_avx512vp2intersect(); - self::x86_64::ctor_v2(); - self::x86_64::ctor_v3(); - self::x86_64::ctor_v4(); - } -} +detect_macros::main!(); diff --git a/crates/detect/src/linux.rs b/crates/detect/src/linux.rs deleted file mode 100644 index e5280bb18..000000000 --- a/crates/detect/src/linux.rs +++ /dev/null @@ -1,21 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; - -static ATOMIC_MEMFD: AtomicBool = AtomicBool::new(false); - -pub fn test_memfd() -> bool { - use rustix::fs::MemfdFlags; - use std::io::ErrorKind; - match rustix::fs::memfd_create(".memfd.VECTORS.SUPPORT", MemfdFlags::empty()) { - Ok(_) => true, - Err(e) if e.kind() == ErrorKind::Unsupported => false, - Err(_) => false, - } -} - -pub fn ctor_memfd() { - ATOMIC_MEMFD.store(test_memfd(), Ordering::Relaxed); -} - -pub fn detect_memfd() -> bool { - ATOMIC_MEMFD.load(Ordering::Relaxed) -} diff --git a/crates/detect/src/x86_64.rs b/crates/detect/src/x86_64.rs deleted file mode 100644 index f8f02faaf..000000000 --- a/crates/detect/src/x86_64.rs +++ /dev/null @@ -1,120 +0,0 @@ -use std::sync::atomic::{AtomicBool, Ordering}; - -static ATOMIC_AVX512FP16: AtomicBool = AtomicBool::new(false); -static ATOMIC_AVX512VPOPCNTDQ: AtomicBool = AtomicBool::new(false); -static ATOMIC_AVX512VP2INTERSECT: AtomicBool = AtomicBool::new(false); - -pub fn test_avx512fp16() -> bool { - std_detect::is_x86_feature_detected!("avx512fp16") && test_v4() -} - -pub fn test_avx512vpopcntdq() -> bool { - std::is_x86_feature_detected!("avx512vpopcntdq") && test_v4() -} - -pub fn test_avx512vp2intersect() -> bool { - std_detect::is_x86_feature_detected!("avx512vp2intersect") && 
test_v4() -} - -pub fn ctor_avx512fp16() { - ATOMIC_AVX512FP16.store(test_avx512fp16(), Ordering::Relaxed); -} - -pub fn ctor_avx512vpopcntdq() { - ATOMIC_AVX512VPOPCNTDQ.store(test_avx512vpopcntdq(), Ordering::Relaxed); -} - -pub fn ctor_avx512vp2intersect() { - ATOMIC_AVX512VP2INTERSECT.store(test_avx512vp2intersect(), Ordering::Relaxed); -} - -pub fn detect_avx512fp16() -> bool { - ATOMIC_AVX512FP16.load(Ordering::Relaxed) -} - -pub fn detect_avx512vpopcntdq() -> bool { - ATOMIC_AVX512VPOPCNTDQ.load(Ordering::Relaxed) -} - -pub fn detect_avx512vp2intersect() -> bool { - ATOMIC_AVX512VP2INTERSECT.load(Ordering::Relaxed) -} - -static ATOMIC_V4: AtomicBool = AtomicBool::new(false); - -pub fn test_v4() -> bool { - std::is_x86_feature_detected!("avx512bw") - && std::is_x86_feature_detected!("avx512cd") - && std::is_x86_feature_detected!("avx512dq") - && std::is_x86_feature_detected!("avx512f") - && std::is_x86_feature_detected!("avx512vl") - && test_v3() -} - -pub fn ctor_v4() { - ATOMIC_V4.store(test_v4(), Ordering::Relaxed); -} - -pub fn detect_v4() -> bool { - ATOMIC_V4.load(Ordering::Relaxed) -} - -static ATOMIC_V3: AtomicBool = AtomicBool::new(false); - -pub fn test_v3() -> bool { - std::is_x86_feature_detected!("avx") - && std::is_x86_feature_detected!("avx2") - && std::is_x86_feature_detected!("bmi1") - && std::is_x86_feature_detected!("bmi2") - && std::is_x86_feature_detected!("f16c") - && std::is_x86_feature_detected!("fma") - && std::is_x86_feature_detected!("lzcnt") - && std::is_x86_feature_detected!("movbe") - && std::is_x86_feature_detected!("xsave") - && test_v2() -} - -pub fn ctor_v3() { - ATOMIC_V3.store(test_v3(), Ordering::Relaxed); -} - -pub fn detect_v3() -> bool { - ATOMIC_V3.load(Ordering::Relaxed) -} - -static ATOMIC_V2: AtomicBool = AtomicBool::new(false); - -pub fn test_v2() -> bool { - std::is_x86_feature_detected!("cmpxchg16b") - && std::is_x86_feature_detected!("fxsr") - && std::is_x86_feature_detected!("popcnt") - && 
std::is_x86_feature_detected!("sse") - && std::is_x86_feature_detected!("sse2") - && std::is_x86_feature_detected!("sse3") - && std::is_x86_feature_detected!("sse4.1") - && std::is_x86_feature_detected!("sse4.2") - && std::is_x86_feature_detected!("ssse3") -} - -pub fn ctor_v2() { - ATOMIC_V2.store(test_v2(), Ordering::Relaxed); -} - -pub fn detect_v2() -> bool { - ATOMIC_V2.load(Ordering::Relaxed) -} - -static ATOMIC_AVX512VNNI: AtomicBool = AtomicBool::new(false); - -/// check if the CPU supports avx512vnni -pub fn test_avx512vnni() -> bool { - std::is_x86_feature_detected!("avx512vnni") && test_v4() -} - -pub fn ctor_vnni() { - ATOMIC_AVX512VNNI.store(test_avx512vnni(), Ordering::Relaxed); -} - -pub fn detect_vnni() -> bool { - ATOMIC_AVX512VNNI.load(Ordering::Relaxed) -} diff --git a/crates/detect/tests/linux.rs b/crates/detect/tests/linux.rs deleted file mode 100644 index 0026ee4fa..000000000 --- a/crates/detect/tests/linux.rs +++ /dev/null @@ -1,7 +0,0 @@ -#![cfg(target_os = "linux")] - -#[test] -fn print() { - detect::initialize(); - assert_eq!(detect::linux::test_memfd(), detect::linux::detect_memfd()); -} diff --git a/crates/detect/tests/x86_64.rs b/crates/detect/tests/x86_64.rs deleted file mode 100644 index becbfd8f2..000000000 --- a/crates/detect/tests/x86_64.rs +++ /dev/null @@ -1,21 +0,0 @@ -#![cfg(target_arch = "x86_64")] - -#[test] -fn print() { - detect::initialize(); - assert_eq!( - detect::x86_64::test_avx512fp16(), - detect::x86_64::detect_avx512fp16() - ); - assert_eq!( - detect::x86_64::test_avx512vpopcntdq(), - detect::x86_64::detect_avx512vpopcntdq() - ); - assert_eq!( - detect::x86_64::test_avx512vp2intersect(), - detect::x86_64::detect_avx512vp2intersect() - ); - assert_eq!(detect::x86_64::test_v4(), detect::x86_64::detect_v4()); - assert_eq!(detect::x86_64::test_v3(), detect::x86_64::detect_v3()); - assert_eq!(detect::x86_64::test_v2(), detect::x86_64::detect_v2()); -} diff --git a/crates/detect_macros/Cargo.toml 
b/crates/detect_macros/Cargo.toml new file mode 100644 index 000000000..7826a3edf --- /dev/null +++ b/crates/detect_macros/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "detect_macros" +version.workspace = true +edition.workspace = true + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = { version = "1.0.79", features = ["proc-macro"] } +quote = "1.0.35" +syn = { version = "2.0.53", default-features = false, features = [ + "clone-impls", + "full", + "parsing", + "printing", + "proc-macro", +] } + +[lints] +workspace = true diff --git a/crates/detect_macros/src/lib.rs b/crates/detect_macros/src/lib.rs new file mode 100644 index 000000000..085e1b169 --- /dev/null +++ b/crates/detect_macros/src/lib.rs @@ -0,0 +1,318 @@ +struct List { + target_cpu: &'static str, + target_arch: &'static str, + target_features: &'static str, +} + +const LIST: &[List] = &[ + List { + target_cpu: "v4", + target_arch: "x86_64", + target_features: + "avx,avx2,avx512bw,avx512cd,avx512dq,avx512f,avx512vl,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,sse,sse2,sse3,sse4.1,sse4.2,ssse3,xsave" + }, + List { + target_cpu: "v3", + target_arch: "x86_64", + target_features: + "avx,avx2,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,sse,sse2,sse3,sse4.1,sse4.2,ssse3,xsave" + }, + List { + target_cpu: "v2", + target_arch: "x86_64", + target_features: "cmpxchg16b,fxsr,popcnt,sse,sse2,sse3,sse4.1,sse4.2,ssse3", + }, + List { + target_cpu: "neon", + target_arch: "aarch64", + target_features: "neon", + }, + List { + target_cpu: "v4_avx512vpopcntdq", + target_arch: "x86_64", + target_features: + "avx512vpopcntdq,avx,avx2,avx512bw,avx512cd,avx512dq,avx512f,avx512vl,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,sse,sse2,sse3,sse4.1,sse4.2,ssse3,xsave", + }, + List { + target_cpu: "v4_avx512fp16", + target_arch: "x86_64", + target_features: + 
"avx512fp16,avx,avx2,avx512bw,avx512cd,avx512dq,avx512f,avx512vl,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,sse,sse2,sse3,sse4.1,sse4.2,ssse3,xsave", + }, + List { + target_cpu: "v4_avx512vnni", + target_arch: "x86_64", + target_features: + "avx512vnni,avx,avx2,avx512bw,avx512cd,avx512dq,avx512f,avx512vl,bmi1,bmi2,cmpxchg16b,f16c,fma,fxsr,lzcnt,movbe,popcnt,sse,sse2,sse3,sse4.1,sse4.2,ssse3,xsave", + }, +]; + +enum MultiversionPort { + Import, + Export, + Hidden, +} + +struct MultiversionVersion { + ident: String, + // Some(false) => import (specialization) + // Some(true) => export + // None => hidden + port: MultiversionPort, +} + +impl syn::parse::Parse for MultiversionVersion { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let ident: syn::Ident = input.parse()?; + let lookahead1 = input.lookahead1(); + if lookahead1.peek(syn::Token![=]) { + let _: syn::Token![=] = input.parse()?; + let p: syn::Ident = input.parse()?; + if p == "import" { + Ok(Self { + ident: ident.to_string(), + port: MultiversionPort::Import, + }) + } else if p == "export" { + Ok(Self { + ident: ident.to_string(), + port: MultiversionPort::Export, + }) + } else if p == "hidden" { + Ok(Self { + ident: ident.to_string(), + port: MultiversionPort::Hidden, + }) + } else { + panic!("unknown port type") + } + } else { + Ok(Self { + ident: ident.to_string(), + port: MultiversionPort::Hidden, + }) + } + } +} + +struct Multiversion { + versions: syn::punctuated::Punctuated, +} + +impl syn::parse::Parse for Multiversion { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(Multiversion { + versions: syn::punctuated::Punctuated::parse_terminated(input)?, + }) + } +} + +#[proc_macro_attribute] +pub fn multiversion( + attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let attr = syn::parse_macro_input!(attr as Multiversion); + let item_fn = syn::parse::(item).expect("not a function item"); + let syn::ItemFn { + attrs, + 
vis, + sig, + block, + } = item_fn; + let name = sig.ident.to_string(); + if sig.constness.is_some() { + panic!("const functions are not supported"); + } + if sig.asyncness.is_some() { + panic!("async functions are not supported"); + } + let generics_params = sig.generics.params.clone(); + for generic_param in generics_params.iter() { + if !matches!(generic_param, syn::GenericParam::Lifetime(_)) { + panic!("generic parameters are not supported"); + } + } + let generics_where = sig.generics.where_clause.clone(); + let inputs = sig.inputs.clone(); + let arguments = { + let mut list = vec![]; + for x in sig.inputs.iter() { + if let syn::FnArg::Typed(y) = x { + if let syn::Pat::Ident(ident) = *y.pat.clone() { + list.push(ident); + } else { + panic!("patterns on parameters are not supported") + } + } else { + panic!("receiver parameters are not supported") + } + } + list + }; + if sig.variadic.is_some() { + panic!("variadic parameters are not supported"); + } + let output = sig.output.clone(); + let mut versions_export = quote::quote! {}; + let mut versions_hidden = quote::quote! {}; + let mut branches = quote::quote! {}; + let mut fallback = false; + for version in attr.versions { + let ident = version.ident.clone(); + let name = syn::Ident::new(&format!("{name}_{ident}"), proc_macro2::Span::mixed_site()); + let port; + let branch; + if fallback { + panic!("fallback version is set"); + } else if ident == "fallback" { + fallback = true; + port = quote::quote! { + unsafe fn #name < #generics_params > (#inputs) #output #generics_where { #block } + }; + branch = quote::quote! 
{ + { + let _multiversion_internal: unsafe fn(#inputs) #output = #name; + CACHE.store(_multiversion_internal as *mut (), core::sync::atomic::Ordering::Relaxed); + unsafe { _multiversion_internal(#(#arguments,)*) } + } + }; + } else { + let target_cpu = ident.clone(); + let t = syn::Ident::new(&target_cpu, proc_macro2::Span::mixed_site()); + let list = LIST + .iter() + .find(|list| list.target_cpu == target_cpu) + .expect("unknown target_cpu"); + let target_arch = list.target_arch; + let target_features = list.target_features; + port = quote::quote! { + #[cfg(any(target_arch = #target_arch, doc))] + #[doc(cfg(target_arch = #target_arch))] + #[target_feature(enable = #target_features)] + unsafe fn #name < #generics_params > (#inputs) #output #generics_where { #block } + }; + branch = quote::quote! { + #[cfg(target_arch = #target_arch)] + if detect::#t::detect() { + let _multiversion_internal: unsafe fn(#inputs) #output = #name; + CACHE.store(_multiversion_internal as *mut (), core::sync::atomic::Ordering::Relaxed); + return unsafe { _multiversion_internal(#(#arguments,)*) }; + } + }; + } + match version.port { + MultiversionPort::Import => (), + MultiversionPort::Export => versions_export.extend(port), + MultiversionPort::Hidden => versions_hidden.extend(port), + } + branches.extend(branch); + } + if !fallback { + panic!("fallback version is not set"); + } + quote::quote! 
{ + #versions_export + #[inline(always)] + #(#attrs)* #vis #sig { + #versions_hidden + static CACHE: core::sync::atomic::AtomicPtr<()> = core::sync::atomic::AtomicPtr::new(core::ptr::null_mut()); + let cache = CACHE.load(core::sync::atomic::Ordering::Relaxed); + if !cache.is_null() { + let f = unsafe { core::mem::transmute::<*mut (), unsafe fn(#inputs) #output>(cache as _) }; + return unsafe { f(#(#arguments,)*) }; + } + #branches + } + } + .into() +} + +struct TargetCpu { + enable: String, +} + +impl syn::parse::Parse for TargetCpu { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let _: syn::Ident = input.parse()?; + let _: syn::Token![=] = input.parse()?; + let enable: syn::LitStr = input.parse()?; + Ok(Self { + enable: enable.value(), + }) + } +} + +#[proc_macro_attribute] +pub fn target_cpu( + attr: proc_macro::TokenStream, + item: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let attr = syn::parse_macro_input!(attr as TargetCpu); + let mut result = quote::quote! {}; + for cpu in attr.enable.split(',') { + let list = LIST + .iter() + .find(|list| list.target_cpu == cpu) + .expect("unknown target_cpu"); + let target_features = list.target_features; + result.extend(quote::quote!(#[target_feature(enable = #target_features)])); + } + result.extend(proc_macro2::TokenStream::from(item)); + result.into() +} + +#[proc_macro] +pub fn main(_: proc_macro::TokenStream) -> proc_macro::TokenStream { + let mut modules = quote::quote! {}; + let mut init = quote::quote! {}; + for x in LIST { + let ident = syn::Ident::new(x.target_cpu, proc_macro2::Span::mixed_site()); + let target_cpu = x.target_cpu; + let list = LIST + .iter() + .find(|list| list.target_cpu == target_cpu) + .expect("unknown target_cpu"); + let target_arch = list.target_arch; + let target_features = list.target_features.split(',').collect::>(); + modules.extend(quote::quote! 
{ + #[cfg(target_arch = #target_arch)] + pub mod #ident { + use std::sync::atomic::{AtomicBool, Ordering}; + + static ATOMIC: AtomicBool = AtomicBool::new(false); + + #[cfg(target_arch = "x86_64")] + pub fn test() -> bool { + true #(&& std_detect::is_x86_feature_detected!(#target_features))* + } + + #[cfg(target_arch = "aarch64")] + pub fn test() -> bool { + true #(&& std_detect::is_aarch64_feature_detected!(#target_features))* + } + + pub(crate) fn init() { + ATOMIC.store(test(), Ordering::Relaxed); + } + + pub fn detect() -> bool { + ATOMIC.load(Ordering::Relaxed) + } + } + }); + init.extend(quote::quote! { + #[cfg(target_arch = #target_arch)] + self::#ident::init(); + }); + } + quote::quote! { + pub use detect_macros::multiversion; + pub use detect_macros::target_cpu; + #modules + pub fn init() { + #init + } + } + .into() +} diff --git a/crates/memfd/src/lib.rs b/crates/memfd/src/lib.rs index 54cdd42e4..73111063f 100644 --- a/crates/memfd/src/lib.rs +++ b/crates/memfd/src/lib.rs @@ -1,32 +1,47 @@ +#![feature(thread_local)] + use std::os::fd::OwnedFd; #[cfg(target_os = "linux")] pub fn memfd_create() -> std::io::Result { - if detect::linux::detect_memfd() { + use std::cell::Cell; + #[thread_local] + static SUPPORT_MEMFD: Cell = Cell::new(true); + if SUPPORT_MEMFD.get() { use rustix::fs::MemfdFlags; - Ok(rustix::fs::memfd_create( + let r = rustix::fs::memfd_create( format!(".memfd.MEMFD.{:x}", std::process::id()), MemfdFlags::empty(), - )?) - } else { - use rustix::fs::Mode; - use rustix::fs::OFlags; - // POSIX fcntl locking do not support shmem, so we use a regular file here. - // reference: https://man7.org/linux/man-pages/man3/fcntl.3p.html - // However, Linux shmem supports fcntl locking. 
- let name = format!( - ".shm.MEMFD.{:x}.{:x}", - std::process::id(), - rand::random::() ); - let fd = rustix::fs::open( - &name, - OFlags::RDWR | OFlags::CREATE | OFlags::EXCL, - Mode::RUSR | Mode::WUSR, - )?; - rustix::fs::unlink(&name)?; - Ok(fd) + match r { + Ok(fd) => { + return Ok(fd); + } + Err(e) if e.kind() == std::io::ErrorKind::Unsupported => { + SUPPORT_MEMFD.set(false); + } + Err(e) => { + return Err(e.into()); + } + } } + use rustix::fs::Mode; + use rustix::fs::OFlags; + // POSIX fcntl locking do not support shmem, so we use a regular file here. + // reference: https://man7.org/linux/man-pages/man3/fcntl.3p.html + // However, Linux shmem supports fcntl locking. + let name = format!( + ".shm.MEMFD.{:x}.{:x}", + std::process::id(), + rand::random::() + ); + let fd = rustix::fs::open( + &name, + OFlags::RDWR | OFlags::CREATE | OFlags::EXCL, + Mode::RUSR | Mode::WUSR, + )?; + rustix::fs::unlink(&name)?; + Ok(fd) } #[cfg(target_os = "macos")] diff --git a/crates/quantization/Cargo.toml b/crates/quantization/Cargo.toml index e2e45b428..9e55343ed 100644 --- a/crates/quantization/Cargo.toml +++ b/crates/quantization/Cargo.toml @@ -4,13 +4,13 @@ version.workspace = true edition.workspace = true [dependencies] -multiversion.workspace = true num-traits.workspace = true rand.workspace = true serde_json.workspace = true base = { path = "../base" } common = { path = "../common" } +detect = { path = "../detect" } elkan_k_means = { path = "../elkan_k_means" } [lints] diff --git a/crates/quantization/src/lib.rs b/crates/quantization/src/lib.rs index 3c05f03ff..56e0b23eb 100644 --- a/crates/quantization/src/lib.rs +++ b/crates/quantization/src/lib.rs @@ -1,3 +1,4 @@ +#![feature(doc_cfg)] #![feature(avx512_target_feature)] pub mod operator; diff --git a/crates/quantization/src/product/operator.rs b/crates/quantization/src/product/operator.rs index 8c01fdbea..7fa162287 100644 --- a/crates/quantization/src/product/operator.rs +++ 
b/crates/quantization/src/product/operator.rs @@ -338,12 +338,7 @@ impl OperatorProductQuantization for SVecf32L2 { impl OperatorProductQuantization for Vecf16Cos { type ProductQuantizationL2 = Vecf16L2; - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance<'a>( dims: u32, ratio: u32, @@ -369,12 +364,7 @@ impl OperatorProductQuantization for Vecf16Cos { F32(1.0) - xy / (x2 * y2).sqrt() } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance2( dims: u32, ratio: u32, @@ -400,12 +390,7 @@ impl OperatorProductQuantization for Vecf16Cos { F32(1.0) - xy / (x2 * y2).sqrt() } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance_with_delta<'a>( dims: u32, ratio: u32, @@ -445,12 +430,7 @@ impl OperatorProductQuantization for Vecf16Cos { impl OperatorProductQuantization for Vecf16Dot { type ProductQuantizationL2 = Vecf16L2; - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance<'a>( dims: u32, ratio: u32, @@ -472,12 +452,7 @@ impl OperatorProductQuantization for Vecf16Dot { xy * (-1.0) } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance2( dims: u32, ratio: u32, @@ -499,12 +474,7 @@ impl OperatorProductQuantization for Vecf16Dot { xy * (-1.0) } - 
#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance_with_delta<'a>( dims: u32, ratio: u32, @@ -540,12 +510,7 @@ impl OperatorProductQuantization for Vecf16Dot { impl OperatorProductQuantization for Vecf16L2 { type ProductQuantizationL2 = Vecf16L2; - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance<'a>( dims: u32, ratio: u32, @@ -566,12 +531,7 @@ impl OperatorProductQuantization for Vecf16L2 { result } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance2( dims: u32, ratio: u32, @@ -592,12 +552,7 @@ impl OperatorProductQuantization for Vecf16L2 { result } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance_with_delta<'a>( dims: u32, ratio: u32, @@ -632,12 +587,7 @@ impl OperatorProductQuantization for Vecf16L2 { impl OperatorProductQuantization for Vecf32Cos { type ProductQuantizationL2 = Vecf32L2; - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance<'a>( dims: u32, ratio: u32, @@ -663,12 +613,7 @@ impl OperatorProductQuantization for Vecf32Cos { F32(1.0) - xy / (x2 * y2).sqrt() } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn 
product_quantization_distance2( dims: u32, ratio: u32, @@ -694,12 +639,7 @@ impl OperatorProductQuantization for Vecf32Cos { F32(1.0) - xy / (x2 * y2).sqrt() } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance_with_delta<'a>( dims: u32, ratio: u32, @@ -739,12 +679,7 @@ impl OperatorProductQuantization for Vecf32Cos { impl OperatorProductQuantization for Vecf32Dot { type ProductQuantizationL2 = Vecf32L2; - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance<'a>( dims: u32, ratio: u32, @@ -766,12 +701,7 @@ impl OperatorProductQuantization for Vecf32Dot { xy * (-1.0) } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance2( dims: u32, ratio: u32, @@ -793,12 +723,7 @@ impl OperatorProductQuantization for Vecf32Dot { xy * (-1.0) } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance_with_delta<'a>( dims: u32, ratio: u32, @@ -834,12 +759,7 @@ impl OperatorProductQuantization for Vecf32Dot { impl OperatorProductQuantization for Vecf32L2 { type ProductQuantizationL2 = Vecf32L2; - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance<'a>( dims: u32, ratio: u32, @@ -860,12 +780,7 @@ impl OperatorProductQuantization for Vecf32L2 { result } - #[multiversion::multiversion(targets( 
- "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance2( dims: u32, ratio: u32, @@ -886,12 +801,7 @@ impl OperatorProductQuantization for Vecf32L2 { result } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn product_quantization_distance_with_delta<'a>( dims: u32, ratio: u32, diff --git a/crates/quantization/src/scalar/operator.rs b/crates/quantization/src/scalar/operator.rs index 9d60415cc..dbdb4bb6c 100644 --- a/crates/quantization/src/scalar/operator.rs +++ b/crates/quantization/src/scalar/operator.rs @@ -175,12 +175,7 @@ impl OperatorScalarQuantization for SVecf32L2 { } impl OperatorScalarQuantization for Vecf16Cos { - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn scalar_quantization_distance<'a>( dims: u16, max: &[F16], @@ -202,12 +197,7 @@ impl OperatorScalarQuantization for Vecf16Cos { F32(1.0) - xy / (x2 * y2).sqrt() } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn scalar_quantization_distance2( dims: u16, max: &[F16], @@ -230,12 +220,7 @@ impl OperatorScalarQuantization for Vecf16Cos { } impl OperatorScalarQuantization for Vecf16Dot { - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn scalar_quantization_distance<'a>( dims: u16, max: &[F16], @@ -253,12 +238,7 @@ impl OperatorScalarQuantization for Vecf16Dot { xy * (-1.0) } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - 
"x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn scalar_quantization_distance2( dims: u16, max: &[F16], @@ -277,12 +257,7 @@ impl OperatorScalarQuantization for Vecf16Dot { } impl OperatorScalarQuantization for Vecf16L2 { - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn scalar_quantization_distance<'a>( dims: u16, max: &[F16], @@ -300,12 +275,7 @@ impl OperatorScalarQuantization for Vecf16L2 { result } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn scalar_quantization_distance2( dims: u16, max: &[F16], @@ -324,12 +294,7 @@ impl OperatorScalarQuantization for Vecf16L2 { } impl OperatorScalarQuantization for Vecf32Cos { - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn scalar_quantization_distance<'a>( dims: u16, max: &[F32], @@ -351,12 +316,7 @@ impl OperatorScalarQuantization for Vecf32Cos { F32(1.0) - xy / (x2 * y2).sqrt() } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn scalar_quantization_distance2( dims: u16, max: &[F32], @@ -379,12 +339,7 @@ impl OperatorScalarQuantization for Vecf32Cos { } impl OperatorScalarQuantization for Vecf32Dot { - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn scalar_quantization_distance<'a>( dims: u16, max: &[F32], @@ -402,12 +357,7 @@ impl OperatorScalarQuantization for Vecf32Dot { xy 
* (-1.0) } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn scalar_quantization_distance2( dims: u16, max: &[F32], @@ -426,12 +376,7 @@ impl OperatorScalarQuantization for Vecf32Dot { } impl OperatorScalarQuantization for Vecf32L2 { - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn scalar_quantization_distance<'a>( dims: u16, max: &[F32], @@ -449,12 +394,7 @@ impl OperatorScalarQuantization for Vecf32L2 { result } - #[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" - ))] + #[detect::multiversion(v4, v3, v2, neon, fallback)] fn scalar_quantization_distance2( dims: u16, max: &[F32], diff --git a/src/datatype/memory_veci8.rs b/src/datatype/memory_veci8.rs index 2afc4bf6e..420641fdb 100644 --- a/src/datatype/memory_veci8.rs +++ b/src/datatype/memory_veci8.rs @@ -80,7 +80,7 @@ impl Veci8Header { } /// return value after dequantization by index - /// since index return &Output, we can't create a new Output and return it as a reference, so we need to use this function to return a new Output directly + /// since `index` return &Output, we can't create a new Output and return it as a reference, so we need to use this function to return a new Output directly #[inline(always)] pub fn index(&self, index: usize) -> F32 { self.data()[index].to_f32() * self.alpha() + self.offset() diff --git a/src/lib.rs b/src/lib.rs index 5d4112703..d85e2b744 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ unsafe extern "C" fn _PG_init() { bad_init(); } unsafe { - detect::initialize(); + detect::init(); self::gucs::init(); self::index::init(); self::ipc::init(); From 19009d08946bedd5cfbff025fb4dd20b0df7c01c Mon Sep 17 00:00:00 2001 From: usamoi 
Date: Mon, 25 Mar 2024 16:45:26 +0800 Subject: [PATCH 02/16] chore: update rust-toolchain Signed-off-by: usamoi --- .../base/src/scalar/{f16.rs => half_f16.rs} | 40 ++++----- crates/base/src/scalar/mod.rs | 4 +- crates/base/src/vector/bvecf32.rs | 75 ---------------- crates/base/src/vector/svecf32.rs | 90 ------------------- crates/base/src/vector/vecf16.rs | 73 ++++++++------- crates/base/src/vector/veci8.rs | 15 ---- crates/quantization/src/product/operator.rs | 72 +++++++-------- rust-toolchain.toml | 2 +- src/bgworker/mod.rs | 4 +- src/datatype/binary.rs | 2 +- src/gucs/mod.rs | 8 +- src/index/am.rs | 4 - src/index/mod.rs | 4 +- src/ipc/mod.rs | 16 ++-- src/lib.rs | 8 +- 15 files changed, 115 insertions(+), 302 deletions(-) rename crates/base/src/scalar/{f16.rs => half_f16.rs} (89%) diff --git a/crates/base/src/scalar/f16.rs b/crates/base/src/scalar/half_f16.rs similarity index 89% rename from crates/base/src/scalar/f16.rs rename to crates/base/src/scalar/half_f16.rs index 005211c10..ca3e1e29e 100644 --- a/crates/base/src/scalar/f16.rs +++ b/crates/base/src/scalar/half_f16.rs @@ -441,14 +441,14 @@ impl Add for F16 { #[inline(always)] fn add(self, rhs: F16) -> F16 { - unsafe { self::intrinsics::fadd_fast(self.0, rhs.0).into() } + unsafe { intrinsics::fadd_fast(self.0, rhs.0).into() } } } impl AddAssign for F16 { #[inline(always)] fn add_assign(&mut self, rhs: F16) { - unsafe { self.0 = self::intrinsics::fadd_fast(self.0, rhs.0) } + unsafe { self.0 = intrinsics::fadd_fast(self.0, rhs.0) } } } @@ -457,14 +457,14 @@ impl Sub for F16 { #[inline(always)] fn sub(self, rhs: F16) -> F16 { - unsafe { self::intrinsics::fsub_fast(self.0, rhs.0).into() } + unsafe { intrinsics::fsub_fast(self.0, rhs.0).into() } } } impl SubAssign for F16 { #[inline(always)] fn sub_assign(&mut self, rhs: F16) { - unsafe { self.0 = self::intrinsics::fsub_fast(self.0, rhs.0) } + unsafe { self.0 = intrinsics::fsub_fast(self.0, rhs.0) } } } @@ -473,14 +473,14 @@ impl Mul for F16 { 
#[inline(always)] fn mul(self, rhs: F16) -> F16 { - unsafe { self::intrinsics::fmul_fast(self.0, rhs.0).into() } + unsafe { intrinsics::fmul_fast(self.0, rhs.0).into() } } } impl MulAssign for F16 { #[inline(always)] fn mul_assign(&mut self, rhs: F16) { - unsafe { self.0 = self::intrinsics::fmul_fast(self.0, rhs.0) } + unsafe { self.0 = intrinsics::fmul_fast(self.0, rhs.0) } } } @@ -489,14 +489,14 @@ impl Div for F16 { #[inline(always)] fn div(self, rhs: F16) -> F16 { - unsafe { self::intrinsics::fdiv_fast(self.0, rhs.0).into() } + unsafe { intrinsics::fdiv_fast(self.0, rhs.0).into() } } } impl DivAssign for F16 { #[inline(always)] fn div_assign(&mut self, rhs: F16) { - unsafe { self.0 = self::intrinsics::fdiv_fast(self.0, rhs.0) } + unsafe { self.0 = intrinsics::fdiv_fast(self.0, rhs.0) } } } @@ -505,14 +505,14 @@ impl Rem for F16 { #[inline(always)] fn rem(self, rhs: F16) -> F16 { - unsafe { self::intrinsics::frem_fast(self.0, rhs.0).into() } + unsafe { intrinsics::frem_fast(self.0, rhs.0).into() } } } impl RemAssign for F16 { #[inline(always)] fn rem_assign(&mut self, rhs: F16) { - unsafe { self.0 = self::intrinsics::frem_fast(self.0, rhs.0) } + unsafe { self.0 = intrinsics::frem_fast(self.0, rhs.0) } } } @@ -549,13 +549,13 @@ impl Add for F16 { #[inline(always)] fn add(self, rhs: f16) -> F16 { - unsafe { self::intrinsics::fadd_fast(self.0, rhs).into() } + unsafe { intrinsics::fadd_fast(self.0, rhs).into() } } } impl AddAssign for F16 { fn add_assign(&mut self, rhs: f16) { - unsafe { self.0 = self::intrinsics::fadd_fast(self.0, rhs) } + unsafe { self.0 = intrinsics::fadd_fast(self.0, rhs) } } } @@ -564,14 +564,14 @@ impl Sub for F16 { #[inline(always)] fn sub(self, rhs: f16) -> F16 { - unsafe { self::intrinsics::fsub_fast(self.0, rhs).into() } + unsafe { intrinsics::fsub_fast(self.0, rhs).into() } } } impl SubAssign for F16 { #[inline(always)] fn sub_assign(&mut self, rhs: f16) { - unsafe { self.0 = self::intrinsics::fsub_fast(self.0, rhs) } + unsafe { self.0 = 
intrinsics::fsub_fast(self.0, rhs) } } } @@ -580,14 +580,14 @@ impl Mul for F16 { #[inline(always)] fn mul(self, rhs: f16) -> F16 { - unsafe { self::intrinsics::fmul_fast(self.0, rhs).into() } + unsafe { intrinsics::fmul_fast(self.0, rhs).into() } } } impl MulAssign for F16 { #[inline(always)] fn mul_assign(&mut self, rhs: f16) { - unsafe { self.0 = self::intrinsics::fmul_fast(self.0, rhs) } + unsafe { self.0 = intrinsics::fmul_fast(self.0, rhs) } } } @@ -596,14 +596,14 @@ impl Div for F16 { #[inline(always)] fn div(self, rhs: f16) -> F16 { - unsafe { self::intrinsics::fdiv_fast(self.0, rhs).into() } + unsafe { intrinsics::fdiv_fast(self.0, rhs).into() } } } impl DivAssign for F16 { #[inline(always)] fn div_assign(&mut self, rhs: f16) { - unsafe { self.0 = self::intrinsics::fdiv_fast(self.0, rhs) } + unsafe { self.0 = intrinsics::fdiv_fast(self.0, rhs) } } } @@ -612,14 +612,14 @@ impl Rem for F16 { #[inline(always)] fn rem(self, rhs: f16) -> F16 { - unsafe { self::intrinsics::frem_fast(self.0, rhs).into() } + unsafe { intrinsics::frem_fast(self.0, rhs).into() } } } impl RemAssign for F16 { #[inline(always)] fn rem_assign(&mut self, rhs: f16) { - unsafe { self.0 = self::intrinsics::frem_fast(self.0, rhs) } + unsafe { self.0 = intrinsics::frem_fast(self.0, rhs) } } } diff --git a/crates/base/src/scalar/mod.rs b/crates/base/src/scalar/mod.rs index 3a3eebaef..f098a2d09 100644 --- a/crates/base/src/scalar/mod.rs +++ b/crates/base/src/scalar/mod.rs @@ -1,9 +1,9 @@ -mod f16; mod f32; +mod half_f16; mod i8; -pub use f16::F16; pub use f32::F32; +pub use half_f16::F16; pub use i8::I8; pub trait ScalarLike: diff --git a/crates/base/src/vector/bvecf32.rs b/crates/base/src/vector/bvecf32.rs index c892e34d9..1bc22b330 100644 --- a/crates/base/src/vector/bvecf32.rs +++ b/crates/base/src/vector/bvecf32.rs @@ -191,21 +191,6 @@ pub fn cosine<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] unsafe fn 
cosine_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { use std::arch::x86_64::*; - #[inline] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { - let mut dst: __m512i; - unsafe { - std::arch::asm!( - "vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]", - p = in(reg) mem_addr, - k = in(kreg) k, - dst = out(zmm_reg) dst, - options(pure, readonly, nostack) - ); - } - dst - } assert_eq!(lhs.len(), rhs.len()); unsafe { const WIDTH: usize = 512 / 8 / std::mem::size_of::(); @@ -269,21 +254,6 @@ pub fn dot<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] unsafe fn dot_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { use std::arch::x86_64::*; - #[inline] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { - let mut dst: __m512i; - unsafe { - std::arch::asm!( - "vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]", - p = in(reg) mem_addr, - k = in(kreg) k, - dst = out(zmm_reg) dst, - options(pure, readonly, nostack) - ); - } - dst - } assert_eq!(lhs.len(), rhs.len()); unsafe { const WIDTH: usize = 512 / 8 / std::mem::size_of::(); @@ -339,21 +309,6 @@ pub fn sl2<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] unsafe fn sl2_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { use std::arch::x86_64::*; - #[inline] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { - let mut dst: __m512i; - unsafe { - std::arch::asm!( - "vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]", - p = in(reg) mem_addr, - k = in(kreg) k, - dst = out(zmm_reg) dst, - options(pure, readonly, nostack) - ); - } - dst - } assert_eq!(lhs.len(), rhs.len()); unsafe { const WIDTH: usize = 512 / 8 / std::mem::size_of::(); @@ -411,21 
+366,6 @@ pub fn jaccard<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] unsafe fn jaccard_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { use std::arch::x86_64::*; - #[inline] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { - let mut dst: __m512i; - unsafe { - std::arch::asm!( - "vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]", - p = in(reg) mem_addr, - k = in(kreg) k, - dst = out(zmm_reg) dst, - options(pure, readonly, nostack) - ); - } - dst - } assert_eq!(lhs.len(), rhs.len()); unsafe { const WIDTH: usize = 512 / 8 / std::mem::size_of::(); @@ -483,21 +423,6 @@ pub fn length(vector: BVecf32Borrowed<'_>) -> F32 { #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] unsafe fn length_avx512vpopcntdq(lhs: &[usize]) -> F32 { use std::arch::x86_64::*; - #[inline] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - pub unsafe fn _mm512_maskz_loadu_epi64(k: __mmask8, mem_addr: *const i8) -> __m512i { - let mut dst: __m512i; - unsafe { - std::arch::asm!( - "vmovdqu64 {dst}{{{k}}} {{z}}, [{p}]", - p = in(reg) mem_addr, - k = in(kreg) k, - dst = out(zmm_reg) dst, - options(pure, readonly, nostack) - ); - } - dst - } unsafe { const WIDTH: usize = 512 / 8 / std::mem::size_of::(); let mut cnt = _mm512_setzero_si512(); diff --git a/crates/base/src/vector/svecf32.rs b/crates/base/src/vector/svecf32.rs index 49bb7cadb..7fabc284d 100644 --- a/crates/base/src/vector/svecf32.rs +++ b/crates/base/src/vector/svecf32.rs @@ -229,36 +229,6 @@ fn cosine_fallback<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F3 unsafe fn cosine_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { use std::arch::x86_64::*; use std::cmp::min; - #[inline] - #[detect::target_cpu(enable = "v4")] - pub unsafe fn _mm512_maskz_loadu_epi32(k: __mmask16, mem_addr: *const i32) -> __m512i { - let mut dst: __m512i; - unsafe { - 
std::arch::asm!( - "vmovdqu32 {dst}{{{k}}} {{z}}, [{p}]", - p = in(reg) mem_addr, - k = in(kreg) k, - dst = out(zmm_reg) dst, - options(pure, readonly, nostack) - ); - } - dst - } - #[inline] - #[detect::target_cpu(enable = "v4")] - pub unsafe fn _mm512_maskz_loadu_ps(k: __mmask16, mem_addr: *const f32) -> __m512 { - let mut dst: __m512; - unsafe { - std::arch::asm!( - "vmovups {dst}{{{k}}} {{z}}, [{p}]", - p = in(reg) mem_addr, - k = in(kreg) k, - dst = out(zmm_reg) dst, - options(pure, readonly, nostack) - ); - } - dst - } unsafe { const W: usize = 16; let mut lhs_pos = 0; @@ -401,36 +371,6 @@ fn dot_fallback<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { unsafe fn dot_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { use std::arch::x86_64::*; use std::cmp::min; - #[inline] - #[detect::target_cpu(enable = "v4")] - pub unsafe fn _mm512_maskz_loadu_epi32(k: __mmask16, mem_addr: *const i32) -> __m512i { - let mut dst: __m512i; - unsafe { - std::arch::asm!( - "vmovdqu32 {dst}{{{k}}} {{z}}, [{p}]", - p = in(reg) mem_addr, - k = in(kreg) k, - dst = out(zmm_reg) dst, - options(pure, readonly, nostack) - ); - } - dst - } - #[inline] - #[detect::target_cpu(enable = "v4")] - pub unsafe fn _mm512_maskz_loadu_ps(k: __mmask16, mem_addr: *const f32) -> __m512 { - let mut dst: __m512; - unsafe { - std::arch::asm!( - "vmovups {dst}{{{k}}} {{z}}, [{p}]", - p = in(reg) mem_addr, - k = in(kreg) k, - dst = out(zmm_reg) dst, - options(pure, readonly, nostack) - ); - } - dst - } unsafe { const W: usize = 16; let mut lhs_pos = 0; @@ -561,36 +501,6 @@ fn sl2_fallback<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { unsafe fn sl2_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { use std::arch::x86_64::*; use std::cmp::min; - #[inline] - #[detect::target_cpu(enable = "v4")] - pub unsafe fn _mm512_maskz_loadu_epi32(k: __mmask16, mem_addr: *const i32) -> __m512i { - let mut dst: __m512i; - unsafe { - std::arch::asm!( - 
"vmovdqu32 {dst}{{{k}}} {{z}}, [{p}]", - p = in(reg) mem_addr, - k = in(kreg) k, - dst = out(zmm_reg) dst, - options(pure, readonly, nostack) - ); - } - dst - } - #[inline] - #[detect::target_cpu(enable = "v4")] - pub unsafe fn _mm512_maskz_loadu_ps(k: __mmask16, mem_addr: *const f32) -> __m512 { - let mut dst: __m512; - unsafe { - std::arch::asm!( - "vmovups {dst}{{{k}}} {{z}}, [{p}]", - p = in(reg) mem_addr, - k = in(kreg) k, - dst = out(zmm_reg) dst, - options(pure, readonly, nostack) - ); - } - dst - } unsafe { const W: usize = 16; let mut lhs_pos = 0; diff --git a/crates/base/src/vector/vecf16.rs b/crates/base/src/vector/vecf16.rs index 6d2389371..6ceaa3b5e 100644 --- a/crates/base/src/vector/vecf16.rs +++ b/crates/base/src/vector/vecf16.rs @@ -101,46 +101,43 @@ impl<'a> VectorBorrowed for Vecf16Borrowed<'a> { } } +#[cfg(any(target_arch = "x86_64", doc))] +#[doc(cfg(target_arch = "x86_64"))] +unsafe fn cosine_v4_avx512fp16(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { c::v_f16_cosine_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } +} + +#[cfg(target_arch = "x86_64")] +#[doc(cfg(target_arch = "x86_64"))] +unsafe fn cosine_v4(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { c::v_f16_cosine_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } +} + +#[cfg(target_arch = "x86_64")] +#[doc(cfg(target_arch = "x86_64"))] +unsafe fn cosine_v3(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { c::v_f16_cosine_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } +} + +#[detect::multiversion(v4_avx512fp16 = import, v4 = import, v3 = import, v2, neon, fallback = export)] pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let mut xy = 
F32::zero(); - let mut x2 = F32::zero(); - let mut y2 = F32::zero(); - for i in 0..n { - xy += lhs[i].to_f() * rhs[i].to_f(); - x2 += lhs[i].to_f() * lhs[i].to_f(); - y2 += rhs[i].to_f() * rhs[i].to_f(); - } - xy / (x2 * y2).sqrt() - } - #[cfg(target_arch = "x86_64")] - if detect::v4_avx512fp16::detect() { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - unsafe { - return c::v_f16_cosine_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); - } - } - #[cfg(target_arch = "x86_64")] - if detect::v4::detect() { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - unsafe { - return c::v_f16_cosine_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); - } - } - #[cfg(target_arch = "x86_64")] - if detect::v3::detect() { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - unsafe { - return c::v_f16_cosine_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); - } + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * rhs[i].to_f(); + x2 += lhs[i].to_f() * lhs[i].to_f(); + y2 += rhs[i].to_f() * rhs[i].to_f(); } - cosine(lhs, rhs) + xy / (x2 * y2).sqrt() } pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { diff --git a/crates/base/src/vector/veci8.rs b/crates/base/src/vector/veci8.rs index 7577fe649..c8470dca6 100644 --- a/crates/base/src/vector/veci8.rs +++ b/crates/base/src/vector/veci8.rs @@ -366,21 +366,6 @@ fn dot_i8_fallback(x: &[I8], y: &[I8]) -> F32 { #[detect::target_cpu(enable = "v4_avx512vnni")] unsafe fn dot_i8_avx512vnni(x: &[I8], y: &[I8]) -> F32 { use std::arch::x86_64::*; - #[inline] - #[detect::target_cpu(enable = "v4_avx512vnni")] - pub unsafe fn _mm512_maskz_loadu_epi8(k: __mmask64, mem_addr: *const i8) -> __m512i { - let mut dst: __m512i; - unsafe { - std::arch::asm!( - "vmovdqu8 {dst}{{{k}}} {{z}}, [{p}]", - p = in(reg) mem_addr, - k = in(kreg) k, - dst = out(zmm_reg) dst, - options(pure, readonly, 
nostack) - ); - } - dst - } assert_eq!(x.len(), y.len()); let mut sum = 0; let mut i = x.len(); diff --git a/crates/quantization/src/product/operator.rs b/crates/quantization/src/product/operator.rs index 7fa162287..de6687373 100644 --- a/crates/quantization/src/product/operator.rs +++ b/crates/quantization/src/product/operator.rs @@ -112,7 +112,7 @@ impl OperatorProductQuantization for BVecf32Dot { } fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf32::sl2(lhs, rhs) + vecf32::sl2(lhs, rhs) } fn product_quantization_dense_distance(_: &[Scalar], _: &[Scalar]) -> F32 { @@ -155,7 +155,7 @@ impl OperatorProductQuantization for BVecf32Jaccard { } fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf32::sl2(lhs, rhs) + vecf32::sl2(lhs, rhs) } fn product_quantization_dense_distance(_: &[Scalar], _: &[Scalar]) -> F32 { @@ -198,7 +198,7 @@ impl OperatorProductQuantization for BVecf32L2 { } fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf32::sl2(lhs, rhs) + vecf32::sl2(lhs, rhs) } fn product_quantization_dense_distance(_: &[Scalar], _: &[Scalar]) -> F32 { @@ -241,7 +241,7 @@ impl OperatorProductQuantization for SVecf32Cos { } fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf32::sl2(lhs, rhs) + vecf32::sl2(lhs, rhs) } fn product_quantization_dense_distance(_: &[Scalar], _: &[Scalar]) -> F32 { @@ -284,7 +284,7 @@ impl OperatorProductQuantization for SVecf32Dot { } fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf32::sl2(lhs, rhs) + vecf32::sl2(lhs, rhs) } fn product_quantization_dense_distance(_: &[Scalar], _: &[Scalar]) -> F32 { @@ -327,7 +327,7 @@ impl OperatorProductQuantization for SVecf32L2 { } fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf32::sl2(lhs, rhs) + vecf32::sl2(lhs, rhs) } fn product_quantization_dense_distance(_: 
&[Scalar], _: &[Scalar]) -> F32 { @@ -356,7 +356,7 @@ impl OperatorProductQuantization for Vecf16Cos { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let (_xy, _x2, _y2) = super::vecf16::xy_x2_y2(lhs, rhs); + let (_xy, _x2, _y2) = vecf16::xy_x2_y2(lhs, rhs); xy += _xy; x2 += _x2; y2 += _y2; @@ -382,7 +382,7 @@ impl OperatorProductQuantization for Vecf16Cos { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let (_xy, _x2, _y2) = super::vecf16::xy_x2_y2(lhs, rhs); + let (_xy, _x2, _y2) = vecf16::xy_x2_y2(lhs, rhs); xy += _xy; x2 += _x2; y2 += _y2; @@ -410,7 +410,7 @@ impl OperatorProductQuantization for Vecf16Cos { let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; let del = &delta[(i * ratio) as usize..][..k as usize]; - let (_xy, _x2, _y2) = super::vecf16::xy_x2_y2_delta(lhs, rhs, del); + let (_xy, _x2, _y2) = vecf16::xy_x2_y2_delta(lhs, rhs, del); xy += _xy; x2 += _x2; y2 += _y2; @@ -419,11 +419,11 @@ impl OperatorProductQuantization for Vecf16Cos { } fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf16::sl2(lhs, rhs) + vecf16::sl2(lhs, rhs) } fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - F32(1.0) - super::vecf16::cosine(lhs, rhs) + F32(1.0) - vecf16::cosine(lhs, rhs) } } @@ -446,7 +446,7 @@ impl OperatorProductQuantization for Vecf16Dot { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = super::vecf16::dot(lhs, rhs); + let _xy = vecf16::dot(lhs, rhs); xy += _xy; } xy * (-1.0) @@ -468,7 +468,7 @@ impl OperatorProductQuantization for 
Vecf16Dot { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = super::vecf16::dot(lhs, rhs); + let _xy = vecf16::dot(lhs, rhs); xy += _xy; } xy * (-1.0) @@ -492,18 +492,18 @@ impl OperatorProductQuantization for Vecf16Dot { let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; let del = &delta[(i * ratio) as usize..][..k as usize]; - let _xy = super::vecf16::dot_delta(lhs, rhs, del); + let _xy = vecf16::dot_delta(lhs, rhs, del); xy += _xy; } xy * (-1.0) } fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf16::sl2(lhs, rhs) + vecf16::sl2(lhs, rhs) } fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf16::dot(lhs, rhs) * (-1.0) + vecf16::dot(lhs, rhs) * (-1.0) } } @@ -526,7 +526,7 @@ impl OperatorProductQuantization for Vecf16L2 { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += super::vecf16::sl2(lhs, rhs); + result += vecf16::sl2(lhs, rhs); } result } @@ -547,7 +547,7 @@ impl OperatorProductQuantization for Vecf16L2 { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += super::vecf16::sl2(lhs, rhs); + result += vecf16::sl2(lhs, rhs); } result } @@ -570,17 +570,17 @@ impl OperatorProductQuantization for Vecf16L2 { let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; let del = &delta[(i * ratio) as usize..][..k as usize]; - result += super::vecf16::distance_squared_l2_delta(lhs, rhs, del); + result += vecf16::distance_squared_l2_delta(lhs, rhs, del); } 
result } fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf16::sl2(lhs, rhs) + vecf16::sl2(lhs, rhs) } fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf16::sl2(lhs, rhs) + vecf16::sl2(lhs, rhs) } } @@ -605,7 +605,7 @@ impl OperatorProductQuantization for Vecf32Cos { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let (_xy, _x2, _y2) = super::vecf32::xy_x2_y2(lhs, rhs); + let (_xy, _x2, _y2) = vecf32::xy_x2_y2(lhs, rhs); xy += _xy; x2 += _x2; y2 += _y2; @@ -631,7 +631,7 @@ impl OperatorProductQuantization for Vecf32Cos { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let (_xy, _x2, _y2) = super::vecf32::xy_x2_y2(lhs, rhs); + let (_xy, _x2, _y2) = vecf32::xy_x2_y2(lhs, rhs); xy += _xy; x2 += _x2; y2 += _y2; @@ -659,7 +659,7 @@ impl OperatorProductQuantization for Vecf32Cos { let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; let del = &delta[(i * ratio) as usize..][..k as usize]; - let (_xy, _x2, _y2) = super::vecf32::xy_x2_y2_delta(lhs, rhs, del); + let (_xy, _x2, _y2) = vecf32::xy_x2_y2_delta(lhs, rhs, del); xy += _xy; x2 += _x2; y2 += _y2; @@ -668,11 +668,11 @@ impl OperatorProductQuantization for Vecf32Cos { } fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf32::sl2(lhs, rhs) + vecf32::sl2(lhs, rhs) } fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - F32(1.0) - super::vecf32::cosine(lhs, rhs) + F32(1.0) - vecf32::cosine(lhs, rhs) } } @@ -695,7 +695,7 @@ impl OperatorProductQuantization for Vecf32Dot { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as 
usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = super::vecf32::dot(lhs, rhs); + let _xy = vecf32::dot(lhs, rhs); xy += _xy; } xy * (-1.0) @@ -717,7 +717,7 @@ impl OperatorProductQuantization for Vecf32Dot { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = super::vecf32::dot(lhs, rhs); + let _xy = vecf32::dot(lhs, rhs); xy += _xy; } xy * (-1.0) @@ -741,18 +741,18 @@ impl OperatorProductQuantization for Vecf32Dot { let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; let del = &delta[(i * ratio) as usize..][..k as usize]; - let _xy = super::vecf32::dot_delta(lhs, rhs, del); + let _xy = vecf32::dot_delta(lhs, rhs, del); xy += _xy; } xy * (-1.0) } fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf32::sl2(lhs, rhs) + vecf32::sl2(lhs, rhs) } fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf32::dot(lhs, rhs) * (-1.0) + vecf32::dot(lhs, rhs) * (-1.0) } } @@ -775,7 +775,7 @@ impl OperatorProductQuantization for Vecf32L2 { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += super::vecf32::sl2(lhs, rhs); + result += vecf32::sl2(lhs, rhs); } result } @@ -796,7 +796,7 @@ impl OperatorProductQuantization for Vecf32L2 { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += super::vecf32::sl2(lhs, rhs); + result += vecf32::sl2(lhs, rhs); } result } @@ -819,17 +819,17 @@ impl OperatorProductQuantization for Vecf32L2 { let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = 
¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; let del = &delta[(i * ratio) as usize..][..k as usize]; - result += super::vecf32::distance_squared_l2_delta(lhs, rhs, del); + result += vecf32::distance_squared_l2_delta(lhs, rhs, del); } result } fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf32::sl2(lhs, rhs) + vecf32::sl2(lhs, rhs) } fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { - super::vecf32::sl2(lhs, rhs) + vecf32::sl2(lhs, rhs) } } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 89c656b5f..6be0e4fa5 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "nightly-2024-03-04" +channel = "nightly-2024-03-24" profile = "default" targets = [ "aarch64-apple-darwin", diff --git a/src/bgworker/mod.rs b/src/bgworker/mod.rs index 6c4a257b6..b36a5c499 100644 --- a/src/bgworker/mod.rs +++ b/src/bgworker/mod.rs @@ -71,10 +71,10 @@ extern "C" fn _vectors_main(_arg: pgrx::pg_sys::Datum) { let path = Path::new("pg_vectors"); if path.try_exists().unwrap() { let worker = Worker::open(path.to_owned()); - self::normal::normal(worker); + normal::normal(worker); } else { let worker = Worker::create(path.to_owned()); Version::write(path.join("VERSION")); - self::normal::normal(worker); + normal::normal(worker); } } diff --git a/src/datatype/binary.rs b/src/datatype/binary.rs index 66337a897..f9cd5d233 100644 --- a/src/datatype/binary.rs +++ b/src/datatype/binary.rs @@ -14,7 +14,7 @@ impl Bytea { impl IntoDatum for Bytea { fn into_datum(self) -> Option { if !self.0.is_null() { - Some(pgrx::pg_sys::Datum::from(self.0)) + Some(Datum::from(self.0)) } else { None } diff --git a/src/gucs/mod.rs b/src/gucs/mod.rs index ba6f5013f..e86d09429 100644 --- a/src/gucs/mod.rs +++ b/src/gucs/mod.rs @@ -9,10 +9,10 @@ pub mod planning; pub unsafe fn init() { unsafe { - self::planning::init(); - self::internal::init(); - self::executing::init(); - 
self::embedding::init(); + planning::init(); + internal::init(); + executing::init(); + embedding::init(); } } diff --git a/src/index/am.rs b/src/index/am.rs index e008efd23..223c6cea0 100644 --- a/src/index/am.rs +++ b/src/index/am.rs @@ -137,8 +137,6 @@ pub unsafe extern "C" fn ambuild( ) -> *mut pgrx::pg_sys::IndexBuildResult { pub struct Builder { pub rpc: ClientRpc, - pub heap: *mut pgrx::pg_sys::RelationData, - pub index_info: *mut pgrx::pg_sys::IndexInfo, pub result: *mut pgrx::pg_sys::IndexBuildResult, } let oid = unsafe { (*index).rd_id }; @@ -155,8 +153,6 @@ pub unsafe extern "C" fn ambuild( let result = unsafe { pgrx::PgBox::::alloc0() }; let mut builder = Builder { rpc, - heap, - index_info, result: result.as_ptr(), }; let table_am = unsafe { &*(*heap).rd_tableam }; diff --git a/src/index/mod.rs b/src/index/mod.rs index 3f196e33c..81d3c7759 100644 --- a/src/index/mod.rs +++ b/src/index/mod.rs @@ -10,7 +10,7 @@ mod views; pub unsafe fn init() { unsafe { - self::hooks::init(); - self::am::init(); + hooks::init(); + am::init(); } } diff --git a/src/ipc/mod.rs b/src/ipc/mod.rs index a746fb574..ab3244044 100644 --- a/src/ipc/mod.rs +++ b/src/ipc/mod.rs @@ -21,29 +21,29 @@ pub enum ConnectionError { pub fn listen_unix() -> impl Iterator { std::iter::from_fn(move || { - let socket = self::transport::ServerSocket::Unix(self::transport::unix::accept()); - Some(self::ServerRpcHandler::new(socket)) + let socket = ServerSocket::Unix(transport::unix::accept()); + Some(ServerRpcHandler::new(socket)) }) } pub fn listen_mmap() -> impl Iterator { std::iter::from_fn(move || { - let socket = self::transport::ServerSocket::Mmap(self::transport::mmap::accept()); - Some(self::ServerRpcHandler::new(socket)) + let socket = ServerSocket::Mmap(transport::mmap::accept()); + Some(ServerRpcHandler::new(socket)) }) } pub fn connect_unix() -> ClientSocket { - self::transport::ClientSocket::Unix(self::transport::unix::connect()) + ClientSocket::Unix(transport::unix::connect()) } pub 
fn connect_mmap() -> ClientSocket { - self::transport::ClientSocket::Mmap(self::transport::mmap::connect()) + ClientSocket::Mmap(transport::mmap::connect()) } pub fn init() { - self::transport::mmap::init(); - self::transport::unix::init(); + transport::mmap::init(); + transport::unix::init(); } impl Drop for ClientRpc { diff --git a/src/lib.rs b/src/lib.rs index d85e2b744..a38ab08da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,10 +28,10 @@ unsafe extern "C" fn _PG_init() { } unsafe { detect::init(); - self::gucs::init(); - self::index::init(); - self::ipc::init(); - self::bgworker::init(); + gucs::init(); + index::init(); + ipc::init(); + bgworker::init(); } } From c930af81a80f693c96b26a4aa8ca22e02d3fe7fb Mon Sep 17 00:00:00 2001 From: usamoi Date: Mon, 25 Mar 2024 17:10:06 +0800 Subject: [PATCH 03/16] chore: use detect::multiversion Signed-off-by: usamoi --- crates/base/src/vector/bvecf32.rs | 440 +++++++++++++----------------- crates/base/src/vector/svecf32.rs | 190 ++++++------- crates/base/src/vector/vecf16.rs | 153 ++++++----- crates/base/src/vector/vecf32.rs | 18 +- crates/base/src/vector/veci8.rs | 52 ++-- 5 files changed, 390 insertions(+), 463 deletions(-) diff --git a/crates/base/src/vector/bvecf32.rs b/crates/base/src/vector/bvecf32.rs index 1bc22b330..467aae270 100644 --- a/crates/base/src/vector/bvecf32.rs +++ b/crates/base/src/vector/bvecf32.rs @@ -164,296 +164,250 @@ impl<'a> PartialOrd for BVecf32Borrowed<'a> { } } -#[inline(always)] -pub fn cosine<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { +#[inline] +#[cfg(any(target_arch = "x86_64", doc))] +#[doc(cfg(target_arch = "x86_64"))] +#[detect::target_cpu(enable = "v4_avx512vpopcntdq")] +unsafe fn cosine_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { + use std::arch::x86_64::*; let lhs = lhs.data(); let rhs = rhs.data(); assert!(lhs.len() == rhs.len()); - - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn cosine(lhs: &[usize], rhs: 
&[usize]) -> F32 { - let mut xy = 0; - let mut xx = 0; - let mut yy = 0; - for i in 0..lhs.len() { - xy += (lhs[i] & rhs[i]).count_ones(); - xx += lhs[i].count_ones(); - yy += rhs[i].count_ones(); + unsafe { + const WIDTH: usize = 512 / 8 / std::mem::size_of::(); + let mut xy = _mm512_setzero_si512(); + let mut xx = _mm512_setzero_si512(); + let mut yy = _mm512_setzero_si512(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut n = lhs.len(); + while n >= WIDTH { + let x = _mm512_loadu_si512(a.cast()); + let y = _mm512_loadu_si512(b.cast()); + a = a.add(WIDTH); + b = b.add(WIDTH); + n -= WIDTH; + xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); + xx = _mm512_add_epi64(xx, _mm512_popcnt_epi64(x)); + yy = _mm512_add_epi64(yy, _mm512_popcnt_epi64(y)); + } + if n > 0 { + let mask = _bzhi_u32(0xFFFF, n as u32) as u8; + let x = _mm512_maskz_loadu_epi64(mask, a.cast()); + let y = _mm512_maskz_loadu_epi64(mask, b.cast()); + xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); + xx = _mm512_add_epi64(xx, _mm512_popcnt_epi64(x)); + yy = _mm512_add_epi64(yy, _mm512_popcnt_epi64(y)); } - let rxy = xy as f32; - let rxx = xx as f32; - let ryy = yy as f32; + let rxy = _mm512_reduce_add_epi64(xy) as f32; + let rxx = _mm512_reduce_add_epi64(xx) as f32; + let ryy = _mm512_reduce_add_epi64(yy) as f32; F32(rxy / (rxx * ryy).sqrt()) } +} - #[inline] - #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - unsafe fn cosine_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { - use std::arch::x86_64::*; - assert_eq!(lhs.len(), rhs.len()); - unsafe { - const WIDTH: usize = 512 / 8 / std::mem::size_of::(); - let mut xy = _mm512_setzero_si512(); - let mut xx = _mm512_setzero_si512(); - let mut yy = _mm512_setzero_si512(); - let mut a = lhs.as_ptr(); - let mut b = rhs.as_ptr(); - let mut n = lhs.len(); - while n >= WIDTH { - let x = _mm512_loadu_si512(a.cast()); - let y = _mm512_loadu_si512(b.cast()); 
- a = a.add(WIDTH); - b = b.add(WIDTH); - n -= WIDTH; - xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); - xx = _mm512_add_epi64(xx, _mm512_popcnt_epi64(x)); - yy = _mm512_add_epi64(yy, _mm512_popcnt_epi64(y)); - } - if n > 0 { - let mask = _bzhi_u32(0xFFFF, n as u32) as u8; - let x = _mm512_maskz_loadu_epi64(mask, a.cast()); - let y = _mm512_maskz_loadu_epi64(mask, b.cast()); - xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); - xx = _mm512_add_epi64(xx, _mm512_popcnt_epi64(x)); - yy = _mm512_add_epi64(yy, _mm512_popcnt_epi64(y)); - } - let rxy = _mm512_reduce_add_epi64(xy) as f32; - let rxx = _mm512_reduce_add_epi64(xx) as f32; - let ryy = _mm512_reduce_add_epi64(yy) as f32; - F32(rxy / (rxx * ryy).sqrt()) - } - } +#[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] +pub fn cosine(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { + let lhs = lhs.data(); + let rhs = rhs.data(); + assert!(lhs.len() == rhs.len()); + let mut xy = 0; + let mut xx = 0; + let mut yy = 0; + for i in 0..lhs.len() { + xy += (lhs[i] & rhs[i]).count_ones(); + xx += lhs[i].count_ones(); + yy += rhs[i].count_ones(); + } + let rxy = xy as f32; + let rxx = xx as f32; + let ryy = yy as f32; + F32(rxy / (rxx * ryy).sqrt()) +} - #[cfg(target_arch = "x86_64")] - if detect::v4_avx512vpopcntdq::detect() { - unsafe { - return cosine_avx512vpopcntdq(lhs, rhs); +#[inline] +#[cfg(any(target_arch = "x86_64", doc))] +#[doc(cfg(target_arch = "x86_64"))] +#[detect::target_cpu(enable = "v4_avx512vpopcntdq")] +unsafe fn dot_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { + use std::arch::x86_64::*; + let lhs = lhs.data(); + let rhs = rhs.data(); + assert!(lhs.len() == rhs.len()); + unsafe { + const WIDTH: usize = 512 / 8 / std::mem::size_of::(); + let mut xy = _mm512_setzero_si512(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut n = lhs.len(); + while n >= WIDTH { + 
let x = _mm512_loadu_si512(a.cast()); + let y = _mm512_loadu_si512(b.cast()); + a = a.add(WIDTH); + b = b.add(WIDTH); + n -= WIDTH; + xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); + } + if n > 0 { + let mask = _bzhi_u32(0xFFFF, n as u32) as u8; + let x = _mm512_maskz_loadu_epi64(mask, a.cast()); + let y = _mm512_maskz_loadu_epi64(mask, b.cast()); + xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); } + let rxy = _mm512_reduce_add_epi64(xy) as f32; + F32(rxy) } - cosine(lhs, rhs) } -#[inline(always)] -pub fn dot<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { +#[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] +pub fn dot(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { let lhs = lhs.data(); let rhs = rhs.data(); assert!(lhs.len() == rhs.len()); - - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn dot(lhs: &[usize], rhs: &[usize]) -> F32 { - let mut xy = 0; - for i in 0..lhs.len() { - xy += (lhs[i] & rhs[i]).count_ones(); - } - F32(xy as f32) + let mut xy = 0; + for i in 0..lhs.len() { + xy += (lhs[i] & rhs[i]).count_ones(); } + F32(xy as f32) +} - #[inline] - #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - unsafe fn dot_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { - use std::arch::x86_64::*; - assert_eq!(lhs.len(), rhs.len()); - unsafe { - const WIDTH: usize = 512 / 8 / std::mem::size_of::(); - let mut xy = _mm512_setzero_si512(); - let mut a = lhs.as_ptr(); - let mut b = rhs.as_ptr(); - let mut n = lhs.len(); - while n >= WIDTH { - let x = _mm512_loadu_si512(a.cast()); - let y = _mm512_loadu_si512(b.cast()); - a = a.add(WIDTH); - b = b.add(WIDTH); - n -= WIDTH; - xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); - } - if n > 0 { - let mask = _bzhi_u32(0xFFFF, n as u32) as u8; - let x = _mm512_maskz_loadu_epi64(mask, a.cast()); - let y = _mm512_maskz_loadu_epi64(mask, 
b.cast()); - xy = _mm512_add_epi64(xy, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); - } - let rxy = _mm512_reduce_add_epi64(xy) as f32; - F32(rxy) +#[inline] +#[cfg(any(target_arch = "x86_64", doc))] +#[doc(cfg(target_arch = "x86_64"))] +#[detect::target_cpu(enable = "v4_avx512vpopcntdq")] +unsafe fn sl2_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { + use std::arch::x86_64::*; + let lhs = lhs.data(); + let rhs = rhs.data(); + assert!(lhs.len() == rhs.len()); + unsafe { + const WIDTH: usize = 512 / 8 / std::mem::size_of::(); + let mut dd = _mm512_setzero_si512(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut n = lhs.len(); + while n >= WIDTH { + let x = _mm512_loadu_si512(a.cast()); + let y = _mm512_loadu_si512(b.cast()); + a = a.add(WIDTH); + b = b.add(WIDTH); + n -= WIDTH; + dd = _mm512_add_epi64(dd, _mm512_popcnt_epi64(_mm512_xor_si512(x, y))); } - } - - #[cfg(target_arch = "x86_64")] - if detect::v4_avx512vpopcntdq::detect() { - unsafe { - return dot_avx512vpopcntdq(lhs, rhs); + if n > 0 { + let mask = _bzhi_u32(0xFFFF, n as u32) as u8; + let x = _mm512_maskz_loadu_epi64(mask, a.cast()); + let y = _mm512_maskz_loadu_epi64(mask, b.cast()); + dd = _mm512_add_epi64(dd, _mm512_popcnt_epi64(_mm512_xor_si512(x, y))); } + let rdd = _mm512_reduce_add_epi64(dd) as f32; + F32(rdd) } - dot(lhs, rhs) } -#[inline(always)] -pub fn sl2<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { +#[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] +pub fn sl2(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { let lhs = lhs.data(); let rhs = rhs.data(); assert!(lhs.len() == rhs.len()); - - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn sl2(lhs: &[usize], rhs: &[usize]) -> F32 { - let mut dd = 0; - for i in 0..lhs.len() { - dd += (lhs[i] ^ rhs[i]).count_ones(); - } - F32(dd as f32) + let mut dd = 0; + for i in 0..lhs.len() { + dd += (lhs[i] ^ rhs[i]).count_ones(); 
} + F32(dd as f32) +} - #[inline] - #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - unsafe fn sl2_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { - use std::arch::x86_64::*; - assert_eq!(lhs.len(), rhs.len()); - unsafe { - const WIDTH: usize = 512 / 8 / std::mem::size_of::(); - let mut dd = _mm512_setzero_si512(); - let mut a = lhs.as_ptr(); - let mut b = rhs.as_ptr(); - let mut n = lhs.len(); - while n >= WIDTH { - let x = _mm512_loadu_si512(a.cast()); - let y = _mm512_loadu_si512(b.cast()); - a = a.add(WIDTH); - b = b.add(WIDTH); - n -= WIDTH; - dd = _mm512_add_epi64(dd, _mm512_popcnt_epi64(_mm512_xor_si512(x, y))); - } - if n > 0 { - let mask = _bzhi_u32(0xFFFF, n as u32) as u8; - let x = _mm512_maskz_loadu_epi64(mask, a.cast()); - let y = _mm512_maskz_loadu_epi64(mask, b.cast()); - dd = _mm512_add_epi64(dd, _mm512_popcnt_epi64(_mm512_xor_si512(x, y))); - } - let rdd = _mm512_reduce_add_epi64(dd) as f32; - F32(rdd) +#[inline] +#[cfg(any(target_arch = "x86_64", doc))] +#[doc(cfg(target_arch = "x86_64"))] +#[detect::target_cpu(enable = "v4_avx512vpopcntdq")] +unsafe fn jaccard_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { + use std::arch::x86_64::*; + let lhs = lhs.data(); + let rhs = rhs.data(); + assert!(lhs.len() == rhs.len()); + unsafe { + const WIDTH: usize = 512 / 8 / std::mem::size_of::(); + let mut inter = _mm512_setzero_si512(); + let mut union = _mm512_setzero_si512(); + let mut a = lhs.as_ptr(); + let mut b = rhs.as_ptr(); + let mut n = lhs.len(); + while n >= WIDTH { + let x = _mm512_loadu_si512(a.cast()); + let y = _mm512_loadu_si512(b.cast()); + a = a.add(WIDTH); + b = b.add(WIDTH); + n -= WIDTH; + inter = _mm512_add_epi64(inter, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); + union = _mm512_add_epi64(union, _mm512_popcnt_epi64(_mm512_or_si512(x, y))); } - } - - #[cfg(target_arch = "x86_64")] - if detect::v4_avx512vpopcntdq::detect() { - unsafe { - return 
sl2_avx512vpopcntdq(lhs, rhs); + if n > 0 { + let mask = _bzhi_u32(0xFFFF, n as u32) as u8; + let x = _mm512_maskz_loadu_epi64(mask, a.cast()); + let y = _mm512_maskz_loadu_epi64(mask, b.cast()); + inter = _mm512_add_epi64(inter, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); + union = _mm512_add_epi64(union, _mm512_popcnt_epi64(_mm512_or_si512(x, y))); } + let rinter = _mm512_reduce_add_epi64(inter) as f32; + let runion = _mm512_reduce_add_epi64(union) as f32; + F32(rinter / runion) } - sl2(lhs, rhs) } -#[inline(always)] -pub fn jaccard<'a>(lhs: BVecf32Borrowed<'a>, rhs: BVecf32Borrowed<'a>) -> F32 { +#[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] +pub fn jaccard(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { let lhs = lhs.data(); let rhs = rhs.data(); assert!(lhs.len() == rhs.len()); - - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn jaccard(lhs: &[usize], rhs: &[usize]) -> F32 { - let mut inter = 0; - let mut union = 0; - for i in 0..lhs.len() { - inter += (lhs[i] & rhs[i]).count_ones(); - union += (lhs[i] | rhs[i]).count_ones(); - } - F32(inter as f32 / union as f32) + let mut inter = 0; + let mut union = 0; + for i in 0..lhs.len() { + inter += (lhs[i] & rhs[i]).count_ones(); + union += (lhs[i] | rhs[i]).count_ones(); } + F32(inter as f32 / union as f32) +} - #[inline] - #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - unsafe fn jaccard_avx512vpopcntdq(lhs: &[usize], rhs: &[usize]) -> F32 { - use std::arch::x86_64::*; - assert_eq!(lhs.len(), rhs.len()); - unsafe { - const WIDTH: usize = 512 / 8 / std::mem::size_of::(); - let mut inter = _mm512_setzero_si512(); - let mut union = _mm512_setzero_si512(); - let mut a = lhs.as_ptr(); - let mut b = rhs.as_ptr(); - let mut n = lhs.len(); - while n >= WIDTH { - let x = _mm512_loadu_si512(a.cast()); - let y = _mm512_loadu_si512(b.cast()); - a = a.add(WIDTH); - b = b.add(WIDTH); - n -= WIDTH; - inter = 
_mm512_add_epi64(inter, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); - union = _mm512_add_epi64(union, _mm512_popcnt_epi64(_mm512_or_si512(x, y))); - } - if n > 0 { - let mask = _bzhi_u32(0xFFFF, n as u32) as u8; - let x = _mm512_maskz_loadu_epi64(mask, a.cast()); - let y = _mm512_maskz_loadu_epi64(mask, b.cast()); - inter = _mm512_add_epi64(inter, _mm512_popcnt_epi64(_mm512_and_si512(x, y))); - union = _mm512_add_epi64(union, _mm512_popcnt_epi64(_mm512_or_si512(x, y))); - } - let rinter = _mm512_reduce_add_epi64(inter) as f32; - let runion = _mm512_reduce_add_epi64(union) as f32; - F32(rinter / runion) +#[inline] +#[cfg(any(target_arch = "x86_64", doc))] +#[doc(cfg(target_arch = "x86_64"))] +#[detect::target_cpu(enable = "v4_avx512vpopcntdq")] +unsafe fn length_v4_avx512vpopcntdq(vector: BVecf32Borrowed<'_>) -> F32 { + use std::arch::x86_64::*; + let lhs = vector.data(); + unsafe { + const WIDTH: usize = 512 / 8 / std::mem::size_of::(); + let mut cnt = _mm512_setzero_si512(); + let mut a = lhs.as_ptr(); + let mut n = lhs.len(); + while n >= WIDTH { + let x = _mm512_loadu_si512(a.cast()); + a = a.add(WIDTH); + n -= WIDTH; + cnt = _mm512_add_epi64(cnt, _mm512_popcnt_epi64(x)); } - } - - #[cfg(target_arch = "x86_64")] - if detect::v4_avx512vpopcntdq::detect() { - unsafe { - return jaccard_avx512vpopcntdq(lhs, rhs); + if n > 0 { + let mask = _bzhi_u32(0xFFFF, n as u32) as u8; + let x = _mm512_maskz_loadu_epi64(mask, a.cast()); + cnt = _mm512_add_epi64(cnt, _mm512_popcnt_epi64(x)); } + let rcnt = _mm512_reduce_add_epi64(cnt) as f32; + F32(rcnt.sqrt()) } - jaccard(lhs, rhs) } -#[inline(always)] +#[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] pub fn length(vector: BVecf32Borrowed<'_>) -> F32 { let vector = vector.data(); - - #[detect::multiversion(v4, v3, v2, neon, fallback)] - pub fn length(vector: &[usize]) -> F32 { - let mut l = 0; - for i in 0..vector.len() { - l += vector[i].count_ones(); - } - F32(l as f32).sqrt() - } - - 
#[inline] - #[cfg(target_arch = "x86_64")] - #[detect::target_cpu(enable = "v4_avx512vpopcntdq")] - unsafe fn length_avx512vpopcntdq(lhs: &[usize]) -> F32 { - use std::arch::x86_64::*; - unsafe { - const WIDTH: usize = 512 / 8 / std::mem::size_of::(); - let mut cnt = _mm512_setzero_si512(); - let mut a = lhs.as_ptr(); - let mut n = lhs.len(); - while n >= WIDTH { - let x = _mm512_loadu_si512(a.cast()); - a = a.add(WIDTH); - n -= WIDTH; - cnt = _mm512_add_epi64(cnt, _mm512_popcnt_epi64(x)); - } - if n > 0 { - let mask = _bzhi_u32(0xFFFF, n as u32) as u8; - let x = _mm512_maskz_loadu_epi64(mask, a.cast()); - cnt = _mm512_add_epi64(cnt, _mm512_popcnt_epi64(x)); - } - let rcnt = _mm512_reduce_add_epi64(cnt) as f32; - F32(rcnt.sqrt()) - } - } - - #[cfg(target_arch = "x86_64")] - if detect::v4_avx512vpopcntdq::detect() { - unsafe { - return length_avx512vpopcntdq(vector); - } + let mut l = 0; + for i in 0..vector.len() { + l += vector[i].count_ones(); } - length(vector) + F32(l as f32).sqrt() } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn l2_normalize<'a>(vector: BVecf32Borrowed<'a>) -> Vecf32Owned { let l = length(vector); Vecf32Owned::new(vector.iter().map(|i| F32(i as u32 as f32) / l).collect()) diff --git a/crates/base/src/vector/svecf32.rs b/crates/base/src/vector/svecf32.rs index 7fabc284d..36b27e90a 100644 --- a/crates/base/src/vector/svecf32.rs +++ b/crates/base/src/vector/svecf32.rs @@ -184,51 +184,14 @@ impl<'a> SVecf32Borrowed<'a> { } } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] -fn cosine_fallback<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { - let mut lhs_pos = 0; - let mut rhs_pos = 0; - let size1 = lhs.len() as usize; - let size2 = rhs.len() as usize; - let mut xy = F32::zero(); - let mut x2 = F32::zero(); - let mut y2 = F32::zero(); - while lhs_pos < size1 && rhs_pos < 
size2 { - let lhs_index = lhs.indexes()[lhs_pos]; - let rhs_index = rhs.indexes()[rhs_pos]; - match lhs_index.cmp(&rhs_index) { - std::cmp::Ordering::Less => { - x2 += lhs.values()[lhs_pos] * lhs.values()[lhs_pos]; - lhs_pos += 1; - } - std::cmp::Ordering::Greater => { - y2 += rhs.values()[rhs_pos] * rhs.values()[rhs_pos]; - rhs_pos += 1; - } - std::cmp::Ordering::Equal => { - xy += lhs.values()[lhs_pos] * rhs.values()[rhs_pos]; - x2 += lhs.values()[lhs_pos] * lhs.values()[lhs_pos]; - y2 += rhs.values()[rhs_pos] * rhs.values()[rhs_pos]; - lhs_pos += 1; - rhs_pos += 1; - } - } - } - for i in lhs_pos..size1 { - x2 += lhs.values()[i] * lhs.values()[i]; - } - for i in rhs_pos..size2 { - y2 += rhs.values()[i] * rhs.values()[i]; - } - xy / (x2 * y2).sqrt() -} - #[inline] -#[cfg(target_arch = "x86_64")] +#[cfg(any(target_arch = "x86_64", doc))] +#[doc(cfg(target_arch = "x86_64"))] #[detect::target_cpu(enable = "v4")] -unsafe fn cosine_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { +unsafe fn cosine_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { use std::arch::x86_64::*; use std::cmp::min; + assert_eq!(lhs.dims(), rhs.dims()); unsafe { const W: usize = 16; let mut lhs_pos = 0; @@ -328,47 +291,51 @@ unsafe fn cosine_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F } } -#[inline(always)] -pub fn cosine<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { +#[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] +pub fn cosine(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { assert_eq!(lhs.dims(), rhs.dims()); - #[cfg(target_arch = "x86_64")] - if detect::v4::detect() { - return unsafe { cosine_v4(lhs, rhs) }; - } - cosine_fallback(lhs, rhs) -} - -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] -fn dot_fallback<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { let mut lhs_pos = 0; let mut rhs_pos = 0; let size1 = lhs.len() 
as usize; let size2 = rhs.len() as usize; let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); while lhs_pos < size1 && rhs_pos < size2 { let lhs_index = lhs.indexes()[lhs_pos]; let rhs_index = rhs.indexes()[rhs_pos]; match lhs_index.cmp(&rhs_index) { std::cmp::Ordering::Less => { + x2 += lhs.values()[lhs_pos] * lhs.values()[lhs_pos]; lhs_pos += 1; } std::cmp::Ordering::Greater => { + y2 += rhs.values()[rhs_pos] * rhs.values()[rhs_pos]; rhs_pos += 1; } std::cmp::Ordering::Equal => { xy += lhs.values()[lhs_pos] * rhs.values()[rhs_pos]; + x2 += lhs.values()[lhs_pos] * lhs.values()[lhs_pos]; + y2 += rhs.values()[rhs_pos] * rhs.values()[rhs_pos]; lhs_pos += 1; rhs_pos += 1; } } } - xy + for i in lhs_pos..size1 { + x2 += lhs.values()[i] * lhs.values()[i]; + } + for i in rhs_pos..size2 { + y2 += rhs.values()[i] * rhs.values()[i]; + } + xy / (x2 * y2).sqrt() } #[inline] -#[cfg(target_arch = "x86_64")] +#[cfg(any(target_arch = "x86_64", doc))] +#[doc(cfg(target_arch = "x86_64"))] #[detect::target_cpu(enable = "v4")] -unsafe fn dot_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { +unsafe fn dot_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { use std::arch::x86_64::*; use std::cmp::min; unsafe { @@ -440,67 +407,51 @@ unsafe fn dot_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 } } -#[inline(always)] -pub fn dot<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { +#[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] +pub fn dot(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { assert_eq!(lhs.dims(), rhs.dims()); - #[cfg(target_arch = "x86_64")] - if detect::v4::detect() { - return unsafe { dot_v4(lhs, rhs) }; - } - dot_fallback(lhs, rhs) -} - -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] -pub fn dot_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 { - let mut xy = F32::zero(); - for i in 0..lhs.len() 
as usize { - xy += lhs.values()[i] * rhs[lhs.indexes()[i] as usize]; - } - xy -} - -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] -fn sl2_fallback<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { let mut lhs_pos = 0; let mut rhs_pos = 0; let size1 = lhs.len() as usize; let size2 = rhs.len() as usize; - let mut d2 = F32::zero(); + let mut xy = F32::zero(); while lhs_pos < size1 && rhs_pos < size2 { let lhs_index = lhs.indexes()[lhs_pos]; let rhs_index = rhs.indexes()[rhs_pos]; match lhs_index.cmp(&rhs_index) { - std::cmp::Ordering::Equal => { - let d = lhs.values()[lhs_pos] - rhs.values()[rhs_pos]; - d2 += d * d; - lhs_pos += 1; - rhs_pos += 1; - } std::cmp::Ordering::Less => { - d2 += lhs.values()[lhs_pos] * lhs.values()[lhs_pos]; lhs_pos += 1; } std::cmp::Ordering::Greater => { - d2 += rhs.values()[rhs_pos] * rhs.values()[rhs_pos]; + rhs_pos += 1; + } + std::cmp::Ordering::Equal => { + xy += lhs.values()[lhs_pos] * rhs.values()[rhs_pos]; + lhs_pos += 1; rhs_pos += 1; } } } - for i in lhs_pos..size1 { - d2 += lhs.values()[i] * lhs.values()[i]; - } - for i in rhs_pos..size2 { - d2 += rhs.values()[i] * rhs.values()[i]; + xy +} + +#[detect::multiversion(v4, v3, v2, neon, fallback)] +pub fn dot_2(lhs: SVecf32Borrowed<'_>, rhs: &[F32]) -> F32 { + let mut xy = F32::zero(); + for i in 0..lhs.len() as usize { + xy += lhs.values()[i] * rhs[lhs.indexes()[i] as usize]; } - d2 + xy } #[inline] -#[cfg(target_arch = "x86_64")] +#[cfg(any(target_arch = "x86_64", doc))] +#[doc(cfg(target_arch = "x86_64"))] #[detect::target_cpu(enable = "v4")] -unsafe fn sl2_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { +unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { use std::arch::x86_64::*; use std::cmp::min; + assert_eq!(lhs.dims(), rhs.dims()); unsafe { const W: usize = 16; let mut lhs_pos = 0; @@ -600,18 +551,45 @@ unsafe fn sl2_v4<'a>(lhs: SVecf32Borrowed<'a>, rhs: 
SVecf32Borrowed<'a>) -> F32 } } -#[inline(always)] -pub fn sl2<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { +#[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] +pub fn sl2(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { assert_eq!(lhs.dims(), rhs.dims()); - #[cfg(target_arch = "x86_64")] - if detect::v4::detect() { - return unsafe { sl2_v4(lhs, rhs) }; + let mut lhs_pos = 0; + let mut rhs_pos = 0; + let size1 = lhs.len() as usize; + let size2 = rhs.len() as usize; + let mut d2 = F32::zero(); + while lhs_pos < size1 && rhs_pos < size2 { + let lhs_index = lhs.indexes()[lhs_pos]; + let rhs_index = rhs.indexes()[rhs_pos]; + match lhs_index.cmp(&rhs_index) { + std::cmp::Ordering::Equal => { + let d = lhs.values()[lhs_pos] - rhs.values()[rhs_pos]; + d2 += d * d; + lhs_pos += 1; + rhs_pos += 1; + } + std::cmp::Ordering::Less => { + d2 += lhs.values()[lhs_pos] * lhs.values()[lhs_pos]; + lhs_pos += 1; + } + std::cmp::Ordering::Greater => { + d2 += rhs.values()[rhs_pos] * rhs.values()[rhs_pos]; + rhs_pos += 1; + } + } + } + for i in lhs_pos..size1 { + d2 += lhs.values()[i] * lhs.values()[i]; } - sl2_fallback(lhs, rhs) + for i in rhs_pos..size2 { + d2 += rhs.values()[i] * rhs.values()[i]; + } + d2 } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] -pub fn sl2_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 { +#[detect::multiversion(v4, v3, v2, neon, fallback)] +pub fn sl2_2(lhs: SVecf32Borrowed<'_>, rhs: &[F32]) -> F32 { let mut d2 = F32::zero(); let mut lhs_pos = 0; let mut rhs_pos = 0; @@ -629,8 +607,8 @@ pub fn sl2_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 { d2 } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] -pub fn length<'a>(vector: SVecf32Borrowed<'a>) -> F32 { +#[detect::multiversion(v4, v3, v2, neon, fallback)] +pub fn length(vector: SVecf32Borrowed<'_>) -> F32 { let mut dot = F32::zero(); for &i in 
vector.values() { dot += i * i; @@ -638,7 +616,7 @@ pub fn length<'a>(vector: SVecf32Borrowed<'a>) -> F32 { dot.sqrt() } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn l2_normalize(vector: &mut SVecf32Owned) { let l = length(vector.for_borrow()); let dims = vector.dims(); @@ -734,7 +712,7 @@ mod tests { fn test_cosine_svector() { let x = random_svector(LHS_SIZE); let y = random_svector(RHS_SIZE); - let cosine_fallback = cosine_fallback(x.for_borrow(), y.for_borrow()); + let cosine_fallback = unsafe { cosine_fallback(x.for_borrow(), y.for_borrow()) }; #[cfg(target_arch = "x86_64")] if detect::v4::detect() { let cosine_v4 = unsafe { cosine_v4(x.for_borrow(), y.for_borrow()) }; @@ -751,7 +729,7 @@ mod tests { fn test_dot_svector() { let x = random_svector(LHS_SIZE); let y = random_svector(RHS_SIZE); - let dot_fallback = dot_fallback(x.for_borrow(), y.for_borrow()); + let dot_fallback = unsafe { dot_fallback(x.for_borrow(), y.for_borrow()) }; #[cfg(target_arch = "x86_64")] if detect::v4::detect() { let dot_v4 = unsafe { dot_v4(x.for_borrow(), y.for_borrow()) }; @@ -768,7 +746,7 @@ mod tests { fn test_sl2_svector() { let x = random_svector(LHS_SIZE); let y = random_svector(RHS_SIZE); - let sl2_fallback = sl2_fallback(x.for_borrow(), y.for_borrow()); + let sl2_fallback = unsafe { sl2_fallback(x.for_borrow(), y.for_borrow()) }; #[cfg(target_arch = "x86_64")] if detect::v4::detect() { let sl2_v4 = unsafe { sl2_v4(x.for_borrow(), y.for_borrow()) }; diff --git a/crates/base/src/vector/vecf16.rs b/crates/base/src/vector/vecf16.rs index 6ceaa3b5e..ca2efbe57 100644 --- a/crates/base/src/vector/vecf16.rs +++ b/crates/base/src/vector/vecf16.rs @@ -101,6 +101,7 @@ impl<'a> VectorBorrowed for Vecf16Borrowed<'a> { } } +#[inline] #[cfg(any(target_arch = "x86_64", doc))] #[doc(cfg(target_arch = "x86_64"))] unsafe fn cosine_v4_avx512fp16(lhs: &[F16], rhs: &[F16]) -> F32 { @@ 
-109,6 +110,7 @@ unsafe fn cosine_v4_avx512fp16(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { c::v_f16_cosine_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } } +#[inline] #[cfg(target_arch = "x86_64")] #[doc(cfg(target_arch = "x86_64"))] unsafe fn cosine_v4(lhs: &[F16], rhs: &[F16]) -> F32 { @@ -117,6 +119,7 @@ unsafe fn cosine_v4(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { c::v_f16_cosine_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } } +#[inline] #[cfg(target_arch = "x86_64")] #[doc(cfg(target_arch = "x86_64"))] unsafe fn cosine_v3(lhs: &[F16], rhs: &[F16]) -> F32 { @@ -140,84 +143,84 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { xy / (x2 * y2).sqrt() } +#[inline] +#[cfg(any(target_arch = "x86_64", doc))] +#[doc(cfg(target_arch = "x86_64"))] +unsafe fn dot_v4_avx512fp16(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { c::v_f16_dot_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[doc(cfg(target_arch = "x86_64"))] +unsafe fn dot_v4(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { c::v_f16_dot_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[doc(cfg(target_arch = "x86_64"))] +unsafe fn dot_v3(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { c::v_f16_dot_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } +} + +#[detect::multiversion(v4_avx512fp16 = import, v4 = import, v3 = import, v2, neon, fallback = export)] pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let mut xy = F32::zero(); - for i in 0..n { - xy += lhs[i].to_f() * rhs[i].to_f(); - } - xy - } - #[cfg(target_arch = "x86_64")] - if 
detect::v4_avx512fp16::detect() { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - unsafe { - return c::v_f16_dot_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); - } - } - #[cfg(target_arch = "x86_64")] - if detect::v4::detect() { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - unsafe { - return c::v_f16_dot_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); - } - } - #[cfg(target_arch = "x86_64")] - if detect::v3::detect() { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - unsafe { - return c::v_f16_dot_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); - } + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + for i in 0..n { + xy += lhs[i].to_f() * rhs[i].to_f(); } - dot(lhs, rhs) + xy } +#[inline] +#[cfg(any(target_arch = "x86_64", doc))] +#[doc(cfg(target_arch = "x86_64"))] +unsafe fn sl2_v4_avx512fp16(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { c::v_f16_sl2_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[doc(cfg(target_arch = "x86_64"))] +unsafe fn sl2_v4(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { c::v_f16_sl2_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } +} + +#[inline] +#[cfg(target_arch = "x86_64")] +#[doc(cfg(target_arch = "x86_64"))] +unsafe fn sl2_v3(lhs: &[F16], rhs: &[F16]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + unsafe { c::v_f16_sl2_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } +} + +#[detect::multiversion(v4_avx512fp16 = import, v4 = import, v3 = import, v2, neon, fallback = export)] pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { - #[detect::multiversion(v4, v3, v2, neon, fallback)] - fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let mut d2 = F32::zero(); - for i in 0..n { - let d = 
lhs[i].to_f() - rhs[i].to_f(); - d2 += d * d; - } - d2 - } - #[cfg(target_arch = "x86_64")] - if detect::v4_avx512fp16::detect() { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - unsafe { - return c::v_f16_sl2_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); - } - } - #[cfg(target_arch = "x86_64")] - if detect::v4::detect() { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - unsafe { - return c::v_f16_sl2_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); - } - } - #[cfg(target_arch = "x86_64")] - if detect::v3::detect() { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - unsafe { - return c::v_f16_sl2_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into(); - } + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = F32::zero(); + for i in 0..n { + let d = lhs[i].to_f() - rhs[i].to_f(); + d2 += d * d; } - sl2(lhs, rhs) + d2 } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] fn length(vector: &[F16]) -> F16 { let n = vector.len(); let mut dot = F16::zero(); @@ -227,7 +230,7 @@ fn length(vector: &[F16]) -> F16 { dot.sqrt() } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn l2_normalize(vector: &mut [F16]) { let n = vector.len(); let l = length(vector); @@ -236,7 +239,7 @@ pub fn l2_normalize(vector: &mut [F16]) { } } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -251,7 +254,7 @@ pub fn xy_x2_y2(lhs: &[F16], rhs: &[F16]) -> (F32, F32, F32) { (xy, x2, y2) } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, 
v3, v2, neon, fallback)] pub fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -266,7 +269,7 @@ pub fn xy_x2_y2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> (F32, F32, F32) (xy, x2, y2) } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n: usize = lhs.len(); @@ -277,7 +280,7 @@ pub fn dot_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { xy } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn distance_squared_l2_delta(lhs: &[F16], rhs: &[F16], del: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); diff --git a/crates/base/src/vector/vecf32.rs b/crates/base/src/vector/vecf32.rs index 1167daf78..68f77dff2 100644 --- a/crates/base/src/vector/vecf32.rs +++ b/crates/base/src/vector/vecf32.rs @@ -101,7 +101,7 @@ impl<'a> VectorBorrowed for Vecf32Borrowed<'a> { } } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -116,7 +116,7 @@ pub fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { xy / (x2 * y2).sqrt() } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -127,7 +127,7 @@ pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 { xy } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, 
fallback)] pub fn sl2(lhs: &[F32], rhs: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -139,7 +139,7 @@ pub fn sl2(lhs: &[F32], rhs: &[F32]) -> F32 { d2 } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn length(vector: &[F32]) -> F32 { let n = vector.len(); let mut dot = F32::zero(); @@ -149,7 +149,7 @@ pub fn length(vector: &[F32]) -> F32 { dot.sqrt() } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn l2_normalize(vector: &mut [F32]) { let n = vector.len(); let l = length(vector); @@ -158,7 +158,7 @@ pub fn l2_normalize(vector: &mut [F32]) { } } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -173,7 +173,7 @@ pub fn xy_x2_y2(lhs: &[F32], rhs: &[F32]) -> (F32, F32, F32) { (xy, x2, y2) } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn xy_x2_y2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> (F32, F32, F32) { assert!(lhs.len() == rhs.len()); let n = lhs.len(); @@ -188,7 +188,7 @@ pub fn xy_x2_y2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> (F32, F32, F32) (xy, x2, y2) } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn dot_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n: usize = lhs.len(); @@ -199,7 +199,7 @@ pub fn dot_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { xy } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = 
export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn distance_squared_l2_delta(lhs: &[F32], rhs: &[F32], del: &[F32]) -> F32 { assert!(lhs.len() == rhs.len()); let n = lhs.len(); diff --git a/crates/base/src/vector/veci8.rs b/crates/base/src/vector/veci8.rs index c8470dca6..9d7184160 100644 --- a/crates/base/src/vector/veci8.rs +++ b/crates/base/src/vector/veci8.rs @@ -289,7 +289,7 @@ impl<'a> From<&'a Veci8Owned> for Veci8Borrowed<'a> { } } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn i8_quantization(vector: &[F32]) -> (Vec, F32, F32) { let min = vector.iter().copied().fold(F32::infinity(), Float::min); let max = vector.iter().copied().fold(F32::neg_infinity(), Float::max); @@ -302,7 +302,7 @@ pub fn i8_quantization(vector: &[F32]) -> (Vec, F32, F32) { (result, alpha, offset) } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn i8_dequantization(vector: &[I8], alpha: F32, offset: F32) -> Vec { vector .iter() @@ -310,7 +310,7 @@ pub fn i8_dequantization(vector: &[I8], alpha: F32, offset: F32) -> Vec { .collect() } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn i8_precompute(data: &[I8], alpha: F32, offset: F32) -> (F32, F32) { let sum = data.iter().map(|&x| x.to_f32() * alpha).sum(); let l2_norm = data @@ -339,32 +339,11 @@ mod tests_0 { } } -pub fn dot(x: &[I8], y: &[I8]) -> F32 { - #[cfg(target_arch = "x86_64")] - { - if detect::v4_avx512vnni::detect() { - return unsafe { dot_i8_avx512vnni(x, y) }; - } - } - dot_i8_fallback(x, y) -} - -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] -fn dot_i8_fallback(x: &[I8], y: &[I8]) -> F32 { - // i8 * i8 fall in 
range of i16. Since our length is less than (2^16 - 1), the result won't overflow. - let mut sum = 0; - assert_eq!(x.len(), y.len()); - let length = x.len(); - // according to https://godbolt.org/z/ff48vW4es, this loop will be autovectorized - for i in 0..length { - sum += (x[i].0 as i16 * y[i].0 as i16) as i32; - } - F32(sum as f32) -} - -#[cfg(target_arch = "x86_64")] +#[inline] +#[cfg(any(target_arch = "x86_64", doc))] +#[doc(cfg(target_arch = "x86_64"))] #[detect::target_cpu(enable = "v4_avx512vnni")] -unsafe fn dot_i8_avx512vnni(x: &[I8], y: &[I8]) -> F32 { +unsafe fn dot_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 { use std::arch::x86_64::*; assert_eq!(x.len(), y.len()); let mut sum = 0; @@ -402,6 +381,19 @@ unsafe fn dot_i8_avx512vnni(x: &[I8], y: &[I8]) -> F32 { F32(sum as f32) } +#[detect::multiversion(v4_avx512vnni = import, v4, v3, v2, neon, fallback = export)] +pub fn dot(x: &[I8], y: &[I8]) -> F32 { + // i8 * i8 fall in range of i16. Since our length is less than (2^16 - 1), the result won't overflow. 
+ let mut sum = 0; + assert_eq!(x.len(), y.len()); + let length = x.len(); + // according to https://godbolt.org/z/ff48vW4es, this loop will be autovectorized + for i in 0..length { + sum += (x[i].0 as i16 * y[i].0 as i16) as i32; + } + F32(sum as f32) +} + pub fn dot_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { // (alpha_x * x[i] + offset_x) * (alpha_y * y[i] + offset_y) // = alpha_x * alpha_y * x[i] * y[i] + alpha_x * offset_y * x[i] + alpha_y * offset_x * y[i] + offset_x * offset_y @@ -427,7 +419,7 @@ pub fn cosine_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { dot_xy / (l2_x * l2_y) } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn l2_2<'a>(lhs: Veci8Borrowed<'a>, rhs: &[F32]) -> F32 { let data = lhs.data(); assert_eq!(data.len(), rhs.len()); @@ -440,7 +432,7 @@ pub fn l2_2<'a>(lhs: Veci8Borrowed<'a>, rhs: &[F32]) -> F32 { .sum::() } -#[detect::multiversion(v4 = export, v3 = export, v2 = export, neon = export, fallback = export)] +#[detect::multiversion(v4, v3, v2, neon, fallback)] pub fn dot_2<'a>(lhs: Veci8Borrowed<'a>, rhs: &[F32]) -> F32 { let data = lhs.data(); assert_eq!(data.len(), rhs.len()); From 539ab2b0139d59ca6dc1f0f0a32702e7a9284b88 Mon Sep 17 00:00:00 2001 From: usamoi Date: Mon, 25 Mar 2024 17:49:48 +0800 Subject: [PATCH 04/16] test: use detect::multiversion Signed-off-by: usamoi --- crates/base/src/vector/svecf32.rs | 139 +++++++++++------------ crates/base/src/vector/vecf16.rs | 180 ++++++++++++++++++++++++++++++ crates/base/src/vector/veci8.rs | 33 +++--- crates/c/tests/f16.rs | 129 --------------------- 4 files changed, 262 insertions(+), 219 deletions(-) delete mode 100644 crates/c/tests/f16.rs diff --git a/crates/base/src/vector/svecf32.rs b/crates/base/src/vector/svecf32.rs index 36b27e90a..f8b04881d 100644 --- a/crates/base/src/vector/svecf32.rs +++ b/crates/base/src/vector/svecf32.rs @@ -291,6 
+291,25 @@ unsafe fn cosine_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn cosine_v4_test() { + const EPSILON: F32 = F32(1e-5); + detect::init(); + if !detect::v4::detect() { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let lhs = random_svector(300); + let rhs = random_svector(350); + let specialized = unsafe { cosine_v4(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { cosine_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] pub fn cosine(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { assert_eq!(lhs.dims(), rhs.dims()); @@ -407,6 +426,25 @@ unsafe fn dot_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn dot_v4_test() { + const EPSILON: F32 = F32(1e-5); + detect::init(); + if !detect::v4::detect() { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let lhs = random_svector(300); + let rhs = random_svector(350); + let specialized = unsafe { dot_v4(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { dot_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] pub fn dot(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { assert_eq!(lhs.dims(), rhs.dims()); @@ -551,6 +589,25 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn sl2_v4_test() { + const EPSILON: F32 = F32(1e-5); + detect::init(); + if !detect::v4::detect() { + println!("test {} ... 
skipped (v4)", module_path!()); + return; + } + let lhs = random_svector(300); + let rhs = random_svector(350); + let specialized = unsafe { sl2_v4(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { sl2_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] pub fn sl2(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { assert_eq!(lhs.dims(), rhs.dims()); @@ -687,75 +744,15 @@ unsafe fn emulate_mm512_2intersect_epi32( } } -#[cfg(target_arch = "x86_64")] -#[cfg(test)] -mod tests { - use super::*; - - const LHS_SIZE: usize = 300; - const RHS_SIZE: usize = 350; - const EPS: F32 = F32(1e-5); - - pub fn random_svector(len: usize) -> SVecf32Owned { - use rand::Rng; - let mut rng = rand::thread_rng(); - let mut indexes: Vec = (0..len).map(|_| rng.gen_range(0..30000)).collect(); - indexes.sort_unstable(); - indexes.dedup(); - let values: Vec = (0..indexes.len()) - .map(|_| F32(rng.gen_range(-1.0..1.0))) - .collect(); - SVecf32Owned::new(30000, indexes, values) - } - - #[test] - fn test_cosine_svector() { - let x = random_svector(LHS_SIZE); - let y = random_svector(RHS_SIZE); - let cosine_fallback = unsafe { cosine_fallback(x.for_borrow(), y.for_borrow()) }; - #[cfg(target_arch = "x86_64")] - if detect::v4::detect() { - let cosine_v4 = unsafe { cosine_v4(x.for_borrow(), y.for_borrow()) }; - assert!( - cosine_fallback - cosine_v4 < EPS, - "cosine_fallback: {}, cosine_v4: {}", - cosine_fallback, - cosine_v4 - ); - } - } - - #[test] - fn test_dot_svector() { - let x = random_svector(LHS_SIZE); - let y = random_svector(RHS_SIZE); - let dot_fallback = unsafe { dot_fallback(x.for_borrow(), y.for_borrow()) }; - #[cfg(target_arch = "x86_64")] - if detect::v4::detect() { - let dot_v4 = unsafe { dot_v4(x.for_borrow(), y.for_borrow()) }; - assert!( - dot_fallback - dot_v4 < EPS, - 
"dot_fallback: {}, dot_v4: {}", - dot_fallback, - dot_v4 - ); - } - } - - #[test] - fn test_sl2_svector() { - let x = random_svector(LHS_SIZE); - let y = random_svector(RHS_SIZE); - let sl2_fallback = unsafe { sl2_fallback(x.for_borrow(), y.for_borrow()) }; - #[cfg(target_arch = "x86_64")] - if detect::v4::detect() { - let sl2_v4 = unsafe { sl2_v4(x.for_borrow(), y.for_borrow()) }; - assert!( - sl2_fallback - sl2_v4 < EPS, - "sl2_fallback: {}, sl2_v4: {}", - sl2_fallback, - sl2_v4 - ); - } - } +#[cfg(all(target_arch = "x86_64", test))] +fn random_svector(len: usize) -> SVecf32Owned { + use rand::Rng; + let mut rng = rand::thread_rng(); + let mut indexes: Vec = (0..len).map(|_| rng.gen_range(0..30000)).collect(); + indexes.sort_unstable(); + indexes.dedup(); + let values: Vec = (0..indexes.len()) + .map(|_| F32(rng.gen_range(-1.0..1.0))) + .collect(); + SVecf32Owned::new(30000, indexes, values) } diff --git a/crates/base/src/vector/vecf16.rs b/crates/base/src/vector/vecf16.rs index ca2efbe57..80b79d6fb 100644 --- a/crates/base/src/vector/vecf16.rs +++ b/crates/base/src/vector/vecf16.rs @@ -110,6 +110,26 @@ unsafe fn cosine_v4_avx512fp16(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { c::v_f16_cosine_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn cosine_v4_avx512fp16_test() { + detect::init(); + if !detect::v4_avx512fp16::detect() { + println!("test {} ... skipped (v4_avx512fp16)", module_path!()); + return; + } + const EPSILON: F32 = F32(half::f16::EPSILON.to_f32_const()); + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { cosine_v4_avx512fp16(&lhs, &rhs) }; + let fallback = unsafe { cosine_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); +} + #[inline] #[cfg(target_arch = "x86_64")] #[doc(cfg(target_arch = "x86_64"))] @@ -119,6 +139,26 @@ unsafe fn cosine_v4(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { c::v_f16_cosine_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn cosine_v4_test() { + detect::init(); + if !detect::v4::detect() { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + const EPSILON: F32 = F32(half::f16::EPSILON.to_f32_const()); + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { cosine_v4(&lhs, &rhs) }; + let fallback = unsafe { cosine_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[inline] #[cfg(target_arch = "x86_64")] #[doc(cfg(target_arch = "x86_64"))] @@ -128,6 +168,26 @@ unsafe fn cosine_v3(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { c::v_f16_cosine_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn cosine_v3_test() { + detect::init(); + if !detect::v3::detect() { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + const EPSILON: F32 = F32(half::f16::EPSILON.to_f32_const()); + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { cosine_v3(&lhs, &rhs) }; + let fallback = unsafe { cosine_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); +} + #[detect::multiversion(v4_avx512fp16 = import, v4 = import, v3 = import, v2, neon, fallback = export)] pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); @@ -152,6 +212,26 @@ unsafe fn dot_v4_avx512fp16(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { c::v_f16_dot_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn dot_v4_avx512fp16_test() { + detect::init(); + if !detect::v4_avx512fp16::detect() { + println!("test {} ... skipped (v4_avx512fp16)", module_path!()); + return; + } + const EPSILON: F32 = F32(1.0); + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { dot_v4_avx512fp16(&lhs, &rhs) }; + let fallback = unsafe { dot_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[inline] #[cfg(target_arch = "x86_64")] #[doc(cfg(target_arch = "x86_64"))] @@ -161,6 +241,26 @@ unsafe fn dot_v4(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { c::v_f16_dot_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn dot_v4_test() { + detect::init(); + if !detect::v4::detect() { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + const EPSILON: F32 = F32(1.0); + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { dot_v4(&lhs, &rhs) }; + let fallback = unsafe { dot_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); +} + #[inline] #[cfg(target_arch = "x86_64")] #[doc(cfg(target_arch = "x86_64"))] @@ -170,6 +270,26 @@ unsafe fn dot_v3(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { c::v_f16_dot_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn dot_v3_test() { + detect::init(); + if !detect::v3::detect() { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + const EPSILON: F32 = F32(1.0); + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { dot_v3(&lhs, &rhs) }; + let fallback = unsafe { dot_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[detect::multiversion(v4_avx512fp16 = import, v4 = import, v3 = import, v2, neon, fallback = export)] pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); @@ -190,6 +310,26 @@ unsafe fn sl2_v4_avx512fp16(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { c::v_f16_sl2_avx512fp16(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn sl2_v4_avx512fp16_test() { + detect::init(); + if !detect::v4_avx512fp16::detect() { + println!("test {} ... skipped (v4_avx512fp16)", module_path!()); + return; + } + const EPSILON: F32 = F32(1.0); + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { sl2_v4_avx512fp16(&lhs, &rhs) }; + let fallback = unsafe { sl2_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); +} + #[inline] #[cfg(target_arch = "x86_64")] #[doc(cfg(target_arch = "x86_64"))] @@ -199,6 +339,26 @@ unsafe fn sl2_v4(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { c::v_f16_sl2_v4(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn sl2_v4_test() { + detect::init(); + if !detect::v4::detect() { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + const EPSILON: F32 = F32(1.0); + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { sl2_v4(&lhs, &rhs) }; + let fallback = unsafe { sl2_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[inline] #[cfg(target_arch = "x86_64")] #[doc(cfg(target_arch = "x86_64"))] @@ -208,6 +368,26 @@ unsafe fn sl2_v3(lhs: &[F16], rhs: &[F16]) -> F32 { unsafe { c::v_f16_sl2_v3(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into() } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn sl2_v3_test() { + detect::init(); + if !detect::v3::detect() { + println!("test {} ... skipped (v3)", module_path!()); + return; + } + const EPSILON: F32 = F32(1.0); + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { sl2_v3(&lhs, &rhs) }; + let fallback = unsafe { sl2_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); +} + #[detect::multiversion(v4_avx512fp16 = import, v4 = import, v3 = import, v2, neon, fallback = export)] pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { assert!(lhs.len() == rhs.len()); diff --git a/crates/base/src/vector/veci8.rs b/crates/base/src/vector/veci8.rs index 9d7184160..871c76af4 100644 --- a/crates/base/src/vector/veci8.rs +++ b/crates/base/src/vector/veci8.rs @@ -321,24 +321,6 @@ pub fn i8_precompute(data: &[I8], alpha: F32, offset: F32) -> (F32, F32) { (sum, l2_norm) } -#[cfg(test)] -mod tests_0 { - use super::*; - - #[test] - fn test_quantization_roundtrip() { - let vector = vec![F32(0.0), F32(1.0), F32(2.0), F32(3.0), F32(4.0)]; - let (result, alpha, offset) = i8_quantization(&vector); - assert_eq!(result, vec![I8(-127), I8(-63), I8(0), I8(63), I8(127)]); - assert_eq!(alpha, F32(4.0 / 254.0)); - assert_eq!(offset, F32(2.0)); - let vector = i8_dequantization(result.as_slice(), alpha, offset); - for (i, x) in vector.iter().enumerate() { - assert!((x.0 - (i as f32)).abs() < 0.05); - } - } -} - #[inline] #[cfg(any(target_arch = "x86_64", doc))] #[doc(cfg(target_arch = "x86_64"))] @@ -443,9 +425,22 @@ pub fn dot_2<'a>(lhs: Veci8Borrowed<'a>, rhs: &[F32]) -> F32 { } #[cfg(test)] -mod tests_1 { +mod tests { use super::*; + #[test] + fn test_quantization_roundtrip() { + let vector = vec![F32(0.0), F32(1.0), F32(2.0), F32(3.0), F32(4.0)]; + let (result, alpha, offset) = i8_quantization(&vector); + assert_eq!(result, vec![I8(-127), I8(-63), I8(0), I8(63), I8(127)]); + assert_eq!(alpha, F32(4.0 / 254.0)); + assert_eq!(offset, F32(2.0)); + let vector = i8_dequantization(result.as_slice(), alpha, offset); + for (i, x) in vector.iter().enumerate() { + assert!((x.0 - (i as f32)).abs() < 0.05); + } + } + fn new_random_vec_f32(size: usize) -> Vec { use rand::Rng; let mut rng = rand::thread_rng(); diff --git a/crates/c/tests/f16.rs b/crates/c/tests/f16.rs deleted file mode 100644 index 837d96068..000000000 --- a/crates/c/tests/f16.rs +++ /dev/null @@ -1,129 
+0,0 @@ -#![cfg(target_arch = "x86_64")] - -#[test] -fn test_v_f16_cosine() { - detect::init(); - const EPSILON: f32 = f16::EPSILON.to_f32_const(); - use half::f16; - unsafe fn v_f16_cosine(a: *const u16, b: *const u16, n: usize) -> f32 { - let mut xy = 0.0f32; - let mut xx = 0.0f32; - let mut yy = 0.0f32; - for i in 0..n { - let x = unsafe { a.add(i).cast::().read() }.to_f32(); - let y = unsafe { b.add(i).cast::().read() }.to_f32(); - xy += x * y; - xx += x * x; - yy += y * y; - } - xy / (xx * yy).sqrt() - } - let n = 4000; - let a = (0..n).map(|_| rand::random::()).collect::>(); - let b = (0..n).map(|_| rand::random::()).collect::>(); - let r = unsafe { v_f16_cosine(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - if detect::v4_avx512fp16::detect() { - println!("detected avx512fp16"); - let c = unsafe { c::v_f16_cosine_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); - } else { - println!("detected no avx512fp16, skipped"); - } - if detect::v4::detect() { - println!("detected v4"); - let c = unsafe { c::v_f16_cosine_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); - } else { - println!("detected no v4, skipped"); - } - if detect::v3::detect() { - println!("detected v3"); - let c = unsafe { c::v_f16_cosine_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); - } else { - println!("detected no v3, skipped"); - } -} - -#[test] -fn test_v_f16_dot() { - detect::init(); - const EPSILON: f32 = 1.0f32; - use half::f16; - unsafe fn v_f16_dot(a: *const u16, b: *const u16, n: usize) -> f32 { - let mut xy = 0.0f32; - for i in 0..n { - let x = unsafe { a.add(i).cast::().read() }.to_f32(); - let y = unsafe { b.add(i).cast::().read() }.to_f32(); - xy += x * y; - } - xy - } - let n = 4000; - let a = (0..n).map(|_| rand::random::()).collect::>(); - let b = (0..n).map(|_| rand::random::()).collect::>(); - let r = 
unsafe { v_f16_dot(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - if detect::v4_avx512fp16::detect() { - println!("detected avx512fp16"); - let c = unsafe { c::v_f16_dot_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); - } else { - println!("detected no avx512fp16, skipped"); - } - if detect::v4::detect() { - println!("detected v4"); - let c = unsafe { c::v_f16_dot_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); - } else { - println!("detected no v4, skipped"); - } - if detect::v3::detect() { - println!("detected v3"); - let c = unsafe { c::v_f16_dot_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); - } else { - println!("detected no v3, skipped"); - } -} - -#[test] -fn test_v_f16_sl2() { - detect::init(); - const EPSILON: f32 = 1.0f32; - use half::f16; - unsafe fn v_f16_sl2(a: *const u16, b: *const u16, n: usize) -> f32 { - let mut dd = 0.0f32; - for i in 0..n { - let x = unsafe { a.add(i).cast::().read() }.to_f32(); - let y = unsafe { b.add(i).cast::().read() }.to_f32(); - let d = x - y; - dd += d * d; - } - dd - } - let n = 4000; - let a = (0..n).map(|_| rand::random::()).collect::>(); - let b = (0..n).map(|_| rand::random::()).collect::>(); - let r = unsafe { v_f16_sl2(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - if detect::v4_avx512fp16::detect() { - println!("detected avx512fp16"); - let c = unsafe { c::v_f16_sl2_avx512fp16(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); - } else { - println!("detected no avx512fp16, skipped"); - } - if detect::v4::detect() { - println!("detected v4"); - let c = unsafe { c::v_f16_sl2_v4(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); - } else { - println!("detected no v4, skipped"); - } - if detect::v3::detect() { - println!("detected v3"); - let c = 
unsafe { c::v_f16_sl2_v3(a.as_ptr().cast(), b.as_ptr().cast(), n) }; - assert!((c - r).abs() < EPSILON, "c = {c}, r = {r}."); - } else { - println!("detected no v3, skipped"); - } -} From f096c968b3b1e0722bd250fe2984f39800b9af7c Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 26 Mar 2024 16:24:21 +0800 Subject: [PATCH 05/16] ci: add sde test Signed-off-by: usamoi --- .github/workflows/rust_check.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/rust_check.yml b/.github/workflows/rust_check.yml index fb7ffabef..04650ea4a 100644 --- a/.github/workflows/rust_check.yml +++ b/.github/workflows/rust_check.yml @@ -103,6 +103,13 @@ jobs: run: cargo build --no-default-features --features "pg$VERSION" --target $ARCH-unknown-linux-gnu - name: Test run: cargo test --all --no-fail-fast --no-default-features --features "pg$VERSION" --target $ARCH-unknown-linux-gnu -- --nocapture + - name: Test (x86_64) + if: matrix.arch == 'x86_64' + run: | + ASSETS=$(mktemp -d) + wget https://downloadmirror.intel.com/813591/sde-external-9.33.0-2024-01-07-lin.tar.xz -O $ASSETS/sde-external.tar.xz + tar -xf $ASSETS/sde-external.tar.xz -C $ASSETS + cargo --config "target.x86_64-unknown-linux-gnu.runner = [\"$ASSETS/sde-external-9.33.0-2024-01-07-lin/sde64\", \"-spr\", \"--\"]" test "_v4" --all --no-fail-fast --no-default-features --features "pg$VERSION" --target $ARCH-unknown-linux-gnu -- --nocapture - name: Post Set up Cache uses: actions/cache/save@v4 if: ${{ !steps.cache.outputs.cache-hit }} From c6001e705aad9b146e19104cfefe9307c380b83c Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 26 Mar 2024 17:03:50 +0800 Subject: [PATCH 06/16] test: add dot_internal_v4_avx512vnni_test Signed-off-by: usamoi --- crates/base/src/operator/veci8_cos.rs | 2 +- crates/base/src/operator/veci8_dot.rs | 2 +- crates/base/src/operator/veci8_l2.rs | 2 +- crates/base/src/vector/veci8.rs | 45 +++++++++++++++++++-------- 4 files changed, 35 insertions(+), 16 deletions(-) diff --git 
a/crates/base/src/operator/veci8_cos.rs b/crates/base/src/operator/veci8_cos.rs index a882749a3..2cccf7107 100644 --- a/crates/base/src/operator/veci8_cos.rs +++ b/crates/base/src/operator/veci8_cos.rs @@ -12,6 +12,6 @@ impl Operator for Veci8Cos { const DISTANCE_KIND: DistanceKind = DistanceKind::Cos; fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 { - F32(1.0) - veci8::cosine_distance(&lhs, &rhs) + F32(1.0) - veci8::cosine(&lhs, &rhs) } } diff --git a/crates/base/src/operator/veci8_dot.rs b/crates/base/src/operator/veci8_dot.rs index b066d7749..f6c5dd1df 100644 --- a/crates/base/src/operator/veci8_dot.rs +++ b/crates/base/src/operator/veci8_dot.rs @@ -12,6 +12,6 @@ impl Operator for Veci8Dot { const DISTANCE_KIND: DistanceKind = DistanceKind::Dot; fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 { - veci8::dot_distance(&lhs, &rhs) * (-1.0) + veci8::dot(&lhs, &rhs) * (-1.0) } } diff --git a/crates/base/src/operator/veci8_l2.rs b/crates/base/src/operator/veci8_l2.rs index bde92d8ee..f856b0be2 100644 --- a/crates/base/src/operator/veci8_l2.rs +++ b/crates/base/src/operator/veci8_l2.rs @@ -12,6 +12,6 @@ impl Operator for Veci8L2 { const DISTANCE_KIND: DistanceKind = DistanceKind::Dot; fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 { - veci8::l2_distance(&lhs, &rhs) + veci8::sl2(&lhs, &rhs) } } diff --git a/crates/base/src/vector/veci8.rs b/crates/base/src/vector/veci8.rs index 871c76af4..84e580c9f 100644 --- a/crates/base/src/vector/veci8.rs +++ b/crates/base/src/vector/veci8.rs @@ -325,7 +325,7 @@ pub fn i8_precompute(data: &[I8], alpha: F32, offset: F32) -> (F32, F32) { #[cfg(any(target_arch = "x86_64", doc))] #[doc(cfg(target_arch = "x86_64"))] #[detect::target_cpu(enable = "v4_avx512vnni")] -unsafe fn dot_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 { +unsafe fn dot_internal_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 { use std::arch::x86_64::*; assert_eq!(x.len(), y.len()); let mut sum = 0; @@ -363,8 
+363,27 @@ unsafe fn dot_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 { F32(sum as f32) } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn dot_internal_v4_avx512vnni_test() { + const EPSILON: F32 = F32(4.0); + detect::init(); + if !detect::v4_avx512vnni::detect() { + println!("test {} ... skipped (v4_avx512vnni)", module_path!()); + return; + } + let lhs = std::array::from_fn::<_, 400, _>(|_| I8(rand::random())); + let rhs = std::array::from_fn::<_, 400, _>(|_| I8(rand::random())); + let specialized = unsafe { dot_internal_v4_avx512vnni(&lhs, &rhs) }; + let fallback = unsafe { dot_internal_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[detect::multiversion(v4_avx512vnni = import, v4, v3, v2, neon, fallback = export)] -pub fn dot(x: &[I8], y: &[I8]) -> F32 { +fn dot_internal(x: &[I8], y: &[I8]) -> F32 { // i8 * i8 fall in range of i16. Since our length is less than (2^16 - 1), the result won't overflow. 
let mut sum = 0; assert_eq!(x.len(), y.len()); @@ -376,26 +395,26 @@ pub fn dot(x: &[I8], y: &[I8]) -> F32 { F32(sum as f32) } -pub fn dot_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { +pub fn dot(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { // (alpha_x * x[i] + offset_x) * (alpha_y * y[i] + offset_y) // = alpha_x * alpha_y * x[i] * y[i] + alpha_x * offset_y * x[i] + alpha_y * offset_x * y[i] + offset_x * offset_y // Sum(dot(origin_x[i] , origin_y[i])) = alpha_x * alpha_y * Sum(dot(x[i], y[i])) + offset_y * Sum(alpha_x * x[i]) + offset_x * Sum(alpha_y * y[i]) + offset_x * offset_y * dims - let dot_xy = dot(x.data(), y.data()); + let dot_xy = dot_internal(x.data(), y.data()); x.alpha() * y.alpha() * dot_xy + x.offset() * y.sum() + y.offset() * x.sum() + x.offset() * y.offset() * F32(x.dims() as f32) } -pub fn l2_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { +pub fn sl2(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { // Sum(l2(origin_x[i] - origin_y[i])) = sum(x[i] ^ 2 - 2 * x[i] * y[i] + y[i] ^ 2) // = dot(x, x) - 2 * dot(x, y) + dot(y, y) - x.l2_norm() * x.l2_norm() - F32(2.0) * dot_distance(x, y) + y.l2_norm() * y.l2_norm() + x.l2_norm() * x.l2_norm() - F32(2.0) * dot(x, y) + y.l2_norm() * y.l2_norm() } -pub fn cosine_distance(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { +pub fn cosine(x: &Veci8Borrowed<'_>, y: &Veci8Borrowed<'_>) -> F32 { // dot(x, y) / (l2(x) * l2(y)) - let dot_xy = dot_distance(x, y); + let dot_xy = dot(x, y); let l2_x = x.l2_norm(); let l2_y = y.l2_norm(); dot_xy / (l2_x * l2_y) @@ -462,7 +481,7 @@ mod tests { let ref_x = x_owned.for_borrow(); let y_owned = vec_to_owned(y); let ref_y = y_owned.for_borrow(); - let result = dot_distance(&ref_x, &ref_y); + let result = dot(&ref_x, &ref_y); assert!((result.0 - 10.0).abs() < 0.1); } @@ -474,7 +493,7 @@ mod tests { let ref_x = x_owned.for_borrow(); let y_owned = vec_to_owned(y); let ref_y = y_owned.for_borrow(); - let result = 
cosine_distance(&ref_x, &ref_y); + let result = cosine(&ref_x, &ref_y); assert!((result.0 - (10.0 / 14.0)).abs() < 0.1); // test cos_i8 using random generated data, check the precision let x = new_random_vec_f32(1000); @@ -487,7 +506,7 @@ mod tests { let ref_x = x_owned.for_borrow(); let y_owned = vec_to_owned(y); let ref_y = y_owned.for_borrow(); - let result = cosine_distance(&ref_x, &ref_y); + let result = cosine(&ref_x, &ref_y); assert!( result_expected < 0.01 || (result.0 - result_expected).abs() < 0.01 @@ -503,7 +522,7 @@ mod tests { let ref_x = x_owned.for_borrow(); let y_owned = vec_to_owned(y); let ref_y = y_owned.for_borrow(); - let result = l2_distance(&ref_x, &ref_y); + let result = sl2(&ref_x, &ref_y); assert!((result.0 - 8.0).abs() < 0.1); // test l2_i8 using random generated data, check the precision let x = new_random_vec_f32(1000); @@ -518,7 +537,7 @@ mod tests { let ref_x = x_owned.for_borrow(); let y_owned = vec_to_owned(y); let ref_y = y_owned.for_borrow(); - let result = l2_distance(&ref_x, &ref_y); + let result = sl2(&ref_x, &ref_y); assert!( result_expected < 1.0 || (result.0 - result_expected).abs() / result_expected < 0.05 ); From 470915bdd708a57bbbbdc5cf06b1f29506d83ba2 Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 26 Mar 2024 17:31:43 +0800 Subject: [PATCH 07/16] test: bvector tests Signed-off-by: usamoi --- crates/base/src/vector/bvecf32.rs | 146 ++++++++++++++++++++++++++++++ crates/base/src/vector/veci8.rs | 3 +- 2 files changed, 148 insertions(+), 1 deletion(-) diff --git a/crates/base/src/vector/bvecf32.rs b/crates/base/src/vector/bvecf32.rs index 467aae270..ec11bf9b5 100644 --- a/crates/base/src/vector/bvecf32.rs +++ b/crates/base/src/vector/bvecf32.rs @@ -13,6 +13,35 @@ pub struct BVecf32Owned { } impl BVecf32Owned { + #[inline(always)] + pub fn new(dims: u16, data: Vec) -> Self { + Self::new_checked(dims, data).unwrap() + } + #[inline(always)] + pub fn new_checked(dims: u16, data: Vec) -> Option { + if dims == 0 { + return 
None; + } + if data.len() != (dims as usize).div_ceil(BVEC_WIDTH) { + return None; + } + if dims % BVEC_WIDTH as u16 != 0 && data[data.len() - 1] >> (dims % BVEC_WIDTH as u16) != 0 + { + return None; + } + unsafe { Some(Self::new_unchecked(dims, data)) } + } + /// # Safety + /// + /// * `dims` must be in `1..=65535`. + /// * `data` must be of the correct length. + /// * The padding bits must be zero. + #[inline(always)] + pub unsafe fn new_unchecked(dims: u16, data: Vec) -> Self { + Self { dims, data } + } + + #[inline(always)] pub fn new_zeroed(dims: u16) -> Self { assert!((1..=65535).contains(&dims)); let size = (dims as usize).div_ceil(BVEC_WIDTH); @@ -22,6 +51,7 @@ impl BVecf32Owned { } } + #[inline(always)] pub fn set(&mut self, index: usize, value: bool) { assert!(index < self.dims as usize); if value { @@ -206,6 +236,35 @@ unsafe fn cosine_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrow } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn cosine_v4_avx512vpopcntdq_test() { + const EPSILON: F32 = F32(1e-5); + detect::init(); + if !detect::v4_avx512vpopcntdq::detect() { + println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); + return; + } + let lhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let rhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let specialized = unsafe { cosine_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { cosine_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); +} + #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] pub fn cosine(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { let lhs = lhs.data(); @@ -259,6 +318,35 @@ unsafe fn dot_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed< } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn dot_v4_avx512vpopcntdq_test() { + const EPSILON: F32 = F32(1e-5); + detect::init(); + if !detect::v4_avx512vpopcntdq::detect() { + println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); + return; + } + let lhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let rhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let specialized = unsafe { dot_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { dot_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] pub fn dot(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { let lhs = lhs.data(); @@ -305,6 +393,35 @@ unsafe fn sl2_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed< } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn sl2_v4_avx512vpopcntdq_test() { + const EPSILON: F32 = F32(1e-5); + detect::init(); + if !detect::v4_avx512vpopcntdq::detect() { + println!("test {} ... 
skipped (v4_avx512vpopcntdq)", module_path!()); + return; + } + let lhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let rhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let specialized = unsafe { sl2_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { sl2_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); +} + #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] pub fn sl2(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { let lhs = lhs.data(); @@ -355,6 +472,35 @@ unsafe fn jaccard_v4_avx512vpopcntdq(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borro } } +#[cfg(all(target_arch = "x86_64", test))] +#[test] +fn jaccard_v4_avx512vpopcntdq_test() { + const EPSILON: F32 = F32(1e-5); + detect::init(); + if !detect::v4_avx512vpopcntdq::detect() { + println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); + return; + } + let lhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let rhs = { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) + }; + let specialized = unsafe { jaccard_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { jaccard_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); +} + #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] pub fn jaccard(lhs: BVecf32Borrowed<'_>, rhs: BVecf32Borrowed<'_>) -> F32 { let lhs = lhs.data(); diff --git a/crates/base/src/vector/veci8.rs b/crates/base/src/vector/veci8.rs index 84e580c9f..94b065167 100644 --- a/crates/base/src/vector/veci8.rs +++ b/crates/base/src/vector/veci8.rs @@ -366,7 +366,8 @@ unsafe fn dot_internal_v4_avx512vnni(x: &[I8], y: &[I8]) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn dot_internal_v4_avx512vnni_test() { - const EPSILON: F32 = F32(4.0); + // A large epsilon is set for loss of precision caused by saturation arithmetic + const EPSILON: F32 = F32(512.0); detect::init(); if !detect::v4_avx512vnni::detect() { println!("test {} ... skipped (v4_avx512vnni)", module_path!()); From f4998e4396b86cee430c3c061c35811f4a8dab59 Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 26 Mar 2024 17:58:30 +0800 Subject: [PATCH 08/16] chore: add comments for detect Signed-off-by: usamoi --- crates/detect/src/lib.rs | 48 +++++++++++++++++++++++++++++++++ crates/detect_macros/src/lib.rs | 2 -- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/crates/detect/src/lib.rs b/crates/detect/src/lib.rs index 6716aee27..c930d1a50 100644 --- a/crates/detect/src/lib.rs +++ b/crates/detect/src/lib.rs @@ -1 +1,49 @@ +/// Function multiversioning attribute macros for `pgvecto.rs`. +/// +/// ```no_run +/// #![feature(doc_cfg)] +/// +/// #[cfg(any(target_arch = "x86_64", doc))] +/// #[doc(cfg(target_arch = "x86_64"))] +/// #[detect::target_cpu(enable = "v3")] +/// unsafe fn g_v3(x: &[u32]) -> u32 { +/// todo!() +/// } +/// +/// #[cfg(all(target_arch = "x86_64", test))] +/// #[test] +/// fn g_v3_test() { +/// const EPSILON: F32 = F32(1e-5); +/// detect::init(); +/// if !detect::v3::detect() { +/// println!("test {} ... 
skipped (v3)", module_path!()); +/// return; +/// } +/// let x = vec![0u32; 400]; +/// x.fill_with(|| rand::random()); +/// let specialized = unsafe { g_v3(&x) }; +/// let fallback = unsafe { g_fallback(&x) }; +/// assert!( +/// (specialized - fallback).abs() < EPSILON, +/// "specialized = {specialized}, fallback = {fallback}." +/// ); +/// } +/// +/// // It generates x86_64/v2, aarch64/neon and fallback versions of this function. +/// // It takes advantage of `g_v3` as the x86_64/v3 version of this function. +/// // It exposes the fallback version with the name "g_fallback". +/// #[detect::multiversion(v3 = import, v2, neon, fallback = export)] +/// fn g(x: &[u32]) -> u32 { +/// let mut result = 0_u32; +/// for v in x { +/// result = result.wrapping_add(*v); +/// } +/// result +/// } +/// ``` +pub use detect_macros::multiversion; + +/// This macro allows you to enable a set of features by target CPU name. +pub use detect_macros::target_cpu; + detect_macros::main!(); diff --git a/crates/detect_macros/src/lib.rs b/crates/detect_macros/src/lib.rs index 085e1b169..32c8c2667 100644 --- a/crates/detect_macros/src/lib.rs +++ b/crates/detect_macros/src/lib.rs @@ -307,8 +307,6 @@ pub fn main(_: proc_macro::TokenStream) -> proc_macro::TokenStream { }); } quote::quote! 
{ - pub use detect_macros::multiversion; - pub use detect_macros::target_cpu; #modules pub fn init() { #init From 67572c75fa2670e545d0a33f48e9503625a0b151 Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 26 Mar 2024 18:22:46 +0800 Subject: [PATCH 09/16] ci: do not run rust test 3 times Signed-off-by: usamoi --- .../workflows/{psql_check.yml => psql.yml} | 4 +- .../workflows/{rust_check.yml => rust.yml} | 65 ++++++++++++++++++- 2 files changed, 65 insertions(+), 4 deletions(-) rename .github/workflows/{psql_check.yml => psql.yml} (99%) rename .github/workflows/{rust_check.yml => rust.yml} (62%) diff --git a/.github/workflows/psql_check.yml b/.github/workflows/psql.yml similarity index 99% rename from .github/workflows/psql_check.yml rename to .github/workflows/psql.yml index 4451c64e0..5258b54fc 100644 --- a/.github/workflows/psql_check.yml +++ b/.github/workflows/psql.yml @@ -1,4 +1,4 @@ -name: PostgreSQL check +name: PostgreSQL on: push: @@ -44,7 +44,7 @@ env: RUSTFLAGS: "-Dwarnings" jobs: - check: + test: strategy: matrix: version: [14, 15, 16] diff --git a/.github/workflows/rust_check.yml b/.github/workflows/rust.yml similarity index 62% rename from .github/workflows/rust_check.yml rename to .github/workflows/rust.yml index 04650ea4a..4fcb0fb86 100644 --- a/.github/workflows/rust_check.yml +++ b/.github/workflows/rust.yml @@ -1,4 +1,4 @@ -name: Rust check +name: Rust on: push: @@ -101,6 +101,67 @@ jobs: run: cargo clippy --no-default-features --features "pg$VERSION" --target $ARCH-unknown-linux-gnu - name: Build run: cargo build --no-default-features --features "pg$VERSION" --target $ARCH-unknown-linux-gnu + - name: Post Set up Cache + uses: actions/cache/save@v4 + if: ${{ !steps.cache.outputs.cache-hit }} + with: + path: | + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + key: ${{ github.job }}-${{ hashFiles('./Cargo.lock') }}-${{ matrix.version }}-${{ matrix.arch }} + test: + strategy: + matrix: + arch: ["x86_64", "aarch64"] + runs-on: 
ubuntu-latest + env: + SEMVER: "0.0.0" + VERSION: "16" + ARCH: ${{ matrix.arch }} + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Set up Environment + run: | + sudo apt-get remove -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' + sudo apt-get purge -y '^postgres.*' '^libpq.*' '^clang.*' '^llvm.*' '^libclang.*' '^libllvm.*' '^mono-llvm.*' + sudo apt-get update + sudo apt-get install -y build-essential crossbuild-essential-arm64 + sudo apt-get install -y qemu-user-static + touch ~/.cargo/config.toml + echo 'target.aarch64-unknown-linux-gnu.linker = "aarch64-linux-gnu-gcc"' >> ~/.cargo/config.toml + echo 'target.aarch64-unknown-linux-gnu.runner = ["qemu-aarch64-static", "-L", "/usr/aarch64-linux-gnu"]' >> ~/.cargo/config.toml + - name: Set up Sccache + uses: mozilla-actions/sccache-action@v0.0.4 + - name: Set up Cache + uses: actions/cache/restore@v4 + id: cache + with: + path: | + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + key: ${{ github.job }}-${{ hashFiles('./Cargo.lock') }}-${{ matrix.arch }} + - name: Set up Clang-16 + run: | + sudo sh -c 'echo "deb http://apt.llvm.org/$(lsb_release -cs)/ llvm-toolchain-$(lsb_release -cs)-16 main" >> /etc/apt/sources.list' + wget --quiet -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - + sudo apt-get update + sudo apt-get install -y clang-16 + - name: Set up Pgrx + run: | + # pg_config + mkdir -p ~/.pg_config + touch ~/.pg_config/pg_config + chmod 777 ~/.pg_config/pg_config + echo "#!/usr/bin/env bash" >> ~/.pg_config/pg_config + echo "$(pwd)/tools/pg_config.sh \"\$@\" < $(pwd)/vendor/pg_config/pg${VERSION}_${ARCH}-unknown-linux-gnu.txt" >> ~/.pg_config/pg_config + mkdir -p ~/.pgrx && echo "configs.pg$VERSION=\"$HOME/.pg_config/pg_config\"" > ~/.pgrx/config.toml + # pgrx_binding + mkdir -p ~/.pgrx_binding + cp ./vendor/pgrx_binding/pg${VERSION}_$(uname --machine)-unknown-linux-gnu.rs 
~/.pgrx_binding/pg${VERSION}_raw_bindings.rs + echo PGRX_TARGET_INFO_PATH_PG$VERSION=$HOME/.pgrx_binding >> "$GITHUB_ENV" - name: Test run: cargo test --all --no-fail-fast --no-default-features --features "pg$VERSION" --target $ARCH-unknown-linux-gnu -- --nocapture - name: Test (x86_64) @@ -118,4 +179,4 @@ jobs: ~/.cargo/registry/index/ ~/.cargo/registry/cache/ ~/.cargo/git/db/ - key: ${{ github.job }}-${{ hashFiles('./Cargo.lock') }}-${{ matrix.version }}-${{ matrix.arch }} + key: ${{ github.job }}-${{ hashFiles('./Cargo.lock') }}-${{ matrix.arch }} From f66651c7ada036afdd1245bba43a9e3cc9337ebf Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 26 Mar 2024 21:43:25 +0800 Subject: [PATCH 10/16] fix: svector sl2_v4 Signed-off-by: usamoi --- crates/base/src/vector/svecf32.rs | 62 ++++++++++++++++--------------- 1 file changed, 33 insertions(+), 29 deletions(-) diff --git a/crates/base/src/vector/svecf32.rs b/crates/base/src/vector/svecf32.rs index f8b04881d..0ea81b6d2 100644 --- a/crates/base/src/vector/svecf32.rs +++ b/crates/base/src/vector/svecf32.rs @@ -511,10 +511,8 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { let v_r = _mm512_loadu_ps(rhs_val.add(rhs_pos)); let v_l = _mm512_maskz_compress_ps(m_l, v_l); let v_r = _mm512_maskz_compress_ps(m_r, v_r); - let d = _mm512_sub_ps(v_l, v_r); - dd = _mm512_fmadd_ps(d, d, dd); - dd = _mm512_fmsub_ps(v_l, v_l, dd); - dd = _mm512_fmsub_ps(v_r, v_r, dd); + dd = _mm512_fmadd_ps(v_l, v_r, dd); + dd = _mm512_fmadd_ps(v_l, v_r, dd); let l_max = lhs.indexes().get_unchecked(lhs_pos + W - 1); let r_max = rhs.indexes().get_unchecked(rhs_pos + W - 1); match l_max.cmp(r_max) { @@ -542,10 +540,8 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { let v_r = _mm512_maskz_loadu_ps(mask_r, rhs_val.add(rhs_pos)); let v_l = _mm512_maskz_compress_ps(m_l, v_l); let v_r = _mm512_maskz_compress_ps(m_r, v_r); - let d = _mm512_sub_ps(v_l, v_r); - dd = _mm512_fmadd_ps(d, d, dd); - dd = 
_mm512_fmsub_ps(v_l, v_l, dd); - dd = _mm512_fmsub_ps(v_r, v_r, dd); + dd = _mm512_fmadd_ps(v_l, v_r, dd); + dd = _mm512_fmadd_ps(v_l, v_r, dd); let l_max = lhs.indexes().get_unchecked(lhs_pos + len_l - 1); let r_max = rhs.indexes().get_unchecked(rhs_pos + len_r - 1); match l_max.cmp(r_max) { @@ -561,7 +557,7 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { } } } - + dd = _mm512_sub_ps(_mm512_setzero_ps(), dd); let mut lhs_pos = 0; while lhs_pos < lhs_size { let v = _mm512_loadu_ps(lhs_val.add(lhs_pos)); @@ -592,20 +588,25 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn sl2_v4_test() { - const EPSILON: F32 = F32(1e-5); - detect::init(); - if !detect::v4::detect() { - println!("test {} ... skipped (v4)", module_path!()); - return; + let mut m = F32(0.0); + for _ in 0..10000 { + const EPSILON: F32 = F32(1e-3); + detect::init(); + if !detect::v4::detect() { + println!("test {} ... skipped (v4)", module_path!()); + return; + } + let lhs = random_svector(300); + let rhs = random_svector(350); + let specialized = unsafe { sl2_v4(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { sl2_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + m = std::cmp::max(m, (specialized - fallback).abs()); } - let lhs = random_svector(300); - let rhs = random_svector(350); - let specialized = unsafe { sl2_v4(lhs.for_borrow(), rhs.for_borrow()) }; - let fallback = unsafe { sl2_fallback(lhs.for_borrow(), rhs.for_borrow()) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." 
- ); + dbg!(m); } #[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] @@ -694,7 +695,7 @@ pub fn l2_normalize(vector: &mut SVecf32Owned) { unsafe fn emulate_mm512_2intersect_epi32( a: std::arch::x86_64::__m512i, b: std::arch::x86_64::__m512i, -) -> (u16, u16) { +) -> (std::arch::x86_64::__mmask16, std::arch::x86_64::__mmask16) { use std::arch::x86_64::*; unsafe { let a1 = _mm512_alignr_epi32(a, a, 4); @@ -748,11 +749,14 @@ unsafe fn emulate_mm512_2intersect_epi32( fn random_svector(len: usize) -> SVecf32Owned { use rand::Rng; let mut rng = rand::thread_rng(); - let mut indexes: Vec = (0..len).map(|_| rng.gen_range(0..30000)).collect(); - indexes.sort_unstable(); - indexes.dedup(); - let values: Vec = (0..indexes.len()) - .map(|_| F32(rng.gen_range(-1.0..1.0))) - .collect(); + let mut indexes = rand::seq::index::sample(&mut rand::thread_rng(), 30000, len) + .into_iter() + .map(|x| x as _) + .collect::>(); + indexes.sort(); + let values: Vec = std::iter::from_fn(|| Some(F32(rng.gen_range(-1.0..1.0)))) + .filter(|x| !x.is_zero()) + .take(indexes.len()) + .collect::>(); SVecf32Owned::new(30000, indexes, values) } From 5a2bdf504f6fa6a8452ba746811170cfd63ac52d Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 26 Mar 2024 21:52:25 +0800 Subject: [PATCH 11/16] test: run svector tests 10000 times Signed-off-by: usamoi --- crates/base/src/vector/svecf32.rs | 58 +++++++++++++++++-------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/crates/base/src/vector/svecf32.rs b/crates/base/src/vector/svecf32.rs index 0ea81b6d2..3b7b644b7 100644 --- a/crates/base/src/vector/svecf32.rs +++ b/crates/base/src/vector/svecf32.rs @@ -294,20 +294,25 @@ unsafe fn cosine_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn cosine_v4_test() { - const EPSILON: F32 = F32(1e-5); + const EPSILON: F32 = F32(5e-7); detect::init(); if !detect::v4::detect() { println!("test {} ... 
skipped (v4)", module_path!()); return; } - let lhs = random_svector(300); - let rhs = random_svector(350); - let specialized = unsafe { cosine_v4(lhs.for_borrow(), rhs.for_borrow()) }; - let fallback = unsafe { cosine_fallback(lhs.for_borrow(), rhs.for_borrow()) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); + let mut eps = F32(0.0); + for _ in 0..10000 { + let lhs = random_svector(300); + let rhs = random_svector(350); + let specialized = unsafe { cosine_v4(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { cosine_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + eps = std::cmp::max(eps, (specialized - fallback).abs()); + } + dbg!(eps); } #[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] @@ -429,20 +434,22 @@ unsafe fn dot_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn dot_v4_test() { - const EPSILON: F32 = F32(1e-5); + const EPSILON: F32 = F32(1e-6); detect::init(); if !detect::v4::detect() { println!("test {} ... skipped (v4)", module_path!()); return; } - let lhs = random_svector(300); - let rhs = random_svector(350); - let specialized = unsafe { dot_v4(lhs.for_borrow(), rhs.for_borrow()) }; - let fallback = unsafe { dot_fallback(lhs.for_borrow(), rhs.for_borrow()) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); + for _ in 0..10000 { + let lhs = random_svector(300); + let rhs = random_svector(350); + let specialized = unsafe { dot_v4(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { dot_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } } #[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] @@ -588,14 +595,13 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn sl2_v4_test() { - let mut m = F32(0.0); + const EPSILON: F32 = F32(5e-4); + detect::init(); + if !detect::v4::detect() { + println!("test {} ... skipped (v4)", module_path!()); + return; + } for _ in 0..10000 { - const EPSILON: F32 = F32(1e-3); - detect::init(); - if !detect::v4::detect() { - println!("test {} ... skipped (v4)", module_path!()); - return; - } let lhs = random_svector(300); let rhs = random_svector(350); let specialized = unsafe { sl2_v4(lhs.for_borrow(), rhs.for_borrow()) }; @@ -604,9 +610,7 @@ fn sl2_v4_test() { (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." ); - m = std::cmp::max(m, (specialized - fallback).abs()); } - dbg!(m); } #[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] From 0e6552ff3df52e3c0cab07c9d06754596e79b896 Mon Sep 17 00:00:00 2001 From: Mingzhuo Yin Date: Tue, 26 Mar 2024 22:03:33 +0800 Subject: [PATCH 12/16] fix: svecf32_sl2_v4 Signed-off-by: Mingzhuo Yin --- crates/base/src/vector/svecf32.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/crates/base/src/vector/svecf32.rs b/crates/base/src/vector/svecf32.rs index 3b7b644b7..63274c04c 100644 --- a/crates/base/src/vector/svecf32.rs +++ b/crates/base/src/vector/svecf32.rs @@ -518,8 +518,10 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { let v_r = _mm512_loadu_ps(rhs_val.add(rhs_pos)); let v_l = _mm512_maskz_compress_ps(m_l, v_l); let v_r = _mm512_maskz_compress_ps(m_r, v_r); - dd = _mm512_fmadd_ps(v_l, v_r, dd); - dd = _mm512_fmadd_ps(v_l, v_r, dd); + let d = _mm512_sub_ps(v_l, v_r); + dd = _mm512_fmadd_ps(d, d, dd); + dd = _mm512_sub_ps(dd, _mm512_mul_ps(v_l, v_l)); + dd = _mm512_sub_ps(dd, _mm512_mul_ps(v_r, v_r)); 
let l_max = lhs.indexes().get_unchecked(lhs_pos + W - 1); let r_max = rhs.indexes().get_unchecked(rhs_pos + W - 1); match l_max.cmp(r_max) { @@ -547,8 +549,10 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { let v_r = _mm512_maskz_loadu_ps(mask_r, rhs_val.add(rhs_pos)); let v_l = _mm512_maskz_compress_ps(m_l, v_l); let v_r = _mm512_maskz_compress_ps(m_r, v_r); - dd = _mm512_fmadd_ps(v_l, v_r, dd); - dd = _mm512_fmadd_ps(v_l, v_r, dd); + let d = _mm512_sub_ps(v_l, v_r); + dd = _mm512_fmadd_ps(d, d, dd); + dd = _mm512_sub_ps(dd, _mm512_mul_ps(v_l, v_l)); + dd = _mm512_sub_ps(dd, _mm512_mul_ps(v_r, v_r)); let l_max = lhs.indexes().get_unchecked(lhs_pos + len_l - 1); let r_max = rhs.indexes().get_unchecked(rhs_pos + len_r - 1); match l_max.cmp(r_max) { @@ -564,7 +568,6 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { } } } - dd = _mm512_sub_ps(_mm512_setzero_ps(), dd); let mut lhs_pos = 0; while lhs_pos < lhs_size { let v = _mm512_loadu_ps(lhs_val.add(lhs_pos)); @@ -587,7 +590,6 @@ unsafe fn sl2_v4(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { rhs_val.add(rhs_pos), ); dd = _mm512_fmadd_ps(v, v, dd); - F32(_mm512_reduce_add_ps(dd)) } } From bc6642364409f763de75b61c96e650e8eac7ec25 Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 26 Mar 2024 22:28:28 +0800 Subject: [PATCH 13/16] test: run vecf16 and veci8 test for 10000 times Signed-off-by: usamoi --- crates/base/src/vector/svecf32.rs | 3 - crates/base/src/vector/vecf16.rs | 204 +++++++++++++++++------------- crates/base/src/vector/veci8.rs | 18 +-- 3 files changed, 124 insertions(+), 101 deletions(-) diff --git a/crates/base/src/vector/svecf32.rs b/crates/base/src/vector/svecf32.rs index 63274c04c..09837356e 100644 --- a/crates/base/src/vector/svecf32.rs +++ b/crates/base/src/vector/svecf32.rs @@ -300,7 +300,6 @@ fn cosine_v4_test() { println!("test {} ... 
skipped (v4)", module_path!()); return; } - let mut eps = F32(0.0); for _ in 0..10000 { let lhs = random_svector(300); let rhs = random_svector(350); @@ -310,9 +309,7 @@ fn cosine_v4_test() { (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." ); - eps = std::cmp::max(eps, (specialized - fallback).abs()); } - dbg!(eps); } #[detect::multiversion(v4 = import, v3, v2, neon, fallback = export)] diff --git a/crates/base/src/vector/vecf16.rs b/crates/base/src/vector/vecf16.rs index 80b79d6fb..20f9da07f 100644 --- a/crates/base/src/vector/vecf16.rs +++ b/crates/base/src/vector/vecf16.rs @@ -113,21 +113,23 @@ unsafe fn cosine_v4_avx512fp16(lhs: &[F16], rhs: &[F16]) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn cosine_v4_avx512fp16_test() { + const EPSILON: F32 = F32(0.002); detect::init(); if !detect::v4_avx512fp16::detect() { println!("test {} ... skipped (v4_avx512fp16)", module_path!()); return; } - const EPSILON: F32 = F32(half::f16::EPSILON.to_f32_const()); - let n = 4000; - let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let specialized = unsafe { cosine_v4_avx512fp16(&lhs, &rhs) }; - let fallback = unsafe { cosine_fallback(&lhs, &rhs) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); + for _ in 0..10000 { + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { cosine_v4_avx512fp16(&lhs, &rhs) }; + let fallback = unsafe { cosine_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } } #[inline] @@ -142,21 +144,23 @@ unsafe fn cosine_v4(lhs: &[F16], rhs: &[F16]) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn cosine_v4_test() { + const EPSILON: F32 = F32(0.002); detect::init(); if !detect::v4::detect() { println!("test {} ... skipped (v4)", module_path!()); return; } - const EPSILON: F32 = F32(half::f16::EPSILON.to_f32_const()); - let n = 4000; - let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let specialized = unsafe { cosine_v4(&lhs, &rhs) }; - let fallback = unsafe { cosine_fallback(&lhs, &rhs) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); + for _ in 0..10000 { + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { cosine_v4(&lhs, &rhs) }; + let fallback = unsafe { cosine_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } } #[inline] @@ -171,21 +175,23 @@ unsafe fn cosine_v3(lhs: &[F16], rhs: &[F16]) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn cosine_v3_test() { + const EPSILON: F32 = F32(0.002); detect::init(); if !detect::v3::detect() { println!("test {} ... skipped (v3)", module_path!()); return; } - const EPSILON: F32 = F32(half::f16::EPSILON.to_f32_const()); - let n = 4000; - let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let specialized = unsafe { cosine_v3(&lhs, &rhs) }; - let fallback = unsafe { cosine_fallback(&lhs, &rhs) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." 
- ); + for _ in 0..10000 { + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { cosine_v3(&lhs, &rhs) }; + let fallback = unsafe { cosine_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } } #[detect::multiversion(v4_avx512fp16 = import, v4 = import, v3 = import, v2, neon, fallback = export)] @@ -215,21 +221,26 @@ unsafe fn dot_v4_avx512fp16(lhs: &[F16], rhs: &[F16]) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn dot_v4_avx512fp16_test() { + const EPSILON: F32 = F32(2.0); detect::init(); if !detect::v4_avx512fp16::detect() { println!("test {} ... skipped (v4_avx512fp16)", module_path!()); return; } - const EPSILON: F32 = F32(1.0); - let n = 4000; - let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let specialized = unsafe { dot_v4_avx512fp16(&lhs, &rhs) }; - let fallback = unsafe { dot_fallback(&lhs, &rhs) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); + let mut m = F32(0.0); + for _ in 0..10000 { + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { dot_v4_avx512fp16(&lhs, &rhs) }; + let fallback = unsafe { dot_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + m = std::cmp::max(m, (specialized - fallback).abs()); + } + dbg!(m); } #[inline] @@ -244,21 +255,23 @@ unsafe fn dot_v4(lhs: &[F16], rhs: &[F16]) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn dot_v4_test() { + const EPSILON: F32 = F32(2.0); detect::init(); if !detect::v4::detect() { println!("test {} ... skipped (v4)", module_path!()); return; } - const EPSILON: F32 = F32(1.0); - let n = 4000; - let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let specialized = unsafe { dot_v4(&lhs, &rhs) }; - let fallback = unsafe { dot_fallback(&lhs, &rhs) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); + for _ in 0..10000 { + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { dot_v4(&lhs, &rhs) }; + let fallback = unsafe { dot_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } } #[inline] @@ -273,21 +286,23 @@ unsafe fn dot_v3(lhs: &[F16], rhs: &[F16]) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn dot_v3_test() { + const EPSILON: F32 = F32(2.0); detect::init(); if !detect::v3::detect() { println!("test {} ... skipped (v3)", module_path!()); return; } - const EPSILON: F32 = F32(1.0); - let n = 4000; - let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let specialized = unsafe { dot_v3(&lhs, &rhs) }; - let fallback = unsafe { dot_fallback(&lhs, &rhs) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." 
- ); + for _ in 0..10000 { + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { dot_v3(&lhs, &rhs) }; + let fallback = unsafe { dot_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } } #[detect::multiversion(v4_avx512fp16 = import, v4 = import, v3 = import, v2, neon, fallback = export)] @@ -313,21 +328,26 @@ unsafe fn sl2_v4_avx512fp16(lhs: &[F16], rhs: &[F16]) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn sl2_v4_avx512fp16_test() { + const EPSILON: F32 = F32(2.0); detect::init(); if !detect::v4_avx512fp16::detect() { println!("test {} ... skipped (v4_avx512fp16)", module_path!()); return; } - const EPSILON: F32 = F32(1.0); - let n = 4000; - let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let specialized = unsafe { sl2_v4_avx512fp16(&lhs, &rhs) }; - let fallback = unsafe { sl2_fallback(&lhs, &rhs) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); + let mut m = F32(0.0); + for _ in 0..10000 { + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { sl2_v4_avx512fp16(&lhs, &rhs) }; + let fallback = unsafe { sl2_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + m = std::cmp::max(m, (specialized - fallback).abs()); + } + dbg!(m); } #[inline] @@ -342,21 +362,23 @@ unsafe fn sl2_v4(lhs: &[F16], rhs: &[F16]) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn sl2_v4_test() { + const EPSILON: F32 = F32(2.0); detect::init(); if !detect::v4::detect() { println!("test {} ... skipped (v4)", module_path!()); return; } - const EPSILON: F32 = F32(1.0); - let n = 4000; - let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let specialized = unsafe { sl2_v4(&lhs, &rhs) }; - let fallback = unsafe { sl2_fallback(&lhs, &rhs) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); + for _ in 0..10000 { + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { sl2_v4(&lhs, &rhs) }; + let fallback = unsafe { sl2_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } } #[inline] @@ -371,21 +393,23 @@ unsafe fn sl2_v3(lhs: &[F16], rhs: &[F16]) -> F32 { #[cfg(all(target_arch = "x86_64", test))] #[test] fn sl2_v3_test() { + const EPSILON: F32 = F32(2.0); detect::init(); if !detect::v3::detect() { println!("test {} ... skipped (v3)", module_path!()); return; } - const EPSILON: F32 = F32(1.0); - let n = 4000; - let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); - let specialized = unsafe { sl2_v3(&lhs, &rhs) }; - let fallback = unsafe { sl2_fallback(&lhs, &rhs) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." 
- ); + for _ in 0..10000 { + let n = 4000; + let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::>(); + let specialized = unsafe { sl2_v3(&lhs, &rhs) }; + let fallback = unsafe { sl2_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } } #[detect::multiversion(v4_avx512fp16 = import, v4 = import, v3 = import, v2, neon, fallback = export)] diff --git a/crates/base/src/vector/veci8.rs b/crates/base/src/vector/veci8.rs index 94b065167..bbc896572 100644 --- a/crates/base/src/vector/veci8.rs +++ b/crates/base/src/vector/veci8.rs @@ -373,14 +373,16 @@ fn dot_internal_v4_avx512vnni_test() { println!("test {} ... skipped (v4_avx512vnni)", module_path!()); return; } - let lhs = std::array::from_fn::<_, 400, _>(|_| I8(rand::random())); - let rhs = std::array::from_fn::<_, 400, _>(|_| I8(rand::random())); - let specialized = unsafe { dot_internal_v4_avx512vnni(&lhs, &rhs) }; - let fallback = unsafe { dot_internal_fallback(&lhs, &rhs) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); + for _ in 0..10000 { + let lhs = std::array::from_fn::<_, 400, _>(|_| I8(rand::random())); + let rhs = std::array::from_fn::<_, 400, _>(|_| I8(rand::random())); + let specialized = unsafe { dot_internal_v4_avx512vnni(&lhs, &rhs) }; + let fallback = unsafe { dot_internal_fallback(&lhs, &rhs) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } } #[detect::multiversion(v4_avx512vnni = import, v4, v3, v2, neon, fallback = export)] From 4a360a36ac590ca8f6d15a4d8d38e6a16c18e5b7 Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 26 Mar 2024 22:31:42 +0800 Subject: [PATCH 14/16] test: run bvecf32 test for 10000 times Signed-off-by: usamoi --- crates/base/src/vector/bvecf32.rs | 120 ++++++++++++------------------ 1 file changed, 48 insertions(+), 72 deletions(-) diff --git a/crates/base/src/vector/bvecf32.rs b/crates/base/src/vector/bvecf32.rs index ec11bf9b5..7fe945495 100644 --- a/crates/base/src/vector/bvecf32.rs +++ b/crates/base/src/vector/bvecf32.rs @@ -245,24 +245,16 @@ fn cosine_v4_avx512vpopcntdq_test() { println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); return; } - let lhs = { - let mut x = vec![0; 126]; - x.fill_with(|| rand::random()); - x[125] &= 1; - BVecf32Owned::new(8001, x) - }; - let rhs = { - let mut x = vec![0; 126]; - x.fill_with(|| rand::random()); - x[125] &= 1; - BVecf32Owned::new(8001, x) - }; - let specialized = unsafe { cosine_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; - let fallback = unsafe { cosine_fallback(lhs.for_borrow(), rhs.for_borrow()) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); + for _ in 0..10000 { + let lhs = random_bvector(); + let rhs = random_bvector(); + let specialized = unsafe { cosine_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { cosine_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } } #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] @@ -327,24 +319,16 @@ fn dot_v4_avx512vpopcntdq_test() { println!("test {} ... 
skipped (v4_avx512vpopcntdq)", module_path!()); return; } - let lhs = { - let mut x = vec![0; 126]; - x.fill_with(|| rand::random()); - x[125] &= 1; - BVecf32Owned::new(8001, x) - }; - let rhs = { - let mut x = vec![0; 126]; - x.fill_with(|| rand::random()); - x[125] &= 1; - BVecf32Owned::new(8001, x) - }; - let specialized = unsafe { dot_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; - let fallback = unsafe { dot_fallback(lhs.for_borrow(), rhs.for_borrow()) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); + for _ in 0..10000 { + let lhs = random_bvector(); + let rhs = random_bvector(); + let specialized = unsafe { dot_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { dot_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } } #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] @@ -402,24 +386,16 @@ fn sl2_v4_avx512vpopcntdq_test() { println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); return; } - let lhs = { - let mut x = vec![0; 126]; - x.fill_with(|| rand::random()); - x[125] &= 1; - BVecf32Owned::new(8001, x) - }; - let rhs = { - let mut x = vec![0; 126]; - x.fill_with(|| rand::random()); - x[125] &= 1; - BVecf32Owned::new(8001, x) - }; - let specialized = unsafe { sl2_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; - let fallback = unsafe { sl2_fallback(lhs.for_borrow(), rhs.for_borrow()) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." 
- ); + for _ in 0..10000 { + let lhs = random_bvector(); + let rhs = random_bvector(); + let specialized = unsafe { sl2_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { sl2_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." + ); + } } #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] @@ -481,24 +457,16 @@ fn jaccard_v4_avx512vpopcntdq_test() { println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); return; } - let lhs = { - let mut x = vec![0; 126]; - x.fill_with(|| rand::random()); - x[125] &= 1; - BVecf32Owned::new(8001, x) - }; - let rhs = { - let mut x = vec![0; 126]; - x.fill_with(|| rand::random()); - x[125] &= 1; - BVecf32Owned::new(8001, x) - }; - let specialized = unsafe { jaccard_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; - let fallback = unsafe { jaccard_fallback(lhs.for_borrow(), rhs.for_borrow()) }; - assert!( - (specialized - fallback).abs() < EPSILON, - "specialized = {specialized}, fallback = {fallback}." - ); + for _ in 0..10000 { + let lhs = random_bvector(); + let rhs = random_bvector(); + let specialized = unsafe { jaccard_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; + let fallback = unsafe { jaccard_fallback(lhs.for_borrow(), rhs.for_borrow()) }; + assert!( + (specialized - fallback).abs() < EPSILON, + "specialized = {specialized}, fallback = {fallback}." 
+ ); + } } #[detect::multiversion(v4_avx512vpopcntdq = import, v4, v3, v2, neon, fallback = export)] @@ -558,3 +526,11 @@ pub fn l2_normalize<'a>(vector: BVecf32Borrowed<'a>) -> Vecf32Owned { let l = length(vector); Vecf32Owned::new(vector.iter().map(|i| F32(i as u32 as f32) / l).collect()) } + +#[cfg(all(target_arch = "x86_64", test))] +fn random_bvector() -> BVecf32Owned { + let mut x = vec![0; 126]; + x.fill_with(|| rand::random()); + x[125] &= 1; + BVecf32Owned::new(8001, x) +} From 1fc097e8ed52d740911f254504bd41d3ff0e43df Mon Sep 17 00:00:00 2001 From: usamoi Date: Wed, 27 Mar 2024 10:31:15 +0800 Subject: [PATCH 15/16] test: run tests for 300 times to reduce ci time Signed-off-by: usamoi --- crates/base/src/vector/bvecf32.rs | 8 ++++---- crates/base/src/vector/svecf32.rs | 6 +++--- crates/base/src/vector/vecf16.rs | 18 +++++++++--------- crates/base/src/vector/veci8.rs | 2 +- 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/crates/base/src/vector/bvecf32.rs b/crates/base/src/vector/bvecf32.rs index 7fe945495..e6b0700bf 100644 --- a/crates/base/src/vector/bvecf32.rs +++ b/crates/base/src/vector/bvecf32.rs @@ -245,7 +245,7 @@ fn cosine_v4_avx512vpopcntdq_test() { println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let lhs = random_bvector(); let rhs = random_bvector(); let specialized = unsafe { cosine_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; @@ -319,7 +319,7 @@ fn dot_v4_avx512vpopcntdq_test() { println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let lhs = random_bvector(); let rhs = random_bvector(); let specialized = unsafe { dot_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; @@ -386,7 +386,7 @@ fn sl2_v4_avx512vpopcntdq_test() { println!("test {} ... 
skipped (v4_avx512vpopcntdq)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let lhs = random_bvector(); let rhs = random_bvector(); let specialized = unsafe { sl2_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; @@ -457,7 +457,7 @@ fn jaccard_v4_avx512vpopcntdq_test() { println!("test {} ... skipped (v4_avx512vpopcntdq)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let lhs = random_bvector(); let rhs = random_bvector(); let specialized = unsafe { jaccard_v4_avx512vpopcntdq(lhs.for_borrow(), rhs.for_borrow()) }; diff --git a/crates/base/src/vector/svecf32.rs b/crates/base/src/vector/svecf32.rs index 09837356e..53e8f7fa5 100644 --- a/crates/base/src/vector/svecf32.rs +++ b/crates/base/src/vector/svecf32.rs @@ -300,7 +300,7 @@ fn cosine_v4_test() { println!("test {} ... skipped (v4)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let lhs = random_svector(300); let rhs = random_svector(350); let specialized = unsafe { cosine_v4(lhs.for_borrow(), rhs.for_borrow()) }; @@ -437,7 +437,7 @@ fn dot_v4_test() { println!("test {} ... skipped (v4)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let lhs = random_svector(300); let rhs = random_svector(350); let specialized = unsafe { dot_v4(lhs.for_borrow(), rhs.for_borrow()) }; @@ -600,7 +600,7 @@ fn sl2_v4_test() { println!("test {} ... skipped (v4)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let lhs = random_svector(300); let rhs = random_svector(350); let specialized = unsafe { sl2_v4(lhs.for_borrow(), rhs.for_borrow()) }; diff --git a/crates/base/src/vector/vecf16.rs b/crates/base/src/vector/vecf16.rs index 20f9da07f..2662cbb19 100644 --- a/crates/base/src/vector/vecf16.rs +++ b/crates/base/src/vector/vecf16.rs @@ -119,7 +119,7 @@ fn cosine_v4_avx512fp16_test() { println!("test {} ... 
skipped (v4_avx512fp16)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let n = 4000; let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); @@ -150,7 +150,7 @@ fn cosine_v4_test() { println!("test {} ... skipped (v4)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let n = 4000; let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); @@ -181,7 +181,7 @@ fn cosine_v3_test() { println!("test {} ... skipped (v3)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let n = 4000; let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); @@ -228,7 +228,7 @@ fn dot_v4_avx512fp16_test() { return; } let mut m = F32(0.0); - for _ in 0..10000 { + for _ in 0..300 { let n = 4000; let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); @@ -261,7 +261,7 @@ fn dot_v4_test() { println!("test {} ... skipped (v4)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let n = 4000; let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); @@ -292,7 +292,7 @@ fn dot_v3_test() { println!("test {} ... skipped (v3)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let n = 4000; let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); @@ -335,7 +335,7 @@ fn sl2_v4_avx512fp16_test() { return; } let mut m = F32(0.0); - for _ in 0..10000 { + for _ in 0..300 { let n = 4000; let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); @@ -368,7 +368,7 @@ fn sl2_v4_test() { println!("test {} ... 
skipped (v4)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let n = 4000; let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); @@ -399,7 +399,7 @@ fn sl2_v3_test() { println!("test {} ... skipped (v3)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let n = 4000; let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); let rhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); diff --git a/crates/base/src/vector/veci8.rs b/crates/base/src/vector/veci8.rs index bbc896572..499aa2330 100644 --- a/crates/base/src/vector/veci8.rs +++ b/crates/base/src/vector/veci8.rs @@ -373,7 +373,7 @@ fn dot_internal_v4_avx512vnni_test() { println!("test {} ... skipped (v4_avx512vnni)", module_path!()); return; } - for _ in 0..10000 { + for _ in 0..300 { let lhs = std::array::from_fn::<_, 400, _>(|_| I8(rand::random())); let rhs = std::array::from_fn::<_, 400, _>(|_| I8(rand::random())); let specialized = unsafe { dot_internal_v4_avx512vnni(&lhs, &rhs) }; From 2224d959ae2b77cf11a5cb8ffb013f30d1d302a4 Mon Sep 17 00:00:00 2001 From: usamoi Date: Wed, 27 Mar 2024 10:34:01 +0800 Subject: [PATCH 16/16] chore: update rust toolchain Signed-off-by: usamoi --- .github/workflows/style.yml | 5 ++--- Cargo.lock | 10 ---------- crates/base/src/vector/vecf16.rs | 6 ------ crates/detect/Cargo.toml | 1 - crates/detect_macros/src/lib.rs | 4 ++-- rust-toolchain.toml | 2 +- 6 files changed, 5 insertions(+), 23 deletions(-) diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index c5962aa2d..3ff8765bc 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -1,4 +1,4 @@ -name: Style check +name: Style on: push: @@ -9,8 +9,7 @@ on: workflow_dispatch: jobs: - run: - name: check + check: runs-on: ubuntu-latest steps: - name: Checkout Actions Repository diff --git a/Cargo.lock b/Cargo.lock index 514953e87..83656b00d 100644 --- 
a/Cargo.lock +++ b/Cargo.lock @@ -750,7 +750,6 @@ version = "0.0.0" dependencies = [ "detect_macros", "rustix 0.38.32", - "std_detect", ] [[package]] @@ -2519,15 +2518,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" -[[package]] -name = "std_detect" -version = "0.1.5" -source = "git+https://github.com/tensorchord/stdarch?rev=e50b2f6fa7f8a9a0081c88b1793d8560462d5848#e50b2f6fa7f8a9a0081c88b1793d8560462d5848" -dependencies = [ - "cfg-if", - "libc", -] - [[package]] name = "storage" version = "0.0.0" diff --git a/crates/base/src/vector/vecf16.rs b/crates/base/src/vector/vecf16.rs index 2662cbb19..e3a18a9e9 100644 --- a/crates/base/src/vector/vecf16.rs +++ b/crates/base/src/vector/vecf16.rs @@ -227,7 +227,6 @@ fn dot_v4_avx512fp16_test() { println!("test {} ... skipped (v4_avx512fp16)", module_path!()); return; } - let mut m = F32(0.0); for _ in 0..300 { let n = 4000; let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); @@ -238,9 +237,7 @@ fn dot_v4_avx512fp16_test() { (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." ); - m = std::cmp::max(m, (specialized - fallback).abs()); } - dbg!(m); } #[inline] @@ -334,7 +331,6 @@ fn sl2_v4_avx512fp16_test() { println!("test {} ... skipped (v4_avx512fp16)", module_path!()); return; } - let mut m = F32(0.0); for _ in 0..300 { let n = 4000; let lhs = (0..n).map(|_| F16(rand::random::<_>())).collect::<Vec<_>>(); @@ -345,9 +341,7 @@ fn sl2_v4_avx512fp16_test() { (specialized - fallback).abs() < EPSILON, "specialized = {specialized}, fallback = {fallback}." 
); - m = std::cmp::max(m, (specialized - fallback).abs()); } - dbg!(m); } #[inline] diff --git a/crates/detect/Cargo.toml b/crates/detect/Cargo.toml index d7cc84b0f..3c1ae6840 100644 --- a/crates/detect/Cargo.toml +++ b/crates/detect/Cargo.toml @@ -5,7 +5,6 @@ edition.workspace = true [dependencies] rustix.workspace = true -std_detect = { git = "https://github.com/tensorchord/stdarch", rev = "e50b2f6fa7f8a9a0081c88b1793d8560462d5848" } detect_macros = { path = "../detect_macros" } diff --git a/crates/detect_macros/src/lib.rs b/crates/detect_macros/src/lib.rs index 32c8c2667..1f60b80be 100644 --- a/crates/detect_macros/src/lib.rs +++ b/crates/detect_macros/src/lib.rs @@ -284,12 +284,12 @@ pub fn main(_: proc_macro::TokenStream) -> proc_macro::TokenStream { #[cfg(target_arch = "x86_64")] pub fn test() -> bool { - true #(&& std_detect::is_x86_feature_detected!(#target_features))* + true #(&& std::arch::is_x86_feature_detected!(#target_features))* } #[cfg(target_arch = "aarch64")] pub fn test() -> bool { - true #(&& std_detect::is_aarch64_feature_detected!(#target_features))* + true #(&& std::arch::is_aarch64_feature_detected!(#target_features))* } pub(crate) fn init() { diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 6be0e4fa5..a9ee38b8d 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "nightly-2024-03-24" +channel = "nightly-2024-03-27" profile = "default" targets = [ "aarch64-apple-darwin",