From b300e750f4217968ec66055d2d7e20e491efe5a1 Mon Sep 17 00:00:00 2001 From: lhy1024 Date: Sun, 7 May 2023 00:44:56 +0800 Subject: [PATCH] add ctx Signed-off-by: lhy1024 --- pkg/utils/etcdutil/etcdutil_test.go | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/pkg/utils/etcdutil/etcdutil_test.go b/pkg/utils/etcdutil/etcdutil_test.go index 4258dc60a2ca..3ba1def329d2 100644 --- a/pkg/utils/etcdutil/etcdutil_test.go +++ b/pkg/utils/etcdutil/etcdutil_test.go @@ -344,7 +344,7 @@ func checkEtcdWithHangLeader(t *testing.T) error { // Create a proxy to etcd1. proxyAddr := tempurl.Alloc() var enableDiscard atomic.Bool - go proxyWithDiscard(re, ep1, proxyAddr, &enableDiscard) + go proxyWithDiscard(context.Background(), re, ep1, proxyAddr, &enableDiscard) // Create a etcd client with etcd1 as endpoint. urls, err := types.NewURLs([]string{proxyAddr}) @@ -402,7 +402,7 @@ func checkMembers(re *require.Assertions, client *clientv3.Client, etcds []*embe } } -func proxyWithDiscard(re *require.Assertions, server, proxy string, enableDiscard *atomic.Bool) { +func proxyWithDiscard(ctx context.Context, re *require.Assertions, server, proxy string, enableDiscard *atomic.Bool) { server = strings.TrimPrefix(server, "http://") proxy = strings.TrimPrefix(proxy, "http://") l, err := net.Listen("tcp", proxy) @@ -413,19 +413,19 @@ func proxyWithDiscard(re *require.Assertions, server, proxy string, enableDiscar go func(connect net.Conn) { serverConnect, err := net.Dial("tcp", server) re.NoError(err) - pipe(connect, serverConnect, enableDiscard) + pipe(ctx, connect, serverConnect, enableDiscard) }(connect) } } -func pipe(src net.Conn, dst net.Conn, enableDiscard *atomic.Bool) { +func pipe(ctx context.Context, src net.Conn, dst net.Conn, enableDiscard *atomic.Bool) { errChan := make(chan error, 1) go func() { - err := ioCopy(src, dst, enableDiscard) + err := ioCopy(ctx, src, dst, enableDiscard) errChan <- err }() go func() { - err := ioCopy(dst, src, enableDiscard) + err := ioCopy(ctx, dst, src, enableDiscard) errChan <- err }() <-errChan @@ -433,9 +433,14 @@ func pipe(src net.Conn, dst net.Conn, enableDiscard *atomic.Bool) { src.Close() } -func ioCopy(dst io.Writer, src io.Reader, enableDiscard *atomic.Bool) (err error) { +func ioCopy(ctx context.Context, dst io.Writer, src io.Reader, enableDiscard *atomic.Bool) (err error) { buffer := make([]byte, 32*1024) for { + select { + case <-ctx.Done(): + return nil + default: + } if enableDiscard.Load() { io.Copy(io.Discard, src) continue