1use crate::virtual_otbn::VirtualMuxAccel;
8use kernel::hil::public_key_crypto::rsa_math::{Client, ClientMut, RsaCryptoBase};
9use kernel::utilities::cells::OptionalCell;
10use kernel::utilities::cells::TakeCell;
11use kernel::utilities::mut_imut_buffer::MutImutBuffer;
12use kernel::ErrorCode;
13
/// Memory locations of an OTBN application image.
///
/// `OtbnRsa::mod_exponent` reads `imem_size` bytes starting at address `imem_start` and hands
/// them to the accelerator's `load_binary`, and reads `dmem_size` bytes starting at
/// `dmem_start` and loads them at DMEM offset 0.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AppAddresses {
    /// Start address of the instruction-memory (IMEM) image.
    pub imem_start: usize,
    /// Length of the IMEM image in bytes.
    pub imem_size: usize,
    /// Start address of the data-memory (DMEM) image.
    pub dmem_start: usize,
    /// Length of the DMEM image in bytes.
    pub dmem_size: usize,
}
20
21pub struct OtbnRsa<'a> {
22 otbn: &'a VirtualMuxAccel<'a>,
23 client: OptionalCell<&'a dyn Client<'a>>,
24 client_mut: OptionalCell<&'a dyn ClientMut<'a>>,
25
26 internal: TakeCell<'static, [u8]>,
27
28 message: TakeCell<'static, [u8]>,
29 modulus: OptionalCell<MutImutBuffer<'static, u8>>,
30 exponent: OptionalCell<MutImutBuffer<'static, u8>>,
31
32 rsa: AppAddresses,
33}
34
35impl<'a> OtbnRsa<'a> {
36 pub fn new(
37 otbn: &'a VirtualMuxAccel<'a>,
38 rsa: AppAddresses,
39 internal_buffer: &'static mut [u8],
40 ) -> Self {
41 OtbnRsa {
42 otbn,
43 client: OptionalCell::empty(),
44 client_mut: OptionalCell::empty(),
45 internal: TakeCell::new(internal_buffer),
46 message: TakeCell::empty(),
47 modulus: OptionalCell::empty(),
48 exponent: OptionalCell::empty(),
49 rsa,
50 }
51 }
52
53 fn report_error(&self, error: ErrorCode, result: &'static mut [u8]) {
54 match self.exponent.take().unwrap() {
55 MutImutBuffer::Mutable(exponent) => {
56 self.client_mut
57 .map(|client| match self.modulus.take().unwrap() {
58 MutImutBuffer::Mutable(modulus) => {
59 client.mod_exponent_done(
60 Err(error),
61 self.message.take().unwrap(),
62 modulus,
63 exponent,
64 result,
65 );
66 }
67 MutImutBuffer::Immutable(_) => unreachable!(),
68 });
69 }
70 MutImutBuffer::Immutable(exponent) => match self.modulus.take().unwrap() {
71 MutImutBuffer::Immutable(modulus) => {
72 self.client.map(|client| {
73 client.mod_exponent_done(
74 Err(error),
75 self.message.take().unwrap(),
76 modulus,
77 exponent,
78 result,
79 );
80 });
81 }
82 MutImutBuffer::Mutable(_) => unreachable!(),
83 },
84 }
85 }
86}
87
88impl<'a> crate::otbn::Client<'a> for OtbnRsa<'a> {
89 fn op_done(&'a self, result: Result<(), ErrorCode>, output: &'static mut [u8]) {
90 if let Err(e) = result {
91 self.report_error(e, output);
92 return;
93 }
94
95 output.reverse();
97
98 match self.exponent.take().unwrap() {
99 MutImutBuffer::Mutable(exponent) => {
100 self.client_mut
101 .map(|client| match self.modulus.take().unwrap() {
102 MutImutBuffer::Mutable(modulus) => {
103 client.mod_exponent_done(
104 Ok(true),
105 self.message.take().unwrap(),
106 modulus,
107 exponent,
108 output,
109 );
110 }
111 MutImutBuffer::Immutable(_) => unreachable!(),
112 });
113 }
114 MutImutBuffer::Immutable(exponent) => match self.modulus.take().unwrap() {
115 MutImutBuffer::Immutable(modulus) => {
116 self.client.map(|client| {
117 client.mod_exponent_done(
118 Ok(true),
119 self.message.take().unwrap(),
120 modulus,
121 exponent,
122 output,
123 );
124 });
125 }
126 MutImutBuffer::Mutable(_) => unreachable!(),
127 },
128 }
129 }
130}
131
132impl<'a> RsaCryptoBase<'a> for OtbnRsa<'a> {
133 fn set_client(&'a self, client: &'a dyn Client<'a>) {
134 self.client.set(client);
135 }
136
137 fn clear_data(&self) {
138 self.otbn.clear_data();
139 }
140
141 fn mod_exponent(
142 &self,
143 message: &'static mut [u8],
144 modulus: &'static [u8],
145 exponent: &'static [u8],
146 result: &'static mut [u8],
147 ) -> Result<
148 (),
149 (
150 ErrorCode,
151 &'static mut [u8],
152 &'static [u8],
153 &'static [u8],
154 &'static mut [u8],
155 ),
156 > {
157 let op_len = modulus.len();
159
160 if result.len() < op_len {
161 return Err((ErrorCode::SIZE, message, modulus, exponent, result));
162 }
163
164 let slice = unsafe {
165 core::slice::from_raw_parts(self.rsa.imem_start as *mut u8, self.rsa.imem_size)
166 };
167 if let Err(e) = self.otbn.load_binary(slice) {
168 return Err((e, message, modulus, exponent, result));
169 }
170
171 let slice = unsafe {
172 core::slice::from_raw_parts(self.rsa.dmem_start as *mut u8, self.rsa.dmem_size)
173 };
174 if let Err(e) = self.otbn.load_data(0, slice) {
175 return Err((e, message, modulus, exponent, result));
176 }
177
178 if let Some(data) = self.internal.take() {
180 data[0] = 2;
181 data[1] = 0;
182 data[2] = 0;
183 data[3] = 0;
184 if let Err(e) = self.otbn.load_data(0, &data[0..4]) {
187 return Err((e, message, modulus, exponent, result));
188 }
189
190 data[0] = (op_len / 32) as u8;
191 data[1] = 0;
192 data[2] = 0;
193 data[3] = 0;
194 if let Err(e) = self.otbn.load_data(4, &data[0..4]) {
197 return Err((e, message, modulus, exponent, result));
198 }
199
200 data[0..op_len].copy_from_slice(modulus);
201 data[0..op_len].reverse();
204 if let Err(e) = self.otbn.load_data(0x20, &data[0..op_len]) {
207 return Err((e, message, modulus, exponent, result));
208 }
209
210 let len = exponent.len().min(op_len);
211 data[0..len].copy_from_slice(exponent);
212 data[0..len].reverse();
215
216 if let Err(e) = self.otbn.load_data(0x220, &data[0..len]) {
219 return Err((e, message, modulus, exponent, result));
220 }
221
222 self.internal.replace(data);
223 } else {
224 return Err((ErrorCode::NOMEM, message, modulus, exponent, result));
225 }
226
227 if let Err(e) = self.otbn.load_data(0x420, message) {
230 return Err((e, message, modulus, exponent, result));
231 }
232
233 self.message.replace(message);
234 self.modulus.replace(MutImutBuffer::Immutable(modulus));
235 self.exponent.replace(MutImutBuffer::Immutable(exponent));
236
237 if let Err(e) = self.otbn.run(0x420, result) {
240 return Err((e.0, self.message.take().unwrap(), modulus, exponent, e.1));
241 }
242
243 Ok(())
244 }
245}