package chat.server.moquette;
|
|
import java.util.ArrayList;
|
import java.util.List;
|
|
import org.apache.log4j.Logger;
|
|
import chat.server.moquette.MqttDecodHandler.DecoderState;
|
import chat.server.moquette.message.MqttConnAckVariableHeader;
|
import chat.server.moquette.message.MqttConnectPayload;
|
import chat.server.moquette.message.MqttConnectReturnCode;
|
import chat.server.moquette.message.MqttConnectVariableHeader;
|
import chat.server.moquette.message.MqttFixedHeader;
|
import chat.server.moquette.message.MqttMessage;
|
import chat.server.moquette.message.MqttMessageFactory;
|
import chat.server.moquette.message.MqttMessageIdVariableHeader;
|
import chat.server.moquette.message.MqttMessageType;
|
import chat.server.moquette.message.MqttPublishVariableHeader;
|
import chat.server.moquette.message.MqttQoS;
|
import chat.server.moquette.message.MqttSubAckPayload;
|
import chat.server.moquette.message.MqttSubscribePayload;
|
import chat.server.moquette.message.MqttTopicSubscription;
|
import chat.server.moquette.message.MqttUnsubscribePayload;
|
import io.netty.buffer.ByteBuf;
|
import io.netty.channel.ChannelHandlerContext;
|
import io.netty.handler.codec.DecoderException;
|
import io.netty.handler.codec.ReplayingDecoder;
|
import io.netty.util.CharsetUtil;
|
|
public class MqttDecodHandler extends ReplayingDecoder<DecoderState> {
|
|
private static final int DEFAULT_MAX_BYTES_IN_MESSAGE = 262144; //256K
|
|
private static Logger logger = Logger.getLogger(MqttDecodHandler.class);
|
|
/**
|
* States of the decoder.
|
* We start at READ_FIXED_HEADER, followed by
|
* READ_VARIABLE_HEADER and finally READ_PAYLOAD.
|
*/
|
enum DecoderState {
|
READ_FIXED_HEADER,
|
READ_VARIABLE_HEADER,
|
READ_PAYLOAD,
|
BAD_MESSAGE,
|
}
|
|
private MqttFixedHeader mqttFixedHeader;
|
private Object variableHeader;
|
private int bytesRemainingInVariablePart;
|
|
private final int maxBytesInMessage;
|
|
public MqttDecodHandler() {
|
this(DEFAULT_MAX_BYTES_IN_MESSAGE);
|
}
|
|
public MqttDecodHandler(int maxBytesInMessage) {
|
super(DecoderState.READ_FIXED_HEADER);
|
logger.debug("MqttDecodHandler=>1");
|
this.maxBytesInMessage = maxBytesInMessage;
|
}
|
|
@Override
|
protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws Exception {
|
logger.debug("begin decode");
|
try {
|
switch (state()) {
|
case READ_FIXED_HEADER:
|
logger.debug("decode Read_FIXED_HEADER");
|
mqttFixedHeader = decodeFixedHeader(buffer);
|
bytesRemainingInVariablePart = mqttFixedHeader.remainingLength();
|
checkpoint(DecoderState.READ_VARIABLE_HEADER);
|
break;
|
case READ_VARIABLE_HEADER:
|
logger.debug("decode READ_VARIABLE_HEADER");
|
if (bytesRemainingInVariablePart > maxBytesInMessage) {
|
throw new DecoderException("too large message: " + bytesRemainingInVariablePart + " bytes");
|
}
|
final Result<?> decodedVariableHeader = decodeVariableHeader(buffer, mqttFixedHeader);
|
variableHeader = decodedVariableHeader.value;
|
bytesRemainingInVariablePart -= decodedVariableHeader.numberOfBytesConsumed;
|
checkpoint(DecoderState.READ_PAYLOAD);
|
break;
|
case READ_PAYLOAD:
|
logger.debug("decode READ_PAYLOAD");
|
final Result<?> decodedPayload =
|
decodePayload(
|
buffer,
|
mqttFixedHeader.messageType(),
|
bytesRemainingInVariablePart,
|
variableHeader);
|
bytesRemainingInVariablePart -= decodedPayload.numberOfBytesConsumed;
|
if (bytesRemainingInVariablePart != 0) {
|
throw new DecoderException(
|
"non-zero remaining payload bytes: " +
|
bytesRemainingInVariablePart + " (" + mqttFixedHeader.messageType() + ')');
|
}
|
checkpoint(DecoderState.READ_FIXED_HEADER);
|
MqttMessage message = MqttMessageFactory.newMessage(
|
mqttFixedHeader, variableHeader, decodedPayload.value);
|
mqttFixedHeader = null;
|
variableHeader = null;
|
out.add(message);
|
case BAD_MESSAGE:
|
// Keep discarding until disconnection.
|
logger.debug("decode BAD_MESSAGE");
|
buffer.skipBytes(actualReadableBytes());
|
break;
|
default:
|
// Shouldn't reach here.
|
throw new Error();
|
}
|
}
|
catch (Exception cause) {
|
cause.printStackTrace();
|
out.add(invalidMessage(cause));
|
}
|
}
|
|
private MqttMessage invalidMessage(Throwable cause) {
|
logger.debug("decode invalidMessage");
|
checkpoint(DecoderState.BAD_MESSAGE);
|
return MqttMessageFactory.newInvalidMessage(cause);
|
}
|
|
/**
|
* Decodes the fixed header. It's one byte for the flags and then variable bytes for the remaining length.
|
*
|
* @param buffer the buffer to decode from
|
* @return the fixed header
|
*/
|
private static MqttFixedHeader decodeFixedHeader(ByteBuf buffer) {
|
logger.debug("decode decodeFixedHeader");
|
|
short b1 = buffer.readUnsignedByte();
|
|
MqttMessageType messageType = MqttMessageType.valueOf(b1 >> 4);
|
boolean dupFlag = (b1 & 0x08) == 0x08;
|
int qosLevel = (b1 & 0x06) >> 1;
|
boolean retain = (b1 & 0x01) != 0;
|
|
int remainingLength = 0;
|
int multiplier = 1;
|
short digit;
|
int loops = 0;
|
do {
|
digit = buffer.readUnsignedByte();
|
remainingLength += (digit & 127) * multiplier;
|
multiplier *= 128;
|
loops++;
|
} while ((digit & 128) != 0 && loops < 4);
|
|
// MQTT protocol limits Remaining Length to 4 bytes
|
if (loops == 4 && (digit & 128) != 0) {
|
throw new DecoderException("remaining length exceeds 4 digits (" + messageType + ')');
|
}
|
MqttFixedHeader decodedFixedHeader =
|
new MqttFixedHeader(messageType, dupFlag, MqttQoS.valueOf(qosLevel), retain, remainingLength);
|
return MqttCodecUtil.validateFixedHeader(MqttCodecUtil.resetUnusedFields(decodedFixedHeader));
|
}
|
|
/**
|
* Decodes the variable header (if any)
|
* @param buffer the buffer to decode from
|
* @param mqttFixedHeader MqttFixedHeader of the same message
|
* @return the variable header
|
*/
|
private static Result<?> decodeVariableHeader(ByteBuf buffer, MqttFixedHeader mqttFixedHeader) {
|
switch (mqttFixedHeader.messageType()) {
|
case CONNECT:
|
logger.debug("decode==>decodeVariableHeader==>CONNECT");
|
return decodeConnectionVariableHeader(buffer);
|
|
case CONNACK:
|
logger.debug("decode==>decodeVariableHeader==>CONNACK");
|
return decodeConnAckVariableHeader(buffer);
|
|
case SUBSCRIBE:
|
case UNSUBSCRIBE:
|
case SUBACK:
|
case UNSUBACK:
|
case PUBACK:
|
case PUBREC:
|
case PUBCOMP:
|
case PUBREL:
|
logger.debug("decode==>decodeVariableHeader==>PUBREL");
|
return decodeMessageIdVariableHeader(buffer);
|
|
case PUBLISH:
|
logger.debug("decode==>decodeVariableHeader==>PUBLISH");
|
return decodePublishVariableHeader(buffer, mqttFixedHeader);
|
|
case PINGREQ:
|
case PINGRESP:
|
case DISCONNECT:
|
// Empty variable header
|
logger.debug("decode==>decodeVariableHeader==>DISCONNECT");
|
return new Result<Object>(null, 0);
|
}
|
return new Result<Object>(null, 0); //should never reach here
|
}
|
|
private static Result<MqttConnectVariableHeader> decodeConnectionVariableHeader(ByteBuf buffer) {
|
logger.debug("decode==>decodeConnectionVariableHeader");
|
|
final Result<String> protoString = decodeString(buffer);
|
int numberOfBytesConsumed = protoString.numberOfBytesConsumed;
|
|
final byte protocolLevel = buffer.readByte();
|
numberOfBytesConsumed += 1;
|
|
final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(protoString.value, protocolLevel);
|
|
final int b1 = buffer.readUnsignedByte();
|
numberOfBytesConsumed += 1;
|
|
final Result<Integer> keepAlive = decodeMsbLsb(buffer);
|
numberOfBytesConsumed += keepAlive.numberOfBytesConsumed;
|
|
final boolean hasUserName = (b1 & 0x80) == 0x80;
|
final boolean hasPassword = (b1 & 0x40) == 0x40;
|
final boolean willRetain = (b1 & 0x20) == 0x20;
|
final int willQos = (b1 & 0x18) >> 3;
|
final boolean willFlag = (b1 & 0x04) == 0x04;
|
final boolean cleanSession = (b1 & 0x02) == 0x02;
|
if (mqttVersion.protocolLevel() >= MqttVersion.MQTT_3_1_1.protocolLevel()) {
|
final boolean zeroReservedFlag = (b1 & 0x01) == 0x0;
|
if (!zeroReservedFlag) {
|
// MQTT v3.1.1: The Server MUST validate that the reserved flag in the CONNECT Control Packet is
|
// set to zero and disconnect the Client if it is not zero.
|
// See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349230
|
throw new DecoderException("non-zero reserved flag");
|
}
|
}
|
|
final MqttConnectVariableHeader mqttConnectVariableHeader = new MqttConnectVariableHeader(
|
mqttVersion.protocolName(),
|
mqttVersion.protocolLevel(),
|
hasUserName,
|
hasPassword,
|
willRetain,
|
willQos,
|
willFlag,
|
cleanSession,
|
keepAlive.value);
|
return new Result<MqttConnectVariableHeader>(mqttConnectVariableHeader, numberOfBytesConsumed);
|
}
|
|
private static Result<MqttConnAckVariableHeader> decodeConnAckVariableHeader(ByteBuf buffer) {
|
logger.debug("decode==>decodeConnAckVariableHeader");
|
|
final boolean sessionPresent = (buffer.readUnsignedByte() & 0x01) == 0x01;
|
byte returnCode = buffer.readByte();
|
final int numberOfBytesConsumed = 2;
|
final MqttConnAckVariableHeader mqttConnAckVariableHeader =
|
new MqttConnAckVariableHeader(MqttConnectReturnCode.valueOf(returnCode), sessionPresent);
|
return new Result<MqttConnAckVariableHeader>(mqttConnAckVariableHeader, numberOfBytesConsumed);
|
}
|
|
private static Result<MqttMessageIdVariableHeader> decodeMessageIdVariableHeader(ByteBuf buffer) {
|
logger.debug("decode==>decodeMessageIdVariableHeader");
|
|
final Result<Integer> messageId = decodeMessageId(buffer);
|
return new Result<MqttMessageIdVariableHeader>(
|
MqttMessageIdVariableHeader.from(messageId.value),
|
messageId.numberOfBytesConsumed);
|
}
|
|
private static Result<MqttPublishVariableHeader> decodePublishVariableHeader(
|
ByteBuf buffer,
|
MqttFixedHeader mqttFixedHeader) {
|
logger.debug("decode==>decodePublishVariableHeader");
|
|
final Result<String> decodedTopic = decodeString(buffer);
|
if (!MqttCodecUtil.isValidPublishTopicName(decodedTopic.value)) {
|
throw new DecoderException("invalid publish topic name: " + decodedTopic.value + " (contains wildcards)");
|
}
|
int numberOfBytesConsumed = decodedTopic.numberOfBytesConsumed;
|
|
int messageId = -1;
|
if (mqttFixedHeader.qosLevel().value() > 0) {
|
final Result<Integer> decodedMessageId = decodeMessageId(buffer);
|
messageId = decodedMessageId.value;
|
numberOfBytesConsumed += decodedMessageId.numberOfBytesConsumed;
|
}
|
final MqttPublishVariableHeader mqttPublishVariableHeader =
|
new MqttPublishVariableHeader(decodedTopic.value, messageId);
|
return new Result<MqttPublishVariableHeader>(mqttPublishVariableHeader, numberOfBytesConsumed);
|
}
|
|
private static Result<Integer> decodeMessageId(ByteBuf buffer) {
|
logger.debug("decode==>decodeMessageId");
|
|
final Result<Integer> messageId = decodeMsbLsb(buffer);
|
if (!MqttCodecUtil.isValidMessageId(messageId.value)) {
|
throw new DecoderException("invalid messageId: " + messageId.value);
|
}
|
return messageId;
|
}
|
|
/**
|
* Decodes the payload.
|
*
|
* @param buffer the buffer to decode from
|
* @param messageType type of the message being decoded
|
* @param bytesRemainingInVariablePart bytes remaining
|
* @param variableHeader variable header of the same message
|
* @return the payload
|
*/
|
private static Result<?> decodePayload(
|
ByteBuf buffer,
|
MqttMessageType messageType,
|
int bytesRemainingInVariablePart,
|
Object variableHeader) {
|
switch (messageType) {
|
case CONNECT:
|
logger.debug("decode==>decodePayload==>CONNECT");
|
return decodeConnectionPayload(buffer, (MqttConnectVariableHeader) variableHeader);
|
|
case SUBSCRIBE:
|
logger.debug("decode==>decodePayload==>SUBSCRIBE");
|
return decodeSubscribePayload(buffer, bytesRemainingInVariablePart);
|
|
case SUBACK:
|
logger.debug("decode==>decodePayload==>SUBACK");
|
return decodeSubackPayload(buffer, bytesRemainingInVariablePart);
|
|
case UNSUBSCRIBE:
|
logger.debug("decode==>decodePayload==>UNSUBSCRIBE");
|
return decodeUnsubscribePayload(buffer, bytesRemainingInVariablePart);
|
|
case PUBLISH:
|
logger.debug("decode==>decodePayload==>PUBLISH");
|
return decodePublishPayload(buffer, bytesRemainingInVariablePart);
|
|
default:
|
// unknown payload , no byte consumed
|
return new Result<Object>(null, 0);
|
}
|
}
|
|
private static Result<MqttConnectPayload> decodeConnectionPayload(
|
ByteBuf buffer,
|
MqttConnectVariableHeader mqttConnectVariableHeader) {
|
logger.debug("decode==>decodeConnectionPayload");
|
|
final Result<String> decodedClientId = decodeString(buffer);
|
final String decodedClientIdValue = decodedClientId.value;
|
final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(mqttConnectVariableHeader.name(),
|
(byte) mqttConnectVariableHeader.version());
|
if (!MqttCodecUtil.isValidClientId(mqttVersion, decodedClientIdValue)) {
|
throw new MqttIdentifierRejectedException("invalid clientIdentifier: " + decodedClientIdValue);
|
}
|
int numberOfBytesConsumed = decodedClientId.numberOfBytesConsumed;
|
|
Result<String> decodedWillTopic = null;
|
Result<String> decodedWillMessage = null;
|
if (mqttConnectVariableHeader.isWillFlag()) {
|
decodedWillTopic = decodeString(buffer, 0, 32767);
|
numberOfBytesConsumed += decodedWillTopic.numberOfBytesConsumed;
|
decodedWillMessage = decodeAsciiString(buffer);
|
numberOfBytesConsumed += decodedWillMessage.numberOfBytesConsumed;
|
}
|
Result<String> decodedUserName = null;
|
Result<byte[]> decodedPassword = null;
|
if (mqttConnectVariableHeader.hasUserName()) {
|
decodedUserName = decodeString(buffer);
|
numberOfBytesConsumed += decodedUserName.numberOfBytesConsumed;
|
}
|
if (mqttConnectVariableHeader.hasPassword()) {
|
decodedPassword = decodeByte(buffer);
|
numberOfBytesConsumed += decodedPassword.numberOfBytesConsumed;
|
}
|
|
final MqttConnectPayload mqttConnectPayload =
|
new MqttConnectPayload(
|
decodedClientId.value,
|
decodedWillTopic != null ? decodedWillTopic.value : null,
|
decodedWillMessage != null ? decodedWillMessage.value : null,
|
decodedUserName != null ? decodedUserName.value : null,
|
decodedPassword != null ? decodedPassword.value : null);
|
return new Result<MqttConnectPayload>(mqttConnectPayload, numberOfBytesConsumed);
|
}
|
|
private static Result<MqttSubscribePayload> decodeSubscribePayload(
|
ByteBuf buffer,
|
int bytesRemainingInVariablePart) {
|
logger.debug("decode==>decodeSubscribePayload");
|
|
final List<MqttTopicSubscription> subscribeTopics = new ArrayList<MqttTopicSubscription>();
|
int numberOfBytesConsumed = 0;
|
while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
|
final Result<String> decodedTopicName = decodeString(buffer);
|
numberOfBytesConsumed += decodedTopicName.numberOfBytesConsumed;
|
int qos = buffer.readUnsignedByte() & 0x03;
|
numberOfBytesConsumed++;
|
subscribeTopics.add(new MqttTopicSubscription(decodedTopicName.value, MqttQoS.valueOf(qos)));
|
}
|
return new Result<MqttSubscribePayload>(new MqttSubscribePayload(subscribeTopics), numberOfBytesConsumed);
|
}
|
|
private static Result<MqttSubAckPayload> decodeSubackPayload(
|
ByteBuf buffer,
|
int bytesRemainingInVariablePart) {
|
logger.debug("decode==>decodeSubackPayload");
|
|
final List<Integer> grantedQos = new ArrayList<Integer>();
|
int numberOfBytesConsumed = 0;
|
while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
|
int qos = buffer.readUnsignedByte() & 0x03;
|
numberOfBytesConsumed++;
|
grantedQos.add(qos);
|
}
|
return new Result<MqttSubAckPayload>(new MqttSubAckPayload(grantedQos), numberOfBytesConsumed);
|
}
|
|
private static Result<MqttUnsubscribePayload> decodeUnsubscribePayload(
|
ByteBuf buffer,
|
int bytesRemainingInVariablePart) {
|
logger.debug("decode==>decodeUnsubscribePayload");
|
|
final List<String> unsubscribeTopics = new ArrayList<String>();
|
int numberOfBytesConsumed = 0;
|
while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
|
final Result<String> decodedTopicName = decodeString(buffer);
|
numberOfBytesConsumed += decodedTopicName.numberOfBytesConsumed;
|
unsubscribeTopics.add(decodedTopicName.value);
|
}
|
return new Result<MqttUnsubscribePayload>(
|
new MqttUnsubscribePayload(unsubscribeTopics),
|
numberOfBytesConsumed);
|
}
|
|
private static Result<ByteBuf> decodePublishPayload(ByteBuf buffer, int bytesRemainingInVariablePart) {
|
logger.debug("decode==>decodePublishPayload");
|
|
ByteBuf b = buffer.readRetainedSlice(bytesRemainingInVariablePart);
|
return new Result<ByteBuf>(b, bytesRemainingInVariablePart);
|
}
|
|
private static Result<String> decodeString(ByteBuf buffer) {
|
logger.debug("decode==>decodeString");
|
|
return decodeString(buffer, 0, Integer.MAX_VALUE);
|
}
|
|
private static Result<String> decodeAsciiString(ByteBuf buffer) {
|
logger.debug("decode==>decodeAsciiString");
|
|
Result<String> result = decodeString(buffer, 0, Integer.MAX_VALUE);
|
final String s = result.value;
|
for (int i = 0; i < s.length(); i++) {
|
if (s.charAt(i) > 127) {
|
return new Result<String>(null, result.numberOfBytesConsumed);
|
}
|
}
|
return new Result<String>(s, result.numberOfBytesConsumed);
|
}
|
|
private static Result<byte[]> decodeByte(ByteBuf buffer) {
|
logger.debug("decode==>decodeByte");
|
|
return decodeByte(buffer, 0, Integer.MAX_VALUE);
|
}
|
|
private static Result<String> decodeString(ByteBuf buffer, int minBytes, int maxBytes) {
|
logger.debug("decode==>decodeString");
|
|
final Result<Integer> decodedSize = decodeMsbLsb(buffer);
|
int size = decodedSize.value;
|
int numberOfBytesConsumed = decodedSize.numberOfBytesConsumed;
|
if (size < minBytes || size > maxBytes) {
|
buffer.skipBytes(size);
|
numberOfBytesConsumed += size;
|
return new Result<String>(null, numberOfBytesConsumed);
|
}
|
String s = buffer.toString(buffer.readerIndex(), size, CharsetUtil.UTF_8);
|
buffer.skipBytes(size);
|
numberOfBytesConsumed += size;
|
return new Result<String>(s, numberOfBytesConsumed);
|
}
|
|
private static Result<byte[]> decodeByte(ByteBuf buffer, int minBytes, int maxBytes) {
|
logger.debug("decode==>decodeByte");
|
|
final Result<Integer> decodedSize = decodeMsbLsb(buffer);
|
int size = decodedSize.value;
|
int numberOfBytesConsumed = decodedSize.numberOfBytesConsumed;
|
if (size < minBytes || size > maxBytes) {
|
buffer.skipBytes(size);
|
numberOfBytesConsumed += size;
|
return new Result<>(null, numberOfBytesConsumed);
|
}
|
byte[] s = new byte[size];
|
buffer.getBytes(buffer.readerIndex(), s, 0, size);
|
buffer.skipBytes(size);
|
numberOfBytesConsumed += size;
|
return new Result<>(s, numberOfBytesConsumed);
|
}
|
|
private static Result<Integer> decodeMsbLsb(ByteBuf buffer) {
|
logger.debug("decode==>decodeMsbLsb=>1");
|
|
return decodeMsbLsb(buffer, 0, 65535);
|
}
|
|
private static Result<Integer> decodeMsbLsb(ByteBuf buffer, int min, int max) {
|
logger.debug("decode==>decodeMsbLsb=>2");
|
|
short msbSize = buffer.readUnsignedByte();
|
short lsbSize = buffer.readUnsignedByte();
|
final int numberOfBytesConsumed = 2;
|
int result = msbSize << 8 | lsbSize;
|
if (result < min || result > max) {
|
result = -1;
|
}
|
return new Result<Integer>(result, numberOfBytesConsumed);
|
}
|
|
private static final class Result<T> {
|
|
private final T value;
|
private final int numberOfBytesConsumed;
|
|
Result(T value, int numberOfBytesConsumed) {
|
this.value = value;
|
this.numberOfBytesConsumed = numberOfBytesConsumed;
|
}
|
}
|
}
|