elliptic_curve/point/non_identity.rs

1//! Non-identity point type.
2
3use core::ops::{Deref, Mul};
4
5use group::{prime::PrimeCurveAffine, Curve, GroupEncoding};
6use rand_core::{CryptoRng, RngCore};
7use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
8
9#[cfg(feature = "serde")]
10use serdect::serde::{de, ser, Deserialize, Serialize};
11
12use crate::{CurveArithmetic, NonZeroScalar, Scalar};
13
/// Point wrapper whose value is guaranteed not to be the identity point,
/// analogous to `core::num::NonZero*` for integers.
///
/// In elliptic-curve code this lets APIs rule out, at the type level,
/// arithmetic whose result could otherwise be the identity point.
#[derive(Clone, Copy)]
pub struct NonIdentity<P> {
    /// Wrapped point; must never be the identity.
    point: P,
}
24
25impl<P> NonIdentity<P>
26where
27    P: ConditionallySelectable + ConstantTimeEq + Default,
28{
29    /// Create a [`NonIdentity`] from a point.
30    pub fn new(point: P) -> CtOption<Self> {
31        CtOption::new(Self { point }, !point.ct_eq(&P::default()))
32    }
33
34    pub(crate) fn new_unchecked(point: P) -> Self {
35        Self { point }
36    }
37}
38
39impl<P> NonIdentity<P>
40where
41    P: ConditionallySelectable + ConstantTimeEq + Default + GroupEncoding,
42{
43    /// Decode a [`NonIdentity`] from its encoding.
44    pub fn from_repr(repr: &P::Repr) -> CtOption<Self> {
45        Self::from_bytes(repr)
46    }
47}
48
49impl<P: Copy> NonIdentity<P> {
50    /// Return wrapped point.
51    pub fn to_point(self) -> P {
52        self.point
53    }
54}
55
56impl<P> NonIdentity<P>
57where
58    P: ConditionallySelectable + ConstantTimeEq + Curve + Default,
59{
60    /// Generate a random `NonIdentity<ProjectivePoint>`.
61    pub fn random(mut rng: impl CryptoRng + RngCore) -> Self {
62        loop {
63            if let Some(point) = Self::new(P::random(&mut rng)).into() {
64                break point;
65            }
66        }
67    }
68
69    /// Converts this element into its affine representation.
70    pub fn to_affine(self) -> NonIdentity<P::AffineRepr> {
71        NonIdentity {
72            point: self.point.to_affine(),
73        }
74    }
75}
76
77impl<P> NonIdentity<P>
78where
79    P: PrimeCurveAffine,
80{
81    /// Converts this element to its curve representation.
82    pub fn to_curve(self) -> NonIdentity<P::Curve> {
83        NonIdentity {
84            point: self.point.to_curve(),
85        }
86    }
87}
88
89impl<P> AsRef<P> for NonIdentity<P> {
90    fn as_ref(&self) -> &P {
91        &self.point
92    }
93}
94
95impl<P> ConditionallySelectable for NonIdentity<P>
96where
97    P: ConditionallySelectable,
98{
99    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
100        Self {
101            point: P::conditional_select(&a.point, &b.point, choice),
102        }
103    }
104}
105
106impl<P> ConstantTimeEq for NonIdentity<P>
107where
108    P: ConstantTimeEq,
109{
110    fn ct_eq(&self, other: &Self) -> Choice {
111        self.point.ct_eq(&other.point)
112    }
113}
114
115impl<P> Deref for NonIdentity<P> {
116    type Target = P;
117
118    fn deref(&self) -> &Self::Target {
119        &self.point
120    }
121}
122
123impl<P> GroupEncoding for NonIdentity<P>
124where
125    P: ConditionallySelectable + ConstantTimeEq + Default + GroupEncoding,
126{
127    type Repr = P::Repr;
128
129    fn from_bytes(bytes: &Self::Repr) -> CtOption<Self> {
130        let point = P::from_bytes(bytes);
131        point.and_then(|point| CtOption::new(Self { point }, !point.ct_eq(&P::default())))
132    }
133
134    fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption<Self> {
135        P::from_bytes_unchecked(bytes).map(|point| Self { point })
136    }
137
138    fn to_bytes(&self) -> Self::Repr {
139        self.point.to_bytes()
140    }
141}
142
143impl<C, P> Mul<NonZeroScalar<C>> for NonIdentity<P>
144where
145    C: CurveArithmetic,
146    P: Copy + Mul<Scalar<C>, Output = P>,
147{
148    type Output = NonIdentity<P>;
149
150    fn mul(self, rhs: NonZeroScalar<C>) -> Self::Output {
151        &self * &rhs
152    }
153}
154
155impl<C, P> Mul<&NonZeroScalar<C>> for &NonIdentity<P>
156where
157    C: CurveArithmetic,
158    P: Copy + Mul<Scalar<C>, Output = P>,
159{
160    type Output = NonIdentity<P>;
161
162    fn mul(self, rhs: &NonZeroScalar<C>) -> Self::Output {
163        NonIdentity {
164            point: self.point * *rhs.as_ref(),
165        }
166    }
167}
168
/// Serialize by delegating to the wrapped point, so the output is identical
/// to serializing a bare `P`.
#[cfg(feature = "serde")]
impl<P> Serialize for NonIdentity<P>
where
    P: Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        P::serialize(&self.point, serializer)
    }
}
181
#[cfg(feature = "serde")]
impl<'de, P> Deserialize<'de> for NonIdentity<P>
where
    P: ConditionallySelectable + ConstantTimeEq + Default + Deserialize<'de> + GroupEncoding,
{
    /// Deserialize a bare `P`, then reject it with a custom error if it is the
    /// identity point.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let point = P::deserialize(deserializer)?;
        let checked: Option<Self> = Self::new(point).into();
        checked.ok_or_else(|| de::Error::custom("expected non-identity point"))
    }
}
195
#[cfg(all(test, feature = "dev"))]
mod tests {
    use super::NonIdentity;
    use crate::dev::{AffinePoint, ProjectivePoint};
    use group::GroupEncoding;
    use hex_literal::hex;

    /// 33-byte compressed encoding (0x02 prefix) of a non-identity point on the
    /// dev curve, shared by the tests below.
    fn sample_encoding() -> [u8; 33] {
        hex!("02c9afa9d845ba75166b5c215767b1d6934e50c3db36e89b127b8a622b120f6721")
    }

    #[test]
    fn new_success() {
        let projective = ProjectivePoint::from_bytes(&sample_encoding().into()).unwrap();
        let affine = AffinePoint::from(projective);

        assert!(bool::from(NonIdentity::new(projective).is_some()));
        assert!(bool::from(NonIdentity::new(affine).is_some()));
    }

    #[test]
    fn new_fail() {
        let projective_identity = NonIdentity::new(ProjectivePoint::default());
        let affine_identity = NonIdentity::new(AffinePoint::default());

        assert!(bool::from(projective_identity.is_none()));
        assert!(bool::from(affine_identity.is_none()));
    }

    #[test]
    fn round_trip() {
        let encoding = sample_encoding();

        let projective = NonIdentity::<ProjectivePoint>::from_repr(&encoding.into()).unwrap();
        assert_eq!(&encoding, projective.to_bytes().as_slice());

        let affine = NonIdentity::<AffinePoint>::from_repr(&encoding.into()).unwrap();
        assert_eq!(&encoding, affine.to_bytes().as_slice());
    }
}