diff --git a/CryptoExchange.Net.UnitTests/RestClientTests.cs b/CryptoExchange.Net.UnitTests/RestClientTests.cs index aa480d6..5ab0f08 100644 --- a/CryptoExchange.Net.UnitTests/RestClientTests.cs +++ b/CryptoExchange.Net.UnitTests/RestClientTests.cs @@ -8,10 +8,11 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; using CryptoExchange.Net.Interfaces; -using CryptoExchange.Net.RateLimiter; using Microsoft.Extensions.Logging; using System.Net.Http; using System.Threading.Tasks; +using CryptoExchange.Net.Logging; +using System.Threading; namespace CryptoExchange.Net.UnitTests { @@ -108,7 +109,7 @@ namespace CryptoExchange.Net.UnitTests var client = new TestRestClient(new RestClientOptions() { BaseAddress = "http://test.address.com", - RateLimiters = new List{new RateLimiterTotal(1, TimeSpan.FromSeconds(1))}, + RateLimiters = new List{new RateLimiter()}, RateLimitingBehaviour = RateLimitingBehaviour.Fail, RequestTimeout = TimeSpan.FromMinutes(1) }); @@ -161,84 +162,199 @@ namespace CryptoExchange.Net.UnitTests Assert.IsTrue(request.GetHeaders().First().Value.Contains("123")); } - [TestCase] - public void SettingRateLimitingBehaviourToFail_Should_FailLimitedRequests() + + [TestCase(1, 0.1)] + [TestCase(2, 0.1)] + [TestCase(5, 1)] + [TestCase(1, 2)] + public async Task PartialEndpointRateLimiterBasics(int requests, double perSeconds) { - // arrange - var client = new TestRestClient(new RestClientOptions() + var log = new Log("Test"); + log.Level = LogLevel.Trace; + + var rateLimiter = new RateLimiter(); + rateLimiter.AddPartialEndpointLimit("/sapi/", requests, TimeSpan.FromSeconds(perSeconds)); + + for (var i = 0; i < requests + 1; i++) { - RateLimiters = new List { new RateLimiterTotal(1, TimeSpan.FromSeconds(1)) }, - RateLimitingBehaviour = RateLimitingBehaviour.Fail - }); - client.SetResponse("{\"property\": 123}", out _); + var result1 = await rateLimiter.LimitRequestAsync(log, "/sapi/v1/system/status", HttpMethod.Get, false, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + Assert.IsTrue(i == requests? result1.Data > 1 : result1.Data == 0); + } - - // act - var result1 = client.Request().Result; - client.SetResponse("{\"property\": 123}", out _); - var result2 = client.Request().Result; - - - // assert - Assert.IsTrue(result1.Success); - Assert.IsFalse(result2.Success); + await Task.Delay((int)Math.Round(perSeconds * 1000) + 10); + var result2 = await rateLimiter.LimitRequestAsync(log, "/sapi/v1/system/status", HttpMethod.Get, false, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + Assert.IsTrue(result2.Data == 0); } - [TestCase] - public void SettingRateLimitingBehaviourToWait_Should_DelayLimitedRequests() + [TestCase("/sapi/test1", true)] + [TestCase("/sapi/test2", true)] + [TestCase("/api/test1", false)] + [TestCase("sapi/test1", false)] + [TestCase("/sapi/", true)] + public async Task PartialEndpointRateLimiterEndpoints(string endpoint, bool expectLimiting) { - // arrange - var client = new TestRestClient(new RestClientOptions() + var log = new Log("Test"); + log.Level = LogLevel.Trace; + + var rateLimiter = new RateLimiter(); + rateLimiter.AddPartialEndpointLimit("/sapi/", 1, TimeSpan.FromSeconds(0.1)); + + for (var i = 0; i < 2; i++) { - RateLimiters = new List { new RateLimiterTotal(1, TimeSpan.FromSeconds(1)) }, - RateLimitingBehaviour = RateLimitingBehaviour.Wait - }); - client.SetResponse("{\"property\": 123}", out _); + var result1 = await rateLimiter.LimitRequestAsync(log, endpoint, HttpMethod.Get, false, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + bool expected = i == 1 ? (expectLimiting ? result1.Data > 1 : result1.Data == 0) : result1.Data == 0; + Assert.IsTrue(expected); + } + } + [TestCase("/sapi/", "/sapi/", true)] + [TestCase("/sapi/test", "/sapi/test", true)] + [TestCase("/sapi/test", "/sapi/test123", false)] + [TestCase("/sapi/test", "/sapi/", false)] + public async Task PartialEndpointRateLimiterEndpoints(string endpoint1, string endpoint2, bool expectLimiting) + { + var log = new Log("Test"); + log.Level = LogLevel.Trace; + var rateLimiter = new RateLimiter(); + rateLimiter.AddPartialEndpointLimit("/sapi/", 1, TimeSpan.FromSeconds(0.1), countPerEndpoint: true); - // act - var sw = Stopwatch.StartNew(); - var result1 = client.Request().Result; - client.SetResponse("{\"property\": 123}", out _); // reset response stream - var result2 = client.Request().Result; - sw.Stop(); - - // assert - Assert.IsTrue(result1.Success); - Assert.IsTrue(result2.Success); - Assert.IsTrue(sw.ElapsedMilliseconds > 900, $"Actual: {sw.ElapsedMilliseconds}"); + var result1 = await rateLimiter.LimitRequestAsync(log, endpoint1, HttpMethod.Get, false, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + var result2 = await rateLimiter.LimitRequestAsync(log, endpoint2, HttpMethod.Get, false, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + Assert.IsTrue(result1.Data == 0); + Assert.IsTrue(expectLimiting ? result2.Data > 0 : result2.Data == 0); } - [TestCase] - public void SettingApiKeyRateLimiter_Should_DelayRequestsFromSameKey() + [TestCase(1, 0.1)] + [TestCase(2, 0.1)] + [TestCase(5, 1)] + [TestCase(1, 2)] + public async Task EndpointRateLimiterBasics(int requests, double perSeconds) { - // arrange - var client = new TestRestClient(new RestClientOptions() + var log = new Log("Test"); + log.Level = LogLevel.Trace; + + var rateLimiter = new RateLimiter(); + rateLimiter.AddEndpointLimit("/sapi/test", requests, TimeSpan.FromSeconds(perSeconds)); + + for (var i = 0; i < requests + 1; i++) { - RateLimiters = new List { new RateLimiterAPIKey(1, TimeSpan.FromSeconds(1)) }, - RateLimitingBehaviour = RateLimitingBehaviour.Wait, - LogLevel = LogLevel.Debug, - ApiCredentials = new ApiCredentials("TestKey", "TestSecret") - }); - client.SetResponse("{\"property\": 123}", out _); + var result1 = await rateLimiter.LimitRequestAsync(log, "/sapi/test", HttpMethod.Get, false, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + Assert.IsTrue(i == requests ? result1.Data > 1 : result1.Data == 0); + } + await Task.Delay((int)Math.Round(perSeconds * 1000) + 10); + var result2 = await rateLimiter.LimitRequestAsync(log, "/sapi/test", HttpMethod.Get, false, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + Assert.IsTrue(result2.Data == 0); + } - // act - var sw = Stopwatch.StartNew(); - var result1 = client.Request().Result; - client.SetKey("TestKey2", "TestSecret2"); // set to different key - client.SetResponse("{\"property\": 123}", out _); // reset response stream - var result2 = client.Request().Result; - client.SetKey("TestKey", "TestSecret"); // set back to original key, should delay - client.SetResponse("{\"property\": 123}", out _); // reset response stream - var result3 = client.Request().Result; - sw.Stop(); + [TestCase("/", false)] + [TestCase("/sapi/test", true)] + [TestCase("/sapi/test/123", false)] + public async Task EndpointRateLimiterEndpoints(string endpoint, bool expectLimited) + { + var log = new Log("Test"); + log.Level = LogLevel.Trace; - // assert - Assert.IsTrue(result1.Success); - Assert.IsTrue(result2.Success); - Assert.IsTrue(result3.Success); - Assert.IsTrue(sw.ElapsedMilliseconds > 900 && sw.ElapsedMilliseconds < 1900, $"Actual: {sw.ElapsedMilliseconds}"); + var rateLimiter = new RateLimiter(); + rateLimiter.AddEndpointLimit("/sapi/test", 1, TimeSpan.FromSeconds(0.1)); + + for (var i = 0; i < 2; i++) + { + var result1 = await rateLimiter.LimitRequestAsync(log, endpoint, HttpMethod.Get, false, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + bool expected = i == 1 ? (expectLimited ? result1.Data > 1 : result1.Data == 0) : result1.Data == 0; + Assert.IsTrue(expected); + } + } + + [TestCase("/", false)] + [TestCase("/sapi/test", true)] + [TestCase("/sapi/test2", true)] + [TestCase("/sapi/test23", false)] + public async Task EndpointRateLimiterMultipleEndpoints(string endpoint, bool expectLimited) + { + var log = new Log("Test"); + log.Level = LogLevel.Trace; + + var rateLimiter = new RateLimiter(); + rateLimiter.AddEndpointLimit(new[] { "/sapi/test", "/sapi/test2" }, 1, TimeSpan.FromSeconds(0.1)); + + for (var i = 0; i < 2; i++) + { + var result1 = await rateLimiter.LimitRequestAsync(log, endpoint, HttpMethod.Get, false, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + bool expected = i == 1 ? (expectLimited ? result1.Data > 1 : result1.Data == 0) : result1.Data == 0; + Assert.IsTrue(expected); + } + } + + [TestCase("123", "123", "/sapi/test", "/sapi/test", true, true, true, true)] + [TestCase("123", "456", "/sapi/test", "/sapi/test", true, true, true, false)] + [TestCase("123", "123", "/sapi/test", "/sapi/test2", true, true, true, true)] + [TestCase("123", "123", "/sapi/test2", "/sapi/test", true, true, true, true)] + [TestCase("123", "123", "/sapi/test", "/sapi/test", true, false, true, false)] + [TestCase("123", "123", "/sapi/test", "/sapi/test", false, true, true, false)] + [TestCase("123", "123", "/sapi/test", "/sapi/test", false, false, true, false)] + [TestCase(null, "123", "/sapi/test", "/sapi/test", false, true, true, false)] + [TestCase("123", null, "/sapi/test", "/sapi/test", true, false, true, false)] + [TestCase(null, null, "/sapi/test", "/sapi/test", false, false, true, false)] + + [TestCase("123", "123", "/sapi/test", "/sapi/test", true, true, false, true)] + [TestCase("123", "456", "/sapi/test", "/sapi/test", true, true, false, false)] + [TestCase("123", "123", "/sapi/test", "/sapi/test2", true, true, false, true)] + [TestCase("123", "123", "/sapi/test2", "/sapi/test", true, true, false, true)] + [TestCase("123", "123", "/sapi/test", "/sapi/test", true, false, false, true)] + [TestCase("123", "123", "/sapi/test", "/sapi/test", false, true, false, true)] + [TestCase("123", "123", "/sapi/test", "/sapi/test", false, false, false, true)] + [TestCase(null, "123", "/sapi/test", "/sapi/test", false, true, false, false)] + [TestCase("123", null, "/sapi/test", "/sapi/test", true, false, false, false)] + [TestCase(null, null, "/sapi/test", "/sapi/test", false, false, false, true)] + public async Task ApiKeyRateLimiterBasics(string key1, string key2, string endpoint1, string endpoint2, bool signed1, bool signed2, bool onlyForSignedRequests, bool expectLimited) + { + var log = new Log("Test"); + log.Level = LogLevel.Trace; + + var rateLimiter = new RateLimiter(); + rateLimiter.AddApiKeyLimit(1, TimeSpan.FromSeconds(0.1), onlyForSignedRequests, false); + + var result1 = await rateLimiter.LimitRequestAsync(log, endpoint1, HttpMethod.Get, signed1, key1?.ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + var result2 = await rateLimiter.LimitRequestAsync(log, endpoint2, HttpMethod.Get, signed2, key2?.ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + Assert.IsTrue(result1.Data == 0); + Assert.IsTrue(expectLimited ? result2.Data > 0 : result2.Data == 0); + } + + [TestCase("/sapi/test", "/sapi/test", true)] + [TestCase("/sapi/test1", "/api/test2", true)] + [TestCase("/", "/sapi/test2", true)] + public async Task TotalRateLimiterBasics(string endpoint1, string endpoint2, bool expectLimited) + { + var log = new Log("Test"); + log.Level = LogLevel.Trace; + + var rateLimiter = new RateLimiter(); + rateLimiter.AddTotalRateLimit(1, TimeSpan.FromSeconds(0.1)); + + var result1 = await rateLimiter.LimitRequestAsync(log, endpoint1, HttpMethod.Get, false, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + var result2 = await rateLimiter.LimitRequestAsync(log, endpoint2, HttpMethod.Get, true, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + Assert.IsTrue(result1.Data == 0); + Assert.IsTrue(expectLimited ? result2.Data > 0 : result2.Data == 0); + } + + [TestCase("/sapi/test", true, true, true, false)] + [TestCase("/sapi/test", false, true, true, false)] + [TestCase("/sapi/test", false, true, false, true)] + [TestCase("/sapi/test", true, true, false, true)] + public async Task ApiKeyRateLimiterIgnores_TotalRateLimiter_IfSet(string endpoint, bool signed1, bool signed2, bool ignoreTotal, bool expectLimited) + { + var log = new Log("Test"); + log.Level = LogLevel.Trace; + + var rateLimiter = new RateLimiter(); + rateLimiter.AddApiKeyLimit(100, TimeSpan.FromSeconds(0.1), true, ignoreTotal); + rateLimiter.AddTotalRateLimit(1, TimeSpan.FromSeconds(0.1)); + + var result1 = await rateLimiter.LimitRequestAsync(log, endpoint, HttpMethod.Get, signed1, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + var result2 = await rateLimiter.LimitRequestAsync(log, endpoint, HttpMethod.Get, signed2, "123".ToSecureString(), RateLimitingBehaviour.Wait, 1, default); + Assert.IsTrue(result1.Data == 0); + Assert.IsTrue(expectLimited ? result2.Data > 0 : result2.Data == 0); } } } diff --git a/CryptoExchange.Net/CryptoExchange.Net.csproj b/CryptoExchange.Net/CryptoExchange.Net.csproj index 5bf50f4..ed0cf96 100644 --- a/CryptoExchange.Net/CryptoExchange.Net.csproj +++ b/CryptoExchange.Net/CryptoExchange.Net.csproj @@ -1,4 +1,4 @@ - + netstandard2.0;netstandard2.1 diff --git a/CryptoExchange.Net/ExtensionMethods.cs b/CryptoExchange.Net/ExtensionMethods.cs index bba9ff4..2055978 100644 --- a/CryptoExchange.Net/ExtensionMethods.cs +++ b/CryptoExchange.Net/ExtensionMethods.cs @@ -177,6 +177,41 @@ namespace CryptoExchange.Net } } + /// + /// Are 2 secure strings equal + /// + /// Source secure string + /// Compare secure string + /// True if equal by value + public static bool IsEqualTo(this SecureString ss1, SecureString ss2) + { + IntPtr bstr1 = IntPtr.Zero; + IntPtr bstr2 = IntPtr.Zero; + try + { + bstr1 = Marshal.SecureStringToBSTR(ss1); + bstr2 = Marshal.SecureStringToBSTR(ss2); + int length1 = Marshal.ReadInt32(bstr1, -4); + int length2 = Marshal.ReadInt32(bstr2, -4); + if (length1 == length2) + { + for (int x = 0; x < length1; ++x) + { + byte b1 = Marshal.ReadByte(bstr1, x); + byte b2 = Marshal.ReadByte(bstr2, x); + if (b1 != b2) return false; + } + } + else return false; + return true; + } + finally + { + if (bstr2 != IntPtr.Zero) Marshal.ZeroFreeBSTR(bstr2); + if (bstr1 != IntPtr.Zero) Marshal.ZeroFreeBSTR(bstr1); + } + } + /// /// Create a secure string from a string /// diff --git a/CryptoExchange.Net/Interfaces/IRateLimiter.cs b/CryptoExchange.Net/Interfaces/IRateLimiter.cs index df6b494..f5b3361 100644 --- a/CryptoExchange.Net/Interfaces/IRateLimiter.cs +++ b/CryptoExchange.Net/Interfaces/IRateLimiter.cs @@ -1,4 +1,10 @@ +using CryptoExchange.Net.Logging; using CryptoExchange.Net.Objects; +using System; +using System.Net.Http; +using System.Security; +using System.Threading; +using System.Threading.Tasks; namespace CryptoExchange.Net.Interfaces { @@ -8,13 +14,17 @@ namespace CryptoExchange.Net.Interfaces public interface IRateLimiter { /// - /// Limit the request if needed + /// Limit a request based on previous requests made /// - /// - /// - /// - /// - /// - CallResult LimitRequest(RestClient client, string url, RateLimitingBehaviour limitBehaviour, int credits=1); + /// The logger + /// The endpoint the request is for + /// The Http request method + /// Whether the request is singed(private) or not + /// The api key making this request + /// The limit behavior for when the limit is reached + /// The weight of the request + /// Cancellation token to cancel waiting + /// The time in milliseconds spend waiting + Task> LimitRequestAsync(Log log, string endpoint, HttpMethod method, bool signed, SecureString? apiKey, RateLimitingBehaviour limitBehaviour, int requestWeight, CancellationToken ct); } } diff --git a/CryptoExchange.Net/Interfaces/IRestClient.cs b/CryptoExchange.Net/Interfaces/IRestClient.cs index 3c150ed..015c6f5 100644 --- a/CryptoExchange.Net/Interfaces/IRestClient.cs +++ b/CryptoExchange.Net/Interfaces/IRestClient.cs @@ -3,7 +3,6 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using CryptoExchange.Net.Objects; -using CryptoExchange.Net.RateLimiter; namespace CryptoExchange.Net.Interfaces { diff --git a/CryptoExchange.Net/Objects/RateLimiter.cs b/CryptoExchange.Net/Objects/RateLimiter.cs new file mode 100644 index 0000000..420f413 --- /dev/null +++ b/CryptoExchange.Net/Objects/RateLimiter.cs @@ -0,0 +1,405 @@ +using CryptoExchange.Net.Interfaces; +using CryptoExchange.Net.Logging; +using CryptoExchange.Net.Objects; +using Microsoft.Extensions.Logging; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Net.Http; +using System.Security; +using System.Threading; +using System.Threading.Tasks; + +namespace CryptoExchange.Net.Objects +{ + /// + /// Limits the amount of requests to a certain constraint + /// + public class RateLimiter : IRateLimiter + { + private readonly object _limiterLock = new object(); + internal List Limiters = new List(); + + /// + /// Create a new RateLimiter. Configure the rate limiter by calling , + /// , or . + /// + public RateLimiter() + { + } + + /// + /// Add a rate limit for the total amount of requests per time period + /// + /// The limit per period. Note that this is weight, not single request, altough by default requests have a weight of 1 + /// The time period the limit is for + public RateLimiter AddTotalRateLimit(int limit, TimeSpan perTimePeriod) + { + lock(_limiterLock) + Limiters.Add(new TotalRateLimiter(limit, perTimePeriod, null)); + return this; + } + + /// + /// Add a rate lmit for the amount of requests per time for an endpoint + /// + /// The endpoint the limit is for + /// The limit per period. Note that this is weight, not single request, altough by default requests have a weight of 1 + /// The time period the limit is for + /// The HttpMethod the limit is for, null for all + /// If set to true it ignores other rate limits + public RateLimiter AddEndpointLimit(string endpoint, int limit, TimeSpan perTimePeriod, HttpMethod? method = null, bool excludeFromOtherRateLimits = false) + { + lock(_limiterLock) + Limiters.Add(new EndpointRateLimiter(new[] { endpoint }, limit, perTimePeriod, method, excludeFromOtherRateLimits)); + return this; + } + + /// + /// Add a rate lmit for the amount of requests per time for an endpoint + /// + /// The endpoints the limit is for + /// The limit per period. Note that this is weight, not single request, altough by default requests have a weight of 1 + /// The time period the limit is for + /// The HttpMethod the limit is for, null for all + /// If set to true it ignores other rate limits + public RateLimiter AddEndpointLimit(IEnumerable endpoints, int limit, TimeSpan perTimePeriod, HttpMethod? method = null, bool excludeFromOtherRateLimits = false) + { + lock(_limiterLock) + Limiters.Add(new EndpointRateLimiter(endpoints.ToArray(), limit, perTimePeriod, method, excludeFromOtherRateLimits)); + return this; + } + + /// + /// Add a rate lmit for the amount of requests per time for an endpoint + /// + /// The endpoint the limit is for + /// The limit per period. Note that this is weight, not single request, altough by default requests have a weight of 1 + /// The time period the limit is for + /// The HttpMethod the limit is for, null for all + /// If set to true it ignores other rate limits + /// Whether all requests for this partial endpoint are bound to the same limit or each individual endpoint has its own limit + public RateLimiter AddPartialEndpointLimit(string endpoint, int limit, TimeSpan perTimePeriod, HttpMethod? method = null, bool countPerEndpoint = false, bool ignoreOtherRateLimits = false) + { + lock(_limiterLock) + Limiters.Add(new PartialEndpointRateLimiter(new[] { endpoint }, limit, perTimePeriod, method, ignoreOtherRateLimits, countPerEndpoint)); + return this; + } + + /// + /// Add a rate limit for the amount of requests per Api key + /// + /// The limit per period. Note that this is weight, not single request, altough by default requests have a weight of 1 + /// The time period the limit is for + /// Only include calls that are signed in this limiter + /// Exclude requests with API key from the total rate limiter + public RateLimiter AddApiKeyLimit(int limit, TimeSpan perTimePeriod, bool onlyForSignedRequests, bool excludeFromTotalRateLimit) + { + lock(_limiterLock) + Limiters.Add(new ApiKeyRateLimiter(limit, perTimePeriod, null, onlyForSignedRequests, excludeFromTotalRateLimit)); + return this; + } + + /// + public async Task> LimitRequestAsync(Log log, string endpoint, HttpMethod method, bool signed, SecureString? apiKey, RateLimitingBehaviour limitBehaviour, int requestWeight, CancellationToken ct) + { + int totalWaitTime = 0; + + EndpointRateLimiter endpointLimit; + lock (_limiterLock) + endpointLimit = Limiters.OfType().SingleOrDefault(h => h.Endpoints.Contains(endpoint) && (h.Method == null || h.Method == method)); + if(endpointLimit != null) + { + var waitResult = await ProcessTopic(log, endpointLimit, endpoint, requestWeight, limitBehaviour, ct).ConfigureAwait(false); + if (!waitResult) + return waitResult; + + totalWaitTime += waitResult.Data; + } + + if (endpointLimit?.IgnoreOtherRateLimits == true) + return new CallResult(totalWaitTime, null); + + List partialEndpointLimits; + lock (_limiterLock) + partialEndpointLimits = Limiters.OfType().Where(h => h.PartialEndpoints.Any(h => endpoint.Contains(h)) && (h.Method == null || h.Method == method)).ToList(); + foreach (var partialEndpointLimit in partialEndpointLimits) + { + if (partialEndpointLimit.CountPerEndpoint) + { + SingleTopicRateLimiter thisEndpointLimit; + lock (_limiterLock) + { + thisEndpointLimit = Limiters.OfType().SingleOrDefault(h => h.Type == RateLimitType.PartialEndpoint && (string)h.Topic == endpoint); + if (thisEndpointLimit == null) + { + thisEndpointLimit = new SingleTopicRateLimiter(endpoint, partialEndpointLimit); + Limiters.Add(thisEndpointLimit); + } + } + + var waitResult = await ProcessTopic(log, thisEndpointLimit, endpoint, requestWeight, limitBehaviour, ct).ConfigureAwait(false); + if (!waitResult) + return waitResult; + + totalWaitTime += waitResult.Data; + } + else + { + var waitResult = await ProcessTopic(log, partialEndpointLimit, endpoint, requestWeight, limitBehaviour, ct).ConfigureAwait(false); + if (!waitResult) + return waitResult; + + totalWaitTime += waitResult.Data; + } + } + + if(partialEndpointLimits.Any(p => p.IgnoreOtherRateLimits)) + return new CallResult(totalWaitTime, null); + + ApiKeyRateLimiter apiLimit; + lock (_limiterLock) + apiLimit = Limiters.OfType().SingleOrDefault(h => h.Type == RateLimitType.ApiKey); + if (apiLimit != null) + { + if(apiKey == null) + { + if (!apiLimit.OnlyForSignedRequests) + { + var waitResult = await ProcessTopic(log, apiLimit, endpoint, requestWeight, limitBehaviour, ct).ConfigureAwait(false); + if (!waitResult) + return waitResult; + + totalWaitTime += waitResult.Data; + } + } + else if (signed || !apiLimit.OnlyForSignedRequests) + { + SingleTopicRateLimiter thisApiLimit; + lock (_limiterLock) + { + thisApiLimit = Limiters.OfType().SingleOrDefault(h => h.Type == RateLimitType.ApiKey && ((SecureString)h.Topic).IsEqualTo(apiKey)); + if (thisApiLimit == null) + { + thisApiLimit = new SingleTopicRateLimiter(apiKey, apiLimit); + Limiters.Add(thisApiLimit); + } + } + + var waitResult = await ProcessTopic(log, thisApiLimit, endpoint, requestWeight, limitBehaviour, ct).ConfigureAwait(false); + if (!waitResult) + return waitResult; + + totalWaitTime += waitResult.Data; + } + } + + if ((signed || apiLimit?.OnlyForSignedRequests == false) && apiLimit?.IgnoreTotalRateLimit == true) + return new CallResult(totalWaitTime, null); + + TotalRateLimiter totalLimit; + lock (_limiterLock) + totalLimit = Limiters.OfType().SingleOrDefault(); + if (totalLimit != null) + { + var waitResult = await ProcessTopic(log, totalLimit, endpoint, requestWeight, limitBehaviour, ct).ConfigureAwait(false); + if (!waitResult) + return waitResult; + + totalWaitTime += waitResult.Data; + } + + return new CallResult(totalWaitTime, null); + } + + private static async Task> ProcessTopic(Log log, Limiter historyTopic, string endpoint, int requestWeight, RateLimitingBehaviour limitBehaviour, CancellationToken ct) + { + var sw = Stopwatch.StartNew(); + try + { + await historyTopic.Semaphore.WaitAsync(ct).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + return new CallResult(0, new CancellationRequestedError()); + } + sw.Stop(); + + int totalWaitTime = 0; + while (true) + { + // Remove requests no longer in time period from the history + var checkTime = DateTime.UtcNow; + for (var i = 0; i < historyTopic.Entries.Count; i++) + { + if (historyTopic.Entries[i].Timestamp < checkTime - historyTopic.Period) + { + historyTopic.Entries.Remove(historyTopic.Entries[i]); + i--; + } + else + break; + } + + var currentWeight = historyTopic.Entries.Sum(h => h.Weight); + if (currentWeight + requestWeight > historyTopic.Limit) + { + // Wait until the next entry should be removed from the history + var thisWaitTime = (int)Math.Round((historyTopic.Entries.First().Timestamp - (checkTime - historyTopic.Period)).TotalMilliseconds); + if (thisWaitTime > 0) + { + if (limitBehaviour == RateLimitingBehaviour.Fail) + { + historyTopic.Semaphore.Release(); + var msg = $"Request to {endpoint} failed because of rate limit `{historyTopic}`. Current weight: {currentWeight}/{historyTopic.Limit}, request weight: {requestWeight}"; + log.Write(LogLevel.Warning, msg); + return new CallResult(thisWaitTime, new RateLimitError(msg)); + } + + log.Write(LogLevel.Information, $"Request to {endpoint} waiting {thisWaitTime}ms for rate limit `{historyTopic}`. Current weight: {currentWeight}/{historyTopic.Limit}, request weight: {requestWeight}"); + try + { + await Task.Delay(thisWaitTime, ct).ConfigureAwait(false); + } + catch (OperationCanceledException) + { + return new CallResult(0, new CancellationRequestedError()); + } + totalWaitTime += thisWaitTime; + } + } + else + { + break; + } + } + + var newTime = DateTime.UtcNow; + historyTopic.Entries.Add(new LimitEntry(newTime, requestWeight)); + historyTopic.Semaphore.Release(); + return new CallResult(totalWaitTime, null); + } + + internal struct LimitEntry + { + public DateTime Timestamp { get; set; } + public int Weight { get; set; } + + public LimitEntry(DateTime timestamp, int weight) + { + Timestamp = timestamp; + Weight = weight; + } + } + + internal class Limiter + { + public RateLimitType Type { get; set; } + public HttpMethod? Method { get; set; } + + public SemaphoreSlim Semaphore { get; set; } + public int Limit { get; set; } + + public TimeSpan Period { get; set; } + public List Entries { get; set; } = new List(); + + public Limiter(RateLimitType type, int limit, TimeSpan perPeriod, HttpMethod? method) + { + Semaphore = new SemaphoreSlim(1, 1); + Type = type; + Limit = limit; + Period = perPeriod; + Method = method; + } + } + + internal class TotalRateLimiter : Limiter + { + public TotalRateLimiter(int limit, TimeSpan perPeriod, HttpMethod? method) + : base(RateLimitType.Total, limit, perPeriod, method) + { + } + + public override string ToString() + { + return nameof(TotalRateLimiter); + } + } + + internal class EndpointRateLimiter: Limiter + { + public string[] Endpoints { get; set; } + public bool IgnoreOtherRateLimits { get; set; } + + public EndpointRateLimiter(string[] endpoints, int limit, TimeSpan perPeriod, HttpMethod? method, bool ignoreOtherRateLimits) + :base(RateLimitType.Endpoint, limit, perPeriod, method) + { + Endpoints = endpoints; + IgnoreOtherRateLimits = ignoreOtherRateLimits; + } + + public override string ToString() + { + return nameof(EndpointRateLimiter) + $": {string.Join(", ", Endpoints)}"; + } + } + + internal class PartialEndpointRateLimiter : Limiter + { + public string[] PartialEndpoints { get; set; } + public bool IgnoreOtherRateLimits { get; set; } + public bool CountPerEndpoint { get; set; } + + public PartialEndpointRateLimiter(string[] partialEndpoints, int limit, TimeSpan perPeriod, HttpMethod? method, bool ignoreOtherRateLimits, bool countPerEndpoint) + : base(RateLimitType.PartialEndpoint, limit, perPeriod, method) + { + PartialEndpoints = partialEndpoints; + IgnoreOtherRateLimits = ignoreOtherRateLimits; + CountPerEndpoint = countPerEndpoint; + } + + public override string ToString() + { + return nameof(PartialEndpointRateLimiter) + $": {string.Join(", ", PartialEndpoints)}"; + } + } + + internal class ApiKeyRateLimiter : Limiter + { + public bool OnlyForSignedRequests { get; set; } + public bool IgnoreTotalRateLimit { get; set; } + + public ApiKeyRateLimiter(int limit, TimeSpan perPeriod, HttpMethod? method, bool onlyForSignedRequests, bool ignoreTotalRateLimit) + :base(RateLimitType.ApiKey, limit, perPeriod, method) + { + OnlyForSignedRequests = onlyForSignedRequests; + IgnoreTotalRateLimit = ignoreTotalRateLimit; + } + } + + internal class SingleTopicRateLimiter: Limiter + { + public object Topic { get; set; } + + public SingleTopicRateLimiter(object topic, Limiter limiter) + :base(limiter.Type, limiter.Limit, limiter.Period, limiter.Method) + { + Topic = topic; + } + + public override string ToString() + { + return (Type == RateLimitType.ApiKey ? nameof(ApiKeyRateLimiter): nameof(EndpointRateLimiter)) + $": {Topic}"; + } + } + + internal enum RateLimitType + { + Total, + Endpoint, + PartialEndpoint, + ApiKey + } + } +} diff --git a/CryptoExchange.Net/RateLimiter/RateLimitObject.cs b/CryptoExchange.Net/RateLimiter/RateLimitObject.cs deleted file mode 100644 index 34ed08c..0000000 --- a/CryptoExchange.Net/RateLimiter/RateLimitObject.cs +++ /dev/null @@ -1,52 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; - -namespace CryptoExchange.Net.RateLimiter -{ - /// - /// Rate limiting object - /// - public class RateLimitObject - { - /// - /// Lock - /// - public object LockObject { get; } - private List Times { get; } - - /// - /// ctor - /// - public RateLimitObject() - { - LockObject = new object(); - Times = new List(); - } - - /// - /// Get time to wait for a specific time - /// - /// - /// - /// - /// - public int GetWaitTime(DateTime time, int limit, TimeSpan perTimePeriod) - { - Times.RemoveAll(d => d < time - perTimePeriod); - if (Times.Count >= limit) - return (int)Math.Round((Times.First() - (time - perTimePeriod)).TotalMilliseconds); - return 0; - } - - /// - /// Add an executed request time - /// - /// - public void Add(DateTime time) - { - Times.Add(time); - Times.Sort(); - } - } -} diff --git a/CryptoExchange.Net/RateLimiter/RateLimiterAPIKey.cs b/CryptoExchange.Net/RateLimiter/RateLimiterAPIKey.cs deleted file mode 100644 index 8bfe2f6..0000000 --- a/CryptoExchange.Net/RateLimiter/RateLimiterAPIKey.cs +++ /dev/null @@ -1,92 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Security.Cryptography; -using System.Text; -using System.Threading; -using CryptoExchange.Net.Interfaces; -using CryptoExchange.Net.Objects; - -namespace CryptoExchange.Net.RateLimiter -{ - /// - /// Limits the amount of requests per time period to a certain limit, counts the request per API key. - /// - public class RateLimiterAPIKey: IRateLimiter, IDisposable - { - internal Dictionary history = new Dictionary(); - - private readonly SHA256 encryptor; - private readonly int limitPerKey; - private readonly TimeSpan perTimePeriod; - private readonly object historyLock = new object(); - - /// - /// Create a new RateLimiterAPIKey. This rate limiter limits the amount of requests per time period to a certain limit, counts the request per API key. - /// - /// The amount to limit to - /// The time period over which the limit counts - public RateLimiterAPIKey(int limitPerApiKey, TimeSpan perTimePeriod) - { - limitPerKey = limitPerApiKey; - encryptor = SHA256.Create(); - this.perTimePeriod = perTimePeriod; - } - - /// - public CallResult LimitRequest(RestClient client, string url, RateLimitingBehaviour limitBehaviour, int credits = 1) - { - if(client.authProvider?.Credentials?.Key == null) - return new CallResult(0, null); - - var keyBytes = encryptor.ComputeHash(Encoding.UTF8.GetBytes(client.authProvider.Credentials.Key.GetString())); - StringBuilder builder = new StringBuilder(); - for (int i = 0; i < keyBytes.Length; i++) - { - builder.Append(keyBytes[i].ToString("x2")); - } - - var key = builder.ToString(); - - int waitTime; - RateLimitObject rlo; - lock (historyLock) - { - if (history.ContainsKey(key)) - rlo = history[key]; - else - { - rlo = new RateLimitObject(); - history.Add(key, rlo); - } - } - - var sw = Stopwatch.StartNew(); - lock (rlo.LockObject) - { - sw.Stop(); - waitTime = rlo.GetWaitTime(DateTime.UtcNow, limitPerKey, perTimePeriod); - if (waitTime != 0) - { - if (limitBehaviour == RateLimitingBehaviour.Fail) - return new CallResult(waitTime, new RateLimitError($"endpoint limit of {limitPerKey} reached on api key " + key)); - - Thread.Sleep(Convert.ToInt32(waitTime)); - waitTime += (int)sw.ElapsedMilliseconds; - } - - rlo.Add(DateTime.UtcNow); - } - - return new CallResult(waitTime, null); - } - - /// - /// Dispose - /// - public void Dispose() - { - encryptor.Dispose(); - } - } -} diff --git a/CryptoExchange.Net/RateLimiter/RateLimiterCredit.cs b/CryptoExchange.Net/RateLimiter/RateLimiterCredit.cs deleted file mode 100644 index ef6cc2b..0000000 --- a/CryptoExchange.Net/RateLimiter/RateLimiterCredit.cs +++ /dev/null @@ -1,65 +0,0 @@ -using CryptoExchange.Net.Interfaces; -using CryptoExchange.Net.Objects; -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Threading; - -namespace CryptoExchange.Net.RateLimiter -{ - /// - /// Limits the amount of requests per time period to a certain limit, counts the total amount of requests. - /// - public class RateLimiterCredit : IRateLimiter - { - internal List history = new List(); - - private readonly int limit; - private readonly TimeSpan perTimePeriod; - private readonly object requestLock = new object(); - - /// - /// Create a new RateLimiterTotal. This rate limiter limits the amount of requests per time period to a certain limit, counts the total amount of requests. - /// - /// The amount to limit to - /// The time period over which the limit counts - public RateLimiterCredit(int limit, TimeSpan perTimePeriod) - { - this.limit = limit; - this.perTimePeriod = perTimePeriod; - } - - /// - public CallResult LimitRequest(RestClient client, string url, RateLimitingBehaviour limitBehaviour, int credits = 1) - { - var sw = Stopwatch.StartNew(); - lock (requestLock) - { - sw.Stop(); - double waitTime = 0; - var checkTime = DateTime.UtcNow; - history.RemoveAll(d => d < checkTime - perTimePeriod); - - if (history.Count >= limit) - { - waitTime = (history.First() - (checkTime - perTimePeriod)).TotalMilliseconds; - if (waitTime > 0) - { - if (limitBehaviour == RateLimitingBehaviour.Fail) - return new CallResult(waitTime, new RateLimitError($"total limit of {limit} reached")); - - Thread.Sleep(Convert.ToInt32(waitTime)); - waitTime += sw.ElapsedMilliseconds; - } - } - - for (int i = 1; i <= credits; i++) - history.Add(DateTime.UtcNow); - - history.Sort(); - return new CallResult(waitTime, null); - } - } - } -} diff --git a/CryptoExchange.Net/RateLimiter/RateLimiterPerEndpoint.cs b/CryptoExchange.Net/RateLimiter/RateLimiterPerEndpoint.cs deleted file mode 100644 index f137ae3..0000000 --- a/CryptoExchange.Net/RateLimiter/RateLimiterPerEndpoint.cs +++ /dev/null @@ -1,68 +0,0 @@ -using CryptoExchange.Net.Interfaces; -using CryptoExchange.Net.Objects; -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Threading; - -namespace CryptoExchange.Net.RateLimiter -{ - /// - /// Limits the amount of requests per time period to a certain limit, counts the request per endpoint. - /// - public class RateLimiterPerEndpoint: IRateLimiter - { - internal Dictionary history = new Dictionary(); - - private readonly int limitPerEndpoint; - private readonly TimeSpan perTimePeriod; - private readonly object historyLock = new object(); - - /// - /// Create a new RateLimiterPerEndpoint. This rate limiter limits the amount of requests per time period to a certain limit, counts the request per endpoint. - /// - /// The amount to limit to - /// The time period over which the limit counts - public RateLimiterPerEndpoint(int limitPerEndpoint, TimeSpan perTimePeriod) - { - this.limitPerEndpoint = limitPerEndpoint; - this.perTimePeriod = perTimePeriod; - } - - /// - public CallResult LimitRequest(RestClient client, string url, RateLimitingBehaviour limitingBehaviour, int credits = 1) - { - int waitTime; - RateLimitObject rlo; - lock (historyLock) - { - if (history.ContainsKey(url)) - rlo = history[url]; - else - { - rlo = new RateLimitObject(); - history.Add(url, rlo); - } - } - - var sw = Stopwatch.StartNew(); - lock (rlo.LockObject) - { - sw.Stop(); - waitTime = rlo.GetWaitTime(DateTime.UtcNow, limitPerEndpoint, perTimePeriod); - if (waitTime != 0) - { - if(limitingBehaviour == RateLimitingBehaviour.Fail) - return new CallResult(waitTime, new RateLimitError($"endpoint limit of {limitPerEndpoint} reached on endpoint " + url)); - - Thread.Sleep(Convert.ToInt32(waitTime)); - waitTime += (int)sw.ElapsedMilliseconds; - } - - rlo.Add(DateTime.UtcNow); - } - - return new CallResult(waitTime, null); - } - } -} diff --git a/CryptoExchange.Net/RateLimiter/RateLimiterTotal.cs b/CryptoExchange.Net/RateLimiter/RateLimiterTotal.cs deleted file mode 100644 index a5d1792..0000000 --- a/CryptoExchange.Net/RateLimiter/RateLimiterTotal.cs +++ /dev/null @@ -1,63 +0,0 @@ -using CryptoExchange.Net.Interfaces; -using CryptoExchange.Net.Objects; -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Threading; - -namespace CryptoExchange.Net.RateLimiter -{ - /// - /// Limits the amount of requests per time period to a certain limit, counts the total amount of requests. - /// - public class RateLimiterTotal: IRateLimiter - { - internal List history = new List(); - - private readonly int limit; - private readonly TimeSpan perTimePeriod; - private readonly object requestLock = new object(); - - /// - /// Create a new RateLimiterTotal. This rate limiter limits the amount of requests per time period to a certain limit, counts the total amount of requests. - /// - /// The amount to limit to - /// The time period over which the limit counts - public RateLimiterTotal(int limit, TimeSpan perTimePeriod) - { - this.limit = limit; - this.perTimePeriod = perTimePeriod; - } - - /// - public CallResult LimitRequest(RestClient client, string url, RateLimitingBehaviour limitBehaviour, int credits = 1) - { - var sw = Stopwatch.StartNew(); - lock (requestLock) - { - sw.Stop(); - double waitTime = 0; - var checkTime = DateTime.UtcNow; - history.RemoveAll(d => d < checkTime - perTimePeriod); - - if (history.Count >= limit) - { - waitTime = (history.First() - (checkTime - perTimePeriod)).TotalMilliseconds; - if (waitTime > 0) - { - if (limitBehaviour == RateLimitingBehaviour.Fail) - return new CallResult(waitTime, new RateLimitError($"total limit of {limit} reached")); - - Thread.Sleep(Convert.ToInt32(waitTime)); - waitTime += sw.ElapsedMilliseconds; - } - } - - history.Add(DateTime.UtcNow); - history.Sort(); - return new CallResult(waitTime, null); - } - } - } -} diff --git a/CryptoExchange.Net/RestClient.cs b/CryptoExchange.Net/RestClient.cs index 559c0b7..68e06c7 100644 --- a/CryptoExchange.Net/RestClient.cs +++ b/CryptoExchange.Net/RestClient.cs @@ -13,7 +13,6 @@ using System.Web; using CryptoExchange.Net.Authentication; using CryptoExchange.Net.Interfaces; using CryptoExchange.Net.Objects; -using CryptoExchange.Net.RateLimiter; using CryptoExchange.Net.Requests; using Microsoft.Extensions.Logging; using Newtonsoft.Json; @@ -133,7 +132,7 @@ namespace CryptoExchange.Net /// Whether or not the request should be authenticated /// Where the parameters should be placed, overwrites the value set in the client /// How array parameters should be serialized, overwrites the value set in the client - /// Credits used for the request + /// Credits used for the request /// The JsonSerializer to use for deserialization /// Additional headers to send with the request /// @@ -146,7 +145,7 @@ namespace CryptoExchange.Net bool signed = false, HttpMethodParameterPosition? parameterPosition = null, ArrayParametersSerialization? arraySerialization = null, - int credits = 1, + int requestWeight = 1, JsonSerializer? deserializer = null, Dictionary? additionalHeaders = null) where T : class { @@ -162,15 +161,9 @@ namespace CryptoExchange.Net var request = ConstructRequest(uri, method, parameters, signed, paramsPosition, arraySerialization ?? this.arraySerialization, requestId, additionalHeaders); foreach (var limiter in RateLimiters) { - var limitResult = limiter.LimitRequest(this, uri.AbsolutePath, ClientOptions.RateLimitingBehaviour, credits); - if (!limitResult.Success) - { - log.Write(LogLevel.Information, $"[{requestId}] Request {uri.AbsolutePath} failed because of rate limit"); + var limitResult = await limiter.LimitRequestAsync(log, uri.AbsolutePath, method, signed, ClientOptions.ApiCredentials?.Key, ClientOptions.RateLimitingBehaviour, requestWeight, cancellationToken).ConfigureAwait(false); + if (!limitResult.Success) return new WebCallResult(null, null, null, limitResult.Error); - } - - if (limitResult.Data > 0) - log.Write(LogLevel.Information, $"[{requestId}] Request {uri.AbsolutePath} was limited by {limitResult.Data}ms by {limiter.GetType().Name}"); } string? paramString = "";