Compare commits
No commits in common. "63340f25948f8055590dd16bfbea2d8b2ec527d0" and "6a66cde0d09a8af9815978fe1cff404152e21af3" have entirely different histories.
63340f2594
...
6a66cde0d0
@ -4,7 +4,6 @@ from huggingface_hub import snapshot_download
|
||||
|
||||
|
||||
MODEL_ID = "zai-org/GLM-4.7-Flash"
|
||||
# MODEL_ID = "mlabonne/Daredevil-8B-abliterated-GGUF"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
||||
@ -11,7 +11,6 @@ from transformers import (
|
||||
)
|
||||
|
||||
MODEL_ID = "zai-org/GLM-4.7-Flash"
|
||||
# MODEL_ID = "/home/gregory/programming/testWithPython/Daredevil-GGUF/daredevil-8b-abliterated.Q8_0.gguf"
|
||||
MAX_NEW_TOKENS = 256
|
||||
|
||||
SYSTEM_PROMPT = (
|
||||
|
||||
@ -4,7 +4,6 @@ name = "dedicated_ai_server"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"accelerate>=1.13.0",
|
||||
"asyncio>=4.0.0",
|
||||
"construct>=2.10.70",
|
||||
"huggingface-hub>=1.0.0",
|
||||
|
||||
@ -1,54 +0,0 @@
|
||||
#!/usr/bin/env bash
# Bootstrap a Fedora machine: install the base development packages and uv,
# silence the fish greeting, and write a two-line fish prompt.
#
# Exits with status 1 when dnf is not available.
set -euo pipefail

FISH_CONFIG_DIR="${HOME}/.config/fish"
FISH_FUNCTIONS_DIR="${FISH_CONFIG_DIR}/functions"
FISH_CONFIG_FILE="${FISH_CONFIG_DIR}/config.fish"
FISH_PROMPT_FILE="${FISH_FUNCTIONS_DIR}/fish_prompt.fish"

# Single source of truth for the package list, so the install command and the
# final summary cannot drift apart (the summary used to omit vim).
PACKAGES=(htop tree git gcc gcc-c++ make cmake fish vim)

if ! command -v dnf >/dev/null 2>&1; then
    echo "This script is for Fedora systems with dnf."
    exit 1
fi

sudo dnf install -y "${PACKAGES[@]}"

# Install uv exactly once: keep an existing install, otherwise prefer the
# distribution package and fall back to the upstream installer only when the
# enabled repositories do not provide it.
if ! command -v uv >/dev/null 2>&1; then
    if ! sudo dnf install -y uv; then
        echo "uv package is not available in the currently enabled Fedora repositories."
        curl -LsSf https://astral.sh/uv/install.sh | sh
    fi
fi

mkdir -p "${FISH_FUNCTIONS_DIR}"
touch "${FISH_CONFIG_FILE}"

# Append the empty-greeting setting only if that exact line is not present yet,
# so re-running the script does not duplicate it.
if ! grep -qxF 'set -g fish_greeting' "${FISH_CONFIG_FILE}"; then
    printf '\nset -g fish_greeting\n' >> "${FISH_CONFIG_FILE}"
fi

# The quoted heredoc delimiter keeps bash from expanding $status and friends;
# the text below is written verbatim as fish code. Any existing prompt file is
# overwritten.
cat > "${FISH_PROMPT_FILE}" <<'EOF'
function fish_prompt
    set -l last_status $status

    set_color brblue
    printf '%s' (prompt_pwd)
    set_color normal

    if test $last_status -ne 0
        printf ' '
        set_color red
        printf '[%s]' $last_status
        set_color normal
    end

    printf '\n'
    set_color brcyan
    printf '> '
    set_color normal
end
EOF

echo "Installed: ${PACKAGES[*]}"
echo "Fish prompt configured. Start it with: fish"
|
||||
@ -3,11 +3,9 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from typing import Dict, Iterable
|
||||
|
||||
from api import Request, Response
|
||||
from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket
|
||||
@ -17,9 +15,8 @@ SERVER_HOST = "127.0.0.1"
|
||||
SERVER_PORT = 9000
|
||||
SHARED_SECRET = "change-me"
|
||||
|
||||
MODEL_ID = "zai-org/GLM-4.7-Flash"
|
||||
MAX_NEW_TOKENS = 256
|
||||
PIECE_CHUNK_SIZE = 64
|
||||
|
||||
PROCESS_DELAY_SECONDS = 1.5
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -54,81 +51,20 @@ class WorkItem:
|
||||
record: PendingChatCompletionRecord
|
||||
|
||||
|
||||
@dataclass
class ModelBundle:
    """Pair of objects needed to run generation: a tokenizer and its model."""

    # Tokenizer object; typed loosely because the concrete class depends on
    # the checkpoint that was loaded.
    tokenizer: Any
    # Causal language model object used together with the tokenizer.
    model: Any
|
||||
def extract_last_message(messages: list) -> str:
    """Return the content of the final message, or "" when there is none.

    The last entry may be an object exposing a ``content`` attribute or a
    mapping; a mapping without a "content" key yields "".
    """
    if messages:
        tail = messages[-1]
        return tail.content if hasattr(tail, "content") else tail.get("content", "")
    return ""
|
||||
|
||||
|
||||
def fail(message: str) -> None:
    """Print *message* (flushed immediately) and abort with exit status 1.

    Raises:
        SystemExit: always, with code 1.
    """
    print(message, flush=True)
    exit_code = 1
    raise SystemExit(exit_code)
|
||||
|
||||
|
||||
def build_context(messages: list) -> list[dict[str, str]]:
    """Normalise messages into a list of ``{"role", "content"}`` dicts.

    Each entry may be an object exposing both ``role`` and ``content``
    attributes, or a mapping; keys missing from a mapping default to "".
    """

    def _as_entry(item) -> dict[str, str]:
        # Attribute access is used only when BOTH attributes exist; anything
        # else is treated as a mapping.
        if not (hasattr(item, "role") and hasattr(item, "content")):
            return {"role": item.get("role", ""), "content": item.get("content", "")}
        return {"role": item.role, "content": item.content}

    return [_as_entry(item) for item in messages]
|
||||
|
||||
|
||||
def load_local_model() -> ModelBundle:
    """Load the tokenizer and model for ``MODEL_ID`` and bundle them.

    Both loads pass ``local_files_only=True``; the model is requested in
    bfloat16 with ``device_map="auto"``. An ``OSError`` from either load is
    reported through ``fail`` instead of propagating.
    """
    try:
        tok = AutoTokenizer.from_pretrained(MODEL_ID, local_files_only=True)
        net = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            local_files_only=True,
            dtype=torch.bfloat16,
            device_map="auto",
        )
    except OSError as exc:
        reason = (
            f"Model {MODEL_ID!r} is not fully available in the Hugging Face cache. "
            f"Run B2.py first. Original error: {exc}"
        )
        fail(reason)

    return ModelBundle(tokenizer=tok, model=net)
|
||||
|
||||
|
||||
def generate_llm_pieces(bundle: ModelBundle, messages: list) -> Iterable[str]:
    """Run one greedy generation for *messages* and return the reply in chunks.

    The whole completion is generated first (``do_sample=False``, at most
    ``MAX_NEW_TOKENS`` new tokens), decoded without special tokens, and then
    sliced into pieces of ``PIECE_CHUNK_SIZE`` characters. An empty decode
    result yields an empty list.
    """
    chat = build_context(messages)
    encoded = bundle.tokenizer.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move every tensor onto the model's device before generating.
    encoded = {key: value.to(bundle.model.device) for key, value in encoded.items()}

    # Length of the prompt, used below to cut it off the generated sequence.
    prompt_ids = encoded.get("input_ids")
    prompt_len = 0 if prompt_ids is None else int(prompt_ids.shape[1])

    with torch.inference_mode():
        sequences = bundle.model.generate(
            **encoded,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
        )

    completion = bundle.tokenizer.decode(
        sequences[0][prompt_len:], skip_special_tokens=True
    )
    if not completion:
        return []

    chunks = []
    for start in range(0, len(completion), PIECE_CHUNK_SIZE):
        chunks.append(completion[start:start + PIECE_CHUNK_SIZE])
    return chunks
|
||||
def generate_uppercase_pieces(text: str) -> Iterable[str]:
    """Yield each whitespace-separated word of *text* upper-cased plus a space.

    Sleeps ``PROCESS_DELAY_SECONDS`` before every word is yielded.
    """
    for token in text.split():
        time.sleep(PROCESS_DELAY_SECONDS)
        yield f"{token.upper()} "
|
||||
|
||||
|
||||
def worker_loop(
|
||||
@ -136,7 +72,6 @@ def worker_loop(
|
||||
pending: Dict[int, PendingChatCompletionRecord],
|
||||
pending_lock: threading.Lock,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
model_bundle: ModelBundle,
|
||||
) -> None:
|
||||
while True:
|
||||
item = work_queue.get()
|
||||
@ -149,7 +84,8 @@ def worker_loop(
|
||||
if record.is_cancelled():
|
||||
cancelled = True
|
||||
else:
|
||||
for piece in generate_llm_pieces(model_bundle, item.messages):
|
||||
text = extract_last_message(item.messages)
|
||||
for piece in generate_uppercase_pieces(text):
|
||||
loop.call_soon_threadsafe(
|
||||
record.response_queue.put_nowait,
|
||||
MessagePiece(piece=piece, is_end=False, is_cancel=False),
|
||||
@ -168,12 +104,6 @@ def worker_loop(
|
||||
record.response_queue.put_nowait,
|
||||
MessagePiece(is_end=True, is_cancel=False),
|
||||
)
|
||||
except BaseException as exc:
|
||||
print(f"[worker] generation failed: {exc}", flush=True)
|
||||
loop.call_soon_threadsafe(
|
||||
record.response_queue.put_nowait,
|
||||
MessagePiece(is_end=True, is_cancel=True),
|
||||
)
|
||||
finally:
|
||||
with pending_lock:
|
||||
pending.pop(item.request_id, None)
|
||||
@ -223,38 +153,6 @@ async def forward_pieces(
|
||||
return
|
||||
|
||||
|
||||
async def handle_chat_request(
    request: Request,
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
    transport: SecretStreamSocket,
    send_lock: asyncio.Lock,
    response_tasks: set[asyncio.Task[None]],
) -> None:
    """Register a chat request, queue it for the worker and start forwarding.

    Raises:
        ProtocolError: if ``request.request_id`` is already present in
            *pending*.
    """
    rid = request.request_id
    print(f"[request] chat request_id={rid}", flush=True)

    pieces: asyncio.Queue[MessagePiece] = asyncio.Queue()
    record = PendingChatCompletionRecord(response_queue=pieces)

    # The lock guards only the duplicate check and the insert.
    with pending_lock:
        if rid in pending:
            raise ProtocolError(
                f"Duplicate request_id {rid} received on this connection"
            )
        pending[rid] = record

    # The worker receives a copy of the message list.
    work_queue.put(WorkItem(rid, list(request.payload.messages), record))

    forwarder = asyncio.create_task(
        forward_pieces(rid, pieces, transport, send_lock)
    )
    # Track the task in the caller's set and drop it again once it finishes.
    response_tasks.add(forwarder)
    forwarder.add_done_callback(response_tasks.discard)
|
||||
|
||||
|
||||
async def handle_client(
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
@ -278,16 +176,27 @@ async def handle_client(
|
||||
request_id = request.request_id
|
||||
|
||||
if request.kind == "chat":
|
||||
await handle_chat_request(
|
||||
request,
|
||||
work_queue,
|
||||
pending,
|
||||
pending_lock,
|
||||
transport,
|
||||
send_lock,
|
||||
response_tasks,
|
||||
print(f"[request] chat request_id={request_id}", flush=True)
|
||||
response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue()
|
||||
record = PendingChatCompletionRecord(
|
||||
response_queue=response_queue,
|
||||
)
|
||||
|
||||
with pending_lock:
|
||||
if request_id in pending:
|
||||
raise ProtocolError(
|
||||
f"Duplicate request_id {request_id} received on this connection"
|
||||
)
|
||||
pending[request_id] = record
|
||||
|
||||
owned_request_ids.add(request_id)
|
||||
work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
|
||||
|
||||
task = asyncio.create_task(
|
||||
forward_pieces(request_id, response_queue, transport, send_lock)
|
||||
)
|
||||
response_tasks.add(task)
|
||||
task.add_done_callback(response_tasks.discard)
|
||||
elif request.kind == "cancel":
|
||||
print(f"[request] cancel request_id={request_id}", flush=True)
|
||||
with pending_lock:
|
||||
@ -313,7 +222,7 @@ async def handle_client(
|
||||
await transport.close()
|
||||
|
||||
|
||||
async def run_server(model_bundle: ModelBundle) -> None:
|
||||
async def run_server() -> None:
|
||||
pending: Dict[int, PendingChatCompletionRecord] = {}
|
||||
pending_lock = threading.Lock()
|
||||
work_queue: queue.Queue[WorkItem | None] = queue.Queue()
|
||||
@ -321,7 +230,7 @@ async def run_server(model_bundle: ModelBundle) -> None:
|
||||
loop = asyncio.get_running_loop()
|
||||
worker = threading.Thread(
|
||||
target=worker_loop,
|
||||
args=(work_queue, pending, pending_lock, loop, model_bundle),
|
||||
args=(work_queue, pending, pending_lock, loop),
|
||||
daemon=True,
|
||||
)
|
||||
worker.start()
|
||||
@ -345,9 +254,7 @@ async def run_server(model_bundle: ModelBundle) -> None:
|
||||
|
||||
|
||||
def main() -> None:
|
||||
model_bundle = load_local_model()
|
||||
print("[model] loaded", flush=True)
|
||||
asyncio.run(run_server(model_bundle))
|
||||
asyncio.run(run_server())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
48
dedicated_ai_server/uv.lock
generated
48
dedicated_ai_server/uv.lock
generated
@ -2,24 +2,6 @@ version = 1
|
||||
revision = 3
|
||||
requires-python = ">=3.13"
|
||||
|
||||
[[package]]
|
||||
name = "accelerate"
|
||||
version = "1.13.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "huggingface-hub" },
|
||||
{ name = "numpy" },
|
||||
{ name = "packaging" },
|
||||
{ name = "psutil" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "safetensors" },
|
||||
{ name = "torch" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl", hash = "sha256:cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0", size = 383744, upload-time = "2026-03-04T19:34:10.313Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "annotated-doc"
|
||||
version = "0.0.4"
|
||||
@ -206,7 +188,6 @@ name = "dedicated-ai-server"
|
||||
version = "0.1.0"
|
||||
source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "accelerate" },
|
||||
{ name = "asyncio" },
|
||||
{ name = "construct" },
|
||||
{ name = "huggingface-hub" },
|
||||
@ -217,7 +198,6 @@ dependencies = [
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "accelerate", specifier = ">=1.13.0" },
|
||||
{ name = "asyncio", specifier = ">=4.0.0" },
|
||||
{ name = "construct", specifier = ">=2.10.70" },
|
||||
{ name = "huggingface-hub", specifier = ">=1.0.0" },
|
||||
@ -653,34 +633,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "psutil"
|
||||
version = "7.2.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/d1ddf4abb55e93cebc4f2ed8b5d6dbad109ecb8d63748dd2b20ab5e57ebe/psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372", size = 493740, upload-time = "2026-01-28T18:14:54.428Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/51/08/510cbdb69c25a96f4ae523f733cdc963ae654904e8db864c07585ef99875/psutil-7.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2edccc433cbfa046b980b0df0171cd25bcaeb3a68fe9022db0979e7aa74a826b", size = 130595, upload-time = "2026-01-28T18:14:57.293Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d6/f5/97baea3fe7a5a9af7436301f85490905379b1c6f2dd51fe3ecf24b4c5fbf/psutil-7.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e78c8603dcd9a04c7364f1a3e670cea95d51ee865e4efb3556a3a63adef958ea", size = 131082, upload-time = "2026-01-28T18:14:59.732Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/37/d6/246513fbf9fa174af531f28412297dd05241d97a75911ac8febefa1a53c6/psutil-7.2.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1a571f2330c966c62aeda00dd24620425d4b0cc86881c89861fbc04549e5dc63", size = 181476, upload-time = "2026-01-28T18:15:01.884Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b8/b5/9182c9af3836cca61696dabe4fd1304e17bc56cb62f17439e1154f225dd3/psutil-7.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:917e891983ca3c1887b4ef36447b1e0873e70c933afc831c6b6da078ba474312", size = 184062, upload-time = "2026-01-28T18:15:04.436Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/16/ba/0756dca669f5a9300d0cbcbfae9a4c30e446dfc7440ffe43ded5724bfd93/psutil-7.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:ab486563df44c17f5173621c7b198955bd6b613fb87c71c161f827d3fb149a9b", size = 139893, upload-time = "2026-01-28T18:15:06.378Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1c/61/8fa0e26f33623b49949346de05ec1ddaad02ed8ba64af45f40a147dbfa97/psutil-7.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:ae0aefdd8796a7737eccea863f80f81e468a1e4cf14d926bd9b6f5f2d5f90ca9", size = 135589, upload-time = "2026-01-28T18:15:08.03Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/81/69/ef179ab5ca24f32acc1dac0c247fd6a13b501fd5534dbae0e05a1c48b66d/psutil-7.2.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:eed63d3b4d62449571547b60578c5b2c4bcccc5387148db46e0c2313dad0ee00", size = 130664, upload-time = "2026-01-28T18:15:09.469Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/7b/64/665248b557a236d3fa9efc378d60d95ef56dd0a490c2cd37dafc7660d4a9/psutil-7.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7b6d09433a10592ce39b13d7be5a54fbac1d1228ed29abc880fb23df7cb694c9", size = 131087, upload-time = "2026-01-28T18:15:11.724Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d5/2e/e6782744700d6759ebce3043dcfa661fb61e2fb752b91cdeae9af12c2178/psutil-7.2.2-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fa4ecf83bcdf6e6c8f4449aff98eefb5d0604bf88cb883d7da3d8d2d909546a", size = 182383, upload-time = "2026-01-28T18:15:13.445Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/57/49/0a41cefd10cb7505cdc04dab3eacf24c0c2cb158a998b8c7b1d27ee2c1f5/psutil-7.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e452c464a02e7dc7822a05d25db4cde564444a67e58539a00f929c51eddda0cf", size = 185210, upload-time = "2026-01-28T18:15:16.002Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/2c/ff9bfb544f283ba5f83ba725a3c5fec6d6b10b8f27ac1dc641c473dc390d/psutil-7.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c7663d4e37f13e884d13994247449e9f8f574bc4655d509c3b95e9ec9e2b9dc1", size = 141228, upload-time = "2026-01-28T18:15:18.385Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/f2/fc/f8d9c31db14fcec13748d373e668bc3bed94d9077dbc17fb0eebc073233c/psutil-7.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:11fe5a4f613759764e79c65cf11ebdf26e33d6dd34336f8a337aa2996d71c841", size = 136284, upload-time = "2026-01-28T18:15:19.912Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e7/36/5ee6e05c9bd427237b11b3937ad82bb8ad2752d72c6969314590dd0c2f6e/psutil-7.2.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ed0cace939114f62738d808fdcecd4c869222507e266e574799e9c0faa17d486", size = 129090, upload-time = "2026-01-28T18:15:22.168Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979", size = 129859, upload-time = "2026-01-28T18:15:23.795Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9", size = 155560, upload-time = "2026-01-28T18:15:25.976Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/63/65/37648c0c158dc222aba51c089eb3bdfa238e621674dc42d48706e639204f/psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e", size = 156997, upload-time = "2026-01-28T18:15:27.794Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/13/125093eadae863ce03c6ffdbae9929430d116a246ef69866dad94da3bfbc/psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8", size = 148972, upload-time = "2026-01-28T18:15:29.342Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/78/0acd37ca84ce3ddffaa92ef0f571e073faa6d8ff1f0559ab1272188ea2be/psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc", size = 148266, upload-time = "2026-01-28T18:15:31.597Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/90/e2159492b5426be0c1fef7acba807a03511f97c5f86b3caeda6ad92351a7/psutil-7.2.2-cp37-abi3-win_amd64.whl", hash = "sha256:eb7e81434c8d223ec4a219b5fc1c47d0417b12be7ea866e24fb5ad6e84b3d988", size = 137737, upload-time = "2026-01-28T18:15:33.849Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pycparser"
|
||||
version = "3.0"
|
||||
|
||||
@ -16,11 +16,8 @@ version = "0.3"
|
||||
features = [
|
||||
"Window",
|
||||
"Document",
|
||||
"Element",
|
||||
"Event",
|
||||
"EventTarget",
|
||||
"HtmlElement",
|
||||
"HtmlTextAreaElement",
|
||||
"KeyboardEvent",
|
||||
"MouseEvent",
|
||||
"WheelEvent",
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
<!doctype html>
|
||||
<html lang="en" class="chat-html">
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
@ -14,28 +14,10 @@
|
||||
init_chat();
|
||||
</script>
|
||||
</head>
|
||||
<body class="chat-body">
|
||||
<main class="chat-shell">
|
||||
<section class="chat-panel">
|
||||
<div class="chat-header">
|
||||
<h1 class="chat-title">Chat</h1>
|
||||
<div class="chat-subtitle">Live session</div>
|
||||
</div>
|
||||
|
||||
<div class="chat-container">
|
||||
<div class="chat-messages" id="chat-messages"></div>
|
||||
|
||||
<div class="chat-input-area">
|
||||
<textarea
|
||||
class="chat-input"
|
||||
id="chat-input"
|
||||
placeholder="Write a message. Shift+Enter for a new line."
|
||||
rows="4"
|
||||
></textarea>
|
||||
<div class="chat-status-label" id="chat-status"></div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
<body>
|
||||
<main class="welcome-card">
|
||||
<h1>Chat</h1>
|
||||
<p>WebSocket connection initialized. Open the console to see messages.</p>
|
||||
</main>
|
||||
</body>
|
||||
</html>
|
||||
|
||||
@ -1,383 +1,79 @@
|
||||
use std::cell::RefCell;
|
||||
use std::rc::Rc;
|
||||
|
||||
use js_sys::{ArrayBuffer, Uint8Array};
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen::JsCast;
|
||||
use web_sys::{
|
||||
console,
|
||||
window,
|
||||
BinaryType,
|
||||
Document,
|
||||
Element,
|
||||
ErrorEvent,
|
||||
Event,
|
||||
HtmlElement,
|
||||
HtmlTextAreaElement,
|
||||
KeyboardEvent,
|
||||
MessageEvent,
|
||||
WebSocket,
|
||||
};
|
||||
|
||||
use web_sys::{console, window, BinaryType, ErrorEvent, Event, MessageEvent, WebSocket};
|
||||
use js_sys::{ArrayBuffer, Uint8Array};
|
||||
use frontend_protocol::{
|
||||
UserChatMessage,
|
||||
ChatMessage,
|
||||
DekuBytes,
|
||||
UserChatCompletionRequest,
|
||||
UserRequest,
|
||||
UserRequestPayload,
|
||||
UserResponse,
|
||||
UserResponsePayload,
|
||||
};
|
||||
|
||||
thread_local! {
|
||||
static APP_HANDLE: RefCell<Option<Rc<AppState>>> = RefCell::new(None);
|
||||
}
|
||||
|
||||
struct ChatState {
|
||||
messages: Vec<UserChatMessage>,
|
||||
message_nodes: Vec<MessageNode>,
|
||||
is_receiving: bool,
|
||||
active_assistant_index: Option<usize>,
|
||||
compose_role: ComposeRole,
|
||||
generation_enabled: bool,
|
||||
}
|
||||
|
||||
struct MessageNode {
|
||||
wrapper: Element,
|
||||
content: Element,
|
||||
status: Option<Element>,
|
||||
}
|
||||
|
||||
/// Status shown under a chat message.
#[derive(Copy, Clone)]
enum ChatStatus {
    /// Rendered as "...".
    Pending,
    /// Rendered as "[canceled]".
    Cancelled,
    /// Rendered with empty text and the `--hidden` modifier class.
    Hidden,
}
|
||||
|
||||
/// Role attached to a message composed in the input box.
#[derive(Copy, Clone, PartialEq)]
enum ComposeRole {
    System,
    Assistant,
    User,
}

impl ComposeRole {
    /// Lower-case role name ("system", "assistant" or "user").
    fn as_str(self) -> &'static str {
        match self {
            Self::System => "system",
            Self::Assistant => "assistant",
            Self::User => "user",
        }
    }

    /// Next role in the cycle System -> Assistant -> User -> System.
    fn next(self) -> Self {
        match self {
            Self::System => Self::Assistant,
            Self::Assistant => Self::User,
            Self::User => Self::System,
        }
    }
}
|
||||
|
||||
struct AppState {
|
||||
ws: WebSocket,
|
||||
document: Document,
|
||||
messages_container: Element,
|
||||
input: HtmlTextAreaElement,
|
||||
status_label: Element,
|
||||
state: RefCell<ChatState>,
|
||||
}
|
||||
|
||||
fn append_message(
|
||||
document: &Document,
|
||||
container: &Element,
|
||||
role: &str,
|
||||
content: &str,
|
||||
status: Option<ChatStatus>,
|
||||
) -> Result<MessageNode, JsValue> {
|
||||
let wrapper = document.create_element("div")?;
|
||||
wrapper.set_class_name(&format!("chat-message chat-message--{}", role));
|
||||
|
||||
let role_el = document.create_element("div")?;
|
||||
role_el.set_class_name("chat-message__role");
|
||||
role_el.set_text_content(Some(role));
|
||||
|
||||
let content_el = document.create_element("div")?;
|
||||
content_el.set_class_name("chat-message__content");
|
||||
content_el.set_text_content(Some(content));
|
||||
|
||||
wrapper.append_child(&role_el)?;
|
||||
wrapper.append_child(&content_el)?;
|
||||
|
||||
let status_el = if let Some(status) = status {
|
||||
let status_el = document.create_element("div")?;
|
||||
apply_status(&status_el, status);
|
||||
wrapper.append_child(&status_el)?;
|
||||
Some(status_el)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
container.append_child(&wrapper)?;
|
||||
|
||||
Ok(MessageNode {
|
||||
wrapper,
|
||||
content: content_el,
|
||||
status: status_el,
|
||||
})
|
||||
}
|
||||
|
||||
fn set_message_content(node: &Element, content: &str) {
|
||||
node.set_text_content(Some(content));
|
||||
}
|
||||
|
||||
fn apply_status(node: &Element, status: ChatStatus) {
|
||||
match status {
|
||||
ChatStatus::Pending => {
|
||||
node.set_class_name("chat-message__status chat-message__status--pending");
|
||||
node.set_text_content(Some("..."));
|
||||
}
|
||||
ChatStatus::Cancelled => {
|
||||
node.set_class_name("chat-message__status chat-message__status--cancelled");
|
||||
node.set_text_content(Some("[canceled]"));
|
||||
}
|
||||
ChatStatus::Hidden => {
|
||||
node.set_class_name("chat-message__status chat-message__status--hidden");
|
||||
node.set_text_content(Some(""));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write the current compose role and generation on/off state into `label`.
fn update_status_label(label: &Element, role: ComposeRole, generation_enabled: bool) {
    let generation = match generation_enabled {
        true => "on",
        false => "off",
    };
    label.set_text_content(Some(&format!(
        "Role: {} | Generation: {}",
        role.as_str(),
        generation
    )));
}
|
||||
|
||||
fn scroll_to_bottom(container: &Element) {
|
||||
if let Some(element) = container.dyn_ref::<HtmlElement>() {
|
||||
let height = element.scroll_height();
|
||||
element.set_scroll_top(height);
|
||||
}
|
||||
static WS_HANDLE: RefCell<Option<Rc<WebSocket>>> = RefCell::new(None);
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
pub fn init_chat() -> Result<(), JsValue> {
|
||||
let window = window().ok_or_else(|| JsValue::from_str("Missing window"))?;
|
||||
let document = window.document().ok_or_else(|| JsValue::from_str("Missing document"))?;
|
||||
let location = window.location();
|
||||
let host = location.host()?;
|
||||
let protocol = location.protocol()?;
|
||||
let scheme = if protocol == "https:" { "wss" } else { "ws" };
|
||||
let ws_url = format!("{scheme}://{host}/chat");
|
||||
|
||||
let messages_container = document
|
||||
.get_element_by_id("chat-messages")
|
||||
.ok_or_else(|| JsValue::from_str("Missing chat-messages element"))?;
|
||||
let input_el: HtmlTextAreaElement = document
|
||||
.get_element_by_id("chat-input")
|
||||
.ok_or_else(|| JsValue::from_str("Missing chat-input element"))?
|
||||
.dyn_into()?;
|
||||
let status_label = document
|
||||
.get_element_by_id("chat-status")
|
||||
.ok_or_else(|| JsValue::from_str("Missing chat-status element"))?;
|
||||
|
||||
let ws = WebSocket::new(&ws_url)?;
|
||||
let ws = Rc::new(WebSocket::new(&ws_url)?);
|
||||
ws.set_binary_type(BinaryType::Arraybuffer);
|
||||
console::log_1(&format!("[ws] connecting to {ws_url}").into());
|
||||
|
||||
let app = Rc::new(AppState {
|
||||
ws,
|
||||
document,
|
||||
messages_container,
|
||||
input: input_el,
|
||||
status_label,
|
||||
state: RefCell::new(ChatState {
|
||||
messages: Vec::new(),
|
||||
message_nodes: Vec::new(),
|
||||
is_receiving: false,
|
||||
active_assistant_index: None,
|
||||
compose_role: ComposeRole::User,
|
||||
generation_enabled: true,
|
||||
}),
|
||||
});
|
||||
|
||||
{
|
||||
let state = app.state.borrow();
|
||||
update_status_label(&app.status_label, state.compose_role, state.generation_enabled);
|
||||
}
|
||||
|
||||
let ws_for_open = ws.clone();
|
||||
let onopen = Closure::<dyn FnMut(Event)>::wrap(Box::new(move |_| {
|
||||
console::log_1(&"[ws] connected".into());
|
||||
let requests = [
|
||||
UserRequest {
|
||||
request_id: 1,
|
||||
payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: r#"this is the first message first chat
|
||||
too bad it is in lower case one one one"#
|
||||
.to_string(),
|
||||
},
|
||||
])),
|
||||
},
|
||||
UserRequest {
|
||||
request_id: 2,
|
||||
payload: UserRequestPayload::ChatCompletion(UserChatCompletionRequest::new(vec![
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: "And this is message two of chat request 2. \
|
||||
Too bad it isn't processed two two two"
|
||||
.to_string(),
|
||||
},
|
||||
])),
|
||||
},
|
||||
];
|
||||
|
||||
for request in requests {
|
||||
match request.to_bytes() {
|
||||
Ok(bytes) => {
|
||||
console::log_1(&format!("[ws] sending request_id={} bytes={}", request.request_id, bytes.len()).into());
|
||||
let _ = ws_for_open.send_with_u8_array(&bytes);
|
||||
}
|
||||
Err(err) => {
|
||||
console::error_1(&format!("[ws] encode error: {err:#}").into());
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
app.ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
|
||||
ws.set_onopen(Some(onopen.as_ref().unchecked_ref()));
|
||||
onopen.forget();
|
||||
|
||||
let app_for_keydown = app.clone();
|
||||
let onkeydown = Closure::<dyn FnMut(KeyboardEvent)>::wrap(Box::new(move |event: KeyboardEvent| {
|
||||
if event.ctrl_key() {
|
||||
let key = event.key();
|
||||
if key == "c" || key == "C" {
|
||||
let state = app_for_keydown.state.borrow();
|
||||
if state.is_receiving {
|
||||
event.prevent_default();
|
||||
if app_for_keydown.ws.ready_state() != WebSocket::OPEN {
|
||||
console::error_1(&"[ws] socket is not open".into());
|
||||
return;
|
||||
}
|
||||
|
||||
let request = UserRequest::ChatCompletionCancellation;
|
||||
match request.to_bytes() {
|
||||
Ok(bytes) => {
|
||||
if let Err(err) = app_for_keydown.ws.send_with_u8_array(&bytes) {
|
||||
console::error_1(&format!("[ws] cancel send error: {:?}", err).into());
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
console::error_1(&format!("[ws] cancel encode error: {err:#}").into());
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if key == "1" {
|
||||
event.prevent_default();
|
||||
let mut state = app_for_keydown.state.borrow_mut();
|
||||
state.compose_role = state.compose_role.next();
|
||||
update_status_label(
|
||||
&app_for_keydown.status_label,
|
||||
state.compose_role,
|
||||
state.generation_enabled,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if key == "4" {
|
||||
event.prevent_default();
|
||||
let mut state = app_for_keydown.state.borrow_mut();
|
||||
state.generation_enabled = !state.generation_enabled;
|
||||
update_status_label(
|
||||
&app_for_keydown.status_label,
|
||||
state.compose_role,
|
||||
state.generation_enabled,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if key == "d" || key == "D" {
|
||||
event.prevent_default();
|
||||
let mut state = app_for_keydown.state.borrow_mut();
|
||||
if state.is_receiving {
|
||||
return;
|
||||
}
|
||||
let last_message = match state.messages.pop() {
|
||||
Some(message) => message,
|
||||
None => return,
|
||||
};
|
||||
if let Some(node) = state.message_nodes.pop() {
|
||||
if let Some(parent) = node.wrapper.parent_node() {
|
||||
let _ = parent.remove_child(&node.wrapper);
|
||||
}
|
||||
}
|
||||
app_for_keydown.input.set_value(&last_message.content);
|
||||
scroll_to_bottom(&app_for_keydown.messages_container);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if event.key() != "Enter" || event.shift_key() {
|
||||
return;
|
||||
}
|
||||
event.prevent_default();
|
||||
|
||||
let mut state = app_for_keydown.state.borrow_mut();
|
||||
if state.is_receiving {
|
||||
return;
|
||||
}
|
||||
|
||||
let raw = app_for_keydown.input.value();
|
||||
let trimmed = raw.trim();
|
||||
if trimmed.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let role = state.compose_role;
|
||||
let user_content = trimmed.to_string();
|
||||
app_for_keydown.input.set_value("");
|
||||
|
||||
let user_node = match append_message(
|
||||
&app_for_keydown.document,
|
||||
&app_for_keydown.messages_container,
|
||||
role.as_str(),
|
||||
&user_content,
|
||||
None,
|
||||
) {
|
||||
Ok(node) => node,
|
||||
Err(err) => {
|
||||
console::error_1(&format!("[ui] failed to append user message: {:?}", err).into());
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
state.messages.push(UserChatMessage {
|
||||
role: role.as_str().to_string(),
|
||||
content: user_content,
|
||||
});
|
||||
state.message_nodes.push(user_node);
|
||||
scroll_to_bottom(&app_for_keydown.messages_container);
|
||||
|
||||
if role != ComposeRole::User || !state.generation_enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
if app_for_keydown.ws.ready_state() != WebSocket::OPEN {
|
||||
console::error_1(&"[ws] socket is not open".into());
|
||||
return;
|
||||
}
|
||||
|
||||
let history = state.messages.clone();
|
||||
let request = UserRequest::ChatCompletion(UserChatCompletionRequest::new(history));
|
||||
|
||||
let bytes = match request.to_bytes() {
|
||||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
console::error_1(&format!("[ws] encode error: {err:#}").into());
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(err) = app_for_keydown.ws.send_with_u8_array(&bytes) {
|
||||
console::error_1(&format!("[ws] send error: {:?}", err).into());
|
||||
return;
|
||||
}
|
||||
|
||||
let assistant_node = match append_message(
|
||||
&app_for_keydown.document,
|
||||
&app_for_keydown.messages_container,
|
||||
"assistant",
|
||||
"",
|
||||
Some(ChatStatus::Pending),
|
||||
) {
|
||||
Ok(node) => node,
|
||||
Err(err) => {
|
||||
console::error_1(&format!("[ui] failed to append assistant message: {:?}", err).into());
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
state.messages.push(UserChatMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: String::new(),
|
||||
});
|
||||
state.message_nodes.push(assistant_node);
|
||||
state.active_assistant_index = Some(state.messages.len() - 1);
|
||||
state.is_receiving = true;
|
||||
scroll_to_bottom(&app_for_keydown.messages_container);
|
||||
}));
|
||||
app.input.set_onkeydown(Some(onkeydown.as_ref().unchecked_ref()));
|
||||
onkeydown.forget();
|
||||
|
||||
let app_for_message = app.clone();
|
||||
let onmessage = Closure::<dyn FnMut(MessageEvent)>::wrap(Box::new(move |event: MessageEvent| {
|
||||
let data = event.data();
|
||||
if let Some(text) = data.as_string() {
|
||||
@ -391,73 +87,36 @@ pub fn init_chat() -> Result<(), JsValue> {
|
||||
|
||||
let data = Uint8Array::new(&data);
|
||||
let bytes = data.to_vec();
|
||||
console::log_1(&format!("[ws] received bytes={}", bytes.len()).into());
|
||||
let response = match UserResponse::from_bytes(&bytes) {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
console::error_1(&format!("[ws] decode error: {err:#}").into());
|
||||
console::error_1(&format!("[ws] decode error: {err:#} (bytes={})", bytes.len()).into());
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut state = app_for_message.state.borrow_mut();
|
||||
let assistant_index = match state.active_assistant_index {
|
||||
Some(index) => index,
|
||||
None => {
|
||||
console::log_1(&"[ws] missing assistant index".into());
|
||||
return;
|
||||
match response.payload {
|
||||
UserResponsePayload::ChatCompletion(payload) => {
|
||||
console::log_1(&format!("[ws] request_id={} piece={}", response.request_id, payload.piece).into());
|
||||
}
|
||||
};
|
||||
|
||||
match response {
|
||||
UserResponse::ChatCompletion(completion) => {
|
||||
if let Some(message) = state.messages.get_mut(assistant_index) {
|
||||
message.content.push_str(&completion.piece);
|
||||
}
|
||||
if let Some(node) = state.message_nodes.get(assistant_index) {
|
||||
if let Some(message) = state.messages.get(assistant_index) {
|
||||
set_message_content(&node.content, &message.content);
|
||||
}
|
||||
}
|
||||
scroll_to_bottom(&app_for_message.messages_container);
|
||||
UserResponsePayload::ChatCompletionCancellation(_) => {
|
||||
console::log_1(&format!("[ws] request_id={} [cancel]", response.request_id).into());
|
||||
}
|
||||
UserResponse::ChatCompletionEnd => {
|
||||
state.is_receiving = false;
|
||||
if let Some(node) = state
|
||||
.active_assistant_index
|
||||
.and_then(|index| state.message_nodes.get(index))
|
||||
{
|
||||
if let Some(status_node) = node.status.as_ref() {
|
||||
apply_status(status_node, ChatStatus::Hidden);
|
||||
}
|
||||
}
|
||||
state.active_assistant_index = None;
|
||||
let _ = app_for_message.input.focus();
|
||||
}
|
||||
UserResponse::ChatCompletionCancellation => {
|
||||
state.is_receiving = false;
|
||||
if let Some(node) = state
|
||||
.active_assistant_index
|
||||
.and_then(|index| state.message_nodes.get(index))
|
||||
{
|
||||
if let Some(status_node) = node.status.as_ref() {
|
||||
apply_status(status_node, ChatStatus::Cancelled);
|
||||
}
|
||||
}
|
||||
state.active_assistant_index = None;
|
||||
let _ = app_for_message.input.focus();
|
||||
UserResponsePayload::ChatCompletionEnd(_) => {
|
||||
console::log_1(&format!("[ws] request_id={} [end]", response.request_id).into());
|
||||
}
|
||||
}
|
||||
}));
|
||||
app.ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
|
||||
ws.set_onmessage(Some(onmessage.as_ref().unchecked_ref()));
|
||||
onmessage.forget();
|
||||
|
||||
let onerror = Closure::<dyn FnMut(ErrorEvent)>::wrap(Box::new(move |event: ErrorEvent| {
|
||||
console::error_1(&format!("[ws] error: {}", event.message()).into());
|
||||
}));
|
||||
app.ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
|
||||
ws.set_onerror(Some(onerror.as_ref().unchecked_ref()));
|
||||
onerror.forget();
|
||||
|
||||
APP_HANDLE.with(|slot| *slot.borrow_mut() = Some(app.clone()));
|
||||
let _ = app.input.focus();
|
||||
WS_HANDLE.with(|slot| *slot.borrow_mut() = Some(ws.clone()));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -7,158 +7,6 @@ body {
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.chat-html,
|
||||
.chat-body {
|
||||
height: 100%;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.chat-body {
|
||||
margin: 0;
|
||||
background: #ffffff;
|
||||
color: #1b1b1b;
|
||||
font-family: "Fira Sans", "Space Grotesk", "Montserrat", sans-serif;
|
||||
font-size: 14pt;
|
||||
display: block;
|
||||
overflow: auto;
|
||||
}
|
||||
|
||||
.chat-shell {
|
||||
min-height: 100%;
|
||||
padding: 32px 24px;
|
||||
box-sizing: border-box;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.chat-panel {
|
||||
width: min(100%, max(70%, 1000pt));
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 16px;
|
||||
min-height: 0;
|
||||
}
|
||||
|
||||
.chat-header {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.chat-title {
|
||||
margin: 0;
|
||||
font-size: 20pt;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.chat-subtitle {
|
||||
font-size: 10pt;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1px;
|
||||
color: #6b6b6b;
|
||||
}
|
||||
|
||||
.chat-container {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 16px;
|
||||
min-height: 0;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.chat-messages {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 12px;
|
||||
overflow-y: auto;
|
||||
padding-right: 4px;
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.chat-message {
|
||||
width: 100%;
|
||||
border-radius: 14px;
|
||||
padding: 12px 14px;
|
||||
box-sizing: border-box;
|
||||
border: 1px solid #e2e2e2;
|
||||
background: #fafafa;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
.chat-message--user {
|
||||
background: #eef4ff;
|
||||
border-color: #d5e2ff;
|
||||
}
|
||||
|
||||
.chat-message--assistant {
|
||||
background: #f7f2ff;
|
||||
border-color: #e5d7ff;
|
||||
}
|
||||
|
||||
.chat-message--system {
|
||||
background: #f8f8f8;
|
||||
border-color: #dddddd;
|
||||
}
|
||||
|
||||
.chat-message__role {
|
||||
font-size: 9pt;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1px;
|
||||
color: #6b6b6b;
|
||||
}
|
||||
|
||||
.chat-message__content {
|
||||
white-space: pre-wrap;
|
||||
line-height: 1.45;
|
||||
}
|
||||
|
||||
.chat-message__status {
|
||||
font-size: 9pt;
|
||||
color: #9aa0a6;
|
||||
}
|
||||
|
||||
.chat-message__status--pending {
|
||||
color: #9aa0a6;
|
||||
}
|
||||
|
||||
.chat-message__status--cancelled {
|
||||
color: #c0392b;
|
||||
}
|
||||
|
||||
.chat-message__status--hidden {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.chat-input-area {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
.chat-input {
|
||||
width: 100%;
|
||||
min-height: 120px;
|
||||
padding: 12px 14px;
|
||||
border-radius: 12px;
|
||||
border: 1px solid #d4d4d4;
|
||||
font-size: 14pt;
|
||||
font-family: inherit;
|
||||
resize: vertical;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.chat-input:disabled {
|
||||
background: #f2f2f2;
|
||||
color: #8a8a8a;
|
||||
}
|
||||
|
||||
.chat-status-label {
|
||||
font-size: 15pt;
|
||||
color: #4b4b4b;
|
||||
}
|
||||
|
||||
body {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
|
||||
@ -10,7 +10,7 @@ pub use self::utils::{
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct UserChatMessage {
|
||||
pub struct ChatMessage {
|
||||
#[deku(
|
||||
reader = "read_pascal_string(deku::reader)",
|
||||
writer = "write_pascal_string(deku::writer, &self.role)"
|
||||
@ -31,23 +31,33 @@ pub struct UserChatCompletionRequest {
|
||||
reader = "read_vec_u32(deku::reader)",
|
||||
writer = "write_vec_u32(deku::writer, &self.messages)"
|
||||
)]
|
||||
pub messages: Vec<UserChatMessage>,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
}
|
||||
|
||||
impl UserChatCompletionRequest {
|
||||
pub fn new(messages: Vec<UserChatMessage>) -> Self {
|
||||
pub fn new(messages: Vec<ChatMessage>) -> Self {
|
||||
Self { messages }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct UserChatCompletionCancellationRequest;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
#[deku(id_type = "u8")]
|
||||
#[repr(u8)]
|
||||
pub enum UserRequest {
|
||||
pub enum UserRequestPayload {
|
||||
#[deku(id = "0")]
|
||||
ChatCompletion(UserChatCompletionRequest),
|
||||
#[deku(id = "1")]
|
||||
ChatCompletionCancellation,
|
||||
ChatCompletionCancellation(UserChatCompletionCancellationRequest),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct UserRequest {
|
||||
#[deku(endian = "little")]
|
||||
pub request_id: u64,
|
||||
pub payload: UserRequestPayload,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
@ -59,23 +69,42 @@ pub struct UserResponseChatCompletion {
|
||||
pub piece: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct UserResponseChatCompletionCancellation;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct UserResponseChatCompletionEnd;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
#[deku(id_type = "u8")]
|
||||
#[repr(u8)]
|
||||
pub enum UserResponse {
|
||||
pub enum UserResponsePayload {
|
||||
#[deku(id = "0")]
|
||||
ChatCompletion(UserResponseChatCompletion),
|
||||
#[deku(id = "1")]
|
||||
ChatCompletionCancellation,
|
||||
ChatCompletionCancellation(UserResponseChatCompletionCancellation),
|
||||
#[deku(id = "2")]
|
||||
ChatCompletionEnd,
|
||||
ChatCompletionEnd(UserResponseChatCompletionEnd),
|
||||
}
|
||||
|
||||
impl DekuBytes for UserChatMessage {}
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct UserResponse {
|
||||
#[deku(endian = "little")]
|
||||
pub request_id: u64,
|
||||
pub payload: UserResponsePayload,
|
||||
}
|
||||
|
||||
impl DekuBytes for ChatMessage {}
|
||||
|
||||
impl DekuBytes for UserChatCompletionRequest {}
|
||||
|
||||
impl DekuBytes for UserChatCompletionCancellationRequest {}
|
||||
|
||||
impl DekuBytes for UserRequestPayload {}
|
||||
impl DekuBytes for UserRequest {}
|
||||
|
||||
impl DekuBytes for UserResponseChatCompletion {}
|
||||
impl DekuBytes for UserResponseChatCompletionCancellation {}
|
||||
impl DekuBytes for UserResponseChatCompletionEnd {}
|
||||
impl DekuBytes for UserResponsePayload {}
|
||||
impl DekuBytes for UserResponse {}
|
||||
|
||||
7
website/src/bin/E.rs
Normal file
7
website/src/bin/E.rs
Normal file
@ -0,0 +1,7 @@
|
||||
use website::dedicated_ai_server::TEST::main_E;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
main_E().await?;
|
||||
Ok(())
|
||||
}
|
||||
7
website/src/bin/F.rs
Normal file
7
website/src/bin/F.rs
Normal file
@ -0,0 +1,7 @@
|
||||
use website::dedicated_ai_server::TEST::main_F;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
main_F().await?;
|
||||
Ok(())
|
||||
}
|
||||
117
website/src/dedicated_ai_server/TEST.rs
Normal file
117
website/src/dedicated_ai_server/TEST.rs
Normal file
@ -0,0 +1,117 @@
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
use anyhow::Result;
|
||||
use crate::dedicated_ai_server::api::{
|
||||
ChatCompletionRequest,
|
||||
MessageInChat,
|
||||
Request,
|
||||
RequestPayload,
|
||||
};
|
||||
use frontend_protocol::DekuBytes;
|
||||
use crate::dedicated_ai_server::talking::{SecretStreamSocket, wrap_connection_socket};
|
||||
use crate::dedicated_ai_server::talking::{ProtocolError, FrameCallback};
|
||||
|
||||
|
||||
const SERVER_HOST: &str = "127.0.0.1";
|
||||
const SERVER_PORT: u16 = 9000;
|
||||
const SHARED_SECRET: &str = "change-me";
|
||||
|
||||
/* client =========== aka F.py */
|
||||
|
||||
|
||||
async fn connect_to_server(host: &str, port: u16, shared_secret: &str) -> Result<SecretStreamSocket> {
|
||||
let socket = TcpStream::connect((host, port)).await?;
|
||||
wrap_connection_socket(socket, shared_secret).await
|
||||
}
|
||||
|
||||
|
||||
async fn send_message(connection: &mut SecretStreamSocket, request_id: u64, role: &str, content: &str) -> Result<()> {
|
||||
let batch = ChatCompletionRequest::new(vec![MessageInChat {
|
||||
role: role.to_string(),
|
||||
content: content.to_string(),
|
||||
}]);
|
||||
let request = Request {
|
||||
request_id,
|
||||
payload: RequestPayload::ChatCompletion(batch),
|
||||
};
|
||||
let payload = request.to_bytes()?;
|
||||
connection.send_frame(&payload).await
|
||||
}
|
||||
|
||||
|
||||
async fn close_connection(connection: &mut SecretStreamSocket) -> Result<()> {
|
||||
connection.close(true).await
|
||||
}
|
||||
|
||||
|
||||
pub async fn main_F() -> Result<()> {
|
||||
let mut connection = connect_to_server(SERVER_HOST, SERVER_PORT, SHARED_SECRET).await?;
|
||||
send_message(&mut connection, 1, "user", "hello from Rust client").await?;
|
||||
close_connection(&mut connection).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
/* server========= aka E.py */
|
||||
|
||||
fn print_batch(peer: &str, request_id: u64, batch: &ChatCompletionRequest) {
|
||||
println!(
|
||||
"[packet] {} sent request_id={} with {} message(s)",
|
||||
peer,
|
||||
request_id,
|
||||
batch.messages.len(),
|
||||
);
|
||||
for (index, message) in batch.messages.iter().enumerate() {
|
||||
println!(
|
||||
" [{}] role={:?} content={:?}",
|
||||
index, message.role, message.content
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
async fn handle_client(client_socket: TcpStream, peer: String, shared_secret: &str) -> Result<()> {
|
||||
let peer_for_callback = peer.clone();
|
||||
let on_frame: FrameCallback<'_> = Box::new(move |frame: Vec<u8>| {
|
||||
let request = Request::from_bytes(&frame)?;
|
||||
match request.payload {
|
||||
RequestPayload::ChatCompletion(batch) => {
|
||||
print_batch(&peer_for_callback, request.request_id, &batch);
|
||||
}
|
||||
RequestPayload::ChatCompletionCancellation(_) => {
|
||||
println!(
|
||||
"[packet] {} sent request_id={} cancel",
|
||||
peer_for_callback,
|
||||
request.request_id,
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
});
|
||||
|
||||
let mut transport = wrap_connection_socket(client_socket, shared_secret).await?;
|
||||
println!("[connected] {}", peer);
|
||||
let result = transport.run_receiving_loop(on_frame).await;
|
||||
println!("[disconnected] {}", peer);
|
||||
let _ = transport.close(true).await;
|
||||
result
|
||||
}
|
||||
|
||||
|
||||
pub async fn main_E() -> Result<()> {
|
||||
let listener = TcpListener::bind((SERVER_HOST, SERVER_PORT)).await?;
|
||||
println!("[listening] {}:{}", SERVER_HOST, SERVER_PORT);
|
||||
|
||||
loop {
|
||||
let (client_socket, addr) = listener.accept().await?;
|
||||
let peer = addr.to_string();
|
||||
|
||||
if let Err(err) = handle_client(client_socket, peer.clone(), SHARED_SECRET).await {
|
||||
if err.downcast_ref::<ProtocolError>().is_some() {
|
||||
eprintln!("[protocol error] {}: {:#}", peer, err);
|
||||
} else {
|
||||
eprintln!("[error] {}: {:#}", peer, err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -40,6 +40,9 @@ impl ChatCompletionRequest {
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct ChatCompletionCancellationRequest;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
#[deku(id_type = "u8")]
|
||||
#[repr(u8)]
|
||||
@ -47,7 +50,7 @@ pub enum RequestPayload {
|
||||
#[deku(id = "0")]
|
||||
ChatCompletion(ChatCompletionRequest),
|
||||
#[deku(id = "1")]
|
||||
ChatCompletionCancellation,
|
||||
ChatCompletionCancellation(ChatCompletionCancellationRequest),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
@ -66,6 +69,12 @@ pub struct ResponseChatCompletion {
|
||||
pub piece: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct ResponseChatCompletionCancellation;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
pub struct ResponseChatCompletionEnd;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
#[deku(id_type = "u8")]
|
||||
#[repr(u8)]
|
||||
@ -73,9 +82,9 @@ pub enum ResponsePayload {
|
||||
#[deku(id = "0")]
|
||||
ChatCompletion(ResponseChatCompletion),
|
||||
#[deku(id = "1")]
|
||||
ChatCompletionCancellation,
|
||||
ChatCompletionCancellation(ResponseChatCompletionCancellation),
|
||||
#[deku(id = "2")]
|
||||
ChatCompletionEnd
|
||||
ChatCompletionEnd(ResponseChatCompletionEnd)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, DekuRead, DekuWrite)]
|
||||
@ -89,9 +98,12 @@ impl DekuBytes for MessageInChat {}
|
||||
|
||||
impl DekuBytes for ChatCompletionRequest {}
|
||||
|
||||
impl DekuBytes for ChatCompletionCancellationRequest {}
|
||||
|
||||
impl DekuBytes for RequestPayload {}
|
||||
impl DekuBytes for Request {}
|
||||
|
||||
impl DekuBytes for ResponseChatCompletion {}
|
||||
impl DekuBytes for ResponseChatCompletionCancellation {}
|
||||
impl DekuBytes for ResponsePayload {}
|
||||
impl DekuBytes for Response {}
|
||||
|
||||
@ -32,14 +32,9 @@ struct PendingChatCompletionRecord {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum RequestWorkItem {
|
||||
ChatCompletion {
|
||||
request: Request,
|
||||
response_tx: mpsc::UnboundedSender<MessagePiece>,
|
||||
},
|
||||
ChatCompletionCancellation {
|
||||
request_id: u64,
|
||||
},
|
||||
struct RequestWorkItem {
|
||||
request: Request,
|
||||
response_tx: mpsc::UnboundedSender<MessagePiece>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -75,7 +70,7 @@ impl DedicatedAiServerConnection {
|
||||
pub fn send_chat_completion(
|
||||
&self,
|
||||
messages: Vec<MessageInChat>,
|
||||
) -> Result<(mpsc::UnboundedReceiver<MessagePiece>, u64)> {
|
||||
) -> Result<mpsc::UnboundedReceiver<MessagePiece>> {
|
||||
let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
|
||||
let request = Request {
|
||||
request_id,
|
||||
@ -83,16 +78,9 @@ impl DedicatedAiServerConnection {
|
||||
};
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
self.request_tx
|
||||
.send(RequestWorkItem::ChatCompletion { request, response_tx })
|
||||
.send(RequestWorkItem { request, response_tx })
|
||||
.map_err(|err| anyhow::anyhow!("failed to enqueue request: {err}"))?;
|
||||
Ok((response_rx, request_id))
|
||||
}
|
||||
|
||||
pub fn send_chat_completion_cancellation(&self, request_id: u64) -> Result<()> {
|
||||
self.request_tx
|
||||
.send(RequestWorkItem::ChatCompletionCancellation { request_id })
|
||||
.map_err(|err| anyhow::anyhow!("failed to enqueue cancellation: {err}"))?;
|
||||
Ok(())
|
||||
Ok(response_rx)
|
||||
}
|
||||
}
|
||||
|
||||
@ -157,26 +145,14 @@ async fn run_connected_loop(
|
||||
Some(item) => item,
|
||||
None => return Err(anyhow::anyhow!("request channel closed")), // idk why this would happen
|
||||
};
|
||||
match item {
|
||||
RequestWorkItem::ChatCompletion { request, response_tx } => {
|
||||
let request_id = request.request_id;
|
||||
pending.insert(
|
||||
request_id,
|
||||
PendingChatCompletionRecord { response_tx },
|
||||
);
|
||||
let request_id = item.request.request_id;
|
||||
pending.insert(
|
||||
request_id,
|
||||
PendingChatCompletionRecord { response_tx: item.response_tx },
|
||||
);
|
||||
|
||||
let payload = request.to_bytes().context("failed to encode request")?;
|
||||
transport.send_frame(&payload).await.context("failed to send request")?;
|
||||
}
|
||||
RequestWorkItem::ChatCompletionCancellation { request_id } => {
|
||||
let request = Request {
|
||||
request_id,
|
||||
payload: RequestPayload::ChatCompletionCancellation,
|
||||
};
|
||||
let payload = request.to_bytes().context("failed to encode cancellation")?;
|
||||
transport.send_frame(&payload).await.context("failed to send cancellation")?;
|
||||
}
|
||||
}
|
||||
let payload = item.request.to_bytes().context("failed to encode request")?;
|
||||
transport.send_frame(&payload).await.context("failed to send request")?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -207,12 +183,12 @@ fn handle_response_frame(
|
||||
pending.remove(&request_id);
|
||||
}
|
||||
}
|
||||
ResponsePayload::ChatCompletionEnd => {
|
||||
ResponsePayload::ChatCompletionEnd(_) => {
|
||||
if let Some(record) = pending.remove(&request_id) {
|
||||
let _ = record.response_tx.send(MessagePiece::End);
|
||||
}
|
||||
}
|
||||
ResponsePayload::ChatCompletionCancellation => {
|
||||
ResponsePayload::ChatCompletionCancellation(_) => {
|
||||
if let Some(record) = pending.remove(&request_id) {
|
||||
let _ = record.response_tx.send(MessagePiece::Cancelled);
|
||||
}
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
pub mod api;
|
||||
pub mod connection;
|
||||
pub mod talking;
|
||||
pub mod TEST;
|
||||
|
||||
@ -31,8 +31,12 @@ use crate::dedicated_ai_server::connection::{MessagePiece, MessagePiecePayload};
|
||||
use frontend_protocol::{
|
||||
DekuBytes,
|
||||
UserRequest,
|
||||
UserRequestPayload,
|
||||
UserResponse,
|
||||
UserResponseChatCompletion,
|
||||
UserResponseChatCompletionCancellation,
|
||||
UserResponseChatCompletionEnd,
|
||||
UserResponsePayload,
|
||||
};
|
||||
|
||||
async fn init_app_state() -> Result<(AppState, tokio::task::JoinHandle<()>), Box<dyn std::error::Error>> {
|
||||
@ -260,11 +264,10 @@ async fn chat(
|
||||
}
|
||||
|
||||
async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
|
||||
'outer: loop {
|
||||
let msg = match socket.recv().await {
|
||||
Some(Ok(msg)) => msg,
|
||||
Some(Err(_)) => break,
|
||||
None => break,
|
||||
'outer: while let Some(msg) = socket.recv().await {
|
||||
let msg = match msg {
|
||||
Ok(msg) => msg,
|
||||
Err(_) => break,
|
||||
};
|
||||
|
||||
match msg {
|
||||
@ -277,62 +280,60 @@ async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
|
||||
}
|
||||
};
|
||||
|
||||
let payload = match request {
|
||||
UserRequest::ChatCompletion(payload) => payload,
|
||||
UserRequest::ChatCompletionCancellation => {
|
||||
eprintln!("[chat] protocol error: unexpected cancellation without active request");
|
||||
break;
|
||||
}
|
||||
};
|
||||
match request.payload {
|
||||
UserRequestPayload::ChatCompletion(payload) => {
|
||||
let messages = payload
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| MessageInChat {
|
||||
role: msg.role,
|
||||
content: msg.content,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let messages = payload
|
||||
.messages
|
||||
.into_iter()
|
||||
.map(|msg| MessageInChat {
|
||||
role: msg.role,
|
||||
content: msg.content,
|
||||
})
|
||||
.collect();
|
||||
let mut response_rx = match state.dedicated_ai.send_chat_completion(messages) {
|
||||
Ok(rx) => rx,
|
||||
Err(err) => {
|
||||
eprintln!("[chat] failed to send request: {err:#}");
|
||||
let response = UserResponse {
|
||||
request_id: request.request_id,
|
||||
payload: UserResponsePayload::ChatCompletionCancellation(
|
||||
UserResponseChatCompletionCancellation,
|
||||
),
|
||||
};
|
||||
if let Ok(bytes) = response.to_bytes() {
|
||||
let _ = socket.send(Message::Binary(bytes)).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let (mut response_rx, request_id) = match state.dedicated_ai.
|
||||
send_chat_completion(messages)
|
||||
{
|
||||
Ok(rx) => rx,
|
||||
Err(err) => {
|
||||
eprintln!("[chat] failed to send request: {err:#}");
|
||||
let response = UserResponse::ChatCompletionCancellation;
|
||||
// todo: make to_bytes nofail
|
||||
if let Ok(bytes) = response.to_bytes() {
|
||||
let _ = socket.send(Message::Binary(bytes)).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
piece = response_rx.recv() => {
|
||||
let piece = match piece {
|
||||
Some(piece) => piece,
|
||||
None => break,
|
||||
};
|
||||
let (response, should_break) = match piece {
|
||||
while let Some(piece) = response_rx.recv().await {
|
||||
let (payload, should_break) = match piece {
|
||||
MessagePiece::Piece(MessagePiecePayload(text)) => (
|
||||
UserResponse::ChatCompletion(UserResponseChatCompletion {
|
||||
UserResponsePayload::ChatCompletion(UserResponseChatCompletion {
|
||||
piece: text,
|
||||
}),
|
||||
false,
|
||||
),
|
||||
MessagePiece::End => (
|
||||
UserResponse::ChatCompletionEnd,
|
||||
UserResponsePayload::ChatCompletionEnd(
|
||||
UserResponseChatCompletionEnd,
|
||||
),
|
||||
true,
|
||||
),
|
||||
MessagePiece::Cancelled | MessagePiece::DedicatedServerDisconnected => (
|
||||
UserResponse::ChatCompletionCancellation,
|
||||
UserResponsePayload::ChatCompletionCancellation(
|
||||
UserResponseChatCompletionCancellation,
|
||||
),
|
||||
true,
|
||||
),
|
||||
};
|
||||
|
||||
let response = UserResponse {
|
||||
request_id: request.request_id,
|
||||
payload,
|
||||
};
|
||||
let bytes = match response.to_bytes() {
|
||||
Ok(bytes) => bytes,
|
||||
Err(err) => {
|
||||
@ -347,57 +348,24 @@ async fn handle_chat_socket(mut socket: WebSocket, state: AppState) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
msg = socket.recv() => {
|
||||
let msg = match msg {
|
||||
Some(Ok(msg)) => msg,
|
||||
Some(Err(_)) => break 'outer,
|
||||
None => break 'outer,
|
||||
};
|
||||
|
||||
match msg {
|
||||
Message::Binary(bytes) => {
|
||||
let request = match UserRequest::from_bytes(&bytes) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
eprintln!("[chat] failed to decode request: {err:#}");
|
||||
break 'outer;
|
||||
}
|
||||
};
|
||||
|
||||
match request {
|
||||
UserRequest::ChatCompletionCancellation => {
|
||||
if let Err(err) = state
|
||||
.dedicated_ai
|
||||
.send_chat_completion_cancellation(request_id)
|
||||
{
|
||||
eprintln!("[chat] failed to send cancellation: {err:#}");
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
UserRequest::ChatCompletion(_) => {
|
||||
eprintln!("[chat] protocol error: chat completion while receiving");
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
}
|
||||
Message::Close(_) => break 'outer,
|
||||
_ => {
|
||||
eprintln!("[chat] protocol error: unexpected non-binary message");
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
UserRequestPayload::ChatCompletionCancellation(_) => {
|
||||
let response = UserResponse {
|
||||
request_id: request.request_id,
|
||||
payload: UserResponsePayload::ChatCompletionCancellation(
|
||||
UserResponseChatCompletionCancellation,
|
||||
),
|
||||
};
|
||||
if let Ok(bytes) = response.to_bytes() {
|
||||
if socket.send(Message::Binary(bytes)).await.is_err() {
|
||||
break 'outer;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Message::Close(_) => {
|
||||
println!(" [debug] websocket closed");
|
||||
break;
|
||||
}
|
||||
_ => {
|
||||
eprintln!("[chat] protocol error: unexpected non-binary message");
|
||||
break;
|
||||
}
|
||||
Message::Close(_) => break,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user