Skip to main content

zebra_network/peer/client/
tests.rs

1//! Tests for the [`Client`] part of peer connections, and some test utilities for mocking
2//! [`Client`] instances.
3
4#![allow(clippy::unwrap_in_result)]
5#![cfg_attr(feature = "proptest-impl", allow(dead_code))]
6
7use std::{
8    net::{Ipv4Addr, SocketAddrV4},
9    sync::Arc,
10    time::Duration,
11};
12
13use chrono::Utc;
14use futures::{
15    channel::{mpsc, oneshot},
16    future::{self, AbortHandle, Future, FutureExt},
17};
18use tokio::{
19    sync::broadcast::{self, error::TryRecvError},
20    task::JoinHandle,
21};
22
23use zebra_chain::block::Height;
24
25use crate::{
26    constants,
27    peer::{
28        error::SharedPeerError, CancelHeartbeatTask, Client, ClientRequest, ConnectionInfo,
29        ErrorSlot,
30    },
31    peer_set::InventoryChange,
32    protocol::{
33        external::{types::Version, AddrInVersion},
34        types::{Nonce, PeerServices},
35    },
36    BoxError, VersionMessage,
37};
38
39#[cfg(test)]
40mod vectors;
41
/// Upper bound on how long a mocked peer connection should stay alive during a test.
///
/// This is also how long the fallback background tasks sleep when a test does not supply its own
/// task future.
const MAX_PEER_CONNECTION_TIME: Duration = Duration::from_millis(10_000);
44
45/// A harness with mocked channels for testing a [`Client`] instance.
46pub struct ClientTestHarness {
47    client_request_receiver: Option<mpsc::Receiver<ClientRequest>>,
48    shutdown_receiver: Option<oneshot::Receiver<CancelHeartbeatTask>>,
49    #[allow(dead_code)]
50    inv_receiver: Option<broadcast::Receiver<InventoryChange>>,
51    error_slot: ErrorSlot,
52    remote_version: Version,
53    connection_aborter: AbortHandle,
54    heartbeat_aborter: AbortHandle,
55}
56
57impl ClientTestHarness {
58    /// Create a [`ClientTestHarnessBuilder`] instance to help create a new [`Client`] instance
59    /// and a [`ClientTestHarness`] to track it.
60    pub fn build() -> ClientTestHarnessBuilder {
61        ClientTestHarnessBuilder {
62            version: None,
63            connection_task: None,
64            heartbeat_task: None,
65        }
66    }
67
68    /// Gets the remote peer protocol version reported by the [`Client`].
69    pub fn remote_version(&self) -> Version {
70        self.remote_version
71    }
72
73    /// Returns true if the [`Client`] instance still wants connection heartbeats to be sent.
74    ///
75    /// Checks that the client:
76    /// - has not been dropped,
77    /// - has not closed or dropped the mocked heartbeat task channel, and
78    /// - has not asked the mocked heartbeat task to shut down.
79    pub fn wants_connection_heartbeats(&mut self) -> bool {
80        let receive_result = self
81            .shutdown_receiver
82            .as_mut()
83            .expect("heartbeat shutdown receiver endpoint has been dropped")
84            .try_recv();
85
86        match receive_result {
87            Ok(None) => true,
88            Ok(Some(CancelHeartbeatTask)) | Err(oneshot::Canceled) => false,
89        }
90    }
91
92    /// Drops the mocked heartbeat shutdown receiver endpoint.
93    pub fn drop_heartbeat_shutdown_receiver(&mut self) {
94        let hearbeat_future = self
95            .shutdown_receiver
96            .take()
97            .expect("unexpected test failure: heartbeat shutdown receiver endpoint has already been dropped");
98
99        std::mem::drop(hearbeat_future);
100    }
101
102    /// Closes the receiver endpoint of [`ClientRequest`]s that are supposed to be sent to the
103    /// remote peer.
104    ///
105    /// The remote peer that would receive the requests is mocked for testing.
106    pub fn close_outbound_client_request_receiver(&mut self) {
107        self.client_request_receiver
108            .as_mut()
109            .expect("request receiver endpoint has been dropped")
110            .close();
111    }
112
113    /// Drops the receiver endpoint of [`ClientRequest`]s, forcefully closing the channel.
114    ///
115    /// The remote peer that would receive the requests is mocked for testing.
116    pub fn drop_outbound_client_request_receiver(&mut self) {
117        self.client_request_receiver
118            .take()
119            .expect("request receiver endpoint has already been dropped");
120    }
121
122    /// Tries to receive a [`ClientRequest`] sent by the [`Client`] instance.
123    ///
124    /// The remote peer that would receive the requests is mocked for testing.
125    pub(crate) fn try_to_receive_outbound_client_request(&mut self) -> ReceiveRequestAttempt {
126        let receive_result = self
127            .client_request_receiver
128            .as_mut()
129            .expect("request receiver endpoint has been dropped")
130            .try_next();
131
132        match receive_result {
133            Ok(Some(request)) => ReceiveRequestAttempt::Request(request),
134            Ok(None) => ReceiveRequestAttempt::Closed,
135            Err(_) => ReceiveRequestAttempt::Empty,
136        }
137    }
138
139    /// Drops the receiver endpoint of [`InventoryChange`]s, forcefully closing the channel.
140    ///
141    /// The inventory registry that would track the changes is mocked for testing.
142    ///
143    /// Note: this closes the broadcast receiver, it doesn't have a separate `close()` method.
144    #[allow(dead_code)]
145    pub fn drop_inventory_change_receiver(&mut self) {
146        self.inv_receiver
147            .take()
148            .expect("inventory change receiver endpoint has already been dropped");
149    }
150
151    /// Tries to receive an [`InventoryChange`] sent by the [`Client`] instance.
152    ///
153    /// This method acts like a mock inventory registry, allowing tests to track the changes.
154    ///
155    /// TODO: make ReceiveRequestAttempt generic, and use it here.
156    #[allow(dead_code)]
157    #[allow(clippy::unwrap_in_result)]
158    pub(crate) fn try_to_receive_inventory_change(&mut self) -> Option<InventoryChange> {
159        let receive_result = self
160            .inv_receiver
161            .as_mut()
162            .expect("inventory change receiver endpoint has been dropped")
163            .try_recv();
164
165        match receive_result {
166            Ok(change) => Some(change),
167            Err(TryRecvError::Empty) => None,
168            Err(TryRecvError::Closed) => None,
169            Err(TryRecvError::Lagged(skipped_messages)) => unreachable!(
170                "unexpected lagged inventory receiver in tests, skipped {} messages",
171                skipped_messages,
172            ),
173        }
174    }
175
176    /// Returns the current error in the [`ErrorSlot`], if there is one.
177    pub fn current_error(&self) -> Option<SharedPeerError> {
178        self.error_slot.try_get_error()
179    }
180
181    /// Sets the error in the [`ErrorSlot`], assuming there isn't one already.
182    ///
183    /// # Panics
184    ///
185    /// If there's already an error in the [`ErrorSlot`].
186    pub fn set_error(&self, error: impl Into<SharedPeerError>) {
187        self.error_slot
188            .try_update_error(error.into())
189            .expect("unexpected earlier error in error slot")
190    }
191
192    /// Stops the mock background task that handles incoming remote requests and replies.
193    pub async fn stop_connection_task(&self) {
194        self.connection_aborter.abort();
195
196        // Allow the task to detect that it was aborted.
197        tokio::task::yield_now().await;
198    }
199
200    /// Stops the mock background task that sends periodic heartbeats.
201    pub async fn stop_heartbeat_task(&self) {
202        self.heartbeat_aborter.abort();
203
204        // Allow the task to detect that it was aborted.
205        tokio::task::yield_now().await;
206    }
207}
208
209/// The result of an attempt to receive a [`ClientRequest`] sent by the [`Client`] instance.
210///
211/// The remote peer that would receive the request is mocked for testing.
212pub(crate) enum ReceiveRequestAttempt {
213    /// The [`Client`] instance has closed the sender endpoint of the channel.
214    Closed,
215
216    /// There were no queued requests in the channel.
217    Empty,
218
219    /// One request was successfully received.
220    Request(ClientRequest),
221}
222
223impl ReceiveRequestAttempt {
224    /// Check if the attempt to receive resulted in discovering that the sender endpoint had been
225    /// closed.
226    pub fn is_closed(&self) -> bool {
227        matches!(self, ReceiveRequestAttempt::Closed)
228    }
229
230    /// Check if the attempt to receive resulted in no requests.
231    pub fn is_empty(&self) -> bool {
232        matches!(self, ReceiveRequestAttempt::Empty)
233    }
234
235    /// Returns the received request, if there was one.
236    #[allow(dead_code)]
237    pub fn request(self) -> Option<ClientRequest> {
238        match self {
239            ReceiveRequestAttempt::Request(request) => Some(request),
240            ReceiveRequestAttempt::Closed | ReceiveRequestAttempt::Empty => None,
241        }
242    }
243}
244
245/// A builder for a [`Client`] and [`ClientTestHarness`] instance.
246///
247/// Mocked data is used to construct a real [`Client`] instance. The mocked data is initialized by
248/// the [`ClientTestHarnessBuilder`], and can be accessed and changed through the
249/// [`ClientTestHarness`].
250pub struct ClientTestHarnessBuilder<C = future::Ready<()>, H = future::Ready<()>> {
251    connection_task: Option<C>,
252    heartbeat_task: Option<H>,
253    version: Option<Version>,
254}
255
256impl<C, H> ClientTestHarnessBuilder<C, H>
257where
258    C: Future<Output = ()> + Send + 'static,
259    H: Future<Output = ()> + Send + 'static,
260{
261    /// Configure the mocked version for the peer.
262    pub fn with_version(mut self, version: Version) -> Self {
263        self.version = Some(version);
264        self
265    }
266
267    /// Configure the mock connection task future to use.
268    pub fn with_connection_task<NewC>(
269        self,
270        connection_task: NewC,
271    ) -> ClientTestHarnessBuilder<NewC, H> {
272        ClientTestHarnessBuilder {
273            connection_task: Some(connection_task),
274            heartbeat_task: self.heartbeat_task,
275            version: self.version,
276        }
277    }
278
279    /// Configure the mock heartbeat task future to use.
280    pub fn with_heartbeat_task<NewH>(
281        self,
282        heartbeat_task: NewH,
283    ) -> ClientTestHarnessBuilder<C, NewH> {
284        ClientTestHarnessBuilder {
285            connection_task: self.connection_task,
286            heartbeat_task: Some(heartbeat_task),
287            version: self.version,
288        }
289    }
290
291    /// Build a [`Client`] instance with the mocked data and a [`ClientTestHarness`] to track it.
292    pub fn finish(self) -> (Client, ClientTestHarness) {
293        let (shutdown_sender, shutdown_receiver) = oneshot::channel();
294        let (client_request_sender, client_request_receiver) = mpsc::channel(1);
295        let (inv_sender, inv_receiver) = broadcast::channel(5);
296
297        let error_slot = ErrorSlot::default();
298        let remote_version = self.version.unwrap_or(Version(0));
299
300        let (connection_task, connection_aborter) =
301            Self::spawn_background_task_or_fallback(self.connection_task);
302        let (heartbeat_task, heartbeat_aborter) =
303            Self::spawn_background_task_or_fallback_with_result(self.heartbeat_task);
304
305        let negotiated_version =
306            std::cmp::min(remote_version, constants::CURRENT_NETWORK_PROTOCOL_VERSION);
307
308        let remote = VersionMessage {
309            version: remote_version,
310            services: PeerServices::default(),
311            timestamp: Utc::now(),
312            address_recv: AddrInVersion::new(
313                SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1),
314                PeerServices::default(),
315            ),
316            address_from: AddrInVersion::new(
317                SocketAddrV4::new(Ipv4Addr::LOCALHOST, 2),
318                PeerServices::default(),
319            ),
320            nonce: Nonce::default(),
321            user_agent: "client test harness".to_string(),
322            start_height: Height(0),
323            relay: true,
324        };
325
326        let connection_info = Arc::new(ConnectionInfo {
327            connected_addr: crate::peer::ConnectedAddr::Isolated,
328            remote,
329            negotiated_version,
330        });
331
332        let client = Client {
333            connection_info,
334            shutdown_tx: Some(shutdown_sender),
335            server_tx: client_request_sender,
336            inv_collector: inv_sender,
337            error_slot: error_slot.clone(),
338            connection_task,
339            heartbeat_task,
340        };
341
342        let harness = ClientTestHarness {
343            client_request_receiver: Some(client_request_receiver),
344            shutdown_receiver: Some(shutdown_receiver),
345            inv_receiver: Some(inv_receiver),
346            error_slot,
347            remote_version,
348            connection_aborter,
349            heartbeat_aborter,
350        };
351
352        (client, harness)
353    }
354
355    /// Spawn a mock background abortable task `task_future` if provided, or a fallback task
356    /// otherwise.
357    ///
358    /// The fallback task lives as long as [`MAX_PEER_CONNECTION_TIME`].
359    fn spawn_background_task_or_fallback<T>(task_future: Option<T>) -> (JoinHandle<()>, AbortHandle)
360    where
361        T: Future<Output = ()> + Send + 'static,
362    {
363        match task_future {
364            Some(future) => Self::spawn_background_task(future),
365            None => Self::spawn_background_task(tokio::time::sleep(MAX_PEER_CONNECTION_TIME)),
366        }
367    }
368
369    /// Spawn a mock background abortable task to run `task_future`.
370    fn spawn_background_task<T>(task_future: T) -> (JoinHandle<()>, AbortHandle)
371    where
372        T: Future<Output = ()> + Send + 'static,
373    {
374        let (task, abort_handle) = future::abortable(task_future);
375        let task_handle = tokio::spawn(task.map(|_result| ()));
376
377        (task_handle, abort_handle)
378    }
379
380    // TODO: In the context of #4734:
381    // - Delete `spawn_background_task_or_fallback` and `spawn_background_task`
382    // - Rename `spawn_background_task_or_fallback_with_result` and `spawn_background_task_with_result` to
383    //   `spawn_background_task_or_fallback` and `spawn_background_task`
384
385    // Similar to `spawn_background_task_or_fallback` but returns a `Result`.
386    fn spawn_background_task_or_fallback_with_result<T>(
387        task_future: Option<T>,
388    ) -> (JoinHandle<Result<(), BoxError>>, AbortHandle)
389    where
390        T: Future<Output = ()> + Send + 'static,
391    {
392        match task_future {
393            Some(future) => Self::spawn_background_task_with_result(future),
394            None => Self::spawn_background_task_with_result(tokio::time::sleep(
395                MAX_PEER_CONNECTION_TIME,
396            )),
397        }
398    }
399
400    // Similar to `spawn_background_task` but returns a `Result`.
401    fn spawn_background_task_with_result<T>(
402        task_future: T,
403    ) -> (JoinHandle<Result<(), BoxError>>, AbortHandle)
404    where
405        T: Future<Output = ()> + Send + 'static,
406    {
407        let (task, abort_handle) = future::abortable(task_future);
408        let task_handle = tokio::spawn(task.map(|_result| Ok(())));
409
410        (task_handle, abort_handle)
411    }
412}