diff --git a/api.go b/api.go index a66f6211..ccfb49f5 100644 --- a/api.go +++ b/api.go @@ -1,377 +1,113 @@ package telebot -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "io/ioutil" - "log" - "mime/multipart" - "net/http" - "os" - "strconv" - "strings" - "time" -) - -// Raw lets you call any method of Bot API manually. -// It also handles API errors, so you only need to unwrap -// result field from json data. -func (b *Bot) Raw(method string, payload interface{}) ([]byte, error) { - url := b.URL + "/bot" + b.Token + "/" + method - - var buf bytes.Buffer - if err := json.NewEncoder(&buf).Encode(payload); err != nil { - return nil, err - } - - // Cancel the request immediately without waiting for the timeout - // when bot is about to stop. - // This may become important if doing long polling with long timeout. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - b.stopMu.RLock() - stopCh := b.stopClient - b.stopMu.RUnlock() - - select { - case <-stopCh: - cancel() - case <-ctx.Done(): - } - }() - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, &buf) - if err != nil { - return nil, wrapError(err) - } - req.Header.Set("Content-Type", "application/json") - - resp, err := b.client.Do(req) - if err != nil { - return nil, wrapError(err) - } - resp.Close = true - defer resp.Body.Close() - - data, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, wrapError(err) - } - - if b.verbose { - verbose(method, payload, data) - } - - // returning data as well - return data, extractOk(data) -} - -func (b *Bot) sendFiles(method string, files map[string]File, params map[string]string) ([]byte, error) { - rawFiles := make(map[string]interface{}) - for name, f := range files { - switch { - case f.InCloud(): - params[name] = f.FileID - case f.FileURL != "": - params[name] = f.FileURL - case f.OnDisk(): - rawFiles[name] = f.FileLocal - case f.FileReader != nil: - rawFiles[name] = f.FileReader - default: - return nil, fmt.Errorf("telebot: file for field %s doesn't exist", name) - } - } - - if len(rawFiles) == 0 { - return b.Raw(method, params) - } - - pipeReader, pipeWriter := io.Pipe() - writer := multipart.NewWriter(pipeWriter) - - go func() { - defer pipeWriter.Close() - - for field, file := range rawFiles { - if err := addFileToWriter(writer, files[field].fileName, field, file); err != nil { - pipeWriter.CloseWithError(err) - return - } - } - for field, value := range params { - if err := writer.WriteField(field, value); err != nil { - pipeWriter.CloseWithError(err) - return - } - } - if err := writer.Close(); err != nil { - pipeWriter.CloseWithError(err) - return - } - }() - - url := b.URL + "/bot" + b.Token + "/" + method - - resp, err := b.client.Post(url, writer.FormDataContentType(), pipeReader) - if err != nil { - err = wrapError(err) - pipeReader.CloseWithError(err) - return nil, err - } - resp.Close = true - defer resp.Body.Close() - - if resp.StatusCode == http.StatusInternalServerError { - return nil, ErrInternal - } - - data, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, wrapError(err) - } - - return data, extractOk(data) -} - -func addFileToWriter(writer *multipart.Writer, filename, field string, file interface{}) error { - var reader io.Reader - if r, ok := file.(io.Reader); ok { - reader = r - } else if path, ok := file.(string); ok { - f, err := os.Open(path) - if err != nil { - return err - } - defer f.Close() - reader = f - } else { - return fmt.Errorf("telebot: file for field %v should be io.ReadCloser or string", field) - } - - part, err := writer.CreateFormFile(field, filename) - if err != nil { - return err - } - - _, err = io.Copy(part, reader) - return err -} - -func (f *File) process(name string, files map[string]File) string { - switch { - case f.InCloud(): - return f.FileID - case f.FileURL != "": - return f.FileURL - case f.OnDisk() || f.FileReader != nil: - files[name] = *f - return "attach://" + name - } - return "" -} - -func (b *Bot) sendText(to Recipient, text string, opt *SendOptions) (*Message, error) { - params := map[string]string{ - "chat_id": to.Recipient(), - "text": text, - } - b.embedSendOptions(params, opt) - - data, err := b.Raw("sendMessage", params) - if err != nil { - return nil, err - } - - return extractMessage(data) -} - -func (b *Bot) sendMedia(media Media, params map[string]string, files map[string]File) (*Message, error) { - kind := media.MediaType() - what := "send" + strings.Title(kind) - - if kind == "videoNote" { - kind = "video_note" - } - - sendFiles := map[string]File{kind: *media.MediaFile()} - for k, v := range files { - sendFiles[k] = v - } - - data, err := b.sendFiles(what, sendFiles, params) - if err != nil { - return nil, err - } - - return extractMessage(data) -} - -func (b *Bot) getMe() (*User, error) { - data, err := b.Raw("getMe", nil) - if err != nil { - return nil, err - } - - var resp struct { - Result *User - } - if err := json.Unmarshal(data, &resp); err != nil { - return nil, wrapError(err) - } - return resp.Result, nil -} - -func (b *Bot) getUpdates(offset, limit int, timeout time.Duration, allowed []string) ([]Update, error) { - params := map[string]string{ - "offset": strconv.Itoa(offset), - "timeout": strconv.Itoa(int(timeout / time.Second)), - } - - data, _ := json.Marshal(allowed) - params["allowed_updates"] = string(data) - - if limit != 0 { - params["limit"] = strconv.Itoa(limit) - } - - data, err := b.Raw("getUpdates", params) - if err != nil { - return nil, err - } - - var resp struct { - Result []Update - } - if err := json.Unmarshal(data, &resp); err != nil { - return nil, wrapError(err) - } - return resp.Result, nil -} - -func (b *Bot) forwardCopyMany(to Recipient, msgs []Editable, key string, opts ...*SendOptions) ([]Message, error) { - params := map[string]string{ - "chat_id": to.Recipient(), - } - - embedMessages(params, msgs) - - if len(opts) > 0 { - b.embedSendOptions(params, opts[0]) - } - - data, err := b.Raw(key, params) - if err != nil { - return nil, err - } - - var resp struct { - Result []Message - } - if err := json.Unmarshal(data, &resp); err != nil { - var resp struct { - Result bool - } - if err := json.Unmarshal(data, &resp); err != nil { - return nil, wrapError(err) - } - return nil, wrapError(err) - } - return resp.Result, nil -} - -// extractOk checks given result for error. If result is ok returns nil. -// In other cases it extracts API error. If error is not presented -// in errors.go, it will be prefixed with `unknown` keyword. -func extractOk(data []byte) error { - var e struct { - Ok bool `json:"ok"` - Code int `json:"error_code"` - Description string `json:"description"` - Parameters map[string]interface{} `json:"parameters"` - } - if json.NewDecoder(bytes.NewReader(data)).Decode(&e) != nil { - return nil // FIXME - } - if e.Ok { - return nil - } - - err := Err(e.Description) - switch err { - case nil: - case ErrGroupMigrated: - migratedTo, ok := e.Parameters["migrate_to_chat_id"] - if !ok { - return NewError(e.Code, e.Description) - } - - return GroupError{ - err: err.(*Error), - MigratedTo: int64(migratedTo.(float64)), - } - default: - return err - } - - switch e.Code { - case http.StatusTooManyRequests: - retryAfter, ok := e.Parameters["retry_after"] - if !ok { - return NewError(e.Code, e.Description) - } - - err = FloodError{ - err: NewError(e.Code, e.Description), - RetryAfter: int(retryAfter.(float64)), - } - default: - err = fmt.Errorf("telegram: %s (%d)", e.Description, e.Code) - } - - return err -} - -// extractMessage extracts common Message result from given data. -// Should be called after extractOk or b.Raw() to handle possible errors. -func extractMessage(data []byte) (*Message, error) { - var resp struct { - Result *Message - } - if err := json.Unmarshal(data, &resp); err != nil { - var resp struct { - Result bool - } - if err := json.Unmarshal(data, &resp); err != nil { - return nil, wrapError(err) - } - if resp.Result { - return nil, ErrTrueResult - } - return nil, wrapError(err) - } - return resp.Result, nil -} - -func verbose(method string, payload interface{}, data []byte) { - body, _ := json.Marshal(payload) - body = bytes.ReplaceAll(body, []byte(`\"`), []byte(`"`)) - body = bytes.ReplaceAll(body, []byte(`"{`), []byte(`{`)) - body = bytes.ReplaceAll(body, []byte(`}"`), []byte(`}`)) - - indent := func(b []byte) string { - var buf bytes.Buffer - json.Indent(&buf, b, "", " ") - return buf.String() - } - - log.Printf( - "[verbose] telebot: sent request\nMethod: %v\nParams: %v\nResponse: %v", - method, indent(body), indent(data), - ) +import "io" + +// API is the interface that wraps all basic methods for interacting +// with Telegram Bot API. +type API interface { + Raw(method string, payload interface{}) ([]byte, error) + + Accept(query *PreCheckoutQuery, errorMessage ...string) error + AddStickerToSet(of Recipient, name string, sticker InputSticker) error + AdminsOf(chat *Chat) ([]ChatMember, error) + Answer(query *Query, resp *QueryResponse) error + AnswerWebApp(query *Query, r Result) (*WebAppMessage, error) + ApproveJoinRequest(chat Recipient, user *User) error + Ban(chat *Chat, member *ChatMember, revokeMessages ...bool) error + BanSenderChat(chat *Chat, sender Recipient) error + ChatByID(id int64) (*Chat, error) + ChatByUsername(name string) (*Chat, error) + ChatMemberOf(chat, user Recipient) (*ChatMember, error) + Close() (bool, error) + CloseGeneralTopic(chat *Chat) error + CloseTopic(chat *Chat, topic *Topic) error + Commands(opts ...interface{}) ([]Command, error) + Copy(to Recipient, msg Editable, opts ...interface{}) (*Message, error) + CopyMany(to Recipient, msgs []Editable, opts ...*SendOptions) ([]Message, error) + CreateInviteLink(chat Recipient, link *ChatInviteLink) (*ChatInviteLink, error) + CreateInvoiceLink(i Invoice) (string, error) + CreateStickerSet(of Recipient, set *StickerSet) error + CreateTopic(chat *Chat, topic *Topic) (*Topic, error) + CustomEmojiStickers(ids []string) ([]Sticker, error) + DeclineJoinRequest(chat Recipient, user *User) error + DefaultRights(forChannels bool) (*Rights, error) + Delete(msg Editable) error + DeleteCommands(opts ...interface{}) error + DeleteGroupPhoto(chat *Chat) error + DeleteGroupStickerSet(chat *Chat) error + DeleteMany(msgs []Editable) error + DeleteSticker(sticker string) error + DeleteStickerSet(name string) error + DeleteTopic(chat *Chat, topic *Topic) error + Download(file *File, localFilename string) error + Edit(msg Editable, what interface{}, opts ...interface{}) (*Message, error) + EditCaption(msg Editable, caption string, opts ...interface{}) (*Message, error) + EditGeneralTopic(chat *Chat, topic *Topic) error + EditInviteLink(chat Recipient, link *ChatInviteLink) (*ChatInviteLink, error) + EditMedia(msg Editable, media Inputtable, opts ...interface{}) (*Message, error) + EditReplyMarkup(msg Editable, markup *ReplyMarkup) (*Message, error) + EditTopic(chat *Chat, topic *Topic) error + File(file *File) (io.ReadCloser, error) + FileByID(fileID string) (File, error) + Forward(to Recipient, msg Editable, opts ...interface{}) (*Message, error) + ForwardMany(to Recipient, msgs []Editable, opts ...*SendOptions) ([]Message, error) + GameScores(user Recipient, msg Editable) ([]GameHighScore, error) + HideGeneralTopic(chat *Chat) error + InviteLink(chat *Chat) (string, error) + Leave(chat Recipient) error + Len(chat *Chat) (int, error) + Logout() (bool, error) + MenuButton(chat *User) (*MenuButton, error) + MyDescription(language string) (*BotInfo, error) + MyName(language string) (*BotInfo, error) + MyShortDescription(language string) (*BotInfo, error) + Notify(to Recipient, action ChatAction, threadID ...int) error + Pin(msg Editable, opts ...interface{}) error + ProfilePhotosOf(user *User) ([]Photo, error) + Promote(chat *Chat, member *ChatMember) error + React(to Recipient, msg Editable, opts ...ReactionOptions) error + RemoveWebhook(dropPending ...bool) error + ReopenGeneralTopic(chat *Chat) error + ReopenTopic(chat *Chat, topic *Topic) error + Reply(to *Message, what interface{}, opts ...interface{}) (*Message, error) + Respond(c *Callback, resp ...*CallbackResponse) error + Restrict(chat *Chat, member *ChatMember) error + RevokeInviteLink(chat Recipient, link string) (*ChatInviteLink, error) + Send(to Recipient, what interface{}, opts ...interface{}) (*Message, error) + SendAlbum(to Recipient, a Album, opts ...interface{}) ([]Message, error) + SetAdminTitle(chat *Chat, user *User, title string) error + SetCommands(opts ...interface{}) error + SetCustomEmojiStickerSetThumb(name, id string) error + SetDefaultRights(rights Rights, forChannels bool) error + SetGameScore(user Recipient, msg Editable, score GameHighScore) (*Message, error) + SetGroupDescription(chat *Chat, description string) error + SetGroupPermissions(chat *Chat, perms Rights) error + SetGroupStickerSet(chat *Chat, setName string) error + SetGroupTitle(chat *Chat, title string) error + SetMenuButton(chat *User, mb interface{}) error + SetMyDescription(desc, language string) error + SetMyName(name, language string) error + SetMyShortDescription(desc, language string) error + SetStickerEmojis(sticker string, emojis []string) error + SetStickerKeywords(sticker string, keywords []string) error + SetStickerMaskPosition(sticker string, mask MaskPosition) error + SetStickerPosition(sticker string, position int) error + SetStickerSetThumb(of Recipient, set *StickerSet) error + SetStickerSetTitle(s StickerSet) error + SetWebhook(w *Webhook) error + Ship(query *ShippingQuery, what ...interface{}) error + StickerSet(name string) (*StickerSet, error) + StopLiveLocation(msg Editable, opts ...interface{}) (*Message, error) + StopPoll(msg Editable, opts ...interface{}) (*Poll, error) + TopicIconStickers() ([]Sticker, error) + Unban(chat *Chat, user *User, forBanned ...bool) error + UnbanSenderChat(chat *Chat, sender Recipient) error + UnhideGeneralTopic(chat *Chat) error + Unpin(chat Recipient, messageID ...int) error + UnpinAll(chat Recipient) error + UnpinAllGeneralTopicMessages(chat *Chat) error + UnpinAllTopicMessages(chat *Chat, topic *Topic) error + UploadSticker(to Recipient, format StickerSetFormat, f File) (*File, error) + UserBoosts(chat, user Recipient) ([]Boost, error) + Webhook() (*Webhook, error) } diff --git a/bot.go b/bot.go index 72406a20..4ba3eac3 100644 --- a/bot.go +++ b/bot.go @@ -266,10 +266,7 @@ func (b *Bot) NewMarkup() *ReplyMarkup { // NewContext returns a new native context object, // field by the passed update. func (b *Bot) NewContext(u Update) Context { - return &nativeContext{ - b: b, - u: u, - } + return NewContext(b, u) } // Send accepts 2+ arguments, starting with destination chat, followed by diff --git a/bot_raw.go b/bot_raw.go new file mode 100644 index 00000000..a66f6211 --- /dev/null +++ b/bot_raw.go @@ -0,0 +1,377 @@ +package telebot + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "log" + "mime/multipart" + "net/http" + "os" + "strconv" + "strings" + "time" +) + +// Raw lets you call any method of Bot API manually. +// It also handles API errors, so you only need to unwrap +// result field from json data. +func (b *Bot) Raw(method string, payload interface{}) ([]byte, error) { + url := b.URL + "/bot" + b.Token + "/" + method + + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(payload); err != nil { + return nil, err + } + + // Cancel the request immediately without waiting for the timeout + // when bot is about to stop. + // This may become important if doing long polling with long timeout. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + b.stopMu.RLock() + stopCh := b.stopClient + b.stopMu.RUnlock() + + select { + case <-stopCh: + cancel() + case <-ctx.Done(): + } + }() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, &buf) + if err != nil { + return nil, wrapError(err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := b.client.Do(req) + if err != nil { + return nil, wrapError(err) + } + resp.Close = true + defer resp.Body.Close() + + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, wrapError(err) + } + + if b.verbose { + verbose(method, payload, data) + } + + // returning data as well + return data, extractOk(data) +} + +func (b *Bot) sendFiles(method string, files map[string]File, params map[string]string) ([]byte, error) { + rawFiles := make(map[string]interface{}) + for name, f := range files { + switch { + case f.InCloud(): + params[name] = f.FileID + case f.FileURL != "": + params[name] = f.FileURL + case f.OnDisk(): + rawFiles[name] = f.FileLocal + case f.FileReader != nil: + rawFiles[name] = f.FileReader + default: + return nil, fmt.Errorf("telebot: file for field %s doesn't exist", name) + } + } + + if len(rawFiles) == 0 { + return b.Raw(method, params) + } + + pipeReader, pipeWriter := io.Pipe() + writer := multipart.NewWriter(pipeWriter) + + go func() { + defer pipeWriter.Close() + + for field, file := range rawFiles { + if err := addFileToWriter(writer, files[field].fileName, field, file); err != nil { + pipeWriter.CloseWithError(err) + return + } + } + for field, value := range params { + if err := writer.WriteField(field, value); err != nil { + pipeWriter.CloseWithError(err) + return + } + } + if err := writer.Close(); err != nil { + pipeWriter.CloseWithError(err) + return + } + }() + + url := b.URL + "/bot" + b.Token + "/" + method + + resp, err := b.client.Post(url, writer.FormDataContentType(), pipeReader) + if err != nil { + err = wrapError(err) + pipeReader.CloseWithError(err) + return nil, err + } + resp.Close = true + defer resp.Body.Close() + + if resp.StatusCode == http.StatusInternalServerError { + return nil, ErrInternal + } + + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, wrapError(err) + } + + return data, extractOk(data) +} + +func addFileToWriter(writer *multipart.Writer, filename, field string, file interface{}) error { + var reader io.Reader + if r, ok := file.(io.Reader); ok { + reader = r + } else if path, ok := file.(string); ok { + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + reader = f + } else { + return fmt.Errorf("telebot: file for field %v should be io.ReadCloser or string", field) + } + + part, err := writer.CreateFormFile(field, filename) + if err != nil { + return err + } + + _, err = io.Copy(part, reader) + return err +} + +func (f *File) process(name string, files map[string]File) string { + switch { + case f.InCloud(): + return f.FileID + case f.FileURL != "": + return f.FileURL + case f.OnDisk() || f.FileReader != nil: + files[name] = *f + return "attach://" + name + } + return "" +} + +func (b *Bot) sendText(to Recipient, text string, opt *SendOptions) (*Message, error) { + params := map[string]string{ + "chat_id": to.Recipient(), + "text": text, + } + b.embedSendOptions(params, opt) + + data, err := b.Raw("sendMessage", params) + if err != nil { + return nil, err + } + + return extractMessage(data) +} + +func (b *Bot) sendMedia(media Media, params map[string]string, files map[string]File) (*Message, error) { + kind := media.MediaType() + what := "send" + strings.Title(kind) + + if kind == "videoNote" { + kind = "video_note" + } + + sendFiles := map[string]File{kind: *media.MediaFile()} + for k, v := range files { + sendFiles[k] = v + } + + data, err := b.sendFiles(what, sendFiles, params) + if err != nil { + return nil, err + } + + return extractMessage(data) +} + +func (b *Bot) getMe() (*User, error) { + data, err := b.Raw("getMe", nil) + if err != nil { + return nil, err + } + + var resp struct { + Result *User + } + if err := json.Unmarshal(data, &resp); err != nil { + return nil, wrapError(err) + } + return resp.Result, nil +} + +func (b *Bot) getUpdates(offset, limit int, timeout time.Duration, allowed []string) ([]Update, error) { + params := map[string]string{ + "offset": strconv.Itoa(offset), + "timeout": strconv.Itoa(int(timeout / time.Second)), + } + + data, _ := json.Marshal(allowed) + params["allowed_updates"] = string(data) + + if limit != 0 { + params["limit"] = strconv.Itoa(limit) + } + + data, err := b.Raw("getUpdates", params) + if err != nil { + return nil, err + } + + var resp struct { + Result []Update + } + if err := json.Unmarshal(data, &resp); err != nil { + return nil, wrapError(err) + } + return resp.Result, nil +} + +func (b *Bot) forwardCopyMany(to Recipient, msgs []Editable, key string, opts ...*SendOptions) ([]Message, error) { + params := map[string]string{ + "chat_id": to.Recipient(), + } + + embedMessages(params, msgs) + + if len(opts) > 0 { + b.embedSendOptions(params, opts[0]) + } + + data, err := b.Raw(key, params) + if err != nil { + return nil, err + } + + var resp struct { + Result []Message + } + if err := json.Unmarshal(data, &resp); err != nil { + var resp struct { + Result bool + } + if err := json.Unmarshal(data, &resp); err != nil { + return nil, wrapError(err) + } + return nil, wrapError(err) + } + return resp.Result, nil +} + +// extractOk checks given result for error. If result is ok returns nil. +// In other cases it extracts API error. If error is not presented +// in errors.go, it will be prefixed with `unknown` keyword. +func extractOk(data []byte) error { + var e struct { + Ok bool `json:"ok"` + Code int `json:"error_code"` + Description string `json:"description"` + Parameters map[string]interface{} `json:"parameters"` + } + if json.NewDecoder(bytes.NewReader(data)).Decode(&e) != nil { + return nil // FIXME + } + if e.Ok { + return nil + } + + err := Err(e.Description) + switch err { + case nil: + case ErrGroupMigrated: + migratedTo, ok := e.Parameters["migrate_to_chat_id"] + if !ok { + return NewError(e.Code, e.Description) + } + + return GroupError{ + err: err.(*Error), + MigratedTo: int64(migratedTo.(float64)), + } + default: + return err + } + + switch e.Code { + case http.StatusTooManyRequests: + retryAfter, ok := e.Parameters["retry_after"] + if !ok { + return NewError(e.Code, e.Description) + } + + err = FloodError{ + err: NewError(e.Code, e.Description), + RetryAfter: int(retryAfter.(float64)), + } + default: + err = fmt.Errorf("telegram: %s (%d)", e.Description, e.Code) + } + + return err +} + +// extractMessage extracts common Message result from given data. +// Should be called after extractOk or b.Raw() to handle possible errors. +func extractMessage(data []byte) (*Message, error) { + var resp struct { + Result *Message + } + if err := json.Unmarshal(data, &resp); err != nil { + var resp struct { + Result bool + } + if err := json.Unmarshal(data, &resp); err != nil { + return nil, wrapError(err) + } + if resp.Result { + return nil, ErrTrueResult + } + return nil, wrapError(err) + } + return resp.Result, nil +} + +func verbose(method string, payload interface{}, data []byte) { + body, _ := json.Marshal(payload) + body = bytes.ReplaceAll(body, []byte(`\"`), []byte(`"`)) + body = bytes.ReplaceAll(body, []byte(`"{`), []byte(`{`)) + body = bytes.ReplaceAll(body, []byte(`}"`), []byte(`}`)) + + indent := func(b []byte) string { + var buf bytes.Buffer + json.Indent(&buf, b, "", " ") + return buf.String() + } + + log.Printf( + "[verbose] telebot: sent request\nMethod: %v\nParams: %v\nResponse: %v", + method, indent(body), indent(data), + ) +} diff --git a/api_test.go b/bot_raw_test.go similarity index 100% rename from api_test.go rename to bot_raw_test.go diff --git a/context.go b/context.go index f5c49656..74f45aba 100644 --- a/context.go +++ b/context.go @@ -11,10 +11,19 @@ import ( // used to handle actual endpoints. type HandlerFunc func(Context) error +// NewContext returns a new native context object, +// field by the passed update. +func NewContext(b API, u Update) Context { + return &nativeContext{ + b: b, + u: u, + } +} + // Context wraps an update and represents the context of current event. type Context interface { // Bot returns the bot instance. - Bot() *Bot + Bot() API // Update returns the original update. Update() Update @@ -174,13 +183,13 @@ type Context interface { // nativeContext is a native implementation of the Context interface. // "context" is taken by context package, maybe there is a better name. type nativeContext struct { - b *Bot + b API u Update lock sync.RWMutex store map[string]interface{} } -func (c *nativeContext) Bot() *Bot { +func (c *nativeContext) Bot() API { return c.b } @@ -478,7 +487,9 @@ func (c *nativeContext) Delete() error { func (c *nativeContext) DeleteAfter(d time.Duration) *time.Timer { return time.AfterFunc(d, func() { if err := c.Delete(); err != nil { - c.b.OnError(err, c) + if b, ok := c.b.(*Bot); ok { + b.OnError(err, c) + } } }) } diff --git a/middleware/middleware.go b/middleware/middleware.go index e616c1e5..7b17d3b4 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -2,6 +2,7 @@ package middleware import ( "errors" + "log" tele "gopkg.in/telebot.v3" ) @@ -42,9 +43,11 @@ func Recover(onError ...RecoverFunc) tele.MiddlewareFunc { var f RecoverFunc if len(onError) > 0 { f = onError[0] + } else if b, ok := c.Bot().(*tele.Bot); ok { + f = b.OnError } else { - f = func(err error, c tele.Context) { - c.Bot().OnError(err, c) + f = func(err error, _ tele.Context) { + log.Println("telebot/middleware/recover:", err) } } diff --git a/update.go b/update.go index 63396dd4..fed3aa09 100644 --- a/update.go +++ b/update.go @@ -29,7 +29,13 @@ type Update struct { // ProcessUpdate processes a single incoming update. // A started bot calls this function automatically. func (b *Bot) ProcessUpdate(u Update) { - c := b.NewContext(u) + b.ProcessContext(b.NewContext(u)) +} + +// ProcessContext processes the given context. +// A started bot calls this function automatically. +func (b *Bot) ProcessContext(c Context) { + u := c.Update() if u.Message != nil { m := u.Message