46 lines
1.4 KiB
Python
import asyncio
import uuid
from collections import defaultdict

from langchain_core.messages import AIMessage, HumanMessage
from pydantic import BaseModel, ValidationError
from sse_starlette.sse import EventSourceResponse
from starlette.responses import JSONResponse

from chatgraph import app_graph
|
|
|
||
|
|
# In-memory buffer of user messages awaiting a streamed reply:
# chat_id -> FIFO queue.  post_message puts, stream_response gets.
# NOTE(review): entries are never removed, so this grows with every
# chat_id ever seen — fine for a demo, a leak in production.
pending: "defaultdict[str, asyncio.Queue]" = defaultdict(asyncio.Queue)
|
||
|
|
|
||
|
|
class ChatIn(BaseModel):
    """Request body for POST /chats/{chat_id}/messages."""

    # The user's chat message text.
    message: str
|
||
|
|
|
||
|
|
async def create_chat(request):
    """POST /chats -> returns {id: <new_chat_id>}"""
    # The first 8 chars of a canonical UUID4 string are hex digits —
    # short enough for a readable demo id.
    fresh = uuid.uuid4()
    short_id = str(fresh)[:8]
    return JSONResponse({"id": short_id})
|
||
|
|
|
||
|
|
async def post_message(request):
    """POST /chats/{chat_id}/messages

    Validate the JSON body against ChatIn and enqueue the message for
    the chat identified by the ``chat_id`` path parameter.

    Returns:
        422 with {"error": ...} when the body fails validation,
        otherwise 200 with {"status": "queued"}.
    """
    chat_id = request.path_params["chat_id"]
    body = await request.json()
    try:
        # TypeError covers a non-mapping body (e.g. a JSON list), which
        # ``**body`` raises before pydantic ever sees it.
        msg = ChatIn(**body).message
    except (ValidationError, TypeError) as exc:
        # A malformed payload is a client error; without this it
        # surfaced as an unhandled 500.
        return JSONResponse({"error": str(exc)}, status_code=422)
    # NOTE(review): chat_id is not checked against ids issued by
    # create_chat — any path value silently creates a new queue.
    await pending[chat_id].put(msg)
    return JSONResponse({"status": "queued"})
|
||
|
|
|
||
|
|
async def stream_response(request):
    """GET /chats/{chat_id}/stream (SSE)"""
    chat_id = request.path_params["chat_id"]

    # Block until post_message has queued something for this chat.
    # NOTE(review): the message is dequeued before the SSE stream is
    # consumed — if the client disconnects here, it is lost.
    user_msg = await pending[chat_id].get()

    run_config = {"configurable": {"thread_id": chat_id}}
    graph_input = {"messages": [HumanMessage(content=user_msg)]}

    async def sse_events():
        # stream_mode="messages" yields (message_chunk, metadata) pairs;
        # only assistant chunks are forwarded to the client.
        stream = app_graph.astream(graph_input, run_config, stream_mode="messages")
        async for message_chunk, _meta in stream:
            if isinstance(message_chunk, AIMessage):
                yield {"data": message_chunk.content}

    return EventSourceResponse(sse_events())
|