Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nsq_to_nsq: multiple topics support #945

Merged
merged 2 commits into from
Sep 25, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 60 additions & 37 deletions apps/nsq_to_nsq/nsq_to_nsq.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ const (

var (
showVersion = flag.Bool("version", false, "print version string")

topic = flag.String("topic", "", "nsq topic")
channel = flag.String("channel", "nsq_to_nsq", "nsq channel")
destTopic = flag.String("destination-topic", "", "destination nsq topic")
destTopic = flag.String("destination-topic", "", "use this destination topic for all consumed topics (default is consumed topic name)")
maxInFlight = flag.Int("max-in-flight", 200, "max number of messages to allow in flight")

statusEvery = flag.Int("status-every", 250, "the # of requests between logging status (per destination), 0 disables")
Expand All @@ -44,6 +42,7 @@ var (
lookupdHTTPAddrs = app.StringArray{}
destNsqdTCPAddrs = app.StringArray{}
whitelistJSONFields = app.StringArray{}
topics = app.StringArray{}

requireJSONField = flag.String("require-json-field", "", "for JSON messages: only pass messages that contain this field")
requireJSONValue = flag.String("require-json-value", "", "for JSON messages: only pass messages in which the required field has this value")
Expand All @@ -53,7 +52,7 @@ func init() {
flag.Var(&nsqdTCPAddrs, "nsqd-tcp-address", "nsqd TCP address (may be given multiple times)")
flag.Var(&destNsqdTCPAddrs, "destination-nsqd-tcp-address", "destination nsqd TCP address (may be given multiple times)")
flag.Var(&lookupdHTTPAddrs, "lookupd-http-address", "lookupd HTTP address (may be given multiple times)")

flag.Var(&topics, "topic", "nsq topic (may be given multiple times)")
flag.Var(&whitelistJSONFields, "whitelist-json-field", "for JSON messages: pass this field (may be given multiple times)")
}

Expand All @@ -75,6 +74,11 @@ type PublishHandler struct {
timermetrics *timer_metrics.TimerMetrics
}

type TopicHandler struct {
publishHandler *PublishHandler
destinationTopic string
}

func (ph *PublishHandler) responder() {
var msg *nsq.Message
var startTime time.Time
Expand Down Expand Up @@ -194,7 +198,11 @@ func filterMessage(js map[string]interface{}, rawMsg []byte) ([]byte, error) {
return newRawMsg, nil
}

func (ph *PublishHandler) HandleMessage(m *nsq.Message) error {
func (t *TopicHandler) HandleMessage(m *nsq.Message) error {
return t.publishHandler.HandleMessage(m, t.destinationTopic)
}

func (ph *PublishHandler) HandleMessage(m *nsq.Message, destinationTopic string) error {
var err error
msgBody := m.Body

Expand All @@ -214,6 +222,7 @@ func (ph *PublishHandler) HandleMessage(m *nsq.Message) error {
}

msgBody, err = filterMessage(js, msgBody)

if err != nil {
log.Printf("ERROR: filterMessage() failed: %s", err)
return err
Expand All @@ -228,11 +237,11 @@ func (ph *PublishHandler) HandleMessage(m *nsq.Message) error {
idx := counter % uint64(len(ph.addresses))
addr := ph.addresses[idx]
p := ph.producers[addr]
err = p.PublishAsync(*destTopic, msgBody, ph.respChan, m, startTime, addr)
err = p.PublishAsync(destinationTopic, msgBody, ph.respChan, m, startTime, addr)
case ModeHostPool:
hostPoolResponse := ph.hostPool.Get()
p := ph.producers[hostPoolResponse.Host()]
err = p.PublishAsync(*destTopic, msgBody, ph.respChan, m, startTime, hostPoolResponse)
err = p.PublishAsync(destinationTopic, msgBody, ph.respChan, m, startTime, hostPoolResponse)
if err != nil {
hostPoolResponse.Mark(err)
}
Expand Down Expand Up @@ -271,19 +280,17 @@ func main() {
return
}

if *topic == "" || *channel == "" {
if len(topics) == 0 || *channel == "" {
log.Fatal("--topic and --channel are required")
}

if *destTopic == "" {
*destTopic = *topic
}

if !protocol.IsValidTopicName(*topic) {
log.Fatal("--topic is invalid")
for _, topic := range topics {
if !protocol.IsValidTopicName(topic) {
log.Fatal("--topic is invalid")
}
}

if !protocol.IsValidTopicName(*destTopic) {
if *destTopic != "" && !protocol.IsValidTopicName(*destTopic) {
log.Fatal("--destination-topic is invalid")
}

Expand Down Expand Up @@ -316,12 +323,6 @@ func main() {

cCfg.UserAgent = defaultUA
cCfg.MaxInFlight = *maxInFlight

consumer, err := nsq.NewConsumer(*topic, *channel, cCfg)
if err != nil {
log.Fatal(err)
}

pCfg.UserAgent = defaultUA

producers := make(map[string]*nsq.Producer)
Expand Down Expand Up @@ -349,7 +350,9 @@ func main() {
hostPool = hostpool.NewEpsilonGreedy(destNsqdTCPAddrs, 0, &hostpool.LinearEpsilonValueCalculator{})
}

handler := &PublishHandler{
var consumerList []*nsq.Consumer

publisher := &PublishHandler{
addresses: destNsqdTCPAddrs,
producers: producers,
mode: selectedMode,
Expand All @@ -358,28 +361,48 @@ func main() {
perAddressStatus: perAddressStatus,
timermetrics: timer_metrics.NewTimerMetrics(*statusEvery, "[aggregate]:"),
}
consumer.AddConcurrentHandlers(handler, len(destNsqdTCPAddrs))

for _, topic := range topics {
consumer, err := nsq.NewConsumer(topic, *channel, cCfg)
consumerList = append(consumerList, consumer)
if err != nil {
log.Fatal(err)
}

publishTopic := topic
if *destTopic != "" {
publishTopic = *destTopic
}
topicHandler := &TopicHandler{
publishHandler: publisher,
destinationTopic: publishTopic,
}
consumer.AddConcurrentHandlers(topicHandler, len(destNsqdTCPAddrs))
}
for i := 0; i < len(destNsqdTCPAddrs); i++ {
go handler.responder()
go publisher.responder()
}

err = consumer.ConnectToNSQDs(nsqdTCPAddrs)
if err != nil {
log.Fatal(err)
for _, consumer := range consumerList {
err := consumer.ConnectToNSQDs(nsqdTCPAddrs)
if err != nil {
log.Fatal(err)
}
}

err = consumer.ConnectToNSQLookupds(lookupdHTTPAddrs)
if err != nil {
log.Fatal(err)
for _, consumer := range consumerList {
err := consumer.ConnectToNSQLookupds(lookupdHTTPAddrs)
if err != nil {
log.Fatal(err)
}
}

for {
select {
case <-consumer.StopChan:
return
case <-termChan:
consumer.Stop()
}
<-termChan // wait for signal

for _, consumer := range consumerList {
consumer.Stop()
}
for _, consumer := range consumerList {
<-consumer.StopChan
}
}