We communicate from client to dedicated ai server

This commit is contained in:
Андреев Григорий 2026-03-28 18:00:22 +03:00
parent c8bc79320e
commit 6a66cde0d0
18 changed files with 718 additions and 162 deletions

10
Cargo.lock generated
View File

@ -567,11 +567,20 @@ dependencies = [
name = "frontend"
version = "0.1.0"
dependencies = [
"frontend_protocol",
"js-sys",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "frontend_protocol"
version = "0.1.0"
dependencies = [
"anyhow",
"deku",
]
[[package]]
name = "funty"
version = "2.0.0"
@ -2563,6 +2572,7 @@ dependencies = [
"blake2b_simd",
"deku",
"dryoc",
"frontend_protocol",
"rand 0.8.5",
"reqwest",
"serde",

View File

@ -1,5 +1,6 @@
[workspace]
members = [
"frontend",
"frontend_protocol",
"website",
]

View File

@ -32,6 +32,9 @@ class PendingChatCompletionRecord:
was_cancelled: bool = False
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
def mark_cancelled(self) -> None:
with self._lock:
self.was_cancelled = True

View File

@ -9,6 +9,7 @@ crate-type = ["cdylib"]
[dependencies]
wasm-bindgen = "0.2"
js-sys = "0.3"
frontend_protocol = { path = "../frontend_protocol" }
[dependencies.web-sys]
version = "0.3"
@ -22,6 +23,7 @@ features = [
"WheelEvent",
"Location",
"WebSocket",
"BinaryType",
"MessageEvent",
"ErrorEvent",
"console",

View File

@ -1,7 +1,22 @@
use std::cell::RefCell;
use std::rc::Rc;
use wasm_bindgen::prelude::*;
use wasm_bindgen::JsCast;
use web_sys::{console, window, ErrorEvent, Event, MessageEvent, WebSocket};
use web_sys::{console, window, BinaryType, ErrorEvent, Event, MessageEvent, WebSocket};
use js_sys::{ArrayBuffer, Uint8Array};
use frontend_protocol::{
ChatMessage,
DekuBytes,
UserChatCompletionRequest,
UserRequest,
UserRequestPayload,
UserResponse,
UserResponsePayload,
};
thread_local! {
static WS_HANDLE: RefCell<Option<Rc<WebSocket>>> = RefCell::new(None);
}
#[wasm_bindgen]
pub fn init_chat() -> Result<(), JsValue> {
@ -13,24 +28,84 @@ pub fn init_chat() -> Result<(), JsValue> {
let ws_url = format!("{scheme}://{host}/chat");
let ws = Rc::new(WebSocket::new(&ws_url)?);
ws.set_binary_type(BinaryType::Arraybuffer);
console::log_1(&format!("[ws] connecting to {ws_url}").into());
let ws_for_open = ws.clone();
let onopen = Closure::<dyn FnMut(Event)>::wrap(Box::new(move |_| {
let _ = ws_for_open.send_with_str(
r#"this is the first message first chat
too bad it is in lower case one one one"#);
let _ = ws_for_open.send_with_str(
"And this is message two of chat request 2. \
Too bad it isn't processed two two two");
console::log_1(&"[ws] connected".into());
let requests = [
UserRequest {
request_id: 1,
payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![
ChatMessage {
role: "user".to_string(),
content: r#"this is the first message first chat
too bad it is in lower case one one one"#
.to_string(),
},
])),
},
UserRequest {
request_id: 2,
payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![
ChatMessage {
role: "user".to_string(),
content: "And this is message two of chat request 2. \
Too bad it isn't processed two two two"
.to_string(),
},
])),
},
];
for request in requests {
match request.to_bytes() {
Ok(bytes) => {
console::log_1(&format!("[ws] sending request_id={} bytes={}", request.request_id, bytes.len()).into());
let _ = ws_for_open.send_with_u8_array(&bytes);
}
Err(err) => {
console::error_1(&format!("[ws] encode error: {err:#}").into());
}
}
}
}));
ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
onopen.forget();
let onmessage = Closure::<dyn FnMut(MessageEvent)>::wrap(Box::new(move |event: MessageEvent| {
if let Some(text) = event.data().as_string() {
console::log_1(&format!("[ws] {text}").into());
} else {
console::log_1(&"[ws] non-text message".into());
let data = event.data();
if let Some(text) = data.as_string() {
console::log_1(&format!("[ws] unexpected text frame: {text}").into());
return;
}
if !data.is_instance_of::<ArrayBuffer>() {
console::log_1(&"[ws] unexpected non-binary frame".into());
return;
}
let data = Uint8Array::new(&data);
let bytes = data.to_vec();
console::log_1(&format!("[ws] received bytes={}", bytes.len()).into());
let response = match UserResponse::from_bytes(&bytes) {
Ok(response) => response,
Err(err) => {
console::error_1(&format!("[ws] decode error: {err:#} (bytes={})", bytes.len()).into());
return;
}
};
match response.payload {
UserResponsePayload::ChatCompletion(payload) => {
console::log_1(&format!("[ws] request_id={} piece={}", response.request_id, payload.piece).into());
}
UserResponsePayload::ChatCompletionCancellation(_) => {
console::log_1(&format!("[ws] request_id={} [cancel]", response.request_id).into());
}
UserResponsePayload::ChatCompletionEnd(_) => {
console::log_1(&format!("[ws] request_id={} [end]", response.request_id).into());
}
}
}));
ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
@ -42,6 +117,6 @@ pub fn init_chat() -> Result<(), JsValue> {
ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
onerror.forget();
let _ = ws.clone();
WS_HANDLE.with(|slot| *slot.borrow_mut() = Some(ws.clone()));
Ok(())
}

View File

@ -0,0 +1,12 @@
[package]
name = "frontend_protocol"
version = "0.1.0"
edition = "2021"
[features]
default = ["std"]
std = ["deku/std"]
[dependencies]
deku = { version = "0.20", default-features = false, features = ["alloc", "bits"] }
anyhow = "1.0"

View File

@ -0,0 +1,110 @@
pub mod utils;
use deku::prelude::*;
pub use self::utils::{
DekuBytes,
read_pascal_string,
read_vec_u32,
write_pascal_string,
write_vec_u32,
};
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct ChatMessage {
#[deku(
reader = "read_pascal_string(deku::reader)",
writer = "write_pascal_string(deku::writer, &self.role)"
)]
pub role: String,
#[deku(
reader = "read_pascal_string(deku::reader)",
writer = "write_pascal_string(deku::writer, &self.content)"
)]
pub content: String,
}
#[deku::deku_derive(DekuRead, DekuWrite)]
#[derive(Debug, Clone, PartialEq)]
pub struct UserChatCompletionRequest {
#[deku(
reader = "read_vec_u32(deku::reader)",
writer = "write_vec_u32(deku::writer, &self.messages)"
)]
pub messages: Vec<ChatMessage>,
}
impl UserChatCompletionRequest {
pub fn new(messages: Vec<ChatMessage>) -> Self {
Self { messages }
}
}
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct UserChatCompletionCancellationRequest;
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
#[deku(id_type = "u8")]
#[repr(u8)]
pub enum UserRequestPayload {
#[deku(id = "0")]
ChatCompletion(UserChatCompletionRequest),
#[deku(id = "1")]
ChatCompletionCancellation(UserChatCompletionCancellationRequest),
}
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct UserRequest {
#[deku(endian = "little")]
pub request_id: u64,
pub payload: UserRequestPayload,
}
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct UserResponseChatCompletion {
#[deku(
reader = "read_pascal_string(deku::reader)",
writer = "write_pascal_string(deku::writer, &self.piece)"
)]
pub piece: String,
}
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct UserResponseChatCompletionCancellation;
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct UserResponseChatCompletionEnd;
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
#[deku(id_type = "u8")]
#[repr(u8)]
pub enum UserResponsePayload {
#[deku(id = "0")]
ChatCompletion(UserResponseChatCompletion),
#[deku(id = "1")]
ChatCompletionCancellation(UserResponseChatCompletionCancellation),
#[deku(id = "2")]
ChatCompletionEnd(UserResponseChatCompletionEnd),
}
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct UserResponse {
#[deku(endian = "little")]
pub request_id: u64,
pub payload: UserResponsePayload,
}
impl DekuBytes for ChatMessage {}
impl DekuBytes for UserChatCompletionRequest {}
impl DekuBytes for UserChatCompletionCancellationRequest {}
impl DekuBytes for UserRequestPayload {}
impl DekuBytes for UserRequest {}
impl DekuBytes for UserResponseChatCompletion {}
impl DekuBytes for UserResponseChatCompletionCancellation {}
impl DekuBytes for UserResponseChatCompletionEnd {}
impl DekuBytes for UserResponsePayload {}
impl DekuBytes for UserResponse {}

View File

@ -0,0 +1,84 @@
use anyhow::{Context, Result, bail};
use deku::ctx::{Endian, Order};
use deku::{DekuError, DekuWriter};
use deku::prelude::{Reader, Writer};
use deku::prelude::*;
pub trait DekuBytes: Sized + Clone + DekuContainerWrite + for<'a> DekuContainerRead<'a> {
fn to_bytes(&self) -> Result<Vec<u8>> {
let type_name = std::any::type_name::<Self>();
DekuContainerWrite::to_bytes(self)
.context(format!("failed to encode {type_name}"))
}
fn from_bytes(bytes: &[u8]) -> Result<Self> {
let type_name = std::any::type_name::<Self>();
let ((rest, bit_offset), value) =
<Self as DekuContainerRead>::from_bytes((bytes, 0))
.context(format!("failed to decode {type_name}"))?;
if !rest.is_empty() || bit_offset != 0 {
bail!("trailing bytes after {type_name}");
}
Ok(value)
}
}
pub fn read_pascal_string<R>(reader: &mut Reader<R>) -> core::result::Result<String, DekuError>
where
R: deku::no_std_io::Read + deku::no_std_io::Seek,
{
let byte_len = u32::from_reader_with_ctx(reader, Endian::Little)? as usize;
let mut bytes = vec![0u8; byte_len];
reader.read_bytes(byte_len, &mut bytes, Order::Msb0)?;
String::from_utf8(bytes).map_err(|_| DekuError::Parse("invalid utf-8".into()))
}
pub fn write_pascal_string<W>(
writer: &mut Writer<W>,
value: &String,
) -> core::result::Result<(), DekuError>
where
W: deku::no_std_io::Write + deku::no_std_io::Seek,
{
let bytes = value.as_bytes();
let byte_len =
u32::try_from(bytes.len()).map_err(|_| DekuError::Parse("string too large".into()))?;
byte_len.to_writer(writer, Endian::Little)?;
writer.write_bytes(bytes)?;
Ok(())
}
pub fn read_vec_u32<R, T>(reader: &mut Reader<R>) -> core::result::Result<Vec<T>, DekuError>
where
R: deku::no_std_io::Read + deku::no_std_io::Seek,
for<'a> T: DekuReader<'a, ()>,
{
let len = u32::from_reader_with_ctx(reader, Endian::Little)? as usize;
let mut items = Vec::with_capacity(len);
for _ in 0..len {
let item = T::from_reader_with_ctx(reader, ())?;
items.push(item);
}
Ok(items)
}
pub fn write_vec_u32<W, T>(
writer: &mut Writer<W>,
value: &Vec<T>,
) -> core::result::Result<(), DekuError>
where
W: deku::no_std_io::Write + deku::no_std_io::Seek,
T: DekuWriter<()>,
{
let len = u32::try_from(value.len())
.map_err(|_| DekuError::Parse("vector too large".into()))?;
len.to_writer(writer, Endian::Little)?;
for item in value {
item.to_writer(writer, ())?;
}
Ok(())
}

1
secret-config.toml Normal file
View File

@ -0,0 +1 @@
dedicated_ai_server_secret="change-me"

View File

@ -6,7 +6,7 @@ edition = "2021"
[dependencies]
rand = "0.8"
axum = { version = "0.7", features = ["multipart", "ws"] }
tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "net"] }
tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "net", "sync", "time"] }
tower-http = { version = "0.5", features = ["fs"] }
tera = "1.19.1"
reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false }
@ -19,3 +19,4 @@ anyhow = "1.0"
blake2b_simd = "1.0"
dryoc = "0.7"
deku = "0.20"
frontend_protocol = { path = "../frontend_protocol" }

View File

@ -2,6 +2,22 @@ use serde::Deserialize;
use std::path::Path;
const DEFAULT_CONFIG_PATH: &str = "config.toml";
const DEFAULT_SECRET_CONFIG_PATH: &str = "secret-config.toml";
#[derive(Debug, Clone, Deserialize)]
struct PublicConfig {
pub postgres_host: String,
pub pg_database: String,
pub postgres_user: String,
pub file_storage: String,
pub dedicated_ai_server_address: String,
pub dedicated_ai_server_port: u16,
}
#[derive(Debug, Clone, Deserialize)]
struct SecretConfig {
pub dedicated_ai_server_secret: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct GeneralServiceConfig {
@ -9,13 +25,28 @@ pub struct GeneralServiceConfig {
pub pg_database: String,
pub postgres_user: String,
pub file_storage: String,
pub dedicated_ai_server_address: String,
pub dedicated_ai_server_port: u16,
pub dedicated_ai_server_secret: String,
}
pub type ConfigResult<T> = Result<T, Box<dyn std::error::Error>>;
pub fn load_config(path: impl AsRef<Path>) -> ConfigResult<GeneralServiceConfig> {
let raw = std::fs::read_to_string(path.as_ref())?;
let config: GeneralServiceConfig = toml::from_str(&raw)?;
pub fn load_config(config_path: impl AsRef<Path>, secret_config_path: impl AsRef<Path>) -> ConfigResult<GeneralServiceConfig> {
let raw_config = std::fs::read_to_string(config_path.as_ref())?;
let public_config: PublicConfig = toml::from_str(&raw_config)?;
let raw_secret = std::fs::read_to_string(secret_config_path.as_ref())?;
let secret_config: SecretConfig = toml::from_str(&raw_secret)?;
let config = GeneralServiceConfig {
postgres_host: public_config.postgres_host,
pg_database: public_config.pg_database,
postgres_user: public_config.postgres_user,
file_storage: public_config.file_storage,
dedicated_ai_server_address: public_config.dedicated_ai_server_address,
dedicated_ai_server_port: public_config.dedicated_ai_server_port,
dedicated_ai_server_secret: secret_config.dedicated_ai_server_secret,
};
if config.pg_database.trim().is_empty() {
return Err(std::io::Error::new(
@ -41,9 +72,17 @@ pub fn load_config(path: impl AsRef<Path>) -> ConfigResult<GeneralServiceConfig>
.into());
}
if config.dedicated_ai_server_secret.trim().is_empty() {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"config dedicated_ai_server_secret is empty",
)
.into());
}
Ok(config)
}
pub fn load_config_default() -> ConfigResult<GeneralServiceConfig> {
load_config(DEFAULT_CONFIG_PATH)
load_config(DEFAULT_CONFIG_PATH, DEFAULT_SECRET_CONFIG_PATH)
}

View File

@ -3,11 +3,11 @@ use tokio::net::{TcpListener, TcpStream};
use anyhow::Result;
use crate::dedicated_ai_server::api::{
ChatCompletionRequest,
DekuBytes,
MessageInChat,
Request,
RequestPayload,
};
use frontend_protocol::DekuBytes;
use crate::dedicated_ai_server::talking::{SecretStreamSocket, wrap_connection_socket};
use crate::dedicated_ai_server::talking::{ProtocolError, FrameCallback};

View File

@ -1,39 +1,12 @@
use anyhow::{Context, Result, bail};
use deku::prelude::*;
use super::marshalling_utils::{
read_bool_u8,
use frontend_protocol::{
DekuBytes,
read_pascal_string,
write_bool_u8,
read_vec_u32,
write_pascal_string,
write_vec_u32,
};
pub trait DekuBytes: Sized + Clone + DekuContainerWrite + for<'a> DekuContainerRead<'a> {
fn pre_encode(&mut self) -> Result<()> {
Ok(())
}
fn to_bytes(&self) -> Result<Vec<u8>> {
let mut value = self.clone();
value.pre_encode()?;
let type_name = std::any::type_name::<Self>();
DekuContainerWrite::to_bytes(&value)
.context(format!("failed to encode {type_name}"))
}
fn from_bytes(bytes: &[u8]) -> Result<Self> {
let type_name = std::any::type_name::<Self>();
let ((rest, bit_offset), value) =
<Self as DekuContainerRead>::from_bytes((bytes, 0))
.context(format!("failed to decode {type_name}"))?;
if !rest.is_empty() || bit_offset != 0 {
bail!("trailing bytes after {type_name}");
}
Ok(value)
}
}
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct MessageInChat {
#[deku(
@ -49,24 +22,20 @@ pub struct MessageInChat {
pub content: String,
}
/* Deku is a joke. A pathetic excuse for a marshalling library. But it's okay */
#[deku::deku_derive(DekuRead, DekuWrite)]
#[derive(Debug, Clone, PartialEq)]
pub struct ChatCompletionRequest {
#[deku(endian = "little", update = "self.messages.len() as u32")]
message_count: u32,
#[deku(count = "*message_count as usize")]
#[deku(
reader = "read_vec_u32(deku::reader)",
writer = "write_vec_u32(deku::writer, &self.messages)"
)]
pub messages: Vec<MessageInChat>,
}
impl ChatCompletionRequest {
pub fn new(messages: Vec<MessageInChat>) -> Self {
Self {
message_count: messages.len() as u32,
messages,
}
Self { messages }
}
}
@ -127,37 +96,12 @@ pub struct Response {
impl DekuBytes for MessageInChat {}
impl DekuBytes for ChatCompletionRequest {
fn pre_encode(&mut self) -> Result<()> {
self.update()
.context("failed to update ChatCompletionRequest before encoding")?;
Ok(())
}
}
impl DekuBytes for ChatCompletionRequest {}
impl DekuBytes for ChatCompletionCancellationRequest {}
impl DekuBytes for RequestPayload {
fn pre_encode(&mut self) -> Result<()> {
if let RequestPayload::ChatCompletion(batch) = self {
batch
.update()
.context("failed to update ChatCompletionRequest before encoding")?;
}
Ok(())
}
}
impl DekuBytes for Request {
fn pre_encode(&mut self) -> Result<()> {
if let RequestPayload::ChatCompletion(batch) = &mut self.payload {
batch
.update()
.context("failed to update ChatCompletionRequest before encoding")?;
}
Ok(())
}
}
impl DekuBytes for RequestPayload {}
impl DekuBytes for Request {}
impl DekuBytes for ResponseChatCompletion {}
impl DekuBytes for ResponseChatCompletionCancellation {}

View File

@ -0,0 +1,205 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use anyhow::{Context, Result};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio::time::{sleep, Duration};
use frontend_protocol::DekuBytes;
use super::api::{
ChatCompletionRequest,
MessageInChat,
Request,
RequestPayload,
Response,
ResponsePayload,
};
use super::talking::{SecretStreamSocket, wrap_connection_socket};
#[derive(Debug, Clone)]
pub struct MessagePiecePayload(pub String);
#[derive(Debug, Clone)]
pub enum MessagePiece {
Piece(MessagePiecePayload), End, Cancelled, DedicatedServerDisconnected
}
#[derive(Debug)]
struct PendingChatCompletionRecord {
response_tx: mpsc::UnboundedSender<MessagePiece>,
}
#[derive(Debug)]
struct RequestWorkItem {
request: Request,
response_tx: mpsc::UnboundedSender<MessagePiece>,
}
#[derive(Debug)]
pub struct DedicatedAiServerConnection {
address: String,
port: u16,
secret: String,
request_tx: mpsc::UnboundedSender<RequestWorkItem>,
next_request_id: AtomicU64,
}
pub fn connect_to_dedicated_ai_server(address: String, port: u16, secret: String)
-> (DedicatedAiServerConnection, JoinHandle<()>) {
let (request_tx, request_rx) = mpsc::unbounded_channel();
let worker_address = address.clone();
let worker_secret = secret.clone();
let join_hand = tokio::spawn(async move {
connection_worker_loop(worker_address, port, worker_secret, request_rx).await;
});
let conn_bridge = DedicatedAiServerConnection {
address,
port,
secret,
request_tx,
next_request_id: AtomicU64::new(1),
};
(conn_bridge, join_hand)
}
impl DedicatedAiServerConnection {
pub fn send_chat_completion(
&self,
messages: Vec<MessageInChat>,
) -> Result<mpsc::UnboundedReceiver<MessagePiece>> {
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
let request = Request {
request_id,
payload: RequestPayload::ChatCompletion(ChatCompletionRequest::new(messages)),
};
let (response_tx, response_rx) = mpsc::unbounded_channel();
self.request_tx
.send(RequestWorkItem { request, response_tx })
.map_err(|err| anyhow::anyhow!("failed to enqueue request: {err}"))?;
Ok(response_rx)
}
}
async fn connection_worker_loop(
address: String,
port: u16,
secret: String,
mut request_rx: mpsc::UnboundedReceiver<RequestWorkItem>,
) {
// todo : fix a lot of errors
loop {
let socket = match TcpStream::connect((address.as_str(), port)).await {
Ok(socket) => socket,
Err(err) => {
eprintln!(
"[dedicated-ai] failed to connect to {}:{}: {err}",
address, port
);
sleep(Duration::from_secs(1)).await;
continue;
}
};
let mut transport = match wrap_connection_socket(socket, &secret).await {
Ok(transport) => transport,
Err(err) => {
eprintln!("[dedicated-ai] failed to wrap connection: {err:#}");
sleep(Duration::from_secs(1)).await;
continue;
}
};
let mut pending: HashMap<u64, PendingChatCompletionRecord> = HashMap::new();
let result = run_connected_loop(&mut transport, &mut request_rx, &mut pending).await;
if let Err(err) = result {
eprintln!("[dedicated-ai] connection error: {err:#}");
}
cancel_all_pending(&mut pending);
let _ = transport.close(true).await;
sleep(Duration::from_secs(1)).await;
}
}
async fn run_connected_loop(
transport: &mut SecretStreamSocket,
request_rx: &mut mpsc::UnboundedReceiver<RequestWorkItem>,
pending: &mut HashMap<u64, PendingChatCompletionRecord>,
) -> Result<()> {
loop {
tokio::select! {
frame = transport.recv_frame() => {
let frame = frame.context("failed to read frame")?;
let frame = match frame {
Some(bytes) => bytes,
None => return Err(anyhow::anyhow!("connection closed by peer")),
};
handle_response_frame(frame, pending)?;
}
maybe_item = request_rx.recv() => {
let item = match maybe_item {
Some(item) => item,
None => return Err(anyhow::anyhow!("request channel closed")), // idk why this would happen
};
let request_id = item.request.request_id;
pending.insert(
request_id,
PendingChatCompletionRecord { response_tx: item.response_tx },
);
let payload = item.request.to_bytes().context("failed to encode request")?;
transport.send_frame(&payload).await.context("failed to send request")?;
}
}
}
}
fn handle_response_frame(
frame: Vec<u8>,
pending: &mut HashMap<u64, PendingChatCompletionRecord>,
) -> Result<()> {
let response = Response::from_bytes(&frame).context("failed to decode response")?;
let request_id = response.request_id;
let record = pending.get(&request_id);
if record.is_none() {
eprintln!("[dedicated-ai] response for unknown request_id={request_id}");
return Ok(());
}
match response.payload {
ResponsePayload::ChatCompletion(payload) => {
let send_failed = match pending.get(&request_id) {
Some(record) => record
.response_tx
.send(MessagePiece::Piece( MessagePiecePayload(payload.piece) ))
.is_err(),
None => false,
};
if send_failed {
pending.remove(&request_id);
}
}
ResponsePayload::ChatCompletionEnd(_) => {
if let Some(record) = pending.remove(&request_id) {
let _ = record.response_tx.send(MessagePiece::End);
}
}
ResponsePayload::ChatCompletionCancellation(_) => {
if let Some(record) = pending.remove(&request_id) {
let _ = record.response_tx.send(MessagePiece::Cancelled);
}
}
}
Ok(())
}
fn cancel_all_pending(pending: &mut HashMap<u64, PendingChatCompletionRecord>) {
for (_request_id, record) in pending.drain() {
let _ = record.response_tx.send(MessagePiece::DedicatedServerDisconnected);
}
}

View File

@ -1,56 +0,0 @@
use deku::ctx::{Endian, Order};
use deku::{DekuError, DekuWriter};
use deku::prelude::{Reader, Writer};
use deku::prelude::*;
pub fn read_pascal_string<R>(reader: &mut Reader<R>) -> core::result::Result<String, DekuError>
where
R: deku::no_std_io::Read + deku::no_std_io::Seek,
{
let byte_len = u32::from_reader_with_ctx(reader, Endian::Little)? as usize;
let mut bytes = vec![0u8; byte_len];
reader.read_bytes(byte_len, &mut bytes, Order::Msb0)?;
String::from_utf8(bytes).map_err(|err| DekuError::Parse(err.to_string().into()))
}
pub fn write_pascal_string<W>(
writer: &mut Writer<W>,
value: &String,
) -> core::result::Result<(), DekuError>
where
W: deku::no_std_io::Write + deku::no_std_io::Seek,
{
let bytes = value.as_bytes();
let byte_len =
u32::try_from(bytes.len()).map_err(|_| DekuError::Parse("string too large".into()))?;
byte_len.to_writer(writer, Endian::Little)?;
writer.write_bytes(bytes)?;
Ok(())
}
pub fn read_bool_u8<R>(reader: &mut Reader<R>) -> core::result::Result<bool, DekuError>
where
R: deku::no_std_io::Read + deku::no_std_io::Seek,
{
let value = u8::from_reader_with_ctx(reader, Endian::Little)?;
match value {
0 => Ok(false),
1 => Ok(true),
_ => Err(DekuError::Parse("invalid bool value".into())),
}
}
pub fn write_bool_u8<W>(
writer: &mut Writer<W>,
value: &bool,
) -> core::result::Result<(), DekuError>
where
W: deku::no_std_io::Write + deku::no_std_io::Seek,
{
let encoded: u8 = if *value { 1 } else { 0 };
encoded.to_writer(writer, Endian::Little)?;
Ok(())
}

View File

@ -1,4 +1,4 @@
pub mod api;
pub mod connection;
pub mod talking;
pub mod TEST;
mod marshalling_utils;

View File

@ -17,24 +17,38 @@ use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
use serde::Deserialize;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::OnceLock;
use tera::{Tera};
use tokio::time::{sleep, Duration};
use tower_http::services::{ServeDir, ServeFile};
use samplers::random_junk::random_string;
use config::{ConfigResult, GeneralServiceConfig, load_config, load_config_default};
use db::{DbResult, connect_db, init_database};
use config::load_config_default;
use db::{connect_db, init_database};
use web_file_uploads::{get_file, upload_get, upload_post};
use web_app_state::{AppState, AppStateInner, AuthenticatedUserId};
use dedicated_ai_server::connection::connect_to_dedicated_ai_server;
use dedicated_ai_server::api::MessageInChat;
use crate::dedicated_ai_server::connection::{MessagePiece, MessagePiecePayload};
use frontend_protocol::{
DekuBytes,
UserRequest,
UserRequestPayload,
UserResponse,
UserResponseChatCompletion,
UserResponseChatCompletionCancellation,
UserResponseChatCompletionEnd,
UserResponsePayload,
};
async fn init_app_state() -> Result<AppState, Box<dyn std::error::Error>> {
async fn init_app_state() -> Result<(AppState, tokio::task::JoinHandle<()>), Box<dyn std::error::Error>> {
let config = load_config_default()?;
let db = connect_db(&config).await?;
let tera = Tera::new("frontend/pages/**/*.html")?;
Ok(std::sync::Arc::new(AppStateInner { config, db, tera }))
let (dedicated_ai, dedicated_ai_task_handler) = connect_to_dedicated_ai_server(
config.dedicated_ai_server_address.clone(),
config.dedicated_ai_server_port,
config.dedicated_ai_server_secret.clone(),
);
Ok((std::sync::Arc::new(AppStateInner { config, db, tera, dedicated_ai }), dedicated_ai_task_handler))
}
enum PasscodeAuthenticationResult{
@ -235,7 +249,9 @@ async fn chat(
user: AuthenticatedUserId,
) -> Response {
if let Some(ws) = ws {
return ws.on_upgrade(move |socket| handle_chat_socket(socket)).into_response();
return ws
.on_upgrade(move |socket| handle_chat_socket(socket, state.clone()))
.into_response();
}
let mut ctx = tera::Context::new();
@ -247,19 +263,105 @@ async fn chat(
Html(body).into_response()
}
async fn handle_chat_socket(mut socket: WebSocket) {
while let Some(msg) = socket.recv().await {
async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
'outer: while let Some(msg) = socket.recv().await {
let msg = match msg {
Ok(msg) => msg,
Err(_) => break,
};
match msg {
Message::Text(text) => {
sleep(Duration::from_secs(2)).await;
let upper = text.to_uppercase();
if socket.send(Message::Text(upper)).await.is_err() {
break;
Message::Binary(bytes) => {
let request = match UserRequest::from_bytes(&bytes) {
Ok(request) => request,
Err(err) => {
eprintln!("[chat] failed to decode request: {err:#}");
continue;
}
};
match request.payload {
UserRequestPayload::ChatCompletion(payload) => {
let messages = payload
.messages
.into_iter()
.map(|msg| MessageInChat {
role: msg.role,
content: msg.content,
})
.collect();
let mut response_rx = match state.dedicated_ai.send_chat_completion(messages) {
Ok(rx) => rx,
Err(err) => {
eprintln!("[chat] failed to send request: {err:#}");
let response = UserResponse {
request_id: request.request_id,
payload: UserResponsePayload::ChatCompletionCancellation(
UserResponseChatCompletionCancellation,
),
};
if let Ok(bytes) = response.to_bytes() {
let _ = socket.send(Message::Binary(bytes)).await;
}
continue;
}
};
while let Some(piece) = response_rx.recv().await {
let (payload, should_break) = match piece {
MessagePiece::Piece(MessagePiecePayload(text)) => (
UserResponsePayload::ChatCompletion(UserResponseChatCompletion {
piece: text,
}),
false,
),
MessagePiece::End => (
UserResponsePayload::ChatCompletionEnd(
UserResponseChatCompletionEnd,
),
true,
),
MessagePiece::Cancelled | MessagePiece::DedicatedServerDisconnected => (
UserResponsePayload::ChatCompletionCancellation(
UserResponseChatCompletionCancellation,
),
true,
),
};
let response = UserResponse {
request_id: request.request_id,
payload,
};
let bytes = match response.to_bytes() {
Ok(bytes) => bytes,
Err(err) => {
eprintln!("[chat] failed to encode response: {err:#}");
break 'outer;
}
};
if socket.send(Message::Binary(bytes)).await.is_err() {
break 'outer;
}
if should_break {
break;
}
}
}
UserRequestPayload::ChatCompletionCancellation(_) => {
let response = UserResponse {
request_id: request.request_id,
payload: UserResponsePayload::ChatCompletionCancellation(
UserResponseChatCompletionCancellation,
),
};
if let Ok(bytes) = response.to_bytes() {
if socket.send(Message::Binary(bytes)).await.is_err() {
break 'outer;
}
}
}
}
}
Message::Close(_) => break,
@ -270,7 +372,7 @@ async fn handle_chat_socket(mut socket: WebSocket) {
pub async fn run_server() -> Result<(), Box<dyn std::error::Error>> {
let state = init_app_state().await.expect("lol");
let (state, mut dedicated_ai_task) = init_app_state().await.expect("lol");
let app = Router::new()
.route("/login", get(login_get))
.route("/login", post(login_post))
@ -292,7 +394,28 @@ pub async fn run_server() -> Result<(), Box<dyn std::error::Error>> {
let listener = tokio::net::TcpListener::bind(addr)
.await
.expect("bind failed");
axum::serve(listener, app).await?;
let mut server_task = tokio::spawn(async move { axum::serve(listener, app).await });
tokio::select! {
res = &mut server_task => {
dedicated_ai_task.abort();
res??;
}
res = &mut dedicated_ai_task => {
server_task.abort();
res.map_err(|err| {
std::io::Error::new(
std::io::ErrorKind::Other,
format!("dedicated ai task failed: {err}"),
)
})?;
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"dedicated ai task ended unexpectedly",
)
.into());
}
}
Ok(())
}

View File

@ -1,10 +1,12 @@
use tera::Tera;
use crate::GeneralServiceConfig;
use crate::config::GeneralServiceConfig;
use crate::dedicated_ai_server::connection::DedicatedAiServerConnection;
pub struct AppStateInner {
pub config: GeneralServiceConfig,
pub db: tokio_postgres::Client,
pub tera: Tera,
pub dedicated_ai: DedicatedAiServerConnection,
}
pub type AppState = std::sync::Arc<AppStateInner>;