main components from chatgpt
18
dockerfile
Normal file
@@ -0,0 +1,18 @@
# Use official Python image
FROM python:3.11-slim

# Set working directory
WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the FastAPI app
COPY main.py .

# Expose port
EXPOSE 8000

# Start FastAPI with uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
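# Build & run sketch (the image tag "ai-gateway" is an assumed example, not defined in this commit):
#   docker build -t ai-gateway .
#   docker run -p 8000:8000 ai-gateway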
@@ -1,5 +0,0 @@
def main():
    print("Hello, world!")

if __name__ == "__main__":
    main()
253
main.py
Normal file
@@ -0,0 +1,253 @@
import json
import uuid

import requests
import numpy as np
from typing import List, Dict, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, VectorParams, Distance, Filter, FieldCondition, MatchValue


# --- Model Client ---
class ModelClient:
    def chat(self, messages: List[Dict], character_id: Optional[str] = None) -> str:
        payload = {"model": "deepseek-main", "messages": messages, "character_id": character_id}
        try:
            resp = requests.post("http://vllm-chat:5001/v1/chat/completions", json=payload)
            resp.raise_for_status()
            return resp.json()["choices"][0]["message"]["content"]
        except requests.RequestException as e:
            return f"[Error calling chat model: {e}]"

    def embed(self, text: str) -> List[float]:
        payload = {"model": "bge-m3", "input": text}
        try:
            resp = requests.post("http://vllm-embed:5002/v1/embeddings", json=payload)
            resp.raise_for_status()
            # OpenAI-style embedding responses wrap the vector in data[0]["embedding"]
            return resp.json()["data"][0]["embedding"]
        except requests.RequestException as e:
            return np.random.rand(768).tolist()  # fallback stub
# --- Memory Store with Qdrant ---
class MemoryStore:
    def __init__(self, host="qdrant", port=6333, collection_name="memory", top_k=5):
        self.client = QdrantClient(host=host, port=port)
        self.collection_name = collection_name
        self.top_k = top_k

        if collection_name not in [c.name for c in self.client.get_collections().collections]:
            self.client.recreate_collection(
                collection_name=collection_name,
                # Vector size must match the embedding model's output dimension
                vectors_config=VectorParams(size=768, distance=Distance.COSINE)
            )

    def store(self, namespace: str, content: str, embedding: Optional[List[float]] = None):
        if embedding is None:
            embedding = np.random.rand(768).tolist()
        self.client.upsert(
            collection_name=self.collection_name,
            points=[PointStruct(
                id=str(uuid.uuid4()),  # Qdrant point IDs must be unsigned integers or UUIDs
                vector=embedding,
                payload={"content": content, "namespace": namespace}
            )]
        )

    def retrieve(self, namespace: str, query: str, top_k: int = None) -> List[str]:
        if top_k is None:
            top_k = self.top_k
        query_embedding = np.array(model_client.embed(query))
        try:
            hits = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                query_filter=Filter(must=[FieldCondition(key="namespace", match=MatchValue(value=namespace))]),
                limit=top_k
            )
            return [hit.payload["content"] for hit in hits]
        except Exception:
            return []
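# Usage sketch (illustrative values): persist a fact under a character's namespace,
# then pull it back by semantic similarity.
#   memory_store.store("alice", "Alice prefers green tea.")
#   memory_store.retrieve("alice", "What does Alice like to drink?")  # -> ["Alice prefers green tea."]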
# --- Tool Registry ---
class ToolRegistry:
    def __init__(self):
        self.tools = {
            "generate_image": self.generate_image,
            "search_documents": self.search_documents,
            "run_code": self.run_code,
            "tts": self.tts_speak,
            "stt": self.stt_transcribe,
        }

    def execute(self, tool_name: str, arguments: dict) -> str:
        if tool_name not in self.tools:
            return f"Error: tool '{tool_name}' not registered"
        return self.tools[tool_name](arguments)

    def generate_image(self, args: dict) -> str:
        prompt = args.get("prompt", "")
        style = args.get("style", "default")
        try:
            resp = requests.post("http://image_service:5100/generate", json={"prompt": prompt, "style": style})
            resp.raise_for_status()
            return resp.json()["image_url"]
        except requests.RequestException as e:
            return f"[Error generating image: {e}]"

    def search_documents(self, args: dict) -> str:
        query = args.get("query", "")
        return f"[Search results for: {query}]"

    def run_code(self, args: dict) -> str:
        code = args.get("code", "")
        return f"[Executed code: {code}]"

    def tts_speak(self, args: dict) -> str:
        text = args.get("text", "")
        try:
            resp = requests.post("http://tts_service:5200/speak", json={"text": text})
            resp.raise_for_status()
            return resp.json()["audio_url"]
        except requests.RequestException as e:
            return f"[Error in TTS: {e}]"

    def stt_transcribe(self, args: dict) -> str:
        audio_url = args.get("audio_url", "")
        try:
            resp = requests.post("http://stt_service:5300/transcribe", json={"audio_url": audio_url})
            resp.raise_for_status()
            return resp.json()["text"]
        except requests.RequestException as e:
            return f"[Error in STT: {e}]"
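# Usage sketch (illustrative arguments): dispatch a registered tool by name.
#   tool_registry.execute("generate_image", {"prompt": "a lighthouse at dusk"})
#   tool_registry.execute("tts", {"text": "Hello there"})
# search_documents and run_code are placeholders that only echo their input.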
# --- Prompt Assembler ---
class PromptAssembler:
    def __init__(self, model_max_tokens: int = 8192):
        self.model_max_tokens = model_max_tokens

    def assemble_prompt(self,
                        character_config: Dict,
                        memory_chunks: List[str],
                        rag_context: List[str],
                        conversation_history: List[Dict],
                        tool_schemas: List[Dict]) -> str:
        system = character_config.get("system_prompt", "")
        memory_text = "\n".join(memory_chunks)
        context_text = "\n".join(rag_context)
        history_text = "\n".join(f'{m["role"].capitalize()}: {m["content"]}' for m in conversation_history)
        tools_text = "TOOLS:\n" + "\n".join([t.get("description", "") for t in tool_schemas])
        prompt = "\n\n".join([system, memory_text, context_text, tools_text, history_text])
        return self._trim_to_token_limit(prompt)

    def _trim_to_token_limit(self, prompt: str) -> str:
        # Optional: truncate if exceeding model_max_tokens (currently a pass-through)
        return prompt
# --- Utility functions ---
def parse_tool_call(model_response: str) -> dict | None:
    try:
        data = json.loads(model_response)
        if isinstance(data, dict) and "tool" in data and "arguments" in data:
            return data
    except json.JSONDecodeError:
        return None
    return None


def handle_model_response(model_client: ModelClient,
                          tool_registry: ToolRegistry,
                          model_response: str) -> str:
    tool_call = parse_tool_call(model_response)
    if tool_call:
        tool_name = tool_call["tool"]
        arguments = tool_call["arguments"]
        result = tool_registry.execute(tool_name, arguments)
        return f"Tool result:\n{result}\nPlease continue the response."
    return model_response
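# Example of the tool-call shape parse_tool_call accepts (the "tool"/"arguments" keys come from the
# code above; the prompt wording that elicits this JSON is assumed, not shown in this commit):
#   '{"tool": "generate_image", "arguments": {"prompt": "a red fox", "style": "default"}}'
# Any plain-text response falls through and is returned unchanged by handle_model_response.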
# --- FastAPI App & Models ---
app = FastAPI(title="AI Gateway")

model_client = ModelClient()
memory_store = MemoryStore()
tool_registry = ToolRegistry()
prompt_assembler = PromptAssembler()


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    character_id: Optional[str] = None
    conversation_id: Optional[str] = None
    tools: Optional[bool] = False
    temperature: Optional[float] = 0.7


class EmbeddingRequest(BaseModel):
    model: str
    input: str


class ToolExecutionRequest(BaseModel):
    tool_name: str
    arguments: Dict
# --- Endpoints ---
@app.post("/v1/chat/completions")
def chat_endpoint(req: ChatRequest):
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages must not be empty")

    namespace = req.character_id or "user"

    character_config = {"system_prompt": f"You are {namespace}."}

    # Retrieve memory for RAG
    memory_chunks = memory_store.retrieve(namespace, req.messages[-1].content, top_k=5)
    rag_context = memory_chunks

    # Assemble prompt
    prompt = prompt_assembler.assemble_prompt(
        character_config=character_config,
        memory_chunks=memory_chunks,
        rag_context=rag_context,
        conversation_history=[m.dict() for m in req.messages],
        tool_schemas=[{"description": "generate_image"}, {"description": "search_documents"}]
    )

    # Inject prompt into chat messages so memory affects response
    messages_with_prompt = [{"role": "system", "content": prompt}] + [m.dict() for m in req.messages]

    # Call model
    model_response = model_client.chat(messages=messages_with_prompt, character_id=req.character_id)

    # Handle tool calls
    final_response = handle_model_response(model_client, tool_registry, model_response)

    # Persist memory
    embedding_vector = model_client.embed(final_response)
    memory_store.store(namespace, final_response, embedding=embedding_vector)

    return {"id": "chatcmpl-001", "choices": [{"message": {"role": "assistant", "content": final_response}}]}


@app.post("/v1/embeddings")
def embeddings_endpoint(req: EmbeddingRequest):
    embedding_vector = model_client.embed(req.input)
    return {"object": "embedding", "data": embedding_vector}


@app.get("/v1/models")
def models_endpoint():
    return {"data": [
        {"id": "deepseek-main", "type": "chat"},
        {"id": "bge-m3", "type": "embedding"}
    ]}


@app.post("/v1/tools/execute")
def tools_execute_endpoint(req: ToolExecutionRequest):
    result = tool_registry.execute(req.tool_name, req.arguments)
    if "Error" in result:
        raise HTTPException(status_code=400, detail=result)
    return {"result": result}


# --- main() ---
def main():
    print("AI Gateway is ready!")


if __name__ == "__main__":
    main()
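# Example request against the gateway (assumes it is reachable on localhost:8000; values are illustrative):
#   requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       json={"model": "deepseek-main",
#             "messages": [{"role": "user", "content": "Hello"}],
#             "character_id": "alice"},
#   ).json()["choices"][0]["message"]["content"]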
requirements.txt
@@ -1 +1,7 @@
# Add your project dependencies here
fastapi==0.102.0
uvicorn[standard]==0.23.0
requests==2.31.0
numpy==1.26.4
pydantic==1.10.12
qdrant-client==1.9.1
15
tools/image_service.py
Normal file
@@ -0,0 +1,15 @@
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI(title="Image Generation Stub")

class ImageRequest(BaseModel):
    prompt: str
    style: str = "default"

@app.post("/generate")
def generate_image(req: ImageRequest):
    # Just a stub, returns a placeholder
    return {"image_url": f"https://dummyimage.com/512x512/000/fff&text={req.prompt.replace(' ', '+')}"}

# Run: uvicorn image_service:app --host 0.0.0.0 --port 5100
14
tools/tts_service.py
Normal file
@@ -0,0 +1,14 @@
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI(title="TTS Stub")

class TTSRequest(BaseModel):
    text: str

@app.post("/speak")
def speak(req: TTSRequest):
    # Return a fake URL to audio file
    return {"audio_url": f"https://dummy-audio.com/{req.text[:10].replace(' ', '_')}.mp3"}

# Run: uvicorn tts_service:app --host 0.0.0.0 --port 5200
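main.py also posts to http://stt_service:5300/transcribe, but no STT stub is part of this commit. A minimal sketch mirroring the TTS stub above (the file name tools/stt_service.py and the fake transcript are assumptions):

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI(title="STT Stub")

class STTRequest(BaseModel):
    audio_url: str

@app.post("/transcribe")
def transcribe(req: STTRequest):
    # Just a stub: echo a fake transcript for the given audio URL
    return {"text": f"[transcript of {req.audio_url}]"}

# Run: uvicorn stt_service:app --host 0.0.0.0 --port 5300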