ChatGPT wrote me something; I can't even be bothered to verify it — it is 4:21.

This commit is contained in:
Андреев Григорий 2026-03-29 04:21:57 +03:00
parent 336177941f
commit 63340f2594
9 changed files with 365 additions and 70 deletions

View File

@ -4,6 +4,7 @@ from huggingface_hub import snapshot_download
MODEL_ID = "zai-org/GLM-4.7-Flash" MODEL_ID = "zai-org/GLM-4.7-Flash"
# MODEL_ID = "mlabonne/Daredevil-8B-abliterated-GGUF"
def main() -> None: def main() -> None:

View File

@ -11,6 +11,7 @@ from transformers import (
) )
MODEL_ID = "zai-org/GLM-4.7-Flash" MODEL_ID = "zai-org/GLM-4.7-Flash"
# MODEL_ID = "/home/gregory/programming/testWithPython/Daredevil-GGUF/daredevil-8b-abliterated.Q8_0.gguf"
MAX_NEW_TOKENS = 256 MAX_NEW_TOKENS = 256
SYSTEM_PROMPT = ( SYSTEM_PROMPT = (

View File

@ -4,6 +4,7 @@ name = "dedicated_ai_server"
version = "0.1.0" version = "0.1.0"
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [ dependencies = [
"accelerate>=1.13.0",
"asyncio>=4.0.0", "asyncio>=4.0.0",
"construct>=2.10.70", "construct>=2.10.70",
"huggingface-hub>=1.0.0", "huggingface-hub>=1.0.0",

View File

@ -0,0 +1,54 @@
#!/usr/bin/env bash
# Fedora workstation bootstrap: install base packages and uv, then configure
# a minimal two-line fish prompt.
set -euo pipefail

FISH_CONFIG_DIR="${HOME}/.config/fish"
FISH_FUNCTIONS_DIR="${FISH_CONFIG_DIR}/functions"
FISH_CONFIG_FILE="${FISH_CONFIG_DIR}/config.fish"
FISH_PROMPT_FILE="${FISH_FUNCTIONS_DIR}/fish_prompt.fish"

# Guard: this script only knows how to drive dnf.
if ! command -v dnf >/dev/null 2>&1; then
  echo "This script is for Fedora systems with dnf."
  exit 1
fi

sudo dnf install -y htop tree git gcc gcc-c++ make cmake fish vim

# Install uv exactly once: prefer the distro package, fall back to the
# upstream installer. (Previously the curl installer ran unconditionally
# and dnf was attempted as well, installing uv twice.)
if ! command -v uv >/dev/null 2>&1; then
  if ! sudo dnf install -y uv; then
    echo "uv package is not available in the currently enabled Fedora repositories."
    curl -LsSf https://astral.sh/uv/install.sh | sh
  fi
fi

mkdir -p "${FISH_FUNCTIONS_DIR}"
touch "${FISH_CONFIG_FILE}"

# Disable the default fish greeting; grep -qxF keeps the append idempotent.
if ! grep -qxF 'set -g fish_greeting' "${FISH_CONFIG_FILE}"; then
  printf '\nset -g fish_greeting\n' >> "${FISH_CONFIG_FILE}"
fi

# Two-line prompt: cwd (plus last exit status when non-zero), then "> ".
cat > "${FISH_PROMPT_FILE}" <<'EOF'
function fish_prompt
    set -l last_status $status
    set_color brblue
    printf '%s' (prompt_pwd)
    set_color normal
    if test $last_status -ne 0
        printf ' '
        set_color red
        printf '[%s]' $last_status
        set_color normal
    end
    printf '\n'
    set_color brcyan
    printf '> '
    set_color normal
end
EOF

# Summary now matches the dnf package list above (vim was missing before).
echo "Installed: htop tree git gcc gcc-c++ make cmake fish vim"
echo "Fish prompt configured. Start it with: fish"

View File

@ -3,9 +3,11 @@ from __future__ import annotations
import asyncio import asyncio
import queue import queue
import threading import threading
import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, Iterable from typing import Any, Dict, Iterable
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from api import Request, Response from api import Request, Response
from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket
@ -15,8 +17,9 @@ SERVER_HOST = "127.0.0.1"
SERVER_PORT = 9000 SERVER_PORT = 9000
SHARED_SECRET = "change-me" SHARED_SECRET = "change-me"
MODEL_ID = "zai-org/GLM-4.7-Flash"
PROCESS_DELAY_SECONDS = 0.4 MAX_NEW_TOKENS = 256
PIECE_CHUNK_SIZE = 64
@dataclass @dataclass
@ -51,20 +54,81 @@ class WorkItem:
record: PendingChatCompletionRecord record: PendingChatCompletionRecord
def extract_last_message(messages: list) -> str: @dataclass
if not messages: class ModelBundle:
return "" tokenizer: Any
last = messages[-1] model: Any
if hasattr(last, "content"):
return last.content
return last.get("content", "")
def generate_uppercase_pieces(text: str) -> Iterable[str]: def fail(message: str) -> None:
words = text.split() print(message, flush=True)
for word in words: raise SystemExit(1)
time.sleep(PROCESS_DELAY_SECONDS)
yield word.upper() + " "
def build_context(messages: list) -> list[dict[str, str]]:
    """Normalize chat messages into ``{"role": ..., "content": ...}`` dicts.

    Each message may be an object exposing ``role``/``content`` attributes
    or a plain mapping; missing mapping keys default to empty strings.
    """
    normalized: list[dict[str, str]] = []
    for entry in messages:
        if hasattr(entry, "role") and hasattr(entry, "content"):
            role, content = entry.role, entry.content
        else:
            role, content = entry.get("role", ""), entry.get("content", "")
        normalized.append({"role": role, "content": content})
    return normalized
def load_local_model() -> ModelBundle:
    """Load MODEL_ID's tokenizer and weights from the local HF cache only.

    Never touches the network (``local_files_only=True``); when the cache is
    missing or incomplete, exits via ``fail()`` with a hint to run B2.py.
    """
    try:
        tok = AutoTokenizer.from_pretrained(MODEL_ID, local_files_only=True)
        # bf16 weights, sharded across available devices by accelerate.
        lm = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            local_files_only=True,
            dtype=torch.bfloat16,
            device_map="auto",
        )
    except OSError as exc:
        fail(
            f"Model {MODEL_ID!r} is not fully available in the Hugging Face cache. "
            f"Run B2.py first. Original error: {exc}"
        )
    return ModelBundle(tokenizer=tok, model=lm)
def generate_llm_pieces(bundle: ModelBundle, messages: list) -> Iterable[str]:
    """Run one greedy generation pass and return the reply as text chunks.

    Returns the decoded completion split into PIECE_CHUNK_SIZE-character
    pieces, or an empty list when the model produced no text.
    """
    context = build_context(messages)
    # Tokenize the full conversation via the model's chat template and
    # append the assistant generation prompt.
    inputs = bundle.tokenizer.apply_chat_template(
        context,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move every input tensor to the model's device before generate().
    inputs = {name: tensor.to(bundle.model.device) for name, tensor in inputs.items()}
    input_ids = inputs.get("input_ids")
    # Prompt length in tokens — used below to strip the echoed prompt from
    # the output sequence.
    input_len = int(input_ids.shape[1]) if input_ids is not None else 0
    with torch.inference_mode():
        # Deterministic (greedy) decoding, capped at MAX_NEW_TOKENS.
        output_ids = bundle.model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
        )
    # Keep only the newly generated tokens, then decode to plain text.
    generated_ids = output_ids[0][input_len:]
    text = bundle.tokenizer.decode(generated_ids, skip_special_tokens=True)
    if not text:
        return []
    return [text[i:i + PIECE_CHUNK_SIZE] for i in range(0, len(text), PIECE_CHUNK_SIZE)]
def worker_loop( def worker_loop(
@ -72,6 +136,7 @@ def worker_loop(
pending: Dict[int, PendingChatCompletionRecord], pending: Dict[int, PendingChatCompletionRecord],
pending_lock: threading.Lock, pending_lock: threading.Lock,
loop: asyncio.AbstractEventLoop, loop: asyncio.AbstractEventLoop,
model_bundle: ModelBundle,
) -> None: ) -> None:
while True: while True:
item = work_queue.get() item = work_queue.get()
@ -84,15 +149,12 @@ def worker_loop(
if record.is_cancelled(): if record.is_cancelled():
cancelled = True cancelled = True
else: else:
text = extract_last_message(item.messages) for piece in generate_llm_pieces(model_bundle, item.messages):
for piece in generate_uppercase_pieces(text):
loop.call_soon_threadsafe( loop.call_soon_threadsafe(
record.response_queue.put_nowait, record.response_queue.put_nowait,
MessagePiece(piece=piece, is_end=False, is_cancel=False), MessagePiece(piece=piece, is_end=False, is_cancel=False),
) )
print("[debug] got a new piece")
if record.is_cancelled(): if record.is_cancelled():
print("[debug] record was cancelled")
cancelled = True cancelled = True
break break
@ -106,6 +168,12 @@ def worker_loop(
record.response_queue.put_nowait, record.response_queue.put_nowait,
MessagePiece(is_end=True, is_cancel=False), MessagePiece(is_end=True, is_cancel=False),
) )
except BaseException as exc:
print(f"[worker] generation failed: {exc}", flush=True)
loop.call_soon_threadsafe(
record.response_queue.put_nowait,
MessagePiece(is_end=True, is_cancel=True),
)
finally: finally:
with pending_lock: with pending_lock:
pending.pop(item.request_id, None) pending.pop(item.request_id, None)
@ -155,6 +223,38 @@ async def forward_pieces(
return return
async def handle_chat_request(
    request: Request,
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
    transport: SecretStreamSocket,
    send_lock: asyncio.Lock,
    response_tasks: set[asyncio.Task[None]],
) -> None:
    """Register a chat request, hand it to the worker thread, and start
    streaming its pieces back to the client.

    Raises ProtocolError when the client reuses a request_id that is still
    pending on this connection.
    """
    request_id = request.request_id
    print(f"[request] chat request_id={request_id}", flush=True)
    response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue()
    record = PendingChatCompletionRecord(
        response_queue=response_queue,
    )
    # pending is shared with the worker thread, hence the threading.Lock
    # (held only briefly, so blocking the event loop here is acceptable).
    with pending_lock:
        if request_id in pending:
            raise ProtocolError(
                f"Duplicate request_id {request_id} received on this connection"
            )
        pending[request_id] = record
    # Hand the work to the background generation thread.
    work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
    # Forward generated pieces to the client concurrently; keep a strong
    # reference to the task so it is not garbage-collected mid-flight.
    task = asyncio.create_task(
        forward_pieces(request_id, response_queue, transport, send_lock)
    )
    response_tasks.add(task)
    task.add_done_callback(response_tasks.discard)
async def handle_client( async def handle_client(
reader: asyncio.StreamReader, reader: asyncio.StreamReader,
writer: asyncio.StreamWriter, writer: asyncio.StreamWriter,
@ -178,27 +278,16 @@ async def handle_client(
request_id = request.request_id request_id = request.request_id
if request.kind == "chat": if request.kind == "chat":
print(f"[request] chat request_id={request_id}", flush=True) await handle_chat_request(
response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue() request,
record = PendingChatCompletionRecord( work_queue,
response_queue=response_queue, pending,
pending_lock,
transport,
send_lock,
response_tasks,
) )
with pending_lock:
if request_id in pending:
raise ProtocolError(
f"Duplicate request_id {request_id} received on this connection"
)
pending[request_id] = record
owned_request_ids.add(request_id) owned_request_ids.add(request_id)
work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
task = asyncio.create_task(
forward_pieces(request_id, response_queue, transport, send_lock)
)
response_tasks.add(task)
task.add_done_callback(response_tasks.discard)
elif request.kind == "cancel": elif request.kind == "cancel":
print(f"[request] cancel request_id={request_id}", flush=True) print(f"[request] cancel request_id={request_id}", flush=True)
with pending_lock: with pending_lock:
@ -224,7 +313,7 @@ async def handle_client(
await transport.close() await transport.close()
async def run_server() -> None: async def run_server(model_bundle: ModelBundle) -> None:
pending: Dict[int, PendingChatCompletionRecord] = {} pending: Dict[int, PendingChatCompletionRecord] = {}
pending_lock = threading.Lock() pending_lock = threading.Lock()
work_queue: queue.Queue[WorkItem | None] = queue.Queue() work_queue: queue.Queue[WorkItem | None] = queue.Queue()
@ -232,7 +321,7 @@ async def run_server() -> None:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
worker = threading.Thread( worker = threading.Thread(
target=worker_loop, target=worker_loop,
args=(work_queue, pending, pending_lock, loop), args=(work_queue, pending, pending_lock, loop, model_bundle),
daemon=True, daemon=True,
) )
worker.start() worker.start()
@ -256,7 +345,9 @@ async def run_server() -> None:
def main() -> None: def main() -> None:
asyncio.run(run_server()) model_bundle = load_local_model()
print("[model] loaded", flush=True)
asyncio.run(run_server(model_bundle))
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -2,6 +2,24 @@ version = 1
revision = 3 revision = 3
requires-python = ">=3.13" requires-python = ">=3.13"
[[package]]
name = "accelerate"
version = "1.13.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "psutil" },
{ name = "pyyaml" },
{ name = "safetensors" },
{ name = "torch" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl", hash = "sha256:cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0", size = 383744, upload-time = "2026-03-04T19:34:10.313Z" },
]
[[package]] [[package]]
name = "annotated-doc" name = "annotated-doc"
version = "0.0.4" version = "0.0.4"
@ -188,6 +206,7 @@ name = "dedicated-ai-server"
version = "0.1.0" version = "0.1.0"
source = { virtual = "." } source = { virtual = "." }
dependencies = [ dependencies = [
{ name = "accelerate" },
{ name = "asyncio" }, { name = "asyncio" },
{ name = "construct" }, { name = "construct" },
{ name = "huggingface-hub" }, { name = "huggingface-hub" },
@ -198,6 +217,7 @@ dependencies = [
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "accelerate", specifier = ">=1.13.0" },
{ name = "asyncio", specifier = ">=4.0.0" }, { name = "asyncio", specifier = ">=4.0.0" },
{ name = "construct", specifier = ">=2.10.70" }, { name = "construct", specifier = ">=2.10.70" },
{ name = "huggingface-hub", specifier = ">=1.0.0" }, { name = "huggingface-hub", specifier = ">=1.0.0" },
@ -633,6 +653,34 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" }, { url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" },
] ]
[[package]]
name = "psutil"
version = "7.2.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/d1ddf4abb55e93cebc4f2ed8b5d6dbad109ecb8d63748dd2b20ab5e57ebe/psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372", size = 493740, upload-time = "2026-01-28T18:14:54.428Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/51/08/510cbdb69c25a96f4ae523f733cdc963ae654904e8db864c07585ef99875/psutil-7.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2edccc433cbfa046b980b0df0171cd25bcaeb3a68fe9022db0979e7aa74a826b", size = 130595, upload-time = "2026-01-28T18:14:57.293Z" },
{ url = "https://files.pythonhosted.org/packages/d6/f5/97baea3fe7a5a9af7436301f85490905379b1c6f2dd51fe3ecf24b4c5fbf/psutil-7.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e78c8603dcd9a04c7364f1a3e670cea95d51ee865e4efb3556a3a63adef958ea", size = 131082, upload-time = "2026-01-28T18:14:59.732Z" },
{ url = "https://files.pythonhosted.org/packages/37/d6/246513fbf9fa174af531f28412297dd05241d97a75911ac8febefa1a53c6/psutil-7.2.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1a571f2330c966c62aeda00dd24620425d4b0cc86881c89861fbc04549e5dc63", size = 181476, upload-time = "2026-01-28T18:15:01.884Z" },
{ url = "https://files.pythonhosted.org/packages/b8/b5/9182c9af3836cca61696dabe4fd1304e17bc56cb62f17439e1154f225dd3/psutil-7.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:917e891983ca3c1887b4ef36447b1e0873e70c933afc831c6b6da078ba474312", size = 184062, upload-time = "2026-01-28T18:15:04.436Z" },
{ url = "https://files.pythonhosted.org/packages/16/ba/0756dca669f5a9300d0cbcbfae9a4c30e446dfc7440ffe43ded5724bfd93/psutil-7.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:ab486563df44c17f5173621c7b198955bd6b613fb87c71c161f827d3fb149a9b", size = 139893, upload-time = "2026-01-28T18:15:06.378Z" },
{ url = "https://files.pythonhosted.org/packages/1c/61/8fa0e26f33623b49949346de05ec1ddaad02ed8ba64af45f40a147dbfa97/psutil-7.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:ae0aefdd8796a7737eccea863f80f81e468a1e4cf14d926bd9b6f5f2d5f90ca9", size = 135589, upload-time = "2026-01-28T18:15:08.03Z" },
{ url = "https://files.pythonhosted.org/packages/81/69/ef179ab5ca24f32acc1dac0c247fd6a13b501fd5534dbae0e05a1c48b66d/psutil-7.2.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:eed63d3b4d62449571547b60578c5b2c4bcccc5387148db46e0c2313dad0ee00", size = 130664, upload-time = "2026-01-28T18:15:09.469Z" },
{ url = "https://files.pythonhosted.org/packages/7b/64/665248b557a236d3fa9efc378d60d95ef56dd0a490c2cd37dafc7660d4a9/psutil-7.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7b6d09433a10592ce39b13d7be5a54fbac1d1228ed29abc880fb23df7cb694c9", size = 131087, upload-time = "2026-01-28T18:15:11.724Z" },
{ url = "https://files.pythonhosted.org/packages/d5/2e/e6782744700d6759ebce3043dcfa661fb61e2fb752b91cdeae9af12c2178/psutil-7.2.2-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fa4ecf83bcdf6e6c8f4449aff98eefb5d0604bf88cb883d7da3d8d2d909546a", size = 182383, upload-time = "2026-01-28T18:15:13.445Z" },
{ url = "https://files.pythonhosted.org/packages/57/49/0a41cefd10cb7505cdc04dab3eacf24c0c2cb158a998b8c7b1d27ee2c1f5/psutil-7.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e452c464a02e7dc7822a05d25db4cde564444a67e58539a00f929c51eddda0cf", size = 185210, upload-time = "2026-01-28T18:15:16.002Z" },
{ url = "https://files.pythonhosted.org/packages/dd/2c/ff9bfb544f283ba5f83ba725a3c5fec6d6b10b8f27ac1dc641c473dc390d/psutil-7.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c7663d4e37f13e884d13994247449e9f8f574bc4655d509c3b95e9ec9e2b9dc1", size = 141228, upload-time = "2026-01-28T18:15:18.385Z" },
{ url = "https://files.pythonhosted.org/packages/f2/fc/f8d9c31db14fcec13748d373e668bc3bed94d9077dbc17fb0eebc073233c/psutil-7.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:11fe5a4f613759764e79c65cf11ebdf26e33d6dd34336f8a337aa2996d71c841", size = 136284, upload-time = "2026-01-28T18:15:19.912Z" },
{ url = "https://files.pythonhosted.org/packages/e7/36/5ee6e05c9bd427237b11b3937ad82bb8ad2752d72c6969314590dd0c2f6e/psutil-7.2.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ed0cace939114f62738d808fdcecd4c869222507e266e574799e9c0faa17d486", size = 129090, upload-time = "2026-01-28T18:15:22.168Z" },
{ url = "https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979", size = 129859, upload-time = "2026-01-28T18:15:23.795Z" },
{ url = "https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9", size = 155560, upload-time = "2026-01-28T18:15:25.976Z" },
{ url = "https://files.pythonhosted.org/packages/63/65/37648c0c158dc222aba51c089eb3bdfa238e621674dc42d48706e639204f/psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e", size = 156997, upload-time = "2026-01-28T18:15:27.794Z" },
{ url = "https://files.pythonhosted.org/packages/8e/13/125093eadae863ce03c6ffdbae9929430d116a246ef69866dad94da3bfbc/psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8", size = 148972, upload-time = "2026-01-28T18:15:29.342Z" },
{ url = "https://files.pythonhosted.org/packages/04/78/0acd37ca84ce3ddffaa92ef0f571e073faa6d8ff1f0559ab1272188ea2be/psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc", size = 148266, upload-time = "2026-01-28T18:15:31.597Z" },
{ url = "https://files.pythonhosted.org/packages/b4/90/e2159492b5426be0c1fef7acba807a03511f97c5f86b3caeda6ad92351a7/psutil-7.2.2-cp37-abi3-win_amd64.whl", hash = "sha256:eb7e81434c8d223ec4a219b5fc1c47d0417b12be7ea866e24fb5ad6e84b3d988", size = 137737, upload-time = "2026-01-28T18:15:33.849Z" },
{ url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" },
]
[[package]] [[package]]
name = "pycparser" name = "pycparser"
version = "3.0" version = "3.0"

View File

@ -29,10 +29,10 @@
<textarea <textarea
class="chat-input" class="chat-input"
id="chat-input" id="chat-input"
placeholder="Write a message. Enter to send, Shift+Enter for a new line." placeholder="Write a message. Shift+Enter for a new line."
rows="4" rows="4"
></textarea> ></textarea>
<div class="chat-input-hint">Enter to send. Shift+Enter for a new line.</div> <div class="chat-status-label" id="chat-status"></div>
</div> </div>
</div> </div>
</section> </section>

View File

@ -36,9 +36,12 @@ struct ChatState {
message_nodes: Vec<MessageNode>, message_nodes: Vec<MessageNode>,
is_receiving: bool, is_receiving: bool,
active_assistant_index: Option<usize>, active_assistant_index: Option<usize>,
compose_role: ComposeRole,
generation_enabled: bool,
} }
struct MessageNode { struct MessageNode {
wrapper: Element,
content: Element, content: Element,
status: Option<Element>, status: Option<Element>,
} }
@ -50,11 +53,37 @@ enum ChatStatus {
Hidden, Hidden,
} }
/// Which chat role the next composed message is attributed to.
#[derive(Copy, Clone, PartialEq)]
enum ComposeRole {
    System,
    Assistant,
    User,
}

impl ComposeRole {
    /// Wire/display name of the role.
    fn as_str(self) -> &'static str {
        match self {
            Self::System => "system",
            Self::Assistant => "assistant",
            Self::User => "user",
        }
    }

    /// Cycle System -> Assistant -> User -> System.
    fn next(self) -> Self {
        match self {
            Self::System => Self::Assistant,
            Self::Assistant => Self::User,
            Self::User => Self::System,
        }
    }
}
struct AppState { struct AppState {
ws: WebSocket, ws: WebSocket,
document: Document, document: Document,
messages_container: Element, messages_container: Element,
input: HtmlTextAreaElement, input: HtmlTextAreaElement,
status_label: Element,
state: RefCell<ChatState>, state: RefCell<ChatState>,
} }
@ -91,6 +120,7 @@ fn append_message(
container.append_child(&wrapper)?; container.append_child(&wrapper)?;
Ok(MessageNode { Ok(MessageNode {
wrapper,
content: content_el, content: content_el,
status: status_el, status: status_el,
}) })
@ -117,6 +147,12 @@ fn apply_status(node: &Element, status: ChatStatus) {
} }
} }
/// Render the compose-role / generation-toggle state into the status label.
fn update_status_label(label: &Element, role: ComposeRole, generation_enabled: bool) {
    let generation_state = match generation_enabled {
        true => "on",
        false => "off",
    };
    let rendered = format!("Role: {} | Generation: {}", role.as_str(), generation_state);
    label.set_text_content(Some(&rendered));
}
fn scroll_to_bottom(container: &Element) { fn scroll_to_bottom(container: &Element) {
if let Some(element) = container.dyn_ref::<HtmlElement>() { if let Some(element) = container.dyn_ref::<HtmlElement>() {
let height = element.scroll_height(); let height = element.scroll_height();
@ -141,6 +177,9 @@ pub fn init_chat() -> Result<(), JsValue> {
.get_element_by_id("chat-input") .get_element_by_id("chat-input")
.ok_or_else(|| JsValue::from_str("Missing chat-input element"))? .ok_or_else(|| JsValue::from_str("Missing chat-input element"))?
.dyn_into()?; .dyn_into()?;
let status_label = document
.get_element_by_id("chat-status")
.ok_or_else(|| JsValue::from_str("Missing chat-status element"))?;
let ws = WebSocket::new(&ws_url)?; let ws = WebSocket::new(&ws_url)?;
ws.set_binary_type(BinaryType::Arraybuffer); ws.set_binary_type(BinaryType::Arraybuffer);
@ -151,14 +190,22 @@ pub fn init_chat() -> Result<(), JsValue> {
document, document,
messages_container, messages_container,
input: input_el, input: input_el,
status_label,
state: RefCell::new(ChatState { state: RefCell::new(ChatState {
messages: Vec::new(), messages: Vec::new(),
message_nodes: Vec::new(), message_nodes: Vec::new(),
is_receiving: false, is_receiving: false,
active_assistant_index: None, active_assistant_index: None,
compose_role: ComposeRole::User,
generation_enabled: true,
}), }),
}); });
{
let state = app.state.borrow();
update_status_label(&app.status_label, state.compose_role, state.generation_enabled);
}
let onopen = Closure::<dyn FnMut(Event)>::wrap(Box::new(move |_| { let onopen = Closure::<dyn FnMut(Event)>::wrap(Box::new(move |_| {
console::log_1(&"[ws] connected".into()); console::log_1(&"[ws] connected".into());
})); }));
@ -167,7 +214,9 @@ pub fn init_chat() -> Result<(), JsValue> {
let app_for_keydown = app.clone(); let app_for_keydown = app.clone();
let onkeydown = Closure::<dyn FnMut(KeyboardEvent)>::wrap(Box::new(move |event: KeyboardEvent| { let onkeydown = Closure::<dyn FnMut(KeyboardEvent)>::wrap(Box::new(move |event: KeyboardEvent| {
if event.ctrl_key() && (event.key() == "c" || event.key() == "C") { if event.ctrl_key() {
let key = event.key();
if key == "c" || key == "C" {
let state = app_for_keydown.state.borrow(); let state = app_for_keydown.state.borrow();
if state.is_receiving { if state.is_receiving {
event.prevent_default(); event.prevent_default();
@ -191,6 +240,51 @@ pub fn init_chat() -> Result<(), JsValue> {
return; return;
} }
if key == "1" {
event.prevent_default();
let mut state = app_for_keydown.state.borrow_mut();
state.compose_role = state.compose_role.next();
update_status_label(
&app_for_keydown.status_label,
state.compose_role,
state.generation_enabled,
);
return;
}
if key == "4" {
event.prevent_default();
let mut state = app_for_keydown.state.borrow_mut();
state.generation_enabled = !state.generation_enabled;
update_status_label(
&app_for_keydown.status_label,
state.compose_role,
state.generation_enabled,
);
return;
}
if key == "d" || key == "D" {
event.prevent_default();
let mut state = app_for_keydown.state.borrow_mut();
if state.is_receiving {
return;
}
let last_message = match state.messages.pop() {
Some(message) => message,
None => return,
};
if let Some(node) = state.message_nodes.pop() {
if let Some(parent) = node.wrapper.parent_node() {
let _ = parent.remove_child(&node.wrapper);
}
}
app_for_keydown.input.set_value(&last_message.content);
scroll_to_bottom(&app_for_keydown.messages_container);
return;
}
}
if event.key() != "Enter" || event.shift_key() { if event.key() != "Enter" || event.shift_key() {
return; return;
} }
@ -207,18 +301,14 @@ pub fn init_chat() -> Result<(), JsValue> {
return; return;
} }
if app_for_keydown.ws.ready_state() != WebSocket::OPEN { let role = state.compose_role;
console::error_1(&"[ws] socket is not open".into());
return;
}
let user_content = trimmed.to_string(); let user_content = trimmed.to_string();
app_for_keydown.input.set_value(""); app_for_keydown.input.set_value("");
let user_node = match append_message( let user_node = match append_message(
&app_for_keydown.document, &app_for_keydown.document,
&app_for_keydown.messages_container, &app_for_keydown.messages_container,
"user", role.as_str(),
&user_content, &user_content,
None, None,
) { ) {
@ -230,12 +320,21 @@ pub fn init_chat() -> Result<(), JsValue> {
}; };
state.messages.push(UserChatMessage { state.messages.push(UserChatMessage {
role: "user".to_string(), role: role.as_str().to_string(),
content: user_content, content: user_content,
}); });
state.message_nodes.push(user_node); state.message_nodes.push(user_node);
scroll_to_bottom(&app_for_keydown.messages_container); scroll_to_bottom(&app_for_keydown.messages_container);
if role != ComposeRole::User || !state.generation_enabled {
return;
}
if app_for_keydown.ws.ready_state() != WebSocket::OPEN {
console::error_1(&"[ws] socket is not open".into());
return;
}
let history = state.messages.clone(); let history = state.messages.clone();
let request = UserRequest::ChatCompletion(UserChatCompletionRequest::new(history)); let request = UserRequest::ChatCompletion(UserChatCompletionRequest::new(history));

View File

@ -154,9 +154,9 @@ body {
color: #8a8a8a; color: #8a8a8a;
} }
.chat-input-hint { .chat-status-label {
font-size: 9.5pt; font-size: 15pt;
color: #6b6b6b; color: #4b4b4b;
} }
body { body {