"""Minimal chat API: create a chat, queue a user message, stream the reply via SSE.

Endpoints (Starlette handlers):
    POST /chats                     -> create_chat
    POST /chats/{chat_id}/messages  -> post_message
    GET  /chats/{chat_id}/stream    -> stream_response (Server-Sent Events)
"""

import asyncio
import json
import uuid
from collections import defaultdict

from langchain_core.messages import AIMessage, HumanMessage
from pydantic import BaseModel, ValidationError
from sse_starlette.sse import EventSourceResponse
from starlette.responses import JSONResponse

from chatgraph import app_graph

# Per-chat queue of user messages waiting to be picked up by the SSE stream.
# NOTE(review): entries are never removed, so this map grows unboundedly with
# the number of chat ids ever seen — consider a cleanup/TTL pass.
pending = defaultdict(asyncio.Queue)


class ChatIn(BaseModel):
    """Request body schema for POST /chats/{chat_id}/messages."""

    message: str


async def create_chat(request):
    """POST /chats -> JSON {"id": "<8-char chat id>"}.

    NOTE(review): truncating a UUID4 to 8 characters raises collision
    probability considerably; acceptable for a demo, verify before scaling.
    """
    chat_id = str(uuid.uuid4())[:8]
    return JSONResponse({"id": chat_id})


async def post_message(request):
    """POST /chats/{chat_id}/messages — queue a user message for the chat.

    Returns:
        200 {"status": "queued"} on success.
        400 with error detail when the body is not valid JSON or does not
        match the ChatIn schema (previously surfaced as an unhandled 500).
    """
    chat_id = request.path_params["chat_id"]
    try:
        body = await request.json()
    except json.JSONDecodeError:
        return JSONResponse({"error": "invalid JSON body"}, status_code=400)
    try:
        msg = ChatIn(**body).message
    except ValidationError as exc:
        return JSONResponse({"error": exc.errors()}, status_code=400)
    await pending[chat_id].put(msg)
    return JSONResponse({"status": "queued"})


async def stream_response(request):
    """GET /chats/{chat_id}/stream (SSE).

    Blocks until a message is queued for this chat, then streams the
    model's reply incrementally as SSE ``data:`` events.
    """
    chat_id = request.path_params["chat_id"]
    # Waits here until post_message() puts something on this chat's queue.
    user_msg = await pending[chat_id].get()

    # thread_id ties the LangGraph checkpoint to this chat, so conversation
    # state persists across requests for the same chat id.
    config = {"configurable": {"thread_id": chat_id}}
    input_messages = [HumanMessage(content=user_msg)]

    async def event_generator():
        # stream_mode="messages" yields (message_chunk, metadata) pairs.
        async for chunk, _ in app_graph.astream(
            {"messages": input_messages},
            config,
            stream_mode="messages",
        ):
            # Forward only AI output chunks; skip human/tool message chunks.
            if isinstance(chunk, AIMessage):
                yield dict(data=chunk.content)

    return EventSourceResponse(event_generator())