diff --git a/CryptoExchange.Net.UnitTests/ClientTests/RestClientTests.cs b/CryptoExchange.Net.UnitTests/ClientTests/RestClientTests.cs index 4e36ac4..c766d7d 100644 --- a/CryptoExchange.Net.UnitTests/ClientTests/RestClientTests.cs +++ b/CryptoExchange.Net.UnitTests/ClientTests/RestClientTests.cs @@ -147,231 +147,5 @@ namespace CryptoExchange.Net.UnitTests.ClientTests Assert.That(result.RequestBody?.Contains("TestParam2") == true == (pos == HttpMethodParameterPosition.InBody)); Assert.That((result.RequestUrl?.ToString().Contains("TestParam2")) == (pos == HttpMethodParameterPosition.InUri)); } - - - [TestCase(1, 0.1)] - [TestCase(2, 0.1)] - [TestCase(5, 1)] - [TestCase(1, 2)] - public async Task PartialEndpointRateLimiterBasics(int requests, double perSeconds) - { - var rateLimiter = new RateLimitGate("Test"); - rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerHost, new PathStartFilter("/sapi/"), requests, TimeSpan.FromSeconds(perSeconds), RateLimitWindowType.Fixed)); - - var triggered = false; - rateLimiter.RateLimitTriggered += (x) => { triggered = true; }; - var requestDefinition = new RequestDefinition("/sapi/v1/system/status", HttpMethod.Get); - - for (var i = 0; i < requests + 1; i++) - { - var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(i == requests ? triggered : !triggered); - } - triggered = false; - await Task.Delay((int)Math.Round(perSeconds * 1000) + 10); - var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(!triggered); - } - - [TestCase("/sapi/test1", true)] - [TestCase("/sapi/test2", true)] - [TestCase("/api/test1", false)] - [TestCase("sapi/test1", true)] - [TestCase("/sapi/", true)] - public async Task PartialEndpointRateLimiterEndpoints(string endpoint, bool expectLimiting) - { - var rateLimiter = new RateLimitGate("Test"); - rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerHost, new PathStartFilter("/sapi/"), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); - - var requestDefinition = new RequestDefinition(endpoint, HttpMethod.Get); - - RateLimitEvent? evnt = null; - rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; - for (var i = 0; i < 2; i++) - { - var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); - bool expected = i == 1 ? expectLimiting ? evnt?.DelayTime > TimeSpan.Zero : evnt == null : evnt == null; - Assert.That(expected); - } - } - - [TestCase("/sapi/", "/sapi/", true)] - [TestCase("/sapi/test", "/sapi/test", true)] - [TestCase("/sapi/test", "/sapi/test123", false)] - [TestCase("/sapi/test", "/sapi/", false)] - public async Task PartialEndpointRateLimiterEndpoints(string endpoint1, string endpoint2, bool expectLimiting) - { - var rateLimiter = new RateLimitGate("Test"); - rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerEndpoint, new PathStartFilter("/sapi/"), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); - - var requestDefinition1 = new RequestDefinition(endpoint1, HttpMethod.Get); - var requestDefinition2 = new RequestDefinition(endpoint2, HttpMethod.Get); - - RateLimitEvent? evnt = null; - rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; - - var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition1, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(evnt == null); - var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition2, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(expectLimiting ? evnt != null : evnt == null); - } - - [TestCase(1, 0.1)] - [TestCase(2, 0.1)] - [TestCase(5, 1)] - [TestCase(1, 2)] - public async Task EndpointRateLimiterBasics(int requests, double perSeconds) - { - var rateLimiter = new RateLimitGate("Test"); - rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerEndpoint, new PathStartFilter("/sapi/test"), requests, TimeSpan.FromSeconds(perSeconds), RateLimitWindowType.Fixed)); - - bool triggered = false; - rateLimiter.RateLimitTriggered += (x) => { triggered = true; }; - var requestDefinition = new RequestDefinition("/sapi/test", HttpMethod.Get); - - for (var i = 0; i < requests + 1; i++) - { - var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(i == requests ? triggered : !triggered); - } - triggered = false; - await Task.Delay((int)Math.Round(perSeconds * 1000) + 10); - var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(!triggered); - } - - [TestCase("/", false)] - [TestCase("/sapi/test", true)] - [TestCase("/sapi/test/123", false)] - public async Task EndpointRateLimiterEndpoints(string endpoint, bool expectLimited) - { - var rateLimiter = new RateLimitGate("Test"); - rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerEndpoint, new ExactPathFilter("/sapi/test"), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); - - var requestDefinition = new RequestDefinition(endpoint, HttpMethod.Get); - - RateLimitEvent? evnt = null; - rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; - for (var i = 0; i < 2; i++) - { - var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); - bool expected = i == 1 ? expectLimited ? evnt?.DelayTime > TimeSpan.Zero : evnt == null : evnt == null; - Assert.That(expected); - } - } - - [TestCase("/", false)] - [TestCase("/sapi/test", true)] - [TestCase("/sapi/test2", true)] - [TestCase("/sapi/test23", false)] - public async Task EndpointRateLimiterMultipleEndpoints(string endpoint, bool expectLimited) - { - var rateLimiter = new RateLimitGate("Test"); - rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerEndpoint, new ExactPathsFilter(new[] { "/sapi/test", "/sapi/test2" }), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); - var requestDefinition = new RequestDefinition(endpoint, HttpMethod.Get); - - RateLimitEvent? evnt = null; - rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; - for (var i = 0; i < 2; i++) - { - var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); - bool expected = i == 1 ? expectLimited ? evnt?.DelayTime > TimeSpan.Zero : evnt == null : evnt == null; - Assert.That(expected); - } - } - - [TestCase("123", "123", "/sapi/test", "/sapi/test", true)] - [TestCase("123", "456", "/sapi/test", "/sapi/test", false)] - [TestCase("123", "123", "/sapi/test", "/sapi/test2", true)] - [TestCase("123", "123", "/sapi/test2", "/sapi/test", true)] - [TestCase(null, "123", "/sapi/test", "/sapi/test", false)] - [TestCase("123", null, "/sapi/test", "/sapi/test", false)] - [TestCase(null, null, "/sapi/test", "/sapi/test", false)] - public async Task ApiKeyRateLimiterBasics(string key1, string key2, string endpoint1, string endpoint2, bool expectLimited) - { - var rateLimiter = new RateLimitGate("Test"); - rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerApiKey, new AuthenticatedEndpointFilter(true), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Sliding)); - var requestDefinition1 = new RequestDefinition(endpoint1, HttpMethod.Get) { Authenticated = key1 != null }; - var requestDefinition2 = new RequestDefinition(endpoint2, HttpMethod.Get) { Authenticated = key2 != null }; - - RateLimitEvent? evnt = null; - rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; - - var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition1, "https://test.com", key1, 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(evnt == null); - var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition2, "https://test.com", key2, 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(expectLimited ? evnt != null : evnt == null); - } - - [TestCase("/sapi/test", "/sapi/test", true)] - [TestCase("/sapi/test1", "/api/test2", true)] - [TestCase("/", "/sapi/test2", true)] - public async Task TotalRateLimiterBasics(string endpoint1, string endpoint2, bool expectLimited) - { - var rateLimiter = new RateLimitGate("Test"); - rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerHost, Array.Empty(), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); - var requestDefinition1 = new RequestDefinition(endpoint1, HttpMethod.Get); - var requestDefinition2 = new RequestDefinition(endpoint2, HttpMethod.Get) { Authenticated = true }; - - RateLimitEvent? evnt = null; - rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; - - var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition1, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(evnt == null); - var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition2, "https://test.com", null, 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(expectLimited ? evnt != null : evnt == null); - } - - [TestCase("https://test.com", "/sapi/test", "https://test.com", "/sapi/test", true)] - [TestCase("https://test2.com", "/sapi/test", "https://test.com", "/sapi/test", false)] - [TestCase("https://test.com", "/sapi/test", "https://test2.com", "/sapi/test", false)] - [TestCase("https://test.com", "/sapi/test", "https://test.com", "/sapi/test2", true)] - public async Task HostRateLimiterBasics(string host1, string endpoint1, string host2, string endpoint2, bool expectLimited) - { - var rateLimiter = new RateLimitGate("Test"); - rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerHost, new HostFilter("https://test.com"), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); - var requestDefinition1 = new RequestDefinition(endpoint1, HttpMethod.Get); - var requestDefinition2 = new RequestDefinition(endpoint2, HttpMethod.Get) { Authenticated = true }; - - RateLimitEvent? evnt = null; - rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; - - var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition1, host1, "123", 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(evnt == null); - var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition1, host2, "123", 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(expectLimited ? evnt != null : evnt == null); - } - - [TestCase("https://test.com", "https://test.com", true)] - [TestCase("https://test2.com", "https://test.com", false)] - [TestCase("https://test.com", "https://test2.com", false)] - public async Task ConnectionRateLimiterBasics(string host1, string host2, bool expectLimited) - { - var rateLimiter = new RateLimitGate("Test"); - rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerHost, new LimitItemTypeFilter(RateLimitItemType.Connection), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); - - RateLimitEvent? evnt = null; - rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; - - var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Connection, new RequestDefinition("1", HttpMethod.Get), host1, "123", 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(evnt == null); - var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Connection, new RequestDefinition("1", HttpMethod.Get), host2, "123", 1, RateLimitingBehaviour.Wait, null, default); - Assert.That(expectLimited ? evnt != null : evnt == null); - } - - [Test] - public async Task ConnectionRateLimiterCancel() - { - var rateLimiter = new RateLimitGate("Test"); - rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerHost, new LimitItemTypeFilter(RateLimitItemType.Connection), 1, TimeSpan.FromSeconds(10), RateLimitWindowType.Fixed)); - - RateLimitEvent? evnt = null; - rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; - var ct = new CancellationTokenSource(TimeSpan.FromSeconds(0.2)); - - var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Connection, new RequestDefinition("1", HttpMethod.Get), "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, ct.Token); - var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Connection, new RequestDefinition("1", HttpMethod.Get), "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, ct.Token); - Assert.That(result2.Error, Is.TypeOf()); - } } } diff --git a/CryptoExchange.Net.UnitTests/RateLimitTests.cs b/CryptoExchange.Net.UnitTests/RateLimitTests.cs new file mode 100644 index 0000000..719be6b --- /dev/null +++ b/CryptoExchange.Net.UnitTests/RateLimitTests.cs @@ -0,0 +1,289 @@ +using CryptoExchange.Net.Objects; +using CryptoExchange.Net.RateLimiting; +using CryptoExchange.Net.RateLimiting.Filters; +using CryptoExchange.Net.RateLimiting.Guards; +using CryptoExchange.Net.RateLimiting.Interfaces; +using NUnit.Framework; +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace CryptoExchange.Net.UnitTests +{ + [TestFixture()] + public class RateLimitTests + { + [TestCase(1, 0.1)] + [TestCase(2, 0.1)] + [TestCase(5, 1)] + [TestCase(1, 2)] + public async Task PartialEndpointRateLimiterBasics(int requests, double perSeconds) + { + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerHost, new PathStartFilter("/sapi/"), requests, TimeSpan.FromSeconds(perSeconds), RateLimitWindowType.Fixed)); + + var triggered = false; + rateLimiter.RateLimitTriggered += (x) => { triggered = true; }; + var requestDefinition = new RequestDefinition("/sapi/v1/system/status", HttpMethod.Get); + + for (var i = 0; i < requests + 1; i++) + { + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(i == requests ? triggered : !triggered); + } + triggered = false; + await Task.Delay((int)Math.Round(perSeconds * 1000) + 10); + var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(!triggered); + } + + [TestCase("/sapi/test1", true)] + [TestCase("/sapi/test2", true)] + [TestCase("/api/test1", false)] + [TestCase("sapi/test1", true)] + [TestCase("/sapi/", true)] + public async Task PartialEndpointRateLimiterEndpoints(string endpoint, bool expectLimiting) + { + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerHost, new PathStartFilter("/sapi/"), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); + + var requestDefinition = new RequestDefinition(endpoint, HttpMethod.Get); + + RateLimitEvent? evnt = null; + rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; + for (var i = 0; i < 2; i++) + { + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); + bool expected = i == 1 ? expectLimiting ? evnt?.DelayTime > TimeSpan.Zero : evnt == null : evnt == null; + Assert.That(expected); + } + } + + [TestCase("/sapi/", "/sapi/", true)] + [TestCase("/sapi/test", "/sapi/test", true)] + [TestCase("/sapi/test", "/sapi/test123", false)] + [TestCase("/sapi/test", "/sapi/", false)] + public async Task PartialEndpointRateLimiterEndpoints(string endpoint1, string endpoint2, bool expectLimiting) + { + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerEndpoint, new PathStartFilter("/sapi/"), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); + + var requestDefinition1 = new RequestDefinition(endpoint1, HttpMethod.Get); + var requestDefinition2 = new RequestDefinition(endpoint2, HttpMethod.Get); + + RateLimitEvent? evnt = null; + rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; + + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition1, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(evnt == null); + var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition2, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(expectLimiting ? evnt != null : evnt == null); + } + + [TestCase(1, 0.1)] + [TestCase(2, 0.1)] + [TestCase(5, 1)] + [TestCase(1, 2)] + public async Task EndpointRateLimiterBasics(int requests, double perSeconds) + { + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerEndpoint, new PathStartFilter("/sapi/test"), requests, TimeSpan.FromSeconds(perSeconds), RateLimitWindowType.Fixed)); + + bool triggered = false; + rateLimiter.RateLimitTriggered += (x) => { triggered = true; }; + var requestDefinition = new RequestDefinition("/sapi/test", HttpMethod.Get); + + for (var i = 0; i < requests + 1; i++) + { + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(i == requests ? triggered : !triggered); + } + triggered = false; + await Task.Delay((int)Math.Round(perSeconds * 1000) + 10); + var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(!triggered); + } + + [TestCase("/", false)] + [TestCase("/sapi/test", true)] + [TestCase("/sapi/test/123", false)] + public async Task EndpointRateLimiterEndpoints(string endpoint, bool expectLimited) + { + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerEndpoint, new ExactPathFilter("/sapi/test"), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); + + var requestDefinition = new RequestDefinition(endpoint, HttpMethod.Get); + + RateLimitEvent? evnt = null; + rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; + for (var i = 0; i < 2; i++) + { + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); + bool expected = i == 1 ? expectLimited ? evnt?.DelayTime > TimeSpan.Zero : evnt == null : evnt == null; + Assert.That(expected); + } + } + + [TestCase("/", false)] + [TestCase("/sapi/test", true)] + [TestCase("/sapi/test2", true)] + [TestCase("/sapi/test23", false)] + public async Task EndpointRateLimiterMultipleEndpoints(string endpoint, bool expectLimited) + { + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerEndpoint, new ExactPathsFilter(new[] { "/sapi/test", "/sapi/test2" }), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); + var requestDefinition = new RequestDefinition(endpoint, HttpMethod.Get); + + RateLimitEvent? evnt = null; + rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; + for (var i = 0; i < 2; i++) + { + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); + bool expected = i == 1 ? expectLimited ? evnt?.DelayTime > TimeSpan.Zero : evnt == null : evnt == null; + Assert.That(expected); + } + } + + [TestCase("123", "123", "/sapi/test", "/sapi/test", true)] + [TestCase("123", "456", "/sapi/test", "/sapi/test", false)] + [TestCase("123", "123", "/sapi/test", "/sapi/test2", true)] + [TestCase("123", "123", "/sapi/test2", "/sapi/test", true)] + [TestCase(null, "123", "/sapi/test", "/sapi/test", false)] + [TestCase("123", null, "/sapi/test", "/sapi/test", false)] + [TestCase(null, null, "/sapi/test", "/sapi/test", false)] + public async Task ApiKeyRateLimiterBasics(string key1, string key2, string endpoint1, string endpoint2, bool expectLimited) + { + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerApiKey, new AuthenticatedEndpointFilter(true), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Sliding)); + var requestDefinition1 = new RequestDefinition(endpoint1, HttpMethod.Get) { Authenticated = key1 != null }; + var requestDefinition2 = new RequestDefinition(endpoint2, HttpMethod.Get) { Authenticated = key2 != null }; + + RateLimitEvent? evnt = null; + rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; + + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition1, "https://test.com", key1, 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(evnt == null); + var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition2, "https://test.com", key2, 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(expectLimited ? evnt != null : evnt == null); + } + + [TestCase("/sapi/test", "/sapi/test", true)] + [TestCase("/sapi/test1", "/api/test2", true)] + [TestCase("/", "/sapi/test2", true)] + public async Task TotalRateLimiterBasics(string endpoint1, string endpoint2, bool expectLimited) + { + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerHost, Array.Empty(), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); + var requestDefinition1 = new RequestDefinition(endpoint1, HttpMethod.Get); + var requestDefinition2 = new RequestDefinition(endpoint2, HttpMethod.Get) { Authenticated = true }; + + RateLimitEvent? evnt = null; + rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; + + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition1, "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(evnt == null); + var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition2, "https://test.com", null, 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(expectLimited ? evnt != null : evnt == null); + } + + [TestCase("https://test.com", "/sapi/test", "https://test.com", "/sapi/test", true)] + [TestCase("https://test2.com", "/sapi/test", "https://test.com", "/sapi/test", false)] + [TestCase("https://test.com", "/sapi/test", "https://test2.com", "/sapi/test", false)] + [TestCase("https://test.com", "/sapi/test", "https://test.com", "/sapi/test2", true)] + public async Task HostRateLimiterBasics(string host1, string endpoint1, string host2, string endpoint2, bool expectLimited) + { + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerHost, new HostFilter("https://test.com"), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); + var requestDefinition1 = new RequestDefinition(endpoint1, HttpMethod.Get); + var requestDefinition2 = new RequestDefinition(endpoint2, HttpMethod.Get) { Authenticated = true }; + + RateLimitEvent? evnt = null; + rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; + + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition1, host1, "123", 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(evnt == null); + var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, requestDefinition1, host2, "123", 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(expectLimited ? evnt != null : evnt == null); + } + + [TestCase("https://test.com", "https://test.com", true)] + [TestCase("https://test2.com", "https://test.com", false)] + [TestCase("https://test.com", "https://test2.com", false)] + public async Task ConnectionRateLimiterBasics(string host1, string host2, bool expectLimited) + { + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerHost, new LimitItemTypeFilter(RateLimitItemType.Connection), 1, TimeSpan.FromSeconds(0.1), RateLimitWindowType.Fixed)); + + RateLimitEvent? evnt = null; + rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; + + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Connection, new RequestDefinition("1", HttpMethod.Get), host1, "123", 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(evnt == null); + var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Connection, new RequestDefinition("1", HttpMethod.Get), host2, "123", 1, RateLimitingBehaviour.Wait, null, default); + Assert.That(expectLimited ? evnt != null : evnt == null); + } + + [Test] + public async Task ConnectionRateLimiterCancel() + { + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerHost, new LimitItemTypeFilter(RateLimitItemType.Connection), 1, TimeSpan.FromSeconds(10), RateLimitWindowType.Fixed)); + + RateLimitEvent? evnt = null; + rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; + var ct = new CancellationTokenSource(TimeSpan.FromSeconds(0.2)); + + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Connection, new RequestDefinition("1", HttpMethod.Get), "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, ct.Token); + var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Connection, new RequestDefinition("1", HttpMethod.Get), "https://test.com", "123", 1, RateLimitingBehaviour.Wait, null, ct.Token); + Assert.That(result2.Error, Is.TypeOf()); + } + + [Test] + public async Task RateLimiterReset_Should_AllowNextRequestForSameDefinition() + { + // arrange + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerConnection, new LimitItemTypeFilter(RateLimitItemType.Request), 1, TimeSpan.FromSeconds(10), RateLimitWindowType.Fixed)); + + var definition = new RequestDefinition("1", HttpMethod.Get) { ConnectionId = 1 }; + + RateLimitEvent? evnt = null; + rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; + var ct = new CancellationTokenSource(TimeSpan.FromSeconds(0.2)); + + // act + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, definition, "https://test.com", null, 1, RateLimitingBehaviour.Fail, null, ct.Token); + await rateLimiter.ResetAsync(RateLimitItemType.Request, definition, "https://test.com", null, null, default); + var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, definition, "https://test.com", null, 1, RateLimitingBehaviour.Fail, null, ct.Token); + + // assert + Assert.That(evnt, Is.Null); + } + + [Test] + public async Task RateLimiterReset_Should_NotAllowNextRequestForDifferentDefinition() + { + // arrange + var rateLimiter = new RateLimitGate("Test"); + rateLimiter.AddGuard(new RateLimitGuard(RateLimitGuard.PerConnection, new LimitItemTypeFilter(RateLimitItemType.Request), 1, TimeSpan.FromSeconds(10), RateLimitWindowType.Fixed)); + + var definition1 = new RequestDefinition("1", HttpMethod.Get) { ConnectionId = 1 }; + var definition2 = new RequestDefinition("2", HttpMethod.Get) { ConnectionId = 2 }; + + RateLimitEvent? evnt = null; + rateLimiter.RateLimitTriggered += (x) => { evnt = x; }; + + // act + var result1 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, definition1, "https://test.com", null, 1, RateLimitingBehaviour.Fail, null, default); + var result2 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, definition2, "https://test.com", null, 1, RateLimitingBehaviour.Fail, null, default); + await rateLimiter.ResetAsync(RateLimitItemType.Request, definition1, "https://test.com", null, null, default); + var result3 = await rateLimiter.ProcessAsync(new TraceLogger(), 1, RateLimitItemType.Request, definition2, "https://test.com", null, 1, RateLimitingBehaviour.Fail, null, default); + + // assert + Assert.That(evnt, Is.Not.Null); + } + } +} diff --git a/CryptoExchange.Net/RateLimiting/Guards/RateLimitGuard.cs b/CryptoExchange.Net/RateLimiting/Guards/RateLimitGuard.cs index 5fd5d38..dabd62e 100644 --- a/CryptoExchange.Net/RateLimiting/Guards/RateLimitGuard.cs +++ b/CryptoExchange.Net/RateLimiting/Guards/RateLimitGuard.cs @@ -3,6 +3,7 @@ using CryptoExchange.Net.RateLimiting.Interfaces; using CryptoExchange.Net.RateLimiting.Trackers; using System; using System.Collections.Generic; +using System.Linq; using System.Threading; namespace CryptoExchange.Net.RateLimiting.Guards @@ -126,7 +127,6 @@ namespace CryptoExchange.Net.RateLimiting.Guards _trackers.Add(key, tracker); } - var delay = tracker.GetWaitTime(requestWeight); if (delay == default) return LimitCheck.NotNeeded(Limit, TimeSpan, tracker.Current); @@ -172,6 +172,27 @@ namespace CryptoExchange.Net.RateLimiting.Guards return RateLimitState.Applied(Limit, TimeSpan, tracker.Current); } + /// + public void Reset(RateLimitItemType type, RequestDefinition definition, string host, string? apiKey, string? keySuffix) + { + if (SharedGuard) + _sharedGuardSemaphore!.Wait(); + + try + { + var key = _keySelector(definition, host, apiKey) + keySuffix; + if (!_trackers.TryGetValue(key, out var tracker)) + return; + + tracker.Reset(); + } + finally + { + if (SharedGuard) + _sharedGuardSemaphore!.Release(); + } + } + /// /// Create a new WindowTracker /// diff --git a/CryptoExchange.Net/RateLimiting/Guards/RetryAfterGuard.cs b/CryptoExchange.Net/RateLimiting/Guards/RetryAfterGuard.cs index c9c4bbe..1b4fe6d 100644 --- a/CryptoExchange.Net/RateLimiting/Guards/RetryAfterGuard.cs +++ b/CryptoExchange.Net/RateLimiting/Guards/RetryAfterGuard.cs @@ -65,5 +65,11 @@ namespace CryptoExchange.Net.RateLimiting.Guards /// /// public void UpdateAfter(DateTime after) => After = after; + + /// + public void Reset(RateLimitItemType type, RequestDefinition definition, string host, string? apiKey, string? keySuffix) + { + After = DateTime.UtcNow; + } } } diff --git a/CryptoExchange.Net/RateLimiting/Guards/SingleLimitGuard.cs b/CryptoExchange.Net/RateLimiting/Guards/SingleLimitGuard.cs index f3e7998..6f35f2d 100644 --- a/CryptoExchange.Net/RateLimiting/Guards/SingleLimitGuard.cs +++ b/CryptoExchange.Net/RateLimiting/Guards/SingleLimitGuard.cs @@ -88,5 +88,15 @@ namespace CryptoExchange.Net.RateLimiting.Guards : _windowType == RateLimitWindowType.Fixed ? new FixedWindowTracker(_limit, _period) : new DecayWindowTracker(_limit, _period, _decayRate ?? throw new InvalidOperationException("Decay rate not provided")); } + + /// + public void Reset(RateLimitItemType type, RequestDefinition definition, string host, string? apiKey, string? keySuffix) + { + var key = _keySelector(definition, host, apiKey) + keySuffix; + if (!_trackers.TryGetValue(key, out var tracker)) + return; + + tracker.Reset(); + } } } diff --git a/CryptoExchange.Net/RateLimiting/Interfaces/IRateLimitGate.cs b/CryptoExchange.Net/RateLimiting/Interfaces/IRateLimitGate.cs index 8d9b8dc..09f279c 100644 --- a/CryptoExchange.Net/RateLimiting/Interfaces/IRateLimitGate.cs +++ b/CryptoExchange.Net/RateLimiting/Interfaces/IRateLimitGate.cs @@ -74,5 +74,22 @@ namespace CryptoExchange.Net.RateLimiting.Interfaces /// Cancelation token /// Error if RateLimitingBehaviour is Fail and rate limit is hit ValueTask ProcessSingleAsync(ILogger logger, int itemId, IRateLimitGuard guard, RateLimitItemType type, RequestDefinition definition, string baseAddress, string? apiKey, int requestWeight, RateLimitingBehaviour behaviour, string? keySuffix, CancellationToken ct); + + /// + /// Reset the limit for the specified parameters + /// + /// The rate limit item type + /// The request definition + /// The host address + /// The API key + /// An additional optional suffix for the key selector. Can be used to make rate limiting work based on parameters. + /// Cancelation token + Task ResetAsync( + RateLimitItemType type, + RequestDefinition definition, + string host, + string? apiKey, + string? keySuffix, + CancellationToken ct); } } diff --git a/CryptoExchange.Net/RateLimiting/Interfaces/IRateLimitGuard.cs b/CryptoExchange.Net/RateLimiting/Interfaces/IRateLimitGuard.cs index 9a3f6a7..23124a1 100644 --- a/CryptoExchange.Net/RateLimiting/Interfaces/IRateLimitGuard.cs +++ b/CryptoExchange.Net/RateLimiting/Interfaces/IRateLimitGuard.cs @@ -40,5 +40,15 @@ namespace CryptoExchange.Net.RateLimiting.Interfaces /// An additional optional suffix for the key selector. Can be used to make rate limiting work based on parameters. /// RateLimitState ApplyWeight(RateLimitItemType type, RequestDefinition definition, string host, string? apiKey, int requestWeight, string? keySuffix); + + /// + /// Reset the limit for the specified parameters + /// + /// The rate limit item type + /// The request definition + /// The host address + /// The API key + /// An additional optional suffix for the key selector. Can be used to make rate limiting work based on parameters. + void Reset(RateLimitItemType type, RequestDefinition definition, string host, string? apiKey, string? keySuffix); } } diff --git a/CryptoExchange.Net/RateLimiting/Interfaces/IWindowTracker.cs b/CryptoExchange.Net/RateLimiting/Interfaces/IWindowTracker.cs index ab768dd..1ddf927 100644 --- a/CryptoExchange.Net/RateLimiting/Interfaces/IWindowTracker.cs +++ b/CryptoExchange.Net/RateLimiting/Interfaces/IWindowTracker.cs @@ -30,5 +30,9 @@ namespace CryptoExchange.Net.RateLimiting.Interfaces /// /// Request weight void ApplyWeight(int weight); + /// + /// Reset the limit counter for this tracker + /// + void Reset(); } } diff --git a/CryptoExchange.Net/RateLimiting/RateLimitGate.cs b/CryptoExchange.Net/RateLimiting/RateLimitGate.cs index 518a105..7427954 100644 --- a/CryptoExchange.Net/RateLimiting/RateLimitGate.cs +++ b/CryptoExchange.Net/RateLimiting/RateLimitGate.cs @@ -192,5 +192,27 @@ namespace CryptoExchange.Net.RateLimiting _semaphore.Release(); } } + + + /// + public async Task ResetAsync( + RateLimitItemType type, + RequestDefinition definition, + string host, + string? apiKey, + string? keySuffix, + CancellationToken ct) + { + await _semaphore.WaitAsync(ct).ConfigureAwait(false); + try + { + foreach (var guard in _guards) + guard.Reset(type, definition, host, apiKey, keySuffix); + } + finally + { + _semaphore.Release(); + } + } } } diff --git a/CryptoExchange.Net/RateLimiting/Trackers/DecayWindowTracker.cs b/CryptoExchange.Net/RateLimiting/Trackers/DecayWindowTracker.cs index 86b783b..4201b80 100644 --- a/CryptoExchange.Net/RateLimiting/Trackers/DecayWindowTracker.cs +++ b/CryptoExchange.Net/RateLimiting/Trackers/DecayWindowTracker.cs @@ -26,6 +26,13 @@ namespace CryptoExchange.Net.RateLimiting.Trackers DecreaseRate = decayRate; } + /// + public void Reset() + { + _currentWeight = 0; + _lastDecrease = DateTime.UtcNow; + } + /// public TimeSpan GetWaitTime(int weight) { diff --git a/CryptoExchange.Net/RateLimiting/Trackers/FixedAfterStartWindowTracker.cs b/CryptoExchange.Net/RateLimiting/Trackers/FixedAfterStartWindowTracker.cs index be0d713..34be989 100644 --- a/CryptoExchange.Net/RateLimiting/Trackers/FixedAfterStartWindowTracker.cs +++ b/CryptoExchange.Net/RateLimiting/Trackers/FixedAfterStartWindowTracker.cs @@ -29,6 +29,14 @@ namespace CryptoExchange.Net.RateLimiting.Trackers _entries = new Queue(); } + /// + public void Reset() + { + _entries.Clear(); + _currentWeight = 0; + _nextReset = null; + } + public TimeSpan GetWaitTime(int weight) { // Remove requests no longer in time period from the history diff --git a/CryptoExchange.Net/RateLimiting/Trackers/FixedWindowTracker.cs b/CryptoExchange.Net/RateLimiting/Trackers/FixedWindowTracker.cs index 1481894..9b76e58 100644 --- a/CryptoExchange.Net/RateLimiting/Trackers/FixedWindowTracker.cs +++ b/CryptoExchange.Net/RateLimiting/Trackers/FixedWindowTracker.cs @@ -28,6 +28,13 @@ namespace CryptoExchange.Net.RateLimiting.Trackers _entries = new Queue(); } + /// + public void Reset() + { + _entries.Clear(); + _currentWeight = 0; + } + /// public TimeSpan GetWaitTime(int weight) { diff --git a/CryptoExchange.Net/RateLimiting/Trackers/SlidingWindowTracker.cs b/CryptoExchange.Net/RateLimiting/Trackers/SlidingWindowTracker.cs index 19f9abb..ea42547 100644 --- a/CryptoExchange.Net/RateLimiting/Trackers/SlidingWindowTracker.cs +++ b/CryptoExchange.Net/RateLimiting/Trackers/SlidingWindowTracker.cs @@ -28,6 +28,13 @@ namespace CryptoExchange.Net.RateLimiting.Trackers _entries = new List(); } + /// + public void Reset() + { + _entries.Clear(); + _currentWeight = 0; + } + /// public TimeSpan GetWaitTime(int weight) { diff --git a/CryptoExchange.Net/Sockets/Default/CryptoExchangeWebSocketClient.cs b/CryptoExchange.Net/Sockets/Default/CryptoExchangeWebSocketClient.cs index 413a434..31ac86a 100644 --- a/CryptoExchange.Net/Sockets/Default/CryptoExchangeWebSocketClient.cs +++ b/CryptoExchange.Net/Sockets/Default/CryptoExchangeWebSocketClient.cs @@ -48,6 +48,7 @@ namespace CryptoExchange.Net.Sockets.Default private readonly string _baseAddress; private int _reconnectAttempt; private readonly int _receiveBufferSize; + private readonly RequestDefinition _requestDefinition; private const int _sendBufferSize = 4096; @@ -137,6 +138,7 @@ namespace CryptoExchange.Net.Sockets.Default _sendBuffer = new ConcurrentQueue(); _ctsSource = new CancellationTokenSource(); _receiveBufferSize = websocketParameters.ReceiveBufferSize ?? 65536; + _requestDefinition = new RequestDefinition(Uri.AbsolutePath, HttpMethod.Get) { ConnectionId = Id }; _closeSem = new SemaphoreSlim(1, 1); _socket = CreateSocket(); @@ -206,8 +208,7 @@ namespace CryptoExchange.Net.Sockets.Default { if (Parameters.RateLimiter != null) { - var definition = new RequestDefinition(Uri.AbsolutePath, HttpMethod.Get) { ConnectionId = Id }; - var limitResult = await Parameters.RateLimiter.ProcessAsync(_logger, Id, RateLimitItemType.Connection, definition, _baseAddress, null, 1, Parameters.RateLimitingBehavior, null, _ctsSource.Token).ConfigureAwait(false); + var limitResult = await Parameters.RateLimiter.ProcessAsync(_logger, Id, RateLimitItemType.Connection, _requestDefinition, _baseAddress, null, 1, Parameters.RateLimitingBehavior, null, _ctsSource.Token).ConfigureAwait(false); if (!limitResult) return new CallResult(new ClientRateLimitError("Connection limit reached")); } @@ -296,6 +297,9 @@ namespace CryptoExchange.Net.Sockets.Default await (OnReconnecting?.Invoke() ?? Task.CompletedTask).ConfigureAwait(false); } + if (Parameters.RateLimiter != null) + await Parameters.RateLimiter.ResetAsync(RateLimitItemType.Connection, _requestDefinition, _baseAddress, null, null, _ctsSource.Token).ConfigureAwait(false); + // Delay here to prevent very rapid looping when a connection to the server is accepted and immediately disconnected var initialDelay = GetReconnectDelay(); await Task.Delay(initialDelay).ConfigureAwait(false); @@ -496,7 +500,6 @@ namespace CryptoExchange.Net.Sockets.Default /// private async Task SendLoopAsync() { - var requestDefinition = new RequestDefinition(Uri.AbsolutePath, HttpMethod.Get) { ConnectionId = Id }; try { while (true) @@ -520,7 +523,7 @@ namespace CryptoExchange.Net.Sockets.Default { try { - var limitResult = await Parameters.RateLimiter.ProcessAsync(_logger, data.Id, RateLimitItemType.Request, requestDefinition, _baseAddress, null, data.Weight, Parameters.RateLimitingBehavior, null, _ctsSource.Token).ConfigureAwait(false); + var limitResult = await Parameters.RateLimiter.ProcessAsync(_logger, data.Id, RateLimitItemType.Request, _requestDefinition, _baseAddress, null, data.Weight, Parameters.RateLimitingBehavior, null, _ctsSource.Token).ConfigureAwait(false); if (!limitResult) { await (OnRequestRateLimited?.Invoke(data.Id) ?? Task.CompletedTask).ConfigureAwait(false);