Skip to main content

zebra_chain/
amount.rs

1//! Strongly-typed zatoshi amounts that prevent under/overflows.
2//!
3//! The [`Amount`] type is parameterized by a [`Constraint`] implementation that
4//! declares the range of allowed values. In contrast to regular arithmetic
5//! operations, which return values, arithmetic on [`Amount`]s returns
6//! [`Result`](std::result::Result)s.
7
8use std::{
9    cmp::Ordering,
10    fmt,
11    hash::{Hash, Hasher},
12    marker::PhantomData,
13    ops::RangeInclusive,
14};
15
16use crate::serialization::{ZcashDeserialize, ZcashSerialize};
17use byteorder::{ByteOrder, LittleEndian, ReadBytesExt, WriteBytesExt};
18
19#[cfg(any(test, feature = "proptest-impl"))]
20pub mod arbitrary;
21
22#[cfg(test)]
23mod tests;
24
25/// The result of an amount operation.
26pub type Result<T, E = Error> = std::result::Result<T, E>;
27
28/// A runtime validated type for representing amounts of zatoshis
29//
30// TODO:
31// - remove the default NegativeAllowed bound, to make consensus rule reviews easier
32// - put a Constraint bound on the type generic, not just some implementations
33#[derive(Clone, Copy, Serialize, Deserialize, Default)]
34#[serde(try_from = "i64")]
35#[serde(into = "i64")]
36#[serde(bound = "C: Constraint + Clone")]
37pub struct Amount<C = NegativeAllowed>(
38    /// The inner amount value.
39    i64,
40    /// Used for [`Constraint`] type inference.
41    ///
42    /// # Correctness
43    ///
44    /// This internal Zebra marker type is not consensus-critical.
45    /// And it should be ignored during testing. (And other internal uses.)
46    #[serde(skip)]
47    PhantomData<C>,
48);
49
50impl<C> fmt::Display for Amount<C> {
51    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52        let zats = self.zatoshis();
53
54        f.pad_integral(zats > 0, "", &zats.to_string())
55    }
56}
57
58impl<C> fmt::Debug for Amount<C> {
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        f.debug_tuple(&format!("Amount<{}>", std::any::type_name::<C>()))
61            .field(&self.0)
62            .finish()
63    }
64}
65
66impl Amount<NonNegative> {
67    /// Create a new non-negative [`Amount`] from a provided value in ZEC.
68    pub const fn new_from_zec(zec_value: i64) -> Self {
69        Self::new(zec_value.checked_mul(COIN).expect("should fit in i64"))
70    }
71
72    /// Create a new non-negative [`Amount`] from a provided value in zatoshis.
73    pub const fn new(zatoshis: i64) -> Self {
74        assert!(zatoshis <= MAX_MONEY && zatoshis >= 0);
75        Self(zatoshis, PhantomData)
76    }
77
78    /// Divide an [`Amount`] by a value that the amount fits into evenly such that there is no remainder.
79    pub const fn div_exact(self, rhs: i64) -> Self {
80        let result = self.0.checked_div(rhs).expect("divisor must be non-zero");
81        if self.0 % rhs != 0 {
82            panic!("divisor must divide amount evenly, no remainder");
83        }
84
85        Self(result, PhantomData)
86    }
87}
88
89impl<C> Amount<C> {
90    /// Convert this amount to a different Amount type if it satisfies the new constraint
91    pub fn constrain<C2>(self) -> Result<Amount<C2>>
92    where
93        C2: Constraint,
94    {
95        self.0.try_into()
96    }
97
98    /// Returns the number of zatoshis in this amount.
99    pub fn zatoshis(&self) -> i64 {
100        self.0
101    }
102
103    /// Checked subtraction. Computes self - rhs, returning None if overflow occurred.
104    pub fn checked_sub<C2: Constraint>(self, rhs: Amount<C2>) -> Option<Amount> {
105        self.0.checked_sub(rhs.0).and_then(|v| v.try_into().ok())
106    }
107
108    /// To little endian byte array
109    pub fn to_bytes(&self) -> [u8; 8] {
110        let mut buf: [u8; 8] = [0; 8];
111        LittleEndian::write_i64(&mut buf, self.0);
112        buf
113    }
114
115    /// From little endian byte array
116    pub fn from_bytes(bytes: [u8; 8]) -> Result<Amount<C>>
117    where
118        C: Constraint,
119    {
120        let amount = i64::from_le_bytes(bytes);
121        amount.try_into()
122    }
123
124    /// Create a zero `Amount`
125    pub fn zero() -> Amount<C>
126    where
127        C: Constraint,
128    {
129        0.try_into().expect("an amount of 0 is always valid")
130    }
131
132    /// Returns true if this amount is zero.
133    pub fn is_zero(&self) -> bool {
134        self.0 == 0
135    }
136}
137
138impl<C> std::ops::Add<Amount<C>> for Amount<C>
139where
140    C: Constraint,
141{
142    type Output = Result<Amount<C>>;
143
144    fn add(self, rhs: Amount<C>) -> Self::Output {
145        let value = self
146            .0
147            .checked_add(rhs.0)
148            .expect("adding two constrained Amounts is always within an i64");
149        value.try_into()
150    }
151}
152
153impl<C> std::ops::Add<Amount<C>> for Result<Amount<C>>
154where
155    C: Constraint,
156{
157    type Output = Result<Amount<C>>;
158
159    fn add(self, rhs: Amount<C>) -> Self::Output {
160        self? + rhs
161    }
162}
163
164impl<C> std::ops::Add<Result<Amount<C>>> for Amount<C>
165where
166    C: Constraint,
167{
168    type Output = Result<Amount<C>>;
169
170    fn add(self, rhs: Result<Amount<C>>) -> Self::Output {
171        self + rhs?
172    }
173}
174
175impl<C> std::ops::AddAssign<Amount<C>> for Result<Amount<C>>
176where
177    Amount<C>: Copy,
178    C: Constraint,
179{
180    fn add_assign(&mut self, rhs: Amount<C>) {
181        if let Ok(lhs) = *self {
182            *self = lhs + rhs;
183        }
184    }
185}
186
187impl<C> std::ops::Sub<Amount<C>> for Amount<C>
188where
189    C: Constraint,
190{
191    type Output = Result<Amount<C>>;
192
193    fn sub(self, rhs: Amount<C>) -> Self::Output {
194        let value = self
195            .0
196            .checked_sub(rhs.0)
197            .expect("subtracting two constrained Amounts is always within an i64");
198        value.try_into()
199    }
200}
201
202impl<C> std::ops::Sub<Amount<C>> for Result<Amount<C>>
203where
204    C: Constraint,
205{
206    type Output = Result<Amount<C>>;
207
208    fn sub(self, rhs: Amount<C>) -> Self::Output {
209        self? - rhs
210    }
211}
212
213impl<C> std::ops::Sub<Result<Amount<C>>> for Amount<C>
214where
215    C: Constraint,
216{
217    type Output = Result<Amount<C>>;
218
219    fn sub(self, rhs: Result<Amount<C>>) -> Self::Output {
220        self - rhs?
221    }
222}
223
224impl<C> std::ops::SubAssign<Amount<C>> for Result<Amount<C>>
225where
226    Amount<C>: Copy,
227    C: Constraint,
228{
229    fn sub_assign(&mut self, rhs: Amount<C>) {
230        if let Ok(lhs) = *self {
231            *self = lhs - rhs;
232        }
233    }
234}
235
236impl<C> From<Amount<C>> for i64 {
237    fn from(amount: Amount<C>) -> Self {
238        amount.0
239    }
240}
241
242impl From<Amount<NonNegative>> for u64 {
243    fn from(amount: Amount<NonNegative>) -> Self {
244        amount.0.try_into().expect("non-negative i64 fits in u64")
245    }
246}
247
248impl<C> From<Amount<C>> for jubjub::Fr {
249    fn from(a: Amount<C>) -> jubjub::Fr {
250        // TODO: this isn't constant time -- does that matter?
251        if a.0 < 0 {
252            let abs_amount = i128::from(a.0)
253                .checked_abs()
254                .expect("absolute i64 fits in i128");
255            let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
256
257            jubjub::Fr::from(abs_amount).neg()
258        } else {
259            jubjub::Fr::from(u64::try_from(a.0).expect("non-negative i64 fits in u64"))
260        }
261    }
262}
263
264impl<C> From<Amount<C>> for halo2::pasta::pallas::Scalar {
265    fn from(a: Amount<C>) -> halo2::pasta::pallas::Scalar {
266        // TODO: this isn't constant time -- does that matter?
267        if a.0 < 0 {
268            let abs_amount = i128::from(a.0)
269                .checked_abs()
270                .expect("absolute i64 fits in i128");
271            let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
272
273            halo2::pasta::pallas::Scalar::from(abs_amount).neg()
274        } else {
275            halo2::pasta::pallas::Scalar::from(
276                u64::try_from(a.0).expect("non-negative i64 fits in u64"),
277            )
278        }
279    }
280}
281
282impl<C> TryFrom<i32> for Amount<C>
283where
284    C: Constraint,
285{
286    type Error = Error;
287
288    fn try_from(value: i32) -> Result<Self, Self::Error> {
289        C::validate(value.into()).map(|v| Self(v, PhantomData))
290    }
291}
292
293impl<C> TryFrom<i64> for Amount<C>
294where
295    C: Constraint,
296{
297    type Error = Error;
298
299    fn try_from(value: i64) -> Result<Self, Self::Error> {
300        C::validate(value).map(|v| Self(v, PhantomData))
301    }
302}
303
304impl<C> TryFrom<u64> for Amount<C>
305where
306    C: Constraint,
307{
308    type Error = Error;
309
310    fn try_from(value: u64) -> Result<Self, Self::Error> {
311        let value = value.try_into().map_err(|source| Error::Convert {
312            value: value.into(),
313            source,
314        })?;
315
316        C::validate(value).map(|v| Self(v, PhantomData))
317    }
318}
319
320/// Conversion from `i128` to `Amount`.
321///
322/// Used to handle the result of multiplying negative `Amount`s by `u64`.
323impl<C> TryFrom<i128> for Amount<C>
324where
325    C: Constraint,
326{
327    type Error = Error;
328
329    fn try_from(value: i128) -> Result<Self, Self::Error> {
330        let value = value
331            .try_into()
332            .map_err(|source| Error::Convert { value, source })?;
333
334        C::validate(value).map(|v| Self(v, PhantomData))
335    }
336}
337
338impl<C> Hash for Amount<C> {
339    /// Amounts with the same value are equal, even if they have different constraints
340    fn hash<H: Hasher>(&self, state: &mut H) {
341        self.0.hash(state);
342    }
343}
344
345impl<C1, C2> PartialEq<Amount<C2>> for Amount<C1> {
346    fn eq(&self, other: &Amount<C2>) -> bool {
347        self.0.eq(&other.0)
348    }
349}
350
351impl<C> PartialEq<i64> for Amount<C> {
352    fn eq(&self, other: &i64) -> bool {
353        self.0.eq(other)
354    }
355}
356
357impl<C> PartialEq<Amount<C>> for i64 {
358    fn eq(&self, other: &Amount<C>) -> bool {
359        self.eq(&other.0)
360    }
361}
362
363impl<C> Eq for Amount<C> {}
364
365impl<C1, C2> PartialOrd<Amount<C2>> for Amount<C1> {
366    fn partial_cmp(&self, other: &Amount<C2>) -> Option<Ordering> {
367        Some(self.0.cmp(&other.0))
368    }
369}
370
371impl<C> Ord for Amount<C> {
372    fn cmp(&self, other: &Amount<C>) -> Ordering {
373        self.0.cmp(&other.0)
374    }
375}
376
377impl<C> std::ops::Mul<u64> for Amount<C>
378where
379    C: Constraint,
380{
381    type Output = Result<Amount<C>>;
382
383    fn mul(self, rhs: u64) -> Self::Output {
384        // use i128 for multiplication, so we can handle negative Amounts
385        let value = i128::from(self.0)
386            .checked_mul(i128::from(rhs))
387            .expect("multiplying i64 by u64 can't overflow i128");
388
389        value.try_into().map_err(|_| Error::MultiplicationOverflow {
390            amount: self.0,
391            multiplier: rhs,
392            overflowing_result: value,
393        })
394    }
395}
396
397impl<C> std::ops::Mul<Amount<C>> for u64
398where
399    C: Constraint,
400{
401    type Output = Result<Amount<C>>;
402
403    fn mul(self, rhs: Amount<C>) -> Self::Output {
404        rhs.mul(self)
405    }
406}
407
408impl<C> std::ops::Div<u64> for Amount<C>
409where
410    C: Constraint,
411{
412    type Output = Result<Amount<C>>;
413
414    fn div(self, rhs: u64) -> Self::Output {
415        let quotient = i128::from(self.0)
416            .checked_div(i128::from(rhs))
417            .ok_or(Error::DivideByZero { amount: self.0 })?;
418
419        Ok(quotient
420            .try_into()
421            .expect("division by a positive integer always stays within the constraint"))
422    }
423}
424
425impl<C> std::iter::Sum<Amount<C>> for Result<Amount<C>>
426where
427    C: Constraint,
428{
429    fn sum<I: Iterator<Item = Amount<C>>>(mut iter: I) -> Self {
430        let sum = iter.try_fold(Amount::zero(), |acc, amount| acc + amount);
431
432        match sum {
433            Ok(sum) => Ok(sum),
434            Err(Error::Constraint { value, .. }) => Err(Error::SumOverflow {
435                partial_sum: value,
436                remaining_items: iter.count(),
437            }),
438            Err(unexpected_error) => unreachable!("unexpected Add error: {:?}", unexpected_error),
439        }
440    }
441}
442
443impl<'amt, C> std::iter::Sum<&'amt Amount<C>> for Result<Amount<C>>
444where
445    C: Constraint + Copy + 'amt,
446{
447    fn sum<I: Iterator<Item = &'amt Amount<C>>>(iter: I) -> Self {
448        iter.copied().sum()
449    }
450}
451
452// TODO: add infallible impls for NonNegative <-> NegativeOrZero,
453//       when Rust uses trait output types to disambiguate overlapping impls.
454impl<C> std::ops::Neg for Amount<C>
455where
456    C: Constraint,
457{
458    type Output = Amount<NegativeAllowed>;
459    fn neg(self) -> Self::Output {
460        Amount::<NegativeAllowed>::try_from(-self.0)
461            .expect("a negation of any Amount into NegativeAllowed is always valid")
462    }
463}
464
465#[allow(missing_docs)]
466#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
467/// Errors that can be returned when validating [`Amount`]s.
468pub enum Error {
469    /// input {value} is outside of valid range for zatoshi Amount, valid_range={range:?}
470    Constraint {
471        value: i64,
472        range: RangeInclusive<i64>,
473    },
474
475    /// {value} could not be converted to an i64 Amount
476    Convert {
477        value: i128,
478        source: std::num::TryFromIntError,
479    },
480
481    /// i64 overflow when multiplying i64 amount {amount} by u64 {multiplier}, overflowing result {overflowing_result}
482    MultiplicationOverflow {
483        amount: i64,
484        multiplier: u64,
485        overflowing_result: i128,
486    },
487
488    /// cannot divide amount {amount} by zero
489    DivideByZero { amount: i64 },
490
491    /// i64 overflow when summing i64 amounts, partial_sum: {partial_sum}, remaining items: {remaining_items}
492    SumOverflow {
493        partial_sum: i64,
494        remaining_items: usize,
495    },
496}
497
498impl fmt::Display for Error {
499    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
500        f.write_str(&match self {
501            Error::Constraint { value, range } => format!(
502                "input {value} is outside of valid range for zatoshi Amount, valid_range={range:?}"
503            ),
504            Error::Convert { value, .. } => {
505                format!("{value} could not be converted to an i64 Amount")
506            }
507            Error::MultiplicationOverflow {
508                amount,
509                multiplier,
510                overflowing_result,
511            } => format!(
512                "overflow when calculating {amount}i64 * {multiplier}u64 = {overflowing_result}i128"
513            ),
514            Error::DivideByZero { amount } => format!("cannot divide amount {amount} by zero"),
515            Error::SumOverflow {
516                partial_sum,
517                remaining_items,
518            } => format!(
519                "overflow when summing i64 amounts; \
520                          partial sum: {partial_sum}, number of remaining items: {remaining_items}"
521            ),
522        })
523    }
524}
525
526impl Error {
527    /// Returns the invalid value for this error.
528    ///
529    /// This value may be an initial input value, partially calculated value,
530    /// or an overflowing or underflowing value.
531    pub fn invalid_value(&self) -> i128 {
532        use Error::*;
533
534        match self.clone() {
535            Constraint { value, .. } => value.into(),
536            Convert { value, .. } => value,
537            MultiplicationOverflow {
538                overflowing_result, ..
539            } => overflowing_result,
540            DivideByZero { amount } => amount.into(),
541            SumOverflow { partial_sum, .. } => partial_sum.into(),
542        }
543    }
544}
545
546/// Marker type for `Amount` that allows negative values.
547///
548/// ```
549/// # use zebra_chain::amount::{Constraint, MAX_MONEY, NegativeAllowed};
550/// assert_eq!(
551///     NegativeAllowed::valid_range(),
552///     -MAX_MONEY..=MAX_MONEY,
553/// );
554/// ```
555#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)]
556pub struct NegativeAllowed;
557
558impl Constraint for NegativeAllowed {
559    fn valid_range() -> RangeInclusive<i64> {
560        -MAX_MONEY..=MAX_MONEY
561    }
562}
563
564/// Marker type for `Amount` that requires nonnegative values.
565///
566/// ```
567/// # use zebra_chain::amount::{Constraint, MAX_MONEY, NonNegative};
568/// assert_eq!(
569///     NonNegative::valid_range(),
570///     0..=MAX_MONEY,
571/// );
572/// ```
573#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Default)]
574#[cfg_attr(
575    any(test, feature = "proptest-impl"),
576    derive(proptest_derive::Arbitrary)
577)]
578pub struct NonNegative;
579
580impl Constraint for NonNegative {
581    fn valid_range() -> RangeInclusive<i64> {
582        0..=MAX_MONEY
583    }
584}
585
586/// Marker type for `Amount` that requires negative or zero values.
587///
588/// Used for coinbase transactions in `getblocktemplate` RPCs.
589///
590/// ```
591/// # use zebra_chain::amount::{Constraint, MAX_MONEY, NegativeOrZero};
592/// assert_eq!(
593///     NegativeOrZero::valid_range(),
594///     -MAX_MONEY..=0,
595/// );
596/// ```
597#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
598pub struct NegativeOrZero;
599
600impl Constraint for NegativeOrZero {
601    fn valid_range() -> RangeInclusive<i64> {
602        -MAX_MONEY..=0
603    }
604}
605
/// The number of zatoshis in one ZEC.
pub const COIN: i64 = 100_000_000;

/// The largest valid zatoshi amount: 21 million ZEC.
pub const MAX_MONEY: i64 = COIN * 21_000_000;
611
612/// A trait for defining constraints on `Amount`
613pub trait Constraint {
614    /// Returns the range of values that are valid under this constraint
615    fn valid_range() -> RangeInclusive<i64>;
616
617    /// Check if an input value is within the valid range
618    fn validate(value: i64) -> Result<i64, Error> {
619        let range = Self::valid_range();
620
621        if !range.contains(&value) {
622            Err(Error::Constraint { value, range })
623        } else {
624            Ok(value)
625        }
626    }
627}
628
629impl ZcashSerialize for Amount<NegativeAllowed> {
630    fn zcash_serialize<W: std::io::Write>(&self, mut writer: W) -> Result<(), std::io::Error> {
631        writer.write_i64::<LittleEndian>(self.0)
632    }
633}
634
635impl ZcashDeserialize for Amount<NegativeAllowed> {
636    fn zcash_deserialize<R: std::io::Read>(
637        mut reader: R,
638    ) -> Result<Self, crate::serialization::SerializationError> {
639        Ok(reader.read_i64::<LittleEndian>()?.try_into()?)
640    }
641}
642
643impl ZcashSerialize for Amount<NonNegative> {
644    #[allow(clippy::unwrap_in_result)]
645    fn zcash_serialize<W: std::io::Write>(&self, mut writer: W) -> Result<(), std::io::Error> {
646        let amount = self
647            .0
648            .try_into()
649            .expect("constraint guarantees value is positive");
650
651        writer.write_u64::<LittleEndian>(amount)
652    }
653}
654
655impl ZcashDeserialize for Amount<NonNegative> {
656    fn zcash_deserialize<R: std::io::Read>(
657        mut reader: R,
658    ) -> Result<Self, crate::serialization::SerializationError> {
659        Ok(reader.read_u64::<LittleEndian>()?.try_into()?)
660    }
661}
662
663/// Represents a change to the deferred pool balance from a coinbase transaction.
664#[derive(Clone, Copy, Default, Debug, PartialEq, Eq)]
665pub struct DeferredPoolBalanceChange(Amount);
666
667impl DeferredPoolBalanceChange {
668    /// Creates a new [`DeferredPoolBalanceChange`]
669    pub fn new(amount: Amount) -> Self {
670        Self(amount)
671    }
672
673    /// Creates a new [`DeferredPoolBalanceChange`] with a zero value.
674    pub fn zero() -> Self {
675        Self(Amount::zero())
676    }
677
678    /// Consumes `self` and returns the inner [`Amount`] value.
679    pub fn value(self) -> Amount {
680        self.0
681    }
682}