001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.nio; 019 020import java.io.DataOutputStream; 021import java.io.EOFException; 022import java.io.IOException; 023import java.net.Socket; 024import java.net.URI; 025import java.net.UnknownHostException; 026import java.nio.ByteBuffer; 027import java.util.concurrent.atomic.AtomicInteger; 028 029import javax.net.SocketFactory; 030import javax.net.ssl.SSLContext; 031import javax.net.ssl.SSLEngine; 032import javax.net.ssl.SSLEngineResult; 033 034import org.apache.activemq.thread.TaskRunnerFactory; 035import org.apache.activemq.util.IOExceptionSupport; 036import org.apache.activemq.util.ServiceStopper; 037import org.apache.activemq.wireformat.WireFormat; 038 039/** 040 * This transport initializes the SSLEngine and reads the first command before 041 * handing off to the detected transport. 042 * 043 */ 044public class AutoInitNioSSLTransport extends NIOSSLTransport { 045 046 public AutoInitNioSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { 047 super(wireFormat, socketFactory, remoteLocation, localLocation); 048 } 049 050 public AutoInitNioSSLTransport(WireFormat wireFormat, Socket socket) throws IOException { 051 super(wireFormat, socket, null, null, null); 052 } 053 054 @Override 055 public void setSslContext(SSLContext sslContext) { 056 this.sslContext = sslContext; 057 } 058 059 public ByteBuffer getInputBuffer() { 060 return this.inputBuffer; 061 } 062 063 @Override 064 protected void initializeStreams() throws IOException { 065 NIOOutputStream outputStream = null; 066 try { 067 channel = socket.getChannel(); 068 channel.configureBlocking(false); 069 070 if (sslContext == null) { 071 sslContext = SSLContext.getDefault(); 072 } 073 074 String remoteHost = null; 075 int remotePort = -1; 076 077 try { 078 URI remoteAddress = new URI(this.getRemoteAddress()); 079 remoteHost = remoteAddress.getHost(); 080 remotePort = remoteAddress.getPort(); 081 } catch (Exception e) { 082 } 083 084 // initialize engine, the initial sslSession we get will need to be 085 // updated once the ssl handshake process is completed. 086 if (remoteHost != null && remotePort != -1) { 087 sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); 088 } else { 089 sslEngine = sslContext.createSSLEngine(); 090 } 091 092 sslEngine.setUseClientMode(false); 093 if (enabledCipherSuites != null) { 094 sslEngine.setEnabledCipherSuites(enabledCipherSuites); 095 } 096 097 if (enabledProtocols != null) { 098 sslEngine.setEnabledProtocols(enabledProtocols); 099 } 100 101 if (wantClientAuth) { 102 sslEngine.setWantClientAuth(wantClientAuth); 103 } 104 105 if (needClientAuth) { 106 sslEngine.setNeedClientAuth(needClientAuth); 107 } 108 109 sslSession = sslEngine.getSession(); 110 111 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); 112 inputBuffer.clear(); 113 114 outputStream = new NIOOutputStream(channel); 115 outputStream.setEngine(sslEngine); 116 this.dataOut = new DataOutputStream(outputStream); 117 this.buffOut = outputStream; 118 sslEngine.beginHandshake(); 119 handshakeStatus = sslEngine.getHandshakeStatus(); 120 doHandshake(); 121 122 } catch (Exception e) { 123 try { 124 if(outputStream != null) { 125 outputStream.close(); 126 } 127 super.closeStreams(); 128 } catch (Exception ex) {} 129 throw new IOException(e); 130 } 131 } 132 133 @Override 134 protected void doOpenWireInit() throws Exception { 135 136 } 137 138 public SSLEngine getSslSession() { 139 return this.sslEngine; 140 } 141 142 private volatile byte[] readData; 143 144 private final AtomicInteger readSize = new AtomicInteger(); 145 146 public byte[] getReadData() { 147 return readData != null ? readData : new byte[0]; 148 } 149 150 public AtomicInteger getReadSize() { 151 return readSize; 152 } 153 154 @Override 155 public void serviceRead() { 156 try { 157 if (handshakeInProgress) { 158 doHandshake(); 159 } 160 161 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); 162 plain.position(plain.limit()); 163 164 while (true) { 165 if (!plain.hasRemaining()) { 166 int readCount = secureRead(plain); 167 168 if (readCount == 0) { 169 break; 170 } 171 172 // channel is closed, cleanup 173 if (readCount == -1) { 174 onException(new EOFException()); 175 break; 176 } 177 178 receiveCounter += readCount; 179 readSize.addAndGet(readCount); 180 } 181 182 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { 183 processCommand(plain); 184 //we have received enough bytes to detect the protocol 185 if (receiveCounter >= 8) { 186 break; 187 } 188 } 189 } 190 } catch (IOException e) { 191 onException(e); 192 } catch (Throwable e) { 193 onException(IOExceptionSupport.create(e)); 194 } 195 } 196 197 @Override 198 protected void processCommand(ByteBuffer plain) throws Exception { 199 ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter); 200 if (readData != null) { 201 newBuffer.put(readData); 202 } 203 newBuffer.put(plain); 204 newBuffer.flip(); 205 readData = newBuffer.array(); 206 } 207 208 209 @Override 210 public void doStart() throws Exception { 211 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); 212 // no need to init as we can delay that until demand (eg in doHandshake) 213 connect(); 214 } 215 216 217 @Override 218 protected void doStop(ServiceStopper stopper) throws Exception { 219 if (taskRunnerFactory != null) { 220 taskRunnerFactory.shutdownNow(); 221 taskRunnerFactory = null; 222 } 223 } 224 225 226}