use std::env; use std::net::SocketAddr; use chrono::Utc; use mail_send::{Credentials, SmtpClientBuilder}; use mail_send::mail_builder::MessageBuilder; use rand::Rng; use regex::Regex; use sha3::{Digest, Sha3_256}; use sha3::digest::Update; use sqlx::{MySql, Pool, Row}; use tarpc::context::Context; use crate::types::{AuthEmail, AuthUser, ErrorCode, RealmAuth}; use crate::types::ErrorCode::*; #[derive(Clone)] pub struct RealmAuthServer { pub socket: SocketAddr, pub db_pool: Pool, pub auth_email: AuthEmail, pub template_html: String, pub template_txt: String, pub domain: String, } impl RealmAuthServer { pub fn new(socket: SocketAddr, db_pool: Pool, auth_email: AuthEmail) -> RealmAuthServer { RealmAuthServer { socket, db_pool, auth_email, template_html: std::fs::read_to_string("./login_email.html").expect("A login_email.html file is needed"), template_txt: std::fs::read_to_string("./login_email.txt").expect("A login_email.txt file is needed"), domain: env::var("DOMAIN").expect("DOMAIN must be set"), } } pub fn gen_login_code(&self) -> u16 { let mut rng = rand::thread_rng(); let mut login_code: u16 = 0; for n in 1..=6 { if n == 1 { login_code += rng.gen_range(1..=9); } login_code += rng.gen_range(0..=9) * (10^n); } login_code } pub async fn is_username_taken(&self, username: &str) -> Result { let result = sqlx::query("SELECT NOT EXISTS (SELECT 1 FROM user WHERE username = ?) AS does_exist") .bind(username) .fetch_one(&self.db_pool).await; match result { Ok(row) => Ok(row.try_get("does_exist").unwrap()), Err(_) => Err(InvalidUsername) } } pub async fn is_email_taken(&self, email: &str) -> Result { let result = sqlx::query("SELECT NOT EXISTS (SELECT 1 FROM user WHERE email = ?) AS does_exist") .bind(email) .fetch_one(&self.db_pool).await; match result { Ok(row) => Ok(row.try_get("does_exist").unwrap()), Err(_) => Err(InvalidUsername) } } pub async fn is_authorized(&self, username: &str, token: &str) -> Result { let result = sqlx::query("SELECT tokens FROM user WHERE username = ?") .bind(username).fetch_one(&self.db_pool).await; match result { Ok(row) => { let token_long: &str = row.try_get("tokens").unwrap(); let tokens = token_long.split(',').collect::>(); for i in 0..tokens.len() { if tokens.get(i).unwrap() == &token { return Ok(true) } } Ok(false) }, Err(_) => Err(InvalidUsername), } } pub async fn send_login_message(&self, username: &str, email: &str, login_code: u16) -> Result<(), ErrorCode> { let message = MessageBuilder::new() .from((self.auth_email.auth_name.clone(), self.auth_email.auth_username.clone())) .to(vec![ (username, email), ]) .subject(format!("Realm confirmation code: {}", &login_code)) .html_body(self.template_html.replace("{}", &login_code.to_string())) .text_body(self.template_txt.replace("{}", &login_code.to_string())); let result = SmtpClientBuilder::new(&self.auth_email.server_address, self.auth_email.server_port) .implicit_tls(false) .credentials(Credentials::new(&self.auth_email.auth_username, &self.auth_email.auth_password)) .connect() .await; match result { Ok(mut client) => { let result = client.send(message).await; match result { Ok(_) => { Ok(()) } Err(_) => { Err(UnableToSendMail) } } } Err(_) => { Err(UnableToConnectToMail) } } } pub async fn is_login_code_valid(&self, username: &str, login_code: u16) -> Result { let result = sqlx::query("SELECT login_code FROM user WHERE username = ?;") .bind(username) .fetch_one(&self.db_pool).await; match result { Ok(row) => { if row.try_get::("login_code").unwrap() != login_code { return Ok(false) } Ok(true) } Err(_) => Err(InvalidUsername) } } pub fn is_username_valid(&self, username: &str) -> bool { if !username.starts_with('@') || !username.contains(':') { return false } let name = &username[1..username.find(':').unwrap()]; let domain = &username[username.find(':').unwrap()+1..]; let re = Regex::new(r"^[a-zA-Z0-9]+$").unwrap(); if !re.is_match(name) { return false } if !domain.eq(&self.domain) { return false } true } } impl RealmAuth for RealmAuthServer { async fn test(self, _: Context, name: String) -> String { format!("Hello {} auth!", name) } async fn server_token_validation(self, _: Context, server_token: String, username: String, server_id: String, domain: String, tarpc_port: u16) -> bool { let result = sqlx::query("SELECT tokens FROM user WHERE username = ?").bind(username).fetch_one(&self.db_pool).await; match result { Ok(row) => { let token_long: &str = row.try_get("tokens").unwrap(); let tokens = token_long.split(',').collect::>(); for token in tokens { let hash = Sha3_256::new().chain(format!("{}{}{}{}", token, server_id, domain, tarpc_port)).finalize(); if hex::encode(hash) == server_token { return true } } false }, Err(_) => false, } } async fn create_account_flow(self, _: Context, username: String, email: String) -> Result<(), ErrorCode> { if !self.is_username_valid(&username) { return Err(InvalidUsername) } if self.is_username_taken(&username).await? { return Err(UsernameTaken) } if self.is_email_taken(&email).await? { return Err(EmailTaken) } let code = self.gen_login_code(); self.send_login_message(&username, &email, code).await?; let result = sqlx::query("INSERT INTO user (username, email, avatar, login_code, tokens) VALUES (?, ?, '', ?, '')") .bind(&username).bind(&email).bind(code).execute(&self.db_pool).await; match result { Ok(_) => Ok(()), Err(_) => Err(Error) } } async fn create_login_flow(self, _: Context, mut username: Option, mut email: Option) -> Result<(), ErrorCode> { if username.is_none() && email.is_none() { return Err(Error) } if username.is_none() { let result = sqlx::query("SELECT username FROM user WHERE email = ?;") .bind(&email.clone().unwrap()) .fetch_one(&self.db_pool).await; match result { Ok(row) => { username = row.try_get("username").unwrap(); } Err(_) => return Err(InvalidEmail) } } if email.is_none() { let result = sqlx::query("SELECT email FROM user WHERE username = ?;") .bind(&username.clone().unwrap()) .fetch_one(&self.db_pool).await; match result { Ok(row) => { email = row.try_get("email").unwrap(); } Err(_) => return Err(InvalidUsername) } } let code = self.gen_login_code(); let result = sqlx::query("UPDATE user SET login_code = ? WHERE username = ?;") .bind(code) .bind(&username) .execute(&self.db_pool).await; match result { Ok(_) => self.send_login_message(&username.unwrap(), &email.unwrap(), code).await, Err(_) => Err(InvalidUsername) } } async fn finish_login_flow(self, _: Context, username: String, login_code: u16) -> Result { if !self.is_login_code_valid(&username, login_code).await? { return Err(InvalidLoginCode) } let _ = sqlx::query("UPDATE user SET login_code = NULL WHERE username = ?").bind(&username).execute(&self.db_pool).await; let hash = Sha3_256::new().chain(format!("{}{}{}", username, login_code, Utc::now().to_utc())).finalize(); let token = hex::encode(hash); let result = sqlx::query("SELECT tokens FROM user WHERE username = ?").bind(&username).fetch_one(&self.db_pool).await; match result { Ok(row) => { let token_long: &str = row.try_get("tokens").unwrap(); let mut tokens = token_long.split(',').collect::>(); tokens.push(&token); let result = sqlx::query("UPDATE user SET tokens = ? WHERE username = ?") .bind(tokens.join(",")) // TODO: This doesn't seem right and may cause problems .bind(&username) .execute(&self.db_pool).await; match result { Ok(_) => Ok(token), Err(_) => Err(InvalidUsername) } } Err(_) => Err(InvalidUsername) } } async fn change_email_flow(self, _: Context, username: String, new_email: String, token: String) -> Result<(), ErrorCode> { if !self.is_authorized(&username, &token).await? { return Err(Unauthorized) } if self.is_email_taken(&new_email).await? { return Err(EmailTaken) } let result = sqlx::query("UPDATE user SET new_email = ? WHERE username = ?") .bind(&new_email) .bind(&username) .execute(&self.db_pool).await; match result { Ok(_) => {} Err(_) => return Err(InvalidUsername) } let code = self.gen_login_code(); let result = sqlx::query("UPDATE user SET login_code = ? WHERE username = ?;") .bind(code) .bind(&username) .execute(&self.db_pool).await; match result { Ok(_) => self.send_login_message(&username, &new_email, code).await, Err(_) => Err(InvalidUsername) } } async fn finish_change_email_flow(self, _: Context, username: String, new_email: String, token: String, login_code: u16) -> Result<(), ErrorCode> { if !self.is_authorized(&username, &token).await? { return Err(Unauthorized) } if self.is_email_taken(&new_email).await? { return Err(EmailTaken) } if !self.is_login_code_valid(&username, login_code).await? { return Err(InvalidLoginCode) } let _ = sqlx::query("UPDATE user SET new_email = NULL WHERE username = ?") .bind(&username) .execute(&self.db_pool).await; let _ = sqlx::query("UPDATE user SET email = ? WHERE username = ?") .bind(&new_email) .bind(&username) .execute(&self.db_pool).await; Ok(()) } async fn change_username(self, _: Context, username: String, token: String, new_username: String) -> Result<(), ErrorCode> { if !self.is_username_valid(&new_username) { return Err(InvalidUsername) } if !self.is_authorized(&username, &token).await? { return Err(Unauthorized) } if self.is_username_taken(&new_username).await? { return Err(UsernameTaken) } let result = sqlx::query("UPDATE user SET username = ? WHERE username = ?") .bind(&new_username) .bind(&username).execute(&self.db_pool).await; match result { Ok(_) => Ok(()), Err(_) => Err(Error) } } async fn change_avatar(self, _: Context, username: String, token: String, new_avatar: String) -> Result<(), ErrorCode> { if !self.is_authorized(&username, &token).await? { return Err(Unauthorized) } let result = sqlx::query("UPDATE user SET avatar = ? WHERE username = ?") .bind(&new_avatar) .bind(&username).execute(&self.db_pool).await; match result { Ok(_) => Ok(()), Err(_) => Err(Error) } } async fn get_all_data(self, _: Context, username: String, token: String) -> Result { let result = self.is_authorized(&username, &token).await; match result { Ok(authorized) => { if !authorized { return Err(Unauthorized) } } Err(error) => return Err(error) } let result = sqlx::query("SELECT * FROM user WHERE username = ?") .bind(&username).fetch_one(&self.db_pool).await; match result { Ok(row) => { Ok(AuthUser { id: row.try_get("id").unwrap(), username: row.try_get("username").unwrap(), email: row.try_get("email").unwrap(), avatar: row.try_get("avatar").unwrap(), login_code: None, bigtoken: row.try_get("tokens").unwrap(), google_oauth: row.try_get("google_oauth").unwrap(), apple_oauth: row.try_get("apple_oauth").unwrap(), github_oauth: row.try_get("github_oauth").unwrap(), discord_oauth: row.try_get("discord_oauth").unwrap(), }) } Err(_) => Err(InvalidUsername) } } async fn sign_out(self, _: Context, username: String, token: String) -> Result<(), ErrorCode> { let result = sqlx::query("SELECT tokens FROM user WHERE username = ?") .bind(&username).fetch_one(&self.db_pool).await; match result { Ok(row) => { let token_long: &str = row.try_get("tokens").unwrap(); let mut tokens = token_long.split(',').collect::>(); for i in 0..tokens.len() { if tokens.get(i).unwrap().eq(&token.as_str()) { tokens.remove(i); let result = sqlx::query("UPDATE user SET tokens = ? WHERE username = ?") .bind(tokens.join(",")) .bind(&username) .execute(&self.db_pool).await; return match result { Ok(_) => Ok(()), Err(_) => Err(Error) }; } } Err(Unauthorized) }, Err(_) => Err(InvalidUsername), } } async fn get_avatar_for_user(self, _: Context, username: String) -> Result { let result = sqlx::query("SELECT tokens FROM user WHERE username = ?").bind(username).fetch_one(&self.db_pool).await; match result { Ok(row) => Ok(row.try_get("avatar").unwrap_or("".to_string())), Err(_) => Err(InvalidUsername) } } }