Skip to main content

zebra_consensus/primitives/
sapling.rs

1//! Async Sapling batch verifier service
2
3use core::fmt;
4use std::{
5    future::Future,
6    mem,
7    pin::Pin,
8    task::{Context, Poll},
9};
10
11use futures::{future::BoxFuture, FutureExt, TryFutureExt};
12use once_cell::sync::Lazy;
13use rand::thread_rng;
14use tokio::sync::watch;
15use tower::{util::ServiceFn, Service};
16use tower_batch_control::{Batch, BatchControl, RequestWeight};
17use tower_fallback::Fallback;
18
19use sapling_crypto::{bundle::Authorized, BatchValidator, Bundle};
20use zcash_proofs::prover::LocalTxProver;
21use zcash_protocol::value::ZatBalance;
22use zebra_chain::transaction::SigHash;
23
24/// Sapling prover containing spend and output params for the Sapling circuit.
25///
26/// Used to:
27///
28/// - construct Sapling outputs in coinbase txs, and
29/// - verify Sapling shielded data in the tx verifier.
30static SAPLING: Lazy<LocalTxProver> = Lazy::new(LocalTxProver::bundled);
31
32#[derive(Clone)]
33pub struct Item {
34    /// The bundle containing the Sapling shielded data to verify.
35    bundle: Bundle<Authorized, ZatBalance>,
36    /// The sighash of the transaction that contains the Sapling shielded data.
37    sighash: SigHash,
38}
39
40impl Item {
41    /// Creates a new [`Item`] from a Sapling bundle and sighash.
42    pub fn new(bundle: Bundle<Authorized, ZatBalance>, sighash: SigHash) -> Self {
43        Self { bundle, sighash }
44    }
45}
46
47impl RequestWeight for Item {}
48
49/// A service that verifies Sapling shielded data in batches.
50///
51/// Handles batching incoming requests, driving batches to completion, and reporting results.
52#[derive(Default)]
53pub struct Verifier {
54    /// A batch verifier for Sapling shielded data.
55    batch: BatchValidator,
56
57    /// A channel for broadcasting the verification result of the batch.
58    ///
59    /// Each batch gets a newly created channel, so there is only ever one result sent per channel.
60    /// Tokio doesn't have a oneshot multi-consumer channel, so we use a watch channel.
61    tx: watch::Sender<Option<bool>>,
62}
63
64impl fmt::Debug for Verifier {
65    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
66        f.debug_struct("Verifier")
67            .field("batch", &"..")
68            .field("tx", &self.tx)
69            .finish()
70    }
71}
72
73impl Drop for Verifier {
74    // Flush the current batch in case there are still any pending futures.
75    //
76    // Flushing the batch means we need to validate it. This function fires off the validation and
77    // returns immediately, usually before the validation finishes.
78    fn drop(&mut self) {
79        let batch = mem::take(&mut self.batch);
80        let tx = mem::take(&mut self.tx);
81
82        // The validation is CPU-intensive; do it on a dedicated thread so it does not block.
83        rayon::spawn_fifo(move || {
84            let (spend_vk, output_vk) = SAPLING.verifying_keys();
85
86            // Validate the batch and send the result through the channel.
87            let res = batch.validate(&spend_vk, &output_vk, thread_rng());
88            let _ = tx.send(Some(res));
89        });
90    }
91}
92
93impl Service<BatchControl<Item>> for Verifier {
94    type Response = ();
95    type Error = Box<dyn std::error::Error + Send + Sync>;
96    type Future = Pin<Box<dyn Future<Output = Result<(), Self::Error>> + Send>>;
97
98    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
99        Poll::Ready(Ok(()))
100    }
101
102    fn call(&mut self, req: BatchControl<Item>) -> Self::Future {
103        match req {
104            BatchControl::Item(item) => {
105                let mut rx = self.tx.subscribe();
106
107                let bundle_check = self
108                    .batch
109                    .check_bundle(item.bundle, item.sighash.into())
110                    .then_some(())
111                    .ok_or("invalid Sapling bundle");
112
113                async move {
114                    bundle_check?;
115
116                    rx.changed()
117                        .await
118                        .map_err(|_| "verifier was dropped without flushing")
119                        .and_then(|_| {
120                            // We use a new channel for each batch, so we always get the correct
121                            // batch result here.
122                            rx.borrow()
123                                .ok_or("threadpool unexpectedly dropped channel sender")?
124                                .then(|| {
125                                    metrics::counter!("proofs.sapling.verified").increment(1);
126                                })
127                                .ok_or_else(|| {
128                                    metrics::counter!("proofs.sapling.invalid").increment(1);
129                                    "batch verification of Sapling shielded data failed"
130                                })
131                        })
132                        .map_err(Self::Error::from)
133                }
134                .boxed()
135            }
136
137            BatchControl::Flush => {
138                let batch = mem::take(&mut self.batch);
139                let tx = mem::take(&mut self.tx);
140
141                async move {
142                    let start = std::time::Instant::now();
143                    let spawn_result = tokio::task::spawn_blocking(move || {
144                        let (spend_vk, output_vk) = SAPLING.verifying_keys();
145                        batch.validate(&spend_vk, &output_vk, thread_rng())
146                    })
147                    .await;
148                    let duration = start.elapsed().as_secs_f64();
149
150                    let result_label = match &spawn_result {
151                        Ok(true) => "success",
152                        _ => "failure",
153                    };
154                    metrics::histogram!(
155                        "zebra.consensus.batch.duration_seconds",
156                        "verifier" => "groth16_sapling",
157                        "result" => result_label
158                    )
159                    .record(duration);
160
161                    // Extract the value before consuming spawn_result
162                    let is_valid = spawn_result.as_ref().ok().copied();
163                    let _ = tx.send(is_valid);
164                    spawn_result.map(|_| ()).map_err(Self::Error::from)
165                }
166                .boxed()
167            }
168        }
169    }
170}
171
172/// Verifies a single [`Item`].
173pub fn verify_single(
174    item: Item,
175) -> Pin<Box<dyn Future<Output = Result<(), Box<dyn std::error::Error + Send + Sync>>> + Send>> {
176    async move {
177        let mut verifier = Verifier::default();
178
179        let check = verifier
180            .batch
181            .check_bundle(item.bundle, item.sighash.into())
182            .then_some(())
183            .ok_or("invalid Sapling bundle");
184        check?;
185
186        tokio::task::spawn_blocking(move || {
187            let (spend_vk, output_vk) = SAPLING.verifying_keys();
188
189            mem::take(&mut verifier.batch).validate(&spend_vk, &output_vk, thread_rng())
190        })
191        .await
192        .map_err(|_| "Sapling bundle validation thread panicked")?
193        .then_some(())
194        .ok_or("invalid proof or sig in Sapling bundle")
195    }
196    .map_err(Box::from)
197    .boxed()
198}
199
200/// Global batch verification context for Sapling shielded data.
201pub static VERIFIER: Lazy<
202    Fallback<
203        Batch<Verifier, Item>,
204        ServiceFn<
205            fn(Item) -> BoxFuture<'static, Result<(), Box<dyn std::error::Error + Send + Sync>>>,
206        >,
207    >,
208> = Lazy::new(|| {
209    Fallback::new(
210        Batch::new(
211            Verifier::default(),
212            super::MAX_BATCH_SIZE,
213            None,
214            super::MAX_BATCH_LATENCY,
215        ),
216        tower::service_fn(verify_single),
217    )
218});