mas_policy/
lib.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7pub mod model;
8
9use std::sync::Arc;
10
11use arc_swap::ArcSwap;
12use mas_data_model::{SessionLimitConfig, Ulid};
13use opa_wasm::{
14    Runtime,
15    wasmtime::{Config, Engine, Module, OptLevel, Store},
16};
17use serde::Serialize;
18use thiserror::Error;
19use tokio::io::{AsyncRead, AsyncReadExt};
20
21pub use self::model::{
22    AuthorizationGrantInput, ClientRegistrationInput, Code as ViolationCode, CompatLoginInput,
23    EmailInput, EvaluationResult, GrantType, RegisterInput, RegistrationMethod, Requester,
24    Violation,
25};
26
/// Errors which can occur while loading a policy, or while updating the data
/// of an already loaded policy.
#[derive(Debug, Error)]
pub enum LoadError {
    /// Reading the module bytes from the source failed.
    #[error("failed to read module")]
    Read(#[from] tokio::io::Error),

    /// The WASM engine could not be created from the configuration.
    #[error("failed to create WASM engine")]
    Engine(#[source] anyhow::Error),

    /// The blocking task in which the module is compiled did not run to
    /// completion.
    #[error("module compilation task crashed")]
    CompilationTask(#[from] tokio::task::JoinError),

    /// The WASM module could not be compiled.
    #[error("failed to compile WASM module")]
    Compilation(#[source] anyhow::Error),

    /// The policy data could not be serialized or merged.
    #[error("invalid policy data")]
    InvalidData(#[source] anyhow::Error),

    /// Instantiating the policy to validate it failed.
    #[error("failed to instantiate a test instance")]
    Instantiate(#[source] InstantiateError),
}
47
48impl LoadError {
49    /// Creates an example of an invalid data error, used for API response
50    /// documentation
51    #[doc(hidden)]
52    #[must_use]
53    pub fn invalid_data_example() -> Self {
54        Self::InvalidData(anyhow::Error::msg("Failed to merge policy data objects"))
55    }
56}
57
/// Errors which can occur while creating a new policy instance.
#[derive(Debug, Error)]
pub enum InstantiateError {
    /// The WASM runtime could not be created for the compiled module.
    #[error("failed to create WASM runtime")]
    Runtime(#[source] anyhow::Error),

    /// One of the configured entrypoints is not exposed by the policy module.
    #[error("missing entrypoint {entrypoint}")]
    MissingEntrypoint { entrypoint: String },

    /// The policy data could not be loaded into the runtime.
    #[error("failed to load policy data")]
    LoadData(#[source] anyhow::Error),
}
69
/// Names of the policy entrypoints, one per kind of evaluation
#[derive(Debug, Clone)]
pub struct Entrypoints {
    /// Entrypoint evaluated for user registrations
    pub register: String,
    /// Entrypoint evaluated for client registrations
    pub client_registration: String,
    /// Entrypoint evaluated for authorization grants
    pub authorization_grant: String,
    /// Entrypoint evaluated for compatibility logins
    pub compat_login: String,
    /// Entrypoint evaluated for email addresses
    pub email: String,
}

impl Entrypoints {
    /// Lists every entrypoint name, in field declaration order.
    fn all(&self) -> [&str; 5] {
        let Self {
            register,
            client_registration,
            authorization_grant,
            compat_login,
            email,
        } = self;

        [
            register,
            client_registration,
            authorization_grant,
            compat_login,
            email,
        ]
        .map(String::as_str)
    }
}
91
/// Static data handed to the policy: a typed base, plus an optional free-form
/// JSON value which gets merged on top of that base.
#[derive(Debug)]
pub struct Data {
    /// Typed part of the data, serialized to form the base JSON value
    base: BaseData,

    // We will merge this in a custom way, so don't emit as part of the base
    rest: Option<serde_json::Value>,
}
99
/// Typed part of the policy data. It is serialized to JSON before any extra
/// data gets merged into it.
#[derive(Serialize, Debug)]
struct BaseData {
    /// Server name, emitted under the `server_name` key
    server_name: String,

    /// Limits on the number of application sessions that each user can have
    session_limit: Option<SessionLimitConfig>,
}
107
108impl Data {
109    #[must_use]
110    pub fn new(server_name: String, session_limit: Option<SessionLimitConfig>) -> Self {
111        Self {
112            base: BaseData {
113                server_name,
114                session_limit,
115            },
116
117            rest: None,
118        }
119    }
120
121    #[must_use]
122    pub fn with_rest(mut self, rest: serde_json::Value) -> Self {
123        self.rest = Some(rest);
124        self
125    }
126
127    fn to_value(&self) -> Result<serde_json::Value, anyhow::Error> {
128        let base = serde_json::to_value(&self.base)?;
129
130        if let Some(rest) = &self.rest {
131            merge_data(base, rest.clone())
132        } else {
133            Ok(base)
134        }
135    }
136}
137
138fn value_kind(value: &serde_json::Value) -> &'static str {
139    match value {
140        serde_json::Value::Object(_) => "object",
141        serde_json::Value::Array(_) => "array",
142        serde_json::Value::String(_) => "string",
143        serde_json::Value::Number(_) => "number",
144        serde_json::Value::Bool(_) => "boolean",
145        serde_json::Value::Null => "null",
146    }
147}
148
149fn merge_data(
150    mut left: serde_json::Value,
151    right: serde_json::Value,
152) -> Result<serde_json::Value, anyhow::Error> {
153    merge_data_rec(&mut left, right)?;
154    Ok(left)
155}
156
157fn merge_data_rec(
158    left: &mut serde_json::Value,
159    right: serde_json::Value,
160) -> Result<(), anyhow::Error> {
161    match (left, right) {
162        (serde_json::Value::Object(left), serde_json::Value::Object(right)) => {
163            for (key, value) in right {
164                if let Some(left_value) = left.get_mut(&key) {
165                    merge_data_rec(left_value, value)?;
166                } else {
167                    left.insert(key, value);
168                }
169            }
170        }
171        (serde_json::Value::Array(left), serde_json::Value::Array(right)) => {
172            left.extend(right);
173        }
174        // Other values override
175        (serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
176            *left = right;
177        }
178        (serde_json::Value::Bool(left), serde_json::Value::Bool(right)) => {
179            *left = right;
180        }
181        (serde_json::Value::String(left), serde_json::Value::String(right)) => {
182            *left = right;
183        }
184
185        // Null gets overridden by anything
186        (left, right) if left.is_null() => *left = right,
187
188        // Null on the right makes the left value null
189        (left, right) if right.is_null() => *left = right,
190
191        (left, right) => anyhow::bail!(
192            "Cannot merge a {} into a {}",
193            value_kind(&right),
194            value_kind(left),
195        ),
196    }
197
198    Ok(())
199}
200
/// Snapshot of the data currently used to instantiate policies.
struct DynamicData {
    /// ID of the dynamic data which was merged in. It is `None` when the
    /// factory is first loaded, before any dynamic data has been set.
    version: Option<Ulid>,
    /// The static data, merged with the dynamic data once some has been set
    merged: serde_json::Value,
}
205
/// Holds a compiled policy module together with its data, and creates
/// [`Policy`] instances from them.
pub struct PolicyFactory {
    /// WASM engine, used to create a store for each instance
    engine: Engine,
    /// Compiled policy module
    module: Module,
    /// Static data given at load time, merged again whenever dynamic data is
    /// set
    data: Data,
    /// Merged data currently in use, along with the dynamic data version
    dynamic_data: ArcSwap<DynamicData>,
    /// Names of the entrypoints each instance must expose
    entrypoints: Entrypoints,
}
213
214impl PolicyFactory {
215    /// Load the policy from the given data source.
216    ///
217    /// # Errors
218    ///
219    /// Returns an error if the policy can't be loaded or instantiated.
220    #[tracing::instrument(name = "policy.load", skip(source))]
221    pub async fn load(
222        mut source: impl AsyncRead + std::marker::Unpin,
223        data: Data,
224        entrypoints: Entrypoints,
225    ) -> Result<Self, LoadError> {
226        let mut config = Config::default();
227        config.async_support(true);
228        config.cranelift_opt_level(OptLevel::SpeedAndSize);
229
230        let engine = Engine::new(&config).map_err(LoadError::Engine)?;
231
232        // Read and compile the module
233        let mut buf = Vec::new();
234        source.read_to_end(&mut buf).await?;
235        // Compilation is CPU-bound, so spawn that in a blocking task
236        let (engine, module) = tokio::task::spawn_blocking(move || {
237            let module = Module::new(&engine, buf)?;
238            anyhow::Ok((engine, module))
239        })
240        .await?
241        .map_err(LoadError::Compilation)?;
242
243        let merged = data.to_value().map_err(LoadError::InvalidData)?;
244        let dynamic_data = ArcSwap::new(Arc::new(DynamicData {
245            version: None,
246            merged,
247        }));
248
249        let factory = Self {
250            engine,
251            module,
252            data,
253            dynamic_data,
254            entrypoints,
255        };
256
257        // Try to instantiate
258        factory
259            .instantiate()
260            .await
261            .map_err(LoadError::Instantiate)?;
262
263        Ok(factory)
264    }
265
266    /// Set the dynamic data for the policy.
267    ///
268    /// The `dynamic_data` object is merged with the static data given when the
269    /// policy was loaded.
270    ///
271    /// Returns `true` if the data was updated, `false` if the version
272    /// of the dynamic data was the same as the one we already have.
273    ///
274    /// # Errors
275    ///
276    /// Returns an error if the data can't be merged with the static data, or if
277    /// the policy can't be instantiated with the new data.
278    pub async fn set_dynamic_data(
279        &self,
280        dynamic_data: mas_data_model::PolicyData,
281    ) -> Result<bool, LoadError> {
282        // Check if the version of the dynamic data we have is the same as the one we're
283        // trying to set
284        if self.dynamic_data.load().version == Some(dynamic_data.id) {
285            // Don't do anything if the version is the same
286            return Ok(false);
287        }
288
289        let static_data = self.data.to_value().map_err(LoadError::InvalidData)?;
290        let merged = merge_data(static_data, dynamic_data.data).map_err(LoadError::InvalidData)?;
291
292        // Try to instantiate with the new data
293        self.instantiate_with_data(&merged)
294            .await
295            .map_err(LoadError::Instantiate)?;
296
297        // If instantiation succeeds, swap the data
298        self.dynamic_data.store(Arc::new(DynamicData {
299            version: Some(dynamic_data.id),
300            merged,
301        }));
302
303        Ok(true)
304    }
305
306    /// Create a new policy instance.
307    ///
308    /// # Errors
309    ///
310    /// Returns an error if the policy can't be instantiated with the current
311    /// dynamic data.
312    #[tracing::instrument(name = "policy.instantiate", skip_all)]
313    pub async fn instantiate(&self) -> Result<Policy, InstantiateError> {
314        let data = self.dynamic_data.load();
315        self.instantiate_with_data(&data.merged).await
316    }
317
318    async fn instantiate_with_data(
319        &self,
320        data: &serde_json::Value,
321    ) -> Result<Policy, InstantiateError> {
322        let mut store = Store::new(&self.engine, ());
323        let runtime = Runtime::new(&mut store, &self.module)
324            .await
325            .map_err(InstantiateError::Runtime)?;
326
327        // Check that we have the required entrypoints
328        let policy_entrypoints = runtime.entrypoints();
329
330        for e in self.entrypoints.all() {
331            if !policy_entrypoints.contains(e) {
332                return Err(InstantiateError::MissingEntrypoint {
333                    entrypoint: e.to_owned(),
334                });
335            }
336        }
337
338        let instance = runtime
339            .with_data(&mut store, data)
340            .await
341            .map_err(InstantiateError::LoadData)?;
342
343        Ok(Policy {
344            store,
345            instance,
346            entrypoints: self.entrypoints.clone(),
347        })
348    }
349}
350
/// A policy instance, on which the configured entrypoints can be evaluated.
pub struct Policy {
    /// WASM store which backs this instance
    store: Store<()>,
    /// The OPA policy instance, with its data already loaded
    instance: opa_wasm::Policy<opa_wasm::DefaultContext>,
    /// Names of the entrypoints to evaluate
    entrypoints: Entrypoints,
}
356
/// Error returned when evaluating a policy entrypoint fails.
///
/// Both variants share the same display message; the underlying error is
/// exposed as the source.
#[derive(Debug, Error)]
#[error("failed to evaluate policy")]
pub enum EvaluationError {
    /// A JSON (de)serialization error
    Serialization(#[from] serde_json::Error),
    /// Any other error raised while evaluating the entrypoint
    Evaluation(#[from] anyhow::Error),
}
363
364impl Policy {
365    /// Evaluate the 'email' entrypoint.
366    ///
367    /// # Errors
368    ///
369    /// Returns an error if the policy engine fails to evaluate the entrypoint.
370    #[tracing::instrument(
371        name = "policy.evaluate_email",
372        skip_all,
373        fields(
374            %input.email,
375        ),
376    )]
377    pub async fn evaluate_email(
378        &mut self,
379        input: EmailInput<'_>,
380    ) -> Result<EvaluationResult, EvaluationError> {
381        let [res]: [EvaluationResult; 1] = self
382            .instance
383            .evaluate(&mut self.store, &self.entrypoints.email, &input)
384            .await?;
385
386        Ok(res)
387    }
388
389    /// Evaluate the 'register' entrypoint.
390    ///
391    /// # Errors
392    ///
393    /// Returns an error if the policy engine fails to evaluate the entrypoint.
394    #[tracing::instrument(
395        name = "policy.evaluate.register",
396        skip_all,
397        fields(
398            ?input.registration_method,
399            input.username = input.username,
400            input.email = input.email,
401        ),
402    )]
403    pub async fn evaluate_register(
404        &mut self,
405        input: RegisterInput<'_>,
406    ) -> Result<EvaluationResult, EvaluationError> {
407        let [res]: [EvaluationResult; 1] = self
408            .instance
409            .evaluate(&mut self.store, &self.entrypoints.register, &input)
410            .await?;
411
412        Ok(res)
413    }
414
415    /// Evaluate the `client_registration` entrypoint.
416    ///
417    /// # Errors
418    ///
419    /// Returns an error if the policy engine fails to evaluate the entrypoint.
420    #[tracing::instrument(skip(self))]
421    pub async fn evaluate_client_registration(
422        &mut self,
423        input: ClientRegistrationInput<'_>,
424    ) -> Result<EvaluationResult, EvaluationError> {
425        let [res]: [EvaluationResult; 1] = self
426            .instance
427            .evaluate(
428                &mut self.store,
429                &self.entrypoints.client_registration,
430                &input,
431            )
432            .await?;
433
434        Ok(res)
435    }
436
437    /// Evaluate the `authorization_grant` entrypoint.
438    ///
439    /// # Errors
440    ///
441    /// Returns an error if the policy engine fails to evaluate the entrypoint.
442    #[tracing::instrument(
443        name = "policy.evaluate.authorization_grant",
444        skip_all,
445        fields(
446            %input.scope,
447            %input.client.id,
448        ),
449    )]
450    pub async fn evaluate_authorization_grant(
451        &mut self,
452        input: AuthorizationGrantInput<'_>,
453    ) -> Result<EvaluationResult, EvaluationError> {
454        let [res]: [EvaluationResult; 1] = self
455            .instance
456            .evaluate(
457                &mut self.store,
458                &self.entrypoints.authorization_grant,
459                &input,
460            )
461            .await?;
462
463        Ok(res)
464    }
465
466    /// Evaluate the `compat_login` entrypoint.
467    ///
468    /// # Errors
469    ///
470    /// Returns an error if the policy engine fails to evaluate the entrypoint.
471    #[tracing::instrument(
472        name = "policy.evaluate.compat_login",
473        skip_all,
474        fields(
475            %input.user.id,
476        ),
477    )]
478    pub async fn evaluate_compat_login(
479        &mut self,
480        input: CompatLoginInput<'_>,
481    ) -> Result<EvaluationResult, EvaluationError> {
482        let [res]: [EvaluationResult; 1] = self
483            .instance
484            .evaluate(&mut self.store, &self.entrypoints.compat_login, &input)
485            .await?;
486
487        Ok(res)
488    }
489}
490
#[cfg(test)]
mod tests {

    use std::time::SystemTime;

    use super::*;

    /// Entrypoint names used by every test in this module.
    fn make_entrypoints() -> Entrypoints {
        let violation = |package: &str| format!("{package}/violation");

        Entrypoints {
            register: violation("register"),
            client_registration: violation("client_registration"),
            authorization_grant: violation("authorization_grant"),
            compat_login: violation("compat_login"),
            email: violation("email"),
        }
    }

    /// Loads the `policy.wasm` module from the `policies` directory with the
    /// given static data.
    async fn load_factory(data: Data) -> PolicyFactory {
        #[allow(clippy::disallowed_types)]
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("policies")
            .join("policy.wasm");

        let file = tokio::fs::File::open(path).await.unwrap();

        PolicyFactory::load(file, data, make_entrypoints())
            .await
            .unwrap()
    }

    /// Wraps the JSON value in a `PolicyData` with a nil ID, created now.
    fn dynamic_data(data: serde_json::Value) -> mas_data_model::PolicyData {
        mas_data_model::PolicyData {
            id: Ulid::nil(),
            created_at: SystemTime::now().into(),
            data,
        }
    }

    /// Evaluates a password registration for the user `hello` with the given
    /// email address.
    async fn evaluate_registration(policy: &mut Policy, email: &str) -> EvaluationResult {
        let input = RegisterInput {
            registration_method: RegistrationMethod::Password,
            username: "hello",
            email: Some(email),
            requester: Requester {
                ip_address: None,
                user_agent: None,
            },
        };

        policy.evaluate_register(input).await.unwrap()
    }

    #[tokio::test]
    async fn test_register() {
        let rest = serde_json::json!({
            "allowed_domains": ["element.io", "*.element.io"],
            "banned_domains": ["staging.element.io"],
        });
        let data = Data::new("example.com".to_owned(), None).with_rest(rest);

        let factory = load_factory(data).await;
        let mut policy = factory.instantiate().await.unwrap();

        // Domain which isn't in the allowed list
        let res = evaluate_registration(&mut policy, "hello@example.com").await;
        assert!(!res.valid());

        // Subdomain covered by the allowed wildcard
        let res = evaluate_registration(&mut policy, "hello@foo.element.io").await;
        assert!(res.valid());

        // Explicitly banned subdomain
        let res = evaluate_registration(&mut policy, "hello@staging.element.io").await;
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned(), None)).await;

        // Nothing is banned before the dynamic data is set
        let mut policy = factory.instantiate().await.unwrap();
        let res = evaluate_registration(&mut policy, "hello@example.com").await;
        assert!(res.valid());

        // Update the policy data
        let update = dynamic_data(serde_json::json!({
            "emails": {
                "banned_addresses": {
                    "substrings": ["hello"]
                }
            }
        }));
        factory.set_dynamic_data(update).await.unwrap();

        // A fresh instance picks up the new data
        let mut policy = factory.instantiate().await.unwrap();
        let res = evaluate_registration(&mut policy, "hello@example.com").await;
        assert!(!res.valid());
    }

    #[tokio::test]
    async fn test_big_dynamic_data() {
        let factory = load_factory(Data::new("example.com".to_owned(), None)).await;

        // That is around 1 MB of JSON data. Each element is a 5-digit string, so 8
        // characters including the quotes and a comma.
        let substrings: Vec<String> = (0..(1024 * 1024 / 8))
            .map(|i| format!("{:05}", i % 100_000))
            .collect();
        let json =
            serde_json::json!({ "emails": { "banned_addresses": { "substrings": substrings } } });
        factory.set_dynamic_data(dynamic_data(json)).await.unwrap();

        // Try instantiating the policy, make sure 5-digit numbers are banned from email
        // addresses
        let mut policy = factory.instantiate().await.unwrap();
        let res = evaluate_registration(&mut policy, "12345@example.com").await;
        assert!(!res.valid());
    }

    #[test]
    fn test_merge() {
        use serde_json::json as j;

        // Merging objects
        let merged = merge_data(j!({"hello": "world"}), j!({"foo": "bar"})).unwrap();
        assert_eq!(merged, j!({"hello": "world", "foo": "bar"}));

        // Override a value of the same type
        let merged = merge_data(j!({"hello": "world"}), j!({"hello": "john"})).unwrap();
        assert_eq!(merged, j!({"hello": "john"}));

        let merged = merge_data(j!({"hello": true}), j!({"hello": false})).unwrap();
        assert_eq!(merged, j!({"hello": false}));

        let merged = merge_data(j!({"hello": 0}), j!({"hello": 42})).unwrap();
        assert_eq!(merged, j!({"hello": 42}));

        // Override a value of a different type
        merge_data(j!({"hello": "world"}), j!({"hello": 123}))
            .expect_err("Can't merge different types");

        // Merge arrays
        let merged = merge_data(j!({"hello": ["world"]}), j!({"hello": ["john"]})).unwrap();
        assert_eq!(merged, j!({"hello": ["world", "john"]}));

        // Null overrides a value
        let merged = merge_data(j!({"hello": "world"}), j!({"hello": null})).unwrap();
        assert_eq!(merged, j!({"hello": null}));

        // Null gets overridden by a value
        let merged = merge_data(j!({"hello": null}), j!({"hello": "world"})).unwrap();
        assert_eq!(merged, j!({"hello": "world"}));

        // Objects get deeply merged
        let merged =
            merge_data(j!({"a": {"b": {"c": "d"}}}), j!({"a": {"b": {"e": "f"}}})).unwrap();
        assert_eq!(merged, j!({"a": {"b": {"c": "d", "e": "f"}}}));
    }
}