1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.transport.nio;
21
import static org.junit.Assert.assertTrue;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.Security;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManagerFactory;

import org.apache.mina.api.AbstractIoHandler;
import org.apache.mina.api.IoSession;
import org.apache.mina.transport.nio.NioTcpServer;
import org.junit.Ignore;
import org.junit.Test;
48
49
50
51
52
53
54
55 public class SslTest {
56 private static Exception clientError = null;
57
58 private static InetAddress address;
59
60 private static SSLSocketFactory factory;
61
62
63 private static final String KEY_MANAGER_FACTORY_ALGORITHM;
64
65 static {
66 String algorithm = Security.getProperty("ssl.KeyManagerFactory.algorithm");
67 if (algorithm == null) {
68 algorithm = KeyManagerFactory.getDefaultAlgorithm();
69 }
70
71 KEY_MANAGER_FACTORY_ALGORITHM = algorithm;
72 }
73
74 private static class TestHandler extends AbstractIoHandler {
75 public void messageReceived(IoSession session, Object message) {
76 String line = Charset.defaultCharset().decode((ByteBuffer) message).toString();
77
78 if (line.startsWith("hello")) {
79 System.out.println("Server got: 'hello', waiting for 'send'");
80 } else if (line.startsWith("send")) {
81 System.out.println("Server got: 'send', sending 'data'");
82 session.write(Charset.defaultCharset().encode("data\n"));
83 }
84 }
85 }
86
87
88
89
90
91 private static int startServer() throws Exception {
92 NioTcpServer server = new NioTcpServer();
93
94 server.setReuseAddress(true);
95 server.getSessionConfig().setSslContext(createSSLContext());
96 server.setIoHandler(new TestHandler());
97 server.bind(new InetSocketAddress(0));
98 return server.getServerSocketChannel().socket().getLocalPort();
99 }
100
101
102
103
104 private static void startClient(int port) throws Exception {
105 address = InetAddress.getByName("localhost");
106
107 SSLContext context = createSSLContext();
108 factory = context.getSocketFactory();
109
110 connectAndSend(port);
111
112
113 connectAndSend(port);
114 }
115
116 private static void connectAndSend(int port) throws Exception {
117 Socket parent = new Socket(address, port);
118 Socket socket = factory.createSocket(parent, address.getCanonicalHostName(), port, false);
119
120 System.out.println("Client sending: hello");
121 socket.getOutputStream().write("hello \n".getBytes());
122 socket.getOutputStream().flush();
123 socket.setSoTimeout(10000);
124
125 System.out.println("Client sending: send");
126 socket.getOutputStream().write("send\n".getBytes());
127 socket.getOutputStream().flush();
128
129 BufferedReader in = new BufferedReader(new InputStreamReader(socket.getInputStream()));
130 String line = in.readLine();
131 System.out.println("Client got: " + line);
132 socket.close();
133
134 }
135
136 private static SSLContext createSSLContext() throws IOException, GeneralSecurityException {
137 char[] passphrase = "password".toCharArray();
138
139 SSLContext ctx = SSLContext.getInstance("TLS");
140 KeyManagerFactory kmf = KeyManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
141 TrustManagerFactory tmf = TrustManagerFactory.getInstance(KEY_MANAGER_FACTORY_ALGORITHM);
142
143 KeyStore ks = KeyStore.getInstance("JKS");
144 KeyStore ts = KeyStore.getInstance("JKS");
145
146 ks.load(SslTest.class.getResourceAsStream("keystore.sslTest"), passphrase);
147 ts.load(SslTest.class.getResourceAsStream("truststore.sslTest"), passphrase);
148
149 kmf.init(ks, passphrase);
150 tmf.init(ts);
151 ctx.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
152
153 return ctx;
154 }
155
156 @Test
157 @Ignore("check for fragmentation")
158 public void testSSL() throws Exception {
159 final int port = startServer();
160
161 Thread t = new Thread() {
162 public void run() {
163 try {
164 startClient(port);
165 } catch (Exception e) {
166 clientError = e;
167 }
168 }
169 };
170 t.start();
171 t.join();
172 if (clientError != null)
173 throw clientError;
174 }
175
176 @Test
177 public void testBigMessage() throws IOException, GeneralSecurityException, InterruptedException {
178 final CountDownLatch counter = new CountDownLatch(1);
179 NioTcpServer server = new NioTcpServer();
180 final int messageSize = 1 * 1024 * 1024;
181
182
183
184
185 server.setReuseAddress(true);
186 server.getSessionConfig().setSslContext(createSSLContext());
187 server.setIoHandler(new AbstractIoHandler() {
188 private int receivedSize = 0;
189
190
191
192
193 @Override
194 public void messageReceived(IoSession session, Object message) {
195 receivedSize += ((ByteBuffer) message).remaining();
196 if (receivedSize == messageSize) {
197 counter.countDown();
198 }
199 }
200 });
201 server.bind(new InetSocketAddress(0));
202 int port = server.getServerSocketChannel().socket().getLocalPort();
203
204
205
206
207 Socket socket = server.getSessionConfig().getSslContext().getSocketFactory().createSocket("localhost", port);
208 socket.getOutputStream().write(new byte[messageSize]);
209 socket.getOutputStream().flush();
210 socket.close();
211 assertTrue(counter.await(10, TimeUnit.SECONDS));
212
213 }
214 }