1#[cfg_attr(target_pointer_width = "32", path = "scalar/scalar32.rs")]
4#[cfg_attr(target_pointer_width = "64", path = "scalar/scalar64.rs")]
5mod scalar_impl;
6
7use self::scalar_impl::barrett_reduce;
8use crate::{FieldBytes, NistP256, SecretKey, ORDER_HEX};
9use core::{
10 fmt::{self, Debug},
11 iter::{Product, Sum},
12 ops::{Add, AddAssign, Mul, MulAssign, Neg, Shr, ShrAssign, Sub, SubAssign},
13};
14use elliptic_curve::{
15 bigint::{prelude::*, Limb, U256},
16 group::ff::{self, Field, PrimeField},
17 ops::{Invert, Reduce, ReduceNonZero},
18 rand_core::RngCore,
19 scalar::{FromUintUnchecked, IsHigh},
20 subtle::{
21 Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess,
22 CtOption,
23 },
24 zeroize::DefaultIsZeroes,
25 Curve, ScalarPrimitive,
26};
27
28#[cfg(feature = "bits")]
29use {crate::ScalarBits, elliptic_curve::group::ff::PrimeFieldBits};
30
31#[cfg(feature = "serde")]
32use serdect::serde::{de, ser, Deserialize, Serialize};
33
34pub(crate) const MODULUS: U256 = NistP256::ORDER;
37
38const FRAC_MODULUS_2: Scalar = Scalar(MODULUS.shr_vartime(1));
40
41pub const MU: [u64; 5] = [
45 0x012f_fd85_eedf_9bfe,
46 0x4319_0552_df1a_6c21,
47 0xffff_fffe_ffff_ffff,
48 0x0000_0000_ffff_ffff,
49 0x0000_0000_0000_0001,
50];
51
52#[derive(Clone, Copy, Default)]
84pub struct Scalar(pub(crate) U256);
85
86impl Scalar {
87 pub const ZERO: Self = Self(U256::ZERO);
89
90 pub const ONE: Self = Self(U256::ONE);
92
93 pub fn to_bytes(&self) -> FieldBytes {
95 self.0.to_be_byte_array()
96 }
97
98 pub const fn add(&self, rhs: &Self) -> Self {
100 Self(self.0.add_mod(&rhs.0, &NistP256::ORDER))
101 }
102
103 pub const fn double(&self) -> Self {
105 self.add(self)
106 }
107
108 pub const fn sub(&self, rhs: &Self) -> Self {
110 Self(self.0.sub_mod(&rhs.0, &NistP256::ORDER))
111 }
112
113 pub const fn multiply(&self, rhs: &Self) -> Self {
115 let (lo, hi) = self.0.mul_wide(&rhs.0);
116 Self(barrett_reduce(lo, hi))
117 }
118
119 pub const fn square(&self) -> Self {
121 self.multiply(self)
123 }
124
125 pub const fn shr_vartime(&self, shift: usize) -> Scalar {
129 Self(self.0.shr_vartime(shift))
130 }
131
132 pub fn invert(&self) -> CtOption<Self> {
134 CtOption::new(self.invert_unchecked(), !self.is_zero())
135 }
136
137 const fn invert_unchecked(&self) -> Self {
141 self.pow_vartime(&[
152 0xf3b9_cac2_fc63_254f,
153 0xbce6_faad_a717_9e84,
154 0xffff_ffff_ffff_ffff,
155 0xffff_ffff_0000_0000,
156 ])
157 }
158
159 pub const fn pow_vartime(&self, exp: &[u64]) -> Self {
162 let mut res = Self::ONE;
163
164 let mut i = exp.len();
165 while i > 0 {
166 i -= 1;
167
168 let mut j = 64;
169 while j > 0 {
170 j -= 1;
171 res = res.square();
172
173 if ((exp[i] >> j) & 1) == 1 {
174 res = res.multiply(self);
175 }
176 }
177 }
178
179 res
180 }
181
182 pub fn is_odd(&self) -> Choice {
184 self.0.is_odd()
185 }
186
187 pub fn is_even(&self) -> Choice {
189 !self.is_odd()
190 }
191}
192
193impl AsRef<Scalar> for Scalar {
194 fn as_ref(&self) -> &Scalar {
195 self
196 }
197}
198
199impl Field for Scalar {
200 const ZERO: Self = Self::ZERO;
201 const ONE: Self = Self::ONE;
202
203 fn random(mut rng: impl RngCore) -> Self {
204 let mut bytes = FieldBytes::default();
205
206 loop {
216 rng.fill_bytes(&mut bytes);
217 if let Some(scalar) = Scalar::from_repr(bytes).into() {
218 return scalar;
219 }
220 }
221 }
222
223 #[must_use]
224 fn square(&self) -> Self {
225 Scalar::square(self)
226 }
227
228 #[must_use]
229 fn double(&self) -> Self {
230 self.add(self)
231 }
232
233 fn invert(&self) -> CtOption<Self> {
234 Scalar::invert(self)
235 }
236
237 #[allow(clippy::many_single_char_names)]
240 fn sqrt(&self) -> CtOption<Self> {
241 let w = self.pow_vartime(&[
243 0x279dce5617e3192a,
244 0xfde737d56d38bcf4,
245 0x07ffffffffffffff,
246 0x07fffffff8000000,
247 ]);
248
249 let mut v = Self::S;
250 let mut x = *self * w;
251 let mut b = x * w;
252 let mut z = Self::ROOT_OF_UNITY;
253
254 for max_v in (1..=Self::S).rev() {
255 let mut k = 1;
256 let mut tmp = b.square();
257 let mut j_less_than_v = Choice::from(1);
258
259 for j in 2..max_v {
260 let tmp_is_one = tmp.ct_eq(&Self::ONE);
261 let squared = Self::conditional_select(&tmp, &z, tmp_is_one).square();
262 tmp = Self::conditional_select(&squared, &tmp, tmp_is_one);
263 let new_z = Self::conditional_select(&z, &squared, tmp_is_one);
264 j_less_than_v &= !j.ct_eq(&v);
265 k = u32::conditional_select(&j, &k, tmp_is_one);
266 z = Self::conditional_select(&z, &new_z, j_less_than_v);
267 }
268
269 let result = x * z;
270 x = Self::conditional_select(&result, &x, b.ct_eq(&Self::ONE));
271 z = z.square();
272 b *= z;
273 v = k;
274 }
275
276 CtOption::new(x, x.square().ct_eq(self))
277 }
278
279 fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
280 ff::helpers::sqrt_ratio_generic(num, div)
281 }
282}
283
284impl PrimeField for Scalar {
285 type Repr = FieldBytes;
286
287 const MODULUS: &'static str = ORDER_HEX;
288 const NUM_BITS: u32 = 256;
289 const CAPACITY: u32 = 255;
290 const TWO_INV: Self = Self(U256::from_u8(2)).invert_unchecked();
291 const MULTIPLICATIVE_GENERATOR: Self = Self(U256::from_u8(7));
292 const S: u32 = 4;
293 const ROOT_OF_UNITY: Self = Self(U256::from_be_hex(
294 "ffc97f062a770992ba807ace842a3dfc1546cad004378daf0592d7fbb41e6602",
295 ));
296 const ROOT_OF_UNITY_INV: Self = Self::ROOT_OF_UNITY.invert_unchecked();
297 const DELTA: Self = Self(U256::from_u64(33232930569601));
298
299 fn from_repr(bytes: FieldBytes) -> CtOption<Self> {
304 let inner = U256::from_be_byte_array(bytes);
305 CtOption::new(Self(inner), inner.ct_lt(&NistP256::ORDER))
306 }
307
308 fn to_repr(&self) -> FieldBytes {
309 self.to_bytes()
310 }
311
312 fn is_odd(&self) -> Choice {
313 self.0.is_odd()
314 }
315}
316
#[cfg(feature = "bits")]
impl PrimeFieldBits for Scalar {
    // The bit representation uses native machine words: 8 × u32 or 4 × u64.
    #[cfg(target_pointer_width = "32")]
    type ReprBits = [u32; 8];

    #[cfg(target_pointer_width = "64")]
    type ReprBits = [u64; 4];

    /// Little-endian bits of this scalar's integer representation.
    fn to_le_bits(&self) -> ScalarBits {
        ScalarBits::from(self)
    }

    /// Little-endian bits of the field characteristic (the group order `n`).
    fn char_le_bits() -> ScalarBits {
        ScalarBits::from(NistP256::ORDER.to_words())
    }
}
333
334impl DefaultIsZeroes for Scalar {}
335
336impl Eq for Scalar {}
337
338impl FromUintUnchecked for Scalar {
339 type Uint = U256;
340
341 fn from_uint_unchecked(uint: Self::Uint) -> Self {
342 Self(uint)
343 }
344}
345
346impl Invert for Scalar {
347 type Output = CtOption<Self>;
348
349 fn invert(&self) -> CtOption<Self> {
350 self.invert()
351 }
352
353 #[allow(non_snake_case)]
365 fn invert_vartime(&self) -> CtOption<Self> {
366 let mut u = *self;
367 let mut v = Self(MODULUS);
368 let mut A = Self::ONE;
369 let mut C = Self::ZERO;
370
371 while !bool::from(u.is_zero()) {
372 while bool::from(u.is_even()) {
374 u >>= 1;
375
376 let was_odd: bool = A.is_odd().into();
377 A >>= 1;
378
379 if was_odd {
380 A += FRAC_MODULUS_2;
381 A += Self::ONE;
382 }
383 }
384
385 while bool::from(v.is_even()) {
387 v >>= 1;
388
389 let was_odd: bool = C.is_odd().into();
390 C >>= 1;
391
392 if was_odd {
393 C += FRAC_MODULUS_2;
394 C += Self::ONE;
395 }
396 }
397
398 if u >= v {
400 u -= &v;
401 A -= &C;
402 } else {
403 v -= &u;
404 C -= &A;
405 }
406 }
407
408 CtOption::new(C, !self.is_zero())
409 }
410}
411
412impl IsHigh for Scalar {
413 fn is_high(&self) -> Choice {
414 self.0.ct_gt(&FRAC_MODULUS_2.0)
415 }
416}
417
418impl Shr<usize> for Scalar {
419 type Output = Self;
420
421 fn shr(self, rhs: usize) -> Self::Output {
422 self.shr_vartime(rhs)
423 }
424}
425
426impl Shr<usize> for &Scalar {
427 type Output = Scalar;
428
429 fn shr(self, rhs: usize) -> Self::Output {
430 self.shr_vartime(rhs)
431 }
432}
433
434impl ShrAssign<usize> for Scalar {
435 fn shr_assign(&mut self, rhs: usize) {
436 *self = *self >> rhs;
437 }
438}
439
440impl PartialEq for Scalar {
441 fn eq(&self, other: &Self) -> bool {
442 self.ct_eq(other).into()
443 }
444}
445
446impl PartialOrd for Scalar {
447 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
448 Some(self.cmp(other))
449 }
450}
451
452impl Ord for Scalar {
453 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
454 self.0.cmp(&other.0)
455 }
456}
457
458impl From<u32> for Scalar {
459 fn from(k: u32) -> Self {
460 Scalar(k.into())
461 }
462}
463
464impl From<u64> for Scalar {
465 fn from(k: u64) -> Self {
466 Scalar(k.into())
467 }
468}
469
470impl From<u128> for Scalar {
471 fn from(k: u128) -> Self {
472 Scalar(k.into())
473 }
474}
475
476impl From<Scalar> for FieldBytes {
477 fn from(scalar: Scalar) -> Self {
478 scalar.to_bytes()
479 }
480}
481
482impl From<&Scalar> for FieldBytes {
483 fn from(scalar: &Scalar) -> Self {
484 scalar.to_bytes()
485 }
486}
487
488impl From<ScalarPrimitive<NistP256>> for Scalar {
489 fn from(scalar: ScalarPrimitive<NistP256>) -> Scalar {
490 Scalar(*scalar.as_uint())
491 }
492}
493
494impl From<&ScalarPrimitive<NistP256>> for Scalar {
495 fn from(scalar: &ScalarPrimitive<NistP256>) -> Scalar {
496 Scalar(*scalar.as_uint())
497 }
498}
499
500impl From<Scalar> for ScalarPrimitive<NistP256> {
501 fn from(scalar: Scalar) -> ScalarPrimitive<NistP256> {
502 ScalarPrimitive::from(&scalar)
503 }
504}
505
506impl From<&Scalar> for ScalarPrimitive<NistP256> {
507 fn from(scalar: &Scalar) -> ScalarPrimitive<NistP256> {
508 ScalarPrimitive::new(scalar.0).unwrap()
509 }
510}
511
512impl From<&SecretKey> for Scalar {
513 fn from(secret_key: &SecretKey) -> Scalar {
514 *secret_key.to_nonzero_scalar()
515 }
516}
517
518impl From<Scalar> for U256 {
519 fn from(scalar: Scalar) -> U256 {
520 scalar.0
521 }
522}
523
524impl From<&Scalar> for U256 {
525 fn from(scalar: &Scalar) -> U256 {
526 scalar.0
527 }
528}
529
530#[cfg(feature = "bits")]
531impl From<&Scalar> for ScalarBits {
532 fn from(scalar: &Scalar) -> ScalarBits {
533 scalar.0.to_words().into()
534 }
535}
536
537impl Add<Scalar> for Scalar {
538 type Output = Scalar;
539
540 fn add(self, other: Scalar) -> Scalar {
541 Scalar::add(&self, &other)
542 }
543}
544
545impl Add<&Scalar> for &Scalar {
546 type Output = Scalar;
547
548 fn add(self, other: &Scalar) -> Scalar {
549 Scalar::add(self, other)
550 }
551}
552
553impl Add<&Scalar> for Scalar {
554 type Output = Scalar;
555
556 fn add(self, other: &Scalar) -> Scalar {
557 Scalar::add(&self, other)
558 }
559}
560
561impl AddAssign<Scalar> for Scalar {
562 fn add_assign(&mut self, rhs: Scalar) {
563 *self = Scalar::add(self, &rhs);
564 }
565}
566
567impl AddAssign<&Scalar> for Scalar {
568 fn add_assign(&mut self, rhs: &Scalar) {
569 *self = Scalar::add(self, rhs);
570 }
571}
572
573impl Sub<Scalar> for Scalar {
574 type Output = Scalar;
575
576 fn sub(self, other: Scalar) -> Scalar {
577 Scalar::sub(&self, &other)
578 }
579}
580
581impl Sub<&Scalar> for &Scalar {
582 type Output = Scalar;
583
584 fn sub(self, other: &Scalar) -> Scalar {
585 Scalar::sub(self, other)
586 }
587}
588
589impl Sub<&Scalar> for Scalar {
590 type Output = Scalar;
591
592 fn sub(self, other: &Scalar) -> Scalar {
593 Scalar::sub(&self, other)
594 }
595}
596
597impl SubAssign<Scalar> for Scalar {
598 fn sub_assign(&mut self, rhs: Scalar) {
599 *self = Scalar::sub(self, &rhs);
600 }
601}
602
603impl SubAssign<&Scalar> for Scalar {
604 fn sub_assign(&mut self, rhs: &Scalar) {
605 *self = Scalar::sub(self, rhs);
606 }
607}
608
609impl Mul<Scalar> for Scalar {
610 type Output = Scalar;
611
612 fn mul(self, other: Scalar) -> Scalar {
613 Scalar::multiply(&self, &other)
614 }
615}
616
617impl Mul<&Scalar> for &Scalar {
618 type Output = Scalar;
619
620 fn mul(self, other: &Scalar) -> Scalar {
621 Scalar::multiply(self, other)
622 }
623}
624
625impl Mul<&Scalar> for Scalar {
626 type Output = Scalar;
627
628 fn mul(self, other: &Scalar) -> Scalar {
629 Scalar::multiply(&self, other)
630 }
631}
632
633impl MulAssign<Scalar> for Scalar {
634 fn mul_assign(&mut self, rhs: Scalar) {
635 *self = Scalar::multiply(self, &rhs);
636 }
637}
638
639impl MulAssign<&Scalar> for Scalar {
640 fn mul_assign(&mut self, rhs: &Scalar) {
641 *self = Scalar::multiply(self, rhs);
642 }
643}
644
645impl Neg for Scalar {
646 type Output = Scalar;
647
648 fn neg(self) -> Scalar {
649 Scalar::ZERO - self
650 }
651}
652
653impl<'a> Neg for &'a Scalar {
654 type Output = Scalar;
655
656 fn neg(self) -> Scalar {
657 Scalar::ZERO - self
658 }
659}
660
661impl Reduce<U256> for Scalar {
662 type Bytes = FieldBytes;
663
664 fn reduce(w: U256) -> Self {
665 let (r, underflow) = w.sbb(&NistP256::ORDER, Limb::ZERO);
666 let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
667 Self(U256::conditional_select(&w, &r, !underflow))
668 }
669
670 fn reduce_bytes(bytes: &FieldBytes) -> Self {
671 Self::reduce(U256::from_be_byte_array(*bytes))
672 }
673}
674
675impl ReduceNonZero<U256> for Scalar {
676 fn reduce_nonzero(w: U256) -> Self {
677 const ORDER_MINUS_ONE: U256 = NistP256::ORDER.wrapping_sub(&U256::ONE);
678 let (r, underflow) = w.sbb(&ORDER_MINUS_ONE, Limb::ZERO);
679 let underflow = Choice::from((underflow.0 >> (Limb::BITS - 1)) as u8);
680 Self(U256::conditional_select(&w, &r, !underflow).wrapping_add(&U256::ONE))
681 }
682
683 fn reduce_nonzero_bytes(bytes: &FieldBytes) -> Self {
684 Self::reduce_nonzero(U256::from_be_byte_array(*bytes))
685 }
686}
687
688impl Sum for Scalar {
689 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
690 iter.reduce(core::ops::Add::add).unwrap_or(Self::ZERO)
691 }
692}
693
694impl<'a> Sum<&'a Scalar> for Scalar {
695 fn sum<I: Iterator<Item = &'a Scalar>>(iter: I) -> Self {
696 iter.copied().sum()
697 }
698}
699
700impl Product for Scalar {
701 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
702 iter.reduce(core::ops::Mul::mul).unwrap_or(Self::ONE)
703 }
704}
705
706impl<'a> Product<&'a Scalar> for Scalar {
707 fn product<I: Iterator<Item = &'a Scalar>>(iter: I) -> Self {
708 iter.copied().product()
709 }
710}
711
712impl ConditionallySelectable for Scalar {
713 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
714 Self(U256::conditional_select(&a.0, &b.0, choice))
715 }
716}
717
718impl ConstantTimeEq for Scalar {
719 fn ct_eq(&self, other: &Self) -> Choice {
720 self.0.ct_eq(&other.0)
721 }
722}
723
724impl Debug for Scalar {
725 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
726 write!(f, "Scalar(0x{:X})", &self.0)
727 }
728}
729
/// Serializes through [`ScalarPrimitive`], which defines the wire format.
#[cfg(feature = "serde")]
impl Serialize for Scalar {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let primitive: ScalarPrimitive<NistP256> = self.into();
        primitive.serialize(serializer)
    }
}

/// Deserializes through [`ScalarPrimitive`], then converts into a `Scalar`.
#[cfg(feature = "serde")]
impl<'de> Deserialize<'de> for Scalar {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        ScalarPrimitive::<NistP256>::deserialize(deserializer).map(Scalar::from)
    }
}
749
#[cfg(test)]
mod tests {
    use super::Scalar;
    use crate::{FieldBytes, SecretKey};
    use elliptic_curve::group::ff::{Field, PrimeField};
    use primeorder::{
        impl_field_identity_tests, impl_field_invert_tests, impl_field_sqrt_tests,
        impl_primefield_tests,
    };

    /// t = (n - 1) >> S, as little-endian 64-bit limbs (consumed by
    /// `impl_primefield_tests!`).
    const T: [u64; 4] = [
        0x4f3b_9cac_2fc6_3255,
        0xfbce_6faa_da71_79e8,
        0x0fff_ffff_ffff_ffff,
        0x0fff_ffff_f000_0000,
    ];

    impl_field_identity_tests!(Scalar);
    impl_field_invert_tests!(Scalar);
    impl_field_sqrt_tests!(Scalar);
    impl_primefield_tests!(Scalar, T);

    /// Encoding 42 as big-endian bytes and decoding it must round-trip exactly.
    #[test]
    fn from_to_bytes_roundtrip() {
        let mut encoded = FieldBytes::default();
        encoded[24..].copy_from_slice(&42u64.to_be_bytes());

        let decoded = Scalar::from_repr(encoded).unwrap();
        assert_eq!(encoded, decoded.to_bytes());
    }

    /// Basic multiplication identities, including products of negated values.
    #[test]
    fn multiply() {
        let one = Scalar::ONE;
        let two = one + &one;
        let three = two + &one;
        let six = three + &three;
        let minus_two = -two;
        let minus_three = -three;

        assert_eq!(two * &three, six);
        assert_eq!(-minus_two, two);
        assert_eq!(minus_two * &minus_three, minus_three * &minus_two);
        assert_eq!(minus_two * &minus_three, six);
    }

    /// A secret key built from a scalar's bytes yields that same scalar back.
    #[test]
    fn from_ec_secret() {
        let expected = Scalar::ONE;
        let secret = SecretKey::from_bytes(&expected.to_bytes()).unwrap();
        let rederived = Scalar::from(&secret);
        assert_eq!(expected.0, rederived.0);
    }

    /// On 32-bit targets, `-1 = n - 1` must map to the expected little-endian words.
    #[test]
    #[cfg(all(feature = "bits", target_pointer_width = "32"))]
    fn scalar_into_scalarbits() {
        use crate::ScalarBits;

        let expected_minus_one = ScalarBits::from([
            0xfc63_2550,
            0xf3b9_cac2,
            0xa717_9e84,
            0xbce6_faad,
            0xffff_ffff,
            0xffff_ffff,
            0x0000_0000,
            0xffff_ffff,
        ]);

        let actual = ScalarBits::from(&-Scalar::from(1u32));
        assert_eq!(expected_minus_one, actual);
    }
}