zebra_network/peer/client/
tests.rs1#![allow(clippy::unwrap_in_result)]
5#![cfg_attr(feature = "proptest-impl", allow(dead_code))]
6
7use std::{
8 net::{Ipv4Addr, SocketAddrV4},
9 sync::Arc,
10 time::Duration,
11};
12
13use chrono::Utc;
14use futures::{
15 channel::{mpsc, oneshot},
16 future::{self, AbortHandle, Future, FutureExt},
17};
18use tokio::{
19 sync::broadcast::{self, error::TryRecvError},
20 task::JoinHandle,
21};
22
23use zebra_chain::block::Height;
24
25use crate::{
26 constants,
27 peer::{
28 error::SharedPeerError, CancelHeartbeatTask, Client, ClientRequest, ConnectionInfo,
29 ErrorSlot,
30 },
31 peer_set::InventoryChange,
32 protocol::{
33 external::{types::Version, AddrInVersion},
34 types::{Nonce, PeerServices},
35 },
36 BoxError, VersionMessage,
37};
38
39#[cfg(test)]
40mod vectors;
41
/// How long the fallback background tasks sleep before finishing: ten seconds.
///
/// Used by the builder's `spawn_background_task_or_fallback*` helpers when no
/// connection or heartbeat task was configured.
const MAX_PEER_CONNECTION_TIME: Duration = Duration::from_millis(10_000);
44
/// Test-side handle for a [`Client`] created by [`ClientTestHarnessBuilder::finish`].
///
/// It holds the opposite ends of the channels installed in the `Client`, a clone of
/// the client's error slot, and abort handles for the two spawned background tasks.
pub struct ClientTestHarness {
    /// Receiving end of the request channel whose sender is the client's `server_tx`.
    /// `None` once the receiver has been dropped through the harness.
    client_request_receiver: Option<mpsc::Receiver<ClientRequest>>,
    /// Receiving end of the oneshot channel whose sender is the client's `shutdown_tx`.
    /// `None` once the receiver has been dropped through the harness.
    shutdown_receiver: Option<oneshot::Receiver<CancelHeartbeatTask>>,
    /// Receiver subscribed to the broadcast channel whose sender is the client's
    /// `inv_collector`. `None` once the receiver has been dropped through the harness.
    #[allow(dead_code)]
    inv_receiver: Option<broadcast::Receiver<InventoryChange>>,
    /// Error slot shared with the client (the client holds a clone of it).
    error_slot: ErrorSlot,
    /// The remote protocol version the harness was built with.
    remote_version: Version,
    /// Aborts the spawned connection background task.
    connection_aborter: AbortHandle,
    /// Aborts the spawned heartbeat background task.
    heartbeat_aborter: AbortHandle,
}
56
57impl ClientTestHarness {
58 pub fn build() -> ClientTestHarnessBuilder {
61 ClientTestHarnessBuilder {
62 version: None,
63 connection_task: None,
64 heartbeat_task: None,
65 }
66 }
67
68 pub fn remote_version(&self) -> Version {
70 self.remote_version
71 }
72
73 pub fn wants_connection_heartbeats(&mut self) -> bool {
80 let receive_result = self
81 .shutdown_receiver
82 .as_mut()
83 .expect("heartbeat shutdown receiver endpoint has been dropped")
84 .try_recv();
85
86 match receive_result {
87 Ok(None) => true,
88 Ok(Some(CancelHeartbeatTask)) | Err(oneshot::Canceled) => false,
89 }
90 }
91
92 pub fn drop_heartbeat_shutdown_receiver(&mut self) {
94 let hearbeat_future = self
95 .shutdown_receiver
96 .take()
97 .expect("unexpected test failure: heartbeat shutdown receiver endpoint has already been dropped");
98
99 std::mem::drop(hearbeat_future);
100 }
101
102 pub fn close_outbound_client_request_receiver(&mut self) {
107 self.client_request_receiver
108 .as_mut()
109 .expect("request receiver endpoint has been dropped")
110 .close();
111 }
112
113 pub fn drop_outbound_client_request_receiver(&mut self) {
117 self.client_request_receiver
118 .take()
119 .expect("request receiver endpoint has already been dropped");
120 }
121
122 pub(crate) fn try_to_receive_outbound_client_request(&mut self) -> ReceiveRequestAttempt {
126 let receive_result = self
127 .client_request_receiver
128 .as_mut()
129 .expect("request receiver endpoint has been dropped")
130 .try_next();
131
132 match receive_result {
133 Ok(Some(request)) => ReceiveRequestAttempt::Request(request),
134 Ok(None) => ReceiveRequestAttempt::Closed,
135 Err(_) => ReceiveRequestAttempt::Empty,
136 }
137 }
138
139 #[allow(dead_code)]
145 pub fn drop_inventory_change_receiver(&mut self) {
146 self.inv_receiver
147 .take()
148 .expect("inventory change receiver endpoint has already been dropped");
149 }
150
151 #[allow(dead_code)]
157 #[allow(clippy::unwrap_in_result)]
158 pub(crate) fn try_to_receive_inventory_change(&mut self) -> Option<InventoryChange> {
159 let receive_result = self
160 .inv_receiver
161 .as_mut()
162 .expect("inventory change receiver endpoint has been dropped")
163 .try_recv();
164
165 match receive_result {
166 Ok(change) => Some(change),
167 Err(TryRecvError::Empty) => None,
168 Err(TryRecvError::Closed) => None,
169 Err(TryRecvError::Lagged(skipped_messages)) => unreachable!(
170 "unexpected lagged inventory receiver in tests, skipped {} messages",
171 skipped_messages,
172 ),
173 }
174 }
175
176 pub fn current_error(&self) -> Option<SharedPeerError> {
178 self.error_slot.try_get_error()
179 }
180
181 pub fn set_error(&self, error: impl Into<SharedPeerError>) {
187 self.error_slot
188 .try_update_error(error.into())
189 .expect("unexpected earlier error in error slot")
190 }
191
192 pub async fn stop_connection_task(&self) {
194 self.connection_aborter.abort();
195
196 tokio::task::yield_now().await;
198 }
199
200 pub async fn stop_heartbeat_task(&self) {
202 self.heartbeat_aborter.abort();
203
204 tokio::task::yield_now().await;
206 }
207}
208
/// Outcome of a single non-blocking attempt to receive a [`ClientRequest`], as
/// produced by `ClientTestHarness::try_to_receive_outbound_client_request`.
pub(crate) enum ReceiveRequestAttempt {
    /// The receiver's `try_next` returned `Ok(None)`.
    Closed,

    /// The receiver's `try_next` returned an error, i.e. no request was obtained.
    Empty,

    /// A request was received from the channel.
    Request(ClientRequest),
}
222
223impl ReceiveRequestAttempt {
224 pub fn is_closed(&self) -> bool {
227 matches!(self, ReceiveRequestAttempt::Closed)
228 }
229
230 pub fn is_empty(&self) -> bool {
232 matches!(self, ReceiveRequestAttempt::Empty)
233 }
234
235 #[allow(dead_code)]
237 pub fn request(self) -> Option<ClientRequest> {
238 match self {
239 ReceiveRequestAttempt::Request(request) => Some(request),
240 ReceiveRequestAttempt::Closed | ReceiveRequestAttempt::Empty => None,
241 }
242 }
243}
244
/// Builder for a [`Client`] and [`ClientTestHarness`] pair.
///
/// `C` and `H` are the future types of the optional connection and heartbeat
/// background tasks. Both default to `future::Ready<()>`, which is the type used
/// when no custom task has been supplied.
pub struct ClientTestHarnessBuilder<C = future::Ready<()>, H = future::Ready<()>> {
    /// Future to spawn as the connection background task, if one was supplied.
    connection_task: Option<C>,
    /// Future to spawn as the heartbeat background task, if one was supplied.
    heartbeat_task: Option<H>,
    /// Remote protocol version to report, if one was supplied.
    version: Option<Version>,
}
255
256impl<C, H> ClientTestHarnessBuilder<C, H>
257where
258 C: Future<Output = ()> + Send + 'static,
259 H: Future<Output = ()> + Send + 'static,
260{
261 pub fn with_version(mut self, version: Version) -> Self {
263 self.version = Some(version);
264 self
265 }
266
267 pub fn with_connection_task<NewC>(
269 self,
270 connection_task: NewC,
271 ) -> ClientTestHarnessBuilder<NewC, H> {
272 ClientTestHarnessBuilder {
273 connection_task: Some(connection_task),
274 heartbeat_task: self.heartbeat_task,
275 version: self.version,
276 }
277 }
278
279 pub fn with_heartbeat_task<NewH>(
281 self,
282 heartbeat_task: NewH,
283 ) -> ClientTestHarnessBuilder<C, NewH> {
284 ClientTestHarnessBuilder {
285 connection_task: self.connection_task,
286 heartbeat_task: Some(heartbeat_task),
287 version: self.version,
288 }
289 }
290
291 pub fn finish(self) -> (Client, ClientTestHarness) {
293 let (shutdown_sender, shutdown_receiver) = oneshot::channel();
294 let (client_request_sender, client_request_receiver) = mpsc::channel(1);
295 let (inv_sender, inv_receiver) = broadcast::channel(5);
296
297 let error_slot = ErrorSlot::default();
298 let remote_version = self.version.unwrap_or(Version(0));
299
300 let (connection_task, connection_aborter) =
301 Self::spawn_background_task_or_fallback(self.connection_task);
302 let (heartbeat_task, heartbeat_aborter) =
303 Self::spawn_background_task_or_fallback_with_result(self.heartbeat_task);
304
305 let negotiated_version =
306 std::cmp::min(remote_version, constants::CURRENT_NETWORK_PROTOCOL_VERSION);
307
308 let remote = VersionMessage {
309 version: remote_version,
310 services: PeerServices::default(),
311 timestamp: Utc::now(),
312 address_recv: AddrInVersion::new(
313 SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1),
314 PeerServices::default(),
315 ),
316 address_from: AddrInVersion::new(
317 SocketAddrV4::new(Ipv4Addr::LOCALHOST, 2),
318 PeerServices::default(),
319 ),
320 nonce: Nonce::default(),
321 user_agent: "client test harness".to_string(),
322 start_height: Height(0),
323 relay: true,
324 };
325
326 let connection_info = Arc::new(ConnectionInfo {
327 connected_addr: crate::peer::ConnectedAddr::Isolated,
328 remote,
329 negotiated_version,
330 });
331
332 let client = Client {
333 connection_info,
334 shutdown_tx: Some(shutdown_sender),
335 server_tx: client_request_sender,
336 inv_collector: inv_sender,
337 error_slot: error_slot.clone(),
338 connection_task,
339 heartbeat_task,
340 };
341
342 let harness = ClientTestHarness {
343 client_request_receiver: Some(client_request_receiver),
344 shutdown_receiver: Some(shutdown_receiver),
345 inv_receiver: Some(inv_receiver),
346 error_slot,
347 remote_version,
348 connection_aborter,
349 heartbeat_aborter,
350 };
351
352 (client, harness)
353 }
354
355 fn spawn_background_task_or_fallback<T>(task_future: Option<T>) -> (JoinHandle<()>, AbortHandle)
360 where
361 T: Future<Output = ()> + Send + 'static,
362 {
363 match task_future {
364 Some(future) => Self::spawn_background_task(future),
365 None => Self::spawn_background_task(tokio::time::sleep(MAX_PEER_CONNECTION_TIME)),
366 }
367 }
368
369 fn spawn_background_task<T>(task_future: T) -> (JoinHandle<()>, AbortHandle)
371 where
372 T: Future<Output = ()> + Send + 'static,
373 {
374 let (task, abort_handle) = future::abortable(task_future);
375 let task_handle = tokio::spawn(task.map(|_result| ()));
376
377 (task_handle, abort_handle)
378 }
379
380 fn spawn_background_task_or_fallback_with_result<T>(
387 task_future: Option<T>,
388 ) -> (JoinHandle<Result<(), BoxError>>, AbortHandle)
389 where
390 T: Future<Output = ()> + Send + 'static,
391 {
392 match task_future {
393 Some(future) => Self::spawn_background_task_with_result(future),
394 None => Self::spawn_background_task_with_result(tokio::time::sleep(
395 MAX_PEER_CONNECTION_TIME,
396 )),
397 }
398 }
399
400 fn spawn_background_task_with_result<T>(
402 task_future: T,
403 ) -> (JoinHandle<Result<(), BoxError>>, AbortHandle)
404 where
405 T: Future<Output = ()> + Send + 'static,
406 {
407 let (task, abort_handle) = future::abortable(task_future);
408 let task_handle = tokio::spawn(task.map(|_result| Ok(())));
409
410 (task_handle, abort_handle)
411 }
412}