diff --git a/src/com/owncloud/android/lib/common/network/AdvancedSslSocketFactory.java b/src/com/owncloud/android/lib/common/network/AdvancedSslSocketFactory.java index 8717590d..898b8f43 100644 --- a/src/com/owncloud/android/lib/common/network/AdvancedSslSocketFactory.java +++ b/src/com/owncloud/android/lib/common/network/AdvancedSslSocketFactory.java @@ -50,18 +50,17 @@ import org.apache.http.conn.ssl.X509HostnameVerifier; import com.owncloud.android.lib.common.utils.Log_OC; - /** - * AdvancedSSLProtocolSocketFactory allows to create SSL {@link Socket}s with + * AdvancedSSLProtocolSocketFactory allows to create SSL {@link Socket}s with * a custom SSLContext and an optional Hostname Verifier. - * + * * @author David A. Velasco */ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory { private static final String TAG = AdvancedSslSocketFactory.class.getSimpleName(); - + private SSLContext mSslContext = null; private AdvancedX509TrustManager mTrustManager = null; private X509HostnameVerifier mHostnameVerifier = null; @@ -69,33 +68,33 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory { public SSLContext getSslContext() { return mSslContext; } - + /** * Constructor for AdvancedSSLProtocolSocketFactory. */ public AdvancedSslSocketFactory( - SSLContext sslContext, AdvancedX509TrustManager trustManager, X509HostnameVerifier hostnameVerifier - ) { - + SSLContext sslContext, AdvancedX509TrustManager trustManager, X509HostnameVerifier hostnameVerifier + ) { + if (sslContext == null) throw new IllegalArgumentException("AdvancedSslSocketFactory can not be created with a null SSLContext"); if (trustManager == null && mHostnameVerifier != null) throw new IllegalArgumentException( - "AdvancedSslSocketFactory can not be created with a null Trust Manager and a " + - "not null Hostname Verifier" - ); + "AdvancedSslSocketFactory can not be created with a null Trust Manager and a " + + "not null Hostname Verifier" + ); mSslContext = sslContext; mTrustManager = trustManager; mHostnameVerifier = hostnameVerifier; } /** - * @see ProtocolSocketFactory#createSocket(java.lang.String,int,java.net.InetAddress,int) + * @see ProtocolSocketFactory#createSocket(java.lang.String, int, java.net.InetAddress, int) */ @Override - public Socket createSocket(String host, int port, InetAddress clientHost, int clientPort) - throws IOException, UnknownHostException { - + public Socket createSocket(String host, int port, InetAddress clientHost, int clientPort) + throws IOException, UnknownHostException { + Socket socket = mSslContext.getSocketFactory().createSocket(host, port, clientHost, clientPort); enableSecureProtocols(socket); verifyPeerIdentity(host, port, socket); @@ -142,36 +141,34 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory { } */ - /** + /** * Attempts to get a new socket connection to the given host within the * given time limit. - * - * @param host the host name/IP - * @param port the port on the host - * @param clientHost the local host name/IP to bind the socket to - * @param clientPort the port on the local machine - * @param params {@link HttpConnectionParams Http connection parameters} - * + * + * @param host the host name/IP + * @param port the port on the host + * @param localAddress the local host name/IP to bind the socket to + * @param localPort the port on the local machine + * @param params {@link HttpConnectionParams Http connection parameters} * @return Socket a new socket - * - * @throws IOException if an I/O error occurs while creating the socket + * @throws IOException if an I/O error occurs while creating the socket * @throws UnknownHostException if the IP address of the host cannot be - * determined + * determined */ @Override public Socket createSocket(final String host, final int port, - final InetAddress localAddress, final int localPort, - final HttpConnectionParams params) throws IOException, - UnknownHostException, ConnectTimeoutException { - Log_OC.d(TAG, "Creating SSL Socket with remote " + host + ":" + port + ", local " + localAddress + ":" + + final InetAddress localAddress, final int localPort, + final HttpConnectionParams params) throws IOException, + UnknownHostException, ConnectTimeoutException { + Log_OC.d(TAG, "Creating SSL Socket with remote " + host + ":" + port + ", local " + localAddress + ":" + localPort + ", params: " + params); if (params == null) { throw new IllegalArgumentException("Parameters may not be null"); - } + } int timeout = params.getConnectionTimeout(); - + //logSslInfo(); - + SocketFactory socketfactory = mSslContext.getSocketFactory(); Log_OC.d(TAG, " ... with connection timeout " + timeout + " and socket timeout " + params.getSoTimeout()); Socket socket = socketfactory.createSocket(); @@ -179,40 +176,41 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory { SocketAddress localaddr = new InetSocketAddress(localAddress, localPort); SocketAddress remoteaddr = new InetSocketAddress(host, port); socket.setSoTimeout(params.getSoTimeout()); + WriteTimeoutEnforcer.setSoWriteTimeout(params.getSoTimeout(), socket); socket.bind(localaddr); - ServerNameIndicator.setServerNameIndication(host, (SSLSocket)socket); + ServerNameIndicator.setServerNameIndication(host, (SSLSocket) socket); socket.connect(remoteaddr, timeout); verifyPeerIdentity(host, port, socket); return socket; } - /** - * @see ProtocolSocketFactory#createSocket(java.lang.String,int) + /** + * @see ProtocolSocketFactory#createSocket(java.lang.String, int) */ @Override public Socket createSocket(String host, int port) throws IOException, - UnknownHostException { - Log_OC.d(TAG, "Creating SSL Socket with remote " + host + ":" + port); + UnknownHostException { + Log_OC.d(TAG, "Creating SSL Socket with remote " + host + ":" + port); Socket socket = mSslContext.getSocketFactory().createSocket(host, port); enableSecureProtocols(socket); verifyPeerIdentity(host, port, socket); - return socket; + return socket; + } + + + @Override + public Socket createSocket(Socket socket, String host, int port, boolean autoClose) throws IOException, + UnknownHostException { + Socket sslSocket = mSslContext.getSocketFactory().createSocket(socket, host, port, autoClose); + enableSecureProtocols(sslSocket); + verifyPeerIdentity(host, port, sslSocket); + return sslSocket; } - - @Override - public Socket createSocket(Socket socket, String host, int port, boolean autoClose) throws IOException, - UnknownHostException { - Socket sslSocket = mSslContext.getSocketFactory().createSocket(socket, host, port, autoClose); - enableSecureProtocols(sslSocket); - verifyPeerIdentity(host, port, sslSocket); - return sslSocket; - } - public boolean equals(Object obj) { return ((obj != null) && obj.getClass().equals( - AdvancedSslSocketFactory.class)); + AdvancedSslSocketFactory.class)); } public int hashCode() { @@ -223,19 +221,20 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory { public X509HostnameVerifier getHostNameVerifier() { return mHostnameVerifier; } - - + + public void setHostNameVerifier(X509HostnameVerifier hostnameVerifier) { mHostnameVerifier = hostnameVerifier; } - + /** - * Verifies the identity of the server. - * + * Verifies the identity of the server. + *

* The server certificate is verified first. - * + *

* Then, the host name is compared with the content of the server certificate using the current host name verifier, - * if any. + * if any. + * * @param socket */ private void verifyPeerIdentity(String host, int port, Socket socket) throws IOException { @@ -246,31 +245,31 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory { try { SSLSocket sock = (SSLSocket) socket; // a new SSLSession instance is created as a "side effect" sock.startHandshake(); - + } catch (RuntimeException e) { - + if (e instanceof CertificateCombinedException) { failInHandshake = (CertificateCombinedException) e; } else { Throwable cause = e.getCause(); Throwable previousCause = null; - while ( cause != null && - cause != previousCause && - !(cause instanceof CertificateCombinedException)) { + while (cause != null && + cause != previousCause && + !(cause instanceof CertificateCombinedException)) { previousCause = cause; cause = cause.getCause(); } if (cause != null && cause instanceof CertificateCombinedException) { - failInHandshake = (CertificateCombinedException)cause; + failInHandshake = (CertificateCombinedException) cause; } } if (failInHandshake == null) { throw e; } failInHandshake.setHostInUrl(host); - + } - + /// 2. VERIFY HOSTNAME SSLSession newSession = null; boolean verifiedHostname = true; @@ -283,12 +282,12 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory { } catch (SSLException e) { verifiedHostname = false; } - + } else { /// 2.2 : a new SSLSession instance was created in the handshake - newSession = ((SSLSocket)socket).getSession(); - if (!mTrustManager.isKnownServer((X509Certificate)(newSession.getPeerCertificates()[0]))) { - verifiedHostname = mHostnameVerifier.verify(host, newSession); + newSession = ((SSLSocket) socket).getSession(); + if (!mTrustManager.isKnownServer((X509Certificate) (newSession.getPeerCertificates()[0]))) { + verifiedHostname = mHostnameVerifier.verify(host, newSession); } } } @@ -296,25 +295,25 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory { /// 3. Combine the exceptions to throw, if any if (!verifiedHostname) { SSLPeerUnverifiedException pue = new SSLPeerUnverifiedException( - "Names in the server certificate do not match to " + host + " in the URL" - ); + "Names in the server certificate do not match to " + host + " in the URL" + ); if (failInHandshake == null) { failInHandshake = new CertificateCombinedException( - (X509Certificate) newSession.getPeerCertificates()[0] - ); + (X509Certificate) newSession.getPeerCertificates()[0] + ); failInHandshake.setHostInUrl(host); } failInHandshake.setSslPeerUnverifiedException(pue); pue.initCause(failInHandshake); throw pue; - + } else if (failInHandshake != null) { SSLHandshakeException hse = new SSLHandshakeException("Server certificate could not be verified"); hse.initCause(failInHandshake); throw hse; } - - } catch (IOException io) { + + } catch (IOException io) { try { socket.close(); } catch (Exception x) { @@ -324,22 +323,22 @@ public class AdvancedSslSocketFactory implements SecureProtocolSocketFactory { } } - /** - * Grants that all protocols supported by the Security Provider in mSslContext are enabled in socket. - * - * Grants also that no unsupported protocol is tried to be enabled. That would trigger an exception, breaking - * the connection process although some protocols are supported. - * - * This is not cosmetic: not all the supported protocols are enabled by default. Too see an overview of - * supported and enabled protocols in the stock Security Provider in Android see the tables in - * http://developer.android.com/reference/javax/net/ssl/SSLSocket.html. - * - * @param socket - */ + /** + * Grants that all protocols supported by the Security Provider in mSslContext are enabled in socket. + *

+ * Grants also that no unsupported protocol is tried to be enabled. That would trigger an exception, breaking + * the connection process although some protocols are supported. + *

+ * This is not cosmetic: not all the supported protocols are enabled by default. Too see an overview of + * supported and enabled protocols in the stock Security Provider in Android see the tables in + * http://developer.android.com/reference/javax/net/ssl/SSLSocket.html. + * + * @param socket + */ private void enableSecureProtocols(Socket socket) { - SSLParameters params = mSslContext.getSupportedSSLParameters(); - String [] supportedProtocols = params.getProtocols(); - ((SSLSocket) socket).setEnabledProtocols(supportedProtocols); + SSLParameters params = mSslContext.getSupportedSSLParameters(); + String[] supportedProtocols = params.getProtocols(); + ((SSLSocket) socket).setEnabledProtocols(supportedProtocols); } - + } diff --git a/src/com/owncloud/android/lib/common/network/WriteTimeoutEnforcer.java b/src/com/owncloud/android/lib/common/network/WriteTimeoutEnforcer.java new file mode 100644 index 00000000..665a07e3 --- /dev/null +++ b/src/com/owncloud/android/lib/common/network/WriteTimeoutEnforcer.java @@ -0,0 +1,173 @@ +/* ownCloud Android Library is available under MIT license + * Copyright (C) 2017 ownCloud GmbH. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ + +package com.owncloud.android.lib.common.network; + +import com.owncloud.android.lib.common.utils.Log_OC; + +import java.lang.ref.WeakReference; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.net.Socket; +import java.util.concurrent.atomic.AtomicReference; + + +/** + * Enforces, if possible, a write timeout for a socket. + * + * Built as a singleton. + * + * Tries to hit something like this: + * https://android.googlesource.com/platform/external/conscrypt/+/lollipop-release/src/main/java/org/conscrypt/OpenSSLSocketImpl.java#1005 + * + * Minimizes the chances of getting stalled in PUT/POST request if the network interface is lost while + * writing the entity into the outwards sockect. + * + * It happens. See https://github.com/owncloud/android/issues/1684#issuecomment-295306015 + * + * @author David A. Velasco + */ +public class WriteTimeoutEnforcer { + + private static final String TAG = WriteTimeoutEnforcer.class.getSimpleName(); + + private static final AtomicReference mSingleInstance = new AtomicReference<>(); + + private static final String METHOD_NAME = "setSoWriteTimeout"; + + + private final WeakReference> mSocketClassRef; + private final WeakReference mSetSoWriteTimeoutMethodRef; + + + /** + * Private constructor, class is a singleton. + * + * @param socketClass Underlying implementation class of {@link Socket} used to connect + * with the server. + * @param setSoWriteTimeoutMethod Name of the method to call to set a write timeout in the socket. + */ + private WriteTimeoutEnforcer(Class socketClass, Method setSoWriteTimeoutMethod) { + mSocketClassRef = new WeakReference>(socketClass); + mSetSoWriteTimeoutMethodRef = + (setSoWriteTimeoutMethod == null) ? + null : + new WeakReference<>(setSoWriteTimeoutMethod) + ; + } + + + /** + * Calls the {@code #setSoWrite(int)} method of the underlying implementation + * of {@link Socket} if exists. + + * Creates and initializes the single instance of the class when needed + * + * @param writeTimeoutMilliseconds Write timeout to set, in milliseconds. + * @param socket Client socket to connect with the server. + */ + public static void setSoWriteTimeout(int writeTimeoutMilliseconds, Socket socket) { + final Method setSoWriteTimeoutMethod = getMethod(socket); + if (setSoWriteTimeoutMethod != null) { + try { + setSoWriteTimeoutMethod.invoke(socket, writeTimeoutMilliseconds); + Log_OC.i( + TAG, + "Write timeout set in socket, writeTimeoutMilliseconds: " + + writeTimeoutMilliseconds + ); + + } catch (IllegalArgumentException e) { + Log_OC.e(TAG, "Call to (SocketImpl)#setSoWriteTimeout(int) failed ", e); + + } catch (IllegalAccessException e) { + Log_OC.e(TAG, "Call to (SocketImpl)#setSoWriteTimeout(int) failed ", e); + + } catch (InvocationTargetException e) { + Log_OC.e(TAG, "Call to (SocketImpl)#setSoWriteTimeout(int) failed ", e); + } + } else { + Log_OC.i(TAG, "Write timeout for socket not supported"); + } + } + + + /** + * Gets the method to invoke trying to minimize the cost of reflection reusing objects cached + * in static members. + * + * @param socket Instance of the socket to use in connection with server. + * @return Method to call to set a write timeout in the socket. + */ + private static Method getMethod(Socket socket) { + final Class socketClass = socket.getClass(); + final WriteTimeoutEnforcer instance = mSingleInstance.get(); + if (instance == null) { + return initFrom(socketClass); + + } else if (instance.mSocketClassRef.get() != socketClass) { + // the underlying class changed + return initFrom(socketClass); + + } else if (instance.mSetSoWriteTimeoutMethodRef == null) { + // method not supported + return null; + + } else { + final Method cachedSetSoWriteTimeoutMethod = instance.mSetSoWriteTimeoutMethodRef.get(); + return (cachedSetSoWriteTimeoutMethod == null) ? + initFrom(socketClass) : + cachedSetSoWriteTimeoutMethod + ; + } + } + + + /** + * Singleton initializer. + * + * Uses reflection to extract and 'cache' the method to invoke to set a write timouet in a socket. + * + * @param socketClass Underlying class providing the implementation of {@link Socket}. + * @return Method to call to set a write timeout in the socket. + */ + private static Method initFrom(Class socketClass) { + Log_OC.i(TAG, "Socket implementation: " + socketClass.getCanonicalName()); + Method setSoWriteTimeoutMethod = null; + try { + setSoWriteTimeoutMethod = socketClass.getMethod(METHOD_NAME, int.class); + } catch (SecurityException e) { + Log_OC.e(TAG, "Could not access to (SocketImpl)#setSoWriteTimeout(int) method ", e); + + } catch (NoSuchMethodException e) { + Log_OC.i( + TAG, + "Could not find (SocketImpl)#setSoWriteTimeout(int) method - write timeout not supported" + ); + } + mSingleInstance.set(new WriteTimeoutEnforcer(socketClass, setSoWriteTimeoutMethod)); + return setSoWriteTimeoutMethod; + } + +}