From ac434fa2c61a1a52b63ef24560cefa9cac8bb7d2 Mon Sep 17 00:00:00 2001 From: JKorf Date: Mon, 6 Nov 2023 21:43:41 +0100 Subject: [PATCH] wip --- .../Converters/SocketConverter.cs | 68 +++---- CryptoExchange.Net/Objects/ConcurrentList.cs | 81 +++++++++ .../Objects/Sockets/MatchingStrategy.cs | 21 +++ .../Sockets/StreamMessageParseCallback.cs | 14 ++ CryptoExchange.Net/Sockets/Query.cs | 15 +- .../Sockets/SocketConnection.cs | 171 +++++++----------- CryptoExchange.Net/Sockets/Subscription.cs | 4 +- .../Sockets/SystemSubscription.cs | 2 +- 8 files changed, 232 insertions(+), 144 deletions(-) create mode 100644 CryptoExchange.Net/Objects/ConcurrentList.cs create mode 100644 CryptoExchange.Net/Objects/Sockets/MatchingStrategy.cs create mode 100644 CryptoExchange.Net/Objects/Sockets/StreamMessageParseCallback.cs diff --git a/CryptoExchange.Net/Converters/SocketConverter.cs b/CryptoExchange.Net/Converters/SocketConverter.cs index 07d0c75..1766474 100644 --- a/CryptoExchange.Net/Converters/SocketConverter.cs +++ b/CryptoExchange.Net/Converters/SocketConverter.cs @@ -1,4 +1,5 @@ -using CryptoExchange.Net.Objects.Sockets; +using CryptoExchange.Net.Objects; +using CryptoExchange.Net.Objects.Sockets; using CryptoExchange.Net.Sockets; using Newtonsoft.Json; using Newtonsoft.Json.Linq; @@ -17,27 +18,12 @@ namespace CryptoExchange.Net.Converters { private static JsonSerializer _serializer = JsonSerializer.Create(SerializerOptions.WithConverters); - /// - /// Fields to use for the message subscription identifier - /// - public virtual string[]? SubscriptionIdFields => null; - /// - /// Fields to use for the message type identifier - /// - public abstract string[] TypeIdFields { get; } - - /// - /// Return the type of object that the message should be parsed to based on the type id values dictionary - /// - /// - /// - /// - public abstract Type? GetDeserializationType(Dictionary idValues, List pendingRequests, List listeners); + public abstract List InterpreterPipeline { get; } public virtual string CreateIdentifierString(Dictionary idValues) => string.Join("-", idValues.Values.Where(v => v != null).Select(v => v!.ToLower())); /// - public BaseParsedMessage? ReadJson(Stream stream, List pendingRequests, List listeners, bool outputOriginalData) + public BaseParsedMessage? ReadJson(Stream stream, ConcurrentList pendingRequests, ConcurrentList listeners, bool outputOriginalData) { // Start reading the data // Once we reach the properties that identify the message we save those in a dict @@ -63,25 +49,39 @@ namespace CryptoExchange.Net.Converters token = token.First!; } - var typeIdDict = new Dictionary(); - foreach (var idField in TypeIdFields) - typeIdDict[idField] = GetValueForKey(token, idField); - - Dictionary? subIdDict = null; - if (SubscriptionIdFields != null) + Type? resultType = null; + Dictionary typeIdDict = new Dictionary(); + StreamMessageParseCallback? usedParser = null; + foreach (var callback in InterpreterPipeline) { - subIdDict = new Dictionary(); - foreach (var idField in SubscriptionIdFields) - subIdDict[idField] = GetValueForKey(token, idField); + bool allFieldsPresent = true; + foreach(var field in callback.TypeFields) + { + var value = typeIdDict.TryGetValue(field, out var cachedValue) ? cachedValue : GetValueForKey(token, field); + if (value == null) + { + allFieldsPresent = false; + break; + } + + typeIdDict[field] = value; + } + + if (allFieldsPresent) + { + resultType = callback.Callback(typeIdDict, pendingRequests, listeners); + usedParser = callback; + break; + } } - var resultType = GetDeserializationType(typeIdDict, pendingRequests, listeners); - if (resultType == null) - { - // ? - return null; - } + if (usedParser == null) + throw new Exception("No parser found for message"); + var subIdDict = new Dictionary(); + foreach (var field in usedParser.IdFields) + subIdDict[field] = typeIdDict.TryGetValue(field, out var cachedValue) ? cachedValue : GetValueForKey(token, field); + var resultMessageType = typeof(ParsedMessage<>).MakeGenericType(resultType); var instance = (BaseParsedMessage)Activator.CreateInstance(resultMessageType, resultType == null ? null : token.ToObject(resultType, _serializer)); if (outputOriginalData) @@ -90,7 +90,7 @@ namespace CryptoExchange.Net.Converters instance.OriginalData = sr.ReadToEnd(); } - instance.Identifier = CreateIdentifierString(subIdDict ?? typeIdDict); + instance.Identifier = CreateIdentifierString(subIdDict); instance.Parsed = resultType != null; return instance; } diff --git a/CryptoExchange.Net/Objects/ConcurrentList.cs b/CryptoExchange.Net/Objects/ConcurrentList.cs new file mode 100644 index 0000000..40a6ce1 --- /dev/null +++ b/CryptoExchange.Net/Objects/ConcurrentList.cs @@ -0,0 +1,81 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using static System.Collections.Specialized.BitVector32; + +namespace CryptoExchange.Net.Objects +{ + public class ConcurrentList : IEnumerable + { + private readonly object _lock = new object(); + private readonly List _collection = new List(); + + public void Add(T item) + { + lock (_lock) + _collection.Add(item); + } + + + public void Remove(T item) + { + lock (_lock) + _collection.Remove(item); + } + + public T? SingleOrDefault(Func action) + { + lock (_lock) + return _collection.SingleOrDefault(action); + } + + public bool All(Func action) + { + lock (_lock) + return _collection.All(action); + } + + public bool Any(Func action) + { + lock (_lock) + return _collection.Any(action); + } + + public int Count(Func action) + { + lock (_lock) + return _collection.Count(action); + } + + public bool Contains(T item) + { + lock (_lock) + return _collection.Contains(item); + } + + public T[] ToArray(Func predicate) + { + lock (_lock) + return _collection.Where(predicate).ToArray(); + } + + public List ToList() + { + lock (_lock) + return _collection.ToList(); + } + + public IEnumerator GetEnumerator() + { + lock (_lock) + { + foreach (var item in _collection) + yield return item; + } + } + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } +} diff --git a/CryptoExchange.Net/Objects/Sockets/MatchingStrategy.cs b/CryptoExchange.Net/Objects/Sockets/MatchingStrategy.cs new file mode 100644 index 0000000..00c4e45 --- /dev/null +++ b/CryptoExchange.Net/Objects/Sockets/MatchingStrategy.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace CryptoExchange.Net.Objects.Sockets +{ + public interface IMatchingStrategy + { + + } + + internal class IdMatchingStrategy : IMatchingStrategy + { + + } + + internal class FieldsMatchingStrategy : IMatchingStrategy + { + + } +} diff --git a/CryptoExchange.Net/Objects/Sockets/StreamMessageParseCallback.cs b/CryptoExchange.Net/Objects/Sockets/StreamMessageParseCallback.cs new file mode 100644 index 0000000..288cf64 --- /dev/null +++ b/CryptoExchange.Net/Objects/Sockets/StreamMessageParseCallback.cs @@ -0,0 +1,14 @@ +using CryptoExchange.Net.Sockets; +using System; +using System.Collections.Generic; +using System.Text; + +namespace CryptoExchange.Net.Objects.Sockets +{ + public class StreamMessageParseCallback + { + public List TypeFields { get; set; } = new List(); + public List IdFields { get; set; } = new List(); + public Func, IEnumerable, IEnumerable, Type?> Callback { get; set; } + } +} diff --git a/CryptoExchange.Net/Sockets/Query.cs b/CryptoExchange.Net/Sockets/Query.cs index 6eacbf7..e947d47 100644 --- a/CryptoExchange.Net/Sockets/Query.cs +++ b/CryptoExchange.Net/Sockets/Query.cs @@ -11,6 +11,9 @@ namespace CryptoExchange.Net.Sockets /// public abstract class BaseQuery : IMessageProcessor { + /// + /// Unique identifier + /// public int Id { get; } = ExchangeHelpers.NextId(); /// /// Strings to identify this subscription with @@ -32,7 +35,10 @@ namespace CryptoExchange.Net.Sockets /// public int Weight { get; } - public BasePendingRequest PendingRequest { get; private set; } + /// + /// The pending request for this query + /// + public BasePendingRequest? PendingRequest { get; private set; } /// /// ctor @@ -56,6 +62,11 @@ namespace CryptoExchange.Net.Sockets return PendingRequest; } + /// + /// Create a pending request for this query + /// + /// + /// public abstract BasePendingRequest GetPendingRequest(int id); /// @@ -86,7 +97,7 @@ namespace CryptoExchange.Net.Sockets /// public override async Task HandleMessageAsync(DataEvent message) { - await PendingRequest.ProcessAsync(message).ConfigureAwait(false); + await PendingRequest!.ProcessAsync(message).ConfigureAwait(false); return await HandleMessageAsync(message.As((ParsedMessage)message.Data)).ConfigureAwait(false); } diff --git a/CryptoExchange.Net/Sockets/SocketConnection.cs b/CryptoExchange.Net/Sockets/SocketConnection.cs index 7a52282..0a195d7 100644 --- a/CryptoExchange.Net/Sockets/SocketConnection.cs +++ b/CryptoExchange.Net/Sockets/SocketConnection.cs @@ -11,6 +11,7 @@ using System.Net.WebSockets; using System.IO; using CryptoExchange.Net.Objects.Sockets; using System.Text; +using System.Collections.Concurrent; namespace CryptoExchange.Net.Sockets { @@ -57,11 +58,7 @@ namespace CryptoExchange.Net.Sockets /// /// The amount of subscriptions on this connection /// - public int UserSubscriptionCount - { - get { lock (_subscriptionLock) - return _subscriptions.Count(h => h.UserSubscription); } - } + public int UserSubscriptionCount => _subscriptions.Count(h => h.UserSubscription); /// /// Get a copy of the current message subscriptions @@ -70,8 +67,7 @@ namespace CryptoExchange.Net.Sockets { get { - lock (_subscriptionLock) - return _subscriptions.Where(h => h.UserSubscription).ToArray(); + return _subscriptions.ToArray(h => h.UserSubscription); } } @@ -156,13 +152,10 @@ namespace CryptoExchange.Net.Sockets } private bool _pausedActivity; - private readonly List _pendingRequests; - private readonly List _subscriptions; - private readonly Dictionary _messageIdMap; - - private readonly object _subscriptionLock = new(); + private readonly ConcurrentList _pendingRequests; + private readonly ConcurrentList _subscriptions; + private readonly ConcurrentDictionary _messageIdMap; private readonly ILogger _logger; - private SocketStatus _status; /// @@ -184,9 +177,9 @@ namespace CryptoExchange.Net.Sockets Tag = tag; Properties = new Dictionary(); - _pendingRequests = new List(); - _subscriptions = new List(); - _messageIdMap = new Dictionary(); + _pendingRequests = new ConcurrentList(); + _subscriptions = new ConcurrentList(); + _messageIdMap = new ConcurrentDictionary(); _socket = socket; _socket.OnStreamMessage += HandleStreamMessage; @@ -215,11 +208,10 @@ namespace CryptoExchange.Net.Sockets { Status = SocketStatus.Closed; Authenticated = false; - lock(_subscriptionLock) - { - foreach (var subscription in _subscriptions) - subscription.Confirmed = false; - } + + foreach (var subscription in _subscriptions) + subscription.Confirmed = false; + Task.Run(() => ConnectionClosed?.Invoke()); } @@ -231,11 +223,9 @@ namespace CryptoExchange.Net.Sockets Status = SocketStatus.Reconnecting; DisconnectTime = DateTime.UtcNow; Authenticated = false; - lock (_subscriptionLock) - { - foreach (var subscription in _subscriptions) - subscription.Confirmed = false; - } + + foreach (var subscription in _subscriptions) + subscription.Confirmed = false; _ = Task.Run(() => ConnectionLost?.Invoke()); } @@ -255,13 +245,11 @@ namespace CryptoExchange.Net.Sockets protected virtual async void HandleReconnected() { Status = SocketStatus.Resubscribing; - lock (_subscriptions) + + foreach (var pendingRequest in _pendingRequests.ToList()) { - foreach (var pendingRequest in _pendingRequests.ToList()) - { - pendingRequest.Fail("Connection interupted"); - // Remove? - } + pendingRequest.Fail("Connection interupted"); + // Remove? } var reconnectSuccessful = await ProcessReconnectAsync().ConfigureAwait(false); @@ -299,10 +287,7 @@ namespace CryptoExchange.Net.Sockets /// Id of the request sent protected virtual void HandleRequestSent(int requestId) { - BasePendingRequest pendingRequest; - lock (_pendingRequests) - pendingRequest = _pendingRequests.SingleOrDefault(p => p.Id == requestId); - + var pendingRequest = _pendingRequests.SingleOrDefault(p => p.Id == requestId); if (pendingRequest == null) { _logger.Log(LogLevel.Debug, $"Socket {SocketId} - msg {requestId} - message sent, but not pending"); @@ -322,11 +307,7 @@ namespace CryptoExchange.Net.Sockets var timestamp = DateTime.UtcNow; TimeSpan userCodeDuration = TimeSpan.Zero; - List subscriptions; - lock (_subscriptionLock) - subscriptions = _subscriptions.OrderByDescending(x => !x.UserSubscription).ToList(); - - var result = ApiClient.StreamConverter.ReadJson(stream, _pendingRequests, subscriptions, ApiClient.ApiOptions.OutputOriginalData ?? ApiClient.ClientOptions.OutputOriginalData); + var result = ApiClient.StreamConverter.ReadJson(stream, _pendingRequests, _subscriptions, ApiClient.ApiOptions.OutputOriginalData ?? ApiClient.ClientOptions.OutputOriginalData); if(result == null) { stream.Position = 0; @@ -358,12 +339,13 @@ namespace CryptoExchange.Net.Sockets return; } + stream.Dispose(); _logger.Log(LogLevel.Trace, $"Socket {SocketId} Message mapped to processor {messageProcessor.Id} with identifier {result.Identifier}"); if (messageProcessor is BaseQuery query) { foreach (var id in query.Identifiers) - _messageIdMap.Remove(id); + _messageIdMap.TryRemove(id, out _); if (query.PendingRequest != null) _pendingRequests.Remove(query.PendingRequest); @@ -412,13 +394,10 @@ namespace CryptoExchange.Net.Sockets if (ApiClient.socketConnections.ContainsKey(SocketId)) ApiClient.socketConnections.TryRemove(SocketId, out _); - lock (_subscriptionLock) + foreach (var subscription in _subscriptions) { - foreach (var subscription in _subscriptions) - { - if (subscription.CancellationTokenRegistration.HasValue) - subscription.CancellationTokenRegistration.Value.Dispose(); - } + if (subscription.CancellationTokenRegistration.HasValue) + subscription.CancellationTokenRegistration.Value.Dispose(); } await _socket.CloseAsync().ConfigureAwait(false); @@ -433,13 +412,10 @@ namespace CryptoExchange.Net.Sockets /// public async Task CloseAsync(Subscription subscription, bool unsubEvenIfNotConfirmed = false) { - lock (_subscriptionLock) - { - if (!_subscriptions.Contains(subscription)) - return; + if (!_subscriptions.Contains(subscription)) + return; - subscription.Closed = true; - } + subscription.Closed = true; if (Status == SocketStatus.Closing || Status == SocketStatus.Closed || Status == SocketStatus.Disposed) return; @@ -451,32 +427,25 @@ namespace CryptoExchange.Net.Sockets if ((unsubEvenIfNotConfirmed || subscription.Confirmed) && _socket.IsOpen) await UnsubscribeAsync(subscription).ConfigureAwait(false); - bool shouldCloseConnection; - lock (_subscriptionLock) + if (Status == SocketStatus.Closing) { - if (Status == SocketStatus.Closing) - { - _logger.Log(LogLevel.Debug, $"Socket {SocketId} already closing"); - return; - } - - shouldCloseConnection = _subscriptions.All(r => !r.UserSubscription || r.Closed); - if (shouldCloseConnection) - Status = SocketStatus.Closing; + _logger.Log(LogLevel.Debug, $"Socket {SocketId} already closing"); + return; } + var shouldCloseConnection = _subscriptions.All(r => !r.UserSubscription || r.Closed); + if (shouldCloseConnection) + Status = SocketStatus.Closing; + if (shouldCloseConnection) { _logger.Log(LogLevel.Debug, $"Socket {SocketId} closing as there are no more subscriptions"); await CloseAsync().ConfigureAwait(false); } - lock (_subscriptionLock) - { - _subscriptions.Remove(subscription); - foreach (var id in subscription.Identifiers) - _messageIdMap.Remove(id); - } + _subscriptions.Remove(subscription); + foreach (var id in subscription.Identifiers) + _messageIdMap.TryRemove(id, out _); } /// @@ -494,44 +463,36 @@ namespace CryptoExchange.Net.Sockets /// public bool AddSubscription(Subscription subscription) { - lock (_subscriptionLock) + if (Status != SocketStatus.None && Status != SocketStatus.Connected) + return false; + + _subscriptions.Add(subscription); + if (subscription.Identifiers != null) { - if (Status != SocketStatus.None && Status != SocketStatus.Connected) - return false; - - _subscriptions.Add(subscription); - if (subscription.Identifiers != null) + foreach (var id in subscription.Identifiers) { - foreach (var id in subscription.Identifiers) - _messageIdMap.Add(id.ToLowerInvariant(), subscription); + if (!_messageIdMap.TryAdd(id.ToLowerInvariant(), subscription)) + throw new InvalidOperationException($"Failed to register subscription id {id}, already registered"); } - - if (subscription.UserSubscription) - _logger.Log(LogLevel.Debug, $"Socket {SocketId} adding new subscription with id {subscription.Id}, total subscriptions on connection: {_subscriptions.Count(s => s.UserSubscription)}"); - return true; } + + if (subscription.UserSubscription) + _logger.Log(LogLevel.Debug, $"Socket {SocketId} adding new subscription with id {subscription.Id}, total subscriptions on connection: {_subscriptions.Count(s => s.UserSubscription)}"); + return true; } /// /// Get a subscription on this connection by id /// /// - public Subscription? GetSubscription(int id) - { - lock (_subscriptionLock) - return _subscriptions.SingleOrDefault(s => s.Id == id); - } + public Subscription? GetSubscription(int id) => _subscriptions.SingleOrDefault(s => s.Id == id); /// /// Get a subscription on this connection by its subscribe request /// /// Filter for a request /// - public Subscription? GetSubscriptionByRequest(Func predicate) - { - lock(_subscriptionLock) - return _subscriptions.SingleOrDefault(s => predicate(s)); - } + public Subscription? GetSubscriptionByRequest(Func predicate) => _subscriptions.SingleOrDefault(s => predicate(s)); /// /// Send a query request and wait for an answer @@ -544,7 +505,10 @@ namespace CryptoExchange.Net.Sockets if (query.Identifiers != null) { foreach (var id in query.Identifiers) - _messageIdMap.Add(id.ToLowerInvariant(), query); + { + if(!_messageIdMap.TryAdd(id.ToLowerInvariant(), query)) + throw new InvalidOperationException($"Failed to register subscription id {id}, already registered"); + } } await SendAndWaitAsync(pendingRequest, query.Weight).ConfigureAwait(false); @@ -563,7 +527,10 @@ namespace CryptoExchange.Net.Sockets if (query.Identifiers != null) { foreach (var id in query.Identifiers) - _messageIdMap.Add(id.ToLowerInvariant(), query); + { + if (!_messageIdMap.TryAdd(id.ToLowerInvariant(), query)) + throw new InvalidOperationException($"Failed to register subscription id {id}, already registered"); + } } await SendAndWaitAsync(pendingRequest, query.Weight).ConfigureAwait(false); @@ -641,10 +608,7 @@ namespace CryptoExchange.Net.Sockets if (!_socket.IsOpen) return new CallResult(new WebError("Socket not connected")); - bool anySubscriptions = false; - lock (_subscriptionLock) - anySubscriptions = _subscriptions.Any(s => s.UserSubscription); - + var anySubscriptions = _subscriptions.Any(s => s.UserSubscription); if (!anySubscriptions) { // No need to resubscribe anything @@ -653,10 +617,7 @@ namespace CryptoExchange.Net.Sockets return new CallResult(true); } - bool anyAuthenticated = false; - lock (_subscriptionLock) - anyAuthenticated = _subscriptions.Any(s => s.Authenticated); - + var anyAuthenticated = _subscriptions.Any(s => s.Authenticated); if (anyAuthenticated) { // If we reconnected a authenticated connection we need to re-authenticate @@ -672,9 +633,7 @@ namespace CryptoExchange.Net.Sockets } // Get a list of all subscriptions on the socket - List subList = new List(); - lock (_subscriptionLock) - subList = _subscriptions.ToList(); + var subList = _subscriptions.ToList(); foreach(var subscription in subList) { diff --git a/CryptoExchange.Net/Sockets/Subscription.cs b/CryptoExchange.Net/Sockets/Subscription.cs index 8b612b9..8fbad5b 100644 --- a/CryptoExchange.Net/Sockets/Subscription.cs +++ b/CryptoExchange.Net/Sockets/Subscription.cs @@ -64,10 +64,12 @@ namespace CryptoExchange.Net.Sockets /// /// /// - public Subscription(ILogger logger, bool authenticated) + /// + public Subscription(ILogger logger, bool authenticated, bool userSubscription = true) { _logger = logger; Authenticated = authenticated; + UserSubscription = userSubscription; Id = ExchangeHelpers.NextId(); } diff --git a/CryptoExchange.Net/Sockets/SystemSubscription.cs b/CryptoExchange.Net/Sockets/SystemSubscription.cs index c2c82e5..e5d6165 100644 --- a/CryptoExchange.Net/Sockets/SystemSubscription.cs +++ b/CryptoExchange.Net/Sockets/SystemSubscription.cs @@ -12,7 +12,7 @@ namespace CryptoExchange.Net.Sockets /// /// /// - public SystemSubscription(ILogger logger, bool authenticated = false) : base(logger, authenticated) + public SystemSubscription(ILogger logger, bool authenticated = false) : base(logger, authenticated, false) { }