using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using CryptoExchange.Net.Authentication;
using CryptoExchange.Net.Interfaces;
using CryptoExchange.Net.Logging;
using CryptoExchange.Net.Objects;
using CryptoExchange.Net.Sockets;
using Newtonsoft.Json.Linq;

namespace CryptoExchange.Net
{
    /// <summary>
    /// Base for socket client implementations
    /// </summary>
    public abstract class SocketClient: BaseClient, ISocketClient
    {
        #region fields
        /// <summary>
        /// The factory for creating sockets. Used for unit testing
        /// </summary>
        public IWebsocketFactory SocketFactory { get; set; } = new WebsocketFactory();

        /// <summary>
        /// Socket connections whose Connect() succeeded, keyed by socket id (entries are added in <see cref="ConnectSocket"/>;
        /// removal is not done in this class — presumably by the connection itself on close, verify)
        /// </summary>
        protected internal ConcurrentDictionary<int, SocketConnection> sockets = new ConcurrentDictionary<int, SocketConnection>();
        /// <summary>
        /// Serializes picking/creating a connection in Subscribe and Query; when <see cref="SocketCombineTarget"/> is not 1
        /// it is also held while that connection is connected
        /// </summary>
        protected internal readonly SemaphoreSlim semaphoreSlim = new SemaphoreSlim(1);

        /// <inheritdoc />
        public TimeSpan ReconnectInterval { get; }
        /// <inheritdoc />
        public bool AutoReconnect { get; }
        /// <inheritdoc />
        public TimeSpan ResponseTimeout { get; }
        /// <inheritdoc />
        public TimeSpan SocketNoDataTimeout { get; }
        /// <summary>
        /// The max amount of concurrent socket connections
        /// </summary>
        public int MaxSocketConnections { get; protected set; } = 9999;
        /// <inheritdoc />
        public int SocketCombineTarget { get; protected set; }

        /// <summary>
        /// Handler for byte data
        /// </summary>
        protected Func<byte[], string>? dataInterpreterBytes;
        /// <summary>
        /// Handler for string data
        /// </summary>
        protected Func<string, string>? dataInterpreterString;
        /// <summary>
        /// Generic handlers, keyed by identifier; attached to every connection created by <see cref="GetWebsocket"/>
        /// </summary>
        protected Dictionary<string, Action<SocketConnection, JToken>> genericHandlers = new Dictionary<string, Action<SocketConnection, JToken>>();
        /// <summary>
        /// Periodic task
        /// </summary>
        protected Task? periodicTask;
        /// <summary>
        /// Periodic task event; signalled by Dispose to wake the periodic loop early
        /// </summary>
        protected AutoResetEvent? periodicEvent;
        /// <summary>
        /// Is disposing
        /// </summary>
        protected bool disposing;

        /// <summary>
        /// If true; data which is a response to a query will also be distributed to subscriptions
        /// If false; data which is a response to a query won't get forwarded to subscriptions as well
        /// </summary>
        protected internal bool ContinueOnQueryResponse { get; protected set; }
        #endregion

        /// <summary>
        /// Create a socket client
        /// </summary>
        /// <param name="exchangeOptions">Client options</param>
        /// <param name="authenticationProvider">Authentication provider</param>
        /// <exception cref="ArgumentNullException"><paramref name="exchangeOptions"/> is null</exception>
        protected SocketClient(SocketClientOptions exchangeOptions, AuthenticationProvider? authenticationProvider)
            // Guard inside the base call so the base constructor never receives a null options object
            : base(exchangeOptions ?? throw new ArgumentNullException(nameof(exchangeOptions)), authenticationProvider)
        {
            AutoReconnect = exchangeOptions.AutoReconnect;
            ReconnectInterval = exchangeOptions.ReconnectInterval;
            ResponseTimeout = exchangeOptions.SocketResponseTimeout;
            SocketNoDataTimeout = exchangeOptions.SocketNoDataTimeout;
            SocketCombineTarget = exchangeOptions.SocketSubscriptionsCombineTarget ?? 1;
        }

        /// <summary>
        /// Set a function to interpret the data, used when the data is received as bytes instead of a string
        /// </summary>
        /// <param name="byteHandler">Handler for byte data</param>
        /// <param name="stringHandler">Handler for string data</param>
        protected void SetDataInterpreter(Func<byte[], string>? byteHandler, Func<string, string>? stringHandler)
        {
            dataInterpreterBytes = byteHandler;
            dataInterpreterString = stringHandler;
        }

        /// <summary>
        /// Subscribe on <see cref="BaseClient.BaseAddress"/>
        /// </summary>
        /// <typeparam name="T">The expected return data</typeparam>
        /// <param name="request">The request to send</param>
        /// <param name="identifier">The identifier to use</param>
        /// <param name="authenticated">If the subscription should be authenticated</param>
        /// <param name="dataHandler">The handler of update data</param>
        /// <returns>The subscription on success, or the error</returns>
        protected virtual Task<CallResult<UpdateSubscription>> Subscribe<T>(object? request, string? identifier, bool authenticated, Action<T> dataHandler)
        {
            return Subscribe(BaseAddress, request, identifier, authenticated, dataHandler);
        }

        /// <summary>
        /// Subscribe using a specific URL. Picks or creates a connection, attaches the handler, connects if needed and,
        /// when a request is given, sends it and waits for confirmation.
        /// </summary>
        /// <typeparam name="T">The type of the expected data</typeparam>
        /// <param name="url">The URL to connect to</param>
        /// <param name="request">The request to send; when null the subscription is marked confirmed without sending anything</param>
        /// <param name="identifier">The identifier to use</param>
        /// <param name="authenticated">If the subscription should be authenticated</param>
        /// <param name="dataHandler">The handler of update data</param>
        /// <returns>The subscription on success, or the error</returns>
        protected virtual async Task<CallResult<UpdateSubscription>> Subscribe<T>(string url, object? request, string? identifier, bool authenticated, Action<T> dataHandler)
        {
            SocketConnection socket;
            SocketSubscription handler;
            var released = false;
            await semaphoreSlim.WaitAsync().ConfigureAwait(false);
            try
            {
                socket = GetWebsocket(url, authenticated);
                handler = AddHandler(request, identifier, true, socket, dataHandler);
                if (SocketCombineTarget == 1)
                {
                    // Can release early when only a single sub per connection
                    semaphoreSlim.Release();
                    released = true;
                }

                var connectResult = await ConnectIfNeeded(socket, authenticated).ConfigureAwait(false);
                if (!connectResult)
                    return new CallResult<UpdateSubscription>(null, connectResult.Error);
            }
            finally
            {
                // The semaphore must be released on every path (including exceptions) or it stays locked forever
                if (!released)
                    semaphoreSlim.Release();
            }

            if (socket.PausedActivity)
            {
                log.Write(LogVerbosity.Info, "Socket has been paused, can't subscribe at this moment");
                return new CallResult<UpdateSubscription>(default, new ServerError("Socket is paused"));
            }

            if (request != null)
            {
                var subResult = await SubscribeAndWait(socket, request, handler).ConfigureAwait(false);
                if (!subResult)
                {
                    await socket.Close(handler).ConfigureAwait(false);
                    return new CallResult<UpdateSubscription>(null, subResult.Error);
                }
            }
            else
            {
                handler.Confirmed = true;
            }

            socket.ShouldReconnect = true;
            return new CallResult<UpdateSubscription>(new UpdateSubscription(socket, handler), null);
        }

        /// <summary>
        /// Sends the subscribe request and waits for a response to that request
        /// </summary>
        /// <param name="socket">The connection to send the request on</param>
        /// <param name="request">The request to send</param>
        /// <param name="subscription">The subscription the request is for; marked confirmed on a successful response</param>
        /// <returns>Success, or the error from the response / a "no response" error when none matched within <see cref="ResponseTimeout"/></returns>
        protected internal virtual async Task<CallResult<bool>> SubscribeAndWait(SocketConnection socket, object request, SocketSubscription subscription)
        {
            CallResult<object>? callResult = null;
            await socket.SendAndWait(request, ResponseTimeout, data => HandleSubscriptionResponse(socket, subscription, request, data, out callResult)).ConfigureAwait(false);

            if (callResult?.Success == true)
                subscription.Confirmed = true;

            return new CallResult<bool>(callResult?.Success ?? false, callResult == null ? new ServerError("No response on subscription request received"): callResult.Error);
        }

        /// <summary>
        /// Query for data on <see cref="BaseClient.BaseAddress"/>
        /// </summary>
        /// <typeparam name="T">Expected result type</typeparam>
        /// <param name="request">The request to send</param>
        /// <param name="authenticated">Whether the socket should be authenticated</param>
        /// <returns>The query result</returns>
        protected virtual Task<CallResult<T>> Query<T>(object request, bool authenticated)
        {
            return Query<T>(BaseAddress, request, authenticated);
        }

        /// <summary>
        /// Query for data
        /// </summary>
        /// <typeparam name="T">The expected result type</typeparam>
        /// <param name="url">The url for the request</param>
        /// <param name="request">The request to send</param>
        /// <param name="authenticated">Whether the socket should be authenticated</param>
        /// <returns>The query result</returns>
        protected virtual async Task<CallResult<T>> Query<T>(string url, object request, bool authenticated)
        {
            SocketConnection socket;
            var released = false;
            await semaphoreSlim.WaitAsync().ConfigureAwait(false);
            try
            {
                socket = GetWebsocket(url, authenticated);
                if (SocketCombineTarget == 1)
                {
                    // Can release early when only a single sub per connection
                    semaphoreSlim.Release();
                    released = true;
                }

                var connectResult = await ConnectIfNeeded(socket, authenticated).ConfigureAwait(false);
                if (!connectResult)
                    return new CallResult<T>(default, connectResult.Error);
            }
            finally
            {
                // The semaphore must be released on every path (including exceptions) or it stays locked forever
                if (!released)
                    semaphoreSlim.Release();
            }

            if (socket.PausedActivity)
            {
                log.Write(LogVerbosity.Info, "Socket has been paused, can't send query at this moment");
                return new CallResult<T>(default, new ServerError("Socket is paused"));
            }

            return await QueryAndWait<T>(socket, request).ConfigureAwait(false);
        }

        /// <summary>
        /// Sends the query request and waits for the result
        /// </summary>
        /// <typeparam name="T">The expected result type</typeparam>
        /// <param name="socket">The connection to send and wait on</param>
        /// <param name="request">The request to send</param>
        /// <returns>The matched response, or a "no response" error when none matched within <see cref="ResponseTimeout"/></returns>
        protected virtual async Task<CallResult<T>> QueryAndWait<T>(SocketConnection socket, object request)
        {
            var dataResult = new CallResult<T>(default, new ServerError("No response on query received"));
            await socket.SendAndWait(request, ResponseTimeout, data =>
            {
                if (!HandleQueryResponse<T>(socket, request, data, out var callResult))
                    return false;

                dataResult = callResult;
                return true;
            }).ConfigureAwait(false);

            return dataResult;
        }

        /// <summary>
        /// Checks if a socket needs to be connected and does so if needed. Authentication is only attempted when the
        /// socket was not yet connected.
        /// </summary>
        /// <param name="socket">The connection to check</param>
        /// <param name="authenticated">Whether the socket should authenticated</param>
        /// <returns>Success, a <see cref="CantConnectError"/>, or the authentication error (message prefixed with "Authentication failed: ")</returns>
        protected virtual async Task<CallResult<bool>> ConnectIfNeeded(SocketConnection socket, bool authenticated)
        {
            if (socket.Connected)
                return new CallResult<bool>(true, null);

            var connectResult = await ConnectSocket(socket).ConfigureAwait(false);
            if (!connectResult)
                return new CallResult<bool>(false, new CantConnectError());

            if (!authenticated || socket.Authenticated)
                return new CallResult<bool>(true, null);

            var result = await AuthenticateSocket(socket).ConfigureAwait(false);
            if (!result)
            {
                log.Write(LogVerbosity.Warning, "Socket authentication failed");
                result.Error!.Message = "Authentication failed: " + result.Error.Message;
                return new CallResult<bool>(false, result.Error);
            }

            socket.Authenticated = true;
            return new CallResult<bool>(true, null);
        }

        /// <summary>
        /// Needs to check if a received message was an answer to a query request (preferably by id) and set the callResult out to whatever the response is
        /// </summary>
        /// <typeparam name="T">The type of response</typeparam>
        /// <param name="s">The socket connection</param>
        /// <param name="request">The request that a response is awaited for</param>
        /// <param name="data">The message</param>
        /// <param name="callResult">The interpretation (null if message wasn't a response to the request)</param>
        /// <returns>True if the message was a response to the query</returns>
        protected internal abstract bool HandleQueryResponse<T>(SocketConnection s, object request, JToken data, [NotNullWhen(true)]out CallResult<T>? callResult);

        /// <summary>
        /// Needs to check if a received message was an answer to a subscription request (preferably by id) and set the callResult out to whatever the response is
        /// </summary>
        /// <param name="s">The socket connection</param>
        /// <param name="subscription">The subscription the request belongs to</param>
        /// <param name="request">The request that a response is awaited for</param>
        /// <param name="message">The message</param>
        /// <param name="callResult">The interpretation (null if message wasn't a response to the request)</param>
        /// <returns>True if the message was a response to the subscription request</returns>
        protected internal abstract bool HandleSubscriptionResponse(SocketConnection s, SocketSubscription subscription, object request, JToken message, out CallResult<object>? callResult);

        /// <summary>
        /// Needs to check if a received message matches a handler. Typically if an update message matches the request
        /// </summary>
        /// <param name="message">The received data</param>
        /// <param name="request">The subscription request</param>
        /// <returns>True if the message belongs to the handler</returns>
        protected internal abstract bool MessageMatchesHandler(JToken message, object request);

        /// <summary>
        /// Needs to check if a received message matches a handler. Typically if a received message matches a ping request or other information pushed from the server
        /// </summary>
        /// <param name="message">The received data</param>
        /// <param name="identifier">The string identifier of the handler</param>
        /// <returns>True if the message belongs to the handler</returns>
        protected internal abstract bool MessageMatchesHandler(JToken message, string identifier);

        /// <summary>
        /// Needs to authenticate the socket so authenticated queries/subscriptions can be made on this socket connection
        /// </summary>
        /// <param name="s">The connection to authenticate</param>
        /// <returns>Whether authentication succeeded</returns>
        protected internal abstract Task<CallResult<bool>> AuthenticateSocket(SocketConnection s);

        /// <summary>
        /// Needs to unsubscribe a subscription, typically by sending an unsubscribe request. If multiple subscriptions per socket is not allowed this can just return since the socket will be closed anyway
        /// </summary>
        /// <param name="connection">The connection on which to unsubscribe</param>
        /// <param name="s">The subscription to unsubscribe</param>
        /// <returns>Whether unsubscribing succeeded</returns>
        protected internal abstract Task<bool> Unsubscribe(SocketConnection connection, SocketSubscription s);

        /// <summary>
        /// Optional hook to process/transform data before it is passed to the handlers. The default returns the message unchanged.
        /// </summary>
        /// <param name="message">The received data</param>
        /// <returns>The data to pass on</returns>
        protected internal virtual JToken ProcessTokenData(JToken message)
        {
            return message;
        }

        /// <summary>
        /// Add a handler for a subscription
        /// </summary>
        /// <typeparam name="T">The type of data the subscription expects</typeparam>
        /// <param name="request">The request of the subscription</param>
        /// <param name="identifier">The identifier of the subscription (can be null if request param is used)</param>
        /// <param name="userSubscription">Whether or not this is a user subscription (counts towards the max amount of handlers on a socket)</param>
        /// <param name="connection">The socket connection the handler is on</param>
        /// <param name="dataHandler">The handler of the data received</param>
        /// <returns>The subscription that was added to <paramref name="connection"/></returns>
        protected virtual SocketSubscription AddHandler<T>(object? request, string? identifier, bool userSubscription, SocketConnection connection, Action<T> dataHandler)
        {
            void InternalHandler(SocketConnection socketWrapper, JToken data)
            {
                // String subscribers get the raw message text instead of a deserialized object
                if (typeof(T) == typeof(string))
                {
                    dataHandler((T) Convert.ChangeType(data.ToString(), typeof(T)));
                    return;
                }

                var desResult = Deserialize<T>(data, false);
                if (!desResult)
                {
                    log.Write(LogVerbosity.Warning, $"Failed to deserialize data into type {typeof(T)}: {desResult.Error}");
                    return;
                }

                dataHandler(desResult.Data);
            }

            var handler = request == null
                ? SocketSubscription.CreateForIdentifier(identifier!, userSubscription, InternalHandler)
                : SocketSubscription.CreateForRequest(request, userSubscription, InternalHandler);
            connection.AddHandler(handler);
            return handler;
        }

        /// <summary>
        /// Adds a generic message handler. Used for example to reply to ping requests
        /// </summary>
        /// <param name="identifier">The name of the request handler. Needs to be unique (<see cref="Dictionary{TKey,TValue}.Add"/> throws on duplicates)</param>
        /// <param name="action">The action to execute when receiving a message for this handler (checked by <see cref="MessageMatchesHandler(JToken,string)"/>)</param>
        protected void AddGenericHandler(string identifier, Action<SocketConnection, JToken> action)
        {
            genericHandlers.Add(identifier, action);
            var handler = SocketSubscription.CreateForIdentifier(identifier, false, action);
            foreach (var connection in sockets.Values)
                connection.AddHandler(handler);
        }

        /// <summary>
        /// Gets a connection for a new subscription or query. Returns the least loaded existing connected socket for the
        /// address if it has room (or no new connection may be made); otherwise a new, not yet connected connection that
        /// has the generic handlers attached but is not yet in <see cref="sockets"/>.
        /// </summary>
        /// <param name="address">The address the socket is for</param>
        /// <param name="authenticated">Whether the socket should be authenticated</param>
        /// <returns>The connection to use</returns>
        protected virtual SocketConnection GetWebsocket(string address, bool authenticated)
        {
            // An authenticated connection can also serve unauthenticated requests
            var result = sockets
                .Where(s => s.Value.Socket.Url == address && (s.Value.Authenticated || !authenticated) && s.Value.Connected)
                .OrderBy(s => s.Value.HandlerCount)
                .Select(s => s.Value)
                .FirstOrDefault();

            if (result != null)
            {
                if (result.HandlerCount < SocketCombineTarget || (sockets.Count >= MaxSocketConnections && sockets.All(s => s.Value.HandlerCount >= SocketCombineTarget)))
                {
                    // Use existing socket if it has less than target connections OR it has the least connections and we can't make new
                    return result;
                }
            }

            // Create new socket
            var socket = CreateSocket(address);
            var socketWrapper = new SocketConnection(this, socket);
            foreach (var kvp in genericHandlers)
            {
                var handler = SocketSubscription.CreateForIdentifier(kvp.Key, false, kvp.Value);
                socketWrapper.AddHandler(handler);
            }

            return socketWrapper;
        }

        /// <summary>
        /// Connect a socket. On success the connection is added to <see cref="sockets"/>; on failure the underlying socket is disposed.
        /// </summary>
        /// <param name="socketConnection">The socket to connect</param>
        /// <returns>Success, or a <see cref="CantConnectError"/></returns>
        protected virtual async Task<CallResult<bool>> ConnectSocket(SocketConnection socketConnection)
        {
            if (await socketConnection.Socket.Connect().ConfigureAwait(false))
            {
                sockets.TryAdd(socketConnection.Socket.Id, socketConnection);
                return new CallResult<bool>(true, null);
            }

            socketConnection.Socket.Dispose();
            return new CallResult<bool>(false, new CantConnectError());
        }

        /// <summary>
        /// Create a socket for an address, applying the proxy, no-data timeout and data interpreters
        /// </summary>
        /// <param name="address">The address the socket should connect to</param>
        /// <returns>The created (not yet connected) socket</returns>
        protected virtual IWebsocket CreateSocket(string address)
        {
            var socket = SocketFactory.CreateWebsocket(log, address);
            log.Write(LogVerbosity.Debug, "Created new socket for " + address);

            if (apiProxy != null)
                socket.SetProxy(apiProxy.Host, apiProxy.Port);

            socket.Timeout = SocketNoDataTimeout;
            socket.DataInterpreterBytes = dataInterpreterBytes;
            socket.DataInterpreterString = dataInterpreterString;
            socket.OnError += e =>
            {
                log.Write(LogVerbosity.Info, $"Socket {socket.Id} error: " + e);
            };
            return socket;
        }

        /// <summary>
        /// Periodically sends an object to every connection in <see cref="sockets"/> until the client is disposed
        /// </summary>
        /// <param name="interval">How often</param>
        /// <param name="objGetter">Method returning the object to send; returning null skips that connection</param>
        /// <exception cref="ArgumentNullException"><paramref name="objGetter"/> is null</exception>
        public virtual void SendPeriodic(TimeSpan interval, Func<SocketConnection, object> objGetter)
        {
            if (objGetter == null)
                throw new ArgumentNullException(nameof(objGetter));

            periodicEvent = new AutoResetEvent(false);
            periodicTask = Task.Run(async () =>
            {
                while (!disposing)
                {
                    // Read the field once per iteration; Dispose clears it before disposing the event
                    var waitEvent = periodicEvent;
                    if (waitEvent == null)
                        break;

                    try
                    {
                        await waitEvent.WaitOneAsync(interval).ConfigureAwait(false);
                    }
                    catch (ObjectDisposedException)
                    {
                        // Dispose ran between reading the field and starting the wait
                        break;
                    }

                    if (disposing)
                        break;

                    if (!sockets.IsEmpty)
                        log.Write(LogVerbosity.Debug, "Sending periodic");

                    foreach (var socket in sockets.Values)
                    {
                        if (disposing)
                            break;

                        var obj = objGetter(socket);
                        if (obj == null)
                            continue;

                        try
                        {
                            socket.Send(obj);
                        }
                        catch (Exception ex)
                        {
                            log.Write(LogVerbosity.Warning, "Periodic send failed: " + ex);
                        }
                    }
                }
            });
        }

        /// <summary>
        /// Unsubscribe from a stream
        /// </summary>
        /// <param name="subscription">The subscription to unsubscribe</param>
        /// <exception cref="ArgumentNullException"><paramref name="subscription"/> is null</exception>
        public virtual async Task Unsubscribe(UpdateSubscription subscription)
        {
            if (subscription == null)
                throw new ArgumentNullException(nameof(subscription));

            log.Write(LogVerbosity.Info, "Closing subscription");
            await subscription.Close().ConfigureAwait(false);
        }

        /// <summary>
        /// Unsubscribe all subscriptions by closing every connection in <see cref="sockets"/>
        /// </summary>
        public virtual async Task UnsubscribeAll()
        {
            log.Write(LogVerbosity.Debug, $"Closing all {sockets.Sum(s => s.Value.HandlerCount)} subscriptions");

            await Task.Run(async () =>
            {
                // ConcurrentDictionary.Values is a point-in-time snapshot, so closing can't disturb this iteration
                var tasks = new List<Task>();
                foreach (var connection in sockets.Values)
                    tasks.Add(connection.Close());

                await Task.WhenAll(tasks).ConfigureAwait(false);
            }).ConfigureAwait(false);
        }

        /// <summary>
        /// Dispose the client: stops the periodic sender, closes all subscriptions and releases the semaphore.
        /// Safe to call more than once.
        /// </summary>
        public override void Dispose()
        {
            disposing = true;

            // Detach the event first so a repeated Dispose never signals/disposes an already disposed handle
            var evnt = Interlocked.Exchange(ref periodicEvent, null);
            evnt?.Set();
            evnt?.Dispose();

            log.Write(LogVerbosity.Debug, "Disposing socket client, closing all subscriptions");
            Task.Run(UnsubscribeAll).Wait();
            semaphoreSlim.Dispose();
            base.Dispose();
        }
    }
}