1use crate::{FixedOutput, FixedOutputReset, Update};
2use crypto_common::{InvalidLength, Key, KeyInit, Output, OutputSizeUser, Reset};
3
4#[cfg(feature = "rand_core")]
5use crate::rand_core::{CryptoRng, RngCore};
6use core::fmt;
7use crypto_common::typenum::Unsigned;
8use subtle::{Choice, ConstantTimeEq};
9
/// Marker trait for Message Authentication Code (MAC) algorithms.
///
/// Any type that implements [`Update`], [`FixedOutput`] and this trait
/// automatically receives the blanket [`Mac`] implementation in this module.
#[cfg_attr(docsrs, doc(cfg(feature = "mac")))]
pub trait MacMarker {}
13
/// Convenience interface for keyed Message Authentication Codes (MACs).
///
/// A blanket implementation is provided in this module for every type that
/// implements [`Update`], [`FixedOutput`] and [`MacMarker`].
#[cfg_attr(docsrs, doc(cfg(feature = "mac")))]
pub trait Mac: OutputSizeUser + Sized {
    /// Create a new MAC instance from a fixed-size key.
    fn new(key: &Key<Self>) -> Self
    where
        Self: KeyInit;

    /// Generate a random key for this MAC using the provided RNG.
    #[cfg(feature = "rand_core")]
    #[cfg_attr(docsrs, doc(cfg(feature = "rand_core")))]
    fn generate_key(rng: impl CryptoRng + RngCore) -> Key<Self>
    where
        Self: KeyInit;

    /// Create a new MAC instance from a key given as a byte slice.
    ///
    /// # Errors
    /// Returns [`InvalidLength`] when the key length is not accepted by the
    /// underlying [`KeyInit`] implementation.
    fn new_from_slice(key: &[u8]) -> Result<Self, InvalidLength>
    where
        Self: KeyInit;

    /// Feed `data` into the MAC state.
    fn update(&mut self, data: &[u8]);

    /// Feed `data` into the MAC state and return the instance, allowing
    /// builder-style chaining of several inputs.
    #[must_use]
    fn chain_update(self, data: impl AsRef<[u8]>) -> Self;

    /// Consume the instance and return the computed tag wrapped in
    /// [`CtOutput`], whose `==` compares in constant time.
    fn finalize(self) -> CtOutput<Self>;

    /// Return the computed tag wrapped in [`CtOutput`] and reset the instance
    /// (see [`FixedOutputReset`]) so it can be reused.
    fn finalize_reset(&mut self) -> CtOutput<Self>
    where
        Self: FixedOutputReset;

    /// Reset the instance via its [`Reset`] implementation.
    fn reset(&mut self)
    where
        Self: Reset;

    /// Check that the computed tag equals `tag`, consuming the instance.
    ///
    /// # Errors
    /// Returns [`MacError`] if the tags differ.
    fn verify(self, tag: &Output<Self>) -> Result<(), MacError>;

    /// Check that the computed tag equals `tag`, then reset the instance.
    ///
    /// # Errors
    /// Returns [`MacError`] if the tags differ.
    fn verify_reset(&mut self, tag: &Output<Self>) -> Result<(), MacError>
    where
        Self: FixedOutputReset;

    /// Check the computed tag against a tag given as a byte slice.
    ///
    /// # Errors
    /// Returns [`MacError`] if `tag.len()` differs from the output size or
    /// if the bytes differ.
    fn verify_slice(self, tag: &[u8]) -> Result<(), MacError>;

    /// Same as [`Mac::verify_slice`], but resets the instance afterwards.
    ///
    /// # Errors
    /// Returns [`MacError`] if `tag.len()` differs from the output size or
    /// if the bytes differ.
    fn verify_slice_reset(&mut self, tag: &[u8]) -> Result<(), MacError>
    where
        Self: FixedOutputReset;

    /// Check a truncated tag against the leftmost `tag.len()` bytes of the
    /// computed tag.
    ///
    /// # Errors
    /// Returns [`MacError`] if `tag` is empty, longer than the output size,
    /// or does not match.
    fn verify_truncated_left(self, tag: &[u8]) -> Result<(), MacError>;

    /// Check a truncated tag against the rightmost `tag.len()` bytes of the
    /// computed tag.
    ///
    /// # Errors
    /// Returns [`MacError`] if `tag` is empty, longer than the output size,
    /// or does not match.
    fn verify_truncated_right(self, tag: &[u8]) -> Result<(), MacError>;
}
96
97impl<T: Update + FixedOutput + MacMarker> Mac for T {
98    #[inline(always)]
99    fn new(key: &Key<Self>) -> Self
100    where
101        Self: KeyInit,
102    {
103        KeyInit::new(key)
104    }
105
106    #[inline(always)]
107    fn new_from_slice(key: &[u8]) -> Result<Self, InvalidLength>
108    where
109        Self: KeyInit,
110    {
111        KeyInit::new_from_slice(key)
112    }
113
114    #[inline]
115    fn update(&mut self, data: &[u8]) {
116        Update::update(self, data);
117    }
118
119    #[inline]
120    fn chain_update(mut self, data: impl AsRef<[u8]>) -> Self {
121        Update::update(&mut self, data.as_ref());
122        self
123    }
124
125    #[inline]
126    fn finalize(self) -> CtOutput<Self> {
127        CtOutput::new(self.finalize_fixed())
128    }
129
130    #[inline(always)]
131    fn finalize_reset(&mut self) -> CtOutput<Self>
132    where
133        Self: FixedOutputReset,
134    {
135        CtOutput::new(self.finalize_fixed_reset())
136    }
137
138    #[inline]
139    fn reset(&mut self)
140    where
141        Self: Reset,
142    {
143        Reset::reset(self)
144    }
145
146    #[inline]
147    fn verify(self, tag: &Output<Self>) -> Result<(), MacError> {
148        if self.finalize() == tag.into() {
149            Ok(())
150        } else {
151            Err(MacError)
152        }
153    }
154
155    #[inline]
156    fn verify_reset(&mut self, tag: &Output<Self>) -> Result<(), MacError>
157    where
158        Self: FixedOutputReset,
159    {
160        if self.finalize_reset() == tag.into() {
161            Ok(())
162        } else {
163            Err(MacError)
164        }
165    }
166
167    #[inline]
168    fn verify_slice(self, tag: &[u8]) -> Result<(), MacError> {
169        let n = tag.len();
170        if n != Self::OutputSize::USIZE {
171            return Err(MacError);
172        }
173        let choice = self.finalize_fixed().ct_eq(tag);
174        if choice.into() {
175            Ok(())
176        } else {
177            Err(MacError)
178        }
179    }
180
181    #[inline]
182    fn verify_slice_reset(&mut self, tag: &[u8]) -> Result<(), MacError>
183    where
184        Self: FixedOutputReset,
185    {
186        let n = tag.len();
187        if n != Self::OutputSize::USIZE {
188            return Err(MacError);
189        }
190        let choice = self.finalize_fixed_reset().ct_eq(tag);
191        if choice.into() {
192            Ok(())
193        } else {
194            Err(MacError)
195        }
196    }
197
198    fn verify_truncated_left(self, tag: &[u8]) -> Result<(), MacError> {
199        let n = tag.len();
200        if n == 0 || n > Self::OutputSize::USIZE {
201            return Err(MacError);
202        }
203        let choice = self.finalize_fixed()[..n].ct_eq(tag);
204
205        if choice.into() {
206            Ok(())
207        } else {
208            Err(MacError)
209        }
210    }
211
212    fn verify_truncated_right(self, tag: &[u8]) -> Result<(), MacError> {
213        let n = tag.len();
214        if n == 0 || n > Self::OutputSize::USIZE {
215            return Err(MacError);
216        }
217        let m = Self::OutputSize::USIZE - n;
218        let choice = self.finalize_fixed()[m..].ct_eq(tag);
219
220        if choice.into() {
221            Ok(())
222        } else {
223            Err(MacError)
224        }
225    }
226
227    #[cfg(feature = "rand_core")]
228    #[cfg_attr(docsrs, doc(cfg(feature = "rand_core")))]
229    #[inline]
230    fn generate_key(rng: impl CryptoRng + RngCore) -> Key<Self>
231    where
232        Self: KeyInit,
233    {
234        <T as KeyInit>::generate_key(rng)
235    }
236}
237
238#[derive(Clone)]
243#[cfg_attr(docsrs, doc(cfg(feature = "mac")))]
244pub struct CtOutput<T: OutputSizeUser> {
245    bytes: Output<T>,
246}
247
248impl<T: OutputSizeUser> CtOutput<T> {
249    #[inline(always)]
251    pub fn new(bytes: Output<T>) -> Self {
252        Self { bytes }
253    }
254
255    #[inline(always)]
257    pub fn into_bytes(self) -> Output<T> {
258        self.bytes
259    }
260}
261
262impl<T: OutputSizeUser> From<Output<T>> for CtOutput<T> {
263    #[inline(always)]
264    fn from(bytes: Output<T>) -> Self {
265        Self { bytes }
266    }
267}
268
269impl<'a, T: OutputSizeUser> From<&'a Output<T>> for CtOutput<T> {
270    #[inline(always)]
271    fn from(bytes: &'a Output<T>) -> Self {
272        bytes.clone().into()
273    }
274}
275
276impl<T: OutputSizeUser> ConstantTimeEq for CtOutput<T> {
277    #[inline(always)]
278    fn ct_eq(&self, other: &Self) -> Choice {
279        self.bytes.ct_eq(&other.bytes)
280    }
281}
282
283impl<T: OutputSizeUser> PartialEq for CtOutput<T> {
284    #[inline(always)]
285    fn eq(&self, x: &CtOutput<T>) -> bool {
286        self.ct_eq(x).into()
287    }
288}
289
290impl<T: OutputSizeUser> Eq for CtOutput<T> {}
291
/// Error returned by the `verify*` methods of [`Mac`] when a tag does not
/// match or has an unacceptable length.
///
/// It deliberately carries no further detail about the failure.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(docsrs, doc(cfg(feature = "mac")))]
pub struct MacError;

impl fmt::Display for MacError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "MAC tag mismatch")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for MacError {}