Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

credentials: local creds implementation #3517

Merged
merged 9 commits into from
May 20, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion credentials/local/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (s) TestServerAndClientHandshake(t *testing.T) {
}
got, err := serverAndClientHandshake(lis)
if got != tc.want {
t.Fatalf("ServerAndClientHandshake(%s, %s) returned %s but want %s. Error: %v", tc.testNetwork, tc.testAddr, got.String(), tc.want.String(), err)
t.Fatalf("serverAndClientHandshake(%s, %s) = %v, %v; want %v, nil", tc.testNetwork, tc.testAddr, got, err, tc.want)
}
})
}
Expand Down
39 changes: 2 additions & 37 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/local"
"google.golang.org/grpc/encoding"
_ "google.golang.org/grpc/encoding/gzip"
"google.golang.org/grpc/health"
Expand Down Expand Up @@ -215,43 +214,18 @@ func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*
if s.security != "" {
// Check Auth info
var authType, serverName string
var secLevel credentials.SecurityLevel
switch info := pr.AuthInfo.(type) {
case credentials.TLSInfo:
authType = info.AuthType()
serverName = info.State.ServerName
secLevel = info.CommonAuthInfo.SecurityLevel
case local.Info:
authType = info.AuthType()
secLevel = info.CommonAuthInfo.SecurityLevel
default:
return nil, status.Error(codes.Unauthenticated, "Unknown AuthInfo type")
}
if authType != s.security {
return nil, status.Errorf(codes.Unauthenticated, "Wrong auth type: got %q, want %q", authType, s.security)
}

// Check Auth info specific to credentials.TLSInfo
if s.security == "tls" {
if secLevel != credentials.PrivacyAndIntegrity {
return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.PrivacyAndIntegrity.String())
}
if serverName != "x.test.youtube.com" {
return nil, status.Errorf(codes.Unauthenticated, "Unknown server name %q", serverName)
}
}
// Check Auth info specific to local.Info
if s.security == "local" {
switch pr.Addr.Network() {
case "tcp":
if secLevel != credentials.NoSecurity {
return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.NoSecurity.String())
}
case "unix":
if secLevel != credentials.PrivacyAndIntegrity {
return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.PrivacyAndIntegrity.String())
}
}
if serverName != "x.test.youtube.com" {
return nil, status.Errorf(codes.Unauthenticated, "Unknown server name %q", serverName)
}
}
// Simulate some service delay.
Expand Down Expand Up @@ -420,7 +394,6 @@ func (s *testServer) HalfDuplexCall(stream testpb.TestService_HalfDuplexCallServ
type env struct {
name string
network string // The type of network such as tcp, unix, etc.
listenerAddr string // The address of listener.
security string // The security protocol such as TLS, SSH, etc.
httpHandler bool // whether to use the http.Handler ServerTransport; requires TLS
balancer string // One of "round_robin", "pick_first", or "".
Expand All @@ -444,7 +417,6 @@ func (e env) dialer(addr string, timeout time.Duration) (net.Conn, error) {
var (
tcpClearEnv = env{name: "tcp-clear-v1-balancer", network: "tcp"}
tcpTLSEnv = env{name: "tcp-tls-v1-balancer", network: "tcp", security: "tls"}
tcpLocalEnv = env{name: "tcp-local-v1-balancer", network: "tcp", listenerAddr: "[::1]:0", security: "local"}
tcpClearRREnv = env{name: "tcp-clear", network: "tcp", balancer: "round_robin"}
tcpTLSRREnv = env{name: "tcp-tls", network: "tcp", security: "tls", balancer: "round_robin"}
handlerEnv = env{name: "handler-tls", network: "tcp", security: "tls", httpHandler: true, balancer: "round_robin"}
Expand Down Expand Up @@ -635,9 +607,6 @@ func (te *test) listenAndServe(ts testpb.TestServiceServer, listen func(network,
sopts = append(sopts, grpc.InitialConnWindowSize(te.serverInitialConnWindowSize))
}
la := "localhost:0"
if te.e.listenerAddr != "" {
la = te.e.listenerAddr
}
switch te.e.network {
case "unix":
la = "/tmp/testsock" + fmt.Sprintf("%d", time.Now().UnixNano())
Expand All @@ -653,8 +622,6 @@ func (te *test) listenAndServe(ts testpb.TestServiceServer, listen func(network,
te.t.Fatalf("Failed to generate credentials %v", err)
}
sopts = append(sopts, grpc.Creds(creds))
} else if te.e.security == "local" {
sopts = append(sopts, grpc.Creds(local.NewCredentials()))
}
sopts = append(sopts, te.customServerOptions...)
s := grpc.NewServer(sopts...)
Expand Down Expand Up @@ -831,8 +798,6 @@ func (te *test) configDial(opts ...grpc.DialOption) ([]grpc.DialOption, string)
te.t.Fatalf("Failed to load credentials: %v", err)
}
opts = append(opts, grpc.WithTransportCredentials(creds))
case "local":
opts = append(opts, grpc.WithTransportCredentials(local.NewCredentials()))
case "empty":
// Don't add any transport creds option.
default:
Expand Down
111 changes: 68 additions & 43 deletions test/local_creds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,34 @@ import (
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/local"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
testpb "google.golang.org/grpc/test/grpc_testing"
)

func testE2ESucceed(network, address string) error {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is in a file called local_creds_test.go, but it's in a huge testing package. Please give this a name that includes LocalCreds. Same for everything global in this file, especially the Test____ functions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

ss := &stubServer{
emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
pr, ok := peer.FromContext(ctx)
if !ok {
return nil, status.Error(codes.DataLoss, "Failed to get peer from ctx")
}
// Check security level
info := pr.AuthInfo.(local.Info)
secLevel := info.CommonAuthInfo.SecurityLevel
switch network {
case "unix":
if secLevel != credentials.PrivacyAndIntegrity {
return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.PrivacyAndIntegrity.String())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am pretty sure fmt.Sprintf (which status.Errorf uses) will call .String() for you in this case (with %q) - won't it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you are right. Thanks for the suggestion. The code is updated.

}
case "tcp":
if secLevel != credentials.NoSecurity {
return nil, status.Errorf(codes.Unauthenticated, "Wrong security level: got %q, want %q", secLevel.String(), credentials.NoSecurity.String())
}
}
return &testpb.Empty{}, nil
},
}
Expand All @@ -52,15 +73,17 @@ func testE2ESucceed(network, address string) error {
go s.Serve(lis)

var cc *grpc.ClientConn
if network == "unix" {
cc, err = grpc.Dial("passthrough:///"+address, grpc.WithTransportCredentials(local.NewCredentials()), grpc.WithContextDialer(
switch network {
case "unix":
cc, err = grpc.Dial(address, grpc.WithTransportCredentials(local.NewCredentials()), grpc.WithContextDialer(
func(ctx context.Context, addr string) (net.Conn, error) {
return net.Dial("unix", address)
}))
} else {
case "tcp":
cc, err = grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(local.NewCredentials()))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this also dial address to be consistent with the above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By dialing address it will dial with port 0, which is incorrect. I updated the test to use lisAddr in both unix and tcp cases.

default:
return fmt.Errorf("unsupported network %q", network)
}

if err != nil {
return fmt.Errorf("Failed to dial server: %v, %v", err, lis.Addr().String())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same: address ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated it to use lisAddr instead.

}
Expand All @@ -73,21 +96,18 @@ func testE2ESucceed(network, address string) error {
if _, err := c.EmptyCall(ctx, &testpb.Empty{}); err != nil {
return fmt.Errorf("EmptyCall(_, _) = _, %v; want _, <nil>", err)
}

return nil
}

func (s) TestLocalhost(t *testing.T) {
err := testE2ESucceed("tcp", "localhost:0")
if err != nil {
if err := testE2ESucceed("tcp", "localhost:0"); err != nil {
t.Fatalf("Failed e2e test for localhost: %v", err)
}
}

func (s) TestUDS(t *testing.T) {
addr := fmt.Sprintf("/tmp/grpc_fullstck_test%d", time.Now().UnixNano())
err := testE2ESucceed("unix", addr)
if err != nil {
if err := testE2ESucceed("unix", addr); err != nil {
t.Fatalf("Failed e2e test for UDS: %v", err)
}
}
Expand All @@ -103,40 +123,32 @@ func (c connWrapper) RemoteAddr() net.Addr {

type lisWrapper struct {
net.Listener
remote net.Addr
}

func newLisWrapper(l net.Listener) net.Listener {
return &lisWrapper{l}
}

var remoteAddrs = []net.Addr{
&net.IPAddr{
IP: net.ParseIP("10.8.9.10"),
Zone: "",
},
&net.IPAddr{
IP: net.ParseIP("10.8.9.11"),
Zone: "",
},
func newLisWrapper(l net.Listener, remote net.Addr) net.Listener {
return &lisWrapper{l, remote}
}

func (l *lisWrapper) Accept() (net.Conn, error) {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return connWrapper{c, remoteAddrs[0]}, nil
return connWrapper{c, l.remote}, nil
}

func dialer(target string, t time.Duration) (net.Conn, error) {
c, err := net.DialTimeout("tcp", target, t)
if err != nil {
return nil, err
func spoofDialer(addr net.Addr) func(target string, t time.Duration) (net.Conn, error) {
return func(t string, d time.Duration) (net.Conn, error) {
c, err := net.DialTimeout("tcp", t, d)
if err != nil {
return nil, err
}
return connWrapper{c, addr}, nil
}
return connWrapper{c, remoteAddrs[1]}, nil
}

func testE2EFail(useLocal bool) error {
func testE2EFail(dopts []grpc.DialOption) error {
ss := &stubServer{
emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
return &testpb.Empty{}, nil
Expand All @@ -154,15 +166,19 @@ func testE2EFail(useLocal bool) error {
return fmt.Errorf("Failed to create listener: %v", err)
}

go s.Serve(newLisWrapper(lis))

var cc *grpc.ClientConn
if useLocal {
cc, err = grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(local.NewCredentials()), grpc.WithDialer(dialer))
} else {
cc, err = grpc.Dial(lis.Addr().String(), grpc.WithInsecure(), grpc.WithDialer(dialer))
var fakeClientAddr, fakeServerAddr net.Addr
fakeClientAddr = &net.IPAddr{
IP: net.ParseIP("10.8.9.10"),
Zone: "",
}
fakeServerAddr = &net.IPAddr{
IP: net.ParseIP("10.8.9.11"),
Zone: "",
}

go s.Serve(newLisWrapper(lis, fakeClientAddr))

cc, err := grpc.Dial(lis.Addr().String(), append(dopts, grpc.WithDialer(spoofDialer(fakeServerAddr)))...)
if err != nil {
return fmt.Errorf("Failed to dial server: %v, %v", err, lis.Addr().String())
}
Expand All @@ -176,18 +192,27 @@ func testE2EFail(useLocal bool) error {
return err
}

func isExpected(got, want error) bool {
if status.Code(got) == status.Code(want) && strings.Contains(status.Convert(got).Message(), status.Convert(want).Message()) {
return true
}
return false
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Simplify: return <condition>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

func (s) TestClientFail(t *testing.T) {
// Use local creds at client-side which should lead to client-side failure.
err := testE2EFail(true /*useLocal*/)
if err == nil || !strings.Contains(err.Error(), "local credentials rejected connection to non-local address") {
t.Fatalf("testE2EFail(%v) = _; want security handshake fails, %v", false, err)
opts := []grpc.DialOption{grpc.WithTransportCredentials(local.NewCredentials())}
want := status.Error(codes.Unavailable, "transport: authentication handshake failed: local credentials rejected connection to non-local address")
if err := testE2EFail(opts); !isExpected(err, want) {
t.Fatalf("testE2EFail() = %v; want %v", err, want)
}
}

func (s) TestServerFail(t *testing.T) {
// Use insecure at client-side which should lead to server-side failure.
err := testE2EFail(false /*useLocal*/)
if err == nil {
t.Fatalf("testE2EFail(%v) = _; want security handshake fails, %v", true, err)
opts := []grpc.DialOption{grpc.WithInsecure()}
want := status.Error(codes.Unavailable, "connection closed")
if err := testE2EFail(opts); !isExpected(err, want) {
t.Fatalf("testE2EFail() = %v; want %v", err, want)
}
}