Compare commits

..

2 Commits

18 changed files with 957 additions and 362 deletions

View File

@ -4,6 +4,7 @@ from huggingface_hub import snapshot_download
MODEL_ID = "zai-org/GLM-4.7-Flash" MODEL_ID = "zai-org/GLM-4.7-Flash"
# MODEL_ID = "mlabonne/Daredevil-8B-abliterated-GGUF"
def main() -> None: def main() -> None:

View File

@ -11,6 +11,7 @@ from transformers import (
) )
MODEL_ID = "zai-org/GLM-4.7-Flash" MODEL_ID = "zai-org/GLM-4.7-Flash"
# MODEL_ID = "/home/gregory/programming/testWithPython/Daredevil-GGUF/daredevil-8b-abliterated.Q8_0.gguf"
MAX_NEW_TOKENS = 256 MAX_NEW_TOKENS = 256
SYSTEM_PROMPT = ( SYSTEM_PROMPT = (

View File

@ -4,6 +4,7 @@ name = "dedicated_ai_server"
version = "0.1.0" version = "0.1.0"
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [ dependencies = [
"accelerate>=1.13.0",
"asyncio>=4.0.0", "asyncio>=4.0.0",
"construct>=2.10.70", "construct>=2.10.70",
"huggingface-hub>=1.0.0", "huggingface-hub>=1.0.0",

View File

@ -0,0 +1,54 @@
#!/usr/bin/env bash
# Bootstrap a Fedora workstation: install base dev tools plus uv, and set up
# a minimal fish shell configuration with a custom two-line prompt.
set -euo pipefail

FISH_CONFIG_DIR="${HOME}/.config/fish"
FISH_FUNCTIONS_DIR="${FISH_CONFIG_DIR}/functions"
FISH_CONFIG_FILE="${FISH_CONFIG_DIR}/config.fish"
FISH_PROMPT_FILE="${FISH_FUNCTIONS_DIR}/fish_prompt.fish"

# Fedora-only guard: everything below relies on dnf.
if ! command -v dnf >/dev/null 2>&1; then
  echo "This script is for Fedora systems with dnf."
  exit 1
fi

sudo dnf install -y htop tree git gcc gcc-c++ make cmake fish vim

# Install uv via the upstream installer (user-local), then also try the distro
# package; the dnf attempt is best-effort and only reports when unavailable.
curl -LsSf https://astral.sh/uv/install.sh | sh
if ! sudo dnf install -y uv; then
  echo "uv package is not available in the currently enabled Fedora repositories."
fi

mkdir -p "${FISH_FUNCTIONS_DIR}"
touch "${FISH_CONFIG_FILE}"

# Silence the default fish greeting; grep -qxF keeps this idempotent on reruns.
if ! grep -qxF 'set -g fish_greeting' "${FISH_CONFIG_FILE}"; then
  printf '\nset -g fish_greeting\n' >> "${FISH_CONFIG_FILE}"
fi

# Two-line prompt: cwd (plus last exit status in red when non-zero), then "> ".
cat > "${FISH_PROMPT_FILE}" <<'EOF'
function fish_prompt
    set -l last_status $status
    set_color brblue
    printf '%s' (prompt_pwd)
    set_color normal
    if test $last_status -ne 0
        printf ' '
        set_color red
        printf '[%s]' $last_status
        set_color normal
    end
    printf '\n'
    set_color brcyan
    printf '> '
    set_color normal
end
EOF

# Fix: the summary previously omitted vim, which the dnf line above installs.
echo "Installed: htop tree git gcc gcc-c++ make cmake fish vim"
echo "Fish prompt configured. Start it with: fish"

View File

@ -3,9 +3,11 @@ from __future__ import annotations
import asyncio import asyncio
import queue import queue
import threading import threading
import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, Iterable from typing import Any, Dict, Iterable
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from api import Request, Response from api import Request, Response
from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket
@ -15,8 +17,9 @@ SERVER_HOST = "127.0.0.1"
SERVER_PORT = 9000 SERVER_PORT = 9000
SHARED_SECRET = "change-me" SHARED_SECRET = "change-me"
MODEL_ID = "zai-org/GLM-4.7-Flash"
PROCESS_DELAY_SECONDS = 1.5 MAX_NEW_TOKENS = 256
PIECE_CHUNK_SIZE = 64
@dataclass @dataclass
@ -51,20 +54,81 @@ class WorkItem:
record: PendingChatCompletionRecord record: PendingChatCompletionRecord
def extract_last_message(messages: list) -> str: @dataclass
if not messages: class ModelBundle:
return "" tokenizer: Any
last = messages[-1] model: Any
if hasattr(last, "content"):
return last.content
return last.get("content", "")
def generate_uppercase_pieces(text: str) -> Iterable[str]: def fail(message: str) -> None:
words = text.split() print(message, flush=True)
for word in words: raise SystemExit(1)
time.sleep(PROCESS_DELAY_SECONDS)
yield word.upper() + " "
def build_context(messages: list) -> list[dict[str, str]]:
    """Normalize heterogeneous chat messages into role/content dicts.

    Each entry may be an object exposing ``role``/``content`` attributes or a
    mapping with ``"role"``/``"content"`` keys; missing mapping keys default
    to the empty string.
    """
    normalized: list[dict[str, str]] = []
    for entry in messages:
        if hasattr(entry, "role") and hasattr(entry, "content"):
            normalized.append({"role": entry.role, "content": entry.content})
        else:
            normalized.append(
                {"role": entry.get("role", ""), "content": entry.get("content", "")}
            )
    return normalized
def load_local_model() -> ModelBundle:
    """Load the tokenizer and model strictly from the local HF cache.

    Exits the process via fail() when the cached snapshot is incomplete,
    so callers can assume a usable ModelBundle on return.
    """
    try:
        tok = AutoTokenizer.from_pretrained(MODEL_ID, local_files_only=True)
        lm = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            local_files_only=True,
            dtype=torch.bfloat16,
            device_map="auto",
        )
    except OSError as exc:
        # Missing/partial cache surfaces as OSError; abort with guidance.
        fail(
            f"Model {MODEL_ID!r} is not fully available in the Hugging Face cache. "
            f"Run B2.py first. Original error: {exc}"
        )
    return ModelBundle(tokenizer=tok, model=lm)
def generate_llm_pieces(bundle: ModelBundle, messages: list) -> Iterable[str]:
    """Run one greedy generation pass and return the reply split into
    fixed-size text chunks of PIECE_CHUNK_SIZE characters.

    Returns an empty list when the model produced no visible text.
    """
    context = build_context(messages)
    # Tokenize through the model's chat template; return_dict=True yields a
    # mapping of tensors (input_ids, attention_mask, ...).
    inputs = bundle.tokenizer.apply_chat_template(
        context,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move every tensor onto the model's device (device_map="auto" upstream).
    inputs = {name: tensor.to(bundle.model.device) for name, tensor in inputs.items()}
    input_ids = inputs.get("input_ids")
    # Prompt length in tokens, used below to strip the echoed prompt.
    input_len = int(input_ids.shape[1]) if input_ids is not None else 0
    with torch.inference_mode():
        output_ids = bundle.model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
        )
    # Drop the prompt tokens; decode only the newly generated suffix.
    generated_ids = output_ids[0][input_len:]
    text = bundle.tokenizer.decode(generated_ids, skip_special_tokens=True)
    if not text:
        return []
    return [text[i:i + PIECE_CHUNK_SIZE] for i in range(0, len(text), PIECE_CHUNK_SIZE)]
def worker_loop( def worker_loop(
@ -72,6 +136,7 @@ def worker_loop(
pending: Dict[int, PendingChatCompletionRecord], pending: Dict[int, PendingChatCompletionRecord],
pending_lock: threading.Lock, pending_lock: threading.Lock,
loop: asyncio.AbstractEventLoop, loop: asyncio.AbstractEventLoop,
model_bundle: ModelBundle,
) -> None: ) -> None:
while True: while True:
item = work_queue.get() item = work_queue.get()
@ -84,8 +149,7 @@ def worker_loop(
if record.is_cancelled(): if record.is_cancelled():
cancelled = True cancelled = True
else: else:
text = extract_last_message(item.messages) for piece in generate_llm_pieces(model_bundle, item.messages):
for piece in generate_uppercase_pieces(text):
loop.call_soon_threadsafe( loop.call_soon_threadsafe(
record.response_queue.put_nowait, record.response_queue.put_nowait,
MessagePiece(piece=piece, is_end=False, is_cancel=False), MessagePiece(piece=piece, is_end=False, is_cancel=False),
@ -104,6 +168,12 @@ def worker_loop(
record.response_queue.put_nowait, record.response_queue.put_nowait,
MessagePiece(is_end=True, is_cancel=False), MessagePiece(is_end=True, is_cancel=False),
) )
except BaseException as exc:
print(f"[worker] generation failed: {exc}", flush=True)
loop.call_soon_threadsafe(
record.response_queue.put_nowait,
MessagePiece(is_end=True, is_cancel=True),
)
finally: finally:
with pending_lock: with pending_lock:
pending.pop(item.request_id, None) pending.pop(item.request_id, None)
@ -153,6 +223,38 @@ async def forward_pieces(
return return
async def handle_chat_request(
    request: Request,
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
    transport: SecretStreamSocket,
    send_lock: asyncio.Lock,
    response_tasks: set[asyncio.Task[None]],
) -> None:
    """Register a chat request, hand it to the worker thread, and start a
    task that forwards generated pieces back over the transport.

    Raises ProtocolError when the connection reuses a request_id that is
    still pending.
    """
    request_id = request.request_id
    print(f"[request] chat request_id={request_id}", flush=True)
    response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue()
    record = PendingChatCompletionRecord(
        response_queue=response_queue,
    )
    # pending is shared with the worker thread, so guard with the lock even
    # though this coroutine runs on the event loop.
    with pending_lock:
        if request_id in pending:
            raise ProtocolError(
                f"Duplicate request_id {request_id} received on this connection"
            )
        pending[request_id] = record
    # Enqueue after registration so the worker always finds the record.
    work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
    task = asyncio.create_task(
        forward_pieces(request_id, response_queue, transport, send_lock)
    )
    # Keep a strong reference until completion, then drop it automatically.
    response_tasks.add(task)
    task.add_done_callback(response_tasks.discard)
async def handle_client( async def handle_client(
reader: asyncio.StreamReader, reader: asyncio.StreamReader,
writer: asyncio.StreamWriter, writer: asyncio.StreamWriter,
@ -176,27 +278,16 @@ async def handle_client(
request_id = request.request_id request_id = request.request_id
if request.kind == "chat": if request.kind == "chat":
print(f"[request] chat request_id={request_id}", flush=True) await handle_chat_request(
response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue() request,
record = PendingChatCompletionRecord( work_queue,
response_queue=response_queue, pending,
pending_lock,
transport,
send_lock,
response_tasks,
) )
with pending_lock:
if request_id in pending:
raise ProtocolError(
f"Duplicate request_id {request_id} received on this connection"
)
pending[request_id] = record
owned_request_ids.add(request_id) owned_request_ids.add(request_id)
work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
task = asyncio.create_task(
forward_pieces(request_id, response_queue, transport, send_lock)
)
response_tasks.add(task)
task.add_done_callback(response_tasks.discard)
elif request.kind == "cancel": elif request.kind == "cancel":
print(f"[request] cancel request_id={request_id}", flush=True) print(f"[request] cancel request_id={request_id}", flush=True)
with pending_lock: with pending_lock:
@ -222,7 +313,7 @@ async def handle_client(
await transport.close() await transport.close()
async def run_server() -> None: async def run_server(model_bundle: ModelBundle) -> None:
pending: Dict[int, PendingChatCompletionRecord] = {} pending: Dict[int, PendingChatCompletionRecord] = {}
pending_lock = threading.Lock() pending_lock = threading.Lock()
work_queue: queue.Queue[WorkItem | None] = queue.Queue() work_queue: queue.Queue[WorkItem | None] = queue.Queue()
@ -230,7 +321,7 @@ async def run_server() -> None:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
worker = threading.Thread( worker = threading.Thread(
target=worker_loop, target=worker_loop,
args=(work_queue, pending, pending_lock, loop), args=(work_queue, pending, pending_lock, loop, model_bundle),
daemon=True, daemon=True,
) )
worker.start() worker.start()
@ -254,7 +345,9 @@ async def run_server() -> None:
def main() -> None: def main() -> None:
asyncio.run(run_server()) model_bundle = load_local_model()
print("[model] loaded", flush=True)
asyncio.run(run_server(model_bundle))
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -2,6 +2,24 @@ version = 1
revision = 3 revision = 3
requires-python = ">=3.13" requires-python = ">=3.13"
[[package]]
name = "accelerate"
version = "1.13.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "psutil" },
{ name = "pyyaml" },
{ name = "safetensors" },
{ name = "torch" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl", hash = "sha256:cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0", size = 383744, upload-time = "2026-03-04T19:34:10.313Z" },
]
[[package]] [[package]]
name = "annotated-doc" name = "annotated-doc"
version = "0.0.4" version = "0.0.4"
@ -188,6 +206,7 @@ name = "dedicated-ai-server"
version = "0.1.0" version = "0.1.0"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "accelerate" },
{ name = "asyncio" }, { name = "asyncio" },
{ name = "construct" }, { name = "construct" },
{ name = "huggingface-hub" }, { name = "huggingface-hub" },
@ -198,6 +217,7 @@ dependencies = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "accelerate", specifier = ">=1.13.0" },
{ name = "asyncio", specifier = ">=4.0.0" }, { name = "asyncio", specifier = ">=4.0.0" },
{ name = "construct", specifier = ">=2.10.70" }, { name = "construct", specifier = ">=2.10.70" },
{ name = "huggingface-hub", specifier = ">=1.0.0" }, { name = "huggingface-hub", specifier = ">=1.0.0" },
@ -633,6 +653,34 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, { url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" },
] ]
[[package]]
name = "psutil"
version = "7.2.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/d1ddf4abb55e93cebc4f2ed8b5d6dbad109ecb8d63748dd2b20ab5e57ebe/psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372", size = 493740, upload-time = "2026-01-28T18:14:54.428Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/51/08/510cbdb69c25a96f4ae523f733cdc963ae654904e8db864c07585ef99875/psutil-7.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2edccc433cbfa046b980b0df0171cd25bcaeb3a68fe9022db0979e7aa74a826b", size = 130595, upload-time = "2026-01-28T18:14:57.293Z" },
{ url = "https://files.pythonhosted.org/packages/d6/f5/97baea3fe7a5a9af7436301f85490905379b1c6f2dd51fe3ecf24b4c5fbf/psutil-7.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e78c8603dcd9a04c7364f1a3e670cea95d51ee865e4efb3556a3a63adef958ea", size = 131082, upload-time = "2026-01-28T18:14:59.732Z" },
{ url = "https://files.pythonhosted.org/packages/37/d6/246513fbf9fa174af531f28412297dd05241d97a75911ac8febefa1a53c6/psutil-7.2.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1a571f2330c966c62aeda00dd24620425d4b0cc86881c89861fbc04549e5dc63", size = 181476, upload-time = "2026-01-28T18:15:01.884Z" },
{ url = "https://files.pythonhosted.org/packages/b8/b5/9182c9af3836cca61696dabe4fd1304e17bc56cb62f17439e1154f225dd3/psutil-7.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:917e891983ca3c1887b4ef36447b1e0873e70c933afc831c6b6da078ba474312", size = 184062, upload-time = "2026-01-28T18:15:04.436Z" },
{ url = "https://files.pythonhosted.org/packages/16/ba/0756dca669f5a9300d0cbcbfae9a4c30e446dfc7440ffe43ded5724bfd93/psutil-7.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:ab486563df44c17f5173621c7b198955bd6b613fb87c71c161f827d3fb149a9b", size = 139893, upload-time = "2026-01-28T18:15:06.378Z" },
{ url = "https://files.pythonhosted.org/packages/1c/61/8fa0e26f33623b49949346de05ec1ddaad02ed8ba64af45f40a147dbfa97/psutil-7.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:ae0aefdd8796a7737eccea863f80f81e468a1e4cf14d926bd9b6f5f2d5f90ca9", size = 135589, upload-time = "2026-01-28T18:15:08.03Z" },
{ url = "https://files.pythonhosted.org/packages/81/69/ef179ab5ca24f32acc1dac0c247fd6a13b501fd5534dbae0e05a1c48b66d/psutil-7.2.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:eed63d3b4d62449571547b60578c5b2c4bcccc5387148db46e0c2313dad0ee00", size = 130664, upload-time = "2026-01-28T18:15:09.469Z" },
{ url = "https://files.pythonhosted.org/packages/7b/64/665248b557a236d3fa9efc378d60d95ef56dd0a490c2cd37dafc7660d4a9/psutil-7.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7b6d09433a10592ce39b13d7be5a54fbac1d1228ed29abc880fb23df7cb694c9", size = 131087, upload-time = "2026-01-28T18:15:11.724Z" },
{ url = "https://files.pythonhosted.org/packages/d5/2e/e6782744700d6759ebce3043dcfa661fb61e2fb752b91cdeae9af12c2178/psutil-7.2.2-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fa4ecf83bcdf6e6c8f4449aff98eefb5d0604bf88cb883d7da3d8d2d909546a", size = 182383, upload-time = "2026-01-28T18:15:13.445Z" },
{ url = "https://files.pythonhosted.org/packages/57/49/0a41cefd10cb7505cdc04dab3eacf24c0c2cb158a998b8c7b1d27ee2c1f5/psutil-7.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e452c464a02e7dc7822a05d25db4cde564444a67e58539a00f929c51eddda0cf", size = 185210, upload-time = "2026-01-28T18:15:16.002Z" },
{ url = "https://files.pythonhosted.org/packages/dd/2c/ff9bfb544f283ba5f83ba725a3c5fec6d6b10b8f27ac1dc641c473dc390d/psutil-7.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c7663d4e37f13e884d13994247449e9f8f574bc4655d509c3b95e9ec9e2b9dc1", size = 141228, upload-time = "2026-01-28T18:15:18.385Z" },
{ url = "https://files.pythonhosted.org/packages/f2/fc/f8d9c31db14fcec13748d373e668bc3bed94d9077dbc17fb0eebc073233c/psutil-7.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:11fe5a4f613759764e79c65cf11ebdf26e33d6dd34336f8a337aa2996d71c841", size = 136284, upload-time = "2026-01-28T18:15:19.912Z" },
{ url = "https://files.pythonhosted.org/packages/e7/36/5ee6e05c9bd427237b11b3937ad82bb8ad2752d72c6969314590dd0c2f6e/psutil-7.2.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ed0cace939114f62738d808fdcecd4c869222507e266e574799e9c0faa17d486", size = 129090, upload-time = "2026-01-28T18:15:22.168Z" },
{ url = "https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979", size = 129859, upload-time = "2026-01-28T18:15:23.795Z" },
{ url = "https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9", size = 155560, upload-time = "2026-01-28T18:15:25.976Z" },
{ url = "https://files.pythonhosted.org/packages/63/65/37648c0c158dc222aba51c089eb3bdfa238e621674dc42d48706e639204f/psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e", size = 156997, upload-time = "2026-01-28T18:15:27.794Z" },
{ url = "https://files.pythonhosted.org/packages/8e/13/125093eadae863ce03c6ffdbae9929430d116a246ef69866dad94da3bfbc/psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8", size = 148972, upload-time = "2026-01-28T18:15:29.342Z" },
{ url = "https://files.pythonhosted.org/packages/04/78/0acd37ca84ce3ddffaa92ef0f571e073faa6d8ff1f0559ab1272188ea2be/psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc", size = 148266, upload-time = "2026-01-28T18:15:31.597Z" },
{ url = "https://files.pythonhosted.org/packages/b4/90/e2159492b5426be0c1fef7acba807a03511f97c5f86b3caeda6ad92351a7/psutil-7.2.2-cp37-abi3-win_amd64.whl", hash = "sha256:eb7e81434c8d223ec4a219b5fc1c47d0417b12be7ea866e24fb5ad6e84b3d988", size = 137737, upload-time = "2026-01-28T18:15:33.849Z" },
{ url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" },
]
[[package]] [[package]]
name = "pycparser" name = "pycparser"
version = "3.0" version = "3.0"

View File

@ -16,8 +16,11 @@ version = "0.3"
features = [ features = [
"Window", "Window",
"Document", "Document",
"Element",
"Event", "Event",
"EventTarget", "EventTarget",
"HtmlElement",
"HtmlTextAreaElement",
"KeyboardEvent", "KeyboardEvent",
"MouseEvent", "MouseEvent",
"WheelEvent", "WheelEvent",

View File

@ -1,5 +1,5 @@
<!doctype html> <!doctype html>
<html lang="en"> <html lang="en" class="chat-html">
<head> <head>
<meta charset="utf-8" /> <meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" /> <meta name="viewport" content="width=device-width, initial-scale=1" />
@ -14,10 +14,28 @@
init_chat(); init_chat();
</script> </script>
</head> </head>
<body> <body class="chat-body">
<main class="welcome-card"> <main class="chat-shell">
<h1>Chat</h1> <section class="chat-panel">
<p>WebSocket connection initialized. Open the console to see messages.</p> <div class="chat-header">
<h1 class="chat-title">Chat</h1>
<div class="chat-subtitle">Live session</div>
</div>
<div class="chat-container">
<div class="chat-messages" id="chat-messages"></div>
<div class="chat-input-area">
<textarea
class="chat-input"
id="chat-input"
placeholder="Write a message. Shift+Enter for a new line."
rows="4"
></textarea>
<div class="chat-status-label" id="chat-status"></div>
</div>
</div>
</section>
</main> </main>
</body> </body>
</html> </html>

View File

@ -1,79 +1,383 @@
use std::cell::RefCell; use std::cell::RefCell;
use std::rc::Rc; use std::rc::Rc;
use js_sys::{ArrayBuffer, Uint8Array};
use wasm_bindgen::prelude::*; use wasm_bindgen::prelude::*;
use wasm_bindgen::JsCast; use wasm_bindgen::JsCast;
use web_sys::{console, window, BinaryType, ErrorEvent, Event, MessageEvent, WebSocket}; use web_sys::{
use js_sys::{ArrayBuffer, Uint8Array}; console,
window,
BinaryType,
Document,
Element,
ErrorEvent,
Event,
HtmlElement,
HtmlTextAreaElement,
KeyboardEvent,
MessageEvent,
WebSocket,
};
use frontend_protocol::{ use frontend_protocol::{
ChatMessage, UserChatMessage,
DekuBytes, DekuBytes,
UserChatCompletionRequest, UserChatCompletionRequest,
UserRequest, UserRequest,
UserRequestPayload,
UserResponse, UserResponse,
UserResponsePayload,
}; };
thread_local! { thread_local! {
static WS_HANDLE: RefCell<Option<Rc<WebSocket>>> = RefCell::new(None); static APP_HANDLE: RefCell<Option<Rc<AppState>>> = RefCell::new(None);
}
/// Mutable UI/chat session state, owned behind a RefCell in AppState.
struct ChatState {
    // Full conversation history, in display order.
    messages: Vec<UserChatMessage>,
    // DOM nodes parallel to `messages` (same indices).
    message_nodes: Vec<MessageNode>,
    // True while an assistant reply is streaming in.
    is_receiving: bool,
    // Index into `messages` of the assistant entry being streamed, if any.
    active_assistant_index: Option<usize>,
    // Role assigned to the next composed message (cycled via hotkey).
    compose_role: ComposeRole,
    // When false, user messages are recorded but no completion is requested.
    generation_enabled: bool,
}
/// DOM handles for one rendered chat message.
struct MessageNode {
    // Outer container element for the whole message.
    wrapper: Element,
    // Element whose text is the message body (updated while streaming).
    content: Element,
    // Optional status indicator (pending/cancelled), assistant messages only.
    status: Option<Element>,
}
/// Visual state shown on an assistant message's status element.
#[derive(Copy, Clone)]
enum ChatStatus {
    Pending,
    Cancelled,
    Hidden,
}
/// Role attributed to the next composed message; cycled with a hotkey.
#[derive(Copy, Clone, PartialEq)]
enum ComposeRole {
    System,
    Assistant,
    User,
}

impl ComposeRole {
    /// Wire/display name of the role (matches the protocol role strings).
    fn as_str(self) -> &'static str {
        match self {
            ComposeRole::System => "system",
            ComposeRole::Assistant => "assistant",
            ComposeRole::User => "user",
        }
    }

    /// Next role in the fixed cycle system -> assistant -> user -> system.
    fn next(self) -> Self {
        match self {
            ComposeRole::System => ComposeRole::Assistant,
            ComposeRole::Assistant => ComposeRole::User,
            ComposeRole::User => ComposeRole::System,
        }
    }
}
/// Long-lived application handles shared by every event closure via Rc.
struct AppState {
    ws: WebSocket,
    document: Document,
    // Scrollable container holding all rendered messages.
    messages_container: Element,
    // The compose textarea.
    input: HtmlTextAreaElement,
    // Label showing current compose role and generation on/off.
    status_label: Element,
    // Interior mutability: closures share &AppState but mutate chat state.
    state: RefCell<ChatState>,
}
/// Build and append one chat message to `container`.
///
/// Creates a wrapper div with role and content children, plus an optional
/// status child when `status` is Some, and returns handles to the created
/// nodes so the caller can update them later (e.g. while streaming).
fn append_message(
    document: &Document,
    container: &Element,
    role: &str,
    content: &str,
    status: Option<ChatStatus>,
) -> Result<MessageNode, JsValue> {
    let wrapper = document.create_element("div")?;
    // Role becomes a BEM-style modifier class, e.g. chat-message--user.
    wrapper.set_class_name(&format!("chat-message chat-message--{}", role));
    let role_el = document.create_element("div")?;
    role_el.set_class_name("chat-message__role");
    role_el.set_text_content(Some(role));
    let content_el = document.create_element("div")?;
    content_el.set_class_name("chat-message__content");
    content_el.set_text_content(Some(content));
    wrapper.append_child(&role_el)?;
    wrapper.append_child(&content_el)?;
    let status_el = if let Some(status) = status {
        let status_el = document.create_element("div")?;
        apply_status(&status_el, status);
        wrapper.append_child(&status_el)?;
        Some(status_el)
    } else {
        None
    };
    container.append_child(&wrapper)?;
    Ok(MessageNode {
        wrapper,
        content: content_el,
        status: status_el,
    })
}
/// Replace the text content of a message body element.
fn set_message_content(node: &Element, content: &str) {
    node.set_text_content(Some(content));
}
/// Set the class and text of a status element to reflect `status`.
fn apply_status(node: &Element, status: ChatStatus) {
    let (class_name, text) = match status {
        ChatStatus::Pending => (
            "chat-message__status chat-message__status--pending",
            "...",
        ),
        ChatStatus::Cancelled => (
            "chat-message__status chat-message__status--cancelled",
            "[canceled]",
        ),
        ChatStatus::Hidden => (
            "chat-message__status chat-message__status--hidden",
            "",
        ),
    };
    node.set_class_name(class_name);
    node.set_text_content(Some(text));
}
/// Render the compose role and generation toggle into the status label.
fn update_status_label(label: &Element, role: ComposeRole, generation_enabled: bool) {
    let generation = match generation_enabled {
        true => "on",
        false => "off",
    };
    label.set_text_content(Some(&format!(
        "Role: {} | Generation: {}",
        role.as_str(),
        generation
    )));
}
fn scroll_to_bottom(container: &Element) {
if let Some(element) = container.dyn_ref::<HtmlElement>() {
let height = element.scroll_height();
element.set_scroll_top(height);
}
} }
#[wasm_bindgen] #[wasm_bindgen]
pub fn init_chat() -> Result<(), JsValue> { pub fn init_chat() -> Result<(), JsValue> {
let window = window().ok_or_else(|| JsValue::from_str("Missing window"))?; let window = window().ok_or_else(|| JsValue::from_str("Missing window"))?;
let document = window.document().ok_or_else(|| JsValue::from_str("Missing document"))?;
let location = window.location(); let location = window.location();
let host = location.host()?; let host = location.host()?;
let protocol = location.protocol()?; let protocol = location.protocol()?;
let scheme = if protocol == "https:" { "wss" } else { "ws" }; let scheme = if protocol == "https:" { "wss" } else { "ws" };
let ws_url = format!("{scheme}://{host}/chat"); let ws_url = format!("{scheme}://{host}/chat");
let ws = Rc::new(WebSocket::new(&ws_url)?); let messages_container = document
.get_element_by_id("chat-messages")
.ok_or_else(|| JsValue::from_str("Missing chat-messages element"))?;
let input_el: HtmlTextAreaElement = document
.get_element_by_id("chat-input")
.ok_or_else(|| JsValue::from_str("Missing chat-input element"))?
.dyn_into()?;
let status_label = document
.get_element_by_id("chat-status")
.ok_or_else(|| JsValue::from_str("Missing chat-status element"))?;
let ws = WebSocket::new(&ws_url)?;
ws.set_binary_type(BinaryType::Arraybuffer); ws.set_binary_type(BinaryType::Arraybuffer);
console::log_1(&format!("[ws] connecting to {ws_url}").into()); console::log_1(&format!("[ws] connecting to {ws_url}").into());
let ws_for_open = ws.clone(); let app = Rc::new(AppState {
ws,
document,
messages_container,
input: input_el,
status_label,
state: RefCell::new(ChatState {
messages: Vec::new(),
message_nodes: Vec::new(),
is_receiving: false,
active_assistant_index: None,
compose_role: ComposeRole::User,
generation_enabled: true,
}),
});
{
let state = app.state.borrow();
update_status_label(&app.status_label, state.compose_role, state.generation_enabled);
}
let onopen = Closure::<dyn FnMut(Event)>::wrap(Box::new(move |_| { let onopen = Closure::<dyn FnMut(Event)>::wrap(Box::new(move |_| {
console::log_1(&"[ws] connected".into()); console::log_1(&"[ws] connected".into());
let requests = [
UserRequest {
request_id: 1,
payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![
ChatMessage {
role: "user".to_string(),
content: r#"this is the first message first chat
too bad it is in lower case one one one"#
.to_string(),
},
])),
},
UserRequest {
request_id: 2,
payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![
ChatMessage {
role: "user".to_string(),
content: "And this is message two of chat request 2. \
Too bad it isn't processed two two two"
.to_string(),
},
])),
},
];
for request in requests {
match request.to_bytes() {
Ok(bytes) => {
console::log_1(&format!("[ws] sending request_id={} bytes={}", request.request_id, bytes.len()).into());
let _ = ws_for_open.send_with_u8_array(&bytes);
}
Err(err) => {
console::error_1(&format!("[ws] encode error: {err:#}").into());
}
}
}
})); }));
ws.set_onopen(Some(onopen.as_ref().unchecked_ref())); app.ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
onopen.forget(); onopen.forget();
let app_for_keydown = app.clone();
let onkeydown = Closure::<dyn FnMut(KeyboardEvent)>::wrap(Box::new(move |event: KeyboardEvent| {
if event.ctrl_key() {
let key = event.key();
if key == "c" || key == "C" {
let state = app_for_keydown.state.borrow();
if state.is_receiving {
event.prevent_default();
if app_for_keydown.ws.ready_state() != WebSocket::OPEN {
console::error_1(&"[ws] socket is not open".into());
return;
}
let request = UserRequest::ChatCompletionCancellation;
match request.to_bytes() {
Ok(bytes) => {
if let Err(err) = app_for_keydown.ws.send_with_u8_array(&bytes) {
console::error_1(&format!("[ws] cancel send error: {:?}", err).into());
}
}
Err(err) => {
console::error_1(&format!("[ws] cancel encode error: {err:#}").into());
}
}
}
return;
}
if key == "1" {
event.prevent_default();
let mut state = app_for_keydown.state.borrow_mut();
state.compose_role = state.compose_role.next();
update_status_label(
&app_for_keydown.status_label,
state.compose_role,
state.generation_enabled,
);
return;
}
if key == "4" {
event.prevent_default();
let mut state = app_for_keydown.state.borrow_mut();
state.generation_enabled = !state.generation_enabled;
update_status_label(
&app_for_keydown.status_label,
state.compose_role,
state.generation_enabled,
);
return;
}
if key == "d" || key == "D" {
event.prevent_default();
let mut state = app_for_keydown.state.borrow_mut();
if state.is_receiving {
return;
}
let last_message = match state.messages.pop() {
Some(message) => message,
None => return,
};
if let Some(node) = state.message_nodes.pop() {
if let Some(parent) = node.wrapper.parent_node() {
let _ = parent.remove_child(&node.wrapper);
}
}
app_for_keydown.input.set_value(&last_message.content);
scroll_to_bottom(&app_for_keydown.messages_container);
return;
}
}
if event.key() != "Enter" || event.shift_key() {
return;
}
event.prevent_default();
let mut state = app_for_keydown.state.borrow_mut();
if state.is_receiving {
return;
}
let raw = app_for_keydown.input.value();
let trimmed = raw.trim();
if trimmed.is_empty() {
return;
}
let role = state.compose_role;
let user_content = trimmed.to_string();
app_for_keydown.input.set_value("");
let user_node = match append_message(
&app_for_keydown.document,
&app_for_keydown.messages_container,
role.as_str(),
&user_content,
None,
) {
Ok(node) => node,
Err(err) => {
console::error_1(&format!("[ui] failed to append user message: {:?}", err).into());
return;
}
};
state.messages.push(UserChatMessage {
role: role.as_str().to_string(),
content: user_content,
});
state.message_nodes.push(user_node);
scroll_to_bottom(&app_for_keydown.messages_container);
if role != ComposeRole::User || !state.generation_enabled {
return;
}
if app_for_keydown.ws.ready_state() != WebSocket::OPEN {
console::error_1(&"[ws] socket is not open".into());
return;
}
let history = state.messages.clone();
let request = UserRequest::ChatCompletion(UserChatCompletionRequest::new(history));
let bytes = match request.to_bytes() {
Ok(bytes) => bytes,
Err(err) => {
console::error_1(&format!("[ws] encode error: {err:#}").into());
return;
}
};
if let Err(err) = app_for_keydown.ws.send_with_u8_array(&bytes) {
console::error_1(&format!("[ws] send error: {:?}", err).into());
return;
}
let assistant_node = match append_message(
&app_for_keydown.document,
&app_for_keydown.messages_container,
"assistant",
"",
Some(ChatStatus::Pending),
) {
Ok(node) => node,
Err(err) => {
console::error_1(&format!("[ui] failed to append assistant message: {:?}", err).into());
return;
}
};
state.messages.push(UserChatMessage {
role: "assistant".to_string(),
content: String::new(),
});
state.message_nodes.push(assistant_node);
state.active_assistant_index = Some(state.messages.len() - 1);
state.is_receiving = true;
scroll_to_bottom(&app_for_keydown.messages_container);
}));
app.input.set_onkeydown(Some(onkeydown.as_ref().unchecked_ref()));
onkeydown.forget();
let app_for_message = app.clone();
let onmessage = Closure::<dyn FnMut(MessageEvent)>::wrap(Box::new(move |event: MessageEvent| { let onmessage = Closure::<dyn FnMut(MessageEvent)>::wrap(Box::new(move |event: MessageEvent| {
let data = event.data(); let data = event.data();
if let Some(text) = data.as_string() { if let Some(text) = data.as_string() {
@ -87,36 +391,73 @@ Too bad it isn't processed two two two"
let data = Uint8Array::new(&data); let data = Uint8Array::new(&data);
let bytes = data.to_vec(); let bytes = data.to_vec();
console::log_1(&format!("[ws] received bytes={}", bytes.len()).into());
let response = match UserResponse::from_bytes(&bytes) { let response = match UserResponse::from_bytes(&bytes) {
Ok(response) => response, Ok(response) => response,
Err(err) => { Err(err) => {
console::error_1(&format!("[ws] decode error: {err:#} (bytes={})", bytes.len()).into()); console::error_1(&format!("[ws] decode error: {err:#}").into());
return; return;
} }
}; };
match response.payload { let mut state = app_for_message.state.borrow_mut();
UserResponsePayload::ChatCompletion(payload) => { let assistant_index = match state.active_assistant_index {
console::log_1(&format!("[ws] request_id={} piece={}", response.request_id, payload.piece).into()); Some(index) => index,
None => {
console::log_1(&"[ws] missing assistant index".into());
return;
} }
UserResponsePayload::ChatCompletionCancellation(_) => { };
console::log_1(&format!("[ws] request_id={} [cancel]", response.request_id).into());
match response {
UserResponse::ChatCompletion(completion) => {
if let Some(message) = state.messages.get_mut(assistant_index) {
message.content.push_str(&completion.piece);
} }
UserResponsePayload::ChatCompletionEnd(_) => { if let Some(node) = state.message_nodes.get(assistant_index) {
console::log_1(&format!("[ws] request_id={} [end]", response.request_id).into()); if let Some(message) = state.messages.get(assistant_index) {
set_message_content(&node.content, &message.content);
}
}
scroll_to_bottom(&app_for_message.messages_container);
}
UserResponse::ChatCompletionEnd => {
state.is_receiving = false;
if let Some(node) = state
.active_assistant_index
.and_then(|index| state.message_nodes.get(index))
{
if let Some(status_node) = node.status.as_ref() {
apply_status(status_node, ChatStatus::Hidden);
}
}
state.active_assistant_index = None;
let _ = app_for_message.input.focus();
}
UserResponse::ChatCompletionCancellation => {
state.is_receiving = false;
if let Some(node) = state
.active_assistant_index
.and_then(|index| state.message_nodes.get(index))
{
if let Some(status_node) = node.status.as_ref() {
apply_status(status_node, ChatStatus::Cancelled);
}
}
state.active_assistant_index = None;
let _ = app_for_message.input.focus();
} }
} }
})); }));
ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref())); app.ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
onmessage.forget(); onmessage.forget();
let onerror = Closure::<dyn FnMut(ErrorEvent)>::wrap(Box::new(move |event: ErrorEvent| { let onerror = Closure::<dyn FnMut(ErrorEvent)>::wrap(Box::new(move |event: ErrorEvent| {
console::error_1(&format!("[ws] error: {}", event.message()).into()); console::error_1(&format!("[ws] error: {}", event.message()).into());
})); }));
ws.set_onerror(Some(onerror.as_ref().unchecked_ref())); app.ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
onerror.forget(); onerror.forget();
WS_HANDLE.with(|slot| *slot.borrow_mut() = Some(ws.clone())); APP_HANDLE.with(|slot| *slot.borrow_mut() = Some(app.clone()));
let _ = app.input.focus();
Ok(()) Ok(())
} }

View File

@ -7,6 +7,158 @@ body {
overflow: hidden; overflow: hidden;
} }
/* ---- Chat page styles ---- */

/* Make the chat root elements span the full viewport. */
.chat-html,
.chat-body {
    height: 100%;
    width: 100%;
}

/* Base typography and scroll behavior for the chat document. */
.chat-body {
    margin: 0;
    background: #ffffff;
    color: #1b1b1b;
    font-family: "Fira Sans", "Space Grotesk", "Montserrat", sans-serif;
    font-size: 14pt;
    display: block;
    overflow: auto;
}

/* Outer wrapper: centers the panel and adds breathing room. */
.chat-shell {
    min-height: 100%;
    padding: 32px 24px;
    box-sizing: border-box;
    display: flex;
    justify-content: center;
}

/* Main column: header, message list, and input stacked vertically. */
.chat-panel {
    width: min(100%, max(70%, 1000pt));
    display: flex;
    flex-direction: column;
    gap: 16px;
    min-height: 0;
}

.chat-header {
    display: flex;
    flex-direction: column;
    gap: 4px;
}

.chat-title {
    margin: 0;
    font-size: 20pt;
    font-weight: 600;
}

.chat-subtitle {
    font-size: 10pt;
    text-transform: uppercase;
    letter-spacing: 1px;
    color: #6b6b6b;
}

/* Flexible region holding the scrolling message list and the input. */
.chat-container {
    display: flex;
    flex-direction: column;
    gap: 16px;
    min-height: 0;
    flex: 1;
}

/* Scrollable message list; flex: 1 absorbs leftover panel height. */
.chat-messages {
    display: flex;
    flex-direction: column;
    gap: 12px;
    overflow-y: auto;
    padding-right: 4px;
    flex: 1;
}

/* One chat bubble; role modifiers below recolor it. */
.chat-message {
    width: 100%;
    border-radius: 14px;
    padding: 12px 14px;
    box-sizing: border-box;
    border: 1px solid #e2e2e2;
    background: #fafafa;
    display: flex;
    flex-direction: column;
    gap: 6px;
}

.chat-message--user {
    background: #eef4ff;
    border-color: #d5e2ff;
}

.chat-message--assistant {
    background: #f7f2ff;
    border-color: #e5d7ff;
}

.chat-message--system {
    background: #f8f8f8;
    border-color: #dddddd;
}

/* Small uppercase role label inside a bubble. */
.chat-message__role {
    font-size: 9pt;
    text-transform: uppercase;
    letter-spacing: 1px;
    color: #6b6b6b;
}

/* Message text; pre-wrap preserves line breaks from the model output. */
.chat-message__content {
    white-space: pre-wrap;
    line-height: 1.45;
}

/* Status line under a message; pending / cancelled / hidden variants. */
.chat-message__status {
    font-size: 9pt;
    color: #9aa0a6;
}

.chat-message__status--pending {
    color: #9aa0a6;
}

.chat-message__status--cancelled {
    color: #c0392b;
}

/* Hidden status nodes are removed from layout entirely. */
.chat-message__status--hidden {
    display: none;
}

/* Input area below the message list. */
.chat-input-area {
    display: flex;
    flex-direction: column;
    gap: 6px;
}

.chat-input {
    width: 100%;
    min-height: 120px;
    padding: 12px 14px;
    border-radius: 12px;
    border: 1px solid #d4d4d4;
    font-size: 14pt;
    font-family: inherit;
    resize: vertical;
    box-sizing: border-box;
}

/* Greyed-out input while sending is disallowed. */
.chat-input:disabled {
    background: #f2f2f2;
    color: #8a8a8a;
}

.chat-status-label {
    font-size: 15pt;
    color: #4b4b4b;
}
body { body {
display: flex; display: flex;
align-items: center; align-items: center;

View File

@ -10,7 +10,7 @@ pub use self::utils::{
}; };
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct ChatMessage { pub struct UserChatMessage {
#[deku( #[deku(
reader = "read_pascal_string(deku::reader)", reader = "read_pascal_string(deku::reader)",
writer = "write_pascal_string(deku::writer, &self.role)" writer = "write_pascal_string(deku::writer, &self.role)"
@ -31,33 +31,23 @@ pub struct UserChatCompletionRequest {
reader = "read_vec_u32(deku::reader)", reader = "read_vec_u32(deku::reader)",
writer = "write_vec_u32(deku::writer, &self.messages)" writer = "write_vec_u32(deku::writer, &self.messages)"
)] )]
pub messages: Vec<ChatMessage>, pub messages: Vec<UserChatMessage>,
} }
impl UserChatCompletionRequest { impl UserChatCompletionRequest {
pub fn new(messages: Vec<ChatMessage>) -> Self { pub fn new(messages: Vec<UserChatMessage>) -> Self {
Self { messages } Self { messages }
} }
} }
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct UserChatCompletionCancellationRequest;
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
#[deku(id_type = "u8")] #[deku(id_type = "u8")]
#[repr(u8)] #[repr(u8)]
pub enum UserRequestPayload { pub enum UserRequest {
#[deku(id = "0")] #[deku(id = "0")]
ChatCompletion(UserChatCompletionRequest), ChatCompletion(UserChatCompletionRequest),
#[deku(id = "1")] #[deku(id = "1")]
ChatCompletionCancellation(UserChatCompletionCancellationRequest), ChatCompletionCancellation,
}
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct UserRequest {
#[deku(endian = "little")]
pub request_id: u64,
pub payload: UserRequestPayload,
} }
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
@ -69,42 +59,23 @@ pub struct UserResponseChatCompletion {
pub piece: String, pub piece: String,
} }
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct UserResponseChatCompletionCancellation;
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct UserResponseChatCompletionEnd;
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
#[deku(id_type = "u8")] #[deku(id_type = "u8")]
#[repr(u8)] #[repr(u8)]
pub enum UserResponsePayload { pub enum UserResponse {
#[deku(id = "0")] #[deku(id = "0")]
ChatCompletion(UserResponseChatCompletion), ChatCompletion(UserResponseChatCompletion),
#[deku(id = "1")] #[deku(id = "1")]
ChatCompletionCancellation(UserResponseChatCompletionCancellation), ChatCompletionCancellation,
#[deku(id = "2")] #[deku(id = "2")]
ChatCompletionEnd(UserResponseChatCompletionEnd), ChatCompletionEnd,
} }
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] impl DekuBytes for UserChatMessage {}
pub struct UserResponse {
#[deku(endian = "little")]
pub request_id: u64,
pub payload: UserResponsePayload,
}
impl DekuBytes for ChatMessage {}
impl DekuBytes for UserChatCompletionRequest {} impl DekuBytes for UserChatCompletionRequest {}
impl DekuBytes for UserChatCompletionCancellationRequest {}
impl DekuBytes for UserRequestPayload {}
impl DekuBytes for UserRequest {} impl DekuBytes for UserRequest {}
impl DekuBytes for UserResponseChatCompletion {} impl DekuBytes for UserResponseChatCompletion {}
impl DekuBytes for UserResponseChatCompletionCancellation {}
impl DekuBytes for UserResponseChatCompletionEnd {}
impl DekuBytes for UserResponsePayload {}
impl DekuBytes for UserResponse {} impl DekuBytes for UserResponse {}

View File

@ -1,7 +0,0 @@
use website::dedicated_ai_server::TEST::main_E;
/// Binary entry point: run the demo server (`main_E`) on the Tokio runtime
/// and surface any failure as a boxed error.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    match main_E().await {
        Ok(()) => Ok(()),
        Err(err) => Err(err.into()),
    }
}

View File

@ -1,7 +0,0 @@
use website::dedicated_ai_server::TEST::main_F;
/// Binary entry point: run the demo client (`main_F`) on the Tokio runtime
/// and surface any failure as a boxed error.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    match main_F().await {
        Ok(()) => Ok(()),
        Err(err) => Err(err.into()),
    }
}

View File

@ -1,117 +0,0 @@
use tokio::net::{TcpListener, TcpStream};
use anyhow::Result;
use crate::dedicated_ai_server::api::{
ChatCompletionRequest,
MessageInChat,
Request,
RequestPayload,
};
use frontend_protocol::DekuBytes;
use crate::dedicated_ai_server::talking::{SecretStreamSocket, wrap_connection_socket};
use crate::dedicated_ai_server::talking::{ProtocolError, FrameCallback};
// Loopback endpoint shared by the demo server (E) and client (F) entry points.
const SERVER_HOST: &str = "127.0.0.1";
const SERVER_PORT: u16 = 9000;
// Placeholder pre-shared key for the encrypted transport — not for production use.
const SHARED_SECRET: &str = "change-me";
/* client =========== aka F.py */
/// Open a TCP connection to `host:port` and wrap it in the encrypted
/// secret-stream transport keyed by `shared_secret`.
async fn connect_to_server(host: &str, port: u16, shared_secret: &str) -> Result<SecretStreamSocket> {
    let stream = TcpStream::connect((host, port)).await?;
    wrap_connection_socket(stream, shared_secret).await
}
/// Build a single-message chat-completion request, encode it, and send it
/// over the transport as one frame.
async fn send_message(connection: &mut SecretStreamSocket, request_id: u64, role: &str, content: &str) -> Result<()> {
    let message = MessageInChat {
        role: role.to_string(),
        content: content.to_string(),
    };
    let request = Request {
        request_id,
        payload: RequestPayload::ChatCompletion(ChatCompletionRequest::new(vec![message])),
    };
    let encoded = request.to_bytes()?;
    connection.send_frame(&encoded).await
}
/// Perform a graceful (`true`) shutdown of the encrypted connection.
async fn close_connection(connection: &mut SecretStreamSocket) -> Result<()> {
    let graceful = true;
    connection.close(graceful).await
}
/// Demo client flow: connect to the local server, send one user message,
/// then close the connection cleanly.
pub async fn main_F() -> Result<()> {
    let mut transport = connect_to_server(SERVER_HOST, SERVER_PORT, SHARED_SECRET).await?;
    send_message(&mut transport, 1, "user", "hello from Rust client").await?;
    close_connection(&mut transport).await?;
    Ok(())
}
/* server========= aka E.py */
/// Log a received chat-completion batch from `peer`: one summary line with
/// the request id and message count, then each message's role and content.
fn print_batch(peer: &str, request_id: u64, batch: &ChatCompletionRequest) {
    println!(
        "[packet] {} sent request_id={} with {} message(s)",
        peer,
        request_id,
        batch.messages.len(),
    );
    for (index, message) in batch.messages.iter().enumerate() {
        println!(
            " [{}] role={:?} content={:?}",
            index, message.role, message.content
        );
    }
}
/// Serve one client: wrap its socket in the encrypted transport, then pump
/// received frames through a logging callback until disconnect or error.
async fn handle_client(client_socket: TcpStream, peer: String, shared_secret: &str) -> Result<()> {
    // Clone for the callback, which must own its copy of the peer label.
    let peer_for_callback = peer.clone();
    // Decode every incoming frame as a `Request` and log its payload;
    // a decode failure propagates out and ends the receive loop.
    let on_frame: FrameCallback<'_> = Box::new(move |frame: Vec<u8>| {
        let request = Request::from_bytes(&frame)?;
        match request.payload {
            RequestPayload::ChatCompletion(batch) => {
                print_batch(&peer_for_callback, request.request_id, &batch);
            }
            RequestPayload::ChatCompletionCancellation(_) => {
                println!(
                    "[packet] {} sent request_id={} cancel",
                    peer_for_callback,
                    request.request_id,
                );
            }
        }
        Ok(())
    });
    let mut transport = wrap_connection_socket(client_socket, shared_secret).await?;
    println!("[connected] {}", peer);
    let result = transport.run_receiving_loop(on_frame).await;
    println!("[disconnected] {}", peer);
    // Best-effort graceful close; the receive-loop outcome is what we report.
    let _ = transport.close(true).await;
    result
}
pub async fn main_E() -> Result<()> {
let listener = TcpListener::bind((SERVER_HOST, SERVER_PORT)).await?;
println!("[listening] {}:{}", SERVER_HOST, SERVER_PORT);
loop {
let (client_socket, addr) = listener.accept().await?;
let peer = addr.to_string();
if let Err(err) = handle_client(client_socket, peer.clone(), SHARED_SECRET).await {
if err.downcast_ref::<ProtocolError>().is_some() {
eprintln!("[protocol error] {}: {:#}", peer, err);
} else {
eprintln!("[error] {}: {:#}", peer, err);
}
}
}
}

View File

@ -40,9 +40,6 @@ impl ChatCompletionRequest {
} }
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct ChatCompletionCancellationRequest;
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
#[deku(id_type = "u8")] #[deku(id_type = "u8")]
#[repr(u8)] #[repr(u8)]
@ -50,7 +47,7 @@ pub enum RequestPayload {
#[deku(id = "0")] #[deku(id = "0")]
ChatCompletion(ChatCompletionRequest), ChatCompletion(ChatCompletionRequest),
#[deku(id = "1")] #[deku(id = "1")]
ChatCompletionCancellation(ChatCompletionCancellationRequest), ChatCompletionCancellation,
} }
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
@ -69,12 +66,6 @@ pub struct ResponseChatCompletion {
pub piece: String, pub piece: String,
} }
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct ResponseChatCompletionCancellation;
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
pub struct ResponseChatCompletionEnd;
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
#[deku(id_type = "u8")] #[deku(id_type = "u8")]
#[repr(u8)] #[repr(u8)]
@ -82,9 +73,9 @@ pub enum ResponsePayload {
#[deku(id = "0")] #[deku(id = "0")]
ChatCompletion(ResponseChatCompletion), ChatCompletion(ResponseChatCompletion),
#[deku(id = "1")] #[deku(id = "1")]
ChatCompletionCancellation(ResponseChatCompletionCancellation), ChatCompletionCancellation,
#[deku(id = "2")] #[deku(id = "2")]
ChatCompletionEnd(ResponseChatCompletionEnd) ChatCompletionEnd
} }
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)] #[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
@ -98,12 +89,9 @@ impl DekuBytes for MessageInChat {}
impl DekuBytes for ChatCompletionRequest {} impl DekuBytes for ChatCompletionRequest {}
impl DekuBytes for ChatCompletionCancellationRequest {}
impl DekuBytes for RequestPayload {} impl DekuBytes for RequestPayload {}
impl DekuBytes for Request {} impl DekuBytes for Request {}
impl DekuBytes for ResponseChatCompletion {} impl DekuBytes for ResponseChatCompletion {}
impl DekuBytes for ResponseChatCompletionCancellation {}
impl DekuBytes for ResponsePayload {} impl DekuBytes for ResponsePayload {}
impl DekuBytes for Response {} impl DekuBytes for Response {}

View File

@ -32,9 +32,14 @@ struct PendingChatCompletionRecord {
} }
#[derive(Debug)] #[derive(Debug)]
struct RequestWorkItem { enum RequestWorkItem {
ChatCompletion {
request: Request, request: Request,
response_tx: mpsc::UnboundedSender<MessagePiece>, response_tx: mpsc::UnboundedSender<MessagePiece>,
},
ChatCompletionCancellation {
request_id: u64,
},
} }
#[derive(Debug)] #[derive(Debug)]
@ -70,7 +75,7 @@ impl DedicatedAiServerConnection {
pub fn send_chat_completion( pub fn send_chat_completion(
&self, &self,
messages: Vec<MessageInChat>, messages: Vec<MessageInChat>,
) -> Result<mpsc::UnboundedReceiver<MessagePiece>> { ) -> Result<(mpsc::UnboundedReceiver<MessagePiece>, u64)> {
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed); let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
let request = Request { let request = Request {
request_id, request_id,
@ -78,9 +83,16 @@ impl DedicatedAiServerConnection {
}; };
let (response_tx, response_rx) = mpsc::unbounded_channel(); let (response_tx, response_rx) = mpsc::unbounded_channel();
self.request_tx self.request_tx
.send(RequestWorkItem { request, response_tx }) .send(RequestWorkItem::ChatCompletion { request, response_tx })
.map_err(|err| anyhow::anyhow!("failed to enqueue request: {err}"))?; .map_err(|err| anyhow::anyhow!("failed to enqueue request: {err}"))?;
Ok(response_rx) Ok((response_rx, request_id))
}
pub fn send_chat_completion_cancellation(&self, request_id: u64) -> Result<()> {
self.request_tx
.send(RequestWorkItem::ChatCompletionCancellation { request_id })
.map_err(|err| anyhow::anyhow!("failed to enqueue cancellation: {err}"))?;
Ok(())
} }
} }
@ -145,15 +157,27 @@ async fn run_connected_loop(
Some(item) => item, Some(item) => item,
None => return Err(anyhow::anyhow!("request channel closed")), // idk why this would happen None => return Err(anyhow::anyhow!("request channel closed")), // idk why this would happen
}; };
let request_id = item.request.request_id; match item {
RequestWorkItem::ChatCompletion { request, response_tx } => {
let request_id = request.request_id;
pending.insert( pending.insert(
request_id, request_id,
PendingChatCompletionRecord { response_tx: item.response_tx }, PendingChatCompletionRecord { response_tx },
); );
let payload = item.request.to_bytes().context("failed to encode request")?; let payload = request.to_bytes().context("failed to encode request")?;
transport.send_frame(&payload).await.context("failed to send request")?; transport.send_frame(&payload).await.context("failed to send request")?;
} }
RequestWorkItem::ChatCompletionCancellation { request_id } => {
let request = Request {
request_id,
payload: RequestPayload::ChatCompletionCancellation,
};
let payload = request.to_bytes().context("failed to encode cancellation")?;
transport.send_frame(&payload).await.context("failed to send cancellation")?;
}
}
}
} }
} }
} }
@ -183,12 +207,12 @@ fn handle_response_frame(
pending.remove(&request_id); pending.remove(&request_id);
} }
} }
ResponsePayload::ChatCompletionEnd(_) => { ResponsePayload::ChatCompletionEnd => {
if let Some(record) = pending.remove(&request_id) { if let Some(record) = pending.remove(&request_id) {
let _ = record.response_tx.send(MessagePiece::End); let _ = record.response_tx.send(MessagePiece::End);
} }
} }
ResponsePayload::ChatCompletionCancellation(_) => { ResponsePayload::ChatCompletionCancellation => {
if let Some(record) = pending.remove(&request_id) { if let Some(record) = pending.remove(&request_id) {
let _ = record.response_tx.send(MessagePiece::Cancelled); let _ = record.response_tx.send(MessagePiece::Cancelled);
} }

View File

@ -1,4 +1,3 @@
pub mod api; pub mod api;
pub mod connection; pub mod connection;
pub mod talking; pub mod talking;
pub mod TEST;

View File

@ -31,12 +31,8 @@ use crate::dedicated_ai_server::connection::{MessagePiece, MessagePiecePayload};
use frontend_protocol::{ use frontend_protocol::{
DekuBytes, DekuBytes,
UserRequest, UserRequest,
UserRequestPayload,
UserResponse, UserResponse,
UserResponseChatCompletion, UserResponseChatCompletion,
UserResponseChatCompletionCancellation,
UserResponseChatCompletionEnd,
UserResponsePayload,
}; };
async fn init_app_state() -> Result<(AppState, tokio::task::JoinHandle<()>), Box<dyn std::error::Error>> { async fn init_app_state() -> Result<(AppState, tokio::task::JoinHandle<()>), Box<dyn std::error::Error>> {
@ -264,10 +260,11 @@ async fn chat(
} }
async fn handle_chat_socket(mut socket: WebSocket, state: AppState) { async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
'outer: while let Some(msg) = socket.recv().await { 'outer: loop {
let msg = match msg { let msg = match socket.recv().await {
Ok(msg) => msg, Some(Ok(msg)) => msg,
Err(_) => break, Some(Err(_)) => break,
None => break,
}; };
match msg { match msg {
@ -280,8 +277,14 @@ async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
} }
}; };
match request.payload { let payload = match request {
UserRequestPayload::ChatCompletion(payload) => { UserRequest::ChatCompletion(payload) => payload,
UserRequest::ChatCompletionCancellation => {
eprintln!("[chat] protocol error: unexpected cancellation without active request");
break;
}
};
let messages = payload let messages = payload
.messages .messages
.into_iter() .into_iter()
@ -291,16 +294,14 @@ async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
}) })
.collect(); .collect();
let mut response_rx = match state.dedicated_ai.send_chat_completion(messages) { let (mut response_rx, request_id) = match state.dedicated_ai.
send_chat_completion(messages)
{
Ok(rx) => rx, Ok(rx) => rx,
Err(err) => { Err(err) => {
eprintln!("[chat] failed to send request: {err:#}"); eprintln!("[chat] failed to send request: {err:#}");
let response = UserResponse { let response = UserResponse::ChatCompletionCancellation;
request_id: request.request_id, // todo: make to_bytes nofail
payload: UserResponsePayload::ChatCompletionCancellation(
UserResponseChatCompletionCancellation,
),
};
if let Ok(bytes) = response.to_bytes() { if let Ok(bytes) = response.to_bytes() {
let _ = socket.send(Message::Binary(bytes)).await; let _ = socket.send(Message::Binary(bytes)).await;
} }
@ -308,32 +309,30 @@ async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
} }
}; };
while let Some(piece) = response_rx.recv().await { loop {
let (payload, should_break) = match piece { tokio::select! {
piece = response_rx.recv() => {
let piece = match piece {
Some(piece) => piece,
None => break,
};
let (response, should_break) = match piece {
MessagePiece::Piece(MessagePiecePayload(text)) => ( MessagePiece::Piece(MessagePiecePayload(text)) => (
UserResponsePayload::ChatCompletion(UserResponseChatCompletion { UserResponse::ChatCompletion(UserResponseChatCompletion {
piece: text, piece: text,
}), }),
false, false,
), ),
MessagePiece::End => ( MessagePiece::End => (
UserResponsePayload::ChatCompletionEnd( UserResponse::ChatCompletionEnd,
UserResponseChatCompletionEnd,
),
true, true,
), ),
MessagePiece::Cancelled | MessagePiece::DedicatedServerDisconnected => ( MessagePiece::Cancelled | MessagePiece::DedicatedServerDisconnected => (
UserResponsePayload::ChatCompletionCancellation( UserResponse::ChatCompletionCancellation,
UserResponseChatCompletionCancellation,
),
true, true,
), ),
}; };
let response = UserResponse {
request_id: request.request_id,
payload,
};
let bytes = match response.to_bytes() { let bytes = match response.to_bytes() {
Ok(bytes) => bytes, Ok(bytes) => bytes,
Err(err) => { Err(err) => {
@ -348,24 +347,57 @@ async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
break; break;
} }
} }
} msg = socket.recv() => {
UserRequestPayload::ChatCompletionCancellation(_) => { let msg = match msg {
let response = UserResponse { Some(Ok(msg)) => msg,
request_id: request.request_id, Some(Err(_)) => break 'outer,
payload: UserResponsePayload::ChatCompletionCancellation( None => break 'outer,
UserResponseChatCompletionCancellation,
),
}; };
if let Ok(bytes) = response.to_bytes() {
if socket.send(Message::Binary(bytes)).await.is_err() { match msg {
Message::Binary(bytes) => {
let request = match UserRequest::from_bytes(&bytes) {
Ok(request) => request,
Err(err) => {
eprintln!("[chat] failed to decode request: {err:#}");
break 'outer;
}
};
match request {
UserRequest::ChatCompletionCancellation => {
if let Err(err) = state
.dedicated_ai
.send_chat_completion_cancellation(request_id)
{
eprintln!("[chat] failed to send cancellation: {err:#}");
break 'outer;
}
}
UserRequest::ChatCompletion(_) => {
eprintln!("[chat] protocol error: chat completion while receiving");
break 'outer;
}
}
}
Message::Close(_) => break 'outer,
_ => {
eprintln!("[chat] protocol error: unexpected non-binary message");
break 'outer; break 'outer;
} }
} }
} }
} }
} }
Message::Close(_) => break, }
_ => {} Message::Close(_) => {
println!(" [debug] websocket closed");
break;
}
_ => {
eprintln!("[chat] protocol error: unexpected non-binary message");
break;
}
} }
} }
} }