chore: fetch model list

This commit is contained in:
Nathan 2025-02-01 23:50:31 +08:00
parent dd6b285cd3
commit 5ac0eded3c
17 changed files with 162 additions and 112 deletions

View file

@ -1,3 +1,5 @@
import 'dart:convert';
import 'package:appflowy/user/application/user_listener.dart'; import 'package:appflowy/user/application/user_listener.dart';
import 'package:appflowy/user/application/user_service.dart'; import 'package:appflowy/user/application/user_service.dart';
import 'package:appflowy_backend/dispatch/dispatch.dart'; import 'package:appflowy_backend/dispatch/dispatch.dart';
@ -9,6 +11,7 @@ import 'package:bloc/bloc.dart';
import 'package:freezed_annotation/freezed_annotation.dart'; import 'package:freezed_annotation/freezed_annotation.dart';
part 'settings_ai_bloc.freezed.dart'; part 'settings_ai_bloc.freezed.dart';
part 'settings_ai_bloc.g.dart';
class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> { class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
SettingsAIBloc( SettingsAIBloc(
@ -65,6 +68,7 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
}, },
); );
_loadUserWorkspaceSetting(); _loadUserWorkspaceSetting();
_loadModelList();
}, },
didReceiveUserProfile: (userProfile) { didReceiveUserProfile: (userProfile) {
emit(state.copyWith(userProfile: userProfile)); emit(state.copyWith(userProfile: userProfile));
@ -78,7 +82,7 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
!(state.aiSettings?.disableSearchIndexing ?? false), !(state.aiSettings?.disableSearchIndexing ?? false),
); );
}, },
selectModel: (AIModelPB model) { selectModel: (String model) {
_updateUserWorkspaceSetting(model: model); _updateUserWorkspaceSetting(model: model);
}, },
didLoadAISetting: (UseAISettingPB settings) { didLoadAISetting: (UseAISettingPB settings) {
@ -89,6 +93,14 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
), ),
); );
}, },
didLoadAvailableModels: (String models) {
final dynamic decodedJson = jsonDecode(models);
Log.info("Available models: $decodedJson");
if (decodedJson is Map<String, dynamic>) {
final models = ModelList.fromJson(decodedJson).models;
emit(state.copyWith(availableModels: models));
}
},
refreshMember: (member) { refreshMember: (member) {
emit(state.copyWith(currentWorkspaceMemberRole: member.role)); emit(state.copyWith(currentWorkspaceMemberRole: member.role));
}, },
@ -98,7 +110,7 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
void _updateUserWorkspaceSetting({ void _updateUserWorkspaceSetting({
bool? disableSearchIndexing, bool? disableSearchIndexing,
AIModelPB? model, String? model,
}) { }) {
final payload = UpdateUserWorkspaceSettingPB( final payload = UpdateUserWorkspaceSettingPB(
workspaceId: workspaceId, workspaceId: workspaceId,
@ -132,6 +144,18 @@ class SettingsAIBloc extends Bloc<SettingsAIEvent, SettingsAIState> {
}); });
}); });
} }
void _loadModelList() {
AIEventGetAvailableModels().send().then((result) {
result.fold((config) {
if (!isClosed) {
add(SettingsAIEvent.didLoadAvailableModels(config.models));
}
}, (err) {
Log.error(err);
});
});
}
} }
@freezed @freezed
@ -145,11 +169,15 @@ class SettingsAIEvent with _$SettingsAIEvent {
const factory SettingsAIEvent.refreshMember(WorkspaceMemberPB member) = const factory SettingsAIEvent.refreshMember(WorkspaceMemberPB member) =
_RefreshMember; _RefreshMember;
const factory SettingsAIEvent.selectModel(AIModelPB model) = _SelectAIModel; const factory SettingsAIEvent.selectModel(String model) = _SelectAIModel;
const factory SettingsAIEvent.didReceiveUserProfile( const factory SettingsAIEvent.didReceiveUserProfile(
UserProfilePB newUserProfile, UserProfilePB newUserProfile,
) = _DidReceiveUserProfile; ) = _DidReceiveUserProfile;
const factory SettingsAIEvent.didLoadAvailableModels(
String models,
) = _DidLoadAvailableModels;
} }
@freezed @freezed
@ -158,6 +186,21 @@ class SettingsAIState with _$SettingsAIState {
required UserProfilePB userProfile, required UserProfilePB userProfile,
UseAISettingPB? aiSettings, UseAISettingPB? aiSettings,
AFRolePB? currentWorkspaceMemberRole, AFRolePB? currentWorkspaceMemberRole,
@Default(["default"]) List<String> availableModels,
@Default(true) bool enableSearchIndexing, @Default(true) bool enableSearchIndexing,
}) = _SettingsAIState; }) = _SettingsAIState;
} }
@JsonSerializable()
class ModelList {
ModelList({
required this.models,
});
factory ModelList.fromJson(Map<String, dynamic> json) =>
_$ModelListFromJson(json);
final List<String> models;
Map<String, dynamic> toJson() => _$ModelListToJson(this);
}

View file

@ -4,8 +4,6 @@ import 'package:appflowy/generated/locale_keys.g.dart';
import 'package:appflowy/workspace/application/settings/ai/settings_ai_bloc.dart'; import 'package:appflowy/workspace/application/settings/ai/settings_ai_bloc.dart';
import 'package:appflowy/workspace/presentation/settings/shared/af_dropdown_menu_entry.dart'; import 'package:appflowy/workspace/presentation/settings/shared/af_dropdown_menu_entry.dart';
import 'package:appflowy/workspace/presentation/settings/shared/settings_dropdown.dart'; import 'package:appflowy/workspace/presentation/settings/shared/settings_dropdown.dart';
import 'package:appflowy_backend/log.dart';
import 'package:appflowy_backend/protobuf/flowy-user/protobuf.dart';
import 'package:easy_localization/easy_localization.dart'; import 'package:easy_localization/easy_localization.dart';
import 'package:flowy_infra_ui/style_widget/text.dart'; import 'package:flowy_infra_ui/style_widget/text.dart';
import 'package:flutter_bloc/flutter_bloc.dart'; import 'package:flutter_bloc/flutter_bloc.dart';
@ -30,18 +28,18 @@ class AIModelSelection extends StatelessWidget {
), ),
const Spacer(), const Spacer(),
Flexible( Flexible(
child: SettingsDropdown<AIModelPB>( child: SettingsDropdown<String>(
key: const Key('_AIModelSelection'), key: const Key('_AIModelSelection'),
onChanged: (model) => context onChanged: (model) => context
.read<SettingsAIBloc>() .read<SettingsAIBloc>()
.add(SettingsAIEvent.selectModel(model)), .add(SettingsAIEvent.selectModel(model)),
selectedOption: state.userProfile.aiModel, selectedOption: state.userProfile.aiModel,
options: _availableModels options: state.availableModels
.map( .map(
(format) => buildDropdownMenuEntry<AIModelPB>( (model) => buildDropdownMenuEntry<String>(
context, context,
value: format, value: model,
label: _titleForAIModel(format), label: model,
), ),
) )
.toList(), .toList(),
@ -54,29 +52,3 @@ class AIModelSelection extends StatelessWidget {
); );
} }
} }
List<AIModelPB> _availableModels = [
AIModelPB.DefaultModel,
AIModelPB.Claude3Opus,
AIModelPB.Claude3Sonnet,
AIModelPB.GPT4oMini,
AIModelPB.GPT4o,
];
String _titleForAIModel(AIModelPB model) {
switch (model) {
case AIModelPB.DefaultModel:
return "Default";
case AIModelPB.Claude3Opus:
return "Claude 3 Opus";
case AIModelPB.Claude3Sonnet:
return "Claude 3 Sonnet";
case AIModelPB.GPT4oMini:
return "GPT-4o-mini";
case AIModelPB.GPT4o:
return "GPT-4o";
default:
Log.error("Unknown AI model: $model, fallback to default");
return "Default";
}
}

View file

@ -163,7 +163,7 @@ checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7"
[[package]] [[package]]
name = "app-error" name = "app-error"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bincode", "bincode",
@ -183,7 +183,7 @@ dependencies = [
[[package]] [[package]]
name = "appflowy-ai-client" name = "appflowy-ai-client"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytes", "bytes",
@ -786,7 +786,7 @@ dependencies = [
[[package]] [[package]]
name = "client-api" name = "client-api"
version = "0.2.0" version = "0.2.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549"
dependencies = [ dependencies = [
"again", "again",
"anyhow", "anyhow",
@ -843,7 +843,7 @@ dependencies = [
[[package]] [[package]]
name = "client-api-entity" name = "client-api-entity"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549"
dependencies = [ dependencies = [
"collab-entity", "collab-entity",
"collab-rt-entity", "collab-rt-entity",
@ -856,7 +856,7 @@ dependencies = [
[[package]] [[package]]
name = "client-websocket" name = "client-websocket"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549"
dependencies = [ dependencies = [
"futures-channel", "futures-channel",
"futures-util", "futures-util",
@ -1128,7 +1128,7 @@ dependencies = [
[[package]] [[package]]
name = "collab-rt-entity" name = "collab-rt-entity"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bincode", "bincode",
@ -1153,7 +1153,7 @@ dependencies = [
[[package]] [[package]]
name = "collab-rt-protocol" name = "collab-rt-protocol"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
@ -1400,7 +1400,7 @@ dependencies = [
"cssparser-macros", "cssparser-macros",
"dtoa-short", "dtoa-short",
"itoa", "itoa",
"phf 0.8.0", "phf 0.11.2",
"smallvec", "smallvec",
] ]
@ -1548,7 +1548,7 @@ checksum = "c2e66c9d817f1720209181c316d28635c050fa304f9c79e47a520882661b7308"
[[package]] [[package]]
name = "database-entity" name = "database-entity"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"app-error", "app-error",
@ -2970,7 +2970,7 @@ dependencies = [
[[package]] [[package]]
name = "gotrue" name = "gotrue"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"futures-util", "futures-util",
@ -2987,7 +2987,7 @@ dependencies = [
[[package]] [[package]]
name = "gotrue-entity" name = "gotrue-entity"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"app-error", "app-error",
@ -3598,7 +3598,7 @@ dependencies = [
[[package]] [[package]]
name = "infra" name = "infra"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bytes", "bytes",
@ -4624,7 +4624,7 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3dfb61232e34fcb633f43d12c58f83c1df82962dcdfa565a4e866ffc17dafe12" checksum = "3dfb61232e34fcb633f43d12c58f83c1df82962dcdfa565a4e866ffc17dafe12"
dependencies = [ dependencies = [
"phf_macros", "phf_macros 0.8.0",
"phf_shared 0.8.0", "phf_shared 0.8.0",
"proc-macro-hack", "proc-macro-hack",
] ]
@ -4644,6 +4644,7 @@ version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc"
dependencies = [ dependencies = [
"phf_macros 0.11.3",
"phf_shared 0.11.2", "phf_shared 0.11.2",
] ]
@ -4711,6 +4712,19 @@ dependencies = [
"syn 1.0.109", "syn 1.0.109",
] ]
[[package]]
name = "phf_macros"
version = "0.11.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f84ac04429c13a7ff43785d75ad27569f2951ce0ffd30a3321230db2fc727216"
dependencies = [
"phf_generator 0.11.2",
"phf_shared 0.11.2",
"proc-macro2",
"quote",
"syn 2.0.94",
]
[[package]] [[package]]
name = "phf_shared" name = "phf_shared"
version = "0.8.0" version = "0.8.0"
@ -6140,7 +6154,7 @@ dependencies = [
[[package]] [[package]]
name = "shared-entity" name = "shared-entity"
version = "0.1.0" version = "0.1.0"
source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=4a26572a4e43714def9b362d444c640fdf1bc0d9#4a26572a4e43714def9b362d444c640fdf1bc0d9" source = "git+https://github.com/AppFlowy-IO/AppFlowy-Cloud?rev=82409199f8ffa0166f2f5d9403ccd55831890549#82409199f8ffa0166f2f5d9403ccd55831890549"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"app-error", "app-error",

View file

@ -103,8 +103,8 @@ dashmap = "6.0.1"
# Run the script.add_workspace_members: # Run the script.add_workspace_members:
# scripts/tool/update_client_api_rev.sh new_rev_id # scripts/tool/update_client_api_rev.sh new_rev_id
# ⚠️⚠️⚠️️ # ⚠️⚠️⚠️️
client-api = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "4a26572a4e43714def9b362d444c640fdf1bc0d9" } client-api = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "82409199f8ffa0166f2f5d9403ccd55831890549" }
client-api-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "4a26572a4e43714def9b362d444c640fdf1bc0d9" } client-api-entity = { git = "https://github.com/AppFlowy-IO/AppFlowy-Cloud", rev = "82409199f8ffa0166f2f5d9403ccd55831890549" }
[profile.dev] [profile.dev]
opt-level = 0 opt-level = 0

View file

@ -1,7 +1,7 @@
use bytes::Bytes; use bytes::Bytes;
pub use client_api::entity::ai_dto::{ pub use client_api::entity::ai_dto::{
AppFlowyOfflineAI, CompleteTextParams, CompletionMetadata, CompletionType, CreateChatContext, AppFlowyOfflineAI, CompleteTextParams, CompletionMetadata, CompletionType, CreateChatContext,
LLMModel, LocalAIConfig, ModelInfo, OutputContent, OutputLayout, RelatedQuestion, LLMModel, LocalAIConfig, ModelInfo, ModelList, OutputContent, OutputLayout, RelatedQuestion,
RepeatedRelatedQuestion, ResponseFormat, StringOrMessage, RepeatedRelatedQuestion, ResponseFormat, StringOrMessage,
}; };
pub use client_api::entity::billing_dto::SubscriptionPlan; pub use client_api::entity::billing_dto::SubscriptionPlan;
@ -119,4 +119,6 @@ pub trait ChatCloudService: Send + Sync + 'static {
chat_id: &str, chat_id: &str,
params: UpdateChatParams, params: UpdateChatParams,
) -> Result<(), FlowyError>; ) -> Result<(), FlowyError>;
async fn get_available_models(&self, workspace_id: &str) -> Result<ModelList, FlowyError>;
} }

View file

@ -10,7 +10,7 @@ use std::collections::HashMap;
use appflowy_plugin::manager::PluginManager; use appflowy_plugin::manager::PluginManager;
use dashmap::DashMap; use dashmap::DashMap;
use flowy_ai_pub::cloud::{ChatCloudService, ChatSettings, UpdateChatParams}; use flowy_ai_pub::cloud::{ChatCloudService, ChatSettings, ModelList, UpdateChatParams};
use flowy_error::{FlowyError, FlowyResult}; use flowy_error::{FlowyError, FlowyResult};
use flowy_sqlite::kv::KVStorePreferences; use flowy_sqlite::kv::KVStorePreferences;
use flowy_sqlite::DBConnection; use flowy_sqlite::DBConnection;
@ -241,6 +241,15 @@ impl AIManager {
Ok(()) Ok(())
} }
pub async fn get_available_models(&self) -> FlowyResult<ModelList> {
let workspace_id = self.user_service.workspace_id()?;
let list = self
.cloud_service_wm
.get_available_models(&workspace_id)
.await?;
Ok(list)
}
pub async fn get_or_create_chat_instance(&self, chat_id: &str) -> Result<Arc<Chat>, FlowyError> { pub async fn get_or_create_chat_instance(&self, chat_id: &str) -> Result<Arc<Chat>, FlowyError> {
let chat = self.chats.get(chat_id).as_deref().cloned(); let chat = self.chats.get(chat_id).as_deref().cloned();
match chat { match chat {

View file

@ -182,6 +182,12 @@ pub struct ChatMessageListPB {
pub total: i64, pub total: i64,
} }
#[derive(Default, ProtoBuf, Validate, Clone, Debug)]
pub struct ModelConfigPB {
#[pb(index = 1)]
pub models: String,
}
impl From<RepeatedChatMessage> for ChatMessageListPB { impl From<RepeatedChatMessage> for ChatMessageListPB {
fn from(repeated_chat_message: RepeatedChatMessage) -> Self { fn from(repeated_chat_message: RepeatedChatMessage) -> Self {
let messages = repeated_chat_message let messages = repeated_chat_message

View file

@ -107,6 +107,15 @@ pub(crate) async fn regenerate_response_handler(
Ok(()) Ok(())
} }
#[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn get_available_model_list_handler(
ai_manager: AFPluginState<Weak<AIManager>>,
) -> DataResult<ModelConfigPB, FlowyError> {
let ai_manager = upgrade_ai_manager(ai_manager)?;
let models = serde_json::to_string(&ai_manager.get_available_models().await?)?;
data_result_ok(ModelConfigPB { models })
}
#[tracing::instrument(level = "debug", skip_all, err)] #[tracing::instrument(level = "debug", skip_all, err)]
pub(crate) async fn load_prev_message_handler( pub(crate) async fn load_prev_message_handler(
data: AFPluginData<LoadPrevChatMessagePB>, data: AFPluginData<LoadPrevChatMessagePB>,

View file

@ -60,6 +60,10 @@ pub fn init(ai_manager: Weak<AIManager>) -> AFPlugin {
.event(AIEvent::GetChatSettings, get_chat_settings_handler) .event(AIEvent::GetChatSettings, get_chat_settings_handler)
.event(AIEvent::UpdateChatSettings, update_chat_settings_handler) .event(AIEvent::UpdateChatSettings, update_chat_settings_handler)
.event(AIEvent::RegenerateResponse, regenerate_response_handler) .event(AIEvent::RegenerateResponse, regenerate_response_handler)
.event(
AIEvent::GetAvailableModels,
get_available_model_list_handler,
)
} }
#[derive(Clone, Copy, PartialEq, Eq, Debug, Display, Hash, ProtoBuf_Enum, Flowy_Event)] #[derive(Clone, Copy, PartialEq, Eq, Debug, Display, Hash, ProtoBuf_Enum, Flowy_Event)]
@ -154,4 +158,7 @@ pub enum AIEvent {
#[event(input = "RegenerateResponsePB")] #[event(input = "RegenerateResponsePB")]
RegenerateResponse = 27, RegenerateResponse = 27,
#[event(output = "ModelConfigPB")]
GetAvailableModels = 28,
} }

View file

@ -10,9 +10,9 @@ use std::collections::HashMap;
use flowy_ai_pub::cloud::{ use flowy_ai_pub::cloud::{
ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings, ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
CompleteTextParams, LocalAIConfig, MessageCursor, RelatedQuestion, RepeatedChatMessage, CompleteTextParams, LocalAIConfig, MessageCursor, ModelList, RelatedQuestion,
RepeatedRelatedQuestion, ResponseFormat, StreamAnswer, StreamComplete, SubscriptionPlan, RepeatedChatMessage, RepeatedRelatedQuestion, ResponseFormat, StreamAnswer, StreamComplete,
UpdateChatParams, SubscriptionPlan, UpdateChatParams,
}; };
use flowy_error::{FlowyError, FlowyResult}; use flowy_error::{FlowyError, FlowyResult};
use futures::{stream, Sink, StreamExt, TryStreamExt}; use futures::{stream, Sink, StreamExt, TryStreamExt};
@ -353,4 +353,8 @@ impl ChatCloudService for AICloudServiceMiddleware {
.update_chat_settings(workspace_id, chat_id, params) .update_chat_settings(workspace_id, chat_id, params)
.await .await
} }
async fn get_available_models(&self, workspace_id: &str) -> Result<ModelList, FlowyError> {
self.cloud_service.get_available_models(workspace_id).await
}
} }

View file

@ -22,7 +22,7 @@ use collab_integrate::collab_builder::{
}; };
use flowy_ai_pub::cloud::{ use flowy_ai_pub::cloud::{
ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings, ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
CompleteTextParams, LocalAIConfig, MessageCursor, RepeatedChatMessage, ResponseFormat, CompleteTextParams, LocalAIConfig, MessageCursor, ModelList, RepeatedChatMessage, ResponseFormat,
StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams, StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams,
}; };
use flowy_database_pub::cloud::{ use flowy_database_pub::cloud::{
@ -838,6 +838,14 @@ impl ChatCloudService for ServerProvider {
.update_chat_settings(workspace_id, chat_id, params) .update_chat_settings(workspace_id, chat_id, params)
.await .await
} }
async fn get_available_models(&self, workspace_id: &str) -> Result<ModelList, FlowyError> {
self
.get_server()?
.chat_service()
.get_available_models(workspace_id)
.await
}
} }
#[async_trait] #[async_trait]

View file

@ -8,7 +8,7 @@ use client_api::entity::chat_dto::{
}; };
use flowy_ai_pub::cloud::{ use flowy_ai_pub::cloud::{
ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings, LocalAIConfig, ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings, LocalAIConfig,
StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams, ModelList, StreamAnswer, StreamComplete, SubscriptionPlan, UpdateChatParams,
}; };
use flowy_error::FlowyError; use flowy_error::FlowyError;
use futures_util::{StreamExt, TryStreamExt}; use futures_util::{StreamExt, TryStreamExt};
@ -270,4 +270,13 @@ where
.await?; .await?;
Ok(()) Ok(())
} }
async fn get_available_models(&self, workspace_id: &str) -> Result<ModelList, FlowyError> {
let list = self
.inner
.try_get_client()?
.get_model_list(workspace_id)
.await?;
Ok(list)
}
} }

View file

@ -1,4 +1,3 @@
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -7,7 +6,6 @@ use crate::af_cloud::define::ServerUser;
use anyhow::Error; use anyhow::Error;
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use client_api::collab_sync::ServerCollabMessage; use client_api::collab_sync::ServerCollabMessage;
use client_api::entity::ai_dto::AIModel;
use client_api::entity::UserMessage; use client_api::entity::UserMessage;
use client_api::notify::{TokenState, TokenStateReceiver}; use client_api::notify::{TokenState, TokenStateReceiver};
use client_api::ws::{ use client_api::ws::{
@ -124,7 +122,7 @@ impl AppFlowyServer for AppFlowyCloudServer {
} }
fn set_ai_model(&self, ai_model: &str) -> Result<(), Error> { fn set_ai_model(&self, ai_model: &str) -> Result<(), Error> {
self.client.set_ai_model(AIModel::from_str(ai_model)?); self.client.set_ai_model(ai_model.to_string());
Ok(()) Ok(())
} }

View file

@ -1,7 +1,7 @@
use client_api::entity::ai_dto::{LocalAIConfig, RepeatedRelatedQuestion}; use client_api::entity::ai_dto::{LocalAIConfig, RepeatedRelatedQuestion};
use flowy_ai_pub::cloud::{ use flowy_ai_pub::cloud::{
ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings, ChatCloudService, ChatMessage, ChatMessageMetadata, ChatMessageType, ChatSettings,
CompleteTextParams, MessageCursor, RepeatedChatMessage, ResponseFormat, StreamAnswer, CompleteTextParams, MessageCursor, ModelList, RepeatedChatMessage, ResponseFormat, StreamAnswer,
StreamComplete, SubscriptionPlan, UpdateChatParams, StreamComplete, SubscriptionPlan, UpdateChatParams,
}; };
use flowy_error::FlowyError; use flowy_error::FlowyError;
@ -144,4 +144,8 @@ impl ChatCloudService for DefaultChatCloudServiceImpl {
) -> Result<(), FlowyError> { ) -> Result<(), FlowyError> {
Err(FlowyError::not_support().with_context("Chat is not supported in local server.")) Err(FlowyError::not_support().with_context("Chat is not supported in local server."))
} }
async fn get_available_models(&self, _workspace_id: &str) -> Result<ModelList, FlowyError> {
Err(FlowyError::not_support().with_context("Chat is not supported in local server."))
}
} }

View file

@ -1,13 +1,11 @@
use std::convert::TryInto;
use std::str::FromStr;
use flowy_derive::{ProtoBuf, ProtoBuf_Enum}; use flowy_derive::{ProtoBuf, ProtoBuf_Enum};
use flowy_user_pub::entities::*; use flowy_user_pub::entities::*;
use lib_infra::validator_fn::required_not_empty_str; use lib_infra::validator_fn::required_not_empty_str;
use std::convert::TryInto;
use validator::Validate; use validator::Validate;
use crate::entities::parser::{UserEmail, UserIcon, UserName, UserOpenaiKey, UserPassword}; use crate::entities::parser::{UserEmail, UserIcon, UserName, UserOpenaiKey, UserPassword};
use crate::entities::{AIModelPB, AuthenticatorPB}; use crate::entities::AuthenticatorPB;
use crate::errors::ErrorCode; use crate::errors::ErrorCode;
use super::parser::UserStabilityAIKey; use super::parser::UserStabilityAIKey;
@ -58,7 +56,7 @@ pub struct UserProfilePB {
pub stability_ai_key: String, pub stability_ai_key: String,
#[pb(index = 11)] #[pb(index = 11)]
pub ai_model: AIModelPB, pub ai_model: String,
} }
#[derive(ProtoBuf_Enum, Eq, PartialEq, Debug, Clone)] #[derive(ProtoBuf_Enum, Eq, PartialEq, Debug, Clone)]
@ -79,6 +77,10 @@ impl From<UserProfile> for UserProfilePB {
EncryptionType::NoEncryption => ("".to_string(), EncryptionTypePB::NoEncryption), EncryptionType::NoEncryption => ("".to_string(), EncryptionTypePB::NoEncryption),
EncryptionType::SelfEncryption(sign) => (sign, EncryptionTypePB::Symmetric), EncryptionType::SelfEncryption(sign) => (sign, EncryptionTypePB::Symmetric),
}; };
let mut ai_model = user_profile.ai_model;
if ai_model.is_empty() {
ai_model = "Default".to_string();
}
Self { Self {
id: user_profile.uid, id: user_profile.uid,
email: user_profile.email, email: user_profile.email,
@ -90,7 +92,7 @@ impl From<UserProfile> for UserProfilePB {
encryption_sign, encryption_sign,
encryption_type: encryption_ty, encryption_type: encryption_ty,
stability_ai_key: user_profile.stability_ai_key, stability_ai_key: user_profile.stability_ai_key,
ai_model: AIModelPB::from_str(&user_profile.ai_model).unwrap_or_default(), ai_model,
} }
} }
} }

View file

@ -3,7 +3,6 @@ use client_api::entity::billing_dto::{
WorkspaceSubscriptionStatus, WorkspaceUsageAndLimit, WorkspaceSubscriptionStatus, WorkspaceUsageAndLimit,
}; };
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::str::FromStr;
use validator::Validate; use validator::Validate;
use flowy_derive::{ProtoBuf, ProtoBuf_Enum}; use flowy_derive::{ProtoBuf, ProtoBuf_Enum};
@ -382,14 +381,14 @@ pub struct UseAISettingPB {
pub disable_search_indexing: bool, pub disable_search_indexing: bool,
#[pb(index = 2)] #[pb(index = 2)]
pub ai_model: AIModelPB, pub ai_model: String,
} }
impl From<AFWorkspaceSettings> for UseAISettingPB { impl From<AFWorkspaceSettings> for UseAISettingPB {
fn from(value: AFWorkspaceSettings) -> Self { fn from(value: AFWorkspaceSettings) -> Self {
Self { Self {
disable_search_indexing: value.disable_search_indexing, disable_search_indexing: value.disable_search_indexing,
ai_model: AIModelPB::from_str(&value.ai_model).unwrap_or_default(), ai_model: value.ai_model,
} }
} }
} }
@ -404,7 +403,7 @@ pub struct UpdateUserWorkspaceSettingPB {
pub disable_search_indexing: Option<bool>, pub disable_search_indexing: Option<bool>,
#[pb(index = 3, one_of)] #[pb(index = 3, one_of)]
pub ai_model: Option<AIModelPB>, pub ai_model: Option<String>,
} }
impl From<UpdateUserWorkspaceSettingPB> for AFWorkspaceSettingsChange { impl From<UpdateUserWorkspaceSettingPB> for AFWorkspaceSettingsChange {
@ -414,48 +413,12 @@ impl From<UpdateUserWorkspaceSettingPB> for AFWorkspaceSettingsChange {
change = change.disable_search_indexing(disable_search_indexing); change = change.disable_search_indexing(disable_search_indexing);
} }
if let Some(ai_model) = value.ai_model { if let Some(ai_model) = value.ai_model {
change = change.ai_model(ai_model.to_str().to_string()); change = change.ai_model(ai_model);
} }
change change
} }
} }
#[derive(ProtoBuf_Enum, Debug, Clone, Eq, PartialEq, Default)]
pub enum AIModelPB {
#[default]
DefaultModel = 0,
GPT4oMini = 1,
GPT4o = 2,
Claude3Sonnet = 3,
Claude3Opus = 4,
}
impl AIModelPB {
pub fn to_str(&self) -> &str {
match self {
AIModelPB::DefaultModel => "default-model",
AIModelPB::GPT4oMini => "gpt-4o-mini",
AIModelPB::GPT4o => "gpt-4o",
AIModelPB::Claude3Sonnet => "claude-3-sonnet",
AIModelPB::Claude3Opus => "claude-3-opus",
}
}
}
impl FromStr for AIModelPB {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"gpt-3.5-turbo" => Ok(AIModelPB::GPT4oMini),
"gpt-4o" => Ok(AIModelPB::GPT4o),
"claude-3-sonnet" => Ok(AIModelPB::Claude3Sonnet),
"claude-3-opus" => Ok(AIModelPB::Claude3Opus),
_ => Ok(AIModelPB::DefaultModel),
}
}
}
#[derive(Debug, ProtoBuf, Default, Clone)] #[derive(Debug, ProtoBuf, Default, Clone)]
pub struct WorkspaceSubscriptionInfoPB { pub struct WorkspaceSubscriptionInfoPB {
#[pb(index = 1)] #[pb(index = 1)]

View file

@ -559,12 +559,12 @@ impl UserManager {
.send(); .send();
if let Some(ai_model) = &ai_model { if let Some(ai_model) = &ai_model {
if let Err(err) = self.cloud_services.set_ai_model(ai_model.to_str()) { if let Err(err) = self.cloud_services.set_ai_model(ai_model) {
error!("Set ai model failed: {}", err); error!("Set ai model failed: {}", err);
} }
let conn = self.db_connection(uid)?; let conn = self.db_connection(uid)?;
let params = UpdateUserProfileParams::new(uid).with_ai_model(ai_model.to_str()); let params = UpdateUserProfileParams::new(uid).with_ai_model(ai_model);
upsert_user_profile_change(uid, conn, UserTableChangeset::new(params))?; upsert_user_profile_change(uid, conn, UserTableChangeset::new(params))?;
} }
Ok(()) Ok(())