We communicate from client to dedicated ai server
This commit is contained in:
parent
c8bc79320e
commit
6a66cde0d0
10
Cargo.lock
generated
10
Cargo.lock
generated
@ -567,11 +567,20 @@ dependencies = [
|
||||
name = "frontend"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"frontend_protocol",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "frontend_protocol"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"deku",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "funty"
|
||||
version = "2.0.0"
|
||||
@ -2563,6 +2572,7 @@ dependencies = [
|
||||
"blake2b_simd",
|
||||
"deku",
|
||||
"dryoc",
|
||||
"frontend_protocol",
|
||||
"rand 0.8.5",
|
||||
"reqwest",
|
||||
"serde",
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
[workspace]
|
||||
members = [
|
||||
"frontend",
|
||||
"frontend_protocol",
|
||||
"website",
|
||||
]
|
||||
@ -32,6 +32,9 @@ class PendingChatCompletionRecord:
|
||||
was_cancelled: bool = False
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock, repr=False)
|
||||
|
||||
|
||||
|
||||
|
||||
def mark_cancelled(self) -> None:
|
||||
with self._lock:
|
||||
self.was_cancelled = True
|
||||
|
||||
@ -9,6 +9,7 @@ crate-type = ["cdylib"]
|
||||
[dependencies]
|
||||
wasm-bindgen = "0.2"
|
||||
js-sys = "0.3"
|
||||
frontend_protocol = { path = "../frontend_protocol" }
|
||||
|
||||
[dependencies.web-sys]
|
||||
version = "0.3"
|
||||
@ -22,6 +23,7 @@ features = [
|
||||
"WheelEvent",
|
||||
"Location",
|
||||
"WebSocket",
|
||||
"BinaryType",
|
||||
"MessageEvent",
|
||||
"ErrorEvent",
|
||||
"console",
|
||||
|
||||
@ -1,7 +1,22 @@
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen::JsCast;
|
||||
use web_sys::{console, window, ErrorEvent, Event, MessageEvent, WebSocket};
|
||||
use web_sys::{console, window, BinaryType, ErrorEvent, Event, MessageEvent, WebSocket};
|
||||
use js_sys::{ArrayBuffer, Uint8Array};
|
||||
use frontend_protocol::{
|
||||
ChatMessage,
|
||||
DekuBytes,
|
||||
UserChatCompletionRequest,
|
||||
UserRequest,
|
||||
UserRequestPayload,
|
||||
UserResponse,
|
||||
UserResponsePayload,
|
||||
};
|
||||
|
||||
thread_local! {
|
||||
static WS_HANDLE: RefCell<Option<Rc<WebSocket>>> = RefCell::new(None);
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn init_chat() -> Result<(), JsValue> {
|
||||
@ -13,24 +28,84 @@ pub fn init_chat() -> Result<(), JsValue> {
|
||||
let ws_url = format!("{scheme}://{host}/chat");
|
||||
|
||||
let ws = Rc::new(WebSocket::new(&ws_url)?);
|
||||
ws.set_binary_type(BinaryType::Arraybuffer);
|
||||
console::log_1(&format!("[ws] connecting to {ws_url}").into());
|
||||
|
||||
let ws_for_open = ws.clone();
|
||||
let onopen = Closure::<dyn FnMut(Event)>::wrap(Box::new(move |_| {
|
||||
let _ = ws_for_open.send_with_str(
|
||||
r#"this is the first message first chat
|
||||
too bad it is in lower case one one one"#);
|
||||
let _ = ws_for_open.send_with_str(
|
||||
"And this is message two of chat request 2. \
|
||||
Too bad it isn't processed two two two");
|
||||
console::log_1(&"[ws] connected".into());
|
||||
let requests = [
|
||||
UserRequest {
|
||||
request_id: 1,
|
||||
payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: r#"this is the first message first chat
|
||||
too bad it is in lower case one one one"#
|
||||
.to_string(),
|
||||
},
|
||||
])),
|
||||
},
|
||||
UserRequest {
|
||||
request_id: 2,
|
||||
payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: "And this is message two of chat request 2. \
|
||||
Too bad it isn't processed two two two"
|
||||
.to_string(),
|
||||
},
|
||||
])),
|
||||
},
|
||||
];
|
||||
|
||||
for request in requests {
|
||||
match request.to_bytes() {
|
||||
Ok(bytes) => {
|
||||
console::log_1(&format!("[ws] sending request_id={} bytes={}", request.request_id, bytes.len()).into());
|
||||
let _ = ws_for_open.send_with_u8_array(&bytes);
|
||||
}
|
||||
Err(err) => {
|
||||
console::error_1(&format!("[ws] encode error: {err:#}").into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
|
||||
onopen.forget();
|
||||
|
||||
let onmessage = Closure::<dyn FnMut(MessageEvent)>::wrap(Box::new(move |event: MessageEvent| {
|
||||
if let Some(text) = event.data().as_string() {
|
||||
console::log_1(&format!("[ws] {text}").into());
|
||||
} else {
|
||||
console::log_1(&"[ws] non-text message".into());
|
||||
let data = event.data();
|
||||
if let Some(text) = data.as_string() {
|
||||
console::log_1(&format!("[ws] unexpected text frame: {text}").into());
|
||||
return;
|
||||
}
|
||||
if !data.is_instance_of::<ArrayBuffer>() {
|
||||
console::log_1(&"[ws] unexpected non-binary frame".into());
|
||||
return;
|
||||
}
|
||||
|
||||
let data = Uint8Array::new(&data);
|
||||
let bytes = data.to_vec();
|
||||
console::log_1(&format!("[ws] received bytes={}", bytes.len()).into());
|
||||
let response = match UserResponse::from_bytes(&bytes) {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
console::error_1(&format!("[ws] decode error: {err:#} (bytes={})", bytes.len()).into());
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match response.payload {
|
||||
UserResponsePayload::ChatCompletion(payload) => {
|
||||
console::log_1(&format!("[ws] request_id={} piece={}", response.request_id, payload.piece).into());
|
||||
}
|
||||
UserResponsePayload::ChatCompletionCancellation(_) => {
|
||||
console::log_1(&format!("[ws] request_id={} [cancel]", response.request_id).into());
|
||||
}
|
||||
UserResponsePayload::ChatCompletionEnd(_) => {
|
||||
console::log_1(&format!("[ws] request_id={} [end]", response.request_id).into());
|
||||
}
|
||||
}
|
||||
}));
|
||||
ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
|
||||
@ -42,6 +117,6 @@ pub fn init_chat() -> Result<(), JsValue> {
|
||||
ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
|
||||
onerror.forget();
|
||||
|
||||
let _ = ws.clone();
|
||||
WS_HANDLE.with(|slot| *slot.borrow_mut() = Some(ws.clone()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
12
frontend_protocol/Cargo.toml
Normal file
12
frontend_protocol/Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "frontend_protocol"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
std = ["deku/std"]
|
||||
|
||||
[dependencies]
|
||||
deku = { version = "0.20", default-features = false, features = ["alloc", "bits"] }
|
||||
anyhow = "1.0"
|
||||
110
frontend_protocol/src/lib.rs
Normal file
110
frontend_protocol/src/lib.rs
Normal file
@ -0,0 +1,110 @@
|
||||
pub mod utils;
|
||||
|
||||
use deku::prelude::*;
|
||||
pub use self::utils::{
|
||||
DekuBytes,
|
||||
read_pascal_string,
|
||||
read_vec_u32,
|
||||
write_pascal_string,
|
||||
write_vec_u32,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct ChatMessage {
|
||||
#[deku(
|
||||
reader = "read_pascal_string(deku::reader)",
|
||||
writer = "write_pascal_string(deku::writer, &self.role)"
|
||||
)]
|
||||
pub role: String,
|
||||
|
||||
#[deku(
|
||||
reader = "read_pascal_string(deku::reader)",
|
||||
writer = "write_pascal_string(deku::writer, &self.content)"
|
||||
)]
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[deku::deku_derive(DekuRead, DekuWrite)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct UserChatCompletionRequest {
|
||||
#[deku(
|
||||
reader = "read_vec_u32(deku::reader)",
|
||||
writer = "write_vec_u32(deku::writer, &self.messages)"
|
||||
)]
|
||||
pub messages: Vec<ChatMessage>,
|
||||
}
|
||||
|
||||
impl UserChatCompletionRequest {
|
||||
pub fn new(messages: Vec<ChatMessage>) -> Self {
|
||||
Self { messages }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct UserChatCompletionCancellationRequest;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
#[deku(id_type = "u8")]
|
||||
#[repr(u8)]
|
||||
pub enum UserRequestPayload {
|
||||
#[deku(id = "0")]
|
||||
ChatCompletion(UserChatCompletionRequest),
|
||||
#[deku(id = "1")]
|
||||
ChatCompletionCancellation(UserChatCompletionCancellationRequest),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct UserRequest {
|
||||
#[deku(endian = "little")]
|
||||
pub request_id: u64,
|
||||
pub payload: UserRequestPayload,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct UserResponseChatCompletion {
|
||||
#[deku(
|
||||
reader = "read_pascal_string(deku::reader)",
|
||||
writer = "write_pascal_string(deku::writer, &self.piece)"
|
||||
)]
|
||||
pub piece: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct UserResponseChatCompletionCancellation;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct UserResponseChatCompletionEnd;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
#[deku(id_type = "u8")]
|
||||
#[repr(u8)]
|
||||
pub enum UserResponsePayload {
|
||||
#[deku(id = "0")]
|
||||
ChatCompletion(UserResponseChatCompletion),
|
||||
#[deku(id = "1")]
|
||||
ChatCompletionCancellation(UserResponseChatCompletionCancellation),
|
||||
#[deku(id = "2")]
|
||||
ChatCompletionEnd(UserResponseChatCompletionEnd),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct UserResponse {
|
||||
#[deku(endian = "little")]
|
||||
pub request_id: u64,
|
||||
pub payload: UserResponsePayload,
|
||||
}
|
||||
|
||||
impl DekuBytes for ChatMessage {}
|
||||
|
||||
impl DekuBytes for UserChatCompletionRequest {}
|
||||
|
||||
impl DekuBytes for UserChatCompletionCancellationRequest {}
|
||||
|
||||
impl DekuBytes for UserRequestPayload {}
|
||||
impl DekuBytes for UserRequest {}
|
||||
|
||||
impl DekuBytes for UserResponseChatCompletion {}
|
||||
impl DekuBytes for UserResponseChatCompletionCancellation {}
|
||||
impl DekuBytes for UserResponseChatCompletionEnd {}
|
||||
impl DekuBytes for UserResponsePayload {}
|
||||
impl DekuBytes for UserResponse {}
|
||||
84
frontend_protocol/src/utils.rs
Normal file
84
frontend_protocol/src/utils.rs
Normal file
@ -0,0 +1,84 @@
|
||||
use anyhow::{Context, Result, bail};
|
||||
use deku::ctx::{Endian, Order};
|
||||
use deku::{DekuError, DekuWriter};
|
||||
use deku::prelude::{Reader, Writer};
|
||||
use deku::prelude::*;
|
||||
|
||||
pub trait DekuBytes: Sized + Clone + DekuContainerWrite + for<'a> DekuContainerRead<'a> {
|
||||
fn to_bytes(&self) -> Result<Vec<u8>> {
|
||||
let type_name = std::any::type_name::<Self>();
|
||||
DekuContainerWrite::to_bytes(self)
|
||||
.context(format!("failed to encode {type_name}"))
|
||||
}
|
||||
|
||||
fn from_bytes(bytes: &[u8]) -> Result<Self> {
|
||||
let type_name = std::any::type_name::<Self>();
|
||||
let ((rest, bit_offset), value) =
|
||||
<Self as DekuContainerRead>::from_bytes((bytes, 0))
|
||||
.context(format!("failed to decode {type_name}"))?;
|
||||
|
||||
if !rest.is_empty() || bit_offset != 0 {
|
||||
bail!("trailing bytes after {type_name}");
|
||||
}
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_pascal_string<R>(reader: &mut Reader<R>) -> core::result::Result<String, DekuError>
|
||||
where
|
||||
R: deku::no_std_io::Read + deku::no_std_io::Seek,
|
||||
{
|
||||
let byte_len = u32::from_reader_with_ctx(reader, Endian::Little)? as usize;
|
||||
let mut bytes = vec![0u8; byte_len];
|
||||
reader.read_bytes(byte_len, &mut bytes, Order::Msb0)?;
|
||||
|
||||
String::from_utf8(bytes).map_err(|_| DekuError::Parse("invalid utf-8".into()))
|
||||
}
|
||||
|
||||
pub fn write_pascal_string<W>(
|
||||
writer: &mut Writer<W>,
|
||||
value: &String,
|
||||
) -> core::result::Result<(), DekuError>
|
||||
where
|
||||
W: deku::no_std_io::Write + deku::no_std_io::Seek,
|
||||
{
|
||||
let bytes = value.as_bytes();
|
||||
let byte_len =
|
||||
u32::try_from(bytes.len()).map_err(|_| DekuError::Parse("string too large".into()))?;
|
||||
|
||||
byte_len.to_writer(writer, Endian::Little)?;
|
||||
writer.write_bytes(bytes)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn read_vec_u32<R, T>(reader: &mut Reader<R>) -> core::result::Result<Vec<T>, DekuError>
|
||||
where
|
||||
R: deku::no_std_io::Read + deku::no_std_io::Seek,
|
||||
for<'a> T: DekuReader<'a, ()>,
|
||||
{
|
||||
let len = u32::from_reader_with_ctx(reader, Endian::Little)? as usize;
|
||||
let mut items = Vec::with_capacity(len);
|
||||
for _ in 0..len {
|
||||
let item = T::from_reader_with_ctx(reader, ())?;
|
||||
items.push(item);
|
||||
}
|
||||
Ok(items)
|
||||
}
|
||||
|
||||
pub fn write_vec_u32<W, T>(
|
||||
writer: &mut Writer<W>,
|
||||
value: &Vec<T>,
|
||||
) -> core::result::Result<(), DekuError>
|
||||
where
|
||||
W: deku::no_std_io::Write + deku::no_std_io::Seek,
|
||||
T: DekuWriter<()>,
|
||||
{
|
||||
let len = u32::try_from(value.len())
|
||||
.map_err(|_| DekuError::Parse("vector too large".into()))?;
|
||||
len.to_writer(writer, Endian::Little)?;
|
||||
for item in value {
|
||||
item.to_writer(writer, ())?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
1
secret-config.toml
Normal file
1
secret-config.toml
Normal file
@ -0,0 +1 @@
|
||||
dedicated_ai_server_secret="change-me"
|
||||
@ -6,7 +6,7 @@ edition = "2021"
|
||||
[dependencies]
|
||||
rand = "0.8"
|
||||
axum = { version = "0.7", features = ["multipart", "ws"] }
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "net"] }
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "fs", "io-util", "net", "sync", "time"] }
|
||||
tower-http = { version = "0.5", features = ["fs"] }
|
||||
tera = "1.19.1"
|
||||
reqwest = { version = "0.12", features = ["json", "rustls-tls"], default-features = false }
|
||||
@ -19,3 +19,4 @@ anyhow = "1.0"
|
||||
blake2b_simd = "1.0"
|
||||
dryoc = "0.7"
|
||||
deku = "0.20"
|
||||
frontend_protocol = { path = "../frontend_protocol" }
|
||||
|
||||
@ -2,6 +2,22 @@ use serde::Deserialize;
|
||||
use std::path::Path;
|
||||
|
||||
const DEFAULT_CONFIG_PATH: &str = "config.toml";
|
||||
const DEFAULT_SECRET_CONFIG_PATH: &str = "secret-config.toml";
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct PublicConfig {
|
||||
pub postgres_host: String,
|
||||
pub pg_database: String,
|
||||
pub postgres_user: String,
|
||||
pub file_storage: String,
|
||||
pub dedicated_ai_server_address: String,
|
||||
pub dedicated_ai_server_port: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct SecretConfig {
|
||||
pub dedicated_ai_server_secret: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct GeneralServiceConfig {
|
||||
@ -9,13 +25,28 @@ pub struct GeneralServiceConfig {
|
||||
pub pg_database: String,
|
||||
pub postgres_user: String,
|
||||
pub file_storage: String,
|
||||
pub dedicated_ai_server_address: String,
|
||||
pub dedicated_ai_server_port: u16,
|
||||
pub dedicated_ai_server_secret: String,
|
||||
}
|
||||
|
||||
pub type ConfigResult<T> = Result<T, Box<dyn std::error::Error>>;
|
||||
|
||||
pub fn load_config(path: impl AsRef<Path>) -> ConfigResult<GeneralServiceConfig> {
|
||||
let raw = std::fs::read_to_string(path.as_ref())?;
|
||||
let config: GeneralServiceConfig = toml::from_str(&raw)?;
|
||||
pub fn load_config(config_path: impl AsRef<Path>, secret_config_path: impl AsRef<Path>) -> ConfigResult<GeneralServiceConfig> {
|
||||
let raw_config = std::fs::read_to_string(config_path.as_ref())?;
|
||||
let public_config: PublicConfig = toml::from_str(&raw_config)?;
|
||||
let raw_secret = std::fs::read_to_string(secret_config_path.as_ref())?;
|
||||
let secret_config: SecretConfig = toml::from_str(&raw_secret)?;
|
||||
|
||||
let config = GeneralServiceConfig {
|
||||
postgres_host: public_config.postgres_host,
|
||||
pg_database: public_config.pg_database,
|
||||
postgres_user: public_config.postgres_user,
|
||||
file_storage: public_config.file_storage,
|
||||
dedicated_ai_server_address: public_config.dedicated_ai_server_address,
|
||||
dedicated_ai_server_port: public_config.dedicated_ai_server_port,
|
||||
dedicated_ai_server_secret: secret_config.dedicated_ai_server_secret,
|
||||
};
|
||||
|
||||
if config.pg_database.trim().is_empty() {
|
||||
return Err(std::io::Error::new(
|
||||
@ -41,9 +72,17 @@ pub fn load_config(path: impl AsRef<Path>) -> ConfigResult<GeneralServiceConfig>
|
||||
.into());
|
||||
}
|
||||
|
||||
if config.dedicated_ai_server_secret.trim().is_empty() {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidData,
|
||||
"config dedicated_ai_server_secret is empty",
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn load_config_default() -> ConfigResult<GeneralServiceConfig> {
|
||||
load_config(DEFAULT_CONFIG_PATH)
|
||||
load_config(DEFAULT_CONFIG_PATH, DEFAULT_SECRET_CONFIG_PATH)
|
||||
}
|
||||
|
||||
@ -3,11 +3,11 @@ use tokio::net::{TcpListener, TcpStream};
|
||||
use anyhow::Result;
|
||||
use crate::dedicated_ai_server::api::{
|
||||
ChatCompletionRequest,
|
||||
DekuBytes,
|
||||
MessageInChat,
|
||||
Request,
|
||||
RequestPayload,
|
||||
};
|
||||
use frontend_protocol::DekuBytes;
|
||||
use crate::dedicated_ai_server::talking::{SecretStreamSocket, wrap_connection_socket};
|
||||
use crate::dedicated_ai_server::talking::{ProtocolError, FrameCallback};
|
||||
|
||||
|
||||
@ -1,39 +1,12 @@
|
||||
use anyhow::{Context, Result, bail};
|
||||
use deku::prelude::*;
|
||||
use super::marshalling_utils::{
|
||||
read_bool_u8,
|
||||
use frontend_protocol::{
|
||||
DekuBytes,
|
||||
read_pascal_string,
|
||||
write_bool_u8,
|
||||
read_vec_u32,
|
||||
write_pascal_string,
|
||||
write_vec_u32,
|
||||
};
|
||||
|
||||
pub trait DekuBytes: Sized + Clone + DekuContainerWrite + for<'a> DekuContainerRead<'a> {
|
||||
fn pre_encode(&mut self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn to_bytes(&self) -> Result<Vec<u8>> {
|
||||
let mut value = self.clone();
|
||||
value.pre_encode()?;
|
||||
let type_name = std::any::type_name::<Self>();
|
||||
DekuContainerWrite::to_bytes(&value)
|
||||
.context(format!("failed to encode {type_name}"))
|
||||
}
|
||||
|
||||
fn from_bytes(bytes: &[u8]) -> Result<Self> {
|
||||
let type_name = std::any::type_name::<Self>();
|
||||
let ((rest, bit_offset), value) =
|
||||
<Self as DekuContainerRead>::from_bytes((bytes, 0))
|
||||
.context(format!("failed to decode {type_name}"))?;
|
||||
|
||||
if !rest.is_empty() || bit_offset != 0 {
|
||||
bail!("trailing bytes after {type_name}");
|
||||
}
|
||||
|
||||
Ok(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct MessageInChat {
|
||||
#[deku(
|
||||
@ -49,24 +22,20 @@ pub struct MessageInChat {
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/* Deku is a joke. A pathetic excuse for a marshalling library. But it's okay */
|
||||
#[deku::deku_derive(DekuRead, DekuWrite)]
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ChatCompletionRequest {
|
||||
#[deku(endian = "little", update = "self.messages.len() as u32")]
|
||||
message_count: u32,
|
||||
|
||||
#[deku(count = "*message_count as usize")]
|
||||
#[deku(
|
||||
reader = "read_vec_u32(deku::reader)",
|
||||
writer = "write_vec_u32(deku::writer, &self.messages)"
|
||||
)]
|
||||
pub messages: Vec<MessageInChat>,
|
||||
}
|
||||
|
||||
|
||||
impl ChatCompletionRequest {
|
||||
pub fn new(messages: Vec<MessageInChat>) -> Self {
|
||||
Self {
|
||||
message_count: messages.len() as u32,
|
||||
messages,
|
||||
}
|
||||
Self { messages }
|
||||
}
|
||||
}
|
||||
|
||||
@ -127,37 +96,12 @@ pub struct Response {
|
||||
|
||||
impl DekuBytes for MessageInChat {}
|
||||
|
||||
impl DekuBytes for ChatCompletionRequest {
|
||||
fn pre_encode(&mut self) -> Result<()> {
|
||||
self.update()
|
||||
.context("failed to update ChatCompletionRequest before encoding")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
impl DekuBytes for ChatCompletionRequest {}
|
||||
|
||||
impl DekuBytes for ChatCompletionCancellationRequest {}
|
||||
|
||||
impl DekuBytes for RequestPayload {
|
||||
fn pre_encode(&mut self) -> Result<()> {
|
||||
if let RequestPayload::ChatCompletion(batch) = self {
|
||||
batch
|
||||
.update()
|
||||
.context("failed to update ChatCompletionRequest before encoding")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl DekuBytes for Request {
|
||||
fn pre_encode(&mut self) -> Result<()> {
|
||||
if let RequestPayload::ChatCompletion(batch) = &mut self.payload {
|
||||
batch
|
||||
.update()
|
||||
.context("failed to update ChatCompletionRequest before encoding")?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
impl DekuBytes for RequestPayload {}
|
||||
impl DekuBytes for Request {}
|
||||
|
||||
impl DekuBytes for ResponseChatCompletion {}
|
||||
impl DekuBytes for ResponseChatCompletionCancellation {}
|
||||
|
||||
205
website/src/dedicated_ai_server/connection.rs
Normal file
205
website/src/dedicated_ai_server/connection.rs
Normal file
@ -0,0 +1,205 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio::time::{sleep, Duration};
|
||||
|
||||
use frontend_protocol::DekuBytes;
|
||||
use super::api::{
|
||||
ChatCompletionRequest,
|
||||
MessageInChat,
|
||||
Request,
|
||||
RequestPayload,
|
||||
Response,
|
||||
ResponsePayload,
|
||||
};
|
||||
use super::talking::{SecretStreamSocket, wrap_connection_socket};
|
||||
|
||||
/// Text of one completion piece received from the dedicated AI server.
#[derive(Debug, Clone)]
pub struct MessagePiecePayload(pub String);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum MessagePiece {
|
||||
Piece(MessagePiecePayload), End, Cancelled, DedicatedServerDisconnected
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct PendingChatCompletionRecord {
|
||||
response_tx: mpsc::UnboundedSender<MessagePiece>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct RequestWorkItem {
|
||||
request: Request,
|
||||
response_tx: mpsc::UnboundedSender<MessagePiece>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DedicatedAiServerConnection {
|
||||
address: String,
|
||||
port: u16,
|
||||
secret: String,
|
||||
request_tx: mpsc::UnboundedSender<RequestWorkItem>,
|
||||
next_request_id: AtomicU64,
|
||||
}
|
||||
|
||||
pub fn connect_to_dedicated_ai_server(address: String, port: u16, secret: String)
|
||||
-> (DedicatedAiServerConnection, JoinHandle<()>) {
|
||||
let (request_tx, request_rx) = mpsc::unbounded_channel();
|
||||
let worker_address = address.clone();
|
||||
let worker_secret = secret.clone();
|
||||
|
||||
let join_hand = tokio::spawn(async move {
|
||||
connection_worker_loop(worker_address, port, worker_secret, request_rx).await;
|
||||
});
|
||||
|
||||
let conn_bridge = DedicatedAiServerConnection {
|
||||
address,
|
||||
port,
|
||||
secret,
|
||||
request_tx,
|
||||
next_request_id: AtomicU64::new(1),
|
||||
};
|
||||
(conn_bridge, join_hand)
|
||||
}
|
||||
|
||||
impl DedicatedAiServerConnection {
|
||||
pub fn send_chat_completion(
|
||||
&self,
|
||||
messages: Vec<MessageInChat>,
|
||||
) -> Result<mpsc::UnboundedReceiver<MessagePiece>> {
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||
let request = Request {
|
||||
request_id,
|
||||
payload: RequestPayload::ChatCompletion(ChatCompletionRequest::new(messages)),
|
||||
};
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
self.request_tx
|
||||
.send(RequestWorkItem { request, response_tx })
|
||||
.map_err(|err| anyhow::anyhow!("failed to enqueue request: {err}"))?;
|
||||
Ok(response_rx)
|
||||
}
|
||||
}
|
||||
|
||||
async fn connection_worker_loop(
|
||||
address: String,
|
||||
port: u16,
|
||||
secret: String,
|
||||
mut request_rx: mpsc::UnboundedReceiver<RequestWorkItem>,
|
||||
) {
|
||||
// todo : fix a lot of errors
|
||||
loop {
|
||||
let socket = match TcpStream::connect((address.as_str(), port)).await {
|
||||
Ok(socket) => socket,
|
||||
Err(err) => {
|
||||
eprintln!(
|
||||
"[dedicated-ai] failed to connect to {}:{}: {err}",
|
||||
address, port
|
||||
);
|
||||
sleep(Duration::from_secs(1)).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let mut transport = match wrap_connection_socket(socket, &secret).await {
|
||||
Ok(transport) => transport,
|
||||
Err(err) => {
|
||||
eprintln!("[dedicated-ai] failed to wrap connection: {err:#}");
|
||||
sleep(Duration::from_secs(1)).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let mut pending: HashMap<u64, PendingChatCompletionRecord> = HashMap::new();
|
||||
let result = run_connected_loop(&mut transport, &mut request_rx, &mut pending).await;
|
||||
if let Err(err) = result {
|
||||
eprintln!("[dedicated-ai] connection error: {err:#}");
|
||||
}
|
||||
|
||||
cancel_all_pending(&mut pending);
|
||||
let _ = transport.close(true).await;
|
||||
sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_connected_loop(
|
||||
transport: &mut SecretStreamSocket,
|
||||
request_rx: &mut mpsc::UnboundedReceiver<RequestWorkItem>,
|
||||
pending: &mut HashMap<u64, PendingChatCompletionRecord>,
|
||||
) -> Result<()> {
|
||||
loop {
|
||||
tokio::select! {
|
||||
frame = transport.recv_frame() => {
|
||||
let frame = frame.context("failed to read frame")?;
|
||||
let frame = match frame {
|
||||
Some(bytes) => bytes,
|
||||
None => return Err(anyhow::anyhow!("connection closed by peer")),
|
||||
};
|
||||
handle_response_frame(frame, pending)?;
|
||||
}
|
||||
maybe_item = request_rx.recv() => {
|
||||
let item = match maybe_item {
|
||||
Some(item) => item,
|
||||
None => return Err(anyhow::anyhow!("request channel closed")), // idk why this would happen
|
||||
};
|
||||
let request_id = item.request.request_id;
|
||||
pending.insert(
|
||||
request_id,
|
||||
PendingChatCompletionRecord { response_tx: item.response_tx },
|
||||
);
|
||||
|
||||
let payload = item.request.to_bytes().context("failed to encode request")?;
|
||||
transport.send_frame(&payload).await.context("failed to send request")?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_response_frame(
|
||||
frame: Vec<u8>,
|
||||
pending: &mut HashMap<u64, PendingChatCompletionRecord>,
|
||||
) -> Result<()> {
|
||||
let response = Response::from_bytes(&frame).context("failed to decode response")?;
|
||||
let request_id = response.request_id;
|
||||
let record = pending.get(&request_id);
|
||||
if record.is_none() {
|
||||
eprintln!("[dedicated-ai] response for unknown request_id={request_id}");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match response.payload {
|
||||
ResponsePayload::ChatCompletion(payload) => {
|
||||
let send_failed = match pending.get(&request_id) {
|
||||
Some(record) => record
|
||||
.response_tx
|
||||
.send(MessagePiece::Piece( MessagePiecePayload(payload.piece) ))
|
||||
.is_err(),
|
||||
None => false,
|
||||
};
|
||||
if send_failed {
|
||||
pending.remove(&request_id);
|
||||
}
|
||||
}
|
||||
ResponsePayload::ChatCompletionEnd(_) => {
|
||||
if let Some(record) = pending.remove(&request_id) {
|
||||
let _ = record.response_tx.send(MessagePiece::End);
|
||||
}
|
||||
}
|
||||
ResponsePayload::ChatCompletionCancellation(_) => {
|
||||
if let Some(record) = pending.remove(&request_id) {
|
||||
let _ = record.response_tx.send(MessagePiece::Cancelled);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cancel_all_pending(pending: &mut HashMap<u64, PendingChatCompletionRecord>) {
|
||||
for (_request_id, record) in pending.drain() {
|
||||
let _ = record.response_tx.send(MessagePiece::DedicatedServerDisconnected);
|
||||
}
|
||||
}
|
||||
@ -1,56 +0,0 @@
|
||||
use deku::ctx::{Endian, Order};
|
||||
use deku::{DekuError, DekuWriter};
|
||||
use deku::prelude::{Reader, Writer};
|
||||
use deku::prelude::*;
|
||||
|
||||
pub fn read_pascal_string<R>(reader: &mut Reader<R>) -> core::result::Result<String, DekuError>
|
||||
where
|
||||
R: deku::no_std_io::Read + deku::no_std_io::Seek,
|
||||
{
|
||||
let byte_len = u32::from_reader_with_ctx(reader, Endian::Little)? as usize;
|
||||
let mut bytes = vec![0u8; byte_len];
|
||||
reader.read_bytes(byte_len, &mut bytes, Order::Msb0)?;
|
||||
|
||||
String::from_utf8(bytes).map_err(|err| DekuError::Parse(err.to_string().into()))
|
||||
}
|
||||
|
||||
|
||||
pub fn write_pascal_string<W>(
|
||||
writer: &mut Writer<W>,
|
||||
value: &String,
|
||||
) -> core::result::Result<(), DekuError>
|
||||
where
|
||||
W: deku::no_std_io::Write + deku::no_std_io::Seek,
|
||||
{
|
||||
let bytes = value.as_bytes();
|
||||
let byte_len =
|
||||
u32::try_from(bytes.len()).map_err(|_| DekuError::Parse("string too large".into()))?;
|
||||
|
||||
byte_len.to_writer(writer, Endian::Little)?;
|
||||
writer.write_bytes(bytes)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn read_bool_u8<R>(reader: &mut Reader<R>) -> core::result::Result<bool, DekuError>
|
||||
where
|
||||
R: deku::no_std_io::Read + deku::no_std_io::Seek,
|
||||
{
|
||||
let value = u8::from_reader_with_ctx(reader, Endian::Little)?;
|
||||
match value {
|
||||
0 => Ok(false),
|
||||
1 => Ok(true),
|
||||
_ => Err(DekuError::Parse("invalid bool value".into())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_bool_u8<W>(
|
||||
writer: &mut Writer<W>,
|
||||
value: &bool,
|
||||
) -> core::result::Result<(), DekuError>
|
||||
where
|
||||
W: deku::no_std_io::Write + deku::no_std_io::Seek,
|
||||
{
|
||||
let encoded: u8 = if *value { 1 } else { 0 };
|
||||
encoded.to_writer(writer, Endian::Little)?;
|
||||
Ok(())
|
||||
}
|
||||
@ -1,4 +1,4 @@
|
||||
pub mod api;
|
||||
pub mod connection;
|
||||
pub mod talking;
|
||||
pub mod TEST;
|
||||
mod marshalling_utils;
|
||||
@ -17,24 +17,38 @@ use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
|
||||
use serde::Deserialize;
|
||||
use std::future::Future;
|
||||
use std::net::SocketAddr;
|
||||
use std::pin::Pin;
|
||||
use std::sync::OnceLock;
|
||||
use tera::{Tera};
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tower_http::services::{ServeDir, ServeFile};
|
||||
|
||||
use samplers::random_junk::random_string;
|
||||
use config::{ConfigResult, GeneralServiceConfig, load_config, load_config_default};
|
||||
use db::{DbResult, connect_db, init_database};
|
||||
use config::load_config_default;
|
||||
use db::{connect_db, init_database};
|
||||
use web_file_uploads::{get_file, upload_get, upload_post};
|
||||
use web_app_state::{AppState, AppStateInner, AuthenticatedUserId};
|
||||
use dedicated_ai_server::connection::connect_to_dedicated_ai_server;
|
||||
use dedicated_ai_server::api::MessageInChat;
|
||||
use crate::dedicated_ai_server::connection::{MessagePiece, MessagePiecePayload};
|
||||
use frontend_protocol::{
|
||||
DekuBytes,
|
||||
UserRequest,
|
||||
UserRequestPayload,
|
||||
UserResponse,
|
||||
UserResponseChatCompletion,
|
||||
UserResponseChatCompletionCancellation,
|
||||
UserResponseChatCompletionEnd,
|
||||
UserResponsePayload,
|
||||
};
|
||||
|
||||
|
||||
async fn init_app_state() -> Result<AppState, Box<dyn std::error::Error>> {
|
||||
async fn init_app_state() -> Result<(AppState, tokio::task::JoinHandle<()>), Box<dyn std::error::Error>> {
|
||||
let config = load_config_default()?;
|
||||
let db = connect_db(&config).await?;
|
||||
let tera = Tera::new("frontend/pages/**/*.html")?;
|
||||
Ok(std::sync::Arc::new(AppStateInner { config, db, tera }))
|
||||
let (dedicated_ai, dedicated_ai_task_handler) = connect_to_dedicated_ai_server(
|
||||
config.dedicated_ai_server_address.clone(),
|
||||
config.dedicated_ai_server_port,
|
||||
config.dedicated_ai_server_secret.clone(),
|
||||
);
|
||||
Ok((std::sync::Arc::new(AppStateInner { config, db, tera, dedicated_ai }), dedicated_ai_task_handler))
|
||||
}
|
||||
|
||||
enum PasscodeAuthenticationResult{
|
||||
@ -235,7 +249,9 @@ async fn chat(
|
||||
user: AuthenticatedUserId,
|
||||
) -> Response {
|
||||
if let Some(ws) = ws {
|
||||
return ws.on_upgrade(move |socket| handle_chat_socket(socket)).into_response();
|
||||
return ws
|
||||
.on_upgrade(move |socket| handle_chat_socket(socket, state.clone()))
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let mut ctx = tera::Context::new();
|
||||
@ -247,19 +263,105 @@ async fn chat(
|
||||
Html(body).into_response()
|
||||
}
|
||||
|
||||
async fn handle_chat_socket(mut socket: WebSocket) {
|
||||
while let Some(msg) = socket.recv().await {
|
||||
async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
|
||||
'outer: while let Some(msg) = socket.recv().await {
|
||||
let msg = match msg {
|
||||
Ok(msg) => msg,
|
||||
Err(_) => break,
|
||||
};
|
||||
|
||||
match msg {
|
||||
Message::Text(text) => {
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
let upper = text.to_uppercase();
|
||||
if socket.send(Message::Text(upper)).await.is_err() {
|
||||
break;
|
||||
Message::Binary(bytes) => {
|
||||
let request = match UserRequest::from_bytes(&bytes) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
eprintln!("[chat] failed to decode request: {err:#}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match request.payload {
|
||||
UserRequestPayload::ChatCompletion(payload) => {
|
||||
let messages = payload
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| MessageInChat {
|
||||
role: msg.role,
|
||||
content: msg.content,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut response_rx = match state.dedicated_ai.send_chat_completion(messages) {
|
||||
Ok(rx) => rx,
|
||||
Err(err) => {
|
||||
eprintln!("[chat] failed to send request: {err:#}");
|
||||
let response = UserResponse {
|
||||
request_id: request.request_id,
|
||||
payload: UserResponsePayload::ChatCompletionCancellation(
|
||||
UserResponseChatCompletionCancellation,
|
||||
),
|
||||
};
|
||||
if let Ok(bytes) = response.to_bytes() {
|
||||
let _ = socket.send(Message::Binary(bytes)).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
while let Some(piece) = response_rx.recv().await {
|
||||
let (payload, should_break) = match piece {
|
||||
MessagePiece::Piece(MessagePiecePayload(text)) => (
|
||||
UserResponsePayload::ChatCompletion(UserResponseChatCompletion {
|
||||
piece: text,
|
||||
}),
|
||||
false,
|
||||
),
|
||||
MessagePiece::End => (
|
||||
UserResponsePayload::ChatCompletionEnd(
|
||||
UserResponseChatCompletionEnd,
|
||||
),
|
||||
true,
|
||||
),
|
||||
MessagePiece::Cancelled | MessagePiece::DedicatedServerDisconnected => (
|
||||
UserResponsePayload::ChatCompletionCancellation(
|
||||
UserResponseChatCompletionCancellation,
|
||||
),
|
||||
true,
|
||||
),
|
||||
};
|
||||
|
||||
let response = UserResponse {
|
||||
request_id: request.request_id,
|
||||
payload,
|
||||
};
|
||||
let bytes = match response.to_bytes() {
|
||||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
eprintln!("[chat] failed to encode response: {err:#}");
|
||||
break 'outer;
|
||||
}
|
||||
};
|
||||
if socket.send(Message::Binary(bytes)).await.is_err() {
|
||||
break 'outer;
|
||||
}
|
||||
if should_break {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
UserRequestPayload::ChatCompletionCancellation(_) => {
|
||||
let response = UserResponse {
|
||||
request_id: request.request_id,
|
||||
payload: UserResponsePayload::ChatCompletionCancellation(
|
||||
UserResponseChatCompletionCancellation,
|
||||
),
|
||||
};
|
||||
if let Ok(bytes) = response.to_bytes() {
|
||||
if socket.send(Message::Binary(bytes)).await.is_err() {
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Message::Close(_) => break,
|
||||
@ -270,7 +372,7 @@ async fn handle_chat_socket(mut socket: WebSocket) {
|
||||
|
||||
|
||||
pub async fn run_server() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let state = init_app_state().await.expect("lol");
|
||||
let (state, mut dedicated_ai_task) = init_app_state().await.expect("lol");
|
||||
let app = Router::new()
|
||||
.route("/login", get(login_get))
|
||||
.route("/login", post(login_post))
|
||||
@ -292,7 +394,28 @@ pub async fn run_server() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let listener = tokio::net::TcpListener::bind(addr)
|
||||
.await
|
||||
.expect("bind failed");
|
||||
axum::serve(listener, app).await?;
|
||||
let mut server_task = tokio::spawn(async move { axum::serve(listener, app).await });
|
||||
|
||||
tokio::select! {
|
||||
res = &mut server_task => {
|
||||
dedicated_ai_task.abort();
|
||||
res??;
|
||||
}
|
||||
res = &mut dedicated_ai_task => {
|
||||
server_task.abort();
|
||||
res.map_err(|err| {
|
||||
std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
format!("dedicated ai task failed: {err}"),
|
||||
)
|
||||
})?;
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"dedicated ai task ended unexpectedly",
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@ -1,10 +1,12 @@
|
||||
use tera::Tera;
|
||||
use crate::GeneralServiceConfig;
|
||||
use crate::config::GeneralServiceConfig;
|
||||
use crate::dedicated_ai_server::connection::DedicatedAiServerConnection;
|
||||
|
||||
pub struct AppStateInner {
|
||||
pub config: GeneralServiceConfig,
|
||||
pub db: tokio_postgres::Client,
|
||||
pub tera: Tera,
|
||||
pub dedicated_ai: DedicatedAiServerConnection,
|
||||
}
|
||||
|
||||
pub type AppState = std::sync::Arc<AppStateInner>;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user