Moved out clones

This commit is contained in:
2024-08-23 19:19:07 -04:00
Unverified
parent ac489c5592
commit fcc0ffb802
2 changed files with 92 additions and 63 deletions

View File

@@ -49,7 +49,7 @@ impl RealmChatServer {
} }
} }
pub async fn is_stoken_valid(&self, userid: &str, stoken: &str) -> bool { async fn is_stoken_valid(&self, userid: &str, stoken: &str) -> bool {
match self.cache.get(stoken).await { match self.cache.get(stoken).await {
None => { None => {
let result = self.auth_client.server_token_validation( let result = self.auth_client.server_token_validation(
@@ -74,7 +74,7 @@ impl RealmChatServer {
} }
} }
pub async fn is_user_admin(&self, stoken: &str) -> bool { async fn is_user_admin(&self, stoken: &str) -> bool {
if let Some(userid) = self.cache.get(stoken).await { if let Some(userid) = self.cache.get(stoken).await {
let result = query!("SELECT admin FROM user WHERE userid = ?", userid).fetch_one(&self.db_pool).await; let result = query!("SELECT admin FROM user WHERE userid = ?", userid).fetch_one(&self.db_pool).await;
return match result { return match result {
@@ -89,6 +89,78 @@ impl RealmChatServer {
} }
false false
} }
async fn inner_get_all_direct_replies(&self, stoken: &str, head: i64) -> Result<Vec<Message>, ErrorCode> {
let is_admin = self.is_user_admin(stoken).await;
let result = sqlx::query(&format!("{}{}", FETCH_MESSAGE, "AND message.referencing_id = ?"))
.bind(is_admin)
.bind(head)
.fetch_all(&self.db_pool).await;
match result {
Ok(rows) => Ok(Message::from_rows(rows).unwrap()),
Err(_) => Err(MessageNotFound),
}
}
async fn inner_get_reply_chain(&self, stoken: &str, head: Message, depth: u8) -> Result<ReplyChain, ErrorCode> {
if depth > 8 {
return Err(DepthTooLarge)
}
let direct_replies = self.inner_get_all_direct_replies(stoken, head.id).await?;
let replies = if direct_replies.is_empty() || depth == 0 {
None
} else {
let mut chains = Vec::new();
for reply in direct_replies {
chains.push(Box::pin(self.inner_get_reply_chain(stoken, reply, depth - 1)).await?);
}
Some(chains)
};
let chain = ReplyChain {
message: head,
replies,
};
Ok(chain)
}
async fn inner_get_room(&self, stoken: &str, roomid: &str) -> Result<Room, ErrorCode> {
let is_admin = self.is_user_admin(&stoken).await;
let result = query_as!(
Room, "SELECT * FROM room WHERE roomid = ? AND admin_only_view = ? OR false", is_admin, roomid).fetch_one(&self.db_pool).await;
match result {
Ok(room) => Ok(room),
Err(_) => Err(RoomNotFound),
}
}
async fn inner_get_user(&self, userid: &str) -> Result<User, ErrorCode> {
let result = query_as!(User, "SELECT * FROM user WHERE userid = ?", userid).fetch_one(&self.db_pool).await;
match result {
Ok(user) => Ok(user),
Err(_) => Err(UserNotFound),
}
}
async fn inner_get_message(&self, stoken: &str, id: i64) -> Result<Message, ErrorCode> {
let is_admin = self.is_user_admin(&stoken).await;
let result = sqlx::query(&format!("{}{}", FETCH_MESSAGE, "AND message.id = ?"))
.bind(is_admin)
.bind(id)
.fetch_one(&self.db_pool).await;
match result {
Ok(row) => Ok(Message::from_row(&row).unwrap()),
Err(_) =>Err(MessageNotFound),
}
}
} }
impl RealmChat for RealmChatServer { impl RealmChat for RealmChatServer {
@@ -96,24 +168,24 @@ impl RealmChat for RealmChatServer {
format!("Hello, {name}!") format!("Hello, {name}!")
} }
async fn send_message(self, ctx: Context, stoken: String, mut message: Message) -> Result<Message, ErrorCode> { async fn send_message(self, _: Context, stoken: String, mut message: Message) -> Result<Message, ErrorCode> {
if !self.is_stoken_valid(&message.user.userid, &stoken).await { // Check sender userid if !self.is_stoken_valid(&message.user.userid, &stoken).await { // Check sender userid
return Err(Unauthorized) return Err(Unauthorized)
} }
// Assert all the data in message is correct // Assert all the data in message is correct
message.user = self.clone().get_user(ctx, message.user.userid).await.unwrap(); message.user = self.inner_get_user(&message.user.userid).await.unwrap();
match &message.data { // Check that the sender is the owner of the referencing msg match &message.data { // Check that the sender is the owner of the referencing msg
MessageData::Edit(e) => { MessageData::Edit(e) => {
let ref_msg = self.clone().get_message_from_id(ctx, stoken.clone(), e.referencing_id).await?; let ref_msg = self.inner_get_message(&stoken, e.referencing_id).await?;
if !ref_msg.user.userid.eq(&message.user.userid) { if !ref_msg.user.userid.eq(&message.user.userid) {
return Err(Unauthorized) return Err(Unauthorized)
} }
} }
MessageData::Redaction(r)=> { MessageData::Redaction(r)=> {
let ref_msg = self.clone().get_message_from_id(ctx, stoken.clone(), r.referencing_id).await?; let ref_msg = self.inner_get_message(&stoken, r.referencing_id).await?;
if !ref_msg.user.userid.eq(&message.user.userid) || !self.clone().is_user_admin(&stoken).await { if !ref_msg.user.userid.eq(&message.user.userid) || !self.is_user_admin(&stoken).await {
return Err(Unauthorized) return Err(Unauthorized)
} }
} }
@@ -132,7 +204,7 @@ impl RealmChat for RealmChatServer {
return Err(RoomNotFound) return Err(RoomNotFound)
} }
message.room = self.clone().get_room(ctx, stoken.clone(), message.room.roomid).await.unwrap(); message.room = self.inner_get_room(&stoken, &message.room.roomid).await.unwrap();
let result = match &message.data { let result = match &message.data {
MessageData::Text(text) => { MessageData::Text(text) => {
@@ -164,7 +236,7 @@ impl RealmChat for RealmChatServer {
}; };
match result { match result {
Ok(ids) => { Ok(id) => {
//TODO: Tell everyone //TODO: Tell everyone
Ok(message) Ok(message)
@@ -185,7 +257,7 @@ impl RealmChat for RealmChatServer {
todo!() todo!()
} }
async fn get_message_from_id(self, _: Context, stoken: String, id: i64) -> Result<Message, ErrorCode> { async fn get_message(self, _: Context, stoken: String, id: i64) -> Result<Message, ErrorCode> {
let is_admin = self.is_user_admin(&stoken).await; let is_admin = self.is_user_admin(&stoken).await;
let result = sqlx::query(&format!("{}{}", FETCH_MESSAGE, "AND message.id = ?")) let result = sqlx::query(&format!("{}{}", FETCH_MESSAGE, "AND message.id = ?"))
.bind(is_admin) .bind(is_admin)
@@ -208,7 +280,7 @@ impl RealmChat for RealmChatServer {
.bind(is_admin) .bind(is_admin)
.bind(time) .bind(time)
.fetch_all(&self.db_pool).await; .fetch_all(&self.db_pool).await;
match result { match result {
Ok(rows) => Ok(Message::from_rows(rows).unwrap()), Ok(rows) => Ok(Message::from_rows(rows).unwrap()),
Err(_) => Err(MalformedDBResponse) Err(_) => Err(MalformedDBResponse)
@@ -216,42 +288,11 @@ impl RealmChat for RealmChatServer {
} }
async fn get_all_direct_replies(self, _: Context, stoken: String, head: i64) -> Result<Vec<Message>, ErrorCode> { async fn get_all_direct_replies(self, _: Context, stoken: String, head: i64) -> Result<Vec<Message>, ErrorCode> {
let is_admin = self.is_user_admin(&stoken).await; self.inner_get_all_direct_replies(&stoken, head).await
let result = sqlx::query(&format!("{}{}", FETCH_MESSAGE, "AND message.referencing_id = ?"))
.bind(is_admin)
.bind(head)
.fetch_all(&self.db_pool).await;
match result {
Ok(rows) => Ok(Message::from_rows(rows).unwrap()),
Err(_) => Err(MessageNotFound),
}
} }
async fn get_reply_chain(self, ctx: Context, stoken: String, head: Message, depth: u8) -> Result<ReplyChain, ErrorCode> { async fn get_reply_chain(self, ctx: Context, stoken: String, head: Message, depth: u8) -> Result<ReplyChain, ErrorCode> {
if depth > 8 { self.inner_get_reply_chain(&stoken, head, depth).await
return Err(DepthTooLarge)
}
let direct_replies = self.clone().get_all_direct_replies(ctx, stoken.clone(), head.id).await?;
let replies = if direct_replies.is_empty() || depth == 0 {
None
} else {
let mut chains = Vec::new();
for reply in direct_replies {
chains.push(Box::pin(self.clone().get_reply_chain(ctx, stoken.clone(), reply, depth - 1)).await?);
}
Some(chains)
};
let chain = ReplyChain {
message: head,
replies,
};
Ok(chain)
} }
async fn get_rooms(self, _: Context, stoken: String) -> Result<Vec<Room>, ErrorCode> { async fn get_rooms(self, _: Context, stoken: String) -> Result<Vec<Room>, ErrorCode> {
@@ -266,23 +307,11 @@ impl RealmChat for RealmChatServer {
} }
async fn get_room(self, _: Context, stoken: String, roomid: String) -> Result<Room, ErrorCode> { async fn get_room(self, _: Context, stoken: String, roomid: String) -> Result<Room, ErrorCode> {
let is_admin = self.is_user_admin(&stoken).await; self.inner_get_room(&stoken, &roomid).await
let result = query_as!(
Room, "SELECT * FROM room WHERE roomid = ? AND admin_only_view = ? OR false", is_admin, roomid).fetch_one(&self.db_pool).await;
match result {
Ok(room) => { Ok(room) },
Err(_) => Err(RoomNotFound),
}
} }
async fn get_user(self, _: Context, userid: String) -> Result<User, ErrorCode> { async fn get_user(self, _: Context, userid: String) -> Result<User, ErrorCode> {
let result = query_as!(User, "SELECT * FROM user WHERE userid = ?", userid).fetch_one(&self.db_pool).await; self.inner_get_user(&userid).await
match result {
Ok(user) => { Ok(user) },
Err(_) => Err(UserNotFound),
}
} }
async fn get_users(self, _: Context) -> Result<Vec<User>, ErrorCode> { async fn get_users(self, _: Context) -> Result<Vec<User>, ErrorCode> {

View File

@@ -18,7 +18,7 @@ pub trait RealmChat {
async fn keep_typing(stoken: String, userid: String, roomid: String) -> ErrorCode; //NOTE: If a keep alive hasn't been received in 5 seconds, stop typing async fn keep_typing(stoken: String, userid: String, roomid: String) -> ErrorCode; //NOTE: If a keep alive hasn't been received in 5 seconds, stop typing
//NOTE: Any user can call, if they are in the server //NOTE: Any user can call, if they are in the server
async fn get_message_from_id(stoken: String, id: i64) -> Result<Message, ErrorCode>; async fn get_message(stoken: String, id: i64) -> Result<Message, ErrorCode>;
async fn get_messages_since(stoken: String, time: DateTime<Utc>) -> Result<Vec<Message>, ErrorCode>; async fn get_messages_since(stoken: String, time: DateTime<Utc>) -> Result<Vec<Message>, ErrorCode>;
async fn get_all_direct_replies(stoken: String, head: i64) -> Result<Vec<Message>, ErrorCode>; async fn get_all_direct_replies(stoken: String, head: i64) -> Result<Vec<Message>, ErrorCode>;
async fn get_reply_chain(stoken: String, head: Message, depth: u8) -> Result<ReplyChain, ErrorCode>; async fn get_reply_chain(stoken: String, head: Message, depth: u8) -> Result<ReplyChain, ErrorCode>;
@@ -54,11 +54,11 @@ pub trait FromRows<R: Row>: Sized {
impl FromRows<SqliteRow> for Message { impl FromRows<SqliteRow> for Message {
fn from_rows(rows: Vec<SqliteRow>) -> sqlx::Result<Vec<Self>> { fn from_rows(rows: Vec<SqliteRow>) -> sqlx::Result<Vec<Self>> {
let mut messages = Vec::new(); let mut messages = Vec::new();
for row in rows { for row in rows {
messages.push(Message::from_row(&row)?); messages.push(Message::from_row(&row)?);
} }
Ok(messages) Ok(messages)
} }
} }