1
0
mirror of https://github.com/JKorf/CryptoExchange.Net synced 2026-04-07 02:01:12 +00:00

Updated SocketConnection internal locking to fix potential deadlock

This commit is contained in:
Jkorf 2026-02-25 09:26:12 +01:00
parent 89a73747b0
commit 78e3523a4f

View File

@ -8,7 +8,9 @@ using CryptoExchange.Net.Sockets.Default.Interfaces;
using CryptoExchange.Net.Sockets.Interfaces; using CryptoExchange.Net.Sockets.Interfaces;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Diagnostics; using System.Diagnostics;
using System.Linq; using System.Linq;
using System.Net.WebSockets; using System.Net.WebSockets;
@ -122,17 +124,9 @@ namespace CryptoExchange.Net.Sockets.Default
public int UserSubscriptionCount public int UserSubscriptionCount
{ {
get get
{
_listenersLock.EnterReadLock();
try
{ {
return _listeners.OfType<Subscription>().Count(h => h.UserSubscription); return _listeners.OfType<Subscription>().Count(h => h.UserSubscription);
} }
finally
{
_listenersLock.ExitReadLock();
}
}
} }
/// <summary> /// <summary>
@ -141,17 +135,8 @@ namespace CryptoExchange.Net.Sockets.Default
public Subscription[] Subscriptions public Subscription[] Subscriptions
{ {
get get
{
_listenersLock.EnterReadLock();
try
{ {
return _listeners.OfType<Subscription>().Where(h => h.UserSubscription).ToArray(); return _listeners.OfType<Subscription>().Where(h => h.UserSubscription).ToArray();
}
finally
{
_listenersLock.ExitReadLock();
}
} }
} }
@ -254,17 +239,9 @@ namespace CryptoExchange.Net.Sockets.Default
public string[] Topics public string[] Topics
{ {
get get
{
_listenersLock.EnterReadLock();
try
{ {
return _listeners.OfType<Subscription>().Select(x => x.Topic).Where(t => t != null).ToArray()!; return _listeners.OfType<Subscription>().Select(x => x.Topic).Where(t => t != null).ToArray()!;
} }
finally
{
_listenersLock.ExitReadLock();
}
}
} }
/// <summary> /// <summary>
@ -273,23 +250,19 @@ namespace CryptoExchange.Net.Sockets.Default
public int PendingRequests public int PendingRequests
{ {
get get
{
_listenersLock.EnterReadLock();
try
{ {
return _listeners.OfType<Query>().Where(x => !x.Completed).Count(); return _listeners.OfType<Query>().Where(x => !x.Completed).Count();
} }
finally
{
_listenersLock.ExitReadLock();
}
}
} }
private bool _pausedActivity; private bool _pausedActivity;
private readonly ReaderWriterLockSlim _listenersLock = new ReaderWriterLockSlim(); #if NET9_0_OR_GREATER
private readonly List<IMessageProcessor> _listeners; private readonly Lock _listenersLock = new Lock();
#else
private readonly object _listenersLock = new object();
#endif
private ReadOnlyCollection<IMessageProcessor> _listeners;
private readonly ILogger _logger; private readonly ILogger _logger;
private SocketStatus _status; private SocketStatus _status;
@ -338,7 +311,7 @@ namespace CryptoExchange.Net.Sockets.Default
_socket.OnError += HandleErrorAsync; _socket.OnError += HandleErrorAsync;
_socket.GetReconnectionUrl = GetReconnectionUrlAsync; _socket.GetReconnectionUrl = GetReconnectionUrlAsync;
_listeners = new List<IMessageProcessor>(); _listeners = new ReadOnlyCollection<IMessageProcessor>([]);
_serializer = apiClient.CreateSerializer(); _serializer = apiClient.CreateSerializer();
} }
@ -365,25 +338,17 @@ namespace CryptoExchange.Net.Sockets.Default
if (ApiClient._socketConnections.ContainsKey(SocketId)) if (ApiClient._socketConnections.ContainsKey(SocketId))
ApiClient._socketConnections.TryRemove(SocketId, out _); ApiClient._socketConnections.TryRemove(SocketId, out _);
_listenersLock.EnterWriteLock();
try
{
foreach (var subscription in _listeners.OfType<Subscription>().Where(l => l.UserSubscription && !l.IsClosingConnection)) foreach (var subscription in _listeners.OfType<Subscription>().Where(l => l.UserSubscription && !l.IsClosingConnection))
{ {
subscription.IsClosingConnection = true; subscription.IsClosingConnection = true;
subscription.Reset(); subscription.Reset();
} }
foreach (var query in _listeners.OfType<Query>().ToList()) var queryList = _listeners.OfType<Query>().ToList();
{ foreach (var query in queryList)
query.Fail(new WebError("Connection interrupted")); query.Fail(new WebError("Connection interrupted"));
_listeners.Remove(query);
} RemoveMessageProcessors(queryList);
}
finally
{
_listenersLock.ExitWriteLock();
}
_ = Task.Run(() => ConnectionClosed?.Invoke()); _ = Task.Run(() => ConnectionClosed?.Invoke());
return Task.CompletedTask; return Task.CompletedTask;
@ -399,22 +364,14 @@ namespace CryptoExchange.Net.Sockets.Default
Authenticated = false; Authenticated = false;
_lastSequenceNumber = 0; _lastSequenceNumber = 0;
_listenersLock.EnterWriteLock();
try
{
foreach (var subscription in _listeners.OfType<Subscription>().Where(l => l.UserSubscription)) foreach (var subscription in _listeners.OfType<Subscription>().Where(l => l.UserSubscription))
subscription.Reset(); subscription.Reset();
foreach (var query in _listeners.OfType<Query>().ToList()) var queryList = _listeners.OfType<Query>().ToList();
{ foreach (var query in queryList)
query.Fail(new WebError("Connection interrupted")); query.Fail(new WebError("Connection interrupted"));
_listeners.Remove(query);
} RemoveMessageProcessors(queryList);
}
finally
{
_listenersLock.ExitWriteLock();
}
_ = Task.Run(() => ConnectionLost?.Invoke()); _ = Task.Run(() => ConnectionLost?.Invoke());
return Task.CompletedTask; return Task.CompletedTask;
@ -436,19 +393,11 @@ namespace CryptoExchange.Net.Sockets.Default
{ {
Status = SocketStatus.Resubscribing; Status = SocketStatus.Resubscribing;
_listenersLock.EnterWriteLock(); var queryList = _listeners.OfType<Query>().ToList();
try foreach (var query in queryList)
{
foreach (var query in _listeners.OfType<Query>().ToList())
{
query.Fail(new WebError("Connection interrupted")); query.Fail(new WebError("Connection interrupted"));
_listeners.Remove(query);
} RemoveMessageProcessors(queryList);
}
finally
{
_listenersLock.ExitWriteLock();
}
// Can't wait for this as it would cause a deadlock // Can't wait for this as it would cause a deadlock
_ = Task.Run(async () => _ = Task.Run(async () =>
@ -503,17 +452,7 @@ namespace CryptoExchange.Net.Sockets.Default
/// <returns></returns> /// <returns></returns>
protected virtual Task HandleRequestRateLimitedAsync(int requestId) protected virtual Task HandleRequestRateLimitedAsync(int requestId)
{ {
Query? query; var query = _listeners.OfType<Query>().FirstOrDefault(x => x.Id == requestId);
_listenersLock.EnterReadLock();
try
{
query = _listeners.OfType<Query>().FirstOrDefault(x => x.Id == requestId);
}
finally
{
_listenersLock.ExitReadLock();
}
if (query == null) if (query == null)
return Task.CompletedTask; return Task.CompletedTask;
@ -537,17 +476,7 @@ namespace CryptoExchange.Net.Sockets.Default
/// <param name="requestId">Id of the request sent</param> /// <param name="requestId">Id of the request sent</param>
protected virtual Task HandleRequestSentAsync(int requestId) protected virtual Task HandleRequestSentAsync(int requestId)
{ {
Query? query; var query = _listeners.OfType<Query>().FirstOrDefault(x => x.Id == requestId);
_listenersLock.EnterReadLock();
try
{
query = _listeners.OfType<Query>().FirstOrDefault(x => x.Id == requestId);
}
finally
{
_listenersLock.ExitReadLock();
}
if (query == null) if (query == null)
return Task.CompletedTask; return Task.CompletedTask;
@ -593,9 +522,6 @@ namespace CryptoExchange.Net.Sockets.Default
} }
Type? deserializationType = null; Type? deserializationType = null;
_listenersLock.EnterReadLock();
try
{
foreach (var subscription in _listeners) foreach (var subscription in _listeners)
{ {
foreach (var route in subscription.MessageRouter.Routes) foreach (var route in subscription.MessageRouter.Routes)
@ -610,11 +536,6 @@ namespace CryptoExchange.Net.Sockets.Default
if (deserializationType != null) if (deserializationType != null)
break; break;
} }
}
finally
{
_listenersLock.ExitReadLock();
}
if (deserializationType == null) if (deserializationType == null)
{ {
@ -660,20 +581,8 @@ namespace CryptoExchange.Net.Sockets.Default
var topicFilter = messageConverter.GetTopicFilter(result); var topicFilter = messageConverter.GetTopicFilter(result);
bool processed = false; bool processed = false;
_listenersLock.EnterReadLock(); foreach (var processor in _listeners)
try
{ {
var currentCount = _listeners.Count;
for(var i = 0; i < _listeners.Count; i++)
{
if (_listeners.Count != currentCount)
{
// Possible a query added or removed. If added it's not a problem, if removed it is
if (_listeners.Count < currentCount)
throw new Exception("Listeners list adjusted, can't continue processing");
}
var processor = _listeners[i];
bool isQuery = false; bool isQuery = false;
Query? query = null; Query? query = null;
if (processor is Query cquery) if (processor is Query cquery)
@ -725,18 +634,10 @@ namespace CryptoExchange.Net.Sockets.Default
if (complete) if (complete)
break; break;
} }
}
finally
{
_listenersLock.ExitReadLock();
}
if (!processed) if (!processed)
{ {
if (!ApiClient.HandleUnhandledMessage(this, typeIdentifier, data)) if (!ApiClient.HandleUnhandledMessage(this, typeIdentifier, data))
{
_listenersLock.EnterReadLock();
try
{ {
_logger.ReceivedMessageNotMatchedToAnyListener( _logger.ReceivedMessageNotMatchedToAnyListener(
SocketId, SocketId,
@ -744,11 +645,6 @@ namespace CryptoExchange.Net.Sockets.Default
topicFilter!, topicFilter!,
string.Join(",", _listeners.Select(x => string.Join(",", x.MessageRouter.Routes.Where(x => x.TypeIdentifier == typeIdentifier).Select(x => x.TopicFilter != null ? string.Join(",", x.TopicFilter) : "[null]"))))); string.Join(",", _listeners.Select(x => string.Join(",", x.MessageRouter.Routes.Where(x => x.TypeIdentifier == typeIdentifier).Select(x => x.TopicFilter != null ? string.Join(",", x.TopicFilter) : "[null]")))));
} }
finally
{
_listenersLock.ExitReadLock();
}
}
} }
} }
@ -792,19 +688,11 @@ namespace CryptoExchange.Net.Sockets.Default
if (ApiClient._socketConnections.ContainsKey(SocketId)) if (ApiClient._socketConnections.ContainsKey(SocketId))
ApiClient._socketConnections.TryRemove(SocketId, out _); ApiClient._socketConnections.TryRemove(SocketId, out _);
_listenersLock.EnterReadLock();
try
{
foreach (var subscription in _listeners.OfType<Subscription>()) foreach (var subscription in _listeners.OfType<Subscription>())
{ {
if (subscription.CancellationTokenRegistration.HasValue) if (subscription.CancellationTokenRegistration.HasValue)
subscription.CancellationTokenRegistration.Value.Dispose(); subscription.CancellationTokenRegistration.Value.Dispose();
} }
}
finally
{
_listenersLock.ExitReadLock();
}
await _socket.CloseAsync().ConfigureAwait(false); await _socket.CloseAsync().ConfigureAwait(false);
_socket.Dispose(); _socket.Dispose();
@ -833,32 +721,12 @@ namespace CryptoExchange.Net.Sockets.Default
if (subscription.CancellationTokenRegistration.HasValue) if (subscription.CancellationTokenRegistration.HasValue)
subscription.CancellationTokenRegistration.Value.Dispose(); subscription.CancellationTokenRegistration.Value.Dispose();
bool anyDuplicateSubscription; bool anyDuplicateSubscription = _listeners.OfType<Subscription>().Any(x => x != subscription && x.MessageRouter.Routes.All(l => subscription.MessageRouter.ContainsCheck(l)));
bool shouldCloseConnection; bool shouldCloseConnection = _listeners.OfType<Subscription>().All(r => !r.UserSubscription || r.Status == SubscriptionStatus.Closing || r.Status == SubscriptionStatus.Closed) && !DedicatedRequestConnection.IsDedicatedRequestConnection;
_listenersLock.EnterReadLock();
try
{
anyDuplicateSubscription = _listeners.OfType<Subscription>().Any(x => x != subscription && x.MessageRouter.Routes.All(l => subscription.MessageRouter.ContainsCheck(l)));
shouldCloseConnection = _listeners.OfType<Subscription>().All(r => !r.UserSubscription || r.Status == SubscriptionStatus.Closing || r.Status == SubscriptionStatus.Closed) && !DedicatedRequestConnection.IsDedicatedRequestConnection;
}
finally
{
_listenersLock.ExitReadLock();
}
if (!anyDuplicateSubscription) if (!anyDuplicateSubscription)
{ {
bool needUnsub; var needUnsub = _listeners.Contains(subscription) && !shouldCloseConnection;
_listenersLock.EnterReadLock();
try
{
needUnsub = _listeners.Contains(subscription) && !shouldCloseConnection;
}
finally
{
_listenersLock.ExitReadLock();
}
if (needUnsub && _socket.IsOpen) if (needUnsub && _socket.IsOpen)
await UnsubscribeAsync(subscription).ConfigureAwait(false); await UnsubscribeAsync(subscription).ConfigureAwait(false);
} }
@ -882,15 +750,7 @@ namespace CryptoExchange.Net.Sockets.Default
await CloseAsync().ConfigureAwait(false); await CloseAsync().ConfigureAwait(false);
} }
_listenersLock.EnterWriteLock(); RemoveMessageProcessor(subscription);
try
{
_listeners.Remove(subscription);
}
finally
{
_listenersLock.ExitWriteLock();
}
subscription.Status = SubscriptionStatus.Closed; subscription.Status = SubscriptionStatus.Closed;
} }
@ -914,15 +774,7 @@ namespace CryptoExchange.Net.Sockets.Default
if (Status != SocketStatus.None && Status != SocketStatus.Connected) if (Status != SocketStatus.None && Status != SocketStatus.Connected)
return false; return false;
_listenersLock.EnterWriteLock(); AddMessageProcessor(subscription);
try
{
_listeners.Add(subscription);
}
finally
{
_listenersLock.ExitWriteLock();
}
if (subscription.UserSubscription) if (subscription.UserSubscription)
_logger.AddingNewSubscription(SocketId, subscription.Id, UserSubscriptionCount); _logger.AddingNewSubscription(SocketId, subscription.Id, UserSubscriptionCount);
@ -934,17 +786,9 @@ namespace CryptoExchange.Net.Sockets.Default
/// </summary> /// </summary>
/// <param name="id"></param> /// <param name="id"></param>
public Subscription? GetSubscription(int id) public Subscription? GetSubscription(int id)
{
_listenersLock.EnterReadLock();
try
{ {
return _listeners.OfType<Subscription>().SingleOrDefault(s => s.Id == id); return _listeners.OfType<Subscription>().SingleOrDefault(s => s.Id == id);
} }
finally
{
_listenersLock.ExitReadLock();
}
}
/// <summary> /// <summary>
/// Get the state of the connection /// Get the state of the connection
@ -991,29 +835,12 @@ namespace CryptoExchange.Net.Sockets.Default
private async Task SendAndWaitIntAsync(Query query, CancellationToken ct = default) private async Task SendAndWaitIntAsync(Query query, CancellationToken ct = default)
{ {
_listenersLock.EnterWriteLock(); AddMessageProcessor(query);
try
{
_listeners.Add(query);
}
finally
{
_listenersLock.ExitWriteLock();
}
var sendResult = await SendAsync(query.Id, query.Request, query.Weight).ConfigureAwait(false); var sendResult = await SendAsync(query.Id, query.Request, query.Weight).ConfigureAwait(false);
if (!sendResult) if (!sendResult)
{ {
query.Fail(sendResult.Error!); query.Fail(sendResult.Error!);
_listenersLock.EnterWriteLock(); RemoveMessageProcessor(query);
try
{
_listeners.Remove(query);
}
finally
{
_listenersLock.ExitWriteLock();
}
return; return;
} }
@ -1044,15 +871,7 @@ namespace CryptoExchange.Net.Sockets.Default
} }
finally finally
{ {
_listenersLock.EnterWriteLock(); RemoveMessageProcessor(query);
try
{
_listeners.Remove(query);
}
finally
{
_listenersLock.ExitWriteLock();
}
} }
} }
@ -1158,17 +977,7 @@ namespace CryptoExchange.Net.Sockets.Default
if (!DedicatedRequestConnection.IsDedicatedRequestConnection) if (!DedicatedRequestConnection.IsDedicatedRequestConnection)
{ {
bool anySubscriptions; var anySubscriptions = _listeners.OfType<Subscription>().Any(s => s.UserSubscription);
_listenersLock.EnterReadLock();
try
{
anySubscriptions = _listeners.OfType<Subscription>().Any(s => s.UserSubscription);
}
finally
{
_listenersLock.ExitReadLock();
}
if (!anySubscriptions) if (!anySubscriptions)
{ {
// No need to resubscribe anything // No need to resubscribe anything
@ -1178,18 +987,8 @@ namespace CryptoExchange.Net.Sockets.Default
} }
} }
bool anyAuthenticated; bool anyAuthenticated = _listeners.OfType<Subscription>().Any(s => s.Authenticated)
_listenersLock.EnterReadLock();
try
{
anyAuthenticated = _listeners.OfType<Subscription>().Any(s => s.Authenticated)
|| DedicatedRequestConnection.IsDedicatedRequestConnection && DedicatedRequestConnection.Authenticated; || DedicatedRequestConnection.IsDedicatedRequestConnection && DedicatedRequestConnection.Authenticated;
}
finally
{
_listenersLock.ExitReadLock();
}
if (anyAuthenticated) if (anyAuthenticated)
{ {
// If we reconnected a authenticated connection we need to re-authenticate // If we reconnected a authenticated connection we need to re-authenticate
@ -1212,17 +1011,7 @@ namespace CryptoExchange.Net.Sockets.Default
if (!_socket.IsOpen) if (!_socket.IsOpen)
return new CallResult(new WebError("Socket not connected")); return new CallResult(new WebError("Socket not connected"));
List<Subscription> subList; var subList = _listeners.OfType<Subscription>().Where(x => x.Active).Skip(batch * batchSize).Take(batchSize).ToList();
_listenersLock.EnterReadLock();
try
{
subList = _listeners.OfType<Subscription>().Where(x => x.Active).Skip(batch * batchSize).Take(batchSize).ToList();
}
finally
{
_listenersLock.ExitReadLock();
}
if (subList.Count == 0) if (subList.Count == 0)
break; break;
@ -1404,6 +1193,37 @@ namespace CryptoExchange.Net.Sockets.Default
}); });
} }
private void AddMessageProcessor(IMessageProcessor processor)
{
lock (_listenersLock)
{
var updatedList = new List<IMessageProcessor>(_listeners);
updatedList.Add(processor);
_listeners = updatedList.AsReadOnly();
}
}
private void RemoveMessageProcessor(IMessageProcessor processor)
{
lock (_listenersLock)
{
var updatedList = new List<IMessageProcessor>(_listeners);
updatedList.Remove(processor);
_listeners = updatedList.AsReadOnly();
}
}
private void RemoveMessageProcessors(IEnumerable<IMessageProcessor> processors)
{
lock (_listenersLock)
{
var updatedList = new List<IMessageProcessor>(_listeners);
foreach (var processor in processors)
updatedList.Remove(processor);
_listeners = updatedList.AsReadOnly();
}
}
} }
} }