Skip to content

Commit

Permalink
Fix review: Use AuthType enum values
Browse files Browse the repository at this point in the history
Signed-off-by: Viet Nguyen Duc <nguyenducviet4496@gmail.com>
  • Loading branch information
VietND96 committed Nov 7, 2024
1 parent 69eb1b4 commit 2657eed
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
28 changes: 17 additions & 11 deletions py/selenium/webdriver/remote/client_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import base64
import os
import socket
from enum import Enum
from typing import Optional
from urllib import parse

Expand All @@ -26,6 +27,12 @@
from selenium.webdriver.common.proxy import ProxyType


class AuthType(Enum):
BASIC = "Basic"
BEARER = "Bearer"
X_API_KEY = "X-API-Key"


class ClientConfig:
def __init__(
self,
Expand All @@ -38,7 +45,7 @@ def __init__(
ca_certs: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
auth_type: Optional[str] = "Basic",
auth_type: Optional[AuthType] = AuthType.BASIC,
token: Optional[str] = None,
user_agent: Optional[str] = None,
extra_headers: Optional[dict] = None,
Expand Down Expand Up @@ -202,16 +209,16 @@ def password(self, value: str) -> None:
self._password = value

@property
def auth_type(self) -> str:
def auth_type(self) -> AuthType:
"""Returns the type of authentication to the remote server."""
return self._auth_type

@auth_type.setter
def auth_type(self, value: str) -> None:
def auth_type(self, value: AuthType) -> None:
"""Sets the type of authentication to the remote server if it is not
using basic with username and password.
Support values: Bearer, X-API-Key. For others, please use `extra_headers` instead
:Args: value - AuthType enum value. For others, please use `extra_headers` instead
"""
self._auth_type = value

Expand Down Expand Up @@ -273,13 +280,12 @@ def get_proxy_url(self) -> Optional[str]:

def get_auth_header(self) -> Optional[dict]:
"""Returns the authorization to add to the request headers."""
auth_type = self.auth_type.lower()
if auth_type == "basic" and self.username and self.password:
if self.auth_type is AuthType.BASIC and self.username and self.password:
credentials = f"{self.username}:{self.password}"
encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
return {"Authorization": f"Basic {encoded_credentials}"}
if auth_type == "bearer" and self.token:
return {"Authorization": f"Bearer {self.token}"}
if auth_type == "x-api-key" and self.token:
return {"X-API-Key": f"{self.token}"}
return {"Authorization": f"{AuthType.BASIC.value} {encoded_credentials}"}
if self.auth_type is AuthType.BEARER and self.token:
return {"Authorization": f"{AuthType.BEARER.value} {self.token}"}
if self.auth_type is AuthType.X_API_KEY and self.token:
return {f"{AuthType.X_API_KEY.value}": f"{self.token}"}
return None
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from selenium import __version__
from selenium.webdriver import Proxy
from selenium.webdriver.common.proxy import ProxyType
from selenium.webdriver.remote.client_config import AuthType
from selenium.webdriver.remote.remote_connection import ClientConfig
from selenium.webdriver.remote.remote_connection import RemoteConnection

Expand Down Expand Up @@ -93,7 +94,7 @@ def test_get_proxy_url_http(mock_proxy_settings):

def test_get_auth_header_if_client_config_pass_basic_auth():
custom_config = ClientConfig(
remote_server_addr="http://remote", keep_alive=True, username="user", password="pass", auth_type="Basic"
remote_server_addr="http://remote", keep_alive=True, username="user", password="pass", auth_type=AuthType.BASIC
)
remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config)
headers = remote_connection._client_config.get_auth_header()
Expand All @@ -102,7 +103,7 @@ def test_get_auth_header_if_client_config_pass_basic_auth():

def test_get_auth_header_if_client_config_pass_bearer_token():
custom_config = ClientConfig(
remote_server_addr="http://remote", keep_alive=True, auth_type="Bearer", token="dXNlcjpwYXNz"
remote_server_addr="http://remote", keep_alive=True, auth_type=AuthType.BEARER, token="dXNlcjpwYXNz"
)
remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config)
headers = remote_connection._client_config.get_auth_header()
Expand All @@ -111,7 +112,7 @@ def test_get_auth_header_if_client_config_pass_bearer_token():

def test_get_auth_header_if_client_config_pass_x_api_key():
custom_config = ClientConfig(
remote_server_addr="http://remote", keep_alive=True, auth_type="X-API-Key", token="abcdefgh123456789"
remote_server_addr="http://remote", keep_alive=True, auth_type=AuthType.X_API_KEY, token="abcdefgh123456789"
)
remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config)
headers = remote_connection._client_config.get_auth_header()
Expand Down

0 comments on commit 2657eed

Please sign in to comment.