Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Ruff formatting & linting #169

Merged
merged 8 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/publish-pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ jobs:
version: "0.5.21"
- name: Pytest
run: uv run pytest
- name: Ruff check
run: uv run ruff check
- name: Ruff format
run: uv run ruff format --check

build-n-publish:
name: Build and publish Python distributions to PyPI
Expand Down
11 changes: 5 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,15 @@ If no arguments are supplied pytr will look for them in the file `~/.pytr/creden

## Linting and Code Formatting

This project uses [black](https://github.com/psf/black) for code linting and auto-formatting. You can auto-format the code by running:
This project uses [Ruff](https://astral.sh/ruff) for code linting and auto-formatting. You can auto-format the code by running:

```bash
# Install black if not already installed
pip install black

# Auto-format code
black ./pytr
uv run ruff format # Format code
uv run ruff check --fix-only # Remove unneeded imports, order imports, etc.
```

Ruff runs as part of CI and your Pull Request cannot be merged unless it satisfies the linting and formatting checks.

## Setting Up a Development Environment

1. Clone the repository:
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,12 @@ locale_dir = "pytr/locale"

[dependency-groups]
dev = [
"ruff>=0.9.4",
"pytest>=8.3.4",
]

[tool.ruff]
line-length = 120

[tool.ruff.lint]
extend-select = ["I"]
19 changes: 7 additions & 12 deletions pytr/account.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import json
import sys
from pygments import highlight, lexers, formatters
import time
from getpass import getpass

from pytr.api import TradeRepublicApi, CREDENTIALS_FILE
from pygments import formatters, highlight, lexers

from pytr.api import CREDENTIALS_FILE, TradeRepublicApi
from pytr.utils import get_logger


def get_settings(tr):
formatted_json = json.dumps(tr.settings(), indent=2)
if sys.stdout.isatty():
colorful_json = highlight(
formatted_json, lexers.JsonLexer(), formatters.TerminalFormatter()
)
colorful_json = highlight(formatted_json, lexers.JsonLexer(), formatters.TerminalFormatter())
return colorful_json
else:
return formatted_json
Expand Down Expand Up @@ -41,9 +40,7 @@ def login(phone_no=None, pin=None, web=True, store_credentials=False):
CREDENTIALS_FILE.parent.mkdir(parents=True, exist_ok=True)
if phone_no is None:
log.info("Credentials file not found")
print(
"Please enter your TradeRepublic phone number in the format +4912345678:"
)
print("Please enter your TradeRepublic phone number in the format +4912345678:")
phone_no = input()
else:
log.info("Phone number provided as argument")
Expand Down Expand Up @@ -74,15 +71,13 @@ def login(phone_no=None, pin=None, web=True, store_credentials=False):
exit(1)
request_time = time.time()
print("Enter the code you received to your mobile app as a notification.")
print(
f"Enter nothing if you want to receive the (same) code as SMS. (Countdown: {countdown})"
)
print(f"Enter nothing if you want to receive the (same) code as SMS. (Countdown: {countdown})")
code = input("Code: ")
if code == "":
countdown = countdown - (time.time() - request_time)
for remaining in range(int(countdown)):
print(
f"Need to wait {int(countdown-remaining)} seconds before requesting SMS...",
f"Need to wait {int(countdown - remaining)} seconds before requesting SMS...",
end="\r",
)
time.sleep(1)
Expand Down
20 changes: 5 additions & 15 deletions pytr/alarms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
from datetime import datetime

from pytr.utils import preview, get_logger
from pytr.utils import get_logger, preview


class Alarms:
Expand All @@ -19,9 +19,7 @@ async def alarms_loop(self):
recv += 1
self.alarms = response
else:
print(
f"unmatched subscription of type '{subscription['type']}':\n{preview(response)}"
)
print(f"unmatched subscription of type '{subscription['type']}':\n{preview(response)}")

if recv == 1:
return
Expand All @@ -36,9 +34,7 @@ async def ticker_loop(self):
recv += 1
self.alarms = response
else:
print(
f"unmatched subscription of type '{subscription['type']}':\n{preview(response)}"
)
print(f"unmatched subscription of type '{subscription['type']}':\n{preview(response)}")

if recv == 1:
return
Expand All @@ -47,11 +43,7 @@ def overview(self):
print("ISIN status created target diff% createdAt triggeredAT")
self.log.debug(f"Processing {len(self.alarms)} alarms")

for (
a
) in (
self.alarms
): # sorted(positions, key=lambda x: x['netValue'], reverse=True):
for a in self.alarms: # sorted(positions, key=lambda x: x['netValue'], reverse=True):
self.log.debug(f" Processing {a} alarm")
ts = int(a["createdAt"]) / 1000.0
target_price = float(a["targetPrice"])
Expand All @@ -61,9 +53,7 @@ def overview(self):
triggered = "-"
else:
ts = int(a["triggeredAt"]) / 1000.0
triggered = datetime.fromtimestamp(ts).isoformat(
sep=" ", timespec="minutes"
)
triggered = datetime.fromtimestamp(ts).isoformat(sep=" ", timespec="minutes")

if a["createdPrice"] == 0:
diffP = 0.0
Expand Down
91 changes: 23 additions & 68 deletions pytr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,20 @@
import hashlib
import json
import pathlib
import ssl
import time
import urllib.parse
import uuid
from http.cookiejar import MozillaCookieJar

import certifi
import ssl
import requests
import websockets
from ecdsa import NIST256p, SigningKey
from ecdsa.util import sigencode_der
from http.cookiejar import MozillaCookieJar

from pytr.utils import get_logger


home = pathlib.Path.home()
BASE_DIR = home / ".pytr"
CREDENTIALS_FILE = BASE_DIR / "credentials"
Expand Down Expand Up @@ -96,9 +96,7 @@ def __init__(
self._locale = locale
self._save_cookies = save_cookies

self._credentials_file = (
pathlib.Path(credentials_file) if credentials_file else CREDENTIALS_FILE
)
self._credentials_file = pathlib.Path(credentials_file) if credentials_file else CREDENTIALS_FILE

if not (phone_no and pin):
try:
Expand All @@ -107,18 +105,12 @@ def __init__(
self.phone_no = lines[0].strip()
self.pin = lines[1].strip()
except FileNotFoundError:
raise ValueError(
f"phone_no and pin must be specified explicitly or via {self._credentials_file}"
)
raise ValueError(f"phone_no and pin must be specified explicitly or via {self._credentials_file}")
else:
self.phone_no = phone_no
self.pin = pin

self._cookies_file = (
pathlib.Path(cookies_file)
if cookies_file
else BASE_DIR / f"cookies.{self.phone_no}.txt"
)
self._cookies_file = pathlib.Path(cookies_file) if cookies_file else BASE_DIR / f"cookies.{self.phone_no}.txt"

self.keyfile = keyfile if keyfile else KEY_FILE
try:
Expand Down Expand Up @@ -231,9 +223,7 @@ def complete_weblogin(self, verify_code):
if not self._process_id and not self._websession:
raise ValueError("Initiate web login first.")

r = self._websession.post(
f"{self._host}/api/v1/auth/web/login/{self._process_id}/{verify_code}"
)
r = self._websession.post(f"{self._host}/api/v1/auth/web/login/{self._process_id}/{verify_code}")
r.raise_for_status()
self.save_websession()
self._weblogin = True
Expand Down Expand Up @@ -270,9 +260,7 @@ def _web_request(self, url_path, payload=None, method="GET"):
r = self._websession.get(f"{self._host}/api/v1/auth/web/session")
r.raise_for_status()
self._web_session_token_expires_at = time.time() + 290
return self._websession.request(
method=method, url=f"{self._host}{url_path}", data=payload
)
return self._websession.request(method=method, url=f"{self._host}{url_path}", data=payload)

async def _get_ws(self):
if self._ws and self._ws.open:
Expand Down Expand Up @@ -301,9 +289,7 @@ async def _get_ws(self):
}
connect_id = 31

self._ws = await websockets.connect(
"wss://api.traderepublic.com", ssl=ssl_context, extra_headers=extra_headers
)
self._ws = await websockets.connect("wss://api.traderepublic.com", ssl=ssl_context, extra_headers=extra_headers)
await self._ws.send(f"connect {connect_id} {json.dumps(connection_message)}")
response = await self._ws.recv()

Expand Down Expand Up @@ -354,9 +340,7 @@ async def recv(self):

if subscription_id not in self.subscriptions:
if code != "C":
self.log.debug(
f"No active subscription for id {subscription_id}, dropping message"
)
self.log.debug(f"No active subscription for id {subscription_id}, dropping message")
continue
subscription = self.subscriptions[subscription_id]

Expand Down Expand Up @@ -408,16 +392,12 @@ async def _receive_one(self, fut, timeout):
subscription_id = await fut

try:
return await asyncio.wait_for(
self._recv_subscription(subscription_id), timeout
)
return await asyncio.wait_for(self._recv_subscription(subscription_id), timeout)
finally:
await self.unsubscribe(subscription_id)

def run_blocking(self, fut, timeout=5.0):
return asyncio.get_event_loop().run_until_complete(
self._receive_one(fut, timeout=timeout)
)
return asyncio.get_event_loop().run_until_complete(self._receive_one(fut, timeout=timeout))

async def portfolio(self):
return await self.subscribe({"type": "portfolio"})
Expand All @@ -437,21 +417,14 @@ async def cash(self):
async def available_cash_for_payout(self):
return await self.subscribe({"type": "availableCashForPayout"})

async def portfolio_status(self):
return await self.subscribe({"type": "portfolioStatus"})

async def portfolio_history(self, timeframe):
return await self.subscribe(
{"type": "portfolioAggregateHistory", "range": timeframe}
)
return await self.subscribe({"type": "portfolioAggregateHistory", "range": timeframe})

async def instrument_details(self, isin):
return await self.subscribe({"type": "instrument", "id": isin})

async def instrument_suitability(self, isin):
return await self.subscribe(
{"type": "instrumentSuitability", "instrumentId": isin}
)
return await self.subscribe({"type": "instrumentSuitability", "instrumentId": isin})

async def stock_details(self, isin):
return await self.subscribe({"type": "stockDetails", "id": isin})
Expand All @@ -460,19 +433,15 @@ async def add_watchlist(self, isin):
return await self.subscribe({"type": "addToWatchlist", "instrumentId": isin})

async def remove_watchlist(self, isin):
return await self.subscribe(
{"type": "removeFromWatchlist", "instrumentId": isin}
)
return await self.subscribe({"type": "removeFromWatchlist", "instrumentId": isin})

async def ticker(self, isin, exchange="LSX"):
return await self.subscribe({"type": "ticker", "id": f"{isin}.{exchange}"})

async def performance(self, isin, exchange="LSX"):
return await self.subscribe({"type": "performance", "id": f"{isin}.{exchange}"})

async def performance_history(
self, isin, timeframe, exchange="LSX", resolution=None
):
async def performance_history(self, isin, timeframe, exchange="LSX", resolution=None):
parameters = {
"type": "aggregateHistory",
"id": f"{isin}.{exchange}",
Expand Down Expand Up @@ -501,9 +470,7 @@ async def timeline_detail_order(self, order_id):
return await self.subscribe({"type": "timelineDetail", "orderId": order_id})

async def timeline_detail_savings_plan(self, savings_plan_id):
return await self.subscribe(
{"type": "timelineDetail", "savingsPlanId": savings_plan_id}
)
return await self.subscribe({"type": "timelineDetail", "savingsPlanId": savings_plan_id})

async def timeline_transactions(self, after=None):
return await self.subscribe({"type": "timelineTransactions", "after": after})
Expand All @@ -518,9 +485,7 @@ async def search_tags(self):
return await self.subscribe({"type": "neonSearchTags"})

async def search_suggested_tags(self, query):
return await self.subscribe(
{"type": "neonSearchSuggestedTags", "data": {"q": query}}
)
return await self.subscribe({"type": "neonSearchSuggestedTags", "data": {"q": query}})

async def search(
self,
Expand All @@ -546,17 +511,11 @@ async def search(
if filter_index:
search_parameters["filter"].append({"key": "index", "value": filter_index})
if filter_country:
search_parameters["filter"].append(
{"key": "country", "value": filter_country}
)
search_parameters["filter"].append({"key": "country", "value": filter_country})
if filter_region:
search_parameters["filter"].append(
{"key": "region", "value": filter_region}
)
search_parameters["filter"].append({"key": "region", "value": filter_region})
if filter_sector:
search_parameters["filter"].append(
{"key": "sector", "value": filter_sector}
)
search_parameters["filter"].append({"key": "sector", "value": filter_sector})

search_type = "neonSearch" if not aggregate else "neonSearchAggregations"
return await self.subscribe({"type": search_type, "data": search_parameters})
Expand Down Expand Up @@ -750,17 +709,13 @@ async def change_savings_plan(
return await self.subscribe(parameters)

async def cancel_savings_plan(self, savings_plan_id):
return await self.subscribe(
{"type": "cancelSavingsPlan", "id": savings_plan_id}
)
return await self.subscribe({"type": "cancelSavingsPlan", "id": savings_plan_id})

async def price_alarm_overview(self):
return await self.subscribe({"type": "priceAlarms"})

async def create_price_alarm(self, isin, price):
return await self.subscribe(
{"type": "createPriceAlarm", "instrumentId": isin, "targetPrice": price}
)
return await self.subscribe({"type": "createPriceAlarm", "instrumentId": isin, "targetPrice": price})

async def cancel_price_alarm(self, price_alarm_id):
return await self.subscribe({"type": "cancelPriceAlarm", "id": price_alarm_id})
Expand Down
Loading
Loading