rs_dapi_client/
connection_pool.rs

1use std::{
2    fmt::Display,
3    sync::{Arc, Mutex},
4};
5
6use lru::LruCache;
7
8use crate::{
9    request_settings::AppliedRequestSettings,
10    transport::{CoreGrpcClient, PlatformGrpcClient},
11    Uri,
12};
13
14/// ConnectionPool represents pool of connections to DAPI nodes.
15///
16/// It can be cloned and shared between threads.
17/// Cloning the pool will create a new reference to the same pool.
18#[derive(Debug, Clone)]
19pub struct ConnectionPool {
20    inner: Arc<Mutex<LruCache<String, PoolItem>>>,
21}
22
23impl ConnectionPool {
24    /// Create a new pool with a given capacity.
25    /// The pool will evict the least recently used item when the capacity is reached.
26    ///
27    /// # Panics
28    ///
29    /// Panics if the capacity is zero.
30    pub fn new(capacity: usize) -> Self {
31        Self {
32            inner: Arc::new(Mutex::new(LruCache::new(
33                capacity.try_into().expect("must be non-zero"),
34            ))),
35        }
36    }
37}
38
39impl Default for ConnectionPool {
40    fn default() -> Self {
41        Self::new(50)
42    }
43}
44
45impl ConnectionPool {
46    /// Get item from the pool for the given uri and settings.
47    ///
48    /// # Arguments
49    /// * `prefix` -  Prefix for the item in the pool. Used to distinguish between Core and Platform clients.
50    /// * `uri` - URI of the node.
51    /// * `settings` - Applied request settings.
52    pub fn get(
53        &self,
54        prefix: PoolPrefix,
55        uri: &Uri,
56        settings: Option<&AppliedRequestSettings>,
57    ) -> Option<PoolItem> {
58        let key = Self::key(prefix, uri, settings);
59        self.inner.lock().expect("must lock").get(&key).cloned()
60    }
61
62    /// Get value from cache or create it using provided closure.
63    /// If value is already in the cache, it will be returned.
64    /// If value is not in the cache, it will be created by calling `create()` and stored in the cache.
65    ///
66    /// # Arguments
67    /// * `prefix` -  Prefix for the item in the pool. Used to distinguish between Core and Platform clients.
68    /// * `uri` - URI of the node.
69    /// * `settings` - Applied request settings.
70    pub fn get_or_create<E>(
71        &self,
72        prefix: PoolPrefix,
73        uri: &Uri,
74        settings: Option<&AppliedRequestSettings>,
75        create: impl FnOnce() -> Result<PoolItem, E>,
76    ) -> Result<PoolItem, E> {
77        if let Some(cli) = self.get(prefix, uri, settings) {
78            return Ok(cli);
79        }
80
81        let cli = create();
82        if let Ok(cli) = &cli {
83            self.put(uri, settings, cli.clone());
84        }
85        cli
86    }
87
88    /// Put item into the pool for the given uri and settings.
89    pub fn put(&self, uri: &Uri, settings: Option<&AppliedRequestSettings>, value: PoolItem) {
90        let key = Self::key(&value, uri, settings);
91        self.inner.lock().expect("must lock").put(key, value);
92    }
93
94    fn key<C: Into<PoolPrefix>>(
95        class: C,
96        uri: &Uri,
97        settings: Option<&AppliedRequestSettings>,
98    ) -> String {
99        let prefix: PoolPrefix = class.into();
100        format!("{}:{}{:?}", prefix, uri, settings)
101    }
102}
103
104/// Item stored in the pool.
105///
106/// We use an enum as we need to represent two different types of clients.
107#[derive(Clone, Debug)]
108pub enum PoolItem {
109    Core(CoreGrpcClient),
110    Platform(PlatformGrpcClient),
111}
112
113impl From<PlatformGrpcClient> for PoolItem {
114    fn from(client: PlatformGrpcClient) -> Self {
115        Self::Platform(client)
116    }
117}
118impl From<CoreGrpcClient> for PoolItem {
119    fn from(client: CoreGrpcClient) -> Self {
120        Self::Core(client)
121    }
122}
123
124impl From<PoolItem> for PlatformGrpcClient {
125    fn from(client: PoolItem) -> Self {
126        match client {
127            PoolItem::Platform(client) => client,
128            _ => {
129                tracing::error!(
130                    ?client,
131                    "invalid connection fetched from pool: expected platform client"
132                );
133                panic!("ClientType is not Platform: {:?}", client)
134            }
135        }
136    }
137}
138
139impl From<PoolItem> for CoreGrpcClient {
140    fn from(client: PoolItem) -> Self {
141        match client {
142            PoolItem::Core(client) => client,
143            _ => {
144                tracing::error!(
145                    ?client,
146                    "invalid connection fetched from pool: expected core client"
147                );
148                panic!("ClientType is not Core: {:?}", client)
149            }
150        }
151    }
152}
153
/// Prefix for the item in the pool. Used to distinguish between Core and Platform clients.
///
/// Each variant corresponds to one [`PoolItem`] variant. The type is a plain tag, so it
/// derives `Copy` and the comparison/hash traits and can be reused, compared and logged freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PoolPrefix {
    /// Prefix for [`PoolItem::Core`] entries.
    Core,
    /// Prefix for [`PoolItem::Platform`] entries.
    Platform,
}
159impl Display for PoolPrefix {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        match self {
162            PoolPrefix::Core => write!(f, "Core"),
163            PoolPrefix::Platform => write!(f, "Platform"),
164        }
165    }
166}
167impl From<&PoolItem> for PoolPrefix {
168    fn from(item: &PoolItem) -> Self {
169        match item {
170            PoolItem::Core(_) => PoolPrefix::Core,
171            PoolItem::Platform(_) => PoolPrefix::Platform,
172        }
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use dapi_grpc::tonic::transport::Channel;
    use std::str::FromStr;

    /// URI of the node used by most tests.
    fn test_uri() -> Uri {
        Uri::from_str("http://127.0.0.1:3000").unwrap()
    }

    /// Loopback URI with the given port, for tests that need several distinct nodes.
    fn uri_with_port(port: u16) -> Uri {
        Uri::from_str(&format!("http://127.0.0.1:{}", port)).unwrap()
    }

    // Clients are built with `connect_lazy`. Tests that build clients run under
    // `#[tokio::test]`; presumably `connect_lazy` needs a Tokio runtime to be present.
    fn make_platform_pool_item() -> PoolItem {
        let channel = Channel::builder(test_uri()).connect_lazy();
        PoolItem::Platform(PlatformGrpcClient::new(channel))
    }

    fn make_core_pool_item() -> PoolItem {
        let channel = Channel::builder(test_uri()).connect_lazy();
        PoolItem::Core(CoreGrpcClient::new(channel))
    }

    #[test]
    fn test_connection_pool_new() {
        let pool = ConnectionPool::new(10);
        let result = pool.get(PoolPrefix::Platform, &test_uri(), None);
        assert!(result.is_none());
    }

    #[test]
    #[should_panic(expected = "must be non-zero")]
    fn test_connection_pool_new_zero_capacity_panics() {
        let _pool = ConnectionPool::new(0);
    }

    #[test]
    fn test_connection_pool_default() {
        let pool = ConnectionPool::default();
        let result = pool.get(PoolPrefix::Core, &test_uri(), None);
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_connection_pool_put_and_get_platform() {
        let pool = ConnectionPool::new(10);
        let uri = test_uri();

        pool.put(&uri, None, make_platform_pool_item());

        let result = pool.get(PoolPrefix::Platform, &uri, None);
        assert!(matches!(result, Some(PoolItem::Platform(_))));
    }

    #[tokio::test]
    async fn test_connection_pool_put_and_get_core() {
        let pool = ConnectionPool::new(10);
        let uri = test_uri();

        pool.put(&uri, None, make_core_pool_item());

        let result = pool.get(PoolPrefix::Core, &uri, None);
        assert!(matches!(result, Some(PoolItem::Core(_))));
    }

    #[tokio::test]
    async fn test_connection_pool_get_or_create_creates_new() {
        let pool = ConnectionPool::new(10);
        let uri = test_uri();

        // First call misses the cache and must invoke `create`.
        let mut first_called = false;
        let result: Result<PoolItem, String> =
            pool.get_or_create(PoolPrefix::Platform, &uri, None, || {
                first_called = true;
                Ok(make_platform_pool_item())
            });

        assert!(result.is_ok());
        assert!(first_called, "create should be called on a cache miss");

        // Second call should return cached version
        let mut create_called = false;
        let result2: Result<PoolItem, String> =
            pool.get_or_create(PoolPrefix::Platform, &uri, None, || {
                create_called = true;
                Ok(make_platform_pool_item())
            });

        assert!(result2.is_ok());
        assert!(
            !create_called,
            "create should not be called for cached item"
        );
    }

    #[test]
    fn test_connection_pool_get_or_create_error_not_cached() {
        let pool = ConnectionPool::new(10);
        let uri = test_uri();

        let result: Result<PoolItem, String> =
            pool.get_or_create(PoolPrefix::Platform, &uri, None, || {
                Err("creation failed".to_string())
            });

        assert!(result.is_err());

        // Pool should still be empty after failed creation
        let cached = pool.get(PoolPrefix::Platform, &uri, None);
        assert!(cached.is_none());
    }

    #[tokio::test]
    async fn test_connection_pool_evicts_least_recently_used() {
        let pool = ConnectionPool::new(2);
        let first = uri_with_port(3001);
        let second = uri_with_port(3002);
        let third = uri_with_port(3003);

        pool.put(&first, None, make_platform_pool_item());
        pool.put(&second, None, make_platform_pool_item());
        // Touch `first` so that `second` becomes the least recently used entry.
        assert!(pool.get(PoolPrefix::Platform, &first, None).is_some());
        pool.put(&third, None, make_platform_pool_item());

        assert!(pool.get(PoolPrefix::Platform, &first, None).is_some());
        assert!(pool.get(PoolPrefix::Platform, &second, None).is_none());
        assert!(pool.get(PoolPrefix::Platform, &third, None).is_some());
    }

    #[test]
    fn test_pool_prefix_display() {
        assert_eq!(format!("{}", PoolPrefix::Core), "Core");
        assert_eq!(format!("{}", PoolPrefix::Platform), "Platform");
    }

    #[tokio::test]
    async fn test_pool_prefix_from_pool_item() {
        let platform_item = make_platform_pool_item();
        let prefix: PoolPrefix = (&platform_item).into();
        assert!(matches!(prefix, PoolPrefix::Platform));

        let core_item = make_core_pool_item();
        let prefix: PoolPrefix = (&core_item).into();
        assert!(matches!(prefix, PoolPrefix::Core));
    }

    #[tokio::test]
    async fn test_pool_item_from_platform_client() {
        let channel = Channel::builder(test_uri()).connect_lazy();
        let client = PlatformGrpcClient::new(channel);
        let item: PoolItem = client.into();
        assert!(matches!(item, PoolItem::Platform(_)));
    }

    #[tokio::test]
    async fn test_pool_item_from_core_client() {
        let channel = Channel::builder(test_uri()).connect_lazy();
        let client = CoreGrpcClient::new(channel);
        let item: PoolItem = client.into();
        assert!(matches!(item, PoolItem::Core(_)));
    }

    #[tokio::test]
    async fn test_pool_item_into_platform_client() {
        let item = make_platform_pool_item();
        let _client: PlatformGrpcClient = item.into();
    }

    #[tokio::test]
    async fn test_pool_item_into_core_client() {
        let item = make_core_pool_item();
        let _client: CoreGrpcClient = item.into();
    }

    #[tokio::test]
    #[should_panic(expected = "ClientType is not Platform")]
    async fn test_pool_item_core_into_platform_panics() {
        let item = make_core_pool_item();
        let _client: PlatformGrpcClient = item.into();
    }

    #[tokio::test]
    #[should_panic(expected = "ClientType is not Core")]
    async fn test_pool_item_platform_into_core_panics() {
        let item = make_platform_pool_item();
        let _client: CoreGrpcClient = item.into();
    }

    #[tokio::test]
    async fn test_connection_pool_different_prefixes_different_keys() {
        let pool = ConnectionPool::new(10);
        let uri = test_uri();

        pool.put(&uri, None, make_platform_pool_item());

        // Core prefix should not find a Platform item
        let result = pool.get(PoolPrefix::Core, &uri, None);
        assert!(result.is_none());

        // Platform prefix should find it
        let result = pool.get(PoolPrefix::Platform, &uri, None);
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn test_connection_pool_clone_shares_data() {
        let pool = ConnectionPool::new(10);
        let pool_clone = pool.clone();
        let uri = test_uri();

        pool.put(&uri, None, make_platform_pool_item());

        // Clone should see the same data
        let result = pool_clone.get(PoolPrefix::Platform, &uri, None);
        assert!(result.is_some());
    }
}
367}