Main components from ChatGPT

2026-01-01 04:07:59 +00:00
parent 1c0a7a92f5
commit 51f76ea9ad
7 changed files with 306 additions and 5 deletions

dockerfile Normal file

@@ -0,0 +1,18 @@
# Use official Python image
FROM python:3.11-slim
# Set working directory
WORKDIR /app
# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# Copy the FastAPI app
COPY main.py .
# Expose port
EXPOSE 8000
# Start FastAPI with uvicorn
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]

@@ -1,5 +0,0 @@
def main():
    print("Hello, world!")

if __name__ == "__main__":
    main()

main.py Normal file

@@ -0,0 +1,253 @@
import json
import uuid

import requests
import numpy as np
from typing import List, Dict, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    MatchValue,
    PointStruct,
    VectorParams,
)
# bge-m3 dense embeddings are 1024-dimensional; the Qdrant collection must match.
EMBED_DIM = 1024

# --- Model Client ---
class ModelClient:
    def chat(self, messages: List[Dict], character_id: Optional[str] = None) -> str:
        payload = {"model": "deepseek-main", "messages": messages, "character_id": character_id}
        try:
            resp = requests.post("http://vllm-chat:5001/v1/chat/completions", json=payload, timeout=120)
            resp.raise_for_status()
            return resp.json()["choices"][0]["message"]["content"]
        except requests.RequestException as e:
            return f"[Error calling chat model: {e}]"

    def embed(self, text: str) -> List[float]:
        payload = {"model": "bge-m3", "input": text}
        try:
            resp = requests.post("http://vllm-embed:5002/v1/embeddings", json=payload, timeout=30)
            resp.raise_for_status()
            # OpenAI-style embedding responses nest the vector under data[0].embedding
            return resp.json()["data"][0]["embedding"]
        except requests.RequestException as e:
            return np.random.rand(EMBED_DIM).tolist()  # fallback stub
# --- Memory Store with Qdrant ---
class MemoryStore:
    def __init__(self, host="qdrant", port=6333, collection_name="memory", top_k=5):
        self.client = QdrantClient(host=host, port=port)
        self.collection_name = collection_name
        self.top_k = top_k
        if collection_name not in [c.name for c in self.client.get_collections().collections]:
            self.client.recreate_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(size=EMBED_DIM, distance=Distance.COSINE),
            )

    def store(self, namespace: str, content: str, embedding: Optional[List[float]] = None):
        if embedding is None:
            embedding = np.random.rand(EMBED_DIM).tolist()
        self.client.upsert(
            collection_name=self.collection_name,
            points=[PointStruct(
                id=str(uuid.uuid4()),  # Qdrant point IDs must be unsigned ints or UUIDs
                vector=embedding,
                payload={"content": content, "namespace": namespace},
            )],
        )

    def retrieve(self, namespace: str, query: str, top_k: Optional[int] = None) -> List[str]:
        if top_k is None:
            top_k = self.top_k
        query_embedding = model_client.embed(query)  # module-level client, defined below
        try:
            hits = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                query_filter=Filter(must=[FieldCondition(key="namespace", match=MatchValue(value=namespace))]),
                limit=top_k,
            )
            return [hit.payload["content"] for hit in hits]
        except Exception:
            return []
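
# Example round-trip (hypothetical values; assumes Qdrant and the embed service are up):
#   memory_store.store("alice", "Alice likes hiking", embedding=model_client.embed("Alice likes hiking"))
#   memory_store.retrieve("alice", "outdoor hobbies")  # -> ["Alice likes hiking"]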
# --- Tool Registry ---
class ToolRegistry:
    def __init__(self):
        self.tools = {
            "generate_image": self.generate_image,
            "search_documents": self.search_documents,
            "run_code": self.run_code,
            "tts": self.tts_speak,
            "stt": self.stt_transcribe,
        }

    def execute(self, tool_name: str, arguments: dict) -> str:
        if tool_name not in self.tools:
            return f"Error: tool '{tool_name}' not registered"
        return self.tools[tool_name](arguments)

    def generate_image(self, args: dict) -> str:
        prompt = args.get("prompt", "")
        style = args.get("style", "default")
        try:
            resp = requests.post("http://image_service:5100/generate", json={"prompt": prompt, "style": style}, timeout=60)
            resp.raise_for_status()
            return resp.json()["image_url"]
        except requests.RequestException as e:
            return f"[Error generating image: {e}]"

    def search_documents(self, args: dict) -> str:
        query = args.get("query", "")
        return f"[Search results for: {query}]"

    def run_code(self, args: dict) -> str:
        code = args.get("code", "")
        return f"[Executed code: {code}]"

    def tts_speak(self, args: dict) -> str:
        text = args.get("text", "")
        try:
            resp = requests.post("http://tts_service:5200/speak", json={"text": text}, timeout=60)
            resp.raise_for_status()
            return resp.json()["audio_url"]
        except requests.RequestException as e:
            return f"[Error in TTS: {e}]"

    def stt_transcribe(self, args: dict) -> str:
        audio_url = args.get("audio_url", "")
        try:
            resp = requests.post("http://stt_service:5300/transcribe", json={"audio_url": audio_url}, timeout=120)
            resp.raise_for_status()
            return resp.json()["text"]
        except requests.RequestException as e:
            return f"[Error in STT: {e}]"
# --- Prompt Assembler ---
class PromptAssembler:
    def __init__(self, model_max_tokens: int = 8192):
        self.model_max_tokens = model_max_tokens

    def assemble_prompt(self,
                        character_config: Dict,
                        memory_chunks: List[str],
                        rag_context: List[str],
                        conversation_history: List[Dict],
                        tool_schemas: List[Dict]) -> str:
        system = character_config.get("system_prompt", "")
        memory_text = "\n".join(memory_chunks)
        context_text = "\n".join(rag_context)
        history_text = "\n".join(f'{m["role"].capitalize()}: {m["content"]}' for m in conversation_history)
        tools_text = "TOOLS:\n" + "\n".join(t.get("description", "") for t in tool_schemas)
        prompt = "\n\n".join([system, memory_text, context_text, tools_text, history_text])
        return self._trim_to_token_limit(prompt)

    def _trim_to_token_limit(self, prompt: str) -> str:
        # Rough heuristic: assume ~4 characters per token and keep the tail,
        # which holds the most recent conversation turns.
        max_chars = self.model_max_tokens * 4
        return prompt if len(prompt) <= max_chars else prompt[-max_chars:]
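
# assemble_prompt joins its sections with blank lines, roughly:
#   <system prompt>
#   <memory chunks>
#   <rag context>
#   TOOLS: <tool descriptions>
#   User: ... / Assistant: ...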
# --- Utility functions ---
def parse_tool_call(model_response: str) -> Optional[dict]:
    try:
        data = json.loads(model_response)
        if isinstance(data, dict) and "tool" in data and "arguments" in data:
            return data
    except json.JSONDecodeError:
        return None
    return None
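
# Expected tool-call shape (an assumed convention; the model must be prompted to emit it):
#   {"tool": "generate_image", "arguments": {"prompt": "a red fox", "style": "photo"}}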
def handle_model_response(model_client: ModelClient,
                          tool_registry: ToolRegistry,
                          model_response: str) -> str:
    tool_call = parse_tool_call(model_response)
    if tool_call:
        tool_name = tool_call["tool"]
        arguments = tool_call["arguments"]
        result = tool_registry.execute(tool_name, arguments)
        return f"Tool result:\n{result}\nPlease continue the response."
    return model_response
# --- FastAPI App & Models ---
app = FastAPI(title="AI Gateway")
model_client = ModelClient()
memory_store = MemoryStore()
tool_registry = ToolRegistry()
prompt_assembler = PromptAssembler()

class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    character_id: Optional[str] = None
    conversation_id: Optional[str] = None
    tools: Optional[bool] = False
    temperature: Optional[float] = 0.7

class EmbeddingRequest(BaseModel):
    model: str
    input: str

class ToolExecutionRequest(BaseModel):
    tool_name: str
    arguments: Dict
# --- Endpoints ---
@app.post("/v1/chat/completions")
def chat_endpoint(req: ChatRequest):
    if not req.messages:
        raise HTTPException(status_code=400, detail="messages must not be empty")
    namespace = req.character_id or "user"
    character_config = {"system_prompt": f"You are {namespace}."}
    # Retrieve memory for RAG
    memory_chunks = memory_store.retrieve(namespace, req.messages[-1].content, top_k=5)
    rag_context = memory_chunks
    # Assemble prompt
    prompt = prompt_assembler.assemble_prompt(
        character_config=character_config,
        memory_chunks=memory_chunks,
        rag_context=rag_context,
        conversation_history=[m.dict() for m in req.messages],
        tool_schemas=[{"description": "generate_image"}, {"description": "search_documents"}],
    )
    # Inject prompt into chat messages so memory affects the response
    messages_with_prompt = [{"role": "system", "content": prompt}] + [m.dict() for m in req.messages]
    # Call model
    model_response = model_client.chat(messages=messages_with_prompt, character_id=req.character_id)
    # Handle tool calls
    final_response = handle_model_response(model_client, tool_registry, model_response)
    # Persist memory
    embedding_vector = model_client.embed(final_response)
    memory_store.store(namespace, final_response, embedding=embedding_vector)
    return {"id": "chatcmpl-001", "choices": [{"message": {"role": "assistant", "content": final_response}}]}

@app.post("/v1/embeddings")
def embeddings_endpoint(req: EmbeddingRequest):
    embedding_vector = model_client.embed(req.input)
    return {"object": "embedding", "data": embedding_vector}

@app.get("/v1/models")
def models_endpoint():
    return {"data": [
        {"id": "deepseek-main", "type": "chat"},
        {"id": "bge-m3", "type": "embedding"},
    ]}

@app.post("/v1/tools/execute")
def tools_execute_endpoint(req: ToolExecutionRequest):
    result = tool_registry.execute(req.tool_name, req.arguments)
    if "Error" in result:
        raise HTTPException(status_code=400, detail=result)
    return {"result": result}

# --- main() ---
def main():
    # Dev entry point; in the container uvicorn is launched by the Dockerfile CMD
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

if __name__ == "__main__":
    main()
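
With the gateway running, a minimal client-side sketch of the chat flow might look like this; localhost:8000 is an assumption from the Dockerfile, and the request body follows the ChatRequest model above.

# sketch: call the gateway's chat endpoint and print the reply
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "deepseek-main",
        "messages": [{"role": "user", "content": "Hello!"}],
        "character_id": "narrator",
    },
    timeout=120,
)
print(resp.json()["choices"][0]["message"]["content"])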

requirements.txt

@@ -1 +1,7 @@
# Add your project dependencies here
fastapi==0.102.0
uvicorn[standard]==0.23.0
requests==2.31.0
numpy==1.26.4
pydantic==1.10.12
qdrant-client==1.9.1

tools/image_service.py Normal file

@@ -0,0 +1,15 @@
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI(title="Image Generation Stub")

class ImageRequest(BaseModel):
    prompt: str
    style: str = "default"

@app.post("/generate")
def generate_image(req: ImageRequest):
    # Just a stub, returns a placeholder
    return {"image_url": f"https://dummyimage.com/512x512/000/fff&text={req.prompt.replace(' ', '+')}"}

# Run: uvicorn image_service:app --host 0.0.0.0 --port 5100
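
# Example call (hypothetical, with the stub running locally):
#   import requests
#   requests.post("http://localhost:5100/generate", json={"prompt": "a red fox"}).json()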

tools/tts_service.py Normal file

@@ -0,0 +1,14 @@
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI(title="TTS Stub")

class TTSRequest(BaseModel):
    text: str

@app.post("/speak")
def speak(req: TTSRequest):
    # Return a fake URL to an audio file
    return {"audio_url": f"https://dummy-audio.com/{req.text[:10].replace(' ', '_')}.mp3"}

# Run: uvicorn tts_service:app --host 0.0.0.0 --port 5200