Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes SocketMode support in slacktest #1247

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions examples/socketmode_handler/socketmode_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,3 @@ func middlewareSlashCommand(evt *socketmode.Event, client *socketmode.Client) {
}}
client.Ack(*evt.Request, payload)
}

func middlewareDefault(evt *socketmode.Event, client *socketmode.Client) {
// fmt.Fprintf(os.Stderr, "Unexpected event type received: %s\n", evt.Type)
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ require (
github.com/gorilla/websocket v1.4.2
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.2.2
go.mills.io/logger v0.0.0-20230806012737-485dbd691907
)
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
go.mills.io/logger v0.0.0-20230806012737-485dbd691907 h1:KXwGupN4n3h/t9HyTLykODg1ope7KtXaknALkBMaaz4=
go.mills.io/logger v0.0.0-20230806012737-485dbd691907/go.mod h1:A+23JY9iOHzujHnRYbFKVzLLAQVObxHnsap8kjAjuQ8=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
19 changes: 19 additions & 0 deletions slacktest/funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ func BotNameFromContext(ctx context.Context) string {
return botname
}

// ServerWSFromContext returns the server websocket endpoint from a provided context
func ServerWSFromContext(ctx context.Context) string {
url, ok := ctx.Value(ServerWSContextKey).(string)
if !ok {
return "ws://wtf?!"
}
return url
}

// BotIDFromContext returns the bot userid from a provided context
func BotIDFromContext(ctx context.Context) string {
botname, ok := ctx.Value(ServerBotIDContextKey).(string)
Expand Down Expand Up @@ -117,6 +126,16 @@ func nowAsJSONTime() slack.JSONTime {
return slack.JSONTime(time.Now().Unix())
}

func defaultAppsConnectionsJSON(ctx context.Context) string {
url := ServerWSFromContext(ctx)
return fmt.Sprintf(`
{
"ok":true,
"url": "%s"
}
`, url)
}

func defaultBotInfoJSON(ctx context.Context) string {
botid := BotIDFromContext(ctx)
botname := BotNameFromContext(ctx)
Expand Down
30 changes: 30 additions & 0 deletions slacktest/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ import (
slack "github.com/slack-go/slack"
)

var (
defaultPingPeriod = time.Second * 15
defaultWriteDeadline = time.Second * 3
)

func contextHandler(server *Server, next http.HandlerFunc) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.WithValue(r.Context(), ServerURLContextKey, server.GetAPIURL())
Expand Down Expand Up @@ -277,10 +282,20 @@ func rtmStartHandler(w http.ResponseWriter, r *http.Request) {
}
}

// handle apps.connections.open
func appsConnectionsOpenHandler(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(defaultAppsConnectionsJSON(r.Context())))
}

func (sts *Server) wsHandler(w http.ResponseWriter, r *http.Request) {
Websocket(func(c *websocket.Conn) {
serverAddr := r.Context().Value(ServerBotHubNameContextKey).(string)
doneCh := make(chan struct{}, 1)
defer func() {
doneCh <- struct{}{}
}()
go handlePendingMessages(c, serverAddr)
go handlePeriodicPings(c, doneCh)
for {
var (
err error
Expand Down Expand Up @@ -309,6 +324,21 @@ func (sts *Server) wsHandler(w http.ResponseWriter, r *http.Request) {
})(w, r)
}

func handlePeriodicPings(c *websocket.Conn, done chan struct{}) {
ticker := time.NewTicker(defaultPingPeriod)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if err := c.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(defaultWriteDeadline)); err != nil {
log.Println("error sending ping:", err)
}
case <-done:
return
}
}
}

// Websocket handler
func Websocket(delegate func(c *websocket.Conn)) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
Expand Down
48 changes: 44 additions & 4 deletions slacktest/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import (
"log"
"net/http"
"net/http/httptest"
"regexp"
"time"

"github.com/slack-go/slack"
"go.mills.io/logger"
)

func newMessageChannels() *messageChannels {
Expand Down Expand Up @@ -61,8 +63,10 @@ func NewTestServer(custom ...Binder) *Server {
s.Handle("/bots.info", botsInfoHandler)
s.Handle("/auth.test", authTestHandler)
s.Handle("/reactions.add", reactionAddHandler)
s.Handle("/apps.connections.open", appsConnectionsOpenHandler)

httpserver := httptest.NewUnstartedServer(s.mux)
httpserver.Config.Handler = logger.New().Handler(httpserver.Config.Handler)
addr := httpserver.Listener.Addr().String()

s.ServerAddr = addr
Expand Down Expand Up @@ -138,8 +142,26 @@ func (sts *Server) SawOutgoingMessage(msg string) bool {
return false
}

// SawMessage checks if an incoming message was seen
func (sts *Server) SawMessage(msg string) bool {
// SawOutgoingMessageMatching checks if a message was sent to connected websocket clients that matches the given pattern
func (sts *Server) SawOutgoingMessageMatching(pattern string) bool {
sts.seenOutboundMessages.RLock()
defer sts.seenOutboundMessages.RUnlock()
for _, m := range sts.seenOutboundMessages.messages {
evt := &slack.MessageEvent{}
jErr := json.Unmarshal([]byte(m), evt)
if jErr != nil {
continue
}

if ok, err := regexp.MatchString(pattern, evt.Text); err == nil && ok {
return true
}
}
return false
}

// SawIncomingMessage checks if an incoming message was seen
func (sts *Server) SawIncomingMessage(msg string) bool {
sts.seenInboundMessages.RLock()
defer sts.seenInboundMessages.RUnlock()
for _, m := range sts.seenInboundMessages.messages {
Expand All @@ -156,6 +178,24 @@ func (sts *Server) SawMessage(msg string) bool {
return false
}

// SawIncomingMessageMatching checks if an incoming message was seen that matches a given pattern
func (sts *Server) SawIncomingMessageMatching(pattern string) bool {
sts.seenInboundMessages.RLock()
defer sts.seenInboundMessages.RUnlock()
for _, m := range sts.seenInboundMessages.messages {
evt := &slack.MessageEvent{}
jErr := json.Unmarshal([]byte(m), evt)
if jErr != nil {
// This event isn't a message event so we'll skip it
continue
}
if ok, err := regexp.MatchString(pattern, evt.Text); err == nil && ok {
return true
}
}
return false
}

// GetAPIURL returns the api url you can pass to slack.SLACK_API
func (sts *Server) GetAPIURL() string {
return "http://" + sts.ServerAddr + "/"
Expand Down Expand Up @@ -302,7 +342,7 @@ func (sts *Server) SendBotGroupInvite() {

// GetTestRTMInstance will give you an RTM instance in the context of the current fake server
func (sts *Server) GetTestRTMInstance() *slack.RTM {
api := slack.New("ABCEFG", slack.OptionAPIURL(sts.GetAPIURL()))
rtm := api.NewRTM()
api := slack.New("ABCEFG", slack.OptionDebug(true), slack.OptionAPIURL(sts.GetAPIURL()))
rtm := api.NewRTM(slack.RTMOptionPingInterval(5 * time.Second))
return rtm
}
4 changes: 2 additions & 2 deletions slacktest/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func TestGetSeenInboundMessages(t *testing.T) {
}
}
assert.True(t, hadMessage, "did not see my sent message")
assert.True(t, s.SawMessage("should see this inbound message"))
assert.True(t, s.SawIncomingMessage("should see this inbound message"))
}

func TestSendChannelInvite(t *testing.T) {
Expand Down Expand Up @@ -166,7 +166,7 @@ func TestSendGroupInvite(t *testing.T) {
func TestServerSawMessage(t *testing.T) {
s := NewTestServer()
go s.Start()
assert.False(t, s.SawMessage("foo"), "should not have seen any message")
assert.False(t, s.SawIncomingMessage("foo"), "should not have seen any message")
}

func TestServerSawOutgoingMessage(t *testing.T) {
Expand Down
2 changes: 0 additions & 2 deletions slacktest/types.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package slacktest

import (
"log"
"net/http"
"net/http/httptest"
"sync"
Expand Down Expand Up @@ -64,7 +63,6 @@ type Server struct {
registered map[string]struct{}
server *httptest.Server
mux *http.ServeMux
Logger *log.Logger
BotName string
BotID string
ServerAddr string
Expand Down
3 changes: 3 additions & 0 deletions socketmode/socket_mode_managed_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"sync"
"time"
Expand Down Expand Up @@ -312,6 +313,7 @@ func (smc *Client) openAndDial(ctx context.Context, additionalPingHandler func(s
return nil, nil, err
}
if additionalPingHandler == nil {
log.Print("no ping handler, default to null handler")
additionalPingHandler = func(_ string) error { return nil }
}

Expand Down Expand Up @@ -380,6 +382,7 @@ func (smc *Client) runRequestHandler(ctx context.Context, websocket chan json.Ra
// listen for incoming messages that need to be parsed
evt, err := smc.parseEvent(message)
if err != nil {
log.Printf("error parsing event %q: %s", message, err)
smc.sendEvent(ctx, newEvent(EventTypeErrorBadMessage, &ErrorBadMessage{
Cause: err,
Message: message,
Expand Down
1 change: 1 addition & 0 deletions socketmode/socketmode.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ func OptionPingInterval(d time.Duration) Option {
// OptionDebug enable debugging for the client
func OptionDebug(b bool) func(*Client) {
return func(c *Client) {
c.log.Printf("Using debug mode: %t", b)
c.debug = b
}
}
Expand Down
Loading