Chat is working somewhat
This commit is contained in:
parent
6a66cde0d0
commit
336177941f
@ -16,7 +16,7 @@ SERVER_PORT = 9000
|
|||||||
SHARED_SECRET = "change-me"
|
SHARED_SECRET = "change-me"
|
||||||
|
|
||||||
|
|
||||||
PROCESS_DELAY_SECONDS = 1.5
|
PROCESS_DELAY_SECONDS = 0.4
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -90,7 +90,9 @@ def worker_loop(
|
|||||||
record.response_queue.put_nowait,
|
record.response_queue.put_nowait,
|
||||||
MessagePiece(piece=piece, is_end=False, is_cancel=False),
|
MessagePiece(piece=piece, is_end=False, is_cancel=False),
|
||||||
)
|
)
|
||||||
|
print("[debug] got a new piece")
|
||||||
if record.is_cancelled():
|
if record.is_cancelled():
|
||||||
|
print("[debug] record was cancelled")
|
||||||
cancelled = True
|
cancelled = True
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@ -16,8 +16,11 @@ version = "0.3"
|
|||||||
features = [
|
features = [
|
||||||
"Window",
|
"Window",
|
||||||
"Document",
|
"Document",
|
||||||
|
"Element",
|
||||||
"Event",
|
"Event",
|
||||||
"EventTarget",
|
"EventTarget",
|
||||||
|
"HtmlElement",
|
||||||
|
"HtmlTextAreaElement",
|
||||||
"KeyboardEvent",
|
"KeyboardEvent",
|
||||||
"MouseEvent",
|
"MouseEvent",
|
||||||
"WheelEvent",
|
"WheelEvent",
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
<!doctype html>
|
<!doctype html>
|
||||||
<html lang="en">
|
<html lang="en" class="chat-html">
|
||||||
<head>
|
<head>
|
||||||
<meta charset="utf-8" />
|
<meta charset="utf-8" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||||
@ -14,10 +14,28 @@
|
|||||||
init_chat();
|
init_chat();
|
||||||
</script>
|
</script>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body class="chat-body">
|
||||||
<main class="welcome-card">
|
<main class="chat-shell">
|
||||||
<h1>Chat</h1>
|
<section class="chat-panel">
|
||||||
<p>WebSocket connection initialized. Open the console to see messages.</p>
|
<div class="chat-header">
|
||||||
|
<h1 class="chat-title">Chat</h1>
|
||||||
|
<div class="chat-subtitle">Live session</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="chat-container">
|
||||||
|
<div class="chat-messages" id="chat-messages"></div>
|
||||||
|
|
||||||
|
<div class="chat-input-area">
|
||||||
|
<textarea
|
||||||
|
class="chat-input"
|
||||||
|
id="chat-input"
|
||||||
|
placeholder="Write a message. Enter to send, Shift+Enter for a new line."
|
||||||
|
rows="4"
|
||||||
|
></textarea>
|
||||||
|
<div class="chat-input-hint">Enter to send. Shift+Enter for a new line.</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</section>
|
||||||
</main>
|
</main>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
|
|||||||
@ -1,79 +1,284 @@
|
|||||||
use std::cell::RefCell;
|
use std::cell::RefCell;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
|
|
||||||
|
use js_sys::{ArrayBuffer, Uint8Array};
|
||||||
use wasm_bindgen::prelude::*;
|
use wasm_bindgen::prelude::*;
|
||||||
use wasm_bindgen::JsCast;
|
use wasm_bindgen::JsCast;
|
||||||
use web_sys::{console, window, BinaryType, ErrorEvent, Event, MessageEvent, WebSocket};
|
use web_sys::{
|
||||||
use js_sys::{ArrayBuffer, Uint8Array};
|
console,
|
||||||
|
window,
|
||||||
|
BinaryType,
|
||||||
|
Document,
|
||||||
|
Element,
|
||||||
|
ErrorEvent,
|
||||||
|
Event,
|
||||||
|
HtmlElement,
|
||||||
|
HtmlTextAreaElement,
|
||||||
|
KeyboardEvent,
|
||||||
|
MessageEvent,
|
||||||
|
WebSocket,
|
||||||
|
};
|
||||||
|
|
||||||
use frontend_protocol::{
|
use frontend_protocol::{
|
||||||
ChatMessage,
|
UserChatMessage,
|
||||||
DekuBytes,
|
DekuBytes,
|
||||||
UserChatCompletionRequest,
|
UserChatCompletionRequest,
|
||||||
UserRequest,
|
UserRequest,
|
||||||
UserRequestPayload,
|
|
||||||
UserResponse,
|
UserResponse,
|
||||||
UserResponsePayload,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
thread_local! {
|
thread_local! {
|
||||||
static WS_HANDLE: RefCell<Option<Rc<WebSocket>>> = RefCell::new(None);
|
static APP_HANDLE: RefCell<Option<Rc<AppState>>> = RefCell::new(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ChatState {
|
||||||
|
messages: Vec<UserChatMessage>,
|
||||||
|
message_nodes: Vec<MessageNode>,
|
||||||
|
is_receiving: bool,
|
||||||
|
active_assistant_index: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MessageNode {
|
||||||
|
content: Element,
|
||||||
|
status: Option<Element>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
enum ChatStatus {
|
||||||
|
Pending,
|
||||||
|
Cancelled,
|
||||||
|
Hidden,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AppState {
|
||||||
|
ws: WebSocket,
|
||||||
|
document: Document,
|
||||||
|
messages_container: Element,
|
||||||
|
input: HtmlTextAreaElement,
|
||||||
|
state: RefCell<ChatState>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn append_message(
|
||||||
|
document: &Document,
|
||||||
|
container: &Element,
|
||||||
|
role: &str,
|
||||||
|
content: &str,
|
||||||
|
status: Option<ChatStatus>,
|
||||||
|
) -> Result<MessageNode, JsValue> {
|
||||||
|
let wrapper = document.create_element("div")?;
|
||||||
|
wrapper.set_class_name(&format!("chat-message chat-message--{}", role));
|
||||||
|
|
||||||
|
let role_el = document.create_element("div")?;
|
||||||
|
role_el.set_class_name("chat-message__role");
|
||||||
|
role_el.set_text_content(Some(role));
|
||||||
|
|
||||||
|
let content_el = document.create_element("div")?;
|
||||||
|
content_el.set_class_name("chat-message__content");
|
||||||
|
content_el.set_text_content(Some(content));
|
||||||
|
|
||||||
|
wrapper.append_child(&role_el)?;
|
||||||
|
wrapper.append_child(&content_el)?;
|
||||||
|
|
||||||
|
let status_el = if let Some(status) = status {
|
||||||
|
let status_el = document.create_element("div")?;
|
||||||
|
apply_status(&status_el, status);
|
||||||
|
wrapper.append_child(&status_el)?;
|
||||||
|
Some(status_el)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
container.append_child(&wrapper)?;
|
||||||
|
|
||||||
|
Ok(MessageNode {
|
||||||
|
content: content_el,
|
||||||
|
status: status_el,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_message_content(node: &Element, content: &str) {
|
||||||
|
node.set_text_content(Some(content));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_status(node: &Element, status: ChatStatus) {
|
||||||
|
match status {
|
||||||
|
ChatStatus::Pending => {
|
||||||
|
node.set_class_name("chat-message__status chat-message__status--pending");
|
||||||
|
node.set_text_content(Some("..."));
|
||||||
|
}
|
||||||
|
ChatStatus::Cancelled => {
|
||||||
|
node.set_class_name("chat-message__status chat-message__status--cancelled");
|
||||||
|
node.set_text_content(Some("[canceled]"));
|
||||||
|
}
|
||||||
|
ChatStatus::Hidden => {
|
||||||
|
node.set_class_name("chat-message__status chat-message__status--hidden");
|
||||||
|
node.set_text_content(Some(""));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scroll_to_bottom(container: &Element) {
|
||||||
|
if let Some(element) = container.dyn_ref::<HtmlElement>() {
|
||||||
|
let height = element.scroll_height();
|
||||||
|
element.set_scroll_top(height);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[wasm_bindgen]
|
#[wasm_bindgen]
|
||||||
pub fn init_chat() -> Result<(), JsValue> {
|
pub fn init_chat() -> Result<(), JsValue> {
|
||||||
let window = window().ok_or_else(|| JsValue::from_str("Missing window"))?;
|
let window = window().ok_or_else(|| JsValue::from_str("Missing window"))?;
|
||||||
|
let document = window.document().ok_or_else(|| JsValue::from_str("Missing document"))?;
|
||||||
let location = window.location();
|
let location = window.location();
|
||||||
let host = location.host()?;
|
let host = location.host()?;
|
||||||
let protocol = location.protocol()?;
|
let protocol = location.protocol()?;
|
||||||
let scheme = if protocol == "https:" { "wss" } else { "ws" };
|
let scheme = if protocol == "https:" { "wss" } else { "ws" };
|
||||||
let ws_url = format!("{scheme}://{host}/chat");
|
let ws_url = format!("{scheme}://{host}/chat");
|
||||||
|
|
||||||
let ws = Rc::new(WebSocket::new(&ws_url)?);
|
let messages_container = document
|
||||||
|
.get_element_by_id("chat-messages")
|
||||||
|
.ok_or_else(|| JsValue::from_str("Missing chat-messages element"))?;
|
||||||
|
let input_el: HtmlTextAreaElement = document
|
||||||
|
.get_element_by_id("chat-input")
|
||||||
|
.ok_or_else(|| JsValue::from_str("Missing chat-input element"))?
|
||||||
|
.dyn_into()?;
|
||||||
|
|
||||||
|
let ws = WebSocket::new(&ws_url)?;
|
||||||
ws.set_binary_type(BinaryType::Arraybuffer);
|
ws.set_binary_type(BinaryType::Arraybuffer);
|
||||||
console::log_1(&format!("[ws] connecting to {ws_url}").into());
|
console::log_1(&format!("[ws] connecting to {ws_url}").into());
|
||||||
|
|
||||||
let ws_for_open = ws.clone();
|
let app = Rc::new(AppState {
|
||||||
|
ws,
|
||||||
|
document,
|
||||||
|
messages_container,
|
||||||
|
input: input_el,
|
||||||
|
state: RefCell::new(ChatState {
|
||||||
|
messages: Vec::new(),
|
||||||
|
message_nodes: Vec::new(),
|
||||||
|
is_receiving: false,
|
||||||
|
active_assistant_index: None,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
let onopen = Closure::<dyn FnMut(Event)>::wrap(Box::new(move |_| {
|
let onopen = Closure::<dyn FnMut(Event)>::wrap(Box::new(move |_| {
|
||||||
console::log_1(&"[ws] connected".into());
|
console::log_1(&"[ws] connected".into());
|
||||||
let requests = [
|
|
||||||
UserRequest {
|
|
||||||
request_id: 1,
|
|
||||||
payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![
|
|
||||||
ChatMessage {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: r#"this is the first message first chat
|
|
||||||
too bad it is in lower case one one one"#
|
|
||||||
.to_string(),
|
|
||||||
},
|
|
||||||
])),
|
|
||||||
},
|
|
||||||
UserRequest {
|
|
||||||
request_id: 2,
|
|
||||||
payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![
|
|
||||||
ChatMessage {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: "And this is message two of chat request 2. \
|
|
||||||
Too bad it isn't processed two two two"
|
|
||||||
.to_string(),
|
|
||||||
},
|
|
||||||
])),
|
|
||||||
},
|
|
||||||
];
|
|
||||||
|
|
||||||
for request in requests {
|
|
||||||
match request.to_bytes() {
|
|
||||||
Ok(bytes) => {
|
|
||||||
console::log_1(&format!("[ws] sending request_id={} bytes={}", request.request_id, bytes.len()).into());
|
|
||||||
let _ = ws_for_open.send_with_u8_array(&bytes);
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
console::error_1(&format!("[ws] encode error: {err:#}").into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}));
|
}));
|
||||||
ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
|
app.ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
|
||||||
onopen.forget();
|
onopen.forget();
|
||||||
|
|
||||||
|
let app_for_keydown = app.clone();
|
||||||
|
let onkeydown = Closure::<dyn FnMut(KeyboardEvent)>::wrap(Box::new(move |event: KeyboardEvent| {
|
||||||
|
if event.ctrl_key() && (event.key() == "c" || event.key() == "C") {
|
||||||
|
let state = app_for_keydown.state.borrow();
|
||||||
|
if state.is_receiving {
|
||||||
|
event.prevent_default();
|
||||||
|
if app_for_keydown.ws.ready_state() != WebSocket::OPEN {
|
||||||
|
console::error_1(&"[ws] socket is not open".into());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let request = UserRequest::ChatCompletionCancellation;
|
||||||
|
match request.to_bytes() {
|
||||||
|
Ok(bytes) => {
|
||||||
|
if let Err(err) = app_for_keydown.ws.send_with_u8_array(&bytes) {
|
||||||
|
console::error_1(&format!("[ws] cancel send error: {:?}", err).into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
console::error_1(&format!("[ws] cancel encode error: {err:#}").into());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.key() != "Enter" || event.shift_key() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
event.prevent_default();
|
||||||
|
|
||||||
|
let mut state = app_for_keydown.state.borrow_mut();
|
||||||
|
if state.is_receiving {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let raw = app_for_keydown.input.value();
|
||||||
|
let trimmed = raw.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if app_for_keydown.ws.ready_state() != WebSocket::OPEN {
|
||||||
|
console::error_1(&"[ws] socket is not open".into());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let user_content = trimmed.to_string();
|
||||||
|
app_for_keydown.input.set_value("");
|
||||||
|
|
||||||
|
let user_node = match append_message(
|
||||||
|
&app_for_keydown.document,
|
||||||
|
&app_for_keydown.messages_container,
|
||||||
|
"user",
|
||||||
|
&user_content,
|
||||||
|
None,
|
||||||
|
) {
|
||||||
|
Ok(node) => node,
|
||||||
|
Err(err) => {
|
||||||
|
console::error_1(&format!("[ui] failed to append user message: {:?}", err).into());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
state.messages.push(UserChatMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: user_content,
|
||||||
|
});
|
||||||
|
state.message_nodes.push(user_node);
|
||||||
|
scroll_to_bottom(&app_for_keydown.messages_container);
|
||||||
|
|
||||||
|
let history = state.messages.clone();
|
||||||
|
let request = UserRequest::ChatCompletion(UserChatCompletionRequest::new(history));
|
||||||
|
|
||||||
|
let bytes = match request.to_bytes() {
|
||||||
|
Ok(bytes) => bytes,
|
||||||
|
Err(err) => {
|
||||||
|
console::error_1(&format!("[ws] encode error: {err:#}").into());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Err(err) = app_for_keydown.ws.send_with_u8_array(&bytes) {
|
||||||
|
console::error_1(&format!("[ws] send error: {:?}", err).into());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let assistant_node = match append_message(
|
||||||
|
&app_for_keydown.document,
|
||||||
|
&app_for_keydown.messages_container,
|
||||||
|
"assistant",
|
||||||
|
"",
|
||||||
|
Some(ChatStatus::Pending),
|
||||||
|
) {
|
||||||
|
Ok(node) => node,
|
||||||
|
Err(err) => {
|
||||||
|
console::error_1(&format!("[ui] failed to append assistant message: {:?}", err).into());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
state.messages.push(UserChatMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: String::new(),
|
||||||
|
});
|
||||||
|
state.message_nodes.push(assistant_node);
|
||||||
|
state.active_assistant_index = Some(state.messages.len() - 1);
|
||||||
|
state.is_receiving = true;
|
||||||
|
scroll_to_bottom(&app_for_keydown.messages_container);
|
||||||
|
}));
|
||||||
|
app.input.set_onkeydown(Some(onkeydown.as_ref().unchecked_ref()));
|
||||||
|
onkeydown.forget();
|
||||||
|
|
||||||
|
let app_for_message = app.clone();
|
||||||
let onmessage = Closure::<dyn FnMut(MessageEvent)>::wrap(Box::new(move |event: MessageEvent| {
|
let onmessage = Closure::<dyn FnMut(MessageEvent)>::wrap(Box::new(move |event: MessageEvent| {
|
||||||
let data = event.data();
|
let data = event.data();
|
||||||
if let Some(text) = data.as_string() {
|
if let Some(text) = data.as_string() {
|
||||||
@ -87,36 +292,73 @@ Too bad it isn't processed two two two"
|
|||||||
|
|
||||||
let data = Uint8Array::new(&data);
|
let data = Uint8Array::new(&data);
|
||||||
let bytes = data.to_vec();
|
let bytes = data.to_vec();
|
||||||
console::log_1(&format!("[ws] received bytes={}", bytes.len()).into());
|
|
||||||
let response = match UserResponse::from_bytes(&bytes) {
|
let response = match UserResponse::from_bytes(&bytes) {
|
||||||
Ok(response) => response,
|
Ok(response) => response,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
console::error_1(&format!("[ws] decode error: {err:#} (bytes={})", bytes.len()).into());
|
console::error_1(&format!("[ws] decode error: {err:#}").into());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match response.payload {
|
let mut state = app_for_message.state.borrow_mut();
|
||||||
UserResponsePayload::ChatCompletion(payload) => {
|
let assistant_index = match state.active_assistant_index {
|
||||||
console::log_1(&format!("[ws] request_id={} piece={}", response.request_id, payload.piece).into());
|
Some(index) => index,
|
||||||
|
None => {
|
||||||
|
console::log_1(&"[ws] missing assistant index".into());
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
UserResponsePayload::ChatCompletionCancellation(_) => {
|
};
|
||||||
console::log_1(&format!("[ws] request_id={} [cancel]", response.request_id).into());
|
|
||||||
|
match response {
|
||||||
|
UserResponse::ChatCompletion(completion) => {
|
||||||
|
if let Some(message) = state.messages.get_mut(assistant_index) {
|
||||||
|
message.content.push_str(&completion.piece);
|
||||||
|
}
|
||||||
|
if let Some(node) = state.message_nodes.get(assistant_index) {
|
||||||
|
if let Some(message) = state.messages.get(assistant_index) {
|
||||||
|
set_message_content(&node.content, &message.content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scroll_to_bottom(&app_for_message.messages_container);
|
||||||
}
|
}
|
||||||
UserResponsePayload::ChatCompletionEnd(_) => {
|
UserResponse::ChatCompletionEnd => {
|
||||||
console::log_1(&format!("[ws] request_id={} [end]", response.request_id).into());
|
state.is_receiving = false;
|
||||||
|
if let Some(node) = state
|
||||||
|
.active_assistant_index
|
||||||
|
.and_then(|index| state.message_nodes.get(index))
|
||||||
|
{
|
||||||
|
if let Some(status_node) = node.status.as_ref() {
|
||||||
|
apply_status(status_node, ChatStatus::Hidden);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.active_assistant_index = None;
|
||||||
|
let _ = app_for_message.input.focus();
|
||||||
|
}
|
||||||
|
UserResponse::ChatCompletionCancellation => {
|
||||||
|
state.is_receiving = false;
|
||||||
|
if let Some(node) = state
|
||||||
|
.active_assistant_index
|
||||||
|
.and_then(|index| state.message_nodes.get(index))
|
||||||
|
{
|
||||||
|
if let Some(status_node) = node.status.as_ref() {
|
||||||
|
apply_status(status_node, ChatStatus::Cancelled);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state.active_assistant_index = None;
|
||||||
|
let _ = app_for_message.input.focus();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}));
|
}));
|
||||||
ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
|
app.ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
|
||||||
onmessage.forget();
|
onmessage.forget();
|
||||||
|
|
||||||
let onerror = Closure::<dyn FnMut(ErrorEvent)>::wrap(Box::new(move |event: ErrorEvent| {
|
let onerror = Closure::<dyn FnMut(ErrorEvent)>::wrap(Box::new(move |event: ErrorEvent| {
|
||||||
console::error_1(&format!("[ws] error: {}", event.message()).into());
|
console::error_1(&format!("[ws] error: {}", event.message()).into());
|
||||||
}));
|
}));
|
||||||
ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
|
app.ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
|
||||||
onerror.forget();
|
onerror.forget();
|
||||||
|
|
||||||
WS_HANDLE.with(|slot| *slot.borrow_mut() = Some(ws.clone()));
|
APP_HANDLE.with(|slot| *slot.borrow_mut() = Some(app.clone()));
|
||||||
|
let _ = app.input.focus();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@ -7,6 +7,158 @@ body {
|
|||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.chat-html,
|
||||||
|
.chat-body {
|
||||||
|
height: 100%;
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-body {
|
||||||
|
margin: 0;
|
||||||
|
background: #ffffff;
|
||||||
|
color: #1b1b1b;
|
||||||
|
font-family: "Fira Sans", "Space Grotesk", "Montserrat", sans-serif;
|
||||||
|
font-size: 14pt;
|
||||||
|
display: block;
|
||||||
|
overflow: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-shell {
|
||||||
|
min-height: 100%;
|
||||||
|
padding: 32px 24px;
|
||||||
|
box-sizing: border-box;
|
||||||
|
display: flex;
|
||||||
|
justify-content: center;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-panel {
|
||||||
|
width: min(100%, max(70%, 1000pt));
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 16px;
|
||||||
|
min-height: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-header {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-title {
|
||||||
|
margin: 0;
|
||||||
|
font-size: 20pt;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-subtitle {
|
||||||
|
font-size: 10pt;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 1px;
|
||||||
|
color: #6b6b6b;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-container {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 16px;
|
||||||
|
min-height: 0;
|
||||||
|
flex: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-messages {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 12px;
|
||||||
|
overflow-y: auto;
|
||||||
|
padding-right: 4px;
|
||||||
|
flex: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-message {
|
||||||
|
width: 100%;
|
||||||
|
border-radius: 14px;
|
||||||
|
padding: 12px 14px;
|
||||||
|
box-sizing: border-box;
|
||||||
|
border: 1px solid #e2e2e2;
|
||||||
|
background: #fafafa;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 6px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-message--user {
|
||||||
|
background: #eef4ff;
|
||||||
|
border-color: #d5e2ff;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-message--assistant {
|
||||||
|
background: #f7f2ff;
|
||||||
|
border-color: #e5d7ff;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-message--system {
|
||||||
|
background: #f8f8f8;
|
||||||
|
border-color: #dddddd;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-message__role {
|
||||||
|
font-size: 9pt;
|
||||||
|
text-transform: uppercase;
|
||||||
|
letter-spacing: 1px;
|
||||||
|
color: #6b6b6b;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-message__content {
|
||||||
|
white-space: pre-wrap;
|
||||||
|
line-height: 1.45;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-message__status {
|
||||||
|
font-size: 9pt;
|
||||||
|
color: #9aa0a6;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-message__status--pending {
|
||||||
|
color: #9aa0a6;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-message__status--cancelled {
|
||||||
|
color: #c0392b;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-message__status--hidden {
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-input-area {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 6px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-input {
|
||||||
|
width: 100%;
|
||||||
|
min-height: 120px;
|
||||||
|
padding: 12px 14px;
|
||||||
|
border-radius: 12px;
|
||||||
|
border: 1px solid #d4d4d4;
|
||||||
|
font-size: 14pt;
|
||||||
|
font-family: inherit;
|
||||||
|
resize: vertical;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-input:disabled {
|
||||||
|
background: #f2f2f2;
|
||||||
|
color: #8a8a8a;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-input-hint {
|
||||||
|
font-size: 9.5pt;
|
||||||
|
color: #6b6b6b;
|
||||||
|
}
|
||||||
|
|
||||||
body {
|
body {
|
||||||
display: flex;
|
display: flex;
|
||||||
align-items: center;
|
align-items: center;
|
||||||
|
|||||||
@ -10,7 +10,7 @@ pub use self::utils::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
pub struct ChatMessage {
|
pub struct UserChatMessage {
|
||||||
#[deku(
|
#[deku(
|
||||||
reader = "read_pascal_string(deku::reader)",
|
reader = "read_pascal_string(deku::reader)",
|
||||||
writer = "write_pascal_string(deku::writer, &self.role)"
|
writer = "write_pascal_string(deku::writer, &self.role)"
|
||||||
@ -31,33 +31,23 @@ pub struct UserChatCompletionRequest {
|
|||||||
reader = "read_vec_u32(deku::reader)",
|
reader = "read_vec_u32(deku::reader)",
|
||||||
writer = "write_vec_u32(deku::writer, &self.messages)"
|
writer = "write_vec_u32(deku::writer, &self.messages)"
|
||||||
)]
|
)]
|
||||||
pub messages: Vec<ChatMessage>,
|
pub messages: Vec<UserChatMessage>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UserChatCompletionRequest {
|
impl UserChatCompletionRequest {
|
||||||
pub fn new(messages: Vec<ChatMessage>) -> Self {
|
pub fn new(messages: Vec<UserChatMessage>) -> Self {
|
||||||
Self { messages }
|
Self { messages }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
|
||||||
pub struct UserChatCompletionCancellationRequest;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
#[deku(id_type = "u8")]
|
#[deku(id_type = "u8")]
|
||||||
#[repr(u8)]
|
#[repr(u8)]
|
||||||
pub enum UserRequestPayload {
|
pub enum UserRequest {
|
||||||
#[deku(id = "0")]
|
#[deku(id = "0")]
|
||||||
ChatCompletion(UserChatCompletionRequest),
|
ChatCompletion(UserChatCompletionRequest),
|
||||||
#[deku(id = "1")]
|
#[deku(id = "1")]
|
||||||
ChatCompletionCancellation(UserChatCompletionCancellationRequest),
|
ChatCompletionCancellation,
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
|
||||||
pub struct UserRequest {
|
|
||||||
#[deku(endian = "little")]
|
|
||||||
pub request_id: u64,
|
|
||||||
pub payload: UserRequestPayload,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
@ -69,42 +59,23 @@ pub struct UserResponseChatCompletion {
|
|||||||
pub piece: String,
|
pub piece: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
|
||||||
pub struct UserResponseChatCompletionCancellation;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
|
||||||
pub struct UserResponseChatCompletionEnd;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
#[deku(id_type = "u8")]
|
#[deku(id_type = "u8")]
|
||||||
#[repr(u8)]
|
#[repr(u8)]
|
||||||
pub enum UserResponsePayload {
|
pub enum UserResponse {
|
||||||
#[deku(id = "0")]
|
#[deku(id = "0")]
|
||||||
ChatCompletion(UserResponseChatCompletion),
|
ChatCompletion(UserResponseChatCompletion),
|
||||||
#[deku(id = "1")]
|
#[deku(id = "1")]
|
||||||
ChatCompletionCancellation(UserResponseChatCompletionCancellation),
|
ChatCompletionCancellation,
|
||||||
#[deku(id = "2")]
|
#[deku(id = "2")]
|
||||||
ChatCompletionEnd(UserResponseChatCompletionEnd),
|
ChatCompletionEnd,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
impl DekuBytes for UserChatMessage {}
|
||||||
pub struct UserResponse {
|
|
||||||
#[deku(endian = "little")]
|
|
||||||
pub request_id: u64,
|
|
||||||
pub payload: UserResponsePayload,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl DekuBytes for ChatMessage {}
|
|
||||||
|
|
||||||
impl DekuBytes for UserChatCompletionRequest {}
|
impl DekuBytes for UserChatCompletionRequest {}
|
||||||
|
|
||||||
impl DekuBytes for UserChatCompletionCancellationRequest {}
|
|
||||||
|
|
||||||
impl DekuBytes for UserRequestPayload {}
|
|
||||||
impl DekuBytes for UserRequest {}
|
impl DekuBytes for UserRequest {}
|
||||||
|
|
||||||
impl DekuBytes for UserResponseChatCompletion {}
|
impl DekuBytes for UserResponseChatCompletion {}
|
||||||
impl DekuBytes for UserResponseChatCompletionCancellation {}
|
|
||||||
impl DekuBytes for UserResponseChatCompletionEnd {}
|
|
||||||
impl DekuBytes for UserResponsePayload {}
|
|
||||||
impl DekuBytes for UserResponse {}
|
impl DekuBytes for UserResponse {}
|
||||||
|
|||||||
@ -1,7 +0,0 @@
|
|||||||
use website::dedicated_ai_server::TEST::main_E;
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
main_E().await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
@ -1,7 +0,0 @@
|
|||||||
use website::dedicated_ai_server::TEST::main_F;
|
|
||||||
|
|
||||||
#[tokio::main]
|
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
||||||
main_F().await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
@ -1,117 +0,0 @@
|
|||||||
use tokio::net::{TcpListener, TcpStream};
|
|
||||||
|
|
||||||
use anyhow::Result;
|
|
||||||
use crate::dedicated_ai_server::api::{
|
|
||||||
ChatCompletionRequest,
|
|
||||||
MessageInChat,
|
|
||||||
Request,
|
|
||||||
RequestPayload,
|
|
||||||
};
|
|
||||||
use frontend_protocol::DekuBytes;
|
|
||||||
use crate::dedicated_ai_server::talking::{SecretStreamSocket, wrap_connection_socket};
|
|
||||||
use crate::dedicated_ai_server::talking::{ProtocolError, FrameCallback};
|
|
||||||
|
|
||||||
|
|
||||||
const SERVER_HOST: &str = "127.0.0.1";
|
|
||||||
const SERVER_PORT: u16 = 9000;
|
|
||||||
const SHARED_SECRET: &str = "change-me";
|
|
||||||
|
|
||||||
/* client =========== aka F.py */
|
|
||||||
|
|
||||||
|
|
||||||
async fn connect_to_server(host: &str, port: u16, shared_secret: &str) -> Result<SecretStreamSocket> {
|
|
||||||
let socket = TcpStream::connect((host, port)).await?;
|
|
||||||
wrap_connection_socket(socket, shared_secret).await
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async fn send_message(connection: &mut SecretStreamSocket, request_id: u64, role: &str, content: &str) -> Result<()> {
|
|
||||||
let batch = ChatCompletionRequest::new(vec![MessageInChat {
|
|
||||||
role: role.to_string(),
|
|
||||||
content: content.to_string(),
|
|
||||||
}]);
|
|
||||||
let request = Request {
|
|
||||||
request_id,
|
|
||||||
payload: RequestPayload::ChatCompletion(batch),
|
|
||||||
};
|
|
||||||
let payload = request.to_bytes()?;
|
|
||||||
connection.send_frame(&payload).await
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async fn close_connection(connection: &mut SecretStreamSocket) -> Result<()> {
|
|
||||||
connection.close(true).await
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
pub async fn main_F() -> Result<()> {
|
|
||||||
let mut connection = connect_to_server(SERVER_HOST, SERVER_PORT, SHARED_SECRET).await?;
|
|
||||||
send_message(&mut connection, 1, "user", "hello from Rust client").await?;
|
|
||||||
close_connection(&mut connection).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/* server========= aka E.py */
|
|
||||||
|
|
||||||
fn print_batch(peer: &str, request_id: u64, batch: &ChatCompletionRequest) {
|
|
||||||
println!(
|
|
||||||
"[packet] {} sent request_id={} with {} message(s)",
|
|
||||||
peer,
|
|
||||||
request_id,
|
|
||||||
batch.messages.len(),
|
|
||||||
);
|
|
||||||
for (index, message) in batch.messages.iter().enumerate() {
|
|
||||||
println!(
|
|
||||||
" [{}] role={:?} content={:?}",
|
|
||||||
index, message.role, message.content
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async fn handle_client(client_socket: TcpStream, peer: String, shared_secret: &str) -> Result<()> {
|
|
||||||
let peer_for_callback = peer.clone();
|
|
||||||
let on_frame: FrameCallback<'_> = Box::new(move |frame: Vec<u8>| {
|
|
||||||
let request = Request::from_bytes(&frame)?;
|
|
||||||
match request.payload {
|
|
||||||
RequestPayload::ChatCompletion(batch) => {
|
|
||||||
print_batch(&peer_for_callback, request.request_id, &batch);
|
|
||||||
}
|
|
||||||
RequestPayload::ChatCompletionCancellation(_) => {
|
|
||||||
println!(
|
|
||||||
"[packet] {} sent request_id={} cancel",
|
|
||||||
peer_for_callback,
|
|
||||||
request.request_id,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut transport = wrap_connection_socket(client_socket, shared_secret).await?;
|
|
||||||
println!("[connected] {}", peer);
|
|
||||||
let result = transport.run_receiving_loop(on_frame).await;
|
|
||||||
println!("[disconnected] {}", peer);
|
|
||||||
let _ = transport.close(true).await;
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
pub async fn main_E() -> Result<()> {
|
|
||||||
let listener = TcpListener::bind((SERVER_HOST, SERVER_PORT)).await?;
|
|
||||||
println!("[listening] {}:{}", SERVER_HOST, SERVER_PORT);
|
|
||||||
|
|
||||||
loop {
|
|
||||||
let (client_socket, addr) = listener.accept().await?;
|
|
||||||
let peer = addr.to_string();
|
|
||||||
|
|
||||||
if let Err(err) = handle_client(client_socket, peer.clone(), SHARED_SECRET).await {
|
|
||||||
if err.downcast_ref::<ProtocolError>().is_some() {
|
|
||||||
eprintln!("[protocol error] {}: {:#}", peer, err);
|
|
||||||
} else {
|
|
||||||
eprintln!("[error] {}: {:#}", peer, err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -40,9 +40,6 @@ impl ChatCompletionRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
|
||||||
pub struct ChatCompletionCancellationRequest;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
#[deku(id_type = "u8")]
|
#[deku(id_type = "u8")]
|
||||||
#[repr(u8)]
|
#[repr(u8)]
|
||||||
@ -50,7 +47,7 @@ pub enum RequestPayload {
|
|||||||
#[deku(id = "0")]
|
#[deku(id = "0")]
|
||||||
ChatCompletion(ChatCompletionRequest),
|
ChatCompletion(ChatCompletionRequest),
|
||||||
#[deku(id = "1")]
|
#[deku(id = "1")]
|
||||||
ChatCompletionCancellation(ChatCompletionCancellationRequest),
|
ChatCompletionCancellation,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
@ -69,12 +66,6 @@ pub struct ResponseChatCompletion {
|
|||||||
pub piece: String,
|
pub piece: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
|
||||||
pub struct ResponseChatCompletionCancellation;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
|
||||||
pub struct ResponseChatCompletionEnd;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
#[deku(id_type = "u8")]
|
#[deku(id_type = "u8")]
|
||||||
#[repr(u8)]
|
#[repr(u8)]
|
||||||
@ -82,9 +73,9 @@ pub enum ResponsePayload {
|
|||||||
#[deku(id = "0")]
|
#[deku(id = "0")]
|
||||||
ChatCompletion(ResponseChatCompletion),
|
ChatCompletion(ResponseChatCompletion),
|
||||||
#[deku(id = "1")]
|
#[deku(id = "1")]
|
||||||
ChatCompletionCancellation(ResponseChatCompletionCancellation),
|
ChatCompletionCancellation,
|
||||||
#[deku(id = "2")]
|
#[deku(id = "2")]
|
||||||
ChatCompletionEnd(ResponseChatCompletionEnd)
|
ChatCompletionEnd
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||||
@ -98,12 +89,9 @@ impl DekuBytes for MessageInChat {}
|
|||||||
|
|
||||||
impl DekuBytes for ChatCompletionRequest {}
|
impl DekuBytes for ChatCompletionRequest {}
|
||||||
|
|
||||||
impl DekuBytes for ChatCompletionCancellationRequest {}
|
|
||||||
|
|
||||||
impl DekuBytes for RequestPayload {}
|
impl DekuBytes for RequestPayload {}
|
||||||
impl DekuBytes for Request {}
|
impl DekuBytes for Request {}
|
||||||
|
|
||||||
impl DekuBytes for ResponseChatCompletion {}
|
impl DekuBytes for ResponseChatCompletion {}
|
||||||
impl DekuBytes for ResponseChatCompletionCancellation {}
|
|
||||||
impl DekuBytes for ResponsePayload {}
|
impl DekuBytes for ResponsePayload {}
|
||||||
impl DekuBytes for Response {}
|
impl DekuBytes for Response {}
|
||||||
|
|||||||
@ -32,9 +32,14 @@ struct PendingChatCompletionRecord {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct RequestWorkItem {
|
enum RequestWorkItem {
|
||||||
request: Request,
|
ChatCompletion {
|
||||||
response_tx: mpsc::UnboundedSender<MessagePiece>,
|
request: Request,
|
||||||
|
response_tx: mpsc::UnboundedSender<MessagePiece>,
|
||||||
|
},
|
||||||
|
ChatCompletionCancellation {
|
||||||
|
request_id: u64,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -70,7 +75,7 @@ impl DedicatedAiServerConnection {
|
|||||||
pub fn send_chat_completion(
|
pub fn send_chat_completion(
|
||||||
&self,
|
&self,
|
||||||
messages: Vec<MessageInChat>,
|
messages: Vec<MessageInChat>,
|
||||||
) -> Result<mpsc::UnboundedReceiver<MessagePiece>> {
|
) -> Result<(mpsc::UnboundedReceiver<MessagePiece>, u64)> {
|
||||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||||
let request = Request {
|
let request = Request {
|
||||||
request_id,
|
request_id,
|
||||||
@ -78,9 +83,16 @@ impl DedicatedAiServerConnection {
|
|||||||
};
|
};
|
||||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||||
self.request_tx
|
self.request_tx
|
||||||
.send(RequestWorkItem { request, response_tx })
|
.send(RequestWorkItem::ChatCompletion { request, response_tx })
|
||||||
.map_err(|err| anyhow::anyhow!("failed to enqueue request: {err}"))?;
|
.map_err(|err| anyhow::anyhow!("failed to enqueue request: {err}"))?;
|
||||||
Ok(response_rx)
|
Ok((response_rx, request_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn send_chat_completion_cancellation(&self, request_id: u64) -> Result<()> {
|
||||||
|
self.request_tx
|
||||||
|
.send(RequestWorkItem::ChatCompletionCancellation { request_id })
|
||||||
|
.map_err(|err| anyhow::anyhow!("failed to enqueue cancellation: {err}"))?;
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,14 +157,26 @@ async fn run_connected_loop(
|
|||||||
Some(item) => item,
|
Some(item) => item,
|
||||||
None => return Err(anyhow::anyhow!("request channel closed")), // idk why this would happen
|
None => return Err(anyhow::anyhow!("request channel closed")), // idk why this would happen
|
||||||
};
|
};
|
||||||
let request_id = item.request.request_id;
|
match item {
|
||||||
pending.insert(
|
RequestWorkItem::ChatCompletion { request, response_tx } => {
|
||||||
request_id,
|
let request_id = request.request_id;
|
||||||
PendingChatCompletionRecord { response_tx: item.response_tx },
|
pending.insert(
|
||||||
);
|
request_id,
|
||||||
|
PendingChatCompletionRecord { response_tx },
|
||||||
|
);
|
||||||
|
|
||||||
let payload = item.request.to_bytes().context("failed to encode request")?;
|
let payload = request.to_bytes().context("failed to encode request")?;
|
||||||
transport.send_frame(&payload).await.context("failed to send request")?;
|
transport.send_frame(&payload).await.context("failed to send request")?;
|
||||||
|
}
|
||||||
|
RequestWorkItem::ChatCompletionCancellation { request_id } => {
|
||||||
|
let request = Request {
|
||||||
|
request_id,
|
||||||
|
payload: RequestPayload::ChatCompletionCancellation,
|
||||||
|
};
|
||||||
|
let payload = request.to_bytes().context("failed to encode cancellation")?;
|
||||||
|
transport.send_frame(&payload).await.context("failed to send cancellation")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -183,12 +207,12 @@ fn handle_response_frame(
|
|||||||
pending.remove(&request_id);
|
pending.remove(&request_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ResponsePayload::ChatCompletionEnd(_) => {
|
ResponsePayload::ChatCompletionEnd => {
|
||||||
if let Some(record) = pending.remove(&request_id) {
|
if let Some(record) = pending.remove(&request_id) {
|
||||||
let _ = record.response_tx.send(MessagePiece::End);
|
let _ = record.response_tx.send(MessagePiece::End);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ResponsePayload::ChatCompletionCancellation(_) => {
|
ResponsePayload::ChatCompletionCancellation => {
|
||||||
if let Some(record) = pending.remove(&request_id) {
|
if let Some(record) = pending.remove(&request_id) {
|
||||||
let _ = record.response_tx.send(MessagePiece::Cancelled);
|
let _ = record.response_tx.send(MessagePiece::Cancelled);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,3 @@
|
|||||||
pub mod api;
|
pub mod api;
|
||||||
pub mod connection;
|
pub mod connection;
|
||||||
pub mod talking;
|
pub mod talking;
|
||||||
pub mod TEST;
|
|
||||||
|
|||||||
@ -31,12 +31,8 @@ use crate::dedicated_ai_server::connection::{MessagePiece, MessagePiecePayload};
|
|||||||
use frontend_protocol::{
|
use frontend_protocol::{
|
||||||
DekuBytes,
|
DekuBytes,
|
||||||
UserRequest,
|
UserRequest,
|
||||||
UserRequestPayload,
|
|
||||||
UserResponse,
|
UserResponse,
|
||||||
UserResponseChatCompletion,
|
UserResponseChatCompletion,
|
||||||
UserResponseChatCompletionCancellation,
|
|
||||||
UserResponseChatCompletionEnd,
|
|
||||||
UserResponsePayload,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
async fn init_app_state() -> Result<(AppState, tokio::task::JoinHandle<()>), Box<dyn std::error::Error>> {
|
async fn init_app_state() -> Result<(AppState, tokio::task::JoinHandle<()>), Box<dyn std::error::Error>> {
|
||||||
@ -264,10 +260,11 @@ async fn chat(
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
|
async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
|
||||||
'outer: while let Some(msg) = socket.recv().await {
|
'outer: loop {
|
||||||
let msg = match msg {
|
let msg = match socket.recv().await {
|
||||||
Ok(msg) => msg,
|
Some(Ok(msg)) => msg,
|
||||||
Err(_) => break,
|
Some(Err(_)) => break,
|
||||||
|
None => break,
|
||||||
};
|
};
|
||||||
|
|
||||||
match msg {
|
match msg {
|
||||||
@ -280,60 +277,62 @@ async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match request.payload {
|
let payload = match request {
|
||||||
UserRequestPayload::ChatCompletion(payload) => {
|
UserRequest::ChatCompletion(payload) => payload,
|
||||||
let messages = payload
|
UserRequest::ChatCompletionCancellation => {
|
||||||
.messages
|
eprintln!("[chat] protocol error: unexpected cancellation without active request");
|
||||||
.into_iter()
|
break;
|
||||||
.map(|msg| MessageInChat {
|
}
|
||||||
role: msg.role,
|
};
|
||||||
content: msg.content,
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let mut response_rx = match state.dedicated_ai.send_chat_completion(messages) {
|
let messages = payload
|
||||||
Ok(rx) => rx,
|
.messages
|
||||||
Err(err) => {
|
.into_iter()
|
||||||
eprintln!("[chat] failed to send request: {err:#}");
|
.map(|msg| MessageInChat {
|
||||||
let response = UserResponse {
|
role: msg.role,
|
||||||
request_id: request.request_id,
|
content: msg.content,
|
||||||
payload: UserResponsePayload::ChatCompletionCancellation(
|
})
|
||||||
UserResponseChatCompletionCancellation,
|
.collect();
|
||||||
),
|
|
||||||
};
|
|
||||||
if let Ok(bytes) = response.to_bytes() {
|
|
||||||
let _ = socket.send(Message::Binary(bytes)).await;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
while let Some(piece) = response_rx.recv().await {
|
let (mut response_rx, request_id) = match state.dedicated_ai.
|
||||||
let (payload, should_break) = match piece {
|
send_chat_completion(messages)
|
||||||
|
{
|
||||||
|
Ok(rx) => rx,
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("[chat] failed to send request: {err:#}");
|
||||||
|
let response = UserResponse::ChatCompletionCancellation;
|
||||||
|
// todo: make to_bytes nofail
|
||||||
|
if let Ok(bytes) = response.to_bytes() {
|
||||||
|
let _ = socket.send(Message::Binary(bytes)).await;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
piece = response_rx.recv() => {
|
||||||
|
let piece = match piece {
|
||||||
|
Some(piece) => piece,
|
||||||
|
None => break,
|
||||||
|
};
|
||||||
|
let (response, should_break) = match piece {
|
||||||
MessagePiece::Piece(MessagePiecePayload(text)) => (
|
MessagePiece::Piece(MessagePiecePayload(text)) => (
|
||||||
UserResponsePayload::ChatCompletion(UserResponseChatCompletion {
|
UserResponse::ChatCompletion(UserResponseChatCompletion {
|
||||||
piece: text,
|
piece: text,
|
||||||
}),
|
}),
|
||||||
false,
|
false,
|
||||||
),
|
),
|
||||||
MessagePiece::End => (
|
MessagePiece::End => (
|
||||||
UserResponsePayload::ChatCompletionEnd(
|
UserResponse::ChatCompletionEnd,
|
||||||
UserResponseChatCompletionEnd,
|
|
||||||
),
|
|
||||||
true,
|
true,
|
||||||
),
|
),
|
||||||
MessagePiece::Cancelled | MessagePiece::DedicatedServerDisconnected => (
|
MessagePiece::Cancelled | MessagePiece::DedicatedServerDisconnected => (
|
||||||
UserResponsePayload::ChatCompletionCancellation(
|
UserResponse::ChatCompletionCancellation,
|
||||||
UserResponseChatCompletionCancellation,
|
|
||||||
),
|
|
||||||
true,
|
true,
|
||||||
),
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = UserResponse {
|
|
||||||
request_id: request.request_id,
|
|
||||||
payload,
|
|
||||||
};
|
|
||||||
let bytes = match response.to_bytes() {
|
let bytes = match response.to_bytes() {
|
||||||
Ok(bytes) => bytes,
|
Ok(bytes) => bytes,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
@ -348,24 +347,57 @@ async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
msg = socket.recv() => {
|
||||||
UserRequestPayload::ChatCompletionCancellation(_) => {
|
let msg = match msg {
|
||||||
let response = UserResponse {
|
Some(Ok(msg)) => msg,
|
||||||
request_id: request.request_id,
|
Some(Err(_)) => break 'outer,
|
||||||
payload: UserResponsePayload::ChatCompletionCancellation(
|
None => break 'outer,
|
||||||
UserResponseChatCompletionCancellation,
|
};
|
||||||
),
|
|
||||||
};
|
match msg {
|
||||||
if let Ok(bytes) = response.to_bytes() {
|
Message::Binary(bytes) => {
|
||||||
if socket.send(Message::Binary(bytes)).await.is_err() {
|
let request = match UserRequest::from_bytes(&bytes) {
|
||||||
break 'outer;
|
Ok(request) => request,
|
||||||
|
Err(err) => {
|
||||||
|
eprintln!("[chat] failed to decode request: {err:#}");
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match request {
|
||||||
|
UserRequest::ChatCompletionCancellation => {
|
||||||
|
if let Err(err) = state
|
||||||
|
.dedicated_ai
|
||||||
|
.send_chat_completion_cancellation(request_id)
|
||||||
|
{
|
||||||
|
eprintln!("[chat] failed to send cancellation: {err:#}");
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
UserRequest::ChatCompletion(_) => {
|
||||||
|
eprintln!("[chat] protocol error: chat completion while receiving");
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Message::Close(_) => break 'outer,
|
||||||
|
_ => {
|
||||||
|
eprintln!("[chat] protocol error: unexpected non-binary message");
|
||||||
|
break 'outer;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Message::Close(_) => break,
|
Message::Close(_) => {
|
||||||
_ => {}
|
println!(" [debug] websocket closed");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
eprintln!("[chat] protocol error: unexpected non-binary message");
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user