From fcc0ffb80268fe5ea52a7b05eed203a7a2127c1a Mon Sep 17 00:00:00 2001 From: Joshua Higgins Date: Fri, 23 Aug 2024 19:19:07 -0400 Subject: [PATCH] Moved out clones --- server/src/server.rs | 149 ++++++++++++++++++++++++++----------------- server/src/types.rs | 6 +- 2 files changed, 92 insertions(+), 63 deletions(-) diff --git a/server/src/server.rs b/server/src/server.rs index efdb4c5..6fe1673 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -49,7 +49,7 @@ impl RealmChatServer { } } - pub async fn is_stoken_valid(&self, userid: &str, stoken: &str) -> bool { + async fn is_stoken_valid(&self, userid: &str, stoken: &str) -> bool { match self.cache.get(stoken).await { None => { let result = self.auth_client.server_token_validation( @@ -74,7 +74,7 @@ impl RealmChatServer { } } - pub async fn is_user_admin(&self, stoken: &str) -> bool { + async fn is_user_admin(&self, stoken: &str) -> bool { if let Some(userid) = self.cache.get(stoken).await { let result = query!("SELECT admin FROM user WHERE userid = ?", userid).fetch_one(&self.db_pool).await; return match result { @@ -89,6 +89,78 @@ impl RealmChatServer { } false } + + async fn inner_get_all_direct_replies(&self, stoken: &str, head: i64) -> Result, ErrorCode> { + let is_admin = self.is_user_admin(stoken).await; + let result = sqlx::query(&format!("{}{}", FETCH_MESSAGE, "AND message.referencing_id = ?")) + .bind(is_admin) + .bind(head) + .fetch_all(&self.db_pool).await; + + match result { + Ok(rows) => Ok(Message::from_rows(rows).unwrap()), + Err(_) => Err(MessageNotFound), + } + } + + async fn inner_get_reply_chain(&self, stoken: &str, head: Message, depth: u8) -> Result { + if depth > 8 { + return Err(DepthTooLarge) + } + + let direct_replies = self.inner_get_all_direct_replies(stoken, head.id).await?; + let replies = if direct_replies.is_empty() || depth == 0 { + None + } else { + let mut chains = Vec::new(); + + for reply in direct_replies { + chains.push(Box::pin(self.inner_get_reply_chain(stoken, reply, depth - 1)).await?); + } + + Some(chains) + }; + + let chain = ReplyChain { + message: head, + replies, + }; + + Ok(chain) + } + + async fn inner_get_room(&self, stoken: &str, roomid: &str) -> Result { + let is_admin = self.is_user_admin(&stoken).await; + let result = query_as!( + Room, "SELECT * FROM room WHERE roomid = ? AND admin_only_view = ? OR false", is_admin, roomid).fetch_one(&self.db_pool).await; + + match result { + Ok(room) => Ok(room), + Err(_) => Err(RoomNotFound), + } + } + + async fn inner_get_user(&self, userid: &str) -> Result { + let result = query_as!(User, "SELECT * FROM user WHERE userid = ?", userid).fetch_one(&self.db_pool).await; + + match result { + Ok(user) => Ok(user), + Err(_) => Err(UserNotFound), + } + } + + async fn inner_get_message(&self, stoken: &str, id: i64) -> Result { + let is_admin = self.is_user_admin(&stoken).await; + let result = sqlx::query(&format!("{}{}", FETCH_MESSAGE, "AND message.id = ?")) + .bind(is_admin) + .bind(id) + .fetch_one(&self.db_pool).await; + + match result { + Ok(row) => Ok(Message::from_row(&row).unwrap()), + Err(_) =>Err(MessageNotFound), + } + } } impl RealmChat for RealmChatServer { @@ -96,24 +168,24 @@ impl RealmChat for RealmChatServer { format!("Hello, {name}!") } - async fn send_message(self, ctx: Context, stoken: String, mut message: Message) -> Result { + async fn send_message(self, _: Context, stoken: String, mut message: Message) -> Result { if !self.is_stoken_valid(&message.user.userid, &stoken).await { // Check sender userid return Err(Unauthorized) } - + // Assert all the data in message is correct - message.user = self.clone().get_user(ctx, message.user.userid).await.unwrap(); - + message.user = self.inner_get_user(&message.user.userid).await.unwrap(); + match &message.data { // Check that the sender is the owner of the referencing msg MessageData::Edit(e) => { - let ref_msg = self.clone().get_message_from_id(ctx, stoken.clone(), e.referencing_id).await?; + let ref_msg = self.inner_get_message(&stoken, e.referencing_id).await?; if !ref_msg.user.userid.eq(&message.user.userid) { return Err(Unauthorized) } } MessageData::Redaction(r)=> { - let ref_msg = self.clone().get_message_from_id(ctx, stoken.clone(), r.referencing_id).await?; - if !ref_msg.user.userid.eq(&message.user.userid) || !self.clone().is_user_admin(&stoken).await { + let ref_msg = self.inner_get_message(&stoken, r.referencing_id).await?; + if !ref_msg.user.userid.eq(&message.user.userid) || !self.is_user_admin(&stoken).await { return Err(Unauthorized) } } @@ -132,7 +204,7 @@ impl RealmChat for RealmChatServer { return Err(RoomNotFound) } - message.room = self.clone().get_room(ctx, stoken.clone(), message.room.roomid).await.unwrap(); + message.room = self.inner_get_room(&stoken, &message.room.roomid).await.unwrap(); let result = match &message.data { MessageData::Text(text) => { @@ -164,7 +236,7 @@ impl RealmChat for RealmChatServer { }; match result { - Ok(ids) => { + Ok(id) => { //TODO: Tell everyone Ok(message) @@ -185,7 +257,7 @@ impl RealmChat for RealmChatServer { todo!() } - async fn get_message_from_id(self, _: Context, stoken: String, id: i64) -> Result { + async fn get_message(self, _: Context, stoken: String, id: i64) -> Result { let is_admin = self.is_user_admin(&stoken).await; let result = sqlx::query(&format!("{}{}", FETCH_MESSAGE, "AND message.id = ?")) .bind(is_admin) @@ -208,7 +280,7 @@ impl RealmChat for RealmChatServer { .bind(is_admin) .bind(time) .fetch_all(&self.db_pool).await; - + match result { Ok(rows) => Ok(Message::from_rows(rows).unwrap()), Err(_) => Err(MalformedDBResponse) @@ -216,42 +288,11 @@ impl RealmChat for RealmChatServer { } async fn get_all_direct_replies(self, _: Context, stoken: String, head: i64) -> Result, ErrorCode> { - let is_admin = self.is_user_admin(&stoken).await; - let result = sqlx::query(&format!("{}{}", FETCH_MESSAGE, "AND message.referencing_id = ?")) - .bind(is_admin) - .bind(head) - .fetch_all(&self.db_pool).await; - - match result { - Ok(rows) => Ok(Message::from_rows(rows).unwrap()), - Err(_) => Err(MessageNotFound), - } + self.inner_get_all_direct_replies(&stoken, head).await } async fn get_reply_chain(self, ctx: Context, stoken: String, head: Message, depth: u8) -> Result { - if depth > 8 { - return Err(DepthTooLarge) - } - - let direct_replies = self.clone().get_all_direct_replies(ctx, stoken.clone(), head.id).await?; - let replies = if direct_replies.is_empty() || depth == 0 { - None - } else { - let mut chains = Vec::new(); - - for reply in direct_replies { - chains.push(Box::pin(self.clone().get_reply_chain(ctx, stoken.clone(), reply, depth - 1)).await?); - } - - Some(chains) - }; - - let chain = ReplyChain { - message: head, - replies, - }; - - Ok(chain) + self.inner_get_reply_chain(&stoken, head, depth).await } async fn get_rooms(self, _: Context, stoken: String) -> Result, ErrorCode> { @@ -266,23 +307,11 @@ impl RealmChat for RealmChatServer { } async fn get_room(self, _: Context, stoken: String, roomid: String) -> Result { - let is_admin = self.is_user_admin(&stoken).await; - let result = query_as!( - Room, "SELECT * FROM room WHERE roomid = ? AND admin_only_view = ? OR false", is_admin, roomid).fetch_one(&self.db_pool).await; - - match result { - Ok(room) => { Ok(room) }, - Err(_) => Err(RoomNotFound), - } + self.inner_get_room(&stoken, &roomid).await } async fn get_user(self, _: Context, userid: String) -> Result { - let result = query_as!(User, "SELECT * FROM user WHERE userid = ?", userid).fetch_one(&self.db_pool).await; - - match result { - Ok(user) => { Ok(user) }, - Err(_) => Err(UserNotFound), - } + self.inner_get_user(&userid).await } async fn get_users(self, _: Context) -> Result, ErrorCode> { diff --git a/server/src/types.rs b/server/src/types.rs index 61e4bb9..ddfce55 100644 --- a/server/src/types.rs +++ b/server/src/types.rs @@ -18,7 +18,7 @@ pub trait RealmChat { async fn keep_typing(stoken: String, userid: String, roomid: String) -> ErrorCode; //NOTE: If a keep alive hasn't been received in 5 seconds, stop typing //NOTE: Any user can call, if they are in the server - async fn get_message_from_id(stoken: String, id: i64) -> Result; + async fn get_message(stoken: String, id: i64) -> Result; async fn get_messages_since(stoken: String, time: DateTime) -> Result, ErrorCode>; async fn get_all_direct_replies(stoken: String, head: i64) -> Result, ErrorCode>; async fn get_reply_chain(stoken: String, head: Message, depth: u8) -> Result; @@ -54,11 +54,11 @@ pub trait FromRows: Sized { impl FromRows for Message { fn from_rows(rows: Vec) -> sqlx::Result> { let mut messages = Vec::new(); - + for row in rows { messages.push(Message::from_row(&row)?); } - + Ok(messages) } }