Refactor: serve the API with Starlette; drive LangGraph through OpenRouter
This commit is contained in:
parent
1dec5d49ec
commit
4527ed19cc
7 changed files with 450 additions and 170 deletions
46
controllers.py
Normal file
46
controllers.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import asyncio
|
||||
import uuid
|
||||
from starlette.responses import JSONResponse
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from pydantic import BaseModel
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
from chatgraph import app_graph
|
||||
from collections import defaultdict
|
||||
|
||||
# Per-chat inbox: chat_id -> asyncio.Queue of user messages waiting for a
# streaming response; defaultdict creates each queue lazily on first access.
# NOTE(review): module-level mutable state — assumes a single-process server;
# confirm the deployment model before scaling out.
pending = defaultdict(asyncio.Queue)
|
||||
|
||||
class ChatIn(BaseModel):
    """Request body schema for POST /chats/{chat_id}/messages."""

    # The user's chat message text (required, no default).
    message: str
|
||||
|
||||
async def create_chat(request):
    """POST /chats -> returns {id: <new_chat_id>}"""
    # An 8-hex-digit id: the first 8 characters of a UUID4's canonical string
    # form are exactly its first 8 hex digits, so .hex[:8] is equivalent.
    new_id = uuid.uuid4().hex[:8]
    return JSONResponse({"id": new_id})
|
||||
|
||||
async def post_message(request):
    """POST /chats/{chat_id}/messages

    Validate the JSON body against ChatIn and enqueue the message for the
    chat's SSE consumer. Returns {"status": "queued"} on success.

    Fix: a malformed body previously escaped as an unhandled exception
    (HTTP 500); it now yields an explicit 400 (bad JSON) or 422 (schema
    mismatch) JSON error instead.
    """
    import json  # stdlib; only needed for the decode-error type
    from pydantic import ValidationError  # pydantic is already a file dependency

    chat_id = request.path_params["chat_id"]
    try:
        body = await request.json()
    except json.JSONDecodeError:
        return JSONResponse(
            {"error": "request body is not valid JSON"}, status_code=400
        )
    try:
        msg = ChatIn(**body).message
    except (ValidationError, TypeError):
        # TypeError covers a non-dict JSON body (e.g. a list), which the
        # ChatIn(**body) unpacking would otherwise raise on.
        return JSONResponse(
            {"error": "body must be an object with a 'message' string"},
            status_code=422,
        )
    await pending[chat_id].put(msg)
    return JSONResponse({"status": "queued"})
|
||||
|
||||
async def stream_response(request):
    """GET /chats/{chat_id}/stream (SSE)

    Pops the next queued user message for this chat, runs it through the
    LangGraph app, and streams the model's output chunks as SSE events.
    """
    chat_id = request.path_params["chat_id"]

    # Block until post_message enqueues a user message for this chat.
    # NOTE(review): exactly one message is consumed per stream connection —
    # the client presumably reconnects for each turn; confirm with the caller.
    user_msg = await pending[chat_id].get()

    # thread_id scopes LangGraph's persisted conversation state to this chat.
    config = {"configurable": {"thread_id": chat_id}}
    input_messages = [HumanMessage(content=user_msg)]

    async def event_generator():
        # stream_mode="messages" yields (message_chunk, metadata) pairs;
        # the metadata half is deliberately discarded here.
        async for chunk, _ in app_graph.astream(
            {"messages": input_messages},
            config,
            stream_mode="messages",
        ):
            # Forward only model output (AIMessage and its chunk subclasses);
            # any other message types in the stream are dropped silently.
            if isinstance(chunk, AIMessage):
                yield dict(data=chunk.content)

    return EventSourceResponse(event_generator())
|
||||
Loading…
Add table
Add a link
Reference in a new issue