ready to test
This commit is contained in:
parent
2ac57a6434
commit
f7f6e32acb
@ -1,69 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import socket
|
|
||||||
from typing import Final
|
|
||||||
from construct import ConstructError
|
|
||||||
from api import Request
|
|
||||||
from secret_stream_socket import ProtocolError, wrap_connection_socket
|
|
||||||
|
|
||||||
# Connection settings shared by server and client scripts.
SERVER_HOST: Final = "127.0.0.1"  # loopback only; server is not reachable externally
SERVER_PORT: Final = 9000  # TCP port the server listens on
SHARED_SECRET: Final = "change-me"  # NOTE(review): placeholder secret — must be replaced in any real deployment
|
|
||||||
|
|
||||||
|
|
||||||
def print_batch(peer: str, request_id: int, batch: object) -> None:
    """Log a received chat batch to stdout: a header line, then one line per message.

    *batch* is expected to expose a ``messages`` sequence whose items carry
    ``role`` and ``content`` attributes.
    """
    header = (
        f"[packet] {peer} sent request_id={request_id} "
        f"with {len(batch.messages)} message(s)"
    )
    print(header, flush=True)
    for position, item in enumerate(batch.messages):
        detail = f"  [{position}] role={item.role!r} content={item.content!r}"
        print(detail, flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
def handle_client(
    client_sock: socket.socket,
    address: tuple[str, int],
    shared_secret: str,
) -> None:
    """Serve one client connection: decode framed requests until the peer disconnects.

    Raises ConstructError for an unknown request kind; frame/handshake failures
    propagate from the transport layer.
    """
    peer = f"{address[0]}:{address[1]}"

    def dispatch(frame: bytes) -> None:
        # One callback per received frame: parse it and act on its kind.
        request = Request.parse(frame)
        if request.kind == "chat":
            print_batch(peer, request.request_id, request.payload)
            return
        if request.kind == "cancel":
            print(f"[packet] {peer} sent request_id={request.request_id} cancel", flush=True)
            return
        raise ConstructError(f"Unknown request kind {request.kind!r}")

    with wrap_connection_socket(client_sock, shared_secret) as transport:
        print(f"[connected] {peer}", flush=True)
        transport.run_receiving_loop(dispatch)
        # NOTE(review): source indentation was lost; this print is assumed to sit
        # inside the `with`, symmetric with the [connected] line — confirm.
        print(f"[disconnected] {peer}", flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
def run_server(host: str, port: int, shared_secret: str) -> None:
    """Accept TCP clients on ``(host, port)`` and serve them one at a time.

    Each connection is handled synchronously by ``handle_client``; protocol
    errors are logged and the accept loop continues with the next client.
    Runs forever until interrupted.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock:
        # Allow quick restarts without waiting for TIME_WAIT to expire.
        server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server_sock.bind((host, port))
        server_sock.listen()

        print(f"[listening] {host}:{port}", flush=True)
        while True:
            client_sock, address = server_sock.accept()
            try:
                handle_client(client_sock, address, shared_secret)
            except (ConstructError, ProtocolError) as exc:
                print(f"[protocol error] {address[0]}:{address[1]}: {exc}", flush=True)
            except Exception as exc:
                # Top-level boundary for this connection: log and keep serving.
                print(f"[error] {address[0]}:{address[1]}: {exc}", flush=True)
            finally:
                # Fix: the original leaked the accepted socket if the secure
                # wrapper raised before taking ownership of it (e.g. a failed
                # handshake). close() is idempotent, so this is always safe.
                client_sock.close()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Run the server until the user interrupts with Ctrl-C.
    try:
        run_server(host=SERVER_HOST, port=SERVER_PORT, shared_secret=SHARED_SECRET)
    except KeyboardInterrupt:
        print("\n[stopped]", flush=True)
|
|
||||||
@ -1,50 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
from typing import Final
|
|
||||||
import socket
|
|
||||||
from api import Request
|
|
||||||
|
|
||||||
# Default endpoint and credential used when connecting to the local server.
SERVER_HOST: Final = "127.0.0.1"  # connect over loopback
SERVER_PORT: Final = 9000  # must match the server's listening port
SHARED_SECRET: Final = "change-me"  # NOTE(review): placeholder secret — keep in sync with the server
|
|
||||||
|
|
||||||
from secret_stream_socket import SecretStreamSocket, wrap_connection_socket
|
|
||||||
|
|
||||||
|
|
||||||
def connect_to_server(
    host: str = SERVER_HOST,
    port: int = SERVER_PORT,
    shared_secret: str = SHARED_SECRET,
) -> SecretStreamSocket:
    """Open a TCP connection to the server and wrap it in the secret-stream transport."""
    raw_socket = socket.create_connection((host, port))
    return wrap_connection_socket(raw_socket, shared_secret)
|
|
||||||
|
|
||||||
|
|
||||||
def send_message(connection: SecretStreamSocket, request_id: int, role: str, content: str) -> None:
    """Build a single-message chat request and send it as one encrypted frame."""
    message = {
        "role": role,
        "content": content,
    }
    request = {
        "request_id": request_id,
        "kind": "chat",
        "payload": {"messages": [message]},
    }
    connection.send_frame(Request.build(request))
|
|
||||||
|
|
||||||
|
|
||||||
def close_connection(connection: SecretStreamSocket) -> None:
    """Shut down the wrapped connection, releasing its underlying socket."""
    connection.close()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    connection = connect_to_server()
    try:
        # Send one large multi-line message (1000 lines) as a single chat request.
        lines = [f"hello {i} from F.py" for i in range(1000)]
        send_message(connection, 1, "user", "\n".join(lines))
    finally:
        close_connection(connection)
|
|
||||||
@ -12,6 +12,7 @@ class Config:
|
|||||||
listening_addr: str
|
listening_addr: str
|
||||||
listening_port: int
|
listening_port: int
|
||||||
secret: str
|
secret: str
|
||||||
|
model_id: str
|
||||||
|
|
||||||
|
|
||||||
def _read_toml(path: Path) -> Mapping[str, Any]:
|
def _read_toml(path: Path) -> Mapping[str, Any]:
|
||||||
@ -41,4 +42,5 @@ def read_config(
|
|||||||
listening_addr=addr,
|
listening_addr=addr,
|
||||||
listening_port=port,
|
listening_port=port,
|
||||||
secret=str(secret),
|
secret=str(secret),
|
||||||
|
model_id=str(config_data.get("model_id", "zai-org/GLM-4.7-Flash")),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -14,7 +14,6 @@ from config import Config, read_config
|
|||||||
from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket
|
from secret_stream_socket import ProtocolError, SecretStreamSocket, wrap_connection_socket
|
||||||
|
|
||||||
|
|
||||||
MODEL_ID = "zai-org/GLM-4.7-Flash"
|
|
||||||
MAX_NEW_TOKENS = 256
|
MAX_NEW_TOKENS = 256
|
||||||
|
|
||||||
|
|
||||||
@ -72,21 +71,21 @@ def build_context(messages: list) -> list[dict[str, str]]:
|
|||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def load_local_model() -> ModelBundle:
|
def load_local_model(model_id: str) -> ModelBundle:
|
||||||
try:
|
try:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
MODEL_ID,
|
model_id,
|
||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
)
|
)
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
MODEL_ID,
|
model_id,
|
||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
dtype=torch.bfloat16,
|
dtype=torch.bfloat16,
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
)
|
)
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
fail(
|
fail(
|
||||||
f"Model {MODEL_ID!r} is not fully available in the Hugging Face cache. "
|
f"Model {model_id!r} is not fully available in the Hugging Face cache. "
|
||||||
f"Run B2.py first. Original error: {exc}"
|
f"Run B2.py first. Original error: {exc}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -134,6 +133,8 @@ def generate_llm_pieces(bundle: ModelBundle, messages: list) -> Iterable[str]:
|
|||||||
next_token_id = torch.argmax(logits, dim=-1)
|
next_token_id = torch.argmax(logits, dim=-1)
|
||||||
token_id = int(next_token_id.item())
|
token_id = int(next_token_id.item())
|
||||||
|
|
||||||
|
print("[debug] get argmaxed")
|
||||||
|
|
||||||
if token_id in eos_token_ids:
|
if token_id in eos_token_ids:
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -168,6 +169,7 @@ def worker_loop(
|
|||||||
if record.is_cancelled():
|
if record.is_cancelled():
|
||||||
cancelled = True
|
cancelled = True
|
||||||
else:
|
else:
|
||||||
|
print("[debug] starting llm inference for a chat completion request")
|
||||||
for piece in generate_llm_pieces(model_bundle, item.messages):
|
for piece in generate_llm_pieces(model_bundle, item.messages):
|
||||||
loop.call_soon_threadsafe(
|
loop.call_soon_threadsafe(
|
||||||
record.response_queue.put_nowait,
|
record.response_queue.put_nowait,
|
||||||
@ -250,6 +252,7 @@ async def handle_chat_request(
|
|||||||
transport: SecretStreamSocket,
|
transport: SecretStreamSocket,
|
||||||
send_lock: asyncio.Lock,
|
send_lock: asyncio.Lock,
|
||||||
response_tasks: set[asyncio.Task[None]],
|
response_tasks: set[asyncio.Task[None]],
|
||||||
|
owned_request_ids: set[int],
|
||||||
) -> None:
|
) -> None:
|
||||||
request_id = request.request_id
|
request_id = request.request_id
|
||||||
print(f"[request] chat request_id={request_id}", flush=True)
|
print(f"[request] chat request_id={request_id}", flush=True)
|
||||||
@ -264,6 +267,7 @@ async def handle_chat_request(
|
|||||||
f"Duplicate request_id {request_id} received on this connection"
|
f"Duplicate request_id {request_id} received on this connection"
|
||||||
)
|
)
|
||||||
pending[request_id] = record
|
pending[request_id] = record
|
||||||
|
owned_request_ids.add(request_id)
|
||||||
|
|
||||||
work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
|
work_queue.put(WorkItem(request_id, list(request.payload.messages), record))
|
||||||
|
|
||||||
@ -369,8 +373,8 @@ async def run_server(model_bundle: ModelBundle, config: Config) -> None:
|
|||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
config = read_config()
|
config = read_config()
|
||||||
model_bundle = load_local_model()
|
model_bundle = load_local_model(config.model_id)
|
||||||
print("[model] loaded", flush=True)
|
print(f"[model] loaded {config.model_id}", flush=True)
|
||||||
asyncio.run(run_server(model_bundle, config))
|
asyncio.run(run_server(model_bundle, config))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user