1use std::{
9 cmp::Ordering,
10 fmt,
11 hash::{Hash, Hasher},
12 marker::PhantomData,
13 ops::RangeInclusive,
14};
15
16use crate::serialization::{ZcashDeserialize, ZcashSerialize};
17use byteorder::{ByteOrder, LittleEndian, ReadBytesExt, WriteBytesExt};
18
19#[cfg(any(test, feature = "proptest-impl"))]
20pub mod arbitrary;
21
22#[cfg(test)]
23mod tests;
24
/// A `Result` alias whose error type defaults to this module's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
27
/// A number of zatoshis, stored as an `i64`, whose valid range is described by the
/// constraint type `C`.
///
/// The constraint exists only at the type level (a `PhantomData<C>` marker). Within this
/// module, values are created either through the `TryFrom` impls, which call
/// [`Constraint::validate`], or through the `const` constructors on `Amount<NonNegative>`,
/// which assert the range directly.
///
/// Serde treats an `Amount` as a plain `i64`. Serializing goes through `Into<i64>`.
/// Deserializing goes through `TryFrom<i64>`, so out-of-range values are rejected.
#[derive(Clone, Copy, Serialize, Deserialize, Default)]
#[serde(try_from = "i64")]
#[serde(into = "i64")]
#[serde(bound = "C: Constraint + Clone")]
pub struct Amount<C = NegativeAllowed>(
    // The raw number of zatoshis.
    i64,
    // Type-level marker for the constraint; it carries no data and is not serialized.
    #[serde(skip)]
    PhantomData<C>,
);
49
50impl<C> fmt::Display for Amount<C> {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 let zats = self.zatoshis();
53
54 f.pad_integral(zats > 0, "", &zats.to_string())
55 }
56}
57
58impl<C> fmt::Debug for Amount<C> {
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 f.debug_tuple(&format!("Amount<{}>", std::any::type_name::<C>()))
61 .field(&self.0)
62 .finish()
63 }
64}
65
66impl Amount<NonNegative> {
67 pub const fn new_from_zec(zec_value: i64) -> Self {
69 Self::new(zec_value.checked_mul(COIN).expect("should fit in i64"))
70 }
71
72 pub const fn new(zatoshis: i64) -> Self {
74 assert!(zatoshis <= MAX_MONEY && zatoshis >= 0);
75 Self(zatoshis, PhantomData)
76 }
77
78 pub const fn div_exact(self, rhs: i64) -> Self {
80 let result = self.0.checked_div(rhs).expect("divisor must be non-zero");
81 if self.0 % rhs != 0 {
82 panic!("divisor must divide amount evenly, no remainder");
83 }
84
85 Self(result, PhantomData)
86 }
87}
88
89impl<C> Amount<C> {
90 pub fn constrain<C2>(self) -> Result<Amount<C2>>
92 where
93 C2: Constraint,
94 {
95 self.0.try_into()
96 }
97
98 pub fn zatoshis(&self) -> i64 {
100 self.0
101 }
102
103 pub fn checked_sub<C2: Constraint>(self, rhs: Amount<C2>) -> Option<Amount> {
105 self.0.checked_sub(rhs.0).and_then(|v| v.try_into().ok())
106 }
107
108 pub fn to_bytes(&self) -> [u8; 8] {
110 let mut buf: [u8; 8] = [0; 8];
111 LittleEndian::write_i64(&mut buf, self.0);
112 buf
113 }
114
115 pub fn from_bytes(bytes: [u8; 8]) -> Result<Amount<C>>
117 where
118 C: Constraint,
119 {
120 let amount = i64::from_le_bytes(bytes);
121 amount.try_into()
122 }
123
124 pub fn zero() -> Amount<C>
126 where
127 C: Constraint,
128 {
129 0.try_into().expect("an amount of 0 is always valid")
130 }
131
132 pub fn is_zero(&self) -> bool {
134 self.0 == 0
135 }
136}
137
138impl<C> std::ops::Add<Amount<C>> for Amount<C>
139where
140 C: Constraint,
141{
142 type Output = Result<Amount<C>>;
143
144 fn add(self, rhs: Amount<C>) -> Self::Output {
145 let value = self
146 .0
147 .checked_add(rhs.0)
148 .expect("adding two constrained Amounts is always within an i64");
149 value.try_into()
150 }
151}
152
153impl<C> std::ops::Add<Amount<C>> for Result<Amount<C>>
154where
155 C: Constraint,
156{
157 type Output = Result<Amount<C>>;
158
159 fn add(self, rhs: Amount<C>) -> Self::Output {
160 self? + rhs
161 }
162}
163
164impl<C> std::ops::Add<Result<Amount<C>>> for Amount<C>
165where
166 C: Constraint,
167{
168 type Output = Result<Amount<C>>;
169
170 fn add(self, rhs: Result<Amount<C>>) -> Self::Output {
171 self + rhs?
172 }
173}
174
175impl<C> std::ops::AddAssign<Amount<C>> for Result<Amount<C>>
176where
177 Amount<C>: Copy,
178 C: Constraint,
179{
180 fn add_assign(&mut self, rhs: Amount<C>) {
181 if let Ok(lhs) = *self {
182 *self = lhs + rhs;
183 }
184 }
185}
186
187impl<C> std::ops::Sub<Amount<C>> for Amount<C>
188where
189 C: Constraint,
190{
191 type Output = Result<Amount<C>>;
192
193 fn sub(self, rhs: Amount<C>) -> Self::Output {
194 let value = self
195 .0
196 .checked_sub(rhs.0)
197 .expect("subtracting two constrained Amounts is always within an i64");
198 value.try_into()
199 }
200}
201
202impl<C> std::ops::Sub<Amount<C>> for Result<Amount<C>>
203where
204 C: Constraint,
205{
206 type Output = Result<Amount<C>>;
207
208 fn sub(self, rhs: Amount<C>) -> Self::Output {
209 self? - rhs
210 }
211}
212
213impl<C> std::ops::Sub<Result<Amount<C>>> for Amount<C>
214where
215 C: Constraint,
216{
217 type Output = Result<Amount<C>>;
218
219 fn sub(self, rhs: Result<Amount<C>>) -> Self::Output {
220 self - rhs?
221 }
222}
223
224impl<C> std::ops::SubAssign<Amount<C>> for Result<Amount<C>>
225where
226 Amount<C>: Copy,
227 C: Constraint,
228{
229 fn sub_assign(&mut self, rhs: Amount<C>) {
230 if let Ok(lhs) = *self {
231 *self = lhs - rhs;
232 }
233 }
234}
235
236impl<C> From<Amount<C>> for i64 {
237 fn from(amount: Amount<C>) -> Self {
238 amount.0
239 }
240}
241
242impl From<Amount<NonNegative>> for u64 {
243 fn from(amount: Amount<NonNegative>) -> Self {
244 amount.0.try_into().expect("non-negative i64 fits in u64")
245 }
246}
247
248impl<C> From<Amount<C>> for jubjub::Fr {
249 fn from(a: Amount<C>) -> jubjub::Fr {
250 if a.0 < 0 {
252 let abs_amount = i128::from(a.0)
253 .checked_abs()
254 .expect("absolute i64 fits in i128");
255 let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
256
257 jubjub::Fr::from(abs_amount).neg()
258 } else {
259 jubjub::Fr::from(u64::try_from(a.0).expect("non-negative i64 fits in u64"))
260 }
261 }
262}
263
264impl<C> From<Amount<C>> for halo2::pasta::pallas::Scalar {
265 fn from(a: Amount<C>) -> halo2::pasta::pallas::Scalar {
266 if a.0 < 0 {
268 let abs_amount = i128::from(a.0)
269 .checked_abs()
270 .expect("absolute i64 fits in i128");
271 let abs_amount = u64::try_from(abs_amount).expect("absolute i64 fits in u64");
272
273 halo2::pasta::pallas::Scalar::from(abs_amount).neg()
274 } else {
275 halo2::pasta::pallas::Scalar::from(
276 u64::try_from(a.0).expect("non-negative i64 fits in u64"),
277 )
278 }
279 }
280}
281
282impl<C> TryFrom<i32> for Amount<C>
283where
284 C: Constraint,
285{
286 type Error = Error;
287
288 fn try_from(value: i32) -> Result<Self, Self::Error> {
289 C::validate(value.into()).map(|v| Self(v, PhantomData))
290 }
291}
292
293impl<C> TryFrom<i64> for Amount<C>
294where
295 C: Constraint,
296{
297 type Error = Error;
298
299 fn try_from(value: i64) -> Result<Self, Self::Error> {
300 C::validate(value).map(|v| Self(v, PhantomData))
301 }
302}
303
304impl<C> TryFrom<u64> for Amount<C>
305where
306 C: Constraint,
307{
308 type Error = Error;
309
310 fn try_from(value: u64) -> Result<Self, Self::Error> {
311 let value = value.try_into().map_err(|source| Error::Convert {
312 value: value.into(),
313 source,
314 })?;
315
316 C::validate(value).map(|v| Self(v, PhantomData))
317 }
318}
319
320impl<C> TryFrom<i128> for Amount<C>
324where
325 C: Constraint,
326{
327 type Error = Error;
328
329 fn try_from(value: i128) -> Result<Self, Self::Error> {
330 let value = value
331 .try_into()
332 .map_err(|source| Error::Convert { value, source })?;
333
334 C::validate(value).map(|v| Self(v, PhantomData))
335 }
336}
337
338impl<C> Hash for Amount<C> {
339 fn hash<H: Hasher>(&self, state: &mut H) {
341 self.0.hash(state);
342 }
343}
344
345impl<C1, C2> PartialEq<Amount<C2>> for Amount<C1> {
346 fn eq(&self, other: &Amount<C2>) -> bool {
347 self.0.eq(&other.0)
348 }
349}
350
351impl<C> PartialEq<i64> for Amount<C> {
352 fn eq(&self, other: &i64) -> bool {
353 self.0.eq(other)
354 }
355}
356
357impl<C> PartialEq<Amount<C>> for i64 {
358 fn eq(&self, other: &Amount<C>) -> bool {
359 self.eq(&other.0)
360 }
361}
362
// Equality compares only the `i64` zatoshi value (see the `PartialEq` impls), which is a
// total equivalence relation, so `Amount<C>` can be `Eq` for every `C`.
impl<C> Eq for Amount<C> {}
364
365impl<C1, C2> PartialOrd<Amount<C2>> for Amount<C1> {
366 fn partial_cmp(&self, other: &Amount<C2>) -> Option<Ordering> {
367 Some(self.0.cmp(&other.0))
368 }
369}
370
371impl<C> Ord for Amount<C> {
372 fn cmp(&self, other: &Amount<C>) -> Ordering {
373 self.0.cmp(&other.0)
374 }
375}
376
377impl<C> std::ops::Mul<u64> for Amount<C>
378where
379 C: Constraint,
380{
381 type Output = Result<Amount<C>>;
382
383 fn mul(self, rhs: u64) -> Self::Output {
384 let value = i128::from(self.0)
386 .checked_mul(i128::from(rhs))
387 .expect("multiplying i64 by u64 can't overflow i128");
388
389 value.try_into().map_err(|_| Error::MultiplicationOverflow {
390 amount: self.0,
391 multiplier: rhs,
392 overflowing_result: value,
393 })
394 }
395}
396
397impl<C> std::ops::Mul<Amount<C>> for u64
398where
399 C: Constraint,
400{
401 type Output = Result<Amount<C>>;
402
403 fn mul(self, rhs: Amount<C>) -> Self::Output {
404 rhs.mul(self)
405 }
406}
407
408impl<C> std::ops::Div<u64> for Amount<C>
409where
410 C: Constraint,
411{
412 type Output = Result<Amount<C>>;
413
414 fn div(self, rhs: u64) -> Self::Output {
415 let quotient = i128::from(self.0)
416 .checked_div(i128::from(rhs))
417 .ok_or(Error::DivideByZero { amount: self.0 })?;
418
419 Ok(quotient
420 .try_into()
421 .expect("division by a positive integer always stays within the constraint"))
422 }
423}
424
425impl<C> std::iter::Sum<Amount<C>> for Result<Amount<C>>
426where
427 C: Constraint,
428{
429 fn sum<I: Iterator<Item = Amount<C>>>(mut iter: I) -> Self {
430 let sum = iter.try_fold(Amount::zero(), |acc, amount| acc + amount);
431
432 match sum {
433 Ok(sum) => Ok(sum),
434 Err(Error::Constraint { value, .. }) => Err(Error::SumOverflow {
435 partial_sum: value,
436 remaining_items: iter.count(),
437 }),
438 Err(unexpected_error) => unreachable!("unexpected Add error: {:?}", unexpected_error),
439 }
440 }
441}
442
443impl<'amt, C> std::iter::Sum<&'amt Amount<C>> for Result<Amount<C>>
444where
445 C: Constraint + Copy + 'amt,
446{
447 fn sum<I: Iterator<Item = &'amt Amount<C>>>(iter: I) -> Self {
448 iter.copied().sum()
449 }
450}
451
452impl<C> std::ops::Neg for Amount<C>
455where
456 C: Constraint,
457{
458 type Output = Amount<NegativeAllowed>;
459 fn neg(self) -> Self::Output {
460 Amount::<NegativeAllowed>::try_from(-self.0)
461 .expect("a negation of any Amount into NegativeAllowed is always valid")
462 }
463}
464
/// Errors returned when validating, converting, or doing arithmetic on [`Amount`]s.
#[allow(missing_docs)]
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The value is outside the range allowed by the amount's [`Constraint`].
    ///
    /// Produced by [`Constraint::validate`]; `range` is that constraint's `valid_range()`.
    Constraint {
        value: i64,
        range: RangeInclusive<i64>,
    },

    /// A wider integer (`u64` or `i128`) could not be narrowed to `i64`.
    Convert {
        value: i128,
        source: std::num::TryFromIntError,
    },

    /// `Amount * u64` produced a value that is not a valid amount.
    ///
    /// `overflowing_result` is the exact product, computed in `i128`.
    MultiplicationOverflow {
        amount: i64,
        multiplier: u64,
        overflowing_result: i128,
    },

    /// `Amount / u64` was called with a zero divisor.
    DivideByZero { amount: i64 },

    /// Summing an iterator of amounts produced an invalid partial sum.
    ///
    /// `remaining_items` counts the items that were not yet consumed when the sum failed.
    SumOverflow {
        partial_sum: i64,
        remaining_items: usize,
    },
}
497
498impl fmt::Display for Error {
499 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
500 f.write_str(&match self {
501 Error::Constraint { value, range } => format!(
502 "input {value} is outside of valid range for zatoshi Amount, valid_range={range:?}"
503 ),
504 Error::Convert { value, .. } => {
505 format!("{value} could not be converted to an i64 Amount")
506 }
507 Error::MultiplicationOverflow {
508 amount,
509 multiplier,
510 overflowing_result,
511 } => format!(
512 "overflow when calculating {amount}i64 * {multiplier}u64 = {overflowing_result}i128"
513 ),
514 Error::DivideByZero { amount } => format!("cannot divide amount {amount} by zero"),
515 Error::SumOverflow {
516 partial_sum,
517 remaining_items,
518 } => format!(
519 "overflow when summing i64 amounts; \
520 partial sum: {partial_sum}, number of remaining items: {remaining_items}"
521 ),
522 })
523 }
524}
525
526impl Error {
527 pub fn invalid_value(&self) -> i128 {
532 use Error::*;
533
534 match self.clone() {
535 Constraint { value, .. } => value.into(),
536 Convert { value, .. } => value,
537 MultiplicationOverflow {
538 overflowing_result, ..
539 } => overflowing_result,
540 DivideByZero { amount } => amount.into(),
541 SumOverflow { partial_sum, .. } => partial_sum.into(),
542 }
543 }
544}
545
546#[derive(Clone, Copy, Debug, Eq, PartialEq, Default)]
556pub struct NegativeAllowed;
557
558impl Constraint for NegativeAllowed {
559 fn valid_range() -> RangeInclusive<i64> {
560 -MAX_MONEY..=MAX_MONEY
561 }
562}
563
564#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Default)]
574#[cfg_attr(
575 any(test, feature = "proptest-impl"),
576 derive(proptest_derive::Arbitrary)
577)]
578pub struct NonNegative;
579
580impl Constraint for NonNegative {
581 fn valid_range() -> RangeInclusive<i64> {
582 0..=MAX_MONEY
583 }
584}
585
586#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
598pub struct NegativeOrZero;
599
600impl Constraint for NegativeOrZero {
601 fn valid_range() -> RangeInclusive<i64> {
602 -MAX_MONEY..=0
603 }
604}
605
/// Number of zatoshis in one ZEC; `Amount::new_from_zec` scales by this factor.
pub const COIN: i64 = 100_000_000;

/// Largest magnitude any built-in constraint allows: 21,000,000 ZEC, expressed in zatoshis.
pub const MAX_MONEY: i64 = 21_000_000 * COIN;
611
612pub trait Constraint {
614 fn valid_range() -> RangeInclusive<i64>;
616
617 fn validate(value: i64) -> Result<i64, Error> {
619 let range = Self::valid_range();
620
621 if !range.contains(&value) {
622 Err(Error::Constraint { value, range })
623 } else {
624 Ok(value)
625 }
626 }
627}
628
629impl ZcashSerialize for Amount<NegativeAllowed> {
630 fn zcash_serialize<W: std::io::Write>(&self, mut writer: W) -> Result<(), std::io::Error> {
631 writer.write_i64::<LittleEndian>(self.0)
632 }
633}
634
635impl ZcashDeserialize for Amount<NegativeAllowed> {
636 fn zcash_deserialize<R: std::io::Read>(
637 mut reader: R,
638 ) -> Result<Self, crate::serialization::SerializationError> {
639 Ok(reader.read_i64::<LittleEndian>()?.try_into()?)
640 }
641}
642
643impl ZcashSerialize for Amount<NonNegative> {
644 #[allow(clippy::unwrap_in_result)]
645 fn zcash_serialize<W: std::io::Write>(&self, mut writer: W) -> Result<(), std::io::Error> {
646 let amount = self
647 .0
648 .try_into()
649 .expect("constraint guarantees value is positive");
650
651 writer.write_u64::<LittleEndian>(amount)
652 }
653}
654
655impl ZcashDeserialize for Amount<NonNegative> {
656 fn zcash_deserialize<R: std::io::Read>(
657 mut reader: R,
658 ) -> Result<Self, crate::serialization::SerializationError> {
659 Ok(reader.read_u64::<LittleEndian>()?.try_into()?)
660 }
661}
662
663#[derive(Clone, Copy, Default, Debug, PartialEq, Eq)]
665pub struct DeferredPoolBalanceChange(Amount);
666
667impl DeferredPoolBalanceChange {
668 pub fn new(amount: Amount) -> Self {
670 Self(amount)
671 }
672
673 pub fn zero() -> Self {
675 Self(Amount::zero())
676 }
677
678 pub fn value(self) -> Amount {
680 self.0
681 }
682}