From f15d4aef6a6412cc1f18e401d4443266d6acd01a Mon Sep 17 00:00:00 2001 From: yihuaz Date: Fri, 8 May 2020 13:09:33 -0700 Subject: [PATCH 1/9] Fix conflicts with upstream master --- credentials/credentials.go | 8 +- credentials/credentials_test.go | 4 +- credentials/local/local.go | 109 +++++++++++++++ credentials/local/local_test.go | 204 +++++++++++++++++++++++++++++ internal/transport/http2_client.go | 4 +- test/end2end_test.go | 39 +++++- 6 files changed, 360 insertions(+), 8 deletions(-) create mode 100644 credentials/local/local.go create mode 100644 credentials/local/local_test.go diff --git a/credentials/credentials.go b/credentials/credentials.go index 53addd8c71e9..02766443ae74 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -58,9 +58,11 @@ type PerRPCCredentials interface { type SecurityLevel int const ( - // NoSecurity indicates a connection is insecure. + // Invalid indicates an invalid security level. // The zero SecurityLevel value is invalid for backward compatibility. - NoSecurity SecurityLevel = iota + 1 + Invalid SecurityLevel = iota + // NoSecurity indicates a connection is insecure. + NoSecurity // IntegrityOnly indicates a connection only provides integrity protection. IntegrityOnly // PrivacyAndIntegrity indicates a connection provides both privacy and integrity protection. @@ -237,7 +239,7 @@ func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error { } if ci, ok := ri.AuthInfo.(internalInfo); ok { // CommonAuthInfo.SecurityLevel has an invalid value. - if ci.GetCommonAuthInfo().SecurityLevel == 0 { + if ci.GetCommonAuthInfo().SecurityLevel == Invalid { return nil } if ci.GetCommonAuthInfo().SecurityLevel < level { diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go index c2a316281928..5ff4850453bd 100644 --- a/credentials/credentials_test.go +++ b/credentials/credentials_test.go @@ -86,12 +86,12 @@ func (s) TestCheckSecurityLevel(t *testing.T) { want: true, }, { - authLevel: 0, + authLevel: Invalid, testLevel: IntegrityOnly, want: true, }, { - authLevel: 0, + authLevel: Invalid, testLevel: PrivacyAndIntegrity, want: true, }, diff --git a/credentials/local/local.go b/credentials/local/local.go new file mode 100644 index 000000000000..23de34cf8a18 --- /dev/null +++ b/credentials/local/local.go @@ -0,0 +1,109 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package local implements local transport credentials. +// Local credentials reports the security level based on the type +// of connetion. If the connection is local TCP, NoSecurity will be +// reported, and if the connection is UDS, PrivacyAndIntegrity will be +// reported. If local credentials is not used in local connections +// (local TCP or UDS), it will fail. +// +// This package is EXPERIMENTAL. +package local + +import ( + "context" + "fmt" + "net" + "strings" + + "google.golang.org/grpc/credentials" +) + +// Info contains the auth information for a local connection. +// It implements the AuthInfo interface. +type Info struct { + credentials.CommonAuthInfo +} + +// AuthType returns the type of Info as a string. +func (Info) AuthType() string { + return "local" +} + +// localTC is the credentials required to establish a local connection. +type localTC struct { + info credentials.ProtocolInfo +} + +func (c *localTC) Info() credentials.ProtocolInfo { + return c.info +} + +// getSecurityLevel returns the security level for a local connection. +// It returns an error if a connection is not local. +func getSecurityLevel(network, addr string) (credentials.SecurityLevel, error) { + switch { + // Local TCP connection + case strings.HasPrefix(addr, "127."), strings.HasPrefix(addr, "[::1]:"): + return credentials.NoSecurity, nil + // UDS connection + case network == "unix": + return credentials.PrivacyAndIntegrity, nil + // Not a local connection and should fail + default: + return credentials.Invalid, fmt.Errorf("local credentials rejected connection to non-local address %q", addr) + } +} + +func (*localTC) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String()) + if err != nil { + return nil, nil, err + } + return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil +} + +func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String()) + if err != nil { + return nil, nil, err + } + return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil +} + +// NewCredentials returns a local credential implementing credentials.TransportCredentials. +func NewCredentials() credentials.TransportCredentials { + return &localTC{ + info: credentials.ProtocolInfo{ + SecurityProtocol: "local", + }, + } +} + +// Clone makes a copy of Local credentials. +func (c *localTC) Clone() credentials.TransportCredentials { + return &localTC{info: c.info} +} + +// OverrideServerName overrides the server name used to verify the hostname on the returned certificates from the server. +// Since this feature is specific to TLS (SNI + hostname verification check), it does not take any effet for local credentials. +func (c *localTC) OverrideServerName(serverNameOverride string) error { + c.info.ServerName = serverNameOverride + return nil +} diff --git a/credentials/local/local_test.go b/credentials/local/local_test.go new file mode 100644 index 000000000000..8317ec3e8ed6 --- /dev/null +++ b/credentials/local/local_test.go @@ -0,0 +1,204 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package local + +import ( + "context" + "fmt" + "net" + "runtime" + "strings" + "testing" + "time" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpctest" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestGetSecurityLevel(t *testing.T) { + testCases := []struct { + testNetwork string + testAddr string + want credentials.SecurityLevel + }{ + { + testNetwork: "tcp", + testAddr: "127.0.0.1:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "tcp", + testAddr: "[::1]:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "unix", + testAddr: "/tmp/grpc_fullstack_test", + want: credentials.PrivacyAndIntegrity, + }, + { + testNetwork: "tcp", + testAddr: "192.168.0.1:10000", + want: credentials.Invalid, + }, + } + for _, tc := range testCases { + got, _ := getSecurityLevel(tc.testNetwork, tc.testAddr) + if got != tc.want { + t.Fatalf("GetSeurityLevel(%s, %s) returned %s but want %s", tc.testNetwork, tc.testAddr, got.String(), tc.want.String()) + } + } +} + +type serverHandshake func(net.Conn) (credentials.AuthInfo, error) + +// Server local handshake implementation. +func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) { + cred := NewCredentials() + _, authInfo, err := cred.ServerHandshake(conn) + if err != nil { + return nil, err + } + return authInfo, nil +} + +// Client local handshake implementation. +func clientLocalHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) { + cred := NewCredentials() + _, authInfo, err := cred.ClientHandshake(context.Background(), lisAddr, conn) + if err != nil { + return nil, err + } + return authInfo, nil +} + +// Client connects to a server with local credentials. +func clientHandle(hs func(net.Conn, string) (credentials.AuthInfo, error), network, lisAddr string) (credentials.AuthInfo, error) { + conn, _ := net.Dial(network, lisAddr) + defer conn.Close() + clientAuthInfo, err := hs(conn, lisAddr) + if err != nil { + return nil, fmt.Errorf("Error on client while handshake") + } + return clientAuthInfo, nil +} + +type testServerHandleResult struct { + authInfo credentials.AuthInfo + err error +} + +// Server accepts a client's connection with local credentials. +func serverHandle(hs serverHandshake, done chan testServerHandleResult, lis net.Listener) { + serverRawConn, err := lis.Accept() + if err != nil { + done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed to accept connection. Error: %v", err)} + } + serverAuthInfo, err := hs(serverRawConn) + if err != nil { + serverRawConn.Close() + done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed while handshake. Error: %v", err)} + } + done <- testServerHandleResult{authInfo: serverAuthInfo, err: nil} +} + +func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, error) { + done := make(chan testServerHandleResult, 1) + const timeout = 5 * time.Second + timer := time.NewTimer(timeout) + defer timer.Stop() + go serverHandle(serverLocalHandshake, done, lis) + defer lis.Close() + clientAuthInfo, err := clientHandle(clientLocalHandshake, lis.Addr().Network(), lis.Addr().String()) + if err != nil { + return credentials.Invalid, fmt.Errorf("Error at client-side: %v", err) + } + select { + case <-timer.C: + return credentials.Invalid, fmt.Errorf("Test didn't finish in time") + case serverHandleResult := <-done: + if serverHandleResult.err != nil { + return credentials.Invalid, fmt.Errorf("Error at server-side: %v", serverHandleResult.err) + } + clientLocal, _ := clientAuthInfo.(Info) + serverLocal, _ := serverHandleResult.authInfo.(Info) + clientSecLevel := clientLocal.CommonAuthInfo.SecurityLevel + serverSecLevel := serverLocal.CommonAuthInfo.SecurityLevel + if clientSecLevel != serverSecLevel { + return credentials.Invalid, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String()) + } + return clientSecLevel, nil + } +} + +func (s) TestServerAndClientHandshake(t *testing.T) { + testCases := []struct { + testNetwork string + testAddr string + want credentials.SecurityLevel + }{ + { + testNetwork: "tcp", + testAddr: "127.0.0.1:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "tcp", + testAddr: "[::1]:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "tcp", + testAddr: "localhost:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "unix", + testAddr: fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano()), + want: credentials.PrivacyAndIntegrity, + }, + } + for _, tc := range testCases { + if runtime.GOOS == "windows" && tc.testNetwork == "unix" { + t.Skip("skipping tests for unix connections on Windows") + } + t.Run("serverAndClientHandshakeResult", func(t *testing.T) { + lis, err := net.Listen(tc.testNetwork, tc.testAddr) + if err != nil { + if strings.Contains(err.Error(), "bind: cannot assign requested address") || + strings.Contains(err.Error(), "socket: address family not supported by protocol") { + t.Skipf("no support for address %v", tc.testAddr) + } + t.Fatalf("Failed to listen: %v", err) + } + got, err := serverAndClientHandshake(lis) + if got != tc.want { + t.Fatalf("ServerAndClientHandshake(%s, %s) returned %s but want %s. Error: %v", tc.testNetwork, tc.testAddr, got.String(), tc.want.String(), err) + } + }) + } +} diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index d1eb17e068fe..b4e55d0101ba 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -221,12 +221,14 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts // address specific arbitrary data to reach the credential handshaker. contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) connectCtx = contextWithHandshakeInfo(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) - scheme = "https" conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn) if err != nil { return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err) } isSecure = true + if transportCreds.Info().SecurityProtocol == "tls" { + scheme = "https" + } } dynamicWindow := true icwz := int32(initialWindowSize) diff --git a/test/end2end_test.go b/test/end2end_test.go index f3a60de5a96f..939839869ac9 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -50,6 +50,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/local" "google.golang.org/grpc/encoding" _ "google.golang.org/grpc/encoding/gzip" "google.golang.org/grpc/health" @@ -214,18 +215,43 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* if s.security != "" { // Check Auth info var authType, serverName string + var secLevel credentials.SecurityLevel switch info := pr.AuthInfo.(type) { case credentials.TLSInfo: authType = info.AuthType() serverName = info.State.ServerName + secLevel = info.CommonAuthInfo.SecurityLevel + case local.Info: + authType = info.AuthType() + secLevel = info.CommonAuthInfo.SecurityLevel default: return nil, status.Error(codes.Unauthenticated, "Unknown AuthInfo type") } if authType != s.security { return nil, status.Errorf(codes.Unauthenticated, "Wrong auth type: got %q, want %q", authType, s.security) } - if serverName != "x.test.youtube.com" { - return nil, status.Errorf(codes.Unauthenticated, "Unknown server name %q", serverName) + + // Check Auth info specific to credentials.TLSInfo + if s.security == "tls" { + if secLevel != credentials.PrivacyAndIntegrity { + return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.PrivacyAndIntegrity.String()) + } + if serverName != "x.test.youtube.com" { + return nil, status.Errorf(codes.Unauthenticated, "Unknown server name %q", serverName) + } + } + // Check Auth info specific to local.Info + if s.security == "local" { + switch pr.Addr.Network() { + case "tcp": + if secLevel != credentials.NoSecurity { + return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.NoSecurity.String()) + } + case "unix": + if secLevel != credentials.PrivacyAndIntegrity { + return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.PrivacyAndIntegrity.String()) + } + } } } // Simulate some service delay. @@ -394,6 +420,7 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ type env struct { name string network string // The type of network such as tcp, unix, etc. + listenerAddr string // The address of listener. security string // The security protocol such as TLS, SSH, etc. httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS balancer string // One of "round_robin", "pick_first", or "". @@ -417,6 +444,7 @@ func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) { var ( tcpClearEnv = env{name: "tcp-clear-v1-balancer", network: "tcp"} tcpTLSEnv = env{name: "tcp-tls-v1-balancer", network: "tcp", security: "tls"} + tcpLocalEnv = env{name: "tcp-local-v1-balancer", network: "tcp", listenerAddr: "[::1]:0", security: "local"} tcpClearRREnv = env{name: "tcp-clear", network: "tcp", balancer: "round_robin"} tcpTLSRREnv = env{name: "tcp-tls", network: "tcp", security: "tls", balancer: "round_robin"} handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true, balancer: "round_robin"} @@ -607,6 +635,9 @@ func (te *test) listenAndServe(ts testpb.TestServiceServer, listen func(network, sopts = append(sopts, grpc.InitialConnWindowSize(te.serverInitialConnWindowSize)) } la := "localhost:0" + if te.e.listenerAddr != "" { + la = te.e.listenerAddr + } switch te.e.network { case "unix": la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now().UnixNano()) @@ -622,6 +653,8 @@ func (te *test) listenAndServe(ts testpb.TestServiceServer, listen func(network, te.t.Fatalf("Failed to generate credentials %v", err) } sopts = append(sopts, grpc.Creds(creds)) + } else if te.e.security == "local" { + sopts = append(sopts, grpc.Creds(local.NewCredentials())) } sopts = append(sopts, te.customServerOptions...) s := grpc.NewServer(sopts...) @@ -798,6 +831,8 @@ func (te *test) configDial(opts ...grpc.DialOption) ([]grpc.DialOption, string) te.t.Fatalf("Failed to load credentials: %v", err) } opts = append(opts, grpc.WithTransportCredentials(creds)) + case "local": + opts = append(opts, grpc.WithTransportCredentials(local.NewCredentials())) case "empty": // Don't add any transport creds option. default: From 9a865e93465e4f707aed14bfc4f7c889deee7ed8 Mon Sep 17 00:00:00 2001 From: yihuaz Date: Sun, 10 May 2020 12:37:11 -0700 Subject: [PATCH 2/9] add end2end test --- test/local_creds_test.go | 193 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) create mode 100644 test/local_creds_test.go diff --git a/test/local_creds_test.go b/test/local_creds_test.go new file mode 100644 index 000000000000..269607b2dbb2 --- /dev/null +++ b/test/local_creds_test.go @@ -0,0 +1,193 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package test + +import ( + "context" + "fmt" + "net" + "strings" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/local" + testpb "google.golang.org/grpc/test/grpc_testing" +) + +func testE2ESucceed(network, address string) error { + ss := &stubServer{ + emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + sopts := []grpc.ServerOption{grpc.Creds(local.NewCredentials())} + s := grpc.NewServer(sopts...) + defer s.Stop() + + testpb.RegisterTestServiceServer(s, ss) + + lis, err := net.Listen(network, address) + if err != nil { + return fmt.Errorf("Failed to create listener: %v", err) + } + + go s.Serve(lis) + + var cc *grpc.ClientConn + if network == "unix" { + cc, err = grpc.Dial("passthrough:///"+address, grpc.WithTransportCredentials(local.NewCredentials()), grpc.WithContextDialer( + func(ctx context.Context, addr string) (net.Conn, error) { + return net.Dial("unix", address) + })) + } else { + cc, err = grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(local.NewCredentials())) + } + + if err != nil { + return fmt.Errorf("Failed to dial server: %v, %v", err, lis.Addr().String()) + } + defer cc.Close() + + c := testpb.NewTestServiceClient(cc) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + if _, err := c.EmptyCall(ctx, &testpb.Empty{}); err != nil { + return fmt.Errorf("EmptyCall(_, _) = _, %v; want _, ", err) + } + + return nil +} + +func (s) TestLocalhost(t *testing.T) { + err := testE2ESucceed("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed e2e test for localhost: %v", err) + } +} + +func (s) TestUDS(t *testing.T) { + addr := fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano()) + err := testE2ESucceed("unix", addr) + if err != nil { + t.Fatalf("Failed e2e test for UDS: %v", err) + } +} + +type connWrapper struct { + net.Conn + remote net.Addr +} + +func (c connWrapper) RemoteAddr() net.Addr { + return c.remote +} + +type lisWrapper struct { + net.Listener +} + +func newLisWrapper(l net.Listener) net.Listener { + return &lisWrapper{l} +} + +var remoteAddrs = []net.Addr{ + &net.IPAddr{ + IP: net.ParseIP("10.8.9.10"), + Zone: "", + }, + &net.IPAddr{ + IP: net.ParseIP("10.8.9.11"), + Zone: "", + }, +} + +func (l *lisWrapper) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return connWrapper{c, remoteAddrs[0]}, nil +} + +func dialer(target string, t time.Duration) (net.Conn, error) { + c, err := net.DialTimeout("tcp", target, t) + if err != nil { + return nil, err + } + return connWrapper{c, remoteAddrs[1]}, nil +} + +func testE2EFail(useLocal bool) error { + ss := &stubServer{ + emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + return &testpb.Empty{}, nil + }, + } + + sopts := []grpc.ServerOption{grpc.Creds(local.NewCredentials())} + s := grpc.NewServer(sopts...) + defer s.Stop() + + testpb.RegisterTestServiceServer(s, ss) + + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + return fmt.Errorf("Failed to create listener: %v", err) + } + + go s.Serve(newLisWrapper(lis)) + + var cc *grpc.ClientConn + if useLocal { + cc, err = grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(local.NewCredentials()), grpc.WithDialer(dialer)) + } else { + cc, err = grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithDialer(dialer)) + } + + if err != nil { + return fmt.Errorf("Failed to dial server: %v, %v", err, lis.Addr().String()) + } + defer cc.Close() + + c := testpb.NewTestServiceClient(cc) + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + _, err = c.EmptyCall(ctx, &testpb.Empty{}) + return err +} + +func (s) TestClientFail(t *testing.T) { + // Use local creds at client-side which should lead to client-side failure. + err := testE2EFail(true /*useLocal*/) + if err == nil || !strings.Contains(err.Error(), "local credentials rejected connection to non-local address") { + t.Fatalf("testE2EFail(%v) = _; want security handshake fails, %v", false, err) + } +} + +func (s) TestServerFail(t *testing.T) { + // Use insecure at client-side which should lead to server-side failure. + err := testE2EFail(false /*useLocal*/) + if err == nil || !strings.Contains(err.Error(), "connection closed") { + t.Fatalf("testE2EFail(%v) = _; want security handshake fails, %v", true, err) + } +} From 9482a1391a5d8f2ef9132f777e4b17940c400cc2 Mon Sep 17 00:00:00 2001 From: yihuaz Date: Sun, 10 May 2020 13:50:19 -0700 Subject: [PATCH 3/9] fix a test failure --- test/local_creds_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/local_creds_test.go b/test/local_creds_test.go index 269607b2dbb2..83a3e4b591fd 100644 --- a/test/local_creds_test.go +++ b/test/local_creds_test.go @@ -187,7 +187,7 @@ func (s) TestClientFail(t *testing.T) { func (s) TestServerFail(t *testing.T) { // Use insecure at client-side which should lead to server-side failure. err := testE2EFail(false /*useLocal*/) - if err == nil || !strings.Contains(err.Error(), "connection closed") { + if err == nil { t.Fatalf("testE2EFail(%v) = _; want security handshake fails, %v", true, err) } } From b81b36ce5e30a08f96a0e88e940f47ffb61abf79 Mon Sep 17 00:00:00 2001 From: yihuaz Date: Wed, 13 May 2020 15:08:05 -0700 Subject: [PATCH 4/9] address doug's 2nd round comments --- credentials/local/local_test.go | 2 +- test/end2end_test.go | 39 +---------- test/local_creds_test.go | 111 +++++++++++++++++++------------- 3 files changed, 71 insertions(+), 81 deletions(-) diff --git a/credentials/local/local_test.go b/credentials/local/local_test.go index 8317ec3e8ed6..a508d89bcd06 100644 --- a/credentials/local/local_test.go +++ b/credentials/local/local_test.go @@ -197,7 +197,7 @@ func (s) TestServerAndClientHandshake(t *testing.T) { } got, err := serverAndClientHandshake(lis) if got != tc.want { - t.Fatalf("ServerAndClientHandshake(%s, %s) returned %s but want %s. Error: %v", tc.testNetwork, tc.testAddr, got.String(), tc.want.String(), err) + t.Fatalf("serverAndClientHandshake(%s, %s) = %v, %v; want %v, nil", tc.testNetwork, tc.testAddr, got, err, tc.want) } }) } diff --git a/test/end2end_test.go b/test/end2end_test.go index 939839869ac9..f3a60de5a96f 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -50,7 +50,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/local" "google.golang.org/grpc/encoding" _ "google.golang.org/grpc/encoding/gzip" "google.golang.org/grpc/health" @@ -215,43 +214,18 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* if s.security != "" { // Check Auth info var authType, serverName string - var secLevel credentials.SecurityLevel switch info := pr.AuthInfo.(type) { case credentials.TLSInfo: authType = info.AuthType() serverName = info.State.ServerName - secLevel = info.CommonAuthInfo.SecurityLevel - case local.Info: - authType = info.AuthType() - secLevel = info.CommonAuthInfo.SecurityLevel default: return nil, status.Error(codes.Unauthenticated, "Unknown AuthInfo type") } if authType != s.security { return nil, status.Errorf(codes.Unauthenticated, "Wrong auth type: got %q, want %q", authType, s.security) } - - // Check Auth info specific to credentials.TLSInfo - if s.security == "tls" { - if secLevel != credentials.PrivacyAndIntegrity { - return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.PrivacyAndIntegrity.String()) - } - if serverName != "x.test.youtube.com" { - return nil, status.Errorf(codes.Unauthenticated, "Unknown server name %q", serverName) - } - } - // Check Auth info specific to local.Info - if s.security == "local" { - switch pr.Addr.Network() { - case "tcp": - if secLevel != credentials.NoSecurity { - return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.NoSecurity.String()) - } - case "unix": - if secLevel != credentials.PrivacyAndIntegrity { - return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.PrivacyAndIntegrity.String()) - } - } + if serverName != "x.test.youtube.com" { + return nil, status.Errorf(codes.Unauthenticated, "Unknown server name %q", serverName) } } // Simulate some service delay. @@ -420,7 +394,6 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ type env struct { name string network string // The type of network such as tcp, unix, etc. - listenerAddr string // The address of listener. security string // The security protocol such as TLS, SSH, etc. httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS balancer string // One of "round_robin", "pick_first", or "". @@ -444,7 +417,6 @@ func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) { var ( tcpClearEnv = env{name: "tcp-clear-v1-balancer", network: "tcp"} tcpTLSEnv = env{name: "tcp-tls-v1-balancer", network: "tcp", security: "tls"} - tcpLocalEnv = env{name: "tcp-local-v1-balancer", network: "tcp", listenerAddr: "[::1]:0", security: "local"} tcpClearRREnv = env{name: "tcp-clear", network: "tcp", balancer: "round_robin"} tcpTLSRREnv = env{name: "tcp-tls", network: "tcp", security: "tls", balancer: "round_robin"} handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true, balancer: "round_robin"} @@ -635,9 +607,6 @@ func (te *test) listenAndServe(ts testpb.TestServiceServer, listen func(network, sopts = append(sopts, grpc.InitialConnWindowSize(te.serverInitialConnWindowSize)) } la := "localhost:0" - if te.e.listenerAddr != "" { - la = te.e.listenerAddr - } switch te.e.network { case "unix": la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now().UnixNano()) @@ -653,8 +622,6 @@ func (te *test) listenAndServe(ts testpb.TestServiceServer, listen func(network, te.t.Fatalf("Failed to generate credentials %v", err) } sopts = append(sopts, grpc.Creds(creds)) - } else if te.e.security == "local" { - sopts = append(sopts, grpc.Creds(local.NewCredentials())) } sopts = append(sopts, te.customServerOptions...) s := grpc.NewServer(sopts...) @@ -831,8 +798,6 @@ func (te *test) configDial(opts ...grpc.DialOption) ([]grpc.DialOption, string) te.t.Fatalf("Failed to load credentials: %v", err) } opts = append(opts, grpc.WithTransportCredentials(creds)) - case "local": - opts = append(opts, grpc.WithTransportCredentials(local.NewCredentials())) case "empty": // Don't add any transport creds option. default: diff --git a/test/local_creds_test.go b/test/local_creds_test.go index 83a3e4b591fd..6f2a83c548d0 100644 --- a/test/local_creds_test.go +++ b/test/local_creds_test.go @@ -27,13 +27,34 @@ import ( "time" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/local" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" testpb "google.golang.org/grpc/test/grpc_testing" ) func testE2ESucceed(network, address string) error { ss := &stubServer{ - emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { + emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { + pr, ok := peer.FromContext(ctx) + if !ok { + return nil, status.Error(codes.DataLoss, "Failed to get peer from ctx") + } + // Check security level + info := pr.AuthInfo.(local.Info) + secLevel := info.CommonAuthInfo.SecurityLevel + switch network { + case "unix": + if secLevel != credentials.PrivacyAndIntegrity { + return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.PrivacyAndIntegrity.String()) + } + case "tcp": + if secLevel != credentials.NoSecurity { + return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.NoSecurity.String()) + } + } return &testpb.Empty{}, nil }, } @@ -52,15 +73,17 @@ func testE2ESucceed(network, address string) error { go s.Serve(lis) var cc *grpc.ClientConn - if network == "unix" { - cc, err = grpc.Dial("passthrough:///"+address, grpc.WithTransportCredentials(local.NewCredentials()), grpc.WithContextDialer( + switch network { + case "unix": + cc, err = grpc.Dial(address, grpc.WithTransportCredentials(local.NewCredentials()), grpc.WithContextDialer( func(ctx context.Context, addr string) (net.Conn, error) { return net.Dial("unix", address) })) - } else { + case "tcp": cc, err = grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(local.NewCredentials())) + default: + return fmt.Errorf("unsupported network %q", network) } - if err != nil { return fmt.Errorf("Failed to dial server: %v, %v", err, lis.Addr().String()) } @@ -73,21 +96,18 @@ func testE2ESucceed(network, address string) error { if _, err := c.EmptyCall(ctx, &testpb.Empty{}); err != nil { return fmt.Errorf("EmptyCall(_, _) = _, %v; want _, ", err) } - return nil } func (s) TestLocalhost(t *testing.T) { - err := testE2ESucceed("tcp", "localhost:0") - if err != nil { + if err := testE2ESucceed("tcp", "localhost:0"); err != nil { t.Fatalf("Failed e2e test for localhost: %v", err) } } func (s) TestUDS(t *testing.T) { addr := fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano()) - err := testE2ESucceed("unix", addr) - if err != nil { + if err := testE2ESucceed("unix", addr); err != nil { t.Fatalf("Failed e2e test for UDS: %v", err) } } @@ -103,21 +123,11 @@ func (c connWrapper) RemoteAddr() net.Addr { type lisWrapper struct { net.Listener + remote net.Addr } -func newLisWrapper(l net.Listener) net.Listener { - return &lisWrapper{l} -} - -var remoteAddrs = []net.Addr{ - &net.IPAddr{ - IP: net.ParseIP("10.8.9.10"), - Zone: "", - }, - &net.IPAddr{ - IP: net.ParseIP("10.8.9.11"), - Zone: "", - }, +func newLisWrapper(l net.Listener, remote net.Addr) net.Listener { + return &lisWrapper{l, remote} } func (l *lisWrapper) Accept() (net.Conn, error) { @@ -125,18 +135,20 @@ func (l *lisWrapper) Accept() (net.Conn, error) { if err != nil { return nil, err } - return connWrapper{c, remoteAddrs[0]}, nil + return connWrapper{c, l.remote}, nil } -func dialer(target string, t time.Duration) (net.Conn, error) { - c, err := net.DialTimeout("tcp", target, t) - if err != nil { - return nil, err +func spoofDialer(addr net.Addr) func(target string, t time.Duration) (net.Conn, error) { + return func(t string, d time.Duration) (net.Conn, error) { + c, err := net.DialTimeout("tcp", t, d) + if err != nil { + return nil, err + } + return connWrapper{c, addr}, nil } - return connWrapper{c, remoteAddrs[1]}, nil } -func testE2EFail(useLocal bool) error { +func testE2EFail(dopts []grpc.DialOption) error { ss := &stubServer{ emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil @@ -154,15 +166,19 @@ func testE2EFail(useLocal bool) error { return fmt.Errorf("Failed to create listener: %v", err) } - go s.Serve(newLisWrapper(lis)) - - var cc *grpc.ClientConn - if useLocal { - cc, err = grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(local.NewCredentials()), grpc.WithDialer(dialer)) - } else { - cc, err = grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithDialer(dialer)) + var fakeClientAddr, fakeServerAddr net.Addr + fakeClientAddr = &net.IPAddr{ + IP: net.ParseIP("10.8.9.10"), + Zone: "", + } + fakeServerAddr = &net.IPAddr{ + IP: net.ParseIP("10.8.9.11"), + Zone: "", } + go s.Serve(newLisWrapper(lis, fakeClientAddr)) + + cc, err := grpc.Dial(lis.Addr().String(), append(dopts, grpc.WithDialer(spoofDialer(fakeServerAddr)))...) if err != nil { return fmt.Errorf("Failed to dial server: %v, %v", err, lis.Addr().String()) } @@ -176,18 +192,27 @@ func testE2EFail(useLocal bool) error { return err } +func isExpected(got, want error) bool { + if status.Code(got) == status.Code(want) && strings.Contains(status.Convert(got).Message(), status.Convert(want).Message()) { + return true + } + return false +} + func (s) TestClientFail(t *testing.T) { // Use local creds at client-side which should lead to client-side failure. - err := testE2EFail(true /*useLocal*/) - if err == nil || !strings.Contains(err.Error(), "local credentials rejected connection to non-local address") { - t.Fatalf("testE2EFail(%v) = _; want security handshake fails, %v", false, err) + opts := []grpc.DialOption{grpc.WithTransportCredentials(local.NewCredentials())} + want := status.Error(codes.Unavailable, "transport: authentication handshake failed: local credentials rejected connection to non-local address") + if err := testE2EFail(opts); !isExpected(err, want) { + t.Fatalf("testE2EFail() = %v; want %v", err, want) } } func (s) TestServerFail(t *testing.T) { // Use insecure at client-side which should lead to server-side failure. - err := testE2EFail(false /*useLocal*/) - if err == nil { - t.Fatalf("testE2EFail(%v) = _; want security handshake fails, %v", true, err) + opts := []grpc.DialOption{grpc.WithInsecure()} + want := status.Error(codes.Unavailable, "connection closed") + if err := testE2EFail(opts); !isExpected(err, want) { + t.Fatalf("testE2EFail() = %v; want %v", err, want) } } From 876fdb9f664cf460503ea96fbb58c11919cec66e Mon Sep 17 00:00:00 2001 From: yihuaz Date: Thu, 14 May 2020 10:35:21 -0700 Subject: [PATCH 5/9] address doug's 3rd round of comments --- test/local_creds_test.go | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/test/local_creds_test.go b/test/local_creds_test.go index 6f2a83c548d0..5c7c1089615f 100644 --- a/test/local_creds_test.go +++ b/test/local_creds_test.go @@ -32,6 +32,7 @@ import ( "google.golang.org/grpc/credentials/local" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" + testpb "google.golang.org/grpc/test/grpc_testing" ) @@ -48,11 +49,11 @@ func testE2ESucceed(network, address string) error { switch network { case "unix": if secLevel != credentials.PrivacyAndIntegrity { - return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.PrivacyAndIntegrity.String()) + return nil, status.Errorf(codes.Unauthenticated, fmt.Sprintf("Wrong security level: got %q, want %q", secLevel, credentials.PrivacyAndIntegrity)) } case "tcp": if secLevel != credentials.NoSecurity { - return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.NoSecurity.String()) + return nil, status.Errorf(codes.Unauthenticated, fmt.Sprintf("Wrong security level: got %q, want %q", secLevel, credentials.NoSecurity)) } } return &testpb.Empty{}, nil @@ -73,19 +74,26 @@ func testE2ESucceed(network, address string) error { go s.Serve(lis) var cc *grpc.ClientConn + lisAddr := address + switch network { case "unix": - cc, err = grpc.Dial(address, grpc.WithTransportCredentials(local.NewCredentials()), grpc.WithContextDialer( + cc, err = grpc.Dial(lisAddr, grpc.WithTransportCredentials(local.NewCredentials()), grpc.WithContextDialer( func(ctx context.Context, addr string) (net.Conn, error) { - return net.Dial("unix", address) + return net.Dial("unix", addr) })) case "tcp": - cc, err = grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(local.NewCredentials())) + _, port, err := net.SplitHostPort(lis.Addr().String()) + if err != nil { + return fmt.Errorf("Failed to parse listener address: %v", err) + } + lisAddr = "localhost:" + port + cc, err = grpc.Dial(lisAddr, grpc.WithTransportCredentials(local.NewCredentials())) default: return fmt.Errorf("unsupported network %q", network) } if err != nil { - return fmt.Errorf("Failed to dial server: %v, %v", err, lis.Addr().String()) + return fmt.Errorf("Failed to dial server: %v, %v", err, lisAddr) } defer cc.Close() @@ -193,10 +201,7 @@ func testE2EFail(dopts []grpc.DialOption) error { } func isExpected(got, want error) bool { - if status.Code(got) == status.Code(want) && strings.Contains(status.Convert(got).Message(), status.Convert(want).Message()) { - return true - } - return false + return status.Code(got) == status.Code(want) && strings.Contains(status.Convert(got).Message(), status.Convert(want).Message()) } func (s) TestClientFail(t *testing.T) { @@ -211,7 +216,7 @@ func (s) TestClientFail(t *testing.T) { func (s) TestServerFail(t *testing.T) { // Use insecure at client-side which should lead to server-side failure. opts := []grpc.DialOption{grpc.WithInsecure()} - want := status.Error(codes.Unavailable, "connection closed") + want := status.Error(codes.Unavailable, "") if err := testE2EFail(opts); !isExpected(err, want) { t.Fatalf("testE2EFail() = %v; want %v", err, want) } From e417e60f4c847a181c063ff2521a3f86c00d75a0 Mon Sep 17 00:00:00 2001 From: yihuaz Date: Thu, 14 May 2020 12:31:04 -0700 Subject: [PATCH 6/9] fix travis error --- test/local_creds_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/local_creds_test.go b/test/local_creds_test.go index 5c7c1089615f..7a4dc0cb73b3 100644 --- a/test/local_creds_test.go +++ b/test/local_creds_test.go @@ -66,6 +66,7 @@ func testE2ESucceed(network, address string) error { testpb.RegisterTestServiceServer(s, ss) + var err error lis, err := net.Listen(network, address) if err != nil { return fmt.Errorf("Failed to create listener: %v", err) @@ -101,7 +102,7 @@ func testE2ESucceed(network, address string) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - if _, err := c.EmptyCall(ctx, &testpb.Empty{}); err != nil { + if _, err = c.EmptyCall(ctx, &testpb.Empty{}); err != nil { return fmt.Errorf("EmptyCall(_, _) = _, %v; want _, ", err) } return nil From 11d75673d11ba95cc788c158c3cb070b24566754 Mon Sep 17 00:00:00 2001 From: yihuaz Date: Thu, 14 May 2020 13:02:00 -0700 Subject: [PATCH 7/9] 2nd attempt to fix travis error --- test/local_creds_test.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/local_creds_test.go b/test/local_creds_test.go index 7a4dc0cb73b3..ced3e547bfe7 100644 --- a/test/local_creds_test.go +++ b/test/local_creds_test.go @@ -84,7 +84,8 @@ func testE2ESucceed(network, address string) error { return net.Dial("unix", addr) })) case "tcp": - _, port, err := net.SplitHostPort(lis.Addr().String()) + var port string + _, port, err = net.SplitHostPort(lis.Addr().String()) if err != nil { return fmt.Errorf("Failed to parse listener address: %v", err) } From 3533a21a51591cc128b7e429b034ab441f706d46 Mon Sep 17 00:00:00 2001 From: yihuaz Date: Thu, 14 May 2020 13:15:05 -0700 Subject: [PATCH 8/9] Address doug's 4th round of comments --- test/local_creds_test.go | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/test/local_creds_test.go b/test/local_creds_test.go index ced3e547bfe7..8effea99162a 100644 --- a/test/local_creds_test.go +++ b/test/local_creds_test.go @@ -49,11 +49,11 @@ func testE2ESucceed(network, address string) error { switch network { case "unix": if secLevel != credentials.PrivacyAndIntegrity { - return nil, status.Errorf(codes.Unauthenticated, fmt.Sprintf("Wrong security level: got %q, want %q", secLevel, credentials.PrivacyAndIntegrity)) + return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel, credentials.PrivacyAndIntegrity) } case "tcp": if secLevel != credentials.NoSecurity { - return nil, status.Errorf(codes.Unauthenticated, fmt.Sprintf("Wrong security level: got %q, want %q", secLevel, credentials.NoSecurity)) + return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel, credentials.NoSecurity) } } return &testpb.Empty{}, nil @@ -66,7 +66,6 @@ func testE2ESucceed(network, address string) error { testpb.RegisterTestServiceServer(s, ss) - var err error lis, err := net.Listen(network, address) if err != nil { return fmt.Errorf("Failed to create listener: %v", err) @@ -75,7 +74,7 @@ func testE2ESucceed(network, address string) error { go s.Serve(lis) var cc *grpc.ClientConn - lisAddr := address + lisAddr := lis.Addr().String() switch network { case "unix": @@ -84,12 +83,6 @@ func testE2ESucceed(network, address string) error { return net.Dial("unix", addr) })) case "tcp": - var port string - _, port, err = net.SplitHostPort(lis.Addr().String()) - if err != nil { - return fmt.Errorf("Failed to parse listener address: %v", err) - } - lisAddr = "localhost:" + port cc, err = grpc.Dial(lisAddr, grpc.WithTransportCredentials(local.NewCredentials())) default: return fmt.Errorf("unsupported network %q", network) @@ -218,8 +211,7 @@ func (s) TestClientFail(t *testing.T) { func (s) TestServerFail(t *testing.T) { // Use insecure at client-side which should lead to server-side failure. opts := []grpc.DialOption{grpc.WithInsecure()} - want := status.Error(codes.Unavailable, "") - if err := testE2EFail(opts); !isExpected(err, want) { - t.Fatalf("testE2EFail() = %v; want %v", err, want) + if err := testE2EFail(opts); status.Code(err) != codes.Unavailable { + t.Fatalf("testE2EFail() = %v; want %v", err, codes.Unavailable) } } From 2f261a37a5a4274ac63c5c4af776143b93078a1b Mon Sep 17 00:00:00 2001 From: yihuaz Date: Tue, 19 May 2020 12:56:12 -0700 Subject: [PATCH 9/9] give a more meaningfule name to functions --- test/local_creds_test.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/test/local_creds_test.go b/test/local_creds_test.go index 8effea99162a..b55b73bdcbce 100644 --- a/test/local_creds_test.go +++ b/test/local_creds_test.go @@ -36,7 +36,7 @@ import ( testpb "google.golang.org/grpc/test/grpc_testing" ) -func testE2ESucceed(network, address string) error { +func testLocalCredsE2ESucceed(network, address string) error { ss := &stubServer{ emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { pr, ok := peer.FromContext(ctx) @@ -102,15 +102,15 @@ func testE2ESucceed(network, address string) error { return nil } -func (s) TestLocalhost(t *testing.T) { - if err := testE2ESucceed("tcp", "localhost:0"); err != nil { +func (s) TestLocalCredsLocalhost(t *testing.T) { + if err := testLocalCredsE2ESucceed("tcp", "localhost:0"); err != nil { t.Fatalf("Failed e2e test for localhost: %v", err) } } -func (s) TestUDS(t *testing.T) { +func (s) TestLocalCredsUDS(t *testing.T) { addr := fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano()) - if err := testE2ESucceed("unix", addr); err != nil { + if err := testLocalCredsE2ESucceed("unix", addr); err != nil { t.Fatalf("Failed e2e test for UDS: %v", err) } } @@ -129,7 +129,7 @@ type lisWrapper struct { remote net.Addr } -func newLisWrapper(l net.Listener, remote net.Addr) net.Listener { +func spoofListener(l net.Listener, remote net.Addr) net.Listener { return &lisWrapper{l, remote} } @@ -151,7 +151,7 @@ func spoofDialer(addr net.Addr) func(target string, t time.Duration) (net.Conn, } } -func testE2EFail(dopts []grpc.DialOption) error { +func testLocalCredsE2EFail(dopts []grpc.DialOption) error { ss := &stubServer{ emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) { return &testpb.Empty{}, nil @@ -179,7 +179,7 @@ func testE2EFail(dopts []grpc.DialOption) error { Zone: "", } - go s.Serve(newLisWrapper(lis, fakeClientAddr)) + go s.Serve(spoofListener(lis, fakeClientAddr)) cc, err := grpc.Dial(lis.Addr().String(), append(dopts, grpc.WithDialer(spoofDialer(fakeServerAddr)))...) if err != nil { @@ -199,19 +199,19 @@ func isExpected(got, want error) bool { return status.Code(got) == status.Code(want) && strings.Contains(status.Convert(got).Message(), status.Convert(want).Message()) } -func (s) TestClientFail(t *testing.T) { +func (s) TestLocalCredsClientFail(t *testing.T) { // Use local creds at client-side which should lead to client-side failure. opts := []grpc.DialOption{grpc.WithTransportCredentials(local.NewCredentials())} want := status.Error(codes.Unavailable, "transport: authentication handshake failed: local credentials rejected connection to non-local address") - if err := testE2EFail(opts); !isExpected(err, want) { - t.Fatalf("testE2EFail() = %v; want %v", err, want) + if err := testLocalCredsE2EFail(opts); !isExpected(err, want) { + t.Fatalf("testLocalCredsE2EFail() = %v; want %v", err, want) } } -func (s) TestServerFail(t *testing.T) { +func (s) TestLocalCredsServerFail(t *testing.T) { // Use insecure at client-side which should lead to server-side failure. opts := []grpc.DialOption{grpc.WithInsecure()} - if err := testE2EFail(opts); status.Code(err) != codes.Unavailable { - t.Fatalf("testE2EFail() = %v; want %v", err, codes.Unavailable) + if err := testLocalCredsE2EFail(opts); status.Code(err) != codes.Unavailable { + t.Fatalf("testLocalCredsE2EFail() = %v; want %v", err, codes.Unavailable) } }