Skip to content

Commit

Permalink
spam_shield: Apply spam prediction natively
Browse files Browse the repository at this point in the history
+ make scikit-learn an extra (optional) dependency
+ edit Dockerfile to install all extra dependencies
  • Loading branch information
MrMissx committed Aug 15, 2021
1 parent aa3068d commit e5d853d
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 109 deletions.
4 changes: 2 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ RUN apt -qq update && apt -qq upgrade -y
RUN apt -qq install -y --no-install-recommends \
wget \
git \
gnupg2
gnupg2

# Copy directory and install dependencies
COPY . .
Expand All @@ -17,7 +17,7 @@ RUN curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-
ENV PATH="${PATH}:/root/.poetry/bin"

RUN poetry config virtualenvs.create false
RUN poetry install --no-root --no-dev -E uvloop
RUN poetry install --no-root --no-dev -E all

# command to run on container start
CMD ["python3","-m","anjani"]
177 changes: 80 additions & 97 deletions anjani/plugins/spam_shield.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import asyncio
import hashlib
import pickle
import re
from datetime import datetime
from typing import Any, ClassVar, MutableMapping, Optional
from typing import Any, ClassVar, List, MutableMapping, Optional, TypeVar

from aiohttp import ClientResponseError
from pyrogram import filters
Expand All @@ -27,11 +29,20 @@
InlineKeyboardButton,
InlineKeyboardMarkup,
Message,
User
User,
)

try:
from sklearn.pipeline import Pipeline

_run_predict = True
except ImportError:
Pipeline = TypeVar("Pipeline")
_run_predict = False

from anjani import command, listener, plugin, util
from anjani.filters import admin_only
from anjani.util import run_sync


class SpamShield(plugin.Plugin):
Expand All @@ -40,9 +51,10 @@ class SpamShield(plugin.Plugin):
db: util.db.AsyncCollection
db_dump: util.db.AsyncCollection
federation_db: util.db.AsyncCollection
token: Optional[str]
sp_token: Optional[str]
sp_url: Optional[str]
model: Optional[Pipeline] = None
token: Optional[str] = None
sp_token: Optional[str] = None
sp_url: Optional[str] = None

async def on_load(self) -> None:
self.token = self.bot.config.get("sw_api")
Expand All @@ -51,9 +63,33 @@ async def on_load(self) -> None:

self.db = self.bot.db.get_collection("GBAN_SETTINGS")
self.federation_db = self.bot.db.get_collection("FEDERATIONS")
self.sp_token = self.bot.config.get("sp_token")
self.sp_url = self.bot.config.get("sp_url")
self.db_dump = self.bot.db.get_collection("SPAM_DUMP")
if _run_predict:
self.sp_token = self.bot.config.get("sp_token")
self.sp_url = self.bot.config.get("sp_url")
self.db_dump = self.bot.db.get_collection("SPAM_DUMP")
if self.sp_url and self.sp_token:
await self._load_model()

async def _load_model(self) -> None:
    """Download the pickled spam-prediction model and store it on ``self.model``.

    The model is fetched from ``self.sp_url`` with ``self.sp_token`` sent as
    a ``token`` Authorization header. On a non-200 response, or when the
    payload cannot be unpickled, ``self.model`` is set to ``None`` and a
    warning is logged, so the plugin keeps working without predictions.

    NOTE(review): the ``Accept`` header looks like the GitHub raw-content
    media type, so ``sp_url`` is presumably a GitHub API URL - confirm.
    """
    self.log.info("Downloading spam prediction model!")
    async with self.bot.http.get(
        self.sp_url,  # type: ignore
        headers={
            "Authorization": f"token {self.sp_token}",
            "Accept": "application/vnd.github.v3.raw",
        },
    ) as res:
        if res.status != 200:
            self.model = None
            self.log.warning("Failed to download model")
            return

        # Read the body while the response is still open; unpickling does
        # not need the connection, so it happens after the block exits.
        payload = await res.read()

    # SECURITY: pickle.loads executes arbitrary code embedded in the payload.
    # This is only acceptable while ``sp_url`` points at a source the bot
    # operator fully controls; never point it at untrusted content.
    try:
        # run_sync presumably off-loads the blocking call from the event
        # loop - verify against anjani.util.
        self.model = await run_sync(pickle.loads, payload)
    except (
        pickle.UnpicklingError,
        AttributeError,
        EOFError,
        ImportError,
        IndexError,
    ) as err:
        # A truncated or incompatible pickle must not break plugin loading;
        # degrade to "no model", the same as a failed download.
        self.model = None
        self.log.warning(f"Failed to load spam prediction model: {err}")

def _build_hash(self, content: str) -> str:
return hashlib.sha256(content.strip().encode()).hexdigest()

def _predict(self, text: str) -> List[List[float]]:
return self.model.predict_proba([text]) # type: ignore

async def on_chat_migrate(self, message: Message) -> None:
new_chat = message.chat.id
Expand All @@ -72,9 +108,7 @@ async def on_plugin_backup(self, chat_id: int) -> MutableMapping[str, Any]:
return {self.name: setting}

async def on_plugin_restore(self, chat_id: int, data: MutableMapping[str, Any]) -> None:
await self.db.update_one({"chat_id": chat_id},
{"$set": data[self.name]},
upsert=True)
await self.db.update_one({"chat_id": chat_id}, {"$set": data[self.name]}, upsert=True)

@listener.filters(filters.regex(r"spam_check_(t|f)\[(.*)\]"))
async def on_callback_query(self, query: CallbackQuery) -> None:
Expand Down Expand Up @@ -132,27 +166,21 @@ async def on_callback_query(self, query: CallbackQuery) -> None:
[
InlineKeyboardButton(
text=f"✅ Correct ({total_correct})",
callback_data=f"spam_check_t{users_on_correct}"
callback_data=f"spam_check_t{users_on_correct}",
),
InlineKeyboardButton(
text=f"❌ Incorrect ({total_incorrect})",
callback_data=f"spam_check_f{users_on_incorrect}"
)
callback_data=f"spam_check_f{users_on_incorrect}",
),
]
]
)

await asyncio.gather(
self.db_dump.update_one(
{"_id": content_hash[0]},
{
"$set": {
"spam": total_correct,
"ham": total_incorrect
}
}
{"_id": content_hash[0]}, {"$set": {"spam": total_correct, "ham": total_incorrect}}
),
query.edit_message_reply_markup(reply_markup=button)
query.edit_message_reply_markup(reply_markup=button),
)

async def on_chat_action(self, message: Message) -> None:
Expand All @@ -178,32 +206,20 @@ async def on_message(self, message: Message) -> None:
user = message.from_user
text = (
message.text.strip()
if message.text else
(
message.caption.strip()
if message.media and message.caption else None
)
if message.text
else (message.caption.strip() if message.media and message.caption else None)
)
if not chat or message.left_chat_member or not user or not text:
return

# Always check the spam probability but run it in the background
self.bot.loop.create_task(self.check_probability(
chat.id,
message.from_user.id,
text
)
)
self.bot.loop.create_task(self.check_probability(chat.id, message.from_user.id, text))

if not await self.is_active(chat.id):
return

try:
me, target = await util.tg.fetch_permissions(
self.bot.client,
chat.id,
user.id
)
me, target = await util.tg.fetch_permissions(self.bot.client, chat.id, user.id)
if not me.can_restrict_members or util.tg.is_staff_or_admin(target, self.bot.staff):
return

Expand All @@ -212,23 +228,17 @@ async def on_message(self, message: Message) -> None:
return

async def check_probability(self, chat: int, user: int, text: str) -> None:
if not self.sp_token or not self.sp_url:
if not self.model:
return

async with self.bot.http.post(
self.sp_url,
headers={"Authorization": f"Bearer {self.sp_token}"},
json={"msg": text}
) as res:
if res.status != 200:
return
text = repr(text.strip())
response = await run_sync(self._predict, text)
probability = response[0][1]
if probability <= 0.6:
return

response = await res.json()
probability = response["spam_probability"]
if probability <= 0.6:
return
content_hash = self._build_hash(text)

content_hash = response["text_hash"]
data = await self.db_dump.find_one({"_id": content_hash})
if data:
return
Expand All @@ -254,24 +264,16 @@ async def check_probability(self, chat: int, user: int, text: str) -> None:
InlineKeyboardButton(
text="❌ Incorrect (0)",
callback_data=f"spam_check_f[]",
)
),
]
]
)
),
),
self.db_dump.update_one(
{"_id": content_hash},
{
"$set": {
"text": text,
"spam": 0,
"ham": 0,
"chat": chat,
"id": user
}
},
upsert=True
)
{"$set": {"text": text, "spam": 0, "ham": 0, "chat": chat, "id": user}},
upsert=True,
),
)

async def get_ban(self, user_id: int) -> MutableMapping[str, Any]:
Expand All @@ -289,13 +291,13 @@ async def get_ban(self, user_id: int) -> MutableMapping[str, Any]:
raise ClientResponseError(
resp.request_info,
resp.history,
message="Make sure your Spamwatch API token is corret"
message="Make sure your Spamwatch API token is corret",
)
elif resp.status == 403:
raise ClientResponseError(
resp.request_info,
resp.history,
message="Forbidden, your token permissions is not valid"
message="Forbidden, your token permissions is not valid",
)
elif resp.status == 404:
return {}
Expand All @@ -304,7 +306,7 @@ async def get_ban(self, user_id: int) -> MutableMapping[str, Any]:
raise ClientResponseError(
resp.request_info,
resp.history,
message=f"Too many requests. Try again in {until - datetime.now()}"
message=f"Too many requests. Try again in {until - datetime.now()}",
)
else:
raise ClientResponseError(resp.request_info, resp.history)
Expand All @@ -314,10 +316,7 @@ async def cas_check(self, user: User) -> Optional[str]:
async with self.bot.http.get(f"https://api.cas.chat/check?user_id={user.id}") as res:
data = await res.json()
if data["ok"]:
fullname = (
user.first_name + user.last_name
if user.last_name else user.first_name
)
fullname = user.first_name + user.last_name if user.last_name else user.first_name
reason = f"https://cas.chat/query?u={user.id}"
await self.federation_db.update_one(
{"_id": "AnjaniSpamShield"},
Expand All @@ -326,10 +325,10 @@ async def cas_check(self, user: User) -> Optional[str]:
f"banned.{user.id}": {
"name": fullname,
"reason": reason,
"time": datetime.now()
"time": datetime.now(),
}
}
}
},
)
return reason

Expand All @@ -342,20 +341,11 @@ async def is_active(self, chat_id: int) -> bool:

async def setting(self, chat_id: int, setting: bool) -> None:
"""Turn on/off SpamShield in chats"""
await self.db.update_one(
{"chat_id": chat_id},
{
"$set": {
"setting": setting
}
},
upsert=True
)
await self.db.update_one({"chat_id": chat_id}, {"$set": {"setting": setting}}, upsert=True)

async def check(self, user: User, chat_id: int) -> None:
"""Shield checker action."""
cas, sw = await asyncio.gather(self.cas_check(user),
self.get_ban(user.id))
cas, sw = await asyncio.gather(self.cas_check(user), self.get_ban(user.id))
if not cas or not sw:
return

Expand All @@ -377,17 +367,10 @@ async def check(self, user: User, chat_id: int) -> None:
self.bot.client.kick_chat_member(chat_id, user.id),
self.bot.client.send_message(
chat_id,
text=await self.text(
chat_id,
"banned-text",
userlink,
user.id,
reason,
banner
),
text=await self.text(chat_id, "banned-text", userlink, user.id, reason, banner),
parse_mode="markdown",
disable_web_page_preview=True,
)
),
)

@command.filters(admin_only)
Expand All @@ -400,8 +383,8 @@ async def cmd_spamshield(self, ctx: command.Context, enable: Optional[bool] = No
if enable is None:
return await self.text(chat.id, "err-invalid-option")

ret, _ = await asyncio.gather(self.text(chat.id,
"spamshield-set",
"on" if enable else "off"),
self.setting(chat.id, enable))
ret, _ = await asyncio.gather(
self.text(chat.id, "spamshield-set", "on" if enable else "off"),
self.setting(chat.id, enable),
)
return ret
Loading

0 comments on commit e5d853d

Please sign in to comment.