1use crate::{Error, Result};
9use base16ct::HexDisplay;
10use core::{
11 cmp::Ordering,
12 fmt::{self, Debug},
13 hash::{Hash, Hasher},
14 ops::Add,
15 str,
16};
17use generic_array::{
18 typenum::{U1, U24, U28, U32, U48, U66},
19 ArrayLength, GenericArray,
20};
21
22#[cfg(feature = "alloc")]
23use alloc::boxed::Box;
24
25#[cfg(feature = "serde")]
26use serdect::serde::{de, ser, Deserialize, Serialize};
27
28#[cfg(feature = "subtle")]
29use subtle::{Choice, ConditionallySelectable};
30
31#[cfg(feature = "zeroize")]
32use zeroize::Zeroize;
33
34pub trait ModulusSize: 'static + ArrayLength<u8> + Copy + Debug {
38 type CompressedPointSize: 'static + ArrayLength<u8> + Copy + Debug;
42
43 type UncompressedPointSize: 'static + ArrayLength<u8> + Copy + Debug;
47
48 type UntaggedPointSize: 'static + ArrayLength<u8> + Copy + Debug;
51}
52
53macro_rules! impl_modulus_size {
54 ($($size:ty),+) => {
55 $(impl ModulusSize for $size {
56 type CompressedPointSize = <$size as Add<U1>>::Output;
57 type UncompressedPointSize = <Self::UntaggedPointSize as Add<U1>>::Output;
58 type UntaggedPointSize = <$size as Add>::Output;
59 })+
60 }
61}
62
63impl_modulus_size!(U24, U28, U32, U48, U66);
64
65#[derive(Clone, Default)]
71pub struct EncodedPoint<Size>
72where
73 Size: ModulusSize,
74{
75 bytes: GenericArray<u8, Size::UncompressedPointSize>,
76}
77
78#[allow(clippy::len_without_is_empty)]
79impl<Size> EncodedPoint<Size>
80where
81 Size: ModulusSize,
82{
83 pub fn from_bytes(input: impl AsRef<[u8]>) -> Result<Self> {
90 let input = input.as_ref();
91
92 let tag = input
94 .first()
95 .cloned()
96 .ok_or(Error::PointEncoding)
97 .and_then(Tag::from_u8)?;
98
99 let expected_len = tag.message_len(Size::to_usize());
101
102 if input.len() != expected_len {
103 return Err(Error::PointEncoding);
104 }
105
106 let mut bytes = GenericArray::default();
107 bytes[..expected_len].copy_from_slice(input);
108 Ok(Self { bytes })
109 }
110
111 pub fn from_untagged_bytes(bytes: &GenericArray<u8, Size::UntaggedPointSize>) -> Self {
115 let (x, y) = bytes.split_at(Size::to_usize());
116 Self::from_affine_coordinates(x.into(), y.into(), false)
117 }
118
119 pub fn from_affine_coordinates(
122 x: &GenericArray<u8, Size>,
123 y: &GenericArray<u8, Size>,
124 compress: bool,
125 ) -> Self {
126 let tag = if compress {
127 Tag::compress_y(y.as_slice())
128 } else {
129 Tag::Uncompressed
130 };
131
132 let mut bytes = GenericArray::default();
133 bytes[0] = tag.into();
134 bytes[1..(Size::to_usize() + 1)].copy_from_slice(x);
135
136 if !compress {
137 bytes[(Size::to_usize() + 1)..].copy_from_slice(y);
138 }
139
140 Self { bytes }
141 }
142
143 pub fn identity() -> Self {
146 Self::default()
147 }
148
149 pub fn len(&self) -> usize {
151 self.tag().message_len(Size::to_usize())
152 }
153
154 pub fn as_bytes(&self) -> &[u8] {
156 &self.bytes[..self.len()]
157 }
158
159 #[cfg(feature = "alloc")]
161 pub fn to_bytes(&self) -> Box<[u8]> {
162 self.as_bytes().to_vec().into_boxed_slice()
163 }
164
165 pub fn is_compact(&self) -> bool {
167 self.tag().is_compact()
168 }
169
170 pub fn is_compressed(&self) -> bool {
172 self.tag().is_compressed()
173 }
174
175 pub fn is_identity(&self) -> bool {
177 self.tag().is_identity()
178 }
179
180 pub fn compress(&self) -> Self {
182 match self.coordinates() {
183 Coordinates::Compressed { .. }
184 | Coordinates::Compact { .. }
185 | Coordinates::Identity => self.clone(),
186 Coordinates::Uncompressed { x, y } => Self::from_affine_coordinates(x, y, true),
187 }
188 }
189
190 pub fn tag(&self) -> Tag {
192 Tag::from_u8(self.bytes[0]).expect("invalid tag")
194 }
195
196 #[inline]
198 pub fn coordinates(&self) -> Coordinates<'_, Size> {
199 if self.is_identity() {
200 return Coordinates::Identity;
201 }
202
203 let (x, y) = self.bytes[1..].split_at(Size::to_usize());
204
205 if self.is_compressed() {
206 Coordinates::Compressed {
207 x: x.into(),
208 y_is_odd: self.tag() as u8 & 1 == 1,
209 }
210 } else if self.is_compact() {
211 Coordinates::Compact { x: x.into() }
212 } else {
213 Coordinates::Uncompressed {
214 x: x.into(),
215 y: y.into(),
216 }
217 }
218 }
219
220 pub fn x(&self) -> Option<&GenericArray<u8, Size>> {
224 match self.coordinates() {
225 Coordinates::Identity => None,
226 Coordinates::Compressed { x, .. } => Some(x),
227 Coordinates::Uncompressed { x, .. } => Some(x),
228 Coordinates::Compact { x } => Some(x),
229 }
230 }
231
232 pub fn y(&self) -> Option<&GenericArray<u8, Size>> {
236 match self.coordinates() {
237 Coordinates::Compressed { .. } | Coordinates::Identity => None,
238 Coordinates::Uncompressed { y, .. } => Some(y),
239 Coordinates::Compact { .. } => None,
240 }
241 }
242}
243
244impl<Size> AsRef<[u8]> for EncodedPoint<Size>
245where
246 Size: ModulusSize,
247{
248 #[inline]
249 fn as_ref(&self) -> &[u8] {
250 self.as_bytes()
251 }
252}
253
/// Byte-wise conditional selection between two encoded points: yields `a`
/// when `choice` is 0 and `b` when it is 1.
#[cfg(feature = "subtle")]
impl<Size> ConditionallySelectable for EncodedPoint<Size>
where
    Size: ModulusSize,
    <Size::UncompressedPointSize as ArrayLength<u8>>::ArrayType: Copy,
{
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let mut bytes = GenericArray::<u8, Size::UncompressedPointSize>::default();

        // All three buffers have the same length, so zipping visits every byte.
        let pairs = a.bytes.iter().zip(b.bytes.iter());
        for (out, (lhs, rhs)) in bytes.iter_mut().zip(pairs) {
            *out = u8::conditional_select(lhs, rhs, choice);
        }

        Self { bytes }
    }
}
270
271impl<Size> Copy for EncodedPoint<Size>
272where
273 Size: ModulusSize,
274 <Size::UncompressedPointSize as ArrayLength<u8>>::ArrayType: Copy,
275{
276}
277
278impl<Size> Debug for EncodedPoint<Size>
279where
280 Size: ModulusSize,
281{
282 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
283 write!(f, "EncodedPoint({:?})", self.coordinates())
284 }
285}
286
287impl<Size: ModulusSize> Eq for EncodedPoint<Size> {}
288
289impl<Size> PartialEq for EncodedPoint<Size>
290where
291 Size: ModulusSize,
292{
293 fn eq(&self, other: &Self) -> bool {
294 self.as_bytes() == other.as_bytes()
295 }
296}
297
298impl<Size> Hash for EncodedPoint<Size>
299where
300 Size: ModulusSize,
301{
302 fn hash<H: Hasher>(&self, state: &mut H) {
303 self.as_bytes().hash(state)
304 }
305}
306
307impl<Size: ModulusSize> PartialOrd for EncodedPoint<Size>
308where
309 Size: ModulusSize,
310{
311 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
312 Some(self.cmp(other))
313 }
314}
315
316impl<Size: ModulusSize> Ord for EncodedPoint<Size>
317where
318 Size: ModulusSize,
319{
320 fn cmp(&self, other: &Self) -> Ordering {
321 self.as_bytes().cmp(other.as_bytes())
322 }
323}
324
325impl<Size: ModulusSize> TryFrom<&[u8]> for EncodedPoint<Size>
326where
327 Size: ModulusSize,
328{
329 type Error = Error;
330
331 fn try_from(bytes: &[u8]) -> Result<Self> {
332 Self::from_bytes(bytes)
333 }
334}
335
/// Wipes the buffer and leaves the value as the identity point.
#[cfg(feature = "zeroize")]
impl<Size: ModulusSize> Zeroize for EncodedPoint<Size> {
    fn zeroize(&mut self) {
        // Wipe the stored bytes first ...
        self.bytes.zeroize();
        // ... then reset to a well-formed value (the default is the identity).
        *self = Self::default();
    }
}
346
347impl<Size> fmt::Display for EncodedPoint<Size>
348where
349 Size: ModulusSize,
350{
351 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
352 write!(f, "{:X}", self)
353 }
354}
355
356impl<Size> fmt::LowerHex for EncodedPoint<Size>
357where
358 Size: ModulusSize,
359{
360 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
361 write!(f, "{:x}", HexDisplay(self.as_bytes()))
362 }
363}
364
365impl<Size> fmt::UpperHex for EncodedPoint<Size>
366where
367 Size: ModulusSize,
368{
369 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370 write!(f, "{:X}", HexDisplay(self.as_bytes()))
371 }
372}
373
374impl<Size> str::FromStr for EncodedPoint<Size>
379where
380 Size: ModulusSize,
381{
382 type Err = Error;
383
384 fn from_str(hex: &str) -> Result<Self> {
385 let mut buf = GenericArray::<u8, Size::UncompressedPointSize>::default();
386 base16ct::mixed::decode(hex, &mut buf)
387 .map_err(|_| Error::PointEncoding)
388 .and_then(Self::from_bytes)
389 }
390}
391
/// Serializes the encoded bytes through `serdect`'s
/// `serialize_hex_upper_or_bin` helper.
#[cfg(feature = "serde")]
impl<Size: ModulusSize> Serialize for EncodedPoint<Size> {
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let encoded = self.as_bytes();
        serdect::slice::serialize_hex_upper_or_bin(&encoded, serializer)
    }
}
404
/// Deserializes through `serdect`'s `deserialize_hex_or_bin_vec` helper and
/// then validates the bytes with [`EncodedPoint::from_bytes`].
#[cfg(feature = "serde")]
impl<'de, Size: ModulusSize> Deserialize<'de> for EncodedPoint<Size> {
    fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        serdect::slice::deserialize_hex_or_bin_vec(deserializer)
            .and_then(|raw| Self::from_bytes(raw).map_err(de::Error::custom))
    }
}
418
419#[derive(Copy, Clone, Debug, Eq, PartialEq)]
422pub enum Coordinates<'a, Size: ModulusSize> {
423 Identity,
425
426 Compact {
428 x: &'a GenericArray<u8, Size>,
430 },
431
432 Compressed {
434 x: &'a GenericArray<u8, Size>,
436
437 y_is_odd: bool,
439 },
440
441 Uncompressed {
443 x: &'a GenericArray<u8, Size>,
445
446 y: &'a GenericArray<u8, Size>,
448 },
449}
450
451impl<'a, Size: ModulusSize> Coordinates<'a, Size> {
452 pub fn tag(&self) -> Tag {
454 match self {
455 Coordinates::Compact { .. } => Tag::Compact,
456 Coordinates::Compressed { y_is_odd, .. } => {
457 if *y_is_odd {
458 Tag::CompressedOddY
459 } else {
460 Tag::CompressedEvenY
461 }
462 }
463 Coordinates::Identity => Tag::Identity,
464 Coordinates::Uncompressed { .. } => Tag::Uncompressed,
465 }
466 }
467}
468
/// Leading byte of an encoded point, identifying which encoding follows.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum Tag {
    /// Identity point (`0x00`)
    Identity = 0x00,

    /// Compressed point whose y-coordinate is even (`0x02`)
    CompressedEvenY = 0x02,

    /// Compressed point whose y-coordinate is odd (`0x03`)
    CompressedOddY = 0x03,

    /// Uncompressed point (`0x04`)
    Uncompressed = 0x04,

    /// Compact point (`0x05`)
    Compact = 0x05,
}
488
489impl Tag {
490 pub fn from_u8(byte: u8) -> Result<Self> {
492 match byte {
493 0 => Ok(Tag::Identity),
494 2 => Ok(Tag::CompressedEvenY),
495 3 => Ok(Tag::CompressedOddY),
496 4 => Ok(Tag::Uncompressed),
497 5 => Ok(Tag::Compact),
498 _ => Err(Error::PointEncoding),
499 }
500 }
501
502 pub fn is_compact(self) -> bool {
504 matches!(self, Tag::Compact)
505 }
506
507 pub fn is_compressed(self) -> bool {
509 matches!(self, Tag::CompressedEvenY | Tag::CompressedOddY)
510 }
511
512 pub fn is_identity(self) -> bool {
514 self == Tag::Identity
515 }
516
517 pub fn message_len(self, field_element_size: usize) -> usize {
521 1 + match self {
522 Tag::Identity => 0,
523 Tag::CompressedEvenY | Tag::CompressedOddY => field_element_size,
524 Tag::Uncompressed => field_element_size * 2,
525 Tag::Compact => field_element_size,
526 }
527 }
528
529 fn compress_y(y: &[u8]) -> Self {
531 if y.as_ref().last().expect("empty y-coordinate") & 1 == 1 {
533 Tag::CompressedOddY
534 } else {
535 Tag::CompressedEvenY
536 }
537 }
538}
539
540impl TryFrom<u8> for Tag {
541 type Error = Error;
542
543 fn try_from(byte: u8) -> Result<Self> {
544 Self::from_u8(byte)
545 }
546}
547
548impl From<Tag> for u8 {
549 fn from(tag: Tag) -> u8 {
550 tag as u8
551 }
552}
553
#[cfg(test)]
mod tests {
    use super::{Coordinates, Tag};
    use core::str::FromStr;
    use generic_array::{typenum::U32, GenericArray};
    use hex_literal::hex;

    #[cfg(feature = "alloc")]
    use alloc::string::ToString;

    #[cfg(feature = "subtle")]
    use subtle::ConditionallySelectable;

    type EncodedPoint = super::EncodedPoint<U32>;

    /// Encoding of the identity point: a lone zero tag byte.
    const IDENTITY_BYTES: [u8; 1] = [0];

    /// Uncompressed example point: tag 0x04, then x = 0x11.., then y = 0x22..
    const UNCOMPRESSED_BYTES: [u8; 65] = hex!("0411111111111111111111111111111111111111111111111111111111111111112222222222222222222222222222222222222222222222222222222222222222");

    /// Compressed form of the same example point (y is even, so tag 0x02).
    const COMPRESSED_BYTES: [u8; 33] =
        hex!("021111111111111111111111111111111111111111111111111111111111111111");

    /// Hex string spelling of `COMPRESSED_BYTES`.
    const COMPRESSED_HEX: &str =
        "021111111111111111111111111111111111111111111111111111111111111111";

    /// x-coordinate of the example point.
    const EXAMPLE_X: [u8; 32] = [0x11; 32];

    /// y-coordinate of the example point.
    const EXAMPLE_Y: [u8; 32] = [0x22; 32];

    #[test]
    fn decode_compressed_point() {
        // (tag byte, first byte of x, expected tag, expected parity of y)
        let cases = [
            (2u8, 1u8, Tag::CompressedEvenY, false),
            (3u8, 2u8, Tag::CompressedOddY, true),
        ];

        for &(tag_byte, x_first, expected_tag, y_is_odd) in cases.iter() {
            let mut x = [0u8; 32];
            x[0] = x_first;

            let mut encoded = [0u8; 33];
            encoded[0] = tag_byte;
            encoded[1..].copy_from_slice(&x);

            let point = EncodedPoint::from_bytes(&encoded[..]).unwrap();

            assert!(point.is_compressed());
            assert_eq!(point.tag(), expected_tag);
            assert_eq!(point.len(), 33);
            assert_eq!(point.as_bytes(), &encoded[..]);

            assert_eq!(
                point.coordinates(),
                Coordinates::Compressed {
                    x: &x.into(),
                    y_is_odd
                }
            );

            assert_eq!(point.x().unwrap(), &x.into());
            assert_eq!(point.y(), None);
        }
    }

    #[test]
    fn decode_uncompressed_point() {
        let point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();

        assert!(!point.is_compressed());
        assert_eq!(point.tag(), Tag::Uncompressed);
        assert_eq!(point.len(), 65);
        assert_eq!(point.as_bytes(), &UNCOMPRESSED_BYTES[..]);

        assert_eq!(
            point.coordinates(),
            Coordinates::Uncompressed {
                x: &EXAMPLE_X.into(),
                y: &EXAMPLE_Y.into()
            }
        );

        assert_eq!(point.x().unwrap(), &EXAMPLE_X.into());
        assert_eq!(point.y().unwrap(), &EXAMPLE_Y.into());
    }

    #[test]
    fn decode_identity() {
        let point = EncodedPoint::from_bytes(&IDENTITY_BYTES[..]).unwrap();

        assert!(point.is_identity());
        assert_eq!(point.tag(), Tag::Identity);
        assert_eq!(point.len(), 1);
        assert_eq!(point.as_bytes(), &IDENTITY_BYTES[..]);
        assert_eq!(point.coordinates(), Coordinates::Identity);
        assert_eq!(point.x(), None);
        assert_eq!(point.y(), None);
    }

    #[test]
    fn decode_invalid_tag() {
        /// Overwrites the tag byte with every value outside 2..=5 and checks
        /// that decoding is rejected each time.
        fn assert_rejected(encoded: &mut [u8]) {
            // 2..=5 are the tags that can be valid for these buffer lengths.
            for tag in (0..=0xFFu8).filter(|tag| !matches!(*tag, 2..=5)) {
                encoded[0] = tag;
                assert!(EncodedPoint::from_bytes(&*encoded).is_err());
            }
        }

        let mut compressed = COMPRESSED_BYTES;
        assert_rejected(&mut compressed);

        let mut uncompressed = UNCOMPRESSED_BYTES;
        assert_rejected(&mut uncompressed);
    }

    #[test]
    fn decode_truncated_point() {
        let encodings: [&[u8]; 2] = [&COMPRESSED_BYTES, &UNCOMPRESSED_BYTES];

        for encoded in encodings.iter() {
            // Every strict prefix (including the empty one) must be rejected.
            for len in 0..encoded.len() {
                assert!(EncodedPoint::from_bytes(&encoded[..len]).is_err());
            }
        }
    }

    #[test]
    fn from_untagged_point() {
        // x || y without the leading tag byte.
        let mut untagged = [0x11u8; 64];
        for byte in &mut untagged[32..] {
            *byte = 0x22;
        }

        let point = EncodedPoint::from_untagged_bytes(GenericArray::from_slice(&untagged[..]));
        assert_eq!(point.as_bytes(), &UNCOMPRESSED_BYTES[..]);
    }

    #[test]
    fn from_affine_coordinates() {
        let uncompressed =
            EncodedPoint::from_affine_coordinates(&EXAMPLE_X.into(), &EXAMPLE_Y.into(), false);
        assert_eq!(uncompressed.as_bytes(), &UNCOMPRESSED_BYTES[..]);

        let compressed =
            EncodedPoint::from_affine_coordinates(&EXAMPLE_X.into(), &EXAMPLE_Y.into(), true);
        assert_eq!(compressed.as_bytes(), &COMPRESSED_BYTES[..]);
    }

    #[test]
    fn compress() {
        let compressed = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..])
            .unwrap()
            .compress();
        assert_eq!(compressed.as_bytes(), &COMPRESSED_BYTES[..]);
    }

    #[cfg(feature = "subtle")]
    #[test]
    fn conditional_select() {
        let a = EncodedPoint::from_bytes(&COMPRESSED_BYTES[..]).unwrap();
        let b = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();

        // Choice 0 keeps the first operand, choice 1 picks the second.
        assert_eq!(a, EncodedPoint::conditional_select(&a, &b, 0u8.into()));
        assert_eq!(b, EncodedPoint::conditional_select(&a, &b, 1u8.into()));
    }

    #[test]
    fn identity() {
        let point = EncodedPoint::identity();

        assert_eq!(point.tag(), Tag::Identity);
        assert_eq!(point.len(), 1);
        assert_eq!(point.as_bytes(), &IDENTITY_BYTES[..]);

        // The identity is also the default value.
        assert_eq!(point, EncodedPoint::default());
    }

    #[test]
    fn decode_hex() {
        let point = EncodedPoint::from_str(COMPRESSED_HEX).unwrap();
        assert_eq!(point.as_bytes(), COMPRESSED_BYTES);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn to_bytes() {
        let point = EncodedPoint::from_bytes(&UNCOMPRESSED_BYTES[..]).unwrap();
        assert_eq!(&*point.to_bytes(), &UNCOMPRESSED_BYTES[..]);
    }

    #[cfg(feature = "alloc")]
    #[test]
    fn to_string() {
        let point = EncodedPoint::from_bytes(&COMPRESSED_BYTES[..]).unwrap();
        assert_eq!(point.to_string(), COMPRESSED_HEX);
    }
}