/**
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.activemq.transport.mqtt;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.activemq.Service;
import org.apache.activemq.broker.BrokerService;
import org.apache.activemq.command.ActiveMQMessage;
import org.apache.activemq.util.LRUCache;
import org.apache.activemq.util.ServiceStopper;
import org.apache.activemq.util.ServiceSupport;
import org.fusesource.mqtt.codec.PUBLISH;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Manages PUBLISH packet ids for clients.
 * <p>
 * Clients registered through {@link #startClientSession(String)} get their own id sequence and
 * a mapping from (destination, JMS message id) to packet id, so a redelivered message reuses its
 * id and is flagged as a duplicate. Clients without a registered session draw ids from a single
 * shared sequence. One instance is kept per broker as a broker {@link Service}; obtain it with
 * {@link #getMQTTPacketIdGenerator(BrokerService)}.
 *
 * @author Dhiraj Bokde
 */
public class MQTTPacketIdGenerator extends ServiceSupport {

    private static final Logger LOG = LoggerFactory.getLogger(MQTTPacketIdGenerator.class);

    /** Guards the look-up-or-register sequence in {@link #getMQTTPacketIdGenerator(BrokerService)}. */
    private static final Object LOCK = new Object();

    /** Per-client id state, keyed by MQTT client id. */
    Map<String, PacketIdMaps> clientIdMap = new ConcurrentHashMap<String, PacketIdMaps>();

    /** Shared sequence for clients that have no registered session. */
    private final NonZeroSequenceGenerator messageIdGenerator = new NonZeroSequenceGenerator();

    private MQTTPacketIdGenerator() {
    }

    /**
     * Drops all per-client state.
     * <p>
     * The map is cleared in place rather than replaced: the field is read without locking, so
     * swapping in a new instance would not be reliably visible to other threads.
     */
    @Override
    protected void doStop(ServiceStopper stopper) throws Exception {
        clientIdMap.clear();
    }

    @Override
    protected void doStart() throws Exception {
    }

    /**
     * Registers per-client id state for {@code clientId} if none exists yet.
     * <p>
     * Synchronized so that two concurrent starts for the same client cannot both pass the
     * absence check, where the second would overwrite (and lose) the first one's mappings.
     *
     * @param clientId the MQTT client id
     */
    public synchronized void startClientSession(String clientId) {
        if (!clientIdMap.containsKey(clientId)) {
            clientIdMap.put(clientId, new PacketIdMaps());
        }
    }

    /**
     * Discards the per-client id state for {@code clientId}.
     *
     * @param clientId the MQTT client id
     * @return {@code true} if state was registered for the client and has been removed
     */
    public boolean stopClientSession(String clientId) {
        return clientIdMap.remove(clientId) != null;
    }

    /**
     * Assigns a packet id to {@code publish} and returns it.
     * <p>
     * For a client with a registered session, the id is tied to the subscription's destination
     * and the message's JMS id: a repeat delivery reuses the earlier id and sets the DUP flag.
     * Otherwise a fresh id from the shared sequence is used.
     *
     * @param clientId     the MQTT client id
     * @param subscription the subscription the message is delivered through
     * @param message      the message being delivered
     * @param publish      the outgoing PUBLISH frame; its message id (and possibly DUP flag) is set
     * @return the assigned packet id
     */
    public short setPacketId(String clientId, MQTTSubscription subscription, ActiveMQMessage message, PUBLISH publish) {
        final PacketIdMaps idMaps = clientIdMap.get(clientId);
        if (idMaps != null) {
            return idMaps.setPacketId(subscription, message, publish);
        }
        // No registered session (presumably a cleansession=true client): use a session-less id.
        final short id = messageIdGenerator.getNextSequenceId();
        publish.messageId(id);
        return id;
    }

    /**
     * Releases the mapping for an acknowledged packet id. This is a no-op when the client has no
     * registered session.
     *
     * @param clientId the MQTT client id
     * @param packetId the acknowledged packet id
     */
    public void ackPacketId(String clientId, short packetId) {
        final PacketIdMaps idMaps = clientIdMap.get(clientId);
        if (idMaps != null) {
            idMaps.ackPacketId(packetId);
        }
    }

    /**
     * Returns the next id from the client's own sequence, or from the shared sequence if the
     * client has no registered session. No mapping is recorded.
     *
     * @param clientId the MQTT client id
     * @return a non-zero packet id
     */
    public short getNextSequenceId(String clientId) {
        final PacketIdMaps idMaps = clientIdMap.get(clientId);
        return idMaps != null ? idMaps.getNextSequenceId() : messageIdGenerator.getNextSequenceId();
    }

    /**
     * Returns the generator registered with {@code broker}, creating and registering one if absent.
     * A newly created generator is started immediately when the broker is already started. A
     * start failure is logged, not thrown.
     *
     * @param broker the broker to look up or register on; may be {@code null}
     * @return the broker's generator, or {@code null} if {@code broker} is {@code null}
     */
    public static MQTTPacketIdGenerator getMQTTPacketIdGenerator(BrokerService broker) {
        if (broker == null) {
            return null;
        }
        synchronized (LOCK) {
            final Service[] services = broker.getServices();
            if (services != null) {
                for (Service service : services) {
                    if (service instanceof MQTTPacketIdGenerator) {
                        return (MQTTPacketIdGenerator) service;
                    }
                }
            }
            final MQTTPacketIdGenerator result = new MQTTPacketIdGenerator();
            broker.addService(result);
            if (broker.isStarted()) {
                try {
                    result.start();
                } catch (Exception e) {
                    LOG.warn("Couldn't start MQTTPacketIdGenerator", e);
                }
            }
            return result;
        }
    }

    /**
     * Per-client id state: a private sequence plus a two-way mapping between
     * "destination:JMSMessageID" keys and packet ids. Both maps are bounded LRU caches of
     * {@code MQTTProtocolConverter.DEFAULT_CACHE_SIZE} entries. All access to the maps is
     * guarded by the {@code activemqToPacketIds} monitor.
     */
    private static class PacketIdMaps {

        private final NonZeroSequenceGenerator messageIdGenerator = new NonZeroSequenceGenerator();
        final Map<String, Short> activemqToPacketIds = new LRUCache<String, Short>(MQTTProtocolConverter.DEFAULT_CACHE_SIZE);
        final Map<Short, String> packetIdsToActivemq = new LRUCache<Short, String>(MQTTProtocolConverter.DEFAULT_CACHE_SIZE);

        short setPacketId(MQTTSubscription subscription, ActiveMQMessage message, PUBLISH publish) {
            // A delivery is identified by destination physical name and JMS message id.
            final String keyStr = subscription.getConsumerInfo().getDestination().getPhysicalName()
                + ':' + message.getJMSMessageID();
            Short packetId;
            synchronized (activemqToPacketIds) {
                packetId = activemqToPacketIds.get(keyStr);
                if (packetId == null) {
                    packetId = getNextSequenceId();
                    activemqToPacketIds.put(keyStr, packetId);
                    // Once the short sequence wraps, the new id may still be mapped to an older,
                    // unacknowledged delivery. Drop that stale forward mapping so the two maps stay
                    // consistent and the old delivery is not re-sent under an id now owned by
                    // another message.
                    final String staleKey = packetIdsToActivemq.put(packetId, keyStr);
                    if (staleKey != null && !staleKey.equals(keyStr)) {
                        activemqToPacketIds.remove(staleKey);
                    }
                } else {
                    // Already sent under this id: mark the publish as a duplicate.
                    publish.dup(true);
                }
            }
            publish.messageId(packetId);
            return packetId;
        }

        void ackPacketId(short packetId) {
            synchronized (activemqToPacketIds) {
                final String subscriptionKey = packetIdsToActivemq.remove(packetId);
                if (subscriptionKey != null) {
                    activemqToPacketIds.remove(subscriptionKey);
                }
            }
        }

        short getNextSequenceId() {
            return messageIdGenerator.getNextSequenceId();
        }
    }

    /**
     * Thread-safe counter producing non-zero {@code short} ids (MQTT packet identifiers must be
     * non-zero). Java {@code short} arithmetic wraps from 32767 to -32768. As a 16-bit pattern, the
     * sequence therefore visits every non-zero value before repeating, and 0 is skipped on wrap.
     */
    private static class NonZeroSequenceGenerator {

        private short lastSequenceId;

        public synchronized short getNextSequenceId() {
            final short val = ++lastSequenceId;
            return val != 0 ? val : ++lastSequenceId;
        }
    }
}