diff --git a/components/aead/Cargo.toml b/components/aead/Cargo.toml index 927208b5c..f2bbcafc6 100644 --- a/components/aead/Cargo.toml +++ b/components/aead/Cargo.toml @@ -19,13 +19,13 @@ mock = ["mpz-common/test-utils", "dep:mpz-ot"] tlsn-block-cipher = { path = "../cipher/block-cipher" } tlsn-stream-cipher = { path = "../cipher/stream-cipher" } tlsn-universal-hash = { path = "../universal-hash" } -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c", optional = true, features = [ +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e", optional = true, features = [ "ideal", ] } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -serio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "f8d4533" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +serio = "0.1" async-trait = "0.1" derive_builder = "0.12" diff --git a/components/cipher/Cargo.toml b/components/cipher/Cargo.toml index be7a0ae90..396822800 100644 --- a/components/cipher/Cargo.toml +++ b/components/cipher/Cargo.toml @@ -4,9 +4,9 @@ resolver = "2" [workspace.dependencies] # tlsn -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "f8d4533" } +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "bb9769d" } # crypto aes = "0.8" diff --git a/components/integration-tests/Cargo.toml b/components/integration-tests/Cargo.toml index ddf69b5d8..95efd9013 100644 --- a/components/integration-tests/Cargo.toml +++ b/components/integration-tests/Cargo.toml @@ -13,9 +13,9 @@ lto = true [dev-dependencies] -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } tlsn-block-cipher = { path = "../cipher/block-cipher" } tlsn-stream-cipher = { path = "../cipher/stream-cipher" } tlsn-universal-hash = { path = "../universal-hash" } @@ -23,7 +23,7 @@ tlsn-aead = { path = "../aead" } tlsn-key-exchange = { path = "../key-exchange" } tlsn-point-addition = { path = "../point-addition" } tlsn-hmac-sha256 = { path = "../prf/hmac-sha256" } -tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "f8d4533" } +tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "bb9769d" } uid-mux = { path = "../uid-mux" } diff --git a/components/key-exchange/Cargo.toml b/components/key-exchange/Cargo.toml index e6484cd3c..75b3b49b4 100644 --- a/components/key-exchange/Cargo.toml +++ b/components/key-exchange/Cargo.toml @@ -16,27 +16,27 @@ default = ["mock"] mock = [] [dependencies] -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c", features = [ +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e", features = [ "ideal", ] } -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } p256 = { version = "0.13", features = ["ecdh", "serde"] } async-trait = "0.1" thiserror = "1" serde = "1" futures = "0.3" -serio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "f8d4533" } +serio = "0.1" derive_builder = "0.12" tracing = "0.1" rand = "0.8" [dev-dependencies] -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c", features = [ +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e", features = [ "ideal", ] } diff --git a/components/prf/Cargo.toml b/components/prf/Cargo.toml index 532b61de6..4e306c8fd 100644 --- a/components/prf/Cargo.toml +++ b/components/prf/Cargo.toml @@ -4,10 +4,10 @@ resolver = "2" [workspace.dependencies] # tlsn -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } # async async-trait = "0.1" diff --git a/components/tls/Cargo.toml b/components/tls/Cargo.toml index 6e326729b..5e6d1e461 100644 --- a/components/tls/Cargo.toml +++ b/components/tls/Cargo.toml @@ -10,9 +10,9 @@ members = [ resolver = "2" [patch."https://github.com/tlsnotary/tlsn-utils"] -tlsn-utils-aio = { git = "https://github.com/tlsnotary//tlsn-utils", rev = "f8d4533" } -uid-mux = { git = "https://github.com/tlsnotary//tlsn-utils", rev = "f8d4533" } -serio = { git = "https://github.com/tlsnotary//tlsn-utils", rev = "f8d4533" } +tlsn-utils-aio = { git = "https://github.com/tlsnotary//tlsn-utils", rev = "bb9769d" } +uid-mux = "0.1" +serio = "0.1" [workspace.dependencies] # rand diff --git a/components/tls/tls-mpc/Cargo.toml b/components/tls/tls-mpc/Cargo.toml index 636d24c52..3727c6fc3 100644 --- a/components/tls/tls-mpc/Cargo.toml +++ b/components/tls/tls-mpc/Cargo.toml @@ -18,13 +18,13 @@ default = [] tlsn-tls-core = { path = "../tls-core", features = ["serde"] } tlsn-tls-backend = { path = "../tls-backend" } -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } tlsn-block-cipher = { path = "../../cipher/block-cipher" } tlsn-stream-cipher = { path = "../../cipher/stream-cipher" } @@ -33,10 +33,8 @@ tlsn-aead = { path = "../../aead" } tlsn-key-exchange = { path = "../../key-exchange" } tlsn-hmac-sha256 = { path = "../../prf/hmac-sha256" } -tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "f8d4533" } -uid-mux = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "f8d4533", features = [ - "serio", -] } +tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "bb9769d" } +uid-mux = { version = "0.1", features = ["serio"] } p256.workspace = true rand.workspace = true @@ -53,9 +51,7 @@ ludi = { git = "https://github.com/sinui0/ludi", rev = "b590de5" } tlsn-tls-client = { path = "../tls-client" } tlsn-tls-client-async = { path = "../tls-client-async" } tls-server-fixture = { path = "../tls-server-fixture" } -serio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "f8d4533", features = [ - "compat", -] } +serio = { version = "0.1", features = ["compat"] } tracing-subscriber.workspace = true tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } diff --git a/components/universal-hash/Cargo.toml b/components/universal-hash/Cargo.toml index dbbada158..ec3e47103 100644 --- a/components/universal-hash/Cargo.toml +++ b/components/universal-hash/Cargo.toml @@ -15,13 +15,13 @@ ideal = ["dep:ghash_rc"] [dependencies] # tlsn -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c", features = [ +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e", features = [ "ideal", ] } -mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-share-conversion-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c" } +mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-share-conversion-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } ghash_rc = { package = "ghash", version = "0.5", optional = true } @@ -39,10 +39,10 @@ tracing = "0.1" derive_builder = "0.12" [dev-dependencies] -mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c", features = [ +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e", features = [ "test-utils", ] } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "544cf5c", features = [ +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e", features = [ "ideal", ] } diff --git a/notary-server/Cargo.toml b/notary-server/Cargo.toml index 865af7baa..3a0a4d539 100644 --- a/notary-server/Cargo.toml +++ b/notary-server/Cargo.toml @@ -18,8 +18,10 @@ futures-util = "0.3.28" http = "1.1" http-body-util = "0.1" hyper = { version = "1.1", features = ["client", "http1", "server"] } -hyper-util = {version = "0.1", features = ["full"]} -notify = { version = "6.1.1", default-features = false, features = ["macos_kqueue"] } +hyper-util = { version = "0.1", features = ["full"] } +notify = { version = "6.1.1", default-features = false, features = [ + "macos_kqueue", +] } opentelemetry = { version = "0.19" } p256 = "0.13" rstest = "0.18" @@ -31,7 +33,7 @@ serde_yaml = "0.9.21" sha1 = "0.10" structopt = "0.3.26" thiserror = "1" -tlsn-verifier = { path = "../tlsn/tlsn-verifier", features = ["tracing"] } +tlsn-verifier = { path = "../tlsn/tlsn-verifier" } tokio = { version = "1", features = ["full"] } tokio-rustls = { version = "0.24.1" } tokio-util = { version = "0.7", features = ["compat"] } diff --git a/tlsn/Cargo.toml b/tlsn/Cargo.toml index 8497bc264..5f8c2f114 100644 --- a/tlsn/Cargo.toml +++ b/tlsn/Cargo.toml @@ -25,18 +25,21 @@ tlsn-tls-mpc = { path = "../components/tls/tls-mpc" } tlsn-tls-client = { path = "../components/tls/tls-client" } tlsn-tls-client-async = { path = "../components/tls/tls-client-async" } tls-server-fixture = { path = "../components/tls/tls-server-fixture" } -uid-mux = { path = "../components/uid-mux" } -tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "51f313d" } -tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "51f313d" } -spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "51f313d" } +tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "bb9769d" } +tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "bb9769d" } +spansy = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "bb9769d" } +serio = "0.1" +uid-mux = { version = "0.1", features = ["serio"] } -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "477448c " } -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "477448c " } -mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "477448c " } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "477448c " } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "477448c " } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "477448c " } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "435526e" } futures = "0.3" tokio-util = "0.7" diff --git a/tlsn/examples/Cargo.toml b/tlsn/examples/Cargo.toml index ec14c396a..fd6a3874a 100644 --- a/tlsn/examples/Cargo.toml +++ b/tlsn/examples/Cargo.toml @@ -6,42 +6,42 @@ version = "0.0.0" [dependencies] mpz-core.workspace = true -notary-server = {path = "../../notary-server"} +notary-server = { path = "../../notary-server" } tlsn-core.workspace = true -tlsn-prover = {workspace = true, features = ["tracing"]} +tlsn-prover.workspace = true tlsn-tls-client.workspace = true tlsn-tls-core.workspace = true tlsn-utils.workspace = true tlsn-verifier.workspace = true -elliptic-curve = {version = "0.13.5", features = ["pkcs8"]} -p256 = {workspace = true, features = ["ecdsa"]} +elliptic-curve = { version = "0.13.5", features = ["pkcs8"] } +p256 = { workspace = true, features = ["ecdsa"] } webpki-roots.workspace = true -async-tls = {version = "0.12", default-features = false, features = [ +async-tls = { version = "0.12", default-features = false, features = [ "client", -]} +] } chrono = "0.4" futures.workspace = true http-body-util = "0.1" -hyper = {version = "1.1", features = ["client", "http1"]} -hyper-util = {version = "0.1", features = ["full"]} -rustls = {version = "0.21"} -rustls-pemfile = {version = "1.0.2"} -tokio = {workspace = true, features = [ +hyper = { version = "1.1", features = ["client", "http1"] } +hyper-util = { version = "0.1", features = ["full"] } +rustls = { version = "0.21" } +rustls-pemfile = { version = "1.0.2" } +tokio = { workspace = true, features = [ "rt", "rt-multi-thread", "macros", "net", "io-std", "fs", -]} -tokio-rustls = {version = "0.24.1"} +] } +tokio-rustls = { version = "0.24.1" } tokio-util.workspace = true dotenv = "0.15.0" eyre = "0.6.8" -serde = {version = "1.0.147", features = ["derive"]} +serde = { version = "1.0.147", features = ["derive"] } serde_json = "1.0" tracing-subscriber.workspace = true tracing.workspace = true diff --git a/tlsn/tests-integration/Cargo.toml b/tlsn/tests-integration/Cargo.toml index 79b0b8a2f..d8593d449 100644 --- a/tlsn/tests-integration/Cargo.toml +++ b/tlsn/tests-integration/Cargo.toml @@ -7,8 +7,8 @@ publish = false [dev-dependencies] tlsn-core.workspace = true tlsn-tls-core.workspace = true -tlsn-prover = { workspace = true, features = ["tracing"] } -tlsn-verifier = { workspace = true, features = ["tracing"] } +tlsn-prover.workspace = true +tlsn-verifier.workspace = true tlsn-server-fixture.workspace = true tlsn-utils.workspace = true diff --git a/tlsn/tlsn-common/Cargo.toml b/tlsn/tlsn-common/Cargo.toml index eeceaf4b0..b3a5afe24 100644 --- a/tlsn/tlsn-common/Cargo.toml +++ b/tlsn/tlsn-common/Cargo.toml @@ -5,11 +5,18 @@ version = "0.1.0-alpha.5" edition = "2021" [features] -default = ["tracing"] -tracing = ["uid-mux/tracing"] +default = [] [dependencies] -tlsn-utils-aio.workspace = true +mpz-share-conversion.workspace = true +mpz-garble.workspace = true +mpz-garble-core.workspace = true +mpz-ot.workspace = true +mpz-ole.workspace = true +mpz-core.workspace = true +mpz-common.workspace = true futures.workspace = true -uid-mux.workspace = true +serio = { workspace = true, features = ["codec", "bincode"] } +uid-mux = { workspace = true, features = ["serio"] } +tracing.workspace = true diff --git a/tlsn/tlsn-common/src/config.rs b/tlsn/tlsn-common/src/config.rs index a9bf5edfc..b7833c270 100644 --- a/tlsn/tlsn-common/src/config.rs +++ b/tlsn/tlsn-common/src/config.rs @@ -7,10 +7,6 @@ pub const DEFAULT_MAX_SENT_LIMIT: usize = 1 << 12; /// Default for the maximum number of bytes that can be received (16Kb). pub const DEFAULT_MAX_RECV_LIMIT: usize = 1 << 14; -// Determined experimentally, will be subject to change if underlying protocols are modified. -const KE_OTS: usize = 3360; -// Secret-sharing the GHASH blocks. -const GHASH_OTS: usize = 65664 * 2; // Extra cushion room, eg. for sharing J0 blocks. const EXTRA_OTS: usize = 16384; const OTS_PER_BYTE_SENT: usize = 8; @@ -20,12 +16,9 @@ const OTS_PER_BYTE_RECV: usize = 16; /// Returns an estimate of the number of OTs that will be sent. pub fn ot_send_estimate(role: Role, max_sent_data: usize, max_recv_data: usize) -> usize { match role { - Role::Prover => KE_OTS + GHASH_OTS + EXTRA_OTS, + Role::Prover => EXTRA_OTS, Role::Verifier => { - KE_OTS - + EXTRA_OTS - + (max_sent_data * OTS_PER_BYTE_SENT) - + (max_recv_data * OTS_PER_BYTE_RECV) + EXTRA_OTS + (max_sent_data * OTS_PER_BYTE_SENT) + (max_recv_data * OTS_PER_BYTE_RECV) } } } @@ -34,11 +27,8 @@ pub fn ot_send_estimate(role: Role, max_sent_data: usize, max_recv_data: usize) pub fn ot_recv_estimate(role: Role, max_sent_data: usize, max_recv_data: usize) -> usize { match role { Role::Prover => { - KE_OTS - + EXTRA_OTS - + (max_sent_data * OTS_PER_BYTE_SENT) - + (max_recv_data * OTS_PER_BYTE_RECV) + EXTRA_OTS + (max_sent_data * OTS_PER_BYTE_SENT) + (max_recv_data * OTS_PER_BYTE_RECV) } - Role::Verifier => KE_OTS + GHASH_OTS + EXTRA_OTS, + Role::Verifier => EXTRA_OTS, } } diff --git a/tlsn/tlsn-common/src/lib.rs b/tlsn/tlsn-common/src/lib.rs index e3e7357ee..c74ea7e16 100644 --- a/tlsn/tlsn-common/src/lib.rs +++ b/tlsn/tlsn-common/src/lib.rs @@ -7,6 +7,27 @@ pub mod config; pub mod mux; +use serio::codec::Codec; + +use crate::mux::MuxControl; + +/// IO type. +pub type Io = >::Framed; +/// Base OT sender. +pub type BaseOTSender = mpz_ot::chou_orlandi::Sender; +/// Base OT receiver. +pub type BaseOTReceiver = mpz_ot::chou_orlandi::Receiver; +/// OT sender. +pub type OTSender = mpz_ot::kos::SharedSender; +/// OT receiver. +pub type OTReceiver = mpz_ot::kos::SharedReceiver; +/// MPC executor. +pub type Executor = mpz_common::executor::MTExecutor; +/// MPC thread context. +pub type Context = mpz_common::executor::MTContext; +/// DEAP thread. +pub type DEAPThread = mpz_garble::protocol::deap::DEAPThread; + /// The party's role in the TLSN protocol. /// /// A Notary is classified as a Verifier. diff --git a/tlsn/tlsn-common/src/mux.rs b/tlsn/tlsn-common/src/mux.rs index 67a9be1a7..29719667f 100644 --- a/tlsn/tlsn-common/src/mux.rs +++ b/tlsn/tlsn-common/src/mux.rs @@ -1,19 +1,63 @@ //! Multiplexer used in the TLSNotary protocol. -use utils_aio::codec::BincodeMux; +use std::future::IntoFuture; -use futures::{AsyncRead, AsyncWrite}; -use uid_mux::{yamux, UidYamux, UidYamuxControl}; +use futures::{ + future::{FusedFuture, FutureExt}, + AsyncRead, AsyncWrite, Future, +}; +use serio::codec::Bincode; +use tracing::error; +use uid_mux::{yamux, FramedMux}; use crate::Role; /// Multiplexer supporting unique deterministic stream IDs. -pub type Mux = UidYamux; +pub type Mux = yamux::Yamux; /// Multiplexer controller providing streams with a codec attached. -pub type MuxControl = BincodeMux; +pub type MuxControl = FramedMux; -const KB: usize = 1024; -const MB: usize = 1024 * KB; +/// Multiplexer future which must be polled for the muxer to make progress. +pub struct MuxFuture( + Box> + Send + Unpin>, +); + +impl MuxFuture { + /// Returns true if the muxer is complete. + pub fn is_complete(&self) -> bool { + self.0.is_terminated() + } + + /// Awaits a future, polling the muxer future concurrently. + pub async fn poll_with(&mut self, fut: F) -> R + where + F: Future, + { + let mut fut = Box::pin(fut.fuse()); + // Poll the future concurrently with the muxer future. + // If the muxer returns an error, continue polling the future + // until it completes. + loop { + futures::select! { + res = fut => return res, + res = &mut self.0 => if let Err(e) = res { + error!("mux error: {:?}", e); + }, + } + } + } +} + +impl Future for MuxFuture { + type Output = Result<(), yamux::ConnectionError>; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + self.0.as_mut().poll_unpin(cx) + } +} /// Attaches a multiplexer to the provided socket. /// @@ -26,20 +70,21 @@ const MB: usize = 1024 * KB; pub fn attach_mux( socket: T, role: Role, -) -> (Mux, MuxControl) { +) -> (MuxFuture, MuxControl) { let mut mux_config = yamux::Config::default(); - // See PR #418 - mux_config.set_max_num_streams(40); - mux_config.set_max_buffer_size(16 * MB); - mux_config.set_receive_window(16 * MB as u32); + mux_config.set_max_num_streams(64); let mux_role = match role { Role::Prover => yamux::Mode::Client, Role::Verifier => yamux::Mode::Server, }; - let mux = UidYamux::new(mux_config, socket, mux_role); - let ctrl = BincodeMux::new(mux.control()); + let mux = Mux::new(socket, mux_config, mux_role); + let ctrl = FramedMux::new(mux.control(), Bincode); + + if let Role::Prover = role { + ctrl.mux().alloc(64); + } - (mux, ctrl) + (MuxFuture(Box::new(mux.into_future().fuse())), ctrl) } diff --git a/tlsn/tlsn-prover/Cargo.toml b/tlsn/tlsn-prover/Cargo.toml index 28b7760b6..490111894 100644 --- a/tlsn/tlsn-prover/Cargo.toml +++ b/tlsn/tlsn-prover/Cargo.toml @@ -9,14 +9,10 @@ version = "0.1.0-alpha.5" edition = "2021" [features] -default = ["formats"] +default = ["formats", "rayon"] formats = ["dep:tlsn-formats"] -tracing = [ - "dep:tracing", - "tlsn-tls-client-async/tracing", - "tlsn-tls-mpc/tracing", - "tlsn-common/tracing", -] +rayon = ["mpz-common/rayon"] +force-st = ["mpz-common/force-st"] [dependencies] tlsn-tls-core.workspace = true @@ -29,12 +25,16 @@ tlsn-tls-mpc.workspace = true tlsn-utils.workspace = true tlsn-utils-aio.workspace = true +serio = { workspace = true, features = ["compat"] } +uid-mux = { workspace = true, features = ["serio"] } mpz-share-conversion.workspace = true mpz-garble.workspace = true mpz-garble-core.workspace = true mpz-ot.workspace = true +mpz-ole.workspace = true mpz-core.workspace = true +mpz-common.workspace = true rand.workspace = true futures.workspace = true @@ -43,8 +43,7 @@ webpki-roots.workspace = true derive_builder.workspace = true opaque-debug.workspace = true bytes.workspace = true - -tracing = { workspace = true, optional = true } +tracing.workspace = true web-time.workspace = true diff --git a/tlsn/tlsn-prover/src/tls/config.rs b/tlsn/tlsn-prover/src/tls/config.rs index 02eca59d2..4d2145a6e 100644 --- a/tlsn/tlsn-prover/src/tls/config.rs +++ b/tlsn/tlsn-prover/src/tls/config.rs @@ -1,5 +1,4 @@ use mpz_ot::{chou_orlandi, kos}; -use mpz_share_conversion::{ReceiverConfig, SenderConfig}; use tls_client::RootCertStore; use tls_mpc::{MpcTlsCommonConfig, MpcTlsLeaderConfig, TranscriptConfig}; use tlsn_common::{ @@ -102,18 +101,6 @@ impl ProverConfig { pub(crate) fn ot_receiver_setup_count(&self) -> usize { ot_recv_estimate(Role::Prover, self.max_sent_data, self.max_recv_data) } - - pub(crate) fn build_p256_sender_config(&self) -> SenderConfig { - SenderConfig::builder().id("p256/0").build().unwrap() - } - - pub(crate) fn build_p256_receiver_config(&self) -> ReceiverConfig { - ReceiverConfig::builder().id("p256/1").build().unwrap() - } - - pub(crate) fn build_gf2_config(&self) -> SenderConfig { - SenderConfig::builder().id("gf2").record().build().unwrap() - } } /// Default root store using mozilla certs. diff --git a/tlsn/tlsn-prover/src/tls/error.rs b/tlsn/tlsn-prover/src/tls/error.rs index 2a4a129b5..932c1cfb3 100644 --- a/tlsn/tlsn-prover/src/tls/error.rs +++ b/tlsn/tlsn-prover/src/tls/error.rs @@ -12,8 +12,6 @@ pub enum ProverError { AsyncClientError(#[from] tls_client_async::ConnectionError), #[error(transparent)] IOError(#[from] std::io::Error), - #[error(transparent)] - MuxerError(#[from] utils_aio::mux::MuxerError), #[error("notarization error: {0}")] NotarizationError(String), #[error(transparent)] @@ -30,6 +28,21 @@ pub enum ProverError { InvalidRange, } +impl From for ProverError { + fn from(e: uid_mux::yamux::ConnectionError) -> Self { + Self::IOError(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + e, + )) + } +} + +impl From for ProverError { + fn from(e: mpz_common::ContextError) -> Self { + Self::MpcError(Box::new(e)) + } +} + impl From for ProverError { fn from(e: MpcTlsError) -> Self { Self::MpcError(Box::new(e)) @@ -42,49 +55,44 @@ impl From for ProverError { } } -impl From for ProverError { - fn from(e: mpz_garble::VmError) -> Self { +impl From for ProverError { + fn from(e: mpz_ot::kos::SenderError) -> Self { Self::MpcError(Box::new(e)) } } -impl From for ProverError { - fn from(e: mpz_garble::MemoryError) -> Self { +impl From for ProverError { + fn from(e: mpz_ole::OLEError) -> Self { Self::MpcError(Box::new(e)) } } -impl From for ProverError { - fn from(e: mpz_garble::ProveError) -> Self { +impl From for ProverError { + fn from(e: mpz_ot::kos::ReceiverError) -> Self { Self::MpcError(Box::new(e)) } } -impl From for ProverError { - fn from(value: mpz_ot::actor::kos::SenderActorError) -> Self { - Self::MpcError(Box::new(value)) +impl From for ProverError { + fn from(e: mpz_garble::VmError) -> Self { + Self::MpcError(Box::new(e)) } } -impl From for ProverError { - fn from(value: mpz_ot::actor::kos::ReceiverActorError) -> Self { - Self::MpcError(Box::new(value)) +impl From for ProverError { + fn from(e: mpz_garble::protocol::deap::DEAPError) -> Self { + Self::MpcError(Box::new(e)) } } -#[derive(Debug)] -pub struct OTShutdownError; - -impl std::fmt::Display for OTShutdownError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("ot shutdown prior to completion") +impl From for ProverError { + fn from(e: mpz_garble::MemoryError) -> Self { + Self::MpcError(Box::new(e)) } } -impl Error for OTShutdownError {} - -impl From for ProverError { - fn from(e: OTShutdownError) -> Self { +impl From for ProverError { + fn from(e: mpz_garble::ProveError) -> Self { Self::MpcError(Box::new(e)) } } diff --git a/tlsn/tlsn-prover/src/tls/future.rs b/tlsn/tlsn-prover/src/tls/future.rs index e4fad7976..9ea8bbfff 100644 --- a/tlsn/tlsn-prover/src/tls/future.rs +++ b/tlsn/tlsn-prover/src/tls/future.rs @@ -1,7 +1,7 @@ //! This module collects futures which are used by the [Prover]. use super::{state, Prover, ProverControl, ProverError}; -use futures::{future::FusedFuture, Future}; +use futures::Future; use std::pin::Pin; /// Prover future which must be polled for the TLS connection to make progress. @@ -29,47 +29,3 @@ impl Future for ProverFuture { self.fut.as_mut().poll(cx) } } - -/// A future which must be polled for the muxer to make progress. -pub(crate) struct MuxFuture { - pub(crate) fut: Pin> + Send + 'static>>, -} - -impl Future for MuxFuture { - type Output = Result<(), ProverError>; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - self.fut.as_mut().poll(cx) - } -} - -impl FusedFuture for MuxFuture { - fn is_terminated(&self) -> bool { - self.fut.is_terminated() - } -} - -/// A future which must be polled for the Oblivious Transfer protocol to make progress. -pub(crate) struct OTFuture { - pub(crate) fut: Pin> + Send + 'static>>, -} - -impl Future for OTFuture { - type Output = Result<(), ProverError>; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - self.fut.as_mut().poll(cx) - } -} - -impl FusedFuture for OTFuture { - fn is_terminated(&self) -> bool { - self.fut.is_terminated() - } -} diff --git a/tlsn/tlsn-prover/src/tls/mod.rs b/tlsn/tlsn-prover/src/tls/mod.rs index cb5a2c02e..86f5b29f8 100644 --- a/tlsn/tlsn-prover/src/tls/mod.rs +++ b/tlsn/tlsn-prover/src/tls/mod.rs @@ -16,33 +16,28 @@ pub mod state; pub use config::{ProverConfig, ProverConfigBuilder, ProverConfigBuilderError}; pub use error::ProverError; pub use future::ProverFuture; -use tlsn_common::{ - mux::{attach_mux, MuxControl}, - Role, -}; +use state::{Notarize, Prove}; -use error::OTShutdownError; -use future::{MuxFuture, OTFuture}; -use futures::{AsyncRead, AsyncWrite, FutureExt, StreamExt, TryFutureExt}; -use mpz_garble::{config::Role as DEAPRole, protocol::deap::DEAPVm}; -use mpz_ot::{ - actor::kos::{ReceiverActor, SenderActor, SharedReceiver, SharedSender}, - chou_orlandi, kos, -}; -use mpz_share_conversion as ff; +use futures::{AsyncRead, AsyncWrite, TryFutureExt}; +use mpz_common::Allocate; +use mpz_garble::config::Role as DEAPRole; +use mpz_ot::{chou_orlandi, kos}; use rand::Rng; -use state::{Notarize, Prove}; +use serio::StreamExt; use std::sync::Arc; use tls_client::{ClientConnection, ServerName as TlsServerName}; use tls_client_async::{bind_client, ClosedConnection, TlsConnection}; -use tls_mpc::{setup_components, LeaderCtrl, MpcTlsLeader, TlsRole}; +use tls_mpc::{build_components, LeaderCtrl, MpcTlsLeader, TlsRole}; +use tlsn_common::{ + mux::{attach_mux, MuxControl}, + DEAPThread, Executor, OTReceiver, OTSender, Role, +}; use tlsn_core::transcript::Transcript; -use utils_aio::mux::MuxChannel; +use uid_mux::FramedUidMux as _; #[cfg(feature = "formats")] use crate::http::{state as http_state, HttpProver, HttpProverError}; -#[cfg(feature = "tracing")] use tracing::{debug, debug_span, instrument, Instrument}; /// A prover instance. @@ -73,31 +68,41 @@ impl Prover { /// # Arguments /// /// * `socket` - The socket to the TLS verifier. + #[instrument(level = "debug", skip_all, err)] pub async fn setup( self, socket: S, ) -> Result, ProverError> { - let (mut mux, mux_ctrl) = attach_mux(socket, Role::Prover); + let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Prover); - let mut mux_fut = MuxFuture { - fut: Box::pin(async move { mux.run().await.map_err(ProverError::from) }.fuse()), - }; + let mut exec = Executor::new(mux_ctrl.clone(), 8); - let mpc_setup_fut = setup_mpc_backend(&self.config, mux_ctrl.clone()); - let (mpc_tls, vm, _, gf2, ot_fut) = futures::select! { - res = mpc_setup_fut.fuse() => res?, - _ = (&mut mux_fut).fuse() => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; + let (mpc_tls, vm, ot_recv) = mux_fut + .poll_with(setup_mpc_backend(&self.config, &mux_ctrl, &mut exec)) + .await?; + + let io = mux_fut + .poll_with( + mux_ctrl + .open_framed(b"tlsnotary") + .map_err(ProverError::from), + ) + .await?; + + let ctx = mux_fut + .poll_with(exec.new_thread().map_err(ProverError::from)) + .await?; Ok(Prover { config: self.config, state: state::Setup { + io, mux_ctrl, mux_fut, mpc_tls, vm, - ot_fut, - gf2, + ot_recv, + ctx, }, }) } @@ -112,21 +117,19 @@ impl Prover { /// # Arguments /// /// * `socket` - The socket to the server. - #[cfg_attr( - feature = "tracing", - instrument(level = "debug", skip(self, socket), err) - )] + #[instrument(level = "debug", skip_all, err)] pub async fn connect( self, socket: S, ) -> Result<(TlsConnection, ProverFuture), ProverError> { let state::Setup { + io, mux_ctrl, mut mux_fut, mpc_tls, vm, - mut ot_fut, - gf2, + ot_recv, + ctx, } = self.state; let (mpc_ctrl, mpc_fut) = mpc_tls.run(); @@ -145,31 +148,31 @@ impl Prover { let fut = Box::pin({ let mpc_ctrl = mpc_ctrl.clone(); - #[allow(clippy::let_and_return)] - let fut = async move { + async move { let conn_fut = async { - let ClosedConnection { sent, recv, .. } = futures::select! { - res = conn_fut.fuse() => res?, - _ = ot_fut => return Err(OTShutdownError)?, - _ = mux_fut => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; + let ClosedConnection { sent, recv, .. } = mux_fut + .poll_with(conn_fut.map_err(ProverError::from)) + .await?; mpc_ctrl.close_connection().await?; Ok::<_, ProverError>((sent, recv)) }; - let ((sent, recv), mpc_tls_data) = - futures::try_join!(conn_fut, mpc_fut.map_err(ProverError::from))?; + let ((sent, recv), mpc_tls_data) = futures::try_join!( + conn_fut, + mpc_fut.in_current_span().map_err(ProverError::from) + )?; Ok(Prover { config: self.config, state: state::Closed { + io, mux_ctrl, mux_fut, vm, - ot_fut, - gf2, + ot_recv, + ctx, start_time, handshake_decommitment: mpc_tls_data .handshake_decommitment @@ -179,10 +182,8 @@ impl Prover { transcript_rx: Transcript::new(recv), }, }) - }; - #[cfg(feature = "tracing")] - let fut = fut.instrument(debug_span!("prover_tls_connection")); - fut + } + .instrument(debug_span!("prover")) }); Ok(( @@ -236,122 +237,116 @@ impl Prover { } /// Performs a setup of the various MPC subprotocols. -#[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] -#[allow(clippy::type_complexity)] +#[instrument(level = "debug", skip_all, err)] async fn setup_mpc_backend( config: &ProverConfig, - mut mux: MuxControl, -) -> Result< - ( - MpcTlsLeader, - DEAPVm, - SharedReceiver, - ff::ConverterSender, - OTFuture, - ), - ProverError, -> { - let (ot_send_sink, ot_send_stream) = mux.get_channel("ot/0").await?.split(); - let (ot_recv_sink, ot_recv_stream) = mux.get_channel("ot/1").await?.split(); - - let mut ot_sender_actor = SenderActor::new( - kos::Sender::new( - config.build_ot_sender_config(), - chou_orlandi::Receiver::new(config.build_base_ot_receiver_config()), - ), - ot_send_sink, - ot_send_stream, + mux: &MuxControl, + exec: &mut Executor, +) -> Result<(MpcTlsLeader, DEAPThread, OTReceiver), ProverError> { + let mut ot_sender = kos::Sender::new( + config.build_ot_sender_config(), + chou_orlandi::Receiver::new(config.build_base_ot_receiver_config()), ); + ot_sender.alloc(config.ot_sender_setup_count()); - let mut ot_receiver_actor = ReceiverActor::new( - kos::Receiver::new( - config.build_ot_receiver_config(), - chou_orlandi::Sender::new(config.build_base_ot_sender_config()), - ), - ot_recv_sink, - ot_recv_stream, + let mut ot_receiver = kos::Receiver::new( + config.build_ot_receiver_config(), + chou_orlandi::Sender::new(config.build_base_ot_sender_config()), ); - - let ot_send = ot_sender_actor.sender(); - let ot_recv = ot_receiver_actor.receiver(); - - #[cfg(feature = "tracing")] - debug!("Starting OT setup"); - - futures::try_join!( - ot_sender_actor - .setup(config.ot_sender_setup_count()) - .map_err(ProverError::from), - ot_receiver_actor - .setup(config.ot_receiver_setup_count()) - .map_err(ProverError::from) + ot_receiver.alloc(config.ot_receiver_setup_count()); + + let ot_sender = OTSender::new(ot_sender); + let ot_receiver = OTReceiver::new(ot_receiver); + + let ( + ctx_vm, + ctx_ke_0, + ctx_ke_1, + ctx_prf_0, + ctx_prf_1, + ctx_encrypter_block_cipher, + ctx_encrypter_stream_cipher, + ctx_encrypter_ghash, + ctx_encrypter, + ctx_decrypter_block_cipher, + ctx_decrypter_stream_cipher, + ctx_decrypter_ghash, + ctx_decrypter, + ) = futures::try_join!( + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), )?; - #[cfg(feature = "tracing")] - debug!("OT setup complete"); - - let ot_fut = OTFuture { - fut: Box::pin( - async move { - futures::try_join!( - ot_sender_actor.run().map_err(ProverError::from), - ot_receiver_actor.run().map_err(ProverError::from) - )?; - - Ok(()) - } - .fuse(), - ), - }; - - let mut vm = DEAPVm::new( - "vm", + let vm = DEAPThread::new( DEAPRole::Leader, rand::rngs::OsRng.gen(), - mux.get_channel("vm").await?, - Box::new(mux.clone()), - ot_send.clone(), - ot_recv.clone(), + ctx_vm, + ot_sender.clone(), + ot_receiver.clone(), ); - let p256_sender_config = config.build_p256_sender_config(); - let channel = mux.get_channel(p256_sender_config.id()).await?; - let p256_send = - ff::ConverterSender::::new(p256_sender_config, ot_send.clone(), channel); - - let p256_receiver_config = config.build_p256_receiver_config(); - let channel = mux.get_channel(p256_receiver_config.id()).await?; - let p256_recv = - ff::ConverterReceiver::::new(p256_receiver_config, ot_recv.clone(), channel); - - let gf2_config = config.build_gf2_config(); - let channel = mux.get_channel(gf2_config.id()).await?; - let gf2 = ff::ConverterSender::::new(gf2_config, ot_send.clone(), channel); - let mpc_tls_config = config.build_mpc_tls_config(); - - let (ke, prf, encrypter, decrypter) = setup_components( - mpc_tls_config.common(), + let (ke, prf, encrypter, decrypter) = build_components( TlsRole::Leader, - &mut mux, - &mut vm, - p256_send, - p256_recv, - gf2.handle() - .map_err(|e| ProverError::MpcError(Box::new(e)))?, - ) - .await - .map_err(|e| ProverError::MpcError(Box::new(e)))?; - - let channel = mux.get_channel(mpc_tls_config.common().id()).await?; - let mut mpc_tls = MpcTlsLeader::new(mpc_tls_config, channel, ke, prf, encrypter, decrypter); + mpc_tls_config.common(), + ctx_ke_0, + ctx_encrypter, + ctx_decrypter, + ctx_encrypter_ghash, + ctx_decrypter_ghash, + vm.new_thread(ctx_ke_1, ot_sender.clone(), ot_receiver.clone())?, + vm.new_thread(ctx_prf_0, ot_sender.clone(), ot_receiver.clone())?, + vm.new_thread(ctx_prf_1, ot_sender.clone(), ot_receiver.clone())?, + vm.new_thread( + ctx_encrypter_block_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + vm.new_thread( + ctx_decrypter_block_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + vm.new_thread( + ctx_encrypter_stream_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + vm.new_thread( + ctx_decrypter_stream_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + ot_sender.clone(), + ot_receiver.clone(), + ); + + let channel = mux.open_framed(b"mpc_tls").await?; + let mut mpc_tls = MpcTlsLeader::new( + mpc_tls_config, + Box::new(StreamExt::compat_stream(channel)), + ke, + prf, + encrypter, + decrypter, + ); mpc_tls.setup().await?; - #[cfg(feature = "tracing")] debug!("MPC backend setup complete"); - Ok((mpc_tls, vm, ot_recv, gf2, ot_fut)) + Ok((mpc_tls, vm, ot_receiver)) } /// A controller for the prover. diff --git a/tlsn/tlsn-prover/src/tls/notarize.rs b/tlsn/tlsn-prover/src/tls/notarize.rs index 49bddd405..9bcf59948 100644 --- a/tlsn/tlsn-prover/src/tls/notarize.rs +++ b/tlsn/tlsn-prover/src/tls/notarize.rs @@ -2,19 +2,14 @@ //! //! The prover deals with a TLS verifier that is only a notary. -use crate::tls::error::OTShutdownError; - -use super::{ff::ShareConversionReveal, state::Notarize, Prover, ProverError}; -use futures::{FutureExt, SinkExt, StreamExt}; +use super::{state::Notarize, Prover, ProverError}; +use mpz_ot::VerifiableOTReceiver; +use serio::{stream::IoStreamExt as _, SinkExt as _}; use tlsn_core::{ - commitment::TranscriptCommitmentBuilder, - msg::{SignedSessionHeader, TlsnMessage}, - transcript::Transcript, + commitment::TranscriptCommitmentBuilder, msg::SignedSessionHeader, transcript::Transcript, NotarizedSession, ServerName, SessionData, }; -#[cfg(feature = "tracing")] -use tracing::instrument; -use utils_aio::{expect_msg_or_err, mux::MuxChannel}; +use tracing::{debug, instrument}; impl Prover { /// Returns the transcript of the sent data. @@ -33,14 +28,15 @@ impl Prover { } /// Finalizes the notarization returning a [`NotarizedSession`]. - #[cfg_attr(feature = "tracing", instrument(level = "info", skip(self), err))] + #[instrument(level = "debug", skip_all, err)] pub async fn finalize(self) -> Result { let Notarize { - mut mux_ctrl, + mut io, + mux_ctrl, mut mux_fut, mut vm, - mut ot_fut, - mut gf2, + mut ot_recv, + mut ctx, start_time, handshake_decommitment, server_public_key, @@ -61,39 +57,32 @@ impl Prover { let merkle_root = session_data.commitments().merkle_root(); - let mut notarize_fut = Box::pin(async move { - let mut channel = mux_ctrl.get_channel("notarize").await?; - - channel - .send(TlsnMessage::TranscriptCommitmentRoot(merkle_root)) - .await?; - - let notary_encoder_seed = vm - .finalize() - .await - .map_err(|e| ProverError::MpcError(Box::new(e)))? - .expect("encoder seed returned"); - - // This is a temporary approach until a maliciously secure share conversion protocol is implemented. - // The prover is essentially revealing the TLS MAC key. In some exotic scenarios this allows a malicious - // TLS verifier to modify the prover's sent data. - gf2.reveal() - .await - .map_err(|e| ProverError::MpcError(Box::new(e)))?; - - let signed_header = expect_msg_or_err!(channel, TlsnMessage::SignedSessionHeader)?; - - Ok::<_, ProverError>((notary_encoder_seed, signed_header)) - }) - .fuse(); - - let (notary_encoder_seed, SignedSessionHeader { header, signature }) = futures::select_biased! { - res = notarize_fut => res?, - _ = ot_fut => return Err(OTShutdownError)?, - _ = &mut mux_fut => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; - // Wait for the notary to correctly close the connection. - mux_fut.await?; + let (notary_encoder_seed, SignedSessionHeader { header, signature }) = mux_fut + .poll_with(async { + debug!("starting finalization"); + + io.send(merkle_root).await?; + + ot_recv.accept_reveal(&mut ctx).await?; + + debug!("received OT secret"); + + let notary_encoder_seed = vm + .finalize() + .await + .map_err(|e| ProverError::MpcError(Box::new(e)))? + .expect("encoder seed returned"); + + let signed_header: SignedSessionHeader = io.expect_next().await?; + + Ok::<_, ProverError>((notary_encoder_seed, signed_header)) + }) + .await?; + + if !mux_fut.is_complete() { + mux_ctrl.mux().close(); + mux_fut.await?; + } // Check the header is consistent with the Prover's view. header diff --git a/tlsn/tlsn-prover/src/tls/prove.rs b/tlsn/tlsn-prover/src/tls/prove.rs index b9932a34b..8b777f48e 100644 --- a/tlsn/tlsn-prover/src/tls/prove.rs +++ b/tlsn/tlsn-prover/src/tls/prove.rs @@ -4,19 +4,13 @@ //! the verifier directly verifies parts of the transcript. use super::{state::Prove as ProveState, Prover, ProverError}; -use crate::tls::error::OTShutdownError; -use futures::{FutureExt, SinkExt}; -use mpz_garble::{Memory, Prove, Vm}; -use mpz_share_conversion::ShareConversionReveal; -use tlsn_core::{ - msg::TlsnMessage, proof::SessionInfo, transcript::get_value_ids, Direction, ServerName, - Transcript, -}; +use mpz_garble::{Memory, Prove}; +use mpz_ot::VerifiableOTReceiver; +use serio::SinkExt as _; +use tlsn_core::{proof::SessionInfo, transcript::get_value_ids, Direction, ServerName, Transcript}; use utils::range::{RangeSet, RangeUnion}; -use utils_aio::mux::MuxChannel; -#[cfg(feature = "tracing")] -use tracing::info; +use tracing::{info, instrument}; impl Prover { /// Returns the transcript of the sent requests @@ -66,99 +60,82 @@ impl Prover { } /// Prove transcript values + #[instrument(level = "debug", skip_all, err)] pub async fn prove(&mut self) -> Result<(), ProverError> { let mut proving_info = std::mem::take(&mut self.state.proving_info); - let mut prove_fut = Box::pin(async { - // Create a new channel and vm thread if not already present - let channel = if let Some(ref mut channel) = self.state.channel { - channel - } else { - self.state.channel = Some(self.state.mux_ctrl.get_channel("prove-verify").await?); - self.state.channel.as_mut().unwrap() - }; - - let prove_thread = if let Some(ref mut prove_thread) = self.state.prove_thread { - prove_thread - } else { - self.state.prove_thread = Some(self.state.vm.new_thread("prove-verify").await?); - self.state.prove_thread.as_mut().unwrap() - }; - - // Now prove the transcript parts which have been marked for reveal - let sent_value_ids = proving_info - .sent_ids - .iter_ranges() - .map(|r| get_value_ids(&r.into(), Direction::Sent).collect::>()); - let recv_value_ids = proving_info - .recv_ids - .iter_ranges() - .map(|r| get_value_ids(&r.into(), Direction::Received).collect::>()); - - let value_refs = sent_value_ids - .chain(recv_value_ids) - .map(|ids| { - let inner_refs = ids - .iter() - .map(|id| { - prove_thread - .get_value(id.as_str()) - .expect("Byte should be in VM memory") - }) - .collect::>(); - - prove_thread - .array_from_values(inner_refs.as_slice()) - .expect("Byte should be in VM Memory") - }) - .collect::>(); - - // Extract cleartext we want to reveal from transcripts - let mut cleartext = - Vec::with_capacity(proving_info.sent_ids.len() + proving_info.recv_ids.len()); - proving_info - .sent_ids - .iter_ranges() - .for_each(|r| cleartext.extend_from_slice(&self.state.transcript_tx.data()[r])); - proving_info - .recv_ids - .iter_ranges() - .for_each(|r| cleartext.extend_from_slice(&self.state.transcript_rx.data()[r])); - proving_info.cleartext = cleartext; - - // Send the proving info to the verifier - channel.send(TlsnMessage::ProvingInfo(proving_info)).await?; - - #[cfg(feature = "tracing")] - info!("Sent proving info to verifier"); - - // Prove the revealed transcript parts - prove_thread.prove(value_refs.as_slice()).await?; - - #[cfg(feature = "tracing")] - info!("Successfully proved cleartext"); - - Ok::<_, ProverError>(()) - }) - .fuse(); - - futures::select_biased! { - res = prove_fut => res?, - _ = &mut self.state.ot_fut => return Err(OTShutdownError)?, - _ = &mut self.state.mux_fut => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; + self.state + .mux_fut + .poll_with(async { + // Now prove the transcript parts which have been marked for reveal + let sent_value_ids = proving_info + .sent_ids + .iter_ranges() + .map(|r| get_value_ids(&r.into(), Direction::Sent).collect::>()); + let recv_value_ids = proving_info.recv_ids.iter_ranges().map(|r| { + get_value_ids(&r.into(), Direction::Received).collect::>() + }); + + let value_refs = sent_value_ids + .chain(recv_value_ids) + .map(|ids| { + let inner_refs = ids + .iter() + .map(|id| { + self.state + .vm + .get_value(id.as_str()) + .expect("Byte should be in VM memory") + }) + .collect::>(); + + self.state + .vm + .array_from_values(inner_refs.as_slice()) + .expect("Byte should be in VM Memory") + }) + .collect::>(); + + // Extract cleartext we want to reveal from transcripts + let mut cleartext = + Vec::with_capacity(proving_info.sent_ids.len() + proving_info.recv_ids.len()); + proving_info + .sent_ids + .iter_ranges() + .for_each(|r| cleartext.extend_from_slice(&self.state.transcript_tx.data()[r])); + proving_info + .recv_ids + .iter_ranges() + .for_each(|r| cleartext.extend_from_slice(&self.state.transcript_rx.data()[r])); + proving_info.cleartext = cleartext; + + // Send the proving info to the verifier + self.state.io.send(proving_info).await?; + + info!("Sent proving info to verifier"); + + // Prove the revealed transcript parts + self.state.vm.prove(value_refs.as_slice()).await?; + + info!("Successfully proved cleartext"); + + Ok::<_, ProverError>(()) + }) + .await?; Ok(()) } /// Finalize the proving + #[instrument(level = "debug", skip_all, err)] pub async fn finalize(self) -> Result<(), ProverError> { let ProveState { - mut mux_ctrl, + mut io, + mux_ctrl, mut mux_fut, mut vm, - mut ot_fut, - mut gf2, + mut ot_recv, + mut ctx, handshake_decommitment, .. } = self.state; @@ -169,38 +146,28 @@ impl Prover { handshake_decommitment, }; - let mut finalize_fut = Box::pin(async move { - let mut channel = mux_ctrl.get_channel("finalize").await?; - - _ = vm - .finalize() - .await - .map_err(|e| ProverError::MpcError(Box::new(e)))? - .expect("encoder seed returned"); - - // This is a temporary approach until a maliciously secure share conversion protocol is implemented. - // The prover is essentially revealing the TLS MAC key. In some exotic scenarios this allows a malicious - // TLS verifier to modify the prover's request. - gf2.reveal() - .await - .map_err(|e| ProverError::MpcError(Box::new(e)))?; - - // Send session_info to the verifier - channel.send(TlsnMessage::SessionInfo(session_info)).await?; - - Ok::<_, ProverError>(()) - }) - .fuse(); - - futures::select_biased! { - res = finalize_fut => res?, - _ = ot_fut => return Err(OTShutdownError)?, - _ = &mut mux_fut => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; + mux_fut + .poll_with(async move { + ot_recv.accept_reveal(&mut ctx).await?; + + _ = vm + .finalize() + .await + .map_err(|e| ProverError::MpcError(Box::new(e)))? + .expect("encoder seed returned"); + + // Send session_info to the verifier + io.send(session_info).await?; + + Ok::<_, ProverError>(()) + }) + .await?; + + if !mux_fut.is_complete() { + mux_ctrl.mux().close(); + mux_fut.await?; + } - // We need to wait for the verifier to correctly close the connection. Otherwise the prover - // would rush ahead and close the connection before the verifier has finished. - mux_fut.await?; Ok(()) } } diff --git a/tlsn/tlsn-prover/src/tls/state.rs b/tlsn/tlsn-prover/src/tls/state.rs index 64ae9c059..f166f7459 100644 --- a/tlsn/tlsn-prover/src/tls/state.rs +++ b/tlsn/tlsn-prover/src/tls/state.rs @@ -1,21 +1,16 @@ //! TLS prover states. -use crate::tls::{MuxFuture, OTFuture}; use mpz_core::commit::Decommitment; -use mpz_garble::protocol::deap::{DEAPThread, DEAPVm, PeerEncodings}; +use mpz_garble::protocol::deap::PeerEncodings; use mpz_garble_core::{encoding_state, EncodedValue}; -use mpz_ot::actor::kos::{SharedReceiver, SharedSender}; -use mpz_share_conversion::{ConverterSender, Gf2_128}; use std::collections::HashMap; use tls_core::{handshake::HandshakeData, key::PublicKey}; use tls_mpc::MpcTlsLeader; -use tlsn_common::mux::MuxControl; -use tlsn_core::{ - commitment::TranscriptCommitmentBuilder, - msg::{ProvingInfo, TlsnMessage}, - Transcript, +use tlsn_common::{ + mux::{MuxControl, MuxFuture}, + Context, DEAPThread, Io, OTReceiver, }; -use utils_aio::duplex::Duplex; +use tlsn_core::{commitment::TranscriptCommitmentBuilder, msg::ProvingInfo, Transcript}; /// Entry state pub struct Initialized; @@ -24,26 +19,27 @@ opaque_debug::implement!(Initialized); /// State after MPC setup has completed. pub struct Setup { - /// A muxer for communication with the TLS verifier. + pub(crate) io: Io, pub(crate) mux_ctrl: MuxControl, pub(crate) mux_fut: MuxFuture, pub(crate) mpc_tls: MpcTlsLeader, - pub(crate) vm: DEAPVm, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterSender, + pub(crate) vm: DEAPThread, + pub(crate) ot_recv: OTReceiver, + pub(crate) ctx: Context, } opaque_debug::implement!(Setup); /// State after the TLS connection has been closed. pub struct Closed { + pub(crate) io: Io, pub(crate) mux_ctrl: MuxControl, pub(crate) mux_fut: MuxFuture, - pub(crate) vm: DEAPVm, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterSender, + pub(crate) vm: DEAPThread, + pub(crate) ot_recv: OTReceiver, + pub(crate) ctx: Context, pub(crate) start_time: u64, pub(crate) handshake_decommitment: Decommitment, @@ -57,13 +53,13 @@ opaque_debug::implement!(Closed); /// Notarizing state. pub struct Notarize { - /// A muxer for communication with the Notary + pub(crate) io: Io, pub(crate) mux_ctrl: MuxControl, pub(crate) mux_fut: MuxFuture, - pub(crate) vm: DEAPVm, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterSender, + pub(crate) vm: DEAPThread, + pub(crate) ot_recv: OTReceiver, + pub(crate) ctx: Context, pub(crate) start_time: u64, pub(crate) handshake_decommitment: Decommitment, @@ -92,11 +88,12 @@ impl From for Notarize { ); Self { + io: state.io, mux_ctrl: state.mux_ctrl, mux_fut: state.mux_fut, vm: state.vm, - ot_fut: state.ot_fut, - gf2: state.gf2, + ot_recv: state.ot_recv, + ctx: state.ctx, start_time: state.start_time, handshake_decommitment: state.handshake_decommitment, server_public_key: state.server_public_key, @@ -109,12 +106,13 @@ impl From for Notarize { /// Proving state. pub struct Prove { + pub(crate) io: Io, pub(crate) mux_ctrl: MuxControl, pub(crate) mux_fut: MuxFuture, - pub(crate) vm: DEAPVm, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterSender, + pub(crate) vm: DEAPThread, + pub(crate) ot_recv: OTReceiver, + pub(crate) ctx: Context, pub(crate) handshake_decommitment: Decommitment, @@ -122,24 +120,21 @@ pub struct Prove { pub(crate) transcript_rx: Transcript, pub(crate) proving_info: ProvingInfo, - pub(crate) channel: Option>>, - pub(crate) prove_thread: Option>, } impl From for Prove { fn from(state: Closed) -> Self { Self { + io: state.io, mux_ctrl: state.mux_ctrl, mux_fut: state.mux_fut, vm: state.vm, - ot_fut: state.ot_fut, - gf2: state.gf2, + ot_recv: state.ot_recv, + ctx: state.ctx, handshake_decommitment: state.handshake_decommitment, transcript_tx: state.transcript_tx, transcript_rx: state.transcript_rx, proving_info: ProvingInfo::default(), - channel: None, - prove_thread: None, } } } @@ -163,7 +158,7 @@ mod sealed { } fn collect_encodings( - vm: &DEAPVm, + vm: &impl PeerEncodings, transcript_tx: &Transcript, transcript_rx: &Transcript, ) -> HashMap> { diff --git a/tlsn/tlsn-verifier/Cargo.toml b/tlsn/tlsn-verifier/Cargo.toml index addf4d4ec..7e6e5433f 100644 --- a/tlsn/tlsn-verifier/Cargo.toml +++ b/tlsn/tlsn-verifier/Cargo.toml @@ -9,22 +9,27 @@ version = "0.1.0-alpha.5" edition = "2021" [features] -tracing = ["dep:tracing", "tlsn-tls-mpc/tracing", "tlsn-common/tracing"] +default = ["rayon"] +rayon = ["mpz-common/rayon"] +force-st = ["mpz-common/force-st"] [dependencies] tlsn-core.workspace = true tlsn-common.workspace = true tlsn-tls-core.workspace = true tlsn-tls-mpc.workspace = true -uid-mux.workspace = true tlsn-utils-aio.workspace = true +serio = { workspace = true, features = ["compat"] } +uid-mux = { workspace = true, features = ["serio"] } mpz-core.workspace = true mpz-garble.workspace = true mpz-ot.workspace = true +mpz-ole.workspace = true mpz-share-conversion.workspace = true mpz-circuits.workspace = true +mpz-common.workspace = true futures.workspace = true thiserror.workspace = true @@ -32,5 +37,4 @@ derive_builder.workspace = true rand.workspace = true signature.workspace = true opaque-debug.workspace = true - -tracing = { workspace = true, optional = true } +tracing.workspace = true diff --git a/tlsn/tlsn-verifier/src/tls/config.rs b/tlsn/tlsn-verifier/src/tls/config.rs index 879a3d547..c98caccf1 100644 --- a/tlsn/tlsn-verifier/src/tls/config.rs +++ b/tlsn/tlsn-verifier/src/tls/config.rs @@ -1,5 +1,4 @@ use mpz_ot::{chou_orlandi, kos}; -use mpz_share_conversion::{ReceiverConfig, SenderConfig}; use std::fmt::{Debug, Formatter, Result}; use tls_core::verify::{ServerCertVerifier, WebPkiVerifier}; use tls_mpc::{MpcTlsCommonConfig, MpcTlsFollowerConfig, TranscriptConfig}; @@ -123,20 +122,4 @@ impl VerifierConfig { pub(crate) fn ot_receiver_setup_count(&self) -> usize { ot_recv_estimate(Role::Verifier, self.max_sent_data, self.max_recv_data) } - - pub(crate) fn build_p256_sender_config(&self) -> SenderConfig { - SenderConfig::builder().id("p256/1").build().unwrap() - } - - pub(crate) fn build_p256_receiver_config(&self) -> ReceiverConfig { - ReceiverConfig::builder().id("p256/0").build().unwrap() - } - - pub(crate) fn build_gf2_config(&self) -> ReceiverConfig { - ReceiverConfig::builder() - .id("gf2") - .record() - .build() - .unwrap() - } } diff --git a/tlsn/tlsn-verifier/src/tls/error.rs b/tlsn/tlsn-verifier/src/tls/error.rs index 39ac6650c..504aa3867 100644 --- a/tlsn/tlsn-verifier/src/tls/error.rs +++ b/tlsn/tlsn-verifier/src/tls/error.rs @@ -15,6 +15,21 @@ pub enum VerifierError { InvalidRange, } +impl From for VerifierError { + fn from(e: uid_mux::yamux::ConnectionError) -> Self { + Self::IOError(std::io::Error::new( + std::io::ErrorKind::ConnectionAborted, + e, + )) + } +} + +impl From for VerifierError { + fn from(e: mpz_common::ContextError) -> Self { + Self::MpcError(Box::new(e)) + } +} + impl From for VerifierError { fn from(e: MpcTlsError) -> Self { Self::MpcError(Box::new(e)) @@ -27,14 +42,20 @@ impl From for VerifierError { } } -impl From for VerifierError { - fn from(e: mpz_ot::actor::kos::SenderActorError) -> Self { +impl From for VerifierError { + fn from(e: mpz_ot::kos::SenderError) -> Self { + Self::MpcError(Box::new(e)) + } +} + +impl From for VerifierError { + fn from(e: mpz_ot::kos::ReceiverError) -> Self { Self::MpcError(Box::new(e)) } } -impl From for VerifierError { - fn from(e: mpz_ot::actor::kos::ReceiverActorError) -> Self { +impl From for VerifierError { + fn from(e: mpz_garble::protocol::deap::DEAPError) -> Self { Self::MpcError(Box::new(e)) } } diff --git a/tlsn/tlsn-verifier/src/tls/future.rs b/tlsn/tlsn-verifier/src/tls/future.rs deleted file mode 100644 index 886fb4128..000000000 --- a/tlsn/tlsn-verifier/src/tls/future.rs +++ /dev/null @@ -1,50 +0,0 @@ -//! This module collects futures which are used by the [Verifier](crate::tls::Verifier). - -use super::{OTSenderActor, VerifierError}; -use futures::{future::FusedFuture, Future}; -use std::pin::Pin; - -/// A future which must be polled for the muxer to make progress. -pub(crate) struct MuxFuture { - pub(crate) fut: Pin> + Send + 'static>>, -} - -impl Future for MuxFuture { - type Output = Result<(), VerifierError>; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - self.fut.as_mut().poll(cx) - } -} - -impl FusedFuture for MuxFuture { - fn is_terminated(&self) -> bool { - self.fut.is_terminated() - } -} - -/// A future which must be polled for the Oblivious Transfer protocol to make progress. -pub(crate) struct OTFuture { - pub(crate) fut: - Pin> + Send + 'static>>, -} - -impl Future for OTFuture { - type Output = Result; - - fn poll( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - self.fut.as_mut().poll(cx) - } -} - -impl FusedFuture for OTFuture { - fn is_terminated(&self) -> bool { - self.fut.is_terminated() - } -} diff --git a/tlsn/tlsn-verifier/src/tls/mod.rs b/tlsn/tlsn-verifier/src/tls/mod.rs index 32599e3f1..c222db23d 100644 --- a/tlsn/tlsn-verifier/src/tls/mod.rs +++ b/tlsn/tlsn-verifier/src/tls/mod.rs @@ -2,53 +2,33 @@ pub(crate) mod config; mod error; -mod future; mod notarize; pub mod state; mod verify; pub use config::{VerifierConfig, VerifierConfigBuilder, VerifierConfigBuilderError}; pub use error::VerifierError; +use mpz_common::Allocate; +use serio::StreamExt; +use uid_mux::FramedUidMux; use std::time::{SystemTime, UNIX_EPOCH}; -use crate::tls::future::OTFuture; -use future::MuxFuture; -use futures::{ - stream::{SplitSink, SplitStream}, - AsyncRead, AsyncWrite, FutureExt, StreamExt, TryFutureExt, -}; -use mpz_garble::{config::Role as GarbleRole, protocol::deap::DEAPVm}; -use mpz_ot::{ - actor::kos::{ - msgs::Message as ActorMessage, ReceiverActor, SenderActor, SharedReceiver, SharedSender, - }, - chou_orlandi, kos, -}; -use mpz_share_conversion as ff; +use futures::{AsyncRead, AsyncWrite, TryFutureExt}; +use mpz_garble::config::Role as DEAPRole; +use mpz_ot::{chou_orlandi, kos}; use rand::Rng; use signature::Signer; use state::{Notarize, Verify}; -use tls_mpc::{setup_components, MpcTlsFollower, MpcTlsFollowerData, TlsRole}; +use tls_mpc::{build_components, MpcTlsFollower, MpcTlsFollowerData, TlsRole}; use tlsn_common::{ mux::{attach_mux, MuxControl}, - Role, + DEAPThread, Executor, OTReceiver, OTSender, Role, }; use tlsn_core::{proof::SessionInfo, RedactedTranscript, SessionHeader, Signature}; -use utils_aio::{duplex::Duplex, mux::MuxChannel}; -#[cfg(feature = "tracing")] use tracing::{debug, info, instrument}; -type OTSenderActor = SenderActor< - chou_orlandi::Receiver, - SplitSink< - Box>>, - ActorMessage, - >, - SplitStream>>>, ->; - /// A Verifier instance. pub struct Verifier { config: VerifierConfig, @@ -75,30 +55,42 @@ impl Verifier { self, socket: S, ) -> Result, VerifierError> { - let (mut mux, mux_ctrl) = attach_mux(socket, Role::Verifier); + let (mut mux_fut, mux_ctrl) = attach_mux(socket, Role::Verifier); - let mut mux_fut = MuxFuture { - fut: Box::pin(async move { mux.run().await.map_err(VerifierError::from) }.fuse()), - }; + let mut exec = Executor::new(mux_ctrl.clone(), 8); let encoder_seed: [u8; 32] = rand::rngs::OsRng.gen(); - let mpc_setup_fut = setup_mpc_backend(&self.config, mux_ctrl.clone(), encoder_seed); - let (mpc_tls, vm, ot_send, ot_recv, gf2, ot_fut) = futures::select! { - res = mpc_setup_fut.fuse() => res?, - _ = &mut mux_fut => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; + let (mpc_tls, vm, ot_send) = mux_fut + .poll_with(setup_mpc_backend( + &self.config, + &mux_ctrl, + &mut exec, + encoder_seed, + )) + .await?; + + let io = mux_fut + .poll_with( + mux_ctrl + .open_framed(b"tlsnotary") + .map_err(VerifierError::from), + ) + .await?; + + let ctx = mux_fut + .poll_with(exec.new_thread().map_err(VerifierError::from)) + .await?; Ok(Verifier { config: self.config, state: state::Setup { + io, mux_ctrl, mux_fut, mpc_tls, vm, ot_send, - ot_recv, - ot_fut, - gf2, + ctx, encoder_seed, }, }) @@ -152,14 +144,13 @@ impl Verifier { /// Runs the verifier until the TLS connection is closed. pub async fn run(self) -> Result, VerifierError> { let state::Setup { + io, mux_ctrl, mut mux_fut, mpc_tls, vm, ot_send, - ot_recv, - mut ot_fut, - gf2, + ctx, encoder_seed, } = self.state; @@ -168,20 +159,15 @@ impl Verifier { .unwrap() .as_secs(); - let (_, mpc_fut) = mpc_tls.run(); - let MpcTlsFollowerData { handshake_commitment, server_key: server_ephemeral_key, bytes_sent: sent_len, bytes_recv: recv_len, - } = futures::select! { - res = mpc_fut.fuse() => res?, - _ = &mut mux_fut => return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - res = ot_fut => return Err(res.map(|_| ()).expect_err("future will not return Ok here")) - }; + } = mux_fut + .poll_with(mpc_tls.run().1.map_err(VerifierError::from)) + .await?; - #[cfg(feature = "tracing")] info!("Finished TLS session"); // TODO: We should be able to skip this commitment and verify the handshake directly. @@ -190,13 +176,12 @@ impl Verifier { Ok(Verifier { config: self.config, state: state::Closed { + io, mux_ctrl, mux_fut, vm, ot_send, - ot_recv, - ot_fut, - gf2, + ctx, encoder_seed, start_time, server_ephemeral_key, @@ -233,122 +218,115 @@ impl Verifier { } /// Performs a setup of the various MPC subprotocols. -#[cfg_attr(feature = "tracing", instrument(level = "debug", skip_all, err))] -#[allow(clippy::type_complexity)] +#[instrument(level = "debug", skip_all, err)] async fn setup_mpc_backend( config: &VerifierConfig, - mut mux_ctrl: MuxControl, + mux: &MuxControl, + exec: &mut Executor, encoder_seed: [u8; 32], -) -> Result< - ( - MpcTlsFollower, - DEAPVm, - SharedSender, - SharedReceiver, - ff::ConverterReceiver, - OTFuture, - ), - VerifierError, -> { - let (ot_send_sink, ot_send_stream) = mux_ctrl.get_channel("ot/1").await?.split(); - let (ot_recv_sink, ot_recv_stream) = mux_ctrl.get_channel("ot/0").await?.split(); - - let mut ot_sender_actor = OTSenderActor::new( - kos::Sender::new( - config.build_ot_sender_config(), - chou_orlandi::Receiver::new(config.build_base_ot_receiver_config()), - ), - ot_send_sink, - ot_send_stream, +) -> Result<(MpcTlsFollower, DEAPThread, OTSender), VerifierError> { + let mut ot_sender = kos::Sender::new( + config.build_ot_sender_config(), + chou_orlandi::Receiver::new(config.build_base_ot_receiver_config()), ); + ot_sender.alloc(config.ot_sender_setup_count()); - let mut ot_receiver_actor = ReceiverActor::new( - kos::Receiver::new( - config.build_ot_receiver_config(), - chou_orlandi::Sender::new(config.build_base_ot_sender_config()), - ), - ot_recv_sink, - ot_recv_stream, + let mut ot_receiver = kos::Receiver::new( + config.build_ot_receiver_config(), + chou_orlandi::Sender::new(config.build_base_ot_sender_config()), ); - - let ot_send = ot_sender_actor.sender(); - let ot_recv = ot_receiver_actor.receiver(); - - #[cfg(feature = "tracing")] - debug!("Starting OT setup"); - - futures::try_join!( - ot_sender_actor - .setup(config.ot_sender_setup_count()) - .map_err(VerifierError::from), - ot_receiver_actor - .setup(config.ot_receiver_setup_count()) - .map_err(VerifierError::from) + ot_receiver.alloc(config.ot_receiver_setup_count()); + + let ot_sender = OTSender::new(ot_sender); + let ot_receiver = OTReceiver::new(ot_receiver); + + let ( + ctx_vm, + ctx_ke_0, + ctx_ke_1, + ctx_prf_0, + ctx_prf_1, + ctx_encrypter_block_cipher, + ctx_encrypter_stream_cipher, + ctx_encrypter_ghash, + ctx_encrypter, + ctx_decrypter_block_cipher, + ctx_decrypter_stream_cipher, + ctx_decrypter_ghash, + ctx_decrypter, + ) = futures::try_join!( + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), + exec.new_thread(), )?; - #[cfg(feature = "tracing")] - debug!("OT setup complete"); - - let ot_fut = OTFuture { - fut: Box::pin( - async move { - futures::try_join!( - ot_sender_actor.run().map_err(VerifierError::from), - ot_receiver_actor.run().map_err(VerifierError::from) - )?; - - Ok(ot_sender_actor) - } - .fuse(), - ), - }; - - let mut vm = DEAPVm::new( - "vm", - GarbleRole::Follower, + let vm = DEAPThread::new( + DEAPRole::Follower, encoder_seed, - mux_ctrl.get_channel("vm").await?, - Box::new(mux_ctrl.clone()), - ot_send.clone(), - ot_recv.clone(), + ctx_vm, + ot_sender.clone(), + ot_receiver.clone(), ); - let p256_sender_config = config.build_p256_sender_config(); - let channel = mux_ctrl.get_channel(p256_sender_config.id()).await?; - let p256_send = - ff::ConverterSender::::new(p256_sender_config, ot_send.clone(), channel); - - let p256_receiver_config = config.build_p256_receiver_config(); - let channel = mux_ctrl.get_channel(p256_receiver_config.id()).await?; - let p256_recv = - ff::ConverterReceiver::::new(p256_receiver_config, ot_recv.clone(), channel); - - let gf2_config = config.build_gf2_config(); - let channel = mux_ctrl.get_channel(gf2_config.id()).await?; - let gf2 = ff::ConverterReceiver::::new(gf2_config, ot_recv.clone(), channel); - let mpc_tls_config = config.build_mpc_tls_config(); - - let (ke, prf, encrypter, decrypter) = setup_components( - mpc_tls_config.common(), + let (ke, prf, encrypter, decrypter) = build_components( TlsRole::Follower, - &mut mux_ctrl, - &mut vm, - p256_send, - p256_recv, - gf2.handle() - .map_err(|e| VerifierError::MpcError(Box::new(e)))?, - ) - .await - .map_err(|e| VerifierError::MpcError(Box::new(e)))?; - - let channel = mux_ctrl.get_channel(mpc_tls_config.common().id()).await?; - let mut mpc_tls = MpcTlsFollower::new(mpc_tls_config, channel, ke, prf, encrypter, decrypter); + mpc_tls_config.common(), + ctx_ke_0, + ctx_encrypter, + ctx_decrypter, + ctx_encrypter_ghash, + ctx_decrypter_ghash, + vm.new_thread(ctx_ke_1, ot_sender.clone(), ot_receiver.clone())?, + vm.new_thread(ctx_prf_0, ot_sender.clone(), ot_receiver.clone())?, + vm.new_thread(ctx_prf_1, ot_sender.clone(), ot_receiver.clone())?, + vm.new_thread( + ctx_encrypter_block_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + vm.new_thread( + ctx_decrypter_block_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + vm.new_thread( + ctx_encrypter_stream_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + vm.new_thread( + ctx_decrypter_stream_cipher, + ot_sender.clone(), + ot_receiver.clone(), + )?, + ot_sender.clone(), + ot_receiver.clone(), + ); + + let channel = mux.open_framed(b"mpc_tls").await?; + let mut mpc_tls = MpcTlsFollower::new( + mpc_tls_config, + Box::new(StreamExt::compat_stream(channel)), + ke, + prf, + encrypter, + decrypter, + ); mpc_tls.setup().await?; - #[cfg(feature = "tracing")] debug!("MPC backend setup complete"); - Ok((mpc_tls, vm, ot_send, ot_recv, gf2, ot_fut)) + Ok((mpc_tls, vm, ot_sender)) } diff --git a/tlsn/tlsn-verifier/src/tls/notarize.rs b/tlsn/tlsn-verifier/src/tls/notarize.rs index 69545a82b..17ab0ea16 100644 --- a/tlsn/tlsn-verifier/src/tls/notarize.rs +++ b/tlsn/tlsn-verifier/src/tls/notarize.rs @@ -3,18 +3,15 @@ //! The TLS verifier is only a notary. use super::{state::Notarize, Verifier, VerifierError}; -use futures::{FutureExt, SinkExt, StreamExt, TryFutureExt}; use mpz_core::serialize::CanonicalSerialize; -use mpz_share_conversion::ShareConversionVerify; +use mpz_ot::CommittedOTSender; +use serio::{stream::IoStreamExt, SinkExt as _}; use signature::Signer; use tlsn_core::{ - msg::{SignedSessionHeader, TlsnMessage}, - HandshakeSummary, SessionHeader, Signature, + merkle::MerkleRoot, msg::SignedSessionHeader, HandshakeSummary, SessionHeader, Signature, }; -use utils_aio::{expect_msg_or_err, mux::MuxChannel}; -#[cfg(feature = "tracing")] -use tracing::info; +use tracing::{debug, info, instrument}; impl Verifier { /// Notarizes the TLS session. @@ -22,18 +19,18 @@ impl Verifier { /// # Arguments /// /// * `signer` - The signer used to sign the notarization result. + #[instrument(level = "debug", skip_all, err)] pub async fn finalize(self, signer: &impl Signer) -> Result where T: Into, { let Notarize { - mut mux_ctrl, + mut io, + mux_ctrl, mut mux_fut, mut vm, - ot_send, - ot_recv, - ot_fut, - mut gf2, + mut ot_send, + mut ctx, encoder_seed, start_time, server_ephemeral_key, @@ -42,69 +39,52 @@ impl Verifier { recv_len, } = self.state; - let notarize_fut = async { - let mut notarize_channel = mux_ctrl.get_channel("notarize").await?; + let session_header = mux_fut + .poll_with(async { + let merkle_root: MerkleRoot = io.expect_next().await?; - let merkle_root = - expect_msg_or_err!(notarize_channel, TlsnMessage::TranscriptCommitmentRoot)?; + // Finalize all MPC before signing the session header. + ot_send.reveal(&mut ctx).await?; - // Finalize all MPC before signing the session header. - let (mut ot_sender_actor, _, _) = futures::try_join!( - ot_fut, - ot_send.shutdown().map_err(VerifierError::from), - ot_recv.shutdown().map_err(VerifierError::from) - )?; + debug!("revealed OT secret"); - ot_sender_actor.reveal().await?; + vm.finalize() + .await + .map_err(|e| VerifierError::MpcError(Box::new(e)))?; - vm.finalize() - .await - .map_err(|e| VerifierError::MpcError(Box::new(e)))?; + info!("Finalized all MPC"); - gf2.verify() - .await - .map_err(|e| VerifierError::MpcError(Box::new(e)))?; + let handshake_summary = + HandshakeSummary::new(start_time, server_ephemeral_key, handshake_commitment); - #[cfg(feature = "tracing")] - info!("Finalized all MPC"); + let session_header = SessionHeader::new( + encoder_seed, + merkle_root, + sent_len, + recv_len, + handshake_summary, + ); - let handshake_summary = - HandshakeSummary::new(start_time, server_ephemeral_key, handshake_commitment); + let signature = signer.sign(&session_header.to_bytes()); - let session_header = SessionHeader::new( - encoder_seed, - merkle_root, - sent_len, - recv_len, - handshake_summary, - ); + info!("Signed session header"); - let signature = signer.sign(&session_header.to_bytes()); - - #[cfg(feature = "tracing")] - info!("Signed session header"); - - notarize_channel - .send(TlsnMessage::SignedSessionHeader(SignedSessionHeader { + io.send(SignedSessionHeader { header: session_header.clone(), signature: signature.into(), - })) + }) .await?; - #[cfg(feature = "tracing")] - info!("Sent session header"); - - Ok::<_, VerifierError>(session_header) - }; - - let session_header = futures::select! { - res = notarize_fut.fuse() => res?, - _ = &mut mux_fut => Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; + info!("Sent session header"); - let mut mux_ctrl = mux_ctrl.into_inner(); + Ok::<_, VerifierError>(session_header) + }) + .await?; - futures::try_join!(mux_ctrl.close().map_err(VerifierError::from), mux_fut)?; + if !mux_fut.is_complete() { + mux_ctrl.mux().close(); + mux_fut.await?; + } Ok(session_header) } diff --git a/tlsn/tlsn-verifier/src/tls/state.rs b/tlsn/tlsn-verifier/src/tls/state.rs index 3100fa061..ecbbb39da 100644 --- a/tlsn/tlsn-verifier/src/tls/state.rs +++ b/tlsn/tlsn-verifier/src/tls/state.rs @@ -1,16 +1,12 @@ //! TLS Verifier state. use mpz_core::hash::Hash; -use mpz_garble::protocol::deap::{DEAPThread, DEAPVm}; -use mpz_ot::actor::kos::{SharedReceiver, SharedSender}; -use mpz_share_conversion::{ConverterReceiver, Gf2_128}; use tls_core::key::PublicKey; use tls_mpc::MpcTlsFollower; -use tlsn_common::mux::MuxControl; -use tlsn_core::msg::TlsnMessage; -use utils_aio::duplex::Duplex; - -use crate::tls::future::{MuxFuture, OTFuture}; +use tlsn_common::{ + mux::{MuxControl, MuxFuture}, + Context, DEAPThread, Io, OTSender, +}; /// TLS Verifier state. pub trait VerifierState: sealed::Sealed {} @@ -22,29 +18,27 @@ opaque_debug::implement!(Initialized); /// State after MPC setup has completed. pub struct Setup { + pub(crate) io: Io, pub(crate) mux_ctrl: MuxControl, pub(crate) mux_fut: MuxFuture, pub(crate) mpc_tls: MpcTlsFollower, - pub(crate) vm: DEAPVm, - pub(crate) ot_send: SharedSender, - pub(crate) ot_recv: SharedReceiver, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterReceiver, + pub(crate) vm: DEAPThread, + pub(crate) ot_send: OTSender, + pub(crate) ctx: Context, pub(crate) encoder_seed: [u8; 32], } /// State after the TLS connection has been closed. pub struct Closed { + pub(crate) io: Io, pub(crate) mux_ctrl: MuxControl, pub(crate) mux_fut: MuxFuture, - pub(crate) vm: DEAPVm, - pub(crate) ot_send: SharedSender, - pub(crate) ot_recv: SharedReceiver, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterReceiver, + pub(crate) vm: DEAPThread, + pub(crate) ot_send: OTSender, + pub(crate) ctx: Context, pub(crate) encoder_seed: [u8; 32], pub(crate) start_time: u64, @@ -58,14 +52,13 @@ opaque_debug::implement!(Closed); /// Notarizing state. pub struct Notarize { + pub(crate) io: Io, pub(crate) mux_ctrl: MuxControl, pub(crate) mux_fut: MuxFuture, - pub(crate) vm: DEAPVm, - pub(crate) ot_send: SharedSender, - pub(crate) ot_recv: SharedReceiver, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterReceiver, + pub(crate) vm: DEAPThread, + pub(crate) ot_send: OTSender, + pub(crate) ctx: Context, pub(crate) encoder_seed: [u8; 32], pub(crate) start_time: u64, @@ -80,13 +73,12 @@ opaque_debug::implement!(Notarize); impl From for Notarize { fn from(value: Closed) -> Self { Self { + io: value.io, mux_ctrl: value.mux_ctrl, mux_fut: value.mux_fut, vm: value.vm, ot_send: value.ot_send, - ot_recv: value.ot_recv, - ot_fut: value.ot_fut, - gf2: value.gf2, + ctx: value.ctx, encoder_seed: value.encoder_seed, start_time: value.start_time, server_ephemeral_key: value.server_ephemeral_key, @@ -99,23 +91,19 @@ impl From for Notarize { /// Verifying state. pub struct Verify { + pub(crate) io: Io, pub(crate) mux_ctrl: MuxControl, pub(crate) mux_fut: MuxFuture, - pub(crate) vm: DEAPVm, - pub(crate) ot_send: SharedSender, - pub(crate) ot_recv: SharedReceiver, - pub(crate) ot_fut: OTFuture, - pub(crate) gf2: ConverterReceiver, + pub(crate) vm: DEAPThread, + pub(crate) ot_send: OTSender, + pub(crate) ctx: Context, pub(crate) start_time: u64, pub(crate) server_ephemeral_key: PublicKey, pub(crate) handshake_commitment: Hash, pub(crate) sent_len: usize, pub(crate) recv_len: usize, - - pub(crate) channel: Option>>, - pub(crate) verify_thread: Option>, } opaque_debug::implement!(Verify); @@ -123,20 +111,17 @@ opaque_debug::implement!(Verify); impl From for Verify { fn from(value: Closed) -> Self { Self { + io: value.io, mux_ctrl: value.mux_ctrl, mux_fut: value.mux_fut, vm: value.vm, ot_send: value.ot_send, - ot_recv: value.ot_recv, - ot_fut: value.ot_fut, - gf2: value.gf2, + ctx: value.ctx, start_time: value.start_time, server_ephemeral_key: value.server_ephemeral_key, handshake_commitment: value.handshake_commitment, sent_len: value.sent_len, recv_len: value.recv_len, - channel: None, - verify_thread: None, } } } diff --git a/tlsn/tlsn-verifier/src/tls/verify.rs b/tlsn/tlsn-verifier/src/tls/verify.rs index 4617629d3..8683a0087 100644 --- a/tlsn/tlsn-verifier/src/tls/verify.rs +++ b/tlsn/tlsn-verifier/src/tls/verify.rs @@ -3,17 +3,15 @@ //! The TLS verifier is an application-specific verifier. use super::{state::Verify as VerifyState, Verifier, VerifierError}; -use futures::{FutureExt, StreamExt, TryFutureExt}; use mpz_circuits::types::Value; -use mpz_garble::{Memory, Verify, Vm}; -use mpz_share_conversion::ShareConversionVerify; +use mpz_garble::{Memory, Verify}; +use mpz_ot::CommittedOTSender; +use serio::stream::IoStreamExt; use tlsn_core::{ - msg::TlsnMessage, proof::SessionInfo, transcript::get_value_ids, Direction, HandshakeSummary, + msg::ProvingInfo, proof::SessionInfo, transcript::get_value_ids, Direction, HandshakeSummary, RedactedTranscript, TranscriptSlice, }; -use utils_aio::{expect_msg_or_err, mux::MuxChannel}; -#[cfg(feature = "tracing")] use tracing::info; impl Verifier { @@ -25,159 +23,123 @@ impl Verifier { pub async fn receive( &mut self, ) -> Result<(RedactedTranscript, RedactedTranscript), VerifierError> { - let verify_fut = async { - // Create a new channel and vm thread if not already present - let channel = if let Some(ref mut channel) = self.state.channel { - channel - } else { - self.state.channel = Some(self.state.mux_ctrl.get_channel("prove-verify").await?); - self.state.channel.as_mut().unwrap() - }; - - let verify_thread = if let Some(ref mut verify_thread) = self.state.verify_thread { - verify_thread - } else { - self.state.verify_thread = Some(self.state.vm.new_thread("prove-verify").await?); - self.state.verify_thread.as_mut().unwrap() - }; - - // Receive the proving info from the prover - let mut proving_info = expect_msg_or_err!(channel, TlsnMessage::ProvingInfo)?; - let mut cleartext = proving_info.cleartext.clone(); - - #[cfg(feature = "tracing")] - info!("Received proving info from prover"); - - // Check ranges - if proving_info.sent_ids.max().unwrap_or_default() > self.state.sent_len - || proving_info.recv_ids.max().unwrap_or_default() > self.state.recv_len - { - return Err(VerifierError::InvalidRange); - } - - // Now verify the transcript parts which the prover wants to reveal - let sent_value_ids = proving_info - .sent_ids - .iter_ranges() - .map(|r| get_value_ids(&r.into(), Direction::Sent).collect::>()); - let recv_value_ids = proving_info - .recv_ids - .iter_ranges() - .map(|r| get_value_ids(&r.into(), Direction::Received).collect::>()); - - let value_refs = sent_value_ids - .chain(recv_value_ids) - .map(|ids| { - let inner_refs = ids - .iter() - .map(|id| { - verify_thread - .get_value(id.as_str()) - .expect("Byte should be in VM memory") - }) - .collect::>(); - - verify_thread - .array_from_values(inner_refs.as_slice()) - .expect("Byte should be in VM Memory") - }) - .collect::>(); - - let values = proving_info - .sent_ids - .iter_ranges() - .chain(proving_info.recv_ids.iter_ranges()) - .map(|range| { - Value::Array(cleartext.drain(..range.len()).map(|b| (b).into()).collect()) - }) - .collect::>(); - - // Check that purported values are correct - verify_thread.verify(&value_refs, &values).await?; - - #[cfg(feature = "tracing")] - info!("Successfully verified purported cleartext"); - - // Create redacted transcripts - let mut transcripts = proving_info - .sent_ids - .iter_ranges() - .chain(proving_info.recv_ids.iter_ranges()) - .map(|range| { - TranscriptSlice::new( - range.clone(), - proving_info.cleartext.drain(..range.len()).collect(), - ) - }) - .collect::>(); - - let recv_transcripts = - transcripts.split_off(proving_info.sent_ids.iter_ranges().count()); - let (sent_redacted, recv_redacted) = ( - RedactedTranscript::new(self.state.sent_len, transcripts), - RedactedTranscript::new(self.state.recv_len, recv_transcripts), - ); - - #[cfg(feature = "tracing")] - info!("Successfully created redacted transcripts"); - - Ok::<_, VerifierError>((sent_redacted, recv_redacted)) - }; - - futures::select! { - res = verify_fut.fuse() => res, - _ = &mut self.state.mux_fut => Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - } + self.state + .mux_fut + .poll_with(async { + // Receive the proving info from the prover + let mut proving_info: ProvingInfo = self.state.io.expect_next().await?; + let mut cleartext = proving_info.cleartext.clone(); + + info!("Received proving info from prover"); + + // Check ranges + if proving_info.sent_ids.max().unwrap_or_default() > self.state.sent_len + || proving_info.recv_ids.max().unwrap_or_default() > self.state.recv_len + { + return Err(VerifierError::InvalidRange); + } + + // Now verify the transcript parts which the prover wants to reveal + let sent_value_ids = proving_info + .sent_ids + .iter_ranges() + .map(|r| get_value_ids(&r.into(), Direction::Sent).collect::>()); + let recv_value_ids = proving_info.recv_ids.iter_ranges().map(|r| { + get_value_ids(&r.into(), Direction::Received).collect::>() + }); + + let value_refs = sent_value_ids + .chain(recv_value_ids) + .map(|ids| { + let inner_refs = ids + .iter() + .map(|id| { + self.state + .vm + .get_value(id.as_str()) + .expect("Byte should be in VM memory") + }) + .collect::>(); + + self.state + .vm + .array_from_values(inner_refs.as_slice()) + .expect("Byte should be in VM Memory") + }) + .collect::>(); + + let values = proving_info + .sent_ids + .iter_ranges() + .chain(proving_info.recv_ids.iter_ranges()) + .map(|range| { + Value::Array(cleartext.drain(..range.len()).map(|b| (b).into()).collect()) + }) + .collect::>(); + + // Check that purported values are correct + self.state.vm.verify(&value_refs, &values).await?; + + info!("Successfully verified purported cleartext"); + + // Create redacted transcripts + let mut transcripts = proving_info + .sent_ids + .iter_ranges() + .chain(proving_info.recv_ids.iter_ranges()) + .map(|range| { + TranscriptSlice::new( + range.clone(), + proving_info.cleartext.drain(..range.len()).collect(), + ) + }) + .collect::>(); + + let recv_transcripts = + transcripts.split_off(proving_info.sent_ids.iter_ranges().count()); + let (sent_redacted, recv_redacted) = ( + RedactedTranscript::new(self.state.sent_len, transcripts), + RedactedTranscript::new(self.state.recv_len, recv_transcripts), + ); + + info!("Successfully created redacted transcripts"); + + Ok::<_, VerifierError>((sent_redacted, recv_redacted)) + }) + .await } /// Verifies the TLS session. pub async fn finalize(self) -> Result { let VerifyState { - mut mux_ctrl, + mut io, + mux_ctrl, mut mux_fut, mut vm, - ot_send, - ot_recv, - ot_fut, - mut gf2, + mut ot_send, + mut ctx, start_time, server_ephemeral_key, handshake_commitment, .. } = self.state; - let finalize_fut = async { - let mut channel = mux_ctrl.get_channel("finalize").await?; - - // Finalize all MPC - let (mut ot_sender_actor, _, _) = futures::try_join!( - ot_fut, - ot_send.shutdown().map_err(VerifierError::from), - ot_recv.shutdown().map_err(VerifierError::from) - )?; - - ot_sender_actor.reveal().await?; - - vm.finalize() - .await - .map_err(|e| VerifierError::MpcError(Box::new(e)))?; + let session_info = mux_fut + .poll_with(async { + // Finalize all MPC + ot_send.reveal(&mut ctx).await?; - gf2.verify() - .await - .map_err(|e| VerifierError::MpcError(Box::new(e)))?; + vm.finalize() + .await + .map_err(|e| VerifierError::MpcError(Box::new(e)))?; - let session_info = expect_msg_or_err!(channel, TlsnMessage::SessionInfo)?; + let session_info: SessionInfo = io.expect_next().await?; - #[cfg(feature = "tracing")] - info!("Finalized all MPC"); + info!("Finalized all MPC"); - Ok::<_, VerifierError>(session_info) - }; - - let session_info = futures::select! { - res = finalize_fut.fuse() => res?, - _ = &mut mux_fut => Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?, - }; + Ok::<_, VerifierError>(session_info) + }) + .await?; let handshake_summary = HandshakeSummary::new(start_time, server_ephemeral_key, handshake_commitment); @@ -185,12 +147,12 @@ impl Verifier { // Verify the TLS session session_info.verify(&handshake_summary, self.config.cert_verifier())?; - #[cfg(feature = "tracing")] info!("Successfully verified session"); - let mut mux_ctrl = mux_ctrl.into_inner(); - - futures::try_join!(mux_ctrl.close().map_err(VerifierError::from), mux_fut)?; + if !mux_fut.is_complete() { + mux_ctrl.mux().close(); + mux_fut.await?; + } Ok(session_info) }