1
0
mirror of https://github.com/owncloud/android-library.git synced 2025-06-08 00:16:09 +00:00

Add write timeout if socket implementation supports it, so that chances of blocking and upload if network is removed are minimized

This commit is contained in:
David A. Velasco 2017-04-21 16:10:16 +02:00
parent 02e3a90df3
commit 9fe7c995dd
2 changed files with 264 additions and 92 deletions

View File

@ -50,18 +50,17 @@ import org.apache.http.conn.ssl.X509HostnameVerifier;
import com.owncloud.android.lib.common.utils.Log_OC; import com.owncloud.android.lib.common.utils.Log_OC;
/** /**
* AdvancedSSLProtocolSocketFactory allows to create SSL {@link Socket}s with * AdvancedSSLProtocolSocketFactory allows to create SSL {@link Socket}s with
* a custom SSLContext and an optional Hostname Verifier. * a custom SSLContext and an optional Hostname Verifier.
* *
* @author David A. Velasco * @author David A. Velasco
*/ */
public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory { public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory {
private static final String TAG = AdvancedSslSocketFactory.class.getSimpleName(); private static final String TAG = AdvancedSslSocketFactory.class.getSimpleName();
private SSLContext mSslContext = null; private SSLContext mSslContext = null;
private AdvancedX509TrustManager mTrustManager = null; private AdvancedX509TrustManager mTrustManager = null;
private X509HostnameVerifier mHostnameVerifier = null; private X509HostnameVerifier mHostnameVerifier = null;
@ -69,33 +68,33 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory {
public SSLContext getSslContext() { public SSLContext getSslContext() {
return mSslContext; return mSslContext;
} }
/** /**
* Constructor for AdvancedSSLProtocolSocketFactory. * Constructor for AdvancedSSLProtocolSocketFactory.
*/ */
public AdvancedSslSocketFactory( public AdvancedSslSocketFactory(
SSLContext sslContext, AdvancedX509TrustManager trustManager, X509HostnameVerifier hostnameVerifier SSLContext sslContext, AdvancedX509TrustManager trustManager, X509HostnameVerifier hostnameVerifier
) { ) {
if (sslContext == null) if (sslContext == null)
throw new IllegalArgumentException("AdvancedSslSocketFactory can not be created with a null SSLContext"); throw new IllegalArgumentException("AdvancedSslSocketFactory can not be created with a null SSLContext");
if (trustManager == null && mHostnameVerifier != null) if (trustManager == null && mHostnameVerifier != null)
throw new IllegalArgumentException( throw new IllegalArgumentException(
"AdvancedSslSocketFactory can not be created with a null Trust Manager and a " + "AdvancedSslSocketFactory can not be created with a null Trust Manager and a " +
"not null Hostname Verifier" "not null Hostname Verifier"
); );
mSslContext = sslContext; mSslContext = sslContext;
mTrustManager = trustManager; mTrustManager = trustManager;
mHostnameVerifier = hostnameVerifier; mHostnameVerifier = hostnameVerifier;
} }
/** /**
* @see ProtocolSocketFactory#createSocket(java.lang.String,int,java.net.InetAddress,int) * @see ProtocolSocketFactory#createSocket(java.lang.String, int, java.net.InetAddress, int)
*/ */
@Override @Override
public Socket createSocket(String host, int port, InetAddress clientHost, int clientPort) public Socket createSocket(String host, int port, InetAddress clientHost, int clientPort)
throws IOException, UnknownHostException { throws IOException, UnknownHostException {
Socket socket = mSslContext.getSocketFactory().createSocket(host, port, clientHost, clientPort); Socket socket = mSslContext.getSocketFactory().createSocket(host, port, clientHost, clientPort);
enableSecureProtocols(socket); enableSecureProtocols(socket);
verifyPeerIdentity(host, port, socket); verifyPeerIdentity(host, port, socket);
@ -142,36 +141,34 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory {
} }
*/ */
/** /**
* Attempts to get a new socket connection to the given host within the * Attempts to get a new socket connection to the given host within the
* given time limit. * given time limit.
* *
* @param host the host name/IP * @param host the host name/IP
* @param port the port on the host * @param port the port on the host
* @param clientHost the local host name/IP to bind the socket to * @param localAddress the local host name/IP to bind the socket to
* @param clientPort the port on the local machine * @param localPort the port on the local machine
* @param params {@link HttpConnectionParams Http connection parameters} * @param params {@link HttpConnectionParams Http connection parameters}
*
* @return Socket a new socket * @return Socket a new socket
* * @throws IOException if an I/O error occurs while creating the socket
* @throws IOException if an I/O error occurs while creating the socket
* @throws UnknownHostException if the IP address of the host cannot be * @throws UnknownHostException if the IP address of the host cannot be
* determined * determined
*/ */
@Override @Override
public Socket createSocket(final String host, final int port, public Socket createSocket(final String host, final int port,
final InetAddress localAddress, final int localPort, final InetAddress localAddress, final int localPort,
final HttpConnectionParams params) throws IOException, final HttpConnectionParams params) throws IOException,
UnknownHostException, ConnectTimeoutException { UnknownHostException, ConnectTimeoutException {
Log_OC.d(TAG, "Creating SSL Socket with remote " + host + ":" + port + ", local " + localAddress + ":" + Log_OC.d(TAG, "Creating SSL Socket with remote " + host + ":" + port + ", local " + localAddress + ":" +
localPort + ", params: " + params); localPort + ", params: " + params);
if (params == null) { if (params == null) {
throw new IllegalArgumentException("Parameters may not be null"); throw new IllegalArgumentException("Parameters may not be null");
} }
int timeout = params.getConnectionTimeout(); int timeout = params.getConnectionTimeout();
//logSslInfo(); //logSslInfo();
SocketFactory socketfactory = mSslContext.getSocketFactory(); SocketFactory socketfactory = mSslContext.getSocketFactory();
Log_OC.d(TAG, " ... with connection timeout " + timeout + " and socket timeout " + params.getSoTimeout()); Log_OC.d(TAG, " ... with connection timeout " + timeout + " and socket timeout " + params.getSoTimeout());
Socket socket = socketfactory.createSocket(); Socket socket = socketfactory.createSocket();
@ -179,40 +176,41 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory {
SocketAddress localaddr = new InetSocketAddress(localAddress, localPort); SocketAddress localaddr = new InetSocketAddress(localAddress, localPort);
SocketAddress remoteaddr = new InetSocketAddress(host, port); SocketAddress remoteaddr = new InetSocketAddress(host, port);
socket.setSoTimeout(params.getSoTimeout()); socket.setSoTimeout(params.getSoTimeout());
WriteTimeoutEnforcer.setSoWriteTimeout(params.getSoTimeout(), socket);
socket.bind(localaddr); socket.bind(localaddr);
ServerNameIndicator.setServerNameIndication(host, (SSLSocket)socket); ServerNameIndicator.setServerNameIndication(host, (SSLSocket) socket);
socket.connect(remoteaddr, timeout); socket.connect(remoteaddr, timeout);
verifyPeerIdentity(host, port, socket); verifyPeerIdentity(host, port, socket);
return socket; return socket;
} }
/** /**
* @see ProtocolSocketFactory#createSocket(java.lang.String,int) * @see ProtocolSocketFactory#createSocket(java.lang.String, int)
*/ */
@Override @Override
public Socket createSocket(String host, int port) throws IOException, public Socket createSocket(String host, int port) throws IOException,
UnknownHostException { UnknownHostException {
Log_OC.d(TAG, "Creating SSL Socket with remote " + host + ":" + port); Log_OC.d(TAG, "Creating SSL Socket with remote " + host + ":" + port);
Socket socket = mSslContext.getSocketFactory().createSocket(host, port); Socket socket = mSslContext.getSocketFactory().createSocket(host, port);
enableSecureProtocols(socket); enableSecureProtocols(socket);
verifyPeerIdentity(host, port, socket); verifyPeerIdentity(host, port, socket);
return socket; return socket;
}
@Override
public Socket createSocket(Socket socket, String host, int port, boolean autoClose) throws IOException,
UnknownHostException {
Socket sslSocket = mSslContext.getSocketFactory().createSocket(socket, host, port, autoClose);
enableSecureProtocols(sslSocket);
verifyPeerIdentity(host, port, sslSocket);
return sslSocket;
} }
@Override
public Socket createSocket(Socket socket, String host, int port, boolean autoClose) throws IOException,
UnknownHostException {
Socket sslSocket = mSslContext.getSocketFactory().createSocket(socket, host, port, autoClose);
enableSecureProtocols(sslSocket);
verifyPeerIdentity(host, port, sslSocket);
return sslSocket;
}
public boolean equals(Object obj) { public boolean equals(Object obj) {
return ((obj != null) && obj.getClass().equals( return ((obj != null) && obj.getClass().equals(
AdvancedSslSocketFactory.class)); AdvancedSslSocketFactory.class));
} }
public int hashCode() { public int hashCode() {
@ -223,19 +221,20 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory {
public X509HostnameVerifier getHostNameVerifier() { public X509HostnameVerifier getHostNameVerifier() {
return mHostnameVerifier; return mHostnameVerifier;
} }
public void setHostNameVerifier(X509HostnameVerifier hostnameVerifier) { public void setHostNameVerifier(X509HostnameVerifier hostnameVerifier) {
mHostnameVerifier = hostnameVerifier; mHostnameVerifier = hostnameVerifier;
} }
/** /**
* Verifies the identity of the server. * Verifies the identity of the server.
* * <p>
* The server certificate is verified first. * The server certificate is verified first.
* * <p>
* Then, the host name is compared with the content of the server certificate using the current host name verifier, * Then, the host name is compared with the content of the server certificate using the current host name verifier,
* if any. * if any.
*
* @param socket * @param socket
*/ */
private void verifyPeerIdentity(String host, int port, Socket socket) throws IOException { private void verifyPeerIdentity(String host, int port, Socket socket) throws IOException {
@ -246,31 +245,31 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory {
try { try {
SSLSocket sock = (SSLSocket) socket; // a new SSLSession instance is created as a "side effect" SSLSocket sock = (SSLSocket) socket; // a new SSLSession instance is created as a "side effect"
sock.startHandshake(); sock.startHandshake();
} catch (RuntimeException e) { } catch (RuntimeException e) {
if (e instanceof CertificateCombinedException) { if (e instanceof CertificateCombinedException) {
failInHandshake = (CertificateCombinedException) e; failInHandshake = (CertificateCombinedException) e;
} else { } else {
Throwable cause = e.getCause(); Throwable cause = e.getCause();
Throwable previousCause = null; Throwable previousCause = null;
while ( cause != null && while (cause != null &&
cause != previousCause && cause != previousCause &&
!(cause instanceof CertificateCombinedException)) { !(cause instanceof CertificateCombinedException)) {
previousCause = cause; previousCause = cause;
cause = cause.getCause(); cause = cause.getCause();
} }
if (cause != null && cause instanceof CertificateCombinedException) { if (cause != null && cause instanceof CertificateCombinedException) {
failInHandshake = (CertificateCombinedException)cause; failInHandshake = (CertificateCombinedException) cause;
} }
} }
if (failInHandshake == null) { if (failInHandshake == null) {
throw e; throw e;
} }
failInHandshake.setHostInUrl(host); failInHandshake.setHostInUrl(host);
} }
/// 2. VERIFY HOSTNAME /// 2. VERIFY HOSTNAME
SSLSession newSession = null; SSLSession newSession = null;
boolean verifiedHostname = true; boolean verifiedHostname = true;
@ -283,12 +282,12 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory {
} catch (SSLException e) { } catch (SSLException e) {
verifiedHostname = false; verifiedHostname = false;
} }
} else { } else {
/// 2.2 : a new SSLSession instance was created in the handshake /// 2.2 : a new SSLSession instance was created in the handshake
newSession = ((SSLSocket)socket).getSession(); newSession = ((SSLSocket) socket).getSession();
if (!mTrustManager.isKnownServer((X509Certificate)(newSession.getPeerCertificates()[0]))) { if (!mTrustManager.isKnownServer((X509Certificate) (newSession.getPeerCertificates()[0]))) {
verifiedHostname = mHostnameVerifier.verify(host, newSession); verifiedHostname = mHostnameVerifier.verify(host, newSession);
} }
} }
} }
@ -296,25 +295,25 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory {
/// 3. Combine the exceptions to throw, if any /// 3. Combine the exceptions to throw, if any
if (!verifiedHostname) { if (!verifiedHostname) {
SSLPeerUnverifiedException pue = new SSLPeerUnverifiedException( SSLPeerUnverifiedException pue = new SSLPeerUnverifiedException(
"Names in the server certificate do not match to " + host + " in the URL" "Names in the server certificate do not match to " + host + " in the URL"
); );
if (failInHandshake == null) { if (failInHandshake == null) {
failInHandshake = new CertificateCombinedException( failInHandshake = new CertificateCombinedException(
(X509Certificate) newSession.getPeerCertificates()[0] (X509Certificate) newSession.getPeerCertificates()[0]
); );
failInHandshake.setHostInUrl(host); failInHandshake.setHostInUrl(host);
} }
failInHandshake.setSslPeerUnverifiedException(pue); failInHandshake.setSslPeerUnverifiedException(pue);
pue.initCause(failInHandshake); pue.initCause(failInHandshake);
throw pue; throw pue;
} else if (failInHandshake != null) { } else if (failInHandshake != null) {
SSLHandshakeException hse = new SSLHandshakeException("Server certificate could not be verified"); SSLHandshakeException hse = new SSLHandshakeException("Server certificate could not be verified");
hse.initCause(failInHandshake); hse.initCause(failInHandshake);
throw hse; throw hse;
} }
} catch (IOException io) { } catch (IOException io) {
try { try {
socket.close(); socket.close();
} catch (Exception x) { } catch (Exception x) {
@ -324,22 +323,22 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory {
} }
} }
/** /**
* Grants that all protocols supported by the Security Provider in mSslContext are enabled in socket. * Grants that all protocols supported by the Security Provider in mSslContext are enabled in socket.
* * <p>
* Grants also that no unsupported protocol is tried to be enabled. That would trigger an exception, breaking * Grants also that no unsupported protocol is tried to be enabled. That would trigger an exception, breaking
* the connection process although some protocols are supported. * the connection process although some protocols are supported.
* * <p>
* This is not cosmetic: not all the supported protocols are enabled by default. Too see an overview of * This is not cosmetic: not all the supported protocols are enabled by default. Too see an overview of
* supported and enabled protocols in the stock Security Provider in Android see the tables in * supported and enabled protocols in the stock Security Provider in Android see the tables in
* http://developer.android.com/reference/javax/net/ssl/SSLSocket.html. * http://developer.android.com/reference/javax/net/ssl/SSLSocket.html.
* *
* @param socket * @param socket
*/ */
private void enableSecureProtocols(Socket socket) { private void enableSecureProtocols(Socket socket) {
SSLParameters params = mSslContext.getSupportedSSLParameters(); SSLParameters params = mSslContext.getSupportedSSLParameters();
String [] supportedProtocols = params.getProtocols(); String[] supportedProtocols = params.getProtocols();
((SSLSocket) socket).setEnabledProtocols(supportedProtocols); ((SSLSocket) socket).setEnabledProtocols(supportedProtocols);
} }
} }

View File

@ -0,0 +1,173 @@
/* ownCloud Android Library is available under MIT license
* Copyright (C) 2017 ownCloud GmbH.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
* BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
* ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
* CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*
*/
package com.owncloud.android.lib.common.network;
import com.owncloud.android.lib.common.utils.Log_OC;
import java.lang.ref.WeakReference;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.Socket;
import java.util.concurrent.atomic.AtomicReference;
/**
* Enforces, if possible, a write timeout for a socket.
*
* Built as a singleton.
*
* Tries to hit something like this:
* https://android.googlesource.com/platform/external/conscrypt/+/lollipop-release/src/main/java/org/conscrypt/OpenSSLSocketImpl.java#1005
*
* Minimizes the chances of getting stalled in PUT/POST request if the network interface is lost while
* writing the entity into the outwards sockect.
*
* It happens. See https://github.com/owncloud/android/issues/1684#issuecomment-295306015
*
* @author David A. Velasco
*/
public class WriteTimeoutEnforcer {
private static final String TAG = WriteTimeoutEnforcer.class.getSimpleName();
private static final AtomicReference<WriteTimeoutEnforcer> mSingleInstance = new AtomicReference<>();
private static final String METHOD_NAME = "setSoWriteTimeout";
private final WeakReference<Class<?>> mSocketClassRef;
private final WeakReference<Method> mSetSoWriteTimeoutMethodRef;
/**
* Private constructor, class is a singleton.
*
* @param socketClass Underlying implementation class of {@link Socket} used to connect
* with the server.
* @param setSoWriteTimeoutMethod Name of the method to call to set a write timeout in the socket.
*/
private WriteTimeoutEnforcer(Class<?> socketClass, Method setSoWriteTimeoutMethod) {
mSocketClassRef = new WeakReference<Class<?>>(socketClass);
mSetSoWriteTimeoutMethodRef =
(setSoWriteTimeoutMethod == null) ?
null :
new WeakReference<>(setSoWriteTimeoutMethod)
;
}
/**
* Calls the {@code #setSoWrite(int)} method of the underlying implementation
* of {@link Socket} if exists.
* Creates and initializes the single instance of the class when needed
*
* @param writeTimeoutMilliseconds Write timeout to set, in milliseconds.
* @param socket Client socket to connect with the server.
*/
public static void setSoWriteTimeout(int writeTimeoutMilliseconds, Socket socket) {
final Method setSoWriteTimeoutMethod = getMethod(socket);
if (setSoWriteTimeoutMethod != null) {
try {
setSoWriteTimeoutMethod.invoke(socket, writeTimeoutMilliseconds);
Log_OC.i(
TAG,
"Write timeout set in socket, writeTimeoutMilliseconds: "
+ writeTimeoutMilliseconds
);
} catch (IllegalArgumentException e) {
Log_OC.e(TAG, "Call to (SocketImpl)#setSoWriteTimeout(int) failed ", e);
} catch (IllegalAccessException e) {
Log_OC.e(TAG, "Call to (SocketImpl)#setSoWriteTimeout(int) failed ", e);
} catch (InvocationTargetException e) {
Log_OC.e(TAG, "Call to (SocketImpl)#setSoWriteTimeout(int) failed ", e);
}
} else {
Log_OC.i(TAG, "Write timeout for socket not supported");
}
}
/**
* Gets the method to invoke trying to minimize the cost of reflection reusing objects cached
* in static members.
*
* @param socket Instance of the socket to use in connection with server.
* @return Method to call to set a write timeout in the socket.
*/
private static Method getMethod(Socket socket) {
final Class<?> socketClass = socket.getClass();
final WriteTimeoutEnforcer instance = mSingleInstance.get();
if (instance == null) {
return initFrom(socketClass);
} else if (instance.mSocketClassRef.get() != socketClass) {
// the underlying class changed
return initFrom(socketClass);
} else if (instance.mSetSoWriteTimeoutMethodRef == null) {
// method not supported
return null;
} else {
final Method cachedSetSoWriteTimeoutMethod = instance.mSetSoWriteTimeoutMethodRef.get();
return (cachedSetSoWriteTimeoutMethod == null) ?
initFrom(socketClass) :
cachedSetSoWriteTimeoutMethod
;
}
}
/**
* Singleton initializer.
*
* Uses reflection to extract and 'cache' the method to invoke to set a write timouet in a socket.
*
* @param socketClass Underlying class providing the implementation of {@link Socket}.
* @return Method to call to set a write timeout in the socket.
*/
private static Method initFrom(Class<?> socketClass) {
Log_OC.i(TAG, "Socket implementation: " + socketClass.getCanonicalName());
Method setSoWriteTimeoutMethod = null;
try {
setSoWriteTimeoutMethod = socketClass.getMethod(METHOD_NAME, int.class);
} catch (SecurityException e) {
Log_OC.e(TAG, "Could not access to (SocketImpl)#setSoWriteTimeout(int) method ", e);
} catch (NoSuchMethodException e) {
Log_OC.i(
TAG,
"Could not find (SocketImpl)#setSoWriteTimeout(int) method - write timeout not supported"
);
}
mSingleInstance.set(new WriteTimeoutEnforcer(socketClass, setSoWriteTimeoutMethod));
return setSoWriteTimeoutMethod;
}
}