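"""AI Gateway: a FastAPI service that fronts vLLM chat and embedding backends,
persists conversational memory in Qdrant, and exposes a small tool registry
(image generation, document search, code execution, TTS, STT).

Service hostnames, ports, and model names below are deployment-specific and
assumed to match the container network (vllm-chat, vllm-embed, qdrant, ...).
"""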
import json
import uuid
from typing import Dict, List, Optional

import numpy as np
import requests
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)

# --- Model Client ---
class ModelClient:
    """Thin HTTP client for the vLLM chat and embedding services."""

    def chat(self, messages: List[Dict], character_id: Optional[str] = None) -> str:
        payload = {"model": "deepseek-main", "messages": messages, "character_id": character_id}
        try:
            resp = requests.post("http://vllm-chat:5001/v1/chat/completions", json=payload)
            resp.raise_for_status()
            return resp.json()["choices"][0]["message"]["content"]
        except requests.RequestException as e:
            return f"[Error calling chat model: {e}]"

    def embed(self, text: str) -> List[float]:
        payload = {"model": "bge-m3", "input": text}
        try:
            resp = requests.post("http://vllm-embed:5002/v1/embeddings", json=payload)
            resp.raise_for_status()
            # OpenAI-compatible embedding responses nest the vector under data[0]["embedding"].
            return resp.json()["data"][0]["embedding"]
        except requests.RequestException as e:
            # Fallback stub: keeps the pipeline alive when the embedding service is
            # unreachable, though retrieval quality is meaningless with random vectors.
            return np.random.rand(768).tolist()

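# Usage sketch (assumes the vLLM services above are reachable on the service network):
#   client = ModelClient()
#   reply = client.chat([{"role": "user", "content": "Hello"}])
#   vector = client.embed("Hello")   # list of floats, one per embedding dimension
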
# --- Memory Store with Qdrant ---
class MemoryStore:
    """Long-term memory backed by a Qdrant collection, partitioned by namespace."""

    # Must match the dimensionality of the vectors produced by ModelClient.embed().
    EMBEDDING_DIM = 768

    def __init__(self, host="qdrant", port=6333, collection_name="memory", top_k=5):
        self.client = QdrantClient(host=host, port=port)
        self.collection_name = collection_name
        self.top_k = top_k

        existing = [c.name for c in self.client.get_collections().collections]
        if collection_name not in existing:
            self.client.recreate_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(size=self.EMBEDDING_DIM, distance=Distance.COSINE),
            )

    def store(self, namespace: str, content: str, embedding: Optional[List[float]] = None):
        if embedding is None:
            embedding = np.random.rand(self.EMBEDDING_DIM).tolist()
        self.client.upsert(
            collection_name=self.collection_name,
            points=[PointStruct(
                # Qdrant point IDs must be unsigned integers or UUID strings.
                id=str(uuid.uuid4()),
                vector=embedding,
                payload={"content": content, "namespace": namespace},
            )],
        )

    def retrieve(self, namespace: str, query: str, top_k: Optional[int] = None) -> List[str]:
        if top_k is None:
            top_k = self.top_k
        # Uses the module-level model_client defined below.
        query_embedding = model_client.embed(query)
        try:
            hits = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                query_filter=Filter(must=[
                    FieldCondition(key="namespace", match=MatchValue(value=namespace)),
                ]),
                limit=top_k,
            )
            return [hit.payload["content"] for hit in hits]
        except Exception:
            return []

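# Usage sketch (assumes a Qdrant instance at qdrant:6333):
#   store = MemoryStore()
#   store.store("alice", "Prefers short answers", embedding=model_client.embed("Prefers short answers"))
#   store.retrieve("alice", "How should I format replies?")   # -> ["Prefers short answers", ...]
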
# --- Tool Registry ---
class ToolRegistry:
    def __init__(self):
        self.tools = {
            "generate_image": self.generate_image,
            "search_documents": self.search_documents,
            "run_code": self.run_code,
            "tts": self.tts_speak,
            "stt": self.stt_transcribe,
        }

    def execute(self, tool_name: str, arguments: dict) -> str:
        if tool_name not in self.tools:
            return f"Error: tool '{tool_name}' not registered"
        return self.tools[tool_name](arguments)

    def generate_image(self, args: dict) -> str:
        prompt = args.get("prompt", "")
        style = args.get("style", "default")
        try:
            resp = requests.post("http://image_service:5100/generate", json={"prompt": prompt, "style": style})
            resp.raise_for_status()
            return resp.json()["image_url"]
        except requests.RequestException as e:
            return f"[Error generating image: {e}]"

    def search_documents(self, args: dict) -> str:
        query = args.get("query", "")
        return f"[Search results for: {query}]"

    def run_code(self, args: dict) -> str:
        code = args.get("code", "")
        return f"[Executed code: {code}]"

    def tts_speak(self, args: dict) -> str:
        text = args.get("text", "")
        try:
            resp = requests.post("http://tts_service:5200/speak", json={"text": text})
            resp.raise_for_status()
            return resp.json()["audio_url"]
        except requests.RequestException as e:
            return f"[Error in TTS: {e}]"

    def stt_transcribe(self, args: dict) -> str:
        audio_url = args.get("audio_url", "")
        try:
            resp = requests.post("http://stt_service:5300/transcribe", json={"audio_url": audio_url})
            resp.raise_for_status()
            return resp.json()["text"]
        except requests.RequestException as e:
            return f"[Error in STT: {e}]"

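# The model requests a tool by replying with a bare JSON object (see parse_tool_call below), e.g.:
#   {"tool": "generate_image", "arguments": {"prompt": "a red fox", "style": "watercolor"}}
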
# --- Prompt Assembler ---
class PromptAssembler:
    def __init__(self, model_max_tokens: int = 8192):
        self.model_max_tokens = model_max_tokens

    def assemble_prompt(self,
                        character_config: Dict,
                        memory_chunks: List[str],
                        rag_context: List[str],
                        conversation_history: List[Dict],
                        tool_schemas: List[Dict]) -> str:
        # Sections are concatenated in order: system prompt, memory, RAG context,
        # tool descriptions, then the running conversation.
        system = character_config.get("system_prompt", "")
        memory_text = "\n".join(memory_chunks)
        context_text = "\n".join(rag_context)
        history_text = "\n".join(f'{m["role"].capitalize()}: {m["content"]}' for m in conversation_history)
        tools_text = "TOOLS:\n" + "\n".join(t.get("description", "") for t in tool_schemas)
        prompt = "\n\n".join([system, memory_text, context_text, tools_text, history_text])
        return self._trim_to_token_limit(prompt)

    def _trim_to_token_limit(self, prompt: str) -> str:
        # Rough truncation: assume ~4 characters per token and keep the tail of the
        # prompt (the most recent conversation) if the limit is exceeded.
        max_chars = self.model_max_tokens * 4
        if len(prompt) > max_chars:
            return prompt[-max_chars:]
        return prompt

# --- Utility functions ---
def parse_tool_call(model_response: str) -> Optional[dict]:
    """Return the parsed tool call if the model replied with a JSON tool-call object."""
    try:
        data = json.loads(model_response)
    except json.JSONDecodeError:
        return None
    if isinstance(data, dict) and "tool" in data and "arguments" in data:
        return data
    return None


def handle_model_response(model_client: ModelClient,
                          tool_registry: ToolRegistry,
                          model_response: str) -> str:
    tool_call = parse_tool_call(model_response)
    if tool_call:
        result = tool_registry.execute(tool_call["tool"], tool_call["arguments"])
        # Feed the tool result back to the model so it can finish its answer.
        follow_up = [
            {"role": "assistant", "content": model_response},
            {"role": "user", "content": f"Tool result:\n{result}\nPlease continue the response."},
        ]
        return model_client.chat(follow_up)
    return model_response

# --- FastAPI App & Models ---
app = FastAPI(title="AI Gateway")

model_client = ModelClient()
memory_store = MemoryStore()
tool_registry = ToolRegistry()
prompt_assembler = PromptAssembler()


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    character_id: Optional[str] = None
    conversation_id: Optional[str] = None
    tools: Optional[bool] = False
    temperature: Optional[float] = 0.7


class EmbeddingRequest(BaseModel):
    model: str
    input: str


class ToolExecutionRequest(BaseModel):
    tool_name: str
    arguments: Dict

# --- Endpoints ---
@app.post("/v1/chat/completions")
def chat_endpoint(req: ChatRequest):
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages must not be empty")

    namespace = req.character_id or "user"
    character_config = {"system_prompt": f"You are {namespace}."}

    # Retrieve memory for RAG
    memory_chunks = memory_store.retrieve(namespace, req.messages[-1].content, top_k=5)
    rag_context = memory_chunks

    # Assemble prompt
    prompt = prompt_assembler.assemble_prompt(
        character_config=character_config,
        memory_chunks=memory_chunks,
        rag_context=rag_context,
        conversation_history=[m.dict() for m in req.messages],
        tool_schemas=[{"description": "generate_image"}, {"description": "search_documents"}],
    )

    # Inject the assembled prompt as a system message so memory affects the response
    messages_with_prompt = [{"role": "system", "content": prompt}] + [m.dict() for m in req.messages]

    # Call model
    model_response = model_client.chat(messages=messages_with_prompt, character_id=req.character_id)

    # Handle tool calls
    final_response = handle_model_response(model_client, tool_registry, model_response)

    # Persist memory
    embedding_vector = model_client.embed(final_response)
    memory_store.store(namespace, final_response, embedding=embedding_vector)

    return {"id": "chatcmpl-001", "choices": [{"message": {"role": "assistant", "content": final_response}}]}


@app.post("/v1/embeddings")
def embeddings_endpoint(req: EmbeddingRequest):
    embedding_vector = model_client.embed(req.input)
    return {"object": "embedding", "data": embedding_vector}


@app.get("/v1/models")
def models_endpoint():
    return {"data": [
        {"id": "deepseek-main", "type": "chat"},
        {"id": "bge-m3", "type": "embedding"},
    ]}


@app.post("/v1/tools/execute")
def tools_execute_endpoint(req: ToolExecutionRequest):
    result = tool_registry.execute(req.tool_name, req.arguments)
    if "Error" in result:
        raise HTTPException(status_code=400, detail=result)
    return {"result": result}

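# Example request against the running gateway (host/port assume the uvicorn settings in main()):
#   curl -X POST http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "deepseek-main", "messages": [{"role": "user", "content": "Hi"}], "character_id": "alice"}'
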
# --- main() ---
def main():
    print("AI Gateway is ready!")
    # Serve the FastAPI app; host and port are assumed defaults for this deployment.
    uvicorn.run(app, host="0.0.0.0", port=8000)


if __name__ == "__main__":
    main()