78 lines
2.0 KiB
Python
78 lines
2.0 KiB
Python
from __future__ import annotations

import asyncio
import os
from typing import Iterable

from api import Request, Response
from secret_stream_socket import wrap_connection_socket
|
|
|
|
# Connection settings. Each one can be overridden through the environment,
# so the shared secret does not have to be edited into source. The defaults
# are the values this client previously hard-coded.
SERVER_HOST = os.environ.get("CHAT_SERVER_HOST", "127.0.0.1")
# int() raises ValueError at import time if CHAT_SERVER_PORT is not an integer.
SERVER_PORT = int(os.environ.get("CHAT_SERVER_PORT", "9000"))
# "change-me" is a placeholder; supply CHAT_SHARED_SECRET instead of relying on it.
SHARED_SECRET = os.environ.get("CHAT_SHARED_SECRET", "change-me")
|
|
|
|
|
|
def build_chat_request(request_id: int, role: str, content: str) -> bytes:
    """Encode a chat request that carries exactly one message.

    Args:
        request_id: Identifier placed in the envelope's ``request_id`` field.
        role: Value for the single message's ``role`` field.
        content: Value for the single message's ``content`` field.

    Returns:
        Whatever ``Request.build`` produces for the envelope (annotated as
        ``bytes`` here).
    """
    single_message = {"role": role, "content": content}
    chat_payload = {"messages": [single_message]}
    envelope = {
        "request_id": request_id,
        "kind": "chat",
        "payload": chat_payload,
    }
    return Request.build(envelope)
|
|
|
|
|
|
def describe_response(response: object) -> str:
    """Return a short human-readable label for a parsed response.

    Reads ``response.kind``. Only for ``"chat"`` responses does it also read
    ``response.payload.piece``, which it includes via ``repr``. The kinds
    ``"cancel"`` and ``"end"`` map to their own names. Any other kind is
    reported as ``unknown <repr(kind)>``.
    """
    kind = response.kind
    if kind == "chat":
        piece = response.payload.piece
        return "chat piece: " + repr(piece)
    for label in ("cancel", "end"):
        if kind == label:
            return label
    return "unknown " + repr(kind)
|
|
|
|
|
|
def is_terminal_response(response: object) -> bool:
    """Return True when ``response.kind`` is ``"cancel"`` or ``"end"``."""
    kind = response.kind
    return kind == "cancel" or kind == "end"
|
|
|
|
|
|
async def read_responses(transport, pending: set[int]) -> None:
    """Print incoming responses until every id in *pending* has terminated.

    Each frame from ``transport.recv_frame()`` is parsed with
    ``Response.parse`` and printed. When a response is terminal (per
    ``is_terminal_response``), its ``request_id`` is removed from *pending*.

    Note: *pending* is mutated in place. If the loop stops early because
    ``recv_frame()`` returned ``None``, the ids still in the set are the
    requests that never terminated.
    """
    while pending:
        raw_frame = await transport.recv_frame()
        if raw_frame is None:
            # A None frame is treated as the connection having gone away.
            print("[disconnected]")
            break
        message = Response.parse(raw_frame)
        rid = message.request_id
        summary = describe_response(message)
        print(f"[recv] request_id={rid} {summary}")
        if not is_terminal_response(message):
            continue
        pending.discard(rid)
|
|
|
|
|
|
async def main() -> None:
    """Connect to the server, send two chat requests, and print every reply.

    Opens a TCP connection to ``SERVER_HOST:SERVER_PORT`` and wraps it with
    ``wrap_connection_socket`` using ``SHARED_SECRET``. It then sends all
    requests back-to-back and reads replies until each request has
    terminated or the server disconnects. The connection is closed on every
    exit path, including when wrapping, sending or reading raises.
    """
    reader, writer = await asyncio.open_connection(SERVER_HOST, SERVER_PORT)
    try:
        transport = await wrap_connection_socket(reader, writer, SHARED_SECRET)
    except BaseException:
        # No transport was returned, so nothing else will close the raw
        # stream; close it here before propagating (StreamWriter.close is
        # safe to call even if the wrapper already closed it).
        writer.close()
        raise

    # Single source of truth for what gets sent: the ids awaited below are
    # derived from this list, so a new request cannot be left untracked.
    chat_turns = [
        (1, "user", "hello from client one"),
        (2, "user", "second request with more words"),
    ]
    try:
        # Build every frame before sending any, as the original flow did.
        requests: Iterable[bytes] = [
            build_chat_request(request_id, role, content)
            for request_id, role, content in chat_turns
        ]
        for payload in requests:
            await transport.send_frame(payload)

        pending = {request_id for request_id, _, _ in chat_turns}
        await read_responses(transport, pending)
    finally:
        # Previously skipped whenever send_frame/read_responses raised.
        await transport.close()
|
|
|
|
|
|
# Script entry point: run the client only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())
|