Skip to content

Commit

Permalink
add more comments and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed May 7, 2023
1 parent 38cff0f commit 241fd6a
Show file tree
Hide file tree
Showing 2 changed files with 254 additions and 136 deletions.
142 changes: 84 additions & 58 deletions p2p/net/swarm/dial_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,20 @@ type pendRequest struct {
addrs map[ma.Multiaddr]struct{} // pending addr dials
}

// addrDial tracks dials to a particular multiaddress.
type addrDial struct {
addr ma.Multiaddr
ctx context.Context
conn *Conn
err error
// addr is the address dialed
addr ma.Multiaddr
// ctx is the context used for dialing the address
ctx context.Context
// conn is the established connection on success
conn *Conn
// err is the err on dialing the address
err error
// requests is the list of dialRequests interested in this dial
requests []int
dialed bool
// dialed indicates whether we have triggered the dial to the address
dialed bool
}

type dialWorker struct {
Expand Down Expand Up @@ -78,20 +85,22 @@ func (w *dialWorker) loop() {
defer w.wg.Done()
defer w.s.limiter.clearAllPeerDials(w.peer)

// used to signal readiness to dial and completion of the dial
ready := make(chan struct{})
close(ready)
// dq is used to pace dials to different addresses of the peer
dq := newDialQueue()
// currDials is the number of dials in flight
currDials := 0
st := w.cl.Now()
// timer is the timer used to trigger dials
timer := w.cl.InstantTimer(st.Add(math.MaxInt64))
timerRunning := true
scheduleNext := func() {
// scheduleNextDial updates timer for triggering the next dial
scheduleNextDial := func() {
if timerRunning && !timer.Stop() {
<-timer.Ch()
}
timerRunning = false
if dq.len() > 0 {
// if there are no dials in flight, trigger the next dials immediately
if currDials == 0 {
timer.Reset(st)
} else {
Expand Down Expand Up @@ -196,9 +205,11 @@ loop:
dq.add(network.AddrDelay{Addr: a, Delay: addrDelay[a]})
}
}
scheduleNext()
scheduleNextDial()

case <-timer.Ch():
// we dont check the delay here because an early trigger means all in flight
// dials have completed
for _, adelay := range dq.nextBatch() {
// spawn the dial
ad := w.pending[adelay.Addr]
Expand All @@ -211,7 +222,7 @@ loop:
}
}
timerRunning = false
scheduleNext()
scheduleNextDial()

case res := <-w.resch:
if res.Conn != nil {
Expand Down Expand Up @@ -255,7 +266,8 @@ loop:
w.s.backf.AddBackoff(w.peer, res.Addr)
}
w.dispatchError(ad, res.Err)
scheduleNext()
// only schedule next dial on error
scheduleNextDial()
}
}
}
Expand Down Expand Up @@ -294,56 +306,67 @@ func (w *dialWorker) dispatchError(ad *addrDial, err error) {
// this is necessary to support active listen scenarios, where a new dial comes in while
// another dial is in progress, and needs to do a direct connection without inhibitions from
// dial backoff.
// it is also necessary to preserve consisent behaviour with the old dialer -- TestDialBackoff
// regresses without this.
if err == ErrDialBackoff {
delete(w.pending, ad.addr)
}
}

// rankAddrs ranks addresses for dialing. if it's a simConnect request we
// dial all addresses immediately without any delay
func (w *dialWorker) rankAddrs(addrs []ma.Multiaddr, isSimConnect bool) []network.AddrDelay {
if isSimConnect {
return noDelayRanker(addrs)
}
return w.s.dialRanker(addrs)
}

// dialQueue is a priority queue used to schedule dials
type dialQueue struct {
q []network.AddrDelay
// q is the queue maintained as a heap
q []network.AddrDelay
// pos is the reverse map from address to its position in q
// the reverse map is required to provide efficient updates
pos map[ma.Multiaddr]int
}

func newDialQueue() *dialQueue {
return &dialQueue{pos: make(map[ma.Multiaddr]int)}
}

// add adds adelay to the queue. if another elements exists in the queue with
// the same address, it replaces that element.
func (dq *dialQueue) add(adelay network.AddrDelay) {
dq.remove(adelay.Addr)
dq.q = append(dq.q, adelay)
dq.pos[adelay.Addr] = len(dq.q) - 1
dq.heapify(len(dq.q) - 1)
}

// swap swaps elements at i and j maintaining the reverse map pos.
func (dq *dialQueue) swap(i, j int) {
dq.pos[dq.q[i].Addr] = j
dq.pos[dq.q[j].Addr] = i
dq.q[i], dq.q[j] = dq.q[j], dq.q[i]
}

// len is the length of the queue. Calling top on an empty queue panics.
func (dq *dialQueue) len() int {
return len(dq.q)
}

// top returns the top element of the queue
func (dq *dialQueue) top() network.AddrDelay {
return dq.q[0]
}

// pop removes the top element from the queue and returns it
func (dq *dialQueue) pop() network.AddrDelay {
v := dq.q[0]
dq.remove(v.Addr)
return v
}

// remove removes the element in the queue with address a
func (dq *dialQueue) remove(a ma.Multiaddr) {
pos, ok := dq.pos[a]
if !ok {
Expand All @@ -352,66 +375,69 @@ func (dq *dialQueue) remove(a ma.Multiaddr) {
dq.swap(pos, len(dq.q)-1)
dq.q = dq.q[:len(dq.q)-1]
delete(dq.pos, a)
dq.heapify(pos)
if pos < len(dq.q) {
dq.heapify(pos)
}
}

// heapify fixes the heap property for element at position i
func (dq *dialQueue) heapify(i int) {
if dq.len() == 0 {
return
}
dq.fixdown(i)
dq.fixup(i)
}

func (dq *dialQueue) fixup(i int) {
if dq.len() == 0 || i == 0 {
return
}
for i != 0 {
p := (i - 1) / 2
if dq.q[i].Delay < dq.q[p].Delay {
dq.swap(i, p)
i = p
continue
}
break
}
}

func (dq *dialQueue) fixdown(i int) {
if i >= dq.len() {
return
}
for {
v := dq.q[i].Delay
l, r := 2*i+1, 2*i+2
if l >= dq.len() && r >= dq.len() {
if i == 0 {
return
}
i = (i - 1) / 2
continue
if l >= dq.len() {
break
}
lv := dq.q[l].Delay
if v <= lv {
if r < dq.len() {
rv := dq.q[r].Delay
if v <= rv {
if i == 0 {
return
}
i = (i - 1) / 2
continue
} else {
dq.swap(i, r)
i = r
continue
}
} else {
if i == 0 {
return
}
i = (i - 1) / 2
continue
}
} else {
if r < dq.len() {
rv := dq.q[r].Delay
if lv <= rv {
dq.swap(i, l)
i = l
continue
} else {
dq.swap(i, r)
i = r
continue
}
} else {
if r >= dq.len() {
if dq.q[i].Delay > dq.q[l].Delay {
dq.swap(i, l)
i = l
continue
}
break
}
v, lv, rv := dq.q[i].Delay, dq.q[l].Delay, dq.q[r].Delay
if lv < v && lv <= rv {
dq.swap(i, l)
i = l
continue
}
if rv < v && rv <= lv {
dq.swap(i, r)
i = r
continue
}
break
}
}

// nextBatch returns all the elements in the queue with delay equal to the top element
// of the queue
func (dq *dialQueue) nextBatch() []network.AddrDelay {
if dq.len() == 0 {
return nil
Expand Down
Loading

0 comments on commit 241fd6a

Please sign in to comment.