diff --git a/auth/Cargo.toml b/auth/Cargo.toml index d4ee8c7..ee6d02d 100644 --- a/auth/Cargo.toml +++ b/auth/Cargo.toml @@ -12,7 +12,7 @@ tracing = "0.1.40" serde = { version = "1.0.203", features = ["derive"] } chrono = { version = "0.4.24", features = ["serde"] } dotenvy = "0.15" -sqlx = { version = "0.7", features = [ "runtime-tokio", "tls-rustls", "mysql", "chrono" ] } +sqlx = { version = "0.7", features = [ "runtime-tokio", "tls-rustls", "sqlite", "macros", "migrate", "chrono" ] } sha3 = "0.10.8" hex = "0.4.3" rand = "0.8.5" diff --git a/auth/build.rs b/auth/build.rs new file mode 100644 index 0000000..d506869 --- /dev/null +++ b/auth/build.rs @@ -0,0 +1,5 @@ +// generated by `sqlx migrate build-script` +fn main() { + // trigger recompilation when a new migration is added + println!("cargo:rerun-if-changed=migrations"); +} diff --git a/auth/migrations/20240725215330_create_everything.sql b/auth/migrations/20240725215330_create_everything.sql new file mode 100644 index 0000000..17a3c1a --- /dev/null +++ b/auth/migrations/20240725215330_create_everything.sql @@ -0,0 +1,14 @@ +-- Add migration script here +CREATE TABLE IF NOT EXISTS user ( + id INTEGER PRIMARY KEY, + username VARCHAR(255) NOT NULL, + email VARCHAR(255) NOT NULL, + new_email VARCHAR(255), + avatar TEXT NOT NULL, + login_code INT(6), + tokens TEXT, + google_oauth VARCHAR(255), + apple_oauth VARCHAR(255), + github_oauth VARCHAR(255), + discord_oauth VARCHAR(255) + ); \ No newline at end of file diff --git a/auth/src/main.rs b/auth/src/main.rs index b9f4196..1972669 100644 --- a/auth/src/main.rs +++ b/auth/src/main.rs @@ -3,7 +3,8 @@ use std::future::Future; use std::net::{IpAddr, Ipv6Addr}; use dotenvy::dotenv; use futures::{future, StreamExt}; -use sqlx::mysql::MySqlPoolOptions; +use sqlx::{migrate, Sqlite, SqlitePool}; +use sqlx::migrate::MigrateDatabase; use tarpc::server::{BaseChannel, Channel}; use tarpc::server::incoming::Incoming; use tarpc::tokio_serde::formats::Json; @@ -27,28 +28,21 @@ async fn main() -> anyhow::Result<()> { auth_password: env::var("SERVER_MAIL_PASSWORD").expect("SERVER_MAIL_PASSWORD must be set"), }; - let db_pool = MySqlPoolOptions::new() - .max_connections(64) - .connect(env::var("DATABASE_URL").expect("DATABASE_URL must be set").as_str()).await?; + let DB_URL: &str = &env::var("DATABASE_URL").expect("DATABASE_URL must be set"); - //TODO: In a docker container or figure out somewhere to do this command - //sqlx::query("CREATE DATABASE IF NOT EXISTS realmauth").execute(&db_pool).await?; + if !Sqlite::database_exists(DB_URL).await.unwrap_or(false) { + println!("Creating database {}", DB_URL); + match Sqlite::create_database(DB_URL).await { + Ok(_) => println!("Create db success"), + Err(error) => panic!("error: {}", error), + } + } else { + println!("Database already exists"); + } // TODO: Do in Docker with Sqlx-cli + + let db_pool = SqlitePool::connect(DB_URL).await.unwrap(); - sqlx::query( - "CREATE TABLE IF NOT EXISTS user ( - id SERIAL, - username VARCHAR(255) NOT NULL, - email VARCHAR(255) NOT NULL, - new_email VARCHAR(255), - avatar TEXT NOT NULL, - login_code INT(6), - tokens TEXT, - google_oauth VARCHAR(255), - apple_oauth VARCHAR(255), - github_oauth VARCHAR(255), - discord_oauth VARCHAR(255) - );" - ).execute(&db_pool).await?; + migrate!().run(&db_pool).await?; // TODO: Do in Docker with Sqlx-cli let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), env::var("PORT").expect("PORT must be set").parse::().unwrap()); diff --git a/auth/src/server.rs b/auth/src/server.rs index 7c68ec0..573e741 100644 --- a/auth/src/server.rs +++ b/auth/src/server.rs @@ -8,7 +8,7 @@ use rand::Rng; use regex::Regex; use sha3::{Digest, Sha3_256}; use sha3::digest::Update; -use sqlx::{MySql, Pool, Row}; +use sqlx::{Pool, query, Sqlite}; use tarpc::context::Context; use crate::types::{AuthEmail, AuthUser, RealmAuth}; @@ -18,7 +18,7 @@ use realm_shared::types::ErrorCode::*; #[derive(Clone)] pub struct RealmAuthServer { pub socket: SocketAddr, - pub db_pool: Pool, + pub db_pool: Pool, pub auth_email: AuthEmail, pub template_html: String, pub template_txt: String, @@ -26,7 +26,7 @@ pub struct RealmAuthServer { } impl RealmAuthServer { - pub fn new(socket: SocketAddr, db_pool: Pool, auth_email: AuthEmail) -> RealmAuthServer { + pub fn new(socket: SocketAddr, db_pool: Pool, auth_email: AuthEmail) -> RealmAuthServer { RealmAuthServer { socket, db_pool, @@ -52,34 +52,29 @@ impl RealmAuthServer { } pub async fn is_username_taken(&self, username: &str) -> Result { - let result = sqlx::query("SELECT NOT EXISTS (SELECT 1 FROM user WHERE username = ?) AS does_exist") - .bind(username) - .fetch_one(&self.db_pool).await; + let result = query!("SELECT NOT EXISTS (SELECT 1 FROM user WHERE username = ?) AS does_exist", username).fetch_one(&self.db_pool).await; match result { - Ok(row) => Ok(row.try_get("does_exist").unwrap()), + Ok(row) => Ok(row.does_exist.unwrap() != 0), Err(_) => Err(InvalidUsername) } } pub async fn is_email_taken(&self, email: &str) -> Result { - let result = sqlx::query("SELECT NOT EXISTS (SELECT 1 FROM user WHERE email = ?) AS does_exist") - .bind(email) - .fetch_one(&self.db_pool).await; + let result = query!("SELECT NOT EXISTS (SELECT 1 FROM user WHERE email = ?) AS does_exist", email).fetch_one(&self.db_pool).await; match result { - Ok(row) => Ok(row.try_get("does_exist").unwrap()), + Ok(row) => Ok(row.does_exist.unwrap() != 0), Err(_) => Err(InvalidUsername) } } pub async fn is_authorized(&self, username: &str, token: &str) -> Result { - let result = sqlx::query("SELECT tokens FROM user WHERE username = ?") - .bind(username).fetch_one(&self.db_pool).await; + let result = query!("SELECT tokens FROM user WHERE username = ?", username).fetch_one(&self.db_pool).await; match result { Ok(row) => { - let token_long: &str = row.try_get("tokens").unwrap(); + let token_long: &str = &row.tokens.unwrap(); let tokens = token_long.split(',').collect::>(); for i in 0..tokens.len() { @@ -129,13 +124,11 @@ impl RealmAuthServer { } pub async fn is_login_code_valid(&self, username: &str, login_code: u16) -> Result { - let result = sqlx::query("SELECT login_code FROM user WHERE username = ?;") - .bind(username) - .fetch_one(&self.db_pool).await; + let result = query!("SELECT login_code FROM user WHERE username = ?;", username).fetch_one(&self.db_pool).await; match result { Ok(row) => { - if row.try_get::("login_code").unwrap() != login_code { + if row.login_code.unwrap() as u16 != login_code { return Ok(false) } Ok(true) @@ -171,11 +164,11 @@ impl RealmAuth for RealmAuthServer { } async fn server_token_validation(self, _: Context, server_token: String, username: String, server_id: String, domain: String, tarpc_port: u16) -> bool { - let result = sqlx::query("SELECT tokens FROM user WHERE username = ?").bind(username).fetch_one(&self.db_pool).await; + let result = query!("SELECT tokens FROM user WHERE username = ?", username).fetch_one(&self.db_pool).await; match result { Ok(row) => { - let token_long: &str = row.try_get("tokens").unwrap(); + let token_long: &str = &row.tokens.unwrap(); let tokens = token_long.split(',').collect::>(); for token in tokens { @@ -207,8 +200,8 @@ impl RealmAuth for RealmAuthServer { let code = self.gen_login_code(); self.send_login_message(&username, &email, code).await?; - let result = sqlx::query("INSERT INTO user (username, email, avatar, login_code, tokens) VALUES (?, ?, '', ?, '')") - .bind(&username).bind(&email).bind(code).execute(&self.db_pool).await; + let result = query!("INSERT INTO user (username, email, avatar, login_code, tokens) VALUES (?, ?, '', ?, '')", username, email, code) + .execute(&self.db_pool).await; match result { Ok(_) => Ok(()), @@ -222,26 +215,26 @@ impl RealmAuth for RealmAuthServer { } if username.is_none() { - let result = sqlx::query("SELECT username FROM user WHERE email = ?;") - .bind(&email.clone().unwrap()) + let tmp = email.clone().unwrap(); + let result = query!("SELECT username FROM user WHERE email = ?;", tmp) .fetch_one(&self.db_pool).await; match result { Ok(row) => { - username = row.try_get("username").unwrap(); + username = Some(row.username); } Err(_) => return Err(InvalidEmail) } } - if email.is_none() { - let result = sqlx::query("SELECT email FROM user WHERE username = ?;") - .bind(&username.clone().unwrap()) + if email.clone().is_none() { + let tmp = username.clone().unwrap(); + let result = query!("SELECT email FROM user WHERE username = ?;", tmp) .fetch_one(&self.db_pool).await; match result { Ok(row) => { - email = row.try_get("email").unwrap(); + email = Some(row.email); } Err(_) => return Err(InvalidUsername) } @@ -249,9 +242,7 @@ impl RealmAuth for RealmAuthServer { let code = self.gen_login_code(); - let result = sqlx::query("UPDATE user SET login_code = ? WHERE username = ?;") - .bind(code) - .bind(&username) + let result = query!("UPDATE user SET login_code = ? WHERE username = ?;", code, username) .execute(&self.db_pool).await; match result { @@ -265,21 +256,20 @@ impl RealmAuth for RealmAuthServer { return Err(InvalidLoginCode) } - let _ = sqlx::query("UPDATE user SET login_code = NULL WHERE username = ?").bind(&username).execute(&self.db_pool).await; + let _ = query!("UPDATE user SET login_code = NULL WHERE username = ?", username).execute(&self.db_pool).await; let hash = Sha3_256::new().chain(format!("{}{}{}", username, login_code, Utc::now().to_utc())).finalize(); let token = hex::encode(hash); - let result = sqlx::query("SELECT tokens FROM user WHERE username = ?").bind(&username).fetch_one(&self.db_pool).await; + let result = query!("SELECT tokens FROM user WHERE username = ?", username).fetch_one(&self.db_pool).await; match result { Ok(row) => { - let token_long: &str = row.try_get("tokens").unwrap(); + let token_long: &str = &row.tokens.unwrap(); let mut tokens = token_long.split(',').collect::>(); tokens.push(&token); - let result = sqlx::query("UPDATE user SET tokens = ? WHERE username = ?") - .bind(tokens.join(",")) // TODO: This doesn't seem right and may cause problems - .bind(&username) + let mega_token = tokens.join(","); + let result = query!("UPDATE user SET tokens = ? WHERE username = ?", mega_token, username) .execute(&self.db_pool).await; match result { Ok(_) => Ok(token), @@ -299,9 +289,7 @@ impl RealmAuth for RealmAuthServer { return Err(EmailTaken) } - let result = sqlx::query("UPDATE user SET new_email = ? WHERE username = ?") - .bind(&new_email) - .bind(&username) + let result = query!("UPDATE user SET new_email = ? WHERE username = ?", new_email, username) .execute(&self.db_pool).await; match result { Ok(_) => {} @@ -310,9 +298,7 @@ impl RealmAuth for RealmAuthServer { let code = self.gen_login_code(); - let result = sqlx::query("UPDATE user SET login_code = ? WHERE username = ?;") - .bind(code) - .bind(&username) + let result = query!("UPDATE user SET login_code = ? WHERE username = ?;", code, username) .execute(&self.db_pool).await; match result { @@ -334,14 +320,9 @@ impl RealmAuth for RealmAuthServer { return Err(InvalidLoginCode) } - let _ = sqlx::query("UPDATE user SET new_email = NULL WHERE username = ?") - .bind(&username) - .execute(&self.db_pool).await; + let _ = query!("UPDATE user SET new_email = NULL WHERE username = ?", username).execute(&self.db_pool).await; - let _ = sqlx::query("UPDATE user SET email = ? WHERE username = ?") - .bind(&new_email) - .bind(&username) - .execute(&self.db_pool).await; + let _ = query!("UPDATE user SET email = ? WHERE username = ?", new_email, username).execute(&self.db_pool).await; Ok(()) } @@ -359,9 +340,7 @@ impl RealmAuth for RealmAuthServer { return Err(UsernameTaken) } - let result = sqlx::query("UPDATE user SET username = ? WHERE username = ?") - .bind(&new_username) - .bind(&username).execute(&self.db_pool).await; + let result = query!("UPDATE user SET username = ? WHERE username = ?", new_username, username).execute(&self.db_pool).await; match result { Ok(_) => Ok(()), Err(_) => Err(Error) @@ -373,9 +352,7 @@ impl RealmAuth for RealmAuthServer { return Err(Unauthorized) } - let result = sqlx::query("UPDATE user SET avatar = ? WHERE username = ?") - .bind(&new_avatar) - .bind(&username).execute(&self.db_pool).await; + let result = query!("UPDATE user SET avatar = ? WHERE username = ?", new_avatar, username).execute(&self.db_pool).await; match result { Ok(_) => Ok(()), Err(_) => Err(Error) @@ -393,21 +370,20 @@ impl RealmAuth for RealmAuthServer { Err(error) => return Err(error) } - let result = sqlx::query("SELECT * FROM user WHERE username = ?") - .bind(&username).fetch_one(&self.db_pool).await; + let result = query!(r"SELECT * FROM user WHERE username = ?", username).fetch_one(&self.db_pool).await; match result { Ok(row) => { Ok(AuthUser { - id: row.try_get("id").unwrap(), - username: row.try_get("username").unwrap(), - email: row.try_get("email").unwrap(), - avatar: row.try_get("avatar").unwrap(), + id: row.id, + username: row.username, + email: row.email, + avatar: row.avatar, login_code: None, - bigtoken: row.try_get("tokens").unwrap(), - google_oauth: row.try_get("google_oauth").unwrap(), - apple_oauth: row.try_get("apple_oauth").unwrap(), - github_oauth: row.try_get("github_oauth").unwrap(), - discord_oauth: row.try_get("discord_oauth").unwrap(), + bigtoken: row.tokens, + google_oauth: row.google_oauth, + apple_oauth: row.apple_oauth, + github_oauth: row.github_oauth, + discord_oauth: row.discord_oauth, }) } Err(_) => Err(InvalidUsername) @@ -415,21 +391,19 @@ impl RealmAuth for RealmAuthServer { } async fn sign_out(self, _: Context, username: String, token: String) -> Result<(), ErrorCode> { - let result = sqlx::query("SELECT tokens FROM user WHERE username = ?") - .bind(&username).fetch_one(&self.db_pool).await; + let result = query!("SELECT tokens FROM user WHERE username = ?", username).fetch_one(&self.db_pool).await; match result { Ok(row) => { - let token_long: &str = row.try_get("tokens").unwrap(); + let token_long: &str = &row.tokens.unwrap(); let mut tokens = token_long.split(',').collect::>(); for i in 0..tokens.len() { if tokens.get(i).unwrap().eq(&token.as_str()) { tokens.remove(i); - let result = sqlx::query("UPDATE user SET tokens = ? WHERE username = ?") - .bind(tokens.join(",")) - .bind(&username) + let mega_token = tokens.join(",").to_string(); + let result = query!("UPDATE user SET tokens = ? WHERE username = ?", mega_token, username) .execute(&self.db_pool).await; return match result { @@ -446,10 +420,10 @@ impl RealmAuth for RealmAuthServer { } async fn get_avatar_for_user(self, _: Context, username: String) -> Result { - let result = sqlx::query("SELECT tokens FROM user WHERE username = ?").bind(username).fetch_one(&self.db_pool).await; + let result = query!("SELECT avatar FROM user WHERE username = ?", username).fetch_one(&self.db_pool).await; match result { - Ok(row) => Ok(row.try_get("avatar").unwrap_or("".to_string())), + Ok(row) => Ok(row.avatar), Err(_) => Err(InvalidUsername) } } diff --git a/auth/src/types.rs b/auth/src/types.rs index 511ea47..b23cdbb 100644 --- a/auth/src/types.rs +++ b/auth/src/types.rs @@ -24,7 +24,7 @@ pub trait RealmAuth { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AuthUser { - pub id: u32, + pub id: i64, pub username: String, pub email: String, pub avatar: String,