Specify User Agent for Console CLI requests #952

Merged (16 commits) on Sep 20, 2024
@@ -10,6 +10,7 @@
- [Metadata Migration](#metadata-migration)
- [Replay](#replay)
- [Kafka](#kafka)
- [Client Options](#client-options)
- [Usage](#usage)
- [Library](#library)
- [CLI](#cli)
@@ -82,6 +83,8 @@ metadata_migration:
kafka:
broker_endpoints: "kafka:9092"
standard:
client_options:
user_agent_extra: "test-user-agent-v1.0"
```

## Services.yaml spec
@@ -225,13 +228,19 @@ Exactly one of the following blocks must be present:

A Kafka cluster is used in the capture and replay stage of the migration to store recorded requests and responses before they're replayed. While it's not necessary for a user to directly interact with the Kafka cluster in most cases, there are a handful of commands that can be helpful for checking on the status or resetting state that are exposed by the Console CLI.

- `broker_endpoints`: required, comma-separated list of kafaka broker endpoints
- `broker_endpoints`: required, comma-separated list of kafka broker endpoints

Exactly one of the following keys must be present, but both are nullable (they don't have or need any additional parameters).

- `msk`: the Kafka instance is deployed as AWS Managed Service Kafka
- `standard`: the Kafka instance is deployed as a standard Kafka cluster (e.g. on Docker)

### Client Options

Client options are global settings that are applied to the various clients (e.g. boto3 and requests) used throughout this library.

- `user_agent_extra`: optional, a user agent string that will be appended to the `User-Agent` header of all requests from this library
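
As a rough illustration of the intended effect (the exact base agent string depends on the client library and version), a sketch of the append behavior, not library code:

```python
# Illustrative sketch only, assuming the configured value is appended to a
# client's existing User-Agent header.
import requests

user_agent_extra = "test-user-agent-v1.0"
session = requests.Session()
base_agent = session.headers.get("User-Agent", "")
session.headers["User-Agent"] = f"{base_agent} {user_agent_extra}".strip()
print(session.headers["User-Agent"])  # e.g. "python-requests/2.32.0 test-user-agent-v1.0"
```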

## Usage

### Library
@@ -8,6 +8,7 @@
from console_link.models.snapshot import Snapshot
from console_link.models.replayer_base import Replayer
from console_link.models.kafka import Kafka
from console_link.models.client_options import ClientOptions

import yaml
from cerberus import Validator
@@ -25,7 +26,8 @@
"snapshot": {"type": "dict", "required": False},
"metadata_migration": {"type": "dict", "required": False},
"replay": {"type": "dict", "required": False},
"kafka": {"type": "dict", "required": False}
"kafka": {"type": "dict", "required": False},
"client_options": {"type": "dict", "required": False},
}


@@ -38,6 +40,7 @@ class Environment:
metadata: Optional[Metadata] = None
replay: Optional[Replayer] = None
kafka: Optional[Kafka] = None
client_options: Optional[ClientOptions] = None

def __init__(self, config_file: str):
logger.info(f"Loading config file: {config_file}")
@@ -50,23 +53,29 @@ def __init__(self, config_file: str):
logger.error(f"Config file validation errors: {v.errors}")
raise ValueError("Invalid config file", v.errors)

if 'client_options' in self.config:
self.client_options: ClientOptions = ClientOptions(self.config["client_options"])

if 'source_cluster' in self.config:
self.source_cluster = Cluster(self.config["source_cluster"])
self.source_cluster = Cluster(config=self.config["source_cluster"],
client_options=self.client_options)
logger.info(f"Source cluster initialized: {self.source_cluster.endpoint}")
else:
logger.info("No source cluster provided")

# At some point, target and replayers should be stored as pairs, but for the time being
# we can probably assume one target cluster.
if 'target_cluster' in self.config:
self.target_cluster: Cluster = Cluster(self.config["target_cluster"])
self.target_cluster: Cluster = Cluster(config=self.config["target_cluster"],
client_options=self.client_options)
logger.info(f"Target cluster initialized: {self.target_cluster.endpoint}")
else:
logger.warning("No target cluster provided. This may prevent other actions from proceeding.")

if 'metrics_source' in self.config:
self.metrics_source: MetricsSource = get_metrics_source(
self.config["metrics_source"]
config=self.config["metrics_source"],
client_options=self.client_options
)
logger.info(f"Metrics source initialized: {self.metrics_source}")
else:
@@ -75,13 +84,14 @@ def __init__(self, config_file: str):
if 'backfill' in self.config:
self.backfill: Backfill = get_backfill(self.config["backfill"],
source_cluster=self.source_cluster,
target_cluster=self.target_cluster)
target_cluster=self.target_cluster,
client_options=self.client_options)
logger.info(f"Backfill migration initialized: {self.backfill}")
else:
logger.info("No backfill provided")

if 'replay' in self.config:
self.replay: Replayer = get_replayer(self.config["replay"])
self.replay: Replayer = get_replayer(self.config["replay"], client_options=self.client_options)
logger.info(f"Replay initialized: {self.replay}")

if 'snapshot' in self.config:
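
For context beyond the diff, a minimal usage sketch of how `client_options` propagates from `services.yaml` through `Environment` (the import path below is an assumption):

```python
# Hypothetical usage sketch; the Environment import path is assumed and may
# differ in the actual package layout.
from console_link.environment import Environment

env = Environment(config_file="/config/migration_services.yaml")
# If the YAML contains a top-level `client_options` block, the clusters, backfill,
# replayer, and metrics source constructed by Environment all receive it and pass
# it on to their underlying requests/boto3 clients.
if env.client_options:
    print(env.client_options.user_agent_extra)
```
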
@@ -1,12 +1,13 @@
from console_link.models.client_options import ClientOptions
from console_link.models.osi_utils import (create_pipeline_from_env, start_pipeline, stop_pipeline,
OpenSearchIngestionMigrationProps)
from console_link.models.cluster import Cluster
from console_link.models.backfill_base import Backfill
from console_link.models.command_result import CommandResult
from typing import Dict
from typing import Dict, Optional
from cerberus import Validator
import boto3

from console_link.models.utils import create_boto3_client

OSI_SCHEMA = {
'pipeline_role_arn': {
@@ -61,15 +62,17 @@ class OpenSearchIngestionBackfill(Backfill):
A migration manager for an OpenSearch Ingestion pipeline.
"""

def __init__(self, config: Dict, source_cluster: Cluster, target_cluster: Cluster) -> None:
def __init__(self, config: Dict, source_cluster: Cluster, target_cluster: Cluster,
client_options: Optional[ClientOptions] = None) -> None:
super().__init__(config)
self.client_options = client_options
config = config["opensearch_ingestion"]

v = Validator(OSI_SCHEMA)
if not v.validate(config):
raise ValueError("Invalid config file for OpenSearchIngestion migration", v.errors)
self.osi_props = OpenSearchIngestionMigrationProps(config=config)
self.osi_client = boto3.client('osis')
self.osi_client = create_boto3_client(aws_service_name='osis', client_options=self.client_options)
self.source_cluster = source_cluster
self.target_cluster = target_cluster

@@ -5,6 +5,7 @@
import requests

from console_link.models.backfill_base import Backfill, BackfillStatus
from console_link.models.client_options import ClientOptions
from console_link.models.cluster import Cluster
from console_link.models.schema_tools import contains_one_of
from console_link.models.command_result import CommandResult
@@ -87,14 +88,17 @@ def scale(self, units: int, *args, **kwargs) -> CommandResult:


class ECSRFSBackfill(RFSBackfill):
def __init__(self, config: Dict, target_cluster: Cluster) -> None:
def __init__(self, config: Dict, target_cluster: Cluster, client_options: Optional[ClientOptions] = None) -> None:
super().__init__(config)
self.client_options = client_options
self.target_cluster = target_cluster
self.default_scale = self.config["reindex_from_snapshot"].get("scale", 1)

self.ecs_config = self.config["reindex_from_snapshot"]["ecs"]
self.ecs_client = ECSService(self.ecs_config["cluster_name"], self.ecs_config["service_name"],
self.ecs_config.get("aws_region", None))
self.ecs_client = ECSService(cluster_name=self.ecs_config["cluster_name"],
service_name=self.ecs_config["service_name"],
aws_region=self.ecs_config.get("aws_region", None),
client_options=self.client_options)

def start(self, *args, **kwargs) -> CommandResult:
logger.info(f"Starting RFS backfill by setting desired count to {self.default_scale} instances")
@@ -0,0 +1,30 @@
from typing import Dict, Optional
import logging
from cerberus import Validator

logger = logging.getLogger(__name__)

SCHEMA = {
"client_options": {
"type": "dict",
"schema": {
"user_agent_extra": {"type": "string", "required": False},
},
}
}


class ClientOptions:
"""
Options that can be configured for boto3 and request library clients.
"""

user_agent_extra: Optional[str] = None

def __init__(self, config: Dict) -> None:
logger.info(f"Initializing client options with config: {config}")
v = Validator(SCHEMA)
if not v.validate({'client_options': config}):
raise ValueError("Invalid config file for client options", v.errors)

self.user_agent_extra = config.get("user_agent_extra", None)
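
A brief usage sketch of the new class (illustrative only, mirroring the `client_options` block from the README example):

```python
# Valid config: `user_agent_extra` is an optional string.
opts = ClientOptions({"user_agent_extra": "test-user-agent-v1.0"})
assert opts.user_agent_extra == "test-user-agent-v1.0"

# An empty dict is also valid and leaves `user_agent_extra` as None.
assert ClientOptions({}).user_agent_extra is None

# Unknown keys fail cerberus validation and surface as a ValueError.
try:
    ClientOptions({"not_a_real_option": True})
except ValueError as err:
    print(err)
```
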
@@ -10,7 +10,9 @@
from requests.auth import HTTPBasicAuth
from requests_auth_aws_sigv4 import AWSSigV4

from console_link.models.client_options import ClientOptions
from console_link.models.schema_tools import contains_one_of
from console_link.models.utils import create_boto3_client, append_user_agent_header_for_requests

requests.packages.urllib3.disable_warnings() # ignore: type

@@ -79,8 +81,9 @@ class Cluster:
auth_type: Optional[AuthMethod] = None
auth_details: Optional[Dict[str, Any]] = None
allow_insecure: bool = False
client_options: Optional[ClientOptions] = None

def __init__(self, config: Dict) -> None:
def __init__(self, config: Dict, client_options: Optional[ClientOptions] = None) -> None:
logger.info(f"Initializing cluster with config: {config}")
v = Validator(SCHEMA)
if not v.validate({'cluster': config}):
@@ -97,6 +100,7 @@ def __init__(self, config: Dict) -> None:
elif 'sigv4' in config:
self.auth_type = AuthMethod.SIGV4
self.auth_details = config["sigv4"] if config["sigv4"] is not None else {}
self.client_options = client_options

def get_basic_auth_password(self) -> str:
"""This method will return the basic auth password, if basic auth is enabled.
@@ -108,11 +112,11 @@ def get_basic_auth_password(self) -> str:
return self.auth_details["password"]
# Pull password from AWS Secrets Manager
assert "password_from_secret_arn" in self.auth_details # for mypy's sake
client = boto3.client('secretsmanager')
client = create_boto3_client(aws_service_name="secretsmanager", client_options=self.client_options)
password = client.get_secret_value(SecretId=self.auth_details["password_from_secret_arn"])
return password["SecretString"]

def _get_sigv4_details(self, force_region=False) -> tuple[str, str]:
def _get_sigv4_details(self, force_region=False) -> tuple[str, Optional[str]]:
"""Return the service signing name and region name. If force_region is true,
it will instantiate a boto3 session to guarantee that the region is not None.
This will fail if AWS credentials are not available.
@@ -145,9 +149,14 @@ def call_api(self, path, method: HttpMethod = HttpMethod.GET, data=None, headers
"""
if session is None:
session = requests.Session()

auth = self._generate_auth_object()

request_headers = headers
if self.client_options and self.client_options.user_agent_extra:
user_agent_extra = self.client_options.user_agent_extra
request_headers = append_user_agent_header_for_requests(headers=headers, user_agent_extra=user_agent_extra)

# Extract query parameters from kwargs
params = kwargs.get('params', {})

@@ -159,7 +168,7 @@ def call_api(self, path, method: HttpMethod = HttpMethod.GET, data=None, headers
params=params,
auth=auth,
data=data,
headers=headers,
headers=request_headers,
timeout=timeout
)
logger.info(f"Received response: {r.status_code} {method.name} {self.endpoint}{path} - {r.text[:1000]}")
@@ -1,11 +1,8 @@
import logging
from typing import NamedTuple, Optional

import boto3

from console_link.models.command_result import CommandResult
from console_link.models.utils import AWSAPIError, raise_for_aws_api_error

from console_link.models.utils import AWSAPIError, raise_for_aws_api_error, create_boto3_client

logger = logging.getLogger(__name__)

@@ -20,13 +17,15 @@ def __str__(self):


class ECSService:
def __init__(self, cluster_name, service_name, aws_region=None):
def __init__(self, cluster_name, service_name, aws_region=None, client_options=None):
self.cluster_name = cluster_name
self.service_name = service_name
self.aws_region = aws_region
self.client_options = client_options

logger.info(f"Creating ECS client for region {aws_region}, if specified")
self.client = boto3.client("ecs", region_name=self.aws_region)
self.client = create_boto3_client(aws_service_name="ecs", region=self.aws_region,
client_options=self.client_options)

def set_desired_count(self, desired_count: int) -> CommandResult:
logger.info(f"Setting desired count for service {self.service_name} to {desired_count}")
@@ -47,7 +46,7 @@ def set_desired_count(self, desired_count: int) -> CommandResult:
desired_count = response["service"]["desiredCount"]
return CommandResult(True, f"Service {self.service_name} set to {desired_count} desired count."
f" Currently {running_count} running and {pending_count} pending.")

def get_instance_statuses(self) -> Optional[InstanceStatuses]:
logger.info(f"Getting instance statuses for service {self.service_name}")
response = self.client.describe_services(
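
`create_boto3_client` also lives in `console_link.models.utils` and isn't part of this excerpt; a plausible sketch based on how `Cluster`, `ECSService`, and the OSI backfill call it (the botocore `Config` approach is an assumption):

```python
from typing import Optional

import boto3
from botocore.config import Config


def create_boto3_client(aws_service_name: str, region: Optional[str] = None, client_options=None):
    # Attach user_agent_extra via botocore's Config so every request made by this
    # client carries the extra User-Agent token; otherwise use the defaults.
    config = None
    if client_options and client_options.user_agent_extra:
        config = Config(user_agent_extra=client_options.user_agent_extra)
    return boto3.client(aws_service_name, region_name=region, config=config)
```
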
@@ -1,6 +1,7 @@
from enum import Enum
from typing import Dict, Optional

from console_link.models.client_options import ClientOptions
Collaborator:
Okay, I think I slightly regret YAML vs env variable decision because we could have kept env variable stuff much more limited to only grabbing it at the level of each client, instead of having to pass it all the way through the code -- it feels like conceptually this shouldn't have to be woven through all these factories and lots of unrelated code and it should just stay low level, but I don't really have any way in mind to implement that if it comes from the YAML.

Which is all to say, I'm not sure if there's a better solution or not, and definitely not sure whether it's worth refactoring this, because I think it's well-implemented for this route. Thoughts?

Collaborator:
(Circling back to add that after looking at the tests, this services.yaml approach is definitely clearer and easier to test, so that's a point in its favor)

Collaborator Author:
From some discussion with Mikayla there are pros and cons with the current approach versus setting something more globally instead of passing through yaml. I'm comfortable with the current approach for now, and think it can be improved so there are fewer touch points next time we want some global settings, but let me know if you have concerns @mikaylathompson

from console_link.models.replayer_docker import DockerReplayer
from console_link.models.metrics_source import CloudwatchMetricsSource, PrometheusMetricsSource
from console_link.models.backfill_base import Backfill
@@ -55,9 +56,9 @@ def get_snapshot(config: Dict, source_cluster: Cluster):
raise UnsupportedSnapshotError(next(iter(config.keys())))


def get_replayer(config: Dict):
def get_replayer(config: Dict, client_options: Optional[ClientOptions] = None):
if 'ecs' in config:
return ECSReplayer(config)
return ECSReplayer(config=config, client_options=client_options)
if 'docker' in config:
return DockerReplayer(config)
logger.error(f"An unsupported replayer type was provided: {config.keys()}")
@@ -74,7 +75,8 @@ def get_kafka(config: Dict):
raise UnsupportedKafkaError(', '.join(config.keys()))


def get_backfill(config: Dict, source_cluster: Optional[Cluster], target_cluster: Optional[Cluster]) -> Backfill:
def get_backfill(config: Dict, source_cluster: Optional[Cluster], target_cluster: Optional[Cluster],
client_options: Optional[ClientOptions] = None) -> Backfill:
if BackfillType.opensearch_ingestion.name in config:
if source_cluster is None:
raise ValueError("source_cluster must be provided for OpenSearch Ingestion backfill")
@@ -83,7 +85,8 @@ def get_backfill(config: Dict, source_cluster: Optional[Cluster], target_cluster
logger.debug("Creating OpenSearch Ingestion backfill instance")
return OpenSearchIngestionBackfill(config=config,
source_cluster=source_cluster,
target_cluster=target_cluster)
target_cluster=target_cluster,
client_options=client_options)
elif BackfillType.reindex_from_snapshot.name in config:
if target_cluster is None:
raise ValueError("target_cluster must be provided for RFS backfill")
@@ -95,17 +98,18 @@ elif 'ecs' in config[BackfillType.reindex_from_snapshot.name]:
elif 'ecs' in config[BackfillType.reindex_from_snapshot.name]:
logger.debug("Creating ECS RFS backfill instance")
return ECSRFSBackfill(config=config,
target_cluster=target_cluster)
target_cluster=target_cluster,
client_options=client_options)

logger.error(f"An unsupported backfill source type was provided: {config.keys()}")
raise UnsupportedBackfillTypeError(', '.join(config.keys()))


def get_metrics_source(config):
def get_metrics_source(config, client_options: Optional[ClientOptions] = None):
if 'prometheus' in config:
return PrometheusMetricsSource(config)
return PrometheusMetricsSource(config=config, client_options=client_options)
elif 'cloudwatch' in config:
return CloudwatchMetricsSource(config)
return CloudwatchMetricsSource(config=config, client_options=client_options)
else:
logger.error(f"An unsupported metrics source type was provided: {config.keys()}")
raise UnsupportedMetricsSourceError(', '.join(config.keys()))