diff --git a/credentials/local/local.go b/credentials/local/local.go new file mode 100644 index 000000000000..ad559293d7ea --- /dev/null +++ b/credentials/local/local.go @@ -0,0 +1,104 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package local includes implementation of local credentials. +// The local credential should be used in either a local TCP or UDS connection. +// This package is experimental. +package local + +import ( + "context" + "fmt" + "net" + "strings" + + "google.golang.org/grpc/credentials" +) + +// Info contains the auth information for a local connection. +// It implements the AuthInfo interface. +type Info struct { + credentials.CommonAuthInfo +} + +// AuthType returns the type of Info as a string. +func (t Info) AuthType() string { + return "local" +} + +// localTC is the credentials required to eatablish a local connection. +type localTC struct { + info *credentials.ProtocolInfo +} + +func (c *localTC) Info() credentials.ProtocolInfo { + return *c.info +} + +// getSecurityLevel returns the security level for a local connection. +// It returns an error if a connection is not local. +func getSecurityLevel(network, addr string) (credentials.SecurityLevel, error) { + switch { + // Local TCP connection + case strings.HasPrefix(addr, "127."), strings.HasPrefix(addr, "localhost:"), strings.HasPrefix(addr, "[::1]:"): + return credentials.NoSecurity, nil + // UDS connection + case network == "unix": + return credentials.PrivacyAndIntegrity, nil + // Not a local connection and should fail + default: + return credentials.Invalid, fmt.Errorf("credentials: local credentials should be used in a local connection") + } +} + +func (c *localTC) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String()) + if err != nil { + return nil, nil, err + } + return conn, Info{credentials.CommonAuthInfo{secLevel}}, nil +} + +func (c *localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) { + secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String()) + if err != nil { + return nil, nil, err + } + return conn, Info{credentials.CommonAuthInfo{secLevel}}, nil +} + +// NewCredentials returns a local credential implementing TransportCredentials interface. +func NewCredentials() credentials.TransportCredentials { + return &localTC{ + info: &credentials.ProtocolInfo{ + SecurityProtocol: "local", + }, + } +} + +// Clone makes a copy of Local credentials. +func (c *localTC) Clone() credentials.TransportCredentials { + return &localTC{info: &(*c.info)} +} + +// OverrideServerName overrides the server name used to verify the hostname on the returned certificates from the server. +// Since this feature is specific to TLS (SNI + hostname verification check), it does not take any effet for local credentials. +func (c *localTC) OverrideServerName(serverNameOverride string) error { + c.info.ServerName = serverNameOverride + return nil +} diff --git a/credentials/local/local_test.go b/credentials/local/local_test.go new file mode 100644 index 000000000000..f04fae89ea5a --- /dev/null +++ b/credentials/local/local_test.go @@ -0,0 +1,209 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package local + +import ( + "context" + "net" + "runtime" + "strings" + "testing" + "time" + + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpctest" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +func (s) TestGetSecurityLevel(t *testing.T) { + testCases := []struct { + testNetwork string + testAddr string + want credentials.SecurityLevel + }{ + { + testNetwork: "tcp", + testAddr: "127.0.0.1:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "tcp", + testAddr: "[::1]:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "tcp", + testAddr: "localhost:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "unix", + testAddr: "/tmp/grpc_fullstack_test", + want: credentials.PrivacyAndIntegrity, + }, + { + testNetwork: "tcp", + testAddr: "192.168.0.1:10000", + want: credentials.Invalid, + }, + } + for _, tc := range testCases { + got, _ := getSecurityLevel(tc.testNetwork, tc.testAddr) + if got != tc.want { + t.Fatalf("GetSeurityLevel(%s, %s) returned %s but want %s", tc.testNetwork, tc.testAddr, got.String(), tc.want.String()) + } + } +} + +type serverHandshake func(net.Conn) (credentials.AuthInfo, error) + +// Server local handshake implementation. +func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) { + cred := NewCredentials() + _, authInfo, err := cred.ServerHandshake(conn) + if err != nil { + return nil, err + } + return authInfo, nil +} + +// Client local handshake implementation. +func clientLocalHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) { + cred := NewCredentials() + _, authInfo, err := cred.ClientHandshake(context.Background(), lisAddr, conn) + if err != nil { + return nil, err + } + return authInfo, nil +} + +// Client connects to a server with local credentials. +func clientHandle(t *testing.T, hs func(net.Conn, string) (credentials.AuthInfo, error), network, lisAddr string) (credentials.AuthInfo, error) { + conn, _ := net.Dial(network, lisAddr) + defer conn.Close() + clientAuthInfo, err := hs(conn, lisAddr) + if err != nil { + t.Fatal("Error on client while handshake.") + return nil, err + } + return clientAuthInfo, nil +} + +type testServerHandleResult struct { + authInfo credentials.AuthInfo + err error +} + +// Server accepts a client's connection with local credentials. +func serverHandle(t *testing.T, hs serverHandshake, done chan testServerHandleResult, lis net.Listener) { + serverRawConn, err := lis.Accept() + serverAuthInfo, err := hs(serverRawConn) + if err != nil { + t.Errorf("Server failed while handshake. Error: %v", err) + serverRawConn.Close() + done <- testServerHandleResult{authInfo: nil, err: err} + } + done <- testServerHandleResult{authInfo: serverAuthInfo, err: nil} +} + +func ServerAndClientHandshake(t *testing.T, network, listAddr string) credentials.SecurityLevel { + done := make(chan testServerHandleResult, 1) + timeout := time.After(100 * time.Second) + lis, err := net.Listen(network, listAddr) + if err != nil { + if strings.Contains(err.Error(), "bind: cannot assign requested address") || + strings.Contains(err.Error(), "socket: address family not supported by protocol") { + t.Skipf("no support for address %v", listAddr) + } + t.Fatalf("Failed to listen: %v", err) + return credentials.Invalid + } + go serverHandle(t, serverLocalHandshake, done, lis) + defer lis.Close() + clientAuthInfo, err := clientHandle(t, clientLocalHandshake, network, lis.Addr().String()) + if err != nil { + t.Fatalf("Error at client-side: Error: %v", err) + return credentials.Invalid + } + select { + case <-timeout: + t.Fatal("Test didn't finish in time") + return credentials.Invalid + case serverHandleResult := <-done: + if serverHandleResult.err != nil { + t.Fatalf("Error at server-side: Error: %v", serverHandleResult.err) + return credentials.Invalid + } + clientLocal, _ := clientAuthInfo.(Info) + serverLocal, _ := serverHandleResult.authInfo.(Info) + clientSecLevel := clientLocal.CommonAuthInfo.SecurityLevel + serverSecLevel := serverLocal.CommonAuthInfo.SecurityLevel + if clientSecLevel != serverSecLevel { + t.Fatalf("client's AuthInfo contains %s but server's AuthInfo contains %s.", clientSecLevel.String(), serverSecLevel.String()) + return credentials.Invalid + } + return clientSecLevel + } + return credentials.Invalid +} + +func (s) TestServerAndClientHandshake(t *testing.T) { + testCases := []struct { + testNetwork string + testAddr string + want credentials.SecurityLevel + }{ + { + testNetwork: "tcp", + testAddr: "127.0.0.1:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "tcp", + testAddr: "[::1]:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "tcp", + testAddr: "localhost:10000", + want: credentials.NoSecurity, + }, + { + testNetwork: "unix", + testAddr: "/tmp/grpc_fullstack_test", + want: credentials.PrivacyAndIntegrity, + }, + } + for _, tc := range testCases { + if runtime.GOOS == "windows" && tc.testNetwork == "unix" { + continue + } + got := ServerAndClientHandshake(t, tc.testNetwork, tc.testAddr) + if got != tc.want { + t.Fatalf("ServerAndClientHandshake(%s, %s) returned %s but want %s", tc.testNetwork, tc.testAddr, got.String(), tc.want.String()) + } + } +}