ChatGPT wrote me something; I can't even be bothered to verify it — it is 4:21.

This commit is contained in:
Андреев Григорий 2026-03-29 04:21:57 +03:00
parent 336177941f
commit 63340f2594
9 changed files with 365 additions and 70 deletions

View File

@ -4,6 +4,7 @@ from huggingface_hub import snapshot_download
MODEL_ID = "zai-org/GLM-4.7-Flash"
# MODEL_ID = "mlabonne/Daredevil-8B-abliterated-GGUF"
def main() -> None:

View File

@ -11,6 +11,7 @@ from transformers import (
)
MODEL_ID = "zai-org/GLM-4.7-Flash"
# MODEL_ID = "/home/gregory/programming/testWithPython/Daredevil-GGUF/daredevil-8b-abliterated.Q8_0.gguf"
MAX_NEW_TOKENS = 256
SYSTEM_PROMPT = (

View File

@ -4,6 +4,7 @@ name = "dedicated_ai_server"
version = "0.1.0"
requires-python = ">=3.13"
dependencies = [
"accelerate>=1.13.0",
"asyncio>=4.0.0",
"construct>=2.10.70",
"huggingface-hub>=1.0.0",

View File

@ -0,0 +1,54 @@
#!/usr/bin/env bash
# Fedora bootstrap: install base dev tools + uv, and configure a minimal
# two-line fish prompt. Safe to re-run (greeting tweak is idempotent; the
# prompt file is simply rewritten).
set -euo pipefail

FISH_CONFIG_DIR="${HOME}/.config/fish"
FISH_FUNCTIONS_DIR="${FISH_CONFIG_DIR}/functions"
FISH_CONFIG_FILE="${FISH_CONFIG_DIR}/config.fish"
FISH_PROMPT_FILE="${FISH_FUNCTIONS_DIR}/fish_prompt.fish"

# Guard: everything below drives dnf, so the script is Fedora-specific.
if ! command -v dnf >/dev/null 2>&1; then
    echo "This script is for Fedora systems with dnf."
    exit 1
fi

sudo dnf install -y htop tree git gcc gcc-c++ make cmake fish vim

# Prefer the distro package for uv; fall back to the upstream installer only
# when the package is unavailable. (Previously the curl installer ran
# unconditionally and THEN dnf was attempted, installing uv twice on success.)
if ! sudo dnf install -y uv; then
    echo "uv package is not available in the currently enabled Fedora repositories."
    curl -LsSf https://astral.sh/uv/install.sh | sh
fi

mkdir -p "${FISH_FUNCTIONS_DIR}"
touch "${FISH_CONFIG_FILE}"

# Disable the fish greeting exactly once; grep -qxF matches the whole line
# literally so repeated runs do not append duplicates.
if ! grep -qxF 'set -g fish_greeting' "${FISH_CONFIG_FILE}"; then
    printf '\nset -g fish_greeting\n' >> "${FISH_CONFIG_FILE}"
fi

# Two-line prompt: cwd (plus last exit status in red when non-zero), then "> ".
cat > "${FISH_PROMPT_FILE}" <<'EOF'
function fish_prompt
    set -l last_status $status
    set_color brblue
    printf '%s' (prompt_pwd)
    set_color normal
    if test $last_status -ne 0
        printf ' '
        set_color red
        printf '[%s]' $last_status
        set_color normal
    end
    printf '\n'
    set_color brcyan
    printf '> '
    set_color normal
end
EOF

# Fixed: the summary previously omitted vim even though it is installed above.
echo "Installed: htop tree git gcc gcc-c++ make cmake fish vim"
echo "Fish prompt configured. Start it with: fish"

View File

@ -3,9 +3,11 @@ from __future__ import annotations
import asyncio
import queue
import threading
import time
from dataclasses import dataclass, field
from typing import Dict, Iterable
from typing import Any, Dict, Iterable
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from api import Request, Response
from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket
@ -15,8 +17,9 @@ SERVER_HOST = "127.0.0.1"
SERVER_PORT = 9000
SHARED_SECRET = "change-me"
PROCESS_DELAY_SECONDS = 0.4
MODEL_ID = "zai-org/GLM-4.7-Flash"
MAX_NEW_TOKENS = 256
PIECE_CHUNK_SIZE = 64
@dataclass
@ -51,20 +54,81 @@ class WorkItem:
record: PendingChatCompletionRecord
def extract_last_message(messages: list) -> str:
    """Return the content of the final message, or "" for an empty history.

    Accepts both attribute-style message objects (``.content``) and plain
    dicts (``message["content"]``); a dict without the key yields "".
    """
    if messages:
        tail = messages[-1]
        if hasattr(tail, "content"):
            return tail.content
        return tail.get("content", "")
    return ""
@dataclass
class ModelBundle:
    """Pairs a tokenizer with its causal-LM model so both travel together
    from load_local_model() into the worker thread.

    Fields are `Any` because the transformers Auto* factories resolve to
    model-specific classes only at load time.
    """
    # Tokenizer returned by AutoTokenizer.from_pretrained
    tokenizer: Any
    # Model returned by AutoModelForCausalLM.from_pretrained
    model: Any
def generate_uppercase_pieces(text: str) -> Iterable[str]:
    """Yield every whitespace-separated word upper-cased, with a trailing space.

    Sleeps PROCESS_DELAY_SECONDS before each word to simulate a slow,
    streamed backend (legacy echo path, superseded by generate_llm_pieces).
    """
    for token in text.split():
        time.sleep(PROCESS_DELAY_SECONDS)
        yield token.upper() + " "
def fail(message: str) -> None:
    """Log `message` to stdout (flushed), then abort the process.

    Always raises SystemExit(1); it never returns normally.
    """
    print(message, flush=True)
    raise SystemExit(1)
def build_context(messages: list) -> list[dict[str, str]]:
    """Normalize a mixed message history into chat-template dicts.

    Each entry may be an object exposing both `.role` and `.content`
    attributes, or a plain dict; missing dict keys default to "".
    """
    normalized: list[dict[str, str]] = []
    for msg in messages:
        if hasattr(msg, "role") and hasattr(msg, "content"):
            entry = {"role": msg.role, "content": msg.content}
        else:
            entry = {
                "role": msg.get("role", ""),
                "content": msg.get("content", ""),
            }
        normalized.append(entry)
    return normalized
def load_local_model() -> ModelBundle:
    """Load MODEL_ID's tokenizer and weights strictly from the local HF cache.

    Never touches the network (local_files_only=True). If the cache is
    incomplete, fail() terminates the process with a hint to run B2.py,
    so the trailing return only executes on success.
    """
    try:
        tok = AutoTokenizer.from_pretrained(MODEL_ID, local_files_only=True)
        lm = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            local_files_only=True,
            dtype=torch.bfloat16,
            device_map="auto",
        )
    except OSError as exc:
        fail(
            f"Model {MODEL_ID!r} is not fully available in the Hugging Face cache. "
            f"Run B2.py first. Original error: {exc}"
        )
    return ModelBundle(tokenizer=tok, model=lm)
def generate_llm_pieces(bundle: ModelBundle, messages: list) -> Iterable[str]:
    """Run one greedy generation over the chat history and return the reply
    split into PIECE_CHUNK_SIZE-character chunks (empty list for empty output).
    """
    # Normalize the heterogeneous history into {"role", "content"} dicts.
    context = build_context(messages)
    # return_dict=True yields a BatchEncoding (input_ids + attention_mask);
    # add_generation_prompt appends the assistant-turn marker.
    inputs = bundle.tokenizer.apply_chat_template(
        context,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move every tensor to the model's device (device_map="auto" placement).
    inputs = {name: tensor.to(bundle.model.device) for name, tensor in inputs.items()}
    input_ids = inputs.get("input_ids")
    # Prompt length, used below to strip the prompt from the generated ids.
    input_len = int(input_ids.shape[1]) if input_ids is not None else 0
    # inference_mode: no autograd state, lower memory; do_sample=False = greedy.
    with torch.inference_mode():
        output_ids = bundle.model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
        )
    # generate() returns prompt + completion; keep only the new tokens.
    generated_ids = output_ids[0][input_len:]
    text = bundle.tokenizer.decode(generated_ids, skip_special_tokens=True)
    if not text:
        return []
    # Fixed-size character chunks so the caller can stream pieces to clients.
    return [text[i:i + PIECE_CHUNK_SIZE] for i in range(0, len(text), PIECE_CHUNK_SIZE)]
def worker_loop(
@ -72,6 +136,7 @@ def worker_loop(
pending: Dict[int, PendingChatCompletionRecord],
pending_lock: threading.Lock,
loop: asyncio.AbstractEventLoop,
model_bundle: ModelBundle,
) -> None:
while True:
item = work_queue.get()
@ -84,15 +149,12 @@ def worker_loop(
if record.is_cancelled():
cancelled = True
else:
text = extract_last_message(item.messages)
for piece in generate_uppercase_pieces(text):
for piece in generate_llm_pieces(model_bundle, item.messages):
loop.call_soon_threadsafe(
record.response_queue.put_nowait,
MessagePiece(piece=piece, is_end=False, is_cancel=False),
)
print("[debug] got a new piece")
if record.is_cancelled():
print("[debug] record was cancelled")
cancelled = True
break
@ -106,6 +168,12 @@ def worker_loop(
record.response_queue.put_nowait,
MessagePiece(is_end=True, is_cancel=False),
)
except BaseException as exc:
print(f"[worker] generation failed: {exc}", flush=True)
loop.call_soon_threadsafe(
record.response_queue.put_nowait,
MessagePiece(is_end=True, is_cancel=True),
)
finally:
with pending_lock:
pending.pop(item.request_id, None)
@ -155,6 +223,38 @@ async def forward_pieces(
return
async def handle_chat_request(
    request: Request,
    work_queue: queue.Queue[WorkItem | None],
    pending: Dict[int, PendingChatCompletionRecord],
    pending_lock: threading.Lock,
    transport: SecretStreamSocket,
    send_lock: asyncio.Lock,
    response_tasks: set[asyncio.Task[None]],
    owned_request_ids: set[int] | None = None,
) -> None:
    """Register a chat request, enqueue it for the worker thread, and start
    a task that forwards generated pieces back over the transport.

    Raises ProtocolError on a duplicate request_id for this connection.
    `owned_request_ids`, when provided, receives the new id so the calling
    connection can track (and cancel) its own in-flight requests.
    """
    request_id = request.request_id
    print(f"[request] chat request_id={request_id}", flush=True)
    response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue()
    record = PendingChatCompletionRecord(
        response_queue=response_queue,
    )
    with pending_lock:
        if request_id in pending:
            raise ProtocolError(
                f"Duplicate request_id {request_id} received on this connection"
            )
        pending[request_id] = record
    # FIX: the pre-refactor inline version also did
    # owned_request_ids.add(request_id) so the connection could clean up its
    # own requests on disconnect; the refactor dropped that. Restored here as
    # an optional, backward-compatible parameter — pass the per-connection
    # set from handle_client to re-enable the tracking.
    if owned_request_ids is not None:
        owned_request_ids.add(request_id)
    work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
    task = asyncio.create_task(
        forward_pieces(request_id, response_queue, transport, send_lock)
    )
    response_tasks.add(task)
    task.add_done_callback(response_tasks.discard)
async def handle_client(
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
@ -178,27 +278,16 @@ async def handle_client(
request_id = request.request_id
if request.kind == "chat":
print(f"[request] chat request_id={request_id}", flush=True)
response_queue: asyncio.Queue[MessagePiece] = asyncio.Queue()
record = PendingChatCompletionRecord(
response_queue=response_queue,
await handle_chat_request(
request,
work_queue,
pending,
pending_lock,
transport,
send_lock,
response_tasks,
)
with pending_lock:
if request_id in pending:
raise ProtocolError(
f"Duplicate request_id {request_id} received on this connection"
)
pending[request_id] = record
owned_request_ids.add(request_id)
work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
task = asyncio.create_task(
forward_pieces(request_id, response_queue, transport, send_lock)
)
response_tasks.add(task)
task.add_done_callback(response_tasks.discard)
elif request.kind == "cancel":
print(f"[request] cancel request_id={request_id}", flush=True)
with pending_lock:
@ -224,7 +313,7 @@ async def handle_client(
await transport.close()
async def run_server() -> None:
async def run_server(model_bundle: ModelBundle) -> None:
pending: Dict[int, PendingChatCompletionRecord] = {}
pending_lock = threading.Lock()
work_queue: queue.Queue[WorkItem | None] = queue.Queue()
@ -232,7 +321,7 @@ async def run_server() -> None:
loop = asyncio.get_running_loop()
worker = threading.Thread(
target=worker_loop,
args=(work_queue, pending, pending_lock, loop),
args=(work_queue, pending, pending_lock, loop, model_bundle),
daemon=True,
)
worker.start()
@ -256,7 +345,9 @@ async def run_server() -> None:
def main() -> None:
asyncio.run(run_server())
model_bundle = load_local_model()
print("[model] loaded", flush=True)
asyncio.run(run_server(model_bundle))
if __name__ == "__main__":

View File

@ -2,6 +2,24 @@ version = 1
revision = 3
requires-python = ">=3.13"
[[package]]
name = "accelerate"
version = "1.13.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "huggingface-hub" },
{ name = "numpy" },
{ name = "packaging" },
{ name = "psutil" },
{ name = "pyyaml" },
{ name = "safetensors" },
{ name = "torch" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ca/14/787e5498cd062640f0f3d92ef4ae4063174f76f9afd29d13fc52a319daae/accelerate-1.13.0.tar.gz", hash = "sha256:d631b4e0f5b3de4aff2d7e9e6857d164810dfc3237d54d017f075122d057b236", size = 402835, upload-time = "2026-03-04T19:34:12.359Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/7e/46/02ac5e262d4af18054b3e922b2baedbb2a03289ee792162de60a865defc5/accelerate-1.13.0-py3-none-any.whl", hash = "sha256:cf1a3efb96c18f7b152eb0fa7490f3710b19c3f395699358f08decca2b8b62e0", size = 383744, upload-time = "2026-03-04T19:34:10.313Z" },
]
[[package]]
name = "annotated-doc"
version = "0.0.4"
@ -188,6 +206,7 @@ name = "dedicated-ai-server"
version = "0.1.0"
source = { virtual = "." }
dependencies = [
{ name = "accelerate" },
{ name = "asyncio" },
{ name = "construct" },
{ name = "huggingface-hub" },
@ -198,6 +217,7 @@ dependencies = [
[package.metadata]
requires-dist = [
{ name = "accelerate", specifier = ">=1.13.0" },
{ name = "asyncio", specifier = ">=4.0.0" },
{ name = "construct", specifier = ">=2.10.70" },
{ name = "huggingface-hub", specifier = ">=1.0.0" },
@ -633,6 +653,34 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b7/b9/c538f279a4e237a006a2c98387d081e9eb060d203d8ed34467cc0f0b9b53/packaging-26.0-py3-none-any.whl", hash = "sha256:b36f1fef9334a5588b4166f8bcd26a14e521f2b55e6b9de3aaa80d3ff7a37529", size = 74366, upload-time = "2026-01-21T20:50:37.788Z" },
]
[[package]]
name = "psutil"
version = "7.2.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/d1ddf4abb55e93cebc4f2ed8b5d6dbad109ecb8d63748dd2b20ab5e57ebe/psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372", size = 493740, upload-time = "2026-01-28T18:14:54.428Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/51/08/510cbdb69c25a96f4ae523f733cdc963ae654904e8db864c07585ef99875/psutil-7.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:2edccc433cbfa046b980b0df0171cd25bcaeb3a68fe9022db0979e7aa74a826b", size = 130595, upload-time = "2026-01-28T18:14:57.293Z" },
{ url = "https://files.pythonhosted.org/packages/d6/f5/97baea3fe7a5a9af7436301f85490905379b1c6f2dd51fe3ecf24b4c5fbf/psutil-7.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:e78c8603dcd9a04c7364f1a3e670cea95d51ee865e4efb3556a3a63adef958ea", size = 131082, upload-time = "2026-01-28T18:14:59.732Z" },
{ url = "https://files.pythonhosted.org/packages/37/d6/246513fbf9fa174af531f28412297dd05241d97a75911ac8febefa1a53c6/psutil-7.2.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1a571f2330c966c62aeda00dd24620425d4b0cc86881c89861fbc04549e5dc63", size = 181476, upload-time = "2026-01-28T18:15:01.884Z" },
{ url = "https://files.pythonhosted.org/packages/b8/b5/9182c9af3836cca61696dabe4fd1304e17bc56cb62f17439e1154f225dd3/psutil-7.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:917e891983ca3c1887b4ef36447b1e0873e70c933afc831c6b6da078ba474312", size = 184062, upload-time = "2026-01-28T18:15:04.436Z" },
{ url = "https://files.pythonhosted.org/packages/16/ba/0756dca669f5a9300d0cbcbfae9a4c30e446dfc7440ffe43ded5724bfd93/psutil-7.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:ab486563df44c17f5173621c7b198955bd6b613fb87c71c161f827d3fb149a9b", size = 139893, upload-time = "2026-01-28T18:15:06.378Z" },
{ url = "https://files.pythonhosted.org/packages/1c/61/8fa0e26f33623b49949346de05ec1ddaad02ed8ba64af45f40a147dbfa97/psutil-7.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:ae0aefdd8796a7737eccea863f80f81e468a1e4cf14d926bd9b6f5f2d5f90ca9", size = 135589, upload-time = "2026-01-28T18:15:08.03Z" },
{ url = "https://files.pythonhosted.org/packages/81/69/ef179ab5ca24f32acc1dac0c247fd6a13b501fd5534dbae0e05a1c48b66d/psutil-7.2.2-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:eed63d3b4d62449571547b60578c5b2c4bcccc5387148db46e0c2313dad0ee00", size = 130664, upload-time = "2026-01-28T18:15:09.469Z" },
{ url = "https://files.pythonhosted.org/packages/7b/64/665248b557a236d3fa9efc378d60d95ef56dd0a490c2cd37dafc7660d4a9/psutil-7.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:7b6d09433a10592ce39b13d7be5a54fbac1d1228ed29abc880fb23df7cb694c9", size = 131087, upload-time = "2026-01-28T18:15:11.724Z" },
{ url = "https://files.pythonhosted.org/packages/d5/2e/e6782744700d6759ebce3043dcfa661fb61e2fb752b91cdeae9af12c2178/psutil-7.2.2-cp314-cp314t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1fa4ecf83bcdf6e6c8f4449aff98eefb5d0604bf88cb883d7da3d8d2d909546a", size = 182383, upload-time = "2026-01-28T18:15:13.445Z" },
{ url = "https://files.pythonhosted.org/packages/57/49/0a41cefd10cb7505cdc04dab3eacf24c0c2cb158a998b8c7b1d27ee2c1f5/psutil-7.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e452c464a02e7dc7822a05d25db4cde564444a67e58539a00f929c51eddda0cf", size = 185210, upload-time = "2026-01-28T18:15:16.002Z" },
{ url = "https://files.pythonhosted.org/packages/dd/2c/ff9bfb544f283ba5f83ba725a3c5fec6d6b10b8f27ac1dc641c473dc390d/psutil-7.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:c7663d4e37f13e884d13994247449e9f8f574bc4655d509c3b95e9ec9e2b9dc1", size = 141228, upload-time = "2026-01-28T18:15:18.385Z" },
{ url = "https://files.pythonhosted.org/packages/f2/fc/f8d9c31db14fcec13748d373e668bc3bed94d9077dbc17fb0eebc073233c/psutil-7.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:11fe5a4f613759764e79c65cf11ebdf26e33d6dd34336f8a337aa2996d71c841", size = 136284, upload-time = "2026-01-28T18:15:19.912Z" },
{ url = "https://files.pythonhosted.org/packages/e7/36/5ee6e05c9bd427237b11b3937ad82bb8ad2752d72c6969314590dd0c2f6e/psutil-7.2.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ed0cace939114f62738d808fdcecd4c869222507e266e574799e9c0faa17d486", size = 129090, upload-time = "2026-01-28T18:15:22.168Z" },
{ url = "https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979", size = 129859, upload-time = "2026-01-28T18:15:23.795Z" },
{ url = "https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9", size = 155560, upload-time = "2026-01-28T18:15:25.976Z" },
{ url = "https://files.pythonhosted.org/packages/63/65/37648c0c158dc222aba51c089eb3bdfa238e621674dc42d48706e639204f/psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e", size = 156997, upload-time = "2026-01-28T18:15:27.794Z" },
{ url = "https://files.pythonhosted.org/packages/8e/13/125093eadae863ce03c6ffdbae9929430d116a246ef69866dad94da3bfbc/psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8", size = 148972, upload-time = "2026-01-28T18:15:29.342Z" },
{ url = "https://files.pythonhosted.org/packages/04/78/0acd37ca84ce3ddffaa92ef0f571e073faa6d8ff1f0559ab1272188ea2be/psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc", size = 148266, upload-time = "2026-01-28T18:15:31.597Z" },
{ url = "https://files.pythonhosted.org/packages/b4/90/e2159492b5426be0c1fef7acba807a03511f97c5f86b3caeda6ad92351a7/psutil-7.2.2-cp37-abi3-win_amd64.whl", hash = "sha256:eb7e81434c8d223ec4a219b5fc1c47d0417b12be7ea866e24fb5ad6e84b3d988", size = 137737, upload-time = "2026-01-28T18:15:33.849Z" },
{ url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" },
]
[[package]]
name = "pycparser"
version = "3.0"

View File

@ -29,10 +29,10 @@
<textarea
class="chat-input"
id="chat-input"
placeholder="Write a message. Enter to send, Shift+Enter for a new line."
placeholder="Write a message. Shift+Enter for a new line."
rows="4"
></textarea>
<div class="chat-input-hint">Enter to send. Shift+Enter for a new line.</div>
<div class="chat-status-label" id="chat-status"></div>
</div>
</div>
</section>

View File

@ -36,9 +36,12 @@ struct ChatState {
message_nodes: Vec<MessageNode>,
is_receiving: bool,
active_assistant_index: Option<usize>,
compose_role: ComposeRole,
generation_enabled: bool,
}
struct MessageNode {
wrapper: Element,
content: Element,
status: Option<Element>,
}
@ -50,11 +53,37 @@ enum ChatStatus {
Hidden,
}
/// Role tag for the next composed message; cycles System → Assistant → User.
#[derive(Copy, Clone, PartialEq)]
enum ComposeRole {
    System,
    Assistant,
    User,
}

impl ComposeRole {
    /// Wire/display name of the role, matching the server's chat schema.
    fn as_str(self) -> &'static str {
        match self {
            Self::System => "system",
            Self::Assistant => "assistant",
            Self::User => "user",
        }
    }

    /// Advance to the following role in the fixed cycle (wraps back to System).
    fn next(self) -> Self {
        match self {
            Self::System => Self::Assistant,
            Self::Assistant => Self::User,
            Self::User => Self::System,
        }
    }
}
struct AppState {
ws: WebSocket,
document: Document,
messages_container: Element,
input: HtmlTextAreaElement,
status_label: Element,
state: RefCell<ChatState>,
}
@ -91,6 +120,7 @@ fn append_message(
container.append_child(&wrapper)?;
Ok(MessageNode {
wrapper,
content: content_el,
status: status_el,
})
@ -117,6 +147,12 @@ fn apply_status(node: &Element, status: ChatStatus) {
}
}
/// Render the compose-role and generation-toggle state into the status label,
/// e.g. "Role: user | Generation: on".
fn update_status_label(label: &Element, role: ComposeRole, generation_enabled: bool) {
    let generation_state = match generation_enabled {
        true => "on",
        false => "off",
    };
    label.set_text_content(Some(&format!(
        "Role: {} | Generation: {}",
        role.as_str(),
        generation_state
    )));
}
fn scroll_to_bottom(container: &Element) {
if let Some(element) = container.dyn_ref::<HtmlElement>() {
let height = element.scroll_height();
@ -141,6 +177,9 @@ pub fn init_chat() -> Result<(), JsValue> {
.get_element_by_id("chat-input")
.ok_or_else(|| JsValue::from_str("Missing chat-input element"))?
.dyn_into()?;
let status_label = document
.get_element_by_id("chat-status")
.ok_or_else(|| JsValue::from_str("Missing chat-status element"))?;
let ws = WebSocket::new(&ws_url)?;
ws.set_binary_type(BinaryType::Arraybuffer);
@ -151,14 +190,22 @@ pub fn init_chat() -> Result<(), JsValue> {
document,
messages_container,
input: input_el,
status_label,
state: RefCell::new(ChatState {
messages: Vec::new(),
message_nodes: Vec::new(),
is_receiving: false,
active_assistant_index: None,
compose_role: ComposeRole::User,
generation_enabled: true,
}),
});
{
let state = app.state.borrow();
update_status_label(&app.status_label, state.compose_role, state.generation_enabled);
}
let onopen = Closure::<dyn FnMut(Event)>::wrap(Box::new(move |_| {
console::log_1(&"[ws] connected".into());
}));
@ -167,28 +214,75 @@ pub fn init_chat() -> Result<(), JsValue> {
let app_for_keydown = app.clone();
let onkeydown = Closure::<dyn FnMut(KeyboardEvent)>::wrap(Box::new(move |event: KeyboardEvent| {
if event.ctrl_key() && (event.key() == "c" || event.key() == "C") {
let state = app_for_keydown.state.borrow();
if state.is_receiving {
event.prevent_default();
if app_for_keydown.ws.ready_state() != WebSocket::OPEN {
console::error_1(&"[ws] socket is not open".into());
return;
}
if event.ctrl_key() {
let key = event.key();
if key == "c" || key == "C" {
let state = app_for_keydown.state.borrow();
if state.is_receiving {
event.prevent_default();
if app_for_keydown.ws.ready_state() != WebSocket::OPEN {
console::error_1(&"[ws] socket is not open".into());
return;
}
let request = UserRequest::ChatCompletionCancellation;
match request.to_bytes() {
Ok(bytes) => {
if let Err(err) = app_for_keydown.ws.send_with_u8_array(&bytes) {
console::error_1(&format!("[ws] cancel send error: {:?}", err).into());
let request = UserRequest::ChatCompletionCancellation;
match request.to_bytes() {
Ok(bytes) => {
if let Err(err) = app_for_keydown.ws.send_with_u8_array(&bytes) {
console::error_1(&format!("[ws] cancel send error: {:?}", err).into());
}
}
Err(err) => {
console::error_1(&format!("[ws] cancel encode error: {err:#}").into());
}
}
Err(err) => {
console::error_1(&format!("[ws] cancel encode error: {err:#}").into());
}
return;
}
if key == "1" {
event.prevent_default();
let mut state = app_for_keydown.state.borrow_mut();
state.compose_role = state.compose_role.next();
update_status_label(
&app_for_keydown.status_label,
state.compose_role,
state.generation_enabled,
);
return;
}
if key == "4" {
event.prevent_default();
let mut state = app_for_keydown.state.borrow_mut();
state.generation_enabled = !state.generation_enabled;
update_status_label(
&app_for_keydown.status_label,
state.compose_role,
state.generation_enabled,
);
return;
}
if key == "d" || key == "D" {
event.prevent_default();
let mut state = app_for_keydown.state.borrow_mut();
if state.is_receiving {
return;
}
let last_message = match state.messages.pop() {
Some(message) => message,
None => return,
};
if let Some(node) = state.message_nodes.pop() {
if let Some(parent) = node.wrapper.parent_node() {
let _ = parent.remove_child(&node.wrapper);
}
}
app_for_keydown.input.set_value(&last_message.content);
scroll_to_bottom(&app_for_keydown.messages_container);
return;
}
return;
}
if event.key() != "Enter" || event.shift_key() {
@ -207,18 +301,14 @@ pub fn init_chat() -> Result<(), JsValue> {
return;
}
if app_for_keydown.ws.ready_state() != WebSocket::OPEN {
console::error_1(&"[ws] socket is not open".into());
return;
}
let role = state.compose_role;
let user_content = trimmed.to_string();
app_for_keydown.input.set_value("");
let user_node = match append_message(
&app_for_keydown.document,
&app_for_keydown.messages_container,
"user",
role.as_str(),
&user_content,
None,
) {
@ -230,12 +320,21 @@ pub fn init_chat() -> Result<(), JsValue> {
};
state.messages.push(UserChatMessage {
role: "user".to_string(),
role: role.as_str().to_string(),
content: user_content,
});
state.message_nodes.push(user_node);
scroll_to_bottom(&app_for_keydown.messages_container);
if role != ComposeRole::User || !state.generation_enabled {
return;
}
if app_for_keydown.ws.ready_state() != WebSocket::OPEN {
console::error_1(&"[ws] socket is not open".into());
return;
}
let history = state.messages.clone();
let request = UserRequest::ChatCompletion(UserChatCompletionRequest::new(history));

View File

@ -154,9 +154,9 @@ body {
color: #8a8a8a;
}
.chat-input-hint {
font-size: 9.5pt;
color: #6b6b6b;
.chat-status-label {
font-size: 15pt;
color: #4b4b4b;
}
body {