mas_handlers/
lib.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7#![deny(clippy::future_not_send)]
8#![allow(
9    // Some axum handlers need that
10    clippy::unused_async,
11    // Because of how axum handlers work, we sometime have take many arguments
12    clippy::too_many_arguments,
13    // Code generated by tracing::instrument trigger this when returning an `impl Trait`
14    // See https://github.com/tokio-rs/tracing/issues/2613
15    clippy::let_with_type_underscore,
16)]
17
18use std::{
19    convert::Infallible,
20    sync::{Arc, LazyLock},
21    time::Duration,
22};
23
24use axum::{
25    Extension, Router,
26    extract::{FromRef, FromRequestParts, OriginalUri, RawQuery, State},
27    http::Method,
28    response::{Html, IntoResponse},
29    routing::{get, post},
30};
31use headers::HeaderName;
32use hyper::{
33    StatusCode, Version,
34    header::{
35        ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE,
36    },
37};
38use mas_axum_utils::{InternalError, cookies::CookieJar};
39use mas_data_model::SiteConfig;
40use mas_http::CorsLayerExt;
41use mas_keystore::{Encrypter, Keystore};
42use mas_matrix::HomeserverConnection;
43use mas_policy::Policy;
44use mas_router::{Route, UrlBuilder};
45use mas_storage::{BoxRepository, BoxRepositoryFactory};
46use mas_templates::{ErrorContext, NotFoundContext, TemplateContext, Templates};
47use opentelemetry::metrics::Meter;
48use sqlx::PgPool;
49use tower::util::AndThenLayer;
50use tower_http::cors::{Any, CorsLayer};
51
52use self::{graphql::ExtraRouterParameters, passwords::PasswordManager};
53
54mod admin;
55mod compat;
56mod graphql;
57mod health;
58mod oauth2;
59pub mod passwords;
60pub mod upstream_oauth2;
61mod views;
62
63mod activity_tracker;
64mod captcha;
65mod preferred_language;
66mod rate_limit;
67mod session;
68#[cfg(test)]
69mod test_utils;
70
71static METER: LazyLock<Meter> = LazyLock::new(|| {
72    let scope = opentelemetry::InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
73        .with_version(env!("CARGO_PKG_VERSION"))
74        .with_schema_url(opentelemetry_semantic_conventions::SCHEMA_URL)
75        .build();
76
77    opentelemetry::global::meter_with_scope(scope)
78});
79
/// Implement `From<E>` for `RouteError`, for "internal server error" kind of
/// errors.
///
/// Two forms are accepted:
///
/// - `impl_from_error_for_route!(MyRouteError: SomeError)` implements
///   `From<SomeError>` for `MyRouteError`;
/// - `impl_from_error_for_route!(SomeError)` is shorthand for the form above
///   with `self::RouteError` as the target, resolved in the module where the
///   macro is invoked.
///
/// In both cases the converted error is wrapped with `Box::new` and stored in
/// the target's `Internal` variant, so that variant must accept a `Box` of the
/// error (typically a boxed `dyn Error` trait object).
#[macro_export]
macro_rules! impl_from_error_for_route {
    ($route_error:ty : $error:ty) => {
        impl From<$error> for $route_error {
            fn from(e: $error) -> Self {
                Self::Internal(Box::new(e))
            }
        }
    };
    ($error:ty) => {
        // Recurse through `$crate` so this expansion resolves even when the
        // caller invoked the macro by path (e.g. `crate::impl_from_error_for_route!`)
        // without importing it by name into their module.
        $crate::impl_from_error_for_route!(self::RouteError: $error);
    };
}
95
96pub use mas_axum_utils::{ErrorWrapper, cookies::CookieManager};
97use mas_data_model::{BoxClock, BoxRng};
98
99pub use self::{
100    activity_tracker::{ActivityTracker, Bound as BoundActivityTracker},
101    admin::router as admin_api_router,
102    graphql::{
103        Schema as GraphQLSchema, schema as graphql_schema, schema_builder as graphql_schema_builder,
104    },
105    preferred_language::PreferredLanguage,
106    rate_limit::{Limiter, RequesterFingerprint},
107    upstream_oauth2::cache::MetadataCache,
108};
109
110pub fn healthcheck_router<S>() -> Router<S>
111where
112    S: Clone + Send + Sync + 'static,
113    PgPool: FromRef<S>,
114{
115    Router::new().route(mas_router::Healthcheck::route(), get(self::health::get))
116}
117
118pub fn graphql_router<S>(playground: bool, undocumented_oauth2_access: bool) -> Router<S>
119where
120    S: Clone + Send + Sync + 'static,
121    graphql::Schema: FromRef<S>,
122    BoundActivityTracker: FromRequestParts<S>,
123    BoxRepository: FromRequestParts<S>,
124    BoxClock: FromRequestParts<S>,
125    Encrypter: FromRef<S>,
126    CookieJar: FromRequestParts<S>,
127    Limiter: FromRef<S>,
128    RequesterFingerprint: FromRequestParts<S>,
129{
130    let mut router = Router::new()
131        .route(
132            mas_router::GraphQL::route(),
133            get(self::graphql::get).post(self::graphql::post),
134        )
135        // Pass the undocumented_oauth2_access parameter through the request extension, as it is
136        // per-listener
137        .layer(Extension(ExtraRouterParameters {
138            undocumented_oauth2_access,
139        }))
140        .layer(
141            CorsLayer::new()
142                .allow_origin(Any)
143                .allow_methods(Any)
144                .allow_otel_headers([
145                    AUTHORIZATION,
146                    ACCEPT,
147                    ACCEPT_LANGUAGE,
148                    CONTENT_LANGUAGE,
149                    CONTENT_TYPE,
150                ]),
151        );
152
153    if playground {
154        router = router.route(
155            mas_router::GraphQLPlayground::route(),
156            get(self::graphql::playground),
157        );
158    }
159
160    router
161}
162
163pub fn discovery_router<S>() -> Router<S>
164where
165    S: Clone + Send + Sync + 'static,
166    Keystore: FromRef<S>,
167    SiteConfig: FromRef<S>,
168    UrlBuilder: FromRef<S>,
169    BoxClock: FromRequestParts<S>,
170    BoxRng: FromRequestParts<S>,
171{
172    Router::new()
173        .route(
174            mas_router::OidcConfiguration::route(),
175            get(self::oauth2::discovery::get),
176        )
177        .route(
178            mas_router::Webfinger::route(),
179            get(self::oauth2::webfinger::get),
180        )
181        .layer(
182            CorsLayer::new()
183                .allow_origin(Any)
184                .allow_methods(Any)
185                .allow_otel_headers([
186                    AUTHORIZATION,
187                    ACCEPT,
188                    ACCEPT_LANGUAGE,
189                    CONTENT_LANGUAGE,
190                    CONTENT_TYPE,
191                ])
192                .max_age(Duration::from_secs(60 * 60)),
193        )
194}
195
196pub fn api_router<S>() -> Router<S>
197where
198    S: Clone + Send + Sync + 'static,
199    Keystore: FromRef<S>,
200    UrlBuilder: FromRef<S>,
201    BoxRepository: FromRequestParts<S>,
202    ActivityTracker: FromRequestParts<S>,
203    BoundActivityTracker: FromRequestParts<S>,
204    Encrypter: FromRef<S>,
205    reqwest::Client: FromRef<S>,
206    SiteConfig: FromRef<S>,
207    Templates: FromRef<S>,
208    Arc<dyn HomeserverConnection>: FromRef<S>,
209    BoxClock: FromRequestParts<S>,
210    BoxRng: FromRequestParts<S>,
211    Policy: FromRequestParts<S>,
212{
213    // All those routes are API-like, with a common CORS layer
214    Router::new()
215        .route(
216            mas_router::OAuth2Keys::route(),
217            get(self::oauth2::keys::get),
218        )
219        .route(
220            mas_router::OidcUserinfo::route(),
221            get(self::oauth2::userinfo::get).post(self::oauth2::userinfo::get),
222        )
223        .route(
224            mas_router::OAuth2Introspection::route(),
225            post(self::oauth2::introspection::post),
226        )
227        .route(
228            mas_router::OAuth2Revocation::route(),
229            post(self::oauth2::revoke::post),
230        )
231        .route(
232            mas_router::OAuth2TokenEndpoint::route(),
233            post(self::oauth2::token::post),
234        )
235        .route(
236            mas_router::OAuth2RegistrationEndpoint::route(),
237            post(self::oauth2::registration::post),
238        )
239        .route(
240            mas_router::OAuth2DeviceAuthorizationEndpoint::route(),
241            post(self::oauth2::device::authorize::post),
242        )
243        .layer(
244            CorsLayer::new()
245                .allow_origin(Any)
246                .allow_methods(Any)
247                .allow_otel_headers([
248                    AUTHORIZATION,
249                    ACCEPT,
250                    ACCEPT_LANGUAGE,
251                    CONTENT_LANGUAGE,
252                    CONTENT_TYPE,
253                    // Swagger will send this header, so we have to allow it to avoid CORS errors
254                    HeaderName::from_static("x-requested-with"),
255                ])
256                .max_age(Duration::from_secs(60 * 60)),
257        )
258}
259
260#[allow(clippy::trait_duplication_in_bounds)]
261pub fn compat_router<S>(templates: Templates) -> Router<S>
262where
263    S: Clone + Send + Sync + 'static,
264    UrlBuilder: FromRef<S>,
265    SiteConfig: FromRef<S>,
266    Arc<dyn HomeserverConnection>: FromRef<S>,
267    PasswordManager: FromRef<S>,
268    Limiter: FromRef<S>,
269    BoxRepositoryFactory: FromRef<S>,
270    BoundActivityTracker: FromRequestParts<S>,
271    RequesterFingerprint: FromRequestParts<S>,
272    BoxRepository: FromRequestParts<S>,
273    BoxClock: FromRequestParts<S>,
274    BoxRng: FromRequestParts<S>,
275    Policy: FromRequestParts<S>,
276{
277    // A sub-router for human-facing routes with error handling
278    let human_router = Router::new()
279        .route(
280            mas_router::CompatLoginSsoRedirect::route(),
281            get(self::compat::login_sso_redirect::get),
282        )
283        .route(
284            mas_router::CompatLoginSsoRedirectIdp::route(),
285            get(self::compat::login_sso_redirect::get),
286        )
287        .route(
288            mas_router::CompatLoginSsoRedirectSlash::route(),
289            get(self::compat::login_sso_redirect::get),
290        )
291        .layer(AndThenLayer::new(
292            async move |response: axum::response::Response| {
293                Ok::<_, Infallible>(recover_error(&templates, response))
294            },
295        ));
296
297    // A sub-router for API-facing routes with CORS
298    let api_router = Router::new()
299        .route(
300            mas_router::CompatLogin::route(),
301            get(self::compat::login::get).post(self::compat::login::post),
302        )
303        .route(
304            mas_router::CompatLogout::route(),
305            post(self::compat::logout::post),
306        )
307        .route(
308            mas_router::CompatLogoutAll::route(),
309            post(self::compat::logout_all::post),
310        )
311        .route(
312            mas_router::CompatRefresh::route(),
313            post(self::compat::refresh::post),
314        )
315        .layer(
316            CorsLayer::new()
317                .allow_origin(Any)
318                .allow_methods(Any)
319                .allow_otel_headers([
320                    AUTHORIZATION,
321                    ACCEPT,
322                    ACCEPT_LANGUAGE,
323                    CONTENT_LANGUAGE,
324                    CONTENT_TYPE,
325                    HeaderName::from_static("x-requested-with"),
326                ])
327                .max_age(Duration::from_secs(60 * 60)),
328        );
329
330    Router::new().merge(human_router).merge(api_router)
331}
332
333pub fn human_router<S>(templates: Templates) -> Router<S>
334where
335    S: Clone + Send + Sync + 'static,
336    UrlBuilder: FromRef<S>,
337    PreferredLanguage: FromRequestParts<S>,
338    BoxRepository: FromRequestParts<S>,
339    CookieJar: FromRequestParts<S>,
340    BoundActivityTracker: FromRequestParts<S>,
341    RequesterFingerprint: FromRequestParts<S>,
342    Encrypter: FromRef<S>,
343    Templates: FromRef<S>,
344    Keystore: FromRef<S>,
345    PasswordManager: FromRef<S>,
346    MetadataCache: FromRef<S>,
347    SiteConfig: FromRef<S>,
348    Limiter: FromRef<S>,
349    reqwest::Client: FromRef<S>,
350    Arc<dyn HomeserverConnection>: FromRef<S>,
351    BoxClock: FromRequestParts<S>,
352    BoxRng: FromRequestParts<S>,
353    Policy: FromRequestParts<S>,
354{
355    Router::new()
356        // XXX: hard-coded redirect from /account to /account/
357        .route(
358            "/account",
359            get(
360                async |State(url_builder): State<UrlBuilder>, RawQuery(query): RawQuery| {
361                    let prefix = url_builder.prefix().unwrap_or_default();
362                    let route = mas_router::Account::route();
363                    let destination = if let Some(query) = query {
364                        format!("{prefix}{route}?{query}")
365                    } else {
366                        format!("{prefix}{route}")
367                    };
368
369                    axum::response::Redirect::to(&destination)
370                },
371            ),
372        )
373        .route(mas_router::Account::route(), get(self::views::app::get))
374        .route(
375            mas_router::AccountWildcard::route(),
376            get(self::views::app::get),
377        )
378        .route(
379            mas_router::AccountRecoveryFinish::route(),
380            get(self::views::app::get_anonymous),
381        )
382        .route(
383            mas_router::ChangePasswordDiscovery::route(),
384            get(async |State(url_builder): State<UrlBuilder>| {
385                url_builder.redirect(&mas_router::AccountPasswordChange)
386            }),
387        )
388        .route(mas_router::Index::route(), get(self::views::index::get))
389        .route(
390            mas_router::Login::route(),
391            get(self::views::login::get).post(self::views::login::post),
392        )
393        .route(mas_router::Logout::route(), post(self::views::logout::post))
394        .route(
395            mas_router::Register::route(),
396            get(self::views::register::get),
397        )
398        .route(
399            mas_router::PasswordRegister::route(),
400            get(self::views::register::password::get).post(self::views::register::password::post),
401        )
402        .route(
403            mas_router::RegisterVerifyEmail::route(),
404            get(self::views::register::steps::verify_email::get)
405                .post(self::views::register::steps::verify_email::post),
406        )
407        .route(
408            mas_router::RegisterToken::route(),
409            get(self::views::register::steps::registration_token::get)
410                .post(self::views::register::steps::registration_token::post),
411        )
412        .route(
413            mas_router::RegisterDisplayName::route(),
414            get(self::views::register::steps::display_name::get)
415                .post(self::views::register::steps::display_name::post),
416        )
417        .route(
418            mas_router::RegisterFinish::route(),
419            get(self::views::register::steps::finish::get),
420        )
421        .route(
422            mas_router::AccountRecoveryStart::route(),
423            get(self::views::recovery::start::get).post(self::views::recovery::start::post),
424        )
425        .route(
426            mas_router::AccountRecoveryProgress::route(),
427            get(self::views::recovery::progress::get).post(self::views::recovery::progress::post),
428        )
429        .route(
430            mas_router::OAuth2AuthorizationEndpoint::route(),
431            get(self::oauth2::authorization::get),
432        )
433        .route(
434            mas_router::Consent::route(),
435            get(self::oauth2::authorization::consent::get)
436                .post(self::oauth2::authorization::consent::post),
437        )
438        .route(
439            mas_router::CompatLoginSsoComplete::route(),
440            get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
441        )
442        .route(
443            mas_router::UpstreamOAuth2Authorize::route(),
444            get(self::upstream_oauth2::authorize::get),
445        )
446        .route(
447            mas_router::UpstreamOAuth2Callback::route(),
448            get(self::upstream_oauth2::callback::handler)
449                .post(self::upstream_oauth2::callback::handler),
450        )
451        .route(
452            mas_router::UpstreamOAuth2Link::route(),
453            get(self::upstream_oauth2::link::get).post(self::upstream_oauth2::link::post),
454        )
455        .route(
456            mas_router::UpstreamOAuth2BackchannelLogout::route(),
457            post(self::upstream_oauth2::backchannel_logout::post),
458        )
459        .route(
460            mas_router::DeviceCodeLink::route(),
461            get(self::oauth2::device::link::get),
462        )
463        .route(
464            mas_router::DeviceCodeConsent::route(),
465            get(self::oauth2::device::consent::get).post(self::oauth2::device::consent::post),
466        )
467        .layer(AndThenLayer::new(
468            async move |response: axum::response::Response| {
469                Ok::<_, Infallible>(recover_error(&templates, response))
470            },
471        ))
472}
473
474fn recover_error(
475    templates: &Templates,
476    response: axum::response::Response,
477) -> axum::response::Response {
478    // Error responses should have an ErrorContext attached to them
479    let ext = response.extensions().get::<ErrorContext>();
480    if let Some(ctx) = ext
481        && let Ok(res) = templates.render_error(ctx)
482    {
483        let (mut parts, _original_body) = response.into_parts();
484        parts.headers.remove(CONTENT_TYPE);
485        parts.headers.remove(CONTENT_LENGTH);
486        return (parts, Html(res)).into_response();
487    }
488
489    response
490}
491
492/// The fallback handler for all routes that don't match anything else.
493///
494/// # Errors
495///
496/// Returns an error if the template rendering fails.
497pub async fn fallback(
498    State(templates): State<Templates>,
499    OriginalUri(uri): OriginalUri,
500    method: Method,
501    version: Version,
502    PreferredLanguage(locale): PreferredLanguage,
503) -> Result<impl IntoResponse, InternalError> {
504    let ctx = NotFoundContext::new(&method, version, &uri).with_language(locale);
505    // XXX: this should look at the Accept header and return JSON if requested
506
507    let res = templates.render_not_found(&ctx)?;
508
509    Ok((StatusCode::NOT_FOUND, Html(res)))
510}