1use crate::{Error, Result};
7use core::{
8    fmt::{self, Debug},
9    ops::{Add, Range},
10};
11use der::{asn1::UintRef, Decode, Encode, FixedTag, Length, Reader, Tag, Writer};
12use elliptic_curve::{
13    consts::U9,
14    generic_array::{typenum::Unsigned, ArrayLength, GenericArray},
15    FieldBytesSize, PrimeCurve,
16};
17
18#[cfg(feature = "alloc")]
19use {
20    alloc::{boxed::Box, vec::Vec},
21    signature::SignatureEncoding,
22    spki::{der::asn1::BitString, SignatureBitStringEncoding},
23};
24
25#[cfg(feature = "serde")]
26use serdect::serde::{de, ser, Deserialize, Serialize};
27
/// Maximum number of bytes an ASN.1 DER encoding of an ECDSA signature adds
/// on top of the raw `r` and `s` scalar bytes.
///
/// NOTE(review): 9 presumably budgets for the outer `SEQUENCE` tag plus a
/// length of up to 2 bytes, and for each of the two `INTEGER`s a tag, a
/// 1-byte length and a possible leading zero byte — confirm this still holds
/// for the largest supported field size.
pub type MaxOverhead = U9;

/// Maximum size of a DER-encoded signature for curve `C`: two field-sized
/// scalars (`FieldBytesSize<C> + FieldBytesSize<C>`) plus [`MaxOverhead`].
pub type MaxSize<C> = <<FieldBytesSize<C> as Add>::Output as Add<MaxOverhead>>::Output;

/// Inline byte buffer of [`MaxSize`] bytes used to hold a DER signature.
type SignatureBytes<C> = GenericArray<u8, MaxSize<C>>;
49
/// ASN.1 DER-encoded ECDSA signature for curve `C`.
///
/// The encoding lives in a fixed-size inline buffer. The byte ranges of the
/// `r` and `s` integer values inside that buffer are stored next to it, so
/// both scalars can be borrowed without re-parsing.
pub struct Signature<C>
where
    C: PrimeCurve,
    MaxSize<C>: ArrayLength<u8>,
    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
{
    /// DER bytes. Only the first `s_range.end` bytes are meaningful; the rest
    /// of the buffer stays zeroed.
    bytes: SignatureBytes<C>,

    /// Range within `bytes` of the `r` INTEGER's value bytes (as returned by
    /// `UintRef::as_bytes`).
    r_range: Range<usize>,

    /// Range within `bytes` of the `s` INTEGER's value bytes. Its end is also
    /// the total encoded length.
    s_range: Range<usize>,
}
75
76#[allow(clippy::len_without_is_empty)]
77impl<C> Signature<C>
78where
79    C: PrimeCurve,
80    MaxSize<C>: ArrayLength<u8>,
81    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
82{
83    pub fn from_bytes(input: &[u8]) -> Result<Self> {
85        let (r, s) = decode_der(input).map_err(|_| Error::new())?;
86
87        if r.as_bytes().len() > C::FieldBytesSize::USIZE
88            || s.as_bytes().len() > C::FieldBytesSize::USIZE
89        {
90            return Err(Error::new());
91        }
92
93        let r_range = find_scalar_range(input, r.as_bytes())?;
94        let s_range = find_scalar_range(input, s.as_bytes())?;
95
96        if s_range.end != input.len() {
97            return Err(Error::new());
98        }
99
100        let mut bytes = SignatureBytes::<C>::default();
101        bytes[..s_range.end].copy_from_slice(input);
102
103        Ok(Signature {
104            bytes,
105            r_range,
106            s_range,
107        })
108    }
109
110    pub(crate) fn from_components(r: &[u8], s: &[u8]) -> der::Result<Self> {
113        let r = UintRef::new(r)?;
114        let s = UintRef::new(s)?;
115
116        let mut bytes = SignatureBytes::<C>::default();
117        let mut writer = der::SliceWriter::new(&mut bytes);
118
119        writer.sequence((r.encoded_len()? + s.encoded_len()?)?, |seq| {
120            seq.encode(&r)?;
121            seq.encode(&s)
122        })?;
123
124        writer
125            .finish()?
126            .try_into()
127            .map_err(|_| der::Tag::Sequence.value_error())
128    }
129
130    pub fn as_bytes(&self) -> &[u8] {
132        &self.bytes.as_slice()[..self.len()]
133    }
134
135    #[cfg(feature = "alloc")]
137    pub fn to_bytes(&self) -> Box<[u8]> {
138        self.as_bytes().to_vec().into_boxed_slice()
139    }
140
141    pub fn len(&self) -> usize {
143        self.s_range.end
144    }
145
146    pub(crate) fn r(&self) -> &[u8] {
148        &self.bytes[self.r_range.clone()]
149    }
150
151    pub(crate) fn s(&self) -> &[u8] {
153        &self.bytes[self.s_range.clone()]
154    }
155}
156
157impl<C> AsRef<[u8]> for Signature<C>
158where
159    C: PrimeCurve,
160    MaxSize<C>: ArrayLength<u8>,
161    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
162{
163    fn as_ref(&self) -> &[u8] {
164        self.as_bytes()
165    }
166}
167
168impl<C> Clone for Signature<C>
169where
170    C: PrimeCurve,
171    MaxSize<C>: ArrayLength<u8>,
172    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
173{
174    fn clone(&self) -> Self {
175        Self {
176            bytes: self.bytes.clone(),
177            r_range: self.r_range.clone(),
178            s_range: self.s_range.clone(),
179        }
180    }
181}
182
183impl<C> Debug for Signature<C>
184where
185    C: PrimeCurve,
186    MaxSize<C>: ArrayLength<u8>,
187    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
188{
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        write!(f, "ecdsa::der::Signature<{:?}>(", C::default())?;
191
192        for &byte in self.as_ref() {
193            write!(f, "{:02X}", byte)?;
194        }
195
196        write!(f, ")")
197    }
198}
199
200impl<'a, C> Decode<'a> for Signature<C>
201where
202    C: PrimeCurve,
203    MaxSize<C>: ArrayLength<u8>,
204    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
205{
206    fn decode<R: Reader<'a>>(reader: &mut R) -> der::Result<Self> {
207        let header = reader.peek_header()?;
208        header.tag.assert_eq(Tag::Sequence)?;
209
210        let mut buf = SignatureBytes::<C>::default();
211        let len = (header.encoded_len()? + header.length)?;
212        let slice = buf
213            .get_mut(..usize::try_from(len)?)
214            .ok_or_else(|| reader.error(Tag::Sequence.length_error().kind()))?;
215
216        reader.read_into(slice)?;
217        Self::from_bytes(slice).map_err(|_| Tag::Integer.value_error())
218    }
219}
220
221impl<C> Encode for Signature<C>
222where
223    C: PrimeCurve,
224    MaxSize<C>: ArrayLength<u8>,
225    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
226{
227    fn encoded_len(&self) -> der::Result<Length> {
228        Length::try_from(self.len())
229    }
230
231    fn encode(&self, writer: &mut impl Writer) -> der::Result<()> {
232        writer.write(self.as_bytes())
233    }
234}
235
236impl<C> FixedTag for Signature<C>
237where
238    C: PrimeCurve,
239    MaxSize<C>: ArrayLength<u8>,
240    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
241{
242    const TAG: Tag = Tag::Sequence;
243}
244
245impl<C> From<crate::Signature<C>> for Signature<C>
246where
247    C: PrimeCurve,
248    MaxSize<C>: ArrayLength<u8>,
249    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
250{
251    fn from(sig: crate::Signature<C>) -> Signature<C> {
252        sig.to_der()
253    }
254}
255
256impl<C> TryFrom<&[u8]> for Signature<C>
257where
258    C: PrimeCurve,
259    MaxSize<C>: ArrayLength<u8>,
260    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
261{
262    type Error = Error;
263
264    fn try_from(input: &[u8]) -> Result<Self> {
265        Self::from_bytes(input)
266    }
267}
268
269impl<C> TryFrom<Signature<C>> for crate::Signature<C>
270where
271    C: PrimeCurve,
272    MaxSize<C>: ArrayLength<u8>,
273    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
274{
275    type Error = Error;
276
277    fn try_from(sig: Signature<C>) -> Result<super::Signature<C>> {
278        let mut bytes = super::SignatureBytes::<C>::default();
279        let r_begin = C::FieldBytesSize::USIZE.saturating_sub(sig.r().len());
280        let s_begin = bytes.len().saturating_sub(sig.s().len());
281        bytes[r_begin..C::FieldBytesSize::USIZE].copy_from_slice(sig.r());
282        bytes[s_begin..].copy_from_slice(sig.s());
283        Self::try_from(bytes.as_slice())
284    }
285}
286
#[cfg(feature = "alloc")]
impl<C> From<Signature<C>> for Box<[u8]>
where
    C: PrimeCurve,
    MaxSize<C>: ArrayLength<u8>,
    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
{
    /// Copy the DER bytes (without padding) into a new boxed slice.
    fn from(signature: Signature<C>) -> Box<[u8]> {
        Box::from(signature.as_bytes())
    }
}
298
#[cfg(feature = "alloc")]
impl<C> SignatureEncoding for Signature<C>
where
    C: PrimeCurve,
    MaxSize<C>: ArrayLength<u8>,
    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
{
    type Repr = Box<[u8]>;

    /// Copy the DER bytes (without padding) into a new `Vec`.
    fn to_vec(&self) -> Vec<u8> {
        Vec::from(self.as_bytes())
    }
}
312
#[cfg(feature = "alloc")]
impl<C> SignatureBitStringEncoding for Signature<C>
where
    C: PrimeCurve,
    MaxSize<C>: ArrayLength<u8>,
    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
{
    /// Wrap the DER bytes in a `BIT STRING` with zero unused bits.
    fn to_bitstring(&self) -> der::Result<BitString> {
        let der_bytes = self.as_bytes().to_vec();
        BitString::new(0, der_bytes)
    }
}
324
#[cfg(feature = "serde")]
impl<C> Serialize for Signature<C>
where
    C: PrimeCurve,
    MaxSize<C>: ArrayLength<u8>,
    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
{
    /// Serialize the DER bytes using `serdect`: upper-case hex for
    /// human-readable formats, raw bytes for binary formats.
    fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let der_bytes: &[u8] = self.as_ref();
        serdect::slice::serialize_hex_upper_or_bin(&der_bytes, serializer)
    }
}
339
#[cfg(feature = "serde")]
impl<'de, C> Deserialize<'de> for Signature<C>
where
    C: PrimeCurve,
    MaxSize<C>: ArrayLength<u8>,
    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
{
    /// Read hex or binary input into a [`MaxSize`] buffer with `serdect`,
    /// then parse the decoded bytes as DER.
    fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
    where
        D: de::Deserializer<'de>,
    {
        let mut scratch = SignatureBytes::<C>::default();
        let der_bytes = serdect::slice::deserialize_hex_or_bin(&mut scratch, deserializer)?;
        Signature::<C>::try_from(der_bytes).map_err(de::Error::custom)
    }
}
356
357fn decode_der(der_bytes: &[u8]) -> der::Result<(UintRef<'_>, UintRef<'_>)> {
359    let mut reader = der::SliceReader::new(der_bytes)?;
360    let header = der::Header::decode(&mut reader)?;
361    header.tag.assert_eq(der::Tag::Sequence)?;
362
363    let ret = reader.read_nested(header.length, |reader| {
364        let r = UintRef::decode(reader)?;
365        let s = UintRef::decode(reader)?;
366        Ok((r, s))
367    })?;
368
369    reader.finish(ret)
370}
371
372fn find_scalar_range(outer: &[u8], inner: &[u8]) -> Result<Range<usize>> {
374    let outer_start = outer.as_ptr() as usize;
375    let inner_start = inner.as_ptr() as usize;
376    let start = inner_start
377        .checked_sub(outer_start)
378        .ok_or_else(Error::new)?;
379    let end = start.checked_add(inner.len()).ok_or_else(Error::new)?;
380    Ok(Range { start, end })
381}
382
#[cfg(all(feature = "digest", feature = "hazmat"))]
impl<C> signature::PrehashSignature for Signature<C>
where
    C: PrimeCurve + crate::hazmat::DigestPrimitive,
    MaxSize<C>: ArrayLength<u8>,
    <FieldBytesSize<C> as Add>::Output: Add<MaxOverhead> + ArrayLength<u8>,
{
    /// The digest associated with curve `C` through `DigestPrimitive`.
    type Digest = <C as crate::hazmat::DigestPrimitive>::Digest;
}
392
#[cfg(all(test, feature = "arithmetic"))]
mod tests {
    use elliptic_curve::dev::MockCurve;

    // Fixed-size (non-DER) signature type for the mock curve.
    type Signature = crate::Signature<MockCurve>;

    // 64-byte fixed-size signature fixture. Presumably a 32-byte `r` followed
    // by a 32-byte `s` for `MockCurve` — TODO confirm.
    const EXAMPLE_SIGNATURE: [u8; 64] = [
        0xf3, 0xac, 0x80, 0x61, 0xb5, 0x14, 0x79, 0x5b, 0x88, 0x43, 0xe3, 0xd6, 0x62, 0x95, 0x27,
        0xed, 0x2a, 0xfd, 0x6b, 0x1f, 0x6a, 0x55, 0x5a, 0x7a, 0xca, 0xbb, 0x5e, 0x6f, 0x79, 0xc8,
        0xc2, 0xac, 0x8b, 0xf7, 0x78, 0x19, 0xca, 0x5, 0xa6, 0xb2, 0x78, 0x6c, 0x76, 0x26, 0x2b,
        0xf7, 0x37, 0x1c, 0xef, 0x97, 0xb2, 0x18, 0xe9, 0x6f, 0x17, 0x5a, 0x3c, 0xcd, 0xda, 0x2a,
        0xcc, 0x5, 0x89, 0x3,
    ];

    // Fixed -> DER -> fixed must reproduce the original signature.
    #[test]
    fn test_fixed_to_asn1_signature_roundtrip() {
        let signature1 = Signature::try_from(EXAMPLE_SIGNATURE.as_ref()).unwrap();

        let asn1_signature = signature1.to_der();
        let signature2 = Signature::from_der(asn1_signature.as_ref()).unwrap();

        assert_eq!(signature1, signature2);
    }

    // Truncated encodings must be rejected: empty input, a lone SEQUENCE tag,
    // an empty SEQUENCE, and a SEQUENCE holding only a single INTEGER.
    #[test]
    fn test_asn1_too_short_signature() {
        assert!(Signature::from_der(&[]).is_err());
        assert!(Signature::from_der(&[der::Tag::Sequence.into()]).is_err());
        assert!(Signature::from_der(&[der::Tag::Sequence.into(), 0x00]).is_err());
        assert!(Signature::from_der(&[
            der::Tag::Sequence.into(),
            0x03, // SEQUENCE length
            der::Tag::Integer.into(),
            0x01, // INTEGER length
            0x01  // INTEGER value = 1 (no second INTEGER follows)
        ])
        .is_err());
    }

    #[test]
    fn test_asn1_non_der_signature() {
        // A minimal well-formed DER signature (r = 1, s = 1) parses OK.
        assert!(Signature::from_der(&[
            der::Tag::Sequence.into(),
            0x06, der::Tag::Integer.into(), // SEQUENCE length 6; `r` tag
            0x01, 0x01, der::Tag::Integer.into(), // `r`: length 1, value 1; `s` tag
            0x01, 0x01, ]) // `s`: length 1, value 1
        .is_ok());

        // The same content with a long-form length (0x81 0x06) is valid BER
        // but not minimally encoded, so it is not DER and must be rejected.
        assert!(Signature::from_der(&[
            der::Tag::Sequence.into(),
            0x81, 0x06, der::Tag::Integer.into(), // non-minimal length; `r` tag
            0x01, 0x01, der::Tag::Integer.into(), // `r`: length 1, value 1; `s` tag
            0x01, 0x01, ]) // `s`: length 1, value 1
        .is_err());
    }
}