diff --git a/credentials/credentials.go b/credentials/credentials.go index 53addd8c71e9..02766443ae74 100644 --- a/credentials/credentials.go +++ b/credentials/credentials.go @@ -58,9 +58,11 @@ type PerRPCCredentials interface { type SecurityLevel int const ( - // NoSecurity indicates a connection is insecure. + // Invalid indicates an invalid security level. // The zero SecurityLevel value is invalid for backward compatibility. - NoSecurity SecurityLevel = iota + 1 + Invalid SecurityLevel = iota + // NoSecurity indicates a connection is insecure. + NoSecurity // IntegrityOnly indicates a connection only provides integrity protection. IntegrityOnly // PrivacyAndIntegrity indicates a connection provides both privacy and integrity protection. @@ -237,7 +239,7 @@ func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error { } if ci, ok := ri.AuthInfo.(internalInfo); ok { // CommonAuthInfo.SecurityLevel has an invalid value. - if ci.GetCommonAuthInfo().SecurityLevel == 0 { + if ci.GetCommonAuthInfo().SecurityLevel == Invalid { return nil } if ci.GetCommonAuthInfo().SecurityLevel < level { diff --git a/credentials/credentials_test.go b/credentials/credentials_test.go index c2a316281928..5ff4850453bd 100644 --- a/credentials/credentials_test.go +++ b/credentials/credentials_test.go @@ -86,12 +86,12 @@ func (s) TestCheckSecurityLevel(t *testing.T) { want: true, }, { - authLevel: 0, + authLevel: Invalid, testLevel: IntegrityOnly, want: true, }, { - authLevel: 0, + authLevel: Invalid, testLevel: PrivacyAndIntegrity, want: true, }, diff --git a/credentials/local/local.go b/credentials/local/local.go new file mode 100644 index 000000000000..23de34cf8a18 --- /dev/null +++ b/credentials/local/local.go @@ -0,0 +1,109 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package local implements local transport credentials. +// Local credentials reports the security level based on the type +// of connetion. If the connection is local TCP, NoSecurity will be +// reported, and if the connection is UDS, PrivacyAndIntegrity will be +// reported. If local credentials is not used in local connections +// (local TCP or UDS), it will fail. +// +// This package is EXPERIMENTAL. +package local + +import ( + "context" + "fmt" + "net" + "strings" + + "google.golang.org/grpc/credentials" +) + +// Info contains the auth information for a local connection. +// It implements the AuthInfo interface. +type Info struct { + credentials.CommonAuthInfo +} + +// AuthType returns the type of Info as a string. +func (Info) AuthType() string { + return "local" +} + +// localTC is the credentials required to establish a local connection. +type localTC struct { + info credentials.ProtocolInfo +} + +func (c *localTC) Info() credentials.ProtocolInfo { + return c.info +} + +// getSecurityLevel returns the security level for a local connection. +// It returns an error if a connection is not local. +func getSecurityLevel(network, addr string) (credentials.SecurityLevel, error) { + switch { + // Local TCP connection + case strings.HasPrefix(addr, "127."), strings.HasPrefix(addr, "[::1]:"): + return credentials.NoSecurity, nil + // UDS connection + case network == "unix": + return credentials.PrivacyAndIntegrity, nil + // Not a local connection and should fail + default: + return credentials.Invalid, fmt.Errorf("local credentials rejected connection to non-local address %q", addr) + } +} + +func (*localTC) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String()) + if err != nil { + return nil, nil, err + } + return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil +} + +func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String()) + if err != nil { + return nil, nil, err + } + return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil +} + +// NewCredentials returns a local credential implementing credentials.TransportCredentials. +func NewCredentials() credentials.TransportCredentials { + return &localTC{ + info: credentials.ProtocolInfo{ + SecurityProtocol: "local", + }, + } +} + +// Clone makes a copy of Local credentials. +func (c *localTC) Clone() credentials.TransportCredentials { + return &localTC{info: c.info} +} + +// OverrideServerName overrides the server name used to verify the hostname on the returned certificates from the server. +// Since this feature is specific to TLS (SNI + hostname verification check), it does not take any effet for local credentials. +func (c *localTC) OverrideServerName(serverNameOverride string) error { + c.info.ServerName = serverNameOverride + return nil +} diff --git a/credentials/local/local_test.go b/credentials/local/local_test.go new file mode 100644 index 000000000000..8317ec3e8ed6 --- /dev/null +++ b/credentials/local/local_test.go @@ -0,0 +1,204 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package local + +import ( + "context" + "fmt" + "net" + "runtime" + "strings" + "testing" + "time" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpctest" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestGetSecurityLevel(t *testing.T) { + testCases := []struct { + testNetwork string + testAddr string + want credentials.SecurityLevel + }{ + { + testNetwork: "tcp", + testAddr: "127.0.0.1:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "tcp", + testAddr: "[::1]:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "unix", + testAddr: "/tmp/grpc_fullstack_test", + want: credentials.PrivacyAndIntegrity, + }, + { + testNetwork: "tcp", + testAddr: "192.168.0.1:10000", + want: credentials.Invalid, + }, + } + for _, tc := range testCases { + got, _ := getSecurityLevel(tc.testNetwork, tc.testAddr) + if got != tc.want { + t.Fatalf("GetSeurityLevel(%s, %s) returned %s but want %s", tc.testNetwork, tc.testAddr, got.String(), tc.want.String()) + } + } +} + +type serverHandshake func(net.Conn) (credentials.AuthInfo, error) + +// Server local handshake implementation. +func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) { + cred := NewCredentials() + _, authInfo, err := cred.ServerHandshake(conn) + if err != nil { + return nil, err + } + return authInfo, nil +} + +// Client local handshake implementation. +func clientLocalHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) { + cred := NewCredentials() + _, authInfo, err := cred.ClientHandshake(context.Background(), lisAddr, conn) + if err != nil { + return nil, err + } + return authInfo, nil +} + +// Client connects to a server with local credentials. +func clientHandle(hs func(net.Conn, string) (credentials.AuthInfo, error), network, lisAddr string) (credentials.AuthInfo, error) { + conn, _ := net.Dial(network, lisAddr) + defer conn.Close() + clientAuthInfo, err := hs(conn, lisAddr) + if err != nil { + return nil, fmt.Errorf("Error on client while handshake") + } + return clientAuthInfo, nil +} + +type testServerHandleResult struct { + authInfo credentials.AuthInfo + err error +} + +// Server accepts a client's connection with local credentials. +func serverHandle(hs serverHandshake, done chan testServerHandleResult, lis net.Listener) { + serverRawConn, err := lis.Accept() + if err != nil { + done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed to accept connection. Error: %v", err)} + } + serverAuthInfo, err := hs(serverRawConn) + if err != nil { + serverRawConn.Close() + done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed while handshake. Error: %v", err)} + } + done <- testServerHandleResult{authInfo: serverAuthInfo, err: nil} +} + +func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, error) { + done := make(chan testServerHandleResult, 1) + const timeout = 5 * time.Second + timer := time.NewTimer(timeout) + defer timer.Stop() + go serverHandle(serverLocalHandshake, done, lis) + defer lis.Close() + clientAuthInfo, err := clientHandle(clientLocalHandshake, lis.Addr().Network(), lis.Addr().String()) + if err != nil { + return credentials.Invalid, fmt.Errorf("Error at client-side: %v", err) + } + select { + case <-timer.C: + return credentials.Invalid, fmt.Errorf("Test didn't finish in time") + case serverHandleResult := <-done: + if serverHandleResult.err != nil { + return credentials.Invalid, fmt.Errorf("Error at server-side: %v", serverHandleResult.err) + } + clientLocal, _ := clientAuthInfo.(Info) + serverLocal, _ := serverHandleResult.authInfo.(Info) + clientSecLevel := clientLocal.CommonAuthInfo.SecurityLevel + serverSecLevel := serverLocal.CommonAuthInfo.SecurityLevel + if clientSecLevel != serverSecLevel { + return credentials.Invalid, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String()) + } + return clientSecLevel, nil + } +} + +func (s) TestServerAndClientHandshake(t *testing.T) { + testCases := []struct { + testNetwork string + testAddr string + want credentials.SecurityLevel + }{ + { + testNetwork: "tcp", + testAddr: "127.0.0.1:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "tcp", + testAddr: "[::1]:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "tcp", + testAddr: "localhost:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "unix", + testAddr: fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano()), + want: credentials.PrivacyAndIntegrity, + }, + } + for _, tc := range testCases { + if runtime.GOOS == "windows" && tc.testNetwork == "unix" { + t.Skip("skipping tests for unix connections on Windows") + } + t.Run("serverAndClientHandshakeResult", func(t *testing.T) { + lis, err := net.Listen(tc.testNetwork, tc.testAddr) + if err != nil { + if strings.Contains(err.Error(), "bind: cannot assign requested address") || + strings.Contains(err.Error(), "socket: address family not supported by protocol") { + t.Skipf("no support for address %v", tc.testAddr) + } + t.Fatalf("Failed to listen: %v", err) + } + got, err := serverAndClientHandshake(lis) + if got != tc.want { + t.Fatalf("ServerAndClientHandshake(%s, %s) returned %s but want %s. Error: %v", tc.testNetwork, tc.testAddr, got.String(), tc.want.String(), err) + } + }) + } +} diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index d1eb17e068fe..b4e55d0101ba 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -221,12 +221,14 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts // address specific arbitrary data to reach the credential handshaker. contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context) connectCtx = contextWithHandshakeInfo(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes}) - scheme = "https" conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn) if err != nil { return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err) } isSecure = true + if transportCreds.Info().SecurityProtocol == "tls" { + scheme = "https" + } } dynamicWindow := true icwz := int32(initialWindowSize) diff --git a/test/end2end_test.go b/test/end2end_test.go index f3a60de5a96f..939839869ac9 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -50,6 +50,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/local" "google.golang.org/grpc/encoding" _ "google.golang.org/grpc/encoding/gzip" "google.golang.org/grpc/health" @@ -214,18 +215,43 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (* if s.security != "" { // Check Auth info var authType, serverName string + var secLevel credentials.SecurityLevel switch info := pr.AuthInfo.(type) { case credentials.TLSInfo: authType = info.AuthType() serverName = info.State.ServerName + secLevel = info.CommonAuthInfo.SecurityLevel + case local.Info: + authType = info.AuthType() + secLevel = info.CommonAuthInfo.SecurityLevel default: return nil, status.Error(codes.Unauthenticated, "Unknown AuthInfo type") } if authType != s.security { return nil, status.Errorf(codes.Unauthenticated, "Wrong auth type: got %q, want %q", authType, s.security) } - if serverName != "x.test.youtube.com" { - return nil, status.Errorf(codes.Unauthenticated, "Unknown server name %q", serverName) + + // Check Auth info specific to credentials.TLSInfo + if s.security == "tls" { + if secLevel != credentials.PrivacyAndIntegrity { + return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.PrivacyAndIntegrity.String()) + } + if serverName != "x.test.youtube.com" { + return nil, status.Errorf(codes.Unauthenticated, "Unknown server name %q", serverName) + } + } + // Check Auth info specific to local.Info + if s.security == "local" { + switch pr.Addr.Network() { + case "tcp": + if secLevel != credentials.NoSecurity { + return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.NoSecurity.String()) + } + case "unix": + if secLevel != credentials.PrivacyAndIntegrity { + return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.PrivacyAndIntegrity.String()) + } + } } } // Simulate some service delay. @@ -394,6 +420,7 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ type env struct { name string network string // The type of network such as tcp, unix, etc. + listenerAddr string // The address of listener. security string // The security protocol such as TLS, SSH, etc. httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS balancer string // One of "round_robin", "pick_first", or "". @@ -417,6 +444,7 @@ func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) { var ( tcpClearEnv = env{name: "tcp-clear-v1-balancer", network: "tcp"} tcpTLSEnv = env{name: "tcp-tls-v1-balancer", network: "tcp", security: "tls"} + tcpLocalEnv = env{name: "tcp-local-v1-balancer", network: "tcp", listenerAddr: "[::1]:0", security: "local"} tcpClearRREnv = env{name: "tcp-clear", network: "tcp", balancer: "round_robin"} tcpTLSRREnv = env{name: "tcp-tls", network: "tcp", security: "tls", balancer: "round_robin"} handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true, balancer: "round_robin"} @@ -607,6 +635,9 @@ func (te *test) listenAndServe(ts testpb.TestServiceServer, listen func(network, sopts = append(sopts, grpc.InitialConnWindowSize(te.serverInitialConnWindowSize)) } la := "localhost:0" + if te.e.listenerAddr != "" { + la = te.e.listenerAddr + } switch te.e.network { case "unix": la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now().UnixNano()) @@ -622,6 +653,8 @@ func (te *test) listenAndServe(ts testpb.TestServiceServer, listen func(network, te.t.Fatalf("Failed to generate credentials %v", err) } sopts = append(sopts, grpc.Creds(creds)) + } else if te.e.security == "local" { + sopts = append(sopts, grpc.Creds(local.NewCredentials())) } sopts = append(sopts, te.customServerOptions...) s := grpc.NewServer(sopts...) @@ -798,6 +831,8 @@ func (te *test) configDial(opts ...grpc.DialOption) ([]grpc.DialOption, string) te.t.Fatalf("Failed to load credentials: %v", err) } opts = append(opts, grpc.WithTransportCredentials(creds)) + case "local": + opts = append(opts, grpc.WithTransportCredentials(local.NewCredentials())) case "empty": // Don't add any transport creds option. default: