package chat.server.moquette; import org.apache.log4j.Logger; import io.netty.buffer.ByteBuf; import io.netty.channel.Channel; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.util.Attribute; import io.netty.util.AttributeKey; public class MqttBytesMetricsHandler extends ChannelDuplexHandler { private static Logger logger; private static AttributeKey ATTR_KEY_METRICS; private static AttributeKey ATTR_KEY_USERNAME; private BytesMetricsCollector collector; static { logger = Logger.getLogger(MqttBytesMetricsHandler.class); ATTR_KEY_METRICS = AttributeKey.valueOf("BytesMetrics"); ATTR_KEY_USERNAME = AttributeKey.valueOf("username"); } public MqttBytesMetricsHandler(BytesMetricsCollector collector) { this.collector = collector; } @Override public void channelActive(ChannelHandlerContext ctx) throws Exception { Attribute attr = ctx.channel().attr(ATTR_KEY_METRICS); attr.set(new Metrics()); super.channelActive(ctx); } @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { Metrics metrics = ctx.channel().attr(ATTR_KEY_METRICS).get(); metrics.incrementRead(((ByteBuf) msg).readableBytes()); ctx.fireChannelRead(msg); } @Override public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { Metrics metrics = ctx.channel().attr(ATTR_KEY_METRICS).get(); metrics.incrementWrote(((ByteBuf) msg).writableBytes()); ctx.write(msg, promise); } @Override public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { Metrics metrics = ctx.channel().attr(ATTR_KEY_METRICS).get(); String userId = ctx.channel().attr(ATTR_KEY_USERNAME).get(); if (userId == null) { userId = ""; } logger.info("channel<" + userId + "> closing after" + metrics); collector.sumReadBytes(metrics.readLength()); collector.sumWroteBytes(metrics.wroteLength()); super.close(ctx, promise); } public static Metrics getBytesMetrics(Channel channel) { return channel.attr(ATTR_KEY_METRICS).get(); } }