diff --git a/client/auth.go b/client/auth.go index 7e422c6e2..269d80d84 100644 --- a/client/auth.go +++ b/client/auth.go @@ -39,6 +39,11 @@ func NewAuthInterceptor(apiKey, token string) *AuthInterceptor { } } +// SetToken sets the token. +func (i *AuthInterceptor) SetToken(token string) { + i.token = token +} + // WrapUnary creates a unary server interceptor for authorization. func (i *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func( diff --git a/client/client.go b/client/client.go index 3d8e4a466..282a19858 100644 --- a/client/client.go +++ b/client/client.go @@ -99,6 +99,7 @@ type Client struct { client v1connect.YorkieServiceClient options Options clientOptions []connect.ClientOption + interceptor *AuthInterceptor logger *zap.Logger id *time.ActorID @@ -149,8 +150,8 @@ func New(opts ...Option) (*Client, error) { } var clientOptions []connect.ClientOption - - clientOptions = append(clientOptions, connect.WithInterceptors(NewAuthInterceptor(options.APIKey, options.Token))) + interceptor := NewAuthInterceptor(options.APIKey, options.Token) + clientOptions = append(clientOptions, connect.WithInterceptors(interceptor)) if options.MaxCallRecvMsgSize != 0 { clientOptions = append(clientOptions, connect.WithReadMaxBytes(options.MaxCallRecvMsgSize)) } @@ -169,6 +170,7 @@ func New(opts ...Option) (*Client, error) { clientOptions: clientOptions, options: options, logger: logger, + interceptor: interceptor, key: k, status: deactivated, @@ -183,7 +185,6 @@ func Dial(rpcAddr string, opts ...Option) (*Client, error) { return nil, err } - cli.options.RPCAddress = rpcAddr if err := cli.Dial(rpcAddr); err != nil { return nil, err } @@ -206,19 +207,9 @@ func (c *Client) Dial(rpcAddr string) error { return nil } -// SetToken updates the client's token for reauthentication purposes. -func (c *Client) SetToken(token string) error { - newClientOptions := []connect.ClientOption{ - connect.WithInterceptors(NewAuthInterceptor(c.options.APIKey, token)), - } - if c.options.MaxCallRecvMsgSize != 0 { - newClientOptions = append(newClientOptions, - connect.WithReadMaxBytes(c.options.MaxCallRecvMsgSize)) - } - c.clientOptions = newClientOptions - - c.conn.CloseIdleConnections() - return c.Dial(c.options.RPCAddress) +// SetToken sets the given token of this client. +func (c *Client) SetToken(token string) { + c.interceptor.SetToken(token) } // Close closes all resources of this client. diff --git a/client/options.go b/client/options.go index 03799cd8f..fcb03b6a9 100644 --- a/client/options.go +++ b/client/options.go @@ -41,9 +41,6 @@ type Options struct { // CertFile is the path to the certificate file. CertFile string - // RPCAddress is the address of the RPC server. - RPCAddress string - // ServerNameOverride is the server name override. ServerNameOverride string diff --git a/test/integration/auth_webhook_test.go b/test/integration/auth_webhook_test.go index 4dc405812..b8f9692b4 100644 --- a/test/integration/auth_webhook_test.go +++ b/test/integration/auth_webhook_test.go @@ -604,7 +604,7 @@ func TestAuthWebhookCache(t *testing.T) { } func TestAuthWebhookNewToken(t *testing.T) { - t.Run("set new token when receiving invalid token test", func(t *testing.T) { + t.Run("set valid token after invalid token test", func(t *testing.T) { ctx := context.Background() authServer, validToken := newAuthServer(t) @@ -636,15 +636,12 @@ func TestAuthWebhookNewToken(t *testing.T) { defer func() { assert.NoError(t, cli.Close()) }() err = cli.Activate(ctx) - // reactivate with new token - if err != nil { - metadata := converter.ErrorMetadataOf(err) - if metadata["reason"] == "invalid token" { - err = cli.SetToken(validToken) - assert.NoError(t, err) - err = cli.Activate(ctx) - assert.NoError(t, err) - } - } + assert.Equal(t, connect.CodeUnauthenticated, connect.CodeOf(err)) + + // activate again with valid token + metadata := converter.ErrorMetadataOf(err) + assert.Equal(t, "invalid token", metadata["reason"]) + cli.SetToken(validToken) + assert.NoError(t, cli.Activate(ctx)) }) }