Skip to content

Commit

Permalink
Anjani: Plugins: SpamShield: add message prediction of spam
Browse files Browse the repository at this point in the history
Change-Id: I7f5734874e4f3c54826525b45106cf9946ec5e57
  • Loading branch information
adekmaulana committed Aug 8, 2021
1 parent 2761d4e commit cb824c2
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 2 deletions.
186 changes: 184 additions & 2 deletions anjani/plugins/spam_shield.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,20 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import asyncio
import re
from datetime import datetime
from typing import Any, ClassVar, MutableMapping, Optional

from aiohttp import ClientResponseError
from pyrogram import filters
from pyrogram.errors import ChannelPrivate
from pyrogram.types import Message, User
from pyrogram.types import (
CallbackQuery,
InlineKeyboardButton,
InlineKeyboardMarkup,
Message,
User
)

from anjani import command, listener, plugin, util
from anjani.filters import admin_only
Expand All @@ -33,6 +40,8 @@ class SpamShield(plugin.Plugin):
db: util.db.AsyncCollection
federation_db: util.db.AsyncCollection
token: str
sp_token: Optional[str]
sp_url: Optional[str]

async def on_load(self) -> None:
try:
Expand All @@ -44,6 +53,15 @@ async def on_load(self) -> None:
self.db = self.bot.db.get_collection("GBAN_SETTINGS")
self.federation_db = self.bot.db.get_collection("FEDERATIONS")

try:
self.sp_token = self.bot.config["sp_token"]
self.sp_url = self.bot.config["sp_url"]
except KeyError:
self.sp_token = None
self.sp_url = None
else:
self.db_dump = self.bot.db.get_collection("SPAM_DUMP")

async def on_chat_migrate(self, message: Message) -> None:
new_chat = message.chat.id
old_chat = message.migrate_from_chat_id
Expand All @@ -65,6 +83,85 @@ async def on_plugin_restore(self, chat_id: int, data: MutableMapping[str, Any])
{"$set": data[self.name]},
upsert=True)

@listener.filters(filters.regex(r"spam_check_(t|f)\[(.*?)\]"))
async def on_callback_query(self, query: CallbackQuery) -> None:
if isinstance(query.data, bytes):
query.data = query.data.decode()

message = query.message
content_hash = re.compile(r"([A-Fa-f0-9]{64})").search(message.text)
author = str(query.from_user.id)
users_on_correct = users_on_incorrect = []
total_correct = total_incorrect = 0

if not content_hash:
self.log.warning("Can't get hash from 'MessageID: %d'", message.message_id)
return

# Correct button data
correct = re.compile(r"spam_check_t(.*?)").match(query.data)
if message.reply_markup and isinstance(message.reply_markup, InlineKeyboardMarkup):
data = message.reply_markup.inline_keyboard[0][0].callback_data
if isinstance(data, bytes):
data = data.decode()

users_on_correct = re.findall("[0-9]+", data)

# Incorrect button data
incorrect = re.compile(r"spam_check_f(.*?)").match(query.data)
if message.reply_markup and isinstance(message.reply_markup, InlineKeyboardMarkup):
data = message.reply_markup.inline_keyboard[0][1].callback_data
if isinstance(data, bytes):
data = data.decode()

users_on_incorrect = re.findall("[0-9]+", data)

if correct:
# Check user in incorrect data
if author in users_on_incorrect:
users_on_incorrect.remove(author)
if author in users_on_correct:
users_on_correct.remove(author)
else:
users_on_correct.append(author)
elif incorrect:
# Check user in correct data
if author in users_on_correct:
users_on_correct.remove(author)
if author in users_on_incorrect:
users_on_incorrect.remove(author)
else:
users_on_incorrect.append(author)

total_correct, total_incorrect = len(users_on_correct), len(users_on_incorrect)
users_on_correct = f"[{', '.join(users_on_correct)}]"
users_on_incorrect = f"[{', '.join(users_on_incorrect)}]"
button = InlineKeyboardMarkup(
[
[
InlineKeyboardButton(
text=f"✅ Correct ({total_correct})",
callback_data=f"spam_check_t{users_on_correct}"
),
InlineKeyboardButton(
text=f"❌ Incorrect ({total_incorrect})",
callback_data=f"spam_check_f{users_on_incorrect}"
)
]
]
)

await self.db_dump.update_one(
{"_id": content_hash[0]},
{
"$set": {
"spam": total_correct,
"ham": total_incorrect
}
}
)
await query.edit_message_reply_markup(reply_markup=button)

async def on_chat_action(self, message: Message) -> None:
"""Checker service for new member"""
if message.left_chat_member:
Expand All @@ -88,9 +185,30 @@ async def on_chat_action(self, message: Message) -> None:
async def on_message(self, message: Message) -> None:
"""Checker service for message"""
chat = message.chat
if not chat or message.left_chat_member:
if not chat or message.left_chat_member or not message.from_user:
return

text = (
message.text.strip()
if message.text else
(
message.caption.strip()
if message.media and message.caption else None
)
)
if not text:
return

# Always check the probability but run it in the background
if self.sp_token:
self.bot.loop.create_task(
self.check_probability(
chat.id,
message.from_user.id,
text
)
)

if not await self.is_active(chat.id):
return

Expand All @@ -111,6 +229,70 @@ async def on_message(self, message: Message) -> None:
except ChannelPrivate:
return

async def check_probability(self, chat: int, user: int, text: str) -> None:
if not self.sp_token or not self.sp_url:
return

async with self.bot.http.post(
self.sp_url,
headers={"Authorization": f"Bearer {self.sp_token}"},
json={"msg": text}
) as res:
if res.status != 200:
return

response = await res.json()
self.log.info(response)
probability = response["spam_probability"]
if probability <= 0.6:
return

content_hash = response["text_hash"]
data = await self.db_dump.find_one({"_id": content_hash})
if data:
return

msg = (
"#SPAM_PREDICTION\n\n"
f"**Prediction Result**: {str(probability * 10 ** 2)[0:7]}\n"
f"**Message Hash:** `{content_hash}`\n"
f"\n**====== CONTENT =======**\n\n{text}"
)
await asyncio.gather(
self.bot.client.send_message(
chat_id=-1001314588569,
text=msg,
disable_web_page_preview=True,
reply_markup=InlineKeyboardMarkup(
[
[
InlineKeyboardButton(
text="✅ Correct (0)",
callback_data=f"spam_check_t[]",
),
InlineKeyboardButton(
text="❌ Incorrect (0)",
callback_data=f"spam_check_f[]",
)
]
]
)
),
self.db_dump.update_one(
{"_id": content_hash},
{
"$set": {
"text": text,
"spam": 0,
"ham": 0,
"chat": chat,
"id": user
}
},
upsert=True
)
)

async def get_ban(self, user_id: int) -> MutableMapping[str, Any]:
path = f"https://api.spamwat.ch/banlist/{user_id}"
headers = {"Authorization": f"Bearer {self.token}"}
Expand Down
2 changes: 2 additions & 0 deletions anjani/util/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def __init__(self) -> None:
"db_uri": os.environ.get("DB_URI"),
"download_path": os.environ.get("DOWNLOAD_PATH"),
"owner_id": os.environ.get("OWNER_ID"),
"sp_token": os.environ.get("SP_TOKEN"),
"sp_url": os.environ.get("SP_URL"),
"sw_token": os.environ.get("SW_API"),
"log_channel": os.environ.get("LOG_CHANNEL")
}
Expand Down

0 comments on commit cb824c2

Please sign in to comment.