1use std::{
2 fmt::Display,
3 sync::{Arc, Mutex},
4};
5
6use lru::LruCache;
7
8use crate::{
9 request_settings::AppliedRequestSettings,
10 transport::{CoreGrpcClient, PlatformGrpcClient},
11 Uri,
12};
13
14#[derive(Debug, Clone)]
19pub struct ConnectionPool {
20 inner: Arc<Mutex<LruCache<String, PoolItem>>>,
21}
22
23impl ConnectionPool {
24 pub fn new(capacity: usize) -> Self {
31 Self {
32 inner: Arc::new(Mutex::new(LruCache::new(
33 capacity.try_into().expect("must be non-zero"),
34 ))),
35 }
36 }
37}
38
39impl Default for ConnectionPool {
40 fn default() -> Self {
41 Self::new(50)
42 }
43}
44
45impl ConnectionPool {
46 pub fn get(
53 &self,
54 prefix: PoolPrefix,
55 uri: &Uri,
56 settings: Option<&AppliedRequestSettings>,
57 ) -> Option<PoolItem> {
58 let key = Self::key(prefix, uri, settings);
59 self.inner.lock().expect("must lock").get(&key).cloned()
60 }
61
62 pub fn get_or_create<E>(
71 &self,
72 prefix: PoolPrefix,
73 uri: &Uri,
74 settings: Option<&AppliedRequestSettings>,
75 create: impl FnOnce() -> Result<PoolItem, E>,
76 ) -> Result<PoolItem, E> {
77 if let Some(cli) = self.get(prefix, uri, settings) {
78 return Ok(cli);
79 }
80
81 let cli = create();
82 if let Ok(cli) = &cli {
83 self.put(uri, settings, cli.clone());
84 }
85 cli
86 }
87
88 pub fn put(&self, uri: &Uri, settings: Option<&AppliedRequestSettings>, value: PoolItem) {
90 let key = Self::key(&value, uri, settings);
91 self.inner.lock().expect("must lock").put(key, value);
92 }
93
94 fn key<C: Into<PoolPrefix>>(
95 class: C,
96 uri: &Uri,
97 settings: Option<&AppliedRequestSettings>,
98 ) -> String {
99 let prefix: PoolPrefix = class.into();
100 format!("{}:{}{:?}", prefix, uri, settings)
101 }
102}
103
104#[derive(Clone, Debug)]
108pub enum PoolItem {
109 Core(CoreGrpcClient),
110 Platform(PlatformGrpcClient),
111}
112
113impl From<PlatformGrpcClient> for PoolItem {
114 fn from(client: PlatformGrpcClient) -> Self {
115 Self::Platform(client)
116 }
117}
118impl From<CoreGrpcClient> for PoolItem {
119 fn from(client: CoreGrpcClient) -> Self {
120 Self::Core(client)
121 }
122}
123
124impl From<PoolItem> for PlatformGrpcClient {
125 fn from(client: PoolItem) -> Self {
126 match client {
127 PoolItem::Platform(client) => client,
128 _ => {
129 tracing::error!(
130 ?client,
131 "invalid connection fetched from pool: expected platform client"
132 );
133 panic!("ClientType is not Platform: {:?}", client)
134 }
135 }
136 }
137}
138
139impl From<PoolItem> for CoreGrpcClient {
140 fn from(client: PoolItem) -> Self {
141 match client {
142 PoolItem::Core(client) => client,
143 _ => {
144 tracing::error!(
145 ?client,
146 "invalid connection fetched from pool: expected core client"
147 );
148 panic!("ClientType is not Core: {:?}", client)
149 }
150 }
151 }
152}
153
/// Kind of client a pool entry holds.
///
/// Its `Display` form is the first segment of the pool's cache key.
///
/// This is a plain field-less enum, so it derives the usual value traits. That makes it
/// debuggable, copyable and comparable, and lets it serve as a key in collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PoolPrefix {
    /// Key prefix for `PoolItem::Core` entries.
    Core,
    /// Key prefix for `PoolItem::Platform` entries.
    Platform,
}
159impl Display for PoolPrefix {
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 match self {
162 PoolPrefix::Core => write!(f, "Core"),
163 PoolPrefix::Platform => write!(f, "Platform"),
164 }
165 }
166}
167impl From<&PoolItem> for PoolPrefix {
168 fn from(item: &PoolItem) -> Self {
169 match item {
170 PoolItem::Core(_) => PoolPrefix::Core,
171 PoolItem::Platform(_) => PoolPrefix::Platform,
172 }
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use dapi_grpc::tonic::transport::Channel;
    use std::str::FromStr;

    fn test_uri() -> Uri {
        Uri::from_str("http://127.0.0.1:3000").unwrap()
    }

    /// Builds a Platform pool item over a lazily-connecting channel.
    /// `connect_lazy` needs a Tokio runtime, hence `#[tokio::test]` on callers.
    fn make_platform_pool_item() -> PoolItem {
        let channel = Channel::builder(test_uri()).connect_lazy();
        PoolItem::Platform(PlatformGrpcClient::new(channel))
    }

    /// Builds a Core pool item over a lazily-connecting channel.
    fn make_core_pool_item() -> PoolItem {
        let channel = Channel::builder(test_uri()).connect_lazy();
        PoolItem::Core(CoreGrpcClient::new(channel))
    }

    #[test]
    fn test_connection_pool_new() {
        let pool = ConnectionPool::new(10);
        let result = pool.get(PoolPrefix::Platform, &test_uri(), None);
        assert!(result.is_none());
    }

    #[test]
    fn test_connection_pool_default() {
        let pool = ConnectionPool::default();
        let result = pool.get(PoolPrefix::Core, &test_uri(), None);
        assert!(result.is_none());
    }

    #[test]
    #[should_panic(expected = "must be non-zero")]
    fn test_connection_pool_zero_capacity_panics() {
        let _ = ConnectionPool::new(0);
    }

    #[tokio::test]
    async fn test_connection_pool_put_and_get_platform() {
        let pool = ConnectionPool::new(10);
        let uri = test_uri();
        let item = make_platform_pool_item();

        pool.put(&uri, None, item);

        let result = pool.get(PoolPrefix::Platform, &uri, None);
        assert!(result.is_some());
        assert!(matches!(result.unwrap(), PoolItem::Platform(_)));
    }

    #[tokio::test]
    async fn test_connection_pool_put_and_get_core() {
        let pool = ConnectionPool::new(10);
        let uri = test_uri();
        let item = make_core_pool_item();

        pool.put(&uri, None, item);

        let result = pool.get(PoolPrefix::Core, &uri, None);
        assert!(result.is_some());
        assert!(matches!(result.unwrap(), PoolItem::Core(_)));
    }

    #[tokio::test]
    async fn test_connection_pool_get_or_create_creates_new() {
        let pool = ConnectionPool::new(10);
        let uri = test_uri();

        let result: Result<PoolItem, String> =
            pool.get_or_create(PoolPrefix::Platform, &uri, None, || {
                Ok(make_platform_pool_item())
            });

        assert!(result.is_ok());

        // The second call must be served from the cache without invoking `create`.
        let mut create_called = false;
        let result2: Result<PoolItem, String> =
            pool.get_or_create(PoolPrefix::Platform, &uri, None, || {
                create_called = true;
                Ok(make_platform_pool_item())
            });

        assert!(result2.is_ok());
        assert!(
            !create_called,
            "create should not be called for cached item"
        );
    }

    #[test]
    fn test_connection_pool_get_or_create_error_not_cached() {
        let pool = ConnectionPool::new(10);
        let uri = test_uri();

        let result: Result<PoolItem, String> =
            pool.get_or_create(PoolPrefix::Platform, &uri, None, || {
                Err("creation failed".to_string())
            });

        assert!(result.is_err());

        // A failed creation must leave the pool untouched.
        let cached = pool.get(PoolPrefix::Platform, &uri, None);
        assert!(cached.is_none());
    }

    #[test]
    fn test_pool_prefix_display() {
        assert_eq!(format!("{}", PoolPrefix::Core), "Core");
        assert_eq!(format!("{}", PoolPrefix::Platform), "Platform");
    }

    #[tokio::test]
    async fn test_pool_prefix_from_pool_item() {
        let platform_item = make_platform_pool_item();
        let prefix: PoolPrefix = (&platform_item).into();
        assert!(matches!(prefix, PoolPrefix::Platform));

        let core_item = make_core_pool_item();
        let prefix: PoolPrefix = (&core_item).into();
        assert!(matches!(prefix, PoolPrefix::Core));
    }

    #[tokio::test]
    async fn test_pool_item_from_platform_client() {
        let channel = Channel::builder(test_uri()).connect_lazy();
        let client = PlatformGrpcClient::new(channel);
        let item: PoolItem = client.into();
        assert!(matches!(item, PoolItem::Platform(_)));
    }

    #[tokio::test]
    async fn test_pool_item_from_core_client() {
        let channel = Channel::builder(test_uri()).connect_lazy();
        let client = CoreGrpcClient::new(channel);
        let item: PoolItem = client.into();
        assert!(matches!(item, PoolItem::Core(_)));
    }

    #[tokio::test]
    async fn test_pool_item_into_platform_client() {
        let item = make_platform_pool_item();
        let _client: PlatformGrpcClient = item.into();
    }

    #[tokio::test]
    async fn test_pool_item_into_core_client() {
        let item = make_core_pool_item();
        let _client: CoreGrpcClient = item.into();
    }

    #[tokio::test]
    #[should_panic(expected = "ClientType is not Platform")]
    async fn test_pool_item_core_into_platform_panics() {
        let item = make_core_pool_item();
        let _client: PlatformGrpcClient = item.into();
    }

    #[tokio::test]
    #[should_panic(expected = "ClientType is not Core")]
    async fn test_pool_item_platform_into_core_panics() {
        let item = make_platform_pool_item();
        let _client: CoreGrpcClient = item.into();
    }

    #[tokio::test]
    async fn test_connection_pool_different_prefixes_different_keys() {
        let pool = ConnectionPool::new(10);
        let uri = test_uri();

        pool.put(&uri, None, make_platform_pool_item());

        // Same URI and settings, different client kind: must be a miss.
        let result = pool.get(PoolPrefix::Core, &uri, None);
        assert!(result.is_none());

        // The matching kind is a hit.
        let result = pool.get(PoolPrefix::Platform, &uri, None);
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn test_connection_pool_clone_shares_data() {
        let pool = ConnectionPool::new(10);
        let pool_clone = pool.clone();
        let uri = test_uri();

        pool.put(&uri, None, make_platform_pool_item());

        // The clone shares the same Arc'd cache, so it sees the entry.
        let result = pool_clone.get(PoolPrefix::Platform, &uri, None);
        assert!(result.is_some());
    }

    #[tokio::test]
    async fn test_connection_pool_evicts_least_recently_used() {
        let pool = ConnectionPool::new(2);
        let uri = test_uri();
        let other_uri = Uri::from_str("http://127.0.0.1:3001").unwrap();

        pool.put(&uri, None, make_platform_pool_item());
        pool.put(&uri, None, make_core_pool_item());

        // Touch the platform entry so the core entry becomes the least recently used.
        assert!(pool.get(PoolPrefix::Platform, &uri, None).is_some());

        // A third distinct key exceeds capacity 2 and must evict the core entry.
        pool.put(&other_uri, None, make_platform_pool_item());

        assert!(pool.get(PoolPrefix::Core, &uri, None).is_none());
        assert!(pool.get(PoolPrefix::Platform, &uri, None).is_some());
        assert!(pool.get(PoolPrefix::Platform, &other_uri, None).is_some());
    }
}