JSSE: fixes for calling SSLSocket methods after SSLSocket.close() has been called

pull/233/head
Chris Conlon 2024-11-22 15:25:27 -07:00
parent 11f6f4b5cd
commit 36f54b02e8
6 changed files with 323 additions and 15 deletions

View File

@ -65,7 +65,9 @@ public class WolfSSLDebug {
* Will be used to determine what string gets put into log messages.
*/
public enum Component {
/** wolfSSL JNI component */
JNI("wolfJNI"),
/** wolfSSL JSSE component */
JSSE("wolfJSSE");
private final String componentString;

View File

@ -1394,7 +1394,7 @@ public class WolfSSLEngine extends SSLEngine {
public String[] getSupportedCipherSuites() {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getSupportedCipherSuites()");
return this.engineHelper.getAllCiphers();
return WolfSSLEngineHelper.getAllCiphers();
}
@Override
@ -1415,7 +1415,7 @@ public class WolfSSLEngine extends SSLEngine {
public String[] getSupportedProtocols() {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getSupportedProtocols()");
return this.engineHelper.getAllProtocols();
return WolfSSLEngineHelper.getAllProtocols();
}
@Override

View File

@ -454,7 +454,7 @@ public class WolfSSLEngineHelper {
*
* @return String array of all supported cipher suites
*/
protected synchronized String[] getAllCiphers() {
protected static synchronized String[] getAllCiphers() {
return WolfSSLUtil.sanitizeSuites(WolfSSL.getCiphersIana());
}
@ -551,7 +551,7 @@ public class WolfSSLEngineHelper {
*
* @return String array of supported protocols
*/
protected synchronized String[] getAllProtocols() {
protected static synchronized String[] getAllProtocols() {
return WolfSSLUtil.sanitizeProtocols(WolfSSL.getProtocols());
}

View File

@ -112,6 +112,9 @@ public class WolfSSLSocket extends SSLSocket {
/** ALPN selector callback, if set */
protected BiFunction<SSLSocket, List<String>, String> alpnSelector = null;
/* true if client, otherwise false */
private boolean isClientMode = false;
/**
* Create new WolfSSLSocket object
*
@ -143,6 +146,7 @@ public class WolfSSLSocket extends SSLSocket {
EngineHelper = new WolfSSLEngineHelper(this.ssl, this.authStore,
this.params);
EngineHelper.setUseClientMode(clientMode);
this.isClientMode = clientMode;
} catch (WolfSSLException e) {
throw new IOException(e);
@ -183,6 +187,7 @@ public class WolfSSLSocket extends SSLSocket {
EngineHelper = new WolfSSLEngineHelper(this.ssl, this.authStore,
this.params, port, host);
EngineHelper.setUseClientMode(clientMode);
this.isClientMode = clientMode;
} catch (WolfSSLException e) {
throw new IOException(e);
@ -226,6 +231,7 @@ public class WolfSSLSocket extends SSLSocket {
EngineHelper = new WolfSSLEngineHelper(this.ssl, this.authStore,
this.params, port, address);
EngineHelper.setUseClientMode(clientMode);
this.isClientMode = clientMode;
} catch (WolfSSLException e) {
throw new IOException(e);
@ -266,6 +272,7 @@ public class WolfSSLSocket extends SSLSocket {
EngineHelper = new WolfSSLEngineHelper(this.ssl, this.authStore,
this.params, port, host);
EngineHelper.setUseClientMode(clientMode);
this.isClientMode = clientMode;
} catch (WolfSSLException e) {
throw new IOException(e);
@ -309,6 +316,7 @@ public class WolfSSLSocket extends SSLSocket {
EngineHelper = new WolfSSLEngineHelper(this.ssl, this.authStore,
this.params, port, host);
EngineHelper.setUseClientMode(clientMode);
this.isClientMode = clientMode;
} catch (WolfSSLException e) {
throw new IOException(e);
@ -366,6 +374,7 @@ public class WolfSSLSocket extends SSLSocket {
EngineHelper = new WolfSSLEngineHelper(this.ssl, this.authStore,
this.params, port, host);
EngineHelper.setUseClientMode(clientMode);
this.isClientMode = clientMode;
} catch (WolfSSLException e) {
throw new IOException(e);
@ -411,6 +420,7 @@ public class WolfSSLSocket extends SSLSocket {
EngineHelper = new WolfSSLEngineHelper(this.ssl, this.authStore,
this.params, s.getPort(), s.getInetAddress());
EngineHelper.setUseClientMode(clientMode);
this.isClientMode = clientMode;
} catch (WolfSSLException e) {
throw new IOException(e);
@ -460,6 +470,7 @@ public class WolfSSLSocket extends SSLSocket {
EngineHelper = new WolfSSLEngineHelper(this.ssl, this.authStore,
this.params, s.getPort(), s.getInetAddress());
EngineHelper.setUseClientMode(false);
this.isClientMode = false;
/* register custom receive callback to read consumed first */
if (consumed != null) {
@ -1030,7 +1041,8 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getSupportedCipherSuites()");
return EngineHelper.getAllCiphers();
/* getAllCiphers() is a static method, calling directly on class */
return WolfSSLEngineHelper.getAllCiphers();
}
/**
@ -1046,6 +1058,10 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getEnabledCipherSuites()");
if (this.isClosed()) {
return null;
}
return EngineHelper.getCiphers();
}
@ -1064,6 +1080,12 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered setEnabledCipherSuites()");
if (this.isClosed()) {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"SSLSocket closed, not setting enabled cipher suites");
return;
}
/* sets cipher suite(s) to be used for connection */
EngineHelper.setCiphers(suites);
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
@ -1085,6 +1107,11 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getApplicationProtocol()");
/* If socket has been closed, return an empty string */
if (this.isClosed()) {
return "";
}
return EngineHelper.getAlpnSelectedProtocolString();
}
@ -1252,8 +1279,9 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getSupportedProtocols()");
/* returns all protocol version supported by native wolfSSL */
return EngineHelper.getAllProtocols();
/* returns all protocol version supported by native wolfSSL.
/* getAllProtocols() is a static method, calling directly on class */
return WolfSSLEngineHelper.getAllProtocols();
}
/**
@ -1267,6 +1295,10 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getEnabledProtocols()");
if (this.isClosed()) {
return null;
}
/* returns protocols versions enabled for this session */
return EngineHelper.getProtocols();
}
@ -1286,6 +1318,12 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered setEnabledProtocols()");
if (this.isClosed()) {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"SSLSocket closed, not setting enabled protocols");
return;
}
/* sets protocol versions to be enabled for use with this session */
EngineHelper.setProtocols(protocols);
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
@ -1337,6 +1375,15 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getSession()");
if (this.isClosed()) {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"SSLSocket has been closed, returning invalid session");
/* return invalid session object with cipher suite
* "SSL_NULL_WITH_NULL_NULL" */
return new WolfSSLImplementSSLSession(this.authStore);
}
try {
/* try to do handshake if not completed yet,
* handles synchronization */
@ -1380,7 +1427,7 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getHandshakeSession()");
if (this.handshakeStarted == false) {
if ((this.handshakeStarted == false) || this.isClosed()) {
return null;
}
@ -1587,7 +1634,11 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered setUseClientMode()");
EngineHelper.setUseClientMode(mode);
if (!this.isClosed()) {
EngineHelper.setUseClientMode(mode);
}
this.isClientMode = mode;
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"socket client mode set to: " + mode);
}
@ -1603,7 +1654,7 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getUseClientMode()");
return EngineHelper.getUseClientMode();
return this.isClientMode;
}
/**
@ -1621,7 +1672,9 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered setNeedClientAuth(need: " + String.valueOf(need) + ")");
EngineHelper.setNeedClientAuth(need);
if (!this.isClosed()) {
EngineHelper.setNeedClientAuth(need);
}
}
/**
@ -1636,6 +1689,12 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getNeedClientAuth()");
/* When socket is closed, EngineHelper gets set to null. Since we
* don't cache needClientAuth value, return false after closure. */
if (this.isClosed()) {
return false;
}
return EngineHelper.getNeedClientAuth();
}
@ -1655,7 +1714,9 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered setWantClientAuth(want: " + String.valueOf(want) + ")");
EngineHelper.setWantClientAuth(want);
if (!this.isClosed()) {
EngineHelper.setWantClientAuth(want);
}
}
/**
@ -1674,6 +1735,12 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getWantClientAuth()");
/* When socket is closed, EngineHelper gets set to null. Since we
* don't cache wantClientAuth value, return false after closure. */
if (this.isClosed()) {
return false;
}
return EngineHelper.getWantClientAuth();
}
@ -1692,7 +1759,9 @@ public class WolfSSLSocket extends SSLSocket {
"entered setEnableSessionCreation(flag: " +
String.valueOf(flag) + ")");
EngineHelper.setEnableSessionCreation(flag);
if (!this.isClosed()) {
EngineHelper.setEnableSessionCreation(flag);
}
}
/**
@ -1706,6 +1775,10 @@ public class WolfSSLSocket extends SSLSocket {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"entered getEnableSessionCreation()");
if (this.isClosed()) {
return false;
}
return EngineHelper.getEnableSessionCreation();
}

View File

@ -2746,6 +2746,237 @@ public class WolfSSLSocketTest {
System.out.println("\t... passed");
}
@Test
public void testSocketMethodsAfterClose() throws Exception {
String protocol = null;
System.out.print("\tTesting methods after close");
if (WolfSSL.TLSv12Enabled()) {
protocol = "TLSv1.2";
} else if (WolfSSL.TLSv11Enabled()) {
protocol = "TLSv1.1";
} else if (WolfSSL.TLSv1Enabled()) {
protocol = "TLSv1.0";
} else {
System.out.println("\t... skipped");
return;
}
/* create new CTX */
this.ctx = tf.createSSLContext(protocol, ctxProvider);
/* create SSLServerSocket first to get ephemeral port */
SSLServerSocket ss = (SSLServerSocket)ctx.getServerSocketFactory()
.createServerSocket(0);
SSLSocket cs = (SSLSocket)ctx.getSocketFactory().createSocket();
cs.connect(new InetSocketAddress(ss.getLocalPort()));
final SSLSocket server = (SSLSocket)ss.accept();
ExecutorService es = Executors.newSingleThreadExecutor();
Future<Void> serverFuture = es.submit(new Callable<Void>() {
@Override
public Void call() throws Exception {
try {
server.startHandshake();
} catch (SSLException e) {
System.out.println("\t... failed");
fail();
}
return null;
}
});
try {
cs.startHandshake();
} catch (SSLHandshakeException e) {
System.out.println("\t... failed");
fail();
}
es.shutdown();
serverFuture.get();
cs.close();
server.close();
ss.close();
/* Test calling public SSLSocket methods after close, make sure
* exception or return value is what we expect. */
try {
cs.getApplicationProtocol();
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("getApplicationProtocol() exception after close()");
}
try {
cs.getEnableSessionCreation();
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("getEnableSessionCreation() exception after close()");
}
try {
cs.setEnableSessionCreation(true);
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("setEnableSessionCreation() exception after close()");
}
try {
if (cs.getWantClientAuth() != false) {
System.out.println("\t... failed");
fail("getWantClientAuth() not false after close()");
}
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("getWantClientAuth() exception after close()");
}
try {
cs.setWantClientAuth(true);
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("setWantClientAuth() exception after close()");
}
try {
if (cs.getNeedClientAuth() != false) {
System.out.println("\t... failed");
fail("getNeedClientAuth() not false after close()");
}
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("getNeedClientAuth() exception after close()");
}
try {
cs.setNeedClientAuth(true);
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("setNeedClientAuth() exception after close()");
}
try {
if (cs.getUseClientMode() != true) {
System.out.println("\t... failed");
fail("getUseClientMode() on client not true after close()");
}
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("getUseClientMode() exception after close()");
}
try {
cs.setUseClientMode(true);
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("setUseClientMode() exception after close()");
}
try {
if (cs.getHandshakeSession() != null) {
System.out.println("\t... failed");
fail("getHandshakeSession() not null after close()");
}
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("getHandshakeSession() exception after close()");
}
try {
SSLSession closeSess = cs.getSession();
if (closeSess == null ||
!closeSess.getCipherSuite().equals("SSL_NULL_WITH_NULL_NULL")) {
System.out.println("\t... failed");
fail("getSession() null or wrong cipher suite after close()");
}
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("getSession() exception after close()");
}
try {
if (cs.getEnabledProtocols() != null) {
System.out.println("\t... failed");
fail("getEnabledProtocols() not null after close()");
}
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("getEnabledProtocols() exception after close()");
}
try {
cs.setEnabledProtocols(new String[] {"INVALID"});
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("setEnabledProtocols() exception after close()");
}
try {
cs.setEnabledCipherSuites(new String[] {"INVALID"});
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("setEnabledCipherSuites() exception after close()");
}
try {
String[] suppProtos = cs.getSupportedProtocols();
if (suppProtos == null || suppProtos.length == 0) {
System.out.println("\t... failed");
fail("getSupportedProtocols() null or empty after close()");
}
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("getSupportedProtocols() exception after close()");
}
try {
String[] suppSuites = cs.getSupportedCipherSuites();
if (suppSuites == null || suppSuites.length == 0) {
System.out.println("\t... failed");
fail("getSupportedCipherSuites() null or empty after close()");
}
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("getSupportedCipherSuites() exception after close()");
}
try {
if (cs.getEnabledCipherSuites() != null) {
System.out.println("\t... failed");
fail("getEnabledCipherSuites() not null after close()");
}
} catch (Exception e) {
/* should not throw exception */
System.out.println("\t... failed");
fail("getEnabledCipherSuites() exception after close()");
}
System.out.println("\t... passed");
}
/**
* Inner class used to hold configuration options for
* TestServer and TestClient classes.

View File

@ -1154,11 +1154,13 @@ public class WolfSSLSessionTest {
}
if (!debugOutput.contains("connect() ret: 1")) {
System.out.println("\t... failed");
fail("Debug output did not contain connect() success");
fail("Debug output did not contain connect() success:\n" +
debugOutput);
}
if (!debugOutput.contains("accept() ret: 1")) {
System.out.println("\t... failed");
fail("Debug output did not contain accept() success");
fail("Debug output did not contain accept() success:\n" +
debugOutput);
}
}