From 38ca7cb26b587df81ef73c8ebfb31060b5c9e117 Mon Sep 17 00:00:00 2001 From: Joshua Higgins Date: Tue, 30 Jul 2024 23:42:04 -0400 Subject: [PATCH] Most of auth checking done TODO: reddit voting, get_messages_since, typing indicators, reddit algorithims for threads, thread fetching --- server/Cargo.toml | 5 +- server/src/main.rs | 105 +++++++++++++------------- server/src/server.rs | 175 +++++++++++++++++++++++++++++++++---------- server/src/types.rs | 16 ++-- 4 files changed, 202 insertions(+), 99 deletions(-) diff --git a/server/Cargo.toml b/server/Cargo.toml index 3e6b489..fa9a7a1 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -15,5 +15,8 @@ emojis = "0.6.3" chrono = { version = "0.4.38", features = ["serde"] } sqlx = { version = "0.8.0", features = [ "runtime-tokio", "tls-rustls", "sqlite", "chrono" ] } dotenvy = "0.15.7" +moka = { version = "0.12.8", features = ["future"] } +futures-util = "0.3.30" + realm_auth = { path = "../auth" } -realm_shared = { path = "../shared" } \ No newline at end of file +realm_shared = { path = "../shared" } diff --git a/server/src/main.rs b/server/src/main.rs index 5b9d865..9ae1be0 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -8,74 +8,79 @@ use sqlx::migrate::MigrateDatabase; use sqlx::{migrate, Sqlite, SqlitePool}; use sqlx::sqlite::SqlitePoolOptions; use tarpc::{ - server::{Channel}, - tokio_serde::formats::Json, + server::{Channel}, + tokio_serde::formats::Json, }; use tarpc::server::incoming::Incoming; use tarpc::server::BaseChannel; use tracing::{info, subscriber, warn}; +use realm_auth::types::RealmAuthClient; use realm_server::server::RealmChatServer; use realm_server::types::RealmChat; async fn spawn(fut: impl Future + Send + 'static) { - tokio::spawn(fut); + tokio::spawn(fut); } #[tokio::main] async fn main() -> anyhow::Result<()> { - dotenv().ok(); + dotenv().ok(); - let subscriber = tracing_subscriber::fmt() - .compact() - .with_file(true) - .with_line_number(true) - .with_thread_ids(true) - .with_target(false) - .finish(); + let subscriber = tracing_subscriber::fmt() + .compact() + .with_file(true) + .with_line_number(true) + .with_thread_ids(true) + .with_target(false) + .finish(); - subscriber::set_global_default(subscriber).unwrap(); + subscriber::set_global_default(subscriber).unwrap(); - let database_url: &str = &env::var("DATABASE_URL").expect("DATABASE_URL must be set"); + let database_url: &str = &env::var("DATABASE_URL").expect("DATABASE_URL must be set"); - if !Sqlite::database_exists(database_url).await.unwrap_or(false) { - info!("Creating database {}", database_url); - match Sqlite::create_database(database_url).await { - Ok(_) => info!("Create db success"), - Err(error) => panic!("error: {}", error), - } - } else { - warn!("Database already exists"); - } // TODO: Do in Docker with Sqlx-cli + if !Sqlite::database_exists(database_url).await.unwrap_or(false) { + info!("Creating database {}", database_url); + match Sqlite::create_database(database_url).await { + Ok(_) => info!("Create db success"), + Err(error) => panic!("error: {}", error), + } + } else { + warn!("Database already exists"); + } // TODO: Do in Docker with Sqlx-cli - let db_pool = SqlitePool::connect(database_url).await.unwrap(); + let db_pool = SqlitePool::connect(database_url).await.unwrap(); - info!("Running migrations..."); - migrate!().run(&db_pool).await?; // TODO: Do in Docker with Sqlx-cli - info!("Migrations complete!"); - - let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), env::var("PORT").expect("PORT must be set").parse::().unwrap()); + info!("Running migrations..."); + migrate!().run(&db_pool).await?; // TODO: Do in Docker with Sqlx-cli + info!("Migrations complete!"); - // JSON transport is provided by the json_transport tarpc module. It makes it easy - // to start up a serde-powered json serialization strategy over TCP. - let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?; - info!("Listening on port {}", listener.local_addr().port()); - listener.config_mut().max_frame_length(usize::MAX); - listener - // Ignore accept errors. - .filter_map(|r| future::ready(r.ok())) - .map(BaseChannel::with_defaults) - // Limit channels to 1 per IP. - .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip()) - // serve is generated by the service attribute. It takes as input any type implementing - // the generated World trait. - .map(|channel| { - let server = RealmChatServer::new(env::var("SERVER_ID").expect("SERVER_ID must be set"), channel.transport().peer_addr().unwrap(), db_pool.clone()); - channel.execute(server.serve()).for_each(spawn) - }) - // Max 10 channels. - .buffer_unordered(10) - .for_each(|_| async {}) - .await; + let mut auth_transport = tarpc::serde_transport::tcp::connect((IpAddr::V6(Ipv6Addr::LOCALHOST), 5052), Json::default); + auth_transport.config_mut().max_frame_length(usize::MAX); + let auth_client = RealmAuthClient::new(tarpc::client::Config::default(), auth_transport.await?).spawn(); - Ok(()) + let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), env::var("PORT").expect("PORT must be set").parse::().unwrap()); + + // JSON transport is provided by the json_transport tarpc module. It makes it easy + // to start up a serde-powered json serialization strategy over TCP. + let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?; + info!("Listening on port {}", listener.local_addr().port()); + listener.config_mut().max_frame_length(usize::MAX); + listener + // Ignore accept errors. + .filter_map(|r| future::ready(r.ok())) + .map(BaseChannel::with_defaults) + // Limit channels to 1 per IP. + .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip()) + // serve is generated by the service attribute. It takes as input any type implementing + // the generated World trait. + .map(|channel| { + let server = RealmChatServer::new(env::var("SERVER_ID").expect("SERVER_ID must be set"), channel.transport().peer_addr().unwrap(), db_pool.clone(), auth_client.clone()); + channel.execute(server.serve()).for_each(spawn) + }) + // Max 10 channels. + .buffer_unordered(10) + .for_each(|_| async {}) + .await; + + Ok(()) } \ No newline at end of file diff --git a/server/src/server.rs b/server/src/server.rs index 927b111..2abe77e 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -1,10 +1,14 @@ +use std::env; use std::net::SocketAddr; - +use std::time::Duration; use chrono::{DateTime, Utc}; +use moka::future::Cache; use sqlx::{FromRow, Pool, query_as, Sqlite}; use sqlx::query; +use sqlx::sqlite::SqliteQueryResult; use tarpc::context::Context; - +use tracing::error; +use realm_auth::types::RealmAuthClient; use realm_shared::types::ErrorCode::*; use realm_shared::types::ErrorCode; @@ -13,24 +17,123 @@ use crate::types::{Message, MessageData, RealmChat, Room, User}; #[derive(Clone)] pub struct RealmChatServer { pub server_id: String, + pub domain: String, + pub port: u16, pub socket: SocketAddr, pub db_pool: Pool, - pub typing_users: Vec<(i64, i64)> //NOTE: userid, roomid -} //TODO: Cache for auth + pub typing_users: Vec<(String, String)>, //NOTE: user.userid, room.roomid + pub auth_client: RealmAuthClient, + pub cache: Cache, +} + +impl RealmChatServer { + pub fn new(server_id: String, socket: SocketAddr, db_pool: Pool, auth_client: RealmAuthClient) -> RealmChatServer { + RealmChatServer { + server_id, + port: env::var("PORT").unwrap().parse::().unwrap(), + domain: env::var("DOMAIN").expect("DOMAIN must be set"), + socket, + db_pool, + typing_users: Vec::new(), + auth_client, + cache: Cache::builder() + .max_capacity(10_000) + .time_to_idle(Duration::from_secs(5*60)) + .time_to_live(Duration::from_secs(60*60)) + .build() + } + } + + pub async fn is_stoken_valid(&self, userid: &str, stoken: &str) -> bool { + match self.cache.get(stoken).await { + None => { + let result = self.auth_client.server_token_validation( + tarpc::context::current(), stoken.to_string(), userid.to_string(), self.server_id.clone(), self.domain.clone(), self.port) + .await; + + match result { + Ok(valid) => { + if valid { + self.cache.insert(stoken.to_string(), userid.to_string()).await; + return true + } + false + } + Err(_) => { + error!("Error validating server token for user, {}, with stoken {}", userid, stoken); + false + } + } + } + Some(cached_username) => { + if cached_username.eq(userid) { + true + } else { + false + } + }, + } + } + + pub async fn is_user_admin(&self, stoken: &str) -> bool { + if let Some(userid) = self.cache.get(stoken).await { + let result = query!("SELECT admin FROM user WHERE userid = ?", userid).fetch_one(&self.db_pool).await; + return match result { + Ok(record) => { + if record.admin { + return true + } + false + } + Err(_) => false + } + } + false + } +} impl RealmChat for RealmChatServer { async fn test(self, _: Context, name: String) -> String { format!("Hello, {name}!") } - async fn send_message(self, _: Context, auth_token: String, message: Message) -> Result { - //TODO: verify authentication somehow for edits and redactions + async fn send_message(self, _: Context, stoken: String, message: Message) -> Result { + if !self.is_stoken_valid(&message.user.userid, &stoken).await { + return Err(Unauthorized) + } + match &message.data { + MessageData::Edit(e) => { + let ref_msg = self.get_message_from_id(tarpc::context::current(), stoken.clone(), e.referencing_id).await?; + if !ref_msg.user.userid.eq(&message.user.userid) { + return Err(Unauthorized) + } + } + MessageData::Redaction(r)=> { + let ref_msg = self.get_message_from_id(tarpc::context::current(), stoken.clone(), r.referencing_id).await?; + if !ref_msg.user.userid.eq(&message.user.userid) { + return Err(Unauthorized) + } + } + _ => {} + } + + let is_admin = self.is_user_admin(&stoken).await; + let admin_only_send = query!("SELECT admin_only_send FROM room WHERE roomid = ?", + message.room.roomid).fetch_one(&self.db_pool).await; + if let Ok(record) = admin_only_send { + if record.admin_only_send && !is_admin { + return Err(Unauthorized) + } + } else { + return Err(RoomNotFound) + } + let result = match &message.data { - MessageData::Text(text) => { + MessageData::Text(text) => { query!("INSERT INTO message (timestamp, user, room, msg_type, msg_text) VALUES (?, ?, ?, 'text', ?)", message.timestamp, message.user.id, message.room.id, text) - .execute(&self.db_pool).await + .execute(&self.db_pool).await } MessageData::Attachment(attachment) => { todo!() } MessageData::Reply(reply) => { @@ -54,9 +157,9 @@ impl RealmChat for RealmChatServer { .execute(&self.db_pool).await } }; - + match result { - Ok(ids) => { + Ok(ids) => { //TODO: Tell everyone Ok(message) @@ -65,29 +168,30 @@ impl RealmChat for RealmChatServer { } } - async fn start_typing(self, _: Context, auth_token: String) -> ErrorCode { //TODO: auth for all of these + async fn start_typing(self, _: Context, stoken: String, userid: String, roomid: String) -> ErrorCode { //TODO: auth for all of these todo!() } - async fn stop_typing(self, _: Context, auth_token: String) -> ErrorCode { + async fn stop_typing(self, _: Context, stoken: String, userid: String, roomid: String) -> ErrorCode { todo!() } - async fn keep_typing(self, _: Context, auth_token: String) -> ErrorCode { + async fn keep_typing(self, _: Context, stoken: String, userid: String, roomid: String) -> ErrorCode { todo!() } - async fn get_message_from_id(self, _: Context, auth_token: String, id: i64) -> Result { - //TODO: Auth for admin room + async fn get_message_from_id(self, _: Context, stoken: String, id: i64) -> Result { + let is_admin = self.is_user_admin(&stoken).await; let result = sqlx::query("SELECT message.*, room.id AS 'room_id', room.roomid AS 'room_roomid', room.name AS 'room_name', room.admin_only_send AS 'room_admin_only_send', room.admin_only_view AS 'room_admin_only_view', user.id AS 'user_id', user.userid AS 'user_userid', user.name AS 'user_name', user.online AS 'user_online', user.admin AS 'user_admin' - FROM message INNER JOIN room ON message.room = room.id INNER JOIN user ON message.user = user.id WHERE message.id = ?") + FROM message INNER JOIN room ON message.room = room.id INNER JOIN user ON message.user = user.id WHERE message.id = ? AND room.admin_only_view = ? OR false") .bind(id) + .bind(is_admin) .fetch_one(&self.db_pool).await; match result { - Ok(row) => { + Ok(row) => { Ok(Message::from_row(&row).unwrap()) }, Err(_) => { @@ -96,36 +200,38 @@ impl RealmChat for RealmChatServer { } } - async fn get_messages_since(self, _: Context, auth_token: String, time: DateTime) -> Result, ErrorCode> { + async fn get_messages_since(self, _: Context, stoken: String, time: DateTime) -> Result, ErrorCode> { //TODO: Auth for admin rooms todo!() } - async fn get_rooms(self, _: Context, auth_token: String) -> Result, ErrorCode> { - //TODO: Auth for admin rooms! - let result = query_as!(Room, "SELECT * FROM room").fetch_all(&self.db_pool).await; + async fn get_rooms(self, _: Context, stoken: String) -> Result, ErrorCode> { + let is_admin = self.is_user_admin(&stoken).await; + let result = query_as!( + Room, "SELECT * FROM room WHERE admin_only_view = ? OR false", is_admin).fetch_all(&self.db_pool).await; match result { - Ok(rooms) => Ok(rooms), + Ok(rooms) => Ok(rooms), Err(_) => Err(Error), } } - async fn get_room(self, _: Context, auth_token: String, roomid: String) -> Result { - //TODO: Auth for admin rooms! - let result = query_as!(Room, "SELECT * FROM room WHERE roomid = ?", roomid).fetch_one(&self.db_pool).await; - + async fn get_room(self, _: Context, stoken: String, roomid: String) -> Result { + let is_admin = self.is_user_admin(&stoken).await; + let result = query_as!( + Room, "SELECT * FROM room WHERE roomid = ? AND admin_only_view = ? OR false", is_admin, roomid).fetch_one(&self.db_pool).await; + match result { - Ok(room) => { Ok(room) }, + Ok(room) => { Ok(room) }, Err(_) => Err(RoomNotFound), } } async fn get_user(self, _: Context, userid: String) -> Result { let result = query_as!(User, "SELECT * FROM user WHERE userid = ?", userid).fetch_one(&self.db_pool).await; - + match result { - Ok(user) => { Ok(user) }, + Ok(user) => { Ok(user) }, Err(_) => Err(UserNotFound), } } @@ -147,15 +253,4 @@ impl RealmChat for RealmChatServer { Err(_) => Err(Error), } } -} - -impl RealmChatServer { - pub fn new(server_id: String, socket: SocketAddr, db_pool: Pool) -> RealmChatServer { - RealmChatServer { - server_id, - socket, - db_pool, - typing_users: Vec::new(), - } - } } \ No newline at end of file diff --git a/server/src/types.rs b/server/src/types.rs index 55864aa..690e91e 100644 --- a/server/src/types.rs +++ b/server/src/types.rs @@ -12,16 +12,16 @@ pub trait RealmChat { async fn test(name: String) -> String; //TODO: Any user authorized as themselves - async fn send_message(auth_token: String, message: Message) -> Result; - async fn start_typing(auth_token: String) -> ErrorCode; - async fn stop_typing(auth_token: String) -> ErrorCode; - async fn keep_typing(auth_token: String) -> ErrorCode; //NOTE: If a keep alive hasn't been received in 5 seconds, stop typing + async fn send_message(stoken: String, message: Message) -> Result; + async fn start_typing(stoken: String, userid: String, roomid: String) -> ErrorCode; + async fn stop_typing(stoken: String, userid: String, roomid: String) -> ErrorCode; + async fn keep_typing(stoken: String, userid: String, roomid: String) -> ErrorCode; //NOTE: If a keep alive hasn't been received in 5 seconds, stop typing //NOTE: Any user can call, if they are in the server - async fn get_message_from_id(auth_token: String, id: i64) -> Result; - async fn get_messages_since(auth_token: String, time: DateTime) -> Result, ErrorCode>; - async fn get_rooms(auth_token: String) -> Result, ErrorCode>; - async fn get_room(auth_token: String, roomid: String) -> Result; + async fn get_message_from_id(stoken: String, id: i64) -> Result; + async fn get_messages_since(stoken: String, time: DateTime) -> Result, ErrorCode>; + async fn get_rooms(stoken: String) -> Result, ErrorCode>; + async fn get_room(stoken: String, roomid: String) -> Result; async fn get_user(userid: String) -> Result; async fn get_users() -> Result, ErrorCode>; async fn get_online_users() -> Result, ErrorCode>;