Skip to content
This repository was archived by the owner on Jan 7, 2024. It is now read-only.

Commit

Permalink
Apply black and isort
Browse files Browse the repository at this point in the history
  • Loading branch information
rmol committed Jun 16, 2020
1 parent 24dc1ea commit a31be70
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 135 deletions.
139 changes: 34 additions & 105 deletions sdclientapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,22 @@
import http
import json
import os
import requests
from datetime import datetime
from requests.exceptions import (
ConnectTimeout,
ReadTimeout,
ConnectionError,
TooManyRedirects
)
from subprocess import PIPE, Popen, TimeoutExpired
from typing import List, Tuple, Dict, Optional, Any
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urljoin

import requests
from requests.exceptions import ConnectionError, ConnectTimeout, ReadTimeout, TooManyRedirects

from .sdlocalobjects import (
BaseError,
WrongUUIDError,
AuthError,
BaseError,
Reply,
ReplyError,
Source,
Reply,
Submission,
WrongUUIDError,
)

DEFAULT_PROXY_VM_NAME = "sd-proxy"
Expand Down Expand Up @@ -111,12 +107,8 @@ def __init__(
self.journalist_last_name = None # type: Optional[str]
self.req_headers = dict() # type: Dict[str, str]
self.proxy = proxy # type: bool
self.default_request_timeout = (
default_request_timeout or DEFAULT_REQUEST_TIMEOUT
)
self.default_download_timeout = (
default_download_timeout or DEFAULT_DOWNLOAD_TIMEOUT
)
self.default_request_timeout = default_request_timeout or DEFAULT_REQUEST_TIMEOUT
self.default_download_timeout = default_download_timeout or DEFAULT_DOWNLOAD_TIMEOUT

self.proxy_vm_name = DEFAULT_PROXY_VM_NAME
config = configparser.ConfigParser()
Expand Down Expand Up @@ -250,9 +242,7 @@ def authenticate(self, totp: Optional[str] = None) -> bool:
raise AuthError("Authentication error")

self.token = token_data["token"]
self.token_expiration = datetime.strptime(
token_data["expiration"], "%Y-%m-%dT%H:%M:%S.%fZ"
)
self.token_expiration = datetime.strptime(token_data["expiration"], "%Y-%m-%dT%H:%M:%S.%fZ")
self.token_journalist_uuid = token_data["journalist_uuid"]
self.journalist_first_name = token_data["journalist_first_name"]
self.journalist_last_name = token_data["journalist_last_name"]
Expand All @@ -279,10 +269,7 @@ def get_sources(self) -> List[Source]:
method = "GET"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

sources = data["sources"]
Expand All @@ -305,10 +292,7 @@ def get_source(self, source: Source) -> Source:
method = "GET"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -339,10 +323,7 @@ def delete_source(self, source: Source) -> bool:
method = "DELETE"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -377,10 +358,7 @@ def add_star(self, source: Source) -> bool:
method = "POST"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)
if status_code == 404:
raise WrongUUIDError("Missing source {}".format(source.uuid))
Expand All @@ -400,10 +378,7 @@ def remove_star(self, source: Source) -> bool:
method = "DELETE"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)
if status_code == 404:
raise WrongUUIDError("Missing source {}".format(source.uuid))
Expand All @@ -424,10 +399,7 @@ def get_submissions(self, source: Source) -> List[Submission]:
method = "GET"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -455,10 +427,7 @@ def get_submission(self, submission: Submission) -> Submission:
method = "GET"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -488,10 +457,7 @@ def get_all_submissions(self) -> List[Submission]:
method = "GET"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

result = [] # type: List[Submission]
Expand Down Expand Up @@ -519,10 +485,7 @@ def delete_submission(self, submission: Submission) -> bool:
method = "DELETE"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

if status_code == 404:
Expand Down Expand Up @@ -584,9 +547,7 @@ def download_submission(
# This is where we will save our downloaded file
filepath = os.path.join(path, submission.filename)
with open(filepath, "wb") as fobj:
for chunk in data.iter_content(
chunk_size=1024
): # Getting 1024 in each chunk
for chunk in data.iter_content(chunk_size=1024): # Getting 1024 in each chunk
if chunk:
fobj.write(chunk)

Expand All @@ -595,7 +556,7 @@ def download_submission(
"/home/user/QubesIncoming/", self.proxy_vm_name, data["filename"]
)

return headers['Etag'].strip('\"'), filepath
return headers["Etag"].strip('"'), filepath

def flag_source(self, source: Source) -> bool:
"""
Expand All @@ -608,10 +569,7 @@ def flag_source(self, source: Source) -> bool:
method = "POST"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

if status_code == 404:
Expand All @@ -635,17 +593,12 @@ def get_current_user(self) -> Any:
method = "GET"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

return data

def reply_source(
self, source: Source, msg: str, reply_uuid: Optional[str] = None
) -> Reply:
def reply_source(self, source: Source, msg: str, reply_uuid: Optional[str] = None) -> Reply:
"""
This method is used to reply to a given source. The message should be preencrypted with the
source's GPG public key.
Expand Down Expand Up @@ -685,10 +638,7 @@ def get_replies_from_source(self, source: Source) -> List[Reply]:
method = "GET"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

if status_code == 404:
Expand All @@ -713,10 +663,7 @@ def get_reply_from_source(self, source: Source, reply_uuid: str) -> Reply:
method = "GET"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

if status_code == 404:
Expand All @@ -736,10 +683,7 @@ def get_all_replies(self) -> List[Reply]:
method = "GET"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

result = []
Expand All @@ -760,9 +704,7 @@ def download_reply(self, reply: Reply, path: str = "") -> Tuple[str, str]:
:returns: Tuple of etag and path of the saved Reply.
"""
path_query = "api/v1/sources/{}/replies/{}/download".format(
reply.source_uuid, reply.uuid
)
path_query = "api/v1/sources/{}/replies/{}/download".format(reply.source_uuid, reply.uuid)

method = "GET"

Expand All @@ -771,10 +713,7 @@ def download_reply(self, reply: Reply, path: str = "") -> Tuple[str, str]:
raise BaseError("Please provide a valid directory to save.")

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

if status_code == 404:
Expand All @@ -787,9 +726,7 @@ def download_reply(self, reply: Reply, path: str = "") -> Tuple[str, str]:
# This is where we will save our downloaded file
filepath = os.path.join(path, reply.filename)
with open(filepath, "wb") as fobj:
for chunk in data.iter_content(
chunk_size=1024
): # Getting 1024 in each chunk
for chunk in data.iter_content(chunk_size=1024): # Getting 1024 in each chunk
if chunk:
fobj.write(chunk)

Expand All @@ -798,7 +735,7 @@ def download_reply(self, reply: Reply, path: str = "") -> Tuple[str, str]:
"/home/user/QubesIncoming/", self.proxy_vm_name, data["filename"]
)

return headers['Etag'].strip('\"'), filepath
return headers["Etag"].strip('"'), filepath

def delete_reply(self, reply: Reply) -> bool:
"""
Expand All @@ -810,17 +747,12 @@ def delete_reply(self, reply: Reply) -> bool:
# Not using direct URL because this helps to use the same method
# from local reply (not fetched from server) objects.
# See the *from_string for an example.
path_query = "api/v1/sources/{}/replies/{}".format(
reply.source_uuid, reply.uuid
)
path_query = "api/v1/sources/{}/replies/{}".format(reply.source_uuid, reply.uuid)

method = "DELETE"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

if status_code == 404:
Expand All @@ -839,10 +771,7 @@ def logout(self) -> bool:
method = "POST"

data, status_code, headers = self._send_json_request(
method,
path_query,
headers=self.req_headers,
timeout=self.default_request_timeout,
method, path_query, headers=self.req_headers, timeout=self.default_request_timeout,
)

if "message" in data and data["message"] == "Your token has been revoked.":
Expand Down
1 change: 1 addition & 0 deletions sdclientapi/sdlocalobjects.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

class BaseError(Exception):
"""For generic errors not covered by other exceptions"""

def __init__(self, message: typing.Optional[str] = None) -> None:
self.msg = message

Expand Down
21 changes: 11 additions & 10 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import datetime
import hashlib
import os
import pytest
import shutil
import tempfile
import time
import unittest

import pyotp
import pytest
import vcr
from requests.exceptions import ConnectTimeout, ReadTimeout
from utils import load_auth_for_http
from utils import save_auth_for_http

from sdclientapi import API, RequestTimeoutError
from sdclientapi.sdlocalobjects import AuthError
from sdclientapi.sdlocalobjects import BaseError
from sdclientapi.sdlocalobjects import Reply
from sdclientapi.sdlocalobjects import ReplyError
from sdclientapi.sdlocalobjects import Source
from sdclientapi.sdlocalobjects import Submission
from sdclientapi.sdlocalobjects import WrongUUIDError
from sdclientapi.sdlocalobjects import (
AuthError,
BaseError,
Reply,
ReplyError,
Source,
Submission,
WrongUUIDError,
)
from utils import load_auth_for_http, save_auth_for_http

NUM_REPLIES_PER_SOURCE = 2

Expand Down
Loading

0 comments on commit a31be70

Please sign in to comment.