Issue #83: Use Asyncio.
Nekmo committed Aug 12, 2023
1 parent 86613a4 commit 069402c
Showing 9 changed files with 186 additions and 84 deletions.
89 changes: 89 additions & 0 deletions dirhunt/console.py
@@ -1,3 +1,22 @@
from __future__ import unicode_literals

from typing import Sequence, Tuple, Optional, TypeVar, Union, Coroutine, Any

from prompt_toolkit.application import Application
from prompt_toolkit.formatted_text import AnyFormattedText
from prompt_toolkit.key_binding import KeyPressEvent
from prompt_toolkit.key_binding.defaults import load_key_bindings
from prompt_toolkit.key_binding.key_bindings import KeyBindings, merge_key_bindings
from prompt_toolkit.layout import Layout
from prompt_toolkit.styles import BaseStyle
from prompt_toolkit.widgets import RadioList, Label
from prompt_toolkit.layout.containers import HSplit


_T = TypeVar("_T")
E = KeyPressEvent


def status_code_colors(status_code):
"""Return a color for a status code."""
if 100 <= status_code < 200:
@@ -16,3 +35,73 @@ def status_code_colors(status_code):
return "magenta1"
else:
return "medium_orchid1"


def radiolist_prompt(
title: str = "",
values: Sequence[Tuple[_T, AnyFormattedText]] = None,
default: Optional[_T] = None,
cancel_value: Optional[_T] = None,
style: Optional[BaseStyle] = None,
async_: bool = False,
) -> Union[_T, Coroutine[Any, Any, _T]]:
"""Create a mini inline application for a radiolist prompt.
:param title: The title to display above the radiolist.
:param values: A sequence of tuples of the form (value, formatted_text).
:param default: The default value to select.
:param cancel_value: The value to return if the user presses Ctrl-C.
:param style: A style to apply to the radiolist.
:param async_: If True, run the prompt in async mode.
:return: The value selected by the user.
"""
# Create the radio list
radio_list = RadioList(values, default)
# Remove the enter key binding so that we can augment it
radio_list.control.key_bindings.remove("up")
radio_list.control.key_bindings.remove("down")
radio_list.control.key_bindings.remove("enter")

bindings = KeyBindings()

@bindings.add("up")
def up(_) -> None:
radio_list._selected_index = max(0, radio_list._selected_index - 1)
radio_list._handle_enter()

@bindings.add("down")
def down(_) -> None:
radio_list._selected_index = min(
len(radio_list.values) - 1, radio_list._selected_index + 1
)
radio_list._handle_enter()

# Replace the enter key binding to select the value and also submit it
@bindings.add("enter")
def exit_with_value(event: E):
"""
Pressing Enter will exit the user interface, returning the highlighted value.
"""
radio_list._handle_enter()
event.app.exit(result=radio_list.current_value)

@bindings.add("c-c")
def backup_exit_with_value(event: E):
"""
Pressing Ctrl-C will exit the user interface with the cancel_value.
"""
event.app.exit(result=cancel_value)

merged_key_bindings = merge_key_bindings([load_key_bindings(), bindings])
# Create and run the mini inline application
application = Application(
layout=Layout(HSplit([Label(title), radio_list])),
key_bindings=merged_key_bindings,
mouse_support=True,
style=style,
full_screen=False,
)
if async_:
return application.run_async()
else:
return application.run()
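
For context, here is a minimal usage sketch of the new radiolist_prompt helper; the title, values and defaults below are illustrative and not part of the commit:

from dirhunt.console import radiolist_prompt

# Synchronous use: blocks until the user presses Enter, or Ctrl-C
# (which returns cancel_value).
choice = radiolist_prompt(
    title="Interesting file found. What do you want to do?",
    values=[("download", "Download the file"), ("skip", "Skip it")],
    default="skip",
    cancel_value="skip",
)

# Inside a coroutine, async_=True makes the helper return an awaitable
# instead of running the application directly:
# choice = await radiolist_prompt(..., async_=True)
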
101 changes: 43 additions & 58 deletions dirhunt/crawler.py
@@ -8,20 +8,17 @@
from collections import defaultdict
from concurrent.futures.thread import _python_exit
from hashlib import sha256
from threading import Lock, ThreadError
from typing import Optional, Set, Coroutine, Any, Dict
from typing import Optional, Set, Coroutine, Any, Dict, Literal

import humanize as humanize
from click import get_terminal_size
from rich.console import Console
from rich.text import Text
from rich.progress import (
Progress,
TaskProgressColumn,
TimeRemainingColumn,
BarColumn,
TextColumn,
)
from rich.text import Text

from dirhunt import __version__
from dirhunt._compat import queue, Queue, unregister
@@ -40,7 +37,12 @@
"""Flags importance"""


# Some URLs can take a long time to process, so we increase the number of urls
# to process in the queue to finish as soon as possible.
DIRHUNT_OVERFLOW_MULTIPLIER = int(os.environ.get("DIRHUNT_OVERFLOW_MULTIPLIER", 3))

resume_dir = os.path.expanduser("~/.cache/dirhunt/")
steps = Literal[None, "crawling_urls", "finishing_crawler_urls"]


class DomainSemaphore:
Expand All @@ -64,7 +66,7 @@ def release(self, domain: str):

class Crawler:
urls_info = None
started = False
step: steps = None

def __init__(self, configuration: Configuration, loop: asyncio.AbstractEventLoop):
"""Initialize Crawler.
@@ -81,11 +83,9 @@ def __init__(self, configuration: Configuration, loop: asyncio.AbstractEventLoop
self.domain_semaphore = DomainSemaphore(configuration.concurrency)
self.results = Queue()
self.index_of_processors = []
self.processed = {}
self.add_lock = Lock()
self.start_dt = datetime.datetime.now()
self.total_crawler_urls: int = 0
self.current_processed_count: int = 0
self.total_crawler_urls: int = 0
self.sources = Sources(self)
self.domain_protocols: Dict[str, set] = defaultdict(set)
self.progress = Progress(
@@ -100,15 +100,15 @@

async def start(self):
"""Add urls to process."""
if self.started:
if self.step == "crawling_urls":
await self.restart()
return
for url in self.configuration.urls:
crawler_url = CrawlerUrl(self, url, depth=self.configuration.max_depth)
await self.add_domain(crawler_url.url.domain)
await self.add_crawler_url(crawler_url)
self.add_domain_protocol(crawler_url)
self.started = True
self.step = "crawling_urls"
await self.run_tasks()

async def run_tasks(self) -> None:
@@ -124,7 +124,7 @@ async def restart(self):
async def add_crawler_url(self, crawler_url: CrawlerUrl) -> Optional[asyncio.Task]:
"""Add crawler_url to tasks"""
if (
self.total_crawler_urls > self.configuration.limit
self.hard_limit_reached
or crawler_url in self.crawler_urls
or not self.in_domains(crawler_url.url.domain)
):
@@ -171,15 +171,20 @@ def print_processor(self, processor: ProcessBase):
"""Print processor to console."""
if 300 > processor.status >= 200:
self.add_domain_protocol(processor.crawler_url)
if self.step == "finishing_crawler_urls":
return
if self.soft_limit_reached:
self.step = "finishing_crawler_urls"
self.cancel_tasks()
self.console.print(processor.get_text())
self.progress.update(
self.progress_task,
description=f"Obtained [bold blue]{self.current_processed_count}[/bold blue] urls out of "
f"[bold blue]{self.total_crawler_urls}[/bold blue]",
f"[bold blue]{min(self.total_crawler_urls, self.configuration.limit)}[/bold blue]",
completed=self.current_processed_count,
refresh=True,
total=self.configuration.limit
if self.total_crawler_urls > self.configuration.limit
if self.total_crawler_urls >= self.configuration.limit
else None,
)

@@ -192,16 +197,26 @@ def pending_crawler_urls(self):
"""Return pending crawler_urls without finished."""
return filter(lambda x: not x.finished, self.crawler_urls)

def add_init_urls(self, *urls):
"""Add urls to queue."""
self.initial_urls.extend(urls)
for crawler_url in urls:
if not isinstance(crawler_url, CrawlerUrl):
crawler_url = CrawlerUrl(
self, crawler_url, depth=self.depth, timeout=self.timeout
)
self.add_domain(crawler_url.url.only_domain)
self.add_url(crawler_url, lock=False)
@property
def soft_limit_reached(self) -> bool:
"""Return True if the soft limit is reached.
This limit is reached once the number of urls already processed hits the configured limit.
"""
return self.current_processed_count >= self.configuration.limit

@property
def hard_limit_reached(self) -> bool:
"""Return True if the hard limit is reached.
This limit is used to add urls to the queue. Some URLs can take a long time to
process, so we increase the number of urls to process in the queue to finish
as soon as possible.
"""
return (
self.total_crawler_urls
>= self.configuration.limit * DIRHUNT_OVERFLOW_MULTIPLIER
)

def add_task(
self, coro: Coroutine[Any, Any, Any], name: Optional[str] = None
@@ -211,40 +226,10 @@ def add_task(
task.add_done_callback(self.tasks.discard)
return task

def add_message(self, body):
from dirhunt.processors import Message

self.results.put(Message(body))

def echo(self, body):
if self.std is None:
return
# TODO: remove ANSI chars on is not tty
self.std.write(str(body))
self.std.write("\n")

def erase(self):
if self.std is None or not self.std.isatty():
return
CURSOR_UP_ONE = "\x1b[1A"
ERASE_LINE = "\x1b[2K"
# This can be improved. In the future we may want to define stdout/stderr with an cli option
# fn = sys.stderr.write if sys.stderr.isatty() else sys.stdout.write
self.std.write(CURSOR_UP_ONE + ERASE_LINE)

def print_progress(self, finished=False):
if not self.progress_enabled:
# Cancel print progress on
return
self.echo(
"{} {} {}".format(
next(self.spinner),
"Finished after" if finished else "Started",
(humanize.naturaldelta if finished else humanize.naturaltime)(
datetime.datetime.now() - self.start_dt
),
)
)
def cancel_tasks(self):
"""Cancel all tasks."""
for task in self.tasks:
task.cancel()

def print_results(self, exclude=None, include=None):
exclude = exclude or set()
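
To summarize the new limit handling in this file: the soft limit stops printing results and cancels pending tasks once the configured number of urls has been processed, while the hard limit only stops new urls from being queued and can be tuned through the DIRHUNT_OVERFLOW_MULTIPLIER environment variable. A small illustrative sketch with made-up numbers:

import os

limit = 1000                     # Configuration.limit (illustrative value)
current_processed_count = 1000   # responses fully processed so far
total_crawler_urls = 2400        # urls ever added to the queue

DIRHUNT_OVERFLOW_MULTIPLIER = int(os.environ.get("DIRHUNT_OVERFLOW_MULTIPLIER", 3))

soft_limit_reached = current_processed_count >= limit
hard_limit_reached = total_crawler_urls >= limit * DIRHUNT_OVERFLOW_MULTIPLIER
print(soft_limit_reached, hard_limit_reached)  # True False
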
21 changes: 10 additions & 11 deletions dirhunt/crawler_url.py
@@ -59,7 +59,7 @@ async def retrieve(self, retries: Optional[int] = None) -> Optional["ProcessBase
get_processor,
)

processor = None
released_lock = False
try:
await self.crawler.domain_semaphore.acquire(self.crawler_url.url.domain)
async with self.crawler.session.get(
Expand All @@ -76,18 +76,17 @@ async def retrieve(self, retries: Optional[int] = None) -> Optional["ProcessBase
self.content = await get_content(response)
if processor.has_descendants:
processor = get_processor(self)
except (ClientError, asyncio.TimeoutError) as e:
except (ClientError, asyncio.TimeoutError):
if retries and retries > 0:
self.crawler.domain_semaphore.release(self.crawler_url.url.domain)
released_lock = True
await asyncio.sleep(RETRIES_WAIT)
return await self.retrieve(retries - 1)
else:
self.crawler.print_error(
f"Request error to {self.crawler_url.url}: {get_message_from_exception(e)}"
)
else:
await processor.process(self)
finally:
self.crawler.domain_semaphore.release(self.crawler_url.url.domain)
if not released_lock:
self.crawler.domain_semaphore.release(self.crawler_url.url.domain)
return processor

@property
@@ -157,14 +156,14 @@ async def retrieve(self):
from processors import GenericProcessor

crawler_url_request = CrawlerUrlRequest(self)
processor = await crawler_url_request.retrieve()
self.processor = await crawler_url_request.retrieve()
self.crawler.current_processed_count += 1
if (
processor is not None
and not isinstance(processor, GenericProcessor)
self.processor is not None
and not isinstance(self.processor, GenericProcessor)
and self.url_type not in {"asset", "index_file"}
):
self.crawler.print_processor(processor)
self.crawler.print_processor(self.processor)
# if self.must_be_downloaded(response):
# processor = get_processor(response, text, self, soup) or GenericProcessor(
# response, self
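
The released_lock bookkeeping in retrieve() above follows a pattern worth spelling out: release the per-domain semaphore before the retry back-off so other requests to the same domain are not starved, and make sure the finally block does not release it a second time. A self-contained sketch of that pattern, assuming a plain asyncio.Semaphore and an arbitrary async do_request callable (neither is dirhunt API):

import asyncio

RETRIES_WAIT = 2  # seconds; illustrative, dirhunt defines its own constant

async def fetch_with_retry(semaphore: asyncio.Semaphore, do_request, retries: int = 2):
    """Run do_request under a semaphore, retrying on timeout without holding the slot."""
    released = False
    await semaphore.acquire()
    try:
        return await do_request()
    except asyncio.TimeoutError:
        if retries > 0:
            # Free the slot during the back-off; the recursive call
            # re-acquires the semaphore itself.
            semaphore.release()
            released = True
            await asyncio.sleep(RETRIES_WAIT)
            return await fetch_with_retry(semaphore, do_request, retries - 1)
        raise
    finally:
        if not released:
            semaphore.release()
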
