1
0
mirror of https://github.com/JKorf/CryptoExchange.Net synced 2025-06-08 16:36:15 +00:00
This commit is contained in:
JKorf 2023-11-10 18:59:04 +01:00
parent b59fe9e3ef
commit ff6a9d5f13
8 changed files with 270 additions and 101 deletions

View File

@ -233,7 +233,7 @@ namespace CryptoExchange.Net
if (subQuery != null)
{
// Send the request and wait for answer
var subResult = await socketConnection.SendAndWaitQueryAsync(subQuery).ConfigureAwait(false); // TODO return null on timeout
var subResult = await socketConnection.SendAndWaitQueryAsync(subQuery).ConfigureAwait(false);
if (!subResult)
{
_logger.Log(LogLevel.Warning, $"Socket {socketConnection.SocketId} failed to subscribe: {subResult.Error}");
@ -259,6 +259,7 @@ namespace CryptoExchange.Net
}, false);
}
subscription.Confirmed = true;
_logger.Log(LogLevel.Information, $"Socket {socketConnection.SocketId} subscription {subscription.Id} completed successfully");
return new CallResult<UpdateSubscription>(new UpdateSubscription(socketConnection, subscription));
}

View File

@ -22,7 +22,7 @@ namespace CryptoExchange.Net.Converters
public abstract MessageInterpreterPipeline InterpreterPipeline { get; }
/// <inheritdoc />
public BaseParsedMessage? ReadJson(Stream stream, IDictionary<string, IMessageProcessor> processors, bool outputOriginalData)
public BaseParsedMessage? ReadJson(Stream stream, Dictionary<string, Type> processors, bool outputOriginalData)
{
// Start reading the data
// Once we reach the properties that identify the message we save those in a dict
@ -62,43 +62,69 @@ namespace CryptoExchange.Net.Converters
return null;
}
if (token.Type == JTokenType.Array)
{
// Received array, take first item as reference
token = token.First!;
}
PostInspectResult? inspectResult = null;
Dictionary<string, string> typeIdDict = new Dictionary<string, string>();
PostInspectCallback? usedParser = null;
foreach (var callback in InterpreterPipeline.PostInspectCallbacks)
object? usedParser = null;
if (token.Type == JTokenType.Object)
{
bool allFieldsPresent = true;
foreach(var field in callback.TypeFields)
foreach (var callback in InterpreterPipeline.PostInspectCallbacks.OfType<PostInspectCallback>())
{
var value = typeIdDict.TryGetValue(field, out var cachedValue) ? cachedValue : GetValueForKey(token, field);
if (value == null)
bool allFieldsPresent = true;
foreach (var field in callback.TypeFields)
{
allFieldsPresent = false;
break;
var value = typeIdDict.TryGetValue(field, out var cachedValue) ? cachedValue : GetValueForKey(token, field);
if (value == null)
{
allFieldsPresent = false;
break;
}
typeIdDict[field] = value;
}
typeIdDict[field] = value;
if (allFieldsPresent)
{
inspectResult = callback.Callback(typeIdDict, processors);
usedParser = callback;
break;
}
}
if (allFieldsPresent)
}
else
{
foreach (var callback in InterpreterPipeline.PostInspectCallbacks.OfType<PostInspectArrayCallback>())
{
inspectResult = callback.Callback(typeIdDict, processors);
usedParser = callback;
break;
var typeIdArrayDict = new Dictionary<int, string>();
bool allFieldsPresent = true;
var maxIndex = callback.TypeFields.Max();
if (((JArray)token).Count <= maxIndex)
continue;
foreach (var field in callback.TypeFields)
{
var value = token[field];
if (value == null)
{
allFieldsPresent = false;
break;
}
typeIdArrayDict[field] = value.ToString();
}
if (allFieldsPresent)
{
inspectResult = callback.Callback(typeIdArrayDict, processors);
usedParser = callback;
break;
}
}
}
if (usedParser == null)
throw new Exception("No parser found for message");
var resultMessageType = typeof(ParsedMessage<>).MakeGenericType(inspectResult.Type);
var instance = (BaseParsedMessage)Activator.CreateInstance(resultMessageType, inspectResult.Type == null ? null : token.ToObject(inspectResult.Type, _serializer));
var instance = InterpreterPipeline.ObjectInitializer(token, inspectResult.Type);
if (outputOriginalData)
{
stream.Position = 0;
@ -110,6 +136,13 @@ namespace CryptoExchange.Net.Converters
return instance;
}
public static BaseParsedMessage InstantiateMessageObject(JToken token, Type type)
{
var resultMessageType = typeof(ParsedMessage<>).MakeGenericType(type);
var instance = (BaseParsedMessage)Activator.CreateInstance(resultMessageType, type == null ? null : token.ToObject(type, _serializer));
return instance;
}
private string? GetValueForKey(JToken token, string key)
{
var splitTokens = key.Split(new char[] { ':' });

View File

@ -10,6 +10,7 @@ namespace CryptoExchange.Net.Interfaces
public interface IMessageProcessor
{
public int Id { get; }
public List<string> Identifiers { get; }
Task<CallResult> HandleMessageAsync(DataEvent<BaseParsedMessage> message);
public Type ExpectedMessageType { get; }
}

View File

@ -1,6 +1,7 @@
using CryptoExchange.Net.Converters;
using CryptoExchange.Net.Interfaces;
using CryptoExchange.Net.Sockets;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.IO;
@ -11,7 +12,8 @@ namespace CryptoExchange.Net.Objects.Sockets
public class MessageInterpreterPipeline
{
public List<PreInspectCallback> PreInspectCallbacks { get; set; } = new List<PreInspectCallback>();
public List<PostInspectCallback> PostInspectCallbacks { get; set; } = new List<PostInspectCallback>();
public List<object> PostInspectCallbacks { get; set; } = new List<object>();
public Func<JToken, Type, BaseParsedMessage> ObjectInitializer { get; set; } = SocketConverter.InstantiateMessageObject;
}
public class PreInspectCallback
@ -22,7 +24,13 @@ namespace CryptoExchange.Net.Objects.Sockets
public class PostInspectCallback
{
public List<string> TypeFields { get; set; } = new List<string>();
public Func<Dictionary<string, string>, IDictionary<string, IMessageProcessor>, PostInspectResult> Callback { get; set; }
public Func<Dictionary<string, string>, Dictionary<string, Type>, PostInspectResult> Callback { get; set; }
}
public class PostInspectArrayCallback
{
public List<int> TypeFields { get; set; } = new List<int>();
public Func<Dictionary<int, string>, Dictionary<string, Type>, PostInspectResult> Callback { get; set; }
}
public class PreInspectResult

View File

@ -152,9 +152,9 @@ namespace CryptoExchange.Net.Sockets
}
private bool _pausedActivity;
private readonly ConcurrentList<BasePendingRequest> _pendingRequests;
private readonly ConcurrentList<Subscription> _subscriptions;
private readonly ConcurrentDictionary<string, IMessageProcessor> _messageIdMap;
//private readonly ConcurrentList<BasePendingRequest> _pendingRequests;
//private readonly ConcurrentList<Subscription> _subscriptions;
private readonly SocketListenerManager _messageIdMap;
private readonly ILogger _logger;
private SocketStatus _status;
@ -177,9 +177,9 @@ namespace CryptoExchange.Net.Sockets
Tag = tag;
Properties = new Dictionary<string, object>();
_pendingRequests = new ConcurrentList<BasePendingRequest>();
_subscriptions = new ConcurrentList<Subscription>();
_messageIdMap = new ConcurrentDictionary<string, IMessageProcessor>();
//_pendingRequests = new ConcurrentList<BasePendingRequest>();
//_subscriptions = new ConcurrentList<Subscription>();
_messageIdMap = new SocketListenerManager(_logger);
_socket = socket;
_socket.OnStreamMessage += HandleStreamMessage;
@ -209,7 +209,7 @@ namespace CryptoExchange.Net.Sockets
Status = SocketStatus.Closed;
Authenticated = false;
foreach (var subscription in _subscriptions)
foreach (var subscription in _messageIdMap.GetSubscriptions())
subscription.Confirmed = false;
Task.Run(() => ConnectionClosed?.Invoke());
@ -224,7 +224,7 @@ namespace CryptoExchange.Net.Sockets
DisconnectTime = DateTime.UtcNow;
Authenticated = false;
foreach (var subscription in _subscriptions)
foreach (var subscription in _messageIdMap.GetSubscriptions())
subscription.Confirmed = false;
_ = Task.Run(() => ConnectionLost?.Invoke());
@ -307,7 +307,9 @@ namespace CryptoExchange.Net.Sockets
var timestamp = DateTime.UtcNow;
TimeSpan userCodeDuration = TimeSpan.Zero;
var result = ApiClient.StreamConverter.ReadJson(stream, _messageIdMap, ApiClient.ApiOptions.OutputOriginalData ?? ApiClient.ClientOptions.OutputOriginalData);
// TODO This shouldn't be done for every request, just when something changes. Might want to make it a seperate type or something with functions 'Add', 'Remove' and 'GetMapping' or something
// This could then cache the internal dictionary mapping of `GetMapping` until something changes, and also make sure there aren't duplicate ids with different message types
var result = ApiClient.StreamConverter.ReadJson(stream, _messageIdMap.GetMapping(), ApiClient.ApiOptions.OutputOriginalData ?? ApiClient.ClientOptions.OutputOriginalData);
if(result == null)
{
stream.Position = 0;
@ -326,8 +328,7 @@ namespace CryptoExchange.Net.Sockets
return;
}
// TODO lock
if (!_messageIdMap.TryGetValue(result.Identifier, out var messageProcessor))
if (!await _messageIdMap.InvokeListenersAsync(result.Identifier, result).ConfigureAwait(false))
{
stream.Position = 0;
var unhandledBuffer = new byte[stream.Length];
@ -340,28 +341,6 @@ namespace CryptoExchange.Net.Sockets
}
stream.Dispose();
_logger.Log(LogLevel.Trace, $"Socket {SocketId} Message mapped to processor {messageProcessor.Id} with identifier {result.Identifier}");
if (messageProcessor is BaseQuery query)
{
foreach (var id in query.Identifiers)
_messageIdMap.TryRemove(id, out _);
if (query.PendingRequest != null)
_pendingRequests.Remove(query.PendingRequest);
if (query.PendingRequest?.Completed == true)
{
// Answer to a timed out request
_logger.Log(LogLevel.Warning, $"Socket {SocketId} Received after request timeout. Consider increasing the RequestTimeout");
}
}
// Matched based on identifier
var userSw = Stopwatch.StartNew();
var dataEvent = new DataEvent<BaseParsedMessage>(result, null, result.OriginalData, DateTime.UtcNow, null);
await messageProcessor.HandleMessageAsync(dataEvent).ConfigureAwait(false);
userSw.Stop();
}
/// <summary>
@ -394,7 +373,7 @@ namespace CryptoExchange.Net.Sockets
if (ApiClient.socketConnections.ContainsKey(SocketId))
ApiClient.socketConnections.TryRemove(SocketId, out _);
foreach (var subscription in _subscriptions)
foreach (var subscription in _messageIdMap.GetSubscriptions())
{
if (subscription.CancellationTokenRegistration.HasValue)
subscription.CancellationTokenRegistration.Value.Dispose();
@ -412,7 +391,7 @@ namespace CryptoExchange.Net.Sockets
/// <returns></returns>
public async Task CloseAsync(Subscription subscription, bool unsubEvenIfNotConfirmed = false)
{
if (!_subscriptions.Contains(subscription))
if (!_messageIdMap.Contains(subscription))
return;
subscription.Closed = true;
@ -433,7 +412,7 @@ namespace CryptoExchange.Net.Sockets
return;
}
var shouldCloseConnection = _subscriptions.All(r => !r.UserSubscription || r.Closed);
var shouldCloseConnection = _messageIdMap.GetSubscriptions().All(r => !r.UserSubscription || r.Closed);
if (shouldCloseConnection)
Status = SocketStatus.Closing;
@ -443,9 +422,7 @@ namespace CryptoExchange.Net.Sockets
await CloseAsync().ConfigureAwait(false);
}
_subscriptions.Remove(subscription);
foreach (var id in subscription.Identifiers)
_messageIdMap.TryRemove(id, out _);
_messageIdMap.Remove(subscription);
}
/// <summary>
@ -466,18 +443,12 @@ namespace CryptoExchange.Net.Sockets
if (Status != SocketStatus.None && Status != SocketStatus.Connected)
return false;
_subscriptions.Add(subscription);
_messageIdMap.Add(subscription);
if (subscription.Identifiers != null)
{
foreach (var id in subscription.Identifiers)
{
if (!_messageIdMap.TryAdd(id.ToLowerInvariant(), subscription))
throw new InvalidOperationException($"Failed to register subscription id {id}, already registered");
}
}
_messageIdMap.Add(subscription);
if (subscription.UserSubscription)
_logger.Log(LogLevel.Debug, $"Socket {SocketId} adding new subscription with id {subscription.Id}, total subscriptions on connection: {_subscriptions.Count(s => s.UserSubscription)}");
//if (subscription.UserSubscription)
// _logger.Log(LogLevel.Debug, $"Socket {SocketId} adding new subscription with id {subscription.Id}, total subscriptions on connection: {_subscriptions.Count(s => s.UserSubscription)}");
return true;
}
@ -485,14 +456,14 @@ namespace CryptoExchange.Net.Sockets
/// Get a subscription on this connection by id
/// </summary>
/// <param name="id"></param>
public Subscription? GetSubscription(int id) => _subscriptions.SingleOrDefault(s => s.Id == id);
public Subscription? GetSubscription(int id) => _messageIdMap.GetSubscriptions().SingleOrDefault(s => s.Id == id);
/// <summary>
/// Get a subscription on this connection by its subscribe request
/// </summary>
/// <param name="predicate">Filter for a request</param>
/// <returns></returns>
public Subscription? GetSubscriptionByRequest(Func<object?, bool> predicate) => _subscriptions.SingleOrDefault(s => predicate(s));
public Subscription? GetSubscriptionByRequest(Func<object?, bool> predicate) => _messageIdMap.GetSubscriptions().SingleOrDefault(s => predicate(s));
/// <summary>
/// Send a query request and wait for an answer
@ -503,16 +474,10 @@ namespace CryptoExchange.Net.Sockets
{
var pendingRequest = query.CreatePendingRequest();
if (query.Identifiers != null)
{
foreach (var id in query.Identifiers)
{
if(!_messageIdMap.TryAdd(id.ToLowerInvariant(), query))
throw new InvalidOperationException($"Failed to register subscription id {id}, already registered");
}
}
_messageIdMap.Add(query);
await SendAndWaitAsync(pendingRequest, query.Weight).ConfigureAwait(false);
return pendingRequest.Result!;
return pendingRequest.Result ?? new CallResult(new ServerError("Timeout"));
}
/// <summary>
@ -523,18 +488,12 @@ namespace CryptoExchange.Net.Sockets
/// <returns></returns>
public virtual async Task<CallResult<T>> SendAndWaitQueryAsync<T>(Query<T> query)
{
var pendingRequest = PendingRequest<T>.CreateForQuery(query, query.Id);
var pendingRequest = (PendingRequest<T>)query.CreatePendingRequest();
if (query.Identifiers != null)
{
foreach (var id in query.Identifiers)
{
if (!_messageIdMap.TryAdd(id.ToLowerInvariant(), query))
throw new InvalidOperationException($"Failed to register subscription id {id}, already registered");
}
}
_messageIdMap.Add(query);
await SendAndWaitAsync(pendingRequest, query.Weight).ConfigureAwait(false);
return pendingRequest.TypedResult!;
return pendingRequest.TypedResult ?? new CallResult<T>(new ServerError("Timeout"));
}
private async Task SendAndWaitAsync(BasePendingRequest pending, int weight)
@ -608,7 +567,7 @@ namespace CryptoExchange.Net.Sockets
if (!_socket.IsOpen)
return new CallResult<bool>(new WebError("Socket not connected"));
var anySubscriptions = _subscriptions.Any(s => s.UserSubscription);
var anySubscriptions = _messageIdMap.GetSubscriptions().Any(s => s.UserSubscription);
if (!anySubscriptions)
{
// No need to resubscribe anything
@ -617,7 +576,7 @@ namespace CryptoExchange.Net.Sockets
return new CallResult<bool>(true);
}
var anyAuthenticated = _subscriptions.Any(s => s.Authenticated);
var anyAuthenticated = _messageIdMap.GetSubscriptions().Any(s => s.Authenticated);
if (anyAuthenticated)
{
// If we reconnected a authenticated connection we need to re-authenticate
@ -633,15 +592,16 @@ namespace CryptoExchange.Net.Sockets
}
// Get a list of all subscriptions on the socket
var subList = _subscriptions.ToList();
var subList = _messageIdMap.GetSubscriptions();
foreach(var subscription in subList)
{
subscription.ConnectionInvocations = 0;
var result = await ApiClient.RevitalizeRequestAsync(subscription).ConfigureAwait(false);
if (!result)
{
_logger.Log(LogLevel.Warning, $"Socket {SocketId} Failed request revitalization: " + result.Error);
return result.As<bool>(false);
return result.As(false);
}
}

View File

@ -0,0 +1,132 @@
using CryptoExchange.Net.Interfaces;
using CryptoExchange.Net.Objects.Sockets;
using Microsoft.Extensions.Logging;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net.Sockets;
using System.Text;
using System.Threading.Tasks;
namespace CryptoExchange.Net.Sockets
{
internal class SocketListenerManager
{
private ILogger _logger;
private object _lock = new object();
private Dictionary<string, Type> _typeMap;
private Dictionary<string, List<IMessageProcessor>> _listeners;
public SocketListenerManager(ILogger logger)
{
_typeMap = new Dictionary<string, Type>();
_logger = logger;
}
public Dictionary<string, Type> GetMapping()
{
lock (this)
return _typeMap;
}
public void Add(IMessageProcessor processor)
{
lock (_lock)
{
foreach (var identifier in processor.Identifiers)
{
if (!_listeners.TryGetValue(identifier, out var list))
{
list = new List<IMessageProcessor>();
_listeners.Add(identifier, list);
}
list.Add(processor);
}
UpdateMap();
}
}
public async Task<bool> InvokeListenersAsync(string id, BaseParsedMessage data)
{
List<IMessageProcessor> listeners;
lock (_lock)
{
if (!_listeners.TryGetValue(id, out var idListeners))
return false;
listeners = idListeners.ToList();
}
foreach (var listener in listeners)
{
//_logger.Log(LogLevel.Trace, $"Socket {SocketId} Message mapped to processor {messageProcessor.Id} with identifier {result.Identifier}");
if (listener is BaseQuery query)
{
Remove(listener);
if (query.PendingRequest != null)
_pendingRequests.Remove(query.PendingRequest);
if (query.PendingRequest?.Completed == true)
{
// Answer to a timed out request
//_logger.Log(LogLevel.Warning, $"Socket {SocketId} Received after request timeout. Consider increasing the RequestTimeout");
}
}
// Matched based on identifier
var userSw = Stopwatch.StartNew();
var dataEvent = new DataEvent<BaseParsedMessage>(data, null, data.OriginalData, DateTime.UtcNow, null);
await listener.HandleMessageAsync(dataEvent).ConfigureAwait(false);
userSw.Stop();
}
return true;
}
public List<Subscription> GetSubscriptions()
{
lock (_lock)
return _listeners.Values.SelectMany(v => v.OfType<Subscription>()).ToList();
}
public List<BaseQuery> GetQueries()
{
lock (_lock)
return _listeners.Values.SelectMany(v => v.OfType<BaseQuery>()).ToList();
}
public bool Contains(IMessageProcessor processor)
{
lock (_lock)
return _listeners.Any(l => l.Value.Contains(processor));
}
public bool Remove(IMessageProcessor processor)
{
lock (_lock)
{
var removed = false;
foreach (var identifier in processor.Identifiers)
{
if (_listeners[identifier].Remove(processor))
removed = true;
if (!_listeners[identifier].Any())
_listeners.Remove(identifier);
}
UpdateMap();
return removed;
}
}
private void UpdateMap()
{
_typeMap = _listeners.ToDictionary(x => x.Key, x => x.Value.First().ExpectedMessageType);
}
}
}

View File

@ -19,6 +19,16 @@ namespace CryptoExchange.Net.Sockets
/// </summary>
public int Id { get; set; }
/// <summary>
/// Total amount of invocations
/// </summary>
public int TotalInvocations { get; set; }
/// <summary>
/// Amount of invocation during this connection
/// </summary>
public int ConnectionInvocations { get; set; }
/// <summary>
/// Is it a user subscription
/// </summary>
@ -87,12 +97,19 @@ namespace CryptoExchange.Net.Sockets
/// <returns></returns>
public abstract BaseQuery? GetUnsubQuery();
public async Task<CallResult> HandleMessageAsync(DataEvent<BaseParsedMessage> message)
{
ConnectionInvocations++;
TotalInvocations++;
return await DoHandleMessageAsync(message).ConfigureAwait(false);
}
/// <summary>
/// Handle the update message
/// </summary>
/// <param name="message"></param>
/// <returns></returns>
public abstract Task<CallResult> HandleMessageAsync(DataEvent<BaseParsedMessage> message);
public abstract Task<CallResult> DoHandleMessageAsync(DataEvent<BaseParsedMessage> message);
/// <summary>
/// Invoke the exception event
@ -132,7 +149,7 @@ namespace CryptoExchange.Net.Sockets
}
/// <inheritdoc />
public override Task<CallResult> HandleMessageAsync(DataEvent<BaseParsedMessage> message)
public override Task<CallResult> DoHandleMessageAsync(DataEvent<BaseParsedMessage> message)
=> HandleEventAsync(message.As((ParsedMessage<TEvent>)message.Data));
/// <summary>

View File

@ -1,4 +1,8 @@
using Microsoft.Extensions.Logging;
using CryptoExchange.Net.Objects;
using CryptoExchange.Net.Objects.Sockets;
using Microsoft.Extensions.Logging;
using System;
using System.Threading.Tasks;
namespace CryptoExchange.Net.Sockets
{
@ -22,4 +26,17 @@ namespace CryptoExchange.Net.Sockets
/// <inheritdoc />
public override BaseQuery? GetUnsubQuery() => null;
}
public abstract class SystemSubscription<T> : SystemSubscription
{
public override Type ExpectedMessageType => typeof(T);
public override Task<CallResult> DoHandleMessageAsync(DataEvent<BaseParsedMessage> message)
=> HandleMessageAsync(message.As((ParsedMessage<T>)message.Data));
protected SystemSubscription(ILogger logger, bool authenticated) : base(logger, authenticated)
{
}
public abstract Task<CallResult> HandleMessageAsync(DataEvent<ParsedMessage<T>> message);
}
}