diff --git a/README.md b/README.md index 66e19a8f..ce5003c7 100644 --- a/README.md +++ b/README.md @@ -77,8 +77,6 @@ Common Options: * Supports will messages -* Queue subscribe - * Websocket Support * TLS/SSL Support @@ -95,13 +93,6 @@ Common Options: ``` -### QUEUE SUBSCRIBE -~~~ -| Prefix | Examples | -| ------------- |---------------------------------| -| $queue/ | mosquitto_sub -t ‘$queue/topic’ | -~~~ - ### ACL Configure #### The ACL rules define: ~~~ @@ -154,6 +145,14 @@ Client -> | Rule1 | --nomatch--> | Rule2 | --nomatch--> | Rule3 | --> allow | deny allow | deny allow | deny ~~~ +### Online/Offline Notification +```bash + topic: + $SYS/broker/connection/clients/ + payload: + {"clientID":"client001","online":true/false,"timestamp":"2018-10-25T09:32:32Z"} +``` + ## Performance * High throughput @@ -166,3 +165,8 @@ Client -> | Rule1 | --nomatch--> | Rule2 | --nomatch--> | Rule3 | --> ## License * Apache License Version 2.0 + + +## Reference + +* Surgermq.(https://github.com/surgemq/surgemq) \ No newline at end of file diff --git a/broker/broker.go b/broker/broker.go index 11006f5d..de1aa735 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -4,6 +4,7 @@ package broker import ( "crypto/tls" + "fmt" "net" "net/http" "runtime/debug" @@ -13,6 +14,8 @@ import ( "github.com/eclipse/paho.mqtt.golang/packets" "github.com/fhmq/hmq/lib/acl" + "github.com/fhmq/hmq/lib/sessions" + "github.com/fhmq/hmq/lib/topics" "github.com/fhmq/hmq/pool" "github.com/shirou/gopsutil/mem" "go.uber.org/zap" @@ -42,9 +45,9 @@ type Broker struct { remotes sync.Map nodes map[string]interface{} clusterPool chan *Message - sl *Sublist - rl *RetainList queues map[string]int + topicsMgr *topics.Manager + sessionMgr *sessions.Manager // messagePool []chan *Message } @@ -62,13 +65,24 @@ func NewBroker(config *Config) (*Broker, error) { id: GenUniqueId(), config: config, wpool: pool.New(config.Worker), - sl: NewSublist(), - rl: NewRetainList(), nodes: make(map[string]interface{}), queues: make(map[string]int), clusterPool: make(chan *Message), - // messagePool: newMessagePool(), } + + var err error + b.topicsMgr, err = topics.NewManager("mem") + if err != nil { + log.Error("new topic manager error", zap.Error(err)) + return nil, err + } + + b.sessionMgr, err = sessions.NewManager("mem") + if err != nil { + log.Error("new session manager error", zap.Error(err)) + return nil, err + } + if b.config.TlsPort != "" { tlsconfig, err := NewTLSConfig(b.config.TlsInfo) if err != nil { @@ -333,6 +347,12 @@ func (b *Broker) handleConnection(typ int, conn net.Conn) { c.init() + err = b.getSession(c, msg, connack) + if err != nil { + log.Error("get session error: ", zap.String("clientID", c.info.clientID)) + return + } + cid := c.info.clientID var exist bool @@ -349,6 +369,8 @@ func (b *Broker) handleConnection(typ int, conn net.Conn) { } } b.clients.Store(cid, c) + + b.OnlineOfflineNotification(cid, true) case ROUTER: old, exist = b.routes.Load(cid) if exist { @@ -535,9 +557,9 @@ func (b *Broker) SendLocalSubsToRouter(c *client) { b.clients.Range(func(key, value interface{}) bool { client, ok := value.(*client) if ok { - subs := client.subs + subs := client.subMap for _, sub := range subs { - subInfo.Topics = append(subInfo.Topics, string(sub.topic)) + subInfo.Topics = append(subInfo.Topics, sub.topic) subInfo.Qoss = append(subInfo.Qoss, sub.qos) } } @@ -593,17 +615,20 @@ func (b *Broker) removeClient(c *client) { } func (b *Broker) PublishMessage(packet *packets.PublishPacket) { - topic := packet.TopicName - r := b.sl.Match(topic) - if len(r.psubs) == 0 { + var subs []interface{} + var qoss []byte + err := b.topicsMgr.Subscribers([]byte(packet.TopicName), packet.Qos, &subs, &qoss) + if err != nil { + log.Error("search sub client error, ", zap.Error(err)) return } - for _, sub := range r.psubs { - if sub != nil { - err := sub.client.WriterPacket(packet) + for _, sub := range subs { + s, ok := sub.(*subscription) + if ok { + err := s.client.WriterPacket(packet) if err != nil { - log.Error("process message for psub error, ", zap.Error(err)) + log.Error("write message error, ", zap.Error(err)) } } } @@ -620,3 +645,12 @@ func (b *Broker) BroadcastUnSubscribe(subs map[string]*subscription) { b.BroadcastSubOrUnsubMessage(unsub) } } + +func (b *Broker) OnlineOfflineNotification(clientID string, online bool) { + packet := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket) + packet.TopicName = "$SYS/broker/connection/clients/" + clientID + packet.Qos = 0 + packet.Payload = []byte(fmt.Sprintf(`{"clientID":"%s","online":%v,"timestamp":"%s"}`, clientID, online, time.Now().UTC().Format(time.RFC3339))) + + b.PublishMessage(packet) +} diff --git a/broker/client.go b/broker/client.go index 30d012e8..204b0c71 100644 --- a/broker/client.go +++ b/broker/client.go @@ -12,6 +12,8 @@ import ( "time" "github.com/eclipse/paho.mqtt.golang/packets" + "github.com/fhmq/hmq/lib/sessions" + "github.com/fhmq/hmq/lib/topics" "go.uber.org/zap" ) @@ -39,11 +41,14 @@ type client struct { info info route route status int - smu sync.RWMutex - subs map[string]*subscription - rsubs map[string]*subInfo ctx context.Context cancelFunc context.CancelFunc + session *sessions.Session + subMap map[string]*subscription + topicsMgr *topics.Manager + subs []interface{} + qoss []byte + rmsgs []*packets.PublishPacket } type subInfo struct { @@ -78,44 +83,12 @@ var ( ) func (c *client) init() { - c.smu.Lock() - defer c.smu.Unlock() c.status = Connected - c.rsubs = make(map[string]*subInfo) - c.subs = make(map[string]*subscription, 10) c.info.localIP = strings.Split(c.conn.LocalAddr().String(), ":")[0] c.info.remoteIP = strings.Split(c.conn.RemoteAddr().String(), ":")[0] c.ctx, c.cancelFunc = context.WithCancel(context.Background()) -} - -func (c *client) keepAlive(ch chan int) { - defer close(ch) - - b := c.broker - - keepalive := time.Duration(c.info.keepalive*3/2) * time.Second - timer := time.NewTimer(keepalive) - - for { - select { - case <-ch: - timer.Reset(keepalive) - case <-timer.C: - if c.typ == REMOTE || c.typ == CLUSTER { - timer.Reset(keepalive) - continue - } - log.Error("Client exceeded timeout, disconnecting. ", zap.String("ClientID", c.info.clientID), zap.Uint16("keepalive", c.info.keepalive)) - - msg := &Message{client: c, packet: DisconnectdPacket} - b.SubmitWork(msg) - - timer.Stop() - return - case <-c.ctx.Done(): - return - } - } + c.subMap = make(map[string]*subscription) + c.topicsMgr = c.broker.topicsMgr } func (c *client) readLoop() { @@ -125,14 +98,20 @@ func (c *client) readLoop() { return } - ch := make(chan int, 1000) - go c.keepAlive(ch) + keepAlive := time.Second * time.Duration(c.info.keepalive) + timeOut := keepAlive + (keepAlive / 2) for { select { case <-c.ctx.Done(): return default: + //add read timeout + if err := nc.SetReadDeadline(time.Now().Add(timeOut)); err != nil { + log.Error("set read timeout error: ", zap.Error(err), zap.String("ClientID", c.info.clientID)) + return + } + packet, err := packets.ReadPacket(nc) if err != nil { log.Error("read packet error: ", zap.Error(err), zap.String("ClientID", c.info.clientID)) @@ -140,8 +119,6 @@ func (c *client) readLoop() { b.SubmitWork(msg) return } - // keepalive channel - ch <- 1 msg := &Message{ client: c, @@ -159,7 +136,6 @@ func ProcessMessage(msg *Message) { if ca == nil { return } - log.Debug("Recv message:", zap.String("message type", reflect.TypeOf(msg.packet).String()[9:]), zap.String("ClientID", c.info.clientID)) switch ca.(type) { case *packets.ConnackPacket: @@ -222,14 +198,6 @@ func (c *client) ProcessPublish(packet *packets.PublishPacket) { log.Error("publish with unknown qos", zap.String("ClientID", c.info.clientID)) return } - if packet.Retain { - if b := c.broker; b != nil { - err := b.rl.Insert(topic, packet) - if err != nil { - log.Error("Insert Retain Message error: ", zap.Error(err), zap.String("ClientID", c.info.clientID)) - } - } - } } @@ -243,85 +211,40 @@ func (c *client) ProcessPublishMessage(packet *packets.PublishPacket) { return } typ := c.typ - topic := packet.TopicName - r := b.sl.Match(topic) - if r == nil { - return + if packet.Retain { + if err := c.topicsMgr.Retain(packet); err != nil { + log.Error("Error retaining message: ", zap.Error(err), zap.String("ClientID", c.info.clientID)) + } } - // log.Info("psubs num: ", len(r.psubs)) - if len(r.qsubs) == 0 && len(r.psubs) == 0 { + err := c.topicsMgr.Subscribers([]byte(packet.TopicName), packet.Qos, &c.subs, &c.qoss) + if err != nil { + log.Error("Error retrieving subscribers list: ", zap.String("ClientID", c.info.clientID)) return } - for _, sub := range r.psubs { - if sub.client.typ == ROUTER { - if typ != CLIENT { - continue - } - } - if sub != nil { - err := sub.client.WriterPacket(packet) - if err != nil { - log.Error("process message for psub error, ", zap.Error(err), zap.String("ClientID", c.info.clientID)) - } - } + // log.Info("psubs num: ", len(r.psubs)) + if len(c.subs) == 0 { + return } - pre := -1 - now := -1 - t := "$queue/" + topic - cnt, exist := b.queues[t] - if exist { - // log.Info("queue index : ", cnt) - for _, sub := range r.qsubs { - if sub.client.typ == ROUTER { + for _, sub := range c.subs { + s, ok := sub.(*subscription) + if ok { + if s.client.typ == ROUTER { if typ != CLIENT { continue } } - if c.typ == CLIENT { - now = now + 1 - } else { - now = now + sub.client.rsubs[t].num - } - if cnt > pre && cnt <= now { - if sub != nil { - err := sub.client.WriterPacket(packet) - if err != nil { - log.Error("send publish error, ", zap.Error(err), zap.String("ClientID", c.info.clientID)) - } - } - - break + err := s.client.WriterPacket(packet) + if err != nil { + log.Error("process message for psub error, ", zap.Error(err), zap.String("ClientID", c.info.clientID)) } - pre = now } - } - length := getQueueSubscribeNum(r.qsubs) - if length > 0 { - b.queues[t] = (b.queues[t] + 1) % length } -} -func getQueueSubscribeNum(qsubs []*subscription) int { - topic := "$queue/" - if len(qsubs) < 1 { - return 0 - } else { - topic = topic + qsubs[0].topic - } - num := 0 - for _, sub := range qsubs { - if sub.client.typ == CLIENT { - num = num + 1 - } else { - num = num + sub.client.rsubs[topic].num - } - } - return num } func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) { @@ -349,54 +272,24 @@ func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) { continue } - queue := strings.HasPrefix(topic, "$queue/") - if queue { - if len(t) > 7 { - t = t[7:] - if _, exists := b.queues[topic]; !exists { - b.queues[topic] = 0 - } - } else { - retcodes = append(retcodes, QosFailure) - continue - } - } sub := &subscription{ topic: t, qos: qoss[i], client: c, - queue: queue, - } - switch c.typ { - case CLIENT: - if _, exist := c.subs[topic]; !exist { - c.subs[topic] = sub - - } else { - //if exist ,check whether qos change - c.subs[topic].qos = qoss[i] - retcodes = append(retcodes, qoss[i]) - continue - } - case ROUTER: - if subinfo, exist := c.rsubs[topic]; !exist { - sinfo := &subInfo{sub: sub, num: 1} - c.rsubs[topic] = sinfo - - } else { - subinfo.num = subinfo.num + 1 - retcodes = append(retcodes, qoss[i]) - continue - } } - err := b.sl.Insert(sub) + + rqos, err := c.topicsMgr.Subscribe([]byte(topic), qoss[i], sub) if err != nil { - log.Error("Insert subscription error: ", zap.Error(err), zap.String("ClientID", c.info.clientID)) - retcodes = append(retcodes, QosFailure) - } else { - retcodes = append(retcodes, qoss[i]) + return } + + c.subMap[topic] = sub + c.session.AddTopic(topic, qoss[i]) + retcodes = append(retcodes, rqos) + c.topicsMgr.Retained([]byte(topic), &c.rmsgs) + } + suback.ReturnCodes = retcodes err := c.WriterPacket(suback) @@ -410,16 +303,11 @@ func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) { } //process retain message - for _, t := range topics { - packets := b.rl.Match(t) - if packets == nil { - continue - } - for _, packet := range packets { + for _, rm := range c.rmsgs { + if err := c.WriterPacket(rm); err != nil { + log.Error("Error publishing retained message:", zap.Any("err", err), zap.String("ClientID", c.info.clientID)) + } else { log.Info("process retain message: ", zap.Any("packet", packet), zap.String("ClientID", c.info.clientID)) - if packet != nil { - c.WriterPacket(packet) - } } } } @@ -432,30 +320,16 @@ func (c *client) ProcessUnSubscribe(packet *packets.UnsubscribePacket) { if b == nil { return } - typ := c.typ topics := packet.Topics - for _, t := range topics { - - switch typ { - case CLIENT: - sub, ok := c.subs[t] - if ok { - c.unsubscribe(sub) - } - case ROUTER: - subinfo, ok := c.rsubs[t] - if ok { - subinfo.num = subinfo.num - 1 - if subinfo.num < 1 { - delete(c.rsubs, t) - c.unsubscribe(subinfo.sub) - } else { - c.rsubs[t] = subinfo - } - } + for _, topic := range topics { + t := []byte(topic) + sub, exist := c.subMap[topic] + if exist { + c.topicsMgr.Unsubscribe(t, sub) + c.session.RemoveTopic(topic) + delete(c.subMap, topic) } - } unsuback := packets.NewControlPacket(packets.Unsuback).(*packets.UnsubackPacket) @@ -472,19 +346,6 @@ func (c *client) ProcessUnSubscribe(packet *packets.UnsubscribePacket) { } } -func (c *client) unsubscribe(sub *subscription) { - - if c.typ == CLIENT { - delete(c.subs, sub.topic) - - } - b := c.broker - if b != nil && sub != nil { - b.sl.Remove(sub) - } - -} - func (c *client) ProcessPing() { if c.status == Disconnected { return @@ -498,9 +359,7 @@ func (c *client) ProcessPing() { } func (c *client) Close() { - c.smu.Lock() if c.status == Disconnected { - c.smu.Unlock() return } @@ -516,21 +375,17 @@ func (c *client) Close() { c.conn = nil } - c.smu.Unlock() - b := c.broker - subs := c.subs + subs := c.subMap if b != nil { b.removeClient(c) - for _, sub := range subs { - err := b.sl.Remove(sub) - if err != nil { - log.Error("closed client but remove sublist error, ", zap.Error(err), zap.String("ClientID", c.info.clientID)) - } - } + if c.typ == CLIENT { b.BroadcastUnSubscribe(subs) + //offline notification + b.OnlineOfflineNotification(c.info.clientID, false) } + if c.info.willMsg != nil { b.PublishMessage(c.info.willMsg) } diff --git a/broker/comm.go b/broker/comm.go index 534a3728..3293be2c 100644 --- a/broker/comm.go +++ b/broker/comm.go @@ -7,10 +7,8 @@ import ( "crypto/rand" "encoding/base64" "encoding/hex" - "errors" "io" "reflect" - "strings" "time" ) @@ -48,47 +46,6 @@ const ( QosFailure = 0x80 ) -func SubscribeTopicCheckAndSpilt(topic string) ([]string, error) { - if strings.Index(topic, "#") != -1 && strings.Index(topic, "#") != len(topic)-1 { - return nil, errors.New("Topic format error with index of #") - } - re := strings.Split(topic, "/") - for i, v := range re { - if i != 0 && i != (len(re)-1) { - if v == "" { - return nil, errors.New("Topic format error with index of //") - } - if strings.Contains(v, "+") && v != "+" { - return nil, errors.New("Topic format error with index of +") - } - } else { - if v == "" { - re[i] = "/" - } - } - } - return re, nil - -} - -func PublishTopicCheckAndSpilt(topic string) ([]string, error) { - if strings.Index(topic, "#") != -1 || strings.Index(topic, "+") != -1 { - return nil, errors.New("Publish Topic format error with + and #") - } - re := strings.Split(topic, "/") - for i, v := range re { - if v == "" { - if i != 0 && i != (len(re)-1) { - return nil, errors.New("Topic format error with index of //") - } else { - re[i] = "/" - } - } - - } - return re, nil -} - func equal(k1, k2 interface{}) bool { if reflect.TypeOf(k1) != reflect.TypeOf(k2) { return false diff --git a/broker/config.go b/broker/config.go index e354988b..a567a209 100644 --- a/broker/config.go +++ b/broker/config.go @@ -9,10 +9,11 @@ import ( "errors" "flag" "fmt" - "github.com/fhmq/hmq/logger" - "go.uber.org/zap" "io/ioutil" "os" + + "github.com/fhmq/hmq/logger" + "go.uber.org/zap" ) type Config struct { diff --git a/broker/retain.go b/broker/retain.go deleted file mode 100644 index 1c5cf123..00000000 --- a/broker/retain.go +++ /dev/null @@ -1,121 +0,0 @@ -package broker - -import ( - "github.com/eclipse/paho.mqtt.golang/packets" - "sync" -) - -type RetainList struct { - sync.RWMutex - root *rlevel -} -type rlevel struct { - nodes map[string]*rnode -} -type rnode struct { - next *rlevel - msg *packets.PublishPacket -} -type RetainResult struct { - msg []*packets.PublishPacket -} - -func newRNode() *rnode { - return &rnode{} -} - -func newRLevel() *rlevel { - return &rlevel{nodes: make(map[string]*rnode)} -} - -func NewRetainList() *RetainList { - return &RetainList{root: newRLevel()} -} - -func (r *RetainList) Insert(topic string, buf *packets.PublishPacket) error { - - tokens, err := PublishTopicCheckAndSpilt(topic) - if err != nil { - return err - } - // log.Info("insert tokens:", tokens) - r.Lock() - - l := r.root - var n *rnode - for _, t := range tokens { - n = l.nodes[t] - if n == nil { - n = newRNode() - l.nodes[t] = n - } - if n.next == nil { - n.next = newRLevel() - } - l = n.next - } - n.msg = buf - r.Unlock() - return nil -} - -func (r *RetainList) Match(topic string) []*packets.PublishPacket { - - tokens, err := SubscribeTopicCheckAndSpilt(topic) - if err != nil { - return nil - } - results := &RetainResult{} - - r.Lock() - l := r.root - matchRLevel(l, tokens, results) - r.Unlock() - // log.Info("results: ", results) - return results.msg - -} -func matchRLevel(l *rlevel, toks []string, results *RetainResult) { - var n *rnode - for i, t := range toks { - if l == nil { - return - } - // log.Info("l info :", l.nodes) - if t == "#" { - for _, n := range l.nodes { - n.GetAll(results) - } - } - if t == "+" { - for _, n := range l.nodes { - if len(t[i+1:]) == 0 { - results.msg = append(results.msg, n.msg) - } else { - matchRLevel(n.next, toks[i+1:], results) - } - } - } - - n = l.nodes[t] - if n != nil { - l = n.next - } else { - l = nil - } - } - if n != nil { - results.msg = append(results.msg, n.msg) - } -} - -func (r *rnode) GetAll(results *RetainResult) { - // log.Info("node 's message: ", string(r.msg)) - if r.msg != nil { - results.msg = append(results.msg, r.msg) - } - l := r.next - for _, n := range l.nodes { - n.GetAll(results) - } -} diff --git a/broker/sesson.go b/broker/sesson.go new file mode 100644 index 00000000..59d5c117 --- /dev/null +++ b/broker/sesson.go @@ -0,0 +1,53 @@ +package broker + +import "github.com/eclipse/paho.mqtt.golang/packets" + +func (b *Broker) getSession(cli *client, req *packets.ConnectPacket, resp *packets.ConnackPacket) error { + // If CleanSession is set to 0, the server MUST resume communications with the + // client based on state from the current session, as identified by the client + // identifier. If there is no session associated with the client identifier the + // server must create a new session. + // + // If CleanSession is set to 1, the client and server must discard any previous + // session and start a new one. b session lasts as long as the network c + // onnection. State data associated with b session must not be reused in any + // subsequent session. + + var err error + + // Check to see if the client supplied an ID, if not, generate one and set + // clean session. + + if len(req.ClientIdentifier) == 0 { + req.CleanSession = true + } + + cid := req.ClientIdentifier + + // If CleanSession is NOT set, check the session store for existing session. + // If found, return it. + if !req.CleanSession { + if cli.session, err = b.sessionMgr.Get(cid); err == nil { + resp.SessionPresent = true + + if err := cli.session.Update(req); err != nil { + return err + } + } + } + + // If CleanSession, or no existing session found, then create a new one + if cli.session == nil { + if cli.session, err = b.sessionMgr.New(cid); err != nil { + return err + } + + resp.SessionPresent = false + + if err := cli.session.Init(req); err != nil { + return err + } + } + + return nil +} diff --git a/broker/sublist.go b/broker/sublist.go deleted file mode 100644 index e1d8ecc1..00000000 --- a/broker/sublist.go +++ /dev/null @@ -1,317 +0,0 @@ -/* Copyright (c) 2018, joy.zhou - */ -package broker - -import ( - "errors" - "go.uber.org/zap" - "sync" -) - -// A result structure better optimized for queue subs. -type SublistResult struct { - psubs []*subscription - qsubs []*subscription // don't make this a map, too expensive to iterate -} - -// A Sublist stores and efficiently retrieves subscriptions. -type Sublist struct { - sync.RWMutex - cache map[string]*SublistResult - root *level -} - -// A node contains subscriptions and a pointer to the next level. -type node struct { - next *level - psubs []*subscription - qsubs []*subscription -} - -// A level represents a group of nodes and special pointers to -// wildcard nodes. -type level struct { - nodes map[string]*node -} - -// Create a new default node. -func newNode() *node { - return &node{psubs: make([]*subscription, 0, 4), qsubs: make([]*subscription, 0, 4)} -} - -// Create a new default level. We use FNV1A as the hash -// algortihm for the tokens, which should be short. -func newLevel() *level { - return &level{nodes: make(map[string]*node)} -} - -// New will create a default sublist -func NewSublist() *Sublist { - return &Sublist{root: newLevel(), cache: make(map[string]*SublistResult)} -} - -// Insert adds a subscription into the sublist -func (s *Sublist) Insert(sub *subscription) error { - - tokens, err := SubscribeTopicCheckAndSpilt(sub.topic) - if err != nil { - return err - } - s.Lock() - - l := s.root - var n *node - for _, t := range tokens { - n = l.nodes[t] - if n == nil { - n = newNode() - l.nodes[t] = n - } - if n.next == nil { - n.next = newLevel() - } - l = n.next - } - if sub.queue { - //check qsub is already exist - for i := range n.qsubs { - if equal(n.qsubs[i], sub) { - n.qsubs[i] = sub - return nil - } - } - n.qsubs = append(n.qsubs, sub) - } else { - //check psub is already exist - for i := range n.psubs { - if equal(n.psubs[i], sub) { - n.psubs[i] = sub - return nil - } - } - n.psubs = append(n.psubs, sub) - } - - topic := string(sub.topic) - s.addToCache(topic, sub) - s.Unlock() - return nil -} - -func (s *Sublist) addToCache(topic string, sub *subscription) { - for k, r := range s.cache { - if matchLiteral(k, topic) { - // Copy since others may have a reference. - nr := copyResult(r) - if sub.queue == false { - nr.psubs = append(nr.psubs, sub) - } else { - nr.qsubs = append(nr.qsubs, sub) - } - s.cache[k] = nr - } - } -} - -func (s *Sublist) removeFromCache(topic string, sub *subscription) { - for k := range s.cache { - if !matchLiteral(k, topic) { - continue - } - // Since someone else may be referecing, can't modify the list - // safely, just let it re-populate. - delete(s.cache, k) - } -} - -func matchLiteral(literal, topic string) bool { - tok, _ := SubscribeTopicCheckAndSpilt(topic) - li, _ := PublishTopicCheckAndSpilt(literal) - - for i := 0; i < len(tok); i++ { - b := tok[i] - switch b { - case "+": - - case "#": - return true - default: - if b != li[i] { - return false - } - } - } - return true -} - -// Deep copy -func copyResult(r *SublistResult) *SublistResult { - nr := &SublistResult{} - nr.psubs = append([]*subscription(nil), r.psubs...) - nr.qsubs = append([]*subscription(nil), r.qsubs...) - return nr -} - -func (s *Sublist) Remove(sub *subscription) error { - tokens, err := SubscribeTopicCheckAndSpilt(sub.topic) - if err != nil { - return err - } - s.Lock() - defer s.Unlock() - - l := s.root - var n *node - - for _, t := range tokens { - if l == nil { - return errors.New("No Matches subscription Found") - } - n = l.nodes[t] - if n != nil { - l = n.next - } else { - l = nil - } - } - if !s.removeFromNode(n, sub) { - return errors.New("No Matches subscription Found") - } - topic := string(sub.topic) - s.removeFromCache(topic, sub) - return nil - -} - -func (s *Sublist) removeFromNode(n *node, sub *subscription) (found bool) { - if n == nil { - return false - } - - if sub.queue { - n.qsubs, found = removeSubFromList(sub, n.qsubs) - return found - } else { - n.psubs, found = removeSubFromList(sub, n.psubs) - return found - } - - return false -} - -func (s *Sublist) Match(topic string) *SublistResult { - s.RLock() - rc, ok := s.cache[topic] - s.RUnlock() - - if ok { - return rc - } - - tokens, err := PublishTopicCheckAndSpilt(topic) - if err != nil { - log.Error("\tserver/sublist.go: ", zap.Error(err)) - return nil - } - - result := &SublistResult{} - - s.Lock() - l := s.root - if len(tokens) > 0 { - if tokens[0] == "/" { - if _, exist := l.nodes["#"]; exist { - addNodeToResults(l.nodes["#"], result) - } - if _, exist := l.nodes["+"]; exist { - matchLevel(l.nodes["/"].next, tokens[1:], result) - } - if _, exist := l.nodes["/"]; exist { - matchLevel(l.nodes["/"].next, tokens[1:], result) - } - } else { - matchLevel(s.root, tokens, result) - } - } - s.cache[topic] = result - if len(s.cache) > 1024 { - for k := range s.cache { - delete(s.cache, k) - break - } - } - - s.Unlock() - return result -} - -func matchLevel(l *level, toks []string, results *SublistResult) { - var swc, n *node - exist := false - for i, t := range toks { - if l == nil { - return - } - - if _, exist = l.nodes["#"]; exist { - addNodeToResults(l.nodes["#"], results) - } - if t != "/" { - if swc, exist = l.nodes["+"]; exist { - matchLevel(l.nodes["+"].next, toks[i+1:], results) - } - } else { - if _, exist = l.nodes["+"]; exist { - addNodeToResults(l.nodes["+"], results) - } - } - - n = l.nodes[t] - if n != nil { - l = n.next - } else { - l = nil - } - } - if n != nil { - addNodeToResults(n, results) - } - if swc != nil { - addNodeToResults(n, results) - } -} - -// This will add in a node's results to the total results. -func addNodeToResults(n *node, results *SublistResult) { - results.psubs = append(results.psubs, n.psubs...) - results.qsubs = append(results.qsubs, n.qsubs...) -} - -func removeSubFromList(sub *subscription, sl []*subscription) ([]*subscription, bool) { - for i := 0; i < len(sl); i++ { - if sl[i] == sub { - last := len(sl) - 1 - sl[i] = sl[last] - sl[last] = nil - sl = sl[:last] - return shrinkAsNeeded(sl), true - } - } - return sl, false -} - -// Checks if we need to do a resize. This is for very large growth then -// subsequent return to a more normal size from unsubscribe. -func shrinkAsNeeded(sl []*subscription) []*subscription { - lsl := len(sl) - csl := cap(sl) - // Don't bother if list not too big - if csl <= 8 { - return sl - } - pFree := float32(csl-lsl) / float32(csl) - if pFree > 0.50 { - return append([]*subscription(nil), sl...) - } - return sl -} diff --git a/lib/sessions/memprovider.go b/lib/sessions/memprovider.go new file mode 100644 index 00000000..d82d117e --- /dev/null +++ b/lib/sessions/memprovider.go @@ -0,0 +1,76 @@ +// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sessions + +import ( + "fmt" + "sync" +) + +var _ SessionsProvider = (*memProvider)(nil) + +func init() { + Register("mem", NewMemProvider()) +} + +type memProvider struct { + st map[string]*Session + mu sync.RWMutex +} + +func NewMemProvider() *memProvider { + return &memProvider{ + st: make(map[string]*Session), + } +} + +func (this *memProvider) New(id string) (*Session, error) { + this.mu.Lock() + defer this.mu.Unlock() + + this.st[id] = &Session{id: id} + return this.st[id], nil +} + +func (this *memProvider) Get(id string) (*Session, error) { + this.mu.RLock() + defer this.mu.RUnlock() + + sess, ok := this.st[id] + if !ok { + return nil, fmt.Errorf("store/Get: No session found for key %s", id) + } + + return sess, nil +} + +func (this *memProvider) Del(id string) { + this.mu.Lock() + defer this.mu.Unlock() + delete(this.st, id) +} + +func (this *memProvider) Save(id string) error { + return nil +} + +func (this *memProvider) Count() int { + return len(this.st) +} + +func (this *memProvider) Close() error { + this.st = make(map[string]*Session) + return nil +} diff --git a/lib/sessions/redisprovider.go b/lib/sessions/redisprovider.go new file mode 100644 index 00000000..30c701d8 --- /dev/null +++ b/lib/sessions/redisprovider.go @@ -0,0 +1,95 @@ +package sessions + +import ( + "time" + + log "github.com/cihub/seelog" + "github.com/go-redis/redis" + jsoniter "github.com/json-iterator/go" +) + +var redisClient *redis.Client +var _ SessionsProvider = (*redisProvider)(nil) + +const ( + sessionName = "session" +) + +type redisProvider struct { +} + +func init() { + Register("redis", NewRedisProvider()) +} + +func InitRedisConn(url string) { + redisClient = redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:6379", + Password: "", // no password set + DB: 0, // use default DB + }) + err := redisClient.Ping().Err() + for err != nil { + log.Error("connect redis error: ", err, " 3s try again...") + time.Sleep(3 * time.Second) + err = redisClient.Ping().Err() + } +} + +func NewRedisProvider() *redisProvider { + return &redisProvider{} +} + +func (r *redisProvider) New(id string) (*Session, error) { + val, _ := jsoniter.Marshal(&Session{id: id}) + + err := redisClient.HSet(sessionName, id, val).Err() + if err != nil { + return nil, err + } + + result, err := redisClient.HGet(sessionName, id).Bytes() + if err != nil { + return nil, err + } + + sess := Session{} + err = jsoniter.Unmarshal(result, &sess) + if err != nil { + return nil, err + } + + return &sess, nil +} + +func (r *redisProvider) Get(id string) (*Session, error) { + + result, err := redisClient.HGet(sessionName, id).Bytes() + if err != nil { + return nil, err + } + + sess := Session{} + err = jsoniter.Unmarshal(result, &sess) + if err != nil { + return nil, err + } + + return &sess, nil +} + +func (r *redisProvider) Del(id string) { + redisClient.HDel(sessionName, id) +} + +func (r *redisProvider) Save(id string) error { + return nil +} + +func (r *redisProvider) Count() int { + return int(redisClient.HLen(sessionName).Val()) +} + +func (r *redisProvider) Close() error { + return redisClient.Del(sessionName).Err() +} diff --git a/lib/sessions/session.go b/lib/sessions/session.go new file mode 100644 index 00000000..83b8d296 --- /dev/null +++ b/lib/sessions/session.go @@ -0,0 +1,149 @@ +package sessions + +import ( + "fmt" + "sync" + + "github.com/eclipse/paho.mqtt.golang/packets" +) + +const ( + // Queue size for the ack queue + defaultQueueSize = 16 +) + +type Session struct { + + // cmsg is the CONNECT message + cmsg *packets.ConnectPacket + + // Will message to publish if connect is closed unexpectedly + Will *packets.PublishPacket + + // Retained publish message + Retained *packets.PublishPacket + + // topics stores all the topis for this session/client + topics map[string]byte + + // Initialized? + initted bool + + // Serialize access to this session + mu sync.Mutex + + id string +} + +func (this *Session) Init(msg *packets.ConnectPacket) error { + this.mu.Lock() + defer this.mu.Unlock() + + if this.initted { + return fmt.Errorf("Session already initialized") + } + + this.cmsg = msg + + if this.cmsg.WillFlag { + this.Will = packets.NewControlPacket(packets.Publish).(*packets.PublishPacket) + this.Will.Qos = this.cmsg.Qos + this.Will.TopicName = this.cmsg.WillTopic + this.Will.Payload = this.cmsg.WillMessage + this.Will.Retain = this.cmsg.WillRetain + } + + this.topics = make(map[string]byte, 1) + + this.id = string(msg.ClientIdentifier) + + this.initted = true + + return nil +} + +func (this *Session) Update(msg *packets.ConnectPacket) error { + this.mu.Lock() + defer this.mu.Unlock() + + this.cmsg = msg + return nil +} + +func (this *Session) RetainMessage(msg *packets.PublishPacket) error { + this.mu.Lock() + defer this.mu.Unlock() + + this.Retained = msg + + return nil +} + +func (this *Session) AddTopic(topic string, qos byte) error { + this.mu.Lock() + defer this.mu.Unlock() + + if !this.initted { + return fmt.Errorf("Session not yet initialized") + } + + this.topics[topic] = qos + + return nil +} + +func (this *Session) RemoveTopic(topic string) error { + this.mu.Lock() + defer this.mu.Unlock() + + if !this.initted { + return fmt.Errorf("Session not yet initialized") + } + + delete(this.topics, topic) + + return nil +} + +func (this *Session) Topics() ([]string, []byte, error) { + this.mu.Lock() + defer this.mu.Unlock() + + if !this.initted { + return nil, nil, fmt.Errorf("Session not yet initialized") + } + + var ( + topics []string + qoss []byte + ) + + for k, v := range this.topics { + topics = append(topics, k) + qoss = append(qoss, v) + } + + return topics, qoss, nil +} + +func (this *Session) ID() string { + return this.cmsg.ClientIdentifier +} + +func (this *Session) WillFlag() bool { + this.mu.Lock() + defer this.mu.Unlock() + return this.cmsg.WillFlag +} + +func (this *Session) SetWillFlag(v bool) { + this.mu.Lock() + defer this.mu.Unlock() + this.cmsg.WillFlag = v +} + +func (this *Session) CleanSession() bool { + this.mu.Lock() + defer this.mu.Unlock() + return this.cmsg.CleanSession +} diff --git a/lib/sessions/sessions.go b/lib/sessions/sessions.go new file mode 100644 index 00000000..b160d516 --- /dev/null +++ b/lib/sessions/sessions.go @@ -0,0 +1,92 @@ +package sessions + +import ( + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" +) + +var ( + ErrSessionsProviderNotFound = errors.New("Session: Session provider not found") + ErrKeyNotAvailable = errors.New("Session: not item found for key.") + + providers = make(map[string]SessionsProvider) +) + +type SessionsProvider interface { + New(id string) (*Session, error) + Get(id string) (*Session, error) + Del(id string) + Save(id string) error + Count() int + Close() error +} + +// Register makes a session provider available by the provided name. +// If a Register is called twice with the same name or if the driver is nil, +// it panics. +func Register(name string, provider SessionsProvider) { + if provider == nil { + panic("session: Register provide is nil") + } + + if _, dup := providers[name]; dup { + panic("session: Register called twice for provider " + name) + } + + providers[name] = provider +} + +func Unregister(name string) { + delete(providers, name) +} + +type Manager struct { + p SessionsProvider +} + +func NewManager(providerName string) (*Manager, error) { + p, ok := providers[providerName] + if !ok { + return nil, fmt.Errorf("session: unknown provider %q", providerName) + } + + return &Manager{p: p}, nil +} + +func (this *Manager) New(id string) (*Session, error) { + if id == "" { + id = this.sessionId() + } + return this.p.New(id) +} + +func (this *Manager) Get(id string) (*Session, error) { + return this.p.Get(id) +} + +func (this *Manager) Del(id string) { + this.p.Del(id) +} + +func (this *Manager) Save(id string) error { + return this.p.Save(id) +} + +func (this *Manager) Count() int { + return this.p.Count() +} + +func (this *Manager) Close() error { + return this.p.Close() +} + +func (manager *Manager) sessionId() string { + b := make([]byte, 15) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "" + } + return base64.URLEncoding.EncodeToString(b) +} diff --git a/lib/topics/memtopics.go b/lib/topics/memtopics.go new file mode 100644 index 00000000..eb813455 --- /dev/null +++ b/lib/topics/memtopics.go @@ -0,0 +1,549 @@ +package topics + +import ( + "fmt" + "reflect" + "sync" + + "github.com/eclipse/paho.mqtt.golang/packets" +) + +const ( + QosAtMostOnce byte = iota + QosAtLeastOnce + QosExactlyOnce + QosFailure = 0x80 +) + +var _ TopicsProvider = (*memTopics)(nil) + +type memTopics struct { + // Sub/unsub mutex + smu sync.RWMutex + // Subscription tree + sroot *snode + + // Retained message mutex + rmu sync.RWMutex + // Retained messages topic tree + rroot *rnode +} + +func init() { + Register("mem", NewMemProvider()) +} + +// NewMemProvider returns an new instance of the memTopics, which is implements the +// TopicsProvider interface. memProvider is a hidden struct that stores the topic +// subscriptions and retained messages in memory. The content is not persistend so +// when the server goes, everything will be gone. Use with care. +func NewMemProvider() *memTopics { + return &memTopics{ + sroot: newSNode(), + rroot: newRNode(), + } +} + +func ValidQos(qos byte) bool { + return qos == QosAtMostOnce || qos == QosAtLeastOnce || qos == QosExactlyOnce +} + +func (this *memTopics) Subscribe(topic []byte, qos byte, sub interface{}) (byte, error) { + if !ValidQos(qos) { + return QosFailure, fmt.Errorf("Invalid QoS %d", qos) + } + + if sub == nil { + return QosFailure, fmt.Errorf("Subscriber cannot be nil") + } + + this.smu.Lock() + defer this.smu.Unlock() + + if qos > QosExactlyOnce { + qos = QosExactlyOnce + } + + if err := this.sroot.sinsert(topic, qos, sub); err != nil { + return QosFailure, err + } + + return qos, nil +} + +func (this *memTopics) Unsubscribe(topic []byte, sub interface{}) error { + this.smu.Lock() + defer this.smu.Unlock() + + return this.sroot.sremove(topic, sub) +} + +// Returned values will be invalidated by the next Subscribers call +func (this *memTopics) Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error { + if !ValidQos(qos) { + return fmt.Errorf("Invalid QoS %d", qos) + } + + this.smu.RLock() + defer this.smu.RUnlock() + + *subs = (*subs)[0:0] + *qoss = (*qoss)[0:0] + + return this.sroot.smatch(topic, qos, subs, qoss) +} + +func (this *memTopics) Retain(msg *packets.PublishPacket) error { + this.rmu.Lock() + defer this.rmu.Unlock() + + // So apparently, at least according to the MQTT Conformance/Interoperability + // Testing, that a payload of 0 means delete the retain message. + // https://eclipse.org/paho/clients/testing/ + if len(msg.Payload) == 0 { + return this.rroot.rremove([]byte(msg.TopicName)) + } + + return this.rroot.rinsert([]byte(msg.TopicName), msg) +} + +func (this *memTopics) Retained(topic []byte, msgs *[]*packets.PublishPacket) error { + this.rmu.RLock() + defer this.rmu.RUnlock() + + return this.rroot.rmatch(topic, msgs) +} + +func (this *memTopics) Close() error { + this.sroot = nil + this.rroot = nil + return nil +} + +// subscrition nodes +type snode struct { + // If this is the end of the topic string, then add subscribers here + subs []interface{} + qos []byte + + // Otherwise add the next topic level here + snodes map[string]*snode +} + +func newSNode() *snode { + return &snode{ + snodes: make(map[string]*snode), + } +} + +func (this *snode) sinsert(topic []byte, qos byte, sub interface{}) error { + // If there's no more topic levels, that means we are at the matching snode + // to insert the subscriber. So let's see if there's such subscriber, + // if so, update it. Otherwise insert it. + if len(topic) == 0 { + // Let's see if the subscriber is already on the list. If yes, update + // QoS and then return. + for i := range this.subs { + if equal(this.subs[i], sub) { + this.qos[i] = qos + return nil + } + } + + // Otherwise add. + this.subs = append(this.subs, sub) + this.qos = append(this.qos, qos) + + return nil + } + + // Not the last level, so let's find or create the next level snode, and + // recursively call it's insert(). + + // ntl = next topic level + ntl, rem, err := nextTopicLevel(topic) + if err != nil { + return err + } + + level := string(ntl) + + // Add snode if it doesn't already exist + n, ok := this.snodes[level] + if !ok { + n = newSNode() + this.snodes[level] = n + } + + return n.sinsert(rem, qos, sub) +} + +// This remove implementation ignores the QoS, as long as the subscriber +// matches then it's removed +func (this *snode) sremove(topic []byte, sub interface{}) error { + // If the topic is empty, it means we are at the final matching snode. If so, + // let's find the matching subscribers and remove them. + if len(topic) == 0 { + // If subscriber == nil, then it's signal to remove ALL subscribers + if sub == nil { + this.subs = this.subs[0:0] + this.qos = this.qos[0:0] + return nil + } + + // If we find the subscriber then remove it from the list. Technically + // we just overwrite the slot by shifting all other items up by one. + for i := range this.subs { + if equal(this.subs[i], sub) { + this.subs = append(this.subs[:i], this.subs[i+1:]...) + this.qos = append(this.qos[:i], this.qos[i+1:]...) + return nil + } + } + + return fmt.Errorf("No topic found for subscriber") + } + + // Not the last level, so let's find the next level snode, and recursively + // call it's remove(). + + // ntl = next topic level + ntl, rem, err := nextTopicLevel(topic) + if err != nil { + return err + } + + level := string(ntl) + + // Find the snode that matches the topic level + n, ok := this.snodes[level] + if !ok { + return fmt.Errorf("No topic found") + } + + // Remove the subscriber from the next level snode + if err := n.sremove(rem, sub); err != nil { + return err + } + + // If there are no more subscribers and snodes to the next level we just visited + // let's remove it + if len(n.subs) == 0 && len(n.snodes) == 0 { + delete(this.snodes, level) + } + + return nil +} + +// smatch() returns all the subscribers that are subscribed to the topic. Given a topic +// with no wildcards (publish topic), it returns a list of subscribers that subscribes +// to the topic. For each of the level names, it's a match +// - if there are subscribers to '#', then all the subscribers are added to result set +func (this *snode) smatch(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error { + // If the topic is empty, it means we are at the final matching snode. If so, + // let's find the subscribers that match the qos and append them to the list. + if len(topic) == 0 { + this.matchQos(qos, subs, qoss) + return nil + } + + // ntl = next topic level + ntl, rem, err := nextTopicLevel(topic) + if err != nil { + return err + } + + level := string(ntl) + + for k, n := range this.snodes { + // If the key is "#", then these subscribers are added to the result set + if k == MWC { + n.matchQos(qos, subs, qoss) + } else if k == SWC || k == level { + if err := n.smatch(rem, qos, subs, qoss); err != nil { + return err + } + } + } + + return nil +} + +// retained message nodes +type rnode struct { + // If this is the end of the topic string, then add retained messages here + msg *packets.PublishPacket + // Otherwise add the next topic level here + rnodes map[string]*rnode +} + +func newRNode() *rnode { + return &rnode{ + rnodes: make(map[string]*rnode), + } +} + +func (this *rnode) rinsert(topic []byte, msg *packets.PublishPacket) error { + // If there's no more topic levels, that means we are at the matching rnode. + if len(topic) == 0 { + // Reuse the message if possible + if this.msg == nil { + this.msg = msg + } + + return nil + } + + // Not the last level, so let's find or create the next level snode, and + // recursively call it's insert(). + + // ntl = next topic level + ntl, rem, err := nextTopicLevel(topic) + if err != nil { + return err + } + + level := string(ntl) + + // Add snode if it doesn't already exist + n, ok := this.rnodes[level] + if !ok { + n = newRNode() + this.rnodes[level] = n + } + + return n.rinsert(rem, msg) +} + +// Remove the retained message for the supplied topic +func (this *rnode) rremove(topic []byte) error { + // If the topic is empty, it means we are at the final matching rnode. If so, + // let's remove the buffer and message. + if len(topic) == 0 { + this.msg = nil + return nil + } + + // Not the last level, so let's find the next level rnode, and recursively + // call it's remove(). + + // ntl = next topic level + ntl, rem, err := nextTopicLevel(topic) + if err != nil { + return err + } + + level := string(ntl) + + // Find the rnode that matches the topic level + n, ok := this.rnodes[level] + if !ok { + return fmt.Errorf("No topic found") + } + + // Remove the subscriber from the next level rnode + if err := n.rremove(rem); err != nil { + return err + } + + // If there are no more rnodes to the next level we just visited let's remove it + if len(n.rnodes) == 0 { + delete(this.rnodes, level) + } + + return nil +} + +// rmatch() finds the retained messages for the topic and qos provided. It's somewhat +// of a reverse match compare to match() since the supplied topic can contain +// wildcards, whereas the retained message topic is a full (no wildcard) topic. +func (this *rnode) rmatch(topic []byte, msgs *[]*packets.PublishPacket) error { + // If the topic is empty, it means we are at the final matching rnode. If so, + // add the retained msg to the list. + if len(topic) == 0 { + if this.msg != nil { + *msgs = append(*msgs, this.msg) + } + return nil + } + + // ntl = next topic level + ntl, rem, err := nextTopicLevel(topic) + if err != nil { + return err + } + + level := string(ntl) + + if level == MWC { + // If '#', add all retained messages starting this node + this.allRetained(msgs) + } else if level == SWC { + // If '+', check all nodes at this level. Next levels must be matched. + for _, n := range this.rnodes { + if err := n.rmatch(rem, msgs); err != nil { + return err + } + } + } else { + // Otherwise, find the matching node, go to the next level + if n, ok := this.rnodes[level]; ok { + if err := n.rmatch(rem, msgs); err != nil { + return err + } + } + } + + return nil +} + +func (this *rnode) allRetained(msgs *[]*packets.PublishPacket) { + if this.msg != nil { + *msgs = append(*msgs, this.msg) + } + + for _, n := range this.rnodes { + n.allRetained(msgs) + } +} + +const ( + stateCHR byte = iota // Regular character + stateMWC // Multi-level wildcard + stateSWC // Single-level wildcard + stateSEP // Topic level separator + stateSYS // System level topic ($) +) + +// Returns topic level, remaining topic levels and any errors +func nextTopicLevel(topic []byte) ([]byte, []byte, error) { + s := stateCHR + + for i, c := range topic { + switch c { + case '/': + if s == stateMWC { + return nil, nil, fmt.Errorf("Multi-level wildcard found in topic and it's not at the last level") + } + + if i == 0 { + return []byte(SWC), topic[i+1:], nil + } + + return topic[:i], topic[i+1:], nil + + case '#': + if i != 0 { + return nil, nil, fmt.Errorf("Wildcard character '#' must occupy entire topic level") + } + + s = stateMWC + + case '+': + if i != 0 { + return nil, nil, fmt.Errorf("Wildcard character '+' must occupy entire topic level") + } + + s = stateSWC + + // case '$': + // if i == 0 { + // return nil, nil, fmt.Errorf("Cannot publish to $ topics") + // } + + // s = stateSYS + + default: + if s == stateMWC || s == stateSWC { + return nil, nil, fmt.Errorf("Wildcard characters '#' and '+' must occupy entire topic level") + } + + s = stateCHR + } + } + + // If we got here that means we didn't hit the separator along the way, so the + // topic is either empty, or does not contain a separator. Either way, we return + // the full topic + return topic, nil, nil +} + +// The QoS of the payload messages sent in response to a subscription must be the +// minimum of the QoS of the originally published message (in this case, it's the +// qos parameter) and the maximum QoS granted by the server (in this case, it's +// the QoS in the topic tree). +// +// It's also possible that even if the topic matches, the subscriber is not included +// due to the QoS granted is lower than the published message QoS. For example, +// if the client is granted only QoS 0, and the publish message is QoS 1, then this +// client is not to be send the published message. +func (this *snode) matchQos(qos byte, subs *[]interface{}, qoss *[]byte) { + for _, sub := range this.subs { + // If the published QoS is higher than the subscriber QoS, then we skip the + // subscriber. Otherwise, add to the list. + // if qos >= this.qos[i] { + *subs = append(*subs, sub) + *qoss = append(*qoss, qos) + // } + } +} + +func equal(k1, k2 interface{}) bool { + if reflect.TypeOf(k1) != reflect.TypeOf(k2) { + return false + } + + if reflect.ValueOf(k1).Kind() == reflect.Func { + return &k1 == &k2 + } + + if k1 == k2 { + return true + } + + switch k1 := k1.(type) { + case string: + return k1 == k2.(string) + + case int64: + return k1 == k2.(int64) + + case int32: + return k1 == k2.(int32) + + case int16: + return k1 == k2.(int16) + + case int8: + return k1 == k2.(int8) + + case int: + return k1 == k2.(int) + + case float32: + return k1 == k2.(float32) + + case float64: + return k1 == k2.(float64) + + case uint: + return k1 == k2.(uint) + + case uint8: + return k1 == k2.(uint8) + + case uint16: + return k1 == k2.(uint16) + + case uint32: + return k1 == k2.(uint32) + + case uint64: + return k1 == k2.(uint64) + + case uintptr: + return k1 == k2.(uintptr) + } + + return false +} diff --git a/lib/topics/topics.go b/lib/topics/topics.go new file mode 100644 index 00000000..b99696a9 --- /dev/null +++ b/lib/topics/topics.go @@ -0,0 +1,91 @@ +package topics + +import ( + "fmt" + + "github.com/eclipse/paho.mqtt.golang/packets" +) + +const ( + // MWC is the multi-level wildcard + MWC = "#" + + // SWC is the single level wildcard + SWC = "+" + + // SEP is the topic level separator + SEP = "/" + + // SYS is the starting character of the system level topics + SYS = "$" + + // Both wildcards + _WC = "#+" +) + +var ( + providers = make(map[string]TopicsProvider) +) + +// TopicsProvider +type TopicsProvider interface { + Subscribe(topic []byte, qos byte, subscriber interface{}) (byte, error) + Unsubscribe(topic []byte, subscriber interface{}) error + Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error + Retain(msg *packets.PublishPacket) error + Retained(topic []byte, msgs *[]*packets.PublishPacket) error + Close() error +} + +func Register(name string, provider TopicsProvider) { + if provider == nil { + panic("topics: Register provide is nil") + } + + if _, dup := providers[name]; dup { + panic("topics: Register called twice for provider " + name) + } + + providers[name] = provider +} + +func Unregister(name string) { + delete(providers, name) +} + +type Manager struct { + p TopicsProvider +} + +func NewManager(providerName string) (*Manager, error) { + p, ok := providers[providerName] + if !ok { + return nil, fmt.Errorf("session: unknown provider %q", providerName) + } + + return &Manager{p: p}, nil +} + +func (this *Manager) Subscribe(topic []byte, qos byte, subscriber interface{}) (byte, error) { + return this.p.Subscribe(topic, qos, subscriber) +} + +func (this *Manager) Unsubscribe(topic []byte, subscriber interface{}) error { + return this.p.Unsubscribe(topic, subscriber) +} + +func (this *Manager) Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error { + return this.p.Subscribers(topic, qos, subs, qoss) +} + +func (this *Manager) Retain(msg *packets.PublishPacket) error { + return this.p.Retain(msg) +} + +func (this *Manager) Retained(topic []byte, msgs *[]*packets.PublishPacket) error { + return this.p.Retained(topic, msgs) +} + +func (this *Manager) Close() error { + return this.p.Close() +} diff --git a/logger/logger_test.go b/logger/logger_test.go index c1291d42..d1af299d 100644 --- a/logger/logger_test.go +++ b/logger/logger_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "go.uber.org/zap" ) diff --git a/main.go b/main.go index f58d31d0..0062ff69 100644 --- a/main.go +++ b/main.go @@ -8,10 +8,11 @@ package main import ( "fmt" - "github.com/fhmq/hmq/broker" "os" "os/signal" "runtime" + + "github.com/fhmq/hmq/broker" ) func main() {