Chat history and multi-model support
parent a9ffb48b4b
commit 44f391ef1e
13 changed files with 1072 additions and 839 deletions
@@ -1,46 +1,83 @@
 import asyncio
 import uuid
+from typing import Dict, List, Tuple
 from starlette.responses import JSONResponse
+from starlette.requests import Request
 from sse_starlette.sse import EventSourceResponse
 from pydantic import BaseModel
-from langchain_core.messages import HumanMessage, AIMessage
-from chatgraph import app_graph
-from collections import defaultdict
-
-pending = defaultdict(asyncio.Queue)
+from chatgraph import get_messages, get_llm
 
 
 class ChatIn(BaseModel):
     message: str
 
 
-async def create_chat(request):
-    """POST /chats -> returns {id: <new_chat_id>}"""
-    chat_id = str(uuid.uuid4())[:8]
-    return JSONResponse({"id": chat_id})
-
-
-async def post_message(request):
-    """POST /chats/{chat_id}/messages"""
-    chat_id = request.path_params["chat_id"]
+CHATS: Dict[str, List[dict]] = {}         # chat_id -> messages
+PENDING: Dict[str, Tuple[str, str]] = {}  # message_id -> (chat_id, provider)
+
+MODELS = {
+    "qwen/qwen3-235b-a22b-2507",
+    "deepseek/deepseek-r1-0528",
+    "moonshotai/kimi-k2",
+}
+
+
+async def create_chat(request: Request):
+    """POST /chats -> {chat_id, model}"""
     body = await request.json()
-    msg = ChatIn(**body).message
-    await pending[chat_id].put(msg)
-    return JSONResponse({"status": "queued"})
+    provider = body.get("model", "")
+    if provider not in MODELS:
+        return JSONResponse({"error": "Unknown model"}, status_code=400)
+    chat_id = str(uuid.uuid4())[:8]
+    CHATS[chat_id] = []
+    return JSONResponse({"id": chat_id, "model": provider})
 
 
-async def stream_response(request):
-    """GET /chats/{chat_id}/stream (SSE)"""
+async def history(request: Request):
+    """GET /chats/{chat_id} -> previous messages"""
     chat_id = request.path_params["chat_id"]
-    user_msg = await pending[chat_id].get()
-    config = {"configurable": {"thread_id": chat_id}}
-    input_messages = [HumanMessage(content=user_msg)]
+    if chat_id not in CHATS:
+        return JSONResponse({"error": "Not found"}, status_code=404)
+    return JSONResponse({"messages": CHATS[chat_id]})
+
+
+async def post_message(request: Request):
+    """POST /chats/{chat_id}/messages
+
+    Body: {"message": "...", "model": "<model_name>"}
+    Returns: {"status": "queued", "message_id": "<message_id>"}
+    """
+    chat_id = request.path_params["chat_id"]
+    if chat_id not in CHATS:
+        return JSONResponse({"error": "Chat not found"}, status_code=404)
+
+    body = await request.json()
+    user_text = body.get("message", "")
+    provider = body.get("model", "")
+    if provider not in MODELS:
+        return JSONResponse({"error": "Unknown model"}, status_code=400)
+
+    message_id = str(uuid.uuid4())
+    PENDING[message_id] = (chat_id, provider)
+    CHATS[chat_id].append({"role": "human", "content": user_text})
+
+    return JSONResponse({
+        "status": "queued",
+        "message_id": message_id,
+    })
+
+
+async def chat_stream(request: Request):
+    """GET /chats/{chat_id}/stream?message_id=<message_id> (SSE)"""
+    chat_id = request.path_params["chat_id"]
+    message_id = request.query_params.get("message_id")
+
+    if chat_id not in CHATS or message_id not in PENDING:
+        return JSONResponse({"error": "Not found"}, status_code=404)
+
+    chat_id_from_map, provider = PENDING.pop(message_id)
+    assert chat_id == chat_id_from_map
+
+    msgs = get_messages(CHATS, chat_id)
+    llm = get_llm(provider)
 
     async def event_generator():
-        async for chunk, _ in app_graph.astream(
-            {"messages": input_messages},
-            config,
-            stream_mode="messages",
-        ):
-            if isinstance(chunk, AIMessage):
-                yield dict(data=chunk.content)
+        buffer = ""
+        async for chunk in llm.astream(msgs):
+            token = chunk.content
+            buffer += token
+            yield {"data": token}
+        # Finished: store the assistant reply in the chat history
+        CHATS[chat_id].append({"role": "assistant", "content": buffer})
+        yield {"event": "done", "data": ""}
 
     return EventSourceResponse(event_generator())
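The handlers above import get_messages and get_llm from chatgraph, but their definitions sit in one of the other changed files and are not part of this hunk. A minimal sketch that stays consistent with the call sites here, get_messages(CHATS, chat_id) producing a LangChain message list and get_llm(provider) returning a streamable chat model, could look like the following; the OpenRouter endpoint and the environment variable are assumptions, not something this commit shows.

# chatgraph.py (sketch) -- the real implementation lives in a file not shown here.
import os
from typing import Dict, List

from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_openai import ChatOpenAI


def get_messages(chats: Dict[str, List[dict]], chat_id: str) -> List[BaseMessage]:
    """Convert the stored {"role", "content"} dicts into LangChain messages."""
    out: List[BaseMessage] = []
    for m in chats[chat_id]:
        if m["role"] == "human":
            out.append(HumanMessage(content=m["content"]))
        else:
            out.append(AIMessage(content=m["content"]))
    return out


def get_llm(provider: str) -> ChatOpenAI:
    """Build a streaming chat model for the selected provider.

    The entries in MODELS look like OpenRouter model IDs, so an
    OpenAI-compatible endpoint is assumed here.
    """
    return ChatOpenAI(
        model=provider,
        base_url="https://openrouter.ai/api/v1",   # assumed endpoint
        api_key=os.environ["OPENROUTER_API_KEY"],  # assumed env var
        streaming=True,
    )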
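The route table that exposes these handlers is also outside this hunk. A plausible wiring, matching the paths given in the docstrings (the module name and imports are assumed):

from starlette.applications import Starlette
from starlette.routing import Route

from app import create_chat, history, post_message, chat_stream  # assumed module name

# Hypothetical route table; the actual one is in a file not shown here.
routes = [
    Route("/chats", create_chat, methods=["POST"]),
    Route("/chats/{chat_id}", history, methods=["GET"]),
    Route("/chats/{chat_id}/messages", post_message, methods=["POST"]),
    Route("/chats/{chat_id}/stream", chat_stream, methods=["GET"]),
]

app = Starlette(routes=routes)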
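Taken together, the new flow is: create a chat bound to one of the MODELS entries, queue a message, then open the SSE stream with the returned message_id. A rough client-side walk-through, assuming the routes sketched above and a server on localhost:8000 (both assumptions):

import httpx

BASE = "http://localhost:8000"  # assumed host/port

# 1. Create a chat bound to a model.
chat = httpx.post(f"{BASE}/chats", json={"model": "moonshotai/kimi-k2"}).json()
chat_id = chat["id"]

# 2. Queue a user message; the server returns a message_id for streaming.
queued = httpx.post(
    f"{BASE}/chats/{chat_id}/messages",
    json={"message": "Hello!", "model": "moonshotai/kimi-k2"},
).json()
message_id = queued["message_id"]

# 3. Read the SSE stream token by token until the "done" event.
with httpx.stream("GET", f"{BASE}/chats/{chat_id}/stream",
                  params={"message_id": message_id}, timeout=None) as resp:
    for line in resp.iter_lines():
        if line.startswith("event: done"):
            break
        if line.startswith("data: "):
            print(line[len("data: "):], end="", flush=True)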