chore: Chat history (#546)

* chore: imple sql curd

* chore: update chat

* chore: select messages

* chore: update test

* chore: update schema

* chore: update
This commit is contained in:
Nathan.fooo 2024-05-11 20:41:21 +08:00 committed by GitHub
parent 348217a117
commit 4c00ddd593
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 1013 additions and 24 deletions

View file

@ -0,0 +1,22 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT EXISTS(\n SELECT 1\n FROM af_workspace\n WHERE workspace_id = $1\n ) AS user_exists;\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "user_exists",
"type_info": "Bool"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
null
]
},
"hash": "291f0916b7868f3598b50f659689b9c77d34112c2a2fff9fc04775da9f97e46d"
}

View file

@ -0,0 +1,14 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE af_chat\n SET deleted_at = now()\n WHERE chat_id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": []
},
"hash": "2c496e29533dd27117fbb688ba2324f04d7cc306181fcf3f82079d5639f632c4"
}

View file

@ -0,0 +1,17 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO af_chat (chat_id, name, workspace_id, rag_ids)\n VALUES ($1, $2, $3, $4)\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Uuid",
"Jsonb"
]
},
"nullable": []
},
"hash": "3bb5b82d46c55bbfd51319310a3cd065c4b796462a1ddf3c17617ee65ce9961a"
}

View file

@ -0,0 +1,40 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT message_id, content, created_at, author\n FROM af_chat_messages\n WHERE chat_id = $1\n ORDER BY created_at ASC\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "message_id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "content",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 3,
"name": "author",
"type_info": "Jsonb"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false,
false,
false,
false
]
},
"hash": "533ef0f5237ca12ce0c6ca1dc938cc8dd34603b256f36dc683013264018332fc"
}

View file

@ -0,0 +1,22 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT COUNT(*)\n FROM public.af_chat_messages\n WHERE chat_id = $1\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "count",
"type_info": "Int8"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
null
]
},
"hash": "5e0d58f612425e1cf36dfc7f56691cfb8f6def1a3d29645922cb437d11ce62ef"
}

View file

@ -0,0 +1,23 @@
{
"db_name": "PostgreSQL",
"query": "SELECT EXISTS(SELECT 1 FROM af_chat_messages WHERE chat_id = $1 AND message_id > $2)",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "exists",
"type_info": "Bool"
}
],
"parameters": {
"Left": [
"Uuid",
"Int8"
]
},
"nullable": [
null
]
},
"hash": "a3ab30d48e4a10aff1fbfa9dbc5d275a06598610bc471893c8c0febfc36c4737"
}

View file

@ -0,0 +1,23 @@
{
"db_name": "PostgreSQL",
"query": "SELECT EXISTS(SELECT 1 FROM af_chat_messages WHERE chat_id = $1 AND message_id < $2)",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "exists",
"type_info": "Bool"
}
],
"parameters": {
"Left": [
"Uuid",
"Int8"
]
},
"nullable": [
null
]
},
"hash": "d2e87c077e5702cd57a88e23e1eabe4b0badd98ef99da1b185bffa8d5c9ed298"
}

View file

@ -0,0 +1,16 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO af_chat_messages (chat_id, author, content)\n VALUES ($1, $2, $3)\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Jsonb",
"Text"
]
},
"nullable": []
},
"hash": "e2c448a5fad523713d2f755935f8cc1e9b66c4b9c40e9e688bc34f5de127f33a"
}

View file

@ -0,0 +1,52 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT *\n FROM af_chat\n WHERE chat_id = $1 AND deleted_at IS NULL\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "chat_id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 2,
"name": "deleted_at",
"type_info": "Timestamptz"
},
{
"ordinal": 3,
"name": "name",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "rag_ids",
"type_info": "Jsonb"
},
{
"ordinal": 5,
"name": "workspace_id",
"type_info": "Uuid"
}
],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": [
false,
false,
true,
false,
false,
false
]
},
"hash": "fb21df2827de97055cdc1c493b079b29667f75b18169c909c4c8341697fd0105"
}

7
Cargo.lock generated
View file

@ -522,12 +522,13 @@ dependencies = [
[[package]]
name = "appflowy-ai-client"
version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-AI?tag=0.0.7#115f185e15f895abf750415584248689b9ff8ea2"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-AI-Client?tag=0.0.1#ea7caaad27bb9773bcca1d67f7785fcd7d4fbc46"
dependencies = [
"anyhow",
"reqwest 0.12.4",
"serde",
"serde_json",
"serde_repr",
"thiserror",
"tracing",
]
@ -6136,9 +6137,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.36.0"
version = "1.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931"
checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787"
dependencies = [
"backtrace",
"bytes",

View file

@ -161,6 +161,7 @@ collab-rt-protocol = { path = "libs/collab-rt-protocol" }
database = { path = "libs/database" }
database-entity = { path = "libs/database-entity" }
shared-entity = { path = "libs/shared-entity" }
gotrue-entity = { path = "libs/gotrue-entity" }
access-control = { path = "libs/access-control" }
app-error = { path = "libs/app-error" }
async-trait = "0.1.77"
@ -209,7 +210,7 @@ inherits = "release"
debug = true
[patch.crates-io]
appflowy-ai-client = { git = "https://github.com/AppFlowy-IO/AppFlowy-AI", tag = "0.0.7" }
appflowy-ai-client = { git = "https://github.com/AppFlowy-IO/AppFlowy-AI-Client", tag = "0.0.1" }
# It's diffcult to resovle different version with the same crate used in AppFlowy Frontend and the Client-API crate.
# So using patch to workaround this issue.

View file

@ -12,8 +12,6 @@ reqwest = { version = "0.11.27", features = ["stream", "json"] }
anyhow = "1.0.79"
serde_repr = "0.1.18"
gotrue = { path = "../gotrue" }
gotrue-entity = { path = "../gotrue-entity" }
shared-entity = { path = "../shared-entity" }
tracing = { version = "0.1" }
thiserror = "1.0.56"
bytes = "1.5"
@ -28,22 +26,25 @@ bincode = "1.3.3"
url = "2.5.0"
mime = "0.3.17"
tokio-stream = { version = "0.1.14" }
collab-rt-entity = { workspace = true }
chrono = "0.4"
client-websocket = { workspace = true, features = ["native-tls"] }
semver = "1.0.22"
collab = { workspace = true, optional = true }
collab-entity = { workspace = true }
yrs = { workspace = true, optional = true }
collab-rt-protocol = { workspace = true }
workspace-template = { workspace = true, optional = true }
serde_json.workspace = true
serde.workspace = true
database-entity.workspace = true
app-error = { workspace = true, features = ["tokio_error", "bincode_error"] }
scraper = { version = "0.17.1", optional = true }
collab-entity = { workspace = true }
gotrue-entity = { workspace = true }
shared-entity = { workspace = true }
collab-rt-entity = { workspace = true }
database-entity.workspace = true
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
tokio-retry = "0.3"
tokio-util = "0.7"

View file

@ -0,0 +1,104 @@
use crate::util::validate_not_empty_str;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use validator::Validate;
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
pub struct CreateChatParams {
#[validate(custom = "validate_not_empty_str")]
pub chat_id: String,
#[validate(custom = "validate_not_empty_str")]
pub name: String,
pub rag_ids: Vec<String>,
}
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
pub struct UpdateChatParams {
#[validate(custom = "validate_not_empty_str")]
pub chat_id: String,
#[validate(custom = "validate_not_empty_str")]
pub name: Option<String>,
pub rag_ids: Option<Vec<String>>,
}
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
pub struct CreateChatMessageParams {
#[validate(custom = "validate_not_empty_str")]
pub content: String,
}
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
pub struct GetChatMessageParams {
#[validate(custom = "validate_not_empty_str")]
pub chat_id: String,
pub offset: MessageOffset,
pub limit: u64,
}
impl GetChatMessageParams {
pub fn offset(chat_id: String, offset: u64, limit: u64) -> Self {
Self {
chat_id,
offset: MessageOffset::Offset(offset),
limit,
}
}
pub fn after_message_id(chat_id: String, after_message_id: i64, limit: u64) -> Self {
Self {
chat_id,
offset: MessageOffset::AfterMessageId(after_message_id),
limit,
}
}
pub fn before_message_id(chat_id: String, before_message_id: i64, limit: u64) -> Self {
Self {
chat_id,
offset: MessageOffset::BeforeMessageId(before_message_id),
limit,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageOffset {
Offset(u64),
AfterMessageId(i64),
BeforeMessageId(i64),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub author: ChatAuthor,
pub message_id: i64,
pub content: String,
pub created_at: DateTime<Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RepeatedChatMessage {
pub messages: Vec<ChatMessage>,
pub has_more: bool,
pub total: i64,
}
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub enum ChatAuthor {
#[default]
Unknown,
Human {
uid: i64,
},
System,
AI,
}
impl From<serde_json::Value> for ChatAuthor {
fn from(value: serde_json::Value) -> Self {
serde_json::from_value::<ChatAuthor>(value).unwrap_or_default()
}
}

View file

@ -1,3 +1,4 @@
use crate::util::{validate_not_empty_payload, validate_not_empty_str};
use chrono::{DateTime, Utc};
use collab_entity::CollabType;
use serde::{Deserialize, Serialize};
@ -9,7 +10,7 @@ use std::ops::{Deref, DerefMut};
use std::str::FromStr;
use tracing::error;
use uuid::Uuid;
use validator::{Validate, ValidationError};
use validator::Validate;
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
pub struct CreateCollabParams {
@ -113,20 +114,6 @@ pub struct DeleteCollabParams {
pub workspace_id: String,
}
fn validate_not_empty_str(s: &str) -> Result<(), ValidationError> {
if s.is_empty() {
return Err(ValidationError::new("should not be empty string"));
}
Ok(())
}
fn validate_not_empty_payload(payload: &[u8]) -> Result<(), ValidationError> {
if payload.is_empty() {
return Err(ValidationError::new("should not be empty payload"));
}
Ok(())
}
#[derive(Debug, Clone, Validate, Serialize, Deserialize)]
pub struct InsertSnapshotParams {
#[validate(custom = "validate_not_empty_str")]

View file

@ -1 +1,3 @@
pub mod chat;
pub mod dto;
mod util;

View file

@ -0,0 +1,15 @@
use validator::ValidationError;
pub(crate) fn validate_not_empty_str(s: &str) -> Result<(), ValidationError> {
if s.is_empty() {
return Err(ValidationError::new("should not be empty string"));
}
Ok(())
}
pub(crate) fn validate_not_empty_payload(payload: &[u8]) -> Result<(), ValidationError> {
if payload.is_empty() {
return Err(ValidationError::new("should not be empty payload"));
}
Ok(())
}

View file

@ -0,0 +1,280 @@
use crate::pg_row::AFChatRow;
use crate::workspace::is_workspace_exist;
use anyhow::anyhow;
use app_error::AppError;
use chrono::{DateTime, Utc};
use database_entity::chat::{
ChatAuthor, ChatMessage, CreateChatMessageParams, CreateChatParams, GetChatMessageParams,
MessageOffset, RepeatedChatMessage, UpdateChatParams,
};
use serde_json::json;
use sqlx::postgres::PgArguments;
use sqlx::{Arguments, Executor, Postgres, Transaction};
use std::ops::DerefMut;
use std::str::FromStr;
use uuid::Uuid;
pub async fn insert_chat(
txn: &mut Transaction<'_, Postgres>,
workspace_id: &str,
params: CreateChatParams,
) -> Result<(), AppError> {
let chat_id = Uuid::from_str(&params.chat_id)?;
let workspace_id = Uuid::from_str(workspace_id)?;
if !is_workspace_exist(txn.deref_mut(), &workspace_id).await? {
return Err(AppError::RecordNotFound(format!(
"workspace with given id:{} is not found",
workspace_id
)));
}
let rag_ids = json!(params.rag_ids);
sqlx::query!(
r#"
INSERT INTO af_chat (chat_id, name, workspace_id, rag_ids)
VALUES ($1, $2, $3, $4)
"#,
chat_id,
params.name,
workspace_id,
rag_ids,
)
.execute(txn.deref_mut())
.await
.map_err(|err| AppError::Internal(anyhow!("Failed to insert chat: {}", err)))?;
Ok(())
}
/// Updates specific fields of a chat record in the database using transactional queries.
///
/// This function dynamically builds an SQL `UPDATE` query based on the provided parameters to
/// update fields of a specific chat record identified by `chat_id`. It uses a transaction to ensure
/// that the update operation is atomic.
///
pub async fn update_chat(
txn: &mut Transaction<'_, Postgres>,
chat_id: &Uuid,
params: UpdateChatParams,
) -> Result<(), AppError> {
let mut query_parts = vec!["UPDATE af_chat SET".to_string()];
let mut args = PgArguments::default();
let mut current_param_pos = 1; // Start counting SQL parameters from 1
if let Some(ref name) = params.name {
query_parts.push(format!("name = ${}", current_param_pos));
args.add(name);
current_param_pos += 1;
}
if let Some(ref rag_ids) = params.rag_ids {
query_parts.push(format!("rag_ids = ${}", current_param_pos));
let rag_ids_json = json!(rag_ids);
args.add(rag_ids_json);
current_param_pos += 1;
}
query_parts.push(format!("WHERE chat_id = ${}", current_param_pos));
args.add(chat_id);
let query = query_parts.join(", ") + ";";
let query = sqlx::query_with(&query, args);
query.execute(txn.deref_mut()).await?;
Ok(())
}
pub async fn delete_chat(
txn: &mut Transaction<'_, Postgres>,
chat_id: &str,
) -> Result<(), AppError> {
let chat_id = Uuid::from_str(chat_id)?;
sqlx::query!(
r#"
UPDATE af_chat
SET deleted_at = now()
WHERE chat_id = $1
"#,
chat_id,
)
.execute(txn.deref_mut())
.await?;
Ok(())
}
pub async fn select_chat<'a, E: Executor<'a, Database = Postgres>>(
executor: E,
chat_id: &str,
) -> Result<AFChatRow, AppError> {
let chat_id = Uuid::from_str(chat_id)?;
let row = sqlx::query_as!(
AFChatRow,
r#"
SELECT *
FROM af_chat
WHERE chat_id = $1 AND deleted_at IS NULL
"#,
&chat_id,
)
.fetch_optional(executor)
.await?;
match row {
Some(row) => Ok(row),
None => Err(AppError::RecordNotFound(format!(
"chat with given id:{} is not found",
chat_id
))),
}
}
pub async fn insert_chat_message<'a, E: Executor<'a, Database = Postgres>>(
executor: E,
author: ChatAuthor,
chat_id: &str,
params: CreateChatMessageParams,
) -> Result<(), AppError> {
let chat_id = Uuid::from_str(chat_id)?;
let author = json!(author);
sqlx::query!(
r#"
INSERT INTO af_chat_messages (chat_id, author, content)
VALUES ($1, $2, $3)
"#,
chat_id,
author,
params.content,
)
.execute(executor)
.await
.map_err(|err| AppError::Internal(anyhow!("Failed to insert chat message: {}", err)))?;
Ok(())
}
pub async fn select_chat_messages(
txn: &mut Transaction<'_, Postgres>,
chat_id: &str,
params: GetChatMessageParams,
) -> Result<RepeatedChatMessage, AppError> {
let chat_id = Uuid::from_str(chat_id)?;
let mut query = r#"
SELECT message_id, content, created_at, author
FROM af_chat_messages
WHERE chat_id = $1
"#
.to_string();
let mut args = PgArguments::default();
args.add(&chat_id);
// Message IDs: 1 2 3 4 5
// AfterMessageId(3, 5): [4] [5] has_more = false
// BeforeMessageId(3, 5): [1] [2] has_more = false
// Offset(3, 5): [4] [5] has_more = true
match params.offset {
MessageOffset::AfterMessageId(after_message_id) => {
query += " AND message_id > $2";
args.add(after_message_id);
query += " ORDER BY message_id ASC LIMIT $3";
args.add(params.limit as i64);
},
MessageOffset::Offset(offset) => {
query += " ORDER BY message_id ASC LIMIT $2 OFFSET $3";
args.add(params.limit as i64);
args.add(offset as i64);
},
MessageOffset::BeforeMessageId(before_message_id) => {
query += " AND message_id < $2";
args.add(before_message_id);
query += " ORDER BY message_id ASC LIMIT $3";
args.add(params.limit as i64);
},
}
let rows: Vec<(i64, String, DateTime<Utc>, serde_json::Value)> =
sqlx::query_as_with(&query, args)
.fetch_all(txn.deref_mut())
.await?;
let messages = rows
.into_iter()
.map(|(message_id, content, created_at, author)| ChatMessage {
author: serde_json::from_value::<ChatAuthor>(author).unwrap_or_default(),
message_id,
content,
created_at,
})
.collect::<Vec<ChatMessage>>();
let total = sqlx::query_scalar!(
r#"
SELECT COUNT(*)
FROM public.af_chat_messages
WHERE chat_id = $1
"#,
&chat_id
)
.fetch_one(txn.deref_mut())
.await?
.unwrap_or(0);
let has_more = match params.offset {
MessageOffset::AfterMessageId(_) => {
if messages.is_empty() {
false
} else {
sqlx::query!(
"SELECT EXISTS(SELECT 1 FROM af_chat_messages WHERE chat_id = $1 AND message_id > $2)",
&chat_id,
messages.last().as_ref().unwrap().message_id
)
.fetch_one(txn.deref_mut())
.await?
.exists
.unwrap_or(false)
}
},
MessageOffset::Offset(offset) => (offset + params.limit) < total as u64,
MessageOffset::BeforeMessageId(_) => {
if messages.is_empty() {
false
} else {
sqlx::query!(
"SELECT EXISTS(SELECT 1 FROM af_chat_messages WHERE chat_id = $1 AND message_id < $2)",
&chat_id,
messages[0].message_id
)
.fetch_one(txn.deref_mut())
.await?
.exists
.unwrap_or(false)
}
},
};
Ok(RepeatedChatMessage {
messages,
total,
has_more,
})
}
pub async fn get_all_chat_messages<'a, E: Executor<'a, Database = Postgres>>(
executor: E,
chat_id: &str,
) -> Result<Vec<ChatMessage>, AppError> {
let chat_id = Uuid::from_str(chat_id)?;
let messages: Vec<ChatMessage> = sqlx::query_as!(
ChatMessage,
r#"
SELECT message_id, content, created_at, author
FROM af_chat_messages
WHERE chat_id = $1
ORDER BY created_at ASC
"#,
chat_id,
)
.fetch_all(executor)
.await?;
Ok(messages)
}

View file

@ -0,0 +1 @@
pub mod chat_ops;

View file

@ -1,3 +1,4 @@
pub mod chat;
pub mod collab;
pub mod file;
pub mod history;

View file

@ -195,3 +195,20 @@ pub struct AFCollabRowMeta {
pub deleted_at: Option<DateTime<Utc>>,
pub created_at: Option<DateTime<Utc>>,
}
#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
pub struct AFChatRow {
pub chat_id: Uuid,
pub name: String,
pub created_at: DateTime<Utc>,
pub deleted_at: Option<DateTime<Utc>>,
pub rag_ids: serde_json::Value,
pub workspace_id: Uuid,
}
#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
pub struct AFChatMessageRow {
pub message_id: i64,
pub chat_id: Uuid,
pub content: String,
pub created_at: DateTime<Utc>,
}

View file

@ -725,3 +725,24 @@ pub async fn select_workspace_pending_invitations(
.await?;
Ok(invitee_emails.into_iter().collect())
}
#[inline]
pub async fn is_workspace_exist<'a, E: Executor<'a, Database = Postgres>>(
executor: E,
workspace_id: &Uuid,
) -> Result<bool, AppError> {
let exists = sqlx::query_scalar!(
r#"
SELECT EXISTS(
SELECT 1
FROM af_workspace
WHERE workspace_id = $1
) AS user_exists;
"#,
workspace_id
)
.fetch_one(executor)
.await?;
Ok(exists.unwrap_or(false))
}

View file

@ -0,0 +1,27 @@
-- Add migration script here
-- Create table for chat documents
CREATE TABLE af_chat
(
chat_id UUID PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
deleted_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
name TEXT NOT NULL DEFAULT '',
rag_ids JSONB NOT NULL DEFAULT '[]',
workspace_id UUID NOT NULL,
FOREIGN KEY (workspace_id) REFERENCES af_workspace (workspace_id) ON DELETE CASCADE
);
-- Create table for chat messages
CREATE TABLE af_chat_messages
(
message_id BIGSERIAL PRIMARY KEY,
author JSONB NOT NULL,
chat_id UUID NOT NULL,
content TEXT NOT NULL,
deleted_at TIMESTAMP WITH TIME ZONE DEFAULT NULL,
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
edited_at TIMESTAMP DEFAULT NULL,
FOREIGN KEY (chat_id) REFERENCES af_chat (chat_id) ON DELETE CASCADE
);
CREATE INDEX idx_chat_messages_chat_id_created_at ON af_chat_messages (message_id ASC, created_at ASC);

74
src/api/chat.rs Normal file
View file

@ -0,0 +1,74 @@
use crate::biz::chat::ops::{create_chat, create_chat_message, delete_chat, get_chat_messages};
use crate::biz::user::auth::jwt::UserUuid;
use crate::state::AppState;
use actix_web::web::{Data, Json};
use actix_web::{web, Scope};
use database_entity::chat::{
CreateChatMessageParams, CreateChatParams, GetChatMessageParams, RepeatedChatMessage,
};
use shared_entity::response::{AppResponse, JsonAppResponse};
pub fn chat_scope() -> Scope {
web::scope("/api/chat/{workspace_id}")
.service(web::resource("/").route(web::post().to(create_chat_handler)))
.service(
web::resource("/{chat_id}")
.route(web::delete().to(delete_chat_handler))
.route(web::post().to(update_chat_handler)),
)
.service(
web::resource("/{chat_id}/messages")
.route(web::get().to(get_chat_message_handler))
.route(web::post().to(post_chat_message_handler)),
)
}
async fn create_chat_handler(
path: web::Path<String>,
state: Data<AppState>,
payload: Json<CreateChatParams>,
) -> actix_web::Result<JsonAppResponse<()>> {
let workspace_id = path.into_inner();
create_chat(&state.pg_pool, payload.into_inner(), &workspace_id).await?;
Ok(AppResponse::Ok().into())
}
async fn delete_chat_handler(
path: web::Path<(String, String)>,
state: Data<AppState>,
) -> actix_web::Result<JsonAppResponse<()>> {
let (_, chat_id) = path.into_inner();
delete_chat(&state.pg_pool, &chat_id).await?;
Ok(AppResponse::Ok().into())
}
async fn update_chat_handler(
path: web::Path<(String, String)>,
state: Data<AppState>,
) -> actix_web::Result<JsonAppResponse<()>> {
let (_, chat_id) = path.into_inner();
delete_chat(&state.pg_pool, &chat_id).await?;
Ok(AppResponse::Ok().into())
}
async fn post_chat_message_handler(
state: Data<AppState>,
chat_id: web::Path<String>,
payload: Json<CreateChatMessageParams>,
uuid: UserUuid,
) -> actix_web::Result<JsonAppResponse<()>> {
let chat_id = chat_id.into_inner();
let uid = state.user_cache.get_user_uid(&uuid).await?;
create_chat_message(&state.pg_pool, uid, payload.into_inner(), &chat_id).await?;
Ok(AppResponse::Ok().into())
}
async fn get_chat_message_handler(
chat_id: web::Path<String>,
state: Data<AppState>,
payload: Json<GetChatMessageParams>,
) -> actix_web::Result<JsonAppResponse<RepeatedChatMessage>> {
let chat_id = chat_id.into_inner();
let messages = get_chat_messages(&state.pg_pool, payload.into_inner(), &chat_id).await?;
Ok(AppResponse::Ok().with_data(messages).into())
}

View file

@ -1,3 +1,4 @@
pub mod chat;
pub mod file_storage;
pub mod metrics;
pub mod user;

View file

@ -7,6 +7,7 @@ use crate::api::ws::ws_scope;
use crate::mailer::Mailer;
use access_control::access::{enable_access_control, AccessControl};
use crate::api::chat::chat_scope;
use crate::biz::actix_ws::server::RealtimeServerActor;
use crate::biz::casbin::{
CollabAccessControlImpl, RealtimeCollabAccessControlImpl, WorkspaceAccessControlImpl,
@ -136,6 +137,7 @@ pub async fn run_actix_server(
.service(collab_scope())
.service(ws_scope())
.service(file_storage_scope())
.service(chat_scope())
.service(metrics_scope())
.app_data(Data::new(state.metrics.registry.clone()))
.app_data(Data::new(state.metrics.request_metrics.clone()))

1
src/biz/chat/mod.rs Normal file
View file

@ -0,0 +1 @@
pub mod ops;

53
src/biz/chat/ops.rs Normal file
View file

@ -0,0 +1,53 @@
use app_error::AppError;
use database::chat;
use database::chat::chat_ops::{insert_chat, insert_chat_message, select_chat_messages};
use database_entity::chat::{
ChatAuthor, CreateChatMessageParams, CreateChatParams, GetChatMessageParams, RepeatedChatMessage,
};
use sqlx::PgPool;
use validator::Validate;
pub(crate) async fn create_chat(
pg_pool: &PgPool,
params: CreateChatParams,
workspace_id: &str,
) -> Result<(), AppError> {
params.validate()?;
let mut txn = pg_pool.begin().await?;
insert_chat(&mut txn, workspace_id, params).await?;
txn.commit().await?;
Ok(())
}
pub(crate) async fn delete_chat(pg_pool: &PgPool, chat_id: &str) -> Result<(), AppError> {
let mut txn = pg_pool.begin().await?;
chat::chat_ops::delete_chat(&mut txn, chat_id).await?;
txn.commit().await?;
Ok(())
}
pub async fn create_chat_message(
pg_pool: &PgPool,
uid: i64,
params: CreateChatMessageParams,
chat_id: &str,
) -> Result<(), AppError> {
params.validate()?;
insert_chat_message(pg_pool, ChatAuthor::Human { uid }, chat_id, params).await?;
Ok(())
}
pub async fn get_chat_messages(
pg_pool: &PgPool,
params: GetChatMessageParams,
chat_id: &str,
) -> Result<RepeatedChatMessage, AppError> {
params.validate()?;
let mut txn = pg_pool.begin().await?;
let messages = select_chat_messages(&mut txn, chat_id, params).await?;
txn.commit().await?;
Ok(messages)
}

View file

@ -1,5 +1,6 @@
pub mod actix_ws;
pub mod casbin;
pub mod chat;
pub mod collab;
pub mod pg_listener;
pub mod snapshot;

169
tests/sql_test/chat_test.rs Normal file
View file

@ -0,0 +1,169 @@
use crate::sql_test::util::{setup_db, test_create_user};
use database::chat::chat_ops::{
delete_chat, get_all_chat_messages, insert_chat, insert_chat_message, select_chat,
select_chat_messages,
};
use database_entity::chat::{
ChatAuthor, CreateChatMessageParams, CreateChatParams, GetChatMessageParams,
};
use serde_json::json;
use sqlx::PgPool;
#[sqlx::test(migrations = false)]
async fn chat_crud_test(pool: PgPool) {
setup_db(&pool).await.unwrap();
let user_uuid = uuid::Uuid::new_v4();
let name = user_uuid.to_string();
let email = format!("{}@appflowy.io", name);
let user = test_create_user(&pool, user_uuid, &email, &name)
.await
.unwrap();
let chat_id = uuid::Uuid::new_v4().to_string();
// create chat
{
let mut txn = pool.begin().await.unwrap();
insert_chat(
&mut txn,
&user.workspace_id,
CreateChatParams {
chat_id: chat_id.clone(),
name: "my first chat".to_string(),
rag_ids: vec!["rag_id_1".to_string(), "rag_id_2".to_string()],
},
)
.await
.unwrap();
txn.commit().await.unwrap();
}
// get chat
{
let chat = select_chat(&pool, &chat_id).await.unwrap();
assert_eq!(chat.name, "my first chat");
assert_eq!(
chat.rag_ids,
json!(vec!["rag_id_1".to_string(), "rag_id_2".to_string()]),
);
}
// delete chat
{
let mut txn = pool.begin().await.unwrap();
delete_chat(&mut txn, &chat_id).await.unwrap();
txn.commit().await.unwrap();
}
// get chat
{
let result = select_chat(&pool, &chat_id).await.unwrap_err();
assert!(result.is_record_not_found());
}
}
#[sqlx::test(migrations = false)]
async fn chat_message_crud_test(pool: PgPool) {
setup_db(&pool).await.unwrap();
let user_uuid = uuid::Uuid::new_v4();
let name = user_uuid.to_string();
let email = format!("{}@appflowy.io", name);
let user = test_create_user(&pool, user_uuid, &email, &name)
.await
.unwrap();
let chat_id = uuid::Uuid::new_v4().to_string();
// create chat
{
let mut txn = pool.begin().await.unwrap();
insert_chat(
&mut txn,
&user.workspace_id,
CreateChatParams {
chat_id: chat_id.clone(),
name: "my first chat".to_string(),
rag_ids: vec!["rag_id_1".to_string(), "rag_id_2".to_string()],
},
)
.await
.unwrap();
txn.commit().await.unwrap();
}
// create chat messages
for i in 0..5 {
let params = CreateChatMessageParams {
content: format!("message {}", i),
};
insert_chat_message(&pool, ChatAuthor::Human { uid: user.uid }, &chat_id, params)
.await
.unwrap();
}
// get 3 messages: 1,2,3
{
// option 1:use offset to get 3 messages => 1,2,3
let mut txn = pool.begin().await.unwrap();
let params = GetChatMessageParams::offset(chat_id.clone(), 0, 3);
let result_1 = select_chat_messages(&mut txn, &chat_id, params)
.await
.unwrap();
txn.commit().await.unwrap();
assert_eq!(result_1.messages.len(), 3);
assert_eq!(result_1.messages[0].message_id, 1);
assert_eq!(result_1.messages[1].message_id, 2);
assert_eq!(result_1.messages[2].message_id, 3);
assert_eq!(result_1.total, 5);
assert!(result_1.has_more);
// option 2:use before_message_id to get 3 messages => 1,2,3
let params = GetChatMessageParams::before_message_id(chat_id.clone(), 4, 3);
let mut txn = pool.begin().await.unwrap();
let result_2 = select_chat_messages(&mut txn, &chat_id, params)
.await
.unwrap();
txn.commit().await.unwrap();
assert_eq!(result_2.messages.len(), 3);
assert_eq!(result_2.messages[0].message_id, 1);
assert_eq!(result_2.messages[1].message_id, 2);
assert_eq!(result_2.messages[2].message_id, 3);
assert_eq!(result_2.total, 5);
assert!(!result_2.has_more);
}
// get two messages: 4,5
{
// option 1:use offset to get 2 messages => 4,5
let params = GetChatMessageParams::offset(chat_id.clone(), 3, 3);
let mut txn = pool.begin().await.unwrap();
let result_1 = select_chat_messages(&mut txn, &chat_id, params)
.await
.unwrap();
txn.commit().await.unwrap();
assert_eq!(result_1.messages.len(), 2);
assert_eq!(result_1.messages[0].message_id, 4);
assert_eq!(result_1.messages[1].message_id, 5);
assert_eq!(result_1.total, 5);
assert!(!result_1.has_more);
// option 2:use after_message_id to get remaining 2 messages => 4,5
let params = GetChatMessageParams::after_message_id(chat_id.clone(), 3, 3);
let mut txn = pool.begin().await.unwrap();
let result_2 = select_chat_messages(&mut txn, &chat_id, params)
.await
.unwrap();
txn.commit().await.unwrap();
assert_eq!(result_2.messages.len(), 2);
assert_eq!(result_2.messages[0].message_id, 4);
assert_eq!(result_2.messages[1].message_id, 5);
assert_eq!(result_2.total, 5);
assert!(!result_2.has_more);
}
// get all messages
{
let messages = get_all_chat_messages(&pool, &chat_id).await.unwrap();
assert_eq!(messages.len(), 5);
}
}

View file

@ -1,3 +1,4 @@
mod chat_test;
mod history_test;
pub(crate) mod util;
mod workspace_test;