diff --git a/internal/quic/conn_send.go b/internal/quic/conn_send.go index a8d930898..c2d8d146b 100644 --- a/internal/quic/conn_send.go +++ b/internal/quic/conn_send.go @@ -60,7 +60,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { pad := false var sentInitial *sentPacket if c.keysInitial.canWrite() { - pnumMaxAcked := c.acks[initialSpace].largestSeen() + pnumMaxAcked := c.loss.spaces[initialSpace].maxAcked pnum := c.loss.nextNumber(initialSpace) p := longPacket{ ptype: packetTypeInitial, @@ -93,7 +93,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // Handshake packet. if c.keysHandshake.canWrite() { - pnumMaxAcked := c.acks[handshakeSpace].largestSeen() + pnumMaxAcked := c.loss.spaces[handshakeSpace].maxAcked pnum := c.loss.nextNumber(handshakeSpace) p := longPacket{ ptype: packetTypeHandshake, @@ -124,7 +124,7 @@ func (c *Conn) maybeSend(now time.Time) (next time.Time) { // 1-RTT packet. if c.keysAppData.canWrite() { - pnumMaxAcked := c.acks[appDataSpace].largestSeen() + pnumMaxAcked := c.loss.spaces[appDataSpace].maxAcked pnum := c.loss.nextNumber(appDataSpace) c.w.start1RTTPacket(pnum, pnumMaxAcked, dstConnID) c.appendFrames(now, appDataSpace, pnum, limit) diff --git a/internal/quic/conn_send_test.go b/internal/quic/conn_send_test.go index 822783c41..2205ff2f7 100644 --- a/internal/quic/conn_send_test.go +++ b/internal/quic/conn_send_test.go @@ -38,3 +38,46 @@ func TestAckElicitingAck(t *testing.T) { } t.Errorf("after sending %v PINGs, got no ack-eliciting response", count) } + +func TestSendPacketNumberSize(t *testing.T) { + tc := newTestConn(t, clientSide, permissiveTransportParameters) + tc.handshake() + + recvPing := func() *testPacket { + t.Helper() + tc.conn.ping(appDataSpace) + p := tc.readPacket() + if p == nil { + t.Fatalf("want packet containing PING, got none") + } + return p + } + + // Desynchronize the packet numbers the conn is sending and the ones it is receiving, + // by having the conn send a number of unacked packets. + for i := 0; i < 16; i++ { + recvPing() + } + + // Establish the maximum packet number the conn has received an ACK for. + maxAcked := recvPing().num + tc.writeAckForAll() + + // Make the conn send a sequence of packets. + // Check that the packet number is encoded with two bytes once the difference between the + // current packet and the max acked one is sufficiently large. + for want := maxAcked + 1; want < maxAcked+0x100; want++ { + p := recvPing() + if p.num != want { + t.Fatalf("received packet number %v, want %v", p.num, want) + } + gotPnumLen := int(p.header&0x03) + 1 + wantPnumLen := 1 + if p.num-maxAcked >= 0x80 { + wantPnumLen = 2 + } + if gotPnumLen != wantPnumLen { + t.Fatalf("packet number 0x%x encoded with %v bytes, want %v (max acked = %v)", p.num, gotPnumLen, wantPnumLen, maxAcked) + } + } +} diff --git a/internal/quic/conn_test.go b/internal/quic/conn_test.go index 058aa7edc..abf7eede7 100644 --- a/internal/quic/conn_test.go +++ b/internal/quic/conn_test.go @@ -82,6 +82,7 @@ func (d testDatagram) String() string { type testPacket struct { ptype packetType + header byte version uint32 num packetNumber keyPhaseBit bool @@ -599,12 +600,18 @@ func (tc *testConn) readFrame() (debugFrame, packetType) { func (tc *testConn) wantDatagram(expectation string, want *testDatagram) { tc.t.Helper() got := tc.readDatagram() - if !reflect.DeepEqual(got, want) { + if !datagramEqual(got, want) { tc.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want) } } func datagramEqual(a, b *testDatagram) bool { + if a == nil && b == nil { + return true + } + if a == nil || b == nil { + return false + } if a.paddedSize != b.paddedSize || a.addr != b.addr || len(a.packets) != len(b.packets) { @@ -622,7 +629,7 @@ func datagramEqual(a, b *testDatagram) bool { func (tc *testConn) wantPacket(expectation string, want *testPacket) { tc.t.Helper() got := tc.readPacket() - if !reflect.DeepEqual(got, want) { + if !packetEqual(got, want) { tc.t.Fatalf("%v:\ngot packet: %v\nwant packet: %v", expectation, got, want) } } @@ -630,8 +637,10 @@ func (tc *testConn) wantPacket(expectation string, want *testPacket) { func packetEqual(a, b *testPacket) bool { ac := *a ac.frames = nil + ac.header = 0 bc := *b bc.frames = nil + bc.header = 0 if !reflect.DeepEqual(ac, bc) { return false } @@ -839,6 +848,7 @@ func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) } d.packets = append(d.packets, &testPacket{ ptype: p.ptype, + header: buf[0], version: p.version, num: p.num, dstConnID: p.dstConnID, @@ -880,6 +890,7 @@ func parseTestDatagram(t *testing.T, te *testEndpoint, tc *testConn, buf []byte) } d.packets = append(d.packets, &testPacket{ ptype: packetType1RTT, + header: hdr[0], num: pnum, dstConnID: hdr[1:][:len(tc.peerConnID)], keyPhaseBit: hdr[0]&keyPhaseBit != 0, diff --git a/internal/quic/endpoint_test.go b/internal/quic/endpoint_test.go index 2a6daa076..452d26052 100644 --- a/internal/quic/endpoint_test.go +++ b/internal/quic/endpoint_test.go @@ -13,7 +13,6 @@ import ( "io" "net" "net/netip" - "reflect" "testing" "time" ) @@ -242,7 +241,7 @@ func (te *testEndpoint) readDatagram() *testDatagram { func (te *testEndpoint) wantDatagram(expectation string, want *testDatagram) { te.t.Helper() got := te.readDatagram() - if !reflect.DeepEqual(got, want) { + if !datagramEqual(got, want) { te.t.Fatalf("%v:\ngot datagram: %v\nwant datagram: %v", expectation, got, want) } } diff --git a/internal/quic/tls_test.go b/internal/quic/tls_test.go index 14f74a00a..9c1dd364e 100644 --- a/internal/quic/tls_test.go +++ b/internal/quic/tls_test.go @@ -10,7 +10,6 @@ import ( "crypto/tls" "crypto/x509" "errors" - "reflect" "testing" "time" ) @@ -56,7 +55,7 @@ func (tc *testConn) handshake() { fillCryptoFrames(want, tc.cryptoDataOut) i++ } - if !reflect.DeepEqual(got, want) { + if !datagramEqual(got, want) { t.Fatalf("dgram %v:\ngot %v\n\nwant %v", i, got, want) } if i >= len(dgrams) {