# chatsbt/controllers.py

import uuid
from typing import Dict, List, Tuple
from starlette.responses import JSONResponse
from starlette.requests import Request
from sse_starlette.sse import EventSourceResponse
from chatgraph import get_messages, get_llm

CHATS: Dict[str, List[dict]] = {}  # chat_id -> messages
PENDING: Dict[str, Tuple[str, str]] = {}  # message_id -> (chat_id, provider)
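# NOTE: both stores above are plain in-process dicts, so chat history and
# queued messages are lost on restart and are not shared across workers.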

MODELS = {
    "qwen/qwen3-235b-a22b-2507",
    "deepseek/deepseek-r1-0528",
    "moonshotai/kimi-k2",
    "x-ai/grok-4",
    "openai/gpt-4.1",
    "anthropic/claude-sonnet-4",
    "meta-llama/llama-4-maverick",
    "mistralai/devstral-medium",
    "qwen/qwen3-coder",
    "google/gemini-2.5-pro",
}
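# The set above is the only gate on model names: any request naming a model
# outside MODELS is rejected with a 400 by the handlers below.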


async def get_models(request: Request):
    """GET /models -> {models: [...]}"""
    return JSONResponse({"models": list(MODELS)})


async def create_chat(request: Request):
    """POST /chats -> {"id": chat_id, "model": model}"""
    body = await request.json()
    provider = body.get("model", "")
    if provider not in MODELS:
        return JSONResponse({"error": "Unknown model"}, status_code=400)
    chat_id = str(uuid.uuid4())[:8]
    CHATS[chat_id] = []
    return JSONResponse({"id": chat_id, "model": provider})


async def history(request: Request):
    """GET /chats/{chat_id} -> previous messages"""
    chat_id = request.path_params["chat_id"]
    if chat_id not in CHATS:
        return JSONResponse({"error": "Not found"}, status_code=404)
    return JSONResponse({"messages": CHATS[chat_id]})


async def post_message(request: Request):
    """POST /chats/{chat_id}/messages
    Body: {"message": "...", "model": "model_name"}
    Returns: {"status": "queued", "message_id": "<uuid>"}
    """
    chat_id = request.path_params["chat_id"]
    if chat_id not in CHATS:
        return JSONResponse({"error": "Chat not found"}, status_code=404)
    body = await request.json()
    user_text = body.get("message", "")
    provider = body.get("model", "")
    if provider not in MODELS:
        return JSONResponse({"error": "Unknown model"}, status_code=400)
    message_id = str(uuid.uuid4())
    PENDING[message_id] = (chat_id, provider)
    CHATS[chat_id].append({"role": "human", "content": user_text})
    return JSONResponse({
        "status": "queued",
        "message_id": message_id,
    })
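# Typical client flow (ids are placeholders; endpoints are those documented in
# the docstrings):
#   1. POST /chats/{chat_id}/messages  {"message": "...", "model": "..."}
#        -> {"status": "queued", "message_id": "<uuid>"}
#   2. GET  /chats/{chat_id}/stream?message_id=<uuid>  (SSE, handled below)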


async def chat_stream(request: Request):
    """GET /chats/{chat_id}/stream?message_id=<message_id>"""
    chat_id = request.path_params["chat_id"]
    message_id = request.query_params.get("message_id")
    if chat_id not in CHATS or message_id not in PENDING:
        return JSONResponse({"error": "Not found"}, status_code=404)

    # PENDING maps the message id back to the chat it was queued for; a
    # mismatch would mean the client mixed up ids from different chats.
    chat_id_from_map, provider = PENDING.pop(message_id)
    assert chat_id == chat_id_from_map
    msgs = get_messages(CHATS, chat_id)
    llm = get_llm(provider)

    async def event_generator():
        # Relay tokens to the client as SSE events while accumulating the
        # full reply in `buffer`.
        buffer = ""
        async for chunk in llm.astream(msgs):
            token = chunk.content
            buffer += token
            yield {"data": token}
        # Finished: store assistant reply
        CHATS[chat_id].append({"role": "assistant", "content": buffer})
        yield {"event": "done", "data": ""}

    return EventSourceResponse(event_generator())
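
# A minimal sketch of how these handlers could be wired into a Starlette app.
# The real application setup lives elsewhere in the project and may differ
# (middleware, CORS, static files, module names), so this is illustrative only:
#
#     from starlette.applications import Starlette
#     from starlette.routing import Route
#
#     from chatsbt.controllers import (
#         chat_stream, create_chat, get_models, history, post_message,
#     )
#
#     app = Starlette(routes=[
#         Route("/models", get_models, methods=["GET"]),
#         Route("/chats", create_chat, methods=["POST"]),
#         Route("/chats/{chat_id}", history, methods=["GET"]),
#         Route("/chats/{chat_id}/messages", post_message, methods=["POST"]),
#         Route("/chats/{chat_id}/stream", chat_stream, methods=["GET"]),
#     ])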