diff --git a/CryptoExchange.Net.UnitTests/SocketClientTests.cs b/CryptoExchange.Net.UnitTests/SocketClientTests.cs index 37ae476..912a4c3 100644 --- a/CryptoExchange.Net.UnitTests/SocketClientTests.cs +++ b/CryptoExchange.Net.UnitTests/SocketClientTests.cs @@ -164,6 +164,7 @@ namespace CryptoExchange.Net.UnitTests { reconnected = true; rstEvent.Set(); + return true; }; // act @@ -229,5 +230,34 @@ namespace CryptoExchange.Net.UnitTests // assert Assert.IsFalse(connectResult.Success); } + + [Test] + public void WhenResubscribeFails_Socket_ShouldReconnect() + { + // arrange + int reconnected = 0; + var client = new TestSocketClient(new SocketClientOptions() { ReconnectInterval = TimeSpan.FromMilliseconds(1), LogVerbosity = LogVerbosity.Debug }); + var socket = client.CreateSocket(); + socket.ShouldReconnect = true; + socket.CanConnect = true; + socket.DisconnectTime = DateTime.UtcNow; + var sub = new SocketSubscription(socket); + client.ConnectSocketSub(sub); + var rstEvent = new ManualResetEvent(false); + client.OnReconnect += () => + { + reconnected++; + rstEvent.Set(); + return reconnected == 2; + }; + + // act + socket.InvokeClose(); + rstEvent.WaitOne(1000); + Thread.Sleep(100); + + // assert + Assert.IsTrue(reconnected == 2); + } } } diff --git a/CryptoExchange.Net.UnitTests/TestImplementations/TestSocket.cs b/CryptoExchange.Net.UnitTests/TestImplementations/TestSocket.cs index c911453..31b4ff0 100644 --- a/CryptoExchange.Net.UnitTests/TestImplementations/TestSocket.cs +++ b/CryptoExchange.Net.UnitTests/TestImplementations/TestSocket.cs @@ -20,6 +20,7 @@ namespace CryptoExchange.Net.UnitTests.TestImplementations public int Id { get; } public bool ShouldReconnect { get; set; } + public TimeSpan Timeout { get; set; } public Func DataInterpreter { get; set; } public DateTime? DisconnectTime { get; set; } public string Url { get; } @@ -30,9 +31,24 @@ namespace CryptoExchange.Net.UnitTests.TestImplementations public TimeSpan PingInterval { get; set; } public SslProtocols SSLProtocols { get; set; } + public int ConnectCalls { get; private set; } + + public static int lastId = 0; + public static object lastIdLock = new object(); + + public TestSocket() + { + lock (lastIdLock) + { + Id = lastId + 1; + lastId++; + } + } + public Task Connect() { Connected = CanConnect; + ConnectCalls++; return Task.FromResult(CanConnect); } @@ -45,6 +61,8 @@ namespace CryptoExchange.Net.UnitTests.TestImplementations public Task Close() { Connected = false; + DisconnectTime = DateTime.UtcNow; + OnClose?.Invoke(); return Task.FromResult(0); } @@ -59,6 +77,7 @@ namespace CryptoExchange.Net.UnitTests.TestImplementations public void InvokeClose() { Connected = false; + DisconnectTime = DateTime.UtcNow; OnClose?.Invoke(); } diff --git a/CryptoExchange.Net.UnitTests/TestImplementations/TestSocketClient.cs b/CryptoExchange.Net.UnitTests/TestImplementations/TestSocketClient.cs index 45ce0e3..07d490c 100644 --- a/CryptoExchange.Net.UnitTests/TestImplementations/TestSocketClient.cs +++ b/CryptoExchange.Net.UnitTests/TestImplementations/TestSocketClient.cs @@ -9,7 +9,7 @@ namespace CryptoExchange.Net.UnitTests.TestImplementations { public class TestSocketClient: SocketClient { - public Action OnReconnect { get; set; } + public Func OnReconnect { get; set; } public TestSocketClient() : this(new SocketClientOptions()) { @@ -23,6 +23,7 @@ namespace CryptoExchange.Net.UnitTests.TestImplementations public TestSocket CreateSocket() { + Mock.Get(SocketFactory).Setup(f => f.CreateWebsocket(It.IsAny(), It.IsAny())).Returns(new TestSocket()); return (TestSocket)CreateSocket(BaseAddress); } @@ -33,8 +34,7 @@ namespace CryptoExchange.Net.UnitTests.TestImplementations protected override bool SocketReconnect(SocketSubscription subscription, TimeSpan disconnectedTime) { - OnReconnect?.Invoke(); - return true; + return OnReconnect.Invoke(); } } } diff --git a/CryptoExchange.Net/SocketClient.cs b/CryptoExchange.Net/SocketClient.cs index f41ad2f..358fb00 100644 --- a/CryptoExchange.Net/SocketClient.cs +++ b/CryptoExchange.Net/SocketClient.cs @@ -76,8 +76,11 @@ namespace CryptoExchange.Net socket.DataInterpreter = dataInterpreter; socket.OnClose += () => { - foreach (var sub in sockets) - sub.ResetEvents(); + lock (sockets) + { + foreach (var sub in sockets) + sub.ResetEvents(); + } SocketOnClose(socket); }; @@ -93,7 +96,8 @@ namespace CryptoExchange.Net protected virtual SocketSubscription GetBackgroundSocket(bool authenticated = false) { - return sockets.SingleOrDefault(s => s.Type == (authenticated ? SocketType.BackgroundAuthenticated : SocketType.Background)); + lock (sockets) + return sockets.SingleOrDefault(s => s.Type == (authenticated ? SocketType.BackgroundAuthenticated : SocketType.Background)); } protected virtual void SocketOpened(IWebsocket socket) { } @@ -197,7 +201,7 @@ namespace CryptoExchange.Net socket.Dispose(); lock (sockets) { - var subscription = sockets.SingleOrDefault(s => s.Socket == socket); + var subscription = sockets.SingleOrDefault(s => s.Socket.Id == socket.Id); if(subscription != null) sockets.Remove(subscription); } @@ -248,8 +252,12 @@ namespace CryptoExchange.Net await Task.Run(() => { var tasks = new List(); - foreach (var sub in new List(sockets)) - tasks.Add(sub.Close()); + lock (sockets) + { + foreach (var sub in new List(sockets)) + tasks.Add(sub.Close()); + } + Task.WaitAll(tasks.ToArray()); }); } @@ -257,8 +265,7 @@ namespace CryptoExchange.Net public override void Dispose() { log.Write(LogVerbosity.Debug, "Disposing socket client, closing all subscriptions"); - lock (sockets) - UnsubscribeAll().Wait(); + UnsubscribeAll().Wait(); base.Dispose(); }