Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

std.net: use send/recv for streams, support vectorized network io on windows #19751

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lib/std/c.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3525,7 +3525,7 @@ pub const itimerspec = switch (native_os) {
};
pub const msghdr = switch (native_os) {
.linux => linux.msghdr,
.openbsd, .emscripten, .dragonfly, .freebsd, .netbsd, .haiku, .solaris, .illumos => extern struct {
.openbsd, .emscripten, .dragonfly, .freebsd, .netbsd, .haiku, .solaris, .illumos, .macos => extern struct {
/// optional address
name: ?*sockaddr,
/// size of address
Expand All @@ -3545,7 +3545,7 @@ pub const msghdr = switch (native_os) {
};
pub const msghdr_const = switch (native_os) {
.linux => linux.msghdr_const,
.openbsd, .emscripten, .dragonfly, .freebsd, .netbsd, .haiku, .solaris, .illumos => extern struct {
.openbsd, .emscripten, .dragonfly, .freebsd, .netbsd, .haiku, .solaris, .illumos, .macos => extern struct {
/// optional address
name: ?*const sockaddr,
/// size of address
Expand Down
91 changes: 35 additions & 56 deletions lib/std/crypto/tls/Client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub const StreamInterface = struct {
/// The `iovecs` parameter is mutable because so that function may to
/// mutate the fields in order to handle partial reads from the underlying
/// stream layer.
pub fn readv(this: @This(), iovecs: []std.posix.iovec) ReadError!usize {
pub fn readv(this: @This(), iovecs: []std.net.IoSlice) ReadError!usize {
_ = .{ this, iovecs };
@panic("unimplemented");
}
Expand All @@ -72,7 +72,7 @@ pub const StreamInterface = struct {

/// Returns the number of bytes read, which may be less than the buffer
/// space provided. A short read does not indicate end-of-stream.
pub fn writev(this: @This(), iovecs: []const std.posix.iovec_const) WriteError!usize {
pub fn writev(this: @This(), iovecs: []const std.net.IoSliceConst) WriteError!usize {
_ = .{ this, iovecs };
@panic("unimplemented");
}
Expand All @@ -81,7 +81,7 @@ pub const StreamInterface = struct {
/// space provided, indicating end-of-stream.
/// The `iovecs` parameter is mutable in case this function needs to mutate
/// the fields in order to handle partial writes from the underlying layer.
pub fn writevAll(this: @This(), iovecs: []std.posix.iovec_const) WriteError!usize {
pub fn writevAll(this: @This(), iovecs: []std.net.IoSliceConst) WriteError!usize {
// This can be implemented in terms of writev, or specialized if desired.
_ = .{ this, iovecs };
@panic("unimplemented");
Expand Down Expand Up @@ -215,16 +215,9 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
} ++ int2(@intCast(out_handshake.len + host_len)) ++ out_handshake;

{
var iovecs = [_]std.posix.iovec_const{
.{
.base = &plaintext_header,
.len = plaintext_header.len,
},
.{
.base = host.ptr,
.len = host.len,
},
};
var iovecs: [2]std.net.IoSliceConst = undefined;
iovecs[0].set(&plaintext_header);
iovecs[1].set(host);
try stream.writevAll(&iovecs);
}

Expand Down Expand Up @@ -677,10 +670,8 @@ pub fn init(stream: anytype, ca_bundle: Certificate.Bundle, host: []const u8) In
P.AEAD.encrypt(ciphertext, auth_tag, &out_cleartext, ad, nonce, p.client_handshake_key);

const both_msgs = client_change_cipher_spec_msg ++ finished_msg;
var both_msgs_vec = [_]std.posix.iovec_const{.{
.base = &both_msgs,
.len = both_msgs.len,
}};
var both_msgs_vec: [1]std.net.IoSliceConst = undefined;
both_msgs_vec[0].set(&both_msgs);
try stream.writevAll(&both_msgs_vec);

const client_secret = hkdfExpandLabel(P.Hkdf, p.master_secret, "c ap traffic", &handshake_hash, P.Hash.digest_length);
Expand Down Expand Up @@ -755,7 +746,7 @@ pub fn writeAllEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !v
/// TLS session, or a truncation attack.
pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usize {
var ciphertext_buf: [tls.max_ciphertext_record_len * 4]u8 = undefined;
var iovecs_buf: [6]std.posix.iovec_const = undefined;
var iovecs_buf: [6]std.net.IoSliceConst = undefined;
var prepared = prepareCiphertextRecord(c, &iovecs_buf, &ciphertext_buf, bytes, .application_data);
if (end) {
prepared.iovec_end += prepareCiphertextRecord(
Expand All @@ -776,8 +767,8 @@ pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usiz
var total_amt: usize = 0;
while (true) {
var amt = try stream.writev(iovecs_buf[i..iovec_end]);
while (amt >= iovecs_buf[i].len) {
const encrypted_amt = iovecs_buf[i].len;
while (amt >= iovecs_buf[i].len()) {
const encrypted_amt = iovecs_buf[i].len();
total_amt += encrypted_amt - overhead_len;
amt -= encrypted_amt;
i += 1;
Expand All @@ -789,14 +780,13 @@ pub fn writeEnd(c: *Client, stream: anytype, bytes: []const u8, end: bool) !usiz
// not sent; otherwise the caller would not know to retry the call.
if (amt == 0 and (!end or i < iovec_end - 1)) return total_amt;
}
iovecs_buf[i].base += amt;
iovecs_buf[i].len -= amt;
iovecs_buf[i].discard(amt);
}
}

fn prepareCiphertextRecord(
c: *Client,
iovecs: []std.posix.iovec_const,
iovecs: []std.net.IoSliceConst,
ciphertext_buf: []u8,
bytes: []const u8,
inner_content_type: tls.ContentType,
Expand Down Expand Up @@ -863,10 +853,7 @@ fn prepareCiphertextRecord(
P.AEAD.encrypt(ciphertext, auth_tag, cleartext, ad, nonce, p.client_key);

const record = ciphertext_buf[record_start..ciphertext_end];
iovecs[iovec_end] = .{
.base = record.ptr,
.len = record.len,
};
iovecs[iovec_end].set(record);
iovec_end += 1;
}
},
Expand Down Expand Up @@ -908,7 +895,7 @@ pub fn readAll(c: *Client, stream: anytype, buffer: []u8) !usize {
/// stream is not an error condition.
/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
/// order to handle partial reads from the underlying stream layer.
pub fn readv(c: *Client, stream: anytype, iovecs: []std.posix.iovec) !usize {
pub fn readv(c: *Client, stream: anytype, iovecs: []std.net.IoSlice) !usize {
return readvAtLeast(c, stream, iovecs, 1);
}

Expand All @@ -919,7 +906,7 @@ pub fn readv(c: *Client, stream: anytype, iovecs: []std.posix.iovec) !usize {
/// Reaching the end of the stream is not an error condition.
/// The `iovecs` parameter is mutable because this function needs to mutate the fields in
/// order to handle partial reads from the underlying stream layer.
pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.posix.iovec, len: usize) !usize {
pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.net.IoSlice, len: usize) !usize {
if (c.eof()) return 0;

var off_i: usize = 0;
Expand All @@ -928,12 +915,11 @@ pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.posix.iovec, len:
var amt = try c.readvAdvanced(stream, iovecs[vec_i..]);
off_i += amt;
if (c.eof() or off_i >= len) return off_i;
while (amt >= iovecs[vec_i].len) {
amt -= iovecs[vec_i].len;
while (amt >= iovecs[vec_i].len()) {
amt -= iovecs[vec_i].len();
vec_i += 1;
}
iovecs[vec_i].base += amt;
iovecs[vec_i].len -= amt;
iovecs[vec_i].discard(amt);
}
}

Expand All @@ -945,7 +931,7 @@ pub fn readvAtLeast(c: *Client, stream: anytype, iovecs: []std.posix.iovec, len:
/// function asserts that `eof()` is `false`.
/// See `readv` for a higher level function that has the same, familiar API as
/// other read functions, such as `std.fs.File.read`.
pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iovec) !usize {
pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.net.IoSlice) !usize {
var vp: VecPut = .{ .iovecs = iovecs };

// Give away the buffered cleartext we have, if any.
Expand Down Expand Up @@ -998,16 +984,9 @@ pub fn readvAdvanced(c: *Client, stream: anytype, iovecs: []const std.posix.iove
c.partial_cleartext_idx = 0;
const first_iov = c.partially_read_buffer[c.partial_ciphertext_end..];

var ask_iovecs_buf: [2]std.posix.iovec = .{
.{
.base = first_iov.ptr,
.len = first_iov.len,
},
.{
.base = &in_stack_buffer,
.len = in_stack_buffer.len,
},
};
var ask_iovecs_buf: [2]std.net.IoSlice = undefined;
ask_iovecs_buf[0].set(first_iov);
ask_iovecs_buf[1].set(&in_stack_buffer);

// Cleartext capacity of output buffer, in records. Minimum one full record.
const buf_cap = @max(cleartext_buf_len / max_ciphertext_len, 1);
Expand Down Expand Up @@ -1352,7 +1331,7 @@ fn SchemeEddsa(comptime scheme: tls.SignatureScheme) type {

/// Abstraction for sending multiple byte buffers to a slice of iovecs.
const VecPut = struct {
iovecs: []const std.posix.iovec,
iovecs: []const std.net.IoSlice,
idx: usize = 0,
off: usize = 0,
total: usize = 0,
Expand All @@ -1364,12 +1343,12 @@ const VecPut = struct {
var bytes_i: usize = 0;
while (true) {
const v = vp.iovecs[vp.idx];
const dest = v.base[vp.off..v.len];
const dest = v.base()[vp.off..v.len()];
const src = bytes[bytes_i..][0..@min(dest.len, bytes.len - bytes_i)];
@memcpy(dest[0..src.len], src);
bytes_i += src.len;
vp.off += src.len;
if (vp.off >= v.len) {
if (vp.off >= v.len()) {
vp.off = 0;
vp.idx += 1;
if (vp.idx >= vp.iovecs.len) {
Expand All @@ -1388,15 +1367,15 @@ const VecPut = struct {
fn peek(vp: VecPut) []u8 {
if (vp.idx >= vp.iovecs.len) return &.{};
const v = vp.iovecs[vp.idx];
return v.base[vp.off..v.len];
return v.base()[vp.off..v.len()];
}

// After writing to the result of peek(), one can call next() to
// advance the cursor.
fn next(vp: *VecPut, len: usize) void {
vp.total += len;
vp.off += len;
if (vp.off >= vp.iovecs[vp.idx].len) {
if (vp.off >= vp.iovecs[vp.idx].len()) {
vp.off = 0;
vp.idx += 1;
}
Expand All @@ -1405,22 +1384,22 @@ const VecPut = struct {
fn freeSize(vp: VecPut) usize {
if (vp.idx >= vp.iovecs.len) return 0;
var total: usize = 0;
total += vp.iovecs[vp.idx].len - vp.off;
total += vp.iovecs[vp.idx].len() - vp.off;
if (vp.idx + 1 >= vp.iovecs.len) return total;
for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.len;
for (vp.iovecs[vp.idx + 1 ..]) |v| total += v.len();
return total;
}
};

/// Limit iovecs to a specific byte size.
fn limitVecs(iovecs: []std.posix.iovec, len: usize) []std.posix.iovec {
var bytes_left: usize = len;
fn limitVecs(iovecs: []std.net.IoSlice, len: usize) []std.net.IoSlice {
var bytes_left: u32 = @intCast(len);
for (iovecs, 0..) |*iovec, vec_i| {
if (bytes_left <= iovec.len) {
iovec.len = bytes_left;
if (bytes_left <= iovec.len()) {
iovec.setLen(bytes_left);
return iovecs[0 .. vec_i + 1];
}
bytes_left -= iovec.len;
bytes_left -= @intCast(iovec.len());
}
return iovecs;
}
Expand Down
2 changes: 1 addition & 1 deletion lib/std/fs/File.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1482,7 +1482,7 @@ pub fn writeFileAll(self: File, in_file: File, args: WriteFileOptions) WriteFile
error.Unseekable,
error.FastOpenAlreadyInProgress,
error.MessageTooBig,
error.FileDescriptorNotASocket,
error.SocketNotBound,
error.NetworkUnreachable,
error.NetworkSubsystemFailed,
=> return self.writeFileAllUnseekable(in_file, args),
Expand Down
20 changes: 9 additions & 11 deletions lib/std/http/Client.zig
Original file line number Diff line number Diff line change
Expand Up @@ -220,21 +220,21 @@ pub const Connection = struct {

pub const Protocol = enum { plain, tls };

pub fn readvDirectTls(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
pub fn readvDirectTls(conn: *Connection, buffers: []std.net.IoSlice) ReadError!usize {
return conn.tls_client.readv(conn.stream, buffers) catch |err| {
// https://github.com/ziglang/zig/issues/2473
if (mem.startsWith(u8, @errorName(err), "TlsAlert")) return error.TlsAlert;

switch (err) {
error.TlsConnectionTruncated, error.TlsRecordOverflow, error.TlsDecodeError, error.TlsBadRecordMac, error.TlsBadLength, error.TlsIllegalParameter, error.TlsUnexpectedMessage => return error.TlsFailure,
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
else => return error.UnexpectedReadFailure,
}
};
}

pub fn readvDirect(conn: *Connection, buffers: []std.posix.iovec) ReadError!usize {
pub fn readvDirect(conn: *Connection, buffers: []std.net.IoSlice) ReadError!usize {
if (conn.protocol == .tls) {
if (disable_tls) unreachable;

Expand All @@ -243,7 +243,7 @@ pub const Connection = struct {

return conn.stream.readv(buffers) catch |err| switch (err) {
error.ConnectionTimedOut => return error.ConnectionTimedOut,
error.ConnectionResetByPeer, error.BrokenPipe => return error.ConnectionResetByPeer,
error.ConnectionResetByPeer => return error.ConnectionResetByPeer,
else => return error.UnexpectedReadFailure,
};
}
Expand All @@ -252,9 +252,8 @@ pub const Connection = struct {
pub fn fill(conn: *Connection) ReadError!void {
if (conn.read_end != conn.read_start) return;

var iovecs = [1]std.posix.iovec{
.{ .base = &conn.read_buf, .len = conn.read_buf.len },
};
var iovecs: [1]std.net.IoSlice = undefined;
iovecs[0].set(&conn.read_buf);
const nread = try conn.readvDirect(&iovecs);
if (nread == 0) return error.EndOfStream;
conn.read_start = 0;
Expand Down Expand Up @@ -288,10 +287,9 @@ pub const Connection = struct {
return available_read;
}

var iovecs = [2]std.posix.iovec{
.{ .base = buffer.ptr, .len = buffer.len },
.{ .base = &conn.read_buf, .len = conn.read_buf.len },
};
var iovecs: [2]std.net.IoSlice = undefined;
iovecs[0].set(buffer);
iovecs[1].set(&conn.read_buf);
const nread = try conn.readvDirect(&iovecs);

if (nread > buffer.len) {
Expand Down
Loading
Loading