diff --git a/CryptoExchange.Net/Objects/CallResult.cs b/CryptoExchange.Net/Objects/CallResult.cs
index a4a02e1..6837fcb 100644
--- a/CryptoExchange.Net/Objects/CallResult.cs
+++ b/CryptoExchange.Net/Objects/CallResult.cs
@@ -154,6 +154,9 @@ namespace CryptoExchange.Net.Objects
///
public CallResult AsDataless()
{
+ if (Error != null )
+ return new CallResult(Error);
+
return SuccessResult;
}
diff --git a/CryptoExchange.Net/RateLimiting/Guards/RateLimitGuard.cs b/CryptoExchange.Net/RateLimiting/Guards/RateLimitGuard.cs
index 6d5b971..5fd5d38 100644
--- a/CryptoExchange.Net/RateLimiting/Guards/RateLimitGuard.cs
+++ b/CryptoExchange.Net/RateLimiting/Guards/RateLimitGuard.cs
@@ -3,6 +3,7 @@ using CryptoExchange.Net.RateLimiting.Interfaces;
using CryptoExchange.Net.RateLimiting.Trackers;
using System;
using System.Collections.Generic;
+using System.Threading;
namespace CryptoExchange.Net.RateLimiting.Guards
{
@@ -36,6 +37,7 @@ namespace CryptoExchange.Net.RateLimiting.Guards
private readonly double? _decayRate;
private readonly int? _connectionWeight;
private readonly Func _keySelector;
+ private readonly SemaphoreSlim? _sharedGuardSemaphore;
///
public string Name => "RateLimitGuard";
@@ -52,6 +54,11 @@ namespace CryptoExchange.Net.RateLimiting.Guards
///
public TimeSpan TimeSpan { get; }
+ ///
+ /// Whether this guard is shared between multiple gates
+ ///
+ public bool SharedGuard { get; }
+
///
/// ctor
///
@@ -62,8 +69,9 @@ namespace CryptoExchange.Net.RateLimiting.Guards
/// Type of rate limit window
/// The decay per timespan if windowType is DecayWindowTracker
/// The weight of a new connection
- public RateLimitGuard(Func keySelector, IGuardFilter filter, int limit, TimeSpan timeSpan, RateLimitWindowType windowType, double? decayPerTimeSpan = null, int? connectionWeight = null)
- : this(keySelector, new[] { filter }, limit, timeSpan, windowType, decayPerTimeSpan, connectionWeight)
+ /// Whether this guard is shared between multiple gates
+ public RateLimitGuard(Func keySelector, IGuardFilter filter, int limit, TimeSpan timeSpan, RateLimitWindowType windowType, double? decayPerTimeSpan = null, int? connectionWeight = null, bool shared = false)
+ : this(keySelector, new[] { filter }, limit, timeSpan, windowType, decayPerTimeSpan, connectionWeight, shared)
{
}
@@ -77,22 +85,27 @@ namespace CryptoExchange.Net.RateLimiting.Guards
/// Type of rate limit window
/// The decay per timespan if windowType is DecayWindowTracker
/// The weight of a new connection
- public RateLimitGuard(Func keySelector, IEnumerable filters, int limit, TimeSpan timeSpan, RateLimitWindowType windowType, double? decayPerTimeSpan = null, int? connectionWeight = null)
+ /// Whether this guard is shared between multiple gates
+ public RateLimitGuard(Func keySelector, IEnumerable filters, int limit, TimeSpan timeSpan, RateLimitWindowType windowType, double? decayPerTimeSpan = null, int? connectionWeight = null, bool shared = false)
{
_filters = filters;
_trackers = new Dictionary();
_windowType = windowType;
Limit = limit;
TimeSpan = timeSpan;
+ SharedGuard = shared;
_keySelector = keySelector;
_decayRate = decayPerTimeSpan;
_connectionWeight = connectionWeight;
+
+ if (SharedGuard)
+ _sharedGuardSemaphore = new SemaphoreSlim(1, 1);
}
///
public LimitCheck Check(RateLimitItemType type, RequestDefinition definition, string host, string? apiKey, int requestWeight, string? keySuffix)
{
- foreach(var filter in _filters)
+ foreach (var filter in _filters)
{
if (!filter.Passes(type, definition, host, apiKey))
return LimitCheck.NotApplicable;
@@ -101,18 +114,30 @@ namespace CryptoExchange.Net.RateLimiting.Guards
if (type == RateLimitItemType.Connection)
requestWeight = _connectionWeight ?? requestWeight;
- var key = _keySelector(definition, host, apiKey) + keySuffix;
- if (!_trackers.TryGetValue(key, out var tracker))
+ if (SharedGuard)
+ _sharedGuardSemaphore!.Wait();
+
+ try
{
- tracker = CreateTracker();
- _trackers.Add(key, tracker);
+ var key = _keySelector(definition, host, apiKey) + keySuffix;
+ if (!_trackers.TryGetValue(key, out var tracker))
+ {
+ tracker = CreateTracker();
+ _trackers.Add(key, tracker);
+ }
+
+
+ var delay = tracker.GetWaitTime(requestWeight);
+ if (delay == default)
+ return LimitCheck.NotNeeded(Limit, TimeSpan, tracker.Current);
+
+ return LimitCheck.Needed(delay, Limit, TimeSpan, tracker.Current);
+ }
+ finally
+ {
+ if (SharedGuard)
+ _sharedGuardSemaphore!.Release();
}
-
- var delay = tracker.GetWaitTime(requestWeight);
- if (delay == default)
- return LimitCheck.NotNeeded(Limit, TimeSpan, tracker.Current);
-
- return LimitCheck.Needed(delay, Limit, TimeSpan, tracker.Current);
}
///
@@ -127,9 +152,23 @@ namespace CryptoExchange.Net.RateLimiting.Guards
if (type == RateLimitItemType.Connection)
requestWeight = _connectionWeight ?? requestWeight;
+
var key = _keySelector(definition, host, apiKey) + keySuffix;
var tracker = _trackers[key];
- tracker.ApplyWeight(requestWeight);
+
+ if (SharedGuard)
+ _sharedGuardSemaphore!.Wait();
+
+ try
+ {
+ tracker.ApplyWeight(requestWeight);
+ }
+ finally
+ {
+ if (SharedGuard)
+ _sharedGuardSemaphore!.Release();
+ }
+
return RateLimitState.Applied(Limit, TimeSpan, tracker.Current);
}
diff --git a/CryptoExchange.Net/Testing/SharedRestRequestValidator.cs b/CryptoExchange.Net/Testing/SharedRestRequestValidator.cs
new file mode 100644
index 0000000..498222a
--- /dev/null
+++ b/CryptoExchange.Net/Testing/SharedRestRequestValidator.cs
@@ -0,0 +1,126 @@
+using CryptoExchange.Net.Clients;
+using CryptoExchange.Net.Objects;
+using CryptoExchange.Net.SharedApis;
+using CryptoExchange.Net.Testing.Comparers;
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Net.Http;
+using System.Text;
+using System.Threading.Tasks;
+
+namespace CryptoExchange.Net.Testing
+{
+ ///
+ /// Validator for REST requests, comparing path, http method, authentication and response parsing
+ ///
+ /// The Rest client
+ public class SharedRestRequestValidator where TClient : BaseRestClient
+ {
+ private readonly TClient _client;
+ private readonly Func _isAuthenticated;
+ private readonly string _folder;
+ private readonly string _baseAddress;
+ private readonly string? _nestedPropertyForCompare;
+
+ ///
+ /// ctor
+ ///
+ /// Client to test
+ /// Folder for json test values
+ /// The base address that is expected
+ /// Func for checking if the request is authenticated
+ /// Property to use for compare
+ public SharedRestRequestValidator(TClient client, string folder, string baseAddress, Func isAuthenticated, string? nestedPropertyForCompare = null)
+ {
+ _client = client;
+ _folder = folder;
+ _baseAddress = baseAddress;
+ _nestedPropertyForCompare = nestedPropertyForCompare;
+ _isAuthenticated = isAuthenticated;
+ }
+
+ ///
+ /// Validate a request
+ ///
+ /// Expected response type
+ /// Method invocation
+ /// Method name for looking up json test values
+ /// Request options
+ ///
+ ///
+ public Task ValidateAsync(
+ Func>> methodInvoke,
+ string name,
+ EndpointOptions endpointOptions,
+ params Func[] validation)
+ => ValidateAsync(methodInvoke, name, endpointOptions, validation);
+
+ ///
+ /// Validate a request
+ ///
+ /// Expected response type
+ /// The concrete response type
+ /// Method invocation
+ /// Method name for looking up json test values
+ /// Request options
+ ///
+ ///
+ public async Task ValidateAsync(
+ Func>> methodInvoke,
+ string name,
+ EndpointOptions endpointOptions,
+ params Func[] validation) where TActualResponse : TResponse
+ {
+ var listener = new EnumValueTraceListener();
+ Trace.Listeners.Add(listener);
+
+ var path = Directory.GetParent(Environment.CurrentDirectory)!.Parent!.Parent!.FullName;
+ FileStream file;
+ try
+ {
+ file = File.OpenRead(Path.Combine(path, _folder, $"{name}.txt"));
+ }
+ catch (FileNotFoundException)
+ {
+ throw new Exception($"Response file not found for {name}: {path}");
+ }
+
+ var buffer = new byte[file.Length];
+ await file.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false);
+ file.Close();
+
+ var data = Encoding.UTF8.GetString(buffer);
+ using var reader = new StringReader(data);
+ var expectedMethod = reader.ReadLine();
+ var expectedPath = reader.ReadLine();
+ var expectedAuth = bool.Parse(reader.ReadLine()!);
+ var response = reader.ReadToEnd();
+
+ TestHelpers.ConfigureRestClient(_client, response, System.Net.HttpStatusCode.OK);
+ var result = await methodInvoke(_client).ConfigureAwait(false);
+
+ // Check request/response properties
+ if (result.Error != null)
+ throw new Exception(name + " returned error " + result.Error);
+ if (endpointOptions.NeedsAuthentication != expectedAuth)
+ throw new Exception(name + $" authentication not matched. Expected: {expectedAuth}, Actual: {_isAuthenticated(result.AsDataless())}");
+ if (result.RequestMethod != new HttpMethod(expectedMethod!))
+ throw new Exception(name + $" http method not matched. Expected {expectedMethod}, Actual: {result.RequestMethod}");
+ if (expectedPath != result.RequestUrl!.Replace(_baseAddress, "").Split(new char[] { '?' })[0])
+ throw new Exception(name + $" path not matched. Expected: {expectedPath}, Actual: {result.RequestUrl!.Replace(_baseAddress, "").Split(new char[] { '?' })[0]}");
+
+ var index = 0;
+ foreach(var validate in validation)
+ {
+ if (!validate(result.Data!))
+ throw new Exception(name + $" response validation #{index} failed");
+
+ index++;
+ }
+
+ Trace.Listeners.Remove(listener);
+ }
+ }
+}