Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[py] Set user_agent and extra_headers via ClientConfig #14718

Merged
merged 8 commits into from
Nov 9, 2024
55 changes: 44 additions & 11 deletions py/selenium/webdriver/remote/client_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import base64
import os
import socket
from enum import Enum
from typing import Optional
from urllib import parse

Expand All @@ -26,6 +27,12 @@
from selenium.webdriver.common.proxy import ProxyType


class AuthType(Enum):
BASIC = "Basic"
BEARER = "Bearer"
X_API_KEY = "X-API-Key"


class ClientConfig:
def __init__(
self,
Expand All @@ -38,8 +45,10 @@ def __init__(
ca_certs: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
auth_type: Optional[str] = "Basic",
auth_type: Optional[AuthType] = AuthType.BASIC,
token: Optional[str] = None,
user_agent: Optional[str] = None,
extra_headers: Optional[dict] = None,
) -> None:
self.remote_server_addr = remote_server_addr
self.keep_alive = keep_alive
Expand All @@ -51,6 +60,8 @@ def __init__(
self.password = password
self.auth_type = auth_type
self.token = token
self.user_agent = user_agent
self.extra_headers = extra_headers

self.timeout = (
(
Expand Down Expand Up @@ -198,14 +209,17 @@ def password(self, value: str) -> None:
self._password = value

@property
def auth_type(self) -> str:
def auth_type(self) -> AuthType:
"""Returns the type of authentication to the remote server."""
return self._auth_type

@auth_type.setter
def auth_type(self, value: str) -> None:
def auth_type(self, value: AuthType) -> None:
"""Sets the type of authentication to the remote server if it is not
using basic with username and password."""
using basic with username and password.

:Args: value - AuthType enum value. For others, please use `extra_headers` instead
"""
self._auth_type = value

@property
Expand All @@ -219,6 +233,26 @@ def token(self, value: str) -> None:
auth_type is not basic."""
self._token = value

@property
def user_agent(self) -> str:
"""Returns user agent to be added to the request headers."""
return self._user_agent

@user_agent.setter
def user_agent(self, value: str) -> None:
"""Sets user agent to be added to the request headers."""
self._user_agent = value

@property
def extra_headers(self) -> dict:
"""Returns extra headers to be added to the request."""
return self._extra_headers

@extra_headers.setter
def extra_headers(self, value: dict) -> None:
"""Sets extra headers to be added to the request."""
self._extra_headers = value

def get_proxy_url(self) -> Optional[str]:
"""Returns the proxy URL to use for the connection."""
proxy_type = self.proxy.proxy_type
Expand Down Expand Up @@ -246,13 +280,12 @@ def get_proxy_url(self) -> Optional[str]:

def get_auth_header(self) -> Optional[dict]:
"""Returns the authorization to add to the request headers."""
auth_type = self.auth_type.lower()
if auth_type == "basic" and self.username and self.password:
if self.auth_type is AuthType.BASIC and self.username and self.password:
credentials = f"{self.username}:{self.password}"
encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
return {"Authorization": f"Basic {encoded_credentials}"}
if auth_type == "bearer" and self.token:
return {"Authorization": f"Bearer {self.token}"}
if auth_type == "oauth" and self.token:
return {"Authorization": f"OAuth {self.token}"}
return {"Authorization": f"{AuthType.BASIC.value} {encoded_credentials}"}
if self.auth_type is AuthType.BEARER and self.token:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks so much better!

return {"Authorization": f"{AuthType.BEARER.value} {self.token}"}
if self.auth_type is AuthType.X_API_KEY and self.token:
return {f"{AuthType.X_API_KEY.value}": f"{self.token}"}
return None
20 changes: 12 additions & 8 deletions py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from base64 import b64encode
from typing import Optional
from urllib import parse
from urllib.parse import urlparse

import urllib3

Expand Down Expand Up @@ -243,6 +244,9 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
}

if parsed_url.username:
warnings.warn(
"Embedding username and password in URL could be insecure, use ClientConfig instead", stacklevel=2
)
base64string = b64encode(f"{parsed_url.username}:{parsed_url.password}".encode())
headers.update({"Authorization": f"Basic {base64string.decode()}"})

Expand All @@ -255,16 +259,14 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False):
return headers

def _identify_http_proxy_auth(self):
url = self._proxy_url
url = url[url.find(":") + 3 :]
return "@" in url and len(url[: url.find("@")]) > 0
parsed_url = urlparse(self._proxy_url)
if parsed_url.username and parsed_url.password:
return True

def _separate_http_proxy_auth(self):
url = self._proxy_url
protocol = url[: url.find(":") + 3]
no_protocol = url[len(protocol) :]
auth = no_protocol[: no_protocol.find("@")]
proxy_without_auth = protocol + no_protocol[len(auth) + 1 :]
parsed_url = urlparse(self._proxy_url)
proxy_without_auth = f"{parsed_url.scheme}://{parsed_url.hostname}:{parsed_url.port}"
auth = f"{parsed_url.username}:{parsed_url.password}"
return proxy_without_auth, auth

def _get_connection_manager(self):
Expand Down Expand Up @@ -312,6 +314,8 @@ def __init__(
RemoteConnection._timeout = self._client_config.timeout
RemoteConnection._ca_certs = self._client_config.ca_certs
RemoteConnection._client_config = self._client_config
RemoteConnection.extra_headers = self._client_config.extra_headers or RemoteConnection.extra_headers
RemoteConnection.user_agent = self._client_config.user_agent or RemoteConnection.user_agent

if remote_server_addr:
warnings.warn(
Expand Down
Loading
Loading