diff --git a/auth/Cargo.toml b/auth/Cargo.toml index 502997d..a892163 100644 --- a/auth/Cargo.toml +++ b/auth/Cargo.toml @@ -17,3 +17,4 @@ sha3 = "0.10.8" hex = "0.4.3" rand = "0.8.5" mail-send = "0.4.8" +regex = "1.10.5" diff --git a/auth/src/server.rs b/auth/src/server.rs index d12fef2..1a9376b 100644 --- a/auth/src/server.rs +++ b/auth/src/server.rs @@ -1,12 +1,14 @@ +use std::env; use std::net::SocketAddr; + use chrono::Utc; use mail_send::{Credentials, SmtpClientBuilder}; use mail_send::mail_builder::MessageBuilder; use rand::Rng; +use regex::Regex; use sha3::{Digest, Sha3_256}; use sha3::digest::Update; use sqlx::{MySql, Pool, Row}; -use sqlx::mysql::{MySqlQueryResult, MySqlRow}; use tarpc::context::Context; use crate::types::{AuthEmail, AuthUser, ErrorCode, RealmAuth}; @@ -19,6 +21,7 @@ pub struct RealmAuthServer { pub auth_email: AuthEmail, pub template_html: String, pub template_txt: String, + pub domain: String, } impl RealmAuthServer { @@ -29,6 +32,7 @@ impl RealmAuthServer { auth_email, template_html: std::fs::read_to_string("./login_email.html").expect("A login_email.html file is needed"), template_txt: std::fs::read_to_string("./login_email.txt").expect("A login_email.txt file is needed"), + domain: env::var("DOMAIN").expect("DOMAIN must be set"), } } @@ -122,7 +126,7 @@ impl RealmAuthServer { } } } - + pub async fn is_login_code_valid(&self, username: &str, login_code: u16) -> Result { let result = sqlx::query("SELECT login_code FROM user WHERE username = ?;") .bind(username) @@ -136,7 +140,27 @@ impl RealmAuthServer { Ok(true) } Err(_) => Err(InvalidUsername) - } + } + } + + pub fn is_username_valid(&self, username: &str) -> bool { + if !username.starts_with('@') || !username.contains(':') { + return false + } + + let name = &username[1..username.find(':').unwrap()]; + let domain = &username[username.find(':').unwrap()+1..]; + + let re = Regex::new(r"^[a-zA-Z0-9]+$").unwrap(); + if !re.is_match(name) { + return false + } + + if !domain.eq(&self.domain) { + return false + } + + true } } @@ -167,18 +191,20 @@ impl RealmAuth for RealmAuthServer { } async fn create_account_flow(self, _: Context, username: String, email: String) -> Result<(), ErrorCode> { - //TODO: USERNAME FORMATTING! + if !self.is_username_valid(&username) { + return Err(InvalidUsername) + } if self.is_username_taken(&username).await? { return Err(UsernameTaken) } - + if self.is_email_taken(&email).await? { return Err(EmailTaken) } let code = self.gen_login_code(); - let _ = self.send_login_message(&username, &email, code).await?; + self.send_login_message(&username, &email, code).await?; let result = sqlx::query("INSERT INTO user (username, email, avatar, login_code, tokens) VALUES (?, ?, '', ?, '')") .bind(&username).bind(&email).bind(code).execute(&self.db_pool).await; @@ -267,11 +293,11 @@ impl RealmAuth for RealmAuthServer { if !self.is_authorized(&username, &token).await? { return Err(Unauthorized) } - + if self.is_email_taken(&new_email).await? { return Err(EmailTaken) } - + let result = sqlx::query("UPDATE user SET new_email = ? WHERE username = ?") .bind(&new_email) .bind(&username) @@ -310,23 +336,24 @@ impl RealmAuth for RealmAuthServer { let _ = sqlx::query("UPDATE user SET new_email = NULL WHERE username = ?") .bind(&username) .execute(&self.db_pool).await; - + let _ = sqlx::query("UPDATE user SET email = ? WHERE username = ?") .bind(&new_email) .bind(&username) .execute(&self.db_pool).await; - + Ok(()) } async fn change_username(self, _: Context, username: String, token: String, new_username: String) -> Result<(), ErrorCode> { - //TODO: USERNAME FORMATTING! - + if !self.is_username_valid(&new_username) { + return Err(InvalidUsername) + } if !self.is_authorized(&username, &token).await? { return Err(Unauthorized) } - + if self.is_username_taken(&new_username).await? { return Err(UsernameTaken) } @@ -403,7 +430,7 @@ impl RealmAuth for RealmAuthServer { .bind(tokens.join(",")) .bind(&username) .execute(&self.db_pool).await; - + return match result { Ok(_) => Ok(()), Err(_) => Err(Error) diff --git a/auth/src/types.rs b/auth/src/types.rs index e927a36..b800ab3 100644 --- a/auth/src/types.rs +++ b/auth/src/types.rs @@ -18,10 +18,7 @@ pub trait RealmAuth { //NOTE: Anyone can call async fn get_avatar_for_user(username: String) -> Result; - //TODO: - // Create account - // Change username - // OAuth login, check against email, store token, take avatar: Google, Apple, GitHub, Discord + // TODO: OAuth login, check against email, store token, take avatar: Google, Apple, GitHub, Discord } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]