Skip to content

Commit

Permalink
Refactoring code, update tests and documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
orenlab committed Aug 28, 2024
1 parent bcc1db2 commit 0ba8606
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 163 deletions.
129 changes: 86 additions & 43 deletions pyoutlineapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,17 @@
PyOutlineAPI is a Python package for interacting with the Outline VPN Server.
Licensed under the MIT License. See the LICENSE file for more details.
This module provides a wrapper around the Outline VPN Server API, allowing
users to programmatically manage access keys, server settings, and monitor
data usage.
Typical usage example:
api = PyOutlineWrapper(api_url="https://example.com", cert_sha256="abc123...")
server_info = api.get_server_info()
access_key = api.create_access_key(name="User1")
Licensed under the MIT License. See the LICENSE file for more details.
"""

from typing import Optional
Expand All @@ -15,7 +24,7 @@
from pydantic import SecretStr, ValidationError as PydanticValidationError
from requests_toolbelt.adapters.fingerprint import FingerprintAdapter

from pyoutlineapi.exceptions import APIError, ValidationError
from pyoutlineapi.exceptions import APIError, ValidationError, HTTPError
from pyoutlineapi.logger import setup_logger
from pyoutlineapi.models import (
AccessKeyCreateRequest,
Expand All @@ -24,7 +33,6 @@
AccessKeyList,
ServerPort,
DataLimit,
MetricsEnabled,
Metrics
)

Expand All @@ -36,9 +44,13 @@ class PyOutlineWrapper:
"""
Class for interacting with the Outline VPN Server.
This class provides methods for managing access keys, retrieving server
information, updating server settings, and monitoring data usage.
Attributes:
api_url (str): The base URL of the API.
cert_sha256 (str): SHA-256 fingerprint of the certificate for authenticity verification.
_api_url (str): The base URL of the API.
_cert_sha256 (str): SHA-256 fingerprint of the certificate for authenticity verification.
_verify_tls (bool): Whether to verify the TLS certificate.
"""

def __init__(self, api_url: str, cert_sha256: str, verify_tls: bool = True):
Expand All @@ -48,12 +60,13 @@ def __init__(self, api_url: str, cert_sha256: str, verify_tls: bool = True):
Args:
api_url (str): The base URL of the API.
cert_sha256 (str): SHA-256 fingerprint of the certificate.
verify_tls (bool, optional): Whether to verify the TLS certificate. Defaults to True.
"""
self.api_url = api_url
self.cert_sha256 = cert_sha256
self.verify_tls = verify_tls
self.session = requests.Session()
self.session.mount(self.api_url, FingerprintAdapter(self.cert_sha256))
self._api_url = api_url
self._cert_sha256 = cert_sha256
self._verify_tls = verify_tls
self._session = requests.Session()
self._session.mount(self._api_url, FingerprintAdapter(self._cert_sha256))

def _request(self, method: str, endpoint: str, json_data=None) -> requests.Response:
"""
Expand All @@ -70,17 +83,18 @@ def _request(self, method: str, endpoint: str, json_data=None) -> requests.Respo
Raises:
APIError: If the request fails.
"""
url = f"{self.api_url}/{endpoint}"
url = f"{self._api_url}/{endpoint}"
try:
response = self.session.request(
response = self._session.request(
method,
url,
json=json_data,
verify=self.verify_tls,
verify=self._verify_tls,
timeout=15
)
response.raise_for_status()
return response
except requests.RequestException as exception:
except (requests.RequestException, HTTPError) as exception:
raise APIError(f"Request to {url} failed: {exception}")

def get_server_info(self) -> Server:
Expand All @@ -89,6 +103,13 @@ def get_server_info(self) -> Server:
Returns:
Server: An object containing server information.
Raises:
APIError: If the request fails.
ValidationError: If the response data is invalid.
Example:
server_info = api.get_server_info()
"""
try:
response = self._request("GET", "server")
Expand All @@ -111,6 +132,9 @@ def create_access_key(self, name: Optional[str] = None, password: Optional[str]
Raises:
ValidationError: If the server response is not 201 or if there's an issue with the request.
Example:
access_key = api.create_access_key(name="User1", password="securepassword")
"""
request_data = {
"name": name,
Expand Down Expand Up @@ -149,45 +173,67 @@ def get_access_keys(self) -> AccessKeyList:
Returns:
AccessKeyList: An object containing a list of access keys.
Raises:
APIError: If the request fails.
ValidationError: If the response data is invalid.
Example:
access_keys = api.get_access_keys()
"""
try:
response = self._request("GET", "access-keys")
return AccessKeyList(**response.json())
except PydanticValidationError as e:
raise ValidationError(f"Failed to get access keys: {e}")

def delete_access_key(self, key_id: str):
def delete_access_key(self, key_id: str) -> bool:
"""
Delete an access key by its ID.
Args:
key_id (str): The ID of the access key.
Returns:
bool: True if the access key was successfully deleted, False otherwise.
Raises:
APIError: If the request fails.
Example:
success = api.delete_access_key(key_id="some_key_id")
"""
try:
self._request("DELETE", f"access-keys/{key_id}")
except PydanticValidationError as e:
raise ValidationError(f"Failed to delete access key with ID {key_id}: {e}")
query = self._request("DELETE", f"access-keys/{key_id}")
return query.status_code == 204
except HTTPError as e:
raise APIError(f"Failed to delete access key with ID {key_id}: {e}")

def update_server_port(self, port: int) -> ServerPort:
def update_server_port(self, port: ServerPort) -> bool:
"""
Update the port for new access keys.
Args:
port (int): The new port.
Returns:
ServerPort: An object containing the updated port information.
bool: True if the server port was successfully updated, False otherwise.
Raises:
APIError: If the request fails.
Example:
success = api.update_server_port(port=12345)
"""
try:
response = self._request("PUT", "server/port-for-new-access-keys", {"port": port})
return ServerPort(**response.json())
except PydanticValidationError as e:
raise ValidationError(f"Failed to update server port: {e}")
if response.status_code == 409:
raise APIError(f"Port {port} is already in use")
return response.status_code == 204
except (PydanticValidationError, APIError) as e:
raise APIError(f"Failed to update server port: {e}")

def set_access_key_data_limit(self, key_id: str, limit: int) -> DataLimit:
def set_access_key_data_limit(self, key_id: str, limit: DataLimit) -> bool:
"""
Set the data limit for an access key.
Expand All @@ -196,36 +242,33 @@ def set_access_key_data_limit(self, key_id: str, limit: int) -> DataLimit:
limit (int): The data limit in bytes.
Returns:
DataLimit: An object containing the set data limit.
"""
try:
response = self._request("PUT", f"access-keys/{key_id}/data-limit", {"bytes": limit})
return DataLimit(**response.json())
except PydanticValidationError as e:
raise ValidationError(f"Failed to set data limit for access key with ID {key_id}: {e}")
bool: True if the data limit was successfully set, False otherwise.
def set_metrics_enabled(self, enabled: bool) -> MetricsEnabled:
"""
Enable or disable metrics on the server.
Args:
enabled (bool): The state of metrics (enabled/disabled).
Raises:
APIError: If the request fails.
Returns:
MetricsEnabled: An object containing the metrics state information.
Example:
success = api.set_access_key_data_limit(key_id="some_key_id", limit=DataLimit(bytes=1048576))
"""
try:
response = self._request("PUT", "server/metrics/enabled", {"enabled": enabled})
return MetricsEnabled(**response.json())
except PydanticValidationError as e:
raise ValidationError(f"Failed to set metrics enabled state: {e}")
response = self._request("PUT", f"access-keys/{key_id}/data-limit", {"bytes": limit})
return response.status_code == 204
except (PydanticValidationError, APIError) as e:
raise APIError(f"Failed to set data limit for access key with ID {key_id}: {e}")

def get_metrics(self) -> Metrics:
"""
Get metrics about data transfer.
Returns:
Metrics: An object containing data transfer metrics.
Raises:
APIError: If the request fails.
ValidationError: If the response data is invalid.
Example:
metrics = api.get_metrics()
"""
try:
response = self._request("GET", "metrics/transfer")
Expand Down
86 changes: 76 additions & 10 deletions pyoutlineapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
PyOutlineAPI is a Python package for interacting with the Outline VPN Server.
Licensed under the MIT License. See the LICENSE file for more details.
"""

from typing import Optional, List, Dict
Expand All @@ -15,7 +14,16 @@


class Server(BaseModel):
"""Model for server information."""
"""
Model for server information.
Attributes:
name (str): The name of the server.
serverId (str): The unique identifier for the server.
metricsEnabled (bool): Indicates if metrics collection is enabled.
createdTimestampMs (int): The timestamp when the server was created, must be non-negative.
portForNewAccessKeys (int): The port used for new access keys, must be between 1 and 65535.
"""
name: str
serverId: str
metricsEnabled: bool
Expand All @@ -24,12 +32,27 @@ class Server(BaseModel):


class DataLimit(BaseModel):
"""Model for data limit information."""
"""
Model for data limit information.
Attributes:
bytes (int): The data limit in bytes, must be non-negative.
"""
bytes: int = Field(ge=0, description="Data limit in bytes must be non-negative")


class AccessKey(BaseModel):
"""Model for access key information."""
"""
Model for access key information.
Attributes:
id (str): The unique identifier for the access key.
name (str): The name of the access key.
password (SecretStr): The password for the access key, must not be empty.
port (int): The port used by the access key, must be between 1 and 65535.
method (str): The encryption method used by the access key.
accessUrl (SecretStr): The URL used to access the server, must not be empty.
"""
id: str
name: str
password: SecretStr = Field(..., min_length=1, description="Password must not be empty")
Expand All @@ -39,34 +62,77 @@ class AccessKey(BaseModel):


class ServerPort(BaseModel):
"""Model for server port information."""
"""
Model for server port information.
Attributes:
port (int): The port used by the server, must be between 1 and 65535.
"""
port: int = Field(ge=1, le=65535, description="Port must be between 1 and 65535")


class AccessKeyCreateRequest(BaseModel):
"""Model for creating access key information."""
"""
Model for creating access key information.
Attributes:
name (Optional[str]): The name of the access key (optional).
password (Optional[str]): The password for the access key (optional).
port (Optional[int]): The port used by the access key, must be between 0 and 65535 (optional).
"""
name: Optional[str]
password: Optional[str]
port: Optional[int] = Field(..., ge=0, le=65535, description="Port must be between 0 and 65535")
port: Optional[int] = Field(ge=0, le=65535, description="Port must be between 0 and 65535")


class AccessKeyList(BaseModel):
"""Model for access key list information."""
"""
Model for access key list information.
Attributes:
accessKeys (List[AccessKey]): A list of access keys.
"""
accessKeys: List[AccessKey]


class MetricsEnabled(BaseModel):
"""Model for metrics enabled information."""
"""
Model for metrics enabled information.
Attributes:
enabled (bool): Indicates if metrics collection is enabled.
"""
enabled: bool


class Metrics(BaseModel):
"""Model for metrics information."""
"""
Model for metrics information.
Attributes:
bytesTransferredByUserId (Dict[str, int]): A dictionary mapping user IDs to the number of bytes transferred.
User IDs must be non-empty strings, and byte values must be non-negative.
Methods:
validate_bytes_transferred: Validates that all byte values in the dictionary are non-negative.
"""
bytesTransferredByUserId: Dict[constr(min_length=1), int] = Field(
description="User IDs must be non-empty strings and byte values must be non-negative")

@field_validator("bytesTransferredByUserId")
def validate_bytes_transferred(cls, value: Dict[str, int]) -> Dict[str, int]:
"""
Validate that all byte values in the dictionary are non-negative.
Args:
value (Dict[str, int]): The dictionary to validate.
Returns:
Dict[str, int]: The validated dictionary.
Raises:
ValueError: If any byte value is negative.
"""
for user_id, bytes_transferred in value.items():
if bytes_transferred < 0:
raise ValueError(f"Transferred bytes for user {user_id} must be non-negative")
Expand Down
Loading

0 comments on commit 0ba8606

Please sign in to comment.