chatsbt/controllers.py

46 lines
No EOL
1.4 KiB
Python

import asyncio
import uuid
from collections import defaultdict

from langchain_core.messages import AIMessage, HumanMessage
from pydantic import BaseModel, ValidationError
from sse_starlette.sse import EventSourceResponse
from starlette.responses import JSONResponse

from chatgraph import app_graph
# Per-chat inbox of user messages, keyed by chat id: post_message puts a
# message onto a chat's queue and stream_response takes it off again.
# A fresh asyncio.Queue is created lazily the first time an id is looked up.
# NOTE(review): nothing in this module removes entries, so the dict grows
# with every distinct chat id ever requested — consider eviction.
pending: "defaultdict[str, asyncio.Queue]" = defaultdict(asyncio.Queue)
class ChatIn(BaseModel):
    """JSON body accepted by ``POST /chats/{chat_id}/messages``."""

    # The user's chat text that gets queued for the stream.
    message: str
async def create_chat(request):
    """POST /chats -> returns {id: <new_chat_id>}

    The returned id is later used as the LangGraph ``thread_id`` for the
    conversation, so it must not collide with any other chat's id.
    """
    # Use the full 128-bit UUID (32 hex characters). Truncating to the first
    # 8 characters kept only 32 random bits, which makes it realistic for two
    # chats to receive the same id and therefore share one conversation thread.
    chat_id = uuid.uuid4().hex
    return JSONResponse({"id": chat_id})
async def post_message(request):
    """POST /chats/{chat_id}/messages

    Validate the JSON body against ``ChatIn`` and queue its ``message`` for
    the chat's SSE stream to pick up.

    Returns:
        200 ``{"status": "queued"}`` on success;
        400 if the body is not valid JSON or is not a JSON object;
        422 if the object lacks a string ``message`` field.
    """
    chat_id = request.path_params["chat_id"]
    try:
        body = await request.json()
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return JSONResponse(
            {"error": "request body must be valid JSON"}, status_code=400
        )
    # ChatIn(**body) needs a mapping; a JSON array/string/number would
    # otherwise raise an unhandled TypeError.
    if not isinstance(body, dict):
        return JSONResponse(
            {"error": "request body must be a JSON object"}, status_code=400
        )
    try:
        msg = ChatIn(**body).message
    except ValidationError:
        return JSONResponse(
            {"error": "request body must contain a string 'message' field"},
            status_code=422,
        )
    await pending[chat_id].put(msg)
    return JSONResponse({"status": "queued"})
async def stream_response(request):
    """GET /chats/{chat_id}/stream (SSE)

    Open a Server-Sent Events stream for ``chat_id``. The stream waits for
    the next message queued by ``post_message``, runs it through
    ``app_graph`` with ``chat_id`` as the conversation ``thread_id``, and
    emits the content of each AI message chunk as an SSE ``data`` event.
    """
    chat_id = request.path_params["chat_id"]
    config = {"configurable": {"thread_id": chat_id}}

    async def event_generator():
        # Wait for the user's message *inside* the generator rather than
        # before building the response. That way the SSE response starts
        # immediately, and if the client disconnects while waiting,
        # EventSourceResponse cancels this get() (asyncio.Queue.get is
        # cancellation-safe). Previously an abandoned request kept awaiting
        # and would silently consume the next posted message.
        user_msg = await pending[chat_id].get()
        input_messages = [HumanMessage(content=user_msg)]
        # stream_mode="messages" yields (message_chunk, metadata) pairs.
        async for chunk, _metadata in app_graph.astream(
            {"messages": input_messages},
            config,
            stream_mode="messages",
        ):
            # Forward only AI output; other message types are skipped.
            if isinstance(chunk, AIMessage):
                yield dict(data=chunk.content)

    return EventSourceResponse(event_generator())