1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.transport.nio;
21
22 import static org.apache.mina.session.AttributeKey.createKey;
23
24 import java.net.InetSocketAddress;
25 import java.nio.ByteBuffer;
26 import java.util.Queue;
27
28 import javax.net.ssl.SSLContext;
29 import javax.net.ssl.SSLEngine;
30 import javax.net.ssl.SSLEngineResult;
31 import javax.net.ssl.SSLEngineResult.HandshakeStatus;
32 import javax.net.ssl.SSLException;
33 import javax.net.ssl.SSLSession;
34
35 import org.apache.mina.api.IoClient;
36 import org.apache.mina.api.IoSession;
37 import org.apache.mina.session.AbstractIoSession;
38 import org.apache.mina.session.AttributeKey;
39 import org.apache.mina.session.DefaultWriteRequest;
40 import org.apache.mina.session.WriteRequest;
41 import org.slf4j.Logger;
42 import org.slf4j.LoggerFactory;
43
44
45
46
47
48
49 public class SslHelper {
50
51 private static final Logger LOGGER = LoggerFactory.getLogger(SslHelper.class);
52
53
54 private SSLEngine sslEngine;
55
56
57 private final SSLContext sslContext;
58
59
60 private final IoSession session;
61
62
63
64
65
66
67
68
69
70
71 public static final AttributeKey<InetSocketAddress> PEER_ADDRESS = createKey(InetSocketAddress.class,
72 "internal_peerAddress");
73
74 public static final AttributeKey<Boolean> WANT_CLIENT_AUTH = createKey(Boolean.class, "internal_wantClientAuth");
75
76 public static final AttributeKey<Boolean> NEED_CLIENT_AUTH = createKey(Boolean.class, "internal_needClientAuth");
77
78
79
80 private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0);
81
82 private ByteBuffer previous = null;
83
84
85
86
87
88
/**
 * Binds a new helper to a session. The SSLEngine is not created here; {@code init()} does that.
 *
 * @param session the session whose traffic this helper encrypts and decrypts
 * @param sslContext the context the SSLEngine will be created from
 */
public SslHelper(IoSession session, SSLContext sslContext) {
    this.sslContext = sslContext;
    this.session = session;
}
93
94
95
96
97
98 return session;
99 }
100
101
102
103
104
105 return sslEngine;
106 }
107
108 boolean isHanshaking() {
109 return sslEngine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING;
110 }
111
112
113
114
115
116 public void init() {
117 if (sslEngine != null) {
118
119 return;
120 }
121
122 LOGGER.debug("{} Initializing the SSL Helper", session);
123
124 InetSocketAddress peer = session.getAttribute(PEER_ADDRESS, null);
125
126
127 if (peer == null) {
128 sslEngine = sslContext.createSSLEngine();
129 } else {
130 sslEngine = sslContext.createSSLEngine(peer.getHostName(), peer.getPort());
131 }
132
133
134 sslEngine.setUseClientMode(session.getService() instanceof IoClient);
135
136
137 if (!sslEngine.getUseClientMode()) {
138
139 boolean needClientAuth = session.getAttribute(NEED_CLIENT_AUTH, false);
140 boolean wantClientAuth = session.getAttribute(WANT_CLIENT_AUTH, false);
141
142
143 if (needClientAuth) {
144 sslEngine.setNeedClientAuth(true);
145 }
146
147 if (wantClientAuth) {
148 sslEngine.setWantClientAuth(true);
149 }
150 }
151
152 if (LOGGER.isDebugEnabled()) {
153 LOGGER.debug("{} SSL Handler Initialization done.", session);
154 }
155 }
156
157
158
159
160
161
162
163 private ByteBuffer duplicate(ByteBuffer buffer) {
164 ByteBuffer newBuffer = ByteBuffer.allocateDirect(buffer.remaining() * 2);
165 newBuffer.put(buffer);
166 newBuffer.flip();
167 return newBuffer;
168 }
169
170
171
172
173
174
175
176 private ByteBuffer accumulate(ByteBuffer buffer) {
177 if (previous.capacity() - previous.remaining() > buffer.remaining()) {
178 int oldPosition = previous.position();
179 previous.position(previous.limit());
180 previous.limit(previous.limit() + buffer.remaining());
181 previous.put(buffer);
182 previous.position(oldPosition);
183 } else {
184 ByteBuffer newPrevious = ByteBuffer.allocateDirect((previous.remaining() + buffer.remaining()) * 2);
185 newPrevious.put(previous);
186 newPrevious.put(buffer);
187 newPrevious.flip();
188 previous = newPrevious;
189 }
190 return previous;
191 }
192
193
194
195
196
197
198
199
200 public void processRead(AbstractIoSession session, ByteBuffer readBuffer) throws SSLException {
201 ByteBuffer tempBuffer;
202
203 if (previous != null) {
204 tempBuffer = accumulate(readBuffer);
205 } else {
206 tempBuffer = readBuffer;
207 }
208
209 boolean done = false;
210 SSLEngineResult result;
211 ByteBuffer appBuffer = ByteBuffer.allocateDirect(sslEngine.getSession().getApplicationBufferSize());
212
213 HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
214 while (!done) {
215 switch (handshakeStatus) {
216 case NEED_UNWRAP:
217 case NOT_HANDSHAKING:
218 case FINISHED:
219 result = sslEngine.unwrap(tempBuffer, appBuffer);
220 handshakeStatus = result.getHandshakeStatus();
221
222 switch (result.getStatus()) {
223 case BUFFER_UNDERFLOW:
224
225 done = true;
226 break;
227 case BUFFER_OVERFLOW:
228
229 appBuffer = ByteBuffer.allocateDirect(appBuffer.capacity() * 2);
230 break;
231 case OK:
232 if ((handshakeStatus == HandshakeStatus.NOT_HANDSHAKING) && (result.bytesProduced() > 0)) {
233 appBuffer.flip();
234 session.processMessageReceived(appBuffer);
235 }
236 }
237 break;
238 case NEED_TASK:
239 Runnable task;
240
241 while ((task = sslEngine.getDelegatedTask()) != null) {
242 task.run();
243 }
244 handshakeStatus = sslEngine.getHandshakeStatus();
245 break;
246 case NEED_WRAP:
247 result = sslEngine.wrap(EMPTY_BUFFER, appBuffer);
248 handshakeStatus = result.getHandshakeStatus();
249 switch (result.getStatus()) {
250 case BUFFER_OVERFLOW:
251 appBuffer = ByteBuffer.allocateDirect(appBuffer.capacity() * 2);
252 break;
253 case BUFFER_UNDERFLOW:
254 done = true;
255 break;
256 case CLOSED:
257 case OK:
258 appBuffer.flip();
259 WriteRequest writeRequest = new DefaultWriteRequest(readBuffer);
260 writeRequest.setMessage(appBuffer);
261 session.enqueueWriteRequest(writeRequest);
262 break;
263 }
264 }
265 }
266 if (tempBuffer.remaining() > 0) {
267 previous = duplicate(tempBuffer);
268 } else {
269 previous = null;
270 }
271 readBuffer.clear();
272 }
273
274
275
276
277
278
279
280
281
282
283 WriteRequest processWrite(IoSession session, Object message, Queue<WriteRequest> writeQueue) {
284 ByteBuffer buf = (ByteBuffer) message;
285 ByteBuffer appBuffer = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize());
286
287 try {
288 while (true) {
289
290 SSLEngineResult result = sslEngine.wrap(buf, appBuffer);
291
292 switch (result.getStatus()) {
293 case BUFFER_OVERFLOW:
294
295 appBuffer = ByteBuffer.allocate(appBuffer.capacity() + 4096);
296 break;
297
298 case BUFFER_UNDERFLOW:
299 case CLOSED:
300 break;
301
302 case OK:
303
304 appBuffer.flip();
305 WriteRequest request = new DefaultWriteRequest(appBuffer);
306
307 return request;
308 }
309 }
310 } catch (SSLException se) {
311 throw new IllegalStateException(se.getMessage());
312 }
313 }
314 }