rs_dapi_client/
address_list.rs

1//! Subsystem to manage DAPI nodes.
2
3use crate::Uri;
4use chrono::Utc;
5use rand::{rngs::SmallRng, seq::IteratorRandom, SeedableRng};
6use std::collections::hash_map::Entry;
7use std::collections::HashMap;
8use std::hash::{Hash, Hasher};
9use std::mem;
10use std::str::FromStr;
11use std::sync::{Arc, RwLock};
12use std::time::Duration;
13
/// Length of the first ban applied to an address by an [AddressList] created with
/// [AddressList::new]; one minute.
const DEFAULT_BASE_BAN_PERIOD: Duration = Duration::from_millis(60_000);
15
16/// DAPI address.
17#[derive(Debug, Clone, Eq)]
18#[cfg_attr(feature = "mocks", derive(serde::Serialize, serde::Deserialize))]
19pub struct Address(#[cfg_attr(feature = "mocks", serde(with = "http_serde::uri"))] Uri);
20
21impl FromStr for Address {
22    type Err = AddressListError;
23
24    fn from_str(s: &str) -> Result<Self, Self::Err> {
25        Uri::from_str(s)
26            .map_err(|e| AddressListError::InvalidAddressUri(e.to_string()))
27            .map(Address::try_from)?
28    }
29}
30
31impl PartialEq<Self> for Address {
32    fn eq(&self, other: &Self) -> bool {
33        self.0 == other.0
34    }
35}
36
37impl PartialEq<Uri> for Address {
38    fn eq(&self, other: &Uri) -> bool {
39        self.0 == *other
40    }
41}
42
43impl Hash for Address {
44    fn hash<H: Hasher>(&self, state: &mut H) {
45        self.0.hash(state);
46    }
47}
48
49impl TryFrom<Uri> for Address {
50    type Error = AddressListError;
51
52    fn try_from(value: Uri) -> Result<Self, Self::Error> {
53        if value.host().is_none() {
54            return Err(AddressListError::InvalidAddressUri(
55                "uri must contain host".to_string(),
56            ));
57        }
58
59        Ok(Address(value))
60    }
61}
62
63impl Address {
64    /// Get [Uri] of a node.
65    pub fn uri(&self) -> &Uri {
66        &self.0
67    }
68}
69
70/// Address status
71/// Contains information about the number of bans and the time until the next ban is lifted.
72#[derive(Debug, Default, Clone)]
73pub struct AddressStatus {
74    ban_count: usize,
75    banned_until: Option<chrono::DateTime<Utc>>,
76}
77
78impl AddressStatus {
79    /// Ban the [Address] so it won't be available through [AddressList::get_live_address] for some time.
80    pub fn ban(&mut self, base_ban_period: &Duration) {
81        let coefficient = (self.ban_count as f64).exp();
82        let ban_period = Duration::from_secs_f64(base_ban_period.as_secs_f64() * coefficient);
83
84        self.banned_until = Some(chrono::Utc::now() + ban_period);
85        self.ban_count += 1;
86    }
87
88    /// Check if [Address] is banned.
89    pub fn is_banned(&self) -> bool {
90        self.ban_count > 0
91    }
92
93    /// Clears ban record.
94    pub fn unban(&mut self) {
95        self.ban_count = 0;
96        self.banned_until = None;
97    }
98}
99
100/// [AddressList] errors
101#[derive(Debug, thiserror::Error, Clone)]
102#[cfg_attr(feature = "mocks", derive(serde::Serialize, serde::Deserialize))]
103pub enum AddressListError {
104    /// A valid uri is required to create an Address
105    #[error("unable parse address: {0}")]
106    #[cfg_attr(feature = "mocks", serde(skip))]
107    InvalidAddressUri(String),
108}
109
110/// A structure to manage DAPI addresses to select from
111/// for [DapiRequest](crate::DapiRequest) execution.
112#[derive(Debug, Clone)]
113pub struct AddressList {
114    addresses: Arc<RwLock<HashMap<Address, AddressStatus>>>,
115    base_ban_period: Duration,
116}
117
118impl Default for AddressList {
119    fn default() -> Self {
120        Self::new()
121    }
122}
123
124impl std::fmt::Display for Address {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        self.0.fmt(f)
127    }
128}
129
130impl AddressList {
131    /// Creates an empty [AddressList] with default base ban time.
132    pub fn new() -> Self {
133        AddressList::with_settings(DEFAULT_BASE_BAN_PERIOD)
134    }
135
136    /// Creates an empty [AddressList] with adjustable base ban time.
137    pub fn with_settings(base_ban_period: Duration) -> Self {
138        AddressList {
139            addresses: Arc::new(RwLock::new(HashMap::new())),
140            base_ban_period,
141        }
142    }
143
144    /// Bans address
145    /// Returns false if the address is not in the list.
146    pub fn ban(&self, address: &Address) -> bool {
147        let mut guard = self.addresses.write().unwrap();
148
149        let Some(status) = guard.get_mut(address) else {
150            return false;
151        };
152
153        status.ban(&self.base_ban_period);
154
155        true
156    }
157
158    /// Clears address' ban record
159    /// Returns false if the address is not in the list.
160    pub fn unban(&self, address: &Address) -> bool {
161        let mut guard = self.addresses.write().unwrap();
162
163        let Some(status) = guard.get_mut(address) else {
164            return false;
165        };
166
167        status.unban();
168
169        true
170    }
171
172    /// Check if the address is banned.
173    pub fn is_banned(&self, address: &Address) -> bool {
174        let guard = self.addresses.read().unwrap();
175
176        guard
177            .get(address)
178            .map(|status| status.is_banned())
179            .unwrap_or(false)
180    }
181
182    /// Adds a node [Address] to [AddressList]
183    /// Returns false if the address is already in the list.
184    pub fn add(&mut self, address: Address) -> bool {
185        let mut guard = self.addresses.write().unwrap();
186
187        match guard.entry(address) {
188            Entry::Occupied(_) => false,
189            Entry::Vacant(e) => {
190                e.insert(AddressStatus::default());
191
192                true
193            }
194        }
195    }
196
197    /// Remove address from the list
198    /// Returns [AddressStatus] if the address was in the list.
199    pub fn remove(&mut self, address: &Address) -> Option<AddressStatus> {
200        let mut guard = self.addresses.write().unwrap();
201
202        guard.remove(address)
203    }
204
205    #[deprecated]
206    // TODO: Remove in favor of add
207    /// Add a node [Address] to [AddressList] by [Uri].
208    /// Returns false if the address is already in the list.
209    pub fn add_uri(&mut self, uri: Uri) -> bool {
210        self.add(Address::try_from(uri).expect("valid uri"))
211    }
212
213    /// Randomly select a not banned address.
214    pub fn get_live_address(&self) -> Option<Address> {
215        let guard = self.addresses.read().unwrap();
216
217        let mut rng = SmallRng::from_entropy();
218
219        let now = chrono::Utc::now();
220
221        guard
222            .iter()
223            .filter(|(_, status)| {
224                status
225                    .banned_until
226                    .map(|banned_until| banned_until < now)
227                    .unwrap_or(true)
228            })
229            .choose(&mut rng)
230            .map(|(addr, _)| addr.clone())
231    }
232
233    /// Get all not banned addresses.
234    ///
235    /// Returns a vector of addresses that are not currently banned or whose ban period has expired.
236    /// The returned addresses use the same filtering logic as [get_live_address], checking if the
237    /// ban period has expired based on the current time.
238    ///
239    /// # Examples
240    ///
241    /// ```
242    /// use rs_dapi_client::{AddressList, Address};
243    ///
244    /// let mut list = AddressList::new();
245    /// list.add("http://127.0.0.1:3000".parse().unwrap());
246    /// list.add("http://127.0.0.1:3001".parse().unwrap());
247    ///
248    /// // Get all non-banned addresses
249    /// let live_addresses = list.get_live_addresses();
250    /// assert_eq!(live_addresses.len(), 2);
251    /// ```
252    pub fn get_live_addresses(&self) -> Vec<Address> {
253        let guard = self.addresses.read().unwrap();
254
255        let now = chrono::Utc::now();
256
257        guard
258            .iter()
259            .filter(|(_, status)| {
260                status
261                    .banned_until
262                    .map(|banned_until| banned_until < now)
263                    .unwrap_or(true)
264            })
265            .map(|(addr, _)| addr.clone())
266            .collect()
267    }
268
269    /// Get number of all addresses, both banned and not banned.
270    pub fn len(&self) -> usize {
271        self.addresses.read().unwrap().len()
272    }
273
274    /// Check if the list is empty.
275    /// Returns true if there are no addresses in the list.
276    /// Returns false if there is at least one address in the list.
277    /// Banned addresses are also counted.
278    pub fn is_empty(&self) -> bool {
279        self.addresses.read().unwrap().is_empty()
280    }
281}
282
283impl IntoIterator for AddressList {
284    type Item = (Address, AddressStatus);
285    type IntoIter = std::collections::hash_map::IntoIter<Address, AddressStatus>;
286
287    fn into_iter(self) -> Self::IntoIter {
288        let mut guard = self.addresses.write().unwrap();
289
290        let addresses_map = mem::take(&mut *guard);
291
292        addresses_map.into_iter()
293    }
294}
295
296impl FromStr for AddressList {
297    type Err = AddressListError;
298
299    fn from_str(s: &str) -> Result<Self, Self::Err> {
300        let uri_list: Vec<Address> = s
301            .split(',')
302            .map(Address::from_str)
303            .collect::<Result<_, _>>()?;
304
305        Ok(Self::from_iter(uri_list))
306    }
307}
308
309impl FromIterator<Address> for AddressList {
310    fn from_iter<T: IntoIterator<Item = Address>>(iter: T) -> Self {
311        let mut address_list = Self::new();
312        for uri in iter {
313            address_list.add(uri);
314        }
315
316        address_list
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_live_addresses_empty_list() {
        let list = AddressList::new();
        let live_addresses = list.get_live_addresses();
        assert_eq!(live_addresses.len(), 0);
    }

    #[test]
    fn test_get_live_addresses_all_unbanned() {
        let mut list = AddressList::new();
        list.add("http://127.0.0.1:3000".parse().unwrap());
        list.add("http://127.0.0.1:3001".parse().unwrap());
        list.add("http://127.0.0.1:3002".parse().unwrap());

        let live_addresses = list.get_live_addresses();
        assert_eq!(live_addresses.len(), 3);
    }

    #[test]
    fn test_get_live_addresses_some_banned() {
        let mut list = AddressList::new();
        let addr1: Address = "http://127.0.0.1:3000".parse().unwrap();
        let addr2: Address = "http://127.0.0.1:3001".parse().unwrap();
        let addr3: Address = "http://127.0.0.1:3002".parse().unwrap();

        list.add(addr1.clone());
        list.add(addr2.clone());
        list.add(addr3.clone());

        // Ban addr2
        list.ban(&addr2);

        let live_addresses = list.get_live_addresses();
        assert_eq!(live_addresses.len(), 2);
        assert!(live_addresses.contains(&addr1));
        assert!(live_addresses.contains(&addr3));
        assert!(!live_addresses.contains(&addr2));
    }

    #[test]
    fn test_get_live_addresses_all_banned() {
        let mut list = AddressList::new();
        let addr1: Address = "http://127.0.0.1:3000".parse().unwrap();
        let addr2: Address = "http://127.0.0.1:3001".parse().unwrap();

        list.add(addr1.clone());
        list.add(addr2.clone());

        // Ban all addresses
        list.ban(&addr1);
        list.ban(&addr2);

        let live_addresses = list.get_live_addresses();
        assert_eq!(live_addresses.len(), 0);
    }

    #[test]
    fn test_get_live_addresses_unbanned_after_ban() {
        let mut list = AddressList::new();
        let addr1: Address = "http://127.0.0.1:3000".parse().unwrap();

        list.add(addr1.clone());

        // Ban and then unban
        list.ban(&addr1);
        list.unban(&addr1);

        let live_addresses = list.get_live_addresses();
        assert_eq!(live_addresses.len(), 1);
        assert!(live_addresses.contains(&addr1));
    }

    #[test]
    fn test_address_try_from_uri_without_host() {
        let uri: Uri = Uri::from_str("/path/only").unwrap();
        let result = Address::try_from(uri);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AddressListError::InvalidAddressUri(_)));
    }

    #[test]
    fn test_address_from_str_invalid_uri() {
        // Use a string with invalid URI characters that http::Uri rejects
        let result = Address::from_str("not a valid uri\x00");
        assert!(result.is_err());
    }

    #[test]
    fn test_address_uri_accessor() {
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();
        let uri = addr.uri();
        assert_eq!(uri.host(), Some("127.0.0.1"));
    }

    #[test]
    fn test_address_partial_eq_with_uri() {
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();
        let uri = Uri::from_str("http://127.0.0.1:3000").unwrap();
        assert!(addr == uri);

        let other_uri = Uri::from_str("http://127.0.0.1:4000").unwrap();
        assert!(addr != other_uri);
    }

    #[test]
    fn test_address_display() {
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();
        let display = format!("{}", addr);
        assert!(display.contains("127.0.0.1"));
    }

    #[test]
    fn test_address_status_is_banned() {
        let mut status = AddressStatus::default();
        assert!(!status.is_banned());

        status.ban(&Duration::from_secs(60));
        assert!(status.is_banned());

        status.unban();
        assert!(!status.is_banned());
    }

    #[test]
    fn test_address_status_exponential_ban() {
        let mut status = AddressStatus::default();
        let base_period = Duration::from_secs(100);
        // Slack for the time elapsed between `before` and the clock read inside `ban`.
        let tolerance_ms = 1_000;

        // First ban: coefficient = exp(0) = 1, period = 100s
        let before = Utc::now();
        status.ban(&base_period);
        assert_eq!(status.ban_count, 1);
        let first = status.banned_until.expect("ban must set banned_until") - before;
        assert!((first.num_milliseconds() - 100_000).abs() < tolerance_ms);

        // Second ban: coefficient = exp(1) ~= 2.718, period ~= 271.8s
        let before = Utc::now();
        status.ban(&base_period);
        assert_eq!(status.ban_count, 2);
        let second = status.banned_until.expect("ban must set banned_until") - before;
        assert!((second.num_milliseconds() - 271_828).abs() < tolerance_ms);
    }

    #[test]
    fn test_address_list_is_empty() {
        let list = AddressList::new();
        assert!(list.is_empty());

        let mut list = AddressList::new();
        list.add("http://127.0.0.1:3000".parse().unwrap());
        assert!(!list.is_empty());
    }

    #[test]
    fn test_address_list_len() {
        let mut list = AddressList::new();
        assert_eq!(list.len(), 0);

        list.add("http://127.0.0.1:3000".parse().unwrap());
        assert_eq!(list.len(), 1);

        list.add("http://127.0.0.1:3001".parse().unwrap());
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn test_address_list_add_duplicate() {
        let mut list = AddressList::new();
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();

        assert!(list.add(addr.clone()));
        assert!(!list.add(addr)); // duplicate returns false
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_address_list_remove() {
        let mut list = AddressList::new();
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();

        list.add(addr.clone());
        assert_eq!(list.len(), 1);

        let removed = list.remove(&addr);
        assert!(removed.is_some());
        assert_eq!(list.len(), 0);

        // Removing non-existent address returns None
        let removed = list.remove(&addr);
        assert!(removed.is_none());
    }

    #[test]
    fn test_address_list_ban_nonexistent() {
        let list = AddressList::new();
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();
        assert!(!list.ban(&addr));
    }

    #[test]
    fn test_address_list_unban_nonexistent() {
        let list = AddressList::new();
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();
        assert!(!list.unban(&addr));
    }

    #[test]
    fn test_address_list_is_banned() {
        let mut list = AddressList::new();
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();
        let unknown: Address = "http://127.0.0.1:9999".parse().unwrap();

        list.add(addr.clone());

        assert!(!list.is_banned(&addr));
        assert!(!list.is_banned(&unknown)); // unknown returns false

        list.ban(&addr);
        assert!(list.is_banned(&addr));
    }

    #[test]
    fn test_address_list_from_str() {
        let list: AddressList = "http://127.0.0.1:3000,http://127.0.0.1:3001"
            .parse()
            .unwrap();
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn test_address_list_from_str_single() {
        let list: AddressList = "http://127.0.0.1:3000".parse().unwrap();
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_address_list_from_str_invalid() {
        let result: Result<AddressList, _> = "not a valid uri\x00".parse();
        assert!(result.is_err());
    }

    #[test]
    fn test_address_list_get_live_address_returns_none_when_empty() {
        let list = AddressList::new();
        assert!(list.get_live_address().is_none());
    }

    #[test]
    fn test_address_list_get_live_address_returns_some_when_available() {
        let mut list = AddressList::new();
        list.add("http://127.0.0.1:3000".parse().unwrap());
        assert!(list.get_live_address().is_some());
    }

    #[test]
    fn test_address_list_get_live_address_returns_none_when_all_banned() {
        let mut list = AddressList::new();
        let addr: Address = "http://127.0.0.1:3000".parse().unwrap();

        list.add(addr.clone());
        list.ban(&addr);

        assert!(list.get_live_address().is_none());
    }

    #[test]
    fn test_address_list_get_live_address_skips_banned() {
        let mut list = AddressList::new();
        let banned: Address = "http://127.0.0.1:3000".parse().unwrap();
        let live: Address = "http://127.0.0.1:3001".parse().unwrap();

        list.add(banned.clone());
        list.add(live.clone());
        list.ban(&banned);

        // Selection is random, so sample several times.
        for _ in 0..16 {
            assert_eq!(list.get_live_address(), Some(live.clone()));
        }
    }

    #[test]
    fn test_address_list_into_iter() {
        let mut list = AddressList::new();
        list.add("http://127.0.0.1:3000".parse().unwrap());
        list.add("http://127.0.0.1:3001".parse().unwrap());

        let items: Vec<_> = list.into_iter().collect();
        assert_eq!(items.len(), 2);
    }

    #[test]
    fn test_address_list_with_settings() {
        let list = AddressList::with_settings(Duration::from_secs(120));
        assert!(list.is_empty());
    }

    #[test]
    fn test_address_list_default() {
        let list = AddressList::default();
        assert!(list.is_empty());
    }
}