Most of auth checking done
TODO: reddit voting, get_messages_since, typing indicators, reddit algorithims for threads, thread fetching
This commit is contained in:
@@ -15,5 +15,8 @@ emojis = "0.6.3"
|
||||
chrono = { version = "0.4.38", features = ["serde"] }
|
||||
sqlx = { version = "0.8.0", features = [ "runtime-tokio", "tls-rustls", "sqlite", "chrono" ] }
|
||||
dotenvy = "0.15.7"
|
||||
moka = { version = "0.12.8", features = ["future"] }
|
||||
futures-util = "0.3.30"
|
||||
|
||||
realm_auth = { path = "../auth" }
|
||||
realm_shared = { path = "../shared" }
|
||||
realm_shared = { path = "../shared" }
|
||||
|
||||
@@ -8,74 +8,79 @@ use sqlx::migrate::MigrateDatabase;
|
||||
use sqlx::{migrate, Sqlite, SqlitePool};
|
||||
use sqlx::sqlite::SqlitePoolOptions;
|
||||
use tarpc::{
|
||||
server::{Channel},
|
||||
tokio_serde::formats::Json,
|
||||
server::{Channel},
|
||||
tokio_serde::formats::Json,
|
||||
};
|
||||
use tarpc::server::incoming::Incoming;
|
||||
use tarpc::server::BaseChannel;
|
||||
use tracing::{info, subscriber, warn};
|
||||
use realm_auth::types::RealmAuthClient;
|
||||
use realm_server::server::RealmChatServer;
|
||||
use realm_server::types::RealmChat;
|
||||
|
||||
async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
|
||||
tokio::spawn(fut);
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
dotenv().ok();
|
||||
dotenv().ok();
|
||||
|
||||
let subscriber = tracing_subscriber::fmt()
|
||||
.compact()
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.with_thread_ids(true)
|
||||
.with_target(false)
|
||||
.finish();
|
||||
let subscriber = tracing_subscriber::fmt()
|
||||
.compact()
|
||||
.with_file(true)
|
||||
.with_line_number(true)
|
||||
.with_thread_ids(true)
|
||||
.with_target(false)
|
||||
.finish();
|
||||
|
||||
subscriber::set_global_default(subscriber).unwrap();
|
||||
subscriber::set_global_default(subscriber).unwrap();
|
||||
|
||||
let database_url: &str = &env::var("DATABASE_URL").expect("DATABASE_URL must be set");
|
||||
let database_url: &str = &env::var("DATABASE_URL").expect("DATABASE_URL must be set");
|
||||
|
||||
if !Sqlite::database_exists(database_url).await.unwrap_or(false) {
|
||||
info!("Creating database {}", database_url);
|
||||
match Sqlite::create_database(database_url).await {
|
||||
Ok(_) => info!("Create db success"),
|
||||
Err(error) => panic!("error: {}", error),
|
||||
}
|
||||
} else {
|
||||
warn!("Database already exists");
|
||||
} // TODO: Do in Docker with Sqlx-cli
|
||||
if !Sqlite::database_exists(database_url).await.unwrap_or(false) {
|
||||
info!("Creating database {}", database_url);
|
||||
match Sqlite::create_database(database_url).await {
|
||||
Ok(_) => info!("Create db success"),
|
||||
Err(error) => panic!("error: {}", error),
|
||||
}
|
||||
} else {
|
||||
warn!("Database already exists");
|
||||
} // TODO: Do in Docker with Sqlx-cli
|
||||
|
||||
let db_pool = SqlitePool::connect(database_url).await.unwrap();
|
||||
let db_pool = SqlitePool::connect(database_url).await.unwrap();
|
||||
|
||||
info!("Running migrations...");
|
||||
migrate!().run(&db_pool).await?; // TODO: Do in Docker with Sqlx-cli
|
||||
info!("Migrations complete!");
|
||||
|
||||
let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), env::var("PORT").expect("PORT must be set").parse::<u16>().unwrap());
|
||||
info!("Running migrations...");
|
||||
migrate!().run(&db_pool).await?; // TODO: Do in Docker with Sqlx-cli
|
||||
info!("Migrations complete!");
|
||||
|
||||
// JSON transport is provided by the json_transport tarpc module. It makes it easy
|
||||
// to start up a serde-powered json serialization strategy over TCP.
|
||||
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
|
||||
info!("Listening on port {}", listener.local_addr().port());
|
||||
listener.config_mut().max_frame_length(usize::MAX);
|
||||
listener
|
||||
// Ignore accept errors.
|
||||
.filter_map(|r| future::ready(r.ok()))
|
||||
.map(BaseChannel::with_defaults)
|
||||
// Limit channels to 1 per IP.
|
||||
.max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
|
||||
// serve is generated by the service attribute. It takes as input any type implementing
|
||||
// the generated World trait.
|
||||
.map(|channel| {
|
||||
let server = RealmChatServer::new(env::var("SERVER_ID").expect("SERVER_ID must be set"), channel.transport().peer_addr().unwrap(), db_pool.clone());
|
||||
channel.execute(server.serve()).for_each(spawn)
|
||||
})
|
||||
// Max 10 channels.
|
||||
.buffer_unordered(10)
|
||||
.for_each(|_| async {})
|
||||
.await;
|
||||
let mut auth_transport = tarpc::serde_transport::tcp::connect((IpAddr::V6(Ipv6Addr::LOCALHOST), 5052), Json::default);
|
||||
auth_transport.config_mut().max_frame_length(usize::MAX);
|
||||
let auth_client = RealmAuthClient::new(tarpc::client::Config::default(), auth_transport.await?).spawn();
|
||||
|
||||
Ok(())
|
||||
let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), env::var("PORT").expect("PORT must be set").parse::<u16>().unwrap());
|
||||
|
||||
// JSON transport is provided by the json_transport tarpc module. It makes it easy
|
||||
// to start up a serde-powered json serialization strategy over TCP.
|
||||
let mut listener = tarpc::serde_transport::tcp::listen(&server_addr, Json::default).await?;
|
||||
info!("Listening on port {}", listener.local_addr().port());
|
||||
listener.config_mut().max_frame_length(usize::MAX);
|
||||
listener
|
||||
// Ignore accept errors.
|
||||
.filter_map(|r| future::ready(r.ok()))
|
||||
.map(BaseChannel::with_defaults)
|
||||
// Limit channels to 1 per IP.
|
||||
.max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
|
||||
// serve is generated by the service attribute. It takes as input any type implementing
|
||||
// the generated World trait.
|
||||
.map(|channel| {
|
||||
let server = RealmChatServer::new(env::var("SERVER_ID").expect("SERVER_ID must be set"), channel.transport().peer_addr().unwrap(), db_pool.clone(), auth_client.clone());
|
||||
channel.execute(server.serve()).for_each(spawn)
|
||||
})
|
||||
// Max 10 channels.
|
||||
.buffer_unordered(10)
|
||||
.for_each(|_| async {})
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,10 +1,14 @@
|
||||
use std::env;
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use std::time::Duration;
|
||||
use chrono::{DateTime, Utc};
|
||||
use moka::future::Cache;
|
||||
use sqlx::{FromRow, Pool, query_as, Sqlite};
|
||||
use sqlx::query;
|
||||
use sqlx::sqlite::SqliteQueryResult;
|
||||
use tarpc::context::Context;
|
||||
|
||||
use tracing::error;
|
||||
use realm_auth::types::RealmAuthClient;
|
||||
use realm_shared::types::ErrorCode::*;
|
||||
use realm_shared::types::ErrorCode;
|
||||
|
||||
@@ -13,24 +17,123 @@ use crate::types::{Message, MessageData, RealmChat, Room, User};
|
||||
#[derive(Clone)]
|
||||
pub struct RealmChatServer {
|
||||
pub server_id: String,
|
||||
pub domain: String,
|
||||
pub port: u16,
|
||||
pub socket: SocketAddr,
|
||||
pub db_pool: Pool<Sqlite>,
|
||||
pub typing_users: Vec<(i64, i64)> //NOTE: userid, roomid
|
||||
} //TODO: Cache for auth
|
||||
pub typing_users: Vec<(String, String)>, //NOTE: user.userid, room.roomid
|
||||
pub auth_client: RealmAuthClient,
|
||||
pub cache: Cache<String, String>,
|
||||
}
|
||||
|
||||
impl RealmChatServer {
|
||||
pub fn new(server_id: String, socket: SocketAddr, db_pool: Pool<Sqlite>, auth_client: RealmAuthClient) -> RealmChatServer {
|
||||
RealmChatServer {
|
||||
server_id,
|
||||
port: env::var("PORT").unwrap().parse::<u16>().unwrap(),
|
||||
domain: env::var("DOMAIN").expect("DOMAIN must be set"),
|
||||
socket,
|
||||
db_pool,
|
||||
typing_users: Vec::new(),
|
||||
auth_client,
|
||||
cache: Cache::builder()
|
||||
.max_capacity(10_000)
|
||||
.time_to_idle(Duration::from_secs(5*60))
|
||||
.time_to_live(Duration::from_secs(60*60))
|
||||
.build()
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn is_stoken_valid(&self, userid: &str, stoken: &str) -> bool {
|
||||
match self.cache.get(stoken).await {
|
||||
None => {
|
||||
let result = self.auth_client.server_token_validation(
|
||||
tarpc::context::current(), stoken.to_string(), userid.to_string(), self.server_id.clone(), self.domain.clone(), self.port)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(valid) => {
|
||||
if valid {
|
||||
self.cache.insert(stoken.to_string(), userid.to_string()).await;
|
||||
return true
|
||||
}
|
||||
false
|
||||
}
|
||||
Err(_) => {
|
||||
error!("Error validating server token for user, {}, with stoken {}", userid, stoken);
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(cached_username) => {
|
||||
if cached_username.eq(userid) {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn is_user_admin(&self, stoken: &str) -> bool {
|
||||
if let Some(userid) = self.cache.get(stoken).await {
|
||||
let result = query!("SELECT admin FROM user WHERE userid = ?", userid).fetch_one(&self.db_pool).await;
|
||||
return match result {
|
||||
Ok(record) => {
|
||||
if record.admin {
|
||||
return true
|
||||
}
|
||||
false
|
||||
}
|
||||
Err(_) => false
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl RealmChat for RealmChatServer {
|
||||
async fn test(self, _: Context, name: String) -> String {
|
||||
format!("Hello, {name}!")
|
||||
}
|
||||
|
||||
async fn send_message(self, _: Context, auth_token: String, message: Message) -> Result<Message, ErrorCode> {
|
||||
//TODO: verify authentication somehow for edits and redactions
|
||||
async fn send_message(self, _: Context, stoken: String, message: Message) -> Result<Message, ErrorCode> {
|
||||
if !self.is_stoken_valid(&message.user.userid, &stoken).await {
|
||||
return Err(Unauthorized)
|
||||
}
|
||||
|
||||
match &message.data {
|
||||
MessageData::Edit(e) => {
|
||||
let ref_msg = self.get_message_from_id(tarpc::context::current(), stoken.clone(), e.referencing_id).await?;
|
||||
if !ref_msg.user.userid.eq(&message.user.userid) {
|
||||
return Err(Unauthorized)
|
||||
}
|
||||
}
|
||||
MessageData::Redaction(r)=> {
|
||||
let ref_msg = self.get_message_from_id(tarpc::context::current(), stoken.clone(), r.referencing_id).await?;
|
||||
if !ref_msg.user.userid.eq(&message.user.userid) {
|
||||
return Err(Unauthorized)
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let is_admin = self.is_user_admin(&stoken).await;
|
||||
let admin_only_send = query!("SELECT admin_only_send FROM room WHERE roomid = ?",
|
||||
message.room.roomid).fetch_one(&self.db_pool).await;
|
||||
if let Ok(record) = admin_only_send {
|
||||
if record.admin_only_send && !is_admin {
|
||||
return Err(Unauthorized)
|
||||
}
|
||||
} else {
|
||||
return Err(RoomNotFound)
|
||||
}
|
||||
|
||||
let result = match &message.data {
|
||||
MessageData::Text(text) => {
|
||||
MessageData::Text(text) => {
|
||||
query!("INSERT INTO message (timestamp, user, room, msg_type, msg_text) VALUES (?, ?, ?, 'text', ?)",
|
||||
message.timestamp, message.user.id, message.room.id, text)
|
||||
.execute(&self.db_pool).await
|
||||
.execute(&self.db_pool).await
|
||||
}
|
||||
MessageData::Attachment(attachment) => { todo!() }
|
||||
MessageData::Reply(reply) => {
|
||||
@@ -54,9 +157,9 @@ impl RealmChat for RealmChatServer {
|
||||
.execute(&self.db_pool).await
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
match result {
|
||||
Ok(ids) => {
|
||||
Ok(ids) => {
|
||||
//TODO: Tell everyone
|
||||
|
||||
Ok(message)
|
||||
@@ -65,29 +168,30 @@ impl RealmChat for RealmChatServer {
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_typing(self, _: Context, auth_token: String) -> ErrorCode { //TODO: auth for all of these
|
||||
async fn start_typing(self, _: Context, stoken: String, userid: String, roomid: String) -> ErrorCode { //TODO: auth for all of these
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn stop_typing(self, _: Context, auth_token: String) -> ErrorCode {
|
||||
async fn stop_typing(self, _: Context, stoken: String, userid: String, roomid: String) -> ErrorCode {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn keep_typing(self, _: Context, auth_token: String) -> ErrorCode {
|
||||
async fn keep_typing(self, _: Context, stoken: String, userid: String, roomid: String) -> ErrorCode {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn get_message_from_id(self, _: Context, auth_token: String, id: i64) -> Result<Message, ErrorCode> {
|
||||
//TODO: Auth for admin room
|
||||
async fn get_message_from_id(self, _: Context, stoken: String, id: i64) -> Result<Message, ErrorCode> {
|
||||
let is_admin = self.is_user_admin(&stoken).await;
|
||||
let result = sqlx::query("SELECT message.*,
|
||||
room.id AS 'room_id', room.roomid AS 'room_roomid', room.name AS 'room_name', room.admin_only_send AS 'room_admin_only_send', room.admin_only_view AS 'room_admin_only_view',
|
||||
user.id AS 'user_id', user.userid AS 'user_userid', user.name AS 'user_name', user.online AS 'user_online', user.admin AS 'user_admin'
|
||||
FROM message INNER JOIN room ON message.room = room.id INNER JOIN user ON message.user = user.id WHERE message.id = ?")
|
||||
FROM message INNER JOIN room ON message.room = room.id INNER JOIN user ON message.user = user.id WHERE message.id = ? AND room.admin_only_view = ? OR false")
|
||||
.bind(id)
|
||||
.bind(is_admin)
|
||||
.fetch_one(&self.db_pool).await;
|
||||
|
||||
match result {
|
||||
Ok(row) => {
|
||||
Ok(row) => {
|
||||
Ok(Message::from_row(&row).unwrap())
|
||||
},
|
||||
Err(_) => {
|
||||
@@ -96,36 +200,38 @@ impl RealmChat for RealmChatServer {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_messages_since(self, _: Context, auth_token: String, time: DateTime<Utc>) -> Result<Vec<Message>, ErrorCode> {
|
||||
async fn get_messages_since(self, _: Context, stoken: String, time: DateTime<Utc>) -> Result<Vec<Message>, ErrorCode> {
|
||||
//TODO: Auth for admin rooms
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn get_rooms(self, _: Context, auth_token: String) -> Result<Vec<Room>, ErrorCode> {
|
||||
//TODO: Auth for admin rooms!
|
||||
let result = query_as!(Room, "SELECT * FROM room").fetch_all(&self.db_pool).await;
|
||||
async fn get_rooms(self, _: Context, stoken: String) -> Result<Vec<Room>, ErrorCode> {
|
||||
let is_admin = self.is_user_admin(&stoken).await;
|
||||
let result = query_as!(
|
||||
Room, "SELECT * FROM room WHERE admin_only_view = ? OR false", is_admin).fetch_all(&self.db_pool).await;
|
||||
|
||||
match result {
|
||||
Ok(rooms) => Ok(rooms),
|
||||
Ok(rooms) => Ok(rooms),
|
||||
Err(_) => Err(Error),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_room(self, _: Context, auth_token: String, roomid: String) -> Result<Room, ErrorCode> {
|
||||
//TODO: Auth for admin rooms!
|
||||
let result = query_as!(Room, "SELECT * FROM room WHERE roomid = ?", roomid).fetch_one(&self.db_pool).await;
|
||||
|
||||
async fn get_room(self, _: Context, stoken: String, roomid: String) -> Result<Room, ErrorCode> {
|
||||
let is_admin = self.is_user_admin(&stoken).await;
|
||||
let result = query_as!(
|
||||
Room, "SELECT * FROM room WHERE roomid = ? AND admin_only_view = ? OR false", is_admin, roomid).fetch_one(&self.db_pool).await;
|
||||
|
||||
match result {
|
||||
Ok(room) => { Ok(room) },
|
||||
Ok(room) => { Ok(room) },
|
||||
Err(_) => Err(RoomNotFound),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_user(self, _: Context, userid: String) -> Result<User, ErrorCode> {
|
||||
let result = query_as!(User, "SELECT * FROM user WHERE userid = ?", userid).fetch_one(&self.db_pool).await;
|
||||
|
||||
|
||||
match result {
|
||||
Ok(user) => { Ok(user) },
|
||||
Ok(user) => { Ok(user) },
|
||||
Err(_) => Err(UserNotFound),
|
||||
}
|
||||
}
|
||||
@@ -147,15 +253,4 @@ impl RealmChat for RealmChatServer {
|
||||
Err(_) => Err(Error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RealmChatServer {
|
||||
pub fn new(server_id: String, socket: SocketAddr, db_pool: Pool<Sqlite>) -> RealmChatServer {
|
||||
RealmChatServer {
|
||||
server_id,
|
||||
socket,
|
||||
db_pool,
|
||||
typing_users: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,16 +12,16 @@ pub trait RealmChat {
|
||||
async fn test(name: String) -> String;
|
||||
|
||||
//TODO: Any user authorized as themselves
|
||||
async fn send_message(auth_token: String, message: Message) -> Result<Message, ErrorCode>;
|
||||
async fn start_typing(auth_token: String) -> ErrorCode;
|
||||
async fn stop_typing(auth_token: String) -> ErrorCode;
|
||||
async fn keep_typing(auth_token: String) -> ErrorCode; //NOTE: If a keep alive hasn't been received in 5 seconds, stop typing
|
||||
async fn send_message(stoken: String, message: Message) -> Result<Message, ErrorCode>;
|
||||
async fn start_typing(stoken: String, userid: String, roomid: String) -> ErrorCode;
|
||||
async fn stop_typing(stoken: String, userid: String, roomid: String) -> ErrorCode;
|
||||
async fn keep_typing(stoken: String, userid: String, roomid: String) -> ErrorCode; //NOTE: If a keep alive hasn't been received in 5 seconds, stop typing
|
||||
|
||||
//NOTE: Any user can call, if they are in the server
|
||||
async fn get_message_from_id(auth_token: String, id: i64) -> Result<Message, ErrorCode>;
|
||||
async fn get_messages_since(auth_token: String, time: DateTime<Utc>) -> Result<Vec<Message>, ErrorCode>;
|
||||
async fn get_rooms(auth_token: String) -> Result<Vec<Room>, ErrorCode>;
|
||||
async fn get_room(auth_token: String, roomid: String) -> Result<Room, ErrorCode>;
|
||||
async fn get_message_from_id(stoken: String, id: i64) -> Result<Message, ErrorCode>;
|
||||
async fn get_messages_since(stoken: String, time: DateTime<Utc>) -> Result<Vec<Message>, ErrorCode>;
|
||||
async fn get_rooms(stoken: String) -> Result<Vec<Room>, ErrorCode>;
|
||||
async fn get_room(stoken: String, roomid: String) -> Result<Room, ErrorCode>;
|
||||
async fn get_user(userid: String) -> Result<User, ErrorCode>;
|
||||
async fn get_users() -> Result<Vec<User>, ErrorCode>;
|
||||
async fn get_online_users() -> Result<Vec<User>, ErrorCode>;
|
||||
|
||||
Reference in New Issue
Block a user