Skip to content

Commit

Permalink
feat: Better logging in prover task
Browse files Browse the repository at this point in the history
# Conflicts:
#	wasm/prover/src/lib.rs
  • Loading branch information
heeckhau committed Feb 22, 2024
1 parent 1415792 commit 9eeb191
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 34 deletions.
3 changes: 3 additions & 0 deletions wasm/prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ ws_stream_wasm = {version = "0.7.4", git = "https://github.com/tlsnotary/ws_stre
# code size when deploying.
console_error_panic_hook = {version = "0.1.7"}

strum = {version = "0.26.1"}
strum_macros = "0.26.1"

[dev-dependencies]
wasm-bindgen-test = "0.3.34"

Expand Down
105 changes: 71 additions & 34 deletions wasm/prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ use elliptic_curve::pkcs8::DecodePublicKey;
use std::time::Duration;
use tlsn_core::proof::{SessionProof, TlsProof};

use strum::EnumMessage;
use strum_macros;

// A macro to provide `println!(..)`-style syntax for `console.log` logging.
macro_rules! log {
( $( $t:tt )* ) => {
Expand Down Expand Up @@ -59,6 +62,53 @@ async fn fetch_as_json_string(url: &str, opts: &RequestInit) -> Result<String, J
.ok_or_else(|| JsValue::from_str("Could not stringify JSON"))
}

#[derive(strum_macros::EnumMessage, Debug, Clone, Copy)]
#[allow(dead_code)]
enum ProverPhases {
#[strum(message = "Connect application server with websocket proxy")]
ConnectWsProxy,
#[strum(message = "Build prover config")]
BuildProverConfig,
#[strum(message = "Set up prover")]
SetUpProver,
#[strum(message = "Bind the prover to the server connection")]
BindProverToConnection,
#[strum(message = "Spawn the prover thread")]
SpawnProverThread,
#[strum(message = "Attach the hyper HTTP client to the TLS connection")]
AttachHttpClient,
#[strum(message = "Spawn the HTTP task to be run concurrently")]
SpawnHttpTask,
#[strum(message = "Build request")]
BuildRequest,
#[strum(message = "Start MPC-TLS connection with the server")]
StartMpcConnection,
#[strum(message = "Received response from the server")]
ReceivedResponse,
#[strum(message = "Parsing response from the server")]
ParseResponse,
#[strum(message = "Close the connection to the server")]
CloseConnection,
#[strum(message = "Start notarization")]
StartNotarization,
#[strum(message = "Commit to data")]
Commit,
#[strum(message = "Finalize")]
Finalize,
#[strum(message = "Notarization complete")]
NotarizationComplete,
#[strum(message = "Create Proof")]
CreateProof,
}

fn log_phase(phase: ProverPhases) {
log!(
"!@# tlsn-js {}: {}",
phase as u8,
phase.get_message().unwrap()
);
}

#[wasm_bindgen]
pub async fn prover(
target_url_str: &str,
Expand All @@ -78,8 +128,8 @@ pub async fn prover(
);
let options: RequestOptions = serde_wasm_bindgen::from_value(val)
.map_err(|e| JsValue::from_str(&format!("Could not deserialize options: {:?}", e)))?;
log!("done!");
log!("options.notary_url: {}", options.notary_url.as_str());

// let fmt_layer = tracing_subscriber::fmt::layer()
// .with_ansi(false) // Only partially supported across browsers
// .with_timer(UtcTime::rfc_3339()) // std::time is not available in browsers
Expand Down Expand Up @@ -162,7 +212,7 @@ pub async fn prover(
.expect_throw("assume the notary ws connection succeeds");
let notary_ws_stream_into = notary_ws_stream.into_io();

log!("!@# 0");
log_phase(ProverPhases::BuildProverConfig);

let target_host = target_url
.host_str()
Expand All @@ -175,68 +225,59 @@ pub async fn prover(
.build()
.map_err(|e| JsValue::from_str(&format!("Could not build prover config: {:?}", e)))?;

log!("!@# 1");

// Create a Prover and set it up with the Notary
// This will set up the MPC backend prior to connecting to the server.
log_phase(ProverPhases::SetUpProver);
let prover = Prover::new(config)
.setup(notary_ws_stream_into)
.await
.map_err(|e| JsValue::from_str(&format!("Could not set up prover: {:?}", e)))?;

log!("!@# 2");
/*
Connect Application Server with websocket proxy
*/
log_phase(ProverPhases::ConnectWsProxy);

let (_, client_ws_stream) = WsMeta::connect(options.websocket_proxy_url, None)
.await
.expect_throw("assume the client ws connection succeeds");
let client_ws_stream_into = client_ws_stream.into_io();

log!("!@# 3");

// Bind the Prover to the server connection.
// The returned `mpc_tls_connection` is an MPC TLS connection to the Server: all data written
// to/read from it will be encrypted/decrypted using MPC with the Notary.
log_phase(ProverPhases::BindProverToConnection);
let (mpc_tls_connection, prover_fut) = prover
.connect(client_ws_stream_into)
.await
.map_err(|e| JsValue::from_str(&format!("Could not connect prover: {:?}", e)))?;

log!("!@# 4");

// let prover_task = tokio::spawn(prover_fut);
log_phase(ProverPhases::SpawnProverThread);
let (prover_sender, prover_receiver) = oneshot::channel();
let handled_prover_fut = async {
let result = prover_fut.await;
let _ = prover_sender.send(result);
};
spawn_local(handled_prover_fut);
log!("!@# 7");

// Attach the hyper HTTP client to the TLS connection
log_phase(ProverPhases::AttachHttpClient);
let (mut request_sender, connection) =
hyper::client::conn::handshake(mpc_tls_connection.compat())
.await
.map_err(|e| JsValue::from_str(&format!("Could not handshake: {:?}", e)))?;
log!("!@# 8");

// Spawn the HTTP task to be run concurrently
// let connection_task = tokio::spawn(connection.without_shutdown());
log_phase(ProverPhases::SpawnHttpTask);
let (connection_sender, connection_receiver) = oneshot::channel();
let connection_fut = connection.without_shutdown();
let handled_connection_fut = async {
let result = connection_fut.await;
let _ = connection_sender.send(result);
};
spawn_local(handled_connection_fut);
log!(
"!@# 9 - {} request to {}",
options.method.as_str(),
target_url_str
);

log_phase(ProverPhases::BuildRequest);
let mut req_with_header = Request::builder()
.uri(target_url_str)
.method(options.method.as_str());
Expand All @@ -259,40 +300,36 @@ pub async fn prover(
let unwrapped_request = req_with_body
.map_err(|e| JsValue::from_str(&format!("Could not build request: {:?}", e)))?;

log!("Starting an MPC TLS connection with the server");
log_phase(ProverPhases::StartMpcConnection);

// Send the request to the Server and get a response via the MPC TLS connection
let response = request_sender
.send_request(unwrapped_request)
.await
.map_err(|e| JsValue::from_str(&format!("Could not send request: {:?}", e)))?;

log!("Got a response from the server");

log_phase(ProverPhases::ReceivedResponse);
if response.status() != StatusCode::OK {
return Err(JsValue::from_str(&format!(
"Response status is not OK: {:?}",
response.status()
)));
}

log!("Request OK");

log_phase(ProverPhases::ParseResponse);
// Pretty printing :)
let payload = to_bytes(response.into_body())
.await
.map_err(|e| JsValue::from_str(&format!("Could not get response body: {:?}", e)))?
.to_vec();
let parsed = serde_json::from_str::<serde_json::Value>(&String::from_utf8_lossy(&payload))
.map_err(|e| JsValue::from_str(&format!("Could not parse response: {:?}", e)))?;
log!("!@# 10");
let response_pretty = serde_json::to_string_pretty(&parsed)
.map_err(|e| JsValue::from_str(&format!("Could not serialize response: {:?}", e)))?;
log!("{}", response_pretty);
log!("!@# 11");
log!("Response: {}", response_pretty);

// Close the connection to the server
// let mut client_socket = connection_task.await.unwrap().unwrap().io.into_inner();
log_phase(ProverPhases::CloseConnection);
let mut client_socket = connection_receiver
.await
.map_err(|e| {
Expand All @@ -304,23 +341,20 @@ pub async fn prover(
.map_err(|e| JsValue::from_str(&format!("Could not get TlsConnection: {:?}", e)))?
.io
.into_inner();
log!("!@# 12");
client_socket
.close()
.await
.map_err(|e| JsValue::from_str(&format!("Could not close socket: {:?}", e)))?;
log!("!@# 13");

// The Prover task should be done now, so we can grab it.
// let mut prover = prover_task.await.unwrap().unwrap();
log_phase(ProverPhases::StartNotarization);
let prover = prover_receiver
.await
.map_err(|e| {
JsValue::from_str(&format!("Could not receive from prover_receiver: {:?}", e))
})?
.map_err(|e| JsValue::from_str(&format!("Could not get Prover: {:?}", e)))?;
let mut prover = prover.start_notarize();
log!("!@# 14");

let secret_headers_vecs = string_list_to_bytes_vec(&secret_headers)?;
let secret_headers_slices: Vec<&[u8]> = secret_headers_vecs
Expand All @@ -343,7 +377,8 @@ pub async fn prover(
prover.recv_transcript().data(),
secret_body_slices.as_slice(),
);
log!("!@# 15");

log_phase(ProverPhases::Commit);

let _recv_len = prover.recv_transcript().data().len();

Expand Down Expand Up @@ -387,14 +422,16 @@ pub async fn prover(
})?;

// Finalize, returning the notarized session
log_phase(ProverPhases::Finalize);
let notarized_session = prover
.finalize()
.await
.map_err(|e| JsValue::from_str(&format!("Error finalizing prover: {:?}", e)))?;

log!("Notarization complete!");
log_phase(ProverPhases::NotarizationComplete);

// Create a proof for all committed data in this session
log_phase(ProverPhases::CreateProof);
let session_proof = notarized_session.session_proof();

let mut proof_builder = notarized_session.data().build_substrings_proof();
Expand Down Expand Up @@ -423,7 +460,7 @@ pub async fn prover(
.map_err(|e| JsValue::from_str(&format!("Could not serialize proof: {:?}", e)))?;

let duration = start_time.elapsed();
log!("!@# request takes: {} seconds", duration.as_secs());
log!("!@# request took {} seconds", duration.as_secs());

Ok(res)
}
Expand Down

0 comments on commit 9eeb191

Please sign in to comment.