1use std::collections::HashMap;
8
9use axum::{
10 BoxError, Json,
11 extract::{
12 Form, FromRequest, FromRequestParts,
13 rejection::{FailedToDeserializeForm, FormRejection},
14 },
15 response::IntoResponse,
16};
17use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
18use headers::{Authorization, authorization::Basic};
19use http::{Request, StatusCode};
20use mas_data_model::{Client, JwksOrJwksUri};
21use mas_http::RequestBuilderExt;
22use mas_iana::oauth::OAuthClientAuthenticationMethod;
23use mas_jose::{jwk::PublicJsonWebKeySet, jwt::Jwt};
24use mas_keystore::Encrypter;
25use mas_storage::{RepositoryAccess, oauth2::OAuth2ClientRepository};
26use oauth2_types::errors::{ClientError, ClientErrorCode};
27use serde::{Deserialize, de::DeserializeOwned};
28use serde_json::Value;
29use thiserror::Error;
30
31use crate::record_error;
32
/// `client_assertion_type` value for JWT bearer client assertions (RFC 7523);
/// this is the only assertion type the extractor accepts.
const JWT_BEARER_CLIENT_ASSERTION: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
34
35#[derive(Deserialize)]
36struct AuthorizedForm<F = ()> {
37 client_id: Option<String>,
38 client_secret: Option<String>,
39 client_assertion_type: Option<String>,
40 client_assertion: Option<String>,
41
42 #[serde(flatten)]
43 inner: F,
44}
45
46#[derive(Debug, PartialEq, Eq)]
47pub enum Credentials {
48 None {
49 client_id: String,
50 },
51 ClientSecretBasic {
52 client_id: String,
53 client_secret: String,
54 },
55 ClientSecretPost {
56 client_id: String,
57 client_secret: String,
58 },
59 ClientAssertionJwtBearer {
60 client_id: String,
61 jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
62 },
63}
64
65impl Credentials {
66 #[must_use]
68 pub fn client_id(&self) -> &str {
69 match self {
70 Credentials::None { client_id }
71 | Credentials::ClientSecretBasic { client_id, .. }
72 | Credentials::ClientSecretPost { client_id, .. }
73 | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
74 }
75 }
76
77 pub async fn fetch<E>(
84 &self,
85 repo: &mut impl RepositoryAccess<Error = E>,
86 ) -> Result<Option<Client>, E> {
87 let client_id = match self {
88 Credentials::None { client_id }
89 | Credentials::ClientSecretBasic { client_id, .. }
90 | Credentials::ClientSecretPost { client_id, .. }
91 | Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
92 };
93
94 repo.oauth2_client().find_by_client_id(client_id).await
95 }
96
97 #[tracing::instrument(skip_all)]
103 pub async fn verify(
104 &self,
105 http_client: &reqwest::Client,
106 encrypter: &Encrypter,
107 method: &OAuthClientAuthenticationMethod,
108 client: &Client,
109 ) -> Result<(), CredentialsVerificationError> {
110 match (self, method) {
111 (Credentials::None { .. }, OAuthClientAuthenticationMethod::None) => {}
112
113 (
114 Credentials::ClientSecretPost { client_secret, .. },
115 OAuthClientAuthenticationMethod::ClientSecretPost,
116 )
117 | (
118 Credentials::ClientSecretBasic { client_secret, .. },
119 OAuthClientAuthenticationMethod::ClientSecretBasic,
120 ) => {
121 let encrypted_client_secret = client
123 .encrypted_client_secret
124 .as_ref()
125 .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
126
127 let decrypted_client_secret = encrypter
128 .decrypt_string(encrypted_client_secret)
129 .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
130
131 if client_secret.as_bytes() != decrypted_client_secret {
133 return Err(CredentialsVerificationError::ClientSecretMismatch);
134 }
135 }
136
137 (
138 Credentials::ClientAssertionJwtBearer { jwt, .. },
139 OAuthClientAuthenticationMethod::PrivateKeyJwt,
140 ) => {
141 let jwks = client
143 .jwks
144 .as_ref()
145 .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
146
147 let jwks = fetch_jwks(http_client, jwks)
148 .await
149 .map_err(CredentialsVerificationError::JwksFetchFailed)?;
150
151 jwt.verify_with_jwks(&jwks)
152 .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
153 }
154
155 (
156 Credentials::ClientAssertionJwtBearer { jwt, .. },
157 OAuthClientAuthenticationMethod::ClientSecretJwt,
158 ) => {
159 let encrypted_client_secret = client
161 .encrypted_client_secret
162 .as_ref()
163 .ok_or(CredentialsVerificationError::InvalidClientConfig)?;
164
165 let decrypted_client_secret = encrypter
166 .decrypt_string(encrypted_client_secret)
167 .map_err(|_e| CredentialsVerificationError::DecryptionError)?;
168
169 jwt.verify_with_shared_secret(decrypted_client_secret)
170 .map_err(|_| CredentialsVerificationError::InvalidAssertionSignature)?;
171 }
172
173 (_, _) => {
174 return Err(CredentialsVerificationError::AuthenticationMethodMismatch);
175 }
176 }
177 Ok(())
178 }
179}
180
181async fn fetch_jwks(
182 http_client: &reqwest::Client,
183 jwks: &JwksOrJwksUri,
184) -> Result<PublicJsonWebKeySet, BoxError> {
185 let uri = match jwks {
186 JwksOrJwksUri::Jwks(j) => return Ok(j.clone()),
187 JwksOrJwksUri::JwksUri(u) => u,
188 };
189
190 let response = http_client
191 .get(uri.as_str())
192 .send_traced()
193 .await?
194 .error_for_status()?
195 .json()
196 .await?;
197
198 Ok(response)
199}
200
201#[derive(Debug, Error)]
202pub enum CredentialsVerificationError {
203 #[error("failed to decrypt client credentials")]
204 DecryptionError,
205
206 #[error("invalid client configuration")]
207 InvalidClientConfig,
208
209 #[error("client secret did not match")]
210 ClientSecretMismatch,
211
212 #[error("authentication method mismatch")]
213 AuthenticationMethodMismatch,
214
215 #[error("invalid assertion signature")]
216 InvalidAssertionSignature,
217
218 #[error("failed to fetch jwks")]
219 JwksFetchFailed(#[source] Box<dyn std::error::Error + Send + Sync + 'static>),
220}
221
222impl CredentialsVerificationError {
223 #[must_use]
225 pub fn is_internal(&self) -> bool {
226 matches!(
227 self,
228 Self::DecryptionError | Self::InvalidClientConfig | Self::JwksFetchFailed(_)
229 )
230 }
231}
232
233#[derive(Debug, PartialEq, Eq)]
234pub struct ClientAuthorization<F = ()> {
235 pub credentials: Credentials,
236 pub form: Option<F>,
237}
238
239impl<F> ClientAuthorization<F> {
240 #[must_use]
242 pub fn client_id(&self) -> &str {
243 self.credentials.client_id()
244 }
245}
246
247#[derive(Debug, Error)]
248pub enum ClientAuthorizationError {
249 #[error("Invalid Authorization header")]
250 InvalidHeader,
251
252 #[error("Could not deserialize request body")]
253 BadForm(#[source] FailedToDeserializeForm),
254
255 #[error("client_id in form ({form:?}) does not match credential ({credential:?})")]
256 ClientIdMismatch { credential: String, form: String },
257
258 #[error("Unsupported client_assertion_type: {client_assertion_type}")]
259 UnsupportedClientAssertion { client_assertion_type: String },
260
261 #[error("No credentials were presented")]
262 MissingCredentials,
263
264 #[error("Invalid request")]
265 InvalidRequest,
266
267 #[error("Invalid client_assertion")]
268 InvalidAssertion,
269
270 #[error(transparent)]
271 Internal(Box<dyn std::error::Error>),
272}
273
274impl IntoResponse for ClientAuthorizationError {
275 fn into_response(self) -> axum::response::Response {
276 let sentry_event_id = record_error!(self, Self::Internal(_));
277 match &self {
278 ClientAuthorizationError::InvalidHeader => (
279 StatusCode::BAD_REQUEST,
280 sentry_event_id,
281 Json(ClientError::new(
282 ClientErrorCode::InvalidRequest,
283 "Invalid Authorization header",
284 )),
285 ),
286
287 ClientAuthorizationError::BadForm(err) => (
288 StatusCode::BAD_REQUEST,
289 sentry_event_id,
290 Json(
291 ClientError::from(ClientErrorCode::InvalidRequest)
292 .with_description(format!("{err}")),
293 ),
294 ),
295
296 ClientAuthorizationError::ClientIdMismatch { .. } => (
297 StatusCode::BAD_REQUEST,
298 sentry_event_id,
299 Json(
300 ClientError::from(ClientErrorCode::InvalidGrant)
301 .with_description(format!("{self}")),
302 ),
303 ),
304
305 ClientAuthorizationError::UnsupportedClientAssertion { .. } => (
306 StatusCode::BAD_REQUEST,
307 sentry_event_id,
308 Json(
309 ClientError::from(ClientErrorCode::InvalidRequest)
310 .with_description(format!("{self}")),
311 ),
312 ),
313
314 ClientAuthorizationError::MissingCredentials => (
315 StatusCode::BAD_REQUEST,
316 sentry_event_id,
317 Json(ClientError::new(
318 ClientErrorCode::InvalidRequest,
319 "No credentials were presented",
320 )),
321 ),
322
323 ClientAuthorizationError::InvalidRequest => (
324 StatusCode::BAD_REQUEST,
325 sentry_event_id,
326 Json(ClientError::from(ClientErrorCode::InvalidRequest)),
327 ),
328
329 ClientAuthorizationError::InvalidAssertion => (
330 StatusCode::BAD_REQUEST,
331 sentry_event_id,
332 Json(ClientError::new(
333 ClientErrorCode::InvalidRequest,
334 "Invalid client_assertion",
335 )),
336 ),
337
338 ClientAuthorizationError::Internal(e) => (
339 StatusCode::INTERNAL_SERVER_ERROR,
340 sentry_event_id,
341 Json(
342 ClientError::from(ClientErrorCode::ServerError)
343 .with_description(format!("{e}")),
344 ),
345 ),
346 }
347 .into_response()
348 }
349}
350
351impl<S, F> FromRequest<S> for ClientAuthorization<F>
352where
353 F: DeserializeOwned,
354 S: Send + Sync,
355{
356 type Rejection = ClientAuthorizationError;
357
358 #[allow(clippy::too_many_lines)]
359 async fn from_request(
360 req: Request<axum::body::Body>,
361 state: &S,
362 ) -> Result<Self, Self::Rejection> {
363 let (mut parts, body) = req.into_parts();
365
366 let header =
367 TypedHeader::<Authorization<Basic>>::from_request_parts(&mut parts, state).await;
368
369 let credentials_from_header = match header {
371 Ok(header) => Some((header.username().to_owned(), header.password().to_owned())),
372 Err(err) => match err.reason() {
373 TypedHeaderRejectionReason::Missing => None,
375 _ => return Err(ClientAuthorizationError::InvalidHeader),
377 },
378 };
379
380 let req = Request::from_parts(parts, body);
382
383 let (
385 client_id_from_form,
386 client_secret_from_form,
387 client_assertion_type,
388 client_assertion,
389 form,
390 ) = match Form::<AuthorizedForm<F>>::from_request(req, state).await {
391 Ok(Form(form)) => (
392 form.client_id,
393 form.client_secret,
394 form.client_assertion_type,
395 form.client_assertion,
396 Some(form.inner),
397 ),
398 Err(FormRejection::InvalidFormContentType(_err)) => (None, None, None, None, None),
400 Err(FormRejection::FailedToDeserializeForm(err)) => {
402 return Err(ClientAuthorizationError::BadForm(err));
403 }
404 Err(e) => return Err(ClientAuthorizationError::Internal(Box::new(e))),
406 };
407
408 let credentials = match (
410 credentials_from_header,
411 client_id_from_form,
412 client_secret_from_form,
413 client_assertion_type,
414 client_assertion,
415 ) {
416 (Some((client_id, client_secret)), client_id_from_form, None, None, None) => {
417 if let Some(client_id_from_form) = client_id_from_form {
418 if client_id != client_id_from_form {
420 return Err(ClientAuthorizationError::ClientIdMismatch {
421 credential: client_id,
422 form: client_id_from_form,
423 });
424 }
425 }
426
427 Credentials::ClientSecretBasic {
428 client_id,
429 client_secret,
430 }
431 }
432
433 (None, Some(client_id), Some(client_secret), None, None) => {
434 Credentials::ClientSecretPost {
436 client_id,
437 client_secret,
438 }
439 }
440
441 (None, Some(client_id), None, None, None) => {
442 Credentials::None { client_id }
444 }
445
446 (
447 None,
448 client_id_from_form,
449 None,
450 Some(client_assertion_type),
451 Some(client_assertion),
452 ) if client_assertion_type == JWT_BEARER_CLIENT_ASSERTION => {
453 let jwt: Jwt<'static, HashMap<String, Value>> = Jwt::try_from(client_assertion)
455 .map_err(|_| ClientAuthorizationError::InvalidAssertion)?;
456
457 let client_id = if let Some(Value::String(client_id)) = jwt.payload().get("sub") {
458 client_id.clone()
459 } else {
460 return Err(ClientAuthorizationError::InvalidAssertion);
461 };
462
463 if let Some(client_id_from_form) = client_id_from_form {
464 if client_id != client_id_from_form {
466 return Err(ClientAuthorizationError::ClientIdMismatch {
467 credential: client_id,
468 form: client_id_from_form,
469 });
470 }
471 }
472
473 Credentials::ClientAssertionJwtBearer {
474 client_id,
475 jwt: Box::new(jwt),
476 }
477 }
478
479 (None, None, None, Some(client_assertion_type), Some(_client_assertion)) => {
480 return Err(ClientAuthorizationError::UnsupportedClientAssertion {
482 client_assertion_type,
483 });
484 }
485
486 (None, None, None, None, None) => {
487 return Err(ClientAuthorizationError::MissingCredentials);
489 }
490
491 _ => {
492 return Err(ClientAuthorizationError::InvalidRequest);
494 }
495 };
496
497 Ok(ClientAuthorization { credentials, form })
498 }
499}
500
#[cfg(test)]
mod tests {
    use axum::body::Body;
    use http::{Method, Request};

    use super::*;

    /// `Authorization` header value encoding `client-id:client-secret`
    const BASIC_CLIENT_CREDENTIALS: &str = "Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=";

    /// Build a form-encoded `POST` request, optionally with an
    /// `Authorization` header.
    fn form_request(authorization: Option<&str>, body: &str) -> Request<Body> {
        let mut builder = Request::builder().method(Method::POST).header(
            http::header::CONTENT_TYPE,
            mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
        );
        if let Some(value) = authorization {
            builder = builder.header(http::header::AUTHORIZATION, value);
        }
        builder.body(Body::new(body.to_owned())).unwrap()
    }

    /// Run the extractor with a free-form JSON value as the inner form type.
    async fn extract(
        req: Request<Body>,
    ) -> Result<ClientAuthorization<serde_json::Value>, ClientAuthorizationError> {
        ClientAuthorization::<serde_json::Value>::from_request(req, &()).await
    }

    #[tokio::test]
    async fn none_test() {
        let req = form_request(None, "client_id=client-id&foo=bar");

        assert_eq!(
            extract(req).await.unwrap(),
            ClientAuthorization {
                credentials: Credentials::None {
                    client_id: "client-id".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }

    #[tokio::test]
    async fn client_secret_basic_test() {
        let expected = ClientAuthorization {
            credentials: Credentials::ClientSecretBasic {
                client_id: "client-id".to_owned(),
                client_secret: "client-secret".to_owned(),
            },
            form: Some(serde_json::json!({"foo": "bar"})),
        };

        // Credentials only in the header
        let req = form_request(Some(BASIC_CLIENT_CREDENTIALS), "foo=bar");
        assert_eq!(extract(req).await.unwrap(), expected);

        // A matching client_id in the form is accepted
        let req = form_request(Some(BASIC_CLIENT_CREDENTIALS), "client_id=client-id&foo=bar");
        assert_eq!(extract(req).await.unwrap(), expected);

        // A conflicting client_id in the form is rejected
        let req = form_request(Some(BASIC_CLIENT_CREDENTIALS), "client_id=mismatch-id&foo=bar");
        assert!(matches!(
            extract(req).await,
            Err(ClientAuthorizationError::ClientIdMismatch { .. }),
        ));

        // A malformed Basic header is rejected
        let req = form_request(Some("Basic invalid"), "foo=bar");
        assert!(matches!(
            extract(req).await,
            Err(ClientAuthorizationError::InvalidHeader),
        ));
    }

    #[tokio::test]
    async fn client_secret_post_test() {
        let req = form_request(
            None,
            "client_id=client-id&client_secret=client-secret&foo=bar",
        );

        assert_eq!(
            extract(req).await.unwrap(),
            ClientAuthorization {
                credentials: Credentials::ClientSecretPost {
                    client_id: "client-id".to_owned(),
                    client_secret: "client-secret".to_owned(),
                },
                form: Some(serde_json::json!({"foo": "bar"})),
            }
        );
    }

    #[tokio::test]
    async fn client_assertion_test() {
        let jwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJjbGllbnQtaWQiLCJzdWIiOiJjbGllbnQtaWQiLCJhdWQiOiJodHRwczovL2V4YW1wbGUuY29tL29hdXRoMi9pbnRyb3NwZWN0IiwianRpIjoiYWFiYmNjIiwiZXhwIjoxNTE2MjM5MzIyLCJpYXQiOjE1MTYyMzkwMjJ9.XTaACG_Rww0GPecSZvkbem-AczNy9LLNBueCLCiQajU";
        let body = format!(
            "client_assertion_type={JWT_BEARER_CLIENT_ASSERTION}&client_assertion={jwt}&foo=bar",
        );

        let authz = extract(form_request(None, &body)).await.unwrap();
        assert_eq!(authz.form, Some(serde_json::json!({"foo": "bar"})));

        let Credentials::ClientAssertionJwtBearer { client_id, jwt } = authz.credentials else {
            panic!("expected a JWT client_assertion");
        };

        assert_eq!(client_id, "client-id");
        jwt.verify_with_shared_secret(b"client-secret".to_vec())
            .unwrap();
    }
}