From 6a66cde0d09a8af9815978fe1cff404152e21af3 Mon Sep 17 00:00:00 2001 From: Andreew Gregory Date: Sat, 28 Mar 2026 18:00:22 +0300 Subject: [PATCH] We communicate from client to dedicated ai server --- Cargo.lock | 10 + Cargo.toml | 1 + dedicated_ai_server/server.py | 3 + frontend/Cargo.toml | 2 + frontend/src/chat.rs | 99 ++++++++- frontend_protocol/Cargo.toml | 12 + frontend_protocol/src/lib.rs | 110 ++++++++++ frontend_protocol/src/utils.rs | 84 +++++++ secret-config.toml | 1 + website/Cargo.toml | 3 +- website/src/config.rs | 47 +++- website/src/dedicated_ai_server/TEST.rs | 2 +- website/src/dedicated_ai_server/api.rs | 80 +------ website/src/dedicated_ai_server/connection.rs | 205 ++++++++++++++++++ .../dedicated_ai_server/marshalling_utils.rs | 56 ----- website/src/dedicated_ai_server/mod.rs | 2 +- website/src/lib.rs | 159 ++++++++++++-- website/src/web_app_state.rs | 4 +- 18 files changed, 718 insertions(+), 162 deletions(-) create mode 100644 frontend_protocol/Cargo.toml create mode 100644 frontend_protocol/src/lib.rs create mode 100644 frontend_protocol/src/utils.rs create mode 100644 secret-config.toml create mode 100644 website/src/dedicated_ai_server/connection.rs delete mode 100644 website/src/dedicated_ai_server/marshalling_utils.rs diff --git a/Cargo.lock b/Cargo.lock index 0c40bc6..db5c87d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -567,11 +567,20 @@ dependencies = [ name = "frontend" version = "0.1.0" dependencies = [ + "frontend_protocol", "js-sys", "wasm-bindgen", "web-sys", ] +[[package]] +name = "frontend_protocol" +version = "0.1.0" +dependencies = [ + "anyhow", + "deku", +] + [[package]] name = "funty" version = "2.0.0" @@ -2563,6 +2572,7 @@ dependencies = [ "blake2b_simd", "deku", "dryoc", + "frontend_protocol", "rand 0.8.5", "reqwest", "serde", diff --git a/Cargo.toml b/Cargo.toml index c5a5520..bda479d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = [ "frontend", + "frontend_protocol", "website", ] \ No newline at end of file diff --git a/dedicated_ai_server/server.py b/dedicated_ai_server/server.py index 6f65fc8..826897c 100644 --- a/dedicated_ai_server/server.py +++ b/dedicated_ai_server/server.py @@ -32,6 +32,9 @@ class PendingChatCompletionRecord: was_cancelled: bool = False _lock: threading.Lock = field(default_factory=threading.Lock, repr=False) + + + def mark_cancelled(self) -> None: with self._lock: self.was_cancelled = True diff --git a/frontend/Cargo.toml b/frontend/Cargo.toml index 5df32cc..22b8534 100644 --- a/frontend/Cargo.toml +++ b/frontend/Cargo.toml @@ -9,6 +9,7 @@ crate-type = ["cdylib"] [dependencies] wasm-bindgen = "0.2" js-sys = "0.3" +frontend_protocol = { path = "../frontend_protocol" } [dependencies.web-sys] version = "0.3" @@ -22,6 +23,7 @@ features = [ "WheelEvent", "Location", "WebSocket", + "BinaryType", "MessageEvent", "ErrorEvent", "console", diff --git a/frontend/src/chat.rs b/frontend/src/chat.rs index 9954579..1adb2c3 100644 --- a/frontend/src/chat.rs +++ b/frontend/src/chat.rs @@ -1,7 +1,22 @@ +use std::cell::RefCell; use std::rc::Rc; use wasm_bindgen::prelude::*; use wasm_bindgen::JsCast; -use web_sys::{console, window, ErrorEvent, Event, MessageEvent, WebSocket}; +use web_sys::{console, window, BinaryType, ErrorEvent, Event, MessageEvent, WebSocket}; +use js_sys::{ArrayBuffer, Uint8Array}; +use frontend_protocol::{ + ChatMessage, + DekuBytes, + UserChatCompletionRequest, + UserRequest, + UserRequestPayload, + UserResponse, + UserResponsePayload, +}; + +thread_local! { + static WS_HANDLE: RefCell>> = RefCell::new(None); +} #[wasm_bindgen] pub fn init_chat() -> Result<(), JsValue> { @@ -13,24 +28,84 @@ pub fn init_chat() -> Result<(), JsValue> { let ws_url = format!("{scheme}://{host}/chat"); let ws = Rc::new(WebSocket::new(&ws_url)?); + ws.set_binary_type(BinaryType::Arraybuffer); + console::log_1(&format!("[ws] connecting to {ws_url}").into()); let ws_for_open = ws.clone(); let onopen = Closure::::wrap(Box::new(move |_| { - let _ = ws_for_open.send_with_str( - r#"this is the first message first chat - too bad it is in lower case one one one"#); - let _ = ws_for_open.send_with_str( - "And this is message two of chat request 2. \ - Too bad it isn't processed two two two"); + console::log_1(&"[ws] connected".into()); + let requests = [ + UserRequest { + request_id: 1, + payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![ + ChatMessage { + role: "user".to_string(), + content: r#"this is the first message first chat +too bad it is in lower case one one one"# + .to_string(), + }, + ])), + }, + UserRequest { + request_id: 2, + payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![ + ChatMessage { + role: "user".to_string(), + content: "And this is message two of chat request 2. \ +Too bad it isn't processed two two two" + .to_string(), + }, + ])), + }, + ]; + + for request in requests { + match request.to_bytes() { + Ok(bytes) => { + console::log_1(&format!("[ws] sending request_id={} bytes={}", request.request_id, bytes.len()).into()); + let _ = ws_for_open.send_with_u8_array(&bytes); + } + Err(err) => { + console::error_1(&format!("[ws] encode error: {err:#}").into()); + } + } + } })); ws.set_onopen(Some(onopen.as_ref().unchecked_ref())); onopen.forget(); let onmessage = Closure::::wrap(Box::new(move |event: MessageEvent| { - if let Some(text) = event.data().as_string() { - console::log_1(&format!("[ws] {text}").into()); - } else { - console::log_1(&"[ws] non-text message".into()); + let data = event.data(); + if let Some(text) = data.as_string() { + console::log_1(&format!("[ws] unexpected text frame: {text}").into()); + return; + } + if !data.is_instance_of::() { + console::log_1(&"[ws] unexpected non-binary frame".into()); + return; + } + + let data = Uint8Array::new(&data); + let bytes = data.to_vec(); + console::log_1(&format!("[ws] received bytes={}", bytes.len()).into()); + let response = match UserResponse::from_bytes(&bytes) { + Ok(response) => response, + Err(err) => { + console::error_1(&format!("[ws] decode error: {err:#} (bytes={})", bytes.len()).into()); + return; + } + }; + + match response.payload { + UserResponsePayload::ChatCompletion(payload) => { + console::log_1(&format!("[ws] request_id={} piece={}", response.request_id, payload.piece).into()); + } + UserResponsePayload::ChatCompletionCancellation(_) => { + console::log_1(&format!("[ws] request_id={} [cancel]", response.request_id).into()); + } + UserResponsePayload::ChatCompletionEnd(_) => { + console::log_1(&format!("[ws] request_id={} [end]", response.request_id).into()); + } } })); ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref())); @@ -42,6 +117,6 @@ pub fn init_chat() -> Result<(), JsValue> { ws.set_onerror(Some(onerror.as_ref().unchecked_ref())); onerror.forget(); - let _ = ws.clone(); + WS_HANDLE.with(|slot| *slot.borrow_mut() = Some(ws.clone())); Ok(()) } diff --git a/frontend_protocol/Cargo.toml b/frontend_protocol/Cargo.toml new file mode 100644 index 0000000..fe17c41 --- /dev/null +++ b/frontend_protocol/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "frontend_protocol" +version = "0.1.0" +edition = "2021" + +[features] +default = ["std"] +std = ["deku/std"] + +[dependencies] +deku = { version = "0.20", default-features = false, features = ["alloc", "bits"] } +anyhow = "1.0" diff --git a/frontend_protocol/src/lib.rs b/frontend_protocol/src/lib.rs new file mode 100644 index 0000000..8e54f6b --- /dev/null +++ b/frontend_protocol/src/lib.rs @@ -0,0 +1,110 @@ +pub mod utils; + +use deku::prelude::*; +pub use self::utils::{ + DekuBytes, + read_pascal_string, + read_vec_u32, + write_pascal_string, + write_vec_u32, +}; + +#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] +pub struct ChatMessage { + #[deku( + reader = "read_pascal_string(deku::reader)", + writer = "write_pascal_string(deku::writer, &self.role)" + )] + pub role: String, + + #[deku( + reader = "read_pascal_string(deku::reader)", + writer = "write_pascal_string(deku::writer, &self.content)" + )] + pub content: String, +} + +#[deku::deku_derive(DekuRead, DekuWrite)] +#[derive(Debug, Clone, PartialEq)] +pub struct UserChatCompletionRequest { + #[deku( + reader = "read_vec_u32(deku::reader)", + writer = "write_vec_u32(deku::writer, &self.messages)" + )] + pub messages: Vec, +} + +impl UserChatCompletionRequest { + pub fn new(messages: Vec) -> Self { + Self { messages } + } +} + +#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] +pub struct UserChatCompletionCancellationRequest; + +#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] +#[deku(id_type = "u8")] +#[repr(u8)] +pub enum UserRequestPayload { + #[deku(id = "0")] + ChatCompletion(UserChatCompletionRequest), + #[deku(id = "1")] + ChatCompletionCancellation(UserChatCompletionCancellationRequest), +} + +#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] +pub struct UserRequest { + #[deku(endian = "little")] + pub request_id: u64, + pub payload: UserRequestPayload, +} + +#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] +pub struct UserResponseChatCompletion { + #[deku( + reader = "read_pascal_string(deku::reader)", + writer = "write_pascal_string(deku::writer, &self.piece)" + )] + pub piece: String, +} + +#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] +pub struct UserResponseChatCompletionCancellation; + +#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] +pub struct UserResponseChatCompletionEnd; + +#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] +#[deku(id_type = "u8")] +#[repr(u8)] +pub enum UserResponsePayload { + #[deku(id = "0")] + ChatCompletion(UserResponseChatCompletion), + #[deku(id = "1")] + ChatCompletionCancellation(UserResponseChatCompletionCancellation), + #[deku(id = "2")] + ChatCompletionEnd(UserResponseChatCompletionEnd), +} + +#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] +pub struct UserResponse { + #[deku(endian = "little")] + pub request_id: u64, + pub payload: UserResponsePayload, +} + +impl DekuBytes for ChatMessage {} + +impl DekuBytes for UserChatCompletionRequest {} + +impl DekuBytes for UserChatCompletionCancellationRequest {} + +impl DekuBytes for UserRequestPayload {} +impl DekuBytes for UserRequest {} + +impl DekuBytes for UserResponseChatCompletion {} +impl DekuBytes for UserResponseChatCompletionCancellation {} +impl DekuBytes for UserResponseChatCompletionEnd {} +impl DekuBytes for UserResponsePayload {} +impl DekuBytes for UserResponse {} diff --git a/frontend_protocol/src/utils.rs b/frontend_protocol/src/utils.rs new file mode 100644 index 0000000..4012184 --- /dev/null +++ b/frontend_protocol/src/utils.rs @@ -0,0 +1,84 @@ +use anyhow::{Context, Result, bail}; +use deku::ctx::{Endian, Order}; +use deku::{DekuError, DekuWriter}; +use deku::prelude::{Reader, Writer}; +use deku::prelude::*; + +pub trait DekuBytes: Sized + Clone + DekuContainerWrite + for<'a> DekuContainerRead<'a> { + fn to_bytes(&self) -> Result> { + let type_name = std::any::type_name::(); + DekuContainerWrite::to_bytes(self) + .context(format!("failed to encode {type_name}")) + } + + fn from_bytes(bytes: &[u8]) -> Result { + let type_name = std::any::type_name::(); + let ((rest, bit_offset), value) = + ::from_bytes((bytes, 0)) + .context(format!("failed to decode {type_name}"))?; + + if !rest.is_empty() || bit_offset != 0 { + bail!("trailing bytes after {type_name}"); + } + + Ok(value) + } +} + +pub fn read_pascal_string(reader: &mut Reader) -> core::result::Result +where + R: deku::no_std_io::Read + deku::no_std_io::Seek, +{ + let byte_len = u32::from_reader_with_ctx(reader, Endian::Little)? as usize; + let mut bytes = vec![0u8; byte_len]; + reader.read_bytes(byte_len, &mut bytes, Order::Msb0)?; + + String::from_utf8(bytes).map_err(|_| DekuError::Parse("invalid utf-8".into())) +} + +pub fn write_pascal_string( + writer: &mut Writer, + value: &String, +) -> core::result::Result<(), DekuError> +where + W: deku::no_std_io::Write + deku::no_std_io::Seek, +{ + let bytes = value.as_bytes(); + let byte_len = + u32::try_from(bytes.len()).map_err(|_| DekuError::Parse("string too large".into()))?; + + byte_len.to_writer(writer, Endian::Little)?; + writer.write_bytes(bytes)?; + Ok(()) +} + +pub fn read_vec_u32(reader: &mut Reader) -> core::result::Result, DekuError> +where + R: deku::no_std_io::Read + deku::no_std_io::Seek, + for<'a> T: DekuReader<'a, ()>, +{ + let len = u32::from_reader_with_ctx(reader, Endian::Little)? as usize; + let mut items = Vec::with_capacity(len); + for _ in 0..len { + let item = T::from_reader_with_ctx(reader, ())?; + items.push(item); + } + Ok(items) +} + +pub fn write_vec_u32( + writer: &mut Writer, + value: &Vec, +) -> core::result::Result<(), DekuError> +where + W: deku::no_std_io::Write + deku::no_std_io::Seek, + T: DekuWriter<()>, +{ + let len = u32::try_from(value.len()) + .map_err(|_| DekuError::Parse("vector too large".into()))?; + len.to_writer(writer, Endian::Little)?; + for item in value { + item.to_writer(writer, ())?; + } + Ok(()) +} diff --git a/secret-config.toml b/secret-config.toml new file mode 100644 index 0000000..bdfaafa --- /dev/null +++ b/secret-config.toml @@ -0,0 +1 @@ +dedicated_ai_server_secret="change-me" \ No newline at end of file diff --git a/website/Cargo.toml b/website/Cargo.toml index 3a21546..b07a8aa 100644 --- a/website/Cargo.toml +++ b/website/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] rand = "0.8" axum = { version = "0.7", features = ["multipart", "ws"] } -tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "net"] } +tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "net", "sync", "time"] } tower-http = { version = "0.5", features = ["fs"] } tera = "1.19.1" reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false } @@ -19,3 +19,4 @@ anyhow = "1.0" blake2b_simd = "1.0" dryoc = "0.7" deku = "0.20" +frontend_protocol = { path = "../frontend_protocol" } diff --git a/website/src/config.rs b/website/src/config.rs index aa02f04..56b1109 100644 --- a/website/src/config.rs +++ b/website/src/config.rs @@ -2,6 +2,22 @@ use serde::Deserialize; use std::path::Path; const DEFAULT_CONFIG_PATH: &str = "config.toml"; +const DEFAULT_SECRET_CONFIG_PATH: &str = "secret-config.toml"; + +#[derive(Debug, Clone, Deserialize)] +struct PublicConfig { + pub postgres_host: String, + pub pg_database: String, + pub postgres_user: String, + pub file_storage: String, + pub dedicated_ai_server_address: String, + pub dedicated_ai_server_port: u16, +} + +#[derive(Debug, Clone, Deserialize)] +struct SecretConfig { + pub dedicated_ai_server_secret: String, +} #[derive(Debug, Clone, Deserialize)] pub struct GeneralServiceConfig { @@ -9,13 +25,28 @@ pub struct GeneralServiceConfig { pub pg_database: String, pub postgres_user: String, pub file_storage: String, + pub dedicated_ai_server_address: String, + pub dedicated_ai_server_port: u16, + pub dedicated_ai_server_secret: String, } pub type ConfigResult = Result>; -pub fn load_config(path: impl AsRef) -> ConfigResult { - let raw = std::fs::read_to_string(path.as_ref())?; - let config: GeneralServiceConfig = toml::from_str(&raw)?; +pub fn load_config(config_path: impl AsRef, secret_config_path: impl AsRef) -> ConfigResult { + let raw_config = std::fs::read_to_string(config_path.as_ref())?; + let public_config: PublicConfig = toml::from_str(&raw_config)?; + let raw_secret = std::fs::read_to_string(secret_config_path.as_ref())?; + let secret_config: SecretConfig = toml::from_str(&raw_secret)?; + + let config = GeneralServiceConfig { + postgres_host: public_config.postgres_host, + pg_database: public_config.pg_database, + postgres_user: public_config.postgres_user, + file_storage: public_config.file_storage, + dedicated_ai_server_address: public_config.dedicated_ai_server_address, + dedicated_ai_server_port: public_config.dedicated_ai_server_port, + dedicated_ai_server_secret: secret_config.dedicated_ai_server_secret, + }; if config.pg_database.trim().is_empty() { return Err(std::io::Error::new( @@ -41,9 +72,17 @@ pub fn load_config(path: impl AsRef) -> ConfigResult .into()); } + if config.dedicated_ai_server_secret.trim().is_empty() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "config dedicated_ai_server_secret is empty", + ) + .into()); + } + Ok(config) } pub fn load_config_default() -> ConfigResult { - load_config(DEFAULT_CONFIG_PATH) + load_config(DEFAULT_CONFIG_PATH, DEFAULT_SECRET_CONFIG_PATH) } diff --git a/website/src/dedicated_ai_server/TEST.rs b/website/src/dedicated_ai_server/TEST.rs index 5ed60e2..71ede44 100644 --- a/website/src/dedicated_ai_server/TEST.rs +++ b/website/src/dedicated_ai_server/TEST.rs @@ -3,11 +3,11 @@ use tokio::net::{TcpListener, TcpStream}; use anyhow::Result; use crate::dedicated_ai_server::api::{ ChatCompletionRequest, - DekuBytes, MessageInChat, Request, RequestPayload, }; +use frontend_protocol::DekuBytes; use crate::dedicated_ai_server::talking::{SecretStreamSocket, wrap_connection_socket}; use crate::dedicated_ai_server::talking::{ProtocolError, FrameCallback}; diff --git a/website/src/dedicated_ai_server/api.rs b/website/src/dedicated_ai_server/api.rs index b6e7972..f8fba38 100644 --- a/website/src/dedicated_ai_server/api.rs +++ b/website/src/dedicated_ai_server/api.rs @@ -1,39 +1,12 @@ -use anyhow::{Context, Result, bail}; use deku::prelude::*; -use super::marshalling_utils::{ - read_bool_u8, +use frontend_protocol::{ + DekuBytes, read_pascal_string, - write_bool_u8, + read_vec_u32, write_pascal_string, + write_vec_u32, }; -pub trait DekuBytes: Sized + Clone + DekuContainerWrite + for<'a> DekuContainerRead<'a> { - fn pre_encode(&mut self) -> Result<()> { - Ok(()) - } - - fn to_bytes(&self) -> Result> { - let mut value = self.clone(); - value.pre_encode()?; - let type_name = std::any::type_name::(); - DekuContainerWrite::to_bytes(&value) - .context(format!("failed to encode {type_name}")) - } - - fn from_bytes(bytes: &[u8]) -> Result { - let type_name = std::any::type_name::(); - let ((rest, bit_offset), value) = - ::from_bytes((bytes, 0)) - .context(format!("failed to decode {type_name}"))?; - - if !rest.is_empty() || bit_offset != 0 { - bail!("trailing bytes after {type_name}"); - } - - Ok(value) - } -} - #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] pub struct MessageInChat { #[deku( @@ -49,24 +22,20 @@ pub struct MessageInChat { pub content: String, } -/* Deku is a joke. A pathetic excuse for a marshalling library. But it's okay */ #[deku::deku_derive(DekuRead, DekuWrite)] #[derive(Debug, Clone, PartialEq)] pub struct ChatCompletionRequest { - #[deku(endian = "little", update = "self.messages.len() as u32")] - message_count: u32, - - #[deku(count = "*message_count as usize")] + #[deku( + reader = "read_vec_u32(deku::reader)", + writer = "write_vec_u32(deku::writer, &self.messages)" + )] pub messages: Vec, } impl ChatCompletionRequest { pub fn new(messages: Vec) -> Self { - Self { - message_count: messages.len() as u32, - messages, - } + Self { messages } } } @@ -127,37 +96,12 @@ pub struct Response { impl DekuBytes for MessageInChat {} -impl DekuBytes for ChatCompletionRequest { - fn pre_encode(&mut self) -> Result<()> { - self.update() - .context("failed to update ChatCompletionRequest before encoding")?; - Ok(()) - } -} +impl DekuBytes for ChatCompletionRequest {} impl DekuBytes for ChatCompletionCancellationRequest {} -impl DekuBytes for RequestPayload { - fn pre_encode(&mut self) -> Result<()> { - if let RequestPayload::ChatCompletion(batch) = self { - batch - .update() - .context("failed to update ChatCompletionRequest before encoding")?; - } - Ok(()) - } -} - -impl DekuBytes for Request { - fn pre_encode(&mut self) -> Result<()> { - if let RequestPayload::ChatCompletion(batch) = &mut self.payload { - batch - .update() - .context("failed to update ChatCompletionRequest before encoding")?; - } - Ok(()) - } -} +impl DekuBytes for RequestPayload {} +impl DekuBytes for Request {} impl DekuBytes for ResponseChatCompletion {} impl DekuBytes for ResponseChatCompletionCancellation {} diff --git a/website/src/dedicated_ai_server/connection.rs b/website/src/dedicated_ai_server/connection.rs new file mode 100644 index 0000000..46414de --- /dev/null +++ b/website/src/dedicated_ai_server/connection.rs @@ -0,0 +1,205 @@ +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; + +use anyhow::{Context, Result}; +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tokio::task::JoinHandle; +use tokio::time::{sleep, Duration}; + +use frontend_protocol::DekuBytes; +use super::api::{ + ChatCompletionRequest, + MessageInChat, + Request, + RequestPayload, + Response, + ResponsePayload, +}; +use super::talking::{SecretStreamSocket, wrap_connection_socket}; + +#[derive(Debug, Clone)] +pub struct MessagePiecePayload(pub String); + +#[derive(Debug, Clone)] +pub enum MessagePiece { + Piece(MessagePiecePayload), End, Cancelled, DedicatedServerDisconnected +} + +#[derive(Debug)] +struct PendingChatCompletionRecord { + response_tx: mpsc::UnboundedSender, +} + +#[derive(Debug)] +struct RequestWorkItem { + request: Request, + response_tx: mpsc::UnboundedSender, +} + +#[derive(Debug)] +pub struct DedicatedAiServerConnection { + address: String, + port: u16, + secret: String, + request_tx: mpsc::UnboundedSender, + next_request_id: AtomicU64, +} + +pub fn connect_to_dedicated_ai_server(address: String, port: u16, secret: String) + -> (DedicatedAiServerConnection, JoinHandle<()>) { + let (request_tx, request_rx) = mpsc::unbounded_channel(); + let worker_address = address.clone(); + let worker_secret = secret.clone(); + + let join_hand = tokio::spawn(async move { + connection_worker_loop(worker_address, port, worker_secret, request_rx).await; + }); + + let conn_bridge = DedicatedAiServerConnection { + address, + port, + secret, + request_tx, + next_request_id: AtomicU64::new(1), + }; + (conn_bridge, join_hand) +} + +impl DedicatedAiServerConnection { + pub fn send_chat_completion( + &self, + messages: Vec, + ) -> Result> { + let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); + let request = Request { + request_id, + payload: RequestPayload::ChatCompletion(ChatCompletionRequest::new(messages)), + }; + let (response_tx, response_rx) = mpsc::unbounded_channel(); + self.request_tx + .send(RequestWorkItem { request, response_tx }) + .map_err(|err| anyhow::anyhow!("failed to enqueue request: {err}"))?; + Ok(response_rx) + } +} + +async fn connection_worker_loop( + address: String, + port: u16, + secret: String, + mut request_rx: mpsc::UnboundedReceiver, +) { + // todo : fix a lot of errors + loop { + let socket = match TcpStream::connect((address.as_str(), port)).await { + Ok(socket) => socket, + Err(err) => { + eprintln!( + "[dedicated-ai] failed to connect to {}:{}: {err}", + address, port + ); + sleep(Duration::from_secs(1)).await; + continue; + } + }; + + let mut transport = match wrap_connection_socket(socket, &secret).await { + Ok(transport) => transport, + Err(err) => { + eprintln!("[dedicated-ai] failed to wrap connection: {err:#}"); + sleep(Duration::from_secs(1)).await; + continue; + } + }; + + let mut pending: HashMap = HashMap::new(); + let result = run_connected_loop(&mut transport, &mut request_rx, &mut pending).await; + if let Err(err) = result { + eprintln!("[dedicated-ai] connection error: {err:#}"); + } + + cancel_all_pending(&mut pending); + let _ = transport.close(true).await; + sleep(Duration::from_secs(1)).await; + } +} + +async fn run_connected_loop( + transport: &mut SecretStreamSocket, + request_rx: &mut mpsc::UnboundedReceiver, + pending: &mut HashMap, +) -> Result<()> { + loop { + tokio::select! { + frame = transport.recv_frame() => { + let frame = frame.context("failed to read frame")?; + let frame = match frame { + Some(bytes) => bytes, + None => return Err(anyhow::anyhow!("connection closed by peer")), + }; + handle_response_frame(frame, pending)?; + } + maybe_item = request_rx.recv() => { + let item = match maybe_item { + Some(item) => item, + None => return Err(anyhow::anyhow!("request channel closed")), // idk why this would happen + }; + let request_id = item.request.request_id; + pending.insert( + request_id, + PendingChatCompletionRecord { response_tx: item.response_tx }, + ); + + let payload = item.request.to_bytes().context("failed to encode request")?; + transport.send_frame(&payload).await.context("failed to send request")?; + } + } + } +} + +fn handle_response_frame( + frame: Vec, + pending: &mut HashMap, +) -> Result<()> { + let response = Response::from_bytes(&frame).context("failed to decode response")?; + let request_id = response.request_id; + let record = pending.get(&request_id); + if record.is_none() { + eprintln!("[dedicated-ai] response for unknown request_id={request_id}"); + return Ok(()); + } + + match response.payload { + ResponsePayload::ChatCompletion(payload) => { + let send_failed = match pending.get(&request_id) { + Some(record) => record + .response_tx + .send(MessagePiece::Piece( MessagePiecePayload(payload.piece) )) + .is_err(), + None => false, + }; + if send_failed { + pending.remove(&request_id); + } + } + ResponsePayload::ChatCompletionEnd(_) => { + if let Some(record) = pending.remove(&request_id) { + let _ = record.response_tx.send(MessagePiece::End); + } + } + ResponsePayload::ChatCompletionCancellation(_) => { + if let Some(record) = pending.remove(&request_id) { + let _ = record.response_tx.send(MessagePiece::Cancelled); + } + } + } + + Ok(()) +} + +fn cancel_all_pending(pending: &mut HashMap) { + for (_request_id, record) in pending.drain() { + let _ = record.response_tx.send(MessagePiece::DedicatedServerDisconnected); + } +} diff --git a/website/src/dedicated_ai_server/marshalling_utils.rs b/website/src/dedicated_ai_server/marshalling_utils.rs deleted file mode 100644 index 1025533..0000000 --- a/website/src/dedicated_ai_server/marshalling_utils.rs +++ /dev/null @@ -1,56 +0,0 @@ -use deku::ctx::{Endian, Order}; -use deku::{DekuError, DekuWriter}; -use deku::prelude::{Reader, Writer}; -use deku::prelude::*; - -pub fn read_pascal_string(reader: &mut Reader) -> core::result::Result -where - R: deku::no_std_io::Read + deku::no_std_io::Seek, -{ - let byte_len = u32::from_reader_with_ctx(reader, Endian::Little)? as usize; - let mut bytes = vec![0u8; byte_len]; - reader.read_bytes(byte_len, &mut bytes, Order::Msb0)?; - - String::from_utf8(bytes).map_err(|err| DekuError::Parse(err.to_string().into())) -} - - -pub fn write_pascal_string( - writer: &mut Writer, - value: &String, -) -> core::result::Result<(), DekuError> -where - W: deku::no_std_io::Write + deku::no_std_io::Seek, -{ - let bytes = value.as_bytes(); - let byte_len = - u32::try_from(bytes.len()).map_err(|_| DekuError::Parse("string too large".into()))?; - - byte_len.to_writer(writer, Endian::Little)?; - writer.write_bytes(bytes)?; - Ok(()) -} - -pub fn read_bool_u8(reader: &mut Reader) -> core::result::Result -where - R: deku::no_std_io::Read + deku::no_std_io::Seek, -{ - let value = u8::from_reader_with_ctx(reader, Endian::Little)?; - match value { - 0 => Ok(false), - 1 => Ok(true), - _ => Err(DekuError::Parse("invalid bool value".into())), - } -} - -pub fn write_bool_u8( - writer: &mut Writer, - value: &bool, -) -> core::result::Result<(), DekuError> -where - W: deku::no_std_io::Write + deku::no_std_io::Seek, -{ - let encoded: u8 = if *value { 1 } else { 0 }; - encoded.to_writer(writer, Endian::Little)?; - Ok(()) -} diff --git a/website/src/dedicated_ai_server/mod.rs b/website/src/dedicated_ai_server/mod.rs index 0b51480..e485011 100644 --- a/website/src/dedicated_ai_server/mod.rs +++ b/website/src/dedicated_ai_server/mod.rs @@ -1,4 +1,4 @@ pub mod api; +pub mod connection; pub mod talking; pub mod TEST; -mod marshalling_utils; \ No newline at end of file diff --git a/website/src/lib.rs b/website/src/lib.rs index 30125e5..35a2009 100644 --- a/website/src/lib.rs +++ b/website/src/lib.rs @@ -17,24 +17,38 @@ use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite}; use serde::Deserialize; use std::future::Future; use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::OnceLock; use tera::{Tera}; -use tokio::time::{sleep, Duration}; use tower_http::services::{ServeDir, ServeFile}; use samplers::random_junk::random_string; -use config::{ConfigResult, GeneralServiceConfig, load_config, load_config_default}; -use db::{DbResult, connect_db, init_database}; +use config::load_config_default; +use db::{connect_db, init_database}; use web_file_uploads::{get_file, upload_get, upload_post}; use web_app_state::{AppState, AppStateInner, AuthenticatedUserId}; +use dedicated_ai_server::connection::connect_to_dedicated_ai_server; +use dedicated_ai_server::api::MessageInChat; +use crate::dedicated_ai_server::connection::{MessagePiece, MessagePiecePayload}; +use frontend_protocol::{ + DekuBytes, + UserRequest, + UserRequestPayload, + UserResponse, + UserResponseChatCompletion, + UserResponseChatCompletionCancellation, + UserResponseChatCompletionEnd, + UserResponsePayload, +}; - -async fn init_app_state() -> Result> { +async fn init_app_state() -> Result<(AppState, tokio::task::JoinHandle<()>), Box> { let config = load_config_default()?; let db = connect_db(&config).await?; let tera = Tera::new("frontend/pages/**/*.html")?; - Ok(std::sync::Arc::new(AppStateInner { config, db, tera })) + let (dedicated_ai, dedicated_ai_task_handler) = connect_to_dedicated_ai_server( + config.dedicated_ai_server_address.clone(), + config.dedicated_ai_server_port, + config.dedicated_ai_server_secret.clone(), + ); + Ok((std::sync::Arc::new(AppStateInner { config, db, tera, dedicated_ai }), dedicated_ai_task_handler)) } enum PasscodeAuthenticationResult{ @@ -235,7 +249,9 @@ async fn chat( user: AuthenticatedUserId, ) -> Response { if let Some(ws) = ws { - return ws.on_upgrade(move |socket| handle_chat_socket(socket)).into_response(); + return ws + .on_upgrade(move |socket| handle_chat_socket(socket, state.clone())) + .into_response(); } let mut ctx = tera::Context::new(); @@ -247,19 +263,105 @@ async fn chat( Html(body).into_response() } -async fn handle_chat_socket(mut socket: WebSocket) { - while let Some(msg) = socket.recv().await { +async fn handle_chat_socket(mut socket: WebSocket, state: AppState) { + 'outer: while let Some(msg) = socket.recv().await { let msg = match msg { Ok(msg) => msg, Err(_) => break, }; match msg { - Message::Text(text) => { - sleep(Duration::from_secs(2)).await; - let upper = text.to_uppercase(); - if socket.send(Message::Text(upper)).await.is_err() { - break; + Message::Binary(bytes) => { + let request = match UserRequest::from_bytes(&bytes) { + Ok(request) => request, + Err(err) => { + eprintln!("[chat] failed to decode request: {err:#}"); + continue; + } + }; + + match request.payload { + UserRequestPayload::ChatCompletion(payload) => { + let messages = payload + .messages + .into_iter() + .map(|msg| MessageInChat { + role: msg.role, + content: msg.content, + }) + .collect(); + + let mut response_rx = match state.dedicated_ai.send_chat_completion(messages) { + Ok(rx) => rx, + Err(err) => { + eprintln!("[chat] failed to send request: {err:#}"); + let response = UserResponse { + request_id: request.request_id, + payload: UserResponsePayload::ChatCompletionCancellation( + UserResponseChatCompletionCancellation, + ), + }; + if let Ok(bytes) = response.to_bytes() { + let _ = socket.send(Message::Binary(bytes)).await; + } + continue; + } + }; + + while let Some(piece) = response_rx.recv().await { + let (payload, should_break) = match piece { + MessagePiece::Piece(MessagePiecePayload(text)) => ( + UserResponsePayload::ChatCompletion(UserResponseChatCompletion { + piece: text, + }), + false, + ), + MessagePiece::End => ( + UserResponsePayload::ChatCompletionEnd( + UserResponseChatCompletionEnd, + ), + true, + ), + MessagePiece::Cancelled | MessagePiece::DedicatedServerDisconnected => ( + UserResponsePayload::ChatCompletionCancellation( + UserResponseChatCompletionCancellation, + ), + true, + ), + }; + + let response = UserResponse { + request_id: request.request_id, + payload, + }; + let bytes = match response.to_bytes() { + Ok(bytes) => bytes, + Err(err) => { + eprintln!("[chat] failed to encode response: {err:#}"); + break 'outer; + } + }; + if socket.send(Message::Binary(bytes)).await.is_err() { + break 'outer; + } + if should_break { + break; + } + } + } + UserRequestPayload::ChatCompletionCancellation(_) => { + let response = UserResponse { + request_id: request.request_id, + payload: UserResponsePayload::ChatCompletionCancellation( + UserResponseChatCompletionCancellation, + ), + }; + if let Ok(bytes) = response.to_bytes() { + if socket.send(Message::Binary(bytes)).await.is_err() { + break 'outer; + } + } + } } } Message::Close(_) => break, @@ -270,7 +372,7 @@ async fn handle_chat_socket(mut socket: WebSocket) { pub async fn run_server() -> Result<(), Box> { - let state = init_app_state().await.expect("lol"); + let (state, mut dedicated_ai_task) = init_app_state().await.expect("lol"); let app = Router::new() .route("/login", get(login_get)) .route("/login", post(login_post)) @@ -292,7 +394,28 @@ pub async fn run_server() -> Result<(), Box> { let listener = tokio::net::TcpListener::bind(addr) .await .expect("bind failed"); - axum::serve(listener, app).await?; + let mut server_task = tokio::spawn(async move { axum::serve(listener, app).await }); + + tokio::select! { + res = &mut server_task => { + dedicated_ai_task.abort(); + res??; + } + res = &mut dedicated_ai_task => { + server_task.abort(); + res.map_err(|err| { + std::io::Error::new( + std::io::ErrorKind::Other, + format!("dedicated ai task failed: {err}"), + ) + })?; + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "dedicated ai task ended unexpectedly", + ) + .into()); + } + } Ok(()) } diff --git a/website/src/web_app_state.rs b/website/src/web_app_state.rs index 95c077e..6cbece0 100644 --- a/website/src/web_app_state.rs +++ b/website/src/web_app_state.rs @@ -1,10 +1,12 @@ use tera::Tera; -use crate::GeneralServiceConfig; +use crate::config::GeneralServiceConfig; +use crate::dedicated_ai_server::connection::DedicatedAiServerConnection; pub struct AppStateInner { pub config: GeneralServiceConfig, pub db: tokio_postgres::Client, pub tera: Tera, + pub dedicated_ai: DedicatedAiServerConnection, } pub type AppState = std::sync::Arc;