diff --git a/CryptoExchange.Net/Clients/SocketApiClient.cs b/CryptoExchange.Net/Clients/SocketApiClient.cs index 7e3b3e8..38c0b40 100644 --- a/CryptoExchange.Net/Clients/SocketApiClient.cs +++ b/CryptoExchange.Net/Clients/SocketApiClient.cs @@ -233,7 +233,7 @@ namespace CryptoExchange.Net if (subQuery != null) { // Send the request and wait for answer - var subResult = await socketConnection.SendAndWaitQueryAsync(subQuery).ConfigureAwait(false); // TODO return null on timeout + var subResult = await socketConnection.SendAndWaitQueryAsync(subQuery).ConfigureAwait(false); if (!subResult) { _logger.Log(LogLevel.Warning, $"Socket {socketConnection.SocketId} failed to subscribe: {subResult.Error}"); @@ -259,6 +259,7 @@ namespace CryptoExchange.Net }, false); } + subscription.Confirmed = true; _logger.Log(LogLevel.Information, $"Socket {socketConnection.SocketId} subscription {subscription.Id} completed successfully"); return new CallResult(new UpdateSubscription(socketConnection, subscription)); } diff --git a/CryptoExchange.Net/Converters/SocketConverter.cs b/CryptoExchange.Net/Converters/SocketConverter.cs index 33c02b3..7fd87bf 100644 --- a/CryptoExchange.Net/Converters/SocketConverter.cs +++ b/CryptoExchange.Net/Converters/SocketConverter.cs @@ -22,7 +22,7 @@ namespace CryptoExchange.Net.Converters public abstract MessageInterpreterPipeline InterpreterPipeline { get; } /// - public BaseParsedMessage? ReadJson(Stream stream, IDictionary processors, bool outputOriginalData) + public BaseParsedMessage? ReadJson(Stream stream, Dictionary processors, bool outputOriginalData) { // Start reading the data // Once we reach the properties that identify the message we save those in a dict @@ -62,43 +62,69 @@ namespace CryptoExchange.Net.Converters return null; } - if (token.Type == JTokenType.Array) - { - // Received array, take first item as reference - token = token.First!; - } - PostInspectResult? inspectResult = null; Dictionary typeIdDict = new Dictionary(); - PostInspectCallback? usedParser = null; - foreach (var callback in InterpreterPipeline.PostInspectCallbacks) + object? usedParser = null; + if (token.Type == JTokenType.Object) { - bool allFieldsPresent = true; - foreach(var field in callback.TypeFields) + foreach (var callback in InterpreterPipeline.PostInspectCallbacks.OfType()) { - var value = typeIdDict.TryGetValue(field, out var cachedValue) ? cachedValue : GetValueForKey(token, field); - if (value == null) + bool allFieldsPresent = true; + foreach (var field in callback.TypeFields) { - allFieldsPresent = false; - break; + var value = typeIdDict.TryGetValue(field, out var cachedValue) ? cachedValue : GetValueForKey(token, field); + if (value == null) + { + allFieldsPresent = false; + break; + } + + typeIdDict[field] = value; } - typeIdDict[field] = value; + if (allFieldsPresent) + { + inspectResult = callback.Callback(typeIdDict, processors); + usedParser = callback; + break; + } } - - if (allFieldsPresent) + } + else + { + foreach (var callback in InterpreterPipeline.PostInspectCallbacks.OfType()) { - inspectResult = callback.Callback(typeIdDict, processors); - usedParser = callback; - break; + var typeIdArrayDict = new Dictionary(); + bool allFieldsPresent = true; + var maxIndex = callback.TypeFields.Max(); + if (((JArray)token).Count <= maxIndex) + continue; + + foreach (var field in callback.TypeFields) + { + var value = token[field]; + if (value == null) + { + allFieldsPresent = false; + break; + } + + typeIdArrayDict[field] = value.ToString(); + } + + if (allFieldsPresent) + { + inspectResult = callback.Callback(typeIdArrayDict, processors); + usedParser = callback; + break; + } } } if (usedParser == null) throw new Exception("No parser found for message"); - var resultMessageType = typeof(ParsedMessage<>).MakeGenericType(inspectResult.Type); - var instance = (BaseParsedMessage)Activator.CreateInstance(resultMessageType, inspectResult.Type == null ? null : token.ToObject(inspectResult.Type, _serializer)); + var instance = InterpreterPipeline.ObjectInitializer(token, inspectResult.Type); if (outputOriginalData) { stream.Position = 0; @@ -110,6 +136,13 @@ namespace CryptoExchange.Net.Converters return instance; } + public static BaseParsedMessage InstantiateMessageObject(JToken token, Type type) + { + var resultMessageType = typeof(ParsedMessage<>).MakeGenericType(type); + var instance = (BaseParsedMessage)Activator.CreateInstance(resultMessageType, type == null ? null : token.ToObject(type, _serializer)); + return instance; + } + private string? GetValueForKey(JToken token, string key) { var splitTokens = key.Split(new char[] { ':' }); diff --git a/CryptoExchange.Net/Interfaces/IMessageProcessor.cs b/CryptoExchange.Net/Interfaces/IMessageProcessor.cs index 9d049b7..9fa1824 100644 --- a/CryptoExchange.Net/Interfaces/IMessageProcessor.cs +++ b/CryptoExchange.Net/Interfaces/IMessageProcessor.cs @@ -10,6 +10,7 @@ namespace CryptoExchange.Net.Interfaces public interface IMessageProcessor { public int Id { get; } + public List Identifiers { get; } Task HandleMessageAsync(DataEvent message); public Type ExpectedMessageType { get; } } diff --git a/CryptoExchange.Net/Objects/Sockets/StreamMessageParseCallback.cs b/CryptoExchange.Net/Objects/Sockets/StreamMessageParseCallback.cs index c60f178..484441a 100644 --- a/CryptoExchange.Net/Objects/Sockets/StreamMessageParseCallback.cs +++ b/CryptoExchange.Net/Objects/Sockets/StreamMessageParseCallback.cs @@ -1,6 +1,7 @@ using CryptoExchange.Net.Converters; using CryptoExchange.Net.Interfaces; using CryptoExchange.Net.Sockets; +using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.IO; @@ -11,7 +12,8 @@ namespace CryptoExchange.Net.Objects.Sockets public class MessageInterpreterPipeline { public List PreInspectCallbacks { get; set; } = new List(); - public List PostInspectCallbacks { get; set; } = new List(); + public List PostInspectCallbacks { get; set; } = new List(); + public Func ObjectInitializer { get; set; } = SocketConverter.InstantiateMessageObject; } public class PreInspectCallback @@ -22,7 +24,13 @@ namespace CryptoExchange.Net.Objects.Sockets public class PostInspectCallback { public List TypeFields { get; set; } = new List(); - public Func, IDictionary, PostInspectResult> Callback { get; set; } + public Func, Dictionary, PostInspectResult> Callback { get; set; } + } + + public class PostInspectArrayCallback + { + public List TypeFields { get; set; } = new List(); + public Func, Dictionary, PostInspectResult> Callback { get; set; } } public class PreInspectResult diff --git a/CryptoExchange.Net/Sockets/SocketConnection.cs b/CryptoExchange.Net/Sockets/SocketConnection.cs index 012e18e..00b718e 100644 --- a/CryptoExchange.Net/Sockets/SocketConnection.cs +++ b/CryptoExchange.Net/Sockets/SocketConnection.cs @@ -152,9 +152,9 @@ namespace CryptoExchange.Net.Sockets } private bool _pausedActivity; - private readonly ConcurrentList _pendingRequests; - private readonly ConcurrentList _subscriptions; - private readonly ConcurrentDictionary _messageIdMap; + //private readonly ConcurrentList _pendingRequests; + //private readonly ConcurrentList _subscriptions; + private readonly SocketListenerManager _messageIdMap; private readonly ILogger _logger; private SocketStatus _status; @@ -177,9 +177,9 @@ namespace CryptoExchange.Net.Sockets Tag = tag; Properties = new Dictionary(); - _pendingRequests = new ConcurrentList(); - _subscriptions = new ConcurrentList(); - _messageIdMap = new ConcurrentDictionary(); + //_pendingRequests = new ConcurrentList(); + //_subscriptions = new ConcurrentList(); + _messageIdMap = new SocketListenerManager(_logger); _socket = socket; _socket.OnStreamMessage += HandleStreamMessage; @@ -209,7 +209,7 @@ namespace CryptoExchange.Net.Sockets Status = SocketStatus.Closed; Authenticated = false; - foreach (var subscription in _subscriptions) + foreach (var subscription in _messageIdMap.GetSubscriptions()) subscription.Confirmed = false; Task.Run(() => ConnectionClosed?.Invoke()); @@ -224,7 +224,7 @@ namespace CryptoExchange.Net.Sockets DisconnectTime = DateTime.UtcNow; Authenticated = false; - foreach (var subscription in _subscriptions) + foreach (var subscription in _messageIdMap.GetSubscriptions()) subscription.Confirmed = false; _ = Task.Run(() => ConnectionLost?.Invoke()); @@ -307,7 +307,9 @@ namespace CryptoExchange.Net.Sockets var timestamp = DateTime.UtcNow; TimeSpan userCodeDuration = TimeSpan.Zero; - var result = ApiClient.StreamConverter.ReadJson(stream, _messageIdMap, ApiClient.ApiOptions.OutputOriginalData ?? ApiClient.ClientOptions.OutputOriginalData); + // TODO This shouldn't be done for every request, just when something changes. Might want to make it a seperate type or something with functions 'Add', 'Remove' and 'GetMapping' or something + // This could then cache the internal dictionary mapping of `GetMapping` until something changes, and also make sure there aren't duplicate ids with different message types + var result = ApiClient.StreamConverter.ReadJson(stream, _messageIdMap.GetMapping(), ApiClient.ApiOptions.OutputOriginalData ?? ApiClient.ClientOptions.OutputOriginalData); if(result == null) { stream.Position = 0; @@ -326,8 +328,7 @@ namespace CryptoExchange.Net.Sockets return; } - // TODO lock - if (!_messageIdMap.TryGetValue(result.Identifier, out var messageProcessor)) + if (!await _messageIdMap.InvokeListenersAsync(result.Identifier, result).ConfigureAwait(false)) { stream.Position = 0; var unhandledBuffer = new byte[stream.Length]; @@ -340,28 +341,6 @@ namespace CryptoExchange.Net.Sockets } stream.Dispose(); - _logger.Log(LogLevel.Trace, $"Socket {SocketId} Message mapped to processor {messageProcessor.Id} with identifier {result.Identifier}"); - - if (messageProcessor is BaseQuery query) - { - foreach (var id in query.Identifiers) - _messageIdMap.TryRemove(id, out _); - - if (query.PendingRequest != null) - _pendingRequests.Remove(query.PendingRequest); - - if (query.PendingRequest?.Completed == true) - { - // Answer to a timed out request - _logger.Log(LogLevel.Warning, $"Socket {SocketId} Received after request timeout. Consider increasing the RequestTimeout"); - } - } - - // Matched based on identifier - var userSw = Stopwatch.StartNew(); - var dataEvent = new DataEvent(result, null, result.OriginalData, DateTime.UtcNow, null); - await messageProcessor.HandleMessageAsync(dataEvent).ConfigureAwait(false); - userSw.Stop(); } /// @@ -394,7 +373,7 @@ namespace CryptoExchange.Net.Sockets if (ApiClient.socketConnections.ContainsKey(SocketId)) ApiClient.socketConnections.TryRemove(SocketId, out _); - foreach (var subscription in _subscriptions) + foreach (var subscription in _messageIdMap.GetSubscriptions()) { if (subscription.CancellationTokenRegistration.HasValue) subscription.CancellationTokenRegistration.Value.Dispose(); @@ -412,7 +391,7 @@ namespace CryptoExchange.Net.Sockets /// public async Task CloseAsync(Subscription subscription, bool unsubEvenIfNotConfirmed = false) { - if (!_subscriptions.Contains(subscription)) + if (!_messageIdMap.Contains(subscription)) return; subscription.Closed = true; @@ -433,7 +412,7 @@ namespace CryptoExchange.Net.Sockets return; } - var shouldCloseConnection = _subscriptions.All(r => !r.UserSubscription || r.Closed); + var shouldCloseConnection = _messageIdMap.GetSubscriptions().All(r => !r.UserSubscription || r.Closed); if (shouldCloseConnection) Status = SocketStatus.Closing; @@ -443,9 +422,7 @@ namespace CryptoExchange.Net.Sockets await CloseAsync().ConfigureAwait(false); } - _subscriptions.Remove(subscription); - foreach (var id in subscription.Identifiers) - _messageIdMap.TryRemove(id, out _); + _messageIdMap.Remove(subscription); } /// @@ -466,18 +443,12 @@ namespace CryptoExchange.Net.Sockets if (Status != SocketStatus.None && Status != SocketStatus.Connected) return false; - _subscriptions.Add(subscription); + _messageIdMap.Add(subscription); if (subscription.Identifiers != null) - { - foreach (var id in subscription.Identifiers) - { - if (!_messageIdMap.TryAdd(id.ToLowerInvariant(), subscription)) - throw new InvalidOperationException($"Failed to register subscription id {id}, already registered"); - } - } + _messageIdMap.Add(subscription); - if (subscription.UserSubscription) - _logger.Log(LogLevel.Debug, $"Socket {SocketId} adding new subscription with id {subscription.Id}, total subscriptions on connection: {_subscriptions.Count(s => s.UserSubscription)}"); + //if (subscription.UserSubscription) + // _logger.Log(LogLevel.Debug, $"Socket {SocketId} adding new subscription with id {subscription.Id}, total subscriptions on connection: {_subscriptions.Count(s => s.UserSubscription)}"); return true; } @@ -485,14 +456,14 @@ namespace CryptoExchange.Net.Sockets /// Get a subscription on this connection by id /// /// - public Subscription? GetSubscription(int id) => _subscriptions.SingleOrDefault(s => s.Id == id); + public Subscription? GetSubscription(int id) => _messageIdMap.GetSubscriptions().SingleOrDefault(s => s.Id == id); /// /// Get a subscription on this connection by its subscribe request /// /// Filter for a request /// - public Subscription? GetSubscriptionByRequest(Func predicate) => _subscriptions.SingleOrDefault(s => predicate(s)); + public Subscription? GetSubscriptionByRequest(Func predicate) => _messageIdMap.GetSubscriptions().SingleOrDefault(s => predicate(s)); /// /// Send a query request and wait for an answer @@ -503,16 +474,10 @@ namespace CryptoExchange.Net.Sockets { var pendingRequest = query.CreatePendingRequest(); if (query.Identifiers != null) - { - foreach (var id in query.Identifiers) - { - if(!_messageIdMap.TryAdd(id.ToLowerInvariant(), query)) - throw new InvalidOperationException($"Failed to register subscription id {id}, already registered"); - } - } + _messageIdMap.Add(query); await SendAndWaitAsync(pendingRequest, query.Weight).ConfigureAwait(false); - return pendingRequest.Result!; + return pendingRequest.Result ?? new CallResult(new ServerError("Timeout")); } /// @@ -523,18 +488,12 @@ namespace CryptoExchange.Net.Sockets /// public virtual async Task> SendAndWaitQueryAsync(Query query) { - var pendingRequest = PendingRequest.CreateForQuery(query, query.Id); + var pendingRequest = (PendingRequest)query.CreatePendingRequest(); if (query.Identifiers != null) - { - foreach (var id in query.Identifiers) - { - if (!_messageIdMap.TryAdd(id.ToLowerInvariant(), query)) - throw new InvalidOperationException($"Failed to register subscription id {id}, already registered"); - } - } + _messageIdMap.Add(query); await SendAndWaitAsync(pendingRequest, query.Weight).ConfigureAwait(false); - return pendingRequest.TypedResult!; + return pendingRequest.TypedResult ?? new CallResult(new ServerError("Timeout")); } private async Task SendAndWaitAsync(BasePendingRequest pending, int weight) @@ -608,7 +567,7 @@ namespace CryptoExchange.Net.Sockets if (!_socket.IsOpen) return new CallResult(new WebError("Socket not connected")); - var anySubscriptions = _subscriptions.Any(s => s.UserSubscription); + var anySubscriptions = _messageIdMap.GetSubscriptions().Any(s => s.UserSubscription); if (!anySubscriptions) { // No need to resubscribe anything @@ -617,7 +576,7 @@ namespace CryptoExchange.Net.Sockets return new CallResult(true); } - var anyAuthenticated = _subscriptions.Any(s => s.Authenticated); + var anyAuthenticated = _messageIdMap.GetSubscriptions().Any(s => s.Authenticated); if (anyAuthenticated) { // If we reconnected a authenticated connection we need to re-authenticate @@ -633,15 +592,16 @@ namespace CryptoExchange.Net.Sockets } // Get a list of all subscriptions on the socket - var subList = _subscriptions.ToList(); + var subList = _messageIdMap.GetSubscriptions(); foreach(var subscription in subList) { + subscription.ConnectionInvocations = 0; var result = await ApiClient.RevitalizeRequestAsync(subscription).ConfigureAwait(false); if (!result) { _logger.Log(LogLevel.Warning, $"Socket {SocketId} Failed request revitalization: " + result.Error); - return result.As(false); + return result.As(false); } } diff --git a/CryptoExchange.Net/Sockets/SocketListenerManager.cs b/CryptoExchange.Net/Sockets/SocketListenerManager.cs new file mode 100644 index 0000000..307334a --- /dev/null +++ b/CryptoExchange.Net/Sockets/SocketListenerManager.cs @@ -0,0 +1,132 @@ +using CryptoExchange.Net.Interfaces; +using CryptoExchange.Net.Objects.Sockets; +using Microsoft.Extensions.Logging; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Net.Sockets; +using System.Text; +using System.Threading.Tasks; + +namespace CryptoExchange.Net.Sockets +{ + internal class SocketListenerManager + { + private ILogger _logger; + private object _lock = new object(); + private Dictionary _typeMap; + private Dictionary> _listeners; + + public SocketListenerManager(ILogger logger) + { + _typeMap = new Dictionary(); + _logger = logger; + } + + public Dictionary GetMapping() + { + lock (this) + return _typeMap; + } + + public void Add(IMessageProcessor processor) + { + lock (_lock) + { + foreach (var identifier in processor.Identifiers) + { + if (!_listeners.TryGetValue(identifier, out var list)) + { + list = new List(); + _listeners.Add(identifier, list); + } + + list.Add(processor); + } + + UpdateMap(); + } + } + + public async Task InvokeListenersAsync(string id, BaseParsedMessage data) + { + List listeners; + lock (_lock) + { + if (!_listeners.TryGetValue(id, out var idListeners)) + return false; + + listeners = idListeners.ToList(); + } + + foreach (var listener in listeners) + { + //_logger.Log(LogLevel.Trace, $"Socket {SocketId} Message mapped to processor {messageProcessor.Id} with identifier {result.Identifier}"); + if (listener is BaseQuery query) + { + Remove(listener); + + if (query.PendingRequest != null) + _pendingRequests.Remove(query.PendingRequest); + + if (query.PendingRequest?.Completed == true) + { + // Answer to a timed out request + //_logger.Log(LogLevel.Warning, $"Socket {SocketId} Received after request timeout. Consider increasing the RequestTimeout"); + } + } + + // Matched based on identifier + var userSw = Stopwatch.StartNew(); + var dataEvent = new DataEvent(data, null, data.OriginalData, DateTime.UtcNow, null); + await listener.HandleMessageAsync(dataEvent).ConfigureAwait(false); + userSw.Stop(); + } + + return true; + } + + public List GetSubscriptions() + { + lock (_lock) + return _listeners.Values.SelectMany(v => v.OfType()).ToList(); + } + + public List GetQueries() + { + lock (_lock) + return _listeners.Values.SelectMany(v => v.OfType()).ToList(); + } + + public bool Contains(IMessageProcessor processor) + { + lock (_lock) + return _listeners.Any(l => l.Value.Contains(processor)); + } + + public bool Remove(IMessageProcessor processor) + { + lock (_lock) + { + var removed = false; + foreach (var identifier in processor.Identifiers) + { + if (_listeners[identifier].Remove(processor)) + removed = true; + + if (!_listeners[identifier].Any()) + _listeners.Remove(identifier); + } + + UpdateMap(); + return removed; + } + } + + private void UpdateMap() + { + _typeMap = _listeners.ToDictionary(x => x.Key, x => x.Value.First().ExpectedMessageType); + } + } +} diff --git a/CryptoExchange.Net/Sockets/Subscription.cs b/CryptoExchange.Net/Sockets/Subscription.cs index eaee2c1..50c99d2 100644 --- a/CryptoExchange.Net/Sockets/Subscription.cs +++ b/CryptoExchange.Net/Sockets/Subscription.cs @@ -19,6 +19,16 @@ namespace CryptoExchange.Net.Sockets /// public int Id { get; set; } + /// + /// Total amount of invocations + /// + public int TotalInvocations { get; set; } + + /// + /// Amount of invocation during this connection + /// + public int ConnectionInvocations { get; set; } + /// /// Is it a user subscription /// @@ -87,12 +97,19 @@ namespace CryptoExchange.Net.Sockets /// public abstract BaseQuery? GetUnsubQuery(); + public async Task HandleMessageAsync(DataEvent message) + { + ConnectionInvocations++; + TotalInvocations++; + return await DoHandleMessageAsync(message).ConfigureAwait(false); + } + /// /// Handle the update message /// /// /// - public abstract Task HandleMessageAsync(DataEvent message); + public abstract Task DoHandleMessageAsync(DataEvent message); /// /// Invoke the exception event @@ -132,7 +149,7 @@ namespace CryptoExchange.Net.Sockets } /// - public override Task HandleMessageAsync(DataEvent message) + public override Task DoHandleMessageAsync(DataEvent message) => HandleEventAsync(message.As((ParsedMessage)message.Data)); /// diff --git a/CryptoExchange.Net/Sockets/SystemSubscription.cs b/CryptoExchange.Net/Sockets/SystemSubscription.cs index e5d6165..c600b05 100644 --- a/CryptoExchange.Net/Sockets/SystemSubscription.cs +++ b/CryptoExchange.Net/Sockets/SystemSubscription.cs @@ -1,4 +1,8 @@ -using Microsoft.Extensions.Logging; +using CryptoExchange.Net.Objects; +using CryptoExchange.Net.Objects.Sockets; +using Microsoft.Extensions.Logging; +using System; +using System.Threading.Tasks; namespace CryptoExchange.Net.Sockets { @@ -22,4 +26,17 @@ namespace CryptoExchange.Net.Sockets /// public override BaseQuery? GetUnsubQuery() => null; } + + public abstract class SystemSubscription : SystemSubscription + { + public override Type ExpectedMessageType => typeof(T); + public override Task DoHandleMessageAsync(DataEvent message) + => HandleMessageAsync(message.As((ParsedMessage)message.Data)); + + protected SystemSubscription(ILogger logger, bool authenticated) : base(logger, authenticated) + { + } + + public abstract Task HandleMessageAsync(DataEvent> message); + } }