diff --git a/auth/Cargo.toml b/auth/Cargo.toml index b653122..dfdeca2 100644 --- a/auth/Cargo.toml +++ b/auth/Cargo.toml @@ -12,4 +12,6 @@ tracing = "0.1.40" serde = { version = "1.0.203", features = ["derive"] } chrono = { version = "0.4.24", features = ["serde"] } dotenvy = "0.15" -sqlx = { version = "0.7", features = [ "runtime-tokio", "tls-rustls", "mysql", "chrono" ] } \ No newline at end of file +sqlx = { version = "0.7", features = [ "runtime-tokio", "tls-rustls", "mysql", "chrono" ] } +sha3 = "0.10.8" +hex = "0.4.3" diff --git a/auth/src/main.rs b/auth/src/main.rs index 11652bc..2535e2e 100644 --- a/auth/src/main.rs +++ b/auth/src/main.rs @@ -31,6 +31,7 @@ async fn main() -> anyhow::Result<()> { id SERIAL, username VARCHAR(255) NOT NULL, email VARCHAR(255) NOT NULL, + avatar TEXT NOT NULL login_code INT(6), tokens TEXT, google_oauth VARCHAR(255), @@ -56,7 +57,7 @@ async fn main() -> anyhow::Result<()> { // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { - let server = RealmAuthServer::new(channel.transport().peer_addr().unwrap()); + let server = RealmAuthServer::new(channel.transport().peer_addr().unwrap(), db_pool); channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. diff --git a/auth/src/server.rs b/auth/src/server.rs index 794b2ba..51a1667 100644 --- a/auth/src/server.rs +++ b/auth/src/server.rs @@ -1,22 +1,86 @@ use std::net::SocketAddr; + +use sha3::{Digest, Sha3_256}; +use sha3::digest::Update; +use sqlx::{MySql, Pool, Row}; use tarpc::context::Context; -use crate::types::RealmAuth; + +use crate::types::{AuthUser, ErrorCode, RealmAuth}; #[derive(Clone)] pub struct RealmAuthServer { pub socket: SocketAddr, + pub db_pool: Pool, } impl RealmAuthServer { - pub fn new(socket: SocketAddr) -> RealmAuthServer { + pub fn new(socket: SocketAddr, db_pool: Pool) -> RealmAuthServer { RealmAuthServer { socket, + db_pool, } } } impl RealmAuth for RealmAuthServer { - async fn test(self, context: Context, name: String) -> String { - format!("Hello {}", name) + async fn test(self, _: Context, name: String) -> String { + format!("Hello {} auth!", name) + } + + async fn server_token_validation(self, _: Context, server_token: String, username: String, server_id: String, domain: String, tarpc_port: u16) -> bool { + let result = sqlx::query("SELECT tokens FROM user WHERE username = ?").bind(username).fetch_one(&self.db_pool).await; + + match result { + Ok(row) => { + let token_long: &str = row.try_get("tokens").unwrap(); + let tokens = token_long.split(',').collect::>(); + + for token in tokens { + let hash = Sha3_256::new().chain(format!("{}{}{}{}", token, server_id, domain, tarpc_port)).finalize(); + if hex::encode(hash) == server_token { + return true + } + } + + false + }, + Err(_) => false, + } + } + + async fn create_account(self, _: Context, username: String, email: String, avatar: String) -> Result { + todo!() + } + + async fn create_login_flow(self, _: Context, username: String) -> ErrorCode { + todo!() + } + + async fn create_token_from_login(self, _: Context, username: String, login_code: u16) -> Result { + todo!() + } + + async fn change_email_flow(self, _: Context, username: String, token: String) -> ErrorCode { + todo!() + } + + async fn resolve_email_flow(self, _: Context, username: String, token: String, login_code: u16, new_email: String) -> ErrorCode { + todo!() + } + + async fn change_username(self, _: Context, username: String, token: String, new_username: String) -> ErrorCode { + todo!() + } + + async fn change_avatar(self, _: Context, username: String, token: String, avatar: String) -> ErrorCode { + todo!() + } + + async fn get_all_data(self, _: Context, username: String, token: String) -> Result { + todo!() + } + + async fn get_avatar_for_user(self, _: Context, username: String) -> Result { + todo!() } } \ No newline at end of file diff --git a/auth/src/types.rs b/auth/src/types.rs index 943b089..b529379 100644 --- a/auth/src/types.rs +++ b/auth/src/types.rs @@ -3,16 +3,20 @@ use serde::{Deserialize, Serialize}; #[tarpc::service] pub trait RealmAuth { async fn test(name: String) -> String; - async fn server_token_validation(username: String, server_id: String, domain: String, tarpc_port: u16) -> bool; + async fn server_token_validation(server_token: String, username: String, server_id: String, domain: String, tarpc_port: u16) -> bool; async fn create_account(username: String, email: String, avatar: String) -> Result; async fn create_login_flow(username: String) -> ErrorCode; async fn create_token_from_login(username: String, login_code: u16) -> Result; //NOTE: Need to be the user - async fn change_email_flow(token: String) -> ErrorCode; - async fn resolve_email_flow(token: String, login_code: u16, new_email: String) -> ErrorCode; - async fn change_username(token: String, new_username: String) -> ErrorCode; - async fn change_avatar(token: String, avatar: String) -> ErrorCode; + async fn change_email_flow(username: String, token: String) -> ErrorCode; + async fn resolve_email_flow(username: String, token: String, login_code: u16, new_email: String) -> ErrorCode; + async fn change_username(username: String, token: String, new_username: String) -> ErrorCode; + async fn change_avatar(username: String, token: String, avatar: String) -> ErrorCode; + async fn get_all_data(username: String, token: String) -> Result; + + //NOTE: Anyone can call + async fn get_avatar_for_user(username: String) -> Result; //TODO: // Create account // Change email @@ -32,4 +36,20 @@ pub enum ErrorCode { EmailTaken, UsernameTaken, InvalidLoginCode, + InvalidImage, + InvalidUsername, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthUser { + pub id: u32, + pub username: String, + pub email: String, + pub avatar: String, + pub login_code: Option, + pub tokens: Option>, + pub google_oauth: Option, + pub apple_oauth: Option, + pub github_oauth: Option, + pub discord_oauth: Option, } \ No newline at end of file