Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply extra review feedback from #1340 #1400

Merged
merged 9 commits into from
Jul 23, 2021
3 changes: 3 additions & 0 deletions internal/db/read_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ type Writer interface {
msgs []*oplog.Message,
opt ...Option,
) error

// ScanRows will scan sql rows into the interface provided
ScanRows(rows *sql.Rows, result interface{}) error
}

const (
Expand Down
37 changes: 19 additions & 18 deletions internal/servers/controller/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -644,31 +644,26 @@ func (tc *TestController) AddClusterControllerMember(t *testing.T, opts *TestCon
// period, this function returns an error.
func (tc *TestController) WaitForNextWorkerStatusUpdate(workerId string) error {
tc.Logger().Debug("waiting for next status report from worker", "worker", workerId)

if err := tc.waitForNextWorkerStatusUpdate(workerId); err != nil {
tc.Logger().Error("error waiting for next status report from worker", "worker", workerId, "err", err)
return err
}

tc.Logger().Debug("waiting for next status report from worker received successfully", "worker", workerId)
return nil
}

func (tc *TestController) waitForNextWorkerStatusUpdate(workerId string) error {
waitStatusStart := time.Now()
ctx, cancel := context.WithTimeout(tc.ctx, tc.b.StatusGracePeriodDuration)
defer cancel()
var err error
for {
select {
case <-ctx.Done():
return ctx.Err()
if err = func() error {
select {
case <-ctx.Done():
return ctx.Err()

case <-time.After(time.Second):
// pass
case <-time.After(time.Second):
// pass
}

return nil
}(); err != nil {
break
}

var waitStatusCurrent time.Time
var err error
tc.Controller().WorkerStatusUpdateTimes().Range(func(k, v interface{}) bool {
if k == nil || v == nil {
err = fmt.Errorf("nil key or value on entry: key=%#v value=%#v", k, v)
Expand Down Expand Up @@ -696,13 +691,19 @@ func (tc *TestController) waitForNextWorkerStatusUpdate(workerId string) error {
})

if err != nil {
return err
break
}

if waitStatusCurrent.Sub(waitStatusStart) > 0 {
break
}
}

if err != nil {
tc.Logger().Error("error waiting for next status report from worker", "worker", workerId, "err", err)
return err
}

tc.Logger().Debug("waiting for next status report from worker received successfully", "worker", workerId)
return nil
}
10 changes: 3 additions & 7 deletions internal/servers/worker/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context) {
// the request.
for _, conn := range sessInfo.GetConnections() {
connId := conn.GetConnectionId()
connInfo, ok := si.connInfoMap[conn.GetConnectionId()]
connInfo, ok := si.connInfoMap[connId]
if !ok {
w.logger.Warn("connection change requested but could not find local information for it", "connection_id", connId)
continue
Expand Down Expand Up @@ -250,7 +250,7 @@ func (w *Worker) cleanupConnections(cancelCtx context.Context, ignoreSessionStat
closedIds := w.cancelConnections(si.connInfoMap, true)
for _, connId := range closedIds {
closeInfo[connId] = si.id
w.logClose(si.id, connId)
w.logger.Info("terminated connection due to cancellation or expiration", "session_id", si.id, "connection_id", connId)
}

// closeTime is marked by closeConnections iff the
Expand All @@ -268,7 +268,7 @@ func (w *Worker) cleanupConnections(cancelCtx context.Context, ignoreSessionStat
closedIds := w.cancelConnections(si.connInfoMap, false)
for _, connId := range closedIds {
closeInfo[connId] = si.id
w.logClose(si.id, connId)
w.logger.Info("terminated connection due to cancellation or expiration", "session_id", si.id, "connection_id", connId)
}
}

Expand Down Expand Up @@ -314,10 +314,6 @@ func (w *Worker) cancelConnections(connInfoMap map[string]*connInfo, ignoreConne
return closedIds
}

func (w *Worker) logClose(sessionId, connId string) {
w.logger.Info("terminated connection due to cancellation or expiration", "session_id", sessionId, "connection_id", connId)
}

func (w *Worker) lastSuccessfulStatusTime() time.Time {
lastStatus := w.LastStatusSuccess()
if lastStatus == nil {
Expand Down
73 changes: 22 additions & 51 deletions internal/session/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,59 +291,30 @@ with
// The query returns the set of servers that have had connections closed
// along with their last update time and the number of connections closed on
// each.

closeConnectionsForDeadServersCte = `
with
-- Get dead servers, parameterized off of grace period in seconds
dead_servers as (
select private_id, update_time
from server
where update_time < wt_sub_seconds_from_now($1)
),
-- Find connections that are not closed so we can reference those IDs
unclosed_connections as (
select connection_id
from session_connection_state
where
-- It's the current state
end_time is null
and
-- Current state isn't closed state
state in ('authorized', 'connected')
and
-- It's not in limbo between when it moved into this state and when
-- it started being reported by the worker, which is roughly every
-- 2-3 seconds
start_time < wt_sub_seconds_from_now(10)
),
connections_to_close as (
select public_id
from session_connection
where
-- Related to the worker that just reported to us
server_id in (select private_id from dead_servers)
and
-- Only unclosed ones
public_id in (select connection_id from unclosed_connections)
),
closed_connections as (
update session_connection
set
closed_reason = 'system error'
where
public_id in (select public_id from connections_to_close)
with
dead_servers (server_id, last_update_time) as (
select private_id, update_time
from server
where update_time < wt_sub_seconds_from_now($1)
),
closed_connections (connection_id, server_id) as (
update session_connection
set closed_reason = 'system error'
where server_id in (select server_id from dead_servers)
and closed_reason is null
returning public_id, server_id
)
select
dead_servers.private_id,
dead_servers.update_time,
count(closed_connections.public_id)
from dead_servers
left join closed_connections
on dead_servers.private_id = closed_connections.server_id
group by dead_servers.private_id, dead_servers.update_time
having count(closed_connections.public_id) > 0
order by dead_servers.private_id
`
)
select closed_connections.server_id,
dead_servers.last_update_time,
count(closed_connections.connection_id) as number_connections_closed
from closed_connections
join dead_servers
on closed_connections.server_id = dead_servers.server_id
group by closed_connections.server_id, dead_servers.last_update_time
order by closed_connections.server_id;
`

// shouldCloseConnectionsCte finds connections that are marked as closed in
// the database given a set of connection IDs. They are returned along with
Expand Down
15 changes: 3 additions & 12 deletions internal/session/repository_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,21 +185,12 @@ func (r *Repository) CloseConnectionsForDeadWorkers(ctx context.Context, gracePe
defer rows.Close()

for rows.Next() {
var (
serverId string
lastUpdateTime time.Time
numberConnectionsClosed int
)

if err := rows.Scan(&serverId, &lastUpdateTime, &numberConnectionsClosed); err != nil {
var result CloseConnectionsForDeadWorkersResult
if err := w.ScanRows(rows, &result); err != nil {
return errors.Wrap(err, op)
}

results = append(results, CloseConnectionsForDeadWorkersResult{
ServerId: serverId,
LastUpdateTime: lastUpdateTime,
NumberConnectionsClosed: numberConnectionsClosed,
})
results = append(results, result)
}

return nil
Expand Down
4 changes: 2 additions & 2 deletions internal/tests/cluster/session_cleanup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ func workerGracePeriod(ty timeoutBurdenType) time.Duration {
return defaultGracePeriod
}

// TestWorkerSessionCleanup is the main test for session cleanup, and
// TestSessionCleanup is the main test for session cleanup, and
// dispatches to the individual subtests.
func TestWorkerSessionCleanup(t *testing.T) {
func TestSessionCleanup(t *testing.T) {
t.Parallel()
for _, burdenCase := range timeoutBurdenCases {
burdenCase := burdenCase
Expand Down