1#![deny(missing_docs)]
8#![allow(clippy::module_name_repetitions)]
9
10use std::{
13 collections::{BTreeMap, HashSet},
14 sync::Arc,
15};
16
17use anyhow::Context as _;
18use arc_swap::ArcSwap;
19use camino::{Utf8Path, Utf8PathBuf};
20use mas_i18n::Translator;
21use mas_router::UrlBuilder;
22use mas_spa::ViteManifest;
23use minijinja::{UndefinedBehavior, Value};
24use rand::Rng;
25use serde::Serialize;
26use thiserror::Error;
27use tokio::task::JoinError;
28use tracing::{debug, info};
29use walkdir::DirEntry;
30
31mod context;
32mod forms;
33mod functions;
34
35#[macro_use]
36mod macros;
37
38pub use self::{
39 context::{
40 AccountInactiveContext, ApiDocContext, AppContext, CompatLoginPolicyViolationContext,
41 CompatSsoContext, ConsentContext, DeviceConsentContext, DeviceLinkContext,
42 DeviceLinkFormField, DeviceNameContext, EmailRecoveryContext, EmailVerificationContext,
43 EmptyContext, ErrorContext, FormPostContext, IndexContext, LoginContext, LoginFormField,
44 NotFoundContext, PasswordRegisterContext, PolicyViolationContext, PostAuthContext,
45 PostAuthContextInner, RecoveryExpiredContext, RecoveryFinishContext,
46 RecoveryFinishFormField, RecoveryProgressContext, RecoveryStartContext,
47 RecoveryStartFormField, RegisterContext, RegisterFormField,
48 RegisterStepsDisplayNameContext, RegisterStepsDisplayNameFormField,
49 RegisterStepsEmailInUseContext, RegisterStepsRegistrationTokenContext,
50 RegisterStepsRegistrationTokenFormField, RegisterStepsVerifyEmailContext,
51 RegisterStepsVerifyEmailFormField, SiteBranding, SiteConfigExt, SiteFeatures,
52 TemplateContext, UpstreamExistingLinkContext, UpstreamRegister, UpstreamRegisterFormField,
53 UpstreamSuggestLink, WithCaptcha, WithCsrf, WithLanguage, WithOptionalSession, WithSession,
54 },
55 forms::{FieldError, FormError, FormField, FormState, ToFormState},
56};
57use crate::context::SampleIdentifier;
58
59#[must_use]
63pub fn escape_html(input: &str) -> String {
64 v_htmlescape::escape(input).to_string()
65}
66
/// Collection of templates and translations used to render pages and emails.
///
/// The environment and the translator live behind [`ArcSwap`]s shared through
/// [`Arc`], so every clone of this value shares them and observes the result
/// of [`Templates::reload`].
#[derive(Debug, Clone)]
pub struct Templates {
    /// The `minijinja` environment holding the compiled templates.
    environment: Arc<ArcSwap<minijinja::Environment<'static>>>,
    /// The translator, also handed to the template functions on load and
    /// exposed through [`Templates::translator`].
    translator: Arc<ArcSwap<Translator>>,
    /// URL builder handed to the template functions on every (re)load.
    url_builder: UrlBuilder,
    /// Site branding, exposed to templates as the `branding` global.
    branding: SiteBranding,
    /// Site features, exposed to templates as the `features` global.
    features: SiteFeatures,
    /// Optional path to the Vite assets manifest.
    vite_manifest_path: Option<Utf8PathBuf>,
    /// Path the translations are loaded from.
    translations_path: Utf8PathBuf,
    /// Root directory the templates are loaded from.
    path: Utf8PathBuf,
    /// Whether undefined variables use [`UndefinedBehavior::Strict`] rather
    /// than [`UndefinedBehavior::SemiStrict`].
    strict: bool,
}
83
/// Error returned when the templates could not be loaded.
#[derive(Error, Debug)]
pub enum TemplateLoadingError {
    /// An I/O error occurred, e.g. while canonicalizing the templates path or
    /// reading a template file.
    #[error(transparent)]
    IO(#[from] std::io::Error),

    /// Failed to read the assets manifest file.
    #[error("failed to read the assets manifest")]
    ViteManifestIO(#[source] std::io::Error),

    /// Failed to deserialize the assets manifest.
    #[error("invalid assets manifest")]
    ViteManifest(#[from] serde_json::Error),

    /// Failed to load the translations.
    #[error("failed to load the translations")]
    Translations(#[from] mas_i18n::LoadError),

    /// Failed to traverse the templates directory.
    #[error("failed to traverse the filesystem")]
    WalkDir(#[from] walkdir::Error),

    /// Encountered a path that is not valid UTF-8 (borrowed path conversion).
    #[error("encountered non-UTF-8 path")]
    NonUtf8Path(#[from] camino::FromPathError),

    /// Encountered a path that is not valid UTF-8 (owned path conversion).
    #[error("encountered non-UTF-8 path")]
    NonUtf8PathBuf(#[from] camino::FromPathBufError),

    /// A path could not be made relative to the templates root.
    #[error("encountered invalid path")]
    InvalidPath(#[from] std::path::StripPrefixError),

    /// Some templates could not be added to (compiled into) the environment.
    #[error("could not load and compile some templates")]
    Compile(#[from] minijinja::Error),

    /// A blocking task spawned on the async runtime did not complete
    /// successfully (e.g. it panicked or was cancelled).
    #[error("error from async runtime")]
    Runtime(#[from] JoinError),

    /// Some of the templates required by this crate were not found.
    #[error("missing templates {missing:?}")]
    MissingTemplates {
        /// Names of the required templates that were not found.
        missing: HashSet<String>,

        /// Names of the templates that were found and loaded.
        loaded: HashSet<String>,
    },
}
136
137fn is_hidden(entry: &DirEntry) -> bool {
138 entry
139 .file_name()
140 .to_str()
141 .is_some_and(|s| s.starts_with('.'))
142}
143
144impl Templates {
145 #[tracing::instrument(
156 name = "templates.load",
157 skip_all,
158 fields(%path),
159 )]
160 pub async fn load(
161 path: Utf8PathBuf,
162 url_builder: UrlBuilder,
163 vite_manifest_path: Option<Utf8PathBuf>,
164 translations_path: Utf8PathBuf,
165 branding: SiteBranding,
166 features: SiteFeatures,
167 strict: bool,
168 ) -> Result<Self, TemplateLoadingError> {
169 let (translator, environment) = Self::load_(
170 &path,
171 url_builder.clone(),
172 vite_manifest_path.as_deref(),
173 &translations_path,
174 branding.clone(),
175 features,
176 strict,
177 )
178 .await?;
179 Ok(Self {
180 environment: Arc::new(ArcSwap::new(environment)),
181 translator: Arc::new(ArcSwap::new(translator)),
182 path,
183 url_builder,
184 vite_manifest_path,
185 translations_path,
186 branding,
187 features,
188 strict,
189 })
190 }
191
192 async fn load_(
193 path: &Utf8Path,
194 url_builder: UrlBuilder,
195 vite_manifest_path: Option<&Utf8Path>,
196 translations_path: &Utf8Path,
197 branding: SiteBranding,
198 features: SiteFeatures,
199 strict: bool,
200 ) -> Result<(Arc<Translator>, Arc<minijinja::Environment<'static>>), TemplateLoadingError> {
201 let path = path.to_owned();
202 let span = tracing::Span::current();
203
204 let vite_manifest = if let Some(vite_manifest_path) = vite_manifest_path {
206 let raw_vite_manifest = tokio::fs::read(vite_manifest_path)
207 .await
208 .map_err(TemplateLoadingError::ViteManifestIO)?;
209
210 Some(
211 serde_json::from_slice::<ViteManifest>(&raw_vite_manifest)
212 .map_err(TemplateLoadingError::ViteManifest)?,
213 )
214 } else {
215 None
216 };
217
218 let translations_path = translations_path.to_owned();
221 let translator =
222 tokio::task::spawn_blocking(move || Translator::load_from_path(&translations_path))
223 .await??;
224 let translator = Arc::new(translator);
225
226 debug!(locales = ?translator.available_locales(), "Loaded translations");
227
228 let (loaded, mut env) = tokio::task::spawn_blocking(move || {
229 span.in_scope(move || {
230 let mut loaded: HashSet<_> = HashSet::new();
231 let mut env = minijinja::Environment::new();
232 env.set_undefined_behavior(if strict {
234 UndefinedBehavior::Strict
235 } else {
236 UndefinedBehavior::SemiStrict
240 });
241 let root = path.canonicalize_utf8()?;
242 info!(%root, "Loading templates from filesystem");
243 for entry in walkdir::WalkDir::new(&root)
244 .min_depth(1)
245 .into_iter()
246 .filter_entry(|e| !is_hidden(e))
247 {
248 let entry = entry?;
249 if entry.file_type().is_file() {
250 let path = Utf8PathBuf::try_from(entry.into_path())?;
251 let Some(ext) = path.extension() else {
252 continue;
253 };
254
255 if ext == "html" || ext == "txt" || ext == "subject" {
256 let relative = path.strip_prefix(&root)?;
257 debug!(%relative, "Registering template");
258 let template = std::fs::read_to_string(&path)?;
259 env.add_template_owned(relative.as_str().to_owned(), template)?;
260 loaded.insert(relative.as_str().to_owned());
261 }
262 }
263 }
264
265 Ok::<_, TemplateLoadingError>((loaded, env))
266 })
267 })
268 .await??;
269
270 env.add_global("branding", Value::from_object(branding));
271 env.add_global("features", Value::from_object(features));
272
273 self::functions::register(
274 &mut env,
275 url_builder,
276 vite_manifest,
277 Arc::clone(&translator),
278 );
279
280 let env = Arc::new(env);
281
282 let needed: HashSet<_> = TEMPLATES.into_iter().map(ToOwned::to_owned).collect();
283 debug!(?loaded, ?needed, "Templates loaded");
284 let missing: HashSet<_> = needed.difference(&loaded).cloned().collect();
285
286 if missing.is_empty() {
287 Ok((translator, env))
288 } else {
289 Err(TemplateLoadingError::MissingTemplates { missing, loaded })
290 }
291 }
292
293 #[tracing::instrument(
299 name = "templates.reload",
300 skip_all,
301 fields(path = %self.path),
302 )]
303 pub async fn reload(&self) -> Result<(), TemplateLoadingError> {
304 let (translator, environment) = Self::load_(
305 &self.path,
306 self.url_builder.clone(),
307 self.vite_manifest_path.as_deref(),
308 &self.translations_path,
309 self.branding.clone(),
310 self.features,
311 self.strict,
312 )
313 .await?;
314
315 self.environment.store(environment);
317 self.translator.store(translator);
318
319 Ok(())
320 }
321
322 #[must_use]
324 pub fn translator(&self) -> Arc<Translator> {
325 self.translator.load_full()
326 }
327}
328
/// Error returned when a template could not be rendered.
#[derive(Error, Debug)]
pub enum TemplateError {
    /// The requested template could not be found.
    #[error("missing template {template:?}")]
    Missing {
        /// The name of the template that was requested.
        template: &'static str,

        /// The underlying `minijinja` error.
        #[source]
        source: minijinja::Error,
    },

    /// The template was found but rendering it failed.
    #[error("could not render template {template:?}")]
    Render {
        /// The name of the template that failed to render.
        template: &'static str,

        /// The underlying `minijinja` error.
        #[source]
        source: minijinja::Error,
    },
}
354
register_templates! {
    /// Render the not found fallback page (`pages/404.html`)
    pub fn render_not_found(WithLanguage<NotFoundContext>) { "pages/404.html" }

    /// Render the frontend app shell (`app.html`)
    pub fn render_app(WithLanguage<AppContext>) { "app.html" }

    /// Render the Swagger API documentation page
    pub fn render_swagger(ApiDocContext) { "swagger/doc.html" }

    /// Render the Swagger OAuth 2.0 redirect callback page
    pub fn render_swagger_callback(ApiDocContext) { "swagger/oauth2-redirect.html" }

    /// Render the login page
    pub fn render_login(WithLanguage<WithCsrf<LoginContext>>) { "pages/login.html" }

    /// Render the registration page
    pub fn render_register(WithLanguage<WithCsrf<RegisterContext>>) { "pages/register/index.html" }

    /// Render the password registration page
    pub fn render_password_register(WithLanguage<WithCsrf<WithCaptcha<PasswordRegisterContext>>>) { "pages/register/password.html" }

    /// Render the email verification step of the registration flow
    pub fn render_register_steps_verify_email(WithLanguage<WithCsrf<RegisterStepsVerifyEmailContext>>) { "pages/register/steps/verify_email.html" }

    /// Render the "email in use" step of the registration flow
    pub fn render_register_steps_email_in_use(WithLanguage<RegisterStepsEmailInUseContext>) { "pages/register/steps/email_in_use.html" }

    /// Render the display name step of the registration flow
    pub fn render_register_steps_display_name(WithLanguage<WithCsrf<RegisterStepsDisplayNameContext>>) { "pages/register/steps/display_name.html" }

    /// Render the registration token step of the registration flow
    pub fn render_register_steps_registration_token(WithLanguage<WithCsrf<RegisterStepsRegistrationTokenContext>>) { "pages/register/steps/registration_token.html" }

    /// Render the consent page
    pub fn render_consent(WithLanguage<WithCsrf<WithSession<ConsentContext>>>) { "pages/consent.html" }

    /// Render the policy violation page
    pub fn render_policy_violation(WithLanguage<WithCsrf<WithSession<PolicyViolationContext>>>) { "pages/policy_violation.html" }

    /// Render the compatibility login policy violation page
    pub fn render_compat_login_policy_violation(WithLanguage<WithCsrf<WithSession<CompatLoginPolicyViolationContext>>>) { "pages/compat_login_policy_violation.html" }

    /// Render the compatibility SSO login page
    pub fn render_sso_login(WithLanguage<WithCsrf<WithSession<CompatSsoContext>>>) { "pages/sso.html" }

    /// Render the index page
    pub fn render_index(WithLanguage<WithCsrf<WithOptionalSession<IndexContext>>>) { "pages/index.html" }

    /// Render the account recovery start page
    pub fn render_recovery_start(WithLanguage<WithCsrf<RecoveryStartContext>>) { "pages/recovery/start.html" }

    /// Render the account recovery progress page
    pub fn render_recovery_progress(WithLanguage<WithCsrf<RecoveryProgressContext>>) { "pages/recovery/progress.html" }

    /// Render the account recovery finish page
    pub fn render_recovery_finish(WithLanguage<WithCsrf<RecoveryFinishContext>>) { "pages/recovery/finish.html" }

    /// Render the account recovery "link expired" page
    pub fn render_recovery_expired(WithLanguage<WithCsrf<RecoveryExpiredContext>>) { "pages/recovery/expired.html" }

    /// Render the account recovery "link already consumed" page
    pub fn render_recovery_consumed(WithLanguage<EmptyContext>) { "pages/recovery/consumed.html" }

    /// Render the "account recovery disabled" page
    pub fn render_recovery_disabled(WithLanguage<EmptyContext>) { "pages/recovery/disabled.html" }

    /// Render the form post page (`form_post.html`) for any serializable form context
    pub fn render_form_post<#[sample(EmptyContext)] T: Serialize>(WithLanguage<FormPostContext<T>>) { "form_post.html" }

    /// Render the HTML error page
    pub fn render_error(ErrorContext) { "pages/error.html" }

    /// Render the plain text body of the account recovery email
    pub fn render_email_recovery_txt(WithLanguage<EmailRecoveryContext>) { "emails/recovery.txt" }

    /// Render the HTML body of the account recovery email
    pub fn render_email_recovery_html(WithLanguage<EmailRecoveryContext>) { "emails/recovery.html" }

    /// Render the subject of the account recovery email
    pub fn render_email_recovery_subject(WithLanguage<EmailRecoveryContext>) { "emails/recovery.subject" }

    /// Render the plain text body of the email verification email
    pub fn render_email_verification_txt(WithLanguage<EmailVerificationContext>) { "emails/verification.txt" }

    /// Render the HTML body of the email verification email
    pub fn render_email_verification_html(WithLanguage<EmailVerificationContext>) { "emails/verification.html" }

    /// Render the subject of the email verification email
    pub fn render_email_verification_subject(WithLanguage<EmailVerificationContext>) { "emails/verification.subject" }

    /// Render the upstream OAuth 2.0 link mismatch page
    pub fn render_upstream_oauth2_link_mismatch(WithLanguage<WithCsrf<WithSession<UpstreamExistingLinkContext>>>) { "pages/upstream_oauth2/link_mismatch.html" }

    /// Render the upstream OAuth 2.0 login link page
    pub fn render_upstream_oauth2_login_link(WithLanguage<WithCsrf<UpstreamExistingLinkContext>>) { "pages/upstream_oauth2/login_link.html" }

    /// Render the upstream OAuth 2.0 link suggestion page
    pub fn render_upstream_oauth2_suggest_link(WithLanguage<WithCsrf<WithSession<UpstreamSuggestLink>>>) { "pages/upstream_oauth2/suggest_link.html" }

    /// Render the upstream OAuth 2.0 registration page
    pub fn render_upstream_oauth2_do_register(WithLanguage<WithCsrf<UpstreamRegister>>) { "pages/upstream_oauth2/do_register.html" }

    /// Render the device link page
    pub fn render_device_link(WithLanguage<DeviceLinkContext>) { "pages/device_link.html" }

    /// Render the device consent page
    pub fn render_device_consent(WithLanguage<WithCsrf<WithSession<DeviceConsentContext>>>) { "pages/device_consent.html" }

    /// Render the "account deactivated" page
    pub fn render_account_deactivated(WithLanguage<WithCsrf<AccountInactiveContext>>) { "pages/account/deactivated.html" }

    /// Render the "account locked" page
    pub fn render_account_locked(WithLanguage<WithCsrf<AccountInactiveContext>>) { "pages/account/locked.html" }

    /// Render the "account logged out" page
    pub fn render_account_logged_out(WithLanguage<WithCsrf<AccountInactiveContext>>) { "pages/account/logged_out.html" }

    /// Render the device name from `device_name.txt`
    pub fn render_device_name(WithLanguage<DeviceNameContext>) { "device_name.txt" }
}
476
477impl Templates {
478 pub fn check_render<R: Rng + Clone>(
491 &self,
492 now: chrono::DateTime<chrono::Utc>,
493 rng: &R,
494 ) -> anyhow::Result<BTreeMap<(&'static str, SampleIdentifier), String>> {
495 check::all(self, now, rng)
496 }
497}
498
#[cfg(test)]
mod tests {
    use rand::SeedableRng;

    use super::*;

    /// Load the templates shipped in the repository, with and without the
    /// real assets manifest, and check that rendering every sample twice
    /// yields identical output.
    #[tokio::test]
    async fn check_builtin_templates() {
        // `chrono::Utc::now` is presumably listed in clippy's
        // `disallowed_methods`; reading the real clock is fine here because
        // both renders below use the same `now`.
        #[allow(clippy::disallowed_methods)]
        let now = chrono::Utc::now();
        let rng = rand_chacha::ChaCha8Rng::from_seed([42; 32]);

        let manifest_dir = Utf8Path::new(env!("CARGO_MANIFEST_DIR"));
        let path = manifest_dir.join("../../templates/");
        let vite_manifest_path = manifest_dir.join("../../frontend/dist/manifest.json");
        let translations_path = manifest_dir.join("../../translations");

        let url_builder = UrlBuilder::new("https://example.com/".parse().unwrap(), None, None);
        let branding = SiteBranding::new("example.com");
        let features = SiteFeatures {
            password_login: true,
            password_registration: true,
            password_registration_email_required: true,
            account_recovery: true,
            login_with_email_allowed: true,
        };

        // First with the real assets manifest, then without any manifest.
        for vite_manifest in [Some(vite_manifest_path), None] {
            let templates = Templates::load(
                path.clone(),
                url_builder.clone(),
                vite_manifest,
                translations_path.clone(),
                branding.clone(),
                features,
                true,
            )
            .await
            .unwrap();

            let first = templates.check_render(now, &rng).unwrap();
            let second = templates.check_render(now, &rng).unwrap();

            assert_eq!(first, second);
        }
    }
}