diff --git a/dune-project b/dune-project index 6af5f58..5b18faf 100644 --- a/dune-project +++ b/dune-project @@ -21,7 +21,7 @@ (name pgx) (synopsis "Pure-OCaml PostgreSQL client library") (description - "PGX is a pure-OCaml PostgreSQL client library, supporting Async, LWT, or synchronous operations.") + "PGX is a pure-OCaml PostgreSQL client library, supporting Async, LWT, or synchronous operations.") (depends (alcotest (and @@ -52,9 +52,9 @@ (package (name pgx_unix) (synopsis - "PGX using the standard library's Unix module for IO (synchronous)") + "PGX using the standard library's Unix module for IO (synchronous)") (description - "PGX using the standard library's Unix module for IO (synchronous)") + "PGX using the standard library's Unix module for IO (synchronous)") (depends (alcotest (and @@ -82,10 +82,12 @@ (>= "v0.13.0")) (async_unix (>= "v0.13.0")) + async_ssl (base64 (and :with-test (>= 3.0.0))) + conduit-async (ocaml (>= 4.08)) (pgx diff --git a/pgx/src/io_intf.ml b/pgx/src/io_intf.ml index 867ac38..00babdd 100644 --- a/pgx/src/io_intf.ml +++ b/pgx/src/io_intf.ml @@ -14,6 +14,18 @@ module type S = sig | Inet of string * int val open_connection : sockaddr -> (in_channel * out_channel) t + + type ssl_config + + val upgrade_ssl + : [ `Not_supported + | `Supported of + ?ssl_config:ssl_config + -> in_channel + -> out_channel + -> (in_channel * out_channel) t + ] + val output_char : out_channel -> char -> unit t val output_binary_int : out_channel -> int -> unit t val output_string : out_channel -> string -> unit t diff --git a/pgx/src/pgx.ml b/pgx/src/pgx.ml index 9c00400..4ae77e5 100644 --- a/pgx/src/pgx.ml +++ b/pgx/src/pgx.ml @@ -288,6 +288,7 @@ module Message_out = struct | Describe_portal of portal (* DP *) | Startup_message of startup | Simple_query of query + | SSLRequest [@@deriving sexp] let add_byte buf i = @@ -381,6 +382,10 @@ module Message_out = struct add_byte msg 0; None, Buffer.contents msg | Simple_query q -> Some 'Q', str q + | SSLRequest -> + let msg = Buffer.create 8 in + add_int32 msg 80877103l; + None, Buffer.contents msg ;; end @@ -526,7 +531,59 @@ module Make (Thread : Io) = struct (*----- Connection. -----*) + let attempt_tls_upgrade ?(ssl = `Auto) ({ ichan; chan; _ } as conn) = + (* To initiate an SSL-encrypted connection, the frontend initially sends an SSLRequest message rather than a + StartupMessage. The server then responds with a single byte containing S or N, indicating that it is willing + or unwilling to perform SSL, respectively. The frontend might close the connection at this point if it is + dissatisfied with the response. To continue after S, perform an SSL startup handshake (not described here, + part of the SSL specification) with the server. If this is successful, continue with sending the usual + StartupMessage. In this case the StartupMessage and all subsequent data will be SSL-encrypted. To continue + after N, send the usual StartupMessage and proceed without encryption. + See https://www.postgresql.org/docs/9.3/protocol-flow.html#AEN100021 *) + match ssl with + | `No -> return conn + | (`Auto | `Always _) as ssl -> + (match Io.upgrade_ssl with + | `Not_supported -> + (match ssl with + | `Always _ -> + failwith + "TLS support is not compiled into this Pgx library but ~ssl was set to \ + `Always" + | _ -> ()); + debug + "TLS-support is not compiled into this Pgx library, not attempting to upgrade" + >>| fun () -> conn + | `Supported upgrade_ssl -> + debug "Request SSL upgrade from server" + >>= fun () -> + let msg = Message_out.SSLRequest in + send_message conn msg + >>= fun () -> + flush chan + >>= fun () -> + input_char ichan + >>= (function + | 'S' -> + debug "Server supports TLS, attempting to upgrade" + >>= fun () -> + let ssl_config = + match ssl with + | `Auto -> None + | `Always ssl_config -> Some ssl_config + in + upgrade_ssl ?ssl_config ichan chan + >>= fun (ichan, chan) -> return { conn with ichan; chan } + | 'N' -> debug "Server does not support TLS, not upgrading" >>| fun () -> conn + | c -> + fail_msg + "Got unexpected response '%c' from server after SSLRequest message. Response \ + should always be 'S' or 'N'." + c)) + ;; + let connect + ?ssl ?host ?port ?user @@ -600,6 +657,8 @@ module Make (Thread : Io) = struct ; prepared_num = Int64.of_int 0 } in + attempt_tls_upgrade ?ssl conn + >>= fun conn -> (* Send the StartUpMessage. NB. At present we do not support SSL. *) let msg = Message_out.Startup_message { Message_out.user; database } in (* Loop around here until the database gives a ReadyForQuery message. *) @@ -665,6 +724,7 @@ module Make (Thread : Io) = struct ;; let with_conn + ?ssl ?host ?port ?user @@ -676,6 +736,7 @@ module Make (Thread : Io) = struct f = connect + ?ssl ?host ?port ?user diff --git a/pgx/src/pgx.mli b/pgx/src/pgx.mli index 526f400..cbf5ec6 100644 --- a/pgx/src/pgx.mli +++ b/pgx/src/pgx.mli @@ -49,4 +49,5 @@ module Value = Pgx_value module type S = Pgx_intf.S -module Make (Thread : Io) : S with type 'a Io.t = 'a Thread.t +module Make (Thread : Io) : + S with type 'a Io.t = 'a Thread.t and type Io.ssl_config = Thread.ssl_config diff --git a/pgx/src/pgx_intf.ml b/pgx/src/pgx_intf.ml index f6167ad..6bd16e4 100644 --- a/pgx/src/pgx_intf.ml +++ b/pgx/src/pgx_intf.ml @@ -5,6 +5,7 @@ module type S = sig module Io : sig type 'a t + type ssl_config val return : 'a -> 'a t val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t @@ -22,7 +23,8 @@ module type S = sig possible denial of service. You may want to set this to a smaller size to avoid this happening. *) val connect - : ?host:string + : ?ssl:[ `Auto | `No | `Always of Io.ssl_config ] + -> ?host:string -> ?port:int -> ?user:string -> ?password:string @@ -42,7 +44,8 @@ module type S = sig [close]. This is the preferred way to use this library since it cleans up after itself. *) val with_conn - : ?host:string + : ?ssl:[ `Auto | `No | `Always of Io.ssl_config ] + -> ?host:string -> ?port:int -> ?user:string -> ?password:string diff --git a/pgx_async.opam b/pgx_async.opam index e9dbcfe..b6a7eff 100644 --- a/pgx_async.opam +++ b/pgx_async.opam @@ -13,7 +13,9 @@ depends: [ "alcotest-async" {with-test & >= "1.0.0"} "async_kernel" {>= "v0.13.0"} "async_unix" {>= "v0.13.0"} + "async_ssl" "base64" {with-test & >= "3.0.0"} + "conduit-async" "ocaml" {>= "4.08"} "pgx" {= version} "pgx_value_core" {= version} diff --git a/pgx_async/src/dune b/pgx_async/src/dune index 5e63847..e361930 100644 --- a/pgx_async/src/dune +++ b/pgx_async/src/dune @@ -11,6 +11,6 @@ let () = Jbuild_plugin.V1.send @@ {| (library (public_name pgx_async) (wrapped false) - (libraries async_kernel async_unix pgx_value_core) + (libraries async_kernel async_unix conduit-async pgx_value_core) |} ^ preprocess ^ {|) |} diff --git a/pgx_async/src/pgx_async.ml b/pgx_async/src/pgx_async.ml index 7aec8ae..0e3ac26 100644 --- a/pgx_async/src/pgx_async.ml +++ b/pgx_async/src/pgx_async.ml @@ -73,19 +73,25 @@ module Thread = struct let close_in = Reader.close let open_connection sockaddr = - let get_reader_writer socket = - let fd = Socket.fd socket in - Reader.create fd, Writer.create fd - in match sockaddr with - | Unix path -> - let unix_sockaddr = Tcp.Where_to_connect.of_unix_address (`Unix path) in - Tcp.connect_sock unix_sockaddr >>| get_reader_writer + | Unix path -> Conduit_async.connect (`Unix_domain_socket path) | Inet (host, port) -> - let inet_sockaddr = - Tcp.Where_to_connect.of_host_and_port (Host_and_port.create ~host ~port) - in - Tcp.connect_sock inet_sockaddr >>| get_reader_writer + Uri.make ~host ~port () + |> Conduit_async.V3.resolve_uri + >>= Conduit_async.V3.connect + >>| fun (_socket, in_channel, out_channel) -> in_channel, out_channel + ;; + + type ssl_config = Conduit_async.Ssl.config + + let upgrade_ssl = + try + let default_config = Conduit_async.V1.Conduit_async_ssl.Ssl_config.configure () in + `Supported + (fun ?(ssl_config = default_config) in_channel out_channel -> + Conduit_async.V1.Conduit_async_ssl.ssl_connect ssl_config in_channel out_channel) + with + | _ -> `Not_supported ;; (* The unix getlogin syscall can fail *) @@ -130,6 +136,7 @@ let check_pgdatabase = ;; let connect + ?ssl ?host ?port ?user @@ -146,6 +153,7 @@ let connect | None -> Lazy_deferred.force_exn default_unix_domain_socket_dir) >>= fun unix_domain_socket_dir -> connect + ?ssl ?host ?port ?user @@ -158,6 +166,7 @@ let connect ;; let with_conn + ?ssl ?host ?port ?user @@ -169,6 +178,7 @@ let with_conn f = connect + ?ssl ?host ?port ?user diff --git a/pgx_async/src/pgx_async.mli b/pgx_async/src/pgx_async.mli index 8adc38f..f816033 100644 --- a/pgx_async/src/pgx_async.mli +++ b/pgx_async/src/pgx_async.mli @@ -1,23 +1,14 @@ (** Async based Postgres client based on Pgx. *) open Async_kernel -include Pgx.S with type 'a Io.t = 'a Deferred.t +include + Pgx.S + with type 'a Io.t = 'a Deferred.t + and type Io.ssl_config = Conduit_async.Ssl.config (* for testing purposes *) module Thread : Pgx.Io with type 'a t = 'a Deferred.t -val with_conn - : ?host:string - -> ?port:int - -> ?user:string - -> ?password:string - -> ?database:string - -> ?unix_domain_socket_dir:string - -> ?verbose:int - -> ?max_message_length:int - -> (t -> 'a Deferred.t) - -> 'a Deferred.t - (** Like [execute] but returns a pipe so you can operate on the results before they have all returned. Note that [execute_iter] and [execute_fold] can perform significantly better because they don't have as much overhead. *) diff --git a/pgx_lwt/src/pgx_lwt.ml b/pgx_lwt/src/pgx_lwt.ml index 70f827d..c70ac8a 100644 --- a/pgx_lwt/src/pgx_lwt.ml +++ b/pgx_lwt/src/pgx_lwt.ml @@ -47,6 +47,8 @@ module Thread = struct let close_in = Io.close_in let open_connection = Io.open_connection + type ssl_config + let upgrade_ssl = `Not_supported let getlogin = Io.getlogin let debug s = Logs_lwt.debug (fun m -> m "%s" s) let protect f ~finally = Lwt.finalize f finally diff --git a/pgx_lwt_unix.opam b/pgx_lwt_unix.opam index 9a8e7e0..3163469 100644 --- a/pgx_lwt_unix.opam +++ b/pgx_lwt_unix.opam @@ -12,7 +12,6 @@ depends: [ "dune" {>= "1.11"} "alcotest-lwt" {with-test & >= "1.0.0"} "base64" {with-test & >= "3.0.0"} - "lwt" "ocaml" {>= "4.08"} "pgx" {= version} "pgx_lwt" {= version} diff --git a/pgx_unix/src/pgx_unix.ml b/pgx_unix/src/pgx_unix.ml index d0b94b6..965f7f6 100644 --- a/pgx_unix/src/pgx_unix.ml +++ b/pgx_unix/src/pgx_unix.ml @@ -55,6 +55,9 @@ module Simple_thread = struct Unix.open_connection std_socket ;; + type ssl_config + + let upgrade_ssl = `Not_supported let output_char = output_char let output_binary_int = output_binary_int let output_string = output_string