diff --git a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/settings_ai_bloc.dart b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/settings_ai_bloc.dart index 91ac63944c..6d0f396424 100644 --- a/frontend/appflowy_flutter/lib/workspace/application/settings/ai/settings_ai_bloc.dart +++ b/frontend/appflowy_flutter/lib/workspace/application/settings/ai/settings_ai_bloc.dart @@ -1,3 +1,5 @@ +import 'dart:convert'; + import 'package:appflowy/user/application/user_listener.dart'; import 'package:appflowy/user/application/user_service.dart'; import 'package:appflowy_backend/dispatch/dispatch.dart'; @@ -9,6 +11,7 @@ import 'package:bloc/bloc.dart'; import 'package:freezed_annotation/freezed_annotation.dart'; part 'settings_ai_bloc.freezed.dart'; +part 'settings_ai_bloc.g.dart'; class SettingsAIBloc extends Bloc { SettingsAIBloc( @@ -65,6 +68,7 @@ class SettingsAIBloc extends Bloc { }, ); _loadUserWorkspaceSetting(); + _loadModelList(); }, didReceiveUserProfile: (userProfile) { emit(state.copyWith(userProfile: userProfile)); @@ -78,7 +82,7 @@ class SettingsAIBloc extends Bloc { !(state.aiSettings?.disableSearchIndexing ?? false), ); }, - selectModel: (AIModelPB model) { + selectModel: (String model) { _updateUserWorkspaceSetting(model: model); }, didLoadAISetting: (UseAISettingPB settings) { @@ -89,6 +93,14 @@ class SettingsAIBloc extends Bloc { ), ); }, + didLoadAvailableModels: (String models) { + final dynamic decodedJson = jsonDecode(models); + Log.info("Available models: $decodedJson"); + if (decodedJson is Map) { + final models = ModelList.fromJson(decodedJson).models; + emit(state.copyWith(availableModels: models)); + } + }, refreshMember: (member) { emit(state.copyWith(currentWorkspaceMemberRole: member.role)); }, @@ -98,7 +110,7 @@ class SettingsAIBloc extends Bloc { void _updateUserWorkspaceSetting({ bool? disableSearchIndexing, - AIModelPB? model, + String? model, }) { final payload = UpdateUserWorkspaceSettingPB( workspaceId: workspaceId, @@ -132,6 +144,18 @@ class SettingsAIBloc extends Bloc { }); }); } + + void _loadModelList() { + AIEventGetAvailableModels().send().then((result) { + result.fold((config) { + if (!isClosed) { + add(SettingsAIEvent.didLoadAvailableModels(config.models)); + } + }, (err) { + Log.error(err); + }); + }); + } } @freezed @@ -145,11 +169,15 @@ class SettingsAIEvent with _$SettingsAIEvent { const factory SettingsAIEvent.refreshMember(WorkspaceMemberPB member) = _RefreshMember; - const factory SettingsAIEvent.selectModel(AIModelPB model) = _SelectAIModel; + const factory SettingsAIEvent.selectModel(String model) = _SelectAIModel; const factory SettingsAIEvent.didReceiveUserProfile( UserProfilePB newUserProfile, ) = _DidReceiveUserProfile; + + const factory SettingsAIEvent.didLoadAvailableModels( + String models, + ) = _DidLoadAvailableModels; } @freezed @@ -158,6 +186,21 @@ class SettingsAIState with _$SettingsAIState { required UserProfilePB userProfile, UseAISettingPB? aiSettings, AFRolePB? currentWorkspaceMemberRole, + @Default(["default"]) List availableModels, @Default(true) bool enableSearchIndexing, }) = _SettingsAIState; } + +@JsonSerializable() +class ModelList { + ModelList({ + required this.models, + }); + + factory ModelList.fromJson(Map json) => + _$ModelListFromJson(json); + + final List models; + + Map toJson() => _$ModelListToJson(this); +} diff --git a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart index 22aaf0bcca..e03aa639e4 100644 --- a/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart +++ b/frontend/appflowy_flutter/lib/workspace/presentation/settings/pages/setting_ai_view/model_selection.dart @@ -4,8 +4,6 @@ import 'package:appflowy/generated/locale_keys.g.dart'; import 'package:appflowy/workspace/application/settings/ai/settings_ai_bloc.dart'; import 'package:appflowy/workspace/presentation/settings/shared/af_dropdown_menu_entry.dart'; import 'package:appflowy/workspace/presentation/settings/shared/settings_dropdown.dart'; -import 'package:appflowy_backend/log.dart'; -import 'package:appflowy_backend/protobuf/flowy-user/protobuf.dart'; import 'package:easy_localization/easy_localization.dart'; import 'package:flowy_infra_ui/style_widget/text.dart'; import 'package:flutter_bloc/flutter_bloc.dart'; @@ -30,18 +28,18 @@ class AIModelSelection extends StatelessWidget { ), const Spacer(), Flexible( - child: SettingsDropdown( + child: SettingsDropdown( key: const Key('_AIModelSelection'), onChanged: (model) => context .read() .add(SettingsAIEvent.selectModel(model)), selectedOption: state.userProfile.aiModel, - options: _availableModels + options: state.availableModels .map( - (format) => buildDropdownMenuEntry( + (model) => buildDropdownMenuEntry( context, - value: format, - label: _titleForAIModel(format), + value: model, + label: model, ), ) .toList(), @@ -54,29 +52,3 @@ class AIModelSelection extends StatelessWidget { ); } } - -List _availableModels = [ - AIModelPB.DefaultModel, - AIModelPB.Claude3Opus, - AIModelPB.Claude3Sonnet, - AIModelPB.GPT4oMini, - AIModelPB.GPT4o, -]; - -String _titleForAIModel(AIModelPB model) { - switch (model) { - case AIModelPB.DefaultModel: - return "Default"; - case AIModelPB.Claude3Opus: - return "Claude 3 Opus"; - case AIModelPB.Claude3Sonnet: - return "Claude 3 Sonnet"; - case AIModelPB.GPT4oMini: - return "GPT-4o-mini"; - case AIModelPB.GPT4o: - return "GPT-4o"; - default: - Log.error("Unknown AI model: $model, fallback to default"); - return "Default"; - } -} diff --git a/frontend/rust-lib/Cargo.lock b/frontend/rust-lib/Cargo.lock index c26bb535ff..733a784e42 100644 --- a/frontend/rust-lib/Cargo.lock +++ b/frontend/rust-lib/Cargo.lock @@ -163,7 +163,7 @@ checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7" [[package]] name = "app-error" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549" dependencies = [ "anyhow", "bincode", @@ -183,7 +183,7 @@ dependencies = [ [[package]] name = "appflowy-ai-client" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549" dependencies = [ "anyhow", "bytes", @@ -786,7 +786,7 @@ dependencies = [ [[package]] name = "client-api" version = "0.2.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549" dependencies = [ "again", "anyhow", @@ -843,7 +843,7 @@ dependencies = [ [[package]] name = "client-api-entity" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549" dependencies = [ "collab-entity", "collab-rt-entity", @@ -856,7 +856,7 @@ dependencies = [ [[package]] name = "client-websocket" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549" dependencies = [ "futures-channel", "futures-util", @@ -1128,7 +1128,7 @@ dependencies = [ [[package]] name = "collab-rt-entity" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549" dependencies = [ "anyhow", "bincode", @@ -1153,7 +1153,7 @@ dependencies = [ [[package]] name = "collab-rt-protocol" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549" dependencies = [ "anyhow", "async-trait", @@ -1400,7 +1400,7 @@ dependencies = [ "cssparser-macros", "dtoa-short", "itoa", - "phf 0.8.0", + "phf 0.11.2", "smallvec", ] @@ -1548,7 +1548,7 @@ checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308" [[package]] name = "database-entity" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549" dependencies = [ "anyhow", "app-error", @@ -2970,7 +2970,7 @@ dependencies = [ [[package]] name = "gotrue" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549" dependencies = [ "anyhow", "futures-util", @@ -2987,7 +2987,7 @@ dependencies = [ [[package]] name = "gotrue-entity" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549" dependencies = [ "anyhow", "app-error", @@ -3598,7 +3598,7 @@ dependencies = [ [[package]] name = "infra" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549" dependencies = [ "anyhow", "bytes", @@ -4624,7 +4624,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3dfb61232e34fcb633f43d12c58f83c1df82962dcdfa565a4e866ffc17dafe12" dependencies = [ - "phf_macros", + "phf_macros 0.8.0", "phf_shared 0.8.0", "proc-macro-hack", ] @@ -4644,6 +4644,7 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" dependencies = [ + "phf_macros 0.11.3", "phf_shared 0.11.2", ] @@ -4711,6 +4712,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "phf_macros" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216" +dependencies = [ + "phf_generator 0.11.2", + "phf_shared 0.11.2", + "proc-macro2", + "quote", + "syn 2.0.94", +] + [[package]] name = "phf_shared" version = "0.8.0" @@ -6140,7 +6154,7 @@ dependencies = [ [[package]] name = "shared-entity" version = "0.1.0" -source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" +source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549" dependencies = [ "anyhow", "app-error", diff --git a/frontend/rust-lib/Cargo.toml b/frontend/rust-lib/Cargo.toml index 8435b0f904..b93d19954b 100644 --- a/frontend/rust-lib/Cargo.toml +++ b/frontend/rust-lib/Cargo.toml @@ -103,8 +103,8 @@ dashmap = "6.0.1" # Run the script.add_workspace_members: # scripts/tool/update_client_api_rev.sh new_rev_id # ⚠️⚠️⚠️️ -client-api = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "4a26572a4e43714def9b362d444c640fdf1bc0d9" } -client-api-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "4a26572a4e43714def9b362d444c640fdf1bc0d9" } +client-api = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "82409199f8ffa0166f2f5d9403ccd55831890549" } +client-api-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "82409199f8ffa0166f2f5d9403ccd55831890549" } [profile.dev] opt-level = 0 diff --git a/frontend/rust-lib/flowy-ai-pub/src/cloud.rs b/frontend/rust-lib/flowy-ai-pub/src/cloud.rs index bf787f3d20..98198c8f9f 100644 --- a/frontend/rust-lib/flowy-ai-pub/src/cloud.rs +++ b/frontend/rust-lib/flowy-ai-pub/src/cloud.rs @@ -1,7 +1,7 @@ use bytes::Bytes; pub use client_api::entity::ai_dto::{ AppFlowyOfflineAI, CompleteTextParams, CompletionMetadata, CompletionType, CreateChatContext, - LLMModel, LocalAIConfig, ModelInfo, OutputContent, OutputLayout, RelatedQuestion, + LLMModel, LocalAIConfig, ModelInfo, ModelList, OutputContent, OutputLayout, RelatedQuestion, RepeatedRelatedQuestion, ResponseFormat, StringOrMessage, }; pub use client_api::entity::billing_dto::SubscriptionPlan; @@ -119,4 +119,6 @@ pub trait ChatCloudService: Send + Sync + 'static { chat_id: &str, params: UpdateChatParams, ) -> Result<(), FlowyError>; + + async fn get_available_models(&self, workspace_id: &str) -> Result; } diff --git a/frontend/rust-lib/flowy-ai/src/ai_manager.rs b/frontend/rust-lib/flowy-ai/src/ai_manager.rs index c88a26504c..5a784f6db6 100644 --- a/frontend/rust-lib/flowy-ai/src/ai_manager.rs +++ b/frontend/rust-lib/flowy-ai/src/ai_manager.rs @@ -10,7 +10,7 @@ use std::collections::HashMap; use appflowy_plugin::manager::PluginManager; use dashmap::DashMap; -use flowy_ai_pub::cloud::{ChatCloudService, ChatSettings, UpdateChatParams}; +use flowy_ai_pub::cloud::{ChatCloudService, ChatSettings, ModelList, UpdateChatParams}; use flowy_error::{FlowyError, FlowyResult}; use flowy_sqlite::kv::KVStorePreferences; use flowy_sqlite::DBConnection; @@ -241,6 +241,15 @@ impl AIManager { Ok(()) } + pub async fn get_available_models(&self) -> FlowyResult { + let workspace_id = self.user_service.workspace_id()?; + let list = self + .cloud_service_wm + .get_available_models(&workspace_id) + .await?; + Ok(list) + } + pub async fn get_or_create_chat_instance(&self, chat_id: &str) -> Result, FlowyError> { let chat = self.chats.get(chat_id).as_deref().cloned(); match chat { diff --git a/frontend/rust-lib/flowy-ai/src/entities.rs b/frontend/rust-lib/flowy-ai/src/entities.rs index a5b9778f06..6a30a7dff7 100644 --- a/frontend/rust-lib/flowy-ai/src/entities.rs +++ b/frontend/rust-lib/flowy-ai/src/entities.rs @@ -182,6 +182,12 @@ pub struct ChatMessageListPB { pub total: i64, } +#[derive(Default, ProtoBuf, Validate, Clone, Debug)] +pub struct ModelConfigPB { + #[pb(index = 1)] + pub models: String, +} + impl From for ChatMessageListPB { fn from(repeated_chat_message: RepeatedChatMessage) -> Self { let messages = repeated_chat_message diff --git a/frontend/rust-lib/flowy-ai/src/event_handler.rs b/frontend/rust-lib/flowy-ai/src/event_handler.rs index d8ebd0e93b..baa0898220 100644 --- a/frontend/rust-lib/flowy-ai/src/event_handler.rs +++ b/frontend/rust-lib/flowy-ai/src/event_handler.rs @@ -107,6 +107,15 @@ pub(crate) async fn regenerate_response_handler( Ok(()) } +#[tracing::instrument(level = "debug", skip_all, err)] +pub(crate) async fn get_available_model_list_handler( + ai_manager: AFPluginState>, +) -> DataResult { + let ai_manager = upgrade_ai_manager(ai_manager)?; + let models = serde_json::to_string(&ai_manager.get_available_models().await?)?; + data_result_ok(ModelConfigPB { models }) +} + #[tracing::instrument(level = "debug", skip_all, err)] pub(crate) async fn load_prev_message_handler( data: AFPluginData, diff --git a/frontend/rust-lib/flowy-ai/src/event_map.rs b/frontend/rust-lib/flowy-ai/src/event_map.rs index c27e94c334..2bd2bee863 100644 --- a/frontend/rust-lib/flowy-ai/src/event_map.rs +++ b/frontend/rust-lib/flowy-ai/src/event_map.rs @@ -60,6 +60,10 @@ pub fn init(ai_manager: Weak) -> AFPlugin { .event(AIEvent::GetChatSettings, get_chat_settings_handler) .event(AIEvent::UpdateChatSettings, update_chat_settings_handler) .event(AIEvent::RegenerateResponse, regenerate_response_handler) + .event( + AIEvent::GetAvailableModels, + get_available_model_list_handler, + ) } #[derive(Clone, Copy, PartialEq, Eq, Debug, Display, Hash, ProtoBuf_Enum, Flowy_Event)] @@ -154,4 +158,7 @@ pub enum AIEvent { #[event(input = "RegenerateResponsePB")] RegenerateResponse = 27, + + #[event(output = "ModelConfigPB")] + GetAvailableModels = 28, } diff --git a/frontend/rust-lib/flowy-ai/src/middleware/chat_service_mw.rs b/frontend/rust-lib/flowy-ai/src/middleware/chat_service_mw.rs index ae8ce9b8d0..d6dce3696c 100644 --- a/frontend/rust-lib/flowy-ai/src/middleware/chat_service_mw.rs +++ b/frontend/rust-lib/flowy-ai/src/middleware/chat_service_mw.rs @@ -10,9 +10,9 @@ use std::collections::HashMap; use flowy_ai_pub::cloud::{ ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings, - CompleteTextParams, LocalAIConfig, MessageCursor, RelatedQuestion, RepeatedChatMessage, - RepeatedRelatedQuestion, ResponseFormat, StreamAnswer, StreamComplete, SubscriptionPlan, - UpdateChatParams, + CompleteTextParams, LocalAIConfig, MessageCursor, ModelList, RelatedQuestion, + RepeatedChatMessage, RepeatedRelatedQuestion, ResponseFormat, StreamAnswer, StreamComplete, + SubscriptionPlan, UpdateChatParams, }; use flowy_error::{FlowyError, FlowyResult}; use futures::{stream, Sink, StreamExt, TryStreamExt}; @@ -353,4 +353,8 @@ impl ChatCloudService for AICloudServiceMiddleware { .update_chat_settings(workspace_id, chat_id, params) .await } + + async fn get_available_models(&self, workspace_id: &str) -> Result { + self.cloud_service.get_available_models(workspace_id).await + } } diff --git a/frontend/rust-lib/flowy-core/src/deps_resolve/cloud_service_impl.rs b/frontend/rust-lib/flowy-core/src/deps_resolve/cloud_service_impl.rs index b01ea49f82..22342615e3 100644 --- a/frontend/rust-lib/flowy-core/src/deps_resolve/cloud_service_impl.rs +++ b/frontend/rust-lib/flowy-core/src/deps_resolve/cloud_service_impl.rs @@ -22,7 +22,7 @@ use collab_integrate::collab_builder::{ }; use flowy_ai_pub::cloud::{ ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings, - CompleteTextParams, LocalAIConfig, MessageCursor, RepeatedChatMessage, ResponseFormat, + CompleteTextParams, LocalAIConfig, MessageCursor, ModelList, RepeatedChatMessage, ResponseFormat, StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams, }; use flowy_database_pub::cloud::{ @@ -838,6 +838,14 @@ impl ChatCloudService for ServerProvider { .update_chat_settings(workspace_id, chat_id, params) .await } + + async fn get_available_models(&self, workspace_id: &str) -> Result { + self + .get_server()? + .chat_service() + .get_available_models(workspace_id) + .await + } } #[async_trait] diff --git a/frontend/rust-lib/flowy-server/src/af_cloud/impls/chat.rs b/frontend/rust-lib/flowy-server/src/af_cloud/impls/chat.rs index 7512b9e48c..11fc5cc27c 100644 --- a/frontend/rust-lib/flowy-server/src/af_cloud/impls/chat.rs +++ b/frontend/rust-lib/flowy-server/src/af_cloud/impls/chat.rs @@ -8,7 +8,7 @@ use client_api::entity::chat_dto::{ }; use flowy_ai_pub::cloud::{ ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings, LocalAIConfig, - StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams, + ModelList, StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams, }; use flowy_error::FlowyError; use futures_util::{StreamExt, TryStreamExt}; @@ -270,4 +270,13 @@ where .await?; Ok(()) } + + async fn get_available_models(&self, workspace_id: &str) -> Result { + let list = self + .inner + .try_get_client()? + .get_model_list(workspace_id) + .await?; + Ok(list) + } } diff --git a/frontend/rust-lib/flowy-server/src/af_cloud/server.rs b/frontend/rust-lib/flowy-server/src/af_cloud/server.rs index ee908ac9cc..06e56a8c05 100644 --- a/frontend/rust-lib/flowy-server/src/af_cloud/server.rs +++ b/frontend/rust-lib/flowy-server/src/af_cloud/server.rs @@ -1,4 +1,3 @@ -use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::Duration; @@ -7,7 +6,6 @@ use crate::af_cloud::define::ServerUser; use anyhow::Error; use arc_swap::ArcSwap; use client_api::collab_sync::ServerCollabMessage; -use client_api::entity::ai_dto::AIModel; use client_api::entity::UserMessage; use client_api::notify::{TokenState, TokenStateReceiver}; use client_api::ws::{ @@ -124,7 +122,7 @@ impl AppFlowyServer for AppFlowyCloudServer { } fn set_ai_model(&self, ai_model: &str) -> Result<(), Error> { - self.client.set_ai_model(AIModel::from_str(ai_model)?); + self.client.set_ai_model(ai_model.to_string()); Ok(()) } diff --git a/frontend/rust-lib/flowy-server/src/default_impl.rs b/frontend/rust-lib/flowy-server/src/default_impl.rs index 792db6e23a..194aa89ef3 100644 --- a/frontend/rust-lib/flowy-server/src/default_impl.rs +++ b/frontend/rust-lib/flowy-server/src/default_impl.rs @@ -1,7 +1,7 @@ use client_api::entity::ai_dto::{LocalAIConfig, RepeatedRelatedQuestion}; use flowy_ai_pub::cloud::{ ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings, - CompleteTextParams, MessageCursor, RepeatedChatMessage, ResponseFormat, StreamAnswer, + CompleteTextParams, MessageCursor, ModelList, RepeatedChatMessage, ResponseFormat, StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams, }; use flowy_error::FlowyError; @@ -144,4 +144,8 @@ impl ChatCloudService for DefaultChatCloudServiceImpl { ) -> Result<(), FlowyError> { Err(FlowyError::not_support().with_context("Chat is not supported in local server.")) } + + async fn get_available_models(&self, _workspace_id: &str) -> Result { + Err(FlowyError::not_support().with_context("Chat is not supported in local server.")) + } } diff --git a/frontend/rust-lib/flowy-user/src/entities/user_profile.rs b/frontend/rust-lib/flowy-user/src/entities/user_profile.rs index facc9d8b41..aa9d38a9cd 100644 --- a/frontend/rust-lib/flowy-user/src/entities/user_profile.rs +++ b/frontend/rust-lib/flowy-user/src/entities/user_profile.rs @@ -1,13 +1,11 @@ -use std::convert::TryInto; -use std::str::FromStr; - use flowy_derive::{ProtoBuf, ProtoBuf_Enum}; use flowy_user_pub::entities::*; use lib_infra::validator_fn::required_not_empty_str; +use std::convert::TryInto; use validator::Validate; use crate::entities::parser::{UserEmail, UserIcon, UserName, UserOpenaiKey, UserPassword}; -use crate::entities::{AIModelPB, AuthenticatorPB}; +use crate::entities::AuthenticatorPB; use crate::errors::ErrorCode; use super::parser::UserStabilityAIKey; @@ -58,7 +56,7 @@ pub struct UserProfilePB { pub stability_ai_key: String, #[pb(index = 11)] - pub ai_model: AIModelPB, + pub ai_model: String, } #[derive(ProtoBuf_Enum, Eq, PartialEq, Debug, Clone)] @@ -79,6 +77,10 @@ impl From for UserProfilePB { EncryptionType::NoEncryption => ("".to_string(), EncryptionTypePB::NoEncryption), EncryptionType::SelfEncryption(sign) => (sign, EncryptionTypePB::Symmetric), }; + let mut ai_model = user_profile.ai_model; + if ai_model.is_empty() { + ai_model = "Default".to_string(); + } Self { id: user_profile.uid, email: user_profile.email, @@ -90,7 +92,7 @@ impl From for UserProfilePB { encryption_sign, encryption_type: encryption_ty, stability_ai_key: user_profile.stability_ai_key, - ai_model: AIModelPB::from_str(&user_profile.ai_model).unwrap_or_default(), + ai_model, } } } diff --git a/frontend/rust-lib/flowy-user/src/entities/workspace.rs b/frontend/rust-lib/flowy-user/src/entities/workspace.rs index 736979f11a..240b153e33 100644 --- a/frontend/rust-lib/flowy-user/src/entities/workspace.rs +++ b/frontend/rust-lib/flowy-user/src/entities/workspace.rs @@ -3,7 +3,6 @@ use client_api::entity::billing_dto::{ WorkspaceSubscriptionStatus, WorkspaceUsageAndLimit, }; use serde::{Deserialize, Serialize}; -use std::str::FromStr; use validator::Validate; use flowy_derive::{ProtoBuf, ProtoBuf_Enum}; @@ -382,14 +381,14 @@ pub struct UseAISettingPB { pub disable_search_indexing: bool, #[pb(index = 2)] - pub ai_model: AIModelPB, + pub ai_model: String, } impl From for UseAISettingPB { fn from(value: AFWorkspaceSettings) -> Self { Self { disable_search_indexing: value.disable_search_indexing, - ai_model: AIModelPB::from_str(&value.ai_model).unwrap_or_default(), + ai_model: value.ai_model, } } } @@ -404,7 +403,7 @@ pub struct UpdateUserWorkspaceSettingPB { pub disable_search_indexing: Option, #[pb(index = 3, one_of)] - pub ai_model: Option, + pub ai_model: Option, } impl From for AFWorkspaceSettingsChange { @@ -414,48 +413,12 @@ impl From for AFWorkspaceSettingsChange { change = change.disable_search_indexing(disable_search_indexing); } if let Some(ai_model) = value.ai_model { - change = change.ai_model(ai_model.to_str().to_string()); + change = change.ai_model(ai_model); } change } } -#[derive(ProtoBuf_Enum, Debug, Clone, Eq, PartialEq, Default)] -pub enum AIModelPB { - #[default] - DefaultModel = 0, - GPT4oMini = 1, - GPT4o = 2, - Claude3Sonnet = 3, - Claude3Opus = 4, -} - -impl AIModelPB { - pub fn to_str(&self) -> &str { - match self { - AIModelPB::DefaultModel => "default-model", - AIModelPB::GPT4oMini => "gpt-4o-mini", - AIModelPB::GPT4o => "gpt-4o", - AIModelPB::Claude3Sonnet => "claude-3-sonnet", - AIModelPB::Claude3Opus => "claude-3-opus", - } - } -} - -impl FromStr for AIModelPB { - type Err = anyhow::Error; - - fn from_str(s: &str) -> Result { - match s { - "gpt-3.5-turbo" => Ok(AIModelPB::GPT4oMini), - "gpt-4o" => Ok(AIModelPB::GPT4o), - "claude-3-sonnet" => Ok(AIModelPB::Claude3Sonnet), - "claude-3-opus" => Ok(AIModelPB::Claude3Opus), - _ => Ok(AIModelPB::DefaultModel), - } - } -} - #[derive(Debug, ProtoBuf, Default, Clone)] pub struct WorkspaceSubscriptionInfoPB { #[pb(index = 1)] diff --git a/frontend/rust-lib/flowy-user/src/user_manager/manager_user_workspace.rs b/frontend/rust-lib/flowy-user/src/user_manager/manager_user_workspace.rs index 31ad71ce74..e666d2486a 100644 --- a/frontend/rust-lib/flowy-user/src/user_manager/manager_user_workspace.rs +++ b/frontend/rust-lib/flowy-user/src/user_manager/manager_user_workspace.rs @@ -559,12 +559,12 @@ impl UserManager { .send(); if let Some(ai_model) = &ai_model { - if let Err(err) = self.cloud_services.set_ai_model(ai_model.to_str()) { + if let Err(err) = self.cloud_services.set_ai_model(ai_model) { error!("Set ai model failed: {}", err); } let conn = self.db_connection(uid)?; - let params = UpdateUserProfileParams::new(uid).with_ai_model(ai_model.to_str()); + let params = UpdateUserProfileParams::new(uid).with_ai_model(ai_model); upsert_user_profile_change(uid, conn, UserTableChangeset::new(params))?; } Ok(())