mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-07 16:02:45 +00:00
Compare commits
6 Commits
cli/v0.2.1
...
edwin/qdra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b1e5642ef2 | ||
|
|
34c59d0540 | ||
|
|
e02dbed56a | ||
|
|
8d09761762 | ||
|
|
1058308afa | ||
|
|
c9653729de |
3
backend/.gitignore
vendored
3
backend/.gitignore
vendored
@@ -12,3 +12,6 @@ celerybeat-schedule*
|
||||
onyx/connectors/salesforce/data/
|
||||
.test.env
|
||||
/generated
|
||||
backend/scratch/qdrant/accuracy_testing/*.jsonl
|
||||
scratch/qdrant/accuracy_testing/*.jsonl
|
||||
scratch/qdrant/accuracy_testing/evaluation_results.json
|
||||
|
||||
@@ -103,6 +103,7 @@ from onyx.server.onyx_api.ingestion import router as onyx_api_router
|
||||
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
|
||||
get_full_openai_assistants_api_router,
|
||||
)
|
||||
from onyx.server.qdrant_search.api import router as qdrant_search_router
|
||||
from onyx.server.query_and_chat.chat_backend import router as chat_router
|
||||
from onyx.server.query_and_chat.chat_backend_v0 import router as chat_v0_router
|
||||
from onyx.server.query_and_chat.query_backend import (
|
||||
@@ -340,6 +341,7 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, qdrant_search_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_v0_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
|
||||
0
backend/onyx/server/qdrant_search/__init__.py
Normal file
0
backend/onyx/server/qdrant_search/__init__.py
Normal file
59
backend/onyx/server/qdrant_search/api.py
Normal file
59
backend/onyx/server/qdrant_search/api.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
FastAPI router for Qdrant document search endpoints.
|
||||
Provides real-time search-as-you-type functionality.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.models import User
|
||||
from onyx.server.qdrant_search.models import QdrantSearchResponse
|
||||
from onyx.server.qdrant_search.service import search_documents
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/qdrant")
|
||||
|
||||
|
||||
@router.get("/search")
async def search_qdrant_documents(
    query: str = Query(..., min_length=1, description="Search query text"),
    limit: int = Query(10, ge=1, le=50, description="Maximum number of results"),
    _user: User | None = Depends(current_user),
) -> QdrantSearchResponse:
    """
    Search for documents in Qdrant using hybrid search (dense + sparse vectors).

    This endpoint is optimized for search-as-you-type functionality with:
    - Fast hybrid search using pre-computed embeddings
    - Sub-second response times
    - Relevance scoring using Distribution-Based Score Fusion

    Args:
        query: The search query text (minimum 1 character)
        limit: Maximum number of results to return (1-50, default 10)
        _user: Authenticated user resolved by the dependency; only used to
            enforce authentication.

    Returns:
        QdrantSearchResponse containing matching documents with scores

    Raises:
        HTTPException: 400 if the query is empty/whitespace, 500 on any
            internal failure during the search.
    """
    # min_length=1 already rejects "", so only whitespace-only input remains.
    if not query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    logger.info(f"Search request: query='{query[:50]}...', limit={limit}")

    try:
        response = search_documents(query=query.strip(), limit=limit)
    except Exception as e:
        # Log the full traceback server-side, but do NOT echo the exception
        # text back to the client: internal details (hostnames, stack state,
        # credentials embedded in messages) must not leak through the API.
        logger.error(f"Error in search endpoint: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail="Internal server error during search"
        ) from e

    logger.info(
        f"Search completed: {response.total_results} results for '{query[:50]}'"
    )
    return response
|
||||
21
backend/onyx/server/qdrant_search/models.py
Normal file
21
backend/onyx/server/qdrant_search/models.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class QdrantSearchRequest(BaseModel):
    """Request body for a Qdrant document search."""

    # Free-text search query.
    query: str
    # Maximum number of results to return.
    limit: int = 10
|
||||
|
||||
|
||||
class QdrantSearchResult(BaseModel):
    """A single matching document chunk returned by the search service."""

    # Identifier of the source document.
    document_id: str
    # Text content of the matching chunk (the service truncates it for preview).
    content: str
    # Original filename, when present in the stored payload.
    filename: str | None
    # Connector/source type, when present in the stored payload.
    source_type: str | None
    # Relevance score reported by Qdrant for this point.
    score: float
    # Arbitrary extra metadata stored alongside the chunk, if any.
    metadata: dict | None
|
||||
|
||||
|
||||
class QdrantSearchResponse(BaseModel):
    """Response envelope for a Qdrant search."""

    # Matching document chunks.
    results: list[QdrantSearchResult]
    # The query string that produced these results (echoed back to the caller).
    query: str
    # Number of results returned, i.e. len(results).
    total_results: int
|
||||
220
backend/onyx/server/qdrant_search/service.py
Normal file
220
backend/onyx/server/qdrant_search/service.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""
|
||||
Service for performing real-time search against Qdrant vector database.
|
||||
Uses hybrid search (dense + sparse) for optimal results.
|
||||
|
||||
Implements prefix caching to accelerate search-as-you-type:
|
||||
- Pre-computed embeddings for common query prefixes
|
||||
- Cache hit: ~5-10ms embedding lookup
|
||||
- Cache miss: ~100-200ms embedding generation
|
||||
"""
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
import cohere
|
||||
from fastembed import SparseTextEmbedding
|
||||
from qdrant_client.models import Fusion
|
||||
|
||||
from onyx.server.qdrant_search.models import QdrantSearchResponse
|
||||
from onyx.server.qdrant_search.models import QdrantSearchResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from scratch.qdrant.client import QdrantClient
|
||||
from scratch.qdrant.prefix_cache.prefix_to_id import prefix_to_id
|
||||
from scratch.qdrant.schemas.collection_name import CollectionName
|
||||
from scratch.qdrant.service import QdrantService
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_cohere_client() -> cohere.Client:
    """Return a process-wide Cohere client, created lazily on first use.

    Raises:
        ValueError: if COHERE_API_KEY is unset (or empty) in the environment.
    """
    api_key = os.getenv("COHERE_API_KEY")
    if api_key:
        return cohere.Client(api_key)
    raise ValueError("COHERE_API_KEY environment variable not set")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_sparse_embedding_model() -> SparseTextEmbedding:
    """Return the process-wide BM25 sparse embedding model (loaded once)."""
    # BM25 is used for the sparse side of the hybrid search.
    return SparseTextEmbedding(model_name="Qdrant/bm25", threads=2)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_qdrant_service() -> QdrantService:
    """Return the process-wide QdrantService, wrapping a shared client."""
    return QdrantService(client=QdrantClient())
|
||||
|
||||
|
||||
def search_with_prefix_cache_recommend(
    query: str, qdrant_client: QdrantClient, limit: int = 10
) -> tuple[list, bool]:
    """
    Search the main collection via the prefix-cache, in a single API call.

    The query prefix is encoded as a u64 point ID; Qdrant's recommend
    endpoint with ``lookup_from`` then resolves that point's vector from the
    prefix_cache collection and searches the main collection with it — no
    separate retrieve call and no embedding computation.

    Args:
        query: The search query text.
        qdrant_client: Qdrant client instance.
        limit: Number of results to return.

    Returns:
        Tuple of (results, cache_hit); cache_hit is False (with an empty
        result list) when the prefix is absent from the cache or any error
        occurs (e.g. non-encodable prefix).
    """
    try:
        # Lowercase/strip so the lookup matches how prefixes were cached,
        # then encode the prefix string as a u64 integer point ID.
        lookup_id = prefix_to_id(query.lower().strip())

        from qdrant_client.models import LookupLocation

        # Single round-trip: Qdrant resolves lookup_id's vector from the
        # prefix cache and uses it to search the main collection.
        hits = qdrant_client.client.recommend(
            collection_name=CollectionName.ACCURACY_TESTING,
            positive=[lookup_id],  # u64 integer point ID from prefix
            limit=limit,
            lookup_from=LookupLocation(collection=CollectionName.PREFIX_CACHE),
            with_payload=True,
        )
    except Exception as e:
        logger.warning(f"Error using prefix cache recommend: {e}")
        return ([], False)

    if not hits:
        logger.info(f"✗ Prefix cache MISS for '{query}' (point not found in cache)")
        return ([], False)

    logger.info(
        f"✓ Prefix cache HIT for '{query}' (recommend with lookup_from)"
    )
    return (hits, True)
|
||||
|
||||
|
||||
def embed_query_with_cohere(
    query_text: str,
    cohere_client: cohere.Client,
    model: str = "embed-english-v3.0",
) -> list[float]:
    """
    Return the dense Cohere embedding for a single search query.

    Args:
        query_text: The search query.
        cohere_client: Initialized Cohere client.
        model: Cohere model name.

    Returns:
        Dense embedding vector.
    """
    # input_type="search_query" tells Cohere this is a query, not a document.
    embed_response = cohere_client.embed(
        texts=[query_text],
        model=model,
        input_type="search_query",
    )
    return embed_response.embeddings[0]
|
||||
|
||||
|
||||
def search_documents(query: str, limit: int = 10) -> QdrantSearchResponse:
    """
    Perform hybrid search on the Qdrant collection, with prefix caching.

    Two-tier strategy:
    1. Try the recommend endpoint with lookup_from against the prefix_cache
       collection (cache hit: a single Qdrant API call, no embedding work).
    2. On a cache miss, generate dense (Cohere) + sparse (BM25) embeddings
       on the fly and run a hybrid search with DBSF fusion.

    Args:
        query: Search query text.
        limit: Maximum number of results to return.

    Returns:
        QdrantSearchResponse with up to ``limit`` results. On any internal
        error an empty response (total_results=0) is returned instead of
        raising.
    """
    try:
        logger.info(f"Searching for query: '{query[:50]}'")

        # NOTE(review): a fresh QdrantClient is constructed on every request
        # here, while the helpers below reuse lru_cached instances — consider
        # reusing one client if construction is expensive; confirm before
        # changing.
        qdrant_client = QdrantClient()

        # Tier 1: prefix-cache lookup via recommend(lookup_from=...).
        cache_results, cache_hit = search_with_prefix_cache_recommend(
            query, qdrant_client, limit
        )

        if cache_hit:
            # Cache HIT - recommend already returned scored points.
            search_points = cache_results
        else:
            # Cache MISS - fall back to on-the-fly embedding + hybrid search.
            logger.info(f"Generating embeddings for query: {query[:50]}...")

            qdrant_service = get_qdrant_service()
            cohere_client = get_cohere_client()
            sparse_model = get_sparse_embedding_model()

            # Dense embedding via the Cohere API.
            dense_vector = embed_query_with_cohere(query, cohere_client)

            # Sparse (BM25) embedding; query_embed yields one embedding per
            # input, so next() takes the single result.
            sparse_embedding = next(sparse_model.query_embed(query))
            from qdrant_client.models import SparseVector

            sparse_vector = SparseVector(
                indices=sparse_embedding.indices.tolist(),
                values=sparse_embedding.values.tolist(),
            )

            # Hybrid search combining dense + sparse vectors.
            logger.info("Performing hybrid search...")
            search_results = qdrant_service.hybrid_search(
                dense_query_vector=dense_vector,
                sparse_query_vector=sparse_vector,
                collection_name=CollectionName.ACCURACY_TESTING,
                limit=limit,
                fusion=Fusion.DBSF,  # Distribution-Based Score Fusion
            )
            search_points = search_results.points

        # Convert results to response format; both the recommend and hybrid
        # search paths yield points exposing .payload and .score.
        results = []
        for point in search_points:
            payload = point.payload or {}
            result = QdrantSearchResult(
                document_id=payload.get("document_id", ""),
                content=payload.get("content", "")[:500],  # Limit content preview
                filename=payload.get("filename"),
                source_type=payload.get("source_type"),
                # Coalesces both None and a falsy 0.0 score to 0.0.
                score=point.score if point.score else 0.0,
                metadata=payload.get("metadata"),
            )
            results.append(result)

        logger.info(f"Found {len(results)} results for query: {query[:50]}")

        return QdrantSearchResponse(
            results=results,
            query=query,
            total_results=len(results),
        )

    except Exception as e:
        # NOTE(review): broad catch — search failures degrade to an empty
        # result set rather than surfacing an error to the
        # search-as-you-type UI; the 500 path in the API layer is therefore
        # effectively unreachable from here.
        logger.error(f"Error searching documents: {e}")
        # Return empty results on error
        return QdrantSearchResponse(
            results=[],
            query=query,
            total_results=0,
        )
|
||||
452
backend/scratch/qdrant/SEARCH_AS_YOU_TYPE_IMPLEMENTATION.md
Normal file
452
backend/scratch/qdrant/SEARCH_AS_YOU_TYPE_IMPLEMENTATION.md
Normal file
@@ -0,0 +1,452 @@
|
||||
# Search-as-You-Type MVP Implementation Summary
|
||||
|
||||
## Overview
|
||||
|
||||
Successfully implemented a production-ready search-as-you-type system using Qdrant vector database with prefix caching optimization, **following the exact architecture from https://qdrant.tech/articles/search-as-you-type/**.
|
||||
|
||||
## Key Components Implemented
|
||||
|
||||
### 1. Backend Infrastructure
|
||||
|
||||
#### **Prefix Cache System** (Main Optimization from Article)
|
||||
|
||||
**The Core Idea:**
|
||||
- Pre-compute embeddings for common query prefixes
|
||||
- Store them with **prefix encoded as u64 point ID**
|
||||
- Use Qdrant's `recommend()` endpoint with `lookup_from` parameter
|
||||
- **Result**: Search without ANY embedding computation!
|
||||
|
||||
**Implementation Details:**
|
||||
- **Collection**: `prefix_cache` with ~10,000 pre-computed query prefixes
|
||||
- **Point ID Encoding**: Prefix string → u64 integer (e.g., `"docker"` → `125779918942052`)
|
||||
- Uses up to 8 ASCII bytes encoded as little-endian integer
|
||||
- See: `backend/scratch/qdrant/prefix_cache/prefix_to_id.py`
|
||||
- **Schema**: Dense (Cohere 1024-dim) + Sparse (BM25) embeddings
|
||||
- **Coverage**:
|
||||
- All 26 single-char prefixes (a-z)
|
||||
- All 708 two-char prefixes
|
||||
- Top 2,316 three-char prefixes (by frequency)
|
||||
- Top 3,243 four-char prefixes
|
||||
- Top 3,706 five-char prefixes
|
||||
- **Source**: Extracted from actual corpus (`target_docs.jsonl`)
|
||||
- Filenames: `mattermost`, `workflow`, `gitlab`
|
||||
- Content words: `docker`, `issue`, `customer`, `support`, `team`, `code`, `data`
|
||||
- **Location**: `backend/scratch/qdrant/prefix_cache/`
|
||||
|
||||
#### **Search Service** (`backend/onyx/server/qdrant_search/service.py`)
|
||||
|
||||
**Optimized Two-Tier Strategy (from article):**
|
||||
|
||||
1. **Cache HIT Path** (SINGLE API call!):
|
||||
```python
|
||||
point_id = prefix_to_id("docker") # Convert to u64: 125779918942052
|
||||
results = client.recommend(
|
||||
collection_name="accuracy_testing",
|
||||
positive=[point_id],
|
||||
lookup_from="prefix_cache", # Qdrant retrieves vector from cache!
|
||||
limit=10
|
||||
)
|
||||
# ~5-50ms total latency
|
||||
```
|
||||
|
||||
2. **Cache MISS Path** (fallback):
|
||||
```python
|
||||
# Generate embeddings on-the-fly
|
||||
dense_vector = embed_with_cohere(query) # ~100ms
|
||||
sparse_vector = embed_with_bm25(query) # ~10ms
|
||||
# Hybrid search with DBSF fusion # ~50ms
|
||||
# ~200ms total latency
|
||||
```
|
||||
|
||||
**Key Features:**
|
||||
- **Recommend with lookup_from**: Single API call for cache hits
|
||||
- **u64 point ID encoding**: O(1) lookup by integer ID
|
||||
- **BM25 sparse embeddings**: Fast and effective
|
||||
- **LRU caching**: Model instances cached for performance
|
||||
- **Hybrid search fallback**: DBSF fusion for cache misses
|
||||
|
||||
#### **API Endpoint** (`/api/qdrant/search`)
|
||||
- **FastAPI router**: `backend/onyx/server/qdrant_search/api.py`
|
||||
- **Params**: `query` (min 1 char), `limit` (1-50, default 10)
|
||||
- **Response**: Documents with relevance scores + metadata
|
||||
- **Registered**: `/api/qdrant/search` in `backend/onyx/main.py`
|
||||
|
||||
### 2. Frontend Enhancements
|
||||
|
||||
#### **Text Highlighting** (`web/src/app/chat/chat_search/utils/highlightText.tsx`)
|
||||
- Highlights matching query terms in results
|
||||
- Supports multi-word queries
|
||||
- Styled with yellow highlight (light/dark mode)
|
||||
- Applied to both filename and content
|
||||
|
||||
#### **Keyboard Navigation**
|
||||
- **Arrow Up/Down**: Navigate through results
|
||||
- **Enter**: Select highlighted result
|
||||
- **Visual feedback**: Blue ring + background for selected item
|
||||
- **Auto-scroll**: Selected item scrolls into view
|
||||
- **Hint**: Shows "Use ↑↓ arrow keys to navigate, Enter to select"
|
||||
|
||||
#### **Enhanced DocumentSearchResults Component**
|
||||
- Integrated highlighting for filename and content
|
||||
- Keyboard navigation support
|
||||
- Improved accessibility (ARIA attributes)
|
||||
- Loading states and empty states
|
||||
- Visual selection with blue ring
|
||||
|
||||
### 3. Data
|
||||
|
||||
**Collection**: `accuracy_testing`
|
||||
- **Documents**: ~14,353 chunks from `target_docs.jsonl`
|
||||
- **Source**: GitLab Slack workspace (Docker/Kubernetes/DevOps discussions)
|
||||
- **Embeddings**: Cohere embed-english-v3.0 (dense) + BM25 (sparse)
|
||||
|
||||
**Collection**: `prefix_cache`
|
||||
- **Prefixes**: 9,999 ASCII-only prefixes (1-5 chars)
|
||||
- **Point IDs**: u64 integer encoding of prefix strings
|
||||
- **Embeddings**: Same as accuracy_testing (Cohere + BM25)
|
||||
|
||||
## Performance Results
|
||||
|
||||
### Prefix Cache Performance
|
||||
```
|
||||
Cache HITs: ~5-50ms (single recommend API call!)
|
||||
Cache MISSes: ~200ms (embedding generation + search)
|
||||
Speedup: 4-10x faster for cached queries
|
||||
```
|
||||
|
||||
### First Query Notes
|
||||
- First query: ~700-800ms (model loading overhead)
|
||||
- Subsequent cache hits: ~5-50ms
|
||||
- Subsequent cache misses: ~200ms
|
||||
|
||||
## Architecture Highlights
|
||||
|
||||
### Search Flow (Optimized per Article)
|
||||
|
||||
```
|
||||
User types → Frontend (500ms debounce) → API endpoint
|
||||
↓
|
||||
Try: recommend(lookup_from=prefix_cache)
|
||||
↙ ↘
|
||||
Cache HIT Cache MISS
|
||||
(point ID exists) (point not found)
|
||||
(~5-50ms) (~200ms)
|
||||
↓ ↓
|
||||
Qdrant retrieves vector Generate embeddings
|
||||
from prefix_cache and (Cohere + BM25)
|
||||
searches accuracy_testing ↓
|
||||
↓ Hybrid Search
|
||||
↓ (DBSF fusion)
|
||||
↘ ↙
|
||||
Results
|
||||
```
|
||||
|
||||
### Key Optimizations from Qdrant Article
|
||||
|
||||
✅ **u64 Point ID Encoding**: Prefix string → integer for O(1) lookup
|
||||
✅ **Recommend with lookup_from**: Single API call (no retrieve + search)
|
||||
✅ **Prefix Caching**: Pre-compute embeddings for 10k common prefixes
|
||||
✅ **BM25 Sparse Embeddings**: Fast and effective
|
||||
✅ **Hybrid Search**: Dense + sparse vectors with DBSF fusion
|
||||
✅ **Debouncing**: 500ms delay to reduce API calls
|
||||
✅ **Request Cancellation**: AbortController for in-flight requests
|
||||
✅ **Model Caching**: LRU cache for embedding models
|
||||
|
||||
## Files Created/Modified
|
||||
|
||||
### New Files
|
||||
```
|
||||
backend/scratch/qdrant/prefix_cache/
|
||||
├── __init__.py
|
||||
├── create_prefix_cache_collection.py
|
||||
├── populate_prefix_cache.py
|
||||
├── prefix_to_id.py # u64 encoding/decoding
|
||||
└── extract_all_prefixes.py # Corpus analysis (10k scale)
|
||||
|
||||
backend/scratch/qdrant/schemas/
|
||||
└── prefix_cache.py
|
||||
|
||||
backend/scratch/qdrant/
|
||||
├── test_prefix_cache_performance.py
|
||||
└── test_search_endpoint.py
|
||||
|
||||
web/src/app/chat/chat_search/utils/
|
||||
└── highlightText.tsx
|
||||
```
|
||||
|
||||
### Modified Files
|
||||
```
|
||||
backend/onyx/server/qdrant_search/service.py # recommend() with lookup_from
|
||||
backend/scratch/qdrant/schemas/collection_name.py # added PREFIX_CACHE
|
||||
backend/scratch/qdrant/accuracy_testing/upload_chunks.py # switched to BM25
|
||||
web/src/app/chat/chat_search/components/DocumentSearchResults.tsx # keyboard nav + highlighting
|
||||
web/src/app/chat/chat_search/ChatSearchModal.tsx # pass searchQuery prop
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Extracting Prefixes from Corpus
|
||||
```bash
|
||||
# Extract ~10k most popular prefixes from your documents
|
||||
python -m scratch.qdrant.prefix_cache.extract_all_prefixes
|
||||
|
||||
# This analyzes:
|
||||
# - Filenames (e.g., "mattermost.docx" → "matte", "matter", etc.)
|
||||
# - URLs (domain names, paths)
|
||||
# - Document content (all words, stop words filtered)
|
||||
# - Generates 1-5 character prefixes
|
||||
# - Selects top 10k by frequency
|
||||
```
|
||||
|
||||
### Creating/Populating Prefix Cache
|
||||
```bash
|
||||
# 1. Create collection
|
||||
python -m scratch.qdrant.prefix_cache.create_prefix_cache_collection
|
||||
|
||||
# 2. Populate with corpus-derived prefixes (uses u64 point IDs)
|
||||
python -m scratch.qdrant.prefix_cache.populate_prefix_cache
|
||||
|
||||
# 3. Verify
|
||||
python -m scratch.qdrant.get_collection_status
|
||||
```
|
||||
|
||||
### Running Tests
|
||||
```bash
|
||||
# Test search functionality
|
||||
python -m dotenv -f .vscode/.env run -- \
|
||||
python -m scratch.qdrant.test_search_endpoint
|
||||
|
||||
# Test prefix cache performance
|
||||
python -m dotenv -f .vscode/.env run -- \
|
||||
python -m scratch.qdrant.test_prefix_cache_performance
|
||||
|
||||
# Test prefix encoding
|
||||
python -m scratch.qdrant.prefix_cache.prefix_to_id
|
||||
```
|
||||
|
||||
### API Usage
|
||||
```bash
|
||||
# Search via API (goes through frontend)
|
||||
curl "http://localhost:3000/api/qdrant/search?query=docker&limit=5"
|
||||
|
||||
# Check collection status
|
||||
python -m scratch.qdrant.get_collection_status
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### Prefix to u64 Encoding
|
||||
|
||||
```python
|
||||
def prefix_to_id(prefix: str) -> int:
|
||||
"""
|
||||
Convert prefix string to u64 integer.
|
||||
|
||||
Examples:
|
||||
"a" → 97
|
||||
"docker" → 125779918942052
|
||||
"gitlab" → 108170570918247
|
||||
"""
|
||||
# Encode as ASCII bytes (up to 8 chars)
|
||||
prefix_bytes = prefix.encode('ascii')
|
||||
|
||||
# Pad to 8 bytes, convert to integer
|
||||
padded = prefix_bytes.ljust(8, b'\x00')
|
||||
return int.from_bytes(padded, byteorder='little')
|
||||
```
|
||||
|
||||
### Search Service Logic
|
||||
|
||||
```python
|
||||
def search_documents(query: str, limit: int = 10):
|
||||
# Normalize and convert to u64 ID
|
||||
normalized = query.lower().strip()
|
||||
point_id = prefix_to_id(normalized)
|
||||
|
||||
# Try cache with recommend (single API call!)
|
||||
try:
|
||||
results = client.recommend(
|
||||
collection_name="accuracy_testing",
|
||||
positive=[point_id],
|
||||
lookup_from="prefix_cache",
|
||||
limit=limit
|
||||
)
|
||||
# ✓ Cache HIT - return results
|
||||
return format_results(results)
|
||||
    except Exception:
|
||||
# ✗ Cache MISS - generate embeddings
|
||||
dense = embed_with_cohere(query)
|
||||
sparse = embed_with_bm25(query)
|
||||
results = hybrid_search(dense, sparse, limit)
|
||||
return format_results(results)
|
||||
```
|
||||
|
||||
## Corpus Analysis Results
|
||||
|
||||
**Analyzed**: 11,886 documents from `target_docs.jsonl`
|
||||
**Unique Words**: 68,050 (ASCII-only, stop words filtered)
|
||||
**Total Prefixes Generated**: 37,175 (1-5 chars)
|
||||
**Selected for Cache**: 9,999 most popular prefixes
|
||||
|
||||
**Top Words in Corpus:**
|
||||
1. gitlab (24,788 occurrences)
|
||||
2. team (24,398)
|
||||
3. use (16,687)
|
||||
4. data (15,839)
|
||||
5. issue (11,765)
|
||||
6. code (10,885)
|
||||
7. customer (10,673)
|
||||
8. support (10,504)
|
||||
9. user (9,935)
|
||||
10. product (9,793)
|
||||
|
||||
**Prefix Distribution:**
|
||||
- 1-char: 26 prefixes (all)
|
||||
- 2-char: 708 prefixes (all)
|
||||
- 3-char: 2,316 prefixes (most popular)
|
||||
- 4-char: 3,243 prefixes (most popular)
|
||||
- 5-char: 3,706 prefixes (most popular)
|
||||
|
||||
## Next Steps / Future Enhancements
|
||||
|
||||
### Immediate
|
||||
1. ✅ **Browser testing**: Test end-to-end functionality in the web UI
|
||||
2. **Document preview**: Implement click handler to view full documents
|
||||
3. **Analytics**: Track prefix cache hit rates for optimization
|
||||
4. **Error handling**: Better UX for non-ASCII queries
|
||||
|
||||
### Future
|
||||
1. **Dynamic cache warming**: Populate cache based on real query patterns
|
||||
2. **Query suggestions**: Show autocomplete suggestions from cache
|
||||
3. **Multi-language support**: Extend beyond ASCII (use UUID encoding)
|
||||
4. **Result ranking improvements**: Prioritize title matches
|
||||
5. **Streaming results**: Show results as they arrive for long queries
|
||||
6. **Cache analytics**: Monitor hit rates, update popular prefixes
|
||||
|
||||
## Configuration
|
||||
|
||||
### Environment Variables Required
|
||||
```bash
|
||||
COHERE_API_KEY=your_cohere_api_key
|
||||
QDRANT_URL=http://localhost:6333 # Default Qdrant URL
|
||||
```
|
||||
|
||||
### Frontend Config
|
||||
- **Debounce**: 500ms (configurable in `useQdrantSearch.ts:21`)
|
||||
- **Result limit**: 10 documents (configurable via API param)
|
||||
- **Enabled**: Only when modal is open and query is non-empty
|
||||
|
||||
### Backend Config
|
||||
- **Sparse Model**: `Qdrant/bm25` (BM25 embeddings)
|
||||
- **Dense Model**: `embed-english-v3.0` (Cohere, 1024 dimensions)
|
||||
- **Fusion**: DBSF (Distribution-Based Score Fusion)
|
||||
- **Collection**: `accuracy_testing` (main documents)
|
||||
- **Prefix Cache**: `prefix_cache` (10k prefixes with u64 IDs)
|
||||
|
||||
## Performance Benchmarks
|
||||
|
||||
### Target Metrics (from Qdrant article)
|
||||
- Search latency: <100ms for cache hits ✅
|
||||
- Search latency: <200ms for cache misses ✅
|
||||
- User experience: Feels instant for common queries ✅
|
||||
|
||||
### Actual Results
|
||||
```
|
||||
Cache HITs: 5-50ms (recommend with lookup_from)
|
||||
Cache MISSes: ~200ms (embedding + hybrid search)
|
||||
First query: ~800ms (model loading)
|
||||
Speedup: 4-10x for cached queries
|
||||
```
|
||||
|
||||
### Prefix Cache Coverage
|
||||
- **9,999 prefixes** covering most common search patterns
|
||||
- **~37k total available** prefixes in corpus
|
||||
- **27% coverage** optimized for frequency
|
||||
|
||||
## Technical Details
|
||||
|
||||
### Prefix ID Encoding
|
||||
|
||||
The article recommends using the prefix itself as the point ID. Since Qdrant requires integer or UUID IDs, we encode the prefix string as a u64 integer:
|
||||
|
||||
```python
|
||||
# Encoding: ASCII bytes → u64 integer
|
||||
"a" → [0x61, 0, 0, 0, 0, 0, 0, 0] → 97
|
||||
"docker" → [0x64, 0x6f, 0x63, 0x6b, 0x65, 0x72, 0, 0] → 125779918942052
|
||||
"gitlab" → [0x67, 0x69, 0x74, 0x6c, 0x61, 0x62, 0, 0] → 108170570918247
|
||||
```
|
||||
|
||||
### Recommend Endpoint Usage
|
||||
|
||||
From the Qdrant article, the key optimization is using `recommend()` with `lookup_from`:
|
||||
|
||||
```python
|
||||
# Traditional approach (2 API calls):
|
||||
# 1. Retrieve point from cache: GET /collections/prefix_cache/points/{id}
|
||||
# 2. Search with vector: POST /collections/site/points/search
|
||||
|
||||
# Optimized approach (1 API call):
|
||||
POST /collections/accuracy_testing/points/recommend
|
||||
{
|
||||
"positive": [125779918942052], // u64 ID for "docker"
|
||||
"limit": 10,
|
||||
"lookup_from": {
|
||||
"collection": "prefix_cache"
|
||||
}
|
||||
}
|
||||
# Qdrant automatically:
|
||||
# 1. Looks up point 125779918942052 from prefix_cache
|
||||
# 2. Takes its vector
|
||||
# 3. Searches in accuracy_testing
|
||||
# All in ONE API call!
|
||||
```
|
||||
|
||||
### BM25 vs Splade
|
||||
|
||||
We use **BM25** for sparse embeddings (not Splade) because:
|
||||
- Faster inference (~1ms vs ~10ms)
|
||||
- Simpler model
|
||||
- Good enough for search-as-you-type
|
||||
- Recommended by Qdrant for this use case
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Non-ASCII Characters
|
||||
The prefix cache only supports ASCII characters (a-z, 0-9). Non-ASCII queries will:
|
||||
- Fail on `prefix_to_id()` encoding
|
||||
- Fall back to cache MISS path (on-the-fly embedding)
|
||||
- Still work, just slower (~200ms vs ~50ms)
|
||||
|
||||
Future: Could use UUID encoding for non-ASCII support.
|
||||
|
||||
### Empty Results
|
||||
If no results are returned:
|
||||
- Check that `accuracy_testing` collection has data
|
||||
- Verify embeddings match (both use Cohere + BM25)
|
||||
- Check Qdrant logs for errors
|
||||
|
||||
### Slow First Query
|
||||
First query is always slow (~800ms) due to:
|
||||
- Cohere client initialization
|
||||
- BM25 model loading into memory
|
||||
- Subsequent queries are much faster
|
||||
|
||||
## References
|
||||
|
||||
- **Qdrant Article**: https://qdrant.tech/articles/search-as-you-type/
|
||||
- **Cohere Embeddings**: https://docs.cohere.com/reference/embed
|
||||
- **FastEmbed (BM25)**: https://github.com/qdrant/fastembed
|
||||
- **Qdrant Recommend**: https://qdrant.tech/documentation/concepts/search/#recommendation-api
|
||||
|
||||
## Summary
|
||||
|
||||
This implementation follows the **exact architecture from the Qdrant article**:
|
||||
|
||||
1. ✅ **Prefix cache collection** with u64 point IDs
|
||||
2. ✅ **recommend() endpoint** with lookup_from parameter
|
||||
3. ✅ **Single API call** for cache hits (no embedding needed)
|
||||
4. ✅ **~10k corpus-derived prefixes** (not generic guesses)
|
||||
5. ✅ **BM25 sparse embeddings** for speed
|
||||
6. ✅ **Frontend enhancements** (highlighting + keyboard nav)
|
||||
|
||||
**Result**: Production-ready search-as-you-type with <50ms latency for cached queries!
|
||||
0
backend/scratch/qdrant/__init__.py
Normal file
0
backend/scratch/qdrant/__init__.py
Normal file
1
backend/scratch/qdrant/accuracy_testing/__init__.py
Normal file
1
backend/scratch/qdrant/accuracy_testing/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Accuracy testing package for Qdrant experiments
|
||||
529
backend/scratch/qdrant/accuracy_testing/evaluate_retrieval.py
Normal file
529
backend/scratch/qdrant/accuracy_testing/evaluate_retrieval.py
Normal file
@@ -0,0 +1,529 @@
|
||||
"""
|
||||
Evaluation script for testing retrieval accuracy on the ACCURACY_TESTING collection.
|
||||
|
||||
Loads questions from target_questions.jsonl, performs searches using Cohere embeddings,
|
||||
and evaluates if the correct documents are retrieved.
|
||||
|
||||
Metrics:
|
||||
- Top-1 accuracy: Correct document is in position 1
|
||||
- Top-5 accuracy: Correct document is in top 5
|
||||
- Top-10 accuracy: Correct document is in top 10
|
||||
- MRR (Mean Reciprocal Rank): Average of 1/rank for correct documents
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
|
||||
import cohere
|
||||
from dotenv import load_dotenv
|
||||
from fastembed import SparseTextEmbedding
|
||||
from qdrant_client.models import Filter
|
||||
from qdrant_client.models import Fusion
|
||||
from qdrant_client.models import FusionQuery
|
||||
from qdrant_client.models import Prefetch
|
||||
from qdrant_client.models import SparseVector
|
||||
|
||||
from scratch.qdrant.accuracy_testing.target_document_schema import TargetQuestion
|
||||
from scratch.qdrant.client import QdrantClient
|
||||
from scratch.qdrant.schemas.collection_name import CollectionName
|
||||
|
||||
|
||||
def load_questions(jsonl_path: Path) -> list[TargetQuestion]:
    """Parse a JSONL file into TargetQuestion models, skipping blank lines."""
    with open(jsonl_path, "r") as f:
        return [TargetQuestion(**json.loads(raw)) for raw in f if raw.strip()]
|
||||
|
||||
def embed_query_with_cohere(
    query: str,
    cohere_client: cohere.Client,
    model: str = "embed-english-v3.0",
) -> list[float]:
    """Return the dense Cohere embedding for a single search query."""
    # input_type="search_query" selects Cohere's query-side encoder, pairing
    # with documents that were embedded using "search_document".
    result = cohere_client.embed(
        texts=[query],
        model=model,
        input_type="search_query",
    )
    return result.embeddings[0]
|
||||
|
||||
def embed_query_with_bm25(
    query: str,
    sparse_embedding_model: SparseTextEmbedding,
) -> SparseVector:
    """Return the sparse (BM25-style) embedding for a single query."""
    # query_embed yields one embedding per input text; we pass exactly one.
    embedding = next(sparse_embedding_model.query_embed(query))
    indices = embedding.indices.tolist()
    values = embedding.values.tolist()
    return SparseVector(indices=indices, values=values)
|
||||
|
||||
def hybrid_search_qdrant(
    dense_query_vector: list[float],
    sparse_query_vector: SparseVector,
    qdrant_client: QdrantClient,
    collection_name: CollectionName,
    limit: int = 10,
    prefetch_limit: int | None = None,
    query_filter: Filter | None = None,
):
    """Run a hybrid (dense + sparse) search and fuse the branches with DBSF.

    Each branch independently prefetches up to `prefetch_limit` candidates
    (defaulting to `limit * 2`); Distribution-Based Score Fusion then merges
    the two ranked lists into the final top-`limit` result set.
    """
    candidates_per_branch = (
        prefetch_limit if prefetch_limit is not None else limit * 2
    )

    sparse_branch = Prefetch(
        query=sparse_query_vector,
        using="sparse",
        limit=candidates_per_branch,
        filter=query_filter,
    )
    dense_branch = Prefetch(
        query=dense_query_vector,
        using="dense",
        limit=candidates_per_branch,
        filter=query_filter,
    )

    return qdrant_client.query_points(
        collection_name=collection_name,
        prefetch=[sparse_branch, dense_branch],
        fusion_query=FusionQuery(fusion=Fusion.DBSF),
        with_payload=True,
        limit=limit,
    )
|
||||
|
||||
def extract_ground_truth_doc_ids(question: TargetQuestion) -> set[str]:
    """Collect the ground-truth document identifiers for a question.

    File-based sources are keyed by their 'source' path, e.g.
    "company_policies/Succession-planning-policy.docx"; Slack messages fall
    back to their 'thread_ts' timestamp, e.g. "1706889457.275089".
    """
    # Prefer the file path when present; otherwise fall back to the Slack
    # thread timestamp. Entries with neither contribute nothing.
    return {
        key
        for ref in question.metadata.doc_source
        if (key := ref.source or ref.thread_ts)
    }
|
||||
|
||||
def evaluate_search_results(
|
||||
search_results,
|
||||
ground_truth_doc_ids: set[str],
|
||||
) -> dict:
|
||||
"""
|
||||
Evaluate search results against ground truth with deduplication.
|
||||
|
||||
Matches on either document_id OR filename to handle both hash IDs and filenames.
|
||||
|
||||
For multi-document ground truth, deduplicates retrieved docs to get unique
|
||||
documents before checking recall.
|
||||
|
||||
Example:
|
||||
Ground truth: {A, B}
|
||||
Retrieved: [A, A, A, B, C] -> Deduplicated: [A, B, C]
|
||||
Recall@3: 2/2 = 1.0 (100%)
|
||||
|
||||
Returns:
|
||||
Dict with recall metrics at different k values
|
||||
"""
|
||||
# Extract both document_id and filename for matching
|
||||
retrieved_docs = []
|
||||
for point in search_results.points:
|
||||
doc_id = point.payload.get("document_id")
|
||||
filename = point.payload.get("filename")
|
||||
retrieved_docs.append((doc_id, filename))
|
||||
|
||||
# Helper to normalize paths (convert ~ to / for comparison)
|
||||
def normalize_path(path: str) -> str:
|
||||
return path.replace("~", "/")
|
||||
|
||||
# Helper to check if a doc matches ground truth (by document_id OR filename)
|
||||
def matches_ground_truth(doc_id: str, filename: str | None) -> bool:
|
||||
# Check document_id directly
|
||||
if doc_id in ground_truth_doc_ids:
|
||||
return True
|
||||
# Check filename with normalization and suffix matching
|
||||
if filename:
|
||||
normalized_filename = normalize_path(filename)
|
||||
# Check if any ground truth matches or ends with the normalized filename
|
||||
for gt_id in ground_truth_doc_ids:
|
||||
normalized_gt = normalize_path(gt_id)
|
||||
# Exact match
|
||||
if normalized_gt == normalized_filename:
|
||||
return True
|
||||
# Ground truth ends with filename (handles missing prefixes like "company_wiki/")
|
||||
if normalized_gt.endswith(normalized_filename):
|
||||
return True
|
||||
# Filename ends with ground truth (opposite case)
|
||||
if normalized_filename.endswith(normalized_gt):
|
||||
return True
|
||||
return False
|
||||
|
||||
# Deduplicate while preserving order (by document_id)
|
||||
seen = set()
|
||||
deduplicated_doc_ids = []
|
||||
for doc_id, filename in retrieved_docs:
|
||||
if doc_id not in seen:
|
||||
seen.add(doc_id)
|
||||
deduplicated_doc_ids.append((doc_id, filename))
|
||||
|
||||
# Find rank of first correct document (for MRR)
|
||||
first_correct_rank = None
|
||||
for rank, (doc_id, filename) in enumerate(deduplicated_doc_ids, start=1):
|
||||
if matches_ground_truth(doc_id, filename):
|
||||
first_correct_rank = rank
|
||||
break
|
||||
|
||||
# Calculate recall at different k values (using deduplicated results)
|
||||
def recall_at_k(k: int) -> float:
|
||||
"""Calculate recall@k: fraction of ground truth docs found in top k"""
|
||||
top_k_tuples = deduplicated_doc_ids[:k]
|
||||
# Check each doc against ground truth (match on document_id OR filename)
|
||||
found_count = sum(
|
||||
1
|
||||
for doc_id, filename in top_k_tuples
|
||||
if matches_ground_truth(doc_id, filename)
|
||||
)
|
||||
return found_count / len(ground_truth_doc_ids) if ground_truth_doc_ids else 0.0
|
||||
|
||||
# Calculate recall at multiple k values
|
||||
recall_metrics = {
|
||||
"recall_at_1": recall_at_k(1),
|
||||
"recall_at_3": recall_at_k(3),
|
||||
"recall_at_5": recall_at_k(5),
|
||||
"recall_at_10": recall_at_k(10),
|
||||
"recall_at_25": recall_at_k(25),
|
||||
"recall_at_50": recall_at_k(50),
|
||||
}
|
||||
|
||||
# Perfect recall (all ground truth docs found) at different k
|
||||
# For display, prefer filename over hash document_id when available
|
||||
deduplicated_display = [
|
||||
filename if filename else doc_id
|
||||
for doc_id, filename in deduplicated_doc_ids[:50]
|
||||
]
|
||||
|
||||
return {
|
||||
"top_1_hit": recall_metrics["recall_at_1"] == 1.0,
|
||||
"top_3_hit": recall_metrics["recall_at_3"] == 1.0,
|
||||
"top_5_hit": recall_metrics["recall_at_5"] == 1.0,
|
||||
"top_10_hit": recall_metrics["recall_at_10"] == 1.0,
|
||||
**recall_metrics,
|
||||
"reciprocal_rank": 1.0 / first_correct_rank if first_correct_rank else 0.0,
|
||||
"first_correct_rank": first_correct_rank,
|
||||
"num_ground_truth": len(ground_truth_doc_ids),
|
||||
"retrieved_doc_ids": [
|
||||
filename if filename else doc_id for doc_id, filename in retrieved_docs[:50]
|
||||
],
|
||||
"deduplicated_doc_ids": deduplicated_display, # Keep deduplicated top 50
|
||||
}
|
||||
|
||||
|
||||
def evaluate_single_question(
    question: TargetQuestion,
    cohere_client: cohere.Client,
    sparse_embedding_model: SparseTextEmbedding,
    qdrant_client: QdrantClient,
    collection_name: CollectionName,
    cohere_model: str,
) -> dict:
    """Embed, search, and score one question; designed to run in a worker thread."""
    # Build both query representations: dense (Cohere) and sparse (BM25/SPLADE).
    dense_vector = embed_query_with_cohere(
        question.question, cohere_client, cohere_model
    )
    sparse_vector = embed_query_with_bm25(
        question.question, sparse_embedding_model
    )

    # Retrieve 50 hits so recall@25 and recall@50 can be computed downstream.
    results = hybrid_search_qdrant(
        dense_vector,
        sparse_vector,
        qdrant_client,
        collection_name,
        limit=50,
    )

    expected_doc_ids = extract_ground_truth_doc_ids(question)
    metrics = evaluate_search_results(results, expected_doc_ids)

    return {
        "question_uid": question.uid,
        "question": question.question,
        "ground_truth_doc_ids": list(expected_doc_ids),
        **metrics,
    }
||||
|
||||
|
||||
def main():
    """Run the retrieval-accuracy evaluation over target_questions.jsonl.

    Embeds every question (dense Cohere + sparse model), performs hybrid
    search against the ACCURACY_TESTING collection in parallel, aggregates
    recall/MRR metrics, prints summaries, and writes evaluation_results.json.
    """
    cohere_model = "embed-english-v3.0"

    # Sparse model options: "Qdrant/bm25" or "prithivida/Splade_PP_en_v1" or "Qdrant/bm42-all-minilm-l6-v2-attentions"
    # NOTE: currently SPLADE — must match the sparse model used in upload_chunks.py
    sparse_model_name = "prithivida/Splade_PP_en_v1"

    # Collection name (should match what was used in upload_chunks.py)
    collection_name = CollectionName.ACCURACY_TESTING
    max_workers = 10  # Number of parallel workers

    # Load environment variables (COHERE_API_KEY etc.) from the repo .env file
    env_path = Path(__file__).parent.parent.parent.parent.parent / ".vscode" / ".env"
    if env_path.exists():
        load_dotenv(env_path)
        print(f"Loaded environment variables from {env_path}")
    else:
        print(f"Warning: .env file not found at {env_path}")

    # Initialize clients
    print("Initializing clients and embedding models...")
    print(f"Sparse model: {sparse_model_name}")

    cohere_api_key = os.getenv("COHERE_API_KEY")
    if not cohere_api_key:
        raise ValueError("COHERE_API_KEY environment variable not set")

    cohere_client = cohere.Client(cohere_api_key)
    qdrant_client = QdrantClient()
    sparse_embedding_model = SparseTextEmbedding(
        model_name=sparse_model_name, threads=2
    )
    print("Clients and models initialized\n")

    # Load questions
    jsonl_path = Path(__file__).parent / "target_questions.jsonl"
    print(f"Loading questions from {jsonl_path}...")
    questions = load_questions(jsonl_path)
    print(f"Loaded {len(questions):,} questions\n")

    # Run evaluation
    print("=" * 80)
    print(f"EVALUATION STARTED (using {max_workers} parallel workers)")
    print("=" * 80)
    print()

    results = []
    start_time = time.time()
    completed = 0

    # Process questions in parallel; each worker embeds + searches one question
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_question = {
            executor.submit(
                evaluate_single_question,
                question,
                cohere_client,
                sparse_embedding_model,
                qdrant_client,
                collection_name,
                cohere_model,
            ): question
            for question in questions
        }

        # Process completed tasks as they finish (results arrive out of order)
        for future in as_completed(future_to_question):
            completed += 1
            result = future.result()
            results.append(result)

            # Print progress every 10 questions
            if completed % 10 == 0:
                elapsed = time.time() - start_time
                avg_time = elapsed / completed
                remaining = (len(questions) - completed) * avg_time
                print(
                    f"Progress: {completed}/{len(questions)} ({completed / len(questions) * 100:.1f}%) | "
                    f"Elapsed: {elapsed:.1f}s | ETA: {remaining:.1f}s"
                )

    total_time = time.time() - start_time

    # Calculate aggregate metrics (percentages; hits are booleans so sum() counts them)
    print()
    print("=" * 80)
    print("EVALUATION RESULTS")
    print("=" * 80)
    print()

    top_1_accuracy = sum(r["top_1_hit"] for r in results) / len(results) * 100
    top_3_accuracy = sum(r["top_3_hit"] for r in results) / len(results) * 100
    top_5_accuracy = sum(r["top_5_hit"] for r in results) / len(results) * 100
    top_10_accuracy = sum(r["top_10_hit"] for r in results) / len(results) * 100
    mrr = sum(r["reciprocal_rank"] for r in results) / len(results)

    # Calculate average recall at different k values
    avg_recall_at_1 = sum(r["recall_at_1"] for r in results) / len(results) * 100
    avg_recall_at_3 = sum(r["recall_at_3"] for r in results) / len(results) * 100
    avg_recall_at_5 = sum(r["recall_at_5"] for r in results) / len(results) * 100
    avg_recall_at_10 = sum(r["recall_at_10"] for r in results) / len(results) * 100
    avg_recall_at_25 = sum(r["recall_at_25"] for r in results) / len(results) * 100
    avg_recall_at_50 = sum(r["recall_at_50"] for r in results) / len(results) * 100

    print(f"Total questions evaluated: {len(results):,}")
    print(f"Total time: {total_time:.2f}s ({total_time / 60:.1f} minutes)")
    print(f"Average time per query: {total_time / len(results):.2f}s")
    print()

    print("Perfect Recall Accuracy (all ground truth docs found):")
    print(f" Top-1 Accuracy: {top_1_accuracy:.2f}%")
    print(f" Top-3 Accuracy: {top_3_accuracy:.2f}%")
    print(f" Top-5 Accuracy: {top_5_accuracy:.2f}%")
    print(f" Top-10 Accuracy: {top_10_accuracy:.2f}%")
    print()

    print("Average Found Ratio (recall metrics):")
    print(f" Average found ratio in first 1 context docs: {avg_recall_at_1:.2f}%")
    print(f" Average found ratio in first 3 context docs: {avg_recall_at_3:.2f}%")
    print(f" Average found ratio in first 5 context docs: {avg_recall_at_5:.2f}%")
    print(f" Average found ratio in first 10 context docs: {avg_recall_at_10:.2f}%")
    print(f" Average found ratio in first 25 context docs: {avg_recall_at_25:.2f}%")
    print(f" Average found ratio in first 50 context docs: {avg_recall_at_50:.2f}%")
    print()

    print(f"MRR (Mean Reciprocal Rank): {mrr:.4f}")
    print()

    # Show some examples
    print("=" * 80)
    print("SAMPLE RESULTS (First 2)")
    print("=" * 80)
    print()

    for idx, result in enumerate(results[:2], start=1):
        print(f"{idx}. Question: {result['question']}")
        print(
            f" Ground truth: {result['ground_truth_doc_ids']} ({result['num_ground_truth']} docs)"
        )
        print(
            f" Recall@3: {result['recall_at_3'] * 100:.0f}% | Recall@5: {result['recall_at_5'] * 100:.0f}% |"
            f" Recall@10: {result['recall_at_10'] * 100:.0f}%"
        )
        print(
            f" First correct rank: {result['first_correct_rank'] if result['first_correct_rank'] else 'Not found'}"
        )
        print(f" Deduplicated top 5: {result['deduplicated_doc_ids'][:5]}")
        print()

    # Show multi-document ground truth examples
    print("=" * 80)
    print("MULTI-DOCUMENT GROUND TRUTH SAMPLES")
    print("=" * 80)
    print()

    multi_doc_results = [r for r in results if len(r["ground_truth_doc_ids"]) > 1]
    if multi_doc_results:
        print(
            f"Found {len(multi_doc_results)} questions with multiple ground truth documents\n"
        )

        for idx, result in enumerate(multi_doc_results[:2], start=1):
            print(f"{idx}. Question: {result['question']}")
            print(f" # of ground truth docs: {result['num_ground_truth']}")
            print(f" Ground truth doc IDs: {result['ground_truth_doc_ids']}")
            print(
                f" Recall@1: {result['recall_at_1'] * 100:.0f}% | Recall@3: {result['recall_at_3'] * 100:.0f}% | Recall@5: "
                f"{result['recall_at_5'] * 100:.0f}% | Recall@10: {result['recall_at_10'] * 100:.0f}%"
            )
            print(
                f" Perfect recall in top-3: {'✓' if result['top_3_hit'] else '✗'} | "
                f"top-5: {'✓' if result['top_5_hit'] else '✗'} | top-10: {'✓' if result['top_10_hit'] else '✗'}"
            )
            print(
                f" First correct rank: {result['first_correct_rank'] if result['first_correct_rank'] else 'Not found'}"
            )
            print(f" Deduplicated (top 10): {result['deduplicated_doc_ids'][:10]}")
            print()
    else:
        print("No questions with multiple ground truth documents found.\n")

    # Show failed retrieval examples (no ground truth found in top 10)
    print("=" * 80)
    print("FAILED RETRIEVAL SAMPLES (Ground Truth Not Found in Top 10)")
    print("=" * 80)
    print()

    failed_results = [r for r in results if not r["top_10_hit"]]
    if failed_results:
        print(
            f"Found {len(failed_results)} questions where ground truth was not in top 10\n"
        )

        for idx, result in enumerate(failed_results[:2], start=1):
            print(f"{idx}. Question: {result['question']}")
            print(f" Ground truth: {result['ground_truth_doc_ids']}")
            print(
                f" First correct rank: "
                f"{result['first_correct_rank'] if result['first_correct_rank'] else 'Not found in top 50'}"
            )
            print(
                f" Recall@10: {result['recall_at_10'] * 100:.0f}% | Recall@50: {result['recall_at_50'] * 100:.0f}%"
            )
            print(f" Retrieved (top 10): {result['deduplicated_doc_ids'][:10]}")
            print()
    else:
        print("All questions found at least one ground truth document in top 10! 🎉\n")

    # Save detailed results to file (summary metrics + per-question detail)
    output_path = Path(__file__).parent / "evaluation_results.json"
    with open(output_path, "w") as f:
        json.dump(
            {
                "summary": {
                    "total_questions": len(results),
                    "total_multi_doc_questions": len(multi_doc_results),
                    "perfect_recall_accuracy": {
                        "top_1": top_1_accuracy,
                        "top_3": top_3_accuracy,
                        "top_5": top_5_accuracy,
                        "top_10": top_10_accuracy,
                    },
                    "average_recall": {
                        "recall_at_1": avg_recall_at_1,
                        "recall_at_3": avg_recall_at_3,
                        "recall_at_5": avg_recall_at_5,
                        "recall_at_10": avg_recall_at_10,
                        "recall_at_25": avg_recall_at_25,
                        "recall_at_50": avg_recall_at_50,
                    },
                    "mrr": mrr,
                    "total_time_seconds": total_time,
                    "avg_time_per_query_seconds": total_time / len(results),
                },
                "detailed_results": results,
            },
            f,
            indent=2,
        )

    print(f"Detailed results saved to: {output_path}")
|
||||
|
||||
# Script entry point: run the full retrieval evaluation when executed directly.
if __name__ == "__main__":
    main()
||||
@@ -0,0 +1,46 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TargetDocument(BaseModel):
    """Schema for documents in target_docs.jsonl."""

    document_id: str  # Hash ID
    semantic_identifier: str | None = None
    title: str | None  # Required key in the JSONL, but its value may be null
    content: str  # Full document text to be chunked and embedded
    source_type: str | None = None
    filename: str | None = None  # Human-readable filename
    url: str | None = None
    metadata: dict | None = None
|
||||
|
||||
class TargetQuestionDocSource(BaseModel):
    """Document source reference in question metadata.

    Typically either the file-based fields or the Slack fields are populated
    for a given entry — consumers check `source` first, then `thread_ts`.
    """

    # For file-based sources
    source: str | None = None  # e.g. "company_policies/Succession-planning-policy.docx"
    source_hash: str | None = None

    # For Slack messages
    channel: str | None = None
    message_count: int | None = None
    source_type: str | None = None
    thread_ts: str | None = None  # Thread timestamp, used as the document id
    workspace: str | None = None
||||
|
||||
|
||||
class TargetQuestionMetadata(BaseModel):
    """Metadata for target questions."""

    question_type: str
    doc_source: list[TargetQuestionDocSource]  # Ground-truth document references
||||
|
||||
|
||||
class TargetQuestion(BaseModel):
    """Schema for questions in target_questions.jsonl."""

    uid: str  # Unique question identifier
    question: str  # Natural-language query to evaluate retrieval against
    ground_truth_answers: list[str]
    ground_truth_context: list[str]
    metadata: TargetQuestionMetadata  # Carries the ground-truth doc references
||||
534
backend/scratch/qdrant/accuracy_testing/upload_chunks.py
Normal file
534
backend/scratch/qdrant/accuracy_testing/upload_chunks.py
Normal file
@@ -0,0 +1,534 @@
|
||||
"""
|
||||
Script to upload documents from target_docs.jsonl to Qdrant for accuracy testing.
|
||||
Converts target documents to QdrantChunk format and embeds them using real embedding models.
|
||||
|
||||
Performance optimizations:
|
||||
- Medium batch sizes (50 documents per batch) to balance speed and memory
|
||||
- Controlled threading (2 threads) for stability
|
||||
- Parallel embedding (dense and sparse run concurrently)
|
||||
- Aggressive garbage collection to prevent OOM
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import cohere
|
||||
from dotenv import load_dotenv
|
||||
from fastembed import SparseTextEmbedding
|
||||
from qdrant_client.models import Distance
|
||||
from qdrant_client.models import OptimizersConfigDiff
|
||||
from qdrant_client.models import SparseVector
|
||||
from qdrant_client.models import SparseVectorParams
|
||||
from qdrant_client.models import VectorParams
|
||||
|
||||
from scratch.qdrant.accuracy_testing.target_document_schema import TargetDocument
|
||||
from scratch.qdrant.client import QdrantClient
|
||||
from scratch.qdrant.schemas.chunk import QdrantChunk
|
||||
from scratch.qdrant.schemas.collection_name import CollectionName
|
||||
from scratch.qdrant.schemas.embeddings import ChunkDenseEmbedding
|
||||
from scratch.qdrant.schemas.embeddings import ChunkSparseEmbedding
|
||||
|
||||
|
||||
def load_target_documents(jsonl_path: Path) -> list[TargetDocument]:
    """Parse target documents out of a JSONL file, ignoring blank lines."""
    with open(jsonl_path, "r") as f:
        return [TargetDocument(**json.loads(raw)) for raw in f if raw.strip()]
||||
|
||||
|
||||
def embed_with_cohere(
    texts: list[str],
    cohere_client: cohere.Client,
    model: str = "embed-english-v3.0",
    input_type: str = "search_document",
    batch_size: int = 96,  # Cohere's max batch size is 96
    max_retries: int = 5,
) -> list[list[float]]:
    """
    Embed texts with the Cohere API, batching and retrying transient failures.

    Texts are sent in slices of at most `batch_size` (Cohere processes each
    request internally in parallel), and each slice is retried up to
    `max_retries` times with exponential backoff (1s, 2s, 4s, ...).

    Args:
        texts: Texts to embed.
        cohere_client: Initialized Cohere client.
        model: Cohere model name.
        input_type: "search_document" for documents, "search_query" for queries.
        batch_size: Maximum texts per API request (Cohere caps this at 96).
        max_retries: Attempts per slice before giving up.

    Returns:
        One embedding vector per input text, in input order.
    """

    def _embed_slice(slice_texts: list[str]) -> list[list[float]]:
        """Embed one slice, retrying with exponential backoff on any error."""
        for attempt in range(max_retries):
            try:
                response = cohere_client.embed(
                    texts=slice_texts,
                    model=model,
                    input_type=input_type,
                )
                return response.embeddings
            except Exception as e:
                if attempt == max_retries - 1:
                    # Out of retries; surface the original error.
                    raise

                # 2^attempt seconds of backoff: 1s, 2s, 4s, 8s, 16s
                backoff_time = 2**attempt
                print(
                    f" ⚠️ Error embedding batch (attempt {attempt + 1}/{max_retries}): {str(e)[:100]}"
                )
                print(f" ⏳ Retrying in {backoff_time}s...")
                time.sleep(backoff_time)

        # Unreachable: the loop either returns or re-raises above.
        raise Exception("Max retries exceeded")

    # A single slice goes straight through without the splitting loop.
    if len(texts) <= batch_size:
        return _embed_slice(texts)

    embeddings: list[list[float]] = []
    for start in range(0, len(texts), batch_size):
        embeddings.extend(_embed_slice(texts[start : start + batch_size]))
    return embeddings
||||
|
||||
|
||||
def chunks_to_cohere_embeddings(
    chunks: list[QdrantChunk],
    cohere_client: cohere.Client,
    model: str = "embed-english-v3.0",
) -> list[ChunkDenseEmbedding]:
    """
    Produce dense Cohere embeddings for a batch of chunks.

    Args:
        chunks: Chunks whose `content` will be embedded.
        cohere_client: Initialized Cohere client.
        model: Cohere model name.

    Returns:
        One ChunkDenseEmbedding per chunk, keyed by the chunk's id.
    """
    # Document-side encoder: pairs with "search_query" at query time.
    vectors = embed_with_cohere(
        [chunk.content for chunk in chunks],
        cohere_client,
        model,
        input_type="search_document",
    )
    return [
        ChunkDenseEmbedding(chunk_id=chunk.id, vector=vec)
        for chunk, vec in zip(chunks, vectors)
    ]
||||
|
||||
|
||||
def chunks_to_bm25_embeddings(
    chunks: list[QdrantChunk],
    sparse_embedding_model: SparseTextEmbedding,
) -> list[ChunkSparseEmbedding]:
    """
    Produce sparse (BM25-style) embeddings for a batch of chunks.

    Args:
        chunks: Chunks whose `content` will be embedded.
        sparse_embedding_model: Initialized fastembed sparse model.

    Returns:
        One ChunkSparseEmbedding per chunk, keyed by the chunk's id.
    """
    contents = [chunk.content for chunk in chunks]
    embeddings = []
    for chunk, emb in zip(chunks, sparse_embedding_model.passage_embed(contents)):
        sparse = SparseVector(
            indices=emb.indices.tolist(), values=emb.values.tolist()
        )
        embeddings.append(ChunkSparseEmbedding(chunk_id=chunk.id, vector=sparse))
    return embeddings
||||
|
||||
|
||||
def convert_target_doc_to_chunks(
    target_doc: TargetDocument, max_chunk_length: int = 8000
) -> list[QdrantChunk]:
    """
    Split a TargetDocument into one or more QdrantChunks.

    - The title (when present) is prepended to the content before splitting.
    - Content is cut into fixed-size character windows of `max_chunk_length`
      (default 8000 chars, roughly 2000 tokens).
    - Every chunk receives a fresh UUID, but all share the same document_id
      and filename so results remain traceable to the source document.
    - ACL is left empty (public access) and created_at is the current time.

    Args:
        target_doc: The document to convert.
        max_chunk_length: Maximum characters per chunk.

    Returns:
        A list containing at least one QdrantChunk.
    """
    timestamp = datetime.datetime.now()

    body = target_doc.content
    if target_doc.title is not None:
        body = f"{target_doc.title}\n{body}"

    # Always emit at least one chunk, even for empty content.
    total_chunks = max(1, (len(body) + max_chunk_length - 1) // max_chunk_length)

    return [
        QdrantChunk(
            id=uuid4(),
            document_id=target_doc.document_id,
            filename=target_doc.filename,
            source_type=None,
            access_control_list=None,
            created_at=timestamp,
            # Slicing clamps automatically, so the last window may be short.
            content=body[offset : offset + max_chunk_length],
        )
        for offset in range(0, total_chunks * max_chunk_length, max_chunk_length)
    ]
||||
|
||||
|
||||
def main():
|
||||
# Embedding model configuration
|
||||
cohere_model = "embed-english-v3.0"
|
||||
vector_size = 1024 # embed-english-v3.0 dimension
|
||||
|
||||
# Sparse model options: "Qdrant/bm25" or "prithivida/Splade_PP_en_v1" "Qdrant/bm42-all-minilm-l6-v2-attentions"
|
||||
sparse_model_name = "Qdrant/bm25"
|
||||
# Collection name includes sparse model type for easy comparison
|
||||
if "bm25" in sparse_model_name.lower():
|
||||
collection_name = CollectionName.ACCURACY_TESTING
|
||||
collection_suffix = "_bm25"
|
||||
elif "splade" in sparse_model_name.lower():
|
||||
collection_name = CollectionName.ACCURACY_TESTING
|
||||
collection_suffix = "_splade"
|
||||
else:
|
||||
collection_name = CollectionName.ACCURACY_TESTING
|
||||
collection_suffix = "_custom"
|
||||
|
||||
# Control whether to index while uploading
|
||||
index_while_uploading = False
|
||||
|
||||
# Batch processing configuration
|
||||
# Cohere API can handle up to 96 texts per request and processes them in parallel
|
||||
batch_size = 96 # Match Cohere's max batch size for optimal performance
|
||||
|
||||
# Load environment variables from .env file
|
||||
env_path = Path(__file__).parent.parent.parent.parent.parent / ".vscode" / ".env"
|
||||
if env_path.exists():
|
||||
load_dotenv(env_path)
|
||||
print(f"Loaded environment variables from {env_path}")
|
||||
else:
|
||||
print(f"Warning: .env file not found at {env_path}")
|
||||
|
||||
# Initialize Cohere client
|
||||
print("Initializing Cohere client...")
|
||||
cohere_api_key = os.getenv("COHERE_API_KEY")
|
||||
if not cohere_api_key:
|
||||
raise ValueError("COHERE_API_KEY environment variable not set")
|
||||
|
||||
cohere_client = cohere.Client(cohere_api_key)
|
||||
print(f"Cohere client initialized with model: {cohere_model}")
|
||||
|
||||
# Initialize sparse embedding model
|
||||
print(f"Initializing sparse embedding model: {sparse_model_name}...")
|
||||
sparse_embedding_model = SparseTextEmbedding(
|
||||
model_name=sparse_model_name, threads=2
|
||||
)
|
||||
print(f"Sparse model initialized: {sparse_model_name}{collection_suffix}\n")
|
||||
|
||||
# Initialize Qdrant client
|
||||
qdrant_client = QdrantClient()
|
||||
|
||||
# Delete and recreate collection
|
||||
print(f"Setting up collection: {collection_name} (sparse: {sparse_model_name})")
|
||||
print(f"Index while uploading: {index_while_uploading}")
|
||||
qdrant_client.delete_collection(collection_name=collection_name)
|
||||
|
||||
# Set indexing threshold based on mode
|
||||
optimizer_config = (
|
||||
None if index_while_uploading else OptimizersConfigDiff(indexing_threshold=0)
|
||||
)
|
||||
|
||||
# Create collection with both dense (Cohere) and sparse (BM25) vectors
|
||||
qdrant_client.create_collection(
|
||||
collection_name=collection_name,
|
||||
dense_vectors_config={
|
||||
"dense": VectorParams(size=vector_size, distance=Distance.COSINE),
|
||||
},
|
||||
sparse_vectors_config={
|
||||
"sparse": SparseVectorParams(),
|
||||
},
|
||||
optimizers_config=optimizer_config,
|
||||
shard_number=4,
|
||||
)
|
||||
print(f"Collection {collection_name} created")
|
||||
print(f"Optimizer config: {optimizer_config}\n")
|
||||
|
||||
# Load target documents - stream them to count total first
|
||||
jsonl_path = Path(__file__).parent / "target_docs.jsonl"
|
||||
print(f"Counting documents in {jsonl_path}...")
|
||||
with open(jsonl_path, "r") as f:
|
||||
total_docs = sum(1 for line in f if line.strip())
|
||||
print(f"Found {total_docs:,} documents\n")
|
||||
|
||||
# Process in batches - stream documents and process in chunks
|
||||
num_batches = (total_docs + batch_size - 1) // batch_size
|
||||
print(
|
||||
f"Processing {total_docs:,} chunks in {num_batches} batches of {batch_size:,}..."
|
||||
)
|
||||
print()
|
||||
|
||||
overall_start = time.time()
|
||||
chunks_processed = 0
|
||||
docs_processed = 0
|
||||
batch_num = 0
|
||||
batch_chunks = []
|
||||
|
||||
# Stream documents and process in batches
|
||||
with open(jsonl_path, "r") as f:
|
||||
for line in f:
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
# Parse and convert document to chunk(s)
|
||||
data = json.loads(line)
|
||||
target_doc = TargetDocument(**data)
|
||||
doc_chunks = convert_target_doc_to_chunks(target_doc)
|
||||
batch_chunks.extend(doc_chunks) # Add all chunks from this document
|
||||
docs_processed += 1
|
||||
|
||||
# Process batch when full
|
||||
if len(batch_chunks) >= batch_size:
|
||||
batch_num += 1
|
||||
print(f"=== Batch {batch_num}/{num_batches} ===")
|
||||
|
||||
# Embed with Cohere (dense)
|
||||
embed_start = time.time()
|
||||
dense_embeddings = chunks_to_cohere_embeddings(
|
||||
batch_chunks, cohere_client, cohere_model
|
||||
)
|
||||
dense_time = time.time() - embed_start
|
||||
|
||||
# Embed with BM25 (sparse)
|
||||
sparse_start = time.time()
|
||||
sparse_embeddings = chunks_to_bm25_embeddings(
|
||||
batch_chunks, sparse_embedding_model
|
||||
)
|
||||
sparse_time = time.time() - sparse_start
|
||||
|
||||
# Calculate embedding sizes
|
||||
dense_dim = len(dense_embeddings[0].vector) if dense_embeddings else 0
|
||||
avg_sparse_dims = (
|
||||
sum(len(e.vector.indices) for e in sparse_embeddings)
|
||||
/ len(sparse_embeddings)
|
||||
if sparse_embeddings
|
||||
else 0
|
||||
)
|
||||
|
||||
embed_time = dense_time + sparse_time
|
||||
print(
|
||||
f"1. Embeddings: {embed_time:.2f}s (dense: {dense_time:.2f}s, sparse: {sparse_time:.2f}s)"
|
||||
)
|
||||
print(
|
||||
f" Dense dim: {dense_dim}, Avg sparse dims: {avg_sparse_dims:.0f}"
|
||||
)
|
||||
|
||||
# Step 2: Build points with both dense and sparse embeddings
|
||||
build_start = time.time()
|
||||
points = []
|
||||
for chunk, dense_emb, sparse_emb in zip(
|
||||
batch_chunks, dense_embeddings, sparse_embeddings
|
||||
):
|
||||
from qdrant_client.models import PointStruct
|
||||
|
||||
points.append(
|
||||
PointStruct(
|
||||
id=str(chunk.id),
|
||||
vector={
|
||||
"dense": dense_emb.vector,
|
||||
"sparse": sparse_emb.vector,
|
||||
},
|
||||
payload=chunk.model_dump(exclude={"id"}),
|
||||
)
|
||||
)
|
||||
build_time = time.time() - build_start
|
||||
print(f"2. Build points: {build_time:.2f}s")
|
||||
|
||||
# Step 3: Insert to Qdrant
|
||||
insert_start = time.time()
|
||||
result = qdrant_client.override_points(points, collection_name)
|
||||
insert_time = time.time() - insert_start
|
||||
print(f"3. Insert to Qdrant: {insert_time:.2f}s")
|
||||
|
||||
batch_total = time.time() - embed_start
|
||||
chunks_processed += len(batch_chunks)
|
||||
|
||||
print(f"Batch total: {batch_total:.2f}s")
|
||||
print(f"Status: {result.status}")
|
||||
print(
|
||||
f"Docs: {docs_processed:,} / {total_docs:,} | Chunks: {chunks_processed:,}"
|
||||
)
|
||||
print()
|
||||
|
||||
# Clear batch for next iteration and free memory aggressively
|
||||
del dense_embeddings, sparse_embeddings, points, result
|
||||
batch_chunks = []
|
||||
gc.collect()
|
||||
|
||||
# Small delay to allow memory cleanup
|
||||
time.sleep(0.1)
|
||||
|
||||
# Process remaining chunks if any
|
||||
if batch_chunks:
|
||||
batch_num += 1
|
||||
print(f"=== Batch {batch_num}/{num_batches} (final) ===")
|
||||
|
||||
# Embed with Cohere (dense)
|
||||
embed_start = time.time()
|
||||
dense_embeddings = chunks_to_cohere_embeddings(
|
||||
batch_chunks, cohere_client, cohere_model
|
||||
)
|
||||
dense_time = time.time() - embed_start
|
||||
|
||||
# Embed with BM25 (sparse)
|
||||
sparse_start = time.time()
|
||||
sparse_embeddings = chunks_to_bm25_embeddings(
|
||||
batch_chunks, sparse_embedding_model
|
||||
)
|
||||
sparse_time = time.time() - sparse_start
|
||||
|
||||
# Calculate embedding sizes
|
||||
dense_dim = len(dense_embeddings[0].vector) if dense_embeddings else 0
|
||||
avg_sparse_dims = (
|
||||
sum(len(e.vector.indices) for e in sparse_embeddings)
|
||||
/ len(sparse_embeddings)
|
||||
if sparse_embeddings
|
||||
else 0
|
||||
)
|
||||
|
||||
embed_time = dense_time + sparse_time
|
||||
print(
|
||||
f"1. Embeddings: {embed_time:.2f}s (dense: {dense_time:.2f}s, sparse: {sparse_time:.2f}s)"
|
||||
)
|
||||
print(f" Dense dim: {dense_dim}, Avg sparse dims: {avg_sparse_dims:.0f}")
|
||||
|
||||
# Build points with both dense and sparse embeddings
|
||||
build_start = time.time()
|
||||
points = []
|
||||
for chunk, dense_emb, sparse_emb in zip(
|
||||
batch_chunks, dense_embeddings, sparse_embeddings
|
||||
):
|
||||
from qdrant_client.models import PointStruct
|
||||
|
||||
points.append(
|
||||
PointStruct(
|
||||
id=str(chunk.id),
|
||||
vector={"dense": dense_emb.vector, "sparse": sparse_emb.vector},
|
||||
payload=chunk.model_dump(exclude={"id"}),
|
||||
)
|
||||
)
|
||||
build_time = time.time() - build_start
|
||||
print(f"2. Build points: {build_time:.2f}s")
|
||||
|
||||
insert_start = time.time()
|
||||
result = qdrant_client.override_points(points, collection_name)
|
||||
insert_time = time.time() - insert_start
|
||||
print(f"3. Insert to Qdrant: {insert_time:.2f}s")
|
||||
|
||||
batch_total = time.time() - embed_start
|
||||
chunks_processed += len(batch_chunks)
|
||||
|
||||
print(f"Batch total: {batch_total:.2f}s")
|
||||
print(f"Status: {result.status}")
|
||||
print(
|
||||
f"Docs: {docs_processed:,} / {total_docs:,} | Chunks: {chunks_processed:,}"
|
||||
)
|
||||
print()
|
||||
|
||||
total_elapsed = time.time() - overall_start
|
||||
|
||||
print("=" * 60)
|
||||
print("COMPLETE")
|
||||
print("=" * 60)
|
||||
print(f"Total documents processed: {total_docs:,}")
|
||||
print(f"Total chunks inserted: {chunks_processed:,}")
|
||||
print(f"Average chunks per document: {chunks_processed / total_docs:.1f}")
|
||||
print(f"Total time: {total_elapsed:.2f} seconds ({total_elapsed / 60:.1f} minutes)")
|
||||
print(f"Average rate: {chunks_processed / total_elapsed:.1f} chunks/sec")
|
||||
print()
|
||||
|
||||
print("Collection info:")
|
||||
collection_info = qdrant_client.get_collection(collection_name)
|
||||
print(f" Points count: {collection_info.points_count:,}")
|
||||
print(f" Indexed vectors count: {collection_info.indexed_vectors_count:,}")
|
||||
print(f" Optimizer status: {collection_info.optimizer_status}")
|
||||
print(f" Status: {collection_info.status}")
|
||||
|
||||
# Only need to trigger indexing if we disabled it during upload
|
||||
if not index_while_uploading:
|
||||
print("\nTriggering indexing (was disabled during upload)...")
|
||||
|
||||
qdrant_client.update_collection(
|
||||
collection_name=collection_name,
|
||||
optimizers_config=OptimizersConfigDiff(
|
||||
indexing_threshold=20000,
|
||||
),
|
||||
)
|
||||
print("Collection optimizers config updated - indexing will now proceed")
|
||||
else:
|
||||
print("\nIndexing was enabled during upload - no manual trigger needed")
|
||||
|
||||
fresh_collection_info = qdrant_client.get_collection(collection_name)
|
||||
print(f" Points count: {fresh_collection_info.points_count:,}")
|
||||
print(f" Indexed vectors count: {fresh_collection_info.indexed_vectors_count:,}")
|
||||
print(f" Optimizer status: {fresh_collection_info.optimizer_status}")
|
||||
print(f" Status: {fresh_collection_info.status}")
|
||||
|
||||
|
||||
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||
125
backend/scratch/qdrant/client.py
Normal file
125
backend/scratch/qdrant/client.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from qdrant_client import QdrantClient as ThirdPartyQdrantClient
|
||||
from qdrant_client.models import CollectionInfo
|
||||
from qdrant_client.models import Filter
|
||||
from qdrant_client.models import FusionQuery
|
||||
from qdrant_client.models import OptimizersConfigDiff
|
||||
from qdrant_client.models import PointStruct
|
||||
from qdrant_client.models import Prefetch
|
||||
from qdrant_client.models import QueryResponse
|
||||
from qdrant_client.models import SparseVectorParams
|
||||
from qdrant_client.models import UpdateResult
|
||||
from qdrant_client.models import VectorParams
|
||||
|
||||
from scratch.qdrant.config import QdrantConfig
|
||||
from scratch.qdrant.schemas.collection_name import CollectionName
|
||||
from scratch.qdrant.schemas.collection_operations import CreateCollectionResult
|
||||
from scratch.qdrant.schemas.collection_operations import DeleteCollectionResult
|
||||
from scratch.qdrant.schemas.collection_operations import UpdateCollectionResult
|
||||
|
||||
|
||||
class QdrantClient:
    """Thin wrapper around the third-party Qdrant client.

    Centralizes connection setup and exposes only the operations the scratch
    experiments need, returning small typed result objects instead of raw
    booleans.
    """

    def __init__(self) -> None:
        # Long timeout: bulk upserts of large hybrid batches can take minutes.
        self.client = ThirdPartyQdrantClient(
            url=QdrantConfig.url,
            timeout=300,
        )

    def create_collection(
        self,
        collection_name: CollectionName,
        dense_vectors_config: VectorParams | dict[str, VectorParams] | None,
        sparse_vectors_config: dict[str, SparseVectorParams] | None,
        optimizers_config: OptimizersConfigDiff | None = None,
        shard_number: int | None = None,
    ) -> CreateCollectionResult:
        """Create a collection with the given dense/sparse vector configs."""
        is_successful = self.client.create_collection(
            collection_name=collection_name,
            vectors_config=dense_vectors_config,
            sparse_vectors_config=sparse_vectors_config,
            optimizers_config=optimizers_config,
            shard_number=shard_number,
        )

        return CreateCollectionResult(success=is_successful)

    def update_collection(
        self,
        collection_name: CollectionName,
        optimizers_config: OptimizersConfigDiff | None = None,
    ) -> UpdateCollectionResult:
        """Update a collection's optimizer configuration in place."""
        is_successful = self.client.update_collection(
            collection_name=collection_name,
            optimizers_config=optimizers_config,
        )

        return UpdateCollectionResult(success=is_successful)

    def delete_collection(
        self, collection_name: CollectionName
    ) -> DeleteCollectionResult:
        """Delete the named collection."""
        is_successful = self.client.delete_collection(collection_name=collection_name)

        return DeleteCollectionResult(success=is_successful)

    def get_collection(self, collection_name: CollectionName) -> CollectionInfo:
        """Fetch collection metadata (status, counts, config)."""
        result = self.client.get_collection(collection_name=collection_name)
        return result

    def override_points(
        self, points: list[PointStruct], collection_name: CollectionName
    ) -> UpdateResult:
        """Upsert points, logging the approximate request payload size.

        The size estimate warns when a batch approaches the request size
        limit so callers can shrink their batches before the server rejects
        them.
        """
        import json
        import sys

        # Calculate approximate request size.
        try:
            # Convert points to dicts for size calculation. default=str is
            # required because vectors may hold non-JSON-serializable objects
            # (e.g. SparseVector in hybrid points); without it json.dumps
            # raises and we would always hit the much less accurate
            # sys.getsizeof fallback below.
            points_dicts = [
                {"id": p.id, "vector": p.vector, "payload": p.payload} for p in points
            ]
            json_str = json.dumps(points_dicts, default=str)
            request_size_bytes = len(json_str.encode("utf-8"))
        except Exception:
            # Last-resort fallback: shallow size only, underestimates badly.
            request_size_bytes = sys.getsizeof(points)

        request_size_mb = request_size_bytes / (1024 * 1024)
        max_size_mb = 32

        print(
            f" Request size: {request_size_mb:.2f} MB / {max_size_mb} MB ({request_size_mb / max_size_mb * 100:.1f}%)"
        )

        if request_size_mb > max_size_mb * 0.9:
            print(f" WARNING: Request size is close to {max_size_mb}MB limit!")

        result = self.client.upsert(points=points, collection_name=collection_name)
        return result

    def get_embedding_size(self, model_name: str) -> int:
        """Get the embedding size for a dense model."""
        return self.client.get_embedding_size(model_name)

    def query_points(
        self,
        collection_name: CollectionName,
        query: list[float] | None = None,
        prefetch: list[Prefetch] | None = None,
        query_filter: Filter | None = None,
        fusion_query: FusionQuery | None = None,
        using: str | None = None,
        with_payload: bool = True,
        with_vectors: bool = False,
        limit: int = 10,
    ) -> QueryResponse:
        """Query points from a collection with optional prefetch and fusion.

        When ``fusion_query`` is given it takes precedence over ``query``
        (used for hybrid RRF-style searches over prefetched candidates).
        """
        return self.client.query_points(
            collection_name=collection_name,
            query=query if fusion_query is None else fusion_query,
            prefetch=prefetch,
            query_filter=query_filter,
            using=using,
            with_payload=with_payload,
            with_vectors=with_vectors,
            limit=limit,
        )
|
||||
7
backend/scratch/qdrant/config.py
Normal file
7
backend/scratch/qdrant/config.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from typing import ClassVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class QdrantConfig(BaseModel):
    """Static connection settings for the local Qdrant server.

    ``url`` is a ``ClassVar``, so pydantic treats it as a plain class-level
    constant rather than a model field.
    """

    url: ClassVar[str] = "http://localhost:6333"
|
||||
121
backend/scratch/qdrant/create_and_search_script.py
Normal file
121
backend/scratch/qdrant/create_and_search_script.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from fastembed import SparseTextEmbedding
|
||||
from fastembed import TextEmbedding
|
||||
from qdrant_client.models import Distance
|
||||
from qdrant_client.models import SparseVectorParams
|
||||
from qdrant_client.models import VectorParams
|
||||
|
||||
from .client import QdrantClient
|
||||
from .performance_testing.fake_chunk_helpers import fake_acl
|
||||
from .performance_testing.fake_chunk_helpers import fake_source_type
|
||||
from .performance_testing.fake_chunk_helpers import generate_fake_qdrant_chunks
|
||||
from .schemas.chunk import QdrantChunk
|
||||
from .schemas.collection_name import CollectionName
|
||||
from .service import QdrantService
|
||||
|
||||
|
||||
def main() -> None:
    """End-to-end smoke test for the Qdrant scratch stack.

    Recreates the test collection, indexes 100 random fake chunks plus one
    known chunk, then runs a hybrid (dense + sparse) search and prints the
    top results.
    """
    collection_name = CollectionName.TEST_COLLECTION
    dense_model_name = "nomic-ai/nomic-embed-text-v1"
    sparse_model_name = "prithivida/Splade_PP_en_v1"

    # Initialize client and service (model construction loads local weights)
    dense_embedding_model = TextEmbedding(model_name=dense_model_name)
    sparse_embedding_model = SparseTextEmbedding(model_name=sparse_model_name)
    service = QdrantService(client=QdrantClient())

    # Get embedding size for dense vectors
    dense_embedding_size = service.client.get_embedding_size(dense_model_name)

    # Delete and recreate collection so each run starts from a clean slate
    service.client.delete_collection(collection_name=collection_name)
    service.client.create_collection(
        collection_name=collection_name,
        dense_vectors_config={
            "dense": VectorParams(size=dense_embedding_size, distance=Distance.COSINE),
        },
        sparse_vectors_config={
            "sparse": SparseVectorParams(),
        },
    )
    print(f"Collection {collection_name} created")

    # Generate fake chunks
    num_fake_chunks = 100
    print(f"Generating {num_fake_chunks} fake chunks...")
    fake_chunks = list(generate_fake_qdrant_chunks(num_fake_chunks))
    print(f"Generated {len(fake_chunks)} chunks")

    # Write to Qdrant using service
    print("Writing chunks to Qdrant...")
    import time

    start_time = time.time()
    result = service.embed_and_upsert_chunks(
        fake_chunks, dense_embedding_model, sparse_embedding_model, collection_name
    )
    elapsed_time = time.time() - start_time
    print(f"Upsert result: {result}")
    print(f"Time taken: {elapsed_time:.2f} seconds")

    # Manually insert a specific chunk so the search below has a known,
    # semantically relevant target among the random lorem-ipsum chunks.
    manual_doc_id = uuid4()
    manual_content = "China is a very large nation"
    manual_acl = fake_acl()
    manual_source_type = fake_source_type()
    chunk_id = uuid4()
    print("\nManually inserting a chunk with the following details:")
    print(f"Chunk ID: {chunk_id}")
    print(f"Document ID: {manual_doc_id}")
    print(f"Content: {manual_content}")
    print(f"ACL: {manual_acl}")
    print(f"Source Type: {manual_source_type}")

    manual_result = service.embed_and_upsert_chunks(
        [
            QdrantChunk(
                id=chunk_id,
                document_id=manual_doc_id,
                source_type=manual_source_type,
                access_control_list=manual_acl,
                content=manual_content,
            )
        ],
        dense_embedding_model,
        sparse_embedding_model,
        collection_name,
    )
    print(f"Upsert result: {manual_result}")

    # Test hybrid search using service
    query = "What is the biggest nation?"
    print(f"\nTesting hybrid search (RRF fusion)... with query: '{query}'")

    # Generate query embeddings once, outside the search call
    dense_query_vector, sparse_query_vector = service.generate_query_embeddings(
        query, dense_embedding_model, sparse_embedding_model
    )

    # Perform search with pre-computed embeddings
    search_result = service.hybrid_search(
        dense_query_vector=dense_query_vector,
        sparse_query_vector=sparse_query_vector,
        collection_name=collection_name,
        limit=3,
    )

    print(f"\nSearch Results for '{query}':")
    print(f"Found {len(search_result.points)} results\n")
    for idx, point in enumerate(search_result.points, 1):
        print(f"{idx}. Score: {point.score:.4f}")
        print(f" ID: {point.id}")
        print(f" Document ID: {point.payload.get('document_id')}")
        print(f" Source: {point.payload.get('source_type')}")
        print(f" ACL: {point.payload.get('access_control_list')}")
        print(f" Content: {point.payload.get('content')}")
        print()


if __name__ == "__main__":
    main()
|
||||
106
backend/scratch/qdrant/get_collection_status.py
Normal file
106
backend/scratch/qdrant/get_collection_status.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Script to get and display the status of a Qdrant collection.
|
||||
|
||||
Usage:
|
||||
python -m scratch.qdrant.get_collection_status [collection_name]
|
||||
|
||||
If no collection name is provided, shows ACCURACY_TESTING by default.
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
from scratch.qdrant.client import QdrantClient
|
||||
from scratch.qdrant.schemas.collection_name import CollectionName
|
||||
|
||||
|
||||
def main() -> None:
    """Print the status and configuration of a Qdrant collection.

    Takes an optional collection name on the command line; defaults to
    ACCURACY_TESTING. On lookup failure, lists the collections that exist.
    """
    # Get collection name from command line or use default
    if len(sys.argv) > 1:
        collection_name_str = sys.argv[1]
        # Try to match to CollectionName enum
        try:
            collection_name = CollectionName[collection_name_str.upper()]
        except KeyError:
            # If not in enum, use the string directly
            # NOTE(review): this leaves collection_name as a plain str rather
            # than a CollectionName; the client appears to accept either, but
            # the annotation on get_collection says CollectionName — confirm.
            collection_name = collection_name_str
    else:
        collection_name = CollectionName.ACCURACY_TESTING

    # Initialize client
    client = QdrantClient()

    print("=" * 80)
    print(f"COLLECTION STATUS: {collection_name}")
    print("=" * 80)
    print()

    try:
        collection_info = client.get_collection(collection_name)

        print("General Info:")
        print(f" Status: {collection_info.status}")
        print(f" Points count: {collection_info.points_count:,}")
        # NOTE(review): the ",," format specifiers assume these counts are
        # never None — verify against the client library's CollectionInfo.
        print(f" Indexed vectors count: {collection_info.indexed_vectors_count:,}")
        print(f" Optimizer status: {collection_info.optimizer_status}")
        print()

        print("Configuration:")
        print(f" Vectors config: {collection_info.config.params.vectors}")
        print(
            f" Sparse vectors config: {collection_info.config.params.sparse_vectors}"
        )
        print(f" Shard number: {collection_info.config.params.shard_number}")
        print(
            f" Replication factor: {collection_info.config.params.replication_factor}"
        )
        print()

        # Optimizer settings are optional, so only printed when present
        if collection_info.config.optimizer_config:
            print("Optimizer Config:")
            print(
                f" Deleted threshold: {collection_info.config.optimizer_config.deleted_threshold}"
            )
            print(
                f" Vacuum min vector number: {collection_info.config.optimizer_config.vacuum_min_vector_number}"
            )
            print(
                f" Default segment number: {collection_info.config.optimizer_config.default_segment_number}"
            )
            print(
                f" Max segment size: {collection_info.config.optimizer_config.max_segment_size}"
            )
            print(
                f" Memmap threshold: {collection_info.config.optimizer_config.memmap_threshold}"
            )
            print(
                f" Indexing threshold: {collection_info.config.optimizer_config.indexing_threshold}"
            )
            print(
                f" Flush interval sec: {collection_info.config.optimizer_config.flush_interval_sec}"
            )
            print(
                f" Max optimization threads: {collection_info.config.optimizer_config.max_optimization_threads}"
            )
            print()

        # Payload index schema, if any fields have been indexed
        if collection_info.payload_schema:
            print("Payload Schema:")
            for field, schema in collection_info.payload_schema.items():
                print(f" {field}: {schema}")
            print()

    except Exception as e:
        # Collection likely does not exist; help the user by listing what does
        print(f"Error retrieving collection: {e}")
        print()
        print("Available collections:")
        # Try to list collections
        try:
            collections = client.client.get_collections()
            for col in collections.collections:
                print(f" - {col.name}")
        except Exception as list_error:
            print(f" Could not list collections: {list_error}")


if __name__ == "__main__":
    main()
|
||||
43
backend/scratch/qdrant/list_supported_models.py
Normal file
43
backend/scratch/qdrant/list_supported_models.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Script to list all supported embedding models from fastembed.
|
||||
|
||||
Shows both dense (TextEmbedding) and sparse (SparseTextEmbedding) models.
|
||||
"""
|
||||
|
||||
from fastembed import SparseTextEmbedding
|
||||
from fastembed import TextEmbedding
|
||||
|
||||
|
||||
def main():
    """Print every embedding model fastembed supports: dense, then sparse."""
    banner = "=" * 80

    print(banner)
    print("DENSE EMBEDDING MODELS (TextEmbedding)")
    print(banner)
    print()

    dense_models = TextEmbedding.list_supported_models()
    print(f"Total models: {len(dense_models)}\n")

    position = 1
    for entry in dense_models:
        print(f"{position}. {entry['model']}")
        print(f" Dimension: {entry['dim']}")
        print(f" Description: {entry.get('description', 'N/A')}")
        print(f" Size: {entry.get('size_in_GB', 'N/A')} GB")
        print()
        position += 1

    print(banner)
    print("SPARSE EMBEDDING MODELS (SparseTextEmbedding)")
    print(banner)
    print()

    sparse_models = SparseTextEmbedding.list_supported_models()
    print(f"Total models: {len(sparse_models)}\n")

    position = 1
    for entry in sparse_models:
        print(f"{position}. {entry['model']}")
        print(f" Description: {entry.get('description', 'N/A')}")
        print(f" Size: {entry.get('size_in_GB', 'N/A')} GB")
        print()
        position += 1


if __name__ == "__main__":
    main()
|
||||
1
backend/scratch/qdrant/performance_testing/__init__.py
Normal file
1
backend/scratch/qdrant/performance_testing/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Performance testing package for Qdrant experiments
|
||||
147
backend/scratch/qdrant/performance_testing/fake_chunk_helpers.py
Normal file
147
backend/scratch/qdrant/performance_testing/fake_chunk_helpers.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import datetime
|
||||
import random
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from qdrant_client.models import SparseVector
|
||||
|
||||
from ..schemas.chunk import QdrantChunk
|
||||
from ..schemas.embeddings import ChunkDenseEmbedding
|
||||
from ..schemas.embeddings import ChunkSparseEmbedding
|
||||
from ..schemas.source_type import SourceType
|
||||
|
||||
|
||||
# Pre-generated pool of 100 emails for consistent ACL testing
_EMAIL_POOL = [f"user_{i:03d}@example.com" for i in range(100)]


def get_email_pool() -> list[str]:
    """Return the pool of 100 pre-generated test emails.

    A fresh copy is returned so callers cannot mutate the shared module-level
    pool; the pool must stay stable for fake_email()/fake_acl() to keep
    drawing overlapping, reproducible values.
    """
    return list(_EMAIL_POOL)
|
||||
|
||||
|
||||
def fake_email() -> str:
    """Draw one email uniformly at random from the shared pool."""
    pool = _EMAIL_POOL
    return pool[random.randrange(len(pool))]
|
||||
|
||||
|
||||
def fake_content() -> str:
    """Build a short run of 10-30 random lorem-ipsum words."""
    vocabulary = (
        "lorem ipsum dolor sit amet consectetur adipiscing elit "
        "sed do eiusmod tempor"
    ).split()
    return " ".join(random.choices(vocabulary, k=random.randint(10, 30)))
|
||||
|
||||
|
||||
def fake_source_type() -> SourceType:
    """Draw a uniformly random SourceType member."""
    members = list(SourceType)
    return random.choice(members)
|
||||
|
||||
|
||||
def fake_created_at() -> datetime.datetime:
    """Produce a random creation timestamp from within the last year.

    Returns:
        A datetime between 365 days ago and now.
    """
    # Random offsets are drawn days-first so seeded runs are reproducible.
    offset = datetime.timedelta(
        days=random.randint(0, 365),
        hours=random.randint(0, 23),
        minutes=random.randint(0, 59),
    )
    return datetime.datetime.now() - offset
|
||||
|
||||
|
||||
def fake_acl() -> list[str]:
    """Build an ACL of 1-3 distinct random emails from the shared pool.

    Drawing from one fixed pool guarantees ACL overlap between chunks, which
    keeps filtering tests realistic.
    """
    return random.sample(_EMAIL_POOL, random.randint(1, 3))
|
||||
|
||||
|
||||
def generate_fake_qdrant_chunks(n: int, content: str | None = None):
    """Yield n fake QdrantChunk objects, one at a time.

    Args:
        n: Number of chunks to generate.
        content: Fixed content for every chunk; random content when None.
    """
    use_fixed_content = content is not None
    for _ in range(n):
        yield QdrantChunk(
            id=uuid4(),
            document_id=uuid4(),
            source_type=fake_source_type(),
            access_control_list=fake_acl(),
            created_at=fake_created_at(),
            content=content if use_fixed_content else fake_content(),
        )
|
||||
|
||||
|
||||
def fake_dense_embedding(chunk_id: UUID, vector_size: int = 768) -> ChunkDenseEmbedding:
    """Build a dense embedding filled with uniform random values in [-1, 1].

    Args:
        chunk_id: The chunk UUID.
        vector_size: Dense vector dimension (default 768, nomic-embed size).
    """
    values = [random.uniform(-1, 1) for _ in range(vector_size)]
    return ChunkDenseEmbedding(chunk_id=chunk_id, vector=values)
|
||||
|
||||
|
||||
def fake_sparse_embedding(chunk_id: UUID, num_dims: int = 100) -> ChunkSparseEmbedding:
    """Build a sparse embedding with random indices and weights.

    Args:
        chunk_id: The chunk UUID.
        num_dims: Number of non-zero dimensions in the sparse vector.
    """
    # Sorted, duplicate-free indices out of a 30k vocabulary-sized space.
    chosen_indices = sorted(random.sample(range(30000), num_dims))
    # Weights in [0, 2], a typical range for sparse embedding values.
    weights = [random.uniform(0, 2) for _ in range(num_dims)]

    return ChunkSparseEmbedding(
        chunk_id=chunk_id,
        vector=SparseVector(indices=chosen_indices, values=weights),
    )
|
||||
|
||||
|
||||
def generate_fake_embeddings_for_chunks(
    chunks: list[QdrantChunk],
    vector_size: int = 768,
    sparse_dims: int = 100,
) -> tuple[list[ChunkDenseEmbedding], list[ChunkSparseEmbedding]]:
    """Generate fake dense and sparse embeddings for a batch of chunks.

    Far faster than running real embedding models, which makes it suitable
    for load testing.

    Args:
        chunks: Chunks to generate embeddings for.
        vector_size: Dimension of the dense vectors.
        sparse_dims: Non-zero dimensions per sparse vector.

    Returns:
        Tuple of (dense_embeddings, sparse_embeddings), each parallel to
        ``chunks``.
    """
    dense = [fake_dense_embedding(chunk.id, vector_size) for chunk in chunks]
    sparse = [fake_sparse_embedding(chunk.id, sparse_dims) for chunk in chunks]
    return dense, sparse
|
||||
@@ -0,0 +1,296 @@
|
||||
"""
|
||||
Benchmark script for testing ACL filtering performance in Qdrant.
|
||||
Inserts many chunks with the same ACL, then measures latency of filtering by that ACL.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from fastembed import SparseTextEmbedding
|
||||
from fastembed import TextEmbedding
|
||||
from qdrant_client.models import DatetimeRange
|
||||
from qdrant_client.models import FieldCondition
|
||||
from qdrant_client.models import Filter
|
||||
from qdrant_client.models import HasIdCondition
|
||||
from qdrant_client.models import MatchValue
|
||||
|
||||
from ..client import QdrantClient
|
||||
from ..schemas.collection_name import CollectionName
|
||||
from ..schemas.source_type import SourceType
|
||||
from ..service import QdrantService
|
||||
from .fake_chunk_helpers import get_email_pool
|
||||
|
||||
|
||||
def main():
|
||||
collection_name = CollectionName.TEST_COLLECTION
|
||||
dense_model_name = "nomic-ai/nomic-embed-text-v1"
|
||||
sparse_model_name = "Qdrant/bm25"
|
||||
# sparse_model_name = "prithivida/Splade_PP_en_v1"
|
||||
|
||||
# Initialize client and service
|
||||
dense_embedding_model = TextEmbedding(model_name=dense_model_name)
|
||||
sparse_embedding_model = SparseTextEmbedding(model_name=sparse_model_name)
|
||||
service = QdrantService(client=QdrantClient())
|
||||
|
||||
collection_info = service.client.get_collection(collection_name=collection_name)
|
||||
print(f"\nCollection {collection_name} info:")
|
||||
print(f" Status: {collection_info.status}")
|
||||
print(f" Points count: {collection_info.points_count:,}")
|
||||
print(f" Indexed vectors count: {collection_info.indexed_vectors_count:,}")
|
||||
print(f" Optimizer status: {collection_info.optimizer_status}")
|
||||
print(f" Payload schema: {collection_info.payload_schema}")
|
||||
print()
|
||||
|
||||
source_type_filter = Filter(
|
||||
should=[
|
||||
FieldCondition(
|
||||
key="source_type", match=MatchValue(value=SourceType.GITHUB)
|
||||
),
|
||||
FieldCondition(key="source_type", match=MatchValue(value=SourceType.ASANA)),
|
||||
FieldCondition(key="source_type", match=MatchValue(value=SourceType.BOX)),
|
||||
]
|
||||
)
|
||||
|
||||
query_text = (
|
||||
"this is the same content for all chunks to test filtering performance."
|
||||
)
|
||||
|
||||
# Pre-compute query embeddings (exclude from search timing)
|
||||
print("\nGenerating query embeddings...")
|
||||
emb_start = time.time()
|
||||
dense_query_vector, sparse_query_vector = service.generate_query_embeddings(
|
||||
query_text, dense_embedding_model, sparse_embedding_model
|
||||
)
|
||||
emb_time = time.time() - emb_start
|
||||
print(f"Embedding time: {emb_time * 1000:.2f}ms")
|
||||
|
||||
# Test different limits
|
||||
test_limits = [1, 10, 100, 1000, 10000]
|
||||
|
||||
# Baseline: No filters
|
||||
print("\n" + "=" * 60)
|
||||
print("BENCHMARK: No Filters (Baseline)")
|
||||
print("=" * 60)
|
||||
print("Filter: None")
|
||||
print(f"{'Limit':<10} {'Results':<10} {'Latency (ms)':<15}")
|
||||
print("-" * 60)
|
||||
|
||||
for test_limit in test_limits:
|
||||
start = time.time()
|
||||
|
||||
search_result = service.hybrid_search(
|
||||
dense_query_vector=dense_query_vector,
|
||||
sparse_query_vector=sparse_query_vector,
|
||||
collection_name=collection_name,
|
||||
limit=test_limit,
|
||||
query_filter=None, # No filter
|
||||
)
|
||||
|
||||
latency = time.time() - start
|
||||
|
||||
print(
|
||||
f"{test_limit:<10} {len(search_result.points):<10} {latency * 1000:<15.2f}"
|
||||
)
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("BENCHMARK: Source Type Filter (One-to-Many)")
|
||||
print("=" * 60)
|
||||
print("Filter: source_type in (github, asana, box)")
|
||||
print(f"{'Limit':<10} {'Results':<10} {'Latency (ms)':<15}")
|
||||
print("-" * 60)
|
||||
|
||||
for test_limit in test_limits:
|
||||
# Measure ONLY search latency (no embedding overhead)
|
||||
start = time.time()
|
||||
|
||||
search_result = service.hybrid_search(
|
||||
dense_query_vector=dense_query_vector,
|
||||
sparse_query_vector=sparse_query_vector,
|
||||
collection_name=collection_name,
|
||||
limit=test_limit,
|
||||
query_filter=source_type_filter,
|
||||
)
|
||||
|
||||
latency = time.time() - start
|
||||
|
||||
print(
|
||||
f"{test_limit:<10} {len(search_result.points):<10} {latency * 1000:<15.2f}"
|
||||
)
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
# Test ACL filtering (many-to-many relationship)
|
||||
print("\n" + "=" * 60)
|
||||
print("BENCHMARK: ACL Filtering (Many-to-Many)")
|
||||
print("=" * 60)
|
||||
|
||||
# Pick a known email from the pool
|
||||
test_email = get_email_pool()[0] # "user_000@example.com"
|
||||
|
||||
# Create ACL filter
|
||||
acl_filter = Filter(
|
||||
must=[
|
||||
FieldCondition(
|
||||
key="access_control_list", match=MatchValue(value=test_email)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
print(f"Filter: access_control_list contains '{test_email}'")
|
||||
print(f"{'Limit':<10} {'Results':<10} {'Latency (ms)':<15}")
|
||||
print("-" * 60)
|
||||
|
||||
for test_limit in test_limits:
|
||||
start = time.time()
|
||||
|
||||
search_result = service.hybrid_search(
|
||||
dense_query_vector=dense_query_vector,
|
||||
sparse_query_vector=sparse_query_vector,
|
||||
collection_name=collection_name,
|
||||
limit=test_limit,
|
||||
query_filter=acl_filter,
|
||||
)
|
||||
|
||||
latency = time.time() - start
|
||||
|
||||
print(
|
||||
f"{test_limit:<10} {len(search_result.points):<10} {latency * 1000:<15.2f}"
|
||||
)
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
# Test composite filtering (source_type AND created_at)
|
||||
print("\n" + "=" * 60)
|
||||
print("BENCHMARK: Composite Filter (Source + Time Range)")
|
||||
print("=" * 60)
|
||||
|
||||
import datetime
|
||||
|
||||
# Filter for source_type=ASANA AND created_at within last 30 days
|
||||
thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=30)
|
||||
|
||||
composite_filter = Filter(
|
||||
must=[
|
||||
FieldCondition(key="source_type", match=MatchValue(value=SourceType.ASANA)),
|
||||
FieldCondition(
|
||||
key="created_at", range=DatetimeRange(gte=thirty_days_ago.isoformat())
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
print(f"Filter: source_type = ASANA AND created_at >= {thirty_days_ago.date()}")
|
||||
print(f"{'Limit':<10} {'Results':<10} {'Latency (ms)':<15}")
|
||||
print("-" * 60)
|
||||
|
||||
for test_limit in test_limits:
|
||||
start = time.time()
|
||||
|
||||
search_result = service.hybrid_search(
|
||||
dense_query_vector=dense_query_vector,
|
||||
sparse_query_vector=sparse_query_vector,
|
||||
collection_name=collection_name,
|
||||
limit=test_limit,
|
||||
query_filter=composite_filter,
|
||||
)
|
||||
|
||||
latency = time.time() - start
|
||||
|
||||
print(
|
||||
f"{test_limit:<10} {len(search_result.points):<10} {latency * 1000:<15.2f}"
|
||||
)
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
# Test filtering by 200 chunk IDs
|
||||
print("\n" + "=" * 60)
|
||||
print("BENCHMARK: Filter by 200 Chunk IDs")
|
||||
print("=" * 60)
|
||||
|
||||
# First, get 200 chunk IDs from the collection
|
||||
print("Fetching 200 chunk IDs from collection...")
|
||||
sample_search = service.hybrid_search(
|
||||
dense_query_vector=dense_query_vector,
|
||||
sparse_query_vector=sparse_query_vector,
|
||||
collection_name=collection_name,
|
||||
limit=200,
|
||||
query_filter=None,
|
||||
)
|
||||
|
||||
chunk_ids = [point.id for point in sample_search.points]
|
||||
print(f"Retrieved {len(chunk_ids)} chunk IDs")
|
||||
|
||||
# Create filter with all 200 IDs
|
||||
id_filter = Filter(must=[HasIdCondition(has_id=chunk_ids)])
|
||||
|
||||
print(f"Filter: chunk_id IN ({len(chunk_ids)} IDs)")
|
||||
print(f"{'Limit':<10} {'Results':<10} {'Latency (ms)':<15}")
|
||||
print("-" * 60)
|
||||
|
||||
for test_limit in test_limits:
|
||||
start = time.time()
|
||||
|
||||
search_result = service.hybrid_search(
|
||||
dense_query_vector=dense_query_vector,
|
||||
sparse_query_vector=sparse_query_vector,
|
||||
collection_name=collection_name,
|
||||
limit=test_limit,
|
||||
query_filter=id_filter,
|
||||
)
|
||||
|
||||
latency = time.time() - start
|
||||
|
||||
print(
|
||||
f"{test_limit:<10} {len(search_result.points):<10} {latency * 1000:<15.2f}"
|
||||
)
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
# Test concurrent queries
|
||||
print("\n" + "=" * 60)
|
||||
print("BENCHMARK: Concurrent Queries (50 parallel)")
|
||||
print("=" * 60)
|
||||
|
||||
import concurrent.futures
|
||||
|
||||
# Use the source_type filter from the first test
|
||||
num_concurrent = 50
|
||||
query_limit = 100
|
||||
|
||||
print(f"Running {num_concurrent} concurrent queries with limit={query_limit}")
|
||||
print("Filter: source_type in (github, asana, box)")
|
||||
|
||||
def run_single_search():
|
||||
"""Helper function to run a single search."""
|
||||
return service.hybrid_search(
|
||||
dense_query_vector=dense_query_vector,
|
||||
sparse_query_vector=sparse_query_vector,
|
||||
collection_name=collection_name,
|
||||
limit=query_limit,
|
||||
query_filter=source_type_filter,
|
||||
)
|
||||
|
||||
# Execute concurrent searches
|
||||
start = time.time()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_concurrent) as executor:
|
||||
futures = [executor.submit(run_single_search) for _ in range(num_concurrent)]
|
||||
results = [
|
||||
future.result() for future in concurrent.futures.as_completed(futures)
|
||||
]
|
||||
|
||||
total_latency = time.time() - start
|
||||
|
||||
print("\nResults:")
|
||||
print(f" Total queries: {num_concurrent}")
|
||||
print(f" Total time: {total_latency * 1000:.2f}ms ({total_latency:.2f}s)")
|
||||
print(
|
||||
f" Average latency per query: {(total_latency / num_concurrent) * 1000:.2f}ms"
|
||||
)
|
||||
print(f" Queries per second: {num_concurrent / total_latency:.2f}")
|
||||
print(f" All queries returned results: {all(len(r.points) > 0 for r in results)}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# Script entry point: run the full benchmark suite defined in main().
if __name__ == "__main__":
    main()
|
||||
148
backend/scratch/qdrant/performance_testing/populate_chunks.py
Normal file
148
backend/scratch/qdrant/performance_testing/populate_chunks.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Script to generate and insert 1 million chunks into Qdrant for load testing.
|
||||
Uses fake embeddings for speed.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from qdrant_client.models import Distance
|
||||
from qdrant_client.models import OptimizersConfigDiff
|
||||
from qdrant_client.models import SparseVectorParams
|
||||
from qdrant_client.models import VectorParams
|
||||
|
||||
from ..client import QdrantClient
|
||||
from ..schemas.collection_name import CollectionName
|
||||
from ..service import QdrantService
|
||||
from .fake_chunk_helpers import generate_fake_embeddings_for_chunks
|
||||
from .fake_chunk_helpers import generate_fake_qdrant_chunks
|
||||
|
||||
|
||||
def main():
    """Generate and insert ~1M fake chunks into Qdrant for load testing.

    Uses fake embeddings for speed. Optionally disables indexing during the
    bulk upload (much faster) and re-enables it afterwards.
    """
    collection_name = CollectionName.TEST_COLLECTION
    vector_size = 768  # nomic-embed-text-v1 dimension
    sparse_dims = 100  # typical sparse vector size

    # Control whether to index while uploading
    index_while_uploading = False

    # Initialize client and service
    service = QdrantService(client=QdrantClient())

    # Use the vector size directly (no need to query embedding model)
    dense_embedding_size = vector_size

    # Delete and recreate collection
    print(f"Setting up collection: {collection_name}")
    print(f"Index while uploading: {index_while_uploading}")
    service.client.delete_collection(collection_name=collection_name)

    # indexing_threshold=0 disables indexing during upload; it is re-enabled
    # after the bulk insert completes (see bottom of this function).
    optimizer_config = (
        None if index_while_uploading else OptimizersConfigDiff(indexing_threshold=0)
    )

    service.client.create_collection(
        collection_name=collection_name,
        dense_vectors_config={
            "dense": VectorParams(size=dense_embedding_size, distance=Distance.COSINE),
        },
        sparse_vectors_config={
            "sparse": SparseVectorParams(),
        },
        optimizers_config=optimizer_config,
        shard_number=4,
    )
    print(f"Collection {collection_name} created")
    print(f"Optimizer config: {optimizer_config}\n")

    # Generate and insert chunks in batches.
    # BUGFIX: the batch count is now rounded UP and the final batch is sized
    # to the remainder, so exactly total_chunks are inserted. The original
    # floor division (total_chunks // batch_size) silently dropped
    # total_chunks % batch_size chunks (999,306 inserted vs 1,000,000 claimed).
    total_chunks = 1_000_000
    batch_size = 1_854
    num_batches = (total_chunks + batch_size - 1) // batch_size

    print(
        f"Generating and inserting {total_chunks:,} chunks in {num_batches} batches of {batch_size:,}..."
    )
    print()

    overall_start = time.time()
    chunks_processed = 0

    for batch_num in range(num_batches):
        print(f"=== Batch {batch_num + 1}/{num_batches} ===")

        # Final batch may be smaller than batch_size.
        current_batch_size = min(batch_size, total_chunks - chunks_processed)

        # Step 1: Generate chunks
        gen_start = time.time()
        fake_chunks = list(generate_fake_qdrant_chunks(current_batch_size))
        gen_time = time.time() - gen_start
        print(f"1. Generate chunks: {gen_time:.2f}s")

        # Step 2: Generate fake embeddings
        emb_start = time.time()
        dense_embeddings, sparse_embeddings = generate_fake_embeddings_for_chunks(
            fake_chunks, vector_size, sparse_dims
        )
        emb_time = time.time() - emb_start
        print(f"2. Generate embeddings: {emb_time:.2f}s")

        # Step 3: Build points
        build_start = time.time()
        points = service.build_points_from_chunks_and_embeddings(
            fake_chunks, dense_embeddings, sparse_embeddings
        )
        build_time = time.time() - build_start
        print(f"3. Build points: {build_time:.2f}s")

        # Step 4: Insert to Qdrant (likely bottleneck)
        insert_start = time.time()
        result = service.client.override_points(points, collection_name)
        insert_time = time.time() - insert_start
        print(f"4. Insert to Qdrant: {insert_time:.2f}s")

        batch_total = time.time() - gen_start
        chunks_processed += current_batch_size

        print(f"Batch total: {batch_total:.2f}s")
        print(f"Status: {result.status}")
        print(f"Chunks processed: {chunks_processed:,} / {total_chunks:,}")
        print()

    total_elapsed = time.time() - overall_start

    print("=" * 60)
    print("COMPLETE")
    print("=" * 60)
    print(f"Total chunks inserted: {total_chunks:,}")
    print(f"Total time: {total_elapsed:.2f} seconds ({total_elapsed / 60:.1f} minutes)")
    print(f"Average rate: {total_chunks / total_elapsed:.1f} chunks/sec")
    print()

    print("Collection info:")
    collection_info = service.client.get_collection(collection_name)
    print(f"  Points count: {collection_info.points_count:,}")
    print(f"  Indexed vectors count: {collection_info.indexed_vectors_count:,}")
    print(f"  Optimizer status: {collection_info.optimizer_status}")
    print(f"  Status: {collection_info.status}")

    # Only need to trigger indexing if we disabled it during upload
    if not index_while_uploading:
        print("\nTriggering indexing (was disabled during upload)...")

        service.client.update_collection(
            collection_name=collection_name,
            optimizers_config=OptimizersConfigDiff(
                indexing_threshold=20000,
            ),
        )
        print("Collection optimizers config updated - indexing will now proceed")
    else:
        print("\nIndexing was enabled during upload - no manual trigger needed")

    fresh_collection_info = service.client.get_collection(collection_name)
    print(f"  Points count: {fresh_collection_info.points_count:,}")
    print(f"  Indexed vectors count: {fresh_collection_info.indexed_vectors_count:,}")
    print(f"  Optimizer status: {fresh_collection_info.optimizer_status}")
    print(f"  Status: {fresh_collection_info.status}")


if __name__ == "__main__":
    main()
|
||||
0
backend/scratch/qdrant/prefix_cache/__init__.py
Normal file
0
backend/scratch/qdrant/prefix_cache/__init__.py
Normal file
9999
backend/scratch/qdrant/prefix_cache/corpus_prefixes_100k.txt
Normal file
9999
backend/scratch/qdrant/prefix_cache/corpus_prefixes_100k.txt
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Script to create the prefix_cache collection in Qdrant.
|
||||
|
||||
This collection stores pre-computed embeddings for common query prefixes
|
||||
to accelerate search-as-you-type functionality.
|
||||
"""
|
||||
|
||||
from qdrant_client.models import Distance
|
||||
from qdrant_client.models import SparseVectorParams
|
||||
from qdrant_client.models import VectorParams
|
||||
|
||||
from scratch.qdrant.client import QdrantClient
|
||||
from scratch.qdrant.schemas.collection_name import CollectionName
|
||||
|
||||
|
||||
def create_prefix_cache_collection() -> None:
    """
    Create the prefix_cache collection with appropriate vector configurations.

    Collection schema:
    - Dense vectors: 1024 dimensions (Cohere embed-english-v3.0)
    - Sparse vectors: Splade_PP_en_v1
    - Payload: prefix text, hit_count for analytics
    """
    client = QdrantClient()

    collection_name = CollectionName.PREFIX_CACHE

    print(f"Creating collection: {collection_name}")

    # Dense vector config (same as accuracy_testing collection)
    dense_config = VectorParams(
        size=1024,  # Cohere embed-english-v3.0 dimension
        distance=Distance.COSINE,
    )

    # Sparse vector config (same as accuracy_testing collection)
    sparse_config = {"sparse": SparseVectorParams()}

    result = client.create_collection(
        collection_name=collection_name,
        dense_vectors_config={"dense": dense_config},
        sparse_vectors_config=sparse_config,
        # NOTE(review): value is 2, but the original comment claimed a
        # "single shard" — confirm the intended shard count for this small
        # cache collection.
        shard_number=2,
    )

    if result.success:
        print(f"✓ Collection '{collection_name}' created successfully")

        # Verify collection was created
        collection_info = client.get_collection(collection_name)
        print("\nCollection info:")
        print(f"  Status: {collection_info.status}")
        print(f"  Points: {collection_info.points_count}")
        print(f"  Vectors config: {collection_info.config.params.vectors}")
        print(
            f"  Sparse vectors config: {collection_info.config.params.sparse_vectors}"
        )
    else:
        print(f"✗ Failed to create collection '{collection_name}'")


if __name__ == "__main__":
    create_prefix_cache_collection()
|
||||
237
backend/scratch/qdrant/prefix_cache/extract_all_prefixes.py
Normal file
237
backend/scratch/qdrant/prefix_cache/extract_all_prefixes.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""
|
||||
Extract ALL prefixes from the target_docs.jsonl corpus at scale (100k+).
|
||||
|
||||
Strategy:
|
||||
1. Extract tokens from filenames, URLs, document IDs
|
||||
2. Extract ALL words from content (not just top N)
|
||||
3. Generate prefixes for everything (1-10 characters)
|
||||
4. Target: 100k+ unique prefixes
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def extract_tokens_from_filename(filename: str | None) -> list[str]:
|
||||
"""Extract searchable tokens from filename."""
|
||||
if not filename:
|
||||
return []
|
||||
|
||||
# Remove extension
|
||||
name_without_ext = re.sub(r"\.[^.]+$", "", filename)
|
||||
|
||||
# Split on common separators: _, -, ~, space, .
|
||||
tokens = re.split(r"[_\-~\s.]+", name_without_ext.lower())
|
||||
|
||||
# Filter out empty, very short tokens, and non-ASCII
|
||||
tokens = [t for t in tokens if len(t) >= 3 and t.isascii()]
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
def extract_tokens_from_url(url: str | None) -> list[str]:
|
||||
"""Extract searchable tokens from URL."""
|
||||
if not url:
|
||||
return []
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Get domain parts
|
||||
domain_parts = parsed.netloc.lower().split(".")
|
||||
|
||||
# Get path parts
|
||||
path_parts = [p for p in parsed.path.lower().split("/") if p]
|
||||
|
||||
all_parts = domain_parts + path_parts
|
||||
|
||||
# Further split on common separators
|
||||
tokens = []
|
||||
for part in all_parts:
|
||||
tokens.extend(re.split(r"[_\-~\s.]+", part))
|
||||
|
||||
# Filter
|
||||
tokens = [t for t in tokens if len(t) >= 3]
|
||||
|
||||
return tokens
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
def extract_words_from_text(text: str) -> list[str]:
|
||||
"""Extract ALL words from text (no stop word filtering for max coverage)."""
|
||||
# Remove URLs, email addresses
|
||||
text = re.sub(r"http[s]?://\S+", "", text)
|
||||
text = re.sub(r"\S+@\S+", "", text)
|
||||
|
||||
# Extract words (3+ characters, ASCII letters only - NO NUMBERS, NO UNICODE)
|
||||
words = re.findall(r"\b[a-z]{3,}\b", text.lower())
|
||||
|
||||
# Filter out any non-ASCII words
|
||||
ascii_words = [w for w in words if w.isascii()]
|
||||
|
||||
return ascii_words
|
||||
|
||||
|
||||
def generate_prefixes(word: str, min_len: int = 1, max_len: int = 10) -> list[str]:
    """Generate all prefixes of ``word`` from ``min_len`` to ``max_len`` chars.

    Words that do not start with a letter yield no prefixes at all — every
    prefix of such a word shares the same non-letter first character, so
    the per-prefix isalpha() check of the original loop is hoisted into a
    single up-front test.

    Args:
        word: Source word (typically lowercase ASCII).
        min_len: Shortest prefix length to emit; values below 1 are
            clamped to 1 so an empty prefix can never be produced.
        max_len: Longest prefix length to emit (capped at ``len(word)``).

    Returns:
        Prefixes ordered from shortest to longest.
    """
    # First character is shared by every prefix: check it once.
    if not word or not word[0].isalpha():
        return []

    longest = min(len(word), max_len)
    return [word[:length] for length in range(max(min_len, 1), longest + 1)]
|
||||
|
||||
|
||||
def main():
    """Extract prefix candidates from the corpus and write the top ~10k to disk.

    Pipeline:
      1. Collect unique tokens from filenames, URLs, document IDs, and content.
      2. Generate 1-5 character prefixes with frequency counts.
      3. Keep all 1-2 char prefixes plus the most popular 3-5 char ones,
         budgeted to ~10k total, and save them sorted alphabetically to
         corpus_prefixes_100k.txt.
    """
    jsonl_path = Path(__file__).parent.parent / "accuracy_testing" / "target_docs.jsonl"

    print("=" * 80)
    print("EXTRACTING ALL PREFIXES AT SCALE")
    print("=" * 80)
    print(f"\nAnalyzing corpus: {jsonl_path}")
    print("Target: 100k+ unique prefixes\n")

    # Collect ALL unique tokens
    all_tokens = set()
    doc_count = 0

    print("Phase 1: Extracting tokens...")
    with open(jsonl_path, "r") as f:
        for line_num, line in enumerate(f, 1):
            if not line.strip():
                continue

            doc = json.loads(line)

            # Extract from filename
            filename = doc.get("filename")
            if filename:
                all_tokens.update(extract_tokens_from_filename(filename))

            # Extract from URL
            url = doc.get("url")
            if url:
                all_tokens.update(extract_tokens_from_url(url))

            # Extract from document_id
            doc_id = doc.get("document_id", "")
            if doc_id:
                # Split on common separators
                id_tokens = re.split(r"[_\-~\s.]+", doc_id.lower())
                all_tokens.update([t for t in id_tokens if len(t) >= 3])

            # Extract from ALL documents (not sampling) to get comprehensive coverage
            content = doc.get("content", "")
            title = doc.get("title", "")
            full_text = f"{title} {content}"
            all_tokens.update(extract_words_from_text(full_text))

            doc_count += 1

            if doc_count % 1000 == 0:
                print(
                    f"  Processed {doc_count:,} documents, {len(all_tokens):,} unique tokens..."
                )

    print(f"\n✓ Processed {doc_count:,} documents")
    print(f"✓ Found {len(all_tokens):,} unique tokens")

    # Generate prefixes from all tokens WITH FREQUENCY TRACKING.
    # Since all_tokens is a set, "frequency" here means the number of
    # distinct tokens sharing the prefix, not raw corpus occurrences.
    print("\nPhase 2: Generating prefixes with frequency...")
    prefix_frequency = Counter()

    for token in all_tokens:
        # Generate prefixes 1-5 chars (only alphabetic)
        prefixes = generate_prefixes(token, min_len=1, max_len=5)
        for prefix in prefixes:
            prefix_frequency[prefix] += 1

    print(
        f"✓ Generated {len(prefix_frequency):,} unique prefixes (1-5 chars, letters only)"
    )

    # Group by length
    by_length = {}
    for prefix, freq in prefix_frequency.items():
        length = len(prefix)
        if length not in by_length:
            by_length[length] = []
        by_length[length].append((prefix, freq))

    print("\nPrefix distribution (total available):")
    for length in sorted(by_length.keys()):
        print(f"  {length}-char: {len(by_length[length]):,}")

    # Select most POPULAR prefixes from each length category
    # Strategy: Take ALL 1-2 char, then most popular 3-5 char to reach ~10k
    target_count = 10000
    selected = []

    # All 1-char (always useful)
    if 1 in by_length:
        selected.extend([p for p, f in by_length[1]])
        print(f"\nTaking all {len(by_length[1])} 1-char prefixes")

    # All 2-char (very useful)
    if 2 in by_length:
        selected.extend([p for p, f in by_length[2]])
        print(f"Taking all {len(by_length[2])} 2-char prefixes")

    # Calculate remaining budget
    remaining = target_count - len(selected)
    print(f"Remaining budget: {remaining:,} prefixes for 3-5 char")

    # Distribute remaining across 3-5 char based on frequency
    # Weight: 25% for 3-char, 35% for 4-char, 40% for 5-char
    # NOTE(review): int() truncation means the final total can fall slightly
    # short of target_count — confirm that is acceptable.
    weights = {3: 0.25, 4: 0.35, 5: 0.40}

    for length in [3, 4, 5]:
        if length in by_length:
            take_count = int(remaining * weights[length])
            # Sort by frequency (most popular first)
            sorted_by_freq = sorted(by_length[length], key=lambda x: x[1], reverse=True)
            top_prefixes = [p for p, f in sorted_by_freq[:take_count]]
            selected.extend(top_prefixes)
            print(
                f"Taking top {len(top_prefixes):,} most popular {length}-char prefixes"
            )

    # Sort alphabetically
    sorted_prefixes = sorted(selected)

    print(f"\n✓ Selected {len(sorted_prefixes):,} prefixes for cache")
    for length in range(1, 6):
        count = sum(1 for p in sorted_prefixes if len(p) == length)
        print(f"  {length}-char: {count:,}")

    # Save to file (one prefix per line; consumed by populate_prefix_cache.py)
    output_path = Path(__file__).parent / "corpus_prefixes_100k.txt"
    with open(output_path, "w") as f:
        for prefix in sorted_prefixes:
            f.write(f"{prefix}\n")

    print(f"\n✓ Saved {len(sorted_prefixes):,} prefixes to: {output_path}")

    # Statistics by length. Lengths 6-10 are always 0 here since only
    # 1-5 char prefixes are generated above (this section largely repeats
    # the distribution printed earlier).
    print("\nPrefix statistics by length:")
    for length in range(1, 11):
        count = sum(1 for p in sorted_prefixes if len(p) == length)
        print(f"  {length}-char: {count:,}")

    # Sample prefixes
    print("\nSample prefixes (first 100):")
    for i, prefix in enumerate(sorted_prefixes[:100], 1):
        print(f"  {i:3d}. '{prefix}'")

    print("\n" + "=" * 80)
    print(f"COMPLETE - {len(sorted_prefixes):,} prefixes extracted")
    print("=" * 80)


if __name__ == "__main__":
    main()
|
||||
143
backend/scratch/qdrant/prefix_cache/populate_prefix_cache.py
Normal file
143
backend/scratch/qdrant/prefix_cache/populate_prefix_cache.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
Script to populate the prefix_cache collection with common query prefixes.
|
||||
|
||||
This pre-computes embeddings for frequently typed prefixes to accelerate
|
||||
search-as-you-type functionality.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import cohere
|
||||
from dotenv import load_dotenv
|
||||
from fastembed import SparseTextEmbedding
|
||||
from qdrant_client.models import PointStruct
|
||||
from qdrant_client.models import SparseVector
|
||||
|
||||
from scratch.qdrant.client import QdrantClient
|
||||
from scratch.qdrant.prefix_cache.prefix_to_id import prefix_to_id
|
||||
from scratch.qdrant.schemas.collection_name import CollectionName
|
||||
|
||||
|
||||
def get_common_prefixes() -> list[str]:
    """
    Load the corpus-derived query prefixes to cache.

    Reads ~10k of the most popular 1-5 character prefixes produced by
    extract_all_prefixes.py from target_docs.jsonl.

    Returns:
        List of prefix strings to cache
    """
    # Prefixes live next to this module, one per line (too many to hardcode).
    prefix_file = Path(__file__).parent / "corpus_prefixes_100k.txt"

    if not prefix_file.exists():
        raise FileNotFoundError(
            f"Prefix file not found: {prefix_file}\n"
            "Run: python -m scratch.qdrant.prefix_cache.extract_all_prefixes"
        )

    # Ignore blank lines and surrounding whitespace.
    prefixes: list[str] = []
    for raw_line in prefix_file.read_text().splitlines():
        cleaned = raw_line.strip()
        if cleaned:
            prefixes.append(cleaned)

    print(f"Loaded {len(prefixes):,} prefixes from {prefix_file.name}")

    return prefixes
|
||||
|
||||
|
||||
def populate_prefix_cache() -> None:
    """
    Populate the prefix_cache collection with pre-computed embeddings.

    For each cached prefix, generates a dense embedding (Cohere
    embed-english-v3.0) and a sparse embedding (BM25), then upserts the
    pair into Qdrant under a u64 integer point ID derived from the prefix
    text via prefix_to_id().

    Raises:
        ValueError: If COHERE_API_KEY is not set in the environment.
        FileNotFoundError: If the prefix file has not been generated yet
            (raised by get_common_prefixes()).
    """
    # Load environment variables
    env_path = Path(__file__).parent.parent.parent.parent.parent / ".vscode" / ".env"
    if env_path.exists():
        load_dotenv(env_path)
        print(f"Loaded environment variables from {env_path}")
    else:
        print(f"Warning: .env file not found at {env_path}")

    print("=" * 80)
    print("POPULATING PREFIX CACHE")
    print("=" * 80)

    # Initialize clients
    client = QdrantClient()
    cohere_api_key = os.getenv("COHERE_API_KEY")
    if not cohere_api_key:
        raise ValueError("COHERE_API_KEY environment variable not set")

    cohere_client = cohere.Client(cohere_api_key)
    # Use BM25 for sparse embeddings
    sparse_model = SparseTextEmbedding(model_name="Qdrant/bm25", threads=2)

    # Get prefixes to cache
    prefixes = get_common_prefixes()
    print(f"\nGenerating embeddings for {len(prefixes)} prefixes...")

    # Generate embeddings in batches
    batch_size = 96  # Cohere API batch limit
    uploaded = 0

    for i in range(0, len(prefixes), batch_size):
        batch = prefixes[i : i + batch_size]
        print(f"\nProcessing batch {i // batch_size + 1} ({len(batch)} prefixes)...")

        # Generate dense embeddings with Cohere
        print("  Generating dense embeddings...")
        dense_response = cohere_client.embed(
            texts=batch,
            model="embed-english-v3.0",
            input_type="search_query",
        )

        # Generate sparse embeddings
        print("  Generating sparse embeddings...")
        sparse_embeddings = list(sparse_model.query_embed(batch))

        # Create points - Convert prefix to u64 integer point ID.
        # BUGFIX: points are now collected per batch instead of being
        # accumulated across the whole run and sliced — the original kept
        # every embedding in memory despite the "don't accumulate" comment.
        batch_points = []
        for prefix, dense_emb, sparse_emb in zip(
            batch, dense_response.embeddings, sparse_embeddings
        ):
            point_id = prefix_to_id(prefix)  # Convert prefix to u64 integer!

            batch_points.append(
                PointStruct(
                    id=point_id,  # u64 integer encoding of the prefix
                    vector={
                        "dense": dense_emb,
                        "sparse": SparseVector(
                            indices=sparse_emb.indices.tolist(),
                            values=sparse_emb.values.tolist(),
                        ),
                    },
                    payload={
                        "prefix": prefix,  # Store original prefix for reference
                        "hit_count": 0,
                    },
                )
            )

        # Upload this batch immediately (don't accumulate all points)
        print("  Uploading batch to Qdrant...")
        client.override_points(
            points=batch_points,
            collection_name=CollectionName.PREFIX_CACHE,
        )
        uploaded += len(batch_points)
        print(f"  ✓ Batch uploaded ({uploaded} / {len(prefixes)} total points)")

    print("\n✓ Successfully uploaded all prefix cache entries")

    # Verify
    collection_info = client.get_collection(CollectionName.PREFIX_CACHE)
    print("\nCollection status:")
    print(f"  Points count: {collection_info.points_count}")
    print(f"  Status: {collection_info.status}")

    print("\n" + "=" * 80)
    print("PREFIX CACHE POPULATION COMPLETE")
    print("=" * 80)


if __name__ == "__main__":
    populate_prefix_cache()
|
||||
94
backend/scratch/qdrant/prefix_cache/prefix_to_id.py
Normal file
94
backend/scratch/qdrant/prefix_cache/prefix_to_id.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Convert prefix strings to u64 integer point IDs.
|
||||
|
||||
Following the approach from the Qdrant article:
|
||||
- Use up to 8 bytes (u64) to encode the prefix
|
||||
- Encode prefix as ASCII bytes, pad with zeros
|
||||
"""
|
||||
|
||||
|
||||
def prefix_to_id(prefix: str) -> int:
|
||||
"""
|
||||
Convert a prefix string to a u64 integer point ID.
|
||||
|
||||
Encodes the prefix as ASCII bytes (up to 8 chars) and converts to integer.
|
||||
Examples:
|
||||
"a" -> 97 (ASCII value of 'a')
|
||||
"ab" -> 24930 (0x6162 = 'a' + 'b' << 8)
|
||||
"docker" -> 29273878479972 (encodes 'd','o','c','k','e','r')
|
||||
|
||||
Args:
|
||||
prefix: The prefix string (max 8 characters, ASCII only)
|
||||
|
||||
Returns:
|
||||
u64 integer point ID
|
||||
|
||||
Raises:
|
||||
ValueError: If prefix is longer than 8 characters or contains non-ASCII
|
||||
"""
|
||||
if len(prefix) > 8:
|
||||
raise ValueError(f"Prefix too long: '{prefix}' (max 8 chars)")
|
||||
|
||||
# Check if prefix is ASCII-only
|
||||
if not prefix.isascii():
|
||||
raise ValueError(f"Prefix contains non-ASCII characters: '{prefix}'")
|
||||
|
||||
# Convert prefix to bytes (ASCII)
|
||||
prefix_bytes = prefix.encode("ascii")
|
||||
|
||||
# Pad to 8 bytes with zeros (right-pad)
|
||||
padded = prefix_bytes.ljust(8, b"\x00")
|
||||
|
||||
# Convert to u64 integer (little-endian)
|
||||
point_id = int.from_bytes(padded, byteorder="little")
|
||||
|
||||
return point_id
|
||||
|
||||
|
||||
def id_to_prefix(point_id: int) -> str:
|
||||
"""
|
||||
Convert a u64 integer point ID back to prefix string.
|
||||
|
||||
Inverse of prefix_to_id() - useful for debugging.
|
||||
|
||||
Args:
|
||||
point_id: The u64 integer point ID
|
||||
|
||||
Returns:
|
||||
The original prefix string
|
||||
"""
|
||||
# Convert to 8 bytes (little-endian)
|
||||
id_bytes = point_id.to_bytes(8, byteorder="little")
|
||||
|
||||
# Remove padding zeros and decode
|
||||
prefix = id_bytes.rstrip(b"\x00").decode("ascii")
|
||||
|
||||
return prefix
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke test: round-trip a handful of representative prefixes through
    # prefix_to_id() and id_to_prefix() and assert the encoding is lossless.
    test_prefixes = [
        "a",
        "ab",
        "abc",
        "docker",
        "gitlab",
        "issue",
        "customer",
        "workflow",
    ]

    print("Testing prefix_to_id conversion:")
    print("=" * 60)

    for prefix in test_prefixes:
        point_id = prefix_to_id(prefix)
        recovered = id_to_prefix(point_id)

        print(f"'{prefix}' -> {point_id} -> '{recovered}'")

        # Verify round-trip
        assert recovered == prefix, f"Round-trip failed: {prefix} != {recovered}"

    print("\n✓ All tests passed!")
|
||||
0
backend/scratch/qdrant/schemas/__init__.py
Normal file
0
backend/scratch/qdrant/schemas/__init__.py
Normal file
16
backend/scratch/qdrant/schemas/chunk.py
Normal file
16
backend/scratch/qdrant/schemas/chunk.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from scratch.qdrant.schemas.source_type import SourceType
|
||||
|
||||
|
||||
class QdrantChunk(BaseModel):
    """A single searchable chunk of a document as stored in Qdrant."""

    id: UUID  # point ID for this chunk in Qdrant
    created_at: datetime.datetime  # timestamp used for time-range filtering
    document_id: str  # ID of the parent document this chunk belongs to
    filename: str | None = None  # Optional filename for matching
    source_type: SourceType | None  # originating connector, if known
    access_control_list: list[str] | None  # lets just say its a list of user emails
    content: str  # raw chunk text that gets embedded
8
backend/scratch/qdrant/schemas/collection_name.py
Normal file
8
backend/scratch/qdrant/schemas/collection_name.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from enum import auto
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class CollectionName(StrEnum):
    """Names of the Qdrant collections used by these scratch scripts.

    StrEnum + auto() makes each member's value its lowercase member name,
    e.g. CollectionName.PREFIX_CACHE == "prefix_cache".
    """

    TEST_COLLECTION = auto()  # bulk-populated fake chunks for load testing
    ACCURACY_TESTING = auto()  # corpus used for accuracy evaluation
    PREFIX_CACHE = auto()  # pre-computed embeddings for common query prefixes
|
||||
13
backend/scratch/qdrant/schemas/collection_operations.py
Normal file
13
backend/scratch/qdrant/schemas/collection_operations.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DeleteCollectionResult(BaseModel):
    """Result of a collection delete operation."""

    success: bool


class CreateCollectionResult(BaseModel):
    """Result of a collection create operation."""

    success: bool


class UpdateCollectionResult(BaseModel):
    """Result of a collection update operation."""

    success: bool
|
||||
19
backend/scratch/qdrant/schemas/embeddings.py
Normal file
19
backend/scratch/qdrant/schemas/embeddings.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic.types import StrictFloat
|
||||
from qdrant_client.models import SparseVector
|
||||
|
||||
|
||||
class ChunkDenseEmbedding(BaseModel):
    """A chunk ID paired with its dense vector embedding."""

    chunk_id: UUID  # ID of the chunk this embedding belongs to
    vector: list[StrictFloat]  # StrictFloat rejects accidental int entries


class ChunkSparseEmbedding(BaseModel):
    """A chunk ID paired with its sparse vector embedding."""

    chunk_id: UUID  # ID of the chunk this embedding belongs to
    vector: SparseVector  # qdrant-client sparse vector (indices + values)
|
||||
27
backend/scratch/qdrant/schemas/prefix_cache.py
Normal file
27
backend/scratch/qdrant/schemas/prefix_cache.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
Schema for prefix cache entries.
|
||||
|
||||
The prefix cache stores pre-computed embeddings for common query prefixes
|
||||
to accelerate search-as-you-type functionality.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class PrefixCacheEntry(BaseModel):
    """
    Represents a cached query prefix with its pre-computed embeddings.

    Attributes:
        prefix: The query prefix text (e.g., "doc", "docker", "kubernetes")
        dense_embedding: Pre-computed dense vector embedding
        sparse_indices: Pre-computed sparse vector indices
        sparse_values: Pre-computed sparse vector values
        hit_count: Number of times this prefix has been used (for analytics)
    """

    prefix: str
    dense_embedding: list[float]
    # NOTE(review): the sparse vector is stored as two parallel lists rather
    # than a qdrant SparseVector — presumably to keep the entry plain-data /
    # serializable; confirm before changing.
    sparse_indices: list[int]
    sparse_values: list[float]
    hit_count: int = 0
|
||||
25
backend/scratch/qdrant/schemas/source_type.py
Normal file
25
backend/scratch/qdrant/schemas/source_type.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from enum import auto
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class SourceType(StrEnum):
|
||||
GITHUB = auto()
|
||||
BITBUCKET = auto()
|
||||
CONFLUENCE = auto()
|
||||
GOOGLE_DRIVE = auto()
|
||||
SLACK = auto()
|
||||
DROPBOX = auto()
|
||||
JIRA = auto()
|
||||
ASANA = auto()
|
||||
TRELLO = auto()
|
||||
ZENDESK = auto()
|
||||
SALESFORCE = auto()
|
||||
NOTION = auto()
|
||||
AIRTABLE = auto()
|
||||
MONDAY = auto()
|
||||
FIGMA = auto()
|
||||
INTERCOM = auto()
|
||||
HUBSPOT = auto()
|
||||
BOX = auto()
|
||||
SHAREPOINT = auto()
|
||||
SERVICENOW = auto()
|
||||
195
backend/scratch/qdrant/service.py
Normal file
195
backend/scratch/qdrant/service.py
Normal file
@@ -0,0 +1,195 @@
|
||||
from fastembed import SparseTextEmbedding
|
||||
from fastembed import TextEmbedding
|
||||
from qdrant_client.models import Filter
|
||||
from qdrant_client.models import Fusion
|
||||
from qdrant_client.models import FusionQuery
|
||||
from qdrant_client.models import PointStruct
|
||||
from qdrant_client.models import Prefetch
|
||||
from qdrant_client.models import SparseVector
|
||||
from qdrant_client.models import UpdateResult
|
||||
|
||||
from scratch.qdrant.client import QdrantClient
|
||||
from scratch.qdrant.schemas.chunk import QdrantChunk
|
||||
from scratch.qdrant.schemas.collection_name import CollectionName
|
||||
from scratch.qdrant.schemas.embeddings import ChunkDenseEmbedding
|
||||
from scratch.qdrant.schemas.embeddings import ChunkSparseEmbedding
|
||||
|
||||
|
||||
class QdrantService:
    """High-level helper around the project ``QdrantClient``.

    Bundles chunk embedding (dense + sparse, using fastembed models supplied
    by the caller), point construction, upsert, and dense / hybrid search.
    The service itself holds no model state — only the client.
    """

    def __init__(self, client: QdrantClient):
        # All collection operations are delegated to this project wrapper.
        self.client = client

    def embed_chunks_to_dense_embeddings(
        self,
        chunks: list[QdrantChunk],
        dense_embedding_model: TextEmbedding,
    ) -> list[ChunkDenseEmbedding]:
        """Compute one dense embedding per chunk's ``content``.

        ``passage_embed`` yields vectors in input order, so ``zip`` pairs
        each chunk with its own vector.
        """
        dense_vectors = dense_embedding_model.passage_embed(
            [chunk.content for chunk in chunks]
        )
        return [
            ChunkDenseEmbedding(chunk_id=chunk.id, vector=vector.tolist())
            for chunk, vector in zip(chunks, dense_vectors)
        ]

    def embed_chunks_to_sparse_embeddings(
        self,
        chunks: list[QdrantChunk],
        sparse_embedding_model: SparseTextEmbedding,
    ) -> list[ChunkSparseEmbedding]:
        """Compute one sparse embedding per chunk's ``content``.

        The fastembed sparse result (numpy ``indices`` / ``values``) is
        converted into a qdrant-client ``SparseVector``.
        """
        sparse_vectors = sparse_embedding_model.passage_embed(
            [chunk.content for chunk in chunks]
        )
        return [
            ChunkSparseEmbedding(
                chunk_id=chunk.id,
                vector=SparseVector(
                    indices=vector.indices.tolist(), values=vector.values.tolist()
                ),
            )
            for chunk, vector in zip(chunks, sparse_vectors)
        ]

    def build_points_from_chunks_and_embeddings(
        self,
        chunks: list[QdrantChunk],
        dense_embeddings: list[ChunkDenseEmbedding],
        sparse_embeddings: list[ChunkSparseEmbedding],
    ) -> list[PointStruct]:
        """Build PointStruct objects from chunks and their embeddings.

        Raises:
            KeyError: if a chunk has no matching dense or sparse embedding
                (the lookups below are deliberately unguarded).
        """
        # Create lookup maps by chunk_id for explicit matching
        dense_emb_map = {emb.chunk_id: emb for emb in dense_embeddings}
        sparse_emb_map = {emb.chunk_id: emb for emb in sparse_embeddings}

        # Build points from chunks and embeddings matched by chunk_id
        points = []
        for chunk in chunks:
            dense_emb = dense_emb_map[chunk.id]
            sparse_emb = sparse_emb_map[chunk.id]

            points.append(
                PointStruct(
                    # Chunk id doubles as the point id; it is excluded from
                    # the payload below to avoid storing it twice.
                    id=str(chunk.id),
                    # Named vectors — the collection must define vector
                    # configs named "dense" and "sparse".
                    vector={"dense": dense_emb.vector, "sparse": sparse_emb.vector},
                    payload=chunk.model_dump(exclude={"id"}),
                )
            )
        return points

    def embed_and_upsert_chunks(
        self,
        chunks: list[QdrantChunk],
        dense_embedding_model: TextEmbedding,
        sparse_embedding_model: SparseTextEmbedding,
        collection_name: CollectionName,
    ) -> UpdateResult:
        """Embed chunks (dense + sparse), build points, and write them.

        Writing goes through ``client.override_points`` — NOTE(review): the
        name suggests existing points with the same ids are replaced;
        confirm against the client implementation.
        """
        # Use the embedding methods to get structured embeddings
        dense_embeddings = self.embed_chunks_to_dense_embeddings(
            chunks, dense_embedding_model
        )
        sparse_embeddings = self.embed_chunks_to_sparse_embeddings(
            chunks, sparse_embedding_model
        )

        # Build points using the helper method
        points = self.build_points_from_chunks_and_embeddings(
            chunks, dense_embeddings, sparse_embeddings
        )

        update_result = self.client.override_points(
            points=points,
            collection_name=collection_name,
        )

        return update_result

    def dense_search(
        self,
        query_text: str,
        dense_embedding_model: TextEmbedding,
        collection_name: CollectionName,
        limit: int = 10,
        query_filter: Filter | None = None,
    ):
        """Perform dense vector search only.

        Embeds ``query_text`` inline (unlike ``hybrid_search``, which takes
        pre-computed vectors) and queries the "dense" named vector.
        """
        # Generate query embedding
        dense_query_vector = next(
            dense_embedding_model.query_embed(query_text)
        ).tolist()

        # Query with dense vector only
        return self.client.query_points(
            collection_name=collection_name,
            query=dense_query_vector,
            using="dense",
            query_filter=query_filter,
            with_payload=True,
            limit=limit,
        )

    def generate_query_embeddings(  # should live in different service, here for convenience purposes for testing
        self,
        query_text: str,
        dense_embedding_model: TextEmbedding,
        sparse_embedding_model: SparseTextEmbedding,
    ) -> tuple[list[float], SparseVector]:
        """
        Generate dense and sparse embeddings for a query.
        Separated out so you can time just the search without embedding overhead.

        Returns:
            Tuple of (dense_vector, sparse_vector)
        """
        # query_embed returns a generator; next() takes the single result.
        dense_query_vector = next(
            dense_embedding_model.query_embed(query_text)
        ).tolist()
        sparse_embedding = next(sparse_embedding_model.query_embed(query_text))
        sparse_query_vector = SparseVector(
            indices=sparse_embedding.indices.tolist(),
            values=sparse_embedding.values.tolist(),
        )
        return dense_query_vector, sparse_query_vector

    def hybrid_search(
        self,
        dense_query_vector: list[float],
        sparse_query_vector: SparseVector,
        collection_name: CollectionName,
        limit: int = 10,
        prefetch_limit: int | None = None,
        query_filter: Filter | None = None,
        fusion: Fusion = Fusion.DBSF,
    ):
        """
        Perform hybrid search using fusion of dense and sparse vectors.

        Use generate_query_embeddings() first to get vectors from text.
        This keeps embedding time separate from search time for benchmarking.
        """
        # If prefetch_limit not specified, use limit * 2 to ensure we get enough results
        effective_prefetch_limit = (
            prefetch_limit if prefetch_limit is not None else limit * 2
        )

        # Query with fusion
        # Note: With prefetch + fusion, filters must be applied to prefetch queries
        return self.client.query_points(
            collection_name=collection_name,
            prefetch=[
                Prefetch(
                    query=sparse_query_vector,
                    using="sparse",
                    limit=effective_prefetch_limit,
                    filter=query_filter,  # Apply filter to prefetch
                ),
                Prefetch(
                    query=dense_query_vector,
                    using="dense",
                    limit=effective_prefetch_limit,
                    filter=query_filter,  # Apply filter to prefetch
                ),
            ],
            fusion_query=FusionQuery(fusion=fusion),
            with_payload=True,
            limit=limit,
        )
|
||||
95
backend/scratch/qdrant/test_prefix_cache_performance.py
Normal file
95
backend/scratch/qdrant/test_prefix_cache_performance.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
Test script to verify prefix cache performance improvements.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from onyx.server.qdrant_search.service import search_documents
|
||||
|
||||
|
||||
def test_prefix_cache_performance():
    """Exercise search_documents() with queries expected to hit or miss the
    prefix cache, print per-query latency, and summarize hit vs. miss times.

    Latency is measured with time.perf_counter() (a monotonic, high-resolution
    clock) rather than time.time(), so wall-clock adjustments cannot skew the
    reported intervals. Failed searches are printed but excluded from the
    summary, since they have no comparable latency.
    """

    # Test queries - mix of cached and uncached
    test_cases = [
        ("docker", "Should HIT cache"),
        ("kubernetes", "Should HIT cache"),
        ("python", "Should HIT cache"),
        ("how to install docker containers", "Should MISS cache"),
        ("docker", "Should HIT cache (2nd time)"),
        ("d", "Should HIT cache (single char)"),
        ("do", "Should HIT cache (2-char)"),
        ("doc", "Should HIT cache (3-char)"),
        ("dock", "Should HIT cache (4-char)"),
        ("container orchestration performance", "Should MISS cache"),
    ]

    print("=" * 80)
    print("PREFIX CACHE PERFORMANCE TEST")
    print("=" * 80)

    results = []

    for query, expected in test_cases:
        print(f"\n\nQuery: '{query}' ({expected})")
        print("-" * 80)

        start_time = time.perf_counter()
        try:
            response = search_documents(query=query, limit=3)
            elapsed_ms = (time.perf_counter() - start_time) * 1000

            print(f"✓ Search completed in {elapsed_ms:.2f}ms")
            print(f" Total results: {response.total_results}")

            if response.results:
                print(f" Top result score: {response.results[0].score:.4f}")

            results.append(
                {
                    "query": query,
                    "time_ms": elapsed_ms,
                    "expected": expected,
                    "num_results": response.total_results,
                }
            )

        except Exception as e:
            elapsed_ms = (time.perf_counter() - start_time) * 1000
            print(f"✗ Search failed after {elapsed_ms:.2f}ms")
            print(f" Error: {e}")
            import traceback

            traceback.print_exc()

    # Summary: partition the successful queries by whether the label said
    # the prefix cache should have been hit.
    print("\n\n" + "=" * 80)
    print("PERFORMANCE SUMMARY")
    print("=" * 80)

    cache_hits = [r for r in results if "HIT" in r["expected"]]
    cache_misses = [r for r in results if "MISS" in r["expected"]]

    if cache_hits:
        avg_hit_time = sum(r["time_ms"] for r in cache_hits) / len(cache_hits)
        print(f"\nCache HITs (n={len(cache_hits)}): avg {avg_hit_time:.2f}ms")
        for r in cache_hits:
            print(f" '{r['query']}': {r['time_ms']:.2f}ms")

    if cache_misses:
        avg_miss_time = sum(r["time_ms"] for r in cache_misses) / len(cache_misses)
        print(f"\nCache MISSes (n={len(cache_misses)}): avg {avg_miss_time:.2f}ms")
        for r in cache_misses:
            print(f" '{r['query']}': {r['time_ms']:.2f}ms")

    # Speedup is only meaningful when both populations were measured.
    if cache_hits and cache_misses:
        speedup = avg_miss_time / avg_hit_time
        print(f"\n📈 Cache speedup: {speedup:.1f}x faster")
        print(f" Cache hit latency: {avg_hit_time:.2f}ms")
        print(f" Cache miss latency: {avg_miss_time:.2f}ms")

    print("\n" + "=" * 80)


if __name__ == "__main__":
    test_prefix_cache_performance()
|
||||
56
backend/scratch/qdrant/test_search_endpoint.py
Normal file
56
backend/scratch/qdrant/test_search_endpoint.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
Test script to verify Qdrant search functionality directly.
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from onyx.server.qdrant_search.service import search_documents
|
||||
|
||||
|
||||
def test_search():
    """Run a few representative queries through search_documents() and print
    the results with per-query latency.

    Timing uses time.perf_counter() (monotonic, high resolution) instead of
    time.time(), so the measured intervals are immune to wall-clock jumps.
    """
    test_queries = [
        "docker",
        "kubernetes",
        "container orchestration",
        "how to install",
    ]

    print("=" * 80)
    print("TESTING QDRANT SEARCH FUNCTIONALITY")
    print("=" * 80)

    for query in test_queries:
        print(f"\n\nQuery: '{query}'")
        print("-" * 80)

        start_time = time.perf_counter()
        try:
            response = search_documents(query=query, limit=3)
            elapsed_time = (time.perf_counter() - start_time) * 1000  # Convert to milliseconds

            print(f"✓ Search completed in {elapsed_time:.2f}ms")
            print(f"Total results: {response.total_results}")

            for i, result in enumerate(response.results, 1):
                print(f"\n Result {i}:")
                print(f" Score: {result.score:.4f}")
                print(f" Source: {result.source_type or 'N/A'}")
                print(f" Filename: {result.filename or 'N/A'}")
                print(f" Content preview: {result.content[:100]}...")

        except Exception as e:
            elapsed_time = (time.perf_counter() - start_time) * 1000
            print(f"✗ Search failed after {elapsed_time:.2f}ms")
            print(f"Error: {e}")
            import traceback

            traceback.print_exc()

    print("\n" + "=" * 80)
    print("TEST COMPLETE")
    print("=" * 80)


if __name__ == "__main__":
    test_search()
|
||||
@@ -4,10 +4,12 @@ import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { ChatSearchGroup } from "./ChatSearchGroup";
|
||||
import { NewChatButton } from "./NewChatButton";
|
||||
import { useChatSearch } from "./hooks/useChatSearch";
|
||||
import { useQdrantSearch } from "./hooks/useQdrantSearch";
|
||||
import { LoadingSpinner } from "./LoadingSpinner";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { SearchInput } from "./components/SearchInput";
|
||||
import { ChatSearchSkeletonList } from "./components/ChatSearchSkeleton";
|
||||
import { DocumentSearchResults } from "./components/DocumentSearchResults";
|
||||
import { useIntersectionObserver } from "./hooks/useIntersectionObserver";
|
||||
|
||||
interface ChatSearchModalProps {
|
||||
@@ -26,6 +28,15 @@ export function ChatSearchModal({ open, onCloseModal }: ChatSearchModalProps) {
|
||||
fetchMoreChats,
|
||||
} = useChatSearch();
|
||||
|
||||
// Qdrant document search
|
||||
const { results: documentResults, isLoading: isLoadingDocuments } =
|
||||
useQdrantSearch({
|
||||
searchQuery,
|
||||
enabled: open && searchQuery.length > 0,
|
||||
debounceMs: 500,
|
||||
limit: 10,
|
||||
});
|
||||
|
||||
const onClose = () => {
|
||||
setSearchQuery("");
|
||||
onCloseModal();
|
||||
@@ -78,6 +89,15 @@ export function ChatSearchModal({ open, onCloseModal }: ChatSearchModalProps) {
|
||||
<div className="px-4 py-2">
|
||||
<NewChatButton onClick={handleNewChat} />
|
||||
|
||||
{/* Document Search Results */}
|
||||
{searchQuery && (
|
||||
<DocumentSearchResults
|
||||
results={documentResults}
|
||||
isLoading={isLoadingDocuments}
|
||||
searchQuery={searchQuery}
|
||||
/>
|
||||
)}
|
||||
|
||||
{isSearching ? (
|
||||
<ChatSearchSkeletonList />
|
||||
) : isLoading && chatGroups.length === 0 ? (
|
||||
|
||||
@@ -0,0 +1,148 @@
|
||||
import React, { useState, useEffect, useRef } from "react";
|
||||
import { QdrantSearchResult } from "../qdrantInterfaces";
|
||||
import { FileText } from "lucide-react";
|
||||
import { highlightText } from "../utils/highlightText";
|
||||
|
||||
/** Props for the document-search result list shown in the chat search modal. */
interface DocumentSearchResultsProps {
  /** Search hits to render. */
  results: QdrantSearchResult[];
  /** True while a search request is in flight. */
  isLoading: boolean;
  /** Raw query string, used for term highlighting in filename/content. */
  searchQuery: string;
}
|
||||
|
||||
/**
 * Renders Qdrant document search hits with keyboard navigation.
 *
 * Selection state (-1 = nothing selected) is local; ArrowUp/ArrowDown move
 * it, Enter triggers a click on the selected row, and the selected row is
 * scrolled into view. The key listener is attached to `window`, so it is
 * active whenever this component is mounted with results.
 */
export function DocumentSearchResults({
  results,
  isLoading,
  searchQuery,
}: DocumentSearchResultsProps) {
  const [selectedIndex, setSelectedIndex] = useState<number>(-1);
  // One ref per rendered row, used for click-dispatch and scrollIntoView.
  const resultRefs = useRef<(HTMLDivElement | null)[]>([]);

  // Reset selection when results change; also drop refs for rows that no
  // longer exist so stale elements cannot be clicked/scrolled.
  useEffect(() => {
    setSelectedIndex(-1);
    resultRefs.current = resultRefs.current.slice(0, results.length);
  }, [results]);

  // Keyboard navigation (window-level listener).
  useEffect(() => {
    const handleKeyDown = (e: KeyboardEvent) => {
      if (results.length === 0) return;

      if (e.key === "ArrowDown") {
        e.preventDefault();
        // Clamp at the last row.
        setSelectedIndex((prev) =>
          prev < results.length - 1 ? prev + 1 : prev
        );
      } else if (e.key === "ArrowUp") {
        e.preventDefault();
        // Moving up from the first row deselects (-1).
        setSelectedIndex((prev) => (prev > 0 ? prev - 1 : -1));
      } else if (e.key === "Enter" && selectedIndex >= 0) {
        e.preventDefault();
        // Trigger click on selected result
        resultRefs.current[selectedIndex]?.click();
      }
    };

    window.addEventListener("keydown", handleKeyDown);
    return () => window.removeEventListener("keydown", handleKeyDown);
  }, [results.length, selectedIndex]);

  // Scroll selected item into view
  useEffect(() => {
    if (selectedIndex >= 0 && resultRefs.current[selectedIndex]) {
      resultRefs.current[selectedIndex]?.scrollIntoView({
        behavior: "smooth",
        block: "nearest",
      });
    }
  }, [selectedIndex]);

  if (isLoading) {
    return (
      <div className="px-4 py-2">
        <div className="text-sm text-neutral-500 dark:text-neutral-400">
          Searching documents...
        </div>
      </div>
    );
  }

  // Nothing to show and not loading: render nothing at all.
  if (results.length === 0) {
    return null;
  }

  const handleResultClick = (result: QdrantSearchResult) => {
    // TODO: Implement document preview/open functionality
    console.log("Document clicked:", result.document_id);
  };

  return (
    <div className="px-4 py-2">
      <div className="text-xs font-semibold text-neutral-600 dark:text-neutral-400 mb-2 uppercase tracking-wide">
        Documents ({results.length})
      </div>

      <div className="space-y-1">
        {results.map((result, index) => {
          const isSelected = index === selectedIndex;

          return (
            <div
              key={result.document_id}
              ref={(el) => (resultRefs.current[index] = el)}
              onClick={() => handleResultClick(result)}
              className={`group flex items-start gap-3 px-3 py-2 rounded-lg cursor-pointer transition-colors ${
                isSelected
                  ? "bg-blue-100 dark:bg-blue-900/30 ring-2 ring-blue-500 dark:ring-blue-400"
                  : "hover:bg-neutral-100 dark:hover:bg-neutral-700"
              }`}
              role="button"
              tabIndex={0}
              aria-selected={isSelected}
            >
              <div className="flex-shrink-0 mt-0.5">
                <FileText
                  size={16}
                  className={`${
                    isSelected
                      ? "text-blue-600 dark:text-blue-400"
                      : "text-neutral-500 dark:text-neutral-400"
                  }`}
                />
              </div>

              <div className="flex-1 min-w-0">
                {result.filename && (
                  <div className="text-sm font-medium text-neutral-900 dark:text-neutral-100 truncate">
                    {highlightText(result.filename, searchQuery)}
                  </div>
                )}

                <div className="text-sm text-neutral-600 dark:text-neutral-400 line-clamp-2 mt-1">
                  {highlightText(result.content, searchQuery)}
                </div>

                <div className="flex items-center gap-2 mt-1">
                  {result.source_type && (
                    <span className="text-xs text-neutral-500 dark:text-neutral-500">
                      {result.source_type}
                    </span>
                  )}
                  <span className="text-xs text-neutral-400 dark:text-neutral-600">
                    Score: {result.score.toFixed(3)}
                  </span>
                </div>
              </div>
            </div>
          );
        })}
      </div>

      {results.length > 0 && (
        <div className="text-xs text-neutral-400 dark:text-neutral-600 mt-2 px-3">
          Use ↑↓ arrow keys to navigate, Enter to select
        </div>
      )}
    </div>
  );
}
|
||||
118
web/src/app/chat/chat_search/hooks/useQdrantSearch.ts
Normal file
118
web/src/app/chat/chat_search/hooks/useQdrantSearch.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
import { useState, useEffect, useCallback, useRef } from "react";
|
||||
import { searchQdrantDocuments } from "../qdrantUtils";
|
||||
import { QdrantSearchResult } from "../qdrantInterfaces";
|
||||
|
||||
/** Options accepted by the useQdrantSearch hook. */
interface UseQdrantSearchOptions {
  /** Current query text; searched after trimming and debouncing. */
  searchQuery: string;
  /** When false, no request is made and results are cleared. */
  enabled?: boolean;
  /** Debounce delay before a query is sent (defaults to 500ms). */
  debounceMs?: number;
  /** Maximum number of results requested from the backend. */
  limit?: number;
}
|
||||
|
||||
/** State returned by the useQdrantSearch hook. */
interface UseQdrantSearchResult {
  /** Latest search hits (empty while loading or when disabled). */
  results: QdrantSearchResult[];
  /** True while a request is debouncing or in flight. */
  isLoading: boolean;
  /** Last non-abort error, or null. */
  error: Error | null;
}
|
||||
|
||||
/**
 * Debounced search-as-you-type hook against the Qdrant search endpoint.
 *
 * Every query change aborts any in-flight request, clears results, and
 * schedules a new search after `debounceMs`. A monotonically increasing
 * search id guards against out-of-order responses: only the most recently
 * issued search may write to state.
 */
export function useQdrantSearch({
  searchQuery,
  enabled = true,
  debounceMs = 500,
  limit = 10,
}: UseQdrantSearchOptions): UseQdrantSearchResult {
  const [results, setResults] = useState<QdrantSearchResult[]>([]);
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<Error | null>(null);

  // Pending debounce timer, in-flight abort controller, and the id of the
  // most recent search — all kept in refs so the effect can cancel them.
  const searchTimeoutRef = useRef<NodeJS.Timeout | null>(null);
  const currentAbortController = useRef<AbortController | null>(null);
  const activeSearchIdRef = useRef<number>(0);

  const performSearch = useCallback(
    async (query: string, searchId: number, signal: AbortSignal) => {
      try {
        setIsLoading(true);
        setError(null);

        const response = await searchQdrantDocuments({
          query,
          limit,
          signal,
        });

        // Only update state if this is still the active search
        if (activeSearchIdRef.current === searchId && !signal.aborted) {
          setResults(response.results);
        }
      } catch (err: any) {
        // AbortError is expected when the query changes; ignore it.
        if (
          err?.name !== "AbortError" &&
          activeSearchIdRef.current === searchId
        ) {
          console.error("Error searching Qdrant:", err);
          setError(err);
          setResults([]);
        }
      } finally {
        if (activeSearchIdRef.current === searchId) {
          setIsLoading(false);
        }
      }
    },
    [limit]
  );

  useEffect(() => {
    // Clear any pending timeouts
    if (searchTimeoutRef.current) {
      clearTimeout(searchTimeoutRef.current);
      searchTimeoutRef.current = null;
    }

    // Abort any in-flight requests
    if (currentAbortController.current) {
      currentAbortController.current.abort();
      currentAbortController.current = null;
    }

    // If search is disabled or query is empty, clear results
    if (!enabled || !searchQuery.trim()) {
      setResults([]);
      setIsLoading(false);
      setError(null);
      return;
    }

    // Clear old results immediately when query changes for better UX
    setResults([]);
    setIsLoading(true);

    // Create a new search ID; stale responses compare against this.
    const newSearchId = activeSearchIdRef.current + 1;
    activeSearchIdRef.current = newSearchId;

    // Create abort controller
    const controller = new AbortController();
    currentAbortController.current = controller;

    // Debounce the search
    searchTimeoutRef.current = setTimeout(() => {
      performSearch(searchQuery.trim(), newSearchId, controller.signal);
    }, debounceMs);

    // Cleanup function
    return () => {
      if (searchTimeoutRef.current) {
        clearTimeout(searchTimeoutRef.current);
      }
      controller.abort();
    };
  }, [searchQuery, enabled, debounceMs, performSearch]);

  return {
    results,
    isLoading,
    error,
  };
}
|
||||
20
web/src/app/chat/chat_search/qdrantInterfaces.ts
Normal file
20
web/src/app/chat/chat_search/qdrantInterfaces.ts
Normal file
@@ -0,0 +1,20 @@
|
||||
/** One document hit returned by the backend /qdrant/search endpoint. */
export interface QdrantSearchResult {
  document_id: string;
  /** Matched chunk text. */
  content: string;
  filename: string | null;
  source_type: string | null;
  /** Relevance score reported by Qdrant. */
  score: number;
  metadata: Record<string, any> | null;
}
|
||||
|
||||
/** Full response body of /qdrant/search. */
export interface QdrantSearchResponse {
  results: QdrantSearchResult[];
  /** Echo of the query that was searched. */
  query: string;
  total_results: number;
}
|
||||
|
||||
/** Parameters for searchQdrantDocuments(). */
export interface QdrantSearchRequest {
  query: string;
  /** Optional max result count; omitted from the request when unset. */
  limit?: number;
  /** Optional abort signal so callers can cancel superseded requests. */
  signal?: AbortSignal;
}
|
||||
32
web/src/app/chat/chat_search/qdrantUtils.ts
Normal file
32
web/src/app/chat/chat_search/qdrantUtils.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
import { QdrantSearchRequest, QdrantSearchResponse } from "./qdrantInterfaces";
|
||||
|
||||
const API_BASE_URL = "/api";
|
||||
|
||||
export async function searchQdrantDocuments(
|
||||
params: QdrantSearchRequest
|
||||
): Promise<QdrantSearchResponse> {
|
||||
const queryParams = new URLSearchParams();
|
||||
queryParams.append("query", params.query);
|
||||
|
||||
if (params.limit) {
|
||||
queryParams.append("limit", params.limit.toString());
|
||||
}
|
||||
|
||||
const queryString = `?${queryParams.toString()}`;
|
||||
|
||||
const response = await fetch(`${API_BASE_URL}/qdrant/search${queryString}`, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
signal: params.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Failed to search Qdrant documents: ${response.statusText}`
|
||||
);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
47
web/src/app/chat/chat_search/utils/highlightText.tsx
Normal file
47
web/src/app/chat/chat_search/utils/highlightText.tsx
Normal file
@@ -0,0 +1,47 @@
|
||||
import React from "react";
|
||||
|
||||
/**
|
||||
* Highlights matching query terms in text.
|
||||
* Returns JSX with highlighted spans.
|
||||
*/
|
||||
/**
 * Highlights matching query terms in text.
 * Returns JSX with highlighted <mark> spans for matching segments and plain
 * <span> wrappers (with stable indices as keys) for everything else.
 */
export function highlightText(text: string, query: string): React.ReactNode {
  if (!query || !text) {
    return text;
  }

  // Split query into individual terms, dropping empty strings. Without the
  // filter, a whitespace-only query yields [""], which becomes an empty
  // regex alternative that matches between every character and shreds the
  // rendered text.
  const terms = query.toLowerCase().trim().split(/\s+/).filter(Boolean);
  if (terms.length === 0) {
    return text;
  }

  // Escape special regex characters
  const escapeRegex = (str: string) =>
    str.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");

  // Create regex pattern that matches any of the terms
  const pattern = terms.map(escapeRegex).join("|");
  const regex = new RegExp(`(${pattern})`, "gi");

  // Split text by matches; the capturing group keeps the matched segments
  // in the resulting array.
  const parts = text.split(regex);

  return (
    <>
      {parts.map((part, index) => {
        // Terms are already lowercased, so a single toLowerCase() on the
        // part suffices for the membership test.
        const isMatch = terms.includes(part.toLowerCase());

        return isMatch ? (
          <mark
            key={index}
            className="bg-yellow-200 dark:bg-yellow-900/50 text-inherit font-medium rounded px-0.5"
          >
            {part}
          </mark>
        ) : (
          <span key={index}>{part}</span>
        );
      })}
    </>
  );
}
|
||||
@@ -38,6 +38,9 @@ import {
|
||||
getIconForAction,
|
||||
hasSearchToolsAvailable,
|
||||
} from "../../services/actionUtils";
|
||||
import { useQdrantSearch } from "../../chat_search/hooks/useQdrantSearch";
|
||||
import { InlineSearchResults } from "./InlineSearchResults";
|
||||
import { QdrantSearchResult } from "../../chat_search/qdrantInterfaces";
|
||||
|
||||
const MAX_INPUT_HEIGHT = 200;
|
||||
|
||||
@@ -214,6 +217,38 @@ function ChatInputBarInner({
|
||||
const { data: federatedConnectorsData } = useFederatedConnectors();
|
||||
const [showPrompts, setShowPrompts] = useState(false);
|
||||
|
||||
// Search-as-you-type for documents
|
||||
const [showDocumentSearch, setShowDocumentSearch] = useState(false);
|
||||
const [searchResultIndex, setSearchResultIndex] = useState(0);
|
||||
|
||||
const { results: documentSearchResults, isLoading: isSearchingDocuments } =
|
||||
useQdrantSearch({
|
||||
searchQuery: message,
|
||||
enabled: message.length > 0 && !showPrompts, // Don't search when showing prompts
|
||||
debounceMs: 300, // Faster for inline search
|
||||
limit: 5, // Fewer results for inline display
|
||||
});
|
||||
|
||||
// Show search results when we have query and results
|
||||
useEffect(() => {
|
||||
const shouldShow =
|
||||
message.length > 0 && !showPrompts && documentSearchResults.length > 0;
|
||||
console.log("[Search Debug]", {
|
||||
message: message.substring(0, 20),
|
||||
messageLength: message.length,
|
||||
showPrompts,
|
||||
resultsLength: documentSearchResults.length,
|
||||
shouldShow,
|
||||
});
|
||||
|
||||
if (shouldShow) {
|
||||
setShowDocumentSearch(true);
|
||||
} else {
|
||||
setShowDocumentSearch(false);
|
||||
setSearchResultIndex(0);
|
||||
}
|
||||
}, [message, showPrompts, documentSearchResults.length]);
|
||||
|
||||
// Memoize availableSources to prevent unnecessary re-renders
|
||||
const memoizedAvailableSources = useMemo(
|
||||
() => [
|
||||
@@ -230,11 +265,27 @@ function ChatInputBarInner({
|
||||
setTabbingIconIndex(0);
|
||||
};
|
||||
|
||||
const hideDocumentSearch = useCallback(() => {
|
||||
setShowDocumentSearch(false);
|
||||
setSearchResultIndex(0);
|
||||
}, []);
|
||||
|
||||
const updateInputPrompt = (prompt: InputPrompt) => {
|
||||
hidePrompts();
|
||||
setMessage(`${prompt.content}`);
|
||||
};
|
||||
|
||||
const handleSelectDocument = useCallback(
|
||||
(result: QdrantSearchResult) => {
|
||||
// Insert document reference into message
|
||||
const docReference = result.filename || result.document_id;
|
||||
const newMessage = `${message}\n\nRef: ${docReference}`;
|
||||
setMessage(newMessage);
|
||||
hideDocumentSearch();
|
||||
},
|
||||
[message, setMessage, hideDocumentSearch]
|
||||
);
|
||||
|
||||
const handlePromptInput = useCallback(
|
||||
(text: string) => {
|
||||
if (!text.startsWith("/")) {
|
||||
@@ -310,6 +361,33 @@ function ChatInputBarInner({
|
||||
}, [selectedAssistant.tools]);
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
// Handle document search navigation
|
||||
if (showDocumentSearch) {
|
||||
if (e.key === "ArrowDown") {
|
||||
e.preventDefault();
|
||||
setSearchResultIndex((prev) =>
|
||||
Math.min(prev + 1, documentSearchResults.length - 1)
|
||||
);
|
||||
return;
|
||||
} else if (e.key === "ArrowUp") {
|
||||
e.preventDefault();
|
||||
setSearchResultIndex((prev) => Math.max(prev - 1, 0));
|
||||
return;
|
||||
} else if (e.key === "Enter" && documentSearchResults.length > 0) {
|
||||
e.preventDefault();
|
||||
const selected = documentSearchResults[searchResultIndex];
|
||||
if (selected) {
|
||||
handleSelectDocument(selected);
|
||||
}
|
||||
return;
|
||||
} else if (e.key === "Escape") {
|
||||
e.preventDefault();
|
||||
hideDocumentSearch();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Handle prompt navigation
|
||||
if (showPrompts && (e.key === "Tab" || e.key == "Enter")) {
|
||||
e.preventDefault();
|
||||
|
||||
@@ -345,7 +423,18 @@ function ChatInputBarInner({
|
||||
};
|
||||
|
||||
return (
|
||||
<div id="onyx-chat-input" className="max-w-full w-[50rem]">
|
||||
<div id="onyx-chat-input" className="max-w-full w-[50rem] relative">
|
||||
{/* Document search results dropdown */}
|
||||
{showDocumentSearch && !showPrompts && (
|
||||
<InlineSearchResults
|
||||
results={documentSearchResults}
|
||||
searchQuery={message}
|
||||
selectedIndex={searchResultIndex}
|
||||
onSelectResult={handleSelectDocument}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Prompt suggestions dropdown */}
|
||||
{showPrompts && user?.preferences?.shortcut_enabled && (
|
||||
<div className="text-sm absolute inset-x-0 top-0 w-full transform -translate-y-full">
|
||||
<div className="rounded-lg overflow-y-auto max-h-[200px] py-1.5 bg-background-neutral-01 border border-border-01 shadow-lg mx-2 px-1.5 mt-2 rounded z-10">
|
||||
@@ -432,6 +521,7 @@ function ChatInputBarInner({
|
||||
if (
|
||||
event.key === "Enter" &&
|
||||
!showPrompts &&
|
||||
!showDocumentSearch && // Don't submit when search results are showing
|
||||
!event.shiftKey &&
|
||||
!(event.nativeEvent as any).isComposing
|
||||
) {
|
||||
|
||||
91
web/src/app/chat/components/input/InlineSearchResults.tsx
Normal file
91
web/src/app/chat/components/input/InlineSearchResults.tsx
Normal file
@@ -0,0 +1,91 @@
|
||||
import React from "react";
|
||||
import { FileText } from "lucide-react";
|
||||
import { QdrantSearchResult } from "../../chat_search/qdrantInterfaces";
|
||||
import { highlightText } from "../../chat_search/utils/highlightText";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface InlineSearchResultsProps {
|
||||
results: QdrantSearchResult[];
|
||||
searchQuery: string;
|
||||
selectedIndex: number;
|
||||
onSelectResult: (result: QdrantSearchResult) => void;
|
||||
}
|
||||
|
||||
export function InlineSearchResults({
|
||||
results,
|
||||
searchQuery,
|
||||
selectedIndex,
|
||||
onSelectResult,
|
||||
}: InlineSearchResultsProps) {
|
||||
if (results.length === 0) {
|
||||
return (
|
||||
<div className="text-sm absolute inset-x-0 top-0 w-full transform -translate-y-full">
|
||||
<div className="rounded-lg py-2 px-3 bg-background-neutral-01 border border-border-01 shadow-lg mx-2 mt-2">
|
||||
<p className="text-text-03 text-sm">
|
||||
No documents found for "{searchQuery}"
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="text-sm absolute inset-x-0 top-0 w-full transform -translate-y-full">
|
||||
<div className="rounded-lg overflow-y-auto max-h-[300px] py-1.5 bg-background-neutral-01 border border-border-01 shadow-lg mx-2 mt-2 z-10">
|
||||
<div className="px-2 py-1 text-xs font-semibold text-text-03 uppercase tracking-wide">
|
||||
Documents ({results.length})
|
||||
</div>
|
||||
|
||||
{results.map((result, index) => {
|
||||
const isSelected = index === selectedIndex;
|
||||
|
||||
return (
|
||||
<button
|
||||
key={result.document_id}
|
||||
className={cn(
|
||||
"w-full px-2 py-1.5 flex items-start gap-2 cursor-pointer rounded",
|
||||
isSelected && "bg-background-neutral-02",
|
||||
"hover:bg-background-neutral-02"
|
||||
)}
|
||||
onClick={() => onSelectResult(result)}
|
||||
>
|
||||
<div className="flex-shrink-0 mt-0.5">
|
||||
<FileText
|
||||
size={14}
|
||||
className={cn(isSelected ? "text-blue-600" : "text-text-03")}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="flex-1 min-w-0 text-left">
|
||||
{result.filename && (
|
||||
<div className="text-xs font-medium text-text-01 truncate">
|
||||
{highlightText(result.filename, searchQuery)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="text-xs text-text-03 line-clamp-2 mt-0.5">
|
||||
{highlightText(result.content, searchQuery)}
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-2 mt-0.5">
|
||||
{result.source_type && (
|
||||
<span className="text-[10px] text-text-04">
|
||||
{result.source_type}
|
||||
</span>
|
||||
)}
|
||||
<span className="text-[10px] text-text-04">
|
||||
Score: {result.score.toFixed(3)}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
|
||||
<div className="px-2 py-1 text-[10px] text-text-04 border-t border-border-01 mt-1">
|
||||
Use ↑↓ to navigate, Enter to select, Esc to close
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user