Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

credentials: local creds implementation #3517

Merged
merged 9 commits into from
May 20, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions credentials/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ type PerRPCCredentials interface {
type SecurityLevel int

const (
// NoSecurity indicates a connection is insecure.
// Invalid indicates an invalid security level.
// The zero SecurityLevel value is invalid for backward compatibility.
NoSecurity SecurityLevel = iota + 1
Invalid SecurityLevel = iota
// NoSecurity indicates a connection is insecure.
NoSecurity
// IntegrityOnly indicates a connection only provides integrity protection.
IntegrityOnly
// PrivacyAndIntegrity indicates a connection provides both privacy and integrity protection.
Expand Down Expand Up @@ -237,7 +239,7 @@ func CheckSecurityLevel(ctx context.Context, level SecurityLevel) error {
}
if ci, ok := ri.AuthInfo.(internalInfo); ok {
// CommonAuthInfo.SecurityLevel has an invalid value.
if ci.GetCommonAuthInfo().SecurityLevel == 0 {
if ci.GetCommonAuthInfo().SecurityLevel == Invalid {
return nil
}
if ci.GetCommonAuthInfo().SecurityLevel < level {
Expand Down
4 changes: 2 additions & 2 deletions credentials/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ func (s) TestCheckSecurityLevel(t *testing.T) {
want: true,
},
{
authLevel: 0,
authLevel: Invalid,
testLevel: IntegrityOnly,
want: true,
},
{
authLevel: 0,
authLevel: Invalid,
testLevel: PrivacyAndIntegrity,
want: true,
},
Expand Down
109 changes: 109 additions & 0 deletions credentials/local/local.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

// Package local implements local transport credentials.
// Local credentials reports the security level based on the type
// of connetion. If the connection is local TCP, NoSecurity will be
// reported, and if the connection is UDS, PrivacyAndIntegrity will be
// reported. If local credentials is not used in local connections
// (local TCP or UDS), it will fail.
//
// This package is EXPERIMENTAL.
package local

import (
"context"
"fmt"
"net"
"strings"

"google.golang.org/grpc/credentials"
)

// Info contains the auth information for a local connection.
// It implements the AuthInfo interface.
type Info struct {
credentials.CommonAuthInfo
}

// AuthType returns the type of Info as a string.
func (Info) AuthType() string {
return "local"
}

// localTC is the credentials required to establish a local connection.
type localTC struct {
info credentials.ProtocolInfo
}

func (c *localTC) Info() credentials.ProtocolInfo {
return c.info
}

// getSecurityLevel returns the security level for a local connection.
// It returns an error if a connection is not local.
func getSecurityLevel(network, addr string) (credentials.SecurityLevel, error) {
switch {
// Local TCP connection
case strings.HasPrefix(addr, "127."), strings.HasPrefix(addr, "[::1]:"):
return credentials.NoSecurity, nil
// UDS connection
case network == "unix":
return credentials.PrivacyAndIntegrity, nil
// Not a local connection and should fail
default:
return credentials.Invalid, fmt.Errorf("local credentials rejected connection to non-local address %q", addr)
}
}

func (*localTC) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
if err != nil {
return nil, nil, err
}
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
}

func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
if err != nil {
return nil, nil, err
}
return conn, Info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
}

// NewCredentials returns a local credential implementing credentials.TransportCredentials.
func NewCredentials() credentials.TransportCredentials {
return &localTC{
info: credentials.ProtocolInfo{
SecurityProtocol: "local",
},
}
}

// Clone makes a copy of Local credentials.
func (c *localTC) Clone() credentials.TransportCredentials {
return &localTC{info: c.info}
}

// OverrideServerName overrides the server name used to verify the hostname on the returned certificates from the server.
// Since this feature is specific to TLS (SNI + hostname verification check), it does not take any effet for local credentials.
func (c *localTC) OverrideServerName(serverNameOverride string) error {
c.info.ServerName = serverNameOverride
return nil
}
204 changes: 204 additions & 0 deletions credentials/local/local_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
*
* Copyright 2020 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package local

import (
"context"
"fmt"
"net"
"runtime"
"strings"
"testing"
"time"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal/grpctest"
)

type s struct {
grpctest.Tester
}

func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}

func (s) TestGetSecurityLevel(t *testing.T) {
testCases := []struct {
testNetwork string
testAddr string
want credentials.SecurityLevel
}{
{
testNetwork: "tcp",
testAddr: "127.0.0.1:10000",
want: credentials.NoSecurity,
},
{
testNetwork: "tcp",
testAddr: "[::1]:10000",
want: credentials.NoSecurity,
},
{
testNetwork: "unix",
testAddr: "/tmp/grpc_fullstack_test",
want: credentials.PrivacyAndIntegrity,
},
{
testNetwork: "tcp",
testAddr: "192.168.0.1:10000",
want: credentials.Invalid,
},
}
for _, tc := range testCases {
got, _ := getSecurityLevel(tc.testNetwork, tc.testAddr)
if got != tc.want {
t.Fatalf("GetSeurityLevel(%s, %s) returned %s but want %s", tc.testNetwork, tc.testAddr, got.String(), tc.want.String())
}
}
}

type serverHandshake func(net.Conn) (credentials.AuthInfo, error)

// Server local handshake implementation.
func serverLocalHandshake(conn net.Conn) (credentials.AuthInfo, error) {
cred := NewCredentials()
_, authInfo, err := cred.ServerHandshake(conn)
if err != nil {
return nil, err
}
return authInfo, nil
}

// Client local handshake implementation.
func clientLocalHandshake(conn net.Conn, lisAddr string) (credentials.AuthInfo, error) {
cred := NewCredentials()
_, authInfo, err := cred.ClientHandshake(context.Background(), lisAddr, conn)
if err != nil {
return nil, err
}
return authInfo, nil
}

// Client connects to a server with local credentials.
func clientHandle(hs func(net.Conn, string) (credentials.AuthInfo, error), network, lisAddr string) (credentials.AuthInfo, error) {
conn, _ := net.Dial(network, lisAddr)
defer conn.Close()
clientAuthInfo, err := hs(conn, lisAddr)
if err != nil {
return nil, fmt.Errorf("Error on client while handshake")
}
return clientAuthInfo, nil
}

type testServerHandleResult struct {
authInfo credentials.AuthInfo
err error
}

// Server accepts a client's connection with local credentials.
func serverHandle(hs serverHandshake, done chan testServerHandleResult, lis net.Listener) {
serverRawConn, err := lis.Accept()
if err != nil {
done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed to accept connection. Error: %v", err)}
}
serverAuthInfo, err := hs(serverRawConn)
if err != nil {
serverRawConn.Close()
done <- testServerHandleResult{authInfo: nil, err: fmt.Errorf("Server failed while handshake. Error: %v", err)}
}
done <- testServerHandleResult{authInfo: serverAuthInfo, err: nil}
}

func serverAndClientHandshake(lis net.Listener) (credentials.SecurityLevel, error) {
done := make(chan testServerHandleResult, 1)
const timeout = 5 * time.Second
timer := time.NewTimer(timeout)
defer timer.Stop()
go serverHandle(serverLocalHandshake, done, lis)
defer lis.Close()
clientAuthInfo, err := clientHandle(clientLocalHandshake, lis.Addr().Network(), lis.Addr().String())
if err != nil {
return credentials.Invalid, fmt.Errorf("Error at client-side: %v", err)
}
select {
case <-timer.C:
return credentials.Invalid, fmt.Errorf("Test didn't finish in time")
case serverHandleResult := <-done:
if serverHandleResult.err != nil {
return credentials.Invalid, fmt.Errorf("Error at server-side: %v", serverHandleResult.err)
}
clientLocal, _ := clientAuthInfo.(Info)
serverLocal, _ := serverHandleResult.authInfo.(Info)
clientSecLevel := clientLocal.CommonAuthInfo.SecurityLevel
serverSecLevel := serverLocal.CommonAuthInfo.SecurityLevel
if clientSecLevel != serverSecLevel {
return credentials.Invalid, fmt.Errorf("client's AuthInfo contains %s but server's AuthInfo contains %s", clientSecLevel.String(), serverSecLevel.String())
}
return clientSecLevel, nil
}
}

func (s) TestServerAndClientHandshake(t *testing.T) {
testCases := []struct {
testNetwork string
testAddr string
want credentials.SecurityLevel
}{
{
testNetwork: "tcp",
testAddr: "127.0.0.1:10000",
want: credentials.NoSecurity,
},
{
testNetwork: "tcp",
testAddr: "[::1]:10000",
want: credentials.NoSecurity,
},
{
testNetwork: "tcp",
testAddr: "localhost:10000",
want: credentials.NoSecurity,
},
{
testNetwork: "unix",
testAddr: fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano()),
want: credentials.PrivacyAndIntegrity,
},
}
for _, tc := range testCases {
if runtime.GOOS == "windows" && tc.testNetwork == "unix" {
dfawley marked this conversation as resolved.
Show resolved Hide resolved
t.Skip("skipping tests for unix connections on Windows")
}
t.Run("serverAndClientHandshakeResult", func(t *testing.T) {
lis, err := net.Listen(tc.testNetwork, tc.testAddr)
if err != nil {
if strings.Contains(err.Error(), "bind: cannot assign requested address") ||
strings.Contains(err.Error(), "socket: address family not supported by protocol") {
t.Skipf("no support for address %v", tc.testAddr)
}
t.Fatalf("Failed to listen: %v", err)
}
got, err := serverAndClientHandshake(lis)
if got != tc.want {
t.Fatalf("ServerAndClientHandshake(%s, %s) returned %s but want %s. Error: %v", tc.testNetwork, tc.testAddr, got.String(), tc.want.String(), err)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Go style:

t.Fatalf("serverAndClientHandshake(%s, %s) = %v, %v; want %v, nil", tc.testNetwork, tc.testAddr, got, err, want)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}
})
}
}
4 changes: 3 additions & 1 deletion internal/transport/http2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,14 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
// address specific arbitrary data to reach the credential handshaker.
contextWithHandshakeInfo := internal.NewClientHandshakeInfoContext.(func(context.Context, credentials.ClientHandshakeInfo) context.Context)
connectCtx = contextWithHandshakeInfo(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
scheme = "https"
conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn)
if err != nil {
return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
}
isSecure = true
if transportCreds.Info().SecurityProtocol == "tls" {
scheme = "https"
}
}
dynamicWindow := true
icwz := int32(initialWindowSize)
Expand Down
Loading