diff --git a/apps/nsq_to_nsq/nsq_to_nsq.go b/apps/nsq_to_nsq/nsq_to_nsq.go index ae1926ba4..7e7c24b5e 100644 --- a/apps/nsq_to_nsq/nsq_to_nsq.go +++ b/apps/nsq_to_nsq/nsq_to_nsq.go @@ -31,10 +31,8 @@ const ( var ( showVersion = flag.Bool("version", false, "print version string") - - topic = flag.String("topic", "", "nsq topic") channel = flag.String("channel", "nsq_to_nsq", "nsq channel") - destTopic = flag.String("destination-topic", "", "destination nsq topic") + destTopic = flag.String("destination-topic", "", "use this destination topic for all consumed topics (default is consumed topic name)") maxInFlight = flag.Int("max-in-flight", 200, "max number of messages to allow in flight") statusEvery = flag.Int("status-every", 250, "the # of requests between logging status (per destination), 0 disables") @@ -44,6 +42,7 @@ var ( lookupdHTTPAddrs = app.StringArray{} destNsqdTCPAddrs = app.StringArray{} whitelistJSONFields = app.StringArray{} + topics = app.StringArray{} requireJSONField = flag.String("require-json-field", "", "for JSON messages: only pass messages that contain this field") requireJSONValue = flag.String("require-json-value", "", "for JSON messages: only pass messages in which the required field has this value") @@ -53,7 +52,7 @@ func init() { flag.Var(&nsqdTCPAddrs, "nsqd-tcp-address", "nsqd TCP address (may be given multiple times)") flag.Var(&destNsqdTCPAddrs, "destination-nsqd-tcp-address", "destination nsqd TCP address (may be given multiple times)") flag.Var(&lookupdHTTPAddrs, "lookupd-http-address", "lookupd HTTP address (may be given multiple times)") - + flag.Var(&topics, "topic", "nsq topic (may be given multiple times)") flag.Var(&whitelistJSONFields, "whitelist-json-field", "for JSON messages: pass this field (may be given multiple times)") } @@ -75,6 +74,11 @@ type PublishHandler struct { timermetrics *timer_metrics.TimerMetrics } +type TopicHandler struct { + publishHandler *PublishHandler + destinationTopic string +} + func (ph *PublishHandler) responder() { var msg *nsq.Message var startTime time.Time @@ -194,7 +198,11 @@ func filterMessage(js map[string]interface{}, rawMsg []byte) ([]byte, error) { return newRawMsg, nil } -func (ph *PublishHandler) HandleMessage(m *nsq.Message) error { +func (t *TopicHandler) HandleMessage(m *nsq.Message) error { + return t.publishHandler.HandleMessage(m, t.destinationTopic) +} + +func (ph *PublishHandler) HandleMessage(m *nsq.Message, destinationTopic string) error { var err error msgBody := m.Body @@ -214,6 +222,7 @@ func (ph *PublishHandler) HandleMessage(m *nsq.Message) error { } msgBody, err = filterMessage(js, msgBody) + if err != nil { log.Printf("ERROR: filterMessage() failed: %s", err) return err @@ -228,11 +237,11 @@ func (ph *PublishHandler) HandleMessage(m *nsq.Message) error { idx := counter % uint64(len(ph.addresses)) addr := ph.addresses[idx] p := ph.producers[addr] - err = p.PublishAsync(*destTopic, msgBody, ph.respChan, m, startTime, addr) + err = p.PublishAsync(destinationTopic, msgBody, ph.respChan, m, startTime, addr) case ModeHostPool: hostPoolResponse := ph.hostPool.Get() p := ph.producers[hostPoolResponse.Host()] - err = p.PublishAsync(*destTopic, msgBody, ph.respChan, m, startTime, hostPoolResponse) + err = p.PublishAsync(destinationTopic, msgBody, ph.respChan, m, startTime, hostPoolResponse) if err != nil { hostPoolResponse.Mark(err) } @@ -271,19 +280,17 @@ func main() { return } - if *topic == "" || *channel == "" { + if len(topics) == 0 || *channel == "" { log.Fatal("--topic and --channel are required") } - if *destTopic == "" { - *destTopic = *topic - } - - if !protocol.IsValidTopicName(*topic) { - log.Fatal("--topic is invalid") + for _, topic := range topics { + if !protocol.IsValidTopicName(topic) { + log.Fatal("--topic is invalid") + } } - if !protocol.IsValidTopicName(*destTopic) { + if *destTopic != "" && !protocol.IsValidTopicName(*destTopic) { log.Fatal("--destination-topic is invalid") } @@ -316,12 +323,6 @@ func main() { cCfg.UserAgent = defaultUA cCfg.MaxInFlight = *maxInFlight - - consumer, err := nsq.NewConsumer(*topic, *channel, cCfg) - if err != nil { - log.Fatal(err) - } - pCfg.UserAgent = defaultUA producers := make(map[string]*nsq.Producer) @@ -349,7 +350,9 @@ func main() { hostPool = hostpool.NewEpsilonGreedy(destNsqdTCPAddrs, 0, &hostpool.LinearEpsilonValueCalculator{}) } - handler := &PublishHandler{ + var consumerList []*nsq.Consumer + + publisher := &PublishHandler{ addresses: destNsqdTCPAddrs, producers: producers, mode: selectedMode, @@ -358,28 +361,48 @@ func main() { perAddressStatus: perAddressStatus, timermetrics: timer_metrics.NewTimerMetrics(*statusEvery, "[aggregate]:"), } - consumer.AddConcurrentHandlers(handler, len(destNsqdTCPAddrs)) + for _, topic := range topics { + consumer, err := nsq.NewConsumer(topic, *channel, cCfg) + consumerList = append(consumerList, consumer) + if err != nil { + log.Fatal(err) + } + + publishTopic := topic + if *destTopic != "" { + publishTopic = *destTopic + } + topicHandler := &TopicHandler{ + publishHandler: publisher, + destinationTopic: publishTopic, + } + consumer.AddConcurrentHandlers(topicHandler, len(destNsqdTCPAddrs)) + } for i := 0; i < len(destNsqdTCPAddrs); i++ { - go handler.responder() + go publisher.responder() } - err = consumer.ConnectToNSQDs(nsqdTCPAddrs) - if err != nil { - log.Fatal(err) + for _, consumer := range consumerList { + err := consumer.ConnectToNSQDs(nsqdTCPAddrs) + if err != nil { + log.Fatal(err) + } } - err = consumer.ConnectToNSQLookupds(lookupdHTTPAddrs) - if err != nil { - log.Fatal(err) + for _, consumer := range consumerList { + err := consumer.ConnectToNSQLookupds(lookupdHTTPAddrs) + if err != nil { + log.Fatal(err) + } } - for { - select { - case <-consumer.StopChan: - return - case <-termChan: - consumer.Stop() - } + <-termChan // wait for signal + + for _, consumer := range consumerList { + consumer.Stop() + } + for _, consumer := range consumerList { + <-consumer.StopChan } }