mas_data_model/upstream_oauth2/session.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7use chrono::{DateTime, Utc};
8use serde::Serialize;
9use ulid::Ulid;
10
11use super::UpstreamOAuthLink;
12use crate::InvalidTransitionError;
13
14#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
15pub enum UpstreamOAuthAuthorizationSessionState {
16    #[default]
17    Pending,
18    Completed {
19        completed_at: DateTime<Utc>,
20        link_id: Ulid,
21        id_token: Option<String>,
22        extra_callback_parameters: Option<serde_json::Value>,
23        userinfo: Option<serde_json::Value>,
24    },
25    Consumed {
26        completed_at: DateTime<Utc>,
27        consumed_at: DateTime<Utc>,
28        link_id: Ulid,
29        id_token: Option<String>,
30        extra_callback_parameters: Option<serde_json::Value>,
31        userinfo: Option<serde_json::Value>,
32    },
33    Unlinked {
34        completed_at: DateTime<Utc>,
35        consumed_at: Option<DateTime<Utc>>,
36        unlinked_at: DateTime<Utc>,
37        id_token: Option<String>,
38    },
39}
40
41impl UpstreamOAuthAuthorizationSessionState {
42    /// Mark the upstream OAuth 2.0 authorization session as completed.
43    ///
44    /// # Errors
45    ///
46    /// Returns an error if the upstream OAuth 2.0 authorization session state
47    /// is not [`Pending`].
48    ///
49    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
50    pub fn complete(
51        self,
52        completed_at: DateTime<Utc>,
53        link: &UpstreamOAuthLink,
54        id_token: Option<String>,
55        extra_callback_parameters: Option<serde_json::Value>,
56        userinfo: Option<serde_json::Value>,
57    ) -> Result<Self, InvalidTransitionError> {
58        match self {
59            Self::Pending => Ok(Self::Completed {
60                completed_at,
61                link_id: link.id,
62                id_token,
63                extra_callback_parameters,
64                userinfo,
65            }),
66            Self::Completed { .. } | Self::Consumed { .. } | Self::Unlinked { .. } => {
67                Err(InvalidTransitionError)
68            }
69        }
70    }
71
72    /// Mark the upstream OAuth 2.0 authorization session as consumed.
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if the upstream OAuth 2.0 authorization session state
77    /// is not [`Completed`].
78    ///
79    /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
80    pub fn consume(self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
81        match self {
82            Self::Completed {
83                completed_at,
84                link_id,
85                id_token,
86                extra_callback_parameters,
87                userinfo,
88            } => Ok(Self::Consumed {
89                completed_at,
90                link_id,
91                consumed_at,
92                id_token,
93                extra_callback_parameters,
94                userinfo,
95            }),
96            Self::Pending | Self::Consumed { .. } | Self::Unlinked { .. } => {
97                Err(InvalidTransitionError)
98            }
99        }
100    }
101
102    /// Get the link ID for the upstream OAuth 2.0 authorization session.
103    ///
104    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
105    /// [`Pending`].
106    ///
107    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
108    #[must_use]
109    pub fn link_id(&self) -> Option<Ulid> {
110        match self {
111            Self::Pending | Self::Unlinked { .. } => None,
112            Self::Completed { link_id, .. } | Self::Consumed { link_id, .. } => Some(*link_id),
113        }
114    }
115
116    /// Get the time at which the upstream OAuth 2.0 authorization session was
117    /// completed.
118    ///
119    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
120    /// [`Pending`].
121    ///
122    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
123    #[must_use]
124    pub fn completed_at(&self) -> Option<DateTime<Utc>> {
125        match self {
126            Self::Pending => None,
127            Self::Completed { completed_at, .. }
128            | Self::Consumed { completed_at, .. }
129            | Self::Unlinked { completed_at, .. } => Some(*completed_at),
130        }
131    }
132
133    /// Get the ID token for the upstream OAuth 2.0 authorization session.
134    ///
135    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
136    /// [`Pending`].
137    ///
138    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
139    #[must_use]
140    pub fn id_token(&self) -> Option<&str> {
141        match self {
142            Self::Pending => None,
143            Self::Completed { id_token, .. }
144            | Self::Consumed { id_token, .. }
145            | Self::Unlinked { id_token, .. } => id_token.as_deref(),
146        }
147    }
148
149    /// Get the extra query parameters that were sent to the upstream provider.
150    ///
151    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
152    /// not [`Pending`].
153    ///
154    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
155    #[must_use]
156    pub fn extra_callback_parameters(&self) -> Option<&serde_json::Value> {
157        match self {
158            Self::Pending | Self::Unlinked { .. } => None,
159            Self::Completed {
160                extra_callback_parameters,
161                ..
162            }
163            | Self::Consumed {
164                extra_callback_parameters,
165                ..
166            } => extra_callback_parameters.as_ref(),
167        }
168    }
169
170    #[must_use]
171    pub fn userinfo(&self) -> Option<&serde_json::Value> {
172        match self {
173            Self::Pending | Self::Unlinked { .. } => None,
174            Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => userinfo.as_ref(),
175        }
176    }
177
178    /// Get the time at which the upstream OAuth 2.0 authorization session was
179    /// consumed.
180    ///
181    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
182    /// not [`Consumed`].
183    ///
184    /// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed
185    #[must_use]
186    pub fn consumed_at(&self) -> Option<DateTime<Utc>> {
187        match self {
188            Self::Pending | Self::Completed { .. } => None,
189            Self::Consumed { consumed_at, .. } => Some(*consumed_at),
190            Self::Unlinked { consumed_at, .. } => *consumed_at,
191        }
192    }
193
194    /// Get the time at which the upstream OAuth 2.0 authorization session was
195    /// unlinked.
196    ///
197    /// Returns `None` if the upstream OAuth 2.0 authorization session state is
198    /// not [`Unlinked`].
199    ///
200    /// [`Unlinked`]: UpstreamOAuthAuthorizationSessionState::Unlinked
201    #[must_use]
202    pub fn unlinked_at(&self) -> Option<DateTime<Utc>> {
203        match self {
204            Self::Pending | Self::Completed { .. } | Self::Consumed { .. } => None,
205            Self::Unlinked { unlinked_at, .. } => Some(*unlinked_at),
206        }
207    }
208
209    /// Returns `true` if the upstream OAuth 2.0 authorization session state is
210    /// [`Pending`].
211    ///
212    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
213    #[must_use]
214    pub fn is_pending(&self) -> bool {
215        matches!(self, Self::Pending)
216    }
217
218    /// Returns `true` if the upstream OAuth 2.0 authorization session state is
219    /// [`Completed`].
220    ///
221    /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
222    #[must_use]
223    pub fn is_completed(&self) -> bool {
224        matches!(self, Self::Completed { .. })
225    }
226
227    /// Returns `true` if the upstream OAuth 2.0 authorization session state is
228    /// [`Consumed`].
229    ///
230    /// [`Consumed`]: UpstreamOAuthAuthorizationSessionState::Consumed
231    #[must_use]
232    pub fn is_consumed(&self) -> bool {
233        matches!(self, Self::Consumed { .. })
234    }
235
236    /// Returns `true` if the upstream OAuth 2.0 authorization session state is
237    /// [`Unlinked`].
238    ///
239    /// [`Unlinked`]: UpstreamOAuthAuthorizationSessionState::Unlinked
240    #[must_use]
241    pub fn is_unlinked(&self) -> bool {
242        matches!(self, Self::Unlinked { .. })
243    }
244}
245
246#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
247pub struct UpstreamOAuthAuthorizationSession {
248    pub id: Ulid,
249    pub state: UpstreamOAuthAuthorizationSessionState,
250    pub provider_id: Ulid,
251    pub state_str: String,
252    pub code_challenge_verifier: Option<String>,
253    pub nonce: String,
254    pub created_at: DateTime<Utc>,
255}
256
257impl std::ops::Deref for UpstreamOAuthAuthorizationSession {
258    type Target = UpstreamOAuthAuthorizationSessionState;
259
260    fn deref(&self) -> &Self::Target {
261        &self.state
262    }
263}
264
265impl UpstreamOAuthAuthorizationSession {
266    /// Mark the upstream OAuth 2.0 authorization session as completed. Returns
267    /// the updated session.
268    ///
269    /// # Errors
270    ///
271    /// Returns an error if the upstream OAuth 2.0 authorization session state
272    /// is not [`Pending`].
273    ///
274    /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
275    pub fn complete(
276        mut self,
277        completed_at: DateTime<Utc>,
278        link: &UpstreamOAuthLink,
279        id_token: Option<String>,
280        extra_callback_parameters: Option<serde_json::Value>,
281        userinfo: Option<serde_json::Value>,
282    ) -> Result<Self, InvalidTransitionError> {
283        self.state = self.state.complete(
284            completed_at,
285            link,
286            id_token,
287            extra_callback_parameters,
288            userinfo,
289        )?;
290        Ok(self)
291    }
292
293    /// Mark the upstream OAuth 2.0 authorization session as consumed. Returns
294    /// the updated session.
295    ///
296    /// # Errors
297    ///
298    /// Returns an error if the upstream OAuth 2.0 authorization session state
299    /// is not [`Completed`].
300    ///
301    /// [`Completed`]: UpstreamOAuthAuthorizationSessionState::Completed
302    pub fn consume(mut self, consumed_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
303        self.state = self.state.consume(consumed_at)?;
304        Ok(self)
305    }
306}