package chat.server.moquette;

import java.util.ArrayList;
import java.util.List;

import org.apache.log4j.Logger;

import chat.server.moquette.MqttDecodHandler.DecoderState;
import chat.server.moquette.message.MqttConnAckVariableHeader;
import chat.server.moquette.message.MqttConnectPayload;
import chat.server.moquette.message.MqttConnectReturnCode;
import chat.server.moquette.message.MqttConnectVariableHeader;
import chat.server.moquette.message.MqttFixedHeader;
import chat.server.moquette.message.MqttMessage;
import chat.server.moquette.message.MqttMessageFactory;
import chat.server.moquette.message.MqttMessageIdVariableHeader;
import chat.server.moquette.message.MqttMessageType;
import chat.server.moquette.message.MqttPublishVariableHeader;
import chat.server.moquette.message.MqttQoS;
import chat.server.moquette.message.MqttSubAckPayload;
import chat.server.moquette.message.MqttSubscribePayload;
import chat.server.moquette.message.MqttTopicSubscription;
import chat.server.moquette.message.MqttUnsubscribePayload;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.ReplayingDecoder;
import io.netty.util.CharsetUtil;

/**
 * Netty inbound decoder that turns raw bytes into {@link MqttMessage} objects.
 *
 * <p>Decoding runs as a {@link ReplayingDecoder} state machine: fixed header, then variable
 * header, then payload. If any stage throws an {@link Exception}, the decoder emits the message
 * built by {@link MqttMessageFactory#newInvalidMessage} and moves to
 * {@link DecoderState#BAD_MESSAGE}. In that state it discards all further input.
 */
public class MqttDecodHandler extends ReplayingDecoder<DecoderState> {

    /** Default cap on the bytes following the fixed header: 262144 bytes (256 KiB). */
    private static final int DEFAULT_MAX_BYTES_IN_MESSAGE = 262144;

    private static final Logger logger = Logger.getLogger(MqttDecodHandler.class);

    /**
     * States of the decoder.
     * We start at READ_FIXED_HEADER, followed by
     * READ_VARIABLE_HEADER and finally READ_PAYLOAD.
     */
    enum DecoderState {
        READ_FIXED_HEADER,
        READ_VARIABLE_HEADER,
        READ_PAYLOAD,
        BAD_MESSAGE,
    }

    /** Fixed header of the message currently being decoded; null between messages. */
    private MqttFixedHeader mqttFixedHeader;
    /** Decoded variable header of the current message; null between messages. */
    private Object variableHeader;
    /** Remaining-length bytes not yet consumed by the variable header / payload. */
    private int bytesRemainingInVariablePart;
    /** Largest remaining length accepted before the message is rejected. */
    private final int maxBytesInMessage;

    /** Creates a decoder with the default 256 KiB message-size limit. */
    public MqttDecodHandler() {
        this(DEFAULT_MAX_BYTES_IN_MESSAGE);
    }

    /**
     * Creates a decoder with a custom message-size limit.
     *
     * @param maxBytesInMessage largest accepted remaining length (variable header + payload)
     */
    public MqttDecodHandler(int maxBytesInMessage) {
        super(DecoderState.READ_FIXED_HEADER);
        logger.debug("MqttDecodHandler=>1");
        this.maxBytesInMessage = maxBytesInMessage;
    }

    /**
     * Advances the decoding state machine by one stage.
     *
     * <p>Each stage calls {@code checkpoint(...)} when it completes. If a stage runs out of
     * bytes, {@link ReplayingDecoder} rewinds to the last checkpoint and retries once more data
     * arrives. A successfully decoded message is appended to {@code out}.
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List<Object> out) throws Exception {
        logger.debug("begin decode");
        try {
            switch (state()) {
                case READ_FIXED_HEADER:
                    logger.debug("decode Read_FIXED_HEADER");
                    mqttFixedHeader = decodeFixedHeader(buffer);
                    bytesRemainingInVariablePart = mqttFixedHeader.remainingLength();
                    checkpoint(DecoderState.READ_VARIABLE_HEADER);
                    break;
                case READ_VARIABLE_HEADER:
                    logger.debug("decode READ_VARIABLE_HEADER");
                    if (bytesRemainingInVariablePart > maxBytesInMessage) {
                        throw new DecoderException("too large message: " + bytesRemainingInVariablePart + " bytes");
                    }
                    final Result<?> decodedVariableHeader = decodeVariableHeader(buffer, mqttFixedHeader);
                    variableHeader = decodedVariableHeader.value;
                    bytesRemainingInVariablePart -= decodedVariableHeader.numberOfBytesConsumed;
                    checkpoint(DecoderState.READ_PAYLOAD);
                    break;
                case READ_PAYLOAD:
                    logger.debug("decode READ_PAYLOAD");
                    final Result<?> decodedPayload = decodePayload(
                            buffer,
                            mqttFixedHeader.messageType(),
                            bytesRemainingInVariablePart,
                            variableHeader);
                    bytesRemainingInVariablePart -= decodedPayload.numberOfBytesConsumed;
                    if (bytesRemainingInVariablePart != 0) {
                        throw new DecoderException(
                                "non-zero remaining payload bytes: " + bytesRemainingInVariablePart
                                        + " (" + mqttFixedHeader.messageType() + ')');
                    }
                    checkpoint(DecoderState.READ_FIXED_HEADER);
                    MqttMessage message = MqttMessageFactory.newMessage(
                            mqttFixedHeader, variableHeader, decodedPayload.value);
                    mqttFixedHeader = null;
                    variableHeader = null;
                    out.add(message);
                    // Must not fall through into BAD_MESSAGE: that would skip every byte still
                    // buffered, silently dropping any packets pipelined behind this one.
                    break;
                case BAD_MESSAGE:
                    // Keep discarding until disconnection.
                    logger.debug("decode BAD_MESSAGE");
                    buffer.skipBytes(actualReadableBytes());
                    break;
                default:
                    // Shouldn't reach here: every DecoderState is handled above.
                    throw new AssertionError("unexpected decoder state: " + state());
            }
        } catch (Exception cause) {
            // ReplayingDecoder's replay signal is an Error, so it is not intercepted here.
            logger.error("failed to decode MQTT message", cause);
            out.add(invalidMessage(cause));
        }
    }

    /**
     * Switches the decoder into {@link DecoderState#BAD_MESSAGE} and builds an invalid message.
     *
     * @param cause the failure that made the input undecodable
     * @return the message produced by {@link MqttMessageFactory#newInvalidMessage}
     */
    private MqttMessage invalidMessage(Throwable cause) {
        logger.debug("decode invalidMessage");
        checkpoint(DecoderState.BAD_MESSAGE);
        return MqttMessageFactory.newInvalidMessage(cause);
    }

    /**
     * Decodes the fixed header: one byte of type/flags, followed by the variable-length
     * remaining-length field (1 to 4 bytes, 7 bits each, least-significant group first).
     *
     * @param buffer the buffer to decode from
     * @return the validated fixed header
     * @throws DecoderException if the remaining length uses more than 4 bytes
     */
    private static MqttFixedHeader decodeFixedHeader(ByteBuf buffer) {
        logger.debug("decode decodeFixedHeader");
        short b1 = buffer.readUnsignedByte();

        MqttMessageType messageType = MqttMessageType.valueOf(b1 >> 4);
        boolean dupFlag = (b1 & 0x08) == 0x08;
        int qosLevel = (b1 & 0x06) >> 1;
        boolean retain = (b1 & 0x01) != 0;

        int remainingLength = 0;
        int multiplier = 1;
        short digit;
        int loops = 0;
        do {
            digit = buffer.readUnsignedByte();
            remainingLength += (digit & 127) * multiplier;
            multiplier *= 128;
            loops++;
        } while ((digit & 128) != 0 && loops < 4);

        // MQTT protocol limits Remaining Length to 4 bytes
        if (loops == 4 && (digit & 128) != 0) {
            throw new DecoderException("remaining length exceeds 4 digits (" + messageType + ')');
        }
        MqttFixedHeader decodedFixedHeader =
                new MqttFixedHeader(messageType, dupFlag, MqttQoS.valueOf(qosLevel), retain, remainingLength);
        return MqttCodecUtil.validateFixedHeader(MqttCodecUtil.resetUnusedFields(decodedFixedHeader));
    }

    /**
     * Decodes the variable header, if the message type has one.
     *
     * @param buffer the buffer to decode from
     * @param mqttFixedHeader MqttFixedHeader of the same message
     * @return the variable header (null value, 0 bytes for types without one)
     */
    private static Result<?> decodeVariableHeader(ByteBuf buffer, MqttFixedHeader mqttFixedHeader) {
        final MqttMessageType messageType = mqttFixedHeader.messageType();
        logger.debug("decode==>decodeVariableHeader==>" + messageType);
        switch (messageType) {
            case CONNECT:
                return decodeConnectionVariableHeader(buffer);

            case CONNACK:
                return decodeConnAckVariableHeader(buffer);

            case SUBSCRIBE:
            case UNSUBSCRIBE:
            case SUBACK:
            case UNSUBACK:
            case PUBACK:
            case PUBREC:
            case PUBCOMP:
            case PUBREL:
                return decodeMessageIdVariableHeader(buffer);

            case PUBLISH:
                return decodePublishVariableHeader(buffer, mqttFixedHeader);

            case PINGREQ:
            case PINGRESP:
            case DISCONNECT:
                // Empty variable header
                return new Result<Object>(null, 0);

            default:
                break;
        }
        return new Result<Object>(null, 0); //should never reach here
    }

    /**
     * Decodes a CONNECT variable header: protocol name, level, connect flags and keep-alive.
     *
     * @throws DecoderException for protocol level 3.1.1+ when the reserved flag bit is set
     */
    private static Result<MqttConnectVariableHeader> decodeConnectionVariableHeader(ByteBuf buffer) {
        logger.debug("decode==>decodeConnectionVariableHeader");
        final Result<String> protoString = decodeString(buffer);
        int numberOfBytesConsumed = protoString.numberOfBytesConsumed;

        final byte protocolLevel = buffer.readByte();
        numberOfBytesConsumed += 1;

        final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(protoString.value, protocolLevel);

        final int b1 = buffer.readUnsignedByte();
        numberOfBytesConsumed += 1;

        final Result<Integer> keepAlive = decodeMsbLsb(buffer);
        numberOfBytesConsumed += keepAlive.numberOfBytesConsumed;

        final boolean hasUserName = (b1 & 0x80) == 0x80;
        final boolean hasPassword = (b1 & 0x40) == 0x40;
        final boolean willRetain = (b1 & 0x20) == 0x20;
        final int willQos = (b1 & 0x18) >> 3;
        final boolean willFlag = (b1 & 0x04) == 0x04;
        final boolean cleanSession = (b1 & 0x02) == 0x02;
        if (mqttVersion.protocolLevel() >= MqttVersion.MQTT_3_1_1.protocolLevel()) {
            final boolean zeroReservedFlag = (b1 & 0x01) == 0x0;
            if (!zeroReservedFlag) {
                // MQTT v3.1.1: The Server MUST validate that the reserved flag in the CONNECT Control Packet is
                // set to zero and disconnect the Client if it is not zero.
                // See http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/os/mqtt-v3.1.1-os.html#_Toc385349230
                throw new DecoderException("non-zero reserved flag");
            }
        }

        final MqttConnectVariableHeader mqttConnectVariableHeader = new MqttConnectVariableHeader(
                mqttVersion.protocolName(),
                mqttVersion.protocolLevel(),
                hasUserName,
                hasPassword,
                willRetain,
                willQos,
                willFlag,
                cleanSession,
                keepAlive.value);
        return new Result<MqttConnectVariableHeader>(mqttConnectVariableHeader, numberOfBytesConsumed);
    }

    /** Decodes a CONNACK variable header: session-present bit and return code (2 bytes). */
    private static Result<MqttConnAckVariableHeader> decodeConnAckVariableHeader(ByteBuf buffer) {
        logger.debug("decode==>decodeConnAckVariableHeader");
        final boolean sessionPresent = (buffer.readUnsignedByte() & 0x01) == 0x01;
        byte returnCode = buffer.readByte();
        final int numberOfBytesConsumed = 2;
        final MqttConnAckVariableHeader mqttConnAckVariableHeader =
                new MqttConnAckVariableHeader(MqttConnectReturnCode.valueOf(returnCode), sessionPresent);
        return new Result<MqttConnAckVariableHeader>(mqttConnAckVariableHeader, numberOfBytesConsumed);
    }

    /** Decodes a variable header that consists only of a 2-byte message id. */
    private static Result<MqttMessageIdVariableHeader> decodeMessageIdVariableHeader(ByteBuf buffer) {
        logger.debug("decode==>decodeMessageIdVariableHeader");
        final Result<Integer> messageId = decodeMessageId(buffer);
        return new Result<MqttMessageIdVariableHeader>(
                MqttMessageIdVariableHeader.from(messageId.value),
                messageId.numberOfBytesConsumed);
    }

    /**
     * Decodes a PUBLISH variable header: topic name, plus a message id when QoS &gt; 0.
     * The message id is -1 for QoS 0.
     *
     * @throws DecoderException if the topic name is rejected by
     *     {@code MqttCodecUtil.isValidPublishTopicName}
     */
    private static Result<MqttPublishVariableHeader> decodePublishVariableHeader(
            ByteBuf buffer,
            MqttFixedHeader mqttFixedHeader) {
        logger.debug("decode==>decodePublishVariableHeader");
        final Result<String> decodedTopic = decodeString(buffer);
        if (!MqttCodecUtil.isValidPublishTopicName(decodedTopic.value)) {
            throw new DecoderException("invalid publish topic name: " + decodedTopic.value + " (contains wildcards)");
        }
        int numberOfBytesConsumed = decodedTopic.numberOfBytesConsumed;

        int messageId = -1;
        if (mqttFixedHeader.qosLevel().value() > 0) {
            final Result<Integer> decodedMessageId = decodeMessageId(buffer);
            messageId = decodedMessageId.value;
            numberOfBytesConsumed += decodedMessageId.numberOfBytesConsumed;
        }
        final MqttPublishVariableHeader mqttPublishVariableHeader =
                new MqttPublishVariableHeader(decodedTopic.value, messageId);
        return new Result<MqttPublishVariableHeader>(mqttPublishVariableHeader, numberOfBytesConsumed);
    }

    /**
     * Reads a 2-byte message id.
     *
     * @throws DecoderException if {@code MqttCodecUtil.isValidMessageId} rejects it
     */
    private static Result<Integer> decodeMessageId(ByteBuf buffer) {
        logger.debug("decode==>decodeMessageId");
        final Result<Integer> messageId = decodeMsbLsb(buffer);
        if (!MqttCodecUtil.isValidMessageId(messageId.value)) {
            throw new DecoderException("invalid messageId: " + messageId.value);
        }
        return messageId;
    }

    /**
     * Decodes the payload.
     *
     * @param buffer the buffer to decode from
     * @param messageType type of the message being decoded
     * @param bytesRemainingInVariablePart bytes remaining
     * @param variableHeader variable header of the same message
     * @return the payload (null value, 0 bytes for types without a payload)
     */
    private static Result<?> decodePayload(
            ByteBuf buffer,
            MqttMessageType messageType,
            int bytesRemainingInVariablePart,
            Object variableHeader) {
        logger.debug("decode==>decodePayload==>" + messageType);
        switch (messageType) {
            case CONNECT:
                return decodeConnectionPayload(buffer, (MqttConnectVariableHeader) variableHeader);

            case SUBSCRIBE:
                return decodeSubscribePayload(buffer, bytesRemainingInVariablePart);

            case SUBACK:
                return decodeSubackPayload(buffer, bytesRemainingInVariablePart);

            case UNSUBSCRIBE:
                return decodeUnsubscribePayload(buffer, bytesRemainingInVariablePart);

            case PUBLISH:
                return decodePublishPayload(buffer, bytesRemainingInVariablePart);

            default:
                // unknown payload , no byte consumed
                return new Result<Object>(null, 0);
        }
    }

    /**
     * Decodes a CONNECT payload. The client id is always present. The will topic/message,
     * user name and password are read only when the matching flag is set in the variable header.
     *
     * @throws MqttIdentifierRejectedException if the client id fails
     *     {@code MqttCodecUtil.isValidClientId}
     */
    private static Result<MqttConnectPayload> decodeConnectionPayload(
            ByteBuf buffer,
            MqttConnectVariableHeader mqttConnectVariableHeader) {
        logger.debug("decode==>decodeConnectionPayload");
        final Result<String> decodedClientId = decodeString(buffer);
        final String decodedClientIdValue = decodedClientId.value;
        final MqttVersion mqttVersion = MqttVersion.fromProtocolNameAndLevel(mqttConnectVariableHeader.name(),
                (byte) mqttConnectVariableHeader.version());
        if (!MqttCodecUtil.isValidClientId(mqttVersion, decodedClientIdValue)) {
            throw new MqttIdentifierRejectedException("invalid clientIdentifier: " + decodedClientIdValue);
        }
        int numberOfBytesConsumed = decodedClientId.numberOfBytesConsumed;

        Result<String> decodedWillTopic = null;
        Result<String> decodedWillMessage = null;
        if (mqttConnectVariableHeader.isWillFlag()) {
            decodedWillTopic = decodeString(buffer, 0, 32767);
            numberOfBytesConsumed += decodedWillTopic.numberOfBytesConsumed;
            decodedWillMessage = decodeAsciiString(buffer);
            numberOfBytesConsumed += decodedWillMessage.numberOfBytesConsumed;
        }
        Result<String> decodedUserName = null;
        Result<byte[]> decodedPassword = null;
        if (mqttConnectVariableHeader.hasUserName()) {
            decodedUserName = decodeString(buffer);
            numberOfBytesConsumed += decodedUserName.numberOfBytesConsumed;
        }
        if (mqttConnectVariableHeader.hasPassword()) {
            decodedPassword = decodeByte(buffer);
            numberOfBytesConsumed += decodedPassword.numberOfBytesConsumed;
        }

        final MqttConnectPayload mqttConnectPayload =
                new MqttConnectPayload(
                        decodedClientId.value,
                        decodedWillTopic != null ? decodedWillTopic.value : null,
                        decodedWillMessage != null ? decodedWillMessage.value : null,
                        decodedUserName != null ? decodedUserName.value : null,
                        decodedPassword != null ? decodedPassword.value : null);
        return new Result<MqttConnectPayload>(mqttConnectPayload, numberOfBytesConsumed);
    }

    /** Decodes (topic filter, requested QoS) pairs until the remaining length is used up. */
    private static Result<MqttSubscribePayload> decodeSubscribePayload(
            ByteBuf buffer,
            int bytesRemainingInVariablePart) {
        logger.debug("decode==>decodeSubscribePayload");
        final List<MqttTopicSubscription> subscribeTopics = new ArrayList<MqttTopicSubscription>();
        int numberOfBytesConsumed = 0;
        while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
            final Result<String> decodedTopicName = decodeString(buffer);
            numberOfBytesConsumed += decodedTopicName.numberOfBytesConsumed;
            int qos = buffer.readUnsignedByte() & 0x03;
            numberOfBytesConsumed++;
            subscribeTopics.add(new MqttTopicSubscription(decodedTopicName.value, MqttQoS.valueOf(qos)));
        }
        return new Result<MqttSubscribePayload>(new MqttSubscribePayload(subscribeTopics), numberOfBytesConsumed);
    }

    /** Decodes one granted-QoS byte per subscription until the remaining length is used up. */
    private static Result<MqttSubAckPayload> decodeSubackPayload(
            ByteBuf buffer,
            int bytesRemainingInVariablePart) {
        logger.debug("decode==>decodeSubackPayload");
        final List<Integer> grantedQos = new ArrayList<Integer>();
        int numberOfBytesConsumed = 0;
        while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
            int qos = buffer.readUnsignedByte() & 0x03;
            numberOfBytesConsumed++;
            grantedQos.add(qos);
        }
        return new Result<MqttSubAckPayload>(new MqttSubAckPayload(grantedQos), numberOfBytesConsumed);
    }

    /** Decodes topic filters until the remaining length is used up. */
    private static Result<MqttUnsubscribePayload> decodeUnsubscribePayload(
            ByteBuf buffer,
            int bytesRemainingInVariablePart) {
        logger.debug("decode==>decodeUnsubscribePayload");
        final List<String> unsubscribeTopics = new ArrayList<String>();
        int numberOfBytesConsumed = 0;
        while (numberOfBytesConsumed < bytesRemainingInVariablePart) {
            final Result<String> decodedTopicName = decodeString(buffer);
            numberOfBytesConsumed += decodedTopicName.numberOfBytesConsumed;
            unsubscribeTopics.add(decodedTopicName.value);
        }
        return new Result<MqttUnsubscribePayload>(
                new MqttUnsubscribePayload(unsubscribeTopics),
                numberOfBytesConsumed);
    }

    /**
     * Returns the rest of the PUBLISH packet as a retained slice. The receiver of the decoded
     * message becomes responsible for releasing it.
     */
    private static Result<ByteBuf> decodePublishPayload(ByteBuf buffer, int bytesRemainingInVariablePart) {
        logger.debug("decode==>decodePublishPayload");
        ByteBuf b = buffer.readRetainedSlice(bytesRemainingInVariablePart);
        return new Result<ByteBuf>(b, bytesRemainingInVariablePart);
    }

    /** Decodes a length-prefixed UTF-8 string with no length restriction beyond 16 bits. */
    private static Result<String> decodeString(ByteBuf buffer) {
        logger.debug("decode==>decodeString");
        return decodeString(buffer, 0, Integer.MAX_VALUE);
    }

    /**
     * Decodes a length-prefixed string. The value becomes null if any character is above 127;
     * the consumed byte count is still reported.
     */
    private static Result<String> decodeAsciiString(ByteBuf buffer) {
        logger.debug("decode==>decodeAsciiString");
        Result<String> result = decodeString(buffer, 0, Integer.MAX_VALUE);
        final String s = result.value;
        for (int i = 0; i < s.length(); i++) {
            if (s.charAt(i) > 127) {
                return new Result<String>(null, result.numberOfBytesConsumed);
            }
        }
        return new Result<String>(s, result.numberOfBytesConsumed);
    }

    /** Decodes a length-prefixed raw byte array with no length restriction beyond 16 bits. */
    private static Result<byte[]> decodeByte(ByteBuf buffer) {
        logger.debug("decode==>decodeByte");
        return decodeByte(buffer, 0, Integer.MAX_VALUE);
    }

    /**
     * Decodes a 2-byte-length-prefixed UTF-8 string.
     *
     * <p>If the length falls outside [{@code minBytes}, {@code maxBytes}], the bytes are skipped
     * and the returned value is null.
     */
    private static Result<String> decodeString(ByteBuf buffer, int minBytes, int maxBytes) {
        logger.debug("decode==>decodeString");
        final Result<Integer> decodedSize = decodeMsbLsb(buffer);
        int size = decodedSize.value;
        int numberOfBytesConsumed = decodedSize.numberOfBytesConsumed;
        if (size < minBytes || size > maxBytes) {
            buffer.skipBytes(size);
            numberOfBytesConsumed += size;
            return new Result<String>(null, numberOfBytesConsumed);
        }
        String s = buffer.toString(buffer.readerIndex(), size, CharsetUtil.UTF_8);
        buffer.skipBytes(size);
        numberOfBytesConsumed += size;
        return new Result<String>(s, numberOfBytesConsumed);
    }

    /**
     * Decodes a 2-byte-length-prefixed raw byte array.
     *
     * <p>If the length falls outside [{@code minBytes}, {@code maxBytes}], the bytes are skipped
     * and the returned value is null.
     */
    private static Result<byte[]> decodeByte(ByteBuf buffer, int minBytes, int maxBytes) {
        logger.debug("decode==>decodeByte");
        final Result<Integer> decodedSize = decodeMsbLsb(buffer);
        int size = decodedSize.value;
        int numberOfBytesConsumed = decodedSize.numberOfBytesConsumed;
        if (size < minBytes || size > maxBytes) {
            buffer.skipBytes(size);
            numberOfBytesConsumed += size;
            return new Result<>(null, numberOfBytesConsumed);
        }
        byte[] s = new byte[size];
        buffer.getBytes(buffer.readerIndex(), s, 0, size);
        buffer.skipBytes(size);
        numberOfBytesConsumed += size;
        return new Result<>(s, numberOfBytesConsumed);
    }

    /** Reads a big-endian unsigned 16-bit value (0..65535). */
    private static Result<Integer> decodeMsbLsb(ByteBuf buffer) {
        logger.debug("decode==>decodeMsbLsb=>1");
        return decodeMsbLsb(buffer, 0, 65535);
    }

    /**
     * Reads a big-endian unsigned 16-bit value. Always consumes 2 bytes; the value becomes -1
     * if it lies outside [{@code min}, {@code max}].
     */
    private static Result<Integer> decodeMsbLsb(ByteBuf buffer, int min, int max) {
        logger.debug("decode==>decodeMsbLsb=>2");
        short msbSize = buffer.readUnsignedByte();
        short lsbSize = buffer.readUnsignedByte();
        final int numberOfBytesConsumed = 2;
        int result = msbSize << 8 | lsbSize;
        if (result < min || result > max) {
            result = -1;
        }
        return new Result<Integer>(result, numberOfBytesConsumed);
    }

    /** A decoded value paired with the number of bytes read from the buffer to produce it. */
    private static final class Result<T> {

        private final T value;
        private final int numberOfBytesConsumed;

        Result(T value, int numberOfBytesConsumed) {
            this.value = value;
            this.numberOfBytesConsumed = numberOfBytesConsumed;
        }
    }
}