1use crate::{
15 transport::TransportRequest, DapiClientError, DapiRequestExecutor, ExecutionError,
16 ExecutionResponse, ExecutionResult, RequestSettings,
17};
18use dapi_grpc::mock::Mockable;
19use dapi_grpc::tonic::async_trait;
20use hex::ToHex;
21use sha2::Digest;
22use std::{
23 any::type_name,
24 collections::HashMap,
25 fmt::{Debug, Display},
26};
27
28#[derive(Default, Debug)]
34pub struct MockDapiClient {
35 expectations: Expectations,
36}
37pub type MockResult<T> = ExecutionResult<<T as TransportRequest>::Response, DapiClientError>;
39
40impl MockDapiClient {
41 pub fn new() -> Self {
43 Self::default()
44 }
45
46 pub fn expect<R>(&mut self, request: &R, result: &MockResult<R>) -> Result<&mut Self, MockError>
48 where
49 R: TransportRequest + Mockable,
50 R::Response: Mockable,
51 {
52 let key = self.expectations.add(request, result)?;
53
54 tracing::trace!(
55 %key,
56 request_type = std::any::type_name::<R>(),
57 response_type = std::any::type_name::<R::Response>(),
58 "mock added expectation"
59 );
60
61 Ok(self)
62 }
63
64 pub fn remove<R>(&mut self, request: &R) -> bool
68 where
69 R: TransportRequest + Mockable,
70 R::Response: Mockable,
71 {
72 let key = Key::new(request);
73 let removed = self.expectations.remove(&key);
74
75 if removed {
76 tracing::trace!(
77 %key,
78 request_type = std::any::type_name::<R>(),
79 "mock removed expectation"
80 );
81 }
82
83 removed
84 }
85
86 #[cfg(feature = "dump")]
95 pub fn load<T, P: AsRef<std::path::Path>>(
96 &mut self,
97 file: P,
98 ) -> Result<(T, MockResult<T>), std::io::Error>
99 where
100 T: TransportRequest + Mockable,
101 T::Response: Mockable,
102 {
103 use crate::DumpData;
104
105 let buf = std::fs::read(file)?;
106 let data = DumpData::<T>::mock_deserialize(&buf).ok_or({
107 std::io::Error::new(std::io::ErrorKind::InvalidData, "unable to parse json")
108 })?;
109
110 let (request, response) = data.deserialize();
111 self.expect(&request, &response).map_err(|e| {
112 std::io::Error::new(
113 std::io::ErrorKind::InvalidInput,
114 format!("unable to add expectation: {}", e),
115 )
116 })?;
117 Ok((request, response))
118 }
119}
120
121#[async_trait]
122impl DapiRequestExecutor for MockDapiClient {
123 async fn execute<R: TransportRequest>(
124 &self,
125 request: R,
126 _settings: RequestSettings,
127 ) -> MockResult<R>
128 where
129 R: Mockable,
130 R::Response: Mockable,
131 {
132 let (key, response) = self.expectations.get(&request);
133
134 tracing::trace!(
135 %key,
136 request_type = std::any::type_name::<R>(),
137 response_type = std::any::type_name::<R::Response>(),
138 response = ?response,
139 "mock execute"
140 );
141
142 if let Some(response) = response {
143 response
144 } else {
145 let error = MockError::MockExpectationNotFound(format!(
146 "unexpected mock request with key {}, use MockDapiClient::expect(): {:?}",
147 key, request
148 ));
149
150 Err(ExecutionError {
151 inner: DapiClientError::Mock(error),
152 retries: 0,
153 address: None,
154 })
155 }
156 }
157}
158
/// Identifier of a mock expectation: a 32-byte digest derived from a request
/// (see `Key::try_new`).
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Key([u8; 32]);
162
163impl Key {
164 pub fn new<S: Mockable>(request: &S) -> Self {
170 Self::try_new(request).expect("unable to create a key")
171 }
172
173 pub fn try_new<S: Mockable>(request: &S) -> Result<Self, std::io::Error> {
175 let typ = type_name::<S>().replace('&', ""); let mut encoded = S::mock_serialize(request).ok_or(std::io::Error::new(
180 std::io::ErrorKind::InvalidData,
181 format!("mocking not supported for object of type {}", typ),
182 ))?;
183 encoded.append(&mut typ.as_bytes().to_vec());
184
185 let mut e = sha2::Sha256::new();
186 e.update(encoded);
187 let key = e.finalize().into();
188
189 Ok(Self(key))
190 }
191}
192
193impl ToHex for Key {
194 fn encode_hex<T: std::iter::FromIterator<char>>(&self) -> T {
195 self.0.encode_hex()
196 }
197
198 fn encode_hex_upper<T: std::iter::FromIterator<char>>(&self) -> T {
199 self.0.encode_hex_upper()
200 }
201}
202impl Display for Key {
203 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204 Display::fmt(&self.encode_hex::<String>(), f)
205 }
206}
207
208#[derive(Debug, thiserror::Error, Clone)]
209#[cfg_attr(feature = "mocks", derive(serde::Serialize, serde::Deserialize))]
210pub enum MockError {
212 #[error("mock expectation not found for request: {0}")]
213 MockExpectationNotFound(String),
215
216 #[error("expectation already defined for request: {0}")]
217 MockExpectationConflict(String),
219}
220
/// Serialized (via `Mockable`) form of an expected result.
#[derive(Debug)]
struct ExpectedResult(Vec<u8>);
224
225impl ExpectedResult {
226 fn serialize<I: Mockable>(request: &I) -> Self {
227 Self(request.mock_serialize().expect("encode value"))
229 }
230
231 fn deserialize<O: Mockable>(&self) -> O {
232 O::mock_deserialize(&self.0).expect("deserialize value")
234 }
235}
236
237#[derive(Default, Debug)]
238struct Expectations {
240 expectations: HashMap<Key, ExpectedResult>,
241}
242
243impl Expectations {
244 pub fn add<I: Mockable + Debug, O: Mockable>(
248 &mut self,
249 request: &I,
250 result: &O,
251 ) -> Result<Key, MockError> {
252 let key = Key::new(request);
253 let value = ExpectedResult::serialize(result);
254
255 if self.expectations.contains_key(&key) {
256 return Err(MockError::MockExpectationConflict(format!(
257 "expectation with key {} already defined for {} request",
258 key,
259 std::any::type_name::<I>(),
260 )));
261 }
262
263 self.expectations.insert(key.clone(), value);
264
265 Ok(key)
266 }
267
268 pub fn get<I: Mockable, O: Mockable>(&self, request: &I) -> (Key, Option<O>) {
272 let key = Key::new(request);
273
274 let response = self.expectations.get(&key).and_then(|v| v.deserialize());
275
276 (key, response)
277 }
278
279 pub fn remove(&mut self, key: &Key) -> bool {
281 self.expectations.remove(key).is_some()
282 }
283}
284
285impl<R: Mockable> Mockable for ExecutionResponse<R> {
286 fn mock_serialize(&self) -> Option<Vec<u8>> {
287 R::mock_serialize(&self.inner)
288 }
289
290 fn mock_deserialize(data: &[u8]) -> Option<Self> {
291 R::mock_deserialize(data).map(|inner| ExecutionResponse {
293 inner,
294 retries: 0,
295 address: "http://127.0.0.1:9000"
296 .parse()
297 .expect("failed to parse address"),
298 })
299 }
300}
301
302impl<E: Mockable> Mockable for ExecutionError<E> {
303 fn mock_serialize(&self) -> Option<Vec<u8>> {
304 E::mock_serialize(&self.inner)
305 }
306
307 fn mock_deserialize(data: &[u8]) -> Option<Self> {
308 E::mock_deserialize(data).map(|inner| ExecutionError {
310 inner,
311 retries: 0,
312 address: None,
313 })
314 }
315}
316
/// Conversion that builds `Self` (which must implement [`Default`]) around an inner
/// value of type `R`.
pub trait FromInner<R>: Default {
    /// Build `Self` wrapping `inner`.
    fn from_inner(inner: R) -> Self;
}
329
330impl<R> FromInner<R> for ExecutionResponse<R>
331where
332 Self: Default,
333{
334 fn from_inner(inner: R) -> Self {
335 Self {
336 inner,
337 ..Default::default()
338 }
339 }
340}