1pub mod model;
8
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use mas_data_model::{SessionLimitConfig, Ulid};
13use opa_wasm::{
14 Runtime,
15 wasmtime::{Config, Engine, Module, OptLevel, Store},
16};
17use serde::Serialize;
18use thiserror::Error;
19use tokio::io::{AsyncRead, AsyncReadExt};
20
21pub use self::model::{
22 AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput,
23 EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester,
24 Violation,
25};
26
27#[derive(Debug, Error)]
28pub enum LoadError {
29 #[error("failed to read module")]
30 Read(#[from] tokio::io::Error),
31
32 #[error("failed to create WASM engine")]
33 Engine(#[source] anyhow::Error),
34
35 #[error("module compilation task crashed")]
36 CompilationTask(#[from] tokio::task::JoinError),
37
38 #[error("failed to compile WASM module")]
39 Compilation(#[source] anyhow::Error),
40
41 #[error("invalid policy data")]
42 InvalidData(#[source] anyhow::Error),
43
44 #[error("failed to instantiate a test instance")]
45 Instantiate(#[source] InstantiateError),
46}
47
48impl LoadError {
49 #[doc(hidden)]
52 #[must_use]
53 pub fn invalid_data_example() -> Self {
54 Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
55 }
56}
57
58#[derive(Debug, Error)]
59pub enum InstantiateError {
60 #[error("failed to create WASM runtime")]
61 Runtime(#[source] anyhow::Error),
62
63 #[error("missing entrypoint {entrypoint}")]
64 MissingEntrypoint { entrypoint: String },
65
66 #[error("failed to load policy data")]
67 LoadData(#[source] anyhow::Error),
68}
69
/// Names of the policy entrypoints evaluated for each kind of request.
///
/// Every one of them must be present in the compiled module, otherwise
/// instantiation fails with `InstantiateError::MissingEntrypoint`.
#[derive(Debug, Clone)]
pub struct Entrypoints {
    /// Entrypoint evaluated by `Policy::evaluate_register`.
    pub register: String,
    /// Entrypoint evaluated by `Policy::evaluate_client_registration`.
    pub client_registration: String,
    /// Entrypoint evaluated by `Policy::evaluate_authorization_grant`.
    pub authorization_grant: String,
    /// Entrypoint evaluated by `Policy::evaluate_compat_login`.
    pub compat_login: String,
    /// Entrypoint evaluated by `Policy::evaluate_email`.
    pub email: String,
}
79
80impl Entrypoints {
81 fn all(&self) -> [&str; 5] {
82 [
83 self.register.as_str(),
84 self.client_registration.as_str(),
85 self.authorization_grant.as_str(),
86 self.compat_login.as_str(),
87 self.email.as_str(),
88 ]
89 }
90}
91
92#[derive(Debug)]
93pub struct Data {
94 base: BaseData,
95
96 rest: Option<serde_json::Value>,
98}
99
100#[derive(Serialize, Debug)]
101struct BaseData {
102 server_name: String,
103
104 session_limit: Option<SessionLimitConfig>,
106}
107
108impl Data {
109 #[must_use]
110 pub fn new(server_name: String, session_limit: Option<SessionLimitConfig>) -> Self {
111 Self {
112 base: BaseData {
113 server_name,
114 session_limit,
115 },
116
117 rest: None,
118 }
119 }
120
121 #[must_use]
122 pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
123 self.rest = Some(rest);
124 self
125 }
126
127 fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
128 let base = serde_json::to_value(&self.base)?;
129
130 if let Some(rest) = &self.rest {
131 merge_data(base, rest.clone())
132 } else {
133 Ok(base)
134 }
135 }
136}
137
138fn value_kind(value: &serde_json::Value) -> &'static str {
139 match value {
140 serde_json::Value::Object(_) => "object",
141 serde_json::Value::Array(_) => "array",
142 serde_json::Value::String(_) => "string",
143 serde_json::Value::Number(_) => "number",
144 serde_json::Value::Bool(_) => "boolean",
145 serde_json::Value::Null => "null",
146 }
147}
148
149fn merge_data(
150 mut left: serde_json::Value,
151 right: serde_json::Value,
152) -> Result<serde_json::Value, anyhow::Error> {
153 merge_data_rec(&mut left, right)?;
154 Ok(left)
155}
156
157fn merge_data_rec(
158 left: &mut serde_json::Value,
159 right: serde_json::Value,
160) -> Result<(), anyhow::Error> {
161 match (left, right) {
162 (serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
163 for (key, value) in right {
164 if let Some(left_value) = left.get_mut(&key) {
165 merge_data_rec(left_value, value)?;
166 } else {
167 left.insert(key, value);
168 }
169 }
170 }
171 (serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
172 left.extend(right);
173 }
174 (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
176 *left = right;
177 }
178 (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
179 *left = right;
180 }
181 (serde_json::Value::String(left), serde_json::Value::String(right)) => {
182 *left = right;
183 }
184
185 (left, right) if left.is_null() => *left = right,
187
188 (left, right) if right.is_null() => *left = right,
190
191 (left, right) => anyhow::bail!(
192 "Cannot merge a {} into a {}",
193 value_kind(&right),
194 value_kind(left),
195 ),
196 }
197
198 Ok(())
199}
200
201struct DynamicData {
202 version: Option<Ulid>,
203 merged: serde_json::Value,
204}
205
206pub struct PolicyFactory {
207 engine: Engine,
208 module: Module,
209 data: Data,
210 dynamic_data: ArcSwap<DynamicData>,
211 entrypoints: Entrypoints,
212}
213
214impl PolicyFactory {
215 #[tracing::instrument(name = "policy.load", skip(source))]
221 pub async fn load(
222 mut source: impl AsyncRead + std::marker::Unpin,
223 data: Data,
224 entrypoints: Entrypoints,
225 ) -> Result<Self, LoadError> {
226 let mut config = Config::default();
227 config.async_support(true);
228 config.cranelift_opt_level(OptLevel::SpeedAndSize);
229
230 let engine = Engine::new(&config).map_err(LoadError::Engine)?;
231
232 let mut buf = Vec::new();
234 source.read_to_end(&mut buf).await?;
235 let (engine, module) = tokio::task::spawn_blocking(move || {
237 let module = Module::new(&engine, buf)?;
238 anyhow::Ok((engine, module))
239 })
240 .await?
241 .map_err(LoadError::Compilation)?;
242
243 let merged = data.to_value().map_err(LoadError::InvalidData)?;
244 let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
245 version: None,
246 merged,
247 }));
248
249 let factory = Self {
250 engine,
251 module,
252 data,
253 dynamic_data,
254 entrypoints,
255 };
256
257 factory
259 .instantiate()
260 .await
261 .map_err(LoadError::Instantiate)?;
262
263 Ok(factory)
264 }
265
266 pub async fn set_dynamic_data(
279 &self,
280 dynamic_data: mas_data_model::PolicyData,
281 ) -> Result<bool, LoadError> {
282 if self.dynamic_data.load().version == Some(dynamic_data.id) {
285 return Ok(false);
287 }
288
289 let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
290 let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
291
292 self.instantiate_with_data(&merged)
294 .await
295 .map_err(LoadError::Instantiate)?;
296
297 self.dynamic_data.store(Arc::new(DynamicData {
299 version: Some(dynamic_data.id),
300 merged,
301 }));
302
303 Ok(true)
304 }
305
306 #[tracing::instrument(name = "policy.instantiate", skip_all)]
313 pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
314 let data = self.dynamic_data.load();
315 self.instantiate_with_data(&data.merged).await
316 }
317
318 async fn instantiate_with_data(
319 &self,
320 data: &serde_json::Value,
321 ) -> Result<Policy, InstantiateError> {
322 let mut store = Store::new(&self.engine, ());
323 let runtime = Runtime::new(&mut store, &self.module)
324 .await
325 .map_err(InstantiateError::Runtime)?;
326
327 let policy_entrypoints = runtime.entrypoints();
329
330 for e in self.entrypoints.all() {
331 if !policy_entrypoints.contains(e) {
332 return Err(InstantiateError::MissingEntrypoint {
333 entrypoint: e.to_owned(),
334 });
335 }
336 }
337
338 let instance = runtime
339 .with_data(&mut store, data)
340 .await
341 .map_err(InstantiateError::LoadData)?;
342
343 Ok(Policy {
344 store,
345 instance,
346 entrypoints: self.entrypoints.clone(),
347 })
348 }
349}
350
351pub struct Policy {
352 store: Store<()>,
353 instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
354 entrypoints: Entrypoints,
355}
356
357#[derive(Debug, Error)]
358#[error("failed to evaluate policy")]
359pub enum EvaluationError {
360 Serialization(#[from] serde_json::Error),
361 Evaluation(#[from] anyhow::Error),
362}
363
364impl Policy {
365 #[tracing::instrument(
371 name = "policy.evaluate_email",
372 skip_all,
373 fields(
374 %input.email,
375 ),
376 )]
377 pub async fn evaluate_email(
378 &mut self,
379 input: EmailInput<'_>,
380 ) -> Result<EvaluationResult, EvaluationError> {
381 let [res]: [EvaluationResult; 1] = self
382 .instance
383 .evaluate(&mut self.store, &self.entrypoints.email, &input)
384 .await?;
385
386 Ok(res)
387 }
388
389 #[tracing::instrument(
395 name = "policy.evaluate.register",
396 skip_all,
397 fields(
398 ?input.registration_method,
399 input.username = input.username,
400 input.email = input.email,
401 ),
402 )]
403 pub async fn evaluate_register(
404 &mut self,
405 input: RegisterInput<'_>,
406 ) -> Result<EvaluationResult, EvaluationError> {
407 let [res]: [EvaluationResult; 1] = self
408 .instance
409 .evaluate(&mut self.store, &self.entrypoints.register, &input)
410 .await?;
411
412 Ok(res)
413 }
414
415 #[tracing::instrument(skip(self))]
421 pub async fn evaluate_client_registration(
422 &mut self,
423 input: ClientRegistrationInput<'_>,
424 ) -> Result<EvaluationResult, EvaluationError> {
425 let [res]: [EvaluationResult; 1] = self
426 .instance
427 .evaluate(
428 &mut self.store,
429 &self.entrypoints.client_registration,
430 &input,
431 )
432 .await?;
433
434 Ok(res)
435 }
436
437 #[tracing::instrument(
443 name = "policy.evaluate.authorization_grant",
444 skip_all,
445 fields(
446 %input.scope,
447 %input.client.id,
448 ),
449 )]
450 pub async fn evaluate_authorization_grant(
451 &mut self,
452 input: AuthorizationGrantInput<'_>,
453 ) -> Result<EvaluationResult, EvaluationError> {
454 let [res]: [EvaluationResult; 1] = self
455 .instance
456 .evaluate(
457 &mut self.store,
458 &self.entrypoints.authorization_grant,
459 &input,
460 )
461 .await?;
462
463 Ok(res)
464 }
465
466 #[tracing::instrument(
472 name = "policy.evaluate.compat_login",
473 skip_all,
474 fields(
475 %input.user.id,
476 ),
477 )]
478 pub async fn evaluate_compat_login(
479 &mut self,
480 input: CompatLoginInput<'_>,
481 ) -> Result<EvaluationResult, EvaluationError> {
482 let [res]: [EvaluationResult; 1] = self
483 .instance
484 .evaluate(&mut self.store, &self.entrypoints.compat_login, &input)
485 .await?;
486
487 Ok(res)
488 }
489}
490
#[cfg(test)]
mod tests {
    use std::time::SystemTime;

    use super::*;

    /// Entrypoint names for the rules in the repository's compiled policy bundle.
    fn make_entrypoints() -> Entrypoints {
        Entrypoints {
            register: "register/violation".to_owned(),
            client_registration: "client_registration/violation".to_owned(),
            authorization_grant: "authorization_grant/violation".to_owned(),
            compat_login: "compat_login/violation".to_owned(),
            email: "email/violation".to_owned(),
        }
    }

    /// Loads `policies/policy.wasm` (two directories above this crate's
    /// manifest) with the given static data.
    async fn load_factory(data: Data) -> PolicyFactory {
        // `std::path::Path` is a disallowed type in this workspace's clippy
        // config; it is only used here to locate the test fixture on disk.
        #[allow(clippy::disallowed_types)]
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("policies")
            .join("policy.wasm");

        let file = tokio::fs::File::open(path).await.unwrap();

        PolicyFactory::load(file, data, make_entrypoints())
            .await
            .unwrap()
    }

    /// A password registration for username `hello` with the given email and
    /// no requester information.
    fn register_input(email: &str) -> RegisterInput<'_> {
        RegisterInput {
            registration_method: RegistrationMethod::Password,
            username: "hello",
            email: Some(email),
            requester: Requester {
                ip_address: None,
                user_agent: None,
            },
        }
    }

    /// Dynamic policy data carrying a fixed (nil) version id.
    fn nil_policy_data(data: serde_json::Value) -> mas_data_model::PolicyData {
        mas_data_model::PolicyData {
            id: Ulid::nil(),
            created_at: SystemTime::now().into(),
            data,
        }
    }

    #[tokio::test]
    async fn test_register() {
        let data = Data::new("example.com".to_owned(), None).with_rest(serde_json::json!({
            "allowed_domains": ["element.io", "*.element.io"],
            "banned_domains": ["staging.element.io"],
        }));
        let factory = load_factory(data).await;
        let mut policy = factory.instantiate().await.unwrap();

        for (email, expected_valid) in [
            // Not in the allowed domains.
            ("hello@example.com", false),
            // Covered by the `*.element.io` allowed pattern.
            ("hello@foo.element.io", true),
            // Covered by the allowed pattern, but explicitly banned.
            ("hello@staging.element.io", false),
        ] {
            let res = policy
                .evaluate_register(register_input(email))
                .await
                .unwrap();
            assert_eq!(res.valid(), expected_valid, "unexpected result for {email}");
        }
    }

    #[tokio::test]
    async fn test_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned(), None)).await;

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(register_input("hello@example.com"))
            .await
            .unwrap();
        assert!(res.valid());

        let changed = factory
            .set_dynamic_data(nil_policy_data(serde_json::json!({
                "emails": {
                    "banned_addresses": {
                        "substrings": ["hello"]
                    }
                }
            })))
            .await
            .unwrap();
        assert!(changed, "first dynamic data must be applied");

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(register_input("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());

        // Same version id again: the update is skipped, so the data from the
        // previous call (which bans "hello") must remain active.
        let changed = factory
            .set_dynamic_data(nil_policy_data(serde_json::json!({})))
            .await
            .unwrap();
        assert!(!changed, "re-applying the same version must be a no-op");

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(register_input("hello@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_big_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned(), None)).await;

        // 131 072 five-digit entries: about 1 MiB once JSON-encoded
        // (8 bytes per entry, counting quotes and the separating comma).
        let substrings: Vec<String> = (0..(1024 * 1024 / 8))
            .map(|i| format!("{:05}", i % 100_000))
            .collect();
        let changed = factory
            .set_dynamic_data(nil_policy_data(serde_json::json!({
                "emails": { "banned_addresses": { "substrings": substrings } }
            })))
            .await
            .unwrap();
        assert!(changed);

        let mut policy = factory.instantiate().await.unwrap();
        let res = policy
            .evaluate_register(register_input("12345@example.com"))
            .await
            .unwrap();
        assert!(!res.valid());
    }

    #[test]
    fn test_merge() {
        use serde_json::json as j;

        let res = merge_data(j!({"hello": "world"}), j!({"foo": "bar"})).unwrap();
        assert_eq!(res, j!({"hello": "world", "foo": "bar"}));

        let res = merge_data(j!({"hello": "world"}), j!({"hello": "john"})).unwrap();
        assert_eq!(res, j!({"hello": "john"}));

        let res = merge_data(j!({"hello": true}), j!({"hello": false})).unwrap();
        assert_eq!(res, j!({"hello": false}));

        let res = merge_data(j!({"hello": 0}), j!({"hello": 42})).unwrap();
        assert_eq!(res, j!({"hello": 42}));

        merge_data(j!({"hello": "world"}), j!({"hello": 123}))
            .expect_err("Can't merge different types");

        let res = merge_data(j!({"hello": ["world"]}), j!({"hello": ["john"]})).unwrap();
        assert_eq!(res, j!({"hello": ["world", "john"]}));

        // Arrays are concatenated as-is, without deduplication.
        let res = merge_data(j!([1]), j!([1])).unwrap();
        assert_eq!(res, j!([1, 1]));

        let res = merge_data(j!({"hello": "world"}), j!({"hello": null})).unwrap();
        assert_eq!(res, j!({"hello": null}));

        let res = merge_data(j!({"hello": null}), j!({"hello": "world"})).unwrap();
        assert_eq!(res, j!({"hello": "world"}));

        // A null on the right also replaces a whole object.
        let res = merge_data(j!({"a": {"b": 1}}), j!({"a": null})).unwrap();
        assert_eq!(res, j!({"a": null}));

        let res = merge_data(j!({"a": {"b": {"c": "d"}}}), j!({"a": {"b": {"e": "f"}}})).unwrap();
        assert_eq!(res, j!({"a": {"b": {"c": "d", "e": "f"}}}));

        // Kind conflicts are detected at any depth.
        merge_data(j!({"a": {"b": [1]}}), j!({"a": {"b": {"c": 1}}}))
            .expect_err("Can't merge an object into an array");

        // Top-level scalars merge like nested ones.
        let res = merge_data(j!(1), j!(2)).unwrap();
        assert_eq!(res, j!(2));
    }
}