Skip to content

Commit

Permalink
NIFI-11270 Refactoring of the overly Paho-specific MQTT interface
Browse files Browse the repository at this point in the history
This closes #7032.

Signed-off-by: Peter Turcsanyi <turcsanyi@apache.org>
  • Loading branch information
nandorsoma authored and turcsanyip committed Mar 29, 2023
1 parent 0afd155 commit 2b9f207
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.processors.mqtt.common.AbstractMQTTProcessor;
import org.apache.nifi.processors.mqtt.common.MqttCallback;
import org.apache.nifi.processors.mqtt.common.MqttException;
import org.apache.nifi.processors.mqtt.common.ReceivedMqttMessage;
import org.apache.nifi.serialization.MalformedRecordException;
Expand Down Expand Up @@ -104,7 +103,7 @@
"on the topic.")})
@SystemResourceConsideration(resource = SystemResource.MEMORY, description = "The 'Max Queue Size' specifies the maximum number of messages that can be hold in memory by NiFi by a single "
+ "instance of this processor. A high value for this property could represent a lot of data being stored in memory.")
public class ConsumeMQTT extends AbstractMQTTProcessor implements MqttCallback {
public class ConsumeMQTT extends AbstractMQTTProcessor {

public final static String RECORD_COUNT_KEY = "record.count";
public final static String BROKER_ATTRIBUTE_KEY = "mqtt.broker";
Expand Down Expand Up @@ -383,9 +382,8 @@ private void initializeClient(ProcessContext context) {
// non-null but not connected, so we need to handle each case and only create a new client when it is null
try {
mqttClient = createMqttClient();
mqttClient.setCallback(this);
mqttClient.connect();
mqttClient.subscribe(topicPrefix + topicFilter, qos);
mqttClient.subscribe(topicPrefix + topicFilter, qos, this::handleReceivedMessage);
} catch (Exception e) {
logger.error("Connection failed to {}. Yielding processor", clientProperties.getRawBrokerUris(), e);
mqttClient = null; // prevent stucked processor when subscribe fails
Expand Down Expand Up @@ -614,13 +612,7 @@ private String getTransitUri(String... appends) {
return stringBuilder.toString();
}

@Override
public void connectionLost(Throwable cause) {
logger.error("Connection to {} lost", clientProperties.getRawBrokerUris(), cause);
}

@Override
public void messageArrived(ReceivedMqttMessage message) {
private void handleReceivedMessage(ReceivedMqttMessage message) {
if (logger.isDebugEnabled()) {
byte[] payload = message.getPayload();
final String text = new String(payload, StandardCharsets.UTF_8);
Expand All @@ -639,11 +631,4 @@ public void messageArrived(ReceivedMqttMessage message) {
throw new MqttException("Failed to process message arrived from topic " + message.getTopic());
}
}

@Override
public void deliveryComplete(String token) {
// Unlikely situation. Api uses the same callback for publisher and consumer as well.
// That's why we have this log message here to indicate something really messy thing happened.
logger.error("Received MQTT 'delivery complete' message to subscriber. Token: [{}]", token);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
import org.apache.nifi.processor.exception.ProcessException;
import org.apache.nifi.processor.util.StandardValidators;
import org.apache.nifi.processors.mqtt.common.AbstractMQTTProcessor;
import org.apache.nifi.processors.mqtt.common.MqttCallback;
import org.apache.nifi.processors.mqtt.common.ReceivedMqttMessage;
import org.apache.nifi.processors.mqtt.common.StandardMqttMessage;
import org.apache.nifi.schema.access.SchemaNotFoundException;
import org.apache.nifi.serialization.MalformedRecordException;
Expand Down Expand Up @@ -74,7 +72,7 @@
@CapabilityDescription("Publishes a message to an MQTT topic")
@SeeAlso({ConsumeMQTT.class})
@SystemResourceConsideration(resource = SystemResource.MEMORY)
public class PublishMQTT extends AbstractMQTTProcessor implements MqttCallback {
public class PublishMQTT extends AbstractMQTTProcessor {

public static final PropertyDescriptor PROP_TOPIC = new PropertyDescriptor.Builder()
.name("Topic")
Expand Down Expand Up @@ -289,32 +287,13 @@ private void initializeClient(ProcessContext context) {
// non-null but not connected, so we need to handle each case and only create a new client when it is null
try {
mqttClient = createMqttClient();
mqttClient.setCallback(this);
mqttClient.connect();
} catch (Exception e) {
logger.error("Connection failed to {}. Yielding processor", clientProperties.getRawBrokerUris(), e);
context.yield();
}
}

@Override
public void connectionLost(Throwable cause) {
logger.error("Connection to {} lost", clientProperties.getRawBrokerUris(), cause);
}

@Override
public void messageArrived(ReceivedMqttMessage message) {
// Unlikely situation. Api uses the same callback for publisher and consumer as well.
// That's why we have this log message here to indicate something really messy thing happened.
logger.error("Message arrived to a PublishMQTT processor { topic:'" + message.getTopic() + "; payload:" + Arrays.toString(message.getPayload()) + "}");
}

@Override
public void deliveryComplete(String token) {
// Client.publish waits for message to be delivered so this token will always have a null message and is useless in this application.
logger.trace("Received 'delivery complete' message from broker. Token: [{}]", token);
}

interface ProcessStrategy {
void process(ProcessContext context, FlowFile flowfile, InputStream in, String topic, AtomicInteger processedRecords, Long previousProcessFailedAt) throws IOException;
String getFailureTemplateMessage();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
import com.hivemq.client.mqtt.mqtt5.message.connect.Mqtt5ConnectBuilder;
import com.hivemq.client.mqtt.mqtt5.message.subscribe.suback.Mqtt5SubAck;
import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processors.mqtt.common.MqttCallback;
import org.apache.nifi.processors.mqtt.common.MqttClient;
import org.apache.nifi.processors.mqtt.common.MqttClientProperties;
import org.apache.nifi.processors.mqtt.common.MqttException;
import org.apache.nifi.processors.mqtt.common.MqttProtocolScheme;
import org.apache.nifi.processors.mqtt.common.ReceivedMqttMessage;
import org.apache.nifi.processors.mqtt.common.ReceivedMqttMessageHandler;
import org.apache.nifi.processors.mqtt.common.StandardMqttMessage;
import org.apache.nifi.security.util.KeyStoreUtils;
import org.apache.nifi.security.util.TlsException;
Expand All @@ -50,8 +50,6 @@ public class HiveMqV5ClientAdapter implements MqttClient {
private final MqttClientProperties clientProperties;
private final ComponentLog logger;

private MqttCallback callback;

public HiveMqV5ClientAdapter(URI brokerUri, MqttClientProperties clientProperties, ComponentLog logger) throws TlsException {
this.mqtt5BlockingClient = createClient(brokerUri, clientProperties, logger);
this.clientProperties = clientProperties;
Expand Down Expand Up @@ -124,9 +122,7 @@ public void publish(String topic, StandardMqttMessage message) {
}

@Override
public void subscribe(String topicFilter, int qos) {
Objects.requireNonNull(callback, "callback should be set");

public void subscribe(String topicFilter, int qos, ReceivedMqttMessageHandler handler) {
logger.debug("Subscribing to {} with QoS: {}", topicFilter, qos);

CompletableFuture<Mqtt5SubAck> futureAck = mqtt5BlockingClient.toAsync().subscribeWith()
Expand All @@ -138,7 +134,7 @@ public void subscribe(String topicFilter, int qos) {
mqtt5Publish.getQos().getCode(),
mqtt5Publish.isRetain(),
mqtt5Publish.getTopic().toString());
callback.messageArrived(receivedMessage);
handler.handleReceivedMessage(receivedMessage);
})
.send();

Expand All @@ -152,11 +148,6 @@ public void subscribe(String topicFilter, int qos) {
}
}

@Override
public void setCallback(MqttCallback callback) {
this.callback = callback;
}

private static Mqtt5BlockingClient createClient(URI brokerUri, MqttClientProperties clientProperties, ComponentLog logger) throws TlsException {
logger.debug("Creating Mqtt v5 client");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,22 @@
package org.apache.nifi.processors.mqtt.adapters;

import org.apache.nifi.logging.ComponentLog;
import org.apache.nifi.processors.mqtt.common.MqttCallback;
import org.apache.nifi.processors.mqtt.common.MqttClient;
import org.apache.nifi.processors.mqtt.common.MqttClientProperties;
import org.apache.nifi.processors.mqtt.common.MqttClient;
import org.apache.nifi.processors.mqtt.common.MqttException;
import org.apache.nifi.processors.mqtt.common.ReceivedMqttMessage;
import org.apache.nifi.processors.mqtt.common.ReceivedMqttMessageHandler;
import org.apache.nifi.processors.mqtt.common.StandardMqttMessage;
import org.apache.nifi.security.util.TlsConfiguration;
import org.eclipse.paho.client.mqttv3.IMqttClient;
import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken;
import org.eclipse.paho.client.mqttv3.MqttCallback;
import org.eclipse.paho.client.mqttv3.MqttConnectOptions;
import org.eclipse.paho.client.mqttv3.MqttMessage;
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence;

import java.net.URI;
import java.util.Arrays;
import java.util.Properties;

public class PahoMqttClientAdapter implements MqttClient {
Expand All @@ -45,6 +47,7 @@ public PahoMqttClientAdapter(URI brokerUri, MqttClientProperties clientPropertie
this.client = createClient(brokerUri, clientProperties, logger);
this.clientProperties = clientProperties;
this.logger = logger;
client.setCallback(new DefaultMqttCallback());
}

@Override
Expand Down Expand Up @@ -121,38 +124,18 @@ public void publish(String topic, StandardMqttMessage message) {
}

@Override
public void subscribe(String topicFilter, int qos) {
public void subscribe(String topicFilter, int qos, ReceivedMqttMessageHandler handler) {
logger.debug("Subscribing to {} with QoS: {}", topicFilter, qos);

client.setCallback(new ConsumerMqttCallback(handler));

try {
client.subscribe(topicFilter, qos);
} catch (org.eclipse.paho.client.mqttv3.MqttException e) {
throw new MqttException("An error has occurred during subscribing to " + topicFilter + " with QoS: " + qos, e);
}
}

@Override
public void setCallback(MqttCallback callback) {
client.setCallback(new org.eclipse.paho.client.mqttv3.MqttCallback() {
@Override
public void connectionLost(Throwable cause) {
callback.connectionLost(cause);
}

@Override
public void messageArrived(String topic, MqttMessage message) {
logger.debug("Message arrived with id: {}", message.getId());
final ReceivedMqttMessage receivedMessage = new ReceivedMqttMessage(message.getPayload(), message.getQos(), message.isRetained(), topic);
callback.messageArrived(receivedMessage);
}

@Override
public void deliveryComplete(IMqttDeliveryToken token) {
callback.deliveryComplete(token.toString());
}
});
}

public static Properties transformSSLContextService(TlsConfiguration tlsConfiguration) {
final Properties properties = new Properties();
if (tlsConfiguration.getProtocol() != null) {
Expand All @@ -176,7 +159,7 @@ public static Properties transformSSLContextService(TlsConfiguration tlsConfigur
if (tlsConfiguration.getTruststoreType() != null) {
properties.setProperty("com.ibm.ssl.trustStoreType", tlsConfiguration.getTruststoreType().getType());
}
return properties;
return properties;
}

private static org.eclipse.paho.client.mqttv3.MqttClient createClient(URI brokerUri, MqttClientProperties clientProperties, ComponentLog logger) {
Expand All @@ -189,4 +172,58 @@ private static org.eclipse.paho.client.mqttv3.MqttClient createClient(URI broker
}
}

/**
* Paho API uses the same callback for the publisher and consumer as well.
* Because of that, DefaultMqttCallback sets some reasonable default logs
* to make it easier to track misconfiguration errors.
* <p>
* In case of subscribing clients messageArrived needs to be overridden.
*/
private class DefaultMqttCallback implements MqttCallback {

@Override
public void connectionLost(Throwable cause) {
logger.error("Connection to [{}] lost", clientProperties.getRawBrokerUris(), cause);
}

@Override
public void messageArrived(String topic, MqttMessage message) {
// Unlikely situation. The Paho api uses the same callback for publisher and consumer as well. That's why
// we have this log message here to indicate something messy thing happened because we don't expect to
// receive messages until the client is not subscribed and the callback is not changed to ConsumerMqttCallback.
logger.error("MQTT message arrived [topic:{}; payload:{}]", topic, Arrays.toString(message.getPayload()));
}

@Override
public void deliveryComplete(IMqttDeliveryToken token) {
logger.trace("Received 'delivery complete' message from broker. Token: [{}]", token);
}
}

/**
* Subscriber specific implementation of MqttCallback
*/
private class ConsumerMqttCallback extends DefaultMqttCallback {

private final ReceivedMqttMessageHandler handler;

private ConsumerMqttCallback(ReceivedMqttMessageHandler handler) {
this.handler = handler;
}

@Override
public void messageArrived(String topic, MqttMessage message) {
logger.debug("Message arrived. Id: [{}]", message.getId());
final ReceivedMqttMessage receivedMessage = new ReceivedMqttMessage(message.getPayload(), message.getQos(), message.isRetained(), topic);
handler.handleReceivedMessage(receivedMessage);
}

@Override
public void deliveryComplete(IMqttDeliveryToken token) {
// Unlikely situation. The Paho api uses the same callback for publisher and consumer as well. That's why
// we have this log message here to indicate something messy thing happened because we don't expect to
// receive 'delivery complete' messages while the client is subscribed.
logger.error("Received MQTT 'delivery complete' message to a subscribed client. Token: [{}]", token);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,7 @@ public interface MqttClient {
* published at a lower quality of service will be received at the published
* QoS. Messages published at a higher quality of service will be received using
* the QoS specified on the subscribe.
* @param handler that further processes the message received by the client
*/
void subscribe(String topicFilter, int qos);

/**
* Sets a callback listener to use for events that happen asynchronously.
*
* @param callback for matching events
*/
void setCallback(MqttCallback callback);
void subscribe(String topicFilter, int qos, ReceivedMqttMessageHandler handler);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
*/
package org.apache.nifi.processors.mqtt.common;

public interface MqttCallback {
void connectionLost(Throwable cause);
void messageArrived(ReceivedMqttMessage message);
void deliveryComplete(String token);
public interface ReceivedMqttMessageHandler {

/**
* Handler to process received MQTT message
*
* @param message to process
*/
void handleReceivedMessage(ReceivedMqttMessage message);
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,13 @@ public class MqttTestClient implements MqttClient {

public AtomicBoolean connected = new AtomicBoolean(false);

public MqttCallback mqttCallback;
public ConnectType type;

public enum ConnectType {Publisher, Subscriber}

public String subscribedTopic;
public int subscribedQos;

public ReceivedMqttMessageHandler receivedMqttMessageHandler;
public MqttTestClient(ConnectType type) {
this.type = type;
}
Expand Down Expand Up @@ -68,20 +67,16 @@ public void publish(String topic, StandardMqttMessage message) {
publishedMessages.add(Pair.of(topic, message));
break;
case Subscriber:
mqttCallback.messageArrived(new ReceivedMqttMessage(message.getPayload(), message.getQos(), message.isRetained(), topic));
receivedMqttMessageHandler.handleReceivedMessage(new ReceivedMqttMessage(message.getPayload(), message.getQos(), message.isRetained(), topic));
break;
}
}

@Override
public void subscribe(String topicFilter, int qos) {
public void subscribe(String topicFilter, int qos, ReceivedMqttMessageHandler handler) {
subscribedTopic = topicFilter;
subscribedQos = qos;
}

@Override
public void setCallback(MqttCallback callback) {
this.mqttCallback = callback;
receivedMqttMessageHandler = handler;
}

public Pair<String, StandardMqttMessage> getLastPublished() {
Expand Down

0 comments on commit 2b9f207

Please sign in to comment.