1use crate::error::Error;
8use dash_context_provider::ContextProviderError;
9use rs_dapi_client::{
10 transport::sleep, update_address_ban_status, AddressList, CanRetry, ExecutionResult,
11 RequestSettings,
12};
13use std::time::Duration;
14use std::{fmt::Debug, future::Future, sync::mpsc::SendError};
15
16#[derive(Debug, thiserror::Error)]
17pub enum AsyncError {
18 #[cfg(not(target_arch = "wasm32"))]
20 #[error("not running inside tokio runtime: {0}")]
21 NotInTokioRuntime(#[from] tokio::runtime::TryCurrentError),
22
23 #[error("cannot receive response from async function: {0}")]
25 RecvError(#[from] std::sync::mpsc::RecvError),
26
27 #[error("cannot send response from async function: {0}")]
29 SendError(String),
30
31 #[error("asynchronous call from synchronous context failed: {0}")]
32 #[allow(unused)]
33 Generic(String),
34}
35
36impl<T> From<SendError<T>> for AsyncError {
37 fn from(error: SendError<T>) -> Self {
38 Self::SendError(error.to_string())
39 }
40}
41
42impl From<AsyncError> for ContextProviderError {
43 fn from(error: AsyncError) -> Self {
44 ContextProviderError::AsyncError(error.to_string())
45 }
46}
47
48impl From<AsyncError> for crate::Error {
49 fn from(error: AsyncError) -> Self {
50 Self::ContextProviderError(error.into())
51 }
52}
53
54#[cfg(not(target_arch = "wasm32"))]
62pub fn block_on<F>(fut: F) -> Result<F::Output, AsyncError>
63where
64 F: Future + Send + 'static,
65 F::Output: Send,
66{
67 tracing::trace!("block_on: running async function from sync code");
68 let rt = tokio::runtime::Handle::try_current()?;
69 let (tx, rx) = std::sync::mpsc::channel();
70 tracing::trace!("block_on: Spawning worker");
71 let hdl = rt.spawn(worker(fut, tx));
72 tracing::trace!("block_on: Worker spawned");
73 let resp = tokio::task::block_in_place(|| rx.recv())?;
74
75 tracing::trace!("Response received");
76 if !hdl.is_finished() {
77 tracing::debug!("async-sync worker future is not finished, aborting; this should not happen, but it's fine");
78 hdl.abort(); }
80
81 Ok(resp)
82}
83
#[cfg(target_arch = "wasm32")]
/// `wasm32` stand-in for the native `block_on`.
///
/// Blocking the current thread on a future is not supported on `wasm32`. This
/// stub drops `_fut` without polling it and reports the situation as an error.
/// Panicking here would abort the whole module, whereas the `Result` return type
/// lets callers handle the failure.
///
/// # Errors
///
/// Always returns [`AsyncError::Generic`].
pub fn block_on<F>(_fut: F) -> Result<F::Output, AsyncError>
where
    F: Future + Send + 'static,
    F::Output: Send,
{
    Err(AsyncError::Generic(
        "block_on is not supported in wasm".to_string(),
    ))
}
92
93#[cfg(not(target_arch = "wasm32"))]
95async fn worker<F: Future>(
96 fut: F,
97 response: std::sync::mpsc::Sender<F::Output>,
99) -> Result<(), AsyncError> {
100 tracing::trace!("Worker start");
101 let result = fut.await;
102 tracing::trace!("Worker async function completed, sending response");
103 response.send(result)?;
104 tracing::trace!("Worker response sent");
105
106 Ok(())
107}
108
109pub async fn retry<Fut, FutureFactoryFn, R>(
166 address_list: &AddressList,
167 settings: RequestSettings,
168 mut future_factory_fn: FutureFactoryFn,
169) -> ExecutionResult<R, Error>
170where
171 Fut: Future<Output = ExecutionResult<R, Error>>,
172 FutureFactoryFn: FnMut(RequestSettings) -> Fut,
173 R: Send,
174{
175 let max_retries = settings.retries.unwrap_or_default();
176 let mut total_retries: usize = 0;
177 let mut current_settings = settings;
178
179 let mut last_meaningful_error: Option<rs_dapi_client::ExecutionError<Error>> = None;
182
183 loop {
184 let result = future_factory_fn(current_settings).await;
185
186 update_address_ban_status(address_list, &result, ¤t_settings.finalize());
188
189 match result {
190 Ok(response) => return Ok(response),
191 Err(error) => {
192 if error.is_no_available_addresses() {
194 if let Some(prev_error) = last_meaningful_error.take() {
195 tracing::error!(
196 retry = total_retries,
197 max_retries,
198 error = ?prev_error,
199 "no addresses available to retry"
200 );
201 return Err(rs_dapi_client::ExecutionError {
203 inner: Error::NoAvailableAddressesToRetry(Box::new(prev_error.inner)),
204 retries: total_retries,
205 address: prev_error.address,
206 });
207 }
208 return Err(error);
210 }
211
212 let requests_sent = error.retries + 1;
214 total_retries += requests_sent;
215
216 if !error.can_retry() {
217 let mut final_error = error;
219 final_error.retries = total_retries;
220 return Err(final_error);
221 }
222
223 if total_retries > max_retries {
224 tracing::error!(
226 retry = total_retries,
227 max_retries,
228 error = ?error,
229 "no more retries left, giving up"
230 );
231 let mut final_error = error;
232 final_error.retries = total_retries;
233 return Err(final_error);
234 }
235
236 tracing::warn!(
238 retry = total_retries,
239 max_retries,
240 error = ?error,
241 "retrying request"
242 );
243
244 current_settings.retries = Some(max_retries.saturating_sub(total_retries));
246
247 let delay = Duration::from_millis(10);
250 tracing::warn!(duration = ?delay, error = ?error, "request failed, retrying");
251
252 last_meaningful_error = Some(error);
254
255 sleep(delay).await;
256 }
257 }
258 }
259}
260
#[cfg(test)]
mod test {
    use super::*;
    use crate::error::StaleNodeError;
    use rs_dapi_client::{DapiClientError, ExecutionError};
    use std::{
        future::Future,
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
    };
    use tokio::{
        runtime::Builder,
        sync::mpsc::{self, Receiver},
    };

    /// Checks that `block_on` can be nested several levels deep (async -> sync ->
    /// async -> ...) and still complete, while a separate task feeds the innermost
    /// future through a bounded channel.
    ///
    /// The runtime is deliberately tiny (one worker thread, one blocking thread),
    /// presumably so that any hidden need for extra threads shows up as a hang.
    #[test]
    fn test_block_on_nested_async_sync() {
        let rt = Builder::new_multi_thread()
            .worker_threads(1)
            .max_blocking_threads(1)
            .enable_all()
            .build()
            .expect("Failed to create Tokio runtime");

        // Repeat so that scheduling-dependent failures get several chances to appear.
        for _repeat in 0..5 {
            // Number of messages the producer task sends.
            const MSGS: usize = 10;

            // Capacity 1: the producer can only run one message ahead of the consumer.
            let (tx, rx) = mpsc::channel::<usize>(1);

            let worker = async move {
                for count in 0..MSGS {
                    tx.send(count).await.unwrap();
                }
            };
            let worker_join = rt.spawn(worker);
            // How many times the result is passed through a synchronous `block_on`.
            let levels = 4;

            /// Receives all `MSGS` messages, checking they arrive in order.
            async fn innermost_async_function(
                mut rx: Receiver<usize>,
            ) -> Result<String, ContextProviderError> {
                for i in 0..MSGS {
                    let count = rx.recv().await.unwrap();
                    assert_eq!(count, i);
                }

                Ok(String::from("Success"))
            }

            /// Waits for `fut` synchronously via `block_on`, flattening both error
            /// layers into `ContextProviderError`.
            fn nested_sync_function<F>(fut: F) -> Result<String, ContextProviderError>
            where
                F: Future<Output = Result<String, ContextProviderError>> + Send + 'static,
                F::Output: Send,
            {
                block_on(fut)?.map_err(|e| ContextProviderError::Generic(e.to_string()))
            }

            /// Runs the innermost future, then pipes its result through `levels`
            /// nested synchronous `block_on` calls.
            async fn outer_async_function(
                levels: usize,
                rx: Receiver<usize>,
            ) -> Result<String, ContextProviderError> {
                let mut result = innermost_async_function(rx).await;
                for _ in 0..levels {
                    result = nested_sync_function(async { result });
                }
                result
            }

            let result = rt.block_on(outer_async_function(levels, rx));

            rt.block_on(worker_join).unwrap();
            assert_eq!(result.unwrap(), "Success");
        }
    }

    /// Simulates a DAPI call that always fails with a `StaleNode` error (which these
    /// tests rely on being retryable).
    ///
    /// Each call claims to have sent `1 + retries` requests and records them in
    /// `counter`. Here `retries` is the number of requests recorded so far, capped at
    /// the budget in `settings.retries`.
    async fn retry_test_function(
        settings: RequestSettings,
        counter: Arc<AtomicUsize>,
    ) -> ExecutionResult<(), Error> {
        // Never claim more internal retries than the caller allowed.
        let retries = counter
            .load(Ordering::Relaxed)
            .min(settings.retries.unwrap_or_default());

        counter.fetch_add(1 + retries, Ordering::Relaxed);

        Err(ExecutionError {
            inner: Error::StaleNode(StaleNodeError::Height {
                expected_height: 100,
                received_height: 50,
                tolerance_blocks: 1,
            }),
            retries,
            address: Some("http://localhost".parse().expect("valid address")),
        })
    }

    /// `retry` must send exactly `settings.retries + 1` requests in total before
    /// giving up, however many internal retries each attempt reports.
    #[test_case::test_matrix([1, 2, 3, 5, 7, 8, 10, 11, 23, 49, usize::MAX])]
    #[tokio::test]
    async fn test_retry(expected_requests: usize) {
        let counter = Arc::new(AtomicUsize::new(0));
        let address_list = AddressList::default();

        let mut global_settings = RequestSettings::default();
        // N requests in total means N - 1 retries.
        global_settings.retries = Some(expected_requests - 1);

        let closure = |s| {
            let counter = counter.clone();
            retry_test_function(s, counter)
        };

        retry(&address_list, global_settings, closure)
            .await
            .expect_err("should fail");

        assert_eq!(
            counter.load(Ordering::Relaxed),
            expected_requests,
            "test failed for expected {} requests",
            expected_requests
        );
    }

    /// A retryable failure followed by "no available addresses" must surface the
    /// earlier failure, wrapped in `NoAvailableAddressesToRetry`.
    #[tokio::test]
    async fn test_retry_returns_last_meaningful_error_on_no_addresses() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let address_list = AddressList::default();

        let mut settings = RequestSettings::default();
        settings.retries = Some(5);

        let call_count_clone = call_count.clone();
        let closure = move |_settings: RequestSettings| {
            let count = call_count_clone.fetch_add(1, Ordering::Relaxed);
            async move {
                if count == 0 {
                    // First attempt: a meaningful, retryable error.
                    Err(ExecutionError {
                        inner: Error::StaleNode(StaleNodeError::Height {
                            expected_height: 100,
                            received_height: 50,
                            tolerance_blocks: 1,
                        }),
                        retries: 0,
                        address: Some("http://localhost:1".parse().unwrap()),
                    })
                } else {
                    // Every later attempt: nothing left to talk to.
                    Err(ExecutionError {
                        inner: Error::DapiClientError(DapiClientError::NoAvailableAddresses),
                        retries: 0,
                        address: None,
                    })
                }
            }
        };

        let result: ExecutionResult<(), Error> = retry(&address_list, settings, closure).await;

        let error = result.expect_err("should fail");
        match &error.inner {
            Error::NoAvailableAddressesToRetry(inner) => {
                assert!(
                    matches!(**inner, Error::StaleNode(_)),
                    "inner error should be StaleNode, got: {:?}",
                    inner
                );
            }
            _ => panic!(
                "expected NoAvailableAddressesToRetry error, got: {:?}",
                error.inner
            ),
        }
        assert_eq!(
            call_count.load(Ordering::Relaxed),
            2,
            "should have called twice"
        );
    }

    /// With no earlier failure to report, "no available addresses" is returned as-is.
    #[tokio::test]
    async fn test_retry_returns_no_addresses_if_no_previous_error() {
        let address_list = AddressList::default();

        let mut settings = RequestSettings::default();
        settings.retries = Some(5);

        let closure = move |_settings: RequestSettings| async move {
            Err(ExecutionError {
                inner: Error::DapiClientError(DapiClientError::NoAvailableAddresses),
                retries: 0,
                address: None,
            })
        };

        let result: ExecutionResult<(), Error> = retry(&address_list, settings, closure).await;

        let error = result.expect_err("should fail");
        assert!(
            matches!(
                error.inner,
                Error::DapiClientError(DapiClientError::NoAvailableAddresses)
            ),
            "should return 'no available addresses' when there's no previous meaningful error, got: {:?}",
            error.inner
        );
    }
}