diff --git a/anjani/plugins/spam_shield.py b/anjani/plugins/spam_shield.py index 17def5cb9..d6f06ddb8 100644 --- a/anjani/plugins/spam_shield.py +++ b/anjani/plugins/spam_shield.py @@ -15,13 +15,20 @@ # along with this program. If not, see . import asyncio +import re from datetime import datetime from typing import Any, ClassVar, MutableMapping, Optional from aiohttp import ClientResponseError from pyrogram import filters from pyrogram.errors import ChannelPrivate -from pyrogram.types import Message, User +from pyrogram.types import ( + CallbackQuery, + InlineKeyboardButton, + InlineKeyboardMarkup, + Message, + User +) from anjani import command, listener, plugin, util from anjani.filters import admin_only @@ -33,6 +40,8 @@ class SpamShield(plugin.Plugin): db: util.db.AsyncCollection federation_db: util.db.AsyncCollection token: str + sp_token: Optional[str] + sp_url: Optional[str] async def on_load(self) -> None: try: @@ -44,6 +53,15 @@ async def on_load(self) -> None: self.db = self.bot.db.get_collection("GBAN_SETTINGS") self.federation_db = self.bot.db.get_collection("FEDERATIONS") + try: + self.sp_token = self.bot.config["sp_token"] + self.sp_url = self.bot.config["sp_url"] + except KeyError: + self.sp_token = None + self.sp_url = None + else: + self.db_dump = self.bot.db.get_collection("SPAM_DUMP") + async def on_chat_migrate(self, message: Message) -> None: new_chat = message.chat.id old_chat = message.migrate_from_chat_id @@ -65,6 +83,85 @@ async def on_plugin_restore(self, chat_id: int, data: MutableMapping[str, Any]) {"$set": data[self.name]}, upsert=True) + @listener.filters(filters.regex(r"spam_check_(t|f)\[(.*?)\]")) + async def on_callback_query(self, query: CallbackQuery) -> None: + if isinstance(query.data, bytes): + query.data = query.data.decode() + + message = query.message + content_hash = re.compile(r"([A-Fa-f0-9]{64})").search(message.text) + author = str(query.from_user.id) + users_on_correct = users_on_incorrect = [] + total_correct = total_incorrect = 0 + + if not content_hash: + self.log.warning("Can't get hash from 'MessageID: %d'", message.message_id) + return + + # Correct button data + correct = re.compile(r"spam_check_t(.*?)").match(query.data) + if message.reply_markup and isinstance(message.reply_markup, InlineKeyboardMarkup): + data = message.reply_markup.inline_keyboard[0][0].callback_data + if isinstance(data, bytes): + data = data.decode() + + users_on_correct = re.findall("[0-9]+", data) + + # Incorrect button data + incorrect = re.compile(r"spam_check_f(.*?)").match(query.data) + if message.reply_markup and isinstance(message.reply_markup, InlineKeyboardMarkup): + data = message.reply_markup.inline_keyboard[0][1].callback_data + if isinstance(data, bytes): + data = data.decode() + + users_on_incorrect = re.findall("[0-9]+", data) + + if correct: + # Check user in incorrect data + if author in users_on_incorrect: + users_on_incorrect.remove(author) + if author in users_on_correct: + users_on_correct.remove(author) + else: + users_on_correct.append(author) + elif incorrect: + # Check user in correct data + if author in users_on_correct: + users_on_correct.remove(author) + if author in users_on_incorrect: + users_on_incorrect.remove(author) + else: + users_on_incorrect.append(author) + + total_correct, total_incorrect = len(users_on_correct), len(users_on_incorrect) + users_on_correct = f"[{', '.join(users_on_correct)}]" + users_on_incorrect = f"[{', '.join(users_on_incorrect)}]" + button = InlineKeyboardMarkup( + [ + [ + InlineKeyboardButton( + text=f"✅ Correct ({total_correct})", + callback_data=f"spam_check_t{users_on_correct}" + ), + InlineKeyboardButton( + text=f"❌ Incorrect ({total_incorrect})", + callback_data=f"spam_check_f{users_on_incorrect}" + ) + ] + ] + ) + + await self.db_dump.update_one( + {"_id": content_hash[0]}, + { + "$set": { + "spam": total_correct, + "ham": total_incorrect + } + } + ) + await query.edit_message_reply_markup(reply_markup=button) + async def on_chat_action(self, message: Message) -> None: """Checker service for new member""" if message.left_chat_member: @@ -88,9 +185,30 @@ async def on_chat_action(self, message: Message) -> None: async def on_message(self, message: Message) -> None: """Checker service for message""" chat = message.chat - if not chat or message.left_chat_member: + if not chat or message.left_chat_member or not message.from_user: + return + + text = ( + message.text.strip() + if message.text else + ( + message.caption.strip() + if message.media and message.caption else None + ) + ) + if not text: return + # Always check the probability but run it in the background + if self.sp_token: + self.bot.loop.create_task( + self.check_probability( + chat.id, + message.from_user.id, + text + ) + ) + if not await self.is_active(chat.id): return @@ -111,6 +229,70 @@ async def on_message(self, message: Message) -> None: except ChannelPrivate: return + async def check_probability(self, chat: int, user: int, text: str) -> None: + if not self.sp_token or not self.sp_url: + return + + async with self.bot.http.post( + self.sp_url, + headers={"Authorization": f"Bearer {self.sp_token}"}, + json={"msg": text} + ) as res: + if res.status != 200: + return + + response = await res.json() + self.log.info(response) + probability = response["spam_probability"] + if probability <= 0.6: + return + + content_hash = response["text_hash"] + data = await self.db_dump.find_one({"_id": content_hash}) + if data: + return + + msg = ( + "#SPAM_PREDICTION\n\n" + f"**Prediction Result**: {str(probability * 10 ** 2)[0:7]}\n" + f"**Message Hash:** `{content_hash}`\n" + f"\n**====== CONTENT =======**\n\n{text}" + ) + await asyncio.gather( + self.bot.client.send_message( + chat_id=-1001314588569, + text=msg, + disable_web_page_preview=True, + reply_markup=InlineKeyboardMarkup( + [ + [ + InlineKeyboardButton( + text="✅ Correct (0)", + callback_data=f"spam_check_t[]", + ), + InlineKeyboardButton( + text="❌ Incorrect (0)", + callback_data=f"spam_check_f[]", + ) + ] + ] + ) + ), + self.db_dump.update_one( + {"_id": content_hash}, + { + "$set": { + "text": text, + "spam": 0, + "ham": 0, + "chat": chat, + "id": user + } + }, + upsert=True + ) + ) + async def get_ban(self, user_id: int) -> MutableMapping[str, Any]: path = f"https://api.spamwat.ch/banlist/{user_id}" headers = {"Authorization": f"Bearer {self.token}"} diff --git a/anjani/util/config.py b/anjani/util/config.py index 4c8507d2c..ff5cff9f9 100644 --- a/anjani/util/config.py +++ b/anjani/util/config.py @@ -22,6 +22,8 @@ def __init__(self) -> None: "db_uri": os.environ.get("DB_URI"), "download_path": os.environ.get("DOWNLOAD_PATH"), "owner_id": os.environ.get("OWNER_ID"), + "sp_token": os.environ.get("SP_TOKEN"), + "sp_url": os.environ.get("SP_URL"), "sw_token": os.environ.get("SW_API"), "log_channel": os.environ.get("LOG_CHANNEL") }