We communicate from client to dedicated ai server
This commit is contained in:
parent
c8bc79320e
commit
6a66cde0d0
10
Cargo.lock
generated
10
Cargo.lock
generated
@ -567,11 +567,20 @@ dependencies = [
|
|||||||
name = "frontend"
|
name = "frontend"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"frontend_protocol",
|
||||||
"js-sys",
|
"js-sys",
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
"web-sys",
|
"web-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "frontend_protocol"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"anyhow",
|
||||||
|
"deku",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "funty"
|
name = "funty"
|
||||||
version = "2.0.0"
|
version = "2.0.0"
|
||||||
@ -2563,6 +2572,7 @@ dependencies = [
|
|||||||
"blake2b_simd",
|
"blake2b_simd",
|
||||||
"deku",
|
"deku",
|
||||||
"dryoc",
|
"dryoc",
|
||||||
|
"frontend_protocol",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"serde",
|
"serde",
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
"frontend",
|
"frontend",
|
||||||
|
"frontend_protocol",
|
||||||
"website",
|
"website",
|
||||||
]
|
]
|
||||||
@ -32,6 +32,9 @@ class PendingChatCompletionRecord:
|
|||||||
was_cancelled: bool = False
|
was_cancelled: bool = False
|
||||||
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
|
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def mark_cancelled(self) -> None:
|
def mark_cancelled(self) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.was_cancelled = True
|
self.was_cancelled = True
|
||||||
|
|||||||
@ -9,6 +9,7 @@ crate-type = ["cdylib"]
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
wasm-bindgen = "0.2"
|
wasm-bindgen = "0.2"
|
||||||
js-sys = "0.3"
|
js-sys = "0.3"
|
||||||
|
frontend_protocol = { path = "../frontend_protocol" }
|
||||||
|
|
||||||
[dependencies.web-sys]
|
[dependencies.web-sys]
|
||||||
version = "0.3"
|
version = "0.3"
|
||||||
@ -22,6 +23,7 @@ features = [
|
|||||||
"WheelEvent",
|
"WheelEvent",
|
||||||
"Location",
|
"Location",
|
||||||
"WebSocket",
|
"WebSocket",
|
||||||
|
"BinaryType",
|
||||||
"MessageEvent",
|
"MessageEvent",
|
||||||
"ErrorEvent",
|
"ErrorEvent",
|
||||||
"console",
|
"console",
|
||||||
|
|||||||
@ -1,7 +1,22 @@
|
|||||||
|
use std::cell::RefCell;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use wasm_bindgen::prelude::*;
|
use wasm_bindgen::prelude::*;
|
||||||
use wasm_bindgen::JsCast;
|
use wasm_bindgen::JsCast;
|
||||||
use web_sys::{console, window, ErrorEvent, Event, MessageEvent, WebSocket};
|
use web_sys::{console, window, BinaryType, ErrorEvent, Event, MessageEvent, WebSocket};
|
||||||
|
use js_sys::{ArrayBuffer, Uint8Array};
|
||||||
|
use frontend_protocol::{
|
||||||
|
ChatMessage,
|
||||||
|
DekuBytes,
|
||||||
|
UserChatCompletionRequest,
|
||||||
|
UserRequest,
|
||||||
|
UserRequestPayload,
|
||||||
|
UserResponse,
|
||||||
|
UserResponsePayload,
|
||||||
|
};
|
||||||
|
|
||||||
|
thread_local! {
|
||||||
|
static WS_HANDLE: RefCell<Option<Rc<WebSocket>>> = RefCell::new(None);
|
||||||
|
}
|
||||||
|
|
||||||
#[wasm_bindgen]
|
#[wasm_bindgen]
|
||||||
pub fn init_chat() -> Result<(), JsValue> {
|
pub fn init_chat() -> Result<(), JsValue> {
|
||||||
@ -13,24 +28,84 @@ pub fn init_chat() -> Result<(), JsValue> {
|
|||||||
let ws_url = format!("{scheme}://{host}/chat");
|
let ws_url = format!("{scheme}://{host}/chat");
|
||||||
|
|
||||||
let ws = Rc::new(WebSocket::new(&ws_url)?);
|
let ws = Rc::new(WebSocket::new(&ws_url)?);
|
||||||
|
ws.set_binary_type(BinaryType::Arraybuffer);
|
||||||
|
console::log_1(&format!("[ws] connecting to {ws_url}").into());
|
||||||
|
|
||||||
let ws_for_open = ws.clone();
|
let ws_for_open = ws.clone();
|
||||||
let onopen = Closure::<dyn FnMut(Event)>::wrap(Box::new(move |_| {
|
let onopen = Closure::<dyn FnMut(Event)>::wrap(Box::new(move |_| {
|
||||||
let _ = ws_for_open.send_with_str(
|
console::log_1(&"[ws] connected".into());
|
||||||
r#"this is the first message first chat
|
let requests = [
|
||||||
too bad it is in lower case one one one"#);
|
UserRequest {
|
||||||
let _ = ws_for_open.send_with_str(
|
request_id: 1,
|
||||||
"And this is message two of chat request 2. \
|
payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![
|
||||||
Too bad it isn't processed two two two");
|
ChatMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: r#"this is the first message first chat
|
||||||
|
too bad it is in lower case one one one"#
|
||||||
|
.to_string(),
|
||||||
|
},
|
||||||
|
])),
|
||||||
|
},
|
||||||
|
UserRequest {
|
||||||
|
request_id: 2,
|
||||||
|
payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![
|
||||||
|
ChatMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "And this is message two of chat request 2. \
|
||||||
|
Too bad it isn't processed two two two"
|
||||||
|
.to_string(),
|
||||||
|
},
|
||||||
|
])),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
for request in requests {
|
||||||
|
match request.to_bytes() {
|
||||||
|
Ok(bytes) => {
|
||||||
|
console::log_1(&format!("[ws] sending request_id={} bytes={}", request.request_id, bytes.len()).into());
|
||||||
|
let _ = ws_for_open.send_with_u8_array(&bytes);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
console::error_1(&format!("[ws] encode error: {err:#}").into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}));
|
}));
|
||||||
ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
|
ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
|
||||||
onopen.forget();
|
onopen.forget();
|
||||||
|
|
||||||
let onmessage = Closure::<dyn FnMut(MessageEvent)>::wrap(Box::new(move |event: MessageEvent| {
|
let onmessage = Closure::<dyn FnMut(MessageEvent)>::wrap(Box::new(move |event: MessageEvent| {
|
||||||
if let Some(text) = event.data().as_string() {
|
let data = event.data();
|
||||||
console::log_1(&format!("[ws] {text}").into());
|
if let Some(text) = data.as_string() {
|
||||||
} else {
|
console::log_1(&format!("[ws] unexpected text frame: {text}").into());
|
||||||
console::log_1(&"[ws] non-text message".into());
|
return;
|
||||||
|
}
|
||||||
|
if !data.is_instance_of::<ArrayBuffer>() {
|
||||||
|
console::log_1(&"[ws] unexpected non-binary frame".into());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = Uint8Array::new(&data);
|
||||||
|
let bytes = data.to_vec();
|
||||||
|
console::log_1(&format!("[ws] received bytes={}", bytes.len()).into());
|
||||||
|
let response = match UserResponse::from_bytes(&bytes) {
|
||||||
|
Ok(response) => response,
|
||||||
|
Err(err) => {
|
||||||
|
console::error_1(&format!("[ws] decode error: {err:#} (bytes={})", bytes.len()).into());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match response.payload {
|
||||||
|
UserResponsePayload::ChatCompletion(payload) => {
|
||||||
|
console::log_1(&format!("[ws] request_id={} piece={}", response.request_id, payload.piece).into());
|
||||||
|
}
|
||||||
|
UserResponsePayload::ChatCompletionCancellation(_) => {
|
||||||
|
console::log_1(&format!("[ws] request_id={} [cancel]", response.request_id).into());
|
||||||
|
}
|
||||||
|
UserResponsePayload::ChatCompletionEnd(_) => {
|
||||||
|
console::log_1(&format!("[ws] request_id={} [end]", response.request_id).into());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
|
ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
|
||||||
@ -42,6 +117,6 @@ pub fn init_chat() -> Result<(), JsValue> {
|
|||||||
ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
|
ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
|
||||||
onerror.forget();
|
onerror.forget();
|
||||||
|
|
||||||
let _ = ws.clone();
|
WS_HANDLE.with(|slot| *slot.borrow_mut() = Some(ws.clone()));
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
12
frontend_protocol/Cargo.toml
Normal file
12
frontend_protocol/Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
[package]
|
||||||
|
name = "frontend_protocol"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["std"]
|
||||||
|
std = ["deku/std"]
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
deku = { version = "0.20", default-features = false, features = ["alloc", "bits"] }
|
||||||
|
anyhow = "1.0"
|
||||||
110
frontend_protocol/src/lib.rs
Normal file
110
frontend_protocol/src/lib.rs
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
pub mod utils;
|
||||||
|
|
||||||
|
use deku::prelude::*;
|
||||||
|
pub use self::utils::{
|
||||||
|
DekuBytes,
|
||||||
|
read_pascal_string,
|
||||||
|
read_vec_u32,
|
||||||
|
write_pascal_string,
|
||||||
|
write_vec_u32,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
|
pub struct ChatMessage {
|
||||||
|
#[deku(
|
||||||
|
reader = "read_pascal_string(deku::reader)",
|
||||||
|
writer = "write_pascal_string(deku::writer, &self.role)"
|
||||||
|
)]
|
||||||
|
pub role: String,
|
||||||
|
|
||||||
|
#[deku(
|
||||||
|
reader = "read_pascal_string(deku::reader)",
|
||||||
|
writer = "write_pascal_string(deku::writer, &self.content)"
|
||||||
|
)]
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[deku::deku_derive(DekuRead, DekuWrite)]
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct UserChatCompletionRequest {
|
||||||
|
#[deku(
|
||||||
|
reader = "read_vec_u32(deku::reader)",
|
||||||
|
writer = "write_vec_u32(deku::writer, &self.messages)"
|
||||||
|
)]
|
||||||
|
pub messages: Vec<ChatMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl UserChatCompletionRequest {
|
||||||
|
pub fn new(messages: Vec<ChatMessage>) -> Self {
|
||||||
|
Self { messages }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
|
pub struct UserChatCompletionCancellationRequest;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
|
#[deku(id_type = "u8")]
|
||||||
|
#[repr(u8)]
|
||||||
|
pub enum UserRequestPayload {
|
||||||
|
#[deku(id = "0")]
|
||||||
|
ChatCompletion(UserChatCompletionRequest),
|
||||||
|
#[deku(id = "1")]
|
||||||
|
ChatCompletionCancellation(UserChatCompletionCancellationRequest),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
|
pub struct UserRequest {
|
||||||
|
#[deku(endian = "little")]
|
||||||
|
pub request_id: u64,
|
||||||
|
pub payload: UserRequestPayload,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
|
pub struct UserResponseChatCompletion {
|
||||||
|
#[deku(
|
||||||
|
reader = "read_pascal_string(deku::reader)",
|
||||||
|
writer = "write_pascal_string(deku::writer, &self.piece)"
|
||||||
|
)]
|
||||||
|
pub piece: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
|
pub struct UserResponseChatCompletionCancellation;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
|
pub struct UserResponseChatCompletionEnd;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
|
#[deku(id_type = "u8")]
|
||||||
|
#[repr(u8)]
|
||||||
|
pub enum UserResponsePayload {
|
||||||
|
#[deku(id = "0")]
|
||||||
|
ChatCompletion(UserResponseChatCompletion),
|
||||||
|
#[deku(id = "1")]
|
||||||
|
ChatCompletionCancellation(UserResponseChatCompletionCancellation),
|
||||||
|
#[deku(id = "2")]
|
||||||
|
ChatCompletionEnd(UserResponseChatCompletionEnd),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
|
pub struct UserResponse {
|
||||||
|
#[deku(endian = "little")]
|
||||||
|
pub request_id: u64,
|
||||||
|
pub payload: UserResponsePayload,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DekuBytes for ChatMessage {}
|
||||||
|
|
||||||
|
impl DekuBytes for UserChatCompletionRequest {}
|
||||||
|
|
||||||
|
impl DekuBytes for UserChatCompletionCancellationRequest {}
|
||||||
|
|
||||||
|
impl DekuBytes for UserRequestPayload {}
|
||||||
|
impl DekuBytes for UserRequest {}
|
||||||
|
|
||||||
|
impl DekuBytes for UserResponseChatCompletion {}
|
||||||
|
impl DekuBytes for UserResponseChatCompletionCancellation {}
|
||||||
|
impl DekuBytes for UserResponseChatCompletionEnd {}
|
||||||
|
impl DekuBytes for UserResponsePayload {}
|
||||||
|
impl DekuBytes for UserResponse {}
|
||||||
84
frontend_protocol/src/utils.rs
Normal file
84
frontend_protocol/src/utils.rs
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
use anyhow::{Context, Result, bail};
|
||||||
|
use deku::ctx::{Endian, Order};
|
||||||
|
use deku::{DekuError, DekuWriter};
|
||||||
|
use deku::prelude::{Reader, Writer};
|
||||||
|
use deku::prelude::*;
|
||||||
|
|
||||||
|
pub trait DekuBytes: Sized + Clone + DekuContainerWrite + for<'a> DekuContainerRead<'a> {
|
||||||
|
fn to_bytes(&self) -> Result<Vec<u8>> {
|
||||||
|
let type_name = std::any::type_name::<Self>();
|
||||||
|
DekuContainerWrite::to_bytes(self)
|
||||||
|
.context(format!("failed to encode {type_name}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn from_bytes(bytes: &[u8]) -> Result<Self> {
|
||||||
|
let type_name = std::any::type_name::<Self>();
|
||||||
|
let ((rest, bit_offset), value) =
|
||||||
|
<Self as DekuContainerRead>::from_bytes((bytes, 0))
|
||||||
|
.context(format!("failed to decode {type_name}"))?;
|
||||||
|
|
||||||
|
if !rest.is_empty() || bit_offset != 0 {
|
||||||
|
bail!("trailing bytes after {type_name}");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_pascal_string<R>(reader: &mut Reader<R>) -> core::result::Result<String, DekuError>
|
||||||
|
where
|
||||||
|
R: deku::no_std_io::Read + deku::no_std_io::Seek,
|
||||||
|
{
|
||||||
|
let byte_len = u32::from_reader_with_ctx(reader, Endian::Little)? as usize;
|
||||||
|
let mut bytes = vec![0u8; byte_len];
|
||||||
|
reader.read_bytes(byte_len, &mut bytes, Order::Msb0)?;
|
||||||
|
|
||||||
|
String::from_utf8(bytes).map_err(|_| DekuError::Parse("invalid utf-8".into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write_pascal_string<W>(
|
||||||
|
writer: &mut Writer<W>,
|
||||||
|
value: &String,
|
||||||
|
) -> core::result::Result<(), DekuError>
|
||||||
|
where
|
||||||
|
W: deku::no_std_io::Write + deku::no_std_io::Seek,
|
||||||
|
{
|
||||||
|
let bytes = value.as_bytes();
|
||||||
|
let byte_len =
|
||||||
|
u32::try_from(bytes.len()).map_err(|_| DekuError::Parse("string too large".into()))?;
|
||||||
|
|
||||||
|
byte_len.to_writer(writer, Endian::Little)?;
|
||||||
|
writer.write_bytes(bytes)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_vec_u32<R, T>(reader: &mut Reader<R>) -> core::result::Result<Vec<T>, DekuError>
|
||||||
|
where
|
||||||
|
R: deku::no_std_io::Read + deku::no_std_io::Seek,
|
||||||
|
for<'a> T: DekuReader<'a, ()>,
|
||||||
|
{
|
||||||
|
let len = u32::from_reader_with_ctx(reader, Endian::Little)? as usize;
|
||||||
|
let mut items = Vec::with_capacity(len);
|
||||||
|
for _ in 0..len {
|
||||||
|
let item = T::from_reader_with_ctx(reader, ())?;
|
||||||
|
items.push(item);
|
||||||
|
}
|
||||||
|
Ok(items)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn write_vec_u32<W, T>(
|
||||||
|
writer: &mut Writer<W>,
|
||||||
|
value: &Vec<T>,
|
||||||
|
) -> core::result::Result<(), DekuError>
|
||||||
|
where
|
||||||
|
W: deku::no_std_io::Write + deku::no_std_io::Seek,
|
||||||
|
T: DekuWriter<()>,
|
||||||
|
{
|
||||||
|
let len = u32::try_from(value.len())
|
||||||
|
.map_err(|_| DekuError::Parse("vector too large".into()))?;
|
||||||
|
len.to_writer(writer, Endian::Little)?;
|
||||||
|
for item in value {
|
||||||
|
item.to_writer(writer, ())?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
1
secret-config.toml
Normal file
1
secret-config.toml
Normal file
@ -0,0 +1 @@
|
|||||||
|
dedicated_ai_server_secret="change-me"
|
||||||
@ -6,7 +6,7 @@ edition = "2021"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
rand = "0.8"
|
rand = "0.8"
|
||||||
axum = { version = "0.7", features = ["multipart", "ws"] }
|
axum = { version = "0.7", features = ["multipart", "ws"] }
|
||||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "net"] }
|
tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "net", "sync", "time"] }
|
||||||
tower-http = { version = "0.5", features = ["fs"] }
|
tower-http = { version = "0.5", features = ["fs"] }
|
||||||
tera = "1.19.1"
|
tera = "1.19.1"
|
||||||
reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false }
|
reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false }
|
||||||
@ -19,3 +19,4 @@ anyhow = "1.0"
|
|||||||
blake2b_simd = "1.0"
|
blake2b_simd = "1.0"
|
||||||
dryoc = "0.7"
|
dryoc = "0.7"
|
||||||
deku = "0.20"
|
deku = "0.20"
|
||||||
|
frontend_protocol = { path = "../frontend_protocol" }
|
||||||
|
|||||||
@ -2,6 +2,22 @@ use serde::Deserialize;
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
const DEFAULT_CONFIG_PATH: &str = "config.toml";
|
const DEFAULT_CONFIG_PATH: &str = "config.toml";
|
||||||
|
const DEFAULT_SECRET_CONFIG_PATH: &str = "secret-config.toml";
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
struct PublicConfig {
|
||||||
|
pub postgres_host: String,
|
||||||
|
pub pg_database: String,
|
||||||
|
pub postgres_user: String,
|
||||||
|
pub file_storage: String,
|
||||||
|
pub dedicated_ai_server_address: String,
|
||||||
|
pub dedicated_ai_server_port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
struct SecretConfig {
|
||||||
|
pub dedicated_ai_server_secret: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
pub struct GeneralServiceConfig {
|
pub struct GeneralServiceConfig {
|
||||||
@ -9,13 +25,28 @@ pub struct GeneralServiceConfig {
|
|||||||
pub pg_database: String,
|
pub pg_database: String,
|
||||||
pub postgres_user: String,
|
pub postgres_user: String,
|
||||||
pub file_storage: String,
|
pub file_storage: String,
|
||||||
|
pub dedicated_ai_server_address: String,
|
||||||
|
pub dedicated_ai_server_port: u16,
|
||||||
|
pub dedicated_ai_server_secret: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type ConfigResult<T> = Result<T, Box<dyn std::error::Error>>;
|
pub type ConfigResult<T> = Result<T, Box<dyn std::error::Error>>;
|
||||||
|
|
||||||
pub fn load_config(path: impl AsRef<Path>) -> ConfigResult<GeneralServiceConfig> {
|
pub fn load_config(config_path: impl AsRef<Path>, secret_config_path: impl AsRef<Path>) -> ConfigResult<GeneralServiceConfig> {
|
||||||
let raw = std::fs::read_to_string(path.as_ref())?;
|
let raw_config = std::fs::read_to_string(config_path.as_ref())?;
|
||||||
let config: GeneralServiceConfig = toml::from_str(&raw)?;
|
let public_config: PublicConfig = toml::from_str(&raw_config)?;
|
||||||
|
let raw_secret = std::fs::read_to_string(secret_config_path.as_ref())?;
|
||||||
|
let secret_config: SecretConfig = toml::from_str(&raw_secret)?;
|
||||||
|
|
||||||
|
let config = GeneralServiceConfig {
|
||||||
|
postgres_host: public_config.postgres_host,
|
||||||
|
pg_database: public_config.pg_database,
|
||||||
|
postgres_user: public_config.postgres_user,
|
||||||
|
file_storage: public_config.file_storage,
|
||||||
|
dedicated_ai_server_address: public_config.dedicated_ai_server_address,
|
||||||
|
dedicated_ai_server_port: public_config.dedicated_ai_server_port,
|
||||||
|
dedicated_ai_server_secret: secret_config.dedicated_ai_server_secret,
|
||||||
|
};
|
||||||
|
|
||||||
if config.pg_database.trim().is_empty() {
|
if config.pg_database.trim().is_empty() {
|
||||||
return Err(std::io::Error::new(
|
return Err(std::io::Error::new(
|
||||||
@ -41,9 +72,17 @@ pub fn load_config(path: impl AsRef<Path>) -> ConfigResult<GeneralServiceConfig>
|
|||||||
.into());
|
.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if config.dedicated_ai_server_secret.trim().is_empty() {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
"config dedicated_ai_server_secret is empty",
|
||||||
|
)
|
||||||
|
.into());
|
||||||
|
}
|
||||||
|
|
||||||
Ok(config)
|
Ok(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn load_config_default() -> ConfigResult<GeneralServiceConfig> {
|
pub fn load_config_default() -> ConfigResult<GeneralServiceConfig> {
|
||||||
load_config(DEFAULT_CONFIG_PATH)
|
load_config(DEFAULT_CONFIG_PATH, DEFAULT_SECRET_CONFIG_PATH)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,11 +3,11 @@ use tokio::net::{TcpListener, TcpStream};
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use crate::dedicated_ai_server::api::{
|
use crate::dedicated_ai_server::api::{
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
DekuBytes,
|
|
||||||
MessageInChat,
|
MessageInChat,
|
||||||
Request,
|
Request,
|
||||||
RequestPayload,
|
RequestPayload,
|
||||||
};
|
};
|
||||||
|
use frontend_protocol::DekuBytes;
|
||||||
use crate::dedicated_ai_server::talking::{SecretStreamSocket, wrap_connection_socket};
|
use crate::dedicated_ai_server::talking::{SecretStreamSocket, wrap_connection_socket};
|
||||||
use crate::dedicated_ai_server::talking::{ProtocolError, FrameCallback};
|
use crate::dedicated_ai_server::talking::{ProtocolError, FrameCallback};
|
||||||
|
|
||||||
|
|||||||
@ -1,39 +1,12 @@
|
|||||||
use anyhow::{Context, Result, bail};
|
|
||||||
use deku::prelude::*;
|
use deku::prelude::*;
|
||||||
use super::marshalling_utils::{
|
use frontend_protocol::{
|
||||||
read_bool_u8,
|
DekuBytes,
|
||||||
read_pascal_string,
|
read_pascal_string,
|
||||||
write_bool_u8,
|
read_vec_u32,
|
||||||
write_pascal_string,
|
write_pascal_string,
|
||||||
|
write_vec_u32,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub trait DekuBytes: Sized + Clone + DekuContainerWrite + for<'a> DekuContainerRead<'a> {
|
|
||||||
fn pre_encode(&mut self) -> Result<()> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn to_bytes(&self) -> Result<Vec<u8>> {
|
|
||||||
let mut value = self.clone();
|
|
||||||
value.pre_encode()?;
|
|
||||||
let type_name = std::any::type_name::<Self>();
|
|
||||||
DekuContainerWrite::to_bytes(&value)
|
|
||||||
.context(format!("failed to encode {type_name}"))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn from_bytes(bytes: &[u8]) -> Result<Self> {
|
|
||||||
let type_name = std::any::type_name::<Self>();
|
|
||||||
let ((rest, bit_offset), value) =
|
|
||||||
<Self as DekuContainerRead>::from_bytes((bytes, 0))
|
|
||||||
.context(format!("failed to decode {type_name}"))?;
|
|
||||||
|
|
||||||
if !rest.is_empty() || bit_offset != 0 {
|
|
||||||
bail!("trailing bytes after {type_name}");
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
pub struct MessageInChat {
|
pub struct MessageInChat {
|
||||||
#[deku(
|
#[deku(
|
||||||
@ -49,24 +22,20 @@ pub struct MessageInChat {
|
|||||||
pub content: String,
|
pub content: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Deku is a joke. A pathetic excuse for a marshalling library. But it's okay */
|
|
||||||
#[deku::deku_derive(DekuRead, DekuWrite)]
|
#[deku::deku_derive(DekuRead, DekuWrite)]
|
||||||
#[derive(Debug, Clone, PartialEq)]
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
pub struct ChatCompletionRequest {
|
pub struct ChatCompletionRequest {
|
||||||
#[deku(endian = "little", update = "self.messages.len() as u32")]
|
#[deku(
|
||||||
message_count: u32,
|
reader = "read_vec_u32(deku::reader)",
|
||||||
|
writer = "write_vec_u32(deku::writer, &self.messages)"
|
||||||
#[deku(count = "*message_count as usize")]
|
)]
|
||||||
pub messages: Vec<MessageInChat>,
|
pub messages: Vec<MessageInChat>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl ChatCompletionRequest {
|
impl ChatCompletionRequest {
|
||||||
pub fn new(messages: Vec<MessageInChat>) -> Self {
|
pub fn new(messages: Vec<MessageInChat>) -> Self {
|
||||||
Self {
|
Self { messages }
|
||||||
message_count: messages.len() as u32,
|
|
||||||
messages,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -127,37 +96,12 @@ pub struct Response {
|
|||||||
|
|
||||||
impl DekuBytes for MessageInChat {}
|
impl DekuBytes for MessageInChat {}
|
||||||
|
|
||||||
impl DekuBytes for ChatCompletionRequest {
|
impl DekuBytes for ChatCompletionRequest {}
|
||||||
fn pre_encode(&mut self) -> Result<()> {
|
|
||||||
self.update()
|
|
||||||
.context("failed to update ChatCompletionRequest before encoding")?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DekuBytes for ChatCompletionCancellationRequest {}
|
impl DekuBytes for ChatCompletionCancellationRequest {}
|
||||||
|
|
||||||
impl DekuBytes for RequestPayload {
|
impl DekuBytes for RequestPayload {}
|
||||||
fn pre_encode(&mut self) -> Result<()> {
|
impl DekuBytes for Request {}
|
||||||
if let RequestPayload::ChatCompletion(batch) = self {
|
|
||||||
batch
|
|
||||||
.update()
|
|
||||||
.context("failed to update ChatCompletionRequest before encoding")?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DekuBytes for Request {
|
|
||||||
fn pre_encode(&mut self) -> Result<()> {
|
|
||||||
if let RequestPayload::ChatCompletion(batch) = &mut self.payload {
|
|
||||||
batch
|
|
||||||
.update()
|
|
||||||
.context("failed to update ChatCompletionRequest before encoding")?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DekuBytes for ResponseChatCompletion {}
|
impl DekuBytes for ResponseChatCompletion {}
|
||||||
impl DekuBytes for ResponseChatCompletionCancellation {}
|
impl DekuBytes for ResponseChatCompletionCancellation {}
|
||||||
|
|||||||
205
website/src/dedicated_ai_server/connection.rs
Normal file
205
website/src/dedicated_ai_server/connection.rs
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
|
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use tokio::net::TcpStream;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
use tokio::time::{sleep, Duration};
|
||||||
|
|
||||||
|
use frontend_protocol::DekuBytes;
|
||||||
|
use super::api::{
|
||||||
|
ChatCompletionRequest,
|
||||||
|
MessageInChat,
|
||||||
|
Request,
|
||||||
|
RequestPayload,
|
||||||
|
Response,
|
||||||
|
ResponsePayload,
|
||||||
|
};
|
||||||
|
use super::talking::{SecretStreamSocket, wrap_connection_socket};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct MessagePiecePayload(pub String);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum MessagePiece {
|
||||||
|
Piece(MessagePiecePayload), End, Cancelled, DedicatedServerDisconnected
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct PendingChatCompletionRecord {
|
||||||
|
response_tx: mpsc::UnboundedSender<MessagePiece>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct RequestWorkItem {
|
||||||
|
request: Request,
|
||||||
|
response_tx: mpsc::UnboundedSender<MessagePiece>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct DedicatedAiServerConnection {
|
||||||
|
address: String,
|
||||||
|
port: u16,
|
||||||
|
secret: String,
|
||||||
|
request_tx: mpsc::UnboundedSender<RequestWorkItem>,
|
||||||
|
next_request_id: AtomicU64,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn connect_to_dedicated_ai_server(address: String, port: u16, secret: String)
|
||||||
|
-> (DedicatedAiServerConnection, JoinHandle<()>) {
|
||||||
|
let (request_tx, request_rx) = mpsc::unbounded_channel();
|
||||||
|
let worker_address = address.clone();
|
||||||
|
let worker_secret = secret.clone();
|
||||||
|
|
||||||
|
let join_hand = tokio::spawn(async move {
|
||||||
|
connection_worker_loop(worker_address, port, worker_secret, request_rx).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
let conn_bridge = DedicatedAiServerConnection {
|
||||||
|
address,
|
||||||
|
port,
|
||||||
|
secret,
|
||||||
|
request_tx,
|
||||||
|
next_request_id: AtomicU64::new(1),
|
||||||
|
};
|
||||||
|
(conn_bridge, join_hand)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DedicatedAiServerConnection {
|
||||||
|
pub fn send_chat_completion(
|
||||||
|
&self,
|
||||||
|
messages: Vec<MessageInChat>,
|
||||||
|
) -> Result<mpsc::UnboundedReceiver<MessagePiece>> {
|
||||||
|
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||||
|
let request = Request {
|
||||||
|
request_id,
|
||||||
|
payload: RequestPayload::ChatCompletion(ChatCompletionRequest::new(messages)),
|
||||||
|
};
|
||||||
|
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||||
|
self.request_tx
|
||||||
|
.send(RequestWorkItem { request, response_tx })
|
||||||
|
.map_err(|err| anyhow::anyhow!("failed to enqueue request: {err}"))?;
|
||||||
|
Ok(response_rx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connection_worker_loop(
|
||||||
|
address: String,
|
||||||
|
port: u16,
|
||||||
|
secret: String,
|
||||||
|
mut request_rx: mpsc::UnboundedReceiver<RequestWorkItem>,
|
||||||
|
) {
|
||||||
|
// todo : fix a lot of errors
|
||||||
|
loop {
|
||||||
|
let socket = match TcpStream::connect((address.as_str(), port)).await {
|
||||||
|
Ok(socket) => socket,
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!(
|
||||||
|
"[dedicated-ai] failed to connect to {}:{}: {err}",
|
||||||
|
address, port
|
||||||
|
);
|
||||||
|
sleep(Duration::from_secs(1)).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut transport = match wrap_connection_socket(socket, &secret).await {
|
||||||
|
Ok(transport) => transport,
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("[dedicated-ai] failed to wrap connection: {err:#}");
|
||||||
|
sleep(Duration::from_secs(1)).await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut pending: HashMap<u64, PendingChatCompletionRecord> = HashMap::new();
|
||||||
|
let result = run_connected_loop(&mut transport, &mut request_rx, &mut pending).await;
|
||||||
|
if let Err(err) = result {
|
||||||
|
eprintln!("[dedicated-ai] connection error: {err:#}");
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel_all_pending(&mut pending);
|
||||||
|
let _ = transport.close(true).await;
|
||||||
|
sleep(Duration::from_secs(1)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn run_connected_loop(
|
||||||
|
transport: &mut SecretStreamSocket,
|
||||||
|
request_rx: &mut mpsc::UnboundedReceiver<RequestWorkItem>,
|
||||||
|
pending: &mut HashMap<u64, PendingChatCompletionRecord>,
|
||||||
|
) -> Result<()> {
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
frame = transport.recv_frame() => {
|
||||||
|
let frame = frame.context("failed to read frame")?;
|
||||||
|
let frame = match frame {
|
||||||
|
Some(bytes) => bytes,
|
||||||
|
None => return Err(anyhow::anyhow!("connection closed by peer")),
|
||||||
|
};
|
||||||
|
handle_response_frame(frame, pending)?;
|
||||||
|
}
|
||||||
|
maybe_item = request_rx.recv() => {
|
||||||
|
let item = match maybe_item {
|
||||||
|
Some(item) => item,
|
||||||
|
None => return Err(anyhow::anyhow!("request channel closed")), // idk why this would happen
|
||||||
|
};
|
||||||
|
let request_id = item.request.request_id;
|
||||||
|
pending.insert(
|
||||||
|
request_id,
|
||||||
|
PendingChatCompletionRecord { response_tx: item.response_tx },
|
||||||
|
);
|
||||||
|
|
||||||
|
let payload = item.request.to_bytes().context("failed to encode request")?;
|
||||||
|
transport.send_frame(&payload).await.context("failed to send request")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_response_frame(
|
||||||
|
frame: Vec<u8>,
|
||||||
|
pending: &mut HashMap<u64, PendingChatCompletionRecord>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let response = Response::from_bytes(&frame).context("failed to decode response")?;
|
||||||
|
let request_id = response.request_id;
|
||||||
|
let record = pending.get(&request_id);
|
||||||
|
if record.is_none() {
|
||||||
|
eprintln!("[dedicated-ai] response for unknown request_id={request_id}");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
match response.payload {
|
||||||
|
ResponsePayload::ChatCompletion(payload) => {
|
||||||
|
let send_failed = match pending.get(&request_id) {
|
||||||
|
Some(record) => record
|
||||||
|
.response_tx
|
||||||
|
.send(MessagePiece::Piece( MessagePiecePayload(payload.piece) ))
|
||||||
|
.is_err(),
|
||||||
|
None => false,
|
||||||
|
};
|
||||||
|
if send_failed {
|
||||||
|
pending.remove(&request_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ResponsePayload::ChatCompletionEnd(_) => {
|
||||||
|
if let Some(record) = pending.remove(&request_id) {
|
||||||
|
let _ = record.response_tx.send(MessagePiece::End);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ResponsePayload::ChatCompletionCancellation(_) => {
|
||||||
|
if let Some(record) = pending.remove(&request_id) {
|
||||||
|
let _ = record.response_tx.send(MessagePiece::Cancelled);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cancel_all_pending(pending: &mut HashMap<u64, PendingChatCompletionRecord>) {
|
||||||
|
for (_request_id, record) in pending.drain() {
|
||||||
|
let _ = record.response_tx.send(MessagePiece::DedicatedServerDisconnected);
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,56 +0,0 @@
|
|||||||
use deku::ctx::{Endian, Order};
|
|
||||||
use deku::{DekuError, DekuWriter};
|
|
||||||
use deku::prelude::{Reader, Writer};
|
|
||||||
use deku::prelude::*;
|
|
||||||
|
|
||||||
pub fn read_pascal_string<R>(reader: &mut Reader<R>) -> core::result::Result<String, DekuError>
|
|
||||||
where
|
|
||||||
R: deku::no_std_io::Read + deku::no_std_io::Seek,
|
|
||||||
{
|
|
||||||
let byte_len = u32::from_reader_with_ctx(reader, Endian::Little)? as usize;
|
|
||||||
let mut bytes = vec![0u8; byte_len];
|
|
||||||
reader.read_bytes(byte_len, &mut bytes, Order::Msb0)?;
|
|
||||||
|
|
||||||
String::from_utf8(bytes).map_err(|err| DekuError::Parse(err.to_string().into()))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
pub fn write_pascal_string<W>(
|
|
||||||
writer: &mut Writer<W>,
|
|
||||||
value: &String,
|
|
||||||
) -> core::result::Result<(), DekuError>
|
|
||||||
where
|
|
||||||
W: deku::no_std_io::Write + deku::no_std_io::Seek,
|
|
||||||
{
|
|
||||||
let bytes = value.as_bytes();
|
|
||||||
let byte_len =
|
|
||||||
u32::try_from(bytes.len()).map_err(|_| DekuError::Parse("string too large".into()))?;
|
|
||||||
|
|
||||||
byte_len.to_writer(writer, Endian::Little)?;
|
|
||||||
writer.write_bytes(bytes)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn read_bool_u8<R>(reader: &mut Reader<R>) -> core::result::Result<bool, DekuError>
|
|
||||||
where
|
|
||||||
R: deku::no_std_io::Read + deku::no_std_io::Seek,
|
|
||||||
{
|
|
||||||
let value = u8::from_reader_with_ctx(reader, Endian::Little)?;
|
|
||||||
match value {
|
|
||||||
0 => Ok(false),
|
|
||||||
1 => Ok(true),
|
|
||||||
_ => Err(DekuError::Parse("invalid bool value".into())),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn write_bool_u8<W>(
|
|
||||||
writer: &mut Writer<W>,
|
|
||||||
value: &bool,
|
|
||||||
) -> core::result::Result<(), DekuError>
|
|
||||||
where
|
|
||||||
W: deku::no_std_io::Write + deku::no_std_io::Seek,
|
|
||||||
{
|
|
||||||
let encoded: u8 = if *value { 1 } else { 0 };
|
|
||||||
encoded.to_writer(writer, Endian::Little)?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
@ -1,4 +1,4 @@
|
|||||||
pub mod api;
|
pub mod api;
|
||||||
|
pub mod connection;
|
||||||
pub mod talking;
|
pub mod talking;
|
||||||
pub mod TEST;
|
pub mod TEST;
|
||||||
mod marshalling_utils;
|
|
||||||
@ -17,24 +17,38 @@ use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
|
|||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::future::Future;
|
use std::future::Future;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::pin::Pin;
|
|
||||||
use std::sync::OnceLock;
|
|
||||||
use tera::{Tera};
|
use tera::{Tera};
|
||||||
use tokio::time::{sleep, Duration};
|
|
||||||
use tower_http::services::{ServeDir, ServeFile};
|
use tower_http::services::{ServeDir, ServeFile};
|
||||||
|
|
||||||
use samplers::random_junk::random_string;
|
use samplers::random_junk::random_string;
|
||||||
use config::{ConfigResult, GeneralServiceConfig, load_config, load_config_default};
|
use config::load_config_default;
|
||||||
use db::{DbResult, connect_db, init_database};
|
use db::{connect_db, init_database};
|
||||||
use web_file_uploads::{get_file, upload_get, upload_post};
|
use web_file_uploads::{get_file, upload_get, upload_post};
|
||||||
use web_app_state::{AppState, AppStateInner, AuthenticatedUserId};
|
use web_app_state::{AppState, AppStateInner, AuthenticatedUserId};
|
||||||
|
use dedicated_ai_server::connection::connect_to_dedicated_ai_server;
|
||||||
|
use dedicated_ai_server::api::MessageInChat;
|
||||||
|
use crate::dedicated_ai_server::connection::{MessagePiece, MessagePiecePayload};
|
||||||
|
use frontend_protocol::{
|
||||||
|
DekuBytes,
|
||||||
|
UserRequest,
|
||||||
|
UserRequestPayload,
|
||||||
|
UserResponse,
|
||||||
|
UserResponseChatCompletion,
|
||||||
|
UserResponseChatCompletionCancellation,
|
||||||
|
UserResponseChatCompletionEnd,
|
||||||
|
UserResponsePayload,
|
||||||
|
};
|
||||||
|
|
||||||
|
async fn init_app_state() -> Result<(AppState, tokio::task::JoinHandle<()>), Box<dyn std::error::Error>> {
|
||||||
async fn init_app_state() -> Result<AppState, Box<dyn std::error::Error>> {
|
|
||||||
let config = load_config_default()?;
|
let config = load_config_default()?;
|
||||||
let db = connect_db(&config).await?;
|
let db = connect_db(&config).await?;
|
||||||
let tera = Tera::new("frontend/pages/**/*.html")?;
|
let tera = Tera::new("frontend/pages/**/*.html")?;
|
||||||
Ok(std::sync::Arc::new(AppStateInner { config, db, tera }))
|
let (dedicated_ai, dedicated_ai_task_handler) = connect_to_dedicated_ai_server(
|
||||||
|
config.dedicated_ai_server_address.clone(),
|
||||||
|
config.dedicated_ai_server_port,
|
||||||
|
config.dedicated_ai_server_secret.clone(),
|
||||||
|
);
|
||||||
|
Ok((std::sync::Arc::new(AppStateInner { config, db, tera, dedicated_ai }), dedicated_ai_task_handler))
|
||||||
}
|
}
|
||||||
|
|
||||||
enum PasscodeAuthenticationResult{
|
enum PasscodeAuthenticationResult{
|
||||||
@ -235,7 +249,9 @@ async fn chat(
|
|||||||
user: AuthenticatedUserId,
|
user: AuthenticatedUserId,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
if let Some(ws) = ws {
|
if let Some(ws) = ws {
|
||||||
return ws.on_upgrade(move |socket| handle_chat_socket(socket)).into_response();
|
return ws
|
||||||
|
.on_upgrade(move |socket| handle_chat_socket(socket, state.clone()))
|
||||||
|
.into_response();
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut ctx = tera::Context::new();
|
let mut ctx = tera::Context::new();
|
||||||
@ -247,21 +263,107 @@ async fn chat(
|
|||||||
Html(body).into_response()
|
Html(body).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_chat_socket(mut socket: WebSocket) {
|
async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
|
||||||
while let Some(msg) = socket.recv().await {
|
'outer: while let Some(msg) = socket.recv().await {
|
||||||
let msg = match msg {
|
let msg = match msg {
|
||||||
Ok(msg) => msg,
|
Ok(msg) => msg,
|
||||||
Err(_) => break,
|
Err(_) => break,
|
||||||
};
|
};
|
||||||
|
|
||||||
match msg {
|
match msg {
|
||||||
Message::Text(text) => {
|
Message::Binary(bytes) => {
|
||||||
sleep(Duration::from_secs(2)).await;
|
let request = match UserRequest::from_bytes(&bytes) {
|
||||||
let upper = text.to_uppercase();
|
Ok(request) => request,
|
||||||
if socket.send(Message::Text(upper)).await.is_err() {
|
Err(err) => {
|
||||||
|
eprintln!("[chat] failed to decode request: {err:#}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match request.payload {
|
||||||
|
UserRequestPayload::ChatCompletion(payload) => {
|
||||||
|
let messages = payload
|
||||||
|
.messages
|
||||||
|
.into_iter()
|
||||||
|
.map(|msg| MessageInChat {
|
||||||
|
role: msg.role,
|
||||||
|
content: msg.content,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut response_rx = match state.dedicated_ai.send_chat_completion(messages) {
|
||||||
|
Ok(rx) => rx,
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("[chat] failed to send request: {err:#}");
|
||||||
|
let response = UserResponse {
|
||||||
|
request_id: request.request_id,
|
||||||
|
payload: UserResponsePayload::ChatCompletionCancellation(
|
||||||
|
UserResponseChatCompletionCancellation,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
if let Ok(bytes) = response.to_bytes() {
|
||||||
|
let _ = socket.send(Message::Binary(bytes)).await;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
while let Some(piece) = response_rx.recv().await {
|
||||||
|
let (payload, should_break) = match piece {
|
||||||
|
MessagePiece::Piece(MessagePiecePayload(text)) => (
|
||||||
|
UserResponsePayload::ChatCompletion(UserResponseChatCompletion {
|
||||||
|
piece: text,
|
||||||
|
}),
|
||||||
|
false,
|
||||||
|
),
|
||||||
|
MessagePiece::End => (
|
||||||
|
UserResponsePayload::ChatCompletionEnd(
|
||||||
|
UserResponseChatCompletionEnd,
|
||||||
|
),
|
||||||
|
true,
|
||||||
|
),
|
||||||
|
MessagePiece::Cancelled | MessagePiece::DedicatedServerDisconnected => (
|
||||||
|
UserResponsePayload::ChatCompletionCancellation(
|
||||||
|
UserResponseChatCompletionCancellation,
|
||||||
|
),
|
||||||
|
true,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = UserResponse {
|
||||||
|
request_id: request.request_id,
|
||||||
|
payload,
|
||||||
|
};
|
||||||
|
let bytes = match response.to_bytes() {
|
||||||
|
Ok(bytes) => bytes,
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("[chat] failed to encode response: {err:#}");
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if socket.send(Message::Binary(bytes)).await.is_err() {
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
if should_break {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
UserRequestPayload::ChatCompletionCancellation(_) => {
|
||||||
|
let response = UserResponse {
|
||||||
|
request_id: request.request_id,
|
||||||
|
payload: UserResponsePayload::ChatCompletionCancellation(
|
||||||
|
UserResponseChatCompletionCancellation,
|
||||||
|
),
|
||||||
|
};
|
||||||
|
if let Ok(bytes) = response.to_bytes() {
|
||||||
|
if socket.send(Message::Binary(bytes)).await.is_err() {
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Message::Close(_) => break,
|
Message::Close(_) => break,
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
@ -270,7 +372,7 @@ async fn handle_chat_socket(mut socket: WebSocket) {
|
|||||||
|
|
||||||
|
|
||||||
pub async fn run_server() -> Result<(), Box<dyn std::error::Error>> {
|
pub async fn run_server() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
let state = init_app_state().await.expect("lol");
|
let (state, mut dedicated_ai_task) = init_app_state().await.expect("lol");
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/login", get(login_get))
|
.route("/login", get(login_get))
|
||||||
.route("/login", post(login_post))
|
.route("/login", post(login_post))
|
||||||
@ -292,7 +394,28 @@ pub async fn run_server() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let listener = tokio::net::TcpListener::bind(addr)
|
let listener = tokio::net::TcpListener::bind(addr)
|
||||||
.await
|
.await
|
||||||
.expect("bind failed");
|
.expect("bind failed");
|
||||||
axum::serve(listener, app).await?;
|
let mut server_task = tokio::spawn(async move { axum::serve(listener, app).await });
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
res = &mut server_task => {
|
||||||
|
dedicated_ai_task.abort();
|
||||||
|
res??;
|
||||||
|
}
|
||||||
|
res = &mut dedicated_ai_task => {
|
||||||
|
server_task.abort();
|
||||||
|
res.map_err(|err| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::Other,
|
||||||
|
format!("dedicated ai task failed: {err}"),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::Other,
|
||||||
|
"dedicated ai task ended unexpectedly",
|
||||||
|
)
|
||||||
|
.into());
|
||||||
|
}
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
use tera::Tera;
|
use tera::Tera;
|
||||||
use crate::GeneralServiceConfig;
|
use crate::config::GeneralServiceConfig;
|
||||||
|
use crate::dedicated_ai_server::connection::DedicatedAiServerConnection;
|
||||||
|
|
||||||
pub struct AppStateInner {
|
pub struct AppStateInner {
|
||||||
pub config: GeneralServiceConfig,
|
pub config: GeneralServiceConfig,
|
||||||
pub db: tokio_postgres::Client,
|
pub db: tokio_postgres::Client,
|
||||||
pub tera: Tera,
|
pub tera: Tera,
|
||||||
|
pub dedicated_ai: DedicatedAiServerConnection,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type AppState = std::sync::Arc<AppStateInner>;
|
pub type AppState = std::sync::Arc<AppStateInner>;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user