diff --git a/dedicated_ai_server/B2.py b/dedicated_ai_server/B2.py index 1dffbc9..b319970 100644 --- a/dedicated_ai_server/B2.py +++ b/dedicated_ai_server/B2.py @@ -4,6 +4,7 @@ from huggingface_hub import snapshot_download MODEL_ID = "zai-org/GLM-4.7-Flash" +# MODEL_ID = "mlabonne/Daredevil-8B-abliterated-GGUF" def main() -> None: diff --git a/dedicated_ai_server/G.py b/dedicated_ai_server/G.py index 28021ca..cc34e0c 100644 --- a/dedicated_ai_server/G.py +++ b/dedicated_ai_server/G.py @@ -11,6 +11,7 @@ from transformers import ( ) MODEL_ID = "zai-org/GLM-4.7-Flash" +# MODEL_ID = "/home/gregory/programming/testWithPython/Daredevil-GGUF/daredevil-8b-abliterated.Q8_0.gguf" MAX_NEW_TOKENS = 256 SYSTEM_PROMPT = ( diff --git a/dedicated_ai_server/pyproject.toml b/dedicated_ai_server/pyproject.toml index 87f1cf9..2b10624 100644 --- a/dedicated_ai_server/pyproject.toml +++ b/dedicated_ai_server/pyproject.toml @@ -4,6 +4,7 @@ name = "dedicated_ai_server" version = "0.1.0" requires-python = ">=3.13" dependencies = [ + "accelerate>=1.13.0", "asyncio>=4.0.0", "construct>=2.10.70", "huggingface-hub>=1.0.0", diff --git a/dedicated_ai_server/quick-setup-1.sh b/dedicated_ai_server/quick-setup-1.sh new file mode 100644 index 0000000..2c1afb6 --- /dev/null +++ b/dedicated_ai_server/quick-setup-1.sh @@ -0,0 +1,54 @@ +#!/usr/bin/env bash +set -euo pipefail + +FISH_CONFIG_DIR="${HOME}/.config/fish" +FISH_FUNCTIONS_DIR="${FISH_CONFIG_DIR}/functions" +FISH_CONFIG_FILE="${FISH_CONFIG_DIR}/config.fish" +FISH_PROMPT_FILE="${FISH_FUNCTIONS_DIR}/fish_prompt.fish" + + +if ! command -v dnf >/dev/null 2>&1; then + echo "This script is for Fedora systems with dnf." + exit 1 +fi + + +sudo dnf install -y htop tree git gcc gcc-c++ make cmake fish vim + +curl -LsSf https://astral.sh/uv/install.sh | sh + +if ! sudo dnf install -y uv; then + echo "uv package is not available in the currently enabled Fedora repositories." +fi + +mkdir -p "${FISH_FUNCTIONS_DIR}" +touch "${FISH_CONFIG_FILE}" + +if ! grep -qxF 'set -g fish_greeting' "${FISH_CONFIG_FILE}"; then + printf '\nset -g fish_greeting\n' >> "${FISH_CONFIG_FILE}" +fi + +cat > "${FISH_PROMPT_FILE}" <<'EOF' +function fish_prompt + set -l last_status $status + + set_color brblue + printf '%s' (prompt_pwd) + set_color normal + + if test $last_status -ne 0 + printf ' ' + set_color red + printf '[%s]' $last_status + set_color normal + end + + printf '\n' + set_color brcyan + printf '> ' + set_color normal +end +EOF + +echo "Installed: htop tree git gcc gcc-c++ make cmake fish" +echo "Fish prompt configured. Start it with: fish" diff --git a/dedicated_ai_server/server.py b/dedicated_ai_server/server.py index 898b7ce..57072d4 100644 --- a/dedicated_ai_server/server.py +++ b/dedicated_ai_server/server.py @@ -3,9 +3,11 @@ from __future__ import annotations import asyncio import queue import threading -import time from dataclasses import dataclass, field -from typing import Dict, Iterable +from typing import Any, Dict, Iterable + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer from api import Request, Response from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket @@ -15,8 +17,9 @@ SERVER_HOST = "127.0.0.1" SERVER_PORT = 9000 SHARED_SECRET = "change-me" - -PROCESS_DELAY_SECONDS = 0.4 +MODEL_ID = "zai-org/GLM-4.7-Flash" +MAX_NEW_TOKENS = 256 +PIECE_CHUNK_SIZE = 64 @dataclass @@ -51,20 +54,81 @@ class WorkItem: record: PendingChatCompletionRecord -def extract_last_message(messages: list) -> str: - if not messages: - return "" - last = messages[-1] - if hasattr(last, "content"): - return last.content - return last.get("content", "") +@dataclass +class ModelBundle: + tokenizer: Any + model: Any -def generate_uppercase_pieces(text: str) -> Iterable[str]: - words = text.split() - for word in words: - time.sleep(PROCESS_DELAY_SECONDS) - yield word.upper() + " " +def fail(message: str) -> None: + print(message, flush=True) + raise SystemExit(1) + + +def build_context(messages: list) -> list[dict[str, str]]: + context: list[dict[str, str]] = [] + for message in messages: + if hasattr(message, "role") and hasattr(message, "content"): + role = message.role + content = message.content + else: + role = message.get("role", "") + content = message.get("content", "") + context.append({"role": role, "content": content}) + return context + + +def load_local_model() -> ModelBundle: + try: + tokenizer = AutoTokenizer.from_pretrained( + MODEL_ID, + local_files_only=True, + ) + model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + local_files_only=True, + dtype=torch.bfloat16, + device_map="auto", + ) + except OSError as exc: + fail( + f"Model {MODEL_ID!r} is not fully available in the Hugging Face cache. " + f"Run B2.py first. Original error: {exc}" + ) + + return ModelBundle( + tokenizer=tokenizer, + model=model, + ) + + +def generate_llm_pieces(bundle: ModelBundle, messages: list) -> Iterable[str]: + context = build_context(messages) + inputs = bundle.tokenizer.apply_chat_template( + context, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt", + ) + inputs = {name: tensor.to(bundle.model.device) for name, tensor in inputs.items()} + + input_ids = inputs.get("input_ids") + input_len = int(input_ids.shape[1]) if input_ids is not None else 0 + + with torch.inference_mode(): + output_ids = bundle.model.generate( + **inputs, + max_new_tokens=MAX_NEW_TOKENS, + do_sample=False, + ) + + generated_ids = output_ids[0][input_len:] + text = bundle.tokenizer.decode(generated_ids, skip_special_tokens=True) + if not text: + return [] + + return [text[i:i + PIECE_CHUNK_SIZE] for i in range(0, len(text), PIECE_CHUNK_SIZE)] def worker_loop( @@ -72,6 +136,7 @@ def worker_loop( pending: Dict[int, PendingChatCompletionRecord], pending_lock: threading.Lock, loop: asyncio.AbstractEventLoop, + model_bundle: ModelBundle, ) -> None: while True: item = work_queue.get() @@ -84,15 +149,12 @@ def worker_loop( if record.is_cancelled(): cancelled = True else: - text = extract_last_message(item.messages) - for piece in generate_uppercase_pieces(text): + for piece in generate_llm_pieces(model_bundle, item.messages): loop.call_soon_threadsafe( record.response_queue.put_nowait, MessagePiece(piece=piece, is_end=False, is_cancel=False), ) - print("[debug] got a new piece") if record.is_cancelled(): - print("[debug] record was cancelled") cancelled = True break @@ -106,6 +168,12 @@ def worker_loop( record.response_queue.put_nowait, MessagePiece(is_end=True, is_cancel=False), ) + except BaseException as exc: + print(f"[worker] generation failed: {exc}", flush=True) + loop.call_soon_threadsafe( + record.response_queue.put_nowait, + MessagePiece(is_end=True, is_cancel=True), + ) finally: with pending_lock: pending.pop(item.request_id, None) @@ -155,6 +223,38 @@ async def forward_pieces( return +async def handle_chat_request( + request: Request, + work_queue: queue.Queue[WorkItem | None], + pending: Dict[int, PendingChatCompletionRecord], + pending_lock: threading.Lock, + transport: SecretStreamSocket, + send_lock: asyncio.Lock, + response_tasks: set[asyncio.Task[None]], +) -> None: + request_id = request.request_id + print(f"[request] chat request_id={request_id}", flush=True) + response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue() + record = PendingChatCompletionRecord( + response_queue=response_queue, + ) + + with pending_lock: + if request_id in pending: + raise ProtocolError( + f"Duplicate request_id {request_id} received on this connection" + ) + pending[request_id] = record + + work_queue.put(WorkItem(request_id, list(request.payload.messages), record)) + + task = asyncio.create_task( + forward_pieces(request_id, response_queue, transport, send_lock) + ) + response_tasks.add(task) + task.add_done_callback(response_tasks.discard) + + async def handle_client( reader: asyncio.StreamReader, writer: asyncio.StreamWriter, @@ -178,27 +278,16 @@ async def handle_client( request_id = request.request_id if request.kind == "chat": - print(f"[request] chat request_id={request_id}", flush=True) - response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue() - record = PendingChatCompletionRecord( - response_queue=response_queue, + await handle_chat_request( + request, + work_queue, + pending, + pending_lock, + transport, + send_lock, + response_tasks, ) - - with pending_lock: - if request_id in pending: - raise ProtocolError( - f"Duplicate request_id {request_id} received on this connection" - ) - pending[request_id] = record - owned_request_ids.add(request_id) - work_queue.put(WorkItem(request_id, list(request.payload.messages), record)) - - task = asyncio.create_task( - forward_pieces(request_id, response_queue, transport, send_lock) - ) - response_tasks.add(task) - task.add_done_callback(response_tasks.discard) elif request.kind == "cancel": print(f"[request] cancel request_id={request_id}", flush=True) with pending_lock: @@ -224,7 +313,7 @@ async def handle_client( await transport.close() -async def run_server() -> None: +async def run_server(model_bundle: ModelBundle) -> None: pending: Dict[int, PendingChatCompletionRecord] = {} pending_lock = threading.Lock() work_queue: queue.Queue[WorkItem | None] = queue.Queue() @@ -232,7 +321,7 @@ async def run_server() -> None: loop = asyncio.get_running_loop() worker = threading.Thread( target=worker_loop, - args=(work_queue, pending, pending_lock, loop), + args=(work_queue, pending, pending_lock, loop, model_bundle), daemon=True, ) worker.start() @@ -256,7 +345,9 @@ async def run_server() -> None: def main() -> None: - asyncio.run(run_server()) + model_bundle = load_local_model() + print("[model] loaded", flush=True) + asyncio.run(run_server(model_bundle)) if __name__ == "__main__": diff --git a/dedicated_ai_server/uv.lock b/dedicated_ai_server/uv.lock index e5aa936..ecb5c87 100644 --- a/dedicated_ai_server/uv.lock +++ b/dedicated_ai_server/uv.lock @@ -2,6 +2,24 @@ version = 1 revision = 3 requires-python = ">=3.13" +[[package]] +name = "accelerate" +version = "1.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "pyyaml" }, + { name = "safetensors" }, + { name = "torch" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl", hash = "sha256:cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0", size = 383744, upload-time = "2026-03-04T19:34:10.313Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4" @@ -188,6 +206,7 @@ name = "dedicated-ai-server" version = "0.1.0" source = { virtual = "." } dependencies = [ + { name = "accelerate" }, { name = "asyncio" }, { name = "construct" }, { name = "huggingface-hub" }, @@ -198,6 +217,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "accelerate", specifier = ">=1.13.0" }, { name = "asyncio", specifier = ">=4.0.0" }, { name = "construct", specifier = ">=2.10.70" }, { name = "huggingface-hub", specifier = ">=1.0.0" }, @@ -633,6 +653,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, ] +[[package]] +name = "psutil" +version = "7.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/d1ddf4abb55e93cebc4f2ed8b5d6dbad109ecb8d63748dd2b20ab5e57ebe/psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372", size = 493740, upload-time = "2026-01-28T18:14:54.428Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/08/510cbdb69c25a96f4ae523f733cdc963ae654904e8db864c07585ef99875/psutil-7.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2edccc433cbfa046b980b0df0171cd25bcaeb3a68fe9022db0979e7aa74a826b", size = 130595, upload-time = "2026-01-28T18:14:57.293Z" }, + { url = "https://files.pythonhosted.org/packages/d6/f5/97baea3fe7a5a9af7436301f85490905379b1c6f2dd51fe3ecf24b4c5fbf/psutil-7.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e78c8603dcd9a04c7364f1a3e670cea95d51ee865e4efb3556a3a63adef958ea", size = 131082, upload-time = "2026-01-28T18:14:59.732Z" }, + { url = "https://files.pythonhosted.org/packages/37/d6/246513fbf9fa174af531f28412297dd05241d97a75911ac8febefa1a53c6/psutil-7.2.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1a571f2330c966c62aeda00dd24620425d4b0cc86881c89861fbc04549e5dc63", size = 181476, upload-time = "2026-01-28T18:15:01.884Z" }, + { url = "https://files.pythonhosted.org/packages/b8/b5/9182c9af3836cca61696dabe4fd1304e17bc56cb62f17439e1154f225dd3/psutil-7.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:917e891983ca3c1887b4ef36447b1e0873e70c933afc831c6b6da078ba474312", size = 184062, upload-time = "2026-01-28T18:15:04.436Z" }, + { url = "https://files.pythonhosted.org/packages/16/ba/0756dca669f5a9300d0cbcbfae9a4c30e446dfc7440ffe43ded5724bfd93/psutil-7.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:ab486563df44c17f5173621c7b198955bd6b613fb87c71c161f827d3fb149a9b", size = 139893, upload-time = "2026-01-28T18:15:06.378Z" }, + { url = "https://files.pythonhosted.org/packages/1c/61/8fa0e26f33623b49949346de05ec1ddaad02ed8ba64af45f40a147dbfa97/psutil-7.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:ae0aefdd8796a7737eccea863f80f81e468a1e4cf14d926bd9b6f5f2d5f90ca9", size = 135589, upload-time = "2026-01-28T18:15:08.03Z" }, + { url = "https://files.pythonhosted.org/packages/81/69/ef179ab5ca24f32acc1dac0c247fd6a13b501fd5534dbae0e05a1c48b66d/psutil-7.2.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:eed63d3b4d62449571547b60578c5b2c4bcccc5387148db46e0c2313dad0ee00", size = 130664, upload-time = "2026-01-28T18:15:09.469Z" }, + { url = "https://files.pythonhosted.org/packages/7b/64/665248b557a236d3fa9efc378d60d95ef56dd0a490c2cd37dafc7660d4a9/psutil-7.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7b6d09433a10592ce39b13d7be5a54fbac1d1228ed29abc880fb23df7cb694c9", size = 131087, upload-time = "2026-01-28T18:15:11.724Z" }, + { url = "https://files.pythonhosted.org/packages/d5/2e/e6782744700d6759ebce3043dcfa661fb61e2fb752b91cdeae9af12c2178/psutil-7.2.2-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fa4ecf83bcdf6e6c8f4449aff98eefb5d0604bf88cb883d7da3d8d2d909546a", size = 182383, upload-time = "2026-01-28T18:15:13.445Z" }, + { url = "https://files.pythonhosted.org/packages/57/49/0a41cefd10cb7505cdc04dab3eacf24c0c2cb158a998b8c7b1d27ee2c1f5/psutil-7.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e452c464a02e7dc7822a05d25db4cde564444a67e58539a00f929c51eddda0cf", size = 185210, upload-time = "2026-01-28T18:15:16.002Z" }, + { url = "https://files.pythonhosted.org/packages/dd/2c/ff9bfb544f283ba5f83ba725a3c5fec6d6b10b8f27ac1dc641c473dc390d/psutil-7.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c7663d4e37f13e884d13994247449e9f8f574bc4655d509c3b95e9ec9e2b9dc1", size = 141228, upload-time = "2026-01-28T18:15:18.385Z" }, + { url = "https://files.pythonhosted.org/packages/f2/fc/f8d9c31db14fcec13748d373e668bc3bed94d9077dbc17fb0eebc073233c/psutil-7.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:11fe5a4f613759764e79c65cf11ebdf26e33d6dd34336f8a337aa2996d71c841", size = 136284, upload-time = "2026-01-28T18:15:19.912Z" }, + { url = "https://files.pythonhosted.org/packages/e7/36/5ee6e05c9bd427237b11b3937ad82bb8ad2752d72c6969314590dd0c2f6e/psutil-7.2.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ed0cace939114f62738d808fdcecd4c869222507e266e574799e9c0faa17d486", size = 129090, upload-time = "2026-01-28T18:15:22.168Z" }, + { url = "https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979", size = 129859, upload-time = "2026-01-28T18:15:23.795Z" }, + { url = "https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9", size = 155560, upload-time = "2026-01-28T18:15:25.976Z" }, + { url = "https://files.pythonhosted.org/packages/63/65/37648c0c158dc222aba51c089eb3bdfa238e621674dc42d48706e639204f/psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e", size = 156997, upload-time = "2026-01-28T18:15:27.794Z" }, + { url = "https://files.pythonhosted.org/packages/8e/13/125093eadae863ce03c6ffdbae9929430d116a246ef69866dad94da3bfbc/psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8", size = 148972, upload-time = "2026-01-28T18:15:29.342Z" }, + { url = "https://files.pythonhosted.org/packages/04/78/0acd37ca84ce3ddffaa92ef0f571e073faa6d8ff1f0559ab1272188ea2be/psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc", size = 148266, upload-time = "2026-01-28T18:15:31.597Z" }, + { url = "https://files.pythonhosted.org/packages/b4/90/e2159492b5426be0c1fef7acba807a03511f97c5f86b3caeda6ad92351a7/psutil-7.2.2-cp37-abi3-win_amd64.whl", hash = "sha256:eb7e81434c8d223ec4a219b5fc1c47d0417b12be7ea866e24fb5ad6e84b3d988", size = 137737, upload-time = "2026-01-28T18:15:33.849Z" }, + { url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" }, +] + [[package]] name = "pycparser" version = "3.0" diff --git a/frontend/pages/chat.html b/frontend/pages/chat.html index 56bf4f7..e5cbfcf 100644 --- a/frontend/pages/chat.html +++ b/frontend/pages/chat.html @@ -29,10 +29,10 @@ -
Enter to send. Shift+Enter for a new line.
+
diff --git a/frontend/src/chat.rs b/frontend/src/chat.rs index 17a54de..de7eb0d 100644 --- a/frontend/src/chat.rs +++ b/frontend/src/chat.rs @@ -36,9 +36,12 @@ struct ChatState { message_nodes: Vec, is_receiving: bool, active_assistant_index: Option, + compose_role: ComposeRole, + generation_enabled: bool, } struct MessageNode { + wrapper: Element, content: Element, status: Option, } @@ -50,11 +53,37 @@ enum ChatStatus { Hidden, } +#[derive(Copy, Clone, PartialEq)] +enum ComposeRole { + System, + Assistant, + User, +} + +impl ComposeRole { + fn as_str(self) -> &'static str { + match self { + ComposeRole::System => "system", + ComposeRole::Assistant => "assistant", + ComposeRole::User => "user", + } + } + + fn next(self) -> Self { + match self { + ComposeRole::System => ComposeRole::Assistant, + ComposeRole::Assistant => ComposeRole::User, + ComposeRole::User => ComposeRole::System, + } + } +} + struct AppState { ws: WebSocket, document: Document, messages_container: Element, input: HtmlTextAreaElement, + status_label: Element, state: RefCell, } @@ -91,6 +120,7 @@ fn append_message( container.append_child(&wrapper)?; Ok(MessageNode { + wrapper, content: content_el, status: status_el, }) @@ -117,6 +147,12 @@ fn apply_status(node: &Element, status: ChatStatus) { } } +fn update_status_label(label: &Element, role: ComposeRole, generation_enabled: bool) { + let generation = if generation_enabled { "on" } else { "off" }; + let text = format!("Role: {} | Generation: {}", role.as_str(), generation); + label.set_text_content(Some(&text)); +} + fn scroll_to_bottom(container: &Element) { if let Some(element) = container.dyn_ref::() { let height = element.scroll_height(); @@ -141,6 +177,9 @@ pub fn init_chat() -> Result<(), JsValue> { .get_element_by_id("chat-input") .ok_or_else(|| JsValue::from_str("Missing chat-input element"))? .dyn_into()?; + let status_label = document + .get_element_by_id("chat-status") + .ok_or_else(|| JsValue::from_str("Missing chat-status element"))?; let ws = WebSocket::new(&ws_url)?; ws.set_binary_type(BinaryType::Arraybuffer); @@ -151,14 +190,22 @@ pub fn init_chat() -> Result<(), JsValue> { document, messages_container, input: input_el, + status_label, state: RefCell::new(ChatState { messages: Vec::new(), message_nodes: Vec::new(), is_receiving: false, active_assistant_index: None, + compose_role: ComposeRole::User, + generation_enabled: true, }), }); + { + let state = app.state.borrow(); + update_status_label(&app.status_label, state.compose_role, state.generation_enabled); + } + let onopen = Closure::::wrap(Box::new(move |_| { console::log_1(&"[ws] connected".into()); })); @@ -167,28 +214,75 @@ pub fn init_chat() -> Result<(), JsValue> { let app_for_keydown = app.clone(); let onkeydown = Closure::::wrap(Box::new(move |event: KeyboardEvent| { - if event.ctrl_key() && (event.key() == "c" || event.key() == "C") { - let state = app_for_keydown.state.borrow(); - if state.is_receiving { - event.prevent_default(); - if app_for_keydown.ws.ready_state() != WebSocket::OPEN { - console::error_1(&"[ws] socket is not open".into()); - return; - } + if event.ctrl_key() { + let key = event.key(); + if key == "c" || key == "C" { + let state = app_for_keydown.state.borrow(); + if state.is_receiving { + event.prevent_default(); + if app_for_keydown.ws.ready_state() != WebSocket::OPEN { + console::error_1(&"[ws] socket is not open".into()); + return; + } - let request = UserRequest::ChatCompletionCancellation; - match request.to_bytes() { - Ok(bytes) => { - if let Err(err) = app_for_keydown.ws.send_with_u8_array(&bytes) { - console::error_1(&format!("[ws] cancel send error: {:?}", err).into()); + let request = UserRequest::ChatCompletionCancellation; + match request.to_bytes() { + Ok(bytes) => { + if let Err(err) = app_for_keydown.ws.send_with_u8_array(&bytes) { + console::error_1(&format!("[ws] cancel send error: {:?}", err).into()); + } + } + Err(err) => { + console::error_1(&format!("[ws] cancel encode error: {err:#}").into()); } } - Err(err) => { - console::error_1(&format!("[ws] cancel encode error: {err:#}").into()); + } + return; + } + + if key == "1" { + event.prevent_default(); + let mut state = app_for_keydown.state.borrow_mut(); + state.compose_role = state.compose_role.next(); + update_status_label( + &app_for_keydown.status_label, + state.compose_role, + state.generation_enabled, + ); + return; + } + + if key == "4" { + event.prevent_default(); + let mut state = app_for_keydown.state.borrow_mut(); + state.generation_enabled = !state.generation_enabled; + update_status_label( + &app_for_keydown.status_label, + state.compose_role, + state.generation_enabled, + ); + return; + } + + if key == "d" || key == "D" { + event.prevent_default(); + let mut state = app_for_keydown.state.borrow_mut(); + if state.is_receiving { + return; + } + let last_message = match state.messages.pop() { + Some(message) => message, + None => return, + }; + if let Some(node) = state.message_nodes.pop() { + if let Some(parent) = node.wrapper.parent_node() { + let _ = parent.remove_child(&node.wrapper); } } + app_for_keydown.input.set_value(&last_message.content); + scroll_to_bottom(&app_for_keydown.messages_container); + return; } - return; } if event.key() != "Enter" || event.shift_key() { @@ -207,18 +301,14 @@ pub fn init_chat() -> Result<(), JsValue> { return; } - if app_for_keydown.ws.ready_state() != WebSocket::OPEN { - console::error_1(&"[ws] socket is not open".into()); - return; - } - + let role = state.compose_role; let user_content = trimmed.to_string(); app_for_keydown.input.set_value(""); let user_node = match append_message( &app_for_keydown.document, &app_for_keydown.messages_container, - "user", + role.as_str(), &user_content, None, ) { @@ -230,12 +320,21 @@ pub fn init_chat() -> Result<(), JsValue> { }; state.messages.push(UserChatMessage { - role: "user".to_string(), + role: role.as_str().to_string(), content: user_content, }); state.message_nodes.push(user_node); scroll_to_bottom(&app_for_keydown.messages_container); + if role != ComposeRole::User || !state.generation_enabled { + return; + } + + if app_for_keydown.ws.ready_state() != WebSocket::OPEN { + console::error_1(&"[ws] socket is not open".into()); + return; + } + let history = state.messages.clone(); let request = UserRequest::ChatCompletion(UserChatCompletionRequest::new(history)); diff --git a/frontend/static/css/site.css b/frontend/static/css/site.css index 28e6240..a15224b 100644 --- a/frontend/static/css/site.css +++ b/frontend/static/css/site.css @@ -154,9 +154,9 @@ body { color: #8a8a8a; } -.chat-input-hint { - font-size: 9.5pt; - color: #6b6b6b; +.chat-status-label { + font-size: 15pt; + color: #4b4b4b; } body {