rs_dapi_client/
address_list.rs1use crate::Uri;
4use chrono::Utc;
5use rand::{rngs::SmallRng, seq::IteratorRandom, SeedableRng};
6use std::collections::hash_map::Entry;
7use std::collections::HashMap;
8use std::hash::{Hash, Hasher};
9use std::mem;
10use std::str::FromStr;
11use std::sync::{Arc, RwLock};
12use std::time::Duration;
13
/// Base ban period used by [`AddressList::new`]: the length of an address's first ban.
const DEFAULT_BASE_BAN_PERIOD: Duration = Duration::from_millis(60 * 1_000);
15
16#[derive(Debug, Clone, Eq)]
18#[cfg_attr(feature = "mocks", derive(serde::Serialize, serde::Deserialize))]
19pub struct Address(#[cfg_attr(feature = "mocks", serde(with = "http_serde::uri"))] Uri);
20
21impl FromStr for Address {
22 type Err = AddressListError;
23
24 fn from_str(s: &str) -> Result<Self, Self::Err> {
25 Uri::from_str(s)
26 .map_err(|e| AddressListError::InvalidAddressUri(e.to_string()))
27 .map(Address::try_from)?
28 }
29}
30
31impl PartialEq<Self> for Address {
32 fn eq(&self, other: &Self) -> bool {
33 self.0 == other.0
34 }
35}
36
37impl PartialEq<Uri> for Address {
38 fn eq(&self, other: &Uri) -> bool {
39 self.0 == *other
40 }
41}
42
43impl Hash for Address {
44 fn hash<H: Hasher>(&self, state: &mut H) {
45 self.0.hash(state);
46 }
47}
48
49impl TryFrom<Uri> for Address {
50 type Error = AddressListError;
51
52 fn try_from(value: Uri) -> Result<Self, Self::Error> {
53 if value.host().is_none() {
54 return Err(AddressListError::InvalidAddressUri(
55 "uri must contain host".to_string(),
56 ));
57 }
58
59 Ok(Address(value))
60 }
61}
62
63impl Address {
64 pub fn uri(&self) -> &Uri {
66 &self.0
67 }
68}
69
70#[derive(Debug, Default, Clone)]
73pub struct AddressStatus {
74 ban_count: usize,
75 banned_until: Option<chrono::DateTime<Utc>>,
76}
77
78impl AddressStatus {
79 pub fn ban(&mut self, base_ban_period: &Duration) {
81 let coefficient = (self.ban_count as f64).exp();
82 let ban_period = Duration::from_secs_f64(base_ban_period.as_secs_f64() * coefficient);
83
84 self.banned_until = Some(chrono::Utc::now() + ban_period);
85 self.ban_count += 1;
86 }
87
88 pub fn is_banned(&self) -> bool {
90 self.ban_count > 0
91 }
92
93 pub fn unban(&mut self) {
95 self.ban_count = 0;
96 self.banned_until = None;
97 }
98}
99
100#[derive(Debug, thiserror::Error, Clone)]
102#[cfg_attr(feature = "mocks", derive(serde::Serialize, serde::Deserialize))]
103pub enum AddressListError {
104 #[error("unable parse address: {0}")]
106 #[cfg_attr(feature = "mocks", serde(skip))]
107 InvalidAddressUri(String),
108}
109
110#[derive(Debug, Clone)]
113pub struct AddressList {
114 addresses: Arc<RwLock<HashMap<Address, AddressStatus>>>,
115 base_ban_period: Duration,
116}
117
118impl Default for AddressList {
119 fn default() -> Self {
120 Self::new()
121 }
122}
123
124impl std::fmt::Display for Address {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 self.0.fmt(f)
127 }
128}
129
130impl AddressList {
131 pub fn new() -> Self {
133 AddressList::with_settings(DEFAULT_BASE_BAN_PERIOD)
134 }
135
136 pub fn with_settings(base_ban_period: Duration) -> Self {
138 AddressList {
139 addresses: Arc::new(RwLock::new(HashMap::new())),
140 base_ban_period,
141 }
142 }
143
144 pub fn ban(&self, address: &Address) -> bool {
147 let mut guard = self.addresses.write().unwrap();
148
149 let Some(status) = guard.get_mut(address) else {
150 return false;
151 };
152
153 status.ban(&self.base_ban_period);
154
155 true
156 }
157
158 pub fn unban(&self, address: &Address) -> bool {
161 let mut guard = self.addresses.write().unwrap();
162
163 let Some(status) = guard.get_mut(address) else {
164 return false;
165 };
166
167 status.unban();
168
169 true
170 }
171
172 pub fn is_banned(&self, address: &Address) -> bool {
174 let guard = self.addresses.read().unwrap();
175
176 guard
177 .get(address)
178 .map(|status| status.is_banned())
179 .unwrap_or(false)
180 }
181
182 pub fn add(&mut self, address: Address) -> bool {
185 let mut guard = self.addresses.write().unwrap();
186
187 match guard.entry(address) {
188 Entry::Occupied(_) => false,
189 Entry::Vacant(e) => {
190 e.insert(AddressStatus::default());
191
192 true
193 }
194 }
195 }
196
197 pub fn remove(&mut self, address: &Address) -> Option<AddressStatus> {
200 let mut guard = self.addresses.write().unwrap();
201
202 guard.remove(address)
203 }
204
205 #[deprecated]
206 pub fn add_uri(&mut self, uri: Uri) -> bool {
210 self.add(Address::try_from(uri).expect("valid uri"))
211 }
212
213 pub fn get_live_address(&self) -> Option<Address> {
215 let guard = self.addresses.read().unwrap();
216
217 let mut rng = SmallRng::from_entropy();
218
219 let now = chrono::Utc::now();
220
221 guard
222 .iter()
223 .filter(|(_, status)| {
224 status
225 .banned_until
226 .map(|banned_until| banned_until < now)
227 .unwrap_or(true)
228 })
229 .choose(&mut rng)
230 .map(|(addr, _)| addr.clone())
231 }
232
233 pub fn get_live_addresses(&self) -> Vec<Address> {
253 let guard = self.addresses.read().unwrap();
254
255 let now = chrono::Utc::now();
256
257 guard
258 .iter()
259 .filter(|(_, status)| {
260 status
261 .banned_until
262 .map(|banned_until| banned_until < now)
263 .unwrap_or(true)
264 })
265 .map(|(addr, _)| addr.clone())
266 .collect()
267 }
268
269 pub fn len(&self) -> usize {
271 self.addresses.read().unwrap().len()
272 }
273
274 pub fn is_empty(&self) -> bool {
279 self.addresses.read().unwrap().is_empty()
280 }
281}
282
283impl IntoIterator for AddressList {
284 type Item = (Address, AddressStatus);
285 type IntoIter = std::collections::hash_map::IntoIter<Address, AddressStatus>;
286
287 fn into_iter(self) -> Self::IntoIter {
288 let mut guard = self.addresses.write().unwrap();
289
290 let addresses_map = mem::take(&mut *guard);
291
292 addresses_map.into_iter()
293 }
294}
295
296impl FromStr for AddressList {
297 type Err = AddressListError;
298
299 fn from_str(s: &str) -> Result<Self, Self::Err> {
300 let uri_list: Vec<Address> = s
301 .split(',')
302 .map(Address::from_str)
303 .collect::<Result<_, _>>()?;
304
305 Ok(Self::from_iter(uri_list))
306 }
307}
308
309impl FromIterator<Address> for AddressList {
310 fn from_iter<T: IntoIterator<Item = Address>>(iter: T) -> Self {
311 let mut address_list = Self::new();
312 for uri in iter {
313 address_list.add(uri);
314 }
315
316 address_list
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_live_addresses_empty_list() {
        let list = AddressList::new();
        let live_addresses = list.get_live_addresses();
        assert_eq!(live_addresses.len(), 0);
    }

    #[test]
    fn test_get_live_addresses_all_unbanned() {
        let mut list = AddressList::new();
        list.add("http://127.0.0.1:3000".parse().unwrap());
        list.add("http://127.0.0.1:3001".parse().unwrap());
        list.add("http://127.0.0.1:3002".parse().unwrap());

        let live_addresses = list.get_live_addresses();
        assert_eq!(live_addresses.len(), 3);
    }

    #[test]
    fn test_get_live_addresses_some_banned() {
        let mut list = AddressList::new();
        let addr1: Address = "http://127.0.0.1:3000".parse().unwrap();
        let addr2: Address = "http://127.0.0.1:3001".parse().unwrap();
        let addr3: Address = "http://127.0.0.1:3002".parse().unwrap();

        list.add(addr1.clone());
        list.add(addr2.clone());
        list.add(addr3.clone());

        list.ban(&addr2);

        let live_addresses = list.get_live_addresses();
        assert_eq!(live_addresses.len(), 2);
        assert!(live_addresses.contains(&addr1));
        assert!(live_addresses.contains(&addr3));
        assert!(!live_addresses.contains(&addr2));
    }

    #[test]
    fn test_get_live_addresses_all_banned() {
        let mut list = AddressList::new();
        let addr1: Address = "http://127.0.0.1:3000".parse().unwrap();
        let addr2: Address = "http://127.0.0.1:3001".parse().unwrap();

        list.add(addr1.clone());
        list.add(addr2.clone());

        list.ban(&addr1);
        list.ban(&addr2);

        let live_addresses = list.get_live_addresses();
        assert_eq!(live_addresses.len(), 0);
    }

    #[test]
    fn test_get_live_addresses_unbanned_after_ban() {
        let mut list = AddressList::new();
        let addr1: Address = "http://127.0.0.1:3000".parse().unwrap();

        list.add(addr1.clone());

        list.ban(&addr1);
        list.unban(&addr1);

        let live_addresses = list.get_live_addresses();
        assert_eq!(live_addresses.len(), 1);
        assert!(live_addresses.contains(&addr1));
    }

    #[test]
    fn test_address_try_from_uri_without_host() {
        let uri: Uri = Uri::from_str("/path/only").unwrap();
        let result = Address::try_from(uri);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AddressListError::InvalidAddressUri(_)));
    }

    #[test]
    fn test_address_from_str_invalid_uri() {
        let result = Address::from_str("not a valid uri\x00");
        assert!(result.is_err());
    }

    #[test]
    fn test_address_uri_accessor() {
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();
        let uri = addr.uri();
        assert_eq!(uri.host(), Some("127.0.0.1"));
    }

    #[test]
    fn test_address_partial_eq_with_uri() {
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();
        let uri = Uri::from_str("http://127.0.0.1:3000").unwrap();
        assert!(addr == uri);

        let other_uri = Uri::from_str("http://127.0.0.1:4000").unwrap();
        assert!(addr != other_uri);
    }

    #[test]
    fn test_address_display() {
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();
        let display = format!("{}", addr);
        assert!(display.contains("127.0.0.1"));
    }

    #[test]
    fn test_address_status_is_banned() {
        let mut status = AddressStatus::default();
        assert!(!status.is_banned());

        status.ban(&Duration::from_secs(60));
        assert!(status.is_banned());

        status.unban();
        assert!(!status.is_banned());
        assert!(status.banned_until.is_none());
    }

    #[test]
    fn test_address_status_exponential_ban() {
        let mut status = AddressStatus::default();
        let base_period = Duration::from_secs(1);

        // First ban lasts the base period (e^0 = 1).
        status.ban(&base_period);
        assert_eq!(status.ban_count, 1);
        let first_until = status
            .banned_until
            .expect("first ban must set banned_until");

        // Second ban is scaled by e^1, so it must end later than the first one did.
        status.ban(&base_period);
        assert_eq!(status.ban_count, 2);
        let second_until = status
            .banned_until
            .expect("second ban must set banned_until");
        assert!(second_until > first_until);
    }

    #[test]
    fn test_address_list_is_empty() {
        let list = AddressList::new();
        assert!(list.is_empty());

        let mut list = AddressList::new();
        list.add("http://127.0.0.1:3000".parse().unwrap());
        assert!(!list.is_empty());
    }

    #[test]
    fn test_address_list_len() {
        let mut list = AddressList::new();
        assert_eq!(list.len(), 0);

        list.add("http://127.0.0.1:3000".parse().unwrap());
        assert_eq!(list.len(), 1);

        list.add("http://127.0.0.1:3001".parse().unwrap());
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn test_address_list_add_duplicate() {
        let mut list = AddressList::new();
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();

        assert!(list.add(addr.clone()));
        assert!(!list.add(addr));
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_address_list_remove() {
        let mut list = AddressList::new();
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();

        list.add(addr.clone());
        assert_eq!(list.len(), 1);

        let removed = list.remove(&addr);
        assert!(removed.is_some());
        assert_eq!(list.len(), 0);

        let removed = list.remove(&addr);
        assert!(removed.is_none());
    }

    #[test]
    fn test_address_list_ban_nonexistent() {
        let list = AddressList::new();
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();
        assert!(!list.ban(&addr));
    }

    #[test]
    fn test_address_list_unban_nonexistent() {
        let list = AddressList::new();
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();
        assert!(!list.unban(&addr));
    }

    #[test]
    fn test_address_list_is_banned() {
        let mut list = AddressList::new();
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();
        let unknown: Address = "http://127.0.0.1:9999".parse().unwrap();

        list.add(addr.clone());

        assert!(!list.is_banned(&addr));
        assert!(!list.is_banned(&unknown));

        list.ban(&addr);
        assert!(list.is_banned(&addr));
    }

    #[test]
    fn test_address_list_from_str() {
        let list: AddressList = "http://127.0.0.1:3000,http://127.0.0.1:3001"
            .parse()
            .unwrap();
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn test_address_list_from_str_single() {
        let list: AddressList = "http://127.0.0.1:3000".parse().unwrap();
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_address_list_from_str_invalid() {
        let result: Result<AddressList, _> = "not a valid uri\x00".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_address_list_get_live_address_returns_none_when_empty() {
        let list = AddressList::new();
        assert!(list.get_live_address().is_none());
    }

    #[test]
    fn test_address_list_get_live_address_returns_some_when_available() {
        let mut list = AddressList::new();
        list.add("http://127.0.0.1:3000".parse().unwrap());
        assert!(list.get_live_address().is_some());
    }

    #[test]
    fn test_address_list_get_live_address_skips_banned() {
        let mut list = AddressList::new();
        let banned: Address = "http://127.0.0.1:3000".parse().unwrap();
        let live: Address = "http://127.0.0.1:3001".parse().unwrap();

        list.add(banned.clone());
        list.add(live.clone());
        list.ban(&banned);

        // The choice is random, so sample repeatedly; the banned address must never be picked.
        for _ in 0..16 {
            assert_eq!(list.get_live_address(), Some(live.clone()));
        }
    }

    #[test]
    fn test_address_list_into_iter() {
        let mut list = AddressList::new();
        list.add("http://127.0.0.1:3000".parse().unwrap());
        list.add("http://127.0.0.1:3001".parse().unwrap());

        let items: Vec<_> = list.into_iter().collect();
        assert_eq!(items.len(), 2);
    }

    #[test]
    fn test_address_list_with_settings() {
        let list = AddressList::with_settings(Duration::from_secs(120));
        assert!(list.is_empty());
    }

    #[test]
    fn test_address_list_default() {
        let list = AddressList::default();
        assert!(list.is_empty());
    }
}