JSSE: avoid potential deadlock between SSLSocket.close() and Input/OutputStream.close()

pull/220/head
Chris Conlon 2024-09-12 15:32:31 -06:00
parent a9c28d7377
commit 3f80193da8
1 changed files with 127 additions and 50 deletions

View File

@ -35,6 +35,7 @@ import java.util.ArrayList;
import java.util.function.BiFunction;
import java.util.List;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicBoolean;
import java.nio.channels.SocketChannel;
import java.security.cert.CertificateEncodingException;
@ -96,6 +97,9 @@ public class WolfSSLSocket extends SSLSocket {
* accessing WolfSSLSession object / WOLFSSL struct */
private final Object ioLock = new Object();
/* lock for get/set of SO timeout */
private final Object timeoutLock = new Object();
/** ALPN selector callback, if set */
protected BiFunction<SSLSocket, List<String>, String> alpnSelector = null;
@ -1733,11 +1737,14 @@ public class WolfSSLSocket extends SSLSocket {
* @throws SocketException if there is an error setting the timeout value
*/
@Override
public synchronized void setSoTimeout(int timeout) throws SocketException {
if (this.socket != null) {
this.socket.setSoTimeout(timeout);
} else {
super.setSoTimeout(timeout);
public void setSoTimeout(int timeout) throws SocketException {
/* timeoutLock synchronizes get/set of timeout */
synchronized (timeoutLock) {
if (this.socket != null) {
this.socket.setSoTimeout(timeout);
} else {
super.setSoTimeout(timeout);
}
}
}
@ -1749,11 +1756,14 @@ public class WolfSSLSocket extends SSLSocket {
* @throws SocketException if there is an error getting timeout value
*/
@Override
public synchronized int getSoTimeout() throws SocketException {
if (this.socket != null) {
return this.socket.getSoTimeout();
} else {
return super.getSoTimeout();
public int getSoTimeout() throws SocketException {
/* timeoutLock synchronizes get/set of timeout */
synchronized (timeoutLock) {
if (this.socket != null) {
return this.socket.getSoTimeout();
} else {
return super.getSoTimeout();
}
}
}
@ -1938,13 +1948,17 @@ public class WolfSSLSocket extends SSLSocket {
this.EngineHelper.clearObjectState();
this.EngineHelper = null;
/* Release Input/OutputStream objects */
/* Release Input/OutputStream objects. Do not
* close WolfSSLSocket inside stream close,
* since we handle that next below and do
* differently depending on if autoClose has been
* set or not. */
if (this.inStream != null) {
this.inStream.close();
this.inStream.close(false);
this.inStream = null;
}
if (this.outStream != null) {
this.outStream.close();
this.outStream.close(false);
this.outStream = null;
}
@ -2376,35 +2390,65 @@ public class WolfSSLSocket extends SSLSocket {
private WolfSSLSocket socket;
private boolean isClosed = true;
/* Atomic boolean to indicate if this InputStream has started to
* close. Protects against deadlock between two threads calling
* SSLSocket.close() and InputStream.close() simulatenously. */
private AtomicBoolean isClosing = new AtomicBoolean(false);
public WolfSSLInputStream(WolfSSLSession ssl, WolfSSLSocket socket) {
this.ssl = ssl;
this.socket = socket; /* parent socket */
this.isClosed = false;
}
public synchronized void close() throws IOException {
/**
* Close InputStream, but gives caller option to close underlying
* Socket or not.
*
* @param closeSocket close underlying WolfSSLSocket if set to true,
* otherwise if false leave WolfSSLSocket open.
*/
protected void close(boolean closeSocket) throws IOException {
if (this.socket == null || this.isClosed) {
return;
}
if (isClosing.compareAndSet(false, true)) {
if (this.socket.isClosed()) {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"socket (input) already closed");
synchronized (this) {
if (closeSocket) {
if (this.socket == null || this.isClosed) {
return;
}
if (this.socket.isClosed()) {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"socket (input) already closed");
}
else {
this.socket.close();
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"socket (input) closed: " + this.socket);
}
}
this.socket = null;
this.ssl = null;
this.isClosed = true;
}
}
else {
this.socket.close();
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"socket (input) closed: " + this.socket);
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"InputStream already in process of being closed");
}
this.socket = null;
this.ssl = null;
this.isClosed = true;
return;
}
/**
* Close InputStream, also closes internal WolfSSLSocket.
*/
public void close() throws IOException {
close(true);
}
@Override
public synchronized int read() throws IOException {
@ -2487,11 +2531,12 @@ public class WolfSSLSocket extends SSLSocket {
try {
int err;
int timeout = socket.getSoTimeout();
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"ssl.read() socket timeout = " + socket.getSoTimeout());
"ssl.read() socket timeout = " + timeout);
ret = ssl.read(b, off, len, socket.getSoTimeout());
ret = ssl.read(b, off, len, timeout);
err = ssl.getError(ret);
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
@ -2519,16 +2564,17 @@ public class WolfSSLSocket extends SSLSocket {
* end of stream */
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"Native wolfSSL_read() error: " + errStr +
" (error code: " + err + "), end of stream");
" (error code: " + err + "ret: " + ret +
"), end of stream");
return -1;
} else {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"Native wolfSSL_read() error: " + errStr +
" (error code: " + err + ")");
" (error code: " + err + ", ret: " + ret + ")");
throw new IOException("Native wolfSSL_read() " +
"error: " + errStr +
" (error code: " + err + ")");
" (error code: " + err + ", ret: " + ret + ")");
}
}
@ -2557,35 +2603,65 @@ public class WolfSSLSocket extends SSLSocket {
private WolfSSLSocket socket;
private boolean isClosed = true;
/* Atomic boolean to indicate if this InputStream has started to
* close. Protects against deadlock between two threads calling
* SSLSocket.close() and InputStream.close() simulatenously. */
private AtomicBoolean isClosing = new AtomicBoolean(false);
public WolfSSLOutputStream(WolfSSLSession ssl, WolfSSLSocket socket) {
this.ssl = ssl;
this.socket = socket; /* parent socket */
this.isClosed = false;
}
public synchronized void close() throws IOException {
/**
* Close OutputStream, but gives caller option to close underlying
* Socket or not.
*
* @param closeSocket close underlying WolfSSLSocket if set to true,
* otherwise if false leave WolfSSLSocket open.
*/
protected void close(boolean closeSocket) throws IOException {
if (this.socket == null || this.isClosed) {
return;
}
if (isClosing.compareAndSet(false, true)) {
if (this.socket.isClosed()) {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"socket (output) already closed");
synchronized (this) {
if (closeSocket) {
if (this.socket == null || this.isClosed) {
return;
}
if (this.socket.isClosed()) {
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"socket (output) already closed");
}
else {
this.socket.close();
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"socket (output) closed: " + this.socket);
}
}
this.socket = null;
this.ssl = null;
this.isClosed = true;
}
}
else {
this.socket.close();
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"socket (output) closed: " + this.socket);
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"OutputStream already in process of being closed");
}
this.socket = null;
this.ssl = null;
this.isClosed = true;
return;
}
/**
* Close OutputStream, also closes internal WolfSSLSocket.
*/
public void close() throws IOException {
this.close(true);
}
public synchronized void write(int b) throws IOException {
byte[] data = new byte[1];
data[0] = (byte)(b & 0xFF);
@ -2606,7 +2682,8 @@ public class WolfSSLSocket extends SSLSocket {
throw new NullPointerException("Input array is null");
}
if (this.socket == null || this.isClosed) {
/* check if socket is closed */
if (this.isClosed || socket == null || socket.isClosed()) {
throw new SocketException("Socket is closed");
}
@ -2642,12 +2719,12 @@ public class WolfSSLSocket extends SSLSocket {
try {
int err;
int timeout = socket.getSoTimeout();
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,
"ssl.write() socket timeout = " +
socket.getSoTimeout());
"ssl.write() socket timeout = " + timeout);
ret = ssl.write(b, off, len, socket.getSoTimeout());
ret = ssl.write(b, off, len, timeout);
err = ssl.getError(ret);
WolfSSLDebug.log(getClass(), WolfSSLDebug.INFO,