Skip to content

Commit

Permalink
integrate locator
Browse files Browse the repository at this point in the history
  • Loading branch information
umputun committed Dec 10, 2023
1 parent 64b3001 commit a59d42a
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 19 deletions.
3 changes: 2 additions & 1 deletion app/bot/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import (

//go:generate moq --out mocks/http_client.go --pkg mocks --skip-ensure . HTTPClient:HTTPClient

// PermanentBanDuration defines duration of permanent ban:
// If user is restricted for more than 366 days or less than 30 seconds from the current time,
// they are considered to be restricted forever.
var permanentBanDuration = time.Hour * 24 * 400
var PermanentBanDuration = time.Hour * 24 * 400

// Response describes bot's reaction on particular message
type Response struct {
Expand Down
2 changes: 1 addition & 1 deletion app/bot/spam.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (s *SpamFilter) OnMessage(msg Message) (response Response) {
msgPrefix = s.params.SpamDryMsg
}
spamRespMsg := fmt.Sprintf("%s: %q (%d)", msgPrefix, displayUsername, msg.From.ID)
return Response{Text: spamRespMsg, Send: true, ReplyTo: msg.ID, BanInterval: permanentBanDuration,
return Response{Text: spamRespMsg, Send: true, ReplyTo: msg.ID, BanInterval: PermanentBanDuration,
DeleteReplyTo: true, User: User{Username: msg.From.Username, ID: msg.From.ID, DisplayName: msg.From.DisplayName},
}
}
Expand Down
2 changes: 1 addition & 1 deletion app/bot/spam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func TestSpamFilter_OnMessage(t *testing.T) {
t.Run("spam detected", func(t *testing.T) {
s := NewSpamFilter(ctx, det, nil, nil, SpamParams{SpamMsg: "detected", SpamDryMsg: "detected dry"})
resp := s.OnMessage(Message{Text: "spam", From: User{ID: 1, Username: "john"}})
assert.Equal(t, Response{Text: `detected: "john" (1)`, Send: true, BanInterval: permanentBanDuration,
assert.Equal(t, Response{Text: `detected: "john" (1)`, Send: true, BanInterval: PermanentBanDuration,
User: User{ID: 1, Username: "john"}, DeleteReplyTo: true}, resp)
t.Logf("resp: %+v", resp)
})
Expand Down
74 changes: 74 additions & 0 deletions app/events/locator.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package events

import (
"crypto/sha256"
"fmt"
"time"
)

// Locator is an in-memory index of chat messages keyed by the sha256 hash of
// their text. It is meant for matching a message seen in the admin chat, where
// only the text is available, back to the original message.
// Entries older than ttl are dropped lazily from Add, at most once per
// cleanupDuration.
// Note: it is not thread-safe, use it from a single goroutine only.
type Locator struct {
	ttl             time.Duration      // maximum age of an entry before cleanup drops it
	data            map[string]MsgMeta // message hash -> message meta
	lastRemoval     time.Time          // when the last cleanup pass ran
	cleanupDuration time.Duration      // minimum interval between cleanup passes
}

// MsgMeta is the metadata recorded for a single stored message.
type MsgMeta struct {
	time   time.Time // when the message was added to the locator
	chatID int64
	userID int64
	msgID  int
}

// String renders the metadata for logs, with the timestamp in RFC3339 format.
func (m MsgMeta) String() string {
	stamp := m.time.Format(time.RFC3339)
	return fmt.Sprintf("{chatID: %d, userID: %d, msgID: %d, time: %s}", m.chatID, m.userID, m.msgID, stamp)
}

// NewLocator returns an empty Locator whose entries expire after ttl.
// Cleanup passes are spaced at least five minutes apart.
func NewLocator(ttl time.Duration) *Locator {
	const defaultCleanupInterval = 5 * time.Minute

	loc := Locator{ttl: ttl, cleanupDuration: defaultCleanupInterval}
	loc.data = map[string]MsgMeta{}
	loc.lastRemoval = time.Now()
	return &loc
}

// Get returns the MsgMeta stored for the given msg text and reports whether
// such an entry exists.
func (l *Locator) Get(msg string) (MsgMeta, bool) {
	meta, found := l.data[l.MsgHash(msg)]
	return meta, found
}

// MsgHash returns the hex-encoded sha256 hash of a message.
func (l *Locator) MsgHash(msg string) string {
	sum := sha256.Sum256([]byte(msg))
	return fmt.Sprintf("%x", sum)
}

// Add stores the message under the hash of its text, stamped with the current
// time; an existing entry with the same text is overwritten. Once at least
// cleanupDuration has passed since the previous cleanup, it also removes every
// entry older than ttl.
func (l *Locator) Add(msg string, chatID, userID int64, msgID int) {
	key := l.MsgHash(msg)
	l.data[key] = MsgMeta{time: time.Now(), chatID: chatID, userID: userID, msgID: msgID}

	// cleanup is rate-limited, most calls stop here
	if cleanupDue := time.Since(l.lastRemoval) >= l.cleanupDuration; !cleanupDue {
		return
	}

	// drop everything that outlived ttl
	for hash, meta := range l.data {
		if age := time.Since(meta.time); age > l.ttl {
			delete(l.data, hash)
		}
	}
	l.lastRemoval = time.Now()
}
89 changes: 89 additions & 0 deletions app/events/locator_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package events

import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// TestNewLocator checks that the constructor stores the requested ttl and
// initialises the map, the cleanup interval and the last-cleanup timestamp.
func TestNewLocator(t *testing.T) {
	wantTTL := 10 * time.Minute

	loc := NewLocator(wantTTL)
	require.NotNil(t, loc)

	assert.Equal(t, wantTTL, loc.ttl)
	assert.NotZero(t, loc.cleanupDuration)
	assert.NotNil(t, loc.data)
	// lastRemoval is set at construction time, so it must be "about now"
	assert.WithinDuration(t, time.Now(), loc.lastRemoval, time.Second)
}

// TestGet checks that a stored message can be looked up by its text and that
// an unknown text reports a miss.
func TestGet(t *testing.T) {
	var (
		wantChat int64 = 123
		wantUser int64 = 456
		wantMsg        = 7890
	)

	loc := NewLocator(10 * time.Minute)
	loc.Add("test message", wantChat, wantUser, wantMsg)

	// the same text must resolve to the stored metadata
	meta, ok := loc.Get("test message")
	require.True(t, ok)
	assert.Equal(t, wantMsg, meta.msgID)
	assert.Equal(t, wantChat, meta.chatID)
	assert.Equal(t, wantUser, meta.userID)

	// a text that was never added must not be found
	_, ok = loc.Get("no such message")
	assert.False(t, ok)
}

// TestMsgHash pins the hex-encoded sha256 output of MsgHash for a few inputs.
func TestMsgHash(t *testing.T) {
	loc := NewLocator(10 * time.Minute)

	cases := []struct {
		name  string
		input string
		want  string
	}{
		{
			name:  "hash for empty message",
			input: "",
			want:  "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
		},
		{
			name:  "hash for non-empty message",
			input: "test message",
			want:  "3f0a377ba0a4a460ecb616f6507ce0d8cfa3e704025d4fda3ed0c5ca05468728",
		},
		{
			name:  "hash for different non-empty message",
			input: "test message blah",
			want:  "21b7035e5ab5664eb7571b1f63d96951d5554a5465302b9cdd2e3de510eda6d8",
		},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy for the closure (needed before Go 1.22)
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, loc.MsgHash(tc.input))
		})
	}
}

// TestAddAndCleanup checks that Add stores a message and that a later Add,
// made once the cleanup interval has elapsed, evicts entries older than ttl
// while keeping fresh ones.
//
// Time is simulated by backdating the stored entry and lastRemoval instead of
// sleeping, which keeps the test fast and independent of timer jitter.
func TestAddAndCleanup(t *testing.T) {
	ttl := 2 * time.Second
	cleanupDuration := 1 * time.Second
	locator := NewLocator(ttl)
	locator.cleanupDuration = cleanupDuration

	// adding a message
	msg := "test message"
	chatID := int64(123)
	userID := int64(456)
	msgID := 7890
	locator.Add(msg, chatID, userID, msgID)

	hash := locator.MsgHash(msg)
	meta, exists := locator.data[hash]
	require.True(t, exists)
	assert.Equal(t, chatID, meta.chatID)
	assert.Equal(t, userID, meta.userID)
	assert.Equal(t, msgID, meta.msgID)

	// pretend the message is older than ttl and the last cleanup happened
	// longer ago than cleanupDuration
	meta.time = meta.time.Add(-(ttl + time.Second))
	locator.data[hash] = meta
	locator.lastRemoval = locator.lastRemoval.Add(-(cleanupDuration + time.Second))

	// add another message to trigger cleanup
	locator.Add("another message", 789, 555, 1011)

	_, existsAfterCleanup := locator.data[hash]
	assert.False(t, existsAfterCleanup, "expired message must be removed by cleanup")

	_, freshExists := locator.data[locator.MsgHash("another message")]
	assert.True(t, freshExists, "fresh message must survive cleanup")
}
69 changes: 55 additions & 14 deletions app/events/telegram.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type TelegramListener struct {
NoSpamReply bool
Dry bool
SpamWeb SpamWeb
Locator *Locator

chatID int64
adminChatID int64
Expand Down Expand Up @@ -159,29 +160,19 @@ func (l *TelegramListener) procEvents(update tbapi.Update) error {

// message from admin chat
if l.isAdminChat(fromChat, msg.From.Username) {
// message from supers to admin chat
if update.Message.ForwardSenderName != "" {
// this is a forwarded message from super to admin chat, it is an example of missed spam
// we need to update spam filter with this message
msgTxt := strings.ReplaceAll(update.Message.Text, "\n", " ")
log.Printf("[DEBUG] forwarded message from superuser %s to admin chat: %q", msg.From.Username, msgTxt)
if err := l.Bot.UpdateSpam(msgTxt); err != nil {
log.Printf("[WARN] failed to update spam for %q, %v", update.Message.Text, err)
return nil
}
log.Printf("[INFO] spam updated with %q", update.Message.Text)
// it would be nice to ban this user right away, but we don't have forwarded user ID here, it is empty in update.Message
if err := l.adminChatMsg(update, fromChat); err != nil {
log.Printf("[WARN] failed to process admin chat message: %v", err)
}
return nil
}

// ignore messages from other chats if not in the test list
if !l.isChatAllowed(fromChat) {
// ignore messages from other chats if not in the test list
return nil
}

log.Printf("[DEBUG] incoming msg: %+v", strings.ReplaceAll(msg.Text, "\n", " "))

l.Locator.Add(update.Message.Text, fromChat, msg.From.ID, msg.ID) // save message to locator
resp := l.Bot.OnMessage(*msg)

if resp.Send && !l.NoSpamReply {
Expand Down Expand Up @@ -223,6 +214,56 @@ func (l *TelegramListener) procEvents(update tbapi.Update) error {
return errs.ErrorOrNil()
}

// adminChatMsg handles a message posted to the admin chat.
//
// Only forwarded messages are acted upon: they are treated as examples of
// missed spam. For such a message it
//  1. updates the spam samples via l.Bot.UpdateSpam (skipped in dry mode),
//  2. looks up the original message in l.Locator by its text,
//  3. deletes the original message from the primary chat (l.chatID), and
//  4. bans the author recorded in the locator.
//
// In dry mode nothing is updated, deleted or banned, but the locator lookup is
// still performed so a miss is reported. Non-forwarded messages are ignored.
// The first failing step is returned as a wrapped error.
func (l *TelegramListener) adminChatMsg(update tbapi.Update, fromChat int64) error {
	// shrink limits inp to at most limit characters for logging. It cuts on
	// rune boundaries so multi-byte UTF-8 text is never split mid-character.
	shrink := func(inp string, limit int) string {
		runes := []rune(inp)
		if len(runes) <= limit {
			return inp
		}
		return string(runes[:limit]) + "..."
	}

	msg := update.Message
	if msg == nil {
		return nil // nothing to process
	}

	// message from supers to admin chat.
	// The forward markers on the message itself are checked here:
	// update.FromChat() reports the chat the update arrived in and is set for
	// every regular message, so it cannot tell forwarded messages apart.
	forwarded := msg.ForwardSenderName != "" || msg.ForwardFrom != nil || msg.ForwardFromChat != nil
	if !forwarded {
		return nil
	}

	// this is a forwarded message from super to admin chat, it is an example of missed spam
	// we need to update spam filter with this message
	msgTxt := strings.ReplaceAll(msg.Text, "\n", " ")
	log.Printf("[DEBUG] forwarded message from superuser %q to admin chat %d: %q",
		msg.From.UserName, l.adminChatID, msgTxt)

	if !l.Dry {
		if err := l.Bot.UpdateSpam(msgTxt); err != nil {
			return fmt.Errorf("failed to update spam for %q: %w", msgTxt, err)
		}
		log.Printf("[INFO] spam updated with %q", shrink(msg.Text, 20))
	}

	// it would be nice to ban this user right away, but we don't have forwarded user ID here due to tg privacy limitation,
	// it is empty in update.Message. To ban this user, we need to get the match on the message from the locator and ban from there.
	info, ok := l.Locator.Get(msg.Text)
	if !ok {
		return fmt.Errorf("not found %q in locator", msg.Text)
	}

	log.Printf("[DEBUG] locator found message %+v", info)
	if l.Dry {
		return nil
	}

	_, err := l.TbAPI.Request(tbapi.DeleteMessageConfig{ChatID: l.chatID, MessageID: info.msgID})
	if err != nil {
		return fmt.Errorf("failed to delete message %d: %w", info.msgID, err)
	}
	log.Printf("[INFO] message %d deleted", info.msgID)

	// NOTE(review): fromChat is the chat this admin message arrived from, and
	// info.chatID is passed as the last argument. banUserOrChannel is not
	// visible here — confirm the ban is applied in the intended chat and that
	// the last argument is meant to carry the original chat ID.
	err = l.banUserOrChannel(bot.PermanentBanDuration, fromChat, info.userID, info.chatID)
	if err != nil {
		return fmt.Errorf("failed to ban user %d: %w", info.userID, err)
	}
	log.Printf("[INFO] user %d %q banned", info.userID, msg.ForwardSenderName)
	return nil
}

func (l *TelegramListener) isChatAllowed(fromChat int64) bool {
if fromChat == l.chatID {
return true
Expand Down
9 changes: 8 additions & 1 deletion app/events/telegram_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ func TestTelegramListener_Do(t *testing.T) {
Group: "gr",
AdminGroup: "987654321",
StartupMsg: "startup",
Locator: NewLocator(10 * time.Minute),
}

ctx, cancel := context.WithTimeout(context.Background(), 500*time.Minute)
Expand Down Expand Up @@ -98,6 +99,7 @@ func TestTelegramListener_DoWithBotBan(t *testing.T) {
Bot: b,
SuperUsers: SuperUser{"admin"},
Group: "gr",
Locator: NewLocator(10 * time.Minute),
}

ctx, cancel := context.WithTimeout(context.Background(), 500*time.Minute)
Expand Down Expand Up @@ -228,6 +230,7 @@ func TestTelegramListener_DoDeleteMessages(t *testing.T) {
TbAPI: mockAPI,
Bot: b,
Group: "gr",
Locator: NewLocator(10 * time.Minute),
}

ctx, cancel := context.WithTimeout(context.Background(), 500*time.Minute)
Expand Down Expand Up @@ -274,6 +277,9 @@ func TestTelegramListener_DoWithForwarded(t *testing.T) {
SendFunc: func(c tbapi.Chattable) (tbapi.Message, error) {
return tbapi.Message{Text: c.(tbapi.MessageConfig).Text, From: &tbapi.User{UserName: "user"}}, nil
},
RequestFunc: func(c tbapi.Chattable) (*tbapi.APIResponse, error) {
return &tbapi.APIResponse{Ok: true}, nil
},
}
b := &mocks.BotMock{
OnMessageFunc: func(msg bot.Message) bot.Response {
Expand All @@ -297,6 +303,7 @@ func TestTelegramListener_DoWithForwarded(t *testing.T) {
AdminGroup: "123",
StartupMsg: "startup",
SuperUsers: SuperUser{"umputun"},
Locator: NewLocator(10 * time.Minute),
}

ctx, cancel := context.WithTimeout(context.Background(), 500*time.Minute)
Expand All @@ -306,7 +313,7 @@ func TestTelegramListener_DoWithForwarded(t *testing.T) {
Message: &tbapi.Message{
Chat: &tbapi.Chat{ID: 123},
Text: "text 123",
From: &tbapi.User{UserName: "umputun"},
From: &tbapi.User{UserName: "umputun", ID: 77},
Date: int(time.Date(2020, 2, 11, 19, 35, 55, 9, time.UTC).Unix()),
ForwardSenderName: "forwarded_name",
},
Expand Down
4 changes: 3 additions & 1 deletion app/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ var opts struct {
Group string `long:"group" env:"GROUP" description:"admin group name/id"`
} `group:"admin" namespace:"admin" env-namespace:"ADMIN"`

TestingIDs []int64 `long:"testing-id" env:"TESTING_ID" env-delim:"," description:"testing ids, allow bot to reply to them"`
TestingIDs []int64 `long:"testing-id" env:"TESTING_ID" env-delim:"," description:"testing ids, allow bot to reply to them"`
HistoryDuration time.Duration `long:"history-duration" env:"HISTORY_DURATION" default:"1h" description:"history duration"`

Logger struct {
Enabled bool `long:"enabled" env:"ENABLED" description:"enable spam rotated logs"`
Expand Down Expand Up @@ -186,6 +187,7 @@ func execute(ctx context.Context) error {
AdminGroup: opts.Admin.Group,
TestingIDs: opts.TestingIDs,
SpamWeb: web,
Locator: events.NewLocator(opts.HistoryDuration),
Dry: opts.Dry,
}

Expand Down

0 comments on commit a59d42a

Please sign in to comment.