diff --git a/proxygen/lib/http/session/test/MockQuicSocketDriver.h b/proxygen/lib/http/session/test/MockQuicSocketDriver.h index 71542dc589..bf9752e8b0 100644 --- a/proxygen/lib/http/session/test/MockQuicSocketDriver.h +++ b/proxygen/lib/http/session/test/MockQuicSocketDriver.h @@ -75,7 +75,11 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback { uint64_t readOffset{0}; bool readEOF{false}; bool writeEOF{false}; - QuicSocket::WriteCallback* pendingWriteCb{nullptr}; + union WriteCallback { + quic::ConnectionWriteCallback* conn{nullptr}; + quic::StreamWriteCallback* stream; + }; + WriteCallback pendingWriteCb; bool isPendingWriteCbStreamNotif{false}; // data written by application folly::IOBufQueue unsentBuf{folly::IOBufQueue::cacheChainLength()}; @@ -380,17 +384,22 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback { EXPECT_CALL(*sock_, notifyPendingWriteOnStream(testing::_, testing::_)) .WillRepeatedly(testing::Invoke( - [this](StreamId id, QuicSocket::WriteCallback* wcb) + [this](StreamId id, quic::StreamWriteCallback* wcb) -> folly::Expected { checkNotReadOnlyStream(id); - return notifyPendingWriteImpl(id, wcb, /* isStreamNotif */ true); + return notifyPendingWriteImpl( + id, + StreamState::WriteCallback({.stream = wcb}), + /* isStreamNotif */ true); })); EXPECT_CALL(*sock_, notifyPendingWriteOnConnection(testing::_)) .WillRepeatedly(testing::Invoke( - [this](QuicSocket::WriteCallback* wcb) + [this](quic::ConnectionWriteCallback* wcb) -> folly::Expected { - return notifyPendingWriteImpl(quic::kConnectionStreamId, wcb); + return notifyPendingWriteImpl( + quic::kConnectionStreamId, + StreamState::WriteCallback({.conn = wcb})); })); EXPECT_CALL(*sock_, getDatagramSizeLimit()) @@ -684,7 +693,7 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback { stream.writeState = ERROR; stream.unsentBuf.move(); stream.pendingWriteBuf.move(); - stream.pendingWriteCb = nullptr; + stream.pendingWriteCb.stream = nullptr; stream.pendingBufMetaLength = 0; stream.unsentBufMeta.length = 0; cancelDeliveryCallbacks(id, stream); @@ -1030,13 +1039,13 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback { void deliverWriteError(quic::StreamId id, StreamState& stream, QuicErrorCode errorCode) { - if (stream.pendingWriteCb) { + if (stream.pendingWriteCb.stream || stream.pendingWriteCb.conn) { auto cb = stream.pendingWriteCb; - stream.pendingWriteCb = nullptr; + stream.pendingWriteCb.stream = nullptr; if (stream.isPendingWriteCbStreamNotif) { - cb->onStreamWriteError(id, QuicError(errorCode)); + cb.stream->onStreamWriteError(id, QuicError(errorCode)); } else { - cb->onConnectionWriteError(QuicError(errorCode)); + cb.conn->onConnectionWriteError(QuicError(errorCode)); } } stream.writeState = ERROR; @@ -1068,14 +1077,14 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback { } folly::Expected notifyPendingWriteImpl( - StreamId id, QuicSocket::WriteCallback* wcb, bool streamNotif = false) { + StreamId id, StreamState::WriteCallback wcb, bool streamNotif = false) { auto& stream = streams_[id]; if (stream.writeState == PAUSED) { stream.pendingWriteCb = wcb; stream.isPendingWriteCbStreamNotif = streamNotif; return folly::unit; } else if (stream.writeState == OPEN) { - if (wcb == nullptr) { + if (wcb.stream == nullptr) { return folly::makeUnexpected(LocalErrorCode::INVALID_WRITE_CALLBACK); } @@ -1092,7 +1101,7 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback { if (stream.writeState != OPEN) { return; } - ERROR_IF(!stream.pendingWriteCb, + ERROR_IF(!stream.pendingWriteCb.stream, fmt::format("write callback not set when calling " "onConnectionWriteReady for streamId={}", id), @@ -1104,11 +1113,11 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback { } auto writeCb = stream.pendingWriteCb; - stream.pendingWriteCb = nullptr; + stream.pendingWriteCb.stream = nullptr; if (stream.isPendingWriteCbStreamNotif) { - writeCb->onStreamWriteReady(id, maxStreamToWrite); + writeCb.stream->onStreamWriteReady(id, maxStreamToWrite); } else { - writeCb->onConnectionWriteReady(maxConnToWrite); + writeCb.conn->onConnectionWriteReady(maxConnToWrite); } }, true); @@ -1368,7 +1377,7 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback { auto& stream = it.second; stream.readCB = nullptr; stream.peekCB = nullptr; - stream.pendingWriteCb = nullptr; + stream.pendingWriteCb.stream = nullptr; } if (cb) { cb->onConnectionEnd(); @@ -1654,7 +1663,8 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback { // now check onConnectionWriteReady/onStreamWriteReady call is warranted. uint64_t maxWritableOnStream = maxStreamWritable(streamId); uint64_t maxWritableOnConn = maxConnWritable(); - bool shouldResume = stream.writeState == OPEN && stream.pendingWriteCb && + bool shouldResume = stream.writeState == OPEN && + stream.pendingWriteCb.stream && (maxWritableOnConn > 0 || maxWritableOnStream > 0); if (shouldResume) { @@ -1669,9 +1679,9 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback { streamNotif = stream.isPendingWriteCbStreamNotif] { if (!*deleted) { if (streamNotif) { - wcb->onStreamWriteReady(streamId, window); + wcb.stream->onStreamWriteReady(streamId, window); } else { - wcb->onConnectionWriteReady(window); + wcb.conn->onConnectionWriteReady(window); } } }, @@ -1680,7 +1690,7 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback { if (streamId != quic::kConnectionStreamId) { sock_->connCb_->onFlowControlUpdate(streamId); } - stream.pendingWriteCb = nullptr; + stream.pendingWriteCb.stream = nullptr; if (!stream.unsentBuf.empty()) { // re-invoke write chain with the pending data sock_->writeChain(