1#![no_std]
3#![doc(
4    html_logo_url = "https://raw.githubusercontent.com/RustCrypto/media/6ee8e381/logo.svg",
5    html_favicon_url = "https://raw.githubusercontent.com/RustCrypto/media/6ee8e381/logo.svg"
6)]
7#![warn(missing_docs, rust_2018_idioms)]
8
9pub use generic_array;
10
11use core::{fmt, marker::PhantomData, slice};
12use generic_array::{
13    typenum::{IsLess, Le, NonZero, U256},
14    ArrayLength, GenericArray,
15};
16
17mod sealed;
18
/// A single block: a `GenericArray` of `u8` whose length is the type-level `BlockSize`.
pub type Block<BlockSize> = GenericArray<u8, BlockSize>;
21
/// Marker trait for the buffer kinds usable with [`BlockBuffer`].
///
/// The trait has no items of its own; everything callable on a kind comes from the
/// `sealed::Sealed` supertrait. Because `sealed` is a private module, the trait cannot be
/// implemented outside of this crate.
pub trait BufferKind: sealed::Sealed {}
24
/// Buffer kind marker used by [`EagerBuffer`].
///
/// NOTE(review): the behavior attached to this kind is defined in the private `sealed`
/// module, which is not visible here. Presumably a block is processed as soon as it is
/// completely filled, so the cursor stays strictly below the block size — confirm.
#[derive(Copy, Clone, Debug, Default)]
pub struct Eager {}
29
/// Buffer kind marker used by [`LazyBuffer`].
///
/// NOTE(review): the behavior attached to this kind is defined in the private `sealed`
/// module, which is not visible here. Presumably a completely filled block may stay in the
/// buffer until more input arrives — confirm.
#[derive(Copy, Clone, Debug, Default)]
pub struct Lazy {}
34
// `BufferKind` declares no items, so opting a marker type in is an empty impl. The actual
// per-kind behavior is supplied through the `sealed::Sealed` supertrait.
impl BufferKind for Eager {}
impl BufferKind for Lazy {}
37
/// [`BlockBuffer`] with block size `B` and the [`Eager`] kind.
pub type EagerBuffer<B> = BlockBuffer<B, Eager>;
/// [`BlockBuffer`] with block size `B` and the [`Lazy`] kind.
pub type LazyBuffer<B> = BlockBuffer<B, Lazy>;
42
/// Opaque error of this crate's fallible buffer operations.
///
/// It carries no payload; its `Display` output is always the same fixed message.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct Error;

impl fmt::Display for Error {
    /// Writes the fixed message, ignoring any width/precision/fill flags of the formatter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = "Block buffer error";
        f.write_str(message)
    }
}
52
/// Buffer holding one block of `BlockSize` bytes together with a cursor into it.
///
/// The bytes before the cursor are the currently buffered data (this is what `get_data`
/// returns). `Kind` selects the buffering behavior, and the `IsLess<U256>` / `NonZero` bounds
/// restrict `BlockSize` to values below 256 so that the cursor can be stored in a `u8`.
#[derive(Debug)]
pub struct BlockBuffer<BlockSize, Kind>
where
    BlockSize: ArrayLength<u8> + IsLess<U256>,
    Le<BlockSize, U256>: NonZero,
    Kind: BufferKind,
{
    // Storage for the block currently being accumulated.
    buffer: Block<BlockSize>,
    // Cursor into `buffer`; always read through `get_pos`, which re-checks `Kind::invariant`.
    pos: u8,
    // Zero-sized marker that ties the buffer to its `Kind` without storing a value of it.
    _pd: PhantomData<Kind>,
}
65
66impl<BlockSize, Kind> Default for BlockBuffer<BlockSize, Kind>
67where
68    BlockSize: ArrayLength<u8> + IsLess<U256>,
69    Le<BlockSize, U256>: NonZero,
70    Kind: BufferKind,
71{
72    fn default() -> Self {
73        if BlockSize::USIZE == 0 {
74            panic!("Block size can not be equal to zero");
75        }
76        Self {
77            buffer: Default::default(),
78            pos: 0,
79            _pd: PhantomData,
80        }
81    }
82}
83
84impl<BlockSize, Kind> Clone for BlockBuffer<BlockSize, Kind>
85where
86    BlockSize: ArrayLength<u8> + IsLess<U256>,
87    Le<BlockSize, U256>: NonZero,
88    Kind: BufferKind,
89{
90    fn clone(&self) -> Self {
91        Self {
92            buffer: self.buffer.clone(),
93            pos: self.pos,
94            _pd: PhantomData,
95        }
96    }
97}
98
99impl<BlockSize, Kind> BlockBuffer<BlockSize, Kind>
100where
101    BlockSize: ArrayLength<u8> + IsLess<U256>,
102    Le<BlockSize, U256>: NonZero,
103    Kind: BufferKind,
104{
105    #[inline(always)]
110    pub fn new(buf: &[u8]) -> Self {
111        Self::try_new(buf).unwrap()
112    }
113
114    #[inline(always)]
118    pub fn try_new(buf: &[u8]) -> Result<Self, Error> {
119        if BlockSize::USIZE == 0 {
120            panic!("Block size can not be equal to zero");
121        }
122        let pos = buf.len();
123        if !Kind::invariant(pos, BlockSize::USIZE) {
124            return Err(Error);
125        }
126        let mut buffer = Block::<BlockSize>::default();
127        buffer[..pos].copy_from_slice(buf);
128        Ok(Self {
129            buffer,
130            pos: pos as u8,
131            _pd: PhantomData,
132        })
133    }
134
135    #[inline]
138    pub fn digest_blocks(
139        &mut self,
140        mut input: &[u8],
141        mut compress: impl FnMut(&[Block<BlockSize>]),
142    ) {
143        let pos = self.get_pos();
144        let rem = self.size() - pos;
147        let n = input.len();
148        if Kind::invariant(n, rem) {
156            self.buffer[pos..][..n].copy_from_slice(input);
158            self.set_pos_unchecked(pos + n);
159            return;
160        }
161        if pos != 0 {
162            let (left, right) = input.split_at(rem);
163            input = right;
164            self.buffer[pos..].copy_from_slice(left);
165            compress(slice::from_ref(&self.buffer));
166        }
167
168        let (blocks, leftover) = Kind::split_blocks(input);
169        if !blocks.is_empty() {
170            compress(blocks);
171        }
172
173        let n = leftover.len();
174        self.buffer[..n].copy_from_slice(leftover);
175        self.set_pos_unchecked(n);
176    }
177
178    #[inline(always)]
180    pub fn reset(&mut self) {
181        self.set_pos_unchecked(0);
182    }
183
184    #[inline(always)]
186    pub fn pad_with_zeros(&mut self) -> &mut Block<BlockSize> {
187        let pos = self.get_pos();
188        self.buffer[pos..].iter_mut().for_each(|b| *b = 0);
189        self.set_pos_unchecked(0);
190        &mut self.buffer
191    }
192
193    #[inline(always)]
195    pub fn get_pos(&self) -> usize {
196        let pos = self.pos as usize;
197        if !Kind::invariant(pos, BlockSize::USIZE) {
198            debug_assert!(false);
199            unsafe {
201                core::hint::unreachable_unchecked();
202            }
203        }
204        pos
205    }
206
207    #[inline(always)]
209    pub fn get_data(&self) -> &[u8] {
210        &self.buffer[..self.get_pos()]
211    }
212
213    #[inline]
218    pub fn set(&mut self, buf: Block<BlockSize>, pos: usize) {
219        assert!(Kind::invariant(pos, BlockSize::USIZE));
220        self.buffer = buf;
221        self.set_pos_unchecked(pos);
222    }
223
224    #[inline(always)]
226    pub fn size(&self) -> usize {
227        BlockSize::USIZE
228    }
229
230    #[inline(always)]
232    pub fn remaining(&self) -> usize {
233        self.size() - self.get_pos()
234    }
235
236    #[inline(always)]
237    fn set_pos_unchecked(&mut self, pos: usize) {
238        debug_assert!(Kind::invariant(pos, BlockSize::USIZE));
239        self.pos = pos as u8;
240    }
241}
242
243impl<BlockSize> BlockBuffer<BlockSize, Eager>
244where
245    BlockSize: ArrayLength<u8> + IsLess<U256>,
246    Le<BlockSize, U256>: NonZero,
247{
248    #[inline]
250    pub fn set_data(
251        &mut self,
252        mut data: &mut [u8],
253        mut process_blocks: impl FnMut(&mut [Block<BlockSize>]),
254    ) {
255        let pos = self.get_pos();
256        let r = self.remaining();
257        let n = data.len();
258        if pos != 0 {
259            if n < r {
260                data.copy_from_slice(&self.buffer[pos..][..n]);
262                self.set_pos_unchecked(pos + n);
263                return;
264            }
265            let (left, right) = data.split_at_mut(r);
266            data = right;
267            left.copy_from_slice(&self.buffer[pos..]);
268        }
269
270        let (blocks, leftover) = to_blocks_mut(data);
271        process_blocks(blocks);
272
273        let n = leftover.len();
274        if n != 0 {
275            let mut block = Default::default();
276            process_blocks(slice::from_mut(&mut block));
277            leftover.copy_from_slice(&block[..n]);
278            self.buffer = block;
279        }
280        self.set_pos_unchecked(n);
281    }
282
283    #[inline(always)]
290    pub fn digest_pad(
291        &mut self,
292        delim: u8,
293        suffix: &[u8],
294        mut compress: impl FnMut(&Block<BlockSize>),
295    ) {
296        if suffix.len() > BlockSize::USIZE {
297            panic!("suffix is too long");
298        }
299        let pos = self.get_pos();
300        self.buffer[pos] = delim;
301        for b in &mut self.buffer[pos + 1..] {
302            *b = 0;
303        }
304
305        let n = self.size() - suffix.len();
306        if self.size() - pos - 1 < suffix.len() {
307            compress(&self.buffer);
308            let mut block = Block::<BlockSize>::default();
309            block[n..].copy_from_slice(suffix);
310            compress(&block);
311        } else {
312            self.buffer[n..].copy_from_slice(suffix);
313            compress(&self.buffer);
314        }
315        self.set_pos_unchecked(0)
316    }
317
318    #[inline]
321    pub fn len64_padding_be(&mut self, data_len: u64, compress: impl FnMut(&Block<BlockSize>)) {
322        self.digest_pad(0x80, &data_len.to_be_bytes(), compress);
323    }
324
325    #[inline]
328    pub fn len64_padding_le(&mut self, data_len: u64, compress: impl FnMut(&Block<BlockSize>)) {
329        self.digest_pad(0x80, &data_len.to_le_bytes(), compress);
330    }
331
332    #[inline]
335    pub fn len128_padding_be(&mut self, data_len: u128, compress: impl FnMut(&Block<BlockSize>)) {
336        self.digest_pad(0x80, &data_len.to_be_bytes(), compress);
337    }
338}
339
340#[inline(always)]
342fn to_blocks_mut<N: ArrayLength<u8>>(data: &mut [u8]) -> (&mut [Block<N>], &mut [u8]) {
343    let nb = data.len() / N::USIZE;
344    let (left, right) = data.split_at_mut(nb * N::USIZE);
345    let p = left.as_mut_ptr() as *mut Block<N>;
346    let blocks = unsafe { slice::from_raw_parts_mut(p, nb) };
349    (blocks, right)
350}