1use crate::{Error, Result};
4
5#[cfg(feature = "signing")]
6use {
7 crate::{hazmat::SignPrimitive, SigningKey},
8 elliptic_curve::subtle::CtOption,
9 signature::{hazmat::PrehashSigner, DigestSigner, Signer},
10};
11
12#[cfg(feature = "verifying")]
13use {
14 crate::{hazmat::VerifyPrimitive, VerifyingKey},
15 elliptic_curve::{
16 bigint::CheckedAdd,
17 ops::{LinearCombination, Reduce},
18 point::DecompressPoint,
19 sec1::{self, FromEncodedPoint, ToEncodedPoint},
20 AffinePoint, FieldBytesEncoding, FieldBytesSize, Group, PrimeField, ProjectivePoint,
21 },
22 signature::hazmat::PrehashVerifier,
23};
24
25#[cfg(any(feature = "signing", feature = "verifying"))]
26use {
27 crate::{
28 hazmat::{bits2field, DigestPrimitive},
29 Signature, SignatureSize,
30 },
31 elliptic_curve::{
32 generic_array::ArrayLength, ops::Invert, CurveArithmetic, PrimeCurve, Scalar,
33 },
34 signature::digest::Digest,
35};
36
/// Recovery ID for an ECDSA signature: a two-bit value in `0..=RecoveryId::MAX`
/// that selects which candidate point `R` a signature's `r` scalar came from,
/// allowing the signer's public key to be reconstructed.
///
/// Bit layout of the wrapped byte:
/// - bit 0: set when the y-coordinate of `R` is odd;
/// - bit 1: set when the x-coordinate of `R` was reduced modulo the curve
///   order (so the curve order must be added back to `r` to obtain `R.x`).
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)]
pub struct RecoveryId(u8);

impl RecoveryId {
    /// Largest byte value that is a valid recovery ID.
    pub const MAX: u8 = 3;

    /// Assemble a recovery ID from its y-parity and x-reduction flags.
    pub const fn new(is_y_odd: bool, is_x_reduced: bool) -> Self {
        let x_bit = (is_x_reduced as u8) << 1;
        let y_bit = is_y_odd as u8;
        Self(x_bit | y_bit)
    }

    /// Whether bit 1 (x-coordinate reduced modulo the curve order) is set.
    pub const fn is_x_reduced(self) -> bool {
        ((self.0 >> 1) & 1) == 1
    }

    /// Whether bit 0 (y-coordinate of `R` is odd) is set.
    pub const fn is_y_odd(self) -> bool {
        self.0 % 2 == 1
    }

    /// Interpret `byte` as a recovery ID.
    ///
    /// Returns `None` when `byte` exceeds [`RecoveryId::MAX`].
    pub const fn from_byte(byte: u8) -> Option<Self> {
        if byte > Self::MAX {
            return None;
        }
        Some(Self(byte))
    }

    /// Raw byte form of this recovery ID (always `<= RecoveryId::MAX` when
    /// constructed via `new` or `from_byte`).
    pub const fn to_byte(self) -> u8 {
        self.0
    }
}
89
#[cfg(feature = "verifying")]
impl RecoveryId {
    /// Find the recovery ID that maps `signature` over `msg` back to
    /// `verifying_key`.
    ///
    /// `msg` is hashed with the curve's default digest (`C::Digest`) and the
    /// digest is handed to [`RecoveryId::trial_recovery_from_digest`].
    ///
    /// # Errors
    ///
    /// Returns an error if no candidate recovery ID recovers `verifying_key`.
    pub fn trial_recovery_from_msg<C>(
        verifying_key: &VerifyingKey<C>,
        msg: &[u8],
        signature: &Signature<C>,
    ) -> Result<Self>
    where
        C: DigestPrimitive + PrimeCurve + CurveArithmetic,
        AffinePoint<C>:
            DecompressPoint<C> + FromEncodedPoint<C> + ToEncodedPoint<C> + VerifyPrimitive<C>,
        FieldBytesSize<C>: sec1::ModulusSize,
        SignatureSize<C>: ArrayLength<u8>,
    {
        let digest = C::Digest::new_with_prefix(msg);
        Self::trial_recovery_from_digest(verifying_key, digest, signature)
    }

    /// Find the recovery ID for a signature over an already-populated digest.
    ///
    /// The digest is finalized and the output is handed to
    /// [`RecoveryId::trial_recovery_from_prehash`].
    ///
    /// # Errors
    ///
    /// Returns an error if no candidate recovery ID recovers `verifying_key`.
    pub fn trial_recovery_from_digest<C, D>(
        verifying_key: &VerifyingKey<C>,
        digest: D,
        signature: &Signature<C>,
    ) -> Result<Self>
    where
        C: PrimeCurve + CurveArithmetic,
        D: Digest,
        AffinePoint<C>:
            DecompressPoint<C> + FromEncodedPoint<C> + ToEncodedPoint<C> + VerifyPrimitive<C>,
        FieldBytesSize<C>: sec1::ModulusSize,
        SignatureSize<C>: ArrayLength<u8>,
    {
        let prehash = digest.finalize();
        Self::trial_recovery_from_prehash(verifying_key, &prehash, signature)
    }

    /// Find the recovery ID for a signature over a message prehash.
    ///
    /// Each candidate ID from `0` to [`RecoveryId::MAX`] is tried in order.
    /// The first one for which `VerifyingKey::recover_from_prehash` succeeds
    /// and yields a key equal to `verifying_key` is returned. Candidates whose
    /// recovery fails are skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if no candidate recovery ID recovers `verifying_key`.
    pub fn trial_recovery_from_prehash<C>(
        verifying_key: &VerifyingKey<C>,
        prehash: &[u8],
        signature: &Signature<C>,
    ) -> Result<Self>
    where
        C: PrimeCurve + CurveArithmetic,
        AffinePoint<C>:
            DecompressPoint<C> + FromEncodedPoint<C> + ToEncodedPoint<C> + VerifyPrimitive<C>,
        FieldBytesSize<C>: sec1::ModulusSize,
        SignatureSize<C>: ArrayLength<u8>,
    {
        (0..=Self::MAX)
            .map(RecoveryId)
            .find(|&candidate| {
                match VerifyingKey::recover_from_prehash(prehash, signature, candidate) {
                    Ok(recovered) => &recovered == verifying_key,
                    Err(_) => false,
                }
            })
            .ok_or_else(Error::new)
    }
}
157
158impl TryFrom<u8> for RecoveryId {
159 type Error = Error;
160
161 fn try_from(byte: u8) -> Result<Self> {
162 Self::from_byte(byte).ok_or_else(Error::new)
163 }
164}
165
166impl From<RecoveryId> for u8 {
167 fn from(id: RecoveryId) -> u8 {
168 id.0
169 }
170}
171
#[cfg(feature = "signing")]
impl<C> SigningKey<C>
where
    C: PrimeCurve + CurveArithmetic + DigestPrimitive,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>> + SignPrimitive<C>,
    SignatureSize<C>: ArrayLength<u8>,
{
    /// Sign a message prehash, returning the signature together with its
    /// [`RecoveryId`].
    ///
    /// The prehash is converted with `bits2field`. Signing is delegated to
    /// `try_sign_prehashed_rfc6979`, parameterized by `C::Digest`, with an
    /// empty second argument.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the prehash cannot be converted by `bits2field`;
    /// - the underlying signing primitive fails;
    /// - the primitive does not produce a recovery ID.
    pub fn sign_prehash_recoverable(&self, prehash: &[u8]) -> Result<(Signature<C>, RecoveryId)> {
        let field_elem = bits2field::<C>(prehash)?;
        let secret = self.as_nonzero_scalar();
        let (signature, maybe_recovery_id) =
            secret.try_sign_prehashed_rfc6979::<C::Digest>(&field_elem, &[])?;
        let recovery_id = maybe_recovery_id.ok_or_else(Error::new)?;
        Ok((signature, recovery_id))
    }

    /// Finalize `msg_digest` and sign the resulting prehash, returning the
    /// signature and its [`RecoveryId`].
    ///
    /// # Errors
    ///
    /// Same as [`SigningKey::sign_prehash_recoverable`].
    pub fn sign_digest_recoverable<D>(&self, msg_digest: D) -> Result<(Signature<C>, RecoveryId)>
    where
        D: Digest,
    {
        let prehash = msg_digest.finalize();
        self.sign_prehash_recoverable(&prehash)
    }

    /// Hash `msg` with the curve's default digest (`C::Digest`) and sign it,
    /// returning the signature and its [`RecoveryId`].
    ///
    /// # Errors
    ///
    /// Same as [`SigningKey::sign_prehash_recoverable`].
    pub fn sign_recoverable(&self, msg: &[u8]) -> Result<(Signature<C>, RecoveryId)> {
        let digest = C::Digest::new_with_prefix(msg);
        self.sign_digest_recoverable(digest)
    }
}
203
/// Recoverable signing over a caller-supplied digest; forwards to
/// [`SigningKey::sign_digest_recoverable`].
#[cfg(feature = "signing")]
impl<C, D> DigestSigner<D, (Signature<C>, RecoveryId)> for SigningKey<C>
where
    C: PrimeCurve + CurveArithmetic + DigestPrimitive,
    D: Digest,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>> + SignPrimitive<C>,
    SignatureSize<C>: ArrayLength<u8>,
{
    fn try_sign_digest(&self, msg_digest: D) -> Result<(Signature<C>, RecoveryId)> {
        Self::sign_digest_recoverable(self, msg_digest)
    }
}
216
/// Recoverable signing over a message prehash; forwards to
/// [`SigningKey::sign_prehash_recoverable`].
#[cfg(feature = "signing")]
impl<C> PrehashSigner<(Signature<C>, RecoveryId)> for SigningKey<C>
where
    C: PrimeCurve + CurveArithmetic + DigestPrimitive,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>> + SignPrimitive<C>,
    SignatureSize<C>: ArrayLength<u8>,
{
    fn sign_prehash(&self, prehash: &[u8]) -> Result<(Signature<C>, RecoveryId)> {
        Self::sign_prehash_recoverable(self, prehash)
    }
}
228
/// Recoverable signing over a raw message; forwards to
/// [`SigningKey::sign_recoverable`].
#[cfg(feature = "signing")]
impl<C> Signer<(Signature<C>, RecoveryId)> for SigningKey<C>
where
    C: PrimeCurve + CurveArithmetic + DigestPrimitive,
    Scalar<C>: Invert<Output = CtOption<Scalar<C>>> + SignPrimitive<C>,
    SignatureSize<C>: ArrayLength<u8>,
{
    fn try_sign(&self, msg: &[u8]) -> Result<(Signature<C>, RecoveryId)> {
        Self::sign_recoverable(self, msg)
    }
}
240
#[cfg(feature = "verifying")]
impl<C> VerifyingKey<C>
where
    C: PrimeCurve + CurveArithmetic,
    AffinePoint<C>:
        DecompressPoint<C> + FromEncodedPoint<C> + ToEncodedPoint<C> + VerifyPrimitive<C>,
    FieldBytesSize<C>: sec1::ModulusSize,
    SignatureSize<C>: ArrayLength<u8>,
{
    /// Recover a [`VerifyingKey`] from a message, a signature and a
    /// [`RecoveryId`].
    ///
    /// `msg` is hashed with the curve's default digest (`C::Digest`) before
    /// recovery; see [`VerifyingKey::recover_from_prehash`].
    ///
    /// # Errors
    ///
    /// Returns an error if recovery fails or the recovered key does not
    /// verify `signature`.
    pub fn recover_from_msg(
        msg: &[u8],
        signature: &Signature<C>,
        recovery_id: RecoveryId,
    ) -> Result<Self>
    where
        C: DigestPrimitive,
    {
        Self::recover_from_digest(C::Digest::new_with_prefix(msg), signature, recovery_id)
    }

    /// Recover a [`VerifyingKey`] from an already-populated digest, a
    /// signature and a [`RecoveryId`].
    ///
    /// The digest is finalized and recovery proceeds as in
    /// [`VerifyingKey::recover_from_prehash`].
    ///
    /// # Errors
    ///
    /// Returns an error if recovery fails or the recovered key does not
    /// verify `signature`.
    pub fn recover_from_digest<D>(
        msg_digest: D,
        signature: &Signature<C>,
        recovery_id: RecoveryId,
    ) -> Result<Self>
    where
        D: Digest,
    {
        Self::recover_from_prehash(&msg_digest.finalize(), signature, recovery_id)
    }

    /// Recover a [`VerifyingKey`] from a message prehash, a signature and a
    /// [`RecoveryId`].
    ///
    /// Steps:
    /// 1. Reconstruct the x-coordinate of the point `R` from the signature's
    ///    `r`. When [`RecoveryId::is_x_reduced`] is set, the curve order is
    ///    added back to `r`.
    /// 2. Decompress `R`, using [`RecoveryId::is_y_odd`] as the y-parity.
    /// 3. With `z` the prehash reduced to a scalar, compute the candidate key
    ///    `u1·G + u2·R`, where `u1 = -(z·r⁻¹)` and `u2 = s·r⁻¹`.
    /// 4. Accept the candidate only if it verifies `signature` over `prehash`.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following fail:
    /// - converting the prehash;
    /// - restoring or decompressing `R`;
    /// - building a key from the resulting point;
    /// - verifying the signature with that key.
    // `R` mirrors the conventional notation for the signature's ephemeral point.
    #[allow(non_snake_case)]
    pub fn recover_from_prehash(
        prehash: &[u8],
        signature: &Signature<C>,
        recovery_id: RecoveryId,
    ) -> Result<Self> {
        let (r, s) = signature.split_scalars();
        let z = <Scalar<C> as Reduce<C::Uint>>::reduce_bytes(&bits2field::<C>(prehash)?);

        // x-coordinate of R: `r` itself, or `r + n` when the recovery ID says
        // the x-coordinate was reduced modulo the curve order.
        let r_bytes = if recovery_id.is_x_reduced() {
            // `checked_add` failing means `r + n` does not fit in `C::Uint`, so
            // no reduction could have produced this `r`.
            Option::<C::Uint>::from(
                C::Uint::decode_field_bytes(&r.to_repr()).checked_add(&C::ORDER),
            )
            .ok_or_else(Error::new)?
            .encode_field_bytes()
        } else {
            r.to_repr()
        };

        // Decompression fails (CtOption is none) when no curve point has this x.
        let R = AffinePoint::<C>::decompress(&r_bytes, u8::from(recovery_id.is_y_odd()).into());
        let R = Option::<AffinePoint<C>>::from(R).ok_or_else(Error::new)?;
        let R = ProjectivePoint::<C>::from(R);

        let r_inv = *r.invert();
        let u1 = -(r_inv * z);
        let u2 = r_inv * *s;
        let pk = ProjectivePoint::<C>::lincomb(&ProjectivePoint::<C>::generator(), &u1, &R, &u2);
        let vk = Self::from_affine(pk.into())?;

        // Only return a key that actually verifies the signature it was recovered from.
        vk.verify_prehash(prehash, signature)?;

        Ok(vk)
    }
}
318
#[cfg(test)]
mod tests {
    use super::RecoveryId;

    #[test]
    fn new() {
        // ((is_y_odd, is_x_reduced), expected byte)
        let cases = [
            ((false, false), 0u8),
            ((true, false), 1),
            ((false, true), 2),
            ((true, true), 3),
        ];
        for ((y_odd, x_reduced), byte) in cases {
            assert_eq!(RecoveryId::new(y_odd, x_reduced).to_byte(), byte);
        }
    }

    #[test]
    fn try_from() {
        for byte in 0..=u8::MAX {
            let parsed = RecoveryId::try_from(byte);
            if byte <= 3 {
                assert_eq!(parsed.unwrap().to_byte(), byte);
            } else {
                assert!(parsed.is_err());
            }
        }
    }

    #[test]
    fn is_x_reduced() {
        for (byte, expected) in [(0u8, false), (1, false), (2, true), (3, true)] {
            assert_eq!(RecoveryId::try_from(byte).unwrap().is_x_reduced(), expected);
        }
    }

    #[test]
    fn is_y_odd() {
        for (byte, expected) in [(0u8, false), (1, true), (2, false), (3, true)] {
            assert_eq!(RecoveryId::try_from(byte).unwrap().is_y_odd(), expected);
        }
    }
}