1use chrono::{DateTime, Utc};
8use mas_iana::oauth::PkceCodeChallengeMethod;
9use oauth2_types::{
10 pkce::{CodeChallengeError, CodeChallengeMethodExt},
11 requests::ResponseMode,
12 scope::{OPENID, PROFILE, Scope},
13};
14use rand::{
15 RngCore,
16 distributions::{Alphanumeric, DistString},
17};
18use ruma_common::UserId;
19use serde::Serialize;
20use ulid::Ulid;
21use url::Url;
22
23use super::session::Session;
24use crate::InvalidTransitionError;
25
26#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
27pub struct Pkce {
28 pub challenge_method: PkceCodeChallengeMethod,
29 pub challenge: String,
30}
31
32impl Pkce {
33 #[must_use]
35 pub fn new(challenge_method: PkceCodeChallengeMethod, challenge: String) -> Self {
36 Pkce {
37 challenge_method,
38 challenge,
39 }
40 }
41
42 pub fn verify(&self, verifier: &str) -> Result<(), CodeChallengeError> {
48 self.challenge_method.verify(&self.challenge, verifier)
49 }
50}
51
52#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
53pub struct AuthorizationCode {
54 pub code: String,
55 pub pkce: Option<Pkce>,
56}
57
58#[derive(Debug, Clone, PartialEq, Eq, Serialize, Default)]
59#[serde(tag = "stage", rename_all = "lowercase")]
60pub enum AuthorizationGrantStage {
61 #[default]
62 Pending,
63 Fulfilled {
64 session_id: Ulid,
65 fulfilled_at: DateTime<Utc>,
66 },
67 Exchanged {
68 session_id: Ulid,
69 fulfilled_at: DateTime<Utc>,
70 exchanged_at: DateTime<Utc>,
71 },
72 Cancelled {
73 cancelled_at: DateTime<Utc>,
74 },
75}
76
77impl AuthorizationGrantStage {
78 #[must_use]
79 pub fn new() -> Self {
80 Self::Pending
81 }
82
83 fn fulfill(
84 self,
85 fulfilled_at: DateTime<Utc>,
86 session: &Session,
87 ) -> Result<Self, InvalidTransitionError> {
88 match self {
89 Self::Pending => Ok(Self::Fulfilled {
90 fulfilled_at,
91 session_id: session.id,
92 }),
93 _ => Err(InvalidTransitionError),
94 }
95 }
96
97 fn exchange(self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
98 match self {
99 Self::Fulfilled {
100 fulfilled_at,
101 session_id,
102 } => Ok(Self::Exchanged {
103 fulfilled_at,
104 exchanged_at,
105 session_id,
106 }),
107 _ => Err(InvalidTransitionError),
108 }
109 }
110
111 fn cancel(self, cancelled_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
112 match self {
113 Self::Pending => Ok(Self::Cancelled { cancelled_at }),
114 _ => Err(InvalidTransitionError),
115 }
116 }
117
118 #[must_use]
122 pub fn is_pending(&self) -> bool {
123 matches!(self, Self::Pending)
124 }
125
126 #[must_use]
130 pub fn is_fulfilled(&self) -> bool {
131 matches!(self, Self::Fulfilled { .. })
132 }
133
134 #[must_use]
138 pub fn is_exchanged(&self) -> bool {
139 matches!(self, Self::Exchanged { .. })
140 }
141}
142
143pub enum LoginHint<'a> {
144 MXID(&'a UserId),
145 None,
146}
147
148#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
149pub struct AuthorizationGrant {
150 pub id: Ulid,
151 #[serde(flatten)]
152 pub stage: AuthorizationGrantStage,
153 pub code: Option<AuthorizationCode>,
154 pub client_id: Ulid,
155 pub redirect_uri: Url,
156 pub scope: Scope,
157 pub state: Option<String>,
158 pub nonce: Option<String>,
159 pub response_mode: ResponseMode,
160 pub response_type_id_token: bool,
161 pub created_at: DateTime<Utc>,
162 pub login_hint: Option<String>,
163}
164
165impl std::ops::Deref for AuthorizationGrant {
166 type Target = AuthorizationGrantStage;
167
168 fn deref(&self) -> &Self::Target {
169 &self.stage
170 }
171}
172
173impl AuthorizationGrant {
174 #[must_use]
175 pub fn parse_login_hint(&self, homeserver: &str) -> LoginHint {
176 let Some(login_hint) = &self.login_hint else {
177 return LoginHint::None;
178 };
179
180 let Some((prefix, value)) = login_hint.split_once(':') else {
182 return LoginHint::None;
183 };
184
185 match prefix {
186 "mxid" => {
187 let Ok(mxid) = <&UserId>::try_from(value) else {
189 return LoginHint::None;
190 };
191
192 if mxid.server_name() != homeserver {
194 return LoginHint::None;
195 }
196
197 LoginHint::MXID(mxid)
198 }
199 _ => LoginHint::None,
201 }
202 }
203
204 pub fn exchange(mut self, exchanged_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
212 self.stage = self.stage.exchange(exchanged_at)?;
213 Ok(self)
214 }
215
216 pub fn fulfill(
224 mut self,
225 fulfilled_at: DateTime<Utc>,
226 session: &Session,
227 ) -> Result<Self, InvalidTransitionError> {
228 self.stage = self.stage.fulfill(fulfilled_at, session)?;
229 Ok(self)
230 }
231
232 pub fn cancel(mut self, canceld_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
244 self.stage = self.stage.cancel(canceld_at)?;
245 Ok(self)
246 }
247
248 #[doc(hidden)]
249 pub fn sample(now: DateTime<Utc>, rng: &mut impl RngCore) -> Self {
250 Self {
251 id: Ulid::from_datetime_with_source(now.into(), rng),
252 stage: AuthorizationGrantStage::Pending,
253 code: Some(AuthorizationCode {
254 code: Alphanumeric.sample_string(rng, 10),
255 pkce: None,
256 }),
257 client_id: Ulid::from_datetime_with_source(now.into(), rng),
258 redirect_uri: Url::parse("http://localhost:8080").unwrap(),
259 scope: Scope::from_iter([OPENID, PROFILE]),
260 state: Some(Alphanumeric.sample_string(rng, 10)),
261 nonce: Some(Alphanumeric.sample_string(rng, 10)),
262 response_mode: ResponseMode::Query,
263 response_type_id_token: false,
264 created_at: now,
265 login_hint: Some(String::from("mxid:@example-user:example.com")),
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use super::*;

    /// Builds a sample grant whose `login_hint` is replaced by the given one.
    fn grant_with_login_hint(login_hint: Option<&str>) -> AuthorizationGrant {
        // NOTE(review): `thread_rng` and `Utc::now` trip clippy's
        // `disallowed_methods` lint (presumably configured project-wide);
        // the lint is silenced here because this is test-only code.
        #[allow(clippy::disallowed_methods)]
        let mut rng = thread_rng();

        #[allow(clippy::disallowed_methods)]
        let now = Utc::now();

        AuthorizationGrant {
            login_hint: login_hint.map(String::from),
            ..AuthorizationGrant::sample(now, &mut rng)
        }
    }

    #[test]
    fn no_login_hint() {
        let grant = grant_with_login_hint(None);

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint() {
        let grant = grant_with_login_hint(Some("mxid:@example-user:example.com"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::MXID(mxid) if mxid.localpart() == "example-user"));
    }

    #[test]
    fn invalid_login_hint() {
        // No `<prefix>:` part at all.
        let grant = grant_with_login_hint(Some("example-user"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn valid_login_hint_for_wrong_homeserver() {
        // Well-formed user ID, but on a different server than the one asked for.
        let grant = grant_with_login_hint(Some("mxid:@example-user:matrix.org"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }

    #[test]
    fn unknown_login_hint_type() {
        // A prefix other than `mxid`.
        let grant = grant_with_login_hint(Some("something:anything"));

        let hint = grant.parse_login_hint("example.com");

        assert!(matches!(hint, LoginHint::None));
    }
}