lowrisc/
rsa.rs

1// Licensed under the Apache License, Version 2.0 or the MIT License.
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3// Copyright Tock Contributors 2022.
4
5//! RSA Implemented on top of the OTBN
6
7use crate::virtual_otbn::VirtualMuxAccel;
8use kernel::hil::public_key_crypto::rsa_math::{Client, ClientMut, RsaCryptoBase};
9use kernel::utilities::cells::OptionalCell;
10use kernel::utilities::cells::TakeCell;
11use kernel::utilities::mut_imut_buffer::MutImutBuffer;
12use kernel::ErrorCode;
13
/// Location of the RSA application images that are loaded into the OTBN.
///
/// `mod_exponent` turns each address/size pair into a raw byte slice, so the
/// values must describe memory that is readable for the given number of
/// bytes.
pub struct AppAddresses {
    /// Start address of the image handed to the OTBN `load_binary` call.
    pub imem_start: usize,
    /// Size of the instruction image in bytes.
    pub imem_size: usize,
    /// Start address of the image written to OTBN data memory at offset 0.
    pub dmem_start: usize,
    /// Size of the data image in bytes.
    pub dmem_size: usize,
}
20
/// RSA modular exponentiation driver built on top of a virtualised OTBN.
///
/// The caller's buffers are held by this struct while an operation is in
/// flight and are handed back through the client callback.
pub struct OtbnRsa<'a> {
    /// Accelerator used to load the RSA application and run it.
    otbn: &'a VirtualMuxAccel<'a>,
    /// Client notified when the key material was supplied as immutable slices.
    client: OptionalCell<&'a dyn Client<'a>>,
    /// Client notified when the key material was supplied as mutable slices.
    client_mut: OptionalCell<&'a dyn ClientMut<'a>>,

    /// Scratch buffer used to stage the mode word, limb count, modulus and
    /// exponent before they are copied into OTBN data memory.
    internal: TakeCell<'static, [u8]>,

    /// Message buffer of the operation currently in flight.
    message: TakeCell<'static, [u8]>,
    /// Modulus of the operation currently in flight.
    modulus: OptionalCell<MutImutBuffer<'static, u8>>,
    /// Exponent of the operation currently in flight.
    exponent: OptionalCell<MutImutBuffer<'static, u8>>,

    /// Addresses of the RSA application images.
    rsa: AppAddresses,
}
34
35impl<'a> OtbnRsa<'a> {
36    pub fn new(
37        otbn: &'a VirtualMuxAccel<'a>,
38        rsa: AppAddresses,
39        internal_buffer: &'static mut [u8],
40    ) -> Self {
41        OtbnRsa {
42            otbn,
43            client: OptionalCell::empty(),
44            client_mut: OptionalCell::empty(),
45            internal: TakeCell::new(internal_buffer),
46            message: TakeCell::empty(),
47            modulus: OptionalCell::empty(),
48            exponent: OptionalCell::empty(),
49            rsa,
50        }
51    }
52
53    fn report_error(&self, error: ErrorCode, result: &'static mut [u8]) {
54        match self.exponent.take().unwrap() {
55            MutImutBuffer::Mutable(exponent) => {
56                self.client_mut
57                    .map(|client| match self.modulus.take().unwrap() {
58                        MutImutBuffer::Mutable(modulus) => {
59                            client.mod_exponent_done(
60                                Err(error),
61                                self.message.take().unwrap(),
62                                modulus,
63                                exponent,
64                                result,
65                            );
66                        }
67                        MutImutBuffer::Immutable(_) => unreachable!(),
68                    });
69            }
70            MutImutBuffer::Immutable(exponent) => match self.modulus.take().unwrap() {
71                MutImutBuffer::Immutable(modulus) => {
72                    self.client.map(|client| {
73                        client.mod_exponent_done(
74                            Err(error),
75                            self.message.take().unwrap(),
76                            modulus,
77                            exponent,
78                            result,
79                        );
80                    });
81                }
82                MutImutBuffer::Mutable(_) => unreachable!(),
83            },
84        }
85    }
86}
87
88impl<'a> crate::otbn::Client<'a> for OtbnRsa<'a> {
89    fn op_done(&'a self, result: Result<(), ErrorCode>, output: &'static mut [u8]) {
90        if let Err(e) = result {
91            self.report_error(e, output);
92            return;
93        }
94
95        // We want to return BE data
96        output.reverse();
97
98        match self.exponent.take().unwrap() {
99            MutImutBuffer::Mutable(exponent) => {
100                self.client_mut
101                    .map(|client| match self.modulus.take().unwrap() {
102                        MutImutBuffer::Mutable(modulus) => {
103                            client.mod_exponent_done(
104                                Ok(true),
105                                self.message.take().unwrap(),
106                                modulus,
107                                exponent,
108                                output,
109                            );
110                        }
111                        MutImutBuffer::Immutable(_) => unreachable!(),
112                    });
113            }
114            MutImutBuffer::Immutable(exponent) => match self.modulus.take().unwrap() {
115                MutImutBuffer::Immutable(modulus) => {
116                    self.client.map(|client| {
117                        client.mod_exponent_done(
118                            Ok(true),
119                            self.message.take().unwrap(),
120                            modulus,
121                            exponent,
122                            output,
123                        );
124                    });
125                }
126                MutImutBuffer::Mutable(_) => unreachable!(),
127            },
128        }
129    }
130}
131
132impl<'a> RsaCryptoBase<'a> for OtbnRsa<'a> {
133    fn set_client(&'a self, client: &'a dyn Client<'a>) {
134        self.client.set(client);
135    }
136
137    fn clear_data(&self) {
138        self.otbn.clear_data();
139    }
140
141    fn mod_exponent(
142        &self,
143        message: &'static mut [u8],
144        modulus: &'static [u8],
145        exponent: &'static [u8],
146        result: &'static mut [u8],
147    ) -> Result<
148        (),
149        (
150            ErrorCode,
151            &'static mut [u8],
152            &'static [u8],
153            &'static [u8],
154            &'static mut [u8],
155        ),
156    > {
157        // Check that the lengths match our expectations
158        let op_len = modulus.len();
159
160        if result.len() < op_len {
161            return Err((ErrorCode::SIZE, message, modulus, exponent, result));
162        }
163
164        let slice = unsafe {
165            core::slice::from_raw_parts(self.rsa.imem_start as *mut u8, self.rsa.imem_size)
166        };
167        if let Err(e) = self.otbn.load_binary(slice) {
168            return Err((e, message, modulus, exponent, result));
169        }
170
171        let slice = unsafe {
172            core::slice::from_raw_parts(self.rsa.dmem_start as *mut u8, self.rsa.dmem_size)
173        };
174        if let Err(e) = self.otbn.load_data(0, slice) {
175            return Err((e, message, modulus, exponent, result));
176        }
177
178        // Set the mode to decryption
179        if let Some(data) = self.internal.take() {
180            data[0] = 2;
181            data[1] = 0;
182            data[2] = 0;
183            data[3] = 0;
184            // Set the RSA mode
185            // The address is the offset of `mode` in the RSA elf
186            if let Err(e) = self.otbn.load_data(0, &data[0..4]) {
187                return Err((e, message, modulus, exponent, result));
188            }
189
190            data[0] = (op_len / 32) as u8;
191            data[1] = 0;
192            data[2] = 0;
193            data[3] = 0;
194            // Set the RSA length
195            // The address is the offset of `n_limbs` in the RSA elf
196            if let Err(e) = self.otbn.load_data(4, &data[0..4]) {
197                return Err((e, message, modulus, exponent, result));
198            }
199
200            data[0..op_len].copy_from_slice(modulus);
201            // We were passed BE data and the OTBN expects LE
202            // so reverse the order.
203            data[0..op_len].reverse();
204            // Set the RSA modulus
205            // The address is the offset of `modulus` in the RSA elf
206            if let Err(e) = self.otbn.load_data(0x20, &data[0..op_len]) {
207                return Err((e, message, modulus, exponent, result));
208            }
209
210            let len = exponent.len().min(op_len);
211            data[0..len].copy_from_slice(exponent);
212            // We were passed BE data and the OTBN expects LE
213            // so reverse the order.
214            data[0..len].reverse();
215
216            // Set the RSA exponent
217            // The address is the offset of `exp` in the RSA elf
218            if let Err(e) = self.otbn.load_data(0x220, &data[0..len]) {
219                return Err((e, message, modulus, exponent, result));
220            }
221
222            self.internal.replace(data);
223        } else {
224            return Err((ErrorCode::NOMEM, message, modulus, exponent, result));
225        }
226
227        // Set the data in
228        // The address is the offset of `inout` in the RSA elf
229        if let Err(e) = self.otbn.load_data(0x420, message) {
230            return Err((e, message, modulus, exponent, result));
231        }
232
233        self.message.replace(message);
234        self.modulus.replace(MutImutBuffer::Immutable(modulus));
235        self.exponent.replace(MutImutBuffer::Immutable(exponent));
236
237        // Get the data out
238        // The address is the offset of `inout` in the RSA elf
239        if let Err(e) = self.otbn.run(0x420, result) {
240            return Err((e.0, self.message.take().unwrap(), modulus, exponent, e.1));
241        }
242
243        Ok(())
244    }
245}