diff --git a/docs/guide/scripting.md b/docs/guide/scripting.md index ddf82b17..d39b0013 100644 --- a/docs/guide/scripting.md +++ b/docs/guide/scripting.md @@ -94,6 +94,7 @@ def on_request(request: RequestModel, posting: Posting) -> None: # Set auth on the request. request.auth = Auth.basic_auth("username", "password") # request.auth = Auth.digest_auth("username", "password") + # request.auth = Auth.bearer_token_auth("token") # This will be captured and written to the log. print("Request is being sent!") diff --git a/docs/roadmap.md b/docs/roadmap.md index eb5e70c7..4963b056 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -22,7 +22,7 @@ If you have any feedback or suggestions, please open a new discussion on GitHub. - Changing the environment at runtime - probably via command palette - push a new command palette screen where you can search for and select one of the previously used environments. - Variable completion autocompletion TextAreas. - Variable resolution highlighting in TextAreas. -- Bearer token auth (can be done now by adding header). +- Bearer token auth (can be done now by adding header). ✅ - API key auth (can be done now by adding header). - OAuth2 (need to scope out what's involved here). - Add "quit" to command palette and footer ✅ diff --git a/pyproject.toml b/pyproject.toml index 8f9fddb9..8ec65b55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "xdg-base-dirs>=6.0.1,<7.0.0", "click-default-group>=1.2.4,<2.0.0", "httpx[brotli]>=0.27.2,<1.0.0", + "openapi-pydantic>=0.5.0", "pyperclip>=1.9.0,<2.0.0", "pydantic>=2.9.2,<3.0.0", "pyyaml>=6.0.2,<7.0.0", diff --git a/src/posting/auth.py b/src/posting/auth.py new file mode 100644 index 00000000..f95e4784 --- /dev/null +++ b/src/posting/auth.py @@ -0,0 +1,12 @@ +from typing import Generator + +import httpx + + +class HttpxBearerTokenAuth(httpx.Auth): + def __init__(self, token: str): + self.token = token + + def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]: + request.headers["Authorization"] = f"Bearer {self.token}" + yield request diff --git a/src/posting/collection.py b/src/posting/collection.py index c276ec9a..071febc2 100644 --- a/src/posting/collection.py +++ b/src/posting/collection.py @@ -8,6 +8,7 @@ import rich import yaml import os +from posting.auth import HttpxBearerTokenAuth from posting.tuple_to_multidict import tuples_to_dict from posting.variables import SubstitutionError @@ -34,9 +35,10 @@ def str_presenter(dumper, data): class Auth(BaseModel): - type: Literal["basic", "digest"] | None = Field(default=None) + type: Literal["basic", "digest", "bearer_token"] | None = Field(default=None) basic: BasicAuth | None = Field(default=None) digest: DigestAuth | None = Field(default=None) + bearer_token: BearerTokenAuth | None = Field(default=None) def to_httpx_auth(self) -> httpx.Auth | None: if self.type == "basic": @@ -45,6 +47,9 @@ def to_httpx_auth(self) -> httpx.Auth | None: elif self.type == "digest": assert self.digest is not None return httpx.DigestAuth(self.digest.username, self.digest.password) + elif self.type == "bearer_token": + assert self.bearer_token is not None + return HttpxBearerTokenAuth(self.bearer_token.token) return None @classmethod @@ -57,6 +62,10 @@ def digest_auth(cls, username: str, password: str) -> Auth: type="digest", digest=DigestAuth(username=username, password=password) ) + @classmethod + def bearer_token_auth(cls, token: str) -> Auth: + return cls(type="bearer_token", bearer_token=BearerTokenAuth(token=token)) + class BasicAuth(BaseModel): username: str = Field(default="") @@ -68,6 +77,11 @@ class DigestAuth(BaseModel): password: str = Field(default="") + +class BearerTokenAuth(BaseModel): + token: str = Field(default="") + + class Header(BaseModel): name: str value: str @@ -239,6 +253,9 @@ def apply_template(self, variables: dict[str, Any]) -> None: self.auth.digest.username = template.substitute(variables) template = Template(self.auth.digest.password) self.auth.digest.password = template.substitute(variables) + if self.auth.bearer_token is not None: + template = Template(self.auth.bearer_token.token) + self.auth.bearer_token.token = template.substitute(variables) except (KeyError, ValueError) as e: raise SubstitutionError(f"Variable not defined: {e}") diff --git a/src/posting/importing/curl.py b/src/posting/importing/curl.py index 7fe990df..535f79c5 100644 --- a/src/posting/importing/curl.py +++ b/src/posting/importing/curl.py @@ -218,6 +218,15 @@ def _extract_auth_from_headers(self) -> tuple[Auth | None, list[tuple[str, str]] except Exception: # If we can't parse it, keep it as a header remaining_headers.append((name, value)) + + elif auth_type_lower == "bearer": + # Bearer token auth + try: + auth = Auth.bearer_token_auth(auth_value) + except Exception: + # If we can't parse it, keep it as a header + remaining_headers.append((name, value)) + else: # Unknown auth type, keep as header remaining_headers.append((name, value)) diff --git a/src/posting/importing/open_api.py b/src/posting/importing/open_api.py index 3a2f65ce..279ccc17 100644 --- a/src/posting/importing/open_api.py +++ b/src/posting/importing/open_api.py @@ -4,12 +4,16 @@ from urllib.parse import urlparse import yaml +from openapi_pydantic import OpenAPI, Reference, SecurityScheme from pathlib import Path from posting.collection import ( VALID_HTTP_METHODS, APIInfo, + Auth, + BasicAuth, + BearerTokenAuth, Collection, ExternalDocs, FormItem, @@ -70,23 +74,58 @@ def extract_server_variables(spec: dict[str, Any]) -> dict[str, dict[str, str]]: "description": f"Server URL {i+1}: {server.get('description', '')}", } - # # Extract security schemes - # security_schemes = spec.get("components", {}).get("securitySchemes", {}) - # for scheme_name, scheme in security_schemes.items(): - # if scheme["type"] == "apiKey": - # variables[f"{scheme_name.upper()}_API_KEY"] = { - # "value": "YOUR_API_KEY_HERE", - # "description": f"API Key for {scheme_name} authentication", - # } - # elif scheme["type"] == "http" and scheme["scheme"] == "bearer": - # variables[f"{scheme_name.upper()}_BEARER_TOKEN"] = { - # "value": "YOUR_BEARER_TOKEN_HERE", - # "description": f"Bearer token for {scheme_name} authentication", - # } - return variables +def security_scheme_to_variables( + name: str, + security_scheme: SecurityScheme | Reference, +) -> dict[str, dict[str, str]]: + match security_scheme: + case SecurityScheme(type="http", scheme="basic"): + return { + f"{name.upper()}_USERNAME": { + "value": "YOUR USERNAME HERE", + "description": f"Username for {name} authentication", + }, + f"{name.upper()}_PASSWORD": { + "value": "YOUR PASSWORD HERE", + "description": f"Password for {name} authentication", + }, + } + case SecurityScheme(type="http", scheme="bearer"): + return { + f"{name.upper()}_BEARER_TOKEN": { + "value": "YOUR BEARER TOKEN HERE", + "description": f"Token for {name} authentication", + }, + } + case _: + return {} + + +def security_scheme_to_auth( + name: str, + security_scheme: SecurityScheme | Reference, +) -> Auth | None: + match security_scheme: + case SecurityScheme(type="http", scheme="basic"): + return Auth( + type="basic", + basic=BasicAuth( + username=f"${{{name.upper()}_USERNAME}}", + password=f"${{{name.upper()}_PASSWORD}}", + ), + ) + case SecurityScheme(type="http", scheme="bearer"): + return Auth( + type="bearer_token", + bearer_token=BearerTokenAuth(token=f"${{{name.upper()}_BEARER_TOKEN}}"), + ) + case _: + return None + + def generate_readme( spec_path: Path, info: APIInfo, @@ -184,9 +223,16 @@ def import_openapi_spec(spec_path: str | Path) -> Collection: name=collection_name, ) + openapi = OpenAPI.model_validate(spec) + security_schemes = openapi.components.securitySchemes or {} + env_files: list[Path] = [] for server in servers: - variables = extract_server_variables(server) + security_variables = {} + for scheme_name, scheme in security_schemes.items(): + security_variables.update(security_scheme_to_variables(scheme_name, scheme)) + + variables = {**extract_server_variables(server), **security_variables} env_filename = generate_unique_env_filename(collection_name, server["url"]) env_file = create_env_file(spec_path.parent, env_filename, variables) console.print( @@ -209,6 +255,14 @@ def import_openapi_spec(spec_path: str | Path) -> Collection: method=method, url=f"${{BASE_URL}}{path}", ) + + # Add auth + for security in operation.get("security", []): + for scheme_name, _scopes in security.items(): + if scheme := security_schemes.get(scheme_name): + request.auth = security_scheme_to_auth(scheme_name, scheme) + break + # Add query parameters for param in operation.get("parameters", []): if param["in"] == "query": diff --git a/src/posting/widgets/request/request_auth.py b/src/posting/widgets/request/request_auth.py index b7beae57..1f4ba9ea 100644 --- a/src/posting/widgets/request/request_auth.py +++ b/src/posting/widgets/request/request_auth.py @@ -6,7 +6,8 @@ from textual.containers import Horizontal, Vertical, VerticalScroll from textual.widgets import ContentSwitcher, Input, Label, Select, Static -from posting.collection import Auth, BasicAuth, DigestAuth +from posting.auth import HttpxBearerTokenAuth +from posting.collection import Auth, BasicAuth, BearerTokenAuth, DigestAuth from posting.widgets.select import PostingSelect from posting.widgets.variable_input import VariableInput @@ -52,6 +53,34 @@ def get_values(self) -> dict[str, str]: } +class BearerTokenForm(Vertical): + DEFAULT_CSS = """ + BearerTokenForm { + padding: 1 0; + + & #token-input { + margin-bottom: 1; + } + } + """ + + def compose(self) -> ComposeResult: + yield Label("Token") + yield VariableInput( + placeholder="Enter a token", + password=True, + id="token-input", + ) + + def set_values(self, token: str) -> None: + self.query_one("#token-input", Input).value = token + + def get_values(self) -> dict[str, str]: + return { + "token": self.query_one("#token-input", Input).value, + } + + class RequestAuth(VerticalScroll): DEFAULT_CSS = """ RequestAuth { @@ -93,6 +122,7 @@ def compose(self) -> ComposeResult: ("No Auth", None), ("Basic", "basic"), ("Digest", "digest"), + ("Bearer Token", "bearer-token"), ], allow_blank=False, prompt="Auth Type", @@ -107,6 +137,7 @@ def compose(self) -> ComposeResult: with ContentSwitcher(initial=None, id="auth-form-switcher"): yield UserNamePasswordForm(id="auth-form-basic") yield UserNamePasswordForm(id="auth-form-digest") + yield BearerTokenForm(id="auth-form-bearer-token") @on(Select.Changed, selector="#auth-type-select") def on_auth_type_changed(self, event: Select.Changed): @@ -123,6 +154,8 @@ def to_httpx_auth(self) -> httpx.Auth | None: return httpx.BasicAuth(**form.get_values()) case "auth-form-digest": return httpx.DigestAuth(**form.get_values()) + case "auth-form-bearer-token": + return HttpxBearerTokenAuth(**form.get_values()) case _: return None @@ -147,6 +180,10 @@ def to_model(self) -> Auth | None: type="digest", digest=DigestAuth(username=username, password=password), ) + case "auth-form-bearer-token": + form_values = form.get_values() + token = form_values["token"] + return Auth(type="bearer_token", bearer_token=BearerTokenAuth(token=token)) case _: return None @@ -177,6 +214,14 @@ def load_auth(self, auth: Auth | None) -> None: auth.digest.username, auth.digest.password, ) + case "bearer_token": + if auth.bearer_token is None: + log.warning("Bearer auth selected, but no values provided for token.") + return + self.query_one("#auth-type-select", Select).value = "bearer-token" + self.query_one("#auth-form-bearer-token", BearerTokenForm).set_values( + auth.bearer_token.token + ) case _: log.warning(f"Unknown auth type: {auth.type}") diff --git a/tests/test_curl_import.py b/tests/test_curl_import.py index c318a540..c388d042 100644 --- a/tests/test_curl_import.py +++ b/tests/test_curl_import.py @@ -82,6 +82,14 @@ def test_curl_with_user_and_password(): assert curl_import.url == "http://example.com" +def test_curl_with_bearer_token(): + """Test parsing of user credentials.""" + curl_command = "curl http://example.com -H 'Authorization: Bearer my-token'" + curl_import = CurlImport(curl_command) + assert curl_import.headers == [("Authorization", "Bearer my-token")] + assert curl_import.url == "http://example.com" + + def test_curl_with_insecure(): """Test parsing of --insecure flag.""" curl_command = "curl -k http://example.com" diff --git a/tests/test_open_api_import.py b/tests/test_open_api_import.py new file mode 100644 index 00000000..2ddbc34c --- /dev/null +++ b/tests/test_open_api_import.py @@ -0,0 +1,67 @@ +import json +from pathlib import Path + +from posting.importing.open_api import import_openapi_spec + + +def test_import(tmp_path: Path): + """Test importing security schemes.""" + spec = { + "openapi": "3.1.0", + "info": {"title": "Test", "version": "1.0", "description": "Test"}, + "paths": { + "/": { + "get": { + "parameters": [ + { + "name": "page", + "in": "query", + }, + { + "name": "account_id", + "in": "header", + "deprecated": True, + }, + ], + "responses": {"200": {"description": "OK"}}, + "security": [ + {"bearerAuth": []}, + ], + } + } + }, + "components": { + "securitySchemes": { + "bearerAuth": { + "type": "http", + "scheme": "bearer", + }, + }, + }, + } + spec_path = tmp_path / "spec.json" + spec_path.write_text(json.dumps(spec)) + collection = import_openapi_spec(spec_path) + + assert len(collection.requests) == 1 + + request = collection.requests[0] + assert request.url == "${BASE_URL}/" + assert request.method == "GET" + + assert len(request.params) == 1 + param = request.params[0] + assert param.name == "page" + assert param.value == "" + assert param.enabled + + assert len(request.headers) == 1 + header = request.headers[0] + assert header.name == "account_id" + assert header.value == "" + assert not header.enabled + + assert request.auth is not None + assert request.auth.type == "bearer_token" + assert request.auth.bearer_token is not None + assert request.auth.bearer_token.token == "${BEARERAUTH_BEARER_TOKEN}" diff --git a/uv.lock b/uv.lock index 737f2ad5..6be4daef 100644 --- a/uv.lock +++ b/uv.lock @@ -823,6 +823,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/b7/b9e70fde2c0f0c9af4cc5277782a89b66d35948ea3369ec9f598358c3ac5/multidict-6.1.0-py3-none-any.whl", hash = "sha256:48e171e52d1c4d33888e529b999e5900356b9ae588c2f09a52dcefb158b27506", size = 10051 }, ] +[[package]] +name = "openapi-pydantic" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/24/83/c6dd05cd518e1217b2096d04f959d7868396a3a99faf9669d14007668c38/openapi_pydantic-0.5.0.tar.gz", hash = "sha256:a48f88e2904a056e1ef6d4728cfb2f36aa3213ce194fb09fc04259b9007165f0", size = 60403 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/81/87f16c4bf2e73ea337b5bb1e074e3ddbdf3e5aabaaf3f8551aad6a9ac9fa/openapi_pydantic-0.5.0-py3-none-any.whl", hash = "sha256:06458efd34969446f42d96d51de39cdef4a9b19daf3cc456a2dfa697458ac542", size = 95858 }, +] + [[package]] name = "packaging" version = "24.1" @@ -870,12 +882,13 @@ wheels = [ [[package]] name = "posting" -version = "2.2.0" +version = "2.3.0" source = { editable = "." } dependencies = [ { name = "click" }, { name = "click-default-group" }, { name = "httpx", extra = ["brotli"] }, + { name = "openapi-pydantic" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "pyperclip" }, @@ -904,6 +917,7 @@ requires-dist = [ { name = "click", specifier = ">=8.1.7,<9.0.0" }, { name = "click-default-group", specifier = ">=1.2.4,<2.0.0" }, { name = "httpx", extras = ["brotli"], specifier = ">=0.27.2,<1.0.0" }, + { name = "openapi-pydantic", specifier = ">=0.5.0" }, { name = "pydantic", specifier = ">=2.9.2,<3.0.0" }, { name = "pydantic-settings", specifier = ">=2.4.0,<3.0.0" }, { name = "pyperclip", specifier = ">=1.9.0,<2.0.0" },