1
0
mirror of https://github.com/JKorf/CryptoExchange.Net synced 2025-06-09 00:46:19 +00:00

Finished up websocket refactoring

This commit is contained in:
JKorf 2022-07-10 16:36:00 +02:00
parent 89b517c936
commit 70f8bd203a
4 changed files with 156 additions and 52 deletions

View File

@ -3,6 +3,7 @@ using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using CryptoExchange.Net.Authentication;
@ -191,10 +192,14 @@ namespace CryptoExchange.Net
while (true)
{
// Get a new or existing socket connection
socketConnection = GetSocketConnection(apiClient, url, authenticated);
var socketResult = await GetSocketConnection(apiClient, url, authenticated).ConfigureAwait(false);
if(!socketResult)
return socketResult.As<UpdateSubscription>(null);
socketConnection = socketResult.Data;
// Add a subscription on the socket connection
subscription = AddSubscription(request, identifier, true, socketConnection, dataHandler);
subscription = AddSubscription(request, identifier, true, socketConnection, dataHandler, authenticated);
if (subscription == null)
{
log.Write(LogLevel.Trace, $"Socket {socketConnection.SocketId} failed to add subscription, retrying on different connection");
@ -235,7 +240,7 @@ namespace CryptoExchange.Net
var subResult = await SubscribeAndWaitAsync(socketConnection, request, subscription).ConfigureAwait(false);
if (!subResult)
{
log.Write(LogLevel.Information, $"Socket {socketConnection.SocketId} failed to subscribe: {subResult.Error}");
log.Write(LogLevel.Warning, $"Socket {socketConnection.SocketId} failed to subscribe: {subResult.Error}");
await socketConnection.CloseAsync(subscription).ConfigureAwait(false);
return new CallResult<UpdateSubscription>(subResult.Error!);
}
@ -315,7 +320,12 @@ namespace CryptoExchange.Net
await semaphoreSlim.WaitAsync().ConfigureAwait(false);
try
{
socketConnection = GetSocketConnection(apiClient, url, authenticated);
var socketResult = await GetSocketConnection(apiClient, url, authenticated).ConfigureAwait(false);
if (!socketResult)
return socketResult.As<T>(default);
socketConnection = socketResult.Data;
if (ClientOptions.SocketSubscriptionsCombineTarget == 1)
{
// Can release early when only a single sub per connection
@ -474,8 +484,9 @@ namespace CryptoExchange.Net
/// <param name="userSubscription">Whether or not this is a user subscription (counts towards the max amount of handlers on a socket)</param>
/// <param name="connection">The socket connection the handler is on</param>
/// <param name="dataHandler">The handler of the data received</param>
/// <param name="authenticated">Whether the subscription needs authentication</param>
/// <returns></returns>
protected virtual SocketSubscription? AddSubscription<T>(object? request, string? identifier, bool userSubscription, SocketConnection connection, Action<DataEvent<T>> dataHandler)
protected virtual SocketSubscription? AddSubscription<T>(object? request, string? identifier, bool userSubscription, SocketConnection connection, Action<DataEvent<T>> dataHandler, bool authenticated)
{
void InternalHandler(MessageEvent messageEvent)
{
@ -497,8 +508,8 @@ namespace CryptoExchange.Net
}
var subscription = request == null
? SocketSubscription.CreateForIdentifier(NextId(), identifier!, userSubscription, InternalHandler)
: SocketSubscription.CreateForRequest(NextId(), request, userSubscription, InternalHandler);
? SocketSubscription.CreateForIdentifier(NextId(), identifier!, userSubscription, authenticated, InternalHandler)
: SocketSubscription.CreateForRequest(NextId(), request, userSubscription, authenticated, InternalHandler);
if (!connection.AddSubscription(subscription))
return null;
return subscription;
@ -512,11 +523,23 @@ namespace CryptoExchange.Net
protected void AddGenericHandler(string identifier, Action<MessageEvent> action)
{
genericHandlers.Add(identifier, action);
var subscription = SocketSubscription.CreateForIdentifier(NextId(), identifier, false, action);
var subscription = SocketSubscription.CreateForIdentifier(NextId(), identifier, false, false, action);
foreach (var connection in socketConnections.Values)
connection.AddSubscription(subscription);
}
/// <summary>
/// Get the url to connect to (defaults to BaseAddress form the client options)
/// </summary>
/// <param name="apiClient"></param>
/// <param name="address"></param>
/// <param name="authentication"></param>
/// <returns></returns>
protected virtual Task<CallResult<string?>> GetConnectionUrlAsync(SocketApiClient apiClient, string address, bool authentication)
{
return Task.FromResult(new CallResult<string?>(address));
}
/// <summary>
/// Gets a connection for a new subscription or query. Can be an existing if there are open position or a new one.
/// </summary>
@ -524,10 +547,10 @@ namespace CryptoExchange.Net
/// <param name="address">The address the socket is for</param>
/// <param name="authenticated">Whether the socket should be authenticated</param>
/// <returns></returns>
protected virtual SocketConnection GetSocketConnection(SocketApiClient apiClient, string address, bool authenticated)
protected virtual async Task<CallResult<SocketConnection>> GetSocketConnection(SocketApiClient apiClient, string address, bool authenticated)
{
var socketResult = socketConnections.Where(s => (s.Value.Status == SocketConnection.SocketStatus.None || s.Value.Status == SocketConnection.SocketStatus.Connected)
&& s.Value.Uri.ToString().TrimEnd('/') == address.TrimEnd('/')
&& s.Value.Tag.TrimEnd('/') == address.TrimEnd('/')
&& (s.Value.ApiClient.GetType() == apiClient.GetType())
&& (s.Value.Authenticated == authenticated || !authenticated) && s.Value.Connected).OrderBy(s => s.Value.SubscriptionCount).FirstOrDefault();
var result = socketResult.Equals(default(KeyValuePair<int, SocketConnection>)) ? null : socketResult.Value;
@ -536,21 +559,31 @@ namespace CryptoExchange.Net
if (result.SubscriptionCount < ClientOptions.SocketSubscriptionsCombineTarget || (socketConnections.Count >= ClientOptions.MaxSocketConnections && socketConnections.All(s => s.Value.SubscriptionCount >= ClientOptions.SocketSubscriptionsCombineTarget)))
{
// Use existing socket if it has less than target connections OR it has the least connections and we can't make new
return result;
return new CallResult<SocketConnection>(result);
}
}
var connectionAddress = await GetConnectionUrlAsync(apiClient, address, authenticated).ConfigureAwait(false);
if (!connectionAddress)
{
log.Write(LogLevel.Warning, $"Failed to determine connection url: " + connectionAddress.Error);
return connectionAddress.As<SocketConnection>(null);
}
if (connectionAddress.Data != address)
log.Write(LogLevel.Debug, $"Connection address set to " + connectionAddress.Data);
// Create new socket
var socket = CreateSocket(address);
var socketConnection = new SocketConnection(this, apiClient, socket);
var socket = CreateSocket(connectionAddress.Data!);
var socketConnection = new SocketConnection(this, apiClient, socket, address);
socketConnection.UnhandledMessage += HandleUnhandledMessage;
foreach (var kvp in genericHandlers)
{
var handler = SocketSubscription.CreateForIdentifier(NextId(), kvp.Key, false, kvp.Value);
var handler = SocketSubscription.CreateForIdentifier(NextId(), kvp.Key, false, false, kvp.Value);
socketConnection.AddSubscription(handler);
}
return socketConnection;
return new CallResult<SocketConnection>(socketConnection);
}
/// <summary>
@ -700,7 +733,7 @@ namespace CryptoExchange.Net
/// <returns></returns>
public virtual async Task UnsubscribeAllAsync()
{
log.Write(LogLevel.Information, $"Closing all {socketConnections.Sum(s => s.Value.SubscriptionCount)} subscriptions");
log.Write(LogLevel.Information, $"Unsubscribing all {socketConnections.Sum(s => s.Value.SubscriptionCount)} subscriptions");
var tasks = new List<Task>();
{
var socketList = socketConnections.Values;
@ -711,6 +744,39 @@ namespace CryptoExchange.Net
await Task.WhenAll(tasks.ToArray()).ConfigureAwait(false);
}
/// <summary>
/// Reconnect all connections
/// </summary>
/// <returns></returns>
public virtual async Task ReconnectAsync()
{
log.Write(LogLevel.Information, $"Reconnecting all {socketConnections.Count} connections");
var tasks = new List<Task>();
{
var socketList = socketConnections.Values;
foreach (var sub in socketList)
tasks.Add(sub.TriggerReconnectAsync());
}
await Task.WhenAll(tasks.ToArray()).ConfigureAwait(false);
}
/// <summary>
/// Log the current state of connections and subscriptions
/// </summary>
public string GetSubscriptionsState()
{
var sb = new StringBuilder();
sb.AppendLine($"{socketConnections.Count} connections, {CurrentSubscriptions} subscriptions, kbps: {IncomingKbps}");
foreach(var connection in socketConnections)
{
sb.AppendLine($" Connection {connection.Key}: {connection.Value.SubscriptionCount} subscriptions, status: {connection.Value.Status}, authenticated: {connection.Value.Authenticated}, kbps: {connection.Value.IncomingKbps}");
foreach (var subscription in connection.Value.Subscriptions)
sb.AppendLine($" Subscription {subscription.Id}, authenticated: {subscription.Authenticated}, confirmed: {subscription.Confirmed}");
}
return sb.ToString();
}
/// <summary>
/// Dispose the client
/// </summary>
@ -719,8 +785,11 @@ namespace CryptoExchange.Net
disposing = true;
periodicEvent?.Set();
periodicEvent?.Dispose();
if (socketConnections.Sum(s => s.Value.SubscriptionCount) > 0)
{
log.Write(LogLevel.Debug, "Disposing socket client, closing all subscriptions");
_ = UnsubscribeAllAsync();
}
semaphoreSlim?.Dispose();
base.Dispose();
}

View File

@ -27,13 +27,6 @@ namespace CryptoExchange.Net.Sockets
Reconnecting
}
enum CloseState
{
Idle,
Closing,
Closed
}
internal static int lastStreamId;
private static readonly object streamIdLock = new();
@ -191,13 +184,13 @@ namespace CryptoExchange.Net.Sockets
{
while (!_stopRequested)
{
_log.Write(LogLevel.Trace, $"Socket {Id} ProcessAsync started");
_log.Write(LogLevel.Debug, $"Socket {Id} starting processing tasks");
_processState = ProcessState.Processing;
var sendTask = SendLoopAsync();
var receiveTask = ReceiveLoopAsync();
var timeoutTask = _parameters.Timeout != null && _parameters.Timeout > TimeSpan.FromSeconds(0) ? CheckTimeoutAsync() : Task.CompletedTask;
await Task.WhenAll(sendTask, receiveTask, timeoutTask).ConfigureAwait(false);
_log.Write(LogLevel.Trace, $"Socket {Id} ProcessAsync finished");
_log.Write(LogLevel.Debug, $"Socket {Id} processing tasks finished");
_processState = ProcessState.WaitingForClose;
while (_closeTask == null)
@ -221,7 +214,7 @@ namespace CryptoExchange.Net.Sockets
while (!_stopRequested)
{
_log.Write(LogLevel.Trace, $"Socket {Id} attempting to reconnect");
_log.Write(LogLevel.Debug, $"Socket {Id} attempting to reconnect");
_socket = CreateSocket();
_ctsSource.Dispose();
_ctsSource = new CancellationTokenSource();
@ -260,7 +253,7 @@ namespace CryptoExchange.Net.Sockets
if (_processState != ProcessState.Processing)
return;
_log.Write(LogLevel.Debug, $"Socket {Id} reconnecting");
_log.Write(LogLevel.Debug, $"Socket {Id} reconnect requested");
_closeTask = CloseInternalAsync();
await _closeTask.ConfigureAwait(false);
}
@ -318,7 +311,6 @@ namespace CryptoExchange.Net.Sockets
{
try
{
_log.Write(LogLevel.Trace, $"Socket {Id} normal closure 1");
await _socket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "Closing", default).ConfigureAwait(false);
}
catch (Exception)
@ -332,7 +324,6 @@ namespace CryptoExchange.Net.Sockets
{
try
{
_log.Write(LogLevel.Trace, $"Socket {Id} normal closure 2");
await _socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing", default).ConfigureAwait(false);
}
catch (Exception)
@ -390,7 +381,7 @@ namespace CryptoExchange.Net.Sockets
}
if (start != null)
_log.Write(LogLevel.Trace, $"Socket {Id} sent delayed {Math.Round((DateTime.UtcNow - start.Value).TotalMilliseconds)}ms because of rate limit");
_log.Write(LogLevel.Debug, $"Socket {Id} sent delayed {Math.Round((DateTime.UtcNow - start.Value).TotalMilliseconds)}ms because of rate limit");
}
try
@ -424,7 +415,7 @@ namespace CryptoExchange.Net.Sockets
}
finally
{
_log.Write(LogLevel.Trace, $"Socket {Id} Send loop finished");
_log.Write(LogLevel.Debug, $"Socket {Id} Send loop finished");
}
}
@ -542,7 +533,7 @@ namespace CryptoExchange.Net.Sockets
}
finally
{
_log.Write(LogLevel.Trace, $"Socket {Id} Receive loop finished");
_log.Write(LogLevel.Debug, $"Socket {Id} Receive loop finished");
}
}

View File

@ -57,10 +57,22 @@ namespace CryptoExchange.Net.Sockets
return subscriptions.Count(h => h.UserSubscription); }
}
/// <summary>
/// Get a copy of the current subscriptions
/// </summary>
public SocketSubscription[] Subscriptions
{
get
{
lock (subscriptionLock)
return subscriptions.Where(h => h.UserSubscription).ToArray();
}
}
/// <summary>
/// If the connection has been authenticated
/// </summary>
public bool Authenticated { get; set; }
public bool Authenticated { get; internal set; }
/// <summary>
/// If connection is made
@ -80,7 +92,7 @@ namespace CryptoExchange.Net.Sockets
/// <summary>
/// The connection uri
/// </summary>
public Uri Uri => _socket.Uri;
public Uri ConnectionUri => _socket.Uri;
/// <summary>
/// The API client the connection is for
@ -95,7 +107,7 @@ namespace CryptoExchange.Net.Sockets
/// <summary>
/// Tag for identificaion
/// </summary>
public string? Tag { get; set; }
public string Tag { get; set; }
/// <summary>
/// If activity is paused
@ -128,7 +140,7 @@ namespace CryptoExchange.Net.Sockets
var oldStatus = _status;
_status = value;
log.Write(LogLevel.Trace, $"Socket {SocketId} status changed from {oldStatus} to {_status}");
log.Write(LogLevel.Debug, $"Socket {SocketId} status changed from {oldStatus} to {_status}");
}
}
@ -154,11 +166,12 @@ namespace CryptoExchange.Net.Sockets
/// <param name="client">The socket client</param>
/// <param name="apiClient">The api client</param>
/// <param name="socket">The socket</param>
public SocketConnection(BaseSocketClient client, SocketApiClient apiClient, IWebsocket socket)
public SocketConnection(BaseSocketClient client, SocketApiClient apiClient, IWebsocket socket, string tag)
{
log = client.log;
socketClient = client;
ApiClient = apiClient;
Tag = tag;
pendingRequests = new List<PendingRequest>();
subscriptions = new List<SocketSubscription>();
@ -187,6 +200,12 @@ namespace CryptoExchange.Net.Sockets
protected virtual void HandleClose()
{
Status = SocketStatus.Closed;
Authenticated = false;
lock(subscriptionLock)
{
foreach (var sub in subscriptions)
sub.Confirmed = false;
}
Task.Run(() => ConnectionClosed?.Invoke());
}
@ -197,6 +216,12 @@ namespace CryptoExchange.Net.Sockets
{
Status = SocketStatus.Reconnecting;
DisconnectTime = DateTime.UtcNow;
Authenticated = false;
lock (subscriptionLock)
{
foreach (var sub in subscriptions)
sub.Confirmed = false;
}
Task.Run(() => ConnectionLost?.Invoke());
}
@ -365,7 +390,7 @@ namespace CryptoExchange.Net.Sockets
if (Status == SocketStatus.Closing || Status == SocketStatus.Closed || Status == SocketStatus.Disposed)
return;
log.Write(LogLevel.Trace, $"Socket {SocketId} closing subscription {subscription.Id}");
log.Write(LogLevel.Debug, $"Socket {SocketId} closing subscription {subscription.Id}");
if (subscription.CancellationTokenRegistration.HasValue)
subscription.CancellationTokenRegistration.Value.Dispose();
@ -377,7 +402,7 @@ namespace CryptoExchange.Net.Sockets
{
if (Status == SocketStatus.Closing)
{
log.Write(LogLevel.Trace, $"Socket {SocketId} already closing");
log.Write(LogLevel.Debug, $"Socket {SocketId} already closing");
return;
}
@ -388,7 +413,7 @@ namespace CryptoExchange.Net.Sockets
if (shouldCloseConnection)
{
log.Write(LogLevel.Trace, $"Socket {SocketId} closing as there are no more subscriptions");
log.Write(LogLevel.Debug, $"Socket {SocketId} closing as there are no more subscriptions");
await CloseAsync().ConfigureAwait(false);
}
}
@ -414,7 +439,8 @@ namespace CryptoExchange.Net.Sockets
return false;
subscriptions.Add(subscription);
log.Write(LogLevel.Trace, $"Socket {SocketId} adding new subscription with id {subscription.Id}, total subscriptions on connection: {subscriptions.Count}");
if(subscription.UserSubscription)
log.Write(LogLevel.Debug, $"Socket {SocketId} adding new subscription with id {subscription.Id}, total subscriptions on connection: {subscriptions.Count(s => s.UserSubscription)}");
return true;
}
}
@ -567,7 +593,7 @@ namespace CryptoExchange.Net.Sockets
return new CallResult<bool>(true);
}
if (Authenticated)
if (subscriptions.Any(s => s.Authenticated))
{
// If we reconnected a authenticated connection we need to re-authenticate
var authResult = await socketClient.AuthenticateSocketAsync(this).ConfigureAwait(false);
@ -577,13 +603,22 @@ namespace CryptoExchange.Net.Sockets
return authResult;
}
Authenticated = true;
log.Write(LogLevel.Debug, $"Socket {SocketId} authentication succeeded on reconnected socket.");
}
// Get a list of all subscriptions on the socket
List<SocketSubscription> subscriptionList;
List<SocketSubscription> subscriptionList = new List<SocketSubscription>();
lock (subscriptionLock)
subscriptionList = subscriptions.Where(h => h.Request != null).ToList();
{
foreach (var subscription in subscriptions)
{
if (subscription.Request != null)
subscriptionList.Add(subscription);
else
subscription.Confirmed = true;
}
}
// Foreach subscription which is subscribed by a subscription request we will need to resend that request to resubscribe
for (var i = 0; i < subscriptionList.Count; i += socketClient.ClientOptions.MaxConcurrentResubscriptionsPerSocket)
@ -600,6 +635,9 @@ namespace CryptoExchange.Net.Sockets
return taskList.First(t => !t.Result.Success).Result;
}
foreach (var subscription in subscriptionList)
subscription.Confirmed = true;
if (!_socket.IsOpen)
return new CallResult<bool>(new WebError("Socket not connected"));

View File

@ -43,19 +43,25 @@ namespace CryptoExchange.Net.Sockets
/// </summary>
public bool Confirmed { get; set; }
/// <summary>
/// Whether authentication is needed for this subscription
/// </summary>
public bool Authenticated { get; set; }
/// <summary>
/// Cancellation token registration, should be disposed when subscription is closed. Used for closing the subscription with
/// a provided cancelation token
/// </summary>
public CancellationTokenRegistration? CancellationTokenRegistration { get; set; }
private SocketSubscription(int id, object? request, string? identifier, bool userSubscription, Action<MessageEvent> dataHandler)
private SocketSubscription(int id, object? request, string? identifier, bool userSubscription, bool authenticated, Action<MessageEvent> dataHandler)
{
Id = id;
UserSubscription = userSubscription;
MessageHandler = dataHandler;
Request = request;
Identifier = identifier;
Authenticated = authenticated;
}
/// <summary>
@ -67,9 +73,9 @@ namespace CryptoExchange.Net.Sockets
/// <param name="dataHandler"></param>
/// <returns></returns>
public static SocketSubscription CreateForRequest(int id, object request, bool userSubscription,
Action<MessageEvent> dataHandler)
bool authenticated, Action<MessageEvent> dataHandler)
{
return new SocketSubscription(id, request, null, userSubscription, dataHandler);
return new SocketSubscription(id, request, null, userSubscription, authenticated, dataHandler);
}
/// <summary>
@ -81,9 +87,9 @@ namespace CryptoExchange.Net.Sockets
/// <param name="dataHandler"></param>
/// <returns></returns>
public static SocketSubscription CreateForIdentifier(int id, string identifier, bool userSubscription,
Action<MessageEvent> dataHandler)
bool authenticated, Action<MessageEvent> dataHandler)
{
return new SocketSubscription(id, null, identifier, userSubscription, dataHandler);
return new SocketSubscription(id, null, identifier, userSubscription, authenticated, dataHandler);
}
/// <summary>