diff --git a/dedicated_ai_server/E.py b/dedicated_ai_server/E.py deleted file mode 100644 index bf6df2b..0000000 --- a/dedicated_ai_server/E.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations - -import socket -from typing import Final -from construct import ConstructError -from api import Request -from secret_stream_socket import ProtocolError, wrap_connection_socket - -SERVER_HOST: Final = "127.0.0.1" -SERVER_PORT: Final = 9000 -SHARED_SECRET: Final = "change-me" - - -def print_batch(peer: str, request_id: int, batch: object) -> None: - print( - f"[packet] {peer} sent request_id={request_id} with {len(batch.messages)} message(s)", - flush=True, - ) - for index, message in enumerate(batch.messages): - print( - f" [{index}] role={message.role!r} content={message.content!r}", - flush=True, - ) - - -def handle_client( - client_sock: socket.socket, - address: tuple[str, int], - shared_secret: str, -) -> None: - peer = f"{address[0]}:{address[1]}" - - def on_frame(frame: bytes) -> None: - req = Request.parse(frame) - if req.kind == "chat": - print_batch(peer, req.request_id, req.payload) - elif req.kind == "cancel": - print(f"[packet] {peer} sent request_id={req.request_id} cancel", flush=True) - else: - raise ConstructError(f"Unknown request kind {req.kind!r}") - - with wrap_connection_socket(client_sock, shared_secret) as transport: - print(f"[connected] {peer}", flush=True) - transport.run_receiving_loop(on_frame) - print(f"[disconnected] {peer}", flush=True) - - -def run_server(host: str, port: int, shared_secret: str) -> None: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock: - server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_sock.bind((host, port)) - server_sock.listen() - - print(f"[listening] {host}:{port}", flush=True) - while True: - client_sock, address = server_sock.accept() - try: - handle_client(client_sock, address, shared_secret) - except (ConstructError, ProtocolError) as exc: - print(f"[protocol error] {address[0]}:{address[1]}: {exc}", flush=True) - except Exception as exc: - print(f"[error] {address[0]}:{address[1]}: {exc}", flush=True) - - -if __name__ == "__main__": - try: - run_server(SERVER_HOST, SERVER_PORT, SHARED_SECRET) - except KeyboardInterrupt: - print("\n[stopped]", flush=True) diff --git a/dedicated_ai_server/F.py b/dedicated_ai_server/F.py deleted file mode 100644 index c39a4a0..0000000 --- a/dedicated_ai_server/F.py +++ /dev/null @@ -1,50 +0,0 @@ -from __future__ import annotations -from typing import Final -import socket -from api import Request - -SERVER_HOST: Final = "127.0.0.1" -SERVER_PORT: Final = 9000 -SHARED_SECRET: Final = "change-me" - -from secret_stream_socket import SecretStreamSocket, wrap_connection_socket - - -def connect_to_server( - host: str = SERVER_HOST, - port: int = SERVER_PORT, - shared_secret: str = SHARED_SECRET, -) -> SecretStreamSocket: - sock = socket.create_connection((host, port)) - return wrap_connection_socket(sock, shared_secret) - - -def send_message(connection: SecretStreamSocket, request_id: int, role: str, content: str) -> None: - payload = Request.build( - { - "request_id": request_id, - "kind": "chat", - "payload": { - "messages": [ - { - "role": role, - "content": content, - } - ] - }, - } - ) - connection.send_frame(payload) - - -def close_connection(connection: SecretStreamSocket) -> None: - connection.close() - - -if __name__ == "__main__": - connection = connect_to_server() - try: - msg = "\n".join([f"hello {i} from F.py" for i in range(1000)]) - send_message(connection, 1, "user", msg) - finally: - close_connection(connection) diff --git a/dedicated_ai_server/config.py b/dedicated_ai_server/config.py index 5ae7b19..a18b41f 100644 --- a/dedicated_ai_server/config.py +++ b/dedicated_ai_server/config.py @@ -12,6 +12,7 @@ class Config: listening_addr: str listening_port: int secret: str + model_id: str def _read_toml(path: Path) -> Mapping[str, Any]: @@ -41,4 +42,5 @@ def read_config( listening_addr=addr, listening_port=port, secret=str(secret), + model_id=str(config_data.get("model_id", "zai-org/GLM-4.7-Flash")), ) diff --git a/dedicated_ai_server/server.py b/dedicated_ai_server/server.py index 4cff69d..84a9a9b 100644 --- a/dedicated_ai_server/server.py +++ b/dedicated_ai_server/server.py @@ -14,7 +14,6 @@ from config import Config, read_config from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket -MODEL_ID = "zai-org/GLM-4.7-Flash" MAX_NEW_TOKENS = 256 @@ -72,21 +71,21 @@ def build_context(messages: list) -> list[dict[str, str]]: return context -def load_local_model() -> ModelBundle: +def load_local_model(model_id: str) -> ModelBundle: try: tokenizer = AutoTokenizer.from_pretrained( - MODEL_ID, + model_id, local_files_only=True, ) model = AutoModelForCausalLM.from_pretrained( - MODEL_ID, + model_id, local_files_only=True, dtype=torch.bfloat16, device_map="auto", ) except OSError as exc: fail( - f"Model {MODEL_ID!r} is not fully available in the Hugging Face cache. " + f"Model {model_id!r} is not fully available in the Hugging Face cache. " f"Run B2.py first. Original error: {exc}" ) @@ -134,6 +133,8 @@ def generate_llm_pieces(bundle: ModelBundle, messages: list) -> Iterable[str]: next_token_id = torch.argmax(logits, dim=-1) token_id = int(next_token_id.item()) + print("[debug] get argmaxed") + if token_id in eos_token_ids: break @@ -168,6 +169,7 @@ def worker_loop( if record.is_cancelled(): cancelled = True else: + print("[debug] starting llm inference for a chat completion request") for piece in generate_llm_pieces(model_bundle, item.messages): loop.call_soon_threadsafe( record.response_queue.put_nowait, @@ -250,6 +252,7 @@ async def handle_chat_request( transport: SecretStreamSocket, send_lock: asyncio.Lock, response_tasks: set[asyncio.Task[None]], + owned_request_ids: set[int], ) -> None: request_id = request.request_id print(f"[request] chat request_id={request_id}", flush=True) @@ -264,6 +267,7 @@ async def handle_chat_request( f"Duplicate request_id {request_id} received on this connection" ) pending[request_id] = record + owned_request_ids.add(request_id) work_queue.put(WorkItem(request_id, list(request.payload.messages), record)) @@ -369,8 +373,8 @@ async def run_server(model_bundle: ModelBundle, config: Config) -> None: def main() -> None: config = read_config() - model_bundle = load_local_model() - print("[model] loaded", flush=True) + model_bundle = load_local_model(config.model_id) + print(f"[model] loaded {config.model_id}", flush=True) asyncio.run(run_server(model_bundle, config))