Skip to content

Commit

Permalink
credentials: local creds implementation (grpc#3517)
Browse files Browse the repository at this point in the history
Local credentials should be used in either a UDS and local TCP connection. The former will be associated with the security level PrigvacyAndIntegrity while the latter is associated with NoSecurity. Local credentials should be used instead of WithInsecure for localhost connections.
  • Loading branch information
yihuazhang authored May 20, 2020
1 parent 636b0d8 commit 9eb3e7d
Show file tree
Hide file tree
Showing 6 changed files with 540 additions and 6 deletions.
8 changes: 5 additions & 3 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ type PerRPCCredentials interface {
type SecurityLevel int

const (
// NoSecurity indicates a connection is insecure.
// Invalid indicates an invalid security level.
// The zero SecurityLevel value is invalid for backward compatibility.
NoSecurity SecurityLevel = iota + 1
Invalid SecurityLevel = iota
// NoSecurity indicates a connection is insecure.
NoSecurity
// IntegrityOnly indicates a connection only provides integrity protection.
IntegrityOnly
// PrivacyAndIntegrity indicates a connection provides both privacy and integrity protection.
Expand Down Expand Up @@ -237,7 +239,7 @@ func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error {
}
if ci, ok := ri.AuthInfo.(internalInfo); ok {
// CommonAuthInfo.SecurityLevel has an invalid value.
if ci.GetCommonAuthInfo().SecurityLevel == 0 {
if ci.GetCommonAuthInfo().SecurityLevel == Invalid {
return nil
}
if ci.GetCommonAuthInfo().SecurityLevel < level {
Expand Down
4 changes: 2 additions & 2 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ func (s) TestCheckSecurityLevel(t *testing.T) {
want: true,
},
{
authLevel: 0,
authLevel: Invalid,
testLevel: IntegrityOnly,
want: true,
},
{
authLevel: 0,
authLevel: Invalid,
testLevel: PrivacyAndIntegrity,
want: true,
},
Expand Down
109 changes: 109 additions & 0 deletions credentials/local/local.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

// Package local implements local transport credentials.
// Local credentials reports the security level based on the type
// of connetion. If the connection is local TCP, NoSecurity will be
// reported, and if the connection is UDS, PrivacyAndIntegrity will be
// reported. If local credentials is not used in local connections
// (local TCP or UDS), it will fail.
//
// This package is EXPERIMENTAL.
package local

import (
"context"
"fmt"
"net"
"strings"

"google.golang.org/grpc/credentials"
)

// Info contains the auth information for a local connection.
// It implements the AuthInfo interface.
type Info struct {
credentials.CommonAuthInfo
}

// AuthType returns the type of Info as a string.
func (Info) AuthType() string {
return "local"
}

// localTC is the credentials required to establish a local connection.
type localTC struct {
info credentials.ProtocolInfo
}

func (c *localTC) Info() credentials.ProtocolInfo {
return c.info
}

// getSecurityLevel returns the security level for a local connection.
// It returns an error if a connection is not local.
func getSecurityLevel(network, addr string) (credentials.SecurityLevel, error) {
switch {
// Local TCP connection
case strings.HasPrefix(addr, "127."), strings.HasPrefix(addr, "[::1]:"):
return credentials.NoSecurity, nil
// UDS connection
case network == "unix":
return credentials.PrivacyAndIntegrity, nil
// Not a local connection and should fail
default:
return credentials.Invalid, fmt.Errorf("local credentials rejected connection to non-local address %q", addr)
}
}

func (*localTC) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
if err != nil {
return nil, nil, err
}
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
}

func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
if err != nil {
return nil, nil, err
}
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
}

// NewCredentials returns a local credential implementing credentials.TransportCredentials.
func NewCredentials() credentials.TransportCredentials {
return &localTC{
info: credentials.ProtocolInfo{
SecurityProtocol: "local",
},
}
}

// Clone makes a copy of Local credentials.
func (c *localTC) Clone() credentials.TransportCredentials {
return &localTC{info: c.info}
}

// OverrideServerName overrides the server name used to verify the hostname on the returned certificates from the server.
// Since this feature is specific to TLS (SNI + hostname verification check), it does not take any effet for local credentials.
func (c *localTC) OverrideServerName(serverNameOverride string) error {
c.info.ServerName = serverNameOverride
return nil
}
204 changes: 204 additions & 0 deletions credentials/local/local_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package local

import (
"context"
"fmt"
"net"
"runtime"
"strings"
"testing"
"time"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/grpctest"
)

type s struct {
grpctest.Tester
}

func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}

func (s) TestGetSecurityLevel(t *testing.T) {
testCases := []struct {
testNetwork string
testAddr string
want credentials.SecurityLevel
}{
{
testNetwork: "tcp",
testAddr: "127.0.0.1:10000",
want: credentials.NoSecurity,
},
{
testNetwork: "tcp",
testAddr: "[::1]:10000",
want: credentials.NoSecurity,
},
{
testNetwork: "unix",
testAddr: "/tmp/grpc_fullstack_test",
want: credentials.PrivacyAndIntegrity,
},
{
testNetwork: "tcp",
testAddr: "192.168.0.1:10000",
want: credentials.Invalid,
},
}
for _, tc := range testCases {
got, _ := getSecurityLevel(tc.testNetwork, tc.testAddr)
if got != tc.want {
t.Fatalf("GetSeurityLevel(%s, %s) returned %s but want %s", tc.testNetwork, tc.testAddr, got.String(), tc.want.String())
}
}
}

type serverHandshake func(net.Conn) (credentials.AuthInfo, error)

// Server local handshake implementation.
func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) {
cred := NewCredentials()
_, authInfo, err := cred.ServerHandshake(conn)
if err != nil {
return nil, err
}
return authInfo, nil
}

// Client local handshake implementation.
func clientLocalHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) {
cred := NewCredentials()
_, authInfo, err := cred.ClientHandshake(context.Background(), lisAddr, conn)
if err != nil {
return nil, err
}
return authInfo, nil
}

// Client connects to a server with local credentials.
func clientHandle(hs func(net.Conn, string) (credentials.AuthInfo, error), network, lisAddr string) (credentials.AuthInfo, error) {
conn, _ := net.Dial(network, lisAddr)
defer conn.Close()
clientAuthInfo, err := hs(conn, lisAddr)
if err != nil {
return nil, fmt.Errorf("Error on client while handshake")
}
return clientAuthInfo, nil
}

type testServerHandleResult struct {
authInfo credentials.AuthInfo
err error
}

// Server accepts a client's connection with local credentials.
func serverHandle(hs serverHandshake, done chan testServerHandleResult, lis net.Listener) {
serverRawConn, err := lis.Accept()
if err != nil {
done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed to accept connection. Error: %v", err)}
}
serverAuthInfo, err := hs(serverRawConn)
if err != nil {
serverRawConn.Close()
done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed while handshake. Error: %v", err)}
}
done <- testServerHandleResult{authInfo: serverAuthInfo, err: nil}
}

func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, error) {
done := make(chan testServerHandleResult, 1)
const timeout = 5 * time.Second
timer := time.NewTimer(timeout)
defer timer.Stop()
go serverHandle(serverLocalHandshake, done, lis)
defer lis.Close()
clientAuthInfo, err := clientHandle(clientLocalHandshake, lis.Addr().Network(), lis.Addr().String())
if err != nil {
return credentials.Invalid, fmt.Errorf("Error at client-side: %v", err)
}
select {
case <-timer.C:
return credentials.Invalid, fmt.Errorf("Test didn't finish in time")
case serverHandleResult := <-done:
if serverHandleResult.err != nil {
return credentials.Invalid, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
}
clientLocal, _ := clientAuthInfo.(Info)
serverLocal, _ := serverHandleResult.authInfo.(Info)
clientSecLevel := clientLocal.CommonAuthInfo.SecurityLevel
serverSecLevel := serverLocal.CommonAuthInfo.SecurityLevel
if clientSecLevel != serverSecLevel {
return credentials.Invalid, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
}
return clientSecLevel, nil
}
}

func (s) TestServerAndClientHandshake(t *testing.T) {
testCases := []struct {
testNetwork string
testAddr string
want credentials.SecurityLevel
}{
{
testNetwork: "tcp",
testAddr: "127.0.0.1:10000",
want: credentials.NoSecurity,
},
{
testNetwork: "tcp",
testAddr: "[::1]:10000",
want: credentials.NoSecurity,
},
{
testNetwork: "tcp",
testAddr: "localhost:10000",
want: credentials.NoSecurity,
},
{
testNetwork: "unix",
testAddr: fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano()),
want: credentials.PrivacyAndIntegrity,
},
}
for _, tc := range testCases {
if runtime.GOOS == "windows" && tc.testNetwork == "unix" {
t.Skip("skipping tests for unix connections on Windows")
}
t.Run("serverAndClientHandshakeResult", func(t *testing.T) {
lis, err := net.Listen(tc.testNetwork, tc.testAddr)
if err != nil {
if strings.Contains(err.Error(), "bind: cannot assign requested address") ||
strings.Contains(err.Error(), "socket: address family not supported by protocol") {
t.Skipf("no support for address %v", tc.testAddr)
}
t.Fatalf("Failed to listen: %v", err)
}
got, err := serverAndClientHandshake(lis)
if got != tc.want {
t.Fatalf("serverAndClientHandshake(%s, %s) = %v, %v; want %v, nil", tc.testNetwork, tc.testAddr, got, err, tc.want)
}
})
}
}
4 changes: 3 additions & 1 deletion internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,14 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
// address specific arbitrary data to reach the credential handshaker.
contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
connectCtx = contextWithHandshakeInfo(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
scheme = "https"
conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn)
if err != nil {
return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
}
isSecure = true
if transportCreds.Info().SecurityProtocol == "tls" {
scheme = "https"
}
}
dynamicWindow := true
icwz := int32(initialWindowSize)
Expand Down
Loading

0 comments on commit 9eb3e7d

Please sign in to comment.