1use std::future::Future;
2use std::sync::mpsc::SendError;
3
4#[derive(Debug, thiserror::Error)]
5pub enum AsyncError {
6 #[cfg(not(target_arch = "wasm32"))]
8 #[error("not running inside tokio runtime: {0}")]
9 NotInTokioRuntime(#[from] tokio::runtime::TryCurrentError),
10
11 #[error("cannot receive response from async function: {0}")]
13 RecvError(#[from] std::sync::mpsc::RecvError),
14
15 #[error("cannot send response from async function: {0}")]
17 SendError(String),
18
19 #[error("asynchronous call from synchronous context failed: {0}")]
20 #[allow(unused)]
21 Generic(String),
22}
23
24impl<T> From<SendError<T>> for AsyncError {
25 fn from(error: SendError<T>) -> Self {
26 Self::SendError(error.to_string())
27 }
28}
29
30#[cfg(not(target_arch = "wasm32"))]
43pub fn block_on<F>(fut: F) -> Result<F::Output, AsyncError>
44where
45 F: Future + Send + 'static,
46 F::Output: Send,
47{
48 use tokio::runtime::RuntimeFlavor;
49
50 tracing::trace!("block_on: running async function from sync code");
51
52 let handle = match tokio::runtime::Handle::try_current() {
53 Ok(h) => h,
54 Err(e) => {
55 tracing::trace!("block_on: no active runtime ({e}), creating temporary runtime");
56 return Ok(tokio::runtime::Builder::new_current_thread()
57 .enable_all()
58 .build()
59 .map_err(|e| AsyncError::Generic(e.to_string()))?
60 .block_on(fut));
61 }
62 };
63
64 match handle.runtime_flavor() {
65 RuntimeFlavor::CurrentThread => {
66 tracing::trace!("block_on: current-thread runtime, spawning dedicated OS thread");
67 let (tx, rx) = std::sync::mpsc::sync_channel::<Result<F::Output, AsyncError>>(1);
68 let join_handle = std::thread::spawn(move || {
69 let result = tokio::runtime::Builder::new_current_thread()
70 .enable_all()
71 .build()
72 .map_err(|e| {
73 tracing::error!("block_on: failed to create worker runtime: {}", e);
74 AsyncError::Generic(format!("failed to create worker runtime: {e}"))
75 })
76 .map(|rt| rt.block_on(fut));
77 let _ = tx.send(result);
78 });
79 let recv_result = rx.recv();
80 let join_result = join_handle.join();
81 match (join_result, recv_result) {
82 (Err(_), _) => Err(AsyncError::Generic(
83 "block_on worker thread panicked".to_string(),
84 )),
85 (Ok(()), Err(_)) => Err(AsyncError::Generic(
86 "block_on worker exited without sending a result".to_string(),
87 )),
88 (Ok(()), Ok(result)) => result,
89 }
90 }
91 _ => {
94 tracing::trace!("block_on: multi-thread runtime, using block_in_place");
95 let (tx, rx) = std::sync::mpsc::sync_channel::<F::Output>(1);
96 let hdl = handle.spawn(worker(fut, tx));
97 let resp = tokio::task::block_in_place(|| rx.recv())?;
98 if !hdl.is_finished() {
99 tracing::debug!("async-sync worker future is not finished, aborting");
100 hdl.abort();
101 }
102 Ok(resp)
103 }
104 }
105}
106
#[cfg(target_arch = "wasm32")]
/// WebAssembly stand-in for `block_on` that always fails without polling `_fut`.
///
/// Synchronously waiting on a future is not supported on wasm32 yet. The returned
/// [`AsyncError::Generic`] explains this and points callers at async entry points instead.
pub fn block_on<F>(_fut: F) -> Result<F::Output, AsyncError>
where
    F: Future,
{
    const UNSUPPORTED: &str = concat!(
        "block_on is not yet supported in WASM. ",
        "Awaiting wasm-bindgen JSPI support ",
        "(https://github.com/rustwasm/wasm-bindgen/issues/3633). ",
        "Use async callers via #[wasm_bindgen] instead."
    );
    Err(AsyncError::Generic(UNSUPPORTED.to_owned()))
}
134
135#[cfg(not(target_arch = "wasm32"))]
137async fn worker<F: Future>(
138 fut: F,
139 response: std::sync::mpsc::SyncSender<F::Output>,
140) -> Result<(), AsyncError> {
141 tracing::trace!("Worker start");
142 let result = fut.await;
143 tracing::trace!("Worker async function completed, sending response");
144 response.send(result)?;
145 tracing::trace!("Worker response sent");
146
147 Ok(())
148}
149
#[cfg(test)]
mod test {
    use super::*;
    use tokio::{
        runtime::Builder,
        sync::mpsc::{self, Receiver},
    };

    /// Nests sync -> async bridges several levels deep on a multi-thread runtime that is
    /// deliberately starved of threads (one worker, one blocking thread). A concurrently
    /// spawned task feeds a channel that must be drained first. The scenario runs several
    /// times to shake out scheduling-dependent hangs.
    #[test]
    fn test_block_on_nested_async_sync() {
        let rt = Builder::new_multi_thread()
            .worker_threads(1)
            .max_blocking_threads(1)
            .enable_all()
            .build()
            .expect("Failed to create Tokio runtime");

        for _repeat in 0..5 {
            const MSGS: usize = 10;
            let (tx, rx) = mpsc::channel::<usize>(1);

            let worker_task = async move {
                for count in 0..MSGS {
                    tx.send(count).await.unwrap();
                }
            };
            let worker_join = rt.spawn(worker_task);

            let levels = 4;

            async fn innermost_async_function(mut rx: Receiver<usize>) -> Result<String, String> {
                for i in 0..MSGS {
                    let count = rx.recv().await.unwrap();
                    assert_eq!(count, i);
                }
                Ok(String::from("Success"))
            }

            fn nested_sync_function<F>(fut: F) -> Result<String, String>
            where
                F: Future<Output = Result<String, String>> + Send + 'static,
                F::Output: Send,
            {
                block_on(fut)
                    .map_err(|e| e.to_string())?
                    .map_err(|e| e.to_string())
            }

            async fn outer_async_function(
                levels: usize,
                rx: Receiver<usize>,
            ) -> Result<String, String> {
                let mut result = innermost_async_function(rx).await;
                for _ in 0..levels {
                    result = nested_sync_function(async { result });
                }
                result
            }

            let result = rt.block_on(outer_async_function(levels, rx));
            rt.block_on(worker_join).unwrap();
            assert_eq!(result.unwrap(), "Success");
        }
    }

    /// `block_on` must neither panic nor deadlock when called from inside a current-thread
    /// runtime, where `tokio::task::block_in_place` is unavailable.
    #[test]
    fn test_block_on_succeeds_on_current_thread_runtime() {
        let rt = Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to create current-thread Tokio runtime");

        const MSGS: usize = 3;
        let (tx, rx) = mpsc::channel::<usize>(1);

        let worker_task = async move {
            for count in 0..MSGS {
                tx.send(count).await.unwrap();
            }
        };
        let worker_join = rt.spawn(worker_task);

        async fn innermost(mut rx: Receiver<usize>) -> Result<String, String> {
            for i in 0..MSGS {
                let count = rx.recv().await.unwrap();
                assert_eq!(count, i);
            }
            Ok("Success".to_string())
        }

        fn sync_bridge<F>(fut: F) -> Result<String, String>
        where
            F: Future<Output = Result<String, String>> + Send + 'static,
            F::Output: Send,
        {
            block_on(fut)
                .map_err(|e| e.to_string())?
                .map_err(|e| e.to_string())
        }

        async fn outer(rx: Receiver<usize>) -> Result<String, String> {
            let result = innermost(rx).await;
            sync_bridge(async { result })
        }

        let result = rt.block_on(outer(rx));

        // Fail loudly (as the multi-thread test does) instead of silently ignoring a
        // failure of the feeder task.
        rt.block_on(worker_join)
            .expect("channel feeder task failed");

        assert_eq!(
            result.unwrap(),
            "Success",
            "block_on should succeed on a current-thread runtime"
        );
    }

    /// With no runtime on the calling thread, `block_on` drives the future itself.
    #[test]
    fn test_block_on_without_runtime() {
        let value = block_on(async { 21 * 2 }).expect("block_on should work without a runtime");
        assert_eq!(value, 42);
    }

    /// A panic inside the future on a current-thread runtime is reported as an error
    /// instead of unwinding through the caller.
    #[test]
    fn test_block_on_reports_panic_on_current_thread_runtime() {
        async fn explode() -> u8 {
            panic!("boom")
        }

        let rt = Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("Failed to create current-thread Tokio runtime");

        let result = rt.block_on(async { block_on(explode()) });

        match result {
            Err(AsyncError::Generic(msg)) => assert!(
                msg.starts_with("block_on worker thread panicked"),
                "unexpected error message: {msg}"
            ),
            other => panic!("expected a Generic panic error, got {other:?}"),
        }
    }
}