Skip to content

Commit

Permalink
Eio.Net.connect: optionally bind before connect
Browse files Browse the repository at this point in the history
  • Loading branch information
art-w committed Mar 20, 2024
1 parent 14ae3cf commit 594bc12
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 21 deletions.
2 changes: 1 addition & 1 deletion lib_eio/mock/net.ml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ module Impl = struct
Switch.on_release sw (fun () -> Eio.Resource.close socket);
socket

let connect t ~sw addr =
let connect t ~sw ?bind:_ addr =
traceln "%s: connect to %a" t.label Eio.Net.Sockaddr.pp addr;
let socket = Handler.run t.on_connect in
Switch.on_release sw (fun () -> Eio.Flow.close socket);
Expand Down
6 changes: 3 additions & 3 deletions lib_eio/net.ml
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ module Pi = struct
type tag

val listen : t -> reuse_addr:bool -> reuse_port:bool -> backlog:int -> sw:Switch.t -> Sockaddr.stream -> tag listening_socket_ty r
val connect : t -> sw:Switch.t -> Sockaddr.stream -> tag stream_socket_ty r
val connect : t -> sw:Switch.t -> ?bind:Sockaddr.stream -> Sockaddr.stream -> tag stream_socket_ty r
val datagram_socket :
t
-> reuse_addr:bool
Expand Down Expand Up @@ -295,10 +295,10 @@ let listen (type tag) ?(reuse_addr=false) ?(reuse_port=false) ~backlog ~sw (t:[>
let module X = (val (Resource.get ops Pi.Network)) in
X.listen t ~reuse_addr ~reuse_port ~backlog ~sw

let connect (type tag) ~sw (t:[> tag ty] r) addr =
let connect (type tag) ~sw ?bind (t:[> tag ty] r) addr =
let (Resource.T (t, ops)) = t in
let module X = (val (Resource.get ops Pi.Network)) in
try X.connect t ~sw addr
try X.connect t ~sw ?bind addr
with Exn.Io _ as ex ->
let bt = Printexc.get_raw_backtrace () in
Exn.reraise_with_context ex bt "connecting to %a" Sockaddr.pp addr
Expand Down
8 changes: 5 additions & 3 deletions lib_eio/net.mli
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,12 @@ type 'a t = 'a r

(** {2 Out-bound Connections} *)

val connect : sw:Switch.t -> [> 'tag ty] t -> Sockaddr.stream -> 'tag stream_socket_ty r
val connect : sw:Switch.t -> ?bind:Sockaddr.stream -> [> 'tag ty] t -> Sockaddr.stream -> 'tag stream_socket_ty r
(** [connect ~sw t addr] is a new socket connected to remote address [addr].
The new socket will be closed when [sw] finishes, unless closed manually first. *)
The new socket will be closed when [sw] finishes, unless closed manually first.
@param bind Set the outbound client address. *)

val with_tcp_connect :
?timeout:Time.Timeout.t ->
Expand Down Expand Up @@ -346,7 +348,7 @@ module Pi : sig
t -> reuse_addr:bool -> reuse_port:bool -> backlog:int -> sw:Switch.t ->
Sockaddr.stream -> tag listening_socket_ty r

val connect : t -> sw:Switch.t -> Sockaddr.stream -> tag stream_socket_ty r
val connect : t -> sw:Switch.t -> ?bind:Sockaddr.stream -> Sockaddr.stream -> tag stream_socket_ty r

val datagram_socket :
t
Expand Down
8 changes: 5 additions & 3 deletions lib_eio_linux/eio_linux.ml
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,12 @@ let socket_domain_of = function
~v4:(fun _ -> Unix.PF_INET)
~v6:(fun _ -> Unix.PF_INET6)

let connect ~sw connect_addr =
let connect ~sw ?bind connect_addr =
let addr = Eio_unix.Net.sockaddr_to_unix connect_addr in
let sock_unix = Unix.socket ~cloexec:true (socket_domain_of connect_addr) Unix.SOCK_STREAM 0 in
let sock = Fd.of_unix ~sw ~seekable:false ~close_unix:true sock_unix in
Low_level.connect sock addr;
let bind = Option.map Eio_unix.Net.sockaddr_to_unix bind in
Low_level.connect sock ?bind addr;
(flow sock :> _ Eio_unix.Net.stream_socket)

module Impl = struct
Expand Down Expand Up @@ -296,7 +297,8 @@ module Impl = struct
Unix.listen sock_unix backlog;
(listening_socket sock :> _ Eio.Net.listening_socket_ty r)

let connect () ~sw addr = (connect ~sw addr :> [`Generic | `Unix] Eio.Net.stream_socket_ty r)
let connect () ~sw ?bind addr =
(connect ~sw ?bind addr :> [`Generic | `Unix] Eio.Net.stream_socket_ty r)

let datagram_socket () ~reuse_addr ~reuse_port ~sw saddr =
if reuse_addr then (
Expand Down
9 changes: 8 additions & 1 deletion lib_eio_linux/low_level.ml
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,15 @@ let splice src ~dst ~len =
else if res = 0 then raise End_of_file
else raise @@ Err.wrap (Uring.error_of_errno res) "splice" ""

let connect fd addr =
let try_bind fd = function
| None -> ()
| Some bind_addr ->
try Unix.bind fd bind_addr
with Unix.Unix_error (code, name, arg) -> raise @@ Err.wrap_fs code name arg

let connect fd ?bind addr =
Fd.use_exn "connect" fd @@ fun fd ->
try_bind fd bind;
let res = Sched.enter "connect" (enqueue_connect fd addr) in
if res < 0 then (
let ex =
Expand Down
2 changes: 1 addition & 1 deletion lib_eio_linux/low_level.mli
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ val splice : fd -> dst:fd -> len:int -> int
@raise End_of_file [src] is at the end of the file.
@raise Unix.Unix_error(EINVAL, "splice", _) if splice is not supported for these FDs. *)

val connect : fd -> Unix.sockaddr -> unit
val connect : fd -> ?bind:Unix.sockaddr -> Unix.sockaddr -> unit
(** [connect fd addr] attempts to connect socket [fd] to [addr]. *)

val await_readable : fd -> unit
Expand Down
11 changes: 7 additions & 4 deletions lib_eio_posix/low_level.ml
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,18 @@ let socket ~sw socket_domain socket_type protocol =
Unix.set_nonblock sock_unix;
Fd.of_unix ~sw ~blocking:false ~close_unix:true sock_unix

let connect fd addr =
let connect fd ?bind addr =
try
Fd.use_exn "connect" fd (fun fd -> Unix.connect fd addr)
Fd.use_exn "connect" fd @@ fun fd ->
Option.iter (Unix.bind fd) bind;
Unix.connect fd addr
with
| Unix.Unix_error ((EINTR | EAGAIN | EWOULDBLOCK | EINPROGRESS), _, _) ->
await_writable "connect" fd;
match Fd.use_exn "connect" fd Unix.getsockopt_error with
(match Fd.use_exn "connect" fd Unix.getsockopt_error with
| None -> ()
| Some code -> raise (Err.wrap code "connect-in-progress" "")
| Some code -> raise (Err.wrap code "connect-in-progress" ""))
| Unix.Unix_error (code, name, arg) -> raise (Err.wrap code name arg)

let accept ~sw sock =
Fd.use_exn "accept" sock @@ fun sock ->
Expand Down
2 changes: 1 addition & 1 deletion lib_eio_posix/low_level.mli
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ val read : fd -> bytes -> int -> int -> int
val write : fd -> bytes -> int -> int -> int

val socket : sw:Switch.t -> Unix.socket_domain -> Unix.socket_type -> int -> fd
val connect : fd -> Unix.sockaddr -> unit
val connect : fd -> ?bind:Unix.sockaddr -> Unix.sockaddr -> unit
val accept : sw:Switch.t -> fd -> fd * Unix.sockaddr

val shutdown : fd -> Unix.shutdown_command -> unit
Expand Down
9 changes: 5 additions & 4 deletions lib_eio_posix/net.ml
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ let listen ~reuse_addr ~reuse_port ~backlog ~sw (listen_addr : Eio.Net.Sockaddr.
);
(listening_socket ~hook sock :> _ Eio.Net.listening_socket_ty r)

let connect ~sw connect_addr =
let connect ~sw ?bind connect_addr =
let socket_type, addr =
match connect_addr with
| `Unix path -> Unix.SOCK_STREAM, Unix.ADDR_UNIX path
Expand All @@ -147,8 +147,9 @@ let connect ~sw connect_addr =
Unix.SOCK_STREAM, Unix.ADDR_INET (host, port)
in
let sock = Low_level.socket ~sw (socket_domain_of connect_addr) socket_type 0 in
let bind = Option.map Eio_unix.Net.sockaddr_to_unix bind in
try
Low_level.connect sock addr;
Low_level.connect sock ?bind addr;
(Flow.of_fd sock :> _ Eio_unix.Net.stream_socket)
with Unix.Unix_error (code, name, arg) -> raise (Err.wrap code name arg)

Expand All @@ -174,8 +175,8 @@ module Impl = struct

let listen () = listen

let connect () ~sw addr =
let socket = connect ~sw addr in
let connect () ~sw ?bind addr =
let socket = connect ~sw ?bind addr in
(socket :> [`Generic | `Unix] Eio.Net.stream_socket_ty r)

let datagram_socket () ~reuse_addr ~reuse_port ~sw saddr =
Expand Down

0 comments on commit 594bc12

Please sign in to comment.