diff --git a/CryptoExchange.Net/Converters/SocketConverter.cs b/CryptoExchange.Net/Converters/SocketConverter.cs index f389d50..07d0c75 100644 --- a/CryptoExchange.Net/Converters/SocketConverter.cs +++ b/CryptoExchange.Net/Converters/SocketConverter.cs @@ -5,6 +5,7 @@ using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Text; namespace CryptoExchange.Net.Converters @@ -33,6 +34,8 @@ namespace CryptoExchange.Net.Converters /// <returns></returns> public abstract Type? GetDeserializationType(Dictionary<string, string?> idValues, List<BasePendingRequest> pendingRequests, List<Subscription> listeners); + public virtual string CreateIdentifierString(Dictionary<string, string?> idValues) => string.Join("-", idValues.Values.Where(v => v != null).Select(v => v!.ToLower())); + /// <inheritdoc /> public BaseParsedMessage? ReadJson(Stream stream, List<BasePendingRequest> pendingRequests, List<Subscription> listeners, bool outputOriginalData) { @@ -61,19 +64,15 @@ namespace CryptoExchange.Net.Converters } var typeIdDict = new Dictionary<string, string?>(); - string idString = ""; foreach (var idField in TypeIdFields) - { - var val = GetValueForKey(token, idField); - idString += val; - typeIdDict[idField] = val; - } + typeIdDict[idField] = GetValueForKey(token, idField); + Dictionary<string, string?>? subIdDict = null; if (SubscriptionIdFields != null) { - idString = ""; + subIdDict = new Dictionary<string, string?>(); foreach (var idField in SubscriptionIdFields) - idString += GetValueForKey(token, idField); + subIdDict[idField] = GetValueForKey(token, idField); } var resultType = GetDeserializationType(typeIdDict, pendingRequests, listeners); @@ -91,7 +90,7 @@ namespace CryptoExchange.Net.Converters instance.OriginalData = sr.ReadToEnd(); } - instance.Identifier = idString; + instance.Identifier = CreateIdentifierString(subIdDict ?? typeIdDict); instance.Parsed = resultType != null; return instance; } diff --git a/CryptoExchange.Net/Interfaces/IMessageProcessor.cs b/CryptoExchange.Net/Interfaces/IMessageProcessor.cs new file mode 100644 index 0000000..0a3f457 --- /dev/null +++ b/CryptoExchange.Net/Interfaces/IMessageProcessor.cs @@ -0,0 +1,15 @@ +using CryptoExchange.Net.Objects; +using CryptoExchange.Net.Objects.Sockets; +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; + +namespace CryptoExchange.Net.Interfaces +{ + public interface IMessageProcessor + { + public int Id { get; } + Task<CallResult> HandleMessageAsync(DataEvent<BaseParsedMessage> message); + } +} diff --git a/CryptoExchange.Net/Objects/Sockets/PendingRequest.cs b/CryptoExchange.Net/Objects/Sockets/PendingRequest.cs index b90d3a5..b0eef4a 100644 --- a/CryptoExchange.Net/Objects/Sockets/PendingRequest.cs +++ b/CryptoExchange.Net/Objects/Sockets/PendingRequest.cs @@ -15,10 +15,6 @@ namespace CryptoExchange.Net.Objects.Sockets /// </summary> public int Id { get; set; } /// <summary> - /// Callback for checking if a message is a response to this request - /// </summary> - public Func<BaseParsedMessage, bool> MessageMatchesHandler { get; } - /// <summary> /// If the request is completed /// </summary> public bool Completed { get; protected set; } @@ -48,12 +44,10 @@ namespace CryptoExchange.Net.Objects.Sockets /// </summary> /// <param name="id"></param> /// <param name="request"></param> - /// <param name="messageMatchesHandler"></param> /// <param name="timeout"></param> - protected BasePendingRequest(int id, object request, Func<BaseParsedMessage, bool> messageMatchesHandler, TimeSpan timeout) + protected BasePendingRequest(int id, object request, TimeSpan timeout) { Id = id; - MessageMatchesHandler = messageMatchesHandler; _event = new AsyncResetEvent(false, false); _timeout = timeout; Request = request; @@ -93,7 +87,7 @@ namespace CryptoExchange.Net.Objects.Sockets /// </summary> /// <param name="message"></param> /// <returns></returns> - public abstract void ProcessAsync(BaseParsedMessage message); + public abstract Task ProcessAsync(DataEvent<BaseParsedMessage> message); } /// <summary> @@ -111,7 +105,7 @@ namespace CryptoExchange.Net.Objects.Sockets /// <summary> /// Data handler /// </summary> - public Func<ParsedMessage<T>, CallResult<T>> Handler { get; } + public Func<DataEvent<ParsedMessage<T>>, Task<CallResult<T>>> Handler { get; } /// <summary> /// The response object type /// </summary> @@ -122,11 +116,10 @@ namespace CryptoExchange.Net.Objects.Sockets /// </summary> /// <param name="id"></param> /// <param name="request"></param> - /// <param name="messageMatchesHandler"></param> /// <param name="messageHandler"></param> /// <param name="timeout"></param> - private PendingRequest(int id, object request, Func<ParsedMessage<T>, bool> messageMatchesHandler, Func<ParsedMessage<T>, CallResult<T>> messageHandler, TimeSpan timeout) - : base(id, request, (x) => messageMatchesHandler((ParsedMessage<T>)x), timeout) + private PendingRequest(int id, object request, Func<DataEvent<ParsedMessage<T>>, Task<CallResult<T>>> messageHandler, TimeSpan timeout) + : base(id, request, timeout) { Handler = messageHandler; } @@ -135,12 +128,13 @@ namespace CryptoExchange.Net.Objects.Sockets /// Create a new pending request for provided query /// </summary> /// <param name="query"></param> + /// <param name="id"></param> /// <returns></returns> - public static PendingRequest<T> CreateForQuery(Query<T> query) + public static PendingRequest<T> CreateForQuery(Query<T> query, int id) { - return new PendingRequest<T>(ExchangeHelpers.NextId(), query.Request, query.MessageMatchesQuery, x => + return new PendingRequest<T>(id, query.Request, async x => { - var response = query.HandleResponse(x); + var response = await query.HandleMessageAsync(x).ConfigureAwait(false); return response.As(response.Data); }, TimeSpan.FromSeconds(5)); } @@ -161,10 +155,10 @@ namespace CryptoExchange.Net.Objects.Sockets } /// <inheritdoc /> - public override void ProcessAsync(BaseParsedMessage message) + public override async Task ProcessAsync(DataEvent<BaseParsedMessage> message) { Completed = true; - Result = Handler((ParsedMessage<T>)message); + Result = await Handler(message.As((ParsedMessage<T>)message.Data)).ConfigureAwait(false); _event.Set(); } } diff --git a/CryptoExchange.Net/Sockets/Query.cs b/CryptoExchange.Net/Sockets/Query.cs index 9310a5d..6eacbf7 100644 --- a/CryptoExchange.Net/Sockets/Query.cs +++ b/CryptoExchange.Net/Sockets/Query.cs @@ -1,13 +1,22 @@ -using CryptoExchange.Net.Objects; +using CryptoExchange.Net.Interfaces; +using CryptoExchange.Net.Objects; using CryptoExchange.Net.Objects.Sockets; +using System.Collections.Generic; +using System.Threading.Tasks; namespace CryptoExchange.Net.Sockets { /// <summary> /// Query /// </summary> - public abstract class BaseQuery + public abstract class BaseQuery : IMessageProcessor { + public int Id { get; } = ExchangeHelpers.NextId(); + /// <summary> + /// Strings to identify this subscription with + /// </summary> + public abstract List<string> Identifiers { get; } + /// <summary> /// The query request object /// </summary> @@ -23,6 +32,8 @@ namespace CryptoExchange.Net.Sockets /// </summary> public int Weight { get; } + public BasePendingRequest PendingRequest { get; private set; } + /// <summary> /// ctor /// </summary> @@ -39,9 +50,23 @@ namespace CryptoExchange.Net.Sockets /// <summary> /// Create a pending request for this query /// </summary> - public abstract BasePendingRequest CreatePendingRequest(); + public BasePendingRequest CreatePendingRequest() + { + PendingRequest = GetPendingRequest(Id); + return PendingRequest; + } + + public abstract BasePendingRequest GetPendingRequest(int id); + + /// <summary> + /// Handle a response message + /// </summary> + /// <param name="message"></param> + /// <returns></returns> + public abstract Task<CallResult> HandleMessageAsync(DataEvent<BaseParsedMessage> message); + } - + /// <summary> /// Query /// </summary> @@ -58,21 +83,21 @@ namespace CryptoExchange.Net.Sockets { } + /// <inheritdoc /> + public override async Task<CallResult> HandleMessageAsync(DataEvent<BaseParsedMessage> message) + { + await PendingRequest.ProcessAsync(message).ConfigureAwait(false); + return await HandleMessageAsync(message.As((ParsedMessage<TResponse>)message.Data)).ConfigureAwait(false); + } + /// <summary> /// Handle the query response /// </summary> /// <param name="message"></param> /// <returns></returns> - public abstract CallResult<TResponse> HandleResponse(ParsedMessage<TResponse> message); - - /// <summary> - /// Check if the message is the response to the query - /// </summary> - /// <param name="message"></param> - /// <returns></returns> - public abstract bool MessageMatchesQuery(ParsedMessage<TResponse> message); + public virtual Task<CallResult<TResponse>> HandleMessageAsync(DataEvent<ParsedMessage<TResponse>> message) => Task.FromResult(new CallResult<TResponse>(message.Data.Data!)); /// <inheritdoc /> - public override BasePendingRequest CreatePendingRequest() => PendingRequest<TResponse>.CreateForQuery(this); + public override BasePendingRequest GetPendingRequest(int id) => PendingRequest<TResponse>.CreateForQuery(this, id); } } diff --git a/CryptoExchange.Net/Sockets/SocketConnection.cs b/CryptoExchange.Net/Sockets/SocketConnection.cs index e02d6b7..7a52282 100644 --- a/CryptoExchange.Net/Sockets/SocketConnection.cs +++ b/CryptoExchange.Net/Sockets/SocketConnection.cs @@ -60,7 +60,7 @@ namespace CryptoExchange.Net.Sockets public int UserSubscriptionCount { get { lock (_subscriptionLock) - return _messageIdentifierSubscriptions.Values.Count(h => h.UserSubscription); } + return _subscriptions.Count(h => h.UserSubscription); } } /// <summary> @@ -71,7 +71,7 @@ namespace CryptoExchange.Net.Sockets get { lock (_subscriptionLock) - return _messageIdentifierSubscriptions.Values.Where(h => h.UserSubscription).ToArray(); + return _subscriptions.Where(h => h.UserSubscription).ToArray(); } } @@ -158,7 +158,7 @@ namespace CryptoExchange.Net.Sockets private bool _pausedActivity; private readonly List<BasePendingRequest> _pendingRequests; private readonly List<Subscription> _subscriptions; - private readonly Dictionary<string, Subscription> _messageIdentifierSubscriptions; + private readonly Dictionary<string, IMessageProcessor> _messageIdMap; private readonly object _subscriptionLock = new(); private readonly ILogger _logger; @@ -186,7 +186,7 @@ namespace CryptoExchange.Net.Sockets _pendingRequests = new List<BasePendingRequest>(); _subscriptions = new List<Subscription>(); - _messageIdentifierSubscriptions = new Dictionary<string, Subscription>(); + _messageIdMap = new Dictionary<string, IMessageProcessor>(); _socket = socket; _socket.OnStreamMessage += HandleStreamMessage; @@ -346,48 +346,40 @@ namespace CryptoExchange.Net.Sockets } // TODO lock - if (_messageIdentifierSubscriptions.TryGetValue(result.Identifier.ToLowerInvariant(), out var idSubscription)) + if (!_messageIdMap.TryGetValue(result.Identifier, out var messageProcessor)) { - // Matched based on identifier - var userSw = Stopwatch.StartNew(); - var dataEvent = new DataEvent<BaseParsedMessage>(result, null, result.OriginalData, DateTime.UtcNow, null); - await idSubscription.HandleEventAsync(dataEvent).ConfigureAwait(false); - userSw.Stop(); + stream.Position = 0; + var unhandledBuffer = new byte[stream.Length]; + stream.Read(unhandledBuffer, 0, unhandledBuffer.Length); + + _logger.Log(LogLevel.Warning, $"Socket {SocketId} Message unidentified. Id: {result.Identifier.ToLowerInvariant()}, Message: {Encoding.UTF8.GetString(unhandledBuffer)} "); + + UnhandledMessage?.Invoke(result); return; } - List<BasePendingRequest> pendingRequests; - lock (_pendingRequests) - pendingRequests = _pendingRequests.ToList(); + _logger.Log(LogLevel.Trace, $"Socket {SocketId} Message mapped to processor {messageProcessor.Id} with identifier {result.Identifier}"); - foreach (var pendingRequest in pendingRequests) + if (messageProcessor is BaseQuery query) { - if (pendingRequest.MessageMatchesHandler(result)) + foreach (var id in query.Identifiers) + _messageIdMap.Remove(id); + + if (query.PendingRequest != null) + _pendingRequests.Remove(query.PendingRequest); + + if (query.PendingRequest?.Completed == true) { - lock (_pendingRequests) - _pendingRequests.Remove(pendingRequest); - - if (pendingRequest.Completed) - { - // Answer to a timed out request - _logger.Log(LogLevel.Warning, $"Socket {SocketId} Received after request timeout. Consider increasing the RequestTimeout"); - } - else - { - _logger.Log(LogLevel.Trace, $"Socket {SocketId} - msg {pendingRequest.Id} - received data matched to pending request"); - pendingRequest.ProcessAsync(result); - } - - return; + // Answer to a timed out request + _logger.Log(LogLevel.Warning, $"Socket {SocketId} Received after request timeout. Consider increasing the RequestTimeout"); } } - stream.Position = 0; - var unhandledBuffer = new byte[stream.Length]; - stream.Read(unhandledBuffer, 0, unhandledBuffer.Length); - - _logger.Log(LogLevel.Warning, $"Socket {SocketId} Message not handled: " + Encoding.UTF8.GetString(unhandledBuffer)); - UnhandledMessage?.Invoke(result); + // Matched based on identifier + var userSw = Stopwatch.StartNew(); + var dataEvent = new DataEvent<BaseParsedMessage>(result, null, result.OriginalData, DateTime.UtcNow, null); + await messageProcessor.HandleMessageAsync(dataEvent).ConfigureAwait(false); + userSw.Stop(); } /// <summary> @@ -468,7 +460,7 @@ namespace CryptoExchange.Net.Sockets return; } - shouldCloseConnection = _messageIdentifierSubscriptions.All(r => !r.Value.UserSubscription || r.Value.Closed); + shouldCloseConnection = _subscriptions.All(r => !r.UserSubscription || r.Closed); if (shouldCloseConnection) Status = SocketStatus.Closing; } @@ -483,7 +475,7 @@ namespace CryptoExchange.Net.Sockets { _subscriptions.Remove(subscription); foreach (var id in subscription.Identifiers) - _messageIdentifierSubscriptions.Remove(id); + _messageIdMap.Remove(id); } } @@ -511,7 +503,7 @@ namespace CryptoExchange.Net.Sockets if (subscription.Identifiers != null) { foreach (var id in subscription.Identifiers) - _messageIdentifierSubscriptions.Add(id.ToLowerInvariant(), subscription); + _messageIdMap.Add(id.ToLowerInvariant(), subscription); } if (subscription.UserSubscription) @@ -549,6 +541,12 @@ namespace CryptoExchange.Net.Sockets public virtual async Task<CallResult> SendAndWaitQueryAsync(BaseQuery query) { var pendingRequest = query.CreatePendingRequest(); + if (query.Identifiers != null) + { + foreach (var id in query.Identifiers) + _messageIdMap.Add(id.ToLowerInvariant(), query); + } + await SendAndWaitAsync(pendingRequest, query.Weight).ConfigureAwait(false); return pendingRequest.Result!; } @@ -561,7 +559,13 @@ namespace CryptoExchange.Net.Sockets /// <returns></returns> public virtual async Task<CallResult<T>> SendAndWaitQueryAsync<T>(Query<T> query) { - var pendingRequest = PendingRequest<T>.CreateForQuery(query); + var pendingRequest = PendingRequest<T>.CreateForQuery(query, query.Id); + if (query.Identifiers != null) + { + foreach (var id in query.Identifiers) + _messageIdMap.Add(id.ToLowerInvariant(), query); + } + await SendAndWaitAsync(pendingRequest, query.Weight).ConfigureAwait(false); return pendingRequest.TypedResult!; } diff --git a/CryptoExchange.Net/Sockets/Subscription.cs b/CryptoExchange.Net/Sockets/Subscription.cs index 5f73def..8b612b9 100644 --- a/CryptoExchange.Net/Sockets/Subscription.cs +++ b/CryptoExchange.Net/Sockets/Subscription.cs @@ -1,4 +1,5 @@ -using CryptoExchange.Net.Objects; +using CryptoExchange.Net.Interfaces; +using CryptoExchange.Net.Objects; using CryptoExchange.Net.Objects.Sockets; using Microsoft.Extensions.Logging; using System; @@ -11,7 +12,7 @@ namespace CryptoExchange.Net.Sockets /// <summary> /// Socket subscription /// </summary> - public abstract class Subscription + public abstract class Subscription : IMessageProcessor { /// <summary> /// Subscription id @@ -67,6 +68,7 @@ namespace CryptoExchange.Net.Sockets { _logger = logger; Authenticated = authenticated; + Id = ExchangeHelpers.NextId(); } /// <summary> @@ -86,7 +88,7 @@ namespace CryptoExchange.Net.Sockets /// </summary> /// <param name="message"></param> /// <returns></returns> - public abstract Task HandleEventAsync(DataEvent<BaseParsedMessage> message); + public abstract Task<CallResult> HandleMessageAsync(DataEvent<BaseParsedMessage> message); /// <summary> /// Invoke the exception event @@ -124,7 +126,7 @@ namespace CryptoExchange.Net.Sockets } /// <inheritdoc /> - public override Task HandleEventAsync(DataEvent<BaseParsedMessage> message) + public override Task<CallResult> HandleMessageAsync(DataEvent<BaseParsedMessage> message) => HandleEventAsync(message.As((ParsedMessage<TEvent>)message.Data)); /// <summary> @@ -132,6 +134,6 @@ namespace CryptoExchange.Net.Sockets /// </summary> /// <param name="message"></param> /// <returns></returns> - public abstract Task HandleEventAsync(DataEvent<ParsedMessage<TEvent>> message); + public abstract Task<CallResult> HandleEventAsync(DataEvent<ParsedMessage<TEvent>> message); } }