Skip to content

Commit

Permalink
Move QuicSocket callbacks outside of QuicSocket
Browse files Browse the repository at this point in the history
Summary:
I have need of these in the proxygen WebTransport implementation, and don't want the full dependency on QuicSocket.

Also refactored WriteCallback into StreamWriteCallback and ConnWriteCallback, leaving the original WriteCallback as both for now.

Reviewed By: hanidamlaj

Differential Revision: D63486821

fbshipit-source-id: 4b16ad871c4deac4e262c12835ad5c457e9240da
  • Loading branch information
afrind authored and facebook-github-bot committed Oct 11, 2024
1 parent 5720c23 commit 48b2c9a
Showing 1 changed file with 31 additions and 21 deletions.
52 changes: 31 additions & 21 deletions proxygen/lib/http/session/test/MockQuicSocketDriver.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback {
uint64_t readOffset{0};
bool readEOF{false};
bool writeEOF{false};
QuicSocket::WriteCallback* pendingWriteCb{nullptr};
union WriteCallback {
quic::ConnectionWriteCallback* conn{nullptr};
quic::StreamWriteCallback* stream;
};
WriteCallback pendingWriteCb;
bool isPendingWriteCbStreamNotif{false};
// data written by application
folly::IOBufQueue unsentBuf{folly::IOBufQueue::cacheChainLength()};
Expand Down Expand Up @@ -380,17 +384,22 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback {

EXPECT_CALL(*sock_, notifyPendingWriteOnStream(testing::_, testing::_))
.WillRepeatedly(testing::Invoke(
[this](StreamId id, QuicSocket::WriteCallback* wcb)
[this](StreamId id, quic::StreamWriteCallback* wcb)
-> folly::Expected<folly::Unit, quic::LocalErrorCode> {
checkNotReadOnlyStream(id);
return notifyPendingWriteImpl(id, wcb, /* isStreamNotif */ true);
return notifyPendingWriteImpl(
id,
StreamState::WriteCallback({.stream = wcb}),
/* isStreamNotif */ true);
}));

EXPECT_CALL(*sock_, notifyPendingWriteOnConnection(testing::_))
.WillRepeatedly(testing::Invoke(
[this](QuicSocket::WriteCallback* wcb)
[this](quic::ConnectionWriteCallback* wcb)
-> folly::Expected<folly::Unit, quic::LocalErrorCode> {
return notifyPendingWriteImpl(quic::kConnectionStreamId, wcb);
return notifyPendingWriteImpl(
quic::kConnectionStreamId,
StreamState::WriteCallback({.conn = wcb}));
}));

EXPECT_CALL(*sock_, getDatagramSizeLimit())
Expand Down Expand Up @@ -684,7 +693,7 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback {
stream.writeState = ERROR;
stream.unsentBuf.move();
stream.pendingWriteBuf.move();
stream.pendingWriteCb = nullptr;
stream.pendingWriteCb.stream = nullptr;
stream.pendingBufMetaLength = 0;
stream.unsentBufMeta.length = 0;
cancelDeliveryCallbacks(id, stream);
Expand Down Expand Up @@ -1030,13 +1039,13 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback {
void deliverWriteError(quic::StreamId id,
StreamState& stream,
QuicErrorCode errorCode) {
if (stream.pendingWriteCb) {
if (stream.pendingWriteCb.stream || stream.pendingWriteCb.conn) {
auto cb = stream.pendingWriteCb;
stream.pendingWriteCb = nullptr;
stream.pendingWriteCb.stream = nullptr;
if (stream.isPendingWriteCbStreamNotif) {
cb->onStreamWriteError(id, QuicError(errorCode));
cb.stream->onStreamWriteError(id, QuicError(errorCode));
} else {
cb->onConnectionWriteError(QuicError(errorCode));
cb.conn->onConnectionWriteError(QuicError(errorCode));
}
}
stream.writeState = ERROR;
Expand Down Expand Up @@ -1068,14 +1077,14 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback {
}

folly::Expected<folly::Unit, quic::LocalErrorCode> notifyPendingWriteImpl(
StreamId id, QuicSocket::WriteCallback* wcb, bool streamNotif = false) {
StreamId id, StreamState::WriteCallback wcb, bool streamNotif = false) {
auto& stream = streams_[id];
if (stream.writeState == PAUSED) {
stream.pendingWriteCb = wcb;
stream.isPendingWriteCbStreamNotif = streamNotif;
return folly::unit;
} else if (stream.writeState == OPEN) {
if (wcb == nullptr) {
if (wcb.stream == nullptr) {
return folly::makeUnexpected(LocalErrorCode::INVALID_WRITE_CALLBACK);
}

Expand All @@ -1092,7 +1101,7 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback {
if (stream.writeState != OPEN) {
return;
}
ERROR_IF(!stream.pendingWriteCb,
ERROR_IF(!stream.pendingWriteCb.stream,
fmt::format("write callback not set when calling "
"onConnectionWriteReady for streamId={}",
id),
Expand All @@ -1104,11 +1113,11 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback {
}

auto writeCb = stream.pendingWriteCb;
stream.pendingWriteCb = nullptr;
stream.pendingWriteCb.stream = nullptr;
if (stream.isPendingWriteCbStreamNotif) {
writeCb->onStreamWriteReady(id, maxStreamToWrite);
writeCb.stream->onStreamWriteReady(id, maxStreamToWrite);
} else {
writeCb->onConnectionWriteReady(maxConnToWrite);
writeCb.conn->onConnectionWriteReady(maxConnToWrite);
}
},
true);
Expand Down Expand Up @@ -1368,7 +1377,7 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback {
auto& stream = it.second;
stream.readCB = nullptr;
stream.peekCB = nullptr;
stream.pendingWriteCb = nullptr;
stream.pendingWriteCb.stream = nullptr;
}
if (cb) {
cb->onConnectionEnd();
Expand Down Expand Up @@ -1654,7 +1663,8 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback {
// now check onConnectionWriteReady/onStreamWriteReady call is warranted.
uint64_t maxWritableOnStream = maxStreamWritable(streamId);
uint64_t maxWritableOnConn = maxConnWritable();
bool shouldResume = stream.writeState == OPEN && stream.pendingWriteCb &&
bool shouldResume = stream.writeState == OPEN &&
stream.pendingWriteCb.stream &&
(maxWritableOnConn > 0 || maxWritableOnStream > 0);

if (shouldResume) {
Expand All @@ -1669,9 +1679,9 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback {
streamNotif = stream.isPendingWriteCbStreamNotif] {
if (!*deleted) {
if (streamNotif) {
wcb->onStreamWriteReady(streamId, window);
wcb.stream->onStreamWriteReady(streamId, window);
} else {
wcb->onConnectionWriteReady(window);
wcb.conn->onConnectionWriteReady(window);
}
}
},
Expand All @@ -1680,7 +1690,7 @@ class MockQuicSocketDriver : public folly::EventBase::LoopCallback {
if (streamId != quic::kConnectionStreamId) {
sock_->connCb_->onFlowControlUpdate(streamId);
}
stream.pendingWriteCb = nullptr;
stream.pendingWriteCb.stream = nullptr;
if (!stream.unsentBuf.empty()) {
// re-invoke write chain with the pending data
sock_->writeChain(
Expand Down

0 comments on commit 48b2c9a

Please sign in to comment.