diff --git a/server/src/main.rs b/server/src/main.rs index 9ae1be0..4f73437 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -6,7 +6,6 @@ use futures::future::{self}; use futures::StreamExt; use sqlx::migrate::MigrateDatabase; use sqlx::{migrate, Sqlite, SqlitePool}; -use sqlx::sqlite::SqlitePoolOptions; use tarpc::{ server::{Channel}, tokio_serde::formats::Json, @@ -14,7 +13,6 @@ use tarpc::{ use tarpc::server::incoming::Incoming; use tarpc::server::BaseChannel; use tracing::{info, subscriber, warn}; -use realm_auth::types::RealmAuthClient; use realm_server::server::RealmChatServer; use realm_server::types::RealmChat; @@ -54,10 +52,6 @@ async fn main() -> anyhow::Result<()> { migrate!().run(&db_pool).await?; // TODO: Do in Docker with Sqlx-cli info!("Migrations complete!"); - let mut auth_transport = tarpc::serde_transport::tcp::connect((IpAddr::V6(Ipv6Addr::LOCALHOST), 5052), Json::default); - auth_transport.config_mut().max_frame_length(usize::MAX); - let auth_client = RealmAuthClient::new(tarpc::client::Config::default(), auth_transport.await?).spawn(); - let server_addr = (IpAddr::V6(Ipv6Addr::LOCALHOST), env::var("PORT").expect("PORT must be set").parse::().unwrap()); // JSON transport is provided by the json_transport tarpc module. It makes it easy @@ -74,7 +68,7 @@ async fn main() -> anyhow::Result<()> { // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { - let server = RealmChatServer::new(env::var("SERVER_ID").expect("SERVER_ID must be set"), channel.transport().peer_addr().unwrap(), db_pool.clone(), auth_client.clone()); + let server = RealmChatServer::new(env::var("SERVER_ID").expect("SERVER_ID must be set"), channel.transport().peer_addr().unwrap(), db_pool.clone()); channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. diff --git a/server/src/server.rs b/server/src/server.rs index f28e463..c4364db 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -1,11 +1,12 @@ use std::env; -use std::net::SocketAddr; +use std::net::{SocketAddr}; use std::time::Duration; use chrono::{DateTime, Utc}; use moka::future::Cache; use sqlx::{FromRow, Pool, query_as, Sqlite}; use sqlx::query; use tarpc::context::Context; +use tarpc::tokio_serde::formats::Json; use tracing::error; use realm_auth::types::RealmAuthClient; use realm_shared::types::ErrorCode::*; @@ -21,7 +22,6 @@ pub struct RealmChatServer { pub socket: SocketAddr, pub db_pool: Pool, pub typing_users: Vec<(String, String)>, //NOTE: user.userid, room.roomid - pub auth_client: RealmAuthClient, pub cache: Cache, } @@ -31,7 +31,7 @@ const FETCH_MESSAGE: &str = "SELECT message.*, FROM message INNER JOIN room ON message.room = room.id INNER JOIN user ON message.user = user.id WHERE room.admin_only_view = ? OR false"; impl RealmChatServer { - pub fn new(server_id: String, socket: SocketAddr, db_pool: Pool, auth_client: RealmAuthClient) -> RealmChatServer { + pub fn new(server_id: String, socket: SocketAddr, db_pool: Pool) -> RealmChatServer { RealmChatServer { server_id, port: env::var("PORT").unwrap().parse::().unwrap(), @@ -39,7 +39,6 @@ impl RealmChatServer { socket, db_pool, typing_users: Vec::new(), - auth_client, cache: Cache::builder() .max_capacity(10_000) .time_to_idle(Duration::from_secs(5*60)) @@ -55,7 +54,20 @@ impl RealmChatServer { return false; } - let result = self.auth_client.server_token_validation( + let user_domain = &userid[userid.find(':').unwrap()+1..]; + + let mut auth_transport = tarpc::serde_transport::tcp::connect((user_domain, 5052), Json::default); + auth_transport.config_mut().max_frame_length(usize::MAX); + let connected = match auth_transport.await { + Ok(out) => Some(out), + Err(_) => None + }; + if connected.is_none() { + return false; + } + let auth_client = RealmAuthClient::new(tarpc::client::Config::default(), connected.unwrap()).spawn(); + + let result = auth_client.server_token_validation( tarpc::context::current(), stoken.to_string(), userid.to_string(), self.server_id.clone(), self.domain.clone(), self.port) .await; @@ -303,7 +315,7 @@ impl RealmChat for RealmChatServer { self.inner_get_all_direct_replies(&stoken, head).await } - async fn get_reply_chain(self, ctx: Context, stoken: String, head: Message, depth: u8) -> Result { + async fn get_reply_chain(self, _: Context, stoken: String, head: Message, depth: u8) -> Result { self.inner_get_reply_chain(&stoken, head, depth).await }