From 336177941f93247463964651f54c4824c672f21f Mon Sep 17 00:00:00 2001 From: Andreew Gregory Date: Sun, 29 Mar 2026 01:37:04 +0300 Subject: [PATCH] Chat is working somewhat --- dedicated_ai_server/server.py | 4 +- frontend/Cargo.toml | 3 + frontend/pages/chat.html | 28 +- frontend/src/chat.rs | 356 +++++++++++++++--- frontend/static/css/site.css | 152 ++++++++ frontend_protocol/src/lib.rs | 47 +-- website/src/bin/E.rs | 7 - website/src/bin/F.rs | 7 - website/src/dedicated_ai_server/TEST.rs | 117 ------ website/src/dedicated_ai_server/api.rs | 18 +- website/src/dedicated_ai_server/connection.rs | 54 ++- website/src/dedicated_ai_server/mod.rs | 1 - website/src/lib.rs | 152 +++++--- 13 files changed, 623 insertions(+), 323 deletions(-) delete mode 100644 website/src/bin/E.rs delete mode 100644 website/src/bin/F.rs delete mode 100644 website/src/dedicated_ai_server/TEST.rs diff --git a/dedicated_ai_server/server.py b/dedicated_ai_server/server.py index 826897c..898b7ce 100644 --- a/dedicated_ai_server/server.py +++ b/dedicated_ai_server/server.py @@ -16,7 +16,7 @@ SERVER_PORT = 9000 SHARED_SECRET = "change-me" -PROCESS_DELAY_SECONDS = 1.5 +PROCESS_DELAY_SECONDS = 0.4 @dataclass @@ -90,7 +90,9 @@ def worker_loop( record.response_queue.put_nowait, MessagePiece(piece=piece, is_end=False, is_cancel=False), ) + print("[debug] got a new piece") if record.is_cancelled(): + print("[debug] record was cancelled") cancelled = True break diff --git a/frontend/Cargo.toml b/frontend/Cargo.toml index 22b8534..5885487 100644 --- a/frontend/Cargo.toml +++ b/frontend/Cargo.toml @@ -16,8 +16,11 @@ version = "0.3" features = [ "Window", "Document", + "Element", "Event", "EventTarget", + "HtmlElement", + "HtmlTextAreaElement", "KeyboardEvent", "MouseEvent", "WheelEvent", diff --git a/frontend/pages/chat.html b/frontend/pages/chat.html index 7ff9e1b..56bf4f7 100644 --- a/frontend/pages/chat.html +++ b/frontend/pages/chat.html @@ -1,5 +1,5 @@ - + @@ -14,10 +14,28 @@ init_chat(); - -
-

Chat

-

WebSocket connection initialized. Open the console to see messages.

+ +
+
+
+

Chat

+
Live session
+
+ +
+
+ +
+ +
Enter to send. Shift+Enter for a new line.
+
+
+
diff --git a/frontend/src/chat.rs b/frontend/src/chat.rs index 1adb2c3..17a54de 100644 --- a/frontend/src/chat.rs +++ b/frontend/src/chat.rs @@ -1,79 +1,284 @@ use std::cell::RefCell; use std::rc::Rc; + +use js_sys::{ArrayBuffer, Uint8Array}; use wasm_bindgen::prelude::*; use wasm_bindgen::JsCast; -use web_sys::{console, window, BinaryType, ErrorEvent, Event, MessageEvent, WebSocket}; -use js_sys::{ArrayBuffer, Uint8Array}; +use web_sys::{ + console, + window, + BinaryType, + Document, + Element, + ErrorEvent, + Event, + HtmlElement, + HtmlTextAreaElement, + KeyboardEvent, + MessageEvent, + WebSocket, +}; + use frontend_protocol::{ - ChatMessage, + UserChatMessage, DekuBytes, UserChatCompletionRequest, UserRequest, - UserRequestPayload, UserResponse, - UserResponsePayload, }; thread_local! { - static WS_HANDLE: RefCell>> = RefCell::new(None); + static APP_HANDLE: RefCell>> = RefCell::new(None); +} + +struct ChatState { + messages: Vec, + message_nodes: Vec, + is_receiving: bool, + active_assistant_index: Option, +} + +struct MessageNode { + content: Element, + status: Option, +} + +#[derive(Copy, Clone)] +enum ChatStatus { + Pending, + Cancelled, + Hidden, +} + +struct AppState { + ws: WebSocket, + document: Document, + messages_container: Element, + input: HtmlTextAreaElement, + state: RefCell, +} + +fn append_message( + document: &Document, + container: &Element, + role: &str, + content: &str, + status: Option, +) -> Result { + let wrapper = document.create_element("div")?; + wrapper.set_class_name(&format!("chat-message chat-message--{}", role)); + + let role_el = document.create_element("div")?; + role_el.set_class_name("chat-message__role"); + role_el.set_text_content(Some(role)); + + let content_el = document.create_element("div")?; + content_el.set_class_name("chat-message__content"); + content_el.set_text_content(Some(content)); + + wrapper.append_child(&role_el)?; + wrapper.append_child(&content_el)?; + + let status_el = if let Some(status) = status { + let status_el = document.create_element("div")?; + apply_status(&status_el, status); + wrapper.append_child(&status_el)?; + Some(status_el) + } else { + None + }; + + container.append_child(&wrapper)?; + + Ok(MessageNode { + content: content_el, + status: status_el, + }) +} + +fn set_message_content(node: &Element, content: &str) { + node.set_text_content(Some(content)); +} + +fn apply_status(node: &Element, status: ChatStatus) { + match status { + ChatStatus::Pending => { + node.set_class_name("chat-message__status chat-message__status--pending"); + node.set_text_content(Some("...")); + } + ChatStatus::Cancelled => { + node.set_class_name("chat-message__status chat-message__status--cancelled"); + node.set_text_content(Some("[canceled]")); + } + ChatStatus::Hidden => { + node.set_class_name("chat-message__status chat-message__status--hidden"); + node.set_text_content(Some("")); + } + } +} + +fn scroll_to_bottom(container: &Element) { + if let Some(element) = container.dyn_ref::() { + let height = element.scroll_height(); + element.set_scroll_top(height); + } } #[wasm_bindgen] pub fn init_chat() -> Result<(), JsValue> { let window = window().ok_or_else(|| JsValue::from_str("Missing window"))?; + let document = window.document().ok_or_else(|| JsValue::from_str("Missing document"))?; let location = window.location(); let host = location.host()?; let protocol = location.protocol()?; let scheme = if protocol == "https:" { "wss" } else { "ws" }; let ws_url = format!("{scheme}://{host}/chat"); - let ws = Rc::new(WebSocket::new(&ws_url)?); + let messages_container = document + .get_element_by_id("chat-messages") + .ok_or_else(|| JsValue::from_str("Missing chat-messages element"))?; + let input_el: HtmlTextAreaElement = document + .get_element_by_id("chat-input") + .ok_or_else(|| JsValue::from_str("Missing chat-input element"))? + .dyn_into()?; + + let ws = WebSocket::new(&ws_url)?; ws.set_binary_type(BinaryType::Arraybuffer); console::log_1(&format!("[ws] connecting to {ws_url}").into()); - let ws_for_open = ws.clone(); + let app = Rc::new(AppState { + ws, + document, + messages_container, + input: input_el, + state: RefCell::new(ChatState { + messages: Vec::new(), + message_nodes: Vec::new(), + is_receiving: false, + active_assistant_index: None, + }), + }); + let onopen = Closure::::wrap(Box::new(move |_| { console::log_1(&"[ws] connected".into()); - let requests = [ - UserRequest { - request_id: 1, - payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![ - ChatMessage { - role: "user".to_string(), - content: r#"this is the first message first chat -too bad it is in lower case one one one"# - .to_string(), - }, - ])), - }, - UserRequest { - request_id: 2, - payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![ - ChatMessage { - role: "user".to_string(), - content: "And this is message two of chat request 2. \ -Too bad it isn't processed two two two" - .to_string(), - }, - ])), - }, - ]; - - for request in requests { - match request.to_bytes() { - Ok(bytes) => { - console::log_1(&format!("[ws] sending request_id={} bytes={}", request.request_id, bytes.len()).into()); - let _ = ws_for_open.send_with_u8_array(&bytes); - } - Err(err) => { - console::error_1(&format!("[ws] encode error: {err:#}").into()); - } - } - } })); - ws.set_onopen(Some(onopen.as_ref().unchecked_ref())); + app.ws.set_onopen(Some(onopen.as_ref().unchecked_ref())); onopen.forget(); + let app_for_keydown = app.clone(); + let onkeydown = Closure::::wrap(Box::new(move |event: KeyboardEvent| { + if event.ctrl_key() && (event.key() == "c" || event.key() == "C") { + let state = app_for_keydown.state.borrow(); + if state.is_receiving { + event.prevent_default(); + if app_for_keydown.ws.ready_state() != WebSocket::OPEN { + console::error_1(&"[ws] socket is not open".into()); + return; + } + + let request = UserRequest::ChatCompletionCancellation; + match request.to_bytes() { + Ok(bytes) => { + if let Err(err) = app_for_keydown.ws.send_with_u8_array(&bytes) { + console::error_1(&format!("[ws] cancel send error: {:?}", err).into()); + } + } + Err(err) => { + console::error_1(&format!("[ws] cancel encode error: {err:#}").into()); + } + } + } + return; + } + + if event.key() != "Enter" || event.shift_key() { + return; + } + event.prevent_default(); + + let mut state = app_for_keydown.state.borrow_mut(); + if state.is_receiving { + return; + } + + let raw = app_for_keydown.input.value(); + let trimmed = raw.trim(); + if trimmed.is_empty() { + return; + } + + if app_for_keydown.ws.ready_state() != WebSocket::OPEN { + console::error_1(&"[ws] socket is not open".into()); + return; + } + + let user_content = trimmed.to_string(); + app_for_keydown.input.set_value(""); + + let user_node = match append_message( + &app_for_keydown.document, + &app_for_keydown.messages_container, + "user", + &user_content, + None, + ) { + Ok(node) => node, + Err(err) => { + console::error_1(&format!("[ui] failed to append user message: {:?}", err).into()); + return; + } + }; + + state.messages.push(UserChatMessage { + role: "user".to_string(), + content: user_content, + }); + state.message_nodes.push(user_node); + scroll_to_bottom(&app_for_keydown.messages_container); + + let history = state.messages.clone(); + let request = UserRequest::ChatCompletion(UserChatCompletionRequest::new(history)); + + let bytes = match request.to_bytes() { + Ok(bytes) => bytes, + Err(err) => { + console::error_1(&format!("[ws] encode error: {err:#}").into()); + return; + } + }; + + if let Err(err) = app_for_keydown.ws.send_with_u8_array(&bytes) { + console::error_1(&format!("[ws] send error: {:?}", err).into()); + return; + } + + let assistant_node = match append_message( + &app_for_keydown.document, + &app_for_keydown.messages_container, + "assistant", + "", + Some(ChatStatus::Pending), + ) { + Ok(node) => node, + Err(err) => { + console::error_1(&format!("[ui] failed to append assistant message: {:?}", err).into()); + return; + } + }; + + state.messages.push(UserChatMessage { + role: "assistant".to_string(), + content: String::new(), + }); + state.message_nodes.push(assistant_node); + state.active_assistant_index = Some(state.messages.len() - 1); + state.is_receiving = true; + scroll_to_bottom(&app_for_keydown.messages_container); + })); + app.input.set_onkeydown(Some(onkeydown.as_ref().unchecked_ref())); + onkeydown.forget(); + + let app_for_message = app.clone(); let onmessage = Closure::::wrap(Box::new(move |event: MessageEvent| { let data = event.data(); if let Some(text) = data.as_string() { @@ -87,36 +292,73 @@ Too bad it isn't processed two two two" let data = Uint8Array::new(&data); let bytes = data.to_vec(); - console::log_1(&format!("[ws] received bytes={}", bytes.len()).into()); let response = match UserResponse::from_bytes(&bytes) { Ok(response) => response, Err(err) => { - console::error_1(&format!("[ws] decode error: {err:#} (bytes={})", bytes.len()).into()); + console::error_1(&format!("[ws] decode error: {err:#}").into()); return; } }; - match response.payload { - UserResponsePayload::ChatCompletion(payload) => { - console::log_1(&format!("[ws] request_id={} piece={}", response.request_id, payload.piece).into()); + let mut state = app_for_message.state.borrow_mut(); + let assistant_index = match state.active_assistant_index { + Some(index) => index, + None => { + console::log_1(&"[ws] missing assistant index".into()); + return; } - UserResponsePayload::ChatCompletionCancellation(_) => { - console::log_1(&format!("[ws] request_id={} [cancel]", response.request_id).into()); + }; + + match response { + UserResponse::ChatCompletion(completion) => { + if let Some(message) = state.messages.get_mut(assistant_index) { + message.content.push_str(&completion.piece); + } + if let Some(node) = state.message_nodes.get(assistant_index) { + if let Some(message) = state.messages.get(assistant_index) { + set_message_content(&node.content, &message.content); + } + } + scroll_to_bottom(&app_for_message.messages_container); } - UserResponsePayload::ChatCompletionEnd(_) => { - console::log_1(&format!("[ws] request_id={} [end]", response.request_id).into()); + UserResponse::ChatCompletionEnd => { + state.is_receiving = false; + if let Some(node) = state + .active_assistant_index + .and_then(|index| state.message_nodes.get(index)) + { + if let Some(status_node) = node.status.as_ref() { + apply_status(status_node, ChatStatus::Hidden); + } + } + state.active_assistant_index = None; + let _ = app_for_message.input.focus(); + } + UserResponse::ChatCompletionCancellation => { + state.is_receiving = false; + if let Some(node) = state + .active_assistant_index + .and_then(|index| state.message_nodes.get(index)) + { + if let Some(status_node) = node.status.as_ref() { + apply_status(status_node, ChatStatus::Cancelled); + } + } + state.active_assistant_index = None; + let _ = app_for_message.input.focus(); } } })); - ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref())); + app.ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref())); onmessage.forget(); let onerror = Closure::::wrap(Box::new(move |event: ErrorEvent| { console::error_1(&format!("[ws] error: {}", event.message()).into()); })); - ws.set_onerror(Some(onerror.as_ref().unchecked_ref())); + app.ws.set_onerror(Some(onerror.as_ref().unchecked_ref())); onerror.forget(); - WS_HANDLE.with(|slot| *slot.borrow_mut() = Some(ws.clone())); + APP_HANDLE.with(|slot| *slot.borrow_mut() = Some(app.clone())); + let _ = app.input.focus(); Ok(()) } diff --git a/frontend/static/css/site.css b/frontend/static/css/site.css index 6594287..28e6240 100644 --- a/frontend/static/css/site.css +++ b/frontend/static/css/site.css @@ -7,6 +7,158 @@ body { overflow: hidden; } +.chat-html, +.chat-body { + height: 100%; + width: 100%; +} + +.chat-body { + margin: 0; + background: #ffffff; + color: #1b1b1b; + font-family: "Fira Sans", "Space Grotesk", "Montserrat", sans-serif; + font-size: 14pt; + display: block; + overflow: auto; +} + +.chat-shell { + min-height: 100%; + padding: 32px 24px; + box-sizing: border-box; + display: flex; + justify-content: center; +} + +.chat-panel { + width: min(100%, max(70%, 1000pt)); + display: flex; + flex-direction: column; + gap: 16px; + min-height: 0; +} + +.chat-header { + display: flex; + flex-direction: column; + gap: 4px; +} + +.chat-title { + margin: 0; + font-size: 20pt; + font-weight: 600; +} + +.chat-subtitle { + font-size: 10pt; + text-transform: uppercase; + letter-spacing: 1px; + color: #6b6b6b; +} + +.chat-container { + display: flex; + flex-direction: column; + gap: 16px; + min-height: 0; + flex: 1; +} + +.chat-messages { + display: flex; + flex-direction: column; + gap: 12px; + overflow-y: auto; + padding-right: 4px; + flex: 1; +} + +.chat-message { + width: 100%; + border-radius: 14px; + padding: 12px 14px; + box-sizing: border-box; + border: 1px solid #e2e2e2; + background: #fafafa; + display: flex; + flex-direction: column; + gap: 6px; +} + +.chat-message--user { + background: #eef4ff; + border-color: #d5e2ff; +} + +.chat-message--assistant { + background: #f7f2ff; + border-color: #e5d7ff; +} + +.chat-message--system { + background: #f8f8f8; + border-color: #dddddd; +} + +.chat-message__role { + font-size: 9pt; + text-transform: uppercase; + letter-spacing: 1px; + color: #6b6b6b; +} + +.chat-message__content { + white-space: pre-wrap; + line-height: 1.45; +} + +.chat-message__status { + font-size: 9pt; + color: #9aa0a6; +} + +.chat-message__status--pending { + color: #9aa0a6; +} + +.chat-message__status--cancelled { + color: #c0392b; +} + +.chat-message__status--hidden { + display: none; +} + +.chat-input-area { + display: flex; + flex-direction: column; + gap: 6px; +} + +.chat-input { + width: 100%; + min-height: 120px; + padding: 12px 14px; + border-radius: 12px; + border: 1px solid #d4d4d4; + font-size: 14pt; + font-family: inherit; + resize: vertical; + box-sizing: border-box; +} + +.chat-input:disabled { + background: #f2f2f2; + color: #8a8a8a; +} + +.chat-input-hint { + font-size: 9.5pt; + color: #6b6b6b; +} + body { display: flex; align-items: center; diff --git a/frontend_protocol/src/lib.rs b/frontend_protocol/src/lib.rs index 8e54f6b..c92150f 100644 --- a/frontend_protocol/src/lib.rs +++ b/frontend_protocol/src/lib.rs @@ -10,7 +10,7 @@ pub use self::utils::{ }; #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] -pub struct ChatMessage { +pub struct UserChatMessage { #[deku( reader = "read_pascal_string(deku::reader)", writer = "write_pascal_string(deku::writer, &self.role)" @@ -31,33 +31,23 @@ pub struct UserChatCompletionRequest { reader = "read_vec_u32(deku::reader)", writer = "write_vec_u32(deku::writer, &self.messages)" )] - pub messages: Vec, + pub messages: Vec, } impl UserChatCompletionRequest { - pub fn new(messages: Vec) -> Self { + pub fn new(messages: Vec) -> Self { Self { messages } } } -#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] -pub struct UserChatCompletionCancellationRequest; - #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[deku(id_type = "u8")] #[repr(u8)] -pub enum UserRequestPayload { +pub enum UserRequest { #[deku(id = "0")] ChatCompletion(UserChatCompletionRequest), #[deku(id = "1")] - ChatCompletionCancellation(UserChatCompletionCancellationRequest), -} - -#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] -pub struct UserRequest { - #[deku(endian = "little")] - pub request_id: u64, - pub payload: UserRequestPayload, + ChatCompletionCancellation, } #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] @@ -69,42 +59,23 @@ pub struct UserResponseChatCompletion { pub piece: String, } -#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] -pub struct UserResponseChatCompletionCancellation; - -#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] -pub struct UserResponseChatCompletionEnd; - #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[deku(id_type = "u8")] #[repr(u8)] -pub enum UserResponsePayload { +pub enum UserResponse { #[deku(id = "0")] ChatCompletion(UserResponseChatCompletion), #[deku(id = "1")] - ChatCompletionCancellation(UserResponseChatCompletionCancellation), + ChatCompletionCancellation, #[deku(id = "2")] - ChatCompletionEnd(UserResponseChatCompletionEnd), + ChatCompletionEnd, } -#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] -pub struct UserResponse { - #[deku(endian = "little")] - pub request_id: u64, - pub payload: UserResponsePayload, -} - -impl DekuBytes for ChatMessage {} +impl DekuBytes for UserChatMessage {} impl DekuBytes for UserChatCompletionRequest {} -impl DekuBytes for UserChatCompletionCancellationRequest {} - -impl DekuBytes for UserRequestPayload {} impl DekuBytes for UserRequest {} impl DekuBytes for UserResponseChatCompletion {} -impl DekuBytes for UserResponseChatCompletionCancellation {} -impl DekuBytes for UserResponseChatCompletionEnd {} -impl DekuBytes for UserResponsePayload {} impl DekuBytes for UserResponse {} diff --git a/website/src/bin/E.rs b/website/src/bin/E.rs deleted file mode 100644 index 416be7d..0000000 --- a/website/src/bin/E.rs +++ /dev/null @@ -1,7 +0,0 @@ -use website::dedicated_ai_server::TEST::main_E; - -#[tokio::main] -async fn main() -> Result<(), Box> { - main_E().await?; - Ok(()) -} diff --git a/website/src/bin/F.rs b/website/src/bin/F.rs deleted file mode 100644 index 63f6316..0000000 --- a/website/src/bin/F.rs +++ /dev/null @@ -1,7 +0,0 @@ -use website::dedicated_ai_server::TEST::main_F; - -#[tokio::main] -async fn main() -> Result<(), Box> { - main_F().await?; - Ok(()) -} diff --git a/website/src/dedicated_ai_server/TEST.rs b/website/src/dedicated_ai_server/TEST.rs deleted file mode 100644 index 71ede44..0000000 --- a/website/src/dedicated_ai_server/TEST.rs +++ /dev/null @@ -1,117 +0,0 @@ -use tokio::net::{TcpListener, TcpStream}; - -use anyhow::Result; -use crate::dedicated_ai_server::api::{ - ChatCompletionRequest, - MessageInChat, - Request, - RequestPayload, -}; -use frontend_protocol::DekuBytes; -use crate::dedicated_ai_server::talking::{SecretStreamSocket, wrap_connection_socket}; -use crate::dedicated_ai_server::talking::{ProtocolError, FrameCallback}; - - -const SERVER_HOST: &str = "127.0.0.1"; -const SERVER_PORT: u16 = 9000; -const SHARED_SECRET: &str = "change-me"; - -/* client =========== aka F.py */ - - -async fn connect_to_server(host: &str, port: u16, shared_secret: &str) -> Result { - let socket = TcpStream::connect((host, port)).await?; - wrap_connection_socket(socket, shared_secret).await -} - - -async fn send_message(connection: &mut SecretStreamSocket, request_id: u64, role: &str, content: &str) -> Result<()> { - let batch = ChatCompletionRequest::new(vec![MessageInChat { - role: role.to_string(), - content: content.to_string(), - }]); - let request = Request { - request_id, - payload: RequestPayload::ChatCompletion(batch), - }; - let payload = request.to_bytes()?; - connection.send_frame(&payload).await -} - - -async fn close_connection(connection: &mut SecretStreamSocket) -> Result<()> { - connection.close(true).await -} - - -pub async fn main_F() -> Result<()> { - let mut connection = connect_to_server(SERVER_HOST, SERVER_PORT, SHARED_SECRET).await?; - send_message(&mut connection, 1, "user", "hello from Rust client").await?; - close_connection(&mut connection).await?; - Ok(()) -} - - -/* server========= aka E.py */ - -fn print_batch(peer: &str, request_id: u64, batch: &ChatCompletionRequest) { - println!( - "[packet] {} sent request_id={} with {} message(s)", - peer, - request_id, - batch.messages.len(), - ); - for (index, message) in batch.messages.iter().enumerate() { - println!( - " [{}] role={:?} content={:?}", - index, message.role, message.content - ); - } -} - - -async fn handle_client(client_socket: TcpStream, peer: String, shared_secret: &str) -> Result<()> { - let peer_for_callback = peer.clone(); - let on_frame: FrameCallback<'_> = Box::new(move |frame: Vec| { - let request = Request::from_bytes(&frame)?; - match request.payload { - RequestPayload::ChatCompletion(batch) => { - print_batch(&peer_for_callback, request.request_id, &batch); - } - RequestPayload::ChatCompletionCancellation(_) => { - println!( - "[packet] {} sent request_id={} cancel", - peer_for_callback, - request.request_id, - ); - } - } - Ok(()) - }); - - let mut transport = wrap_connection_socket(client_socket, shared_secret).await?; - println!("[connected] {}", peer); - let result = transport.run_receiving_loop(on_frame).await; - println!("[disconnected] {}", peer); - let _ = transport.close(true).await; - result -} - - -pub async fn main_E() -> Result<()> { - let listener = TcpListener::bind((SERVER_HOST, SERVER_PORT)).await?; - println!("[listening] {}:{}", SERVER_HOST, SERVER_PORT); - - loop { - let (client_socket, addr) = listener.accept().await?; - let peer = addr.to_string(); - - if let Err(err) = handle_client(client_socket, peer.clone(), SHARED_SECRET).await { - if err.downcast_ref::().is_some() { - eprintln!("[protocol error] {}: {:#}", peer, err); - } else { - eprintln!("[error] {}: {:#}", peer, err); - } - } - } -} diff --git a/website/src/dedicated_ai_server/api.rs b/website/src/dedicated_ai_server/api.rs index f8fba38..8f10475 100644 --- a/website/src/dedicated_ai_server/api.rs +++ b/website/src/dedicated_ai_server/api.rs @@ -40,9 +40,6 @@ impl ChatCompletionRequest { } -#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] -pub struct ChatCompletionCancellationRequest; - #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[deku(id_type = "u8")] #[repr(u8)] @@ -50,7 +47,7 @@ pub enum RequestPayload { #[deku(id = "0")] ChatCompletion(ChatCompletionRequest), #[deku(id = "1")] - ChatCompletionCancellation(ChatCompletionCancellationRequest), + ChatCompletionCancellation, } #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] @@ -69,12 +66,6 @@ pub struct ResponseChatCompletion { pub piece: String, } -#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] -pub struct ResponseChatCompletionCancellation; - -#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] -pub struct ResponseChatCompletionEnd; - #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[deku(id_type = "u8")] #[repr(u8)] @@ -82,9 +73,9 @@ pub enum ResponsePayload { #[deku(id = "0")] ChatCompletion(ResponseChatCompletion), #[deku(id = "1")] - ChatCompletionCancellation(ResponseChatCompletionCancellation), + ChatCompletionCancellation, #[deku(id = "2")] - ChatCompletionEnd(ResponseChatCompletionEnd) + ChatCompletionEnd } #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] @@ -98,12 +89,9 @@ impl DekuBytes for MessageInChat {} impl DekuBytes for ChatCompletionRequest {} -impl DekuBytes for ChatCompletionCancellationRequest {} - impl DekuBytes for RequestPayload {} impl DekuBytes for Request {} impl DekuBytes for ResponseChatCompletion {} -impl DekuBytes for ResponseChatCompletionCancellation {} impl DekuBytes for ResponsePayload {} impl DekuBytes for Response {} diff --git a/website/src/dedicated_ai_server/connection.rs b/website/src/dedicated_ai_server/connection.rs index 46414de..3e3397c 100644 --- a/website/src/dedicated_ai_server/connection.rs +++ b/website/src/dedicated_ai_server/connection.rs @@ -32,9 +32,14 @@ struct PendingChatCompletionRecord { } #[derive(Debug)] -struct RequestWorkItem { - request: Request, - response_tx: mpsc::UnboundedSender, +enum RequestWorkItem { + ChatCompletion { + request: Request, + response_tx: mpsc::UnboundedSender, + }, + ChatCompletionCancellation { + request_id: u64, + }, } #[derive(Debug)] @@ -70,7 +75,7 @@ impl DedicatedAiServerConnection { pub fn send_chat_completion( &self, messages: Vec, - ) -> Result> { + ) -> Result<(mpsc::UnboundedReceiver, u64)> { let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); let request = Request { request_id, @@ -78,9 +83,16 @@ impl DedicatedAiServerConnection { }; let (response_tx, response_rx) = mpsc::unbounded_channel(); self.request_tx - .send(RequestWorkItem { request, response_tx }) + .send(RequestWorkItem::ChatCompletion { request, response_tx }) .map_err(|err| anyhow::anyhow!("failed to enqueue request: {err}"))?; - Ok(response_rx) + Ok((response_rx, request_id)) + } + + pub fn send_chat_completion_cancellation(&self, request_id: u64) -> Result<()> { + self.request_tx + .send(RequestWorkItem::ChatCompletionCancellation { request_id }) + .map_err(|err| anyhow::anyhow!("failed to enqueue cancellation: {err}"))?; + Ok(()) } } @@ -145,14 +157,26 @@ async fn run_connected_loop( Some(item) => item, None => return Err(anyhow::anyhow!("request channel closed")), // idk why this would happen }; - let request_id = item.request.request_id; - pending.insert( - request_id, - PendingChatCompletionRecord { response_tx: item.response_tx }, - ); + match item { + RequestWorkItem::ChatCompletion { request, response_tx } => { + let request_id = request.request_id; + pending.insert( + request_id, + PendingChatCompletionRecord { response_tx }, + ); - let payload = item.request.to_bytes().context("failed to encode request")?; - transport.send_frame(&payload).await.context("failed to send request")?; + let payload = request.to_bytes().context("failed to encode request")?; + transport.send_frame(&payload).await.context("failed to send request")?; + } + RequestWorkItem::ChatCompletionCancellation { request_id } => { + let request = Request { + request_id, + payload: RequestPayload::ChatCompletionCancellation, + }; + let payload = request.to_bytes().context("failed to encode cancellation")?; + transport.send_frame(&payload).await.context("failed to send cancellation")?; + } + } } } } @@ -183,12 +207,12 @@ fn handle_response_frame( pending.remove(&request_id); } } - ResponsePayload::ChatCompletionEnd(_) => { + ResponsePayload::ChatCompletionEnd => { if let Some(record) = pending.remove(&request_id) { let _ = record.response_tx.send(MessagePiece::End); } } - ResponsePayload::ChatCompletionCancellation(_) => { + ResponsePayload::ChatCompletionCancellation => { if let Some(record) = pending.remove(&request_id) { let _ = record.response_tx.send(MessagePiece::Cancelled); } diff --git a/website/src/dedicated_ai_server/mod.rs b/website/src/dedicated_ai_server/mod.rs index e485011..0b739b4 100644 --- a/website/src/dedicated_ai_server/mod.rs +++ b/website/src/dedicated_ai_server/mod.rs @@ -1,4 +1,3 @@ pub mod api; pub mod connection; pub mod talking; -pub mod TEST; diff --git a/website/src/lib.rs b/website/src/lib.rs index 35a2009..ce12370 100644 --- a/website/src/lib.rs +++ b/website/src/lib.rs @@ -31,12 +31,8 @@ use crate::dedicated_ai_server::connection::{MessagePiece, MessagePiecePayload}; use frontend_protocol::{ DekuBytes, UserRequest, - UserRequestPayload, UserResponse, UserResponseChatCompletion, - UserResponseChatCompletionCancellation, - UserResponseChatCompletionEnd, - UserResponsePayload, }; async fn init_app_state() -> Result<(AppState, tokio::task::JoinHandle<()>), Box> { @@ -264,10 +260,11 @@ async fn chat( } async fn handle_chat_socket(mut socket: WebSocket, state: AppState) { - 'outer: while let Some(msg) = socket.recv().await { - let msg = match msg { - Ok(msg) => msg, - Err(_) => break, + 'outer: loop { + let msg = match socket.recv().await { + Some(Ok(msg)) => msg, + Some(Err(_)) => break, + None => break, }; match msg { @@ -280,60 +277,62 @@ async fn handle_chat_socket(mut socket: WebSocket, state: AppState) { } }; - match request.payload { - UserRequestPayload::ChatCompletion(payload) => { - let messages = payload - .messages - .into_iter() - .map(|msg| MessageInChat { - role: msg.role, - content: msg.content, - }) - .collect(); + let payload = match request { + UserRequest::ChatCompletion(payload) => payload, + UserRequest::ChatCompletionCancellation => { + eprintln!("[chat] protocol error: unexpected cancellation without active request"); + break; + } + }; - let mut response_rx = match state.dedicated_ai.send_chat_completion(messages) { - Ok(rx) => rx, - Err(err) => { - eprintln!("[chat] failed to send request: {err:#}"); - let response = UserResponse { - request_id: request.request_id, - payload: UserResponsePayload::ChatCompletionCancellation( - UserResponseChatCompletionCancellation, - ), - }; - if let Ok(bytes) = response.to_bytes() { - let _ = socket.send(Message::Binary(bytes)).await; - } - continue; - } - }; + let messages = payload + .messages + .into_iter() + .map(|msg| MessageInChat { + role: msg.role, + content: msg.content, + }) + .collect(); - while let Some(piece) = response_rx.recv().await { - let (payload, should_break) = match piece { + let (mut response_rx, request_id) = match state.dedicated_ai. + send_chat_completion(messages) + { + Ok(rx) => rx, + Err(err) => { + eprintln!("[chat] failed to send request: {err:#}"); + let response = UserResponse::ChatCompletionCancellation; + // todo: make to_bytes nofail + if let Ok(bytes) = response.to_bytes() { + let _ = socket.send(Message::Binary(bytes)).await; + } + continue; + } + }; + + loop { + tokio::select! { + piece = response_rx.recv() => { + let piece = match piece { + Some(piece) => piece, + None => break, + }; + let (response, should_break) = match piece { MessagePiece::Piece(MessagePiecePayload(text)) => ( - UserResponsePayload::ChatCompletion(UserResponseChatCompletion { + UserResponse::ChatCompletion(UserResponseChatCompletion { piece: text, }), false, ), MessagePiece::End => ( - UserResponsePayload::ChatCompletionEnd( - UserResponseChatCompletionEnd, - ), + UserResponse::ChatCompletionEnd, true, ), MessagePiece::Cancelled | MessagePiece::DedicatedServerDisconnected => ( - UserResponsePayload::ChatCompletionCancellation( - UserResponseChatCompletionCancellation, - ), + UserResponse::ChatCompletionCancellation, true, ), }; - let response = UserResponse { - request_id: request.request_id, - payload, - }; let bytes = match response.to_bytes() { Ok(bytes) => bytes, Err(err) => { @@ -348,24 +347,57 @@ async fn handle_chat_socket(mut socket: WebSocket, state: AppState) { break; } } - } - UserRequestPayload::ChatCompletionCancellation(_) => { - let response = UserResponse { - request_id: request.request_id, - payload: UserResponsePayload::ChatCompletionCancellation( - UserResponseChatCompletionCancellation, - ), - }; - if let Ok(bytes) = response.to_bytes() { - if socket.send(Message::Binary(bytes)).await.is_err() { - break 'outer; + msg = socket.recv() => { + let msg = match msg { + Some(Ok(msg)) => msg, + Some(Err(_)) => break 'outer, + None => break 'outer, + }; + + match msg { + Message::Binary(bytes) => { + let request = match UserRequest::from_bytes(&bytes) { + Ok(request) => request, + Err(err) => { + eprintln!("[chat] failed to decode request: {err:#}"); + break 'outer; + } + }; + + match request { + UserRequest::ChatCompletionCancellation => { + if let Err(err) = state + .dedicated_ai + .send_chat_completion_cancellation(request_id) + { + eprintln!("[chat] failed to send cancellation: {err:#}"); + break 'outer; + } + } + UserRequest::ChatCompletion(_) => { + eprintln!("[chat] protocol error: chat completion while receiving"); + break 'outer; + } + } + } + Message::Close(_) => break 'outer, + _ => { + eprintln!("[chat] protocol error: unexpected non-binary message"); + break 'outer; + } } } } } } - Message::Close(_) => break, - _ => {} + Message::Close(_) => { + println!(" [debug] websocket closed"); + break; + } + _ => { + eprintln!("[chat] protocol error: unexpected non-binary message"); + break; + } } } }