mas_handlers/graphql/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7#![allow(clippy::module_name_repetitions)]
8
9use std::{net::IpAddr, ops::Deref, sync::Arc};
10
11use async_graphql::{
12    EmptySubscription, InputObject,
13    extensions::Tracing,
14    http::{GraphQLPlaygroundConfig, MultipartOptions, playground_source},
15};
16use axum::{
17    Extension, Json,
18    body::Body,
19    extract::{RawQuery, State as AxumState},
20    http::StatusCode,
21    response::{Html, IntoResponse, Response},
22};
23use axum_extra::typed_header::TypedHeader;
24use chrono::{DateTime, Utc};
25use futures_util::TryStreamExt;
26use headers::{Authorization, ContentType, HeaderValue, authorization::Bearer};
27use hyper::header::CACHE_CONTROL;
28use mas_axum_utils::{
29    FancyError, SessionInfo, SessionInfoExt, cookies::CookieJar, sentry::SentryEventID,
30};
31use mas_data_model::{BrowserSession, Session, SiteConfig, User};
32use mas_matrix::HomeserverConnection;
33use mas_policy::{InstantiateError, Policy, PolicyFactory};
34use mas_router::UrlBuilder;
35use mas_storage::{BoxClock, BoxRepository, BoxRng, Clock, RepositoryError, SystemClock};
36use mas_storage_pg::PgRepository;
37use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME};
38use rand::{SeedableRng, thread_rng};
39use rand_chacha::ChaChaRng;
40use sqlx::PgPool;
41use state::has_session_ended;
42use tracing::{Instrument, info_span};
43use ulid::Ulid;
44
45mod model;
46mod mutations;
47mod query;
48mod state;
49
50pub use self::state::{BoxState, State};
51use self::{
52    model::{CreationEvent, Node},
53    mutations::Mutation,
54    query::Query,
55};
56use crate::{
57    BoundActivityTracker, Limiter, RequesterFingerprint, impl_from_error_for_route,
58    passwords::PasswordManager,
59};
60
61#[cfg(test)]
62mod tests;
63
/// Extra parameters we get from the listener configuration, because they are
/// per-listener options. We pass them through request extensions.
#[derive(Debug, Clone)]
pub struct ExtraRouterParameters {
    /// Whether requests authenticated with a bearer access token are accepted
    /// on this listener. When `false`, any request carrying a bearer token is
    /// rejected as if the token was invalid.
    pub undocumented_oauth2_access: bool,
}
70
/// The shared state stored in the GraphQL schema, exposed to resolvers through
/// the [`state::State`] trait.
struct GraphQLState {
    /// Database connection pool, used to open a new repository on demand
    pool: PgPool,
    /// Connection to the homeserver, handed out by reference to resolvers
    homeserver_connection: Arc<dyn HomeserverConnection>,
    /// Factory used to instantiate a fresh policy on demand
    policy_factory: Arc<PolicyFactory>,
    /// Site configuration, handed out by reference to resolvers
    site_config: SiteConfig,
    /// Password manager, cloned each time it is requested
    password_manager: PasswordManager,
    /// URL builder, handed out by reference to resolvers
    url_builder: UrlBuilder,
    /// Limiter, handed out by reference to resolvers
    limiter: Limiter,
}
80
81#[async_trait::async_trait]
82impl state::State for GraphQLState {
83    async fn repository(&self) -> Result<BoxRepository, RepositoryError> {
84        let repo = PgRepository::from_pool(&self.pool)
85            .await
86            .map_err(RepositoryError::from_error)?;
87
88        Ok(repo.boxed())
89    }
90
91    async fn policy(&self) -> Result<Policy, InstantiateError> {
92        self.policy_factory.instantiate().await
93    }
94
95    fn password_manager(&self) -> PasswordManager {
96        self.password_manager.clone()
97    }
98
99    fn site_config(&self) -> &SiteConfig {
100        &self.site_config
101    }
102
103    fn homeserver_connection(&self) -> &dyn HomeserverConnection {
104        self.homeserver_connection.as_ref()
105    }
106
107    fn url_builder(&self) -> &UrlBuilder {
108        &self.url_builder
109    }
110
111    fn limiter(&self) -> &Limiter {
112        &self.limiter
113    }
114
115    fn clock(&self) -> BoxClock {
116        let clock = SystemClock::default();
117        Box::new(clock)
118    }
119
120    fn rng(&self) -> BoxRng {
121        #[allow(clippy::disallowed_methods)]
122        let rng = thread_rng();
123
124        let rng = ChaChaRng::from_rng(rng).expect("Failed to seed rng");
125        Box::new(rng)
126    }
127}
128
129#[must_use]
130pub fn schema(
131    pool: &PgPool,
132    policy_factory: &Arc<PolicyFactory>,
133    homeserver_connection: impl HomeserverConnection + 'static,
134    site_config: SiteConfig,
135    password_manager: PasswordManager,
136    url_builder: UrlBuilder,
137    limiter: Limiter,
138) -> Schema {
139    let state = GraphQLState {
140        pool: pool.clone(),
141        policy_factory: Arc::clone(policy_factory),
142        homeserver_connection: Arc::new(homeserver_connection),
143        site_config,
144        password_manager,
145        url_builder,
146        limiter,
147    };
148    let state: BoxState = Box::new(state);
149
150    schema_builder().extension(Tracing).data(state).finish()
151}
152
153fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
154    let span = info_span!(
155        "GraphQL operation",
156        "otel.name" = tracing::field::Empty,
157        "otel.kind" = "server",
158        { GRAPHQL_DOCUMENT } = request.query,
159        { GRAPHQL_OPERATION_NAME } = tracing::field::Empty,
160    );
161
162    if let Some(name) = &request.operation_name {
163        span.record("otel.name", name);
164        span.record(GRAPHQL_OPERATION_NAME, name);
165    }
166
167    span
168}
169
/// Errors which can happen while handling a GraphQL HTTP request, before the
/// operation itself gets executed.
#[derive(thiserror::Error, Debug)]
pub enum RouteError {
    /// An unexpected internal error. Rendered as a `500 Internal Server
    /// Error`.
    #[error(transparent)]
    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),

    /// An object referenced by another one (e.g. the session of an access
    /// token) could not be found. Rendered as a `500 Internal Server Error`.
    #[error("Loading of some database objects failed")]
    LoadFailed,

    /// The access token is unknown or no longer valid, or access tokens are
    /// not accepted on this listener. Rendered as a `401 Unauthorized`.
    #[error("Invalid access token")]
    InvalidToken,

    /// The session behind the access token lacks the `urn:mas:graphql:*`
    /// scope. Rendered as a `401 Unauthorized`.
    #[error("Missing scope")]
    MissingScope,

    /// The GraphQL request could not be parsed. Rendered as a `400 Bad
    /// Request`.
    #[error(transparent)]
    ParseRequest(#[from] async_graphql::ParseRequestError),
}

// NOTE(review): the macro is defined elsewhere in the crate; presumably it
// generates the `From<RepositoryError>` conversion which lets repository
// errors be propagated with `?` in this module -- confirm.
impl_from_error_for_route!(mas_storage::RepositoryError);
189
190impl IntoResponse for RouteError {
191    fn into_response(self) -> Response {
192        let event_id = sentry::capture_error(&self);
193
194        let response = match self {
195            e @ (Self::Internal(_) | Self::LoadFailed) => {
196                let error = async_graphql::Error::new_with_source(e);
197                (
198                    StatusCode::INTERNAL_SERVER_ERROR,
199                    Json(serde_json::json!({"errors": [error]})),
200                )
201                    .into_response()
202            }
203
204            Self::InvalidToken => {
205                let error = async_graphql::Error::new("Invalid token");
206                (
207                    StatusCode::UNAUTHORIZED,
208                    Json(serde_json::json!({"errors": [error]})),
209                )
210                    .into_response()
211            }
212
213            Self::MissingScope => {
214                let error = async_graphql::Error::new("Missing urn:mas:graphql:* scope");
215                (
216                    StatusCode::UNAUTHORIZED,
217                    Json(serde_json::json!({"errors": [error]})),
218                )
219                    .into_response()
220            }
221
222            Self::ParseRequest(e) => {
223                let error = async_graphql::Error::new_with_source(e);
224                (
225                    StatusCode::BAD_REQUEST,
226                    Json(serde_json::json!({"errors": [error]})),
227                )
228                    .into_response()
229            }
230        };
231
232        (SentryEventID::from(event_id), response).into_response()
233    }
234}
235
236async fn get_requester(
237    undocumented_oauth2_access: bool,
238    clock: &impl Clock,
239    activity_tracker: &BoundActivityTracker,
240    mut repo: BoxRepository,
241    session_info: &SessionInfo,
242    user_agent: Option<String>,
243    token: Option<&str>,
244) -> Result<Requester, RouteError> {
245    let entity = if let Some(token) = token {
246        // If we haven't enabled undocumented_oauth2_access on the listener, we bail out
247        if !undocumented_oauth2_access {
248            return Err(RouteError::InvalidToken);
249        }
250
251        let token = repo
252            .oauth2_access_token()
253            .find_by_token(token)
254            .await?
255            .ok_or(RouteError::InvalidToken)?;
256
257        let session = repo
258            .oauth2_session()
259            .lookup(token.session_id)
260            .await?
261            .ok_or(RouteError::LoadFailed)?;
262
263        activity_tracker
264            .record_oauth2_session(clock, &session)
265            .await;
266
267        // Load the user if there is one
268        let user = if let Some(user_id) = session.user_id {
269            let user = repo
270                .user()
271                .lookup(user_id)
272                .await?
273                .ok_or(RouteError::LoadFailed)?;
274            Some(user)
275        } else {
276            None
277        };
278
279        // If there is a user for this session, check that it is not locked
280        let user_valid = user.as_ref().is_none_or(User::is_valid);
281
282        if !token.is_valid(clock.now()) || !session.is_valid() || !user_valid {
283            return Err(RouteError::InvalidToken);
284        }
285
286        if !session.scope.contains("urn:mas:graphql:*") {
287            return Err(RouteError::MissingScope);
288        }
289
290        RequestingEntity::OAuth2Session(Box::new((session, user)))
291    } else {
292        let maybe_session = session_info.load_active_session(&mut repo).await?;
293
294        if let Some(session) = maybe_session.as_ref() {
295            activity_tracker
296                .record_browser_session(clock, session)
297                .await;
298        }
299
300        RequestingEntity::from(maybe_session)
301    };
302
303    let requester = Requester {
304        entity,
305        ip_address: activity_tracker.ip(),
306        user_agent,
307    };
308
309    repo.cancel().await?;
310    Ok(requester)
311}
312
313pub async fn post(
314    AxumState(schema): AxumState<Schema>,
315    Extension(ExtraRouterParameters {
316        undocumented_oauth2_access,
317    }): Extension<ExtraRouterParameters>,
318    clock: BoxClock,
319    repo: BoxRepository,
320    activity_tracker: BoundActivityTracker,
321    cookie_jar: CookieJar,
322    content_type: Option<TypedHeader<ContentType>>,
323    authorization: Option<TypedHeader<Authorization<Bearer>>>,
324    user_agent: Option<TypedHeader<headers::UserAgent>>,
325    body: Body,
326) -> Result<impl IntoResponse, RouteError> {
327    let body = body.into_data_stream();
328    let token = authorization
329        .as_ref()
330        .map(|TypedHeader(Authorization(bearer))| bearer.token());
331    let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
332    let (session_info, mut cookie_jar) = cookie_jar.session_info();
333    let requester = get_requester(
334        undocumented_oauth2_access,
335        &clock,
336        &activity_tracker,
337        repo,
338        &session_info,
339        user_agent,
340        token,
341    )
342    .await?;
343
344    let content_type = content_type.map(|TypedHeader(h)| h.to_string());
345
346    let request = async_graphql::http::receive_body(
347        content_type,
348        body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
349            .into_async_read(),
350        MultipartOptions::default(),
351    )
352    .await?
353    .data(requester); // XXX: this should probably return another error response?
354
355    let span = span_for_graphql_request(&request);
356    let mut response = schema.execute(request).instrument(span).await;
357
358    if has_session_ended(&mut response) {
359        let session_info = session_info.mark_session_ended();
360        cookie_jar = cookie_jar.update_session_info(&session_info);
361    }
362
363    let cache_control = response
364        .cache_control
365        .value()
366        .and_then(|v| HeaderValue::from_str(&v).ok())
367        .map(|h| [(CACHE_CONTROL, h)]);
368
369    let headers = response.http_headers.clone();
370
371    Ok((headers, cache_control, cookie_jar, Json(response)))
372}
373
374pub async fn get(
375    AxumState(schema): AxumState<Schema>,
376    Extension(ExtraRouterParameters {
377        undocumented_oauth2_access,
378    }): Extension<ExtraRouterParameters>,
379    clock: BoxClock,
380    repo: BoxRepository,
381    activity_tracker: BoundActivityTracker,
382    cookie_jar: CookieJar,
383    authorization: Option<TypedHeader<Authorization<Bearer>>>,
384    user_agent: Option<TypedHeader<headers::UserAgent>>,
385    RawQuery(query): RawQuery,
386) -> Result<impl IntoResponse, FancyError> {
387    let token = authorization
388        .as_ref()
389        .map(|TypedHeader(Authorization(bearer))| bearer.token());
390    let user_agent = user_agent.map(|TypedHeader(h)| h.to_string());
391    let (session_info, mut cookie_jar) = cookie_jar.session_info();
392    let requester = get_requester(
393        undocumented_oauth2_access,
394        &clock,
395        &activity_tracker,
396        repo,
397        &session_info,
398        user_agent,
399        token,
400    )
401    .await?;
402
403    let request =
404        async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
405
406    let span = span_for_graphql_request(&request);
407    let mut response = schema.execute(request).instrument(span).await;
408
409    if has_session_ended(&mut response) {
410        let session_info = session_info.mark_session_ended();
411        cookie_jar = cookie_jar.update_session_info(&session_info);
412    }
413
414    let cache_control = response
415        .cache_control
416        .value()
417        .and_then(|v| HeaderValue::from_str(&v).ok())
418        .map(|h| [(CACHE_CONTROL, h)]);
419
420    let headers = response.http_headers.clone();
421
422    Ok((headers, cache_control, cookie_jar, Json(response)))
423}
424
425pub async fn playground() -> impl IntoResponse {
426    Html(playground_source(
427        GraphQLPlaygroundConfig::new("/graphql").with_setting("request.credentials", "include"),
428    ))
429}
430
/// The GraphQL schema served by this module: queries and mutations, without
/// subscriptions.
pub type Schema = async_graphql::Schema<Query, Mutation, EmptySubscription>;
/// The builder for [`Schema`], as returned by [`schema_builder`].
pub type SchemaBuilder = async_graphql::SchemaBuilder<Query, Mutation, EmptySubscription>;
433
434#[must_use]
435pub fn schema_builder() -> SchemaBuilder {
436    async_graphql::Schema::build(Query::new(), Mutation::new(), EmptySubscription)
437        .register_output_type::<Node>()
438        .register_output_type::<CreationEvent>()
439}
440
441pub struct Requester {
442    entity: RequestingEntity,
443    ip_address: Option<IpAddr>,
444    user_agent: Option<String>,
445}
446
447impl Requester {
448    pub fn fingerprint(&self) -> RequesterFingerprint {
449        if let Some(ip) = self.ip_address {
450            RequesterFingerprint::new(ip)
451        } else {
452            RequesterFingerprint::EMPTY
453        }
454    }
455
456    pub fn for_policy(&self) -> mas_policy::Requester {
457        mas_policy::Requester {
458            ip_address: self.ip_address,
459            user_agent: self.user_agent.clone(),
460        }
461    }
462}
463
464impl Deref for Requester {
465    type Target = RequestingEntity;
466
467    fn deref(&self) -> &Self::Target {
468        &self.entity
469    }
470}
471
/// The identity of the requester.
///
/// Defaults to [`RequestingEntity::Anonymous`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RequestingEntity {
    /// The requester presented no authentication information.
    #[default]
    Anonymous,

    /// The requester is a browser session, stored in a cookie.
    BrowserSession(Box<BrowserSession>),

    /// The requester is a `OAuth2` session, with an access token.
    ///
    /// The second element is the user the session belongs to, if it has one.
    OAuth2Session(Box<(Session, Option<User>)>),
}
485
/// A resource which may belong to a user.
///
/// This is what [`RequestingEntity::is_owner_or_admin`] relies on to decide
/// whether a requester may access a resource.
trait OwnerId {
    /// Returns the ID of the user owning the resource, or `None` if it has
    /// no owner.
    fn owner_id(&self) -> Option<Ulid>;
}
489
490impl OwnerId for User {
491    fn owner_id(&self) -> Option<Ulid> {
492        Some(self.id)
493    }
494}
495
496impl OwnerId for BrowserSession {
497    fn owner_id(&self) -> Option<Ulid> {
498        Some(self.user.id)
499    }
500}
501
502impl OwnerId for mas_data_model::UserEmail {
503    fn owner_id(&self) -> Option<Ulid> {
504        Some(self.user_id)
505    }
506}
507
508impl OwnerId for Session {
509    fn owner_id(&self) -> Option<Ulid> {
510        self.user_id
511    }
512}
513
514impl OwnerId for mas_data_model::CompatSession {
515    fn owner_id(&self) -> Option<Ulid> {
516        Some(self.user_id)
517    }
518}
519
520impl OwnerId for mas_data_model::UpstreamOAuthLink {
521    fn owner_id(&self) -> Option<Ulid> {
522        self.user_id
523    }
524}
525
526/// A dumb wrapper around a `Ulid` to implement `OwnerId` for it.
527pub struct UserId(Ulid);
528
529impl OwnerId for UserId {
530    fn owner_id(&self) -> Option<Ulid> {
531        Some(self.0)
532    }
533}
534
535impl RequestingEntity {
536    fn browser_session(&self) -> Option<&BrowserSession> {
537        match self {
538            Self::BrowserSession(session) => Some(session),
539            Self::OAuth2Session(_) | Self::Anonymous => None,
540        }
541    }
542
543    fn user(&self) -> Option<&User> {
544        match self {
545            Self::BrowserSession(session) => Some(&session.user),
546            Self::OAuth2Session(tuple) => tuple.1.as_ref(),
547            Self::Anonymous => None,
548        }
549    }
550
551    fn oauth2_session(&self) -> Option<&Session> {
552        match self {
553            Self::OAuth2Session(tuple) => Some(&tuple.0),
554            Self::BrowserSession(_) | Self::Anonymous => None,
555        }
556    }
557
558    /// Returns true if the requester can access the resource.
559    fn is_owner_or_admin(&self, resource: &impl OwnerId) -> bool {
560        // If the requester is an admin, they can do anything.
561        if self.is_admin() {
562            return true;
563        }
564
565        // Otherwise, they must be the owner of the resource.
566        let Some(owner_id) = resource.owner_id() else {
567            return false;
568        };
569
570        let Some(user) = self.user() else {
571            return false;
572        };
573
574        user.id == owner_id
575    }
576
577    fn is_admin(&self) -> bool {
578        match self {
579            Self::OAuth2Session(tuple) => {
580                // TODO: is this the right scope?
581                // This has to be in sync with the policy
582                tuple.0.scope.contains("urn:mas:admin")
583            }
584            Self::BrowserSession(_) | Self::Anonymous => false,
585        }
586    }
587
588    fn is_unauthenticated(&self) -> bool {
589        matches!(self, Self::Anonymous)
590    }
591}
592
593impl From<BrowserSession> for RequestingEntity {
594    fn from(session: BrowserSession) -> Self {
595        Self::BrowserSession(Box::new(session))
596    }
597}
598
599impl<T> From<Option<T>> for RequestingEntity
600where
601    T: Into<RequestingEntity>,
602{
603    fn from(session: Option<T>) -> Self {
604        session.map(Into::into).unwrap_or_default()
605    }
606}
607
// NOTE(review): the `InputObject` derive presumably exports the doc comments
// below as descriptions in the GraphQL schema, so rewording them would change
// the generated schema -- confirm before editing them.
/// A filter for dates, with a lower bound and an upper bound
#[derive(InputObject, Default, Clone, Copy)]
pub struct DateFilter {
    // Both bounds are optional; the derived `Default` leaves them unset.
    /// The lower bound of the date range
    after: Option<DateTime<Utc>>,

    /// The upper bound of the date range
    before: Option<DateTime<Utc>>,
}