1#![allow(clippy::module_name_repetitions)]
8
9use std::{net::IpAddr, ops::Deref, sync::Arc};
10
11use async_graphql::{
12 EmptySubscription, InputObject,
13 extensions::Tracing,
14 http::{GraphQLPlaygroundConfig, MultipartOptions, playground_source},
15};
16use axum::{
17 Extension, Json,
18 body::Body,
19 extract::{RawQuery, State as AxumState},
20 http::StatusCode,
21 response::{Html, IntoResponse, Response},
22};
23use axum_extra::typed_header::TypedHeader;
24use chrono::{DateTime, Utc};
25use futures_util::TryStreamExt;
26use headers::{Authorization, ContentType, HeaderValue, authorization::Bearer};
27use hyper::header::CACHE_CONTROL;
28use mas_axum_utils::{
29 FancyError, SessionInfo, SessionInfoExt, cookies::CookieJar, sentry::SentryEventID,
30};
31use mas_data_model::{BrowserSession, Session, SiteConfig, User};
32use mas_matrix::HomeserverConnection;
33use mas_policy::{InstantiateError, Policy, PolicyFactory};
34use mas_router::UrlBuilder;
35use mas_storage::{BoxClock, BoxRepository, BoxRng, Clock, RepositoryError, SystemClock};
36use mas_storage_pg::PgRepository;
37use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME};
38use rand::{SeedableRng, thread_rng};
39use rand_chacha::ChaChaRng;
40use sqlx::PgPool;
41use state::has_session_ended;
42use tracing::{Instrument, info_span};
43use ulid::Ulid;
44
45mod model;
46mod mutations;
47mod query;
48mod state;
49
50pub use self::state::{BoxState, State};
51use self::{
52 model::{CreationEvent, Node},
53 mutations::Mutation,
54 query::Query,
55};
56use crate::{
57 BoundActivityTracker, Limiter, RequesterFingerprint, impl_from_error_for_route,
58 passwords::PasswordManager,
59};
60
61#[cfg(test)]
62mod tests;
63
/// Extra, per-listener parameters made available to the GraphQL handlers
/// through an axum [`Extension`].
#[derive(Clone, Debug)]
pub struct ExtraRouterParameters {
    /// Whether requests may authenticate with an OAuth 2.0 bearer token.
    ///
    /// When `false`, any request carrying an `Authorization: Bearer` header is
    /// rejected with [`RouteError::InvalidToken`].
    pub undocumented_oauth2_access: bool,
}
70
71struct GraphQLState {
72 pool: PgPool,
73 homeserver_connection: Arc<dyn HomeserverConnection>,
74 policy_factory: Arc<PolicyFactory>,
75 site_config: SiteConfig,
76 password_manager: PasswordManager,
77 url_builder: UrlBuilder,
78 limiter: Limiter,
79}
80
81#[async_trait::async_trait]
82impl state::State for GraphQLState {
83 async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
84 let repo = PgRepository::from_pool(&self.pool)
85 .await
86 .map_err(RepositoryError::from_error)?;
87
88 Ok(repo.boxed())
89 }
90
91 async fn policy(&self) -> Result<Policy, InstantiateError> {
92 self.policy_factory.instantiate().await
93 }
94
95 fn password_manager(&self) -> PasswordManager {
96 self.password_manager.clone()
97 }
98
99 fn site_config(&self) -> &SiteConfig {
100 &self.site_config
101 }
102
103 fn homeserver_connection(&self) -> &dyn HomeserverConnection {
104 self.homeserver_connection.as_ref()
105 }
106
107 fn url_builder(&self) -> &UrlBuilder {
108 &self.url_builder
109 }
110
111 fn limiter(&self) -> &Limiter {
112 &self.limiter
113 }
114
115 fn clock(&self) -> BoxClock {
116 let clock = SystemClock::default();
117 Box::new(clock)
118 }
119
120 fn rng(&self) -> BoxRng {
121 #[allow(clippy::disallowed_methods)]
122 let rng = thread_rng();
123
124 let rng = ChaChaRng::from_rng(rng).expect("Failed to seed rng");
125 Box::new(rng)
126 }
127}
128
129#[must_use]
130pub fn schema(
131 pool: &PgPool,
132 policy_factory: &Arc<PolicyFactory>,
133 homeserver_connection: impl HomeserverConnection + 'static,
134 site_config: SiteConfig,
135 password_manager: PasswordManager,
136 url_builder: UrlBuilder,
137 limiter: Limiter,
138) -> Schema {
139 let state = GraphQLState {
140 pool: pool.clone(),
141 policy_factory: Arc::clone(policy_factory),
142 homeserver_connection: Arc::new(homeserver_connection),
143 site_config,
144 password_manager,
145 url_builder,
146 limiter,
147 };
148 let state: BoxState = Box::new(state);
149
150 schema_builder().extension(Tracing).data(state).finish()
151}
152
153fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
154 let span = info_span!(
155 "GraphQL operation",
156 "otel.name" = tracing::field::Empty,
157 "otel.kind" = "server",
158 { GRAPHQL_DOCUMENT } = request.query,
159 { GRAPHQL_OPERATION_NAME } = tracing::field::Empty,
160 );
161
162 if let Some(name) = &request.operation_name {
163 span.record("otel.name", name);
164 span.record(GRAPHQL_OPERATION_NAME, name);
165 }
166
167 span
168}
169
170#[derive(thiserror::Error, Debug)]
171pub enum RouteError {
172 #[error(transparent)]
173 Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
174
175 #[error("Loading of some database objects failed")]
176 LoadFailed,
177
178 #[error("Invalid access token")]
179 InvalidToken,
180
181 #[error("Missing scope")]
182 MissingScope,
183
184 #[error(transparent)]
185 ParseRequest(#[from] async_graphql::ParseRequestError),
186}
187
188impl_from_error_for_route!(mas_storage::RepositoryError);
189
190impl IntoResponse for RouteError {
191 fn into_response(self) -> Response {
192 let event_id = sentry::capture_error(&self);
193
194 let response = match self {
195 e @ (Self::Internal(_) | Self::LoadFailed) => {
196 let error = async_graphql::Error::new_with_source(e);
197 (
198 StatusCode::INTERNAL_SERVER_ERROR,
199 Json(serde_json::json!({"errors": [error]})),
200 )
201 .into_response()
202 }
203
204 Self::InvalidToken => {
205 let error = async_graphql::Error::new("Invalid token");
206 (
207 StatusCode::UNAUTHORIZED,
208 Json(serde_json::json!({"errors": [error]})),
209 )
210 .into_response()
211 }
212
213 Self::MissingScope => {
214 let error = async_graphql::Error::new("Missing urn:mas:graphql:* scope");
215 (
216 StatusCode::UNAUTHORIZED,
217 Json(serde_json::json!({"errors": [error]})),
218 )
219 .into_response()
220 }
221
222 Self::ParseRequest(e) => {
223 let error = async_graphql::Error::new_with_source(e);
224 (
225 StatusCode::BAD_REQUEST,
226 Json(serde_json::json!({"errors": [error]})),
227 )
228 .into_response()
229 }
230 };
231
232 (SentryEventID::from(event_id), response).into_response()
233 }
234}
235
236async fn get_requester(
237 undocumented_oauth2_access: bool,
238 clock: &impl Clock,
239 activity_tracker: &BoundActivityTracker,
240 mut repo: BoxRepository,
241 session_info: &SessionInfo,
242 user_agent: Option<String>,
243 token: Option<&str>,
244) -> Result<Requester, RouteError> {
245 let entity = if let Some(token) = token {
246 if !undocumented_oauth2_access {
248 return Err(RouteError::InvalidToken);
249 }
250
251 let token = repo
252 .oauth2_access_token()
253 .find_by_token(token)
254 .await?
255 .ok_or(RouteError::InvalidToken)?;
256
257 let session = repo
258 .oauth2_session()
259 .lookup(token.session_id)
260 .await?
261 .ok_or(RouteError::LoadFailed)?;
262
263 activity_tracker
264 .record_oauth2_session(clock, &session)
265 .await;
266
267 let user = if let Some(user_id) = session.user_id {
269 let user = repo
270 .user()
271 .lookup(user_id)
272 .await?
273 .ok_or(RouteError::LoadFailed)?;
274 Some(user)
275 } else {
276 None
277 };
278
279 let user_valid = user.as_ref().is_none_or(User::is_valid);
281
282 if !token.is_valid(clock.now()) || !session.is_valid() || !user_valid {
283 return Err(RouteError::InvalidToken);
284 }
285
286 if !session.scope.contains("urn:mas:graphql:*") {
287 return Err(RouteError::MissingScope);
288 }
289
290 RequestingEntity::OAuth2Session(Box::new((session, user)))
291 } else {
292 let maybe_session = session_info.load_active_session(&mut repo).await?;
293
294 if let Some(session) = maybe_session.as_ref() {
295 activity_tracker
296 .record_browser_session(clock, session)
297 .await;
298 }
299
300 RequestingEntity::from(maybe_session)
301 };
302
303 let requester = Requester {
304 entity,
305 ip_address: activity_tracker.ip(),
306 user_agent,
307 };
308
309 repo.cancel().await?;
310 Ok(requester)
311}
312
313pub async fn post(
314 AxumState(schema): AxumState<Schema>,
315 Extension(ExtraRouterParameters {
316 undocumented_oauth2_access,
317 }): Extension<ExtraRouterParameters>,
318 clock: BoxClock,
319 repo: BoxRepository,
320 activity_tracker: BoundActivityTracker,
321 cookie_jar: CookieJar,
322 content_type: Option<TypedHeader<ContentType>>,
323 authorization: Option<TypedHeader<Authorization<Bearer>>>,
324 user_agent: Option<TypedHeader<headers::UserAgent>>,
325 body: Body,
326) -> Result<impl IntoResponse, RouteError> {
327 let body = body.into_data_stream();
328 let token = authorization
329 .as_ref()
330 .map(|TypedHeader(Authorization(bearer))| bearer.token());
331 let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
332 let (session_info, mut cookie_jar) = cookie_jar.session_info();
333 let requester = get_requester(
334 undocumented_oauth2_access,
335 &clock,
336 &activity_tracker,
337 repo,
338 &session_info,
339 user_agent,
340 token,
341 )
342 .await?;
343
344 let content_type = content_type.map(|TypedHeader(h)| h.to_string());
345
346 let request = async_graphql::http::receive_body(
347 content_type,
348 body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
349 .into_async_read(),
350 MultipartOptions::default(),
351 )
352 .await?
353 .data(requester); let span = span_for_graphql_request(&request);
356 let mut response = schema.execute(request).instrument(span).await;
357
358 if has_session_ended(&mut response) {
359 let session_info = session_info.mark_session_ended();
360 cookie_jar = cookie_jar.update_session_info(&session_info);
361 }
362
363 let cache_control = response
364 .cache_control
365 .value()
366 .and_then(|v| HeaderValue::from_str(&v).ok())
367 .map(|h| [(CACHE_CONTROL, h)]);
368
369 let headers = response.http_headers.clone();
370
371 Ok((headers, cache_control, cookie_jar, Json(response)))
372}
373
374pub async fn get(
375 AxumState(schema): AxumState<Schema>,
376 Extension(ExtraRouterParameters {
377 undocumented_oauth2_access,
378 }): Extension<ExtraRouterParameters>,
379 clock: BoxClock,
380 repo: BoxRepository,
381 activity_tracker: BoundActivityTracker,
382 cookie_jar: CookieJar,
383 authorization: Option<TypedHeader<Authorization<Bearer>>>,
384 user_agent: Option<TypedHeader<headers::UserAgent>>,
385 RawQuery(query): RawQuery,
386) -> Result<impl IntoResponse, FancyError> {
387 let token = authorization
388 .as_ref()
389 .map(|TypedHeader(Authorization(bearer))| bearer.token());
390 let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
391 let (session_info, mut cookie_jar) = cookie_jar.session_info();
392 let requester = get_requester(
393 undocumented_oauth2_access,
394 &clock,
395 &activity_tracker,
396 repo,
397 &session_info,
398 user_agent,
399 token,
400 )
401 .await?;
402
403 let request =
404 async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
405
406 let span = span_for_graphql_request(&request);
407 let mut response = schema.execute(request).instrument(span).await;
408
409 if has_session_ended(&mut response) {
410 let session_info = session_info.mark_session_ended();
411 cookie_jar = cookie_jar.update_session_info(&session_info);
412 }
413
414 let cache_control = response
415 .cache_control
416 .value()
417 .and_then(|v| HeaderValue::from_str(&v).ok())
418 .map(|h| [(CACHE_CONTROL, h)]);
419
420 let headers = response.http_headers.clone();
421
422 Ok((headers, cache_control, cookie_jar, Json(response)))
423}
424
425pub async fn playground() -> impl IntoResponse {
426 Html(playground_source(
427 GraphQLPlaygroundConfig::new("/graphql").with_setting("request.credentials", "include"),
428 ))
429}
430
431pub type Schema = async_graphql::Schema<Query, Mutation, EmptySubscription>;
432pub type SchemaBuilder = async_graphql::SchemaBuilder<Query, Mutation, EmptySubscription>;
433
434#[must_use]
435pub fn schema_builder() -> SchemaBuilder {
436 async_graphql::Schema::build(Query::new(), Mutation::new(), EmptySubscription)
437 .register_output_type::<Node>()
438 .register_output_type::<CreationEvent>()
439}
440
441pub struct Requester {
442 entity: RequestingEntity,
443 ip_address: Option<IpAddr>,
444 user_agent: Option<String>,
445}
446
447impl Requester {
448 pub fn fingerprint(&self) -> RequesterFingerprint {
449 if let Some(ip) = self.ip_address {
450 RequesterFingerprint::new(ip)
451 } else {
452 RequesterFingerprint::EMPTY
453 }
454 }
455
456 pub fn for_policy(&self) -> mas_policy::Requester {
457 mas_policy::Requester {
458 ip_address: self.ip_address,
459 user_agent: self.user_agent.clone(),
460 }
461 }
462}
463
464impl Deref for Requester {
465 type Target = RequestingEntity;
466
467 fn deref(&self) -> &Self::Target {
468 &self.entity
469 }
470}
471
472#[derive(Debug, Clone, Default, PartialEq, Eq)]
474pub enum RequestingEntity {
475 #[default]
477 Anonymous,
478
479 BrowserSession(Box<BrowserSession>),
481
482 OAuth2Session(Box<(Session, Option<User>)>),
484}
485
486trait OwnerId {
487 fn owner_id(&self) -> Option<Ulid>;
488}
489
490impl OwnerId for User {
491 fn owner_id(&self) -> Option<Ulid> {
492 Some(self.id)
493 }
494}
495
496impl OwnerId for BrowserSession {
497 fn owner_id(&self) -> Option<Ulid> {
498 Some(self.user.id)
499 }
500}
501
502impl OwnerId for mas_data_model::UserEmail {
503 fn owner_id(&self) -> Option<Ulid> {
504 Some(self.user_id)
505 }
506}
507
508impl OwnerId for Session {
509 fn owner_id(&self) -> Option<Ulid> {
510 self.user_id
511 }
512}
513
514impl OwnerId for mas_data_model::CompatSession {
515 fn owner_id(&self) -> Option<Ulid> {
516 Some(self.user_id)
517 }
518}
519
520impl OwnerId for mas_data_model::UpstreamOAuthLink {
521 fn owner_id(&self) -> Option<Ulid> {
522 self.user_id
523 }
524}
525
526pub struct UserId(Ulid);
528
529impl OwnerId for UserId {
530 fn owner_id(&self) -> Option<Ulid> {
531 Some(self.0)
532 }
533}
534
535impl RequestingEntity {
536 fn browser_session(&self) -> Option<&BrowserSession> {
537 match self {
538 Self::BrowserSession(session) => Some(session),
539 Self::OAuth2Session(_) | Self::Anonymous => None,
540 }
541 }
542
543 fn user(&self) -> Option<&User> {
544 match self {
545 Self::BrowserSession(session) => Some(&session.user),
546 Self::OAuth2Session(tuple) => tuple.1.as_ref(),
547 Self::Anonymous => None,
548 }
549 }
550
551 fn oauth2_session(&self) -> Option<&Session> {
552 match self {
553 Self::OAuth2Session(tuple) => Some(&tuple.0),
554 Self::BrowserSession(_) | Self::Anonymous => None,
555 }
556 }
557
558 fn is_owner_or_admin(&self, resource: &impl OwnerId) -> bool {
560 if self.is_admin() {
562 return true;
563 }
564
565 let Some(owner_id) = resource.owner_id() else {
567 return false;
568 };
569
570 let Some(user) = self.user() else {
571 return false;
572 };
573
574 user.id == owner_id
575 }
576
577 fn is_admin(&self) -> bool {
578 match self {
579 Self::OAuth2Session(tuple) => {
580 tuple.0.scope.contains("urn:mas:admin")
583 }
584 Self::BrowserSession(_) | Self::Anonymous => false,
585 }
586 }
587
588 fn is_unauthenticated(&self) -> bool {
589 matches!(self, Self::Anonymous)
590 }
591}
592
593impl From<BrowserSession> for RequestingEntity {
594 fn from(session: BrowserSession) -> Self {
595 Self::BrowserSession(Box::new(session))
596 }
597}
598
599impl<T> From<Option<T>> for RequestingEntity
600where
601 T: Into<RequestingEntity>,
602{
603 fn from(session: Option<T>) -> Self {
604 session.map(Into::into).unwrap_or_default()
605 }
606}
607
608#[derive(InputObject, Default, Clone, Copy)]
610pub struct DateFilter {
611 after: Option<DateTime<Utc>>,
613
614 before: Option<DateTime<Utc>>,
616}