diff --git a/CryptoExchange.Net/Clients/SocketApiClient.cs b/CryptoExchange.Net/Clients/SocketApiClient.cs index a57bc59..7ca9391 100644 --- a/CryptoExchange.Net/Clients/SocketApiClient.cs +++ b/CryptoExchange.Net/Clients/SocketApiClient.cs @@ -282,10 +282,11 @@ namespace CryptoExchange.Net.Clients /// Expected result type /// The type returned to the caller /// The query + /// Cancellation token /// - protected virtual Task> QueryAsync(Query query) + protected virtual Task> QueryAsync(Query query, CancellationToken ct = default) { - return QueryAsync(BaseAddress, query); + return QueryAsync(BaseAddress, query, ct); } /// @@ -295,12 +296,16 @@ namespace CryptoExchange.Net.Clients /// The type returned to the caller /// The url for the request /// The query + /// Cancellation token /// - protected virtual async Task> QueryAsync(string url, Query query) + protected virtual async Task> QueryAsync(string url, Query query, CancellationToken ct = default) { if (_disposing) return new CallResult(new InvalidOperationError("Client disposed, can't query")); + if (ct.IsCancellationRequested) + return new CallResult(new CancellationRequestedError()); + SocketConnection socketConnection; var released = false; await semaphoreSlim.WaitAsync().ConfigureAwait(false); @@ -335,7 +340,10 @@ namespace CryptoExchange.Net.Clients return new CallResult(new ServerError("Socket is paused")); } - return await socketConnection.SendAndWaitQueryAsync(query).ConfigureAwait(false); + if (ct.IsCancellationRequested) + return new CallResult(new CancellationRequestedError()); + + return await socketConnection.SendAndWaitQueryAsync(query, null, ct).ConfigureAwait(false); } /// diff --git a/CryptoExchange.Net/Objects/AsyncAutoResetEvent.cs b/CryptoExchange.Net/Objects/AsyncAutoResetEvent.cs index 4cbffd0..9c6f77e 100644 --- a/CryptoExchange.Net/Objects/AsyncAutoResetEvent.cs +++ b/CryptoExchange.Net/Objects/AsyncAutoResetEvent.cs @@ -32,7 +32,7 @@ namespace CryptoExchange.Net.Objects /// Wait for the AutoResetEvent to be set /// /// - public Task WaitAsync(TimeSpan? timeout = null) + public Task WaitAsync(TimeSpan? timeout = null, CancellationToken ct = default) { lock (_waits) { @@ -44,22 +44,29 @@ namespace CryptoExchange.Net.Objects } else { - var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - if(timeout != null) - { - var cancellationSource = new CancellationTokenSource(timeout.Value); - var registration = cancellationSource.Token.Register(() => - { - lock (_waits) - { - tcs.TrySetResult(false); + if (ct.IsCancellationRequested) + return _completed; - // Not the cleanest but it works - _waits = new Queue>(_waits.Where(i => i != tcs)); - } - }, useSynchronizationContext: false); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + if (timeout.HasValue) + { + var timeoutSource = new CancellationTokenSource(timeout.Value); + var cancellationSource = CancellationTokenSource.CreateLinkedTokenSource(timeoutSource.Token, ct); + ct = cancellationSource.Token; } + var registration = ct.Register(() => + { + lock (_waits) + { + tcs.TrySetResult(false); + + // Not the cleanest but it works + _waits = new Queue>(_waits.Where(i => i != tcs)); + } + }, useSynchronizationContext: false); + + _waits.Enqueue(tcs); return tcs.Task; } diff --git a/CryptoExchange.Net/Sockets/Query.cs b/CryptoExchange.Net/Sockets/Query.cs index d660b2a..a45cd3c 100644 --- a/CryptoExchange.Net/Sockets/Query.cs +++ b/CryptoExchange.Net/Sockets/Query.cs @@ -111,8 +111,9 @@ namespace CryptoExchange.Net.Sockets /// Wait untill timeout or the request is competed /// /// + /// Cancellation token /// - public async Task WaitAsync(TimeSpan timeout) => await _event.WaitAsync(timeout).ConfigureAwait(false); + public async Task WaitAsync(TimeSpan timeout, CancellationToken ct) => await _event.WaitAsync(timeout, ct).ConfigureAwait(false); /// public virtual CallResult Deserialize(IMessageAccessor message, Type type) => message.Deserialize(type); diff --git a/CryptoExchange.Net/Sockets/SocketConnection.cs b/CryptoExchange.Net/Sockets/SocketConnection.cs index 8238964..912909b 100644 --- a/CryptoExchange.Net/Sockets/SocketConnection.cs +++ b/CryptoExchange.Net/Sockets/SocketConnection.cs @@ -690,10 +690,11 @@ namespace CryptoExchange.Net.Sockets /// /// Query to send /// Wait event for when the socket message handler can continue + /// Cancellation token /// - public virtual async Task SendAndWaitQueryAsync(Query query, ManualResetEvent? continueEvent = null) + public virtual async Task SendAndWaitQueryAsync(Query query, ManualResetEvent? continueEvent = null, CancellationToken ct = default) { - await SendAndWaitIntAsync(query, continueEvent).ConfigureAwait(false); + await SendAndWaitIntAsync(query, continueEvent, ct).ConfigureAwait(false); return query.Result ?? new CallResult(new ServerError("Timeout")); } @@ -704,14 +705,15 @@ namespace CryptoExchange.Net.Sockets /// The type returned to the caller /// Query to send /// Wait event for when the socket message handler can continue + /// Cancellation token /// - public virtual async Task> SendAndWaitQueryAsync(Query query, ManualResetEvent? continueEvent = null) + public virtual async Task> SendAndWaitQueryAsync(Query query, ManualResetEvent? continueEvent = null, CancellationToken ct = default) { - await SendAndWaitIntAsync(query, continueEvent).ConfigureAwait(false); + await SendAndWaitIntAsync(query, continueEvent, ct).ConfigureAwait(false); return query.TypedResult ?? new CallResult(new ServerError("Timeout")); } - private async Task SendAndWaitIntAsync(Query query, ManualResetEvent? continueEvent) + private async Task SendAndWaitIntAsync(Query query, ManualResetEvent? continueEvent, CancellationToken ct = default) { lock(_listenersLock) _listeners.Add(query); @@ -728,7 +730,7 @@ namespace CryptoExchange.Net.Sockets try { - while (true) + while (!ct.IsCancellationRequested) { if (!_socket.IsOpen) { @@ -739,11 +741,17 @@ namespace CryptoExchange.Net.Sockets if (query.Completed) return; - await query.WaitAsync(TimeSpan.FromMilliseconds(500)).ConfigureAwait(false); + await query.WaitAsync(TimeSpan.FromMilliseconds(500), ct).ConfigureAwait(false); if (query.Completed) return; } + + if (ct.IsCancellationRequested) + { + query.Fail(new CancellationRequestedError()); + return; + } } finally {