Issue #83: Use Asyncio
Nekmo committed Aug 10, 2023
1 parent b5a2d14 commit d742ffe
Showing 13 changed files with 200 additions and 101 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.9.0
current_version = 2.0.0
commit = True
tag = True

2 changes: 1 addition & 1 deletion dirhunt/__init__.py
@@ -1,3 +1,3 @@
# -*- coding: utf-8 -*-

__version__ = "0.9.0"
__version__ = "2.0.0"
59 changes: 36 additions & 23 deletions dirhunt/crawler.py
@@ -1,41 +1,37 @@
# -*- coding: utf-8 -*-
import asyncio
import datetime
import json
import multiprocessing
import os
from asyncio import Semaphore
from hashlib import sha256
from asyncio import Semaphore, Task
from concurrent.futures.thread import _python_exit
from hashlib import sha256
from threading import Lock, ThreadError
import datetime
from typing import Optional
from typing import Optional, Set, Coroutine, Any

import humanize as humanize
from click import get_terminal_size
from rich.console import Console
from rich.text import Text
from rich.traceback import install

from dirhunt import processors
from dirhunt import __version__
from dirhunt._compat import queue, Queue, unregister
from dirhunt.cli import random_spinner
from dirhunt.configuration import Configuration
from dirhunt.crawler_url import CrawlerUrl
from dirhunt.exceptions import (
EmptyError,
RequestError,
reraise_with_stack,
IncompatibleVersionError,
)
from dirhunt.json_report import JsonReportEncoder
from dirhunt.sessions import Sessions, Session
from dirhunt.sessions import Session
from dirhunt.sources import Sources
from dirhunt.url import Url
from dirhunt.url_info import UrlsInfo

"""Flags importance"""


resume_dir = os.path.expanduser("~/.cache/dirhunt/")
install(show_locals=True)


class DomainSemaphore:
@@ -67,9 +63,9 @@ def __init__(self, configuration: Configuration, loop: asyncio.AbstractEventLoop
"""
self.configuration = configuration
self.loop = loop
self.tasks = set()
self.crawler_urls = set()
self.domains = set()
self.tasks: Set[Task] = set()
self.crawler_urls: Set[CrawlerUrl] = set()
self.domains: Set[str] = set()
self.console = Console(highlight=False)
self.session = Session()
self.domain_semaphore = DomainSemaphore(configuration.concurrency)
@@ -78,13 +74,14 @@ def __init__(self, configuration: Configuration, loop: asyncio.AbstractEventLoop
self.processed = {}
self.add_lock = Lock()
self.start_dt = datetime.datetime.now()
self.current_processed_count = 0
self.current_processed_count: int = 0
self.sources = Sources(self)

async def start(self):
"""Add urls to process."""
for url in self.configuration.urls:
crawler_url = CrawlerUrl(self, url, depth=self.configuration.max_depth)
self.domains.add(crawler_url.url.domain)
await self.add_domain(crawler_url.url.domain)
await self.add_crawler_url(crawler_url)

while self.tasks:
@@ -98,11 +95,18 @@ async def add_crawler_url(self, crawler_url: CrawlerUrl) -> Optional[asyncio.Tas
or crawler_url.url.domain not in self.domains
):
return
self.current_processed_count += 1
self.crawler_urls.add(crawler_url)
task = self.loop.create_task(crawler_url.retrieve())
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
return task
return self.add_task(
crawler_url.retrieve(), name=f"crawlerurl-{self.current_processed_count}"
)

def print_error(self, message: str):
"""Print error message to console."""
text = Text()
text.append("[ERROR] ", style="red")
text.append(message)
self.console.print(text)

def add_init_urls(self, *urls):
"""Add urls to queue."""
@@ -130,11 +134,20 @@ def in_domains(self, domain):
return False
domain = ".".join(parts[1:])

def add_domain(self, domain):
async def add_domain(self, domain: str):
"""Add domain to domains."""
if domain in self.domains:
return
self.domains.add(domain)
self.sources.add_domain(domain)
await self.sources.add_domain(domain)

def add_task(
self, coro: Coroutine[Any, Any, Any], name: Optional[str] = None
) -> Task:
task = self.loop.create_task(coro, name=name)
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
return task

def add_message(self, body):
from dirhunt.processors import Message
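
The rewritten Crawler keeps a strong reference to every scheduled coroutine in self.tasks, drops it via add_done_callback once the task finishes, and start() keeps working while the set is non-empty. The body of that loop is collapsed in this hunk, so the drain strategy below (gather a snapshot of the set each round) is an assumption; TaskTracker, work and the crawlerurl-N names are illustrative only, not Dirhunt code.

import asyncio


class TaskTracker:
    """Illustrative stand-in for the task bookkeeping in Crawler.add_task."""

    def __init__(self):
        self.tasks = set()  # strong references keep tasks from being GC'd mid-flight

    def add_task(self, coro, name=None):
        # Mirror of Crawler.add_task: schedule the coroutine, remember the task,
        # and forget it automatically once it completes.
        task = asyncio.get_running_loop().create_task(coro, name=name)
        self.tasks.add(task)
        task.add_done_callback(self.tasks.discard)
        return task

    async def drain(self):
        # Assumption: finished tasks may schedule new ones, so keep gathering
        # snapshots until the set empties (the "while self.tasks" loop in start()).
        while self.tasks:
            await asyncio.gather(*list(self.tasks))


async def work(n):
    await asyncio.sleep(0.01)
    return n


async def main():
    tracker = TaskTracker()
    for n in range(3):
        tracker.add_task(work(n), name=f"crawlerurl-{n}")
    await tracker.drain()


asyncio.run(main())
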
65 changes: 41 additions & 24 deletions dirhunt/sources/__init__.py
@@ -1,3 +1,8 @@
from functools import cached_property
from typing import List, Type, Union

from typing_extensions import TYPE_CHECKING

from dirhunt.sources.commoncrawl import CommonCrawl
from dirhunt.sources.crtsh import CrtSh
from dirhunt.sources.google import Google
@@ -6,37 +11,49 @@
from dirhunt.sources.virustotal import VirusTotal
from dirhunt.sources.wayback import Wayback

SOURCE_CLASSES = [
Robots,
VirusTotal,
Google,

if TYPE_CHECKING:
from dirhunt.sources.base import SourceBase


SOURCE_CLASSES: List[Type["SourceBase"]] = [
# Robots,
# VirusTotal,
# Google,
CommonCrawl,
CrtSh,
CertificateSSL,
Wayback,
# CrtSh,
# CertificateSSL,
# Wayback,
]


def get_source_name(cls):
def get_source_name(cls: Type["SourceBase"]):
return cls.__name__.lower()


class Sources(object):
def __init__(self, callback, error_callback, excluded_sources=()):
self.callback = callback
self.error_callback = error_callback
self.sources = [
cls(self.callback, error_callback)
if TYPE_CHECKING:
from dirhunt.crawler import Crawler


class Sources:
"""Sources class. This class is used to manage the sources."""

def __init__(self, crawler: "Crawler"):
self.crawler = crawler

@cached_property
def source_classes(self) -> List[Type["SourceBase"]]:
"""Return source classes."""
return [
cls
for cls in SOURCE_CLASSES
if get_source_name(cls) not in excluded_sources
if cls not in self.crawler.configuration.exclude_sources
]

def add_domain(self, domain):
for source in self.sources:
source.add_domain(domain)

def finished(self):
for source in self.sources:
if source.is_running():
return False
return True
async def add_domain(self, domain: str):
"""Add domain to sources."""
for source_cls in self.source_classes:
source = source_cls(self, domain)
self.crawler.add_task(
source.retrieve_urls(domain), f"{source.get_source_name()}-{domain}"
)
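
Sources.add_domain now fans out one asyncio task per enabled source class instead of submitting work to a thread pool, naming each task "<source>-<domain>". A minimal sketch of that fan-out under stated assumptions: FakeSource is a hypothetical stand-in for a SourceBase subclass such as CommonCrawl, and the task-set bookkeeping mirrors the crawler sketch above.

import asyncio


class FakeSource:
    """Hypothetical stand-in for a SourceBase subclass such as CommonCrawl."""

    def __init__(self, domain):
        self.domain = domain

    async def retrieve_urls(self, domain):
        await asyncio.sleep(0.01)  # pretend to query an external index
        return {f"https://{domain}/"}


SOURCE_CLASSES = [FakeSource]


async def add_domain(tasks, domain):
    # One task per enabled source class, named "<source>-<domain>" as in the diff.
    for source_cls in SOURCE_CLASSES:
        source = source_cls(domain)
        task = asyncio.create_task(
            source.retrieve_urls(domain),
            name=f"{source_cls.__name__.lower()}-{domain}",
        )
        tasks.add(task)
        task.add_done_callback(tasks.discard)


async def main():
    tasks = set()
    await add_domain(tasks, "example.com")
    while tasks:
        await asyncio.gather(*list(tasks))


asyncio.run(main())
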
92 changes: 80 additions & 12 deletions dirhunt/sources/base.py
@@ -1,21 +1,89 @@
from dirhunt.pool import Pool
import datetime
import json
import os
from functools import cached_property
from pathlib import Path
from typing import List, Iterable, Optional

import aiofiles
from aiohttp import ClientError
from platformdirs import user_cache_dir
from typing_extensions import TYPE_CHECKING

class Source(Pool):
def __init__(self, result_callback, error_callback, max_workers=None):
super(Source, self).__init__(max_workers)
self.result_callback = result_callback
self.error_callback = error_callback
from dirhunt import __version__
from dirhunt.crawler_url import CrawlerUrl

def add_domain(self, domain):
self.submit(domain)
if TYPE_CHECKING:
from dirhunt.sources import Sources

def callback(self, domain):

class SourceBase:
max_cache_age = datetime.timedelta(days=7)

def __init__(self, sources: "Sources", domain: str):
self.sources = sources
self.domain = domain

@classmethod
def get_source_name(cls) -> str:
return cls.__name__.lower()

@property
def cache_dir(self) -> Path:
return Path(user_cache_dir()) / "dirhunt" / self.get_source_name()

@property
def cache_file(self) -> Path:
return self.cache_dir / f"{self.domain}.json"

@cached_property
def is_cache_expired(self) -> bool:
return (
not self.cache_file.exists()
or self.cache_file.stat().st_mtime
< (datetime.datetime.now() - self.max_cache_age).timestamp()
)

def get_from_cache(self) -> Optional[List[str]]:
with open(self.cache_file) as file:
data = json.load(file)
if data["version"] != __version__:
return None
return data["urls"]

async def search_by_domain(self, domain: str) -> Iterable[str]:
raise NotImplementedError

def add_result(self, url):
if self.result_callback:
self.result_callback(url)
async def retrieve_urls(self, domain: str):
urls = None
if not self.is_cache_expired:
urls = self.get_from_cache()
if urls is None:
try:
urls = await self.search_by_domain(domain)
except ClientError as e:
self.sources.crawler.print_error(str(e))
urls = []
else:
self.save_to_cache(urls)
for url in urls:
await self.add_url(url)

def save_to_cache(self, urls: Iterable[str]) -> None:
cache_data = {
"version": __version__,
"domain": self.domain,
"urls": list(urls),
}
os.makedirs(str(self.cache_file.parent), exist_ok=True)
with open(self.cache_file, "w") as file:
json.dump(cache_data, file)

async def add_url(self, url: str):
"""Add url to crawler."""
await self.sources.crawler.add_crawler_url(
CrawlerUrl(self.sources.crawler, url)
)

def add_error(self, message):
if self.error_callback:
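
SourceBase caches each domain's results as a JSON file under the platform user cache dir, treats the file as stale after max_cache_age (seven days, checked via mtime) or when it was written by a different Dirhunt version, and only hits the network on a cache miss. A self-contained sketch of that flow using synchronous file I/O; the helper names and the /tmp cache directory in the usage line are illustrative, not part of the diff.

import datetime
import json
from pathlib import Path

VERSION = "2.0.0"
MAX_CACHE_AGE = datetime.timedelta(days=7)


def cache_file(cache_dir: Path, domain: str) -> Path:
    # One JSON file per domain, e.g. <cache_dir>/example.com.json
    return cache_dir / f"{domain}.json"


def is_cache_expired(path: Path) -> bool:
    # Stale if the file is missing or its mtime is older than MAX_CACHE_AGE.
    return (
        not path.exists()
        or path.stat().st_mtime
        < (datetime.datetime.now() - MAX_CACHE_AGE).timestamp()
    )


def get_from_cache(path: Path):
    with open(path) as file:
        data = json.load(file)
    # A cache written by another version of the tool is discarded.
    return data["urls"] if data["version"] == VERSION else None


def save_to_cache(path: Path, domain: str, urls):
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as file:
        json.dump({"version": VERSION, "domain": domain, "urls": list(urls)}, file)


def retrieve_urls(cache_dir: Path, domain: str, search):
    # Cache hit -> reuse; miss or stale -> search and refresh the cache.
    path = cache_file(cache_dir, domain)
    urls = None if is_cache_expired(path) else get_from_cache(path)
    if urls is None:
        urls = search(domain)
        save_to_cache(path, domain, urls)
    return urls


print(retrieve_urls(Path("/tmp/dirhunt-cache-demo"), "example.com",
                    lambda domain: [f"https://{domain}/index.php"]))
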
56 changes: 28 additions & 28 deletions dirhunt/sources/commoncrawl.py
@@ -1,48 +1,48 @@
import json
from json import JSONDecodeError
from typing import Iterable

from requests.exceptions import RequestException
from dirhunt.sessions import Sessions
from dirhunt.sources.base import Source
from aiohttp import ClientError

from dirhunt.sources.base import SourceBase


COMMONCRAWL_URL = "https://index.commoncrawl.org/collinfo.json"
TIMEOUT = 10


class CommonCrawl(Source):
def get_latest_craw_index(self):
class CommonCrawl(SourceBase):
async def get_latest_craw_index(self):
url = COMMONCRAWL_URL
session = Sessions().get_session()
try:
with session.get(url, timeout=TIMEOUT) as response:
async with self.sources.crawler.session.get(
url, timeout=TIMEOUT
) as response:
response.raise_for_status()
crawl_indexes = response.json()
except (RequestException, ValueError, JSONDecodeError) as e:
crawl_indexes = await response.json()
except (ClientError, ValueError, JSONDecodeError) as e:
self.add_error("Error on CommonCrawl source: {}".format(e))
return
if not crawl_indexes:
return
latest_crawl_index = crawl_indexes[0]
return latest_crawl_index["cdx-api"]

def callback(self, domain):
latest_crawl_index = self.get_latest_craw_index()
async def search_by_domain(self, domain: str) -> Iterable[str]:
latest_crawl_index = await self.get_latest_craw_index()
if not latest_crawl_index:
return
session = Sessions().get_session()
try:
with session.get(
latest_crawl_index,
params={"url": "*.{}".format(domain), "output": "json"},
timeout=TIMEOUT,
stream=True,
) as response:
response.raise_for_status()
for line in filter(bool, response.iter_lines()):
if isinstance(line, bytes):
line = line.decode("utf-8")
data = json.loads(line)
self.add_result(data["url"])
except RequestException:
return
return []
async with self.sources.crawler.session.get(
latest_crawl_index,
params={"url": "*.{}".format(domain), "output": "json"},
timeout=TIMEOUT,
) as response:
response.raise_for_status()
urls = set()
while True:
line = (await response.content.readline()).decode("utf-8")
if not line:
break
data = json.loads(line)
urls.add(data["url"])
return urls
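
The rewritten CommonCrawl source first fetches collinfo.json to find the latest crawl index, then streams the per-domain CDX results line by line with aiohttp, parsing each line as a JSON object. A standalone sketch of that pattern that opens its own ClientSession (the real source reuses the crawler's session); the __main__ call against example.com is only a usage illustration.

import asyncio
import json

import aiohttp

COMMONCRAWL_URL = "https://index.commoncrawl.org/collinfo.json"
TIMEOUT = aiohttp.ClientTimeout(total=10)


async def search_commoncrawl(domain: str) -> set:
    urls = set()
    async with aiohttp.ClientSession(timeout=TIMEOUT) as session:
        # 1. Pick the most recent crawl index from collinfo.json.
        async with session.get(COMMONCRAWL_URL) as response:
            response.raise_for_status()
            crawl_indexes = await response.json()
        if not crawl_indexes:
            return urls
        cdx_api = crawl_indexes[0]["cdx-api"]
        # 2. Stream the results: one JSON object per line, read until EOF.
        async with session.get(
            cdx_api, params={"url": f"*.{domain}", "output": "json"}
        ) as response:
            response.raise_for_status()
            while True:
                line = (await response.content.readline()).decode("utf-8")
                if not line:
                    break
                urls.add(json.loads(line)["url"])
    return urls


if __name__ == "__main__":
    found = asyncio.run(search_commoncrawl("example.com"))
    print(f"{len(found)} urls found")
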