diff --git a/src/aioquic/quic/connection.py b/src/aioquic/quic/connection.py index b99c4ab85..4fa965388 100644 --- a/src/aioquic/quic/connection.py +++ b/src/aioquic/quic/connection.py @@ -5,6 +5,7 @@ from dataclasses import dataclass from enum import Enum from functools import partial +from itertools import count from typing import Any, Deque, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple from .. import tls @@ -416,6 +417,13 @@ def __init__( 0x31: (self._handle_datagram_frame, EPOCHS("01")), } + if self._is_client: + self._bidi_stream_id = count(0, 4) + self._uni_stream_id = count(2, 4) + else: + self._bidi_stream_id = count(1, 4) + self._uni_stream_id = count(3, 4) + @property def configuration(self) -> QuicConfiguration: return self._configuration @@ -623,10 +631,7 @@ def get_next_available_stream_id(self, is_unidirectional=False) -> int: """ Return the stream ID for the next stream created by this endpoint. """ - stream_id = (int(is_unidirectional) << 1) | int(not self._is_client) - while stream_id in self._streams or stream_id in self._streams_finished: - stream_id += 4 - return stream_id + return next(self._uni_stream_id if is_unidirectional else self._bidi_stream_id) def get_timer(self) -> Optional[float]: """