import uuid
import json
from typing import Dict, Tuple

from starlette.responses import JSONResponse
from starlette.requests import Request
from sse_starlette.sse import EventSourceResponse

from chatgraph import get_messages, get_llm
from models.Chat import Chat


# In-flight generations: message_id -> (chat_id, provider). post_message
# queues an entry here; chat_stream pops it when the client connects.
PENDING: Dict[str, Tuple[str, str]] = {}

MODELS = {
    "qwen/qwen3-235b-a22b-2507",
    "deepseek/deepseek-r1-0528",
    "moonshotai/kimi-k2",
    "x-ai/grok-4",
    "openai/gpt-4.1",
    "anthropic/claude-sonnet-4",
    "meta-llama/llama-4-maverick",
    "mistralai/devstral-medium",
    "qwen/qwen3-coder",
    "google/gemini-2.5-pro",
}


async def get_models(request: Request):
    """GET /models -> {"models": [...]}"""
    return JSONResponse({"models": list(MODELS)})


async def create_chat(request: Request):
    """POST /chats -> {"id": chat_id, "model": model}"""
    body = await request.json()
    provider = body.get("model", "")
    if provider not in MODELS:
        return JSONResponse({"error": "Unknown model"}, status_code=400)

    chat = Chat()
    chat_id = str(uuid.uuid4())
    chat.id = chat_id
    chat.title = "New Chat"
    chat.messages = json.dumps([])
    chat.save()
    return JSONResponse({"id": chat_id, "model": provider})


async def history(request: Request):
    """GET /chats/{chat_id} -> previous messages"""
    chat_id = request.path_params["chat_id"]
    chat = Chat.find(chat_id)
    if not chat:
        return JSONResponse({"error": "Not found"}, status_code=404)

    messages = json.loads(chat.messages) if chat.messages else []
    return JSONResponse({"messages": messages})


async def post_message(request: Request):
    """POST /chats/{chat_id}/messages

    Body: {"message": "...", "model": "model_name"}
    Returns: {"status": "queued", "message_id": "<uuid>"}
    """
    chat_id = request.path_params["chat_id"]
    chat = Chat.find(chat_id)
    if not chat:
        return JSONResponse({"error": "Chat not found"}, status_code=404)

    body = await request.json()
    user_text = body.get("message", "")
    provider = body.get("model", "")
    if provider not in MODELS:
        return JSONResponse({"error": "Unknown model"}, status_code=400)

    # Load existing messages and append the new user message
    messages = json.loads(chat.messages) if chat.messages else []
    messages.append({"role": "human", "content": user_text})
    chat.messages = json.dumps(messages)
    chat.save()

    # Queue the generation; the client collects the reply via chat_stream
    message_id = str(uuid.uuid4())
    PENDING[message_id] = (chat_id, provider)

    return JSONResponse({
        "status": "queued",
        "message_id": message_id
    })
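

# --- Illustrative client handshake (a sketch, not part of this module's API) ---
# Shows how a caller is expected to pair the two endpoints: POST the message,
# then open the SSE stream with the returned message_id. The base_url, the
# helper name, and the httpx dependency are assumptions made for this example.
async def _example_send_and_stream(base_url: str, chat_id: str,
                                   text: str, model: str) -> str:
    import httpx  # local import: only needed if this example is actually run

    async with httpx.AsyncClient(base_url=base_url) as client:
        # Step 1: queue the message; the server replies with a message_id.
        r = await client.post(f"/chats/{chat_id}/messages",
                              json={"message": text, "model": model})
        message_id = r.json()["message_id"]

        # Step 2: read raw SSE lines, accumulating "data:" tokens until "done".
        reply = ""
        async with client.stream("GET", f"/chats/{chat_id}/stream",
                                 params={"message_id": message_id}) as resp:
            async for line in resp.aiter_lines():
                if line == "event: done":
                    break
                if line.startswith("data: "):
                    reply += line[len("data: "):]
        return reply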


async def chat_stream(request: Request):
    """GET /chats/{chat_id}/stream?message_id=<message_id>"""
    chat_id = request.path_params["chat_id"]
    message_id = request.query_params.get("message_id")

    if message_id not in PENDING:
        return JSONResponse({"error": "Not found"}, status_code=404)

    chat_id_from_map, provider = PENDING.pop(message_id)
    if chat_id != chat_id_from_map:
        return JSONResponse({"error": "message_id does not belong to this chat"},
                            status_code=400)

    chat = Chat.find(chat_id)
    if not chat:
        return JSONResponse({"error": "Chat not found"}, status_code=404)

    messages = json.loads(chat.messages) if chat.messages else []
    msgs = get_messages(messages, chat_id)
    llm = get_llm(provider)

    async def event_generator():
        buffer = ""
        async for chunk in llm.astream(msgs):
            token = chunk.content
            buffer += token
            yield {"data": token}
        # Finished: persist the full assistant reply
        messages.append({"role": "assistant", "content": buffer})
        chat.messages = json.dumps(messages)
        chat.save()
        yield {"event": "done", "data": ""}

    return EventSourceResponse(event_generator())
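

# --- Wiring sketch (an assumption; this module does not create the app) ---
# The handlers above map onto Starlette routes roughly like this; the actual
# application setup presumably lives elsewhere in the project.
#
# from starlette.applications import Starlette
# from starlette.routing import Route
#
# app = Starlette(routes=[
#     Route("/models", get_models, methods=["GET"]),
#     Route("/chats", create_chat, methods=["POST"]),
#     Route("/chats/{chat_id}", history, methods=["GET"]),
#     Route("/chats/{chat_id}/messages", post_message, methods=["POST"]),
#     Route("/chats/{chat_id}/stream", chat_stream, methods=["GET"]),
# ])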