mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 16:25:45 +00:00
Compare commits
4 Commits
fix_auth
...
fix_textse
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a475f1b328 | ||
|
|
e37bb2209e | ||
|
|
5aba739aee | ||
|
|
906b29bd9b |
209
.github/workflows/pr-mit-integration-tests.yml
vendored
209
.github/workflows/pr-mit-integration-tests.yml
vendored
@@ -1,209 +0,0 @@
|
||||
name: Run MIT Integration Tests v2
|
||||
concurrency:
|
||||
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches:
|
||||
- main
|
||||
- "release/**"
|
||||
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
|
||||
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
|
||||
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
|
||||
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Web Docker image
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-web-server:latest
|
||||
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
|
||||
|
||||
# we use the runs-on cache for docker builds
|
||||
# in conjunction with runs-on runners, it has better speed and unlimited caching
|
||||
# https://runs-on.com/caching/s3-cache-for-github-actions/
|
||||
# https://runs-on.com/caching/docker/
|
||||
# https://github.com/moby/buildkit#s3-cache-experimental
|
||||
|
||||
# images are built and run locally for testing purposes. Not pushed.
|
||||
- name: Build Backend Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-backend:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build Model Server Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-model-server:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
- name: Build integration test Docker image
|
||||
uses: ./.github/actions/custom-build-and-push
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/amd64
|
||||
tags: onyxdotapp/onyx-integration:test
|
||||
push: false
|
||||
load: true
|
||||
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
|
||||
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
AUTH_TYPE=basic \
|
||||
POSTGRES_POOL_PRE_PING=true \
|
||||
POSTGRES_USE_NULL_POOL=true \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
docker logs -f onyx-stack-api_server-1 &
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Start Mock Services
|
||||
run: |
|
||||
cd backend/tests/integration/mock_services
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
echo "Running integration tests..."
|
||||
docker run --rm --network onyx-stack_default \
|
||||
--name test-runner \
|
||||
-e POSTGRES_HOST=relational_db \
|
||||
-e POSTGRES_USER=postgres \
|
||||
-e POSTGRES_PASSWORD=password \
|
||||
-e POSTGRES_DB=postgres \
|
||||
-e POSTGRES_POOL_PRE_PING=true \
|
||||
-e POSTGRES_USE_NULL_POOL=true \
|
||||
-e VESPA_HOST=index \
|
||||
-e REDIS_HOST=cache \
|
||||
-e API_SERVER_HOST=api_server \
|
||||
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
|
||||
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
|
||||
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
|
||||
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
|
||||
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
onyxdotapp/onyx-integration:test \
|
||||
/app/tests/integration/tests \
|
||||
/app/tests/integration/connector_job_tests
|
||||
continue-on-error: true
|
||||
id: run_tests
|
||||
|
||||
- name: Check test results
|
||||
run: |
|
||||
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
|
||||
echo "Integration tests failed. Exiting with error."
|
||||
exit 1
|
||||
else
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# Always gather logs BEFORE "down":
|
||||
- name: Dump API server logs
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
# ------------------------------------------------------------
|
||||
|
||||
- name: Stop Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
|
||||
@@ -1,50 +0,0 @@
|
||||
"""enable contextual retrieval
|
||||
|
||||
Revision ID: 8e1ac4f39a9f
|
||||
Revises: 3781a5eb12cb
|
||||
Create Date: 2024-12-20 13:29:09.918661
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8e1ac4f39a9f"
|
||||
down_revision = "3781a5eb12cb"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"enable_contextual_rag",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="false",
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"contextual_rag_llm_name",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"search_settings",
|
||||
sa.Column(
|
||||
"contextual_rag_llm_provider",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("search_settings", "enable_contextual_rag")
|
||||
op.drop_column("search_settings", "contextual_rag_llm_name")
|
||||
op.drop_column("search_settings", "contextual_rag_llm_provider")
|
||||
@@ -44,7 +44,7 @@ async def _get_tenant_id_from_request(
|
||||
Attempt to extract tenant_id from:
|
||||
1) The API key header
|
||||
2) The Redis-based token (stored in Cookie: fastapiusersauth)
|
||||
3) The anonymous user cookie
|
||||
3) Reset token cookie
|
||||
Fallback: POSTGRES_DEFAULT_SCHEMA
|
||||
"""
|
||||
# Check for API key
|
||||
@@ -52,55 +52,41 @@ async def _get_tenant_id_from_request(
|
||||
if tenant_id is not None:
|
||||
return tenant_id
|
||||
|
||||
# Check for anonymous user cookie
|
||||
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
|
||||
if anonymous_user_cookie:
|
||||
try:
|
||||
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
|
||||
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
|
||||
# Continue and attempt to authenticate
|
||||
|
||||
try:
|
||||
# Look up token data in Redis
|
||||
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
|
||||
if token_data:
|
||||
tenant_id_from_payload = token_data.get(
|
||||
"tenant_id", POSTGRES_DEFAULT_SCHEMA
|
||||
if not token_data:
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
)
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
else None
|
||||
)
|
||||
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
if tenant_id and not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
|
||||
# Check for anonymous user cookie
|
||||
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
|
||||
if anonymous_user_cookie:
|
||||
try:
|
||||
anonymous_user_data = decode_anonymous_user_jwt_token(
|
||||
anonymous_user_cookie
|
||||
)
|
||||
tenant_id = anonymous_user_data.get(
|
||||
"tenant_id", POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
if not tenant_id or not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid tenant ID format"
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
|
||||
# Continue and attempt to authenticate
|
||||
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
# Since token_data.get() can return None, ensure we have a string
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
||||
|
||||
@@ -94,7 +94,6 @@ async def get_or_provision_tenant(
|
||||
# Notify control plane if we have created / assigned a new tenant
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
return tenant_id
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -360,6 +360,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reason="Password must contain at least one special character from the following set: "
|
||||
f"{PASSWORD_SPECIAL_CHARS}."
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
async def oauth_callback(
|
||||
|
||||
@@ -153,8 +153,6 @@ def _apply_pruning(
|
||||
# remove docs that are explicitly marked as not for QA
|
||||
sections = _remove_sections_to_ignore(sections=sections)
|
||||
|
||||
section_idx_token_count: dict[int, int] = {}
|
||||
|
||||
final_section_ind = None
|
||||
total_tokens = 0
|
||||
for ind, section in enumerate(sections):
|
||||
@@ -204,20 +202,10 @@ def _apply_pruning(
|
||||
section_token_count = DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
total_tokens += section_token_count
|
||||
section_idx_token_count[ind] = section_token_count
|
||||
|
||||
if total_tokens > token_limit:
|
||||
final_section_ind = ind
|
||||
break
|
||||
|
||||
try:
|
||||
logger.debug(f"Number of documents after pruning: {ind + 1}")
|
||||
logger.debug("Number of tokens per document (pruned):")
|
||||
for x, y in section_idx_token_count.items():
|
||||
logger.debug(f"{x + 1}: {y}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging prune statistics: {e}")
|
||||
|
||||
if final_section_ind is not None:
|
||||
if is_manually_selected_docs or use_sections:
|
||||
if final_section_ind != len(sections) - 1:
|
||||
@@ -374,26 +362,6 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
try:
|
||||
num_original_sections = len(sections)
|
||||
num_original_document_ids = len(
|
||||
set([section.center_chunk.document_id for section in sections])
|
||||
)
|
||||
num_merged_sections = len(new_sections)
|
||||
num_merged_document_ids = len(
|
||||
set([section.center_chunk.document_id for section in new_sections])
|
||||
)
|
||||
logger.debug(
|
||||
f"Merged {num_original_sections} sections from {num_original_document_ids} documents "
|
||||
f"into {num_merged_sections} new sections in {num_merged_document_ids} documents"
|
||||
)
|
||||
|
||||
logger.debug("Number of chunks per document (new ranking):")
|
||||
for x, y in enumerate(new_sections):
|
||||
logger.debug(f"{x + 1}: {len(y.chunks)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging merge statistics: {e}")
|
||||
|
||||
return new_sections
|
||||
|
||||
|
||||
|
||||
@@ -495,11 +495,6 @@ NUM_SECONDARY_INDEXING_WORKERS = int(
|
||||
ENABLE_MULTIPASS_INDEXING = (
|
||||
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
|
||||
)
|
||||
# Enable contextual retrieval
|
||||
ENABLE_CONTEXTUAL_RAG = os.environ.get("ENABLE_CONTEXTUAL_RAG", "").lower() == "true"
|
||||
|
||||
DEFAULT_CONTEXTUAL_RAG_LLM_NAME = "gpt-4o-mini"
|
||||
DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER = "DevEnvPresetOpenAI"
|
||||
# Finer grained chunking for more detail retention
|
||||
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
|
||||
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
|
||||
@@ -541,17 +536,6 @@ MAX_FILE_SIZE_BYTES = int(
|
||||
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
|
||||
) # 2GB in bytes
|
||||
|
||||
# Use document summary for contextual rag
|
||||
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
|
||||
# Use chunk summary for contextual rag
|
||||
USE_CHUNK_SUMMARY = os.environ.get("USE_CHUNK_SUMMARY", "true").lower() == "true"
|
||||
# Average summary embeddings for contextual rag (not yet implemented)
|
||||
AVERAGE_SUMMARY_EMBEDDINGS = (
|
||||
os.environ.get("AVERAGE_SUMMARY_EMBEDDINGS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
#####
|
||||
|
||||
@@ -30,7 +30,6 @@ from onyx.file_processing.file_validation import is_valid_image_type
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.lazy import lazy_eval
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -77,26 +76,6 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
|
||||
return is_valid_image_type(mime_type)
|
||||
|
||||
|
||||
def download_request(service: GoogleDriveService, file_id: str) -> bytes:
|
||||
"""
|
||||
Download the file from Google Drive.
|
||||
"""
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logger.warning(f"Failed to download {file_id}")
|
||||
return bytes()
|
||||
return response
|
||||
|
||||
|
||||
def _download_and_extract_sections_basic(
|
||||
file: dict[str, str],
|
||||
service: GoogleDriveService,
|
||||
@@ -135,31 +114,41 @@ def _download_and_extract_sections_basic(
|
||||
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
response_call = lazy_eval(lambda: download_request(service, file_id))
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logger.warning(f"Failed to download {file_name}")
|
||||
return []
|
||||
|
||||
# Process based on mime type
|
||||
if mime_type == "text/plain":
|
||||
text = response_call().decode("utf-8")
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
|
||||
text, _ = docx_to_text_and_images(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
):
|
||||
text = xlsx_to_text(io.BytesIO(response_call()))
|
||||
text = xlsx_to_text(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
):
|
||||
text = pptx_to_text(io.BytesIO(response_call()))
|
||||
text = pptx_to_text(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif is_gdrive_image_mime_type(mime_type):
|
||||
@@ -169,7 +158,7 @@ def _download_and_extract_sections_basic(
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=response_call(),
|
||||
image_data=response,
|
||||
file_name=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
@@ -182,7 +171,7 @@ def _download_and_extract_sections_basic(
|
||||
return sections
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
|
||||
pdf_sections: list[TextSection | ImageSection] = [
|
||||
TextSection(link=link, text=text)
|
||||
]
|
||||
@@ -205,15 +194,8 @@ def _download_and_extract_sections_basic(
|
||||
|
||||
else:
|
||||
# For unsupported file types, try to extract text
|
||||
if mime_type in [
|
||||
"application/vnd.google-apps.video",
|
||||
"application/vnd.google-apps.audio",
|
||||
"application/zip",
|
||||
]:
|
||||
return []
|
||||
# For unsupported file types, try to extract text
|
||||
try:
|
||||
text = extract_file_text(io.BytesIO(response_call()), file_name)
|
||||
text = extract_file_text(io.BytesIO(response), file_name)
|
||||
return [TextSection(link=link, text=text)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract text from {file_name}: {e}")
|
||||
|
||||
@@ -75,7 +75,7 @@ class HighspotClient:
|
||||
|
||||
self.key = key
|
||||
self.secret = secret
|
||||
self.base_url = base_url.rstrip("/") + "/"
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
# Set up session with retry logic
|
||||
|
||||
@@ -163,9 +163,6 @@ class DocumentBase(BaseModel):
|
||||
attributes.append(k + INDEX_SEPARATOR + v)
|
||||
return attributes
|
||||
|
||||
def get_text_content(self) -> str:
|
||||
return " ".join([section.text for section in self.sections if section.text])
|
||||
|
||||
|
||||
class Document(DocumentBase):
|
||||
"""Used for Onyx ingestion api, the ID is required"""
|
||||
|
||||
@@ -60,7 +60,7 @@ class SearchSettingsCreationRequest(InferenceSettings, IndexingSetting):
|
||||
inference_settings = InferenceSettings.from_db_model(search_settings)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
return cls(**inference_settings.model_dump(), **indexing_setting.model_dump())
|
||||
return cls(**inference_settings.dict(), **indexing_setting.dict())
|
||||
|
||||
|
||||
class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
@@ -80,9 +80,6 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
# Whether switching to this model requires re-indexing
|
||||
background_reindex_enabled=search_settings.background_reindex_enabled,
|
||||
enable_contextual_rag=search_settings.enable_contextual_rag,
|
||||
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
|
||||
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
|
||||
# Reranking Details
|
||||
rerank_model_name=search_settings.rerank_model_name,
|
||||
rerank_provider_type=search_settings.rerank_provider_type,
|
||||
@@ -221,8 +218,6 @@ class InferenceChunk(BaseChunk):
|
||||
# to specify that a set of words should be highlighted. For example:
|
||||
# ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
|
||||
match_highlights: list[str]
|
||||
doc_summary: str
|
||||
chunk_context: str
|
||||
|
||||
# when the doc was last updated
|
||||
updated_at: datetime | None
|
||||
|
||||
@@ -196,21 +196,9 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
|
||||
RETURN_SEPARATOR
|
||||
)
|
||||
|
||||
def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str:
|
||||
# remove document summary
|
||||
if chunk.content.startswith(chunk.doc_summary):
|
||||
chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip()
|
||||
# remove chunk context
|
||||
if chunk.content.endswith(chunk.chunk_context):
|
||||
chunk.content = chunk.content[
|
||||
: len(chunk.content) - len(chunk.chunk_context)
|
||||
].rstrip()
|
||||
return chunk.content
|
||||
|
||||
for chunk in chunks:
|
||||
chunk.content = _remove_title(chunk)
|
||||
chunk.content = _remove_metadata_suffix(chunk)
|
||||
chunk.content = _remove_contextual_rag(chunk)
|
||||
|
||||
return [chunk.to_inference_chunk() for chunk in chunks]
|
||||
|
||||
|
||||
@@ -791,15 +791,6 @@ class SearchSettings(Base):
|
||||
# Mini and Large Chunks (large chunk also checks for model max context)
|
||||
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
# Contextual RAG
|
||||
enable_contextual_rag: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Contextual RAG LLM
|
||||
contextual_rag_llm_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
contextual_rag_llm_provider: Mapped[str | None] = mapped_column(
|
||||
String, nullable=True
|
||||
)
|
||||
|
||||
multilingual_expansion: Mapped[list[str]] = mapped_column(
|
||||
postgresql.ARRAY(String), default=[]
|
||||
)
|
||||
|
||||
@@ -62,9 +62,6 @@ def create_search_settings(
|
||||
multipass_indexing=search_settings.multipass_indexing,
|
||||
embedding_precision=search_settings.embedding_precision,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
enable_contextual_rag=search_settings.enable_contextual_rag,
|
||||
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
|
||||
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
|
||||
multilingual_expansion=search_settings.multilingual_expansion,
|
||||
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
|
||||
rerank_model_name=search_settings.rerank_model_name,
|
||||
@@ -322,7 +319,6 @@ def get_old_default_embedding_model() -> IndexingSetting:
|
||||
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
|
||||
index_name="danswer_chunk",
|
||||
multipass_indexing=False,
|
||||
enable_contextual_rag=False,
|
||||
api_url=None,
|
||||
)
|
||||
|
||||
@@ -337,6 +333,5 @@ def get_new_default_embedding_model() -> IndexingSetting:
|
||||
passage_prefix=ASYM_PASSAGE_PREFIX,
|
||||
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
|
||||
multipass_indexing=False,
|
||||
enable_contextual_rag=False,
|
||||
api_url=None,
|
||||
)
|
||||
|
||||
@@ -98,12 +98,6 @@ schema DANSWER_CHUNK_NAME {
|
||||
field metadata type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
field chunk_context type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
field doc_summary type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
field metadata_suffix type string {
|
||||
indexing: summary | attribute
|
||||
}
|
||||
|
||||
@@ -24,11 +24,9 @@ from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
|
||||
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
|
||||
from onyx.document_index.vespa_constants import BLURB
|
||||
from onyx.document_index.vespa_constants import BOOST
|
||||
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
|
||||
from onyx.document_index.vespa_constants import CHUNK_ID
|
||||
from onyx.document_index.vespa_constants import CONTENT
|
||||
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
|
||||
from onyx.document_index.vespa_constants import DOC_SUMMARY
|
||||
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
@@ -128,8 +126,7 @@ def _vespa_hit_to_inference_chunk(
|
||||
return InferenceChunkUncleaned(
|
||||
chunk_id=fields[CHUNK_ID],
|
||||
blurb=fields.get(BLURB, ""), # Unused
|
||||
content=fields[CONTENT], # Includes extra title prefix and metadata suffix;
|
||||
# also sometimes context for contextual rag
|
||||
content=fields[CONTENT], # Includes extra title prefix and metadata suffix
|
||||
source_links=source_links_dict or {0: ""},
|
||||
section_continuation=fields[SECTION_CONTINUATION],
|
||||
document_id=fields[DOCUMENT_ID],
|
||||
@@ -146,8 +143,6 @@ def _vespa_hit_to_inference_chunk(
|
||||
large_chunk_reference_ids=fields.get(LARGE_CHUNK_REFERENCE_IDS, []),
|
||||
metadata=metadata,
|
||||
metadata_suffix=fields.get(METADATA_SUFFIX),
|
||||
doc_summary=fields.get(DOC_SUMMARY, ""),
|
||||
chunk_context=fields.get(CHUNK_CONTEXT, ""),
|
||||
match_highlights=match_highlights,
|
||||
updated_at=updated_at,
|
||||
)
|
||||
@@ -350,19 +345,6 @@ def query_vespa(
|
||||
filtered_hits = [hit for hit in hits if hit["fields"].get(CONTENT) is not None]
|
||||
|
||||
inference_chunks = [_vespa_hit_to_inference_chunk(hit) for hit in filtered_hits]
|
||||
|
||||
try:
|
||||
num_retrieved_inference_chunks = len(inference_chunks)
|
||||
num_retrieved_document_ids = len(
|
||||
set([chunk.document_id for chunk in inference_chunks])
|
||||
)
|
||||
logger.debug(
|
||||
f"Retrieved {num_retrieved_inference_chunks} inference chunks for {num_retrieved_document_ids} documents"
|
||||
)
|
||||
except Exception as e:
|
||||
# Debug logging only, should not fail the retrieval
|
||||
logger.error(f"Error logging retrieval statistics: {e}")
|
||||
|
||||
# Good Debugging Spot
|
||||
return inference_chunks
|
||||
|
||||
|
||||
@@ -187,7 +187,7 @@ class VespaIndex(DocumentIndex):
|
||||
) -> None:
|
||||
if MULTI_TENANT:
|
||||
logger.info(
|
||||
"Skipping Vespa index setup for multitenant (would wipe all indices)"
|
||||
"Skipping Vespa index seup for multitenant (would wipe all indices)"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -25,11 +25,9 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
|
||||
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
|
||||
from onyx.document_index.vespa_constants import BLURB
|
||||
from onyx.document_index.vespa_constants import BOOST
|
||||
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
|
||||
from onyx.document_index.vespa_constants import CHUNK_ID
|
||||
from onyx.document_index.vespa_constants import CONTENT
|
||||
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
|
||||
from onyx.document_index.vespa_constants import DOC_SUMMARY
|
||||
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
@@ -176,7 +174,7 @@ def _index_vespa_chunk(
|
||||
# For the BM25 index, the keyword suffix is used, the vector is already generated with the more
|
||||
# natural language representation of the metadata section
|
||||
CONTENT: remove_invalid_unicode_chars(
|
||||
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}"
|
||||
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}"
|
||||
),
|
||||
# This duplication of `content` is needed for keyword highlighting
|
||||
# Note that it's not exactly the same as the actual content
|
||||
@@ -191,8 +189,6 @@ def _index_vespa_chunk(
|
||||
# Save as a list for efficient extraction as an Attribute
|
||||
METADATA_LIST: metadata_list,
|
||||
METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword),
|
||||
CHUNK_CONTEXT: chunk.chunk_context,
|
||||
DOC_SUMMARY: chunk.doc_summary,
|
||||
EMBEDDINGS: embeddings_name_vector_map,
|
||||
TITLE_EMBEDDING: chunk.title_embedding,
|
||||
DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),
|
||||
|
||||
@@ -71,8 +71,6 @@ LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
|
||||
METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
METADATA_SUFFIX = "metadata_suffix"
|
||||
DOC_SUMMARY = "doc_summary"
|
||||
CHUNK_CONTEXT = "chunk_context"
|
||||
BOOST = "boost"
|
||||
AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor"
|
||||
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
|
||||
@@ -108,8 +106,6 @@ YQL_BASE = (
|
||||
f"{LARGE_CHUNK_REFERENCE_IDS}, "
|
||||
f"{METADATA}, "
|
||||
f"{METADATA_SUFFIX}, "
|
||||
f"{DOC_SUMMARY}, "
|
||||
f"{CHUNK_CONTEXT}, "
|
||||
f"{CONTENT_SUMMARY} "
|
||||
f"from {{index_name}} where "
|
||||
)
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS
|
||||
from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
|
||||
from onyx.configs.app_configs import MINI_CHUNK_SIZE
|
||||
from onyx.configs.app_configs import SKIP_METADATA_IN_CHUNK
|
||||
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
|
||||
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import RETURN_SEPARATOR
|
||||
from onyx.configs.constants import SECTION_SEPARATOR
|
||||
@@ -16,7 +13,6 @@ from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.llm.utils import MAX_CONTEXT_TOKENS
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import clean_text
|
||||
@@ -86,9 +82,6 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar
|
||||
large_chunk_reference_ids=[chunk.chunk_id for chunk in chunks],
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=large_chunk_id,
|
||||
chunk_context="",
|
||||
doc_summary="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
)
|
||||
|
||||
offset = 0
|
||||
@@ -127,7 +120,6 @@ class Chunker:
|
||||
tokenizer: BaseTokenizer,
|
||||
enable_multipass: bool = False,
|
||||
enable_large_chunks: bool = False,
|
||||
enable_contextual_rag: bool = False,
|
||||
blurb_size: int = BLURB_SIZE,
|
||||
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
|
||||
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
@@ -141,20 +133,9 @@ class Chunker:
|
||||
self.chunk_token_limit = chunk_token_limit
|
||||
self.enable_multipass = enable_multipass
|
||||
self.enable_large_chunks = enable_large_chunks
|
||||
self.enable_contextual_rag = enable_contextual_rag
|
||||
if enable_contextual_rag:
|
||||
assert (
|
||||
USE_CHUNK_SUMMARY or USE_DOCUMENT_SUMMARY
|
||||
), "Contextual RAG requires at least one of chunk summary and document summary enabled"
|
||||
self.default_contextual_rag_reserved_tokens = MAX_CONTEXT_TOKENS * (
|
||||
int(USE_CHUNK_SUMMARY) + int(USE_DOCUMENT_SUMMARY)
|
||||
)
|
||||
self.tokenizer = tokenizer
|
||||
self.callback = callback
|
||||
|
||||
self.max_context = 0
|
||||
self.prompt_tokens = 0
|
||||
|
||||
self.blurb_splitter = SentenceSplitter(
|
||||
tokenizer=tokenizer.tokenize,
|
||||
chunk_size=blurb_size,
|
||||
@@ -240,9 +221,6 @@ class Chunker:
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=self._get_mini_chunk_texts(text),
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0, # set per-document in _handle_single_document
|
||||
)
|
||||
chunks_list.append(new_chunk)
|
||||
|
||||
@@ -310,7 +288,7 @@ class Chunker:
|
||||
continue
|
||||
|
||||
# CASE 2: Normal text section
|
||||
section_token_count = len(self.tokenizer.encode(section_text))
|
||||
section_token_count = len(self.tokenizer.tokenize(section_text))
|
||||
|
||||
# If the section is large on its own, split it separately
|
||||
if section_token_count > content_token_limit:
|
||||
@@ -333,7 +311,8 @@ class Chunker:
|
||||
# If even the split_text is bigger than strict limit, further split
|
||||
if (
|
||||
STRICT_CHUNK_TOKEN_LIMIT
|
||||
and len(self.tokenizer.encode(split_text)) > content_token_limit
|
||||
and len(self.tokenizer.tokenize(split_text))
|
||||
> content_token_limit
|
||||
):
|
||||
smaller_chunks = self._split_oversized_chunk(
|
||||
split_text, content_token_limit
|
||||
@@ -363,10 +342,10 @@ class Chunker:
|
||||
continue
|
||||
|
||||
# If we can still fit this section into the current chunk, do so
|
||||
current_token_count = len(self.tokenizer.encode(chunk_text))
|
||||
current_token_count = len(self.tokenizer.tokenize(chunk_text))
|
||||
current_offset = len(shared_precompare_cleanup(chunk_text))
|
||||
next_section_tokens = (
|
||||
len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count
|
||||
len(self.tokenizer.tokenize(SECTION_SEPARATOR)) + section_token_count
|
||||
)
|
||||
|
||||
if next_section_tokens + current_token_count <= content_token_limit:
|
||||
@@ -414,7 +393,7 @@ class Chunker:
|
||||
# Title prep
|
||||
title = self._extract_blurb(document.get_title_for_document_index() or "")
|
||||
title_prefix = title + RETURN_SEPARATOR if title else ""
|
||||
title_tokens = len(self.tokenizer.encode(title_prefix))
|
||||
title_tokens = len(self.tokenizer.tokenize(title_prefix))
|
||||
|
||||
# Metadata prep
|
||||
metadata_suffix_semantic = ""
|
||||
@@ -427,50 +406,15 @@ class Chunker:
|
||||
) = _get_metadata_suffix_for_document_index(
|
||||
document.metadata, include_separator=True
|
||||
)
|
||||
metadata_tokens = len(self.tokenizer.encode(metadata_suffix_semantic))
|
||||
metadata_tokens = len(self.tokenizer.tokenize(metadata_suffix_semantic))
|
||||
|
||||
# If metadata is too large, skip it in the semantic content
|
||||
if metadata_tokens >= self.chunk_token_limit * MAX_METADATA_PERCENTAGE:
|
||||
metadata_suffix_semantic = ""
|
||||
metadata_tokens = 0
|
||||
|
||||
single_chunk_fits = True
|
||||
doc_token_count = 0
|
||||
if self.enable_contextual_rag:
|
||||
doc_content = document.get_text_content()
|
||||
tokenized_doc = self.tokenizer.tokenize(doc_content)
|
||||
doc_token_count = len(tokenized_doc)
|
||||
|
||||
# check if doc + title + metadata fits in a single chunk. If so, no need for contextual RAG
|
||||
single_chunk_fits = (
|
||||
doc_token_count + title_tokens + metadata_tokens
|
||||
<= self.chunk_token_limit
|
||||
)
|
||||
|
||||
# expand the size of the context used for contextual rag based on whether chunk context and doc summary are used
|
||||
context_size = 0
|
||||
if (
|
||||
self.enable_contextual_rag
|
||||
and not single_chunk_fits
|
||||
and not AVERAGE_SUMMARY_EMBEDDINGS
|
||||
):
|
||||
context_size += self.default_contextual_rag_reserved_tokens
|
||||
|
||||
# Adjust content token limit to accommodate title + metadata
|
||||
content_token_limit = (
|
||||
self.chunk_token_limit - title_tokens - metadata_tokens - context_size
|
||||
)
|
||||
|
||||
# first check: if there is not enough actual chunk content when including contextual rag,
|
||||
# then don't do contextual rag
|
||||
if content_token_limit <= CHUNK_MIN_CONTENT:
|
||||
context_size = 0 # Don't do contextual RAG
|
||||
# revert to previous content token limit
|
||||
content_token_limit = (
|
||||
self.chunk_token_limit - title_tokens - metadata_tokens
|
||||
)
|
||||
|
||||
# If there is not enough context remaining then just index the chunk with no prefix/suffix
|
||||
content_token_limit = self.chunk_token_limit - title_tokens - metadata_tokens
|
||||
if content_token_limit <= CHUNK_MIN_CONTENT:
|
||||
# Not enough space left, so revert to full chunk without the prefix
|
||||
content_token_limit = self.chunk_token_limit
|
||||
@@ -494,9 +438,6 @@ class Chunker:
|
||||
large_chunks = generate_large_chunks(normal_chunks)
|
||||
normal_chunks.extend(large_chunks)
|
||||
|
||||
for chunk in normal_chunks:
|
||||
chunk.contextual_rag_reserved_tokens = context_size
|
||||
|
||||
return normal_chunks
|
||||
|
||||
def chunk(self, documents: list[IndexingDocument]) -> list[DocAwareChunk]:
|
||||
|
||||
@@ -121,7 +121,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
if chunk.large_chunk_reference_ids:
|
||||
large_chunks_present = True
|
||||
chunk_text = (
|
||||
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}"
|
||||
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
|
||||
) or chunk.source_document.get_title_for_document_index()
|
||||
|
||||
if not chunk_text:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import Protocol
|
||||
@@ -9,13 +8,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.access import get_access_for_documents
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
|
||||
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
|
||||
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
|
||||
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
|
||||
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
|
||||
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION
|
||||
@@ -43,10 +36,9 @@ from onyx.db.document import upsert_documents
|
||||
from onyx.db.document_set import fetch_document_sets_for_documents
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Document as DBDocument
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
|
||||
from onyx.db.pg_file_store import read_lobj
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.tag import create_or_add_document_tag
|
||||
from onyx.db.tag import create_or_add_document_tag_list
|
||||
from onyx.document_index.document_index_utils import (
|
||||
@@ -65,24 +57,11 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.factory import get_llm_for_contextual_rag
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.llm.utils import MAX_CONTEXT_TOKENS
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.natural_language_processing.utils import tokenizer_trim_middle
|
||||
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT1
|
||||
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT2
|
||||
from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
|
||||
@@ -270,8 +249,6 @@ def index_doc_batch_with_handler(
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
ignore_time_skip: bool = False,
|
||||
enable_contextual_rag: bool = False,
|
||||
llm: LLM | None = None,
|
||||
) -> IndexingPipelineResult:
|
||||
try:
|
||||
index_pipeline_result = index_doc_batch(
|
||||
@@ -284,8 +261,6 @@ def index_doc_batch_with_handler(
|
||||
db_session=db_session,
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
tenant_id=tenant_id,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
llm=llm,
|
||||
)
|
||||
except Exception as e:
|
||||
# don't log the batch directly, it's too much text
|
||||
@@ -556,145 +531,6 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
return indexed_documents
|
||||
|
||||
|
||||
def add_document_summaries(
|
||||
chunks_by_doc: list[DocAwareChunk],
|
||||
llm: LLM,
|
||||
tokenizer: BaseTokenizer,
|
||||
trunc_doc_tokens: int,
|
||||
) -> list[int] | None:
|
||||
"""
|
||||
Adds a document summary to a list of chunks from the same document.
|
||||
Returns the number of tokens in the document.
|
||||
"""
|
||||
|
||||
doc_tokens = []
|
||||
# this is value is the same for each chunk in the document; 0 indicates
|
||||
# There is not enough space for contextual RAG (the chunk content
|
||||
# and possibly metadata took up too much space)
|
||||
if chunks_by_doc[0].contextual_rag_reserved_tokens == 0:
|
||||
return None
|
||||
|
||||
doc_tokens = tokenizer.encode(chunks_by_doc[0].source_document.get_text_content())
|
||||
doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_tokens, tokenizer)
|
||||
summary_prompt = DOCUMENT_SUMMARY_PROMPT.format(document=doc_content)
|
||||
doc_summary = message_to_string(
|
||||
llm.invoke(summary_prompt, max_tokens=MAX_CONTEXT_TOKENS)
|
||||
)
|
||||
|
||||
for chunk in chunks_by_doc:
|
||||
chunk.doc_summary = doc_summary
|
||||
|
||||
return doc_tokens
|
||||
|
||||
|
||||
def add_chunk_summaries(
|
||||
chunks_by_doc: list[DocAwareChunk],
|
||||
llm: LLM,
|
||||
tokenizer: BaseTokenizer,
|
||||
trunc_doc_chunk_tokens: int,
|
||||
doc_tokens: list[int] | None,
|
||||
) -> None:
|
||||
"""
|
||||
Adds chunk summaries to the chunks grouped by document id.
|
||||
Chunk summaries look at the chunk as well as the entire document (or a summary,
|
||||
if the document is too long) and describe how the chunk relates to the document.
|
||||
"""
|
||||
# all chunks within a document have the same contextual_rag_reserved_tokens
|
||||
if chunks_by_doc[0].contextual_rag_reserved_tokens == 0:
|
||||
return
|
||||
|
||||
# use values computed in above doc summary section if available
|
||||
doc_tokens = doc_tokens or tokenizer.encode(
|
||||
chunks_by_doc[0].source_document.get_text_content()
|
||||
)
|
||||
doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_chunk_tokens, tokenizer)
|
||||
|
||||
# only compute doc summary if needed
|
||||
doc_info = (
|
||||
doc_content
|
||||
if len(doc_tokens) <= MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
else chunks_by_doc[0].doc_summary
|
||||
)
|
||||
if not doc_info:
|
||||
# This happens if the document is too long AND document summaries are turned off
|
||||
# In this case we compute a doc summary using the LLM
|
||||
doc_info = message_to_string(
|
||||
llm.invoke(
|
||||
DOCUMENT_SUMMARY_PROMPT.format(document=doc_content),
|
||||
max_tokens=MAX_CONTEXT_TOKENS,
|
||||
)
|
||||
)
|
||||
|
||||
context_prompt1 = CONTEXTUAL_RAG_PROMPT1.format(document=doc_info)
|
||||
|
||||
def assign_context(chunk: DocAwareChunk) -> None:
|
||||
context_prompt2 = CONTEXTUAL_RAG_PROMPT2.format(chunk=chunk.content)
|
||||
try:
|
||||
chunk.chunk_context = message_to_string(
|
||||
llm.invoke(
|
||||
context_prompt1 + context_prompt2,
|
||||
max_tokens=MAX_CONTEXT_TOKENS,
|
||||
)
|
||||
)
|
||||
except LLMRateLimitError as e:
|
||||
# Erroring during chunker is undesirable, so we log the error and continue
|
||||
# TODO: for v2, add robust retry logic
|
||||
logger.exception(f"Rate limit adding chunk summary: {e}", exc_info=e)
|
||||
chunk.chunk_context = ""
|
||||
except Exception as e:
|
||||
logger.exception(f"Error adding chunk summary: {e}", exc_info=e)
|
||||
chunk.chunk_context = ""
|
||||
|
||||
run_functions_tuples_in_parallel(
|
||||
[(assign_context, (chunk,)) for chunk in chunks_by_doc]
|
||||
)
|
||||
|
||||
|
||||
def add_contextual_summaries(
|
||||
chunks: list[DocAwareChunk],
|
||||
llm: LLM,
|
||||
tokenizer: BaseTokenizer,
|
||||
chunk_token_limit: int,
|
||||
) -> list[DocAwareChunk]:
|
||||
"""
|
||||
Adds Document summary and chunk-within-document context to the chunks
|
||||
based on which environment variables are set.
|
||||
"""
|
||||
max_context = get_max_input_tokens(
|
||||
model_name=llm.config.model_name,
|
||||
model_provider=llm.config.model_provider,
|
||||
output_tokens=MAX_CONTEXT_TOKENS,
|
||||
)
|
||||
doc2chunks = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
doc2chunks[chunk.source_document.id].append(chunk)
|
||||
|
||||
# The number of tokens allowed for the document when computing a document summary
|
||||
trunc_doc_summary_tokens = max_context - len(
|
||||
tokenizer.encode(DOCUMENT_SUMMARY_PROMPT)
|
||||
)
|
||||
|
||||
prompt_tokens = len(
|
||||
tokenizer.encode(CONTEXTUAL_RAG_PROMPT1 + CONTEXTUAL_RAG_PROMPT2)
|
||||
)
|
||||
# The number of tokens allowed for the document when computing a
|
||||
# "chunk in context of document" summary
|
||||
trunc_doc_chunk_tokens = max_context - prompt_tokens - chunk_token_limit
|
||||
for chunks_by_doc in doc2chunks.values():
|
||||
doc_tokens = None
|
||||
if USE_DOCUMENT_SUMMARY:
|
||||
doc_tokens = add_document_summaries(
|
||||
chunks_by_doc, llm, tokenizer, trunc_doc_summary_tokens
|
||||
)
|
||||
|
||||
if USE_CHUNK_SUMMARY:
|
||||
add_chunk_summaries(
|
||||
chunks_by_doc, llm, tokenizer, trunc_doc_chunk_tokens, doc_tokens
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
@log_function_time(debug_only=True)
|
||||
def index_doc_batch(
|
||||
*,
|
||||
@@ -706,8 +542,6 @@ def index_doc_batch(
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
enable_contextual_rag: bool = False,
|
||||
llm: LLM | None = None,
|
||||
ignore_time_skip: bool = False,
|
||||
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
|
||||
) -> IndexingPipelineResult:
|
||||
@@ -770,20 +604,6 @@ def index_doc_batch(
|
||||
# a common source of failure for the indexing pipeline
|
||||
chunks: list[DocAwareChunk] = chunker.chunk(ctx.indexable_docs)
|
||||
|
||||
# contextual RAG
|
||||
if enable_contextual_rag:
|
||||
assert llm is not None, "must provide an LLM for contextual RAG"
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
|
||||
# Because the chunker's tokens are different from the LLM's tokens,
|
||||
# We add a fudge factor to ensure we truncate prompts to the LLM's token limit
|
||||
chunks = add_contextual_summaries(
|
||||
chunks, llm, llm_tokenizer, chunker.chunk_token_limit * 2
|
||||
)
|
||||
|
||||
logger.debug("Starting embedding")
|
||||
chunks_with_embeddings, embedding_failures = (
|
||||
embed_chunks_with_failure_handling(
|
||||
@@ -979,33 +799,13 @@ def build_indexing_pipeline(
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
|
||||
all_search_settings = get_active_search_settings(db_session)
|
||||
if (
|
||||
all_search_settings.secondary
|
||||
and all_search_settings.secondary.status == IndexModelStatus.FUTURE
|
||||
):
|
||||
search_settings = all_search_settings.secondary
|
||||
else:
|
||||
search_settings = all_search_settings.primary
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass_config = get_multipass_config(search_settings)
|
||||
|
||||
enable_contextual_rag = (
|
||||
search_settings.enable_contextual_rag or ENABLE_CONTEXTUAL_RAG
|
||||
)
|
||||
llm = None
|
||||
if enable_contextual_rag:
|
||||
llm = get_llm_for_contextual_rag(
|
||||
search_settings.contextual_rag_llm_name or DEFAULT_CONTEXTUAL_RAG_LLM_NAME,
|
||||
search_settings.contextual_rag_llm_provider
|
||||
or DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER,
|
||||
)
|
||||
|
||||
chunker = chunker or Chunker(
|
||||
tokenizer=embedder.embedding_model.tokenizer,
|
||||
enable_multipass=multipass_config.multipass_indexing,
|
||||
enable_large_chunks=multipass_config.enable_large_chunks,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
# after every doc, update status in case there are a bunch of really long docs
|
||||
callback=callback,
|
||||
)
|
||||
@@ -1019,6 +819,4 @@ def build_indexing_pipeline(
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
llm=llm,
|
||||
)
|
||||
|
||||
@@ -49,15 +49,6 @@ class DocAwareChunk(BaseChunk):
|
||||
metadata_suffix_semantic: str
|
||||
metadata_suffix_keyword: str
|
||||
|
||||
# This is the number of tokens reserved for contextual RAG
|
||||
# in the chunk. doc_summary and chunk_context conbined should
|
||||
# contain at most this many tokens.
|
||||
contextual_rag_reserved_tokens: int
|
||||
# This is the summary for the document generated for contextual RAG
|
||||
doc_summary: str
|
||||
# This is the context for this chunk generated for contextual RAG
|
||||
chunk_context: str
|
||||
|
||||
mini_chunk_texts: list[str] | None
|
||||
|
||||
large_chunk_id: int | None
|
||||
@@ -163,9 +154,6 @@ class IndexingSetting(EmbeddingModelDetail):
|
||||
reduced_dimension: int | None = None
|
||||
|
||||
background_reindex_enabled: bool = True
|
||||
enable_contextual_rag: bool
|
||||
contextual_rag_llm_name: str | None = None
|
||||
contextual_rag_llm_provider: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
@@ -190,7 +178,6 @@ class IndexingSetting(EmbeddingModelDetail):
|
||||
embedding_precision=search_settings.embedding_precision,
|
||||
reduced_dimension=search_settings.reduced_dimension,
|
||||
background_reindex_enabled=search_settings.background_reindex_enabled,
|
||||
enable_contextual_rag=search_settings.enable_contextual_rag,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -425,12 +425,12 @@ class DefaultMultiLLM(LLM):
|
||||
messages=processed_prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice if tools else None,
|
||||
max_tokens=max_tokens,
|
||||
# streaming choice
|
||||
stream=stream,
|
||||
# model params
|
||||
temperature=0,
|
||||
timeout=timeout_override or self._timeout,
|
||||
max_tokens=max_tokens,
|
||||
# For now, we don't support parallel tool calls
|
||||
# NOTE: we can't pass this in if tools are not specified
|
||||
# or else OpenAI throws an error
|
||||
@@ -531,7 +531,6 @@ class DefaultMultiLLM(LLM):
|
||||
tool_choice,
|
||||
structured_response_format,
|
||||
timeout_override,
|
||||
max_tokens,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.server.manage.llm.models import LLMProvider
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.utils.headers import build_llm_extra_headers
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -155,40 +154,6 @@ def get_default_llm_with_vision(
|
||||
return None
|
||||
|
||||
|
||||
def llm_from_provider(
|
||||
model_name: str,
|
||||
llm_provider: LLMProvider,
|
||||
timeout: int | None = None,
|
||||
temperature: float | None = None,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> LLM:
|
||||
return get_llm(
|
||||
provider=llm_provider.provider,
|
||||
model=model_name,
|
||||
deployment_name=llm_provider.deployment_name,
|
||||
api_key=llm_provider.api_key,
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
custom_config=llm_provider.custom_config,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
additional_headers=additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
|
||||
|
||||
def get_llm_for_contextual_rag(model_name: str, model_provider: str) -> LLM:
|
||||
with get_session_context_manager() as db_session:
|
||||
llm_provider = fetch_llm_provider_view(db_session, model_provider)
|
||||
if not llm_provider:
|
||||
raise ValueError("No LLM provider with name {} found".format(model_provider))
|
||||
return llm_from_provider(
|
||||
model_name=model_name,
|
||||
llm_provider=llm_provider,
|
||||
)
|
||||
|
||||
|
||||
def get_default_llms(
|
||||
timeout: int | None = None,
|
||||
temperature: float | None = None,
|
||||
@@ -214,9 +179,14 @@ def get_default_llms(
|
||||
raise ValueError("No fast default model name found")
|
||||
|
||||
def _create_llm(model: str) -> LLM:
|
||||
return llm_from_provider(
|
||||
model_name=model,
|
||||
llm_provider=llm_provider,
|
||||
return get_llm(
|
||||
provider=llm_provider.provider,
|
||||
model=model,
|
||||
deployment_name=llm_provider.deployment_name,
|
||||
api_key=llm_provider.api_key,
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
custom_config=llm_provider.custom_config,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
additional_headers=additional_headers,
|
||||
|
||||
@@ -29,19 +29,13 @@ from litellm.exceptions import Timeout # type: ignore
|
||||
from litellm.exceptions import UnprocessableEntityError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
|
||||
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
|
||||
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_TOKEN_ESTIMATE
|
||||
from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_TOKEN_ESTIMATE
|
||||
from onyx.prompts.constants import CODE_BLOCK_PAT
|
||||
from onyx.utils.b64 import get_image_type
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
@@ -50,10 +44,6 @@ from shared_configs.configs import LOG_LEVEL
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
MAX_CONTEXT_TOKENS = 100
|
||||
ONE_MILLION = 1_000_000
|
||||
CHUNKS_PER_DOC_ESTIMATE = 5
|
||||
|
||||
|
||||
def litellm_exception_to_error_msg(
|
||||
e: Exception,
|
||||
@@ -426,72 +416,6 @@ def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | N
|
||||
return None
|
||||
|
||||
|
||||
def get_llm_contextual_cost(
|
||||
llm: LLM,
|
||||
) -> float:
|
||||
"""
|
||||
Approximate the cost of using the given LLM for indexing with Contextual RAG.
|
||||
|
||||
We use a precomputed estimate for the number of tokens in the contextualizing prompts,
|
||||
and we assume that every chunk is maximized in terms of content and context.
|
||||
We also assume that every document is maximized in terms of content, as currently if
|
||||
a document is longer than a certain length, its summary is used instead of the full content.
|
||||
|
||||
We expect that the first assumption will overestimate more than the second one
|
||||
underestimates, so this should be a fairly conservative price estimate. Also,
|
||||
this does not account for the cost of documents that fit within a single chunk
|
||||
which do not get contextualized.
|
||||
"""
|
||||
|
||||
# calculate input costs
|
||||
num_tokens = ONE_MILLION
|
||||
num_input_chunks = num_tokens // DOC_EMBEDDING_CONTEXT_SIZE
|
||||
|
||||
# We assume that the documents are MAX_TOKENS_FOR_FULL_INCLUSION tokens long
|
||||
# on average.
|
||||
num_docs = num_tokens // MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
|
||||
num_input_tokens = 0
|
||||
num_output_tokens = 0
|
||||
|
||||
if not USE_CHUNK_SUMMARY and not USE_DOCUMENT_SUMMARY:
|
||||
return 0
|
||||
|
||||
if USE_CHUNK_SUMMARY:
|
||||
# Each per-chunk prompt includes:
|
||||
# - The prompt tokens
|
||||
# - the document tokens
|
||||
# - the chunk tokens
|
||||
|
||||
# for each chunk, we prompt the LLM with the contextual RAG prompt
|
||||
# and the full document content (or the doc summary, so this is an overestimate)
|
||||
num_input_tokens += num_input_chunks * (
|
||||
CONTEXTUAL_RAG_TOKEN_ESTIMATE + MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
)
|
||||
|
||||
# in aggregate, each chunk content is used as a prompt input once
|
||||
# so the full input size is covered
|
||||
num_input_tokens += num_tokens
|
||||
|
||||
# A single MAX_CONTEXT_TOKENS worth of output is generated per chunk
|
||||
num_output_tokens += num_input_chunks * MAX_CONTEXT_TOKENS
|
||||
|
||||
# going over each doc once means all the tokens, plus the prompt tokens for
|
||||
# the summary prompt. This CAN happen even when USE_DOCUMENT_SUMMARY is false,
|
||||
# since doc summaries are used for longer documents when USE_CHUNK_SUMMARY is true.
|
||||
# So, we include this unconditionally to overestimate.
|
||||
num_input_tokens += num_tokens + num_docs * DOCUMENT_SUMMARY_TOKEN_ESTIMATE
|
||||
num_output_tokens += num_docs * MAX_CONTEXT_TOKENS
|
||||
|
||||
usd_per_prompt, usd_per_completion = litellm.cost_per_token(
|
||||
model=llm.config.model_name,
|
||||
prompt_tokens=num_input_tokens,
|
||||
completion_tokens=num_output_tokens,
|
||||
)
|
||||
# Costs are in USD dollars per million tokens
|
||||
return usd_per_prompt + usd_per_completion
|
||||
|
||||
|
||||
def get_llm_max_tokens(
|
||||
model_map: dict,
|
||||
model_name: str,
|
||||
|
||||
@@ -391,11 +391,6 @@ def get_application() -> FastAPI:
|
||||
prefix="/auth",
|
||||
)
|
||||
|
||||
if (
|
||||
AUTH_TYPE == AuthType.CLOUD
|
||||
or AUTH_TYPE == AuthType.BASIC
|
||||
or AUTH_TYPE == AuthType.GOOGLE_OAUTH
|
||||
):
|
||||
# Add refresh token endpoint for OAuth as well
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
|
||||
@@ -3,8 +3,6 @@ from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from copy import copy
|
||||
|
||||
from tokenizers import Encoding # type: ignore
|
||||
from tokenizers import Tokenizer # type: ignore
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
|
||||
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
@@ -13,8 +11,6 @@ from onyx.context.search.models import InferenceChunk
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
TRIM_SEP_PAT = "\n... {n} tokens removed...\n"
|
||||
|
||||
logger = setup_logger()
|
||||
transformer_logging.set_verbosity_error()
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -71,27 +67,16 @@ class TiktokenTokenizer(BaseTokenizer):
|
||||
|
||||
class HuggingFaceTokenizer(BaseTokenizer):
|
||||
def __init__(self, model_name: str):
|
||||
self.encoder: Tokenizer = Tokenizer.from_pretrained(model_name)
|
||||
from tokenizers import Tokenizer # type: ignore
|
||||
|
||||
def _safer_encode(self, string: str) -> Encoding:
|
||||
"""
|
||||
Encode a string using the HuggingFaceTokenizer, but if it fails,
|
||||
encode the string as ASCII and decode it back to a string. This helps
|
||||
in cases where the string has weird characters like \udeb4.
|
||||
"""
|
||||
try:
|
||||
return self.encoder.encode(string, add_special_tokens=False)
|
||||
except Exception:
|
||||
return self.encoder.encode(
|
||||
string.encode("ascii", "ignore").decode(), add_special_tokens=False
|
||||
)
|
||||
self.encoder = Tokenizer.from_pretrained(model_name)
|
||||
|
||||
def encode(self, string: str) -> list[int]:
|
||||
# this returns no special tokens
|
||||
return self._safer_encode(string).ids
|
||||
return self.encoder.encode(string, add_special_tokens=False).ids
|
||||
|
||||
def tokenize(self, string: str) -> list[str]:
|
||||
return self._safer_encode(string).tokens
|
||||
return self.encoder.encode(string, add_special_tokens=False).tokens
|
||||
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
return self.encoder.decode(tokens)
|
||||
@@ -174,26 +159,9 @@ def tokenizer_trim_content(
|
||||
content: str, desired_length: int, tokenizer: BaseTokenizer
|
||||
) -> str:
|
||||
tokens = tokenizer.encode(content)
|
||||
if len(tokens) <= desired_length:
|
||||
return content
|
||||
|
||||
return tokenizer.decode(tokens[:desired_length])
|
||||
|
||||
|
||||
def tokenizer_trim_middle(
|
||||
tokens: list[int], desired_length: int, tokenizer: BaseTokenizer
|
||||
) -> str:
|
||||
if len(tokens) <= desired_length:
|
||||
return tokenizer.decode(tokens)
|
||||
sep_str = TRIM_SEP_PAT.format(n=len(tokens) - desired_length)
|
||||
sep_tokens = tokenizer.encode(sep_str)
|
||||
slice_size = (desired_length - len(sep_tokens)) // 2
|
||||
assert slice_size > 0, "Slice size is not positive, desired length is too short"
|
||||
return (
|
||||
tokenizer.decode(tokens[:slice_size])
|
||||
+ sep_str
|
||||
+ tokenizer.decode(tokens[-slice_size:])
|
||||
)
|
||||
if len(tokens) > desired_length:
|
||||
content = tokenizer.decode(tokens[:desired_length])
|
||||
return content
|
||||
|
||||
|
||||
def tokenizer_trim_chunks(
|
||||
|
||||
@@ -220,29 +220,3 @@ Chat History:
|
||||
|
||||
Based on the above, what is a short name to convey the topic of the conversation?
|
||||
""".strip()
|
||||
|
||||
# NOTE: the prompt separation is partially done for efficiency; previously I tried
|
||||
# to do it all in one prompt with sequential format() calls but this will cause a backend
|
||||
# error when the document contains any {} as python will expect the {} to be filled by
|
||||
# format() arguments
|
||||
CONTEXTUAL_RAG_PROMPT1 = """<document>
|
||||
{document}
|
||||
</document>
|
||||
Here is the chunk we want to situate within the whole document"""
|
||||
|
||||
CONTEXTUAL_RAG_PROMPT2 = """<chunk>
|
||||
{chunk}
|
||||
</chunk>
|
||||
Please give a short succinct context to situate this chunk within the overall document
|
||||
for the purposes of improving search retrieval of the chunk. Answer only with the succinct
|
||||
context and nothing else. """
|
||||
|
||||
CONTEXTUAL_RAG_TOKEN_ESTIMATE = 64 # 19 + 45
|
||||
|
||||
DOCUMENT_SUMMARY_PROMPT = """<document>
|
||||
{document}
|
||||
</document>
|
||||
Please give a short succinct summary of the entire document. Answer only with the succinct
|
||||
summary and nothing else. """
|
||||
|
||||
DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29
|
||||
|
||||
@@ -87,9 +87,6 @@ def _create_indexable_chunks(
|
||||
metadata_suffix_keyword="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_reference_ids=[],
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=preprocessed_doc["content_embedding"],
|
||||
mini_chunk_embeddings=[],
|
||||
|
||||
@@ -21,11 +21,9 @@ from onyx.llm.factory import get_default_llms
|
||||
from onyx.llm.factory import get_llm
|
||||
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
|
||||
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
||||
from onyx.llm.utils import get_llm_contextual_cost
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.llm.utils import test_llm
|
||||
from onyx.server.manage.llm.models import LLMCost
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
@@ -288,38 +286,3 @@ def list_llm_provider_basics(
|
||||
db_session, user
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@admin_router.get("/provider-contextual-cost")
|
||||
def get_provider_contextual_cost(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMCost]:
|
||||
"""
|
||||
Get the cost of Re-indexing all documents for contextual retrieval.
|
||||
|
||||
See https://docs.litellm.ai/docs/completion/token_usage#5-cost_per_token
|
||||
This includes:
|
||||
- The cost of invoking the LLM on each chunk-document pair to get
|
||||
- the doc_summary
|
||||
- the chunk_context
|
||||
- The per-token cost of the LLM used to generate the doc_summary and chunk_context
|
||||
"""
|
||||
providers = fetch_existing_llm_providers(db_session)
|
||||
costs = []
|
||||
for provider in providers:
|
||||
for model_name in provider.display_model_names or provider.model_names or []:
|
||||
llm = get_llm(
|
||||
provider=provider.provider,
|
||||
model=model_name,
|
||||
deployment_name=provider.deployment_name,
|
||||
api_key=provider.api_key,
|
||||
api_base=provider.api_base,
|
||||
api_version=provider.api_version,
|
||||
custom_config=provider.custom_config,
|
||||
)
|
||||
cost = get_llm_contextual_cost(llm)
|
||||
costs.append(
|
||||
LLMCost(provider=provider.name, model_name=model_name, cost=cost)
|
||||
)
|
||||
return costs
|
||||
|
||||
@@ -119,9 +119,3 @@ class VisionProviderResponse(LLMProviderView):
|
||||
"""Response model for vision providers endpoint, including vision-specific fields."""
|
||||
|
||||
vision_models: list[str]
|
||||
|
||||
|
||||
class LLMCost(BaseModel):
|
||||
provider: str
|
||||
model_name: str
|
||||
cost: float
|
||||
|
||||
@@ -63,10 +63,7 @@ def generate_dummy_chunk(
|
||||
title_prefix=f"Title prefix for doc {doc_id}",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
contextual_rag_reserved_tokens=0,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=generate_random_embedding(embedding_dim),
|
||||
mini_chunk_embeddings=[],
|
||||
|
||||
@@ -99,7 +99,6 @@ PRESERVED_SEARCH_FIELDS = [
|
||||
"api_url",
|
||||
"index_name",
|
||||
"multipass_indexing",
|
||||
"enable_contextual_rag",
|
||||
"model_dim",
|
||||
"normalize",
|
||||
"passage_prefix",
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -9,10 +6,6 @@ from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_user_conversion(reset: None) -> None:
|
||||
"""
|
||||
Test that SAML login correctly converts users with non-authenticated roles
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -15,10 +12,6 @@ from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/chat/send-message-simple-with-history is enterprise only",
|
||||
)
|
||||
def test_all_stream_chat_message_objects_outputs(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -18,10 +16,6 @@ from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/chat/send-message-simple-with-history tests are enterprise only",
|
||||
)
|
||||
def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -> None:
|
||||
# create connectors
|
||||
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
|
||||
@@ -68,10 +62,6 @@ def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -
|
||||
assert found_doc["metadata"]["document_id"] == doc.id
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/chat/send-message-simple-with-history tests are enterprise only",
|
||||
)
|
||||
def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
@@ -161,10 +151,6 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
|
||||
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="/chat/send-message-simple-with-history tests are enterprise only",
|
||||
)
|
||||
def test_send_message_simple_with_history_strict_json(
|
||||
new_admin_user: DATestUser | None,
|
||||
) -> None:
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
This file takes the happy path to adding a curator to a user group and then tests
|
||||
the permissions of the curator manipulating connector-credential pairs.
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
@@ -17,10 +15,6 @@ from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Curator and User Group tests are enterprise only",
|
||||
)
|
||||
def test_cc_pair_permissions(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
This file takes the happy path to adding a curator to a user group and then tests
|
||||
the permissions of the curator manipulating connectors.
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
@@ -15,10 +13,6 @@ from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Curator and user group tests are enterprise only",
|
||||
)
|
||||
def test_connector_permissions(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
This file takes the happy path to adding a curator to a user group and then tests
|
||||
the permissions of the curator manipulating credentials.
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
@@ -14,10 +12,6 @@ from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Curator and user group tests are enterprise only",
|
||||
)
|
||||
def test_credential_permissions(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
@@ -12,10 +10,6 @@ from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Curator and user group tests are enterprise only",
|
||||
)
|
||||
def test_doc_set_permissions_setup(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -4,8 +4,6 @@ This file tests the permissions for creating and editing personas for different
|
||||
- Curators can edit personas that belong exclusively to groups they curate
|
||||
- Admins can edit all personas
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
@@ -15,10 +13,6 @@ from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Curator and user group tests are enterprise only",
|
||||
)
|
||||
def test_persona_permissions(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
"""
|
||||
This file tests the ability of different user types to set the role of other users.
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
@@ -12,10 +10,6 @@ from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Curator and user group tests are enterprise only",
|
||||
)
|
||||
def test_user_role_setting_permissions(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
"""
|
||||
This test tests the happy path for curator permissions
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
@@ -16,10 +12,6 @@ from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Curator tests are enterprise only",
|
||||
)
|
||||
def test_whole_curator_flow(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
@@ -97,10 +89,6 @@ def test_whole_curator_flow(reset: None) -> None:
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Curator tests are enterprise only",
|
||||
)
|
||||
def test_global_curator_flow(reset: None) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
@@ -64,10 +63,6 @@ def setup_chat_session(reset: None) -> tuple[DATestUser, str]:
|
||||
return admin_user, str(chat_session.id)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Chat history tests are enterprise only",
|
||||
)
|
||||
def test_chat_history_endpoints(
|
||||
reset: None, setup_chat_session: tuple[DATestUser, str]
|
||||
) -> None:
|
||||
@@ -121,10 +116,6 @@ def test_chat_history_endpoints(
|
||||
assert len(history_response.items) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Chat history tests are enterprise only",
|
||||
)
|
||||
def test_chat_history_csv_export(
|
||||
reset: None, setup_chat_session: tuple[DATestUser, str]
|
||||
) -> None:
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from tests.integration.common_utils.managers.query_history import QueryHistoryManager
|
||||
from tests.integration.common_utils.test_models import DAQueryHistoryEntry
|
||||
@@ -50,10 +47,6 @@ def _verify_query_history_pagination(
|
||||
assert all_expected_sessions == all_retrieved_sessions
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Query history tests are enterprise only",
|
||||
)
|
||||
def test_query_history_pagination(reset: None) -> None:
|
||||
(
|
||||
admin_user,
|
||||
|
||||
@@ -8,10 +8,6 @@ This tests the deletion of a user group with the following foreign key constrain
|
||||
- token_rate_limit (Not Implemented)
|
||||
- persona
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
@@ -29,10 +25,6 @@ from tests.integration.common_utils.test_models import DATestUserGroup
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_user_group_deletion(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.documents.models import DocumentSource
|
||||
from tests.integration.common_utils.constants import NUM_DOCS
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
@@ -15,10 +11,6 @@ from tests.integration.common_utils.test_models import DATestUserGroup
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_removing_connector(reset: None, vespa_client: vespa_fixture) -> None:
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
@@ -32,8 +32,6 @@ def create_test_chunk(
|
||||
match_highlights=[],
|
||||
updated_at=datetime.now(),
|
||||
image_file_name=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -78,8 +78,6 @@ def mock_inference_sections() -> list[InferenceSection]:
|
||||
source_links={0: "https://example.com/doc1"},
|
||||
match_highlights=[],
|
||||
image_file_name=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
),
|
||||
chunks=MagicMock(),
|
||||
),
|
||||
@@ -103,8 +101,6 @@ def mock_inference_sections() -> list[InferenceSection]:
|
||||
source_links={0: "https://example.com/doc2"},
|
||||
match_highlights=[],
|
||||
image_file_name=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
),
|
||||
chunks=MagicMock(),
|
||||
),
|
||||
|
||||
@@ -151,8 +151,6 @@ def test_fuzzy_match_quotes_to_docs() -> None:
|
||||
match_highlights=[],
|
||||
updated_at=None,
|
||||
image_file_name=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
)
|
||||
test_chunk_1 = InferenceChunk(
|
||||
document_id="test doc 1",
|
||||
@@ -172,8 +170,6 @@ def test_fuzzy_match_quotes_to_docs() -> None:
|
||||
match_highlights=[],
|
||||
updated_at=None,
|
||||
image_file_name=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
)
|
||||
|
||||
test_quotes = [
|
||||
|
||||
@@ -38,8 +38,6 @@ def create_inference_chunk(
|
||||
match_highlights=[],
|
||||
updated_at=None,
|
||||
image_file_name=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
|
||||
|
||||
@@ -18,13 +17,3 @@ class MockHeartbeat(IndexingHeartbeatInterface):
|
||||
@pytest.fixture
|
||||
def mock_heartbeat() -> MockHeartbeat:
|
||||
return MockHeartbeat()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embedder() -> DefaultIndexingEmbedder:
|
||||
return DefaultIndexingEmbedder(
|
||||
model_name="intfloat/e5-base-v2",
|
||||
normalize=True,
|
||||
query_prefix=None,
|
||||
passage_prefix=None,
|
||||
)
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
from typing import Any
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
|
||||
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import process_image_sections
|
||||
from onyx.llm.utils import MAX_CONTEXT_TOKENS
|
||||
from tests.unit.onyx.indexing.conftest import MockHeartbeat
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_contextual_rag", [True, False])
|
||||
def test_chunk_document(
|
||||
embedder: DefaultIndexingEmbedder, enable_contextual_rag: bool
|
||||
) -> None:
|
||||
@pytest.fixture
|
||||
def embedder() -> DefaultIndexingEmbedder:
|
||||
return DefaultIndexingEmbedder(
|
||||
model_name="intfloat/e5-base-v2",
|
||||
normalize=True,
|
||||
query_prefix=None,
|
||||
passage_prefix=None,
|
||||
)
|
||||
|
||||
|
||||
def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
|
||||
short_section_1 = "This is a short section."
|
||||
long_section = (
|
||||
"This is a long section that should be split into multiple chunks. " * 100
|
||||
@@ -44,22 +45,9 @@ def test_chunk_document(
|
||||
)
|
||||
indexing_documents = process_image_sections([document])
|
||||
|
||||
mock_llm_invoke_count = 0
|
||||
|
||||
def mock_llm_invoke(self: Any, *args: Any, **kwargs: Any) -> Mock:
|
||||
nonlocal mock_llm_invoke_count
|
||||
mock_llm_invoke_count += 1
|
||||
m = Mock()
|
||||
m.content = f"Test{mock_llm_invoke_count}"
|
||||
return m
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.invoke = mock_llm_invoke
|
||||
|
||||
chunker = Chunker(
|
||||
tokenizer=embedder.embedding_model.tokenizer,
|
||||
enable_multipass=False,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
)
|
||||
chunks = chunker.chunk(indexing_documents)
|
||||
|
||||
@@ -70,14 +58,6 @@ def test_chunk_document(
|
||||
assert "tag1" in chunks[0].metadata_suffix_keyword
|
||||
assert "tag2" in chunks[0].metadata_suffix_semantic
|
||||
|
||||
rag_tokens = MAX_CONTEXT_TOKENS * (
|
||||
int(USE_DOCUMENT_SUMMARY) + int(USE_CHUNK_SUMMARY)
|
||||
)
|
||||
for chunk in chunks:
|
||||
assert chunk.contextual_rag_reserved_tokens == (
|
||||
rag_tokens if enable_contextual_rag else 0
|
||||
)
|
||||
|
||||
|
||||
def test_chunker_heartbeat(
|
||||
embedder: DefaultIndexingEmbedder, mock_heartbeat: MockHeartbeat
|
||||
@@ -98,7 +78,6 @@ def test_chunker_heartbeat(
|
||||
tokenizer=embedder.embedding_model.tokenizer,
|
||||
enable_multipass=False,
|
||||
callback=mock_heartbeat,
|
||||
enable_contextual_rag=False,
|
||||
)
|
||||
|
||||
chunks = chunker.chunk(indexing_documents)
|
||||
|
||||
@@ -21,13 +21,7 @@ def mock_embedding_model() -> Generator[Mock, None, None]:
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"chunk_context, doc_summary",
|
||||
[("Test chunk context", "Test document summary"), ("", "")],
|
||||
)
|
||||
def test_default_indexing_embedder_embed_chunks(
|
||||
mock_embedding_model: Mock, chunk_context: str, doc_summary: str
|
||||
) -> None:
|
||||
def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None:
|
||||
# Setup
|
||||
embedder = DefaultIndexingEmbedder(
|
||||
model_name="test-model",
|
||||
@@ -69,9 +63,6 @@ def test_default_indexing_embedder_embed_chunks(
|
||||
large_chunk_reference_ids=[],
|
||||
large_chunk_id=None,
|
||||
image_file_name=None,
|
||||
chunk_context=chunk_context,
|
||||
doc_summary=doc_summary,
|
||||
contextual_rag_reserved_tokens=200,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -90,7 +81,7 @@ def test_default_indexing_embedder_embed_chunks(
|
||||
|
||||
# Verify the embedding model was called correctly
|
||||
mock_embedding_model.return_value.encode.assert_any_call(
|
||||
texts=[f"Title: {doc_summary}Test chunk{chunk_context}"],
|
||||
texts=["Title: Test chunk"],
|
||||
text_type=EmbedTextType.PASSAGE,
|
||||
large_chunks_present=False,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import List
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -11,12 +9,8 @@ from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentSource
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import _get_aggregated_chunk_boost_factor
|
||||
from onyx.indexing.indexing_pipeline import add_contextual_summaries
|
||||
from onyx.indexing.indexing_pipeline import filter_documents
|
||||
from onyx.indexing.indexing_pipeline import process_image_sections
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
@@ -172,9 +166,6 @@ def create_test_chunk(
|
||||
embeddings=ChunkEmbedding(full_embedding=[], mini_chunk_embeddings=[]),
|
||||
title_embedding=None,
|
||||
image_file_name=None,
|
||||
chunk_context="",
|
||||
doc_summary="",
|
||||
contextual_rag_reserved_tokens=200,
|
||||
)
|
||||
|
||||
|
||||
@@ -258,76 +249,3 @@ def test_get_aggregated_boost_factor_individual_failure() -> None:
|
||||
)
|
||||
|
||||
assert "Failed to predict content classification for chunk" in str(exc_info.value)
|
||||
|
||||
|
||||
@patch("onyx.llm.utils.GEN_AI_MAX_TOKENS", 4096)
|
||||
@pytest.mark.parametrize("enable_contextual_rag", [True, False])
|
||||
def test_contextual_rag(
|
||||
embedder: DefaultIndexingEmbedder, enable_contextual_rag: bool
|
||||
) -> None:
|
||||
short_section_1 = "This is a short section."
|
||||
long_section = (
|
||||
"This is a long section that should be split into multiple chunks. " * 100
|
||||
)
|
||||
short_section_2 = "This is another short section."
|
||||
short_section_3 = "This is another short section again."
|
||||
short_section_4 = "Final short section."
|
||||
semantic_identifier = "Test Document"
|
||||
|
||||
document = Document(
|
||||
id="test_doc",
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=semantic_identifier,
|
||||
metadata={"tags": ["tag1", "tag2"]},
|
||||
doc_updated_at=None,
|
||||
sections=[
|
||||
TextSection(text=short_section_1, link="link1"),
|
||||
TextSection(text=short_section_2, link="link2"),
|
||||
TextSection(text=long_section, link="link3"),
|
||||
TextSection(text=short_section_3, link="link4"),
|
||||
TextSection(text=short_section_4, link="link5"),
|
||||
],
|
||||
)
|
||||
indexing_documents = process_image_sections([document])
|
||||
|
||||
mock_llm_invoke_count = 0
|
||||
|
||||
def mock_llm_invoke(self: Any, *args: Any, **kwargs: Any) -> Mock:
|
||||
nonlocal mock_llm_invoke_count
|
||||
mock_llm_invoke_count += 1
|
||||
m = Mock()
|
||||
m.content = f"Test{mock_llm_invoke_count}"
|
||||
return m
|
||||
|
||||
llm_tokenizer = embedder.embedding_model.tokenizer
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.invoke = mock_llm_invoke
|
||||
|
||||
chunker = Chunker(
|
||||
tokenizer=embedder.embedding_model.tokenizer,
|
||||
enable_multipass=False,
|
||||
enable_contextual_rag=enable_contextual_rag,
|
||||
)
|
||||
chunks = chunker.chunk(indexing_documents)
|
||||
|
||||
chunks = add_contextual_summaries(
|
||||
chunks, mock_llm, llm_tokenizer, chunker.chunk_token_limit * 2
|
||||
)
|
||||
|
||||
assert len(chunks) == 5
|
||||
assert short_section_1 in chunks[0].content
|
||||
assert short_section_3 in chunks[-1].content
|
||||
assert short_section_4 in chunks[-1].content
|
||||
assert "tag1" in chunks[0].metadata_suffix_keyword
|
||||
assert "tag2" in chunks[0].metadata_suffix_semantic
|
||||
|
||||
doc_summary = "Test1" if enable_contextual_rag else ""
|
||||
chunk_context = ""
|
||||
count = 2
|
||||
for chunk in chunks:
|
||||
if enable_contextual_rag:
|
||||
chunk_context = f"Test{count}"
|
||||
count += 1
|
||||
assert chunk.doc_summary == doc_summary
|
||||
assert chunk.chunk_context == chunk_context
|
||||
|
||||
@@ -140,12 +140,12 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice=None,
|
||||
max_tokens=None,
|
||||
stream=False,
|
||||
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
|
||||
timeout=30,
|
||||
parallel_tool_calls=False,
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
max_tokens=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -286,10 +286,10 @@ def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> No
|
||||
],
|
||||
tools=tools,
|
||||
tool_choice=None,
|
||||
max_tokens=None,
|
||||
stream=True,
|
||||
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
|
||||
timeout=30,
|
||||
parallel_tool_calls=False,
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
max_tokens=None,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
|
||||
|
||||
export const LLM_CONTEXTUAL_COST_ADMIN_URL =
|
||||
"/api/admin/llm/provider-contextual-cost";
|
||||
|
||||
export const EMBEDDING_PROVIDERS_ADMIN_URL =
|
||||
"/api/admin/embedding/embedding-provider";
|
||||
|
||||
|
||||
@@ -143,15 +143,6 @@ function Main() {
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Text className="font-semibold">Contextual RAG</Text>
|
||||
<Text className="text-text-700">
|
||||
{searchSettings.enable_contextual_rag
|
||||
? "Enabled"
|
||||
: "Disabled"}
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<Text className="font-semibold">
|
||||
Disable Reranking for Streaming
|
||||
|
||||
@@ -26,18 +26,9 @@ export enum EmbeddingPrecision {
|
||||
BFLOAT16 = "bfloat16",
|
||||
}
|
||||
|
||||
export interface LLMContextualCost {
|
||||
provider: string;
|
||||
model_name: string;
|
||||
cost: number;
|
||||
}
|
||||
|
||||
export interface AdvancedSearchConfiguration {
|
||||
index_name: string | null;
|
||||
multipass_indexing: boolean;
|
||||
enable_contextual_rag: boolean;
|
||||
contextual_rag_llm_name: string | null;
|
||||
contextual_rag_llm_provider: string | null;
|
||||
multilingual_expansion: string[];
|
||||
disable_rerank_for_streaming: boolean;
|
||||
api_url: string | null;
|
||||
|
||||
@@ -3,11 +3,7 @@ import { Formik, Form, FormikProps, FieldArray, Field } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { TrashIcon } from "@/components/icons/icons";
|
||||
import { FaPlus } from "react-icons/fa";
|
||||
import {
|
||||
AdvancedSearchConfiguration,
|
||||
EmbeddingPrecision,
|
||||
LLMContextualCost,
|
||||
} from "../interfaces";
|
||||
import { AdvancedSearchConfiguration, EmbeddingPrecision } from "../interfaces";
|
||||
import {
|
||||
BooleanFormField,
|
||||
Label,
|
||||
@@ -16,13 +12,6 @@ import {
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput";
|
||||
import { StringOrNumberOption } from "@/components/Dropdown";
|
||||
import useSWR from "swr";
|
||||
import { LLM_CONTEXTUAL_COST_ADMIN_URL } from "../../configuration/llm/constants";
|
||||
import { getDisplayNameForModel } from "@/lib/hooks";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
|
||||
// Number of tokens to show cost calculation for
|
||||
const COST_CALCULATION_TOKENS = 1_000_000;
|
||||
|
||||
interface AdvancedEmbeddingFormPageProps {
|
||||
updateAdvancedEmbeddingDetails: (
|
||||
@@ -56,66 +45,14 @@ const AdvancedEmbeddingFormPage = forwardRef<
|
||||
},
|
||||
ref
|
||||
) => {
|
||||
// Fetch contextual costs
|
||||
const { data: contextualCosts, error: costError } = useSWR<
|
||||
LLMContextualCost[]
|
||||
>(LLM_CONTEXTUAL_COST_ADMIN_URL, errorHandlingFetcher);
|
||||
|
||||
const llmOptions: StringOrNumberOption[] = React.useMemo(
|
||||
() =>
|
||||
(contextualCosts || []).map((cost) => {
|
||||
return {
|
||||
name: getDisplayNameForModel(cost.model_name),
|
||||
value: cost.model_name,
|
||||
};
|
||||
}),
|
||||
[contextualCosts]
|
||||
);
|
||||
|
||||
// Helper function to format cost as USD
|
||||
const formatCost = (cost: number) => {
|
||||
return new Intl.NumberFormat("en-US", {
|
||||
style: "currency",
|
||||
currency: "USD",
|
||||
}).format(cost);
|
||||
};
|
||||
|
||||
// Get cost info for selected model
|
||||
const getSelectedModelCost = (modelName: string | null) => {
|
||||
if (!contextualCosts || !modelName) return null;
|
||||
return contextualCosts.find((cost) => cost.model_name === modelName);
|
||||
};
|
||||
|
||||
// Get the current value for the selector based on the parent state
|
||||
const getCurrentLLMValue = React.useMemo(() => {
|
||||
if (!advancedEmbeddingDetails.contextual_rag_llm_name) return null;
|
||||
return advancedEmbeddingDetails.contextual_rag_llm_name;
|
||||
}, [advancedEmbeddingDetails.contextual_rag_llm_name]);
|
||||
|
||||
return (
|
||||
<div className="py-4 rounded-lg max-w-4xl px-4 mx-auto">
|
||||
<Formik
|
||||
innerRef={ref}
|
||||
initialValues={{
|
||||
...advancedEmbeddingDetails,
|
||||
contextual_rag_llm: getCurrentLLMValue,
|
||||
}}
|
||||
initialValues={advancedEmbeddingDetails}
|
||||
validationSchema={Yup.object().shape({
|
||||
multilingual_expansion: Yup.array().of(Yup.string()),
|
||||
multipass_indexing: Yup.boolean(),
|
||||
enable_contextual_rag: Yup.boolean(),
|
||||
contextual_rag_llm: Yup.string()
|
||||
.nullable()
|
||||
.test(
|
||||
"required-if-contextual-rag",
|
||||
"LLM must be selected when Contextual RAG is enabled",
|
||||
function (value) {
|
||||
const enableContextualRag = this.parent.enable_contextual_rag;
|
||||
console.log("enableContextualRag", enableContextualRag);
|
||||
console.log("value", value);
|
||||
return !enableContextualRag || value !== null;
|
||||
}
|
||||
),
|
||||
disable_rerank_for_streaming: Yup.boolean(),
|
||||
num_rerank: Yup.number()
|
||||
.required("Number of results to rerank is required")
|
||||
@@ -142,26 +79,10 @@ const AdvancedEmbeddingFormPage = forwardRef<
|
||||
validate={(values) => {
|
||||
// Call updateAdvancedEmbeddingDetails for each changed field
|
||||
Object.entries(values).forEach(([key, value]) => {
|
||||
if (key === "contextual_rag_llm") {
|
||||
const selectedModel = (contextualCosts || []).find(
|
||||
(cost) => cost.model_name === value
|
||||
);
|
||||
if (selectedModel) {
|
||||
updateAdvancedEmbeddingDetails(
|
||||
"contextual_rag_llm_provider",
|
||||
selectedModel.provider
|
||||
);
|
||||
updateAdvancedEmbeddingDetails(
|
||||
"contextual_rag_llm_name",
|
||||
selectedModel.model_name
|
||||
);
|
||||
}
|
||||
} else {
|
||||
updateAdvancedEmbeddingDetails(
|
||||
key as keyof AdvancedSearchConfiguration,
|
||||
value
|
||||
);
|
||||
}
|
||||
updateAdvancedEmbeddingDetails(
|
||||
key as keyof AdvancedSearchConfiguration,
|
||||
value
|
||||
);
|
||||
});
|
||||
|
||||
// Run validation and report errors
|
||||
@@ -175,23 +96,6 @@ const AdvancedEmbeddingFormPage = forwardRef<
|
||||
.shape({
|
||||
multilingual_expansion: Yup.array().of(Yup.string()),
|
||||
multipass_indexing: Yup.boolean(),
|
||||
enable_contextual_rag: Yup.boolean(),
|
||||
contextual_rag_llm: Yup.string()
|
||||
.nullable()
|
||||
.test(
|
||||
"required-if-contextual-rag",
|
||||
"LLM must be selected when Contextual RAG is enabled",
|
||||
function (value) {
|
||||
const enableContextualRag =
|
||||
this.parent.enable_contextual_rag;
|
||||
console.log(
|
||||
"enableContextualRag2",
|
||||
enableContextualRag
|
||||
);
|
||||
console.log("value2", value);
|
||||
return !enableContextualRag || value !== null;
|
||||
}
|
||||
),
|
||||
disable_rerank_for_streaming: Yup.boolean(),
|
||||
num_rerank: Yup.number()
|
||||
.required("Number of results to rerank is required")
|
||||
@@ -286,56 +190,6 @@ const AdvancedEmbeddingFormPage = forwardRef<
|
||||
label="Disable Rerank for Streaming"
|
||||
name="disable_rerank_for_streaming"
|
||||
/>
|
||||
<BooleanFormField
|
||||
subtext="Enable contextual RAG for all chunk sizes."
|
||||
optional
|
||||
label="Contextual RAG"
|
||||
name="enable_contextual_rag"
|
||||
/>
|
||||
<div>
|
||||
<SelectorFormField
|
||||
name="contextual_rag_llm"
|
||||
label="Contextual RAG LLM"
|
||||
subtext={
|
||||
costError
|
||||
? "Error loading LLM models. Please try again later."
|
||||
: !contextualCosts
|
||||
? "Loading available LLM models..."
|
||||
: values.enable_contextual_rag
|
||||
? "Select the LLM model to use for contextual RAG processing."
|
||||
: "Enable Contextual RAG above to select an LLM model."
|
||||
}
|
||||
options={llmOptions}
|
||||
disabled={
|
||||
!values.enable_contextual_rag ||
|
||||
!contextualCosts ||
|
||||
!!costError
|
||||
}
|
||||
/>
|
||||
{values.enable_contextual_rag &&
|
||||
values.contextual_rag_llm &&
|
||||
!costError && (
|
||||
<div className="mt-2 text-sm text-text-600">
|
||||
{contextualCosts ? (
|
||||
<>
|
||||
Estimated cost for processing{" "}
|
||||
{COST_CALCULATION_TOKENS.toLocaleString()} tokens:{" "}
|
||||
<span className="font-medium">
|
||||
{getSelectedModelCost(values.contextual_rag_llm)
|
||||
? formatCost(
|
||||
getSelectedModelCost(
|
||||
values.contextual_rag_llm
|
||||
)!.cost
|
||||
)
|
||||
: "Cost information not available"}
|
||||
</span>
|
||||
</>
|
||||
) : (
|
||||
"Loading cost information..."
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<NumberInput
|
||||
description="Number of results to rerank"
|
||||
optional={false}
|
||||
|
||||
@@ -64,9 +64,6 @@ export default function EmbeddingForm() {
|
||||
useState<AdvancedSearchConfiguration>({
|
||||
index_name: "",
|
||||
multipass_indexing: true,
|
||||
enable_contextual_rag: false,
|
||||
contextual_rag_llm_name: null,
|
||||
contextual_rag_llm_provider: null,
|
||||
multilingual_expansion: [],
|
||||
disable_rerank_for_streaming: false,
|
||||
api_url: null,
|
||||
@@ -155,9 +152,6 @@ export default function EmbeddingForm() {
|
||||
setAdvancedEmbeddingDetails({
|
||||
index_name: searchSettings.index_name,
|
||||
multipass_indexing: searchSettings.multipass_indexing,
|
||||
enable_contextual_rag: searchSettings.enable_contextual_rag,
|
||||
contextual_rag_llm_name: searchSettings.contextual_rag_llm_name,
|
||||
contextual_rag_llm_provider: searchSettings.contextual_rag_llm_provider,
|
||||
multilingual_expansion: searchSettings.multilingual_expansion,
|
||||
disable_rerank_for_streaming:
|
||||
searchSettings.disable_rerank_for_streaming,
|
||||
@@ -203,9 +197,7 @@ export default function EmbeddingForm() {
|
||||
searchSettings?.embedding_precision !=
|
||||
advancedEmbeddingDetails.embedding_precision ||
|
||||
searchSettings?.reduced_dimension !=
|
||||
advancedEmbeddingDetails.reduced_dimension ||
|
||||
searchSettings?.enable_contextual_rag !=
|
||||
advancedEmbeddingDetails.enable_contextual_rag;
|
||||
advancedEmbeddingDetails.reduced_dimension;
|
||||
|
||||
const updateSearch = useCallback(async () => {
|
||||
if (!selectedProvider) {
|
||||
@@ -392,14 +384,6 @@ export default function EmbeddingForm() {
|
||||
advancedEmbeddingDetails.reduced_dimension && (
|
||||
<li>Reduced dimension modification</li>
|
||||
)}
|
||||
{(searchSettings?.enable_contextual_rag !=
|
||||
advancedEmbeddingDetails.enable_contextual_rag ||
|
||||
searchSettings?.contextual_rag_llm_name !=
|
||||
advancedEmbeddingDetails.contextual_rag_llm_name ||
|
||||
searchSettings?.contextual_rag_llm_provider !=
|
||||
advancedEmbeddingDetails.contextual_rag_llm_provider) && (
|
||||
<li>Contextual RAG modification</li>
|
||||
)}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
@@ -487,11 +471,6 @@ export default function EmbeddingForm() {
|
||||
};
|
||||
|
||||
const handleReIndex = async () => {
|
||||
console.log("handleReIndex");
|
||||
console.log(selectedProvider);
|
||||
console.log(advancedEmbeddingDetails);
|
||||
console.log(rerankingDetails);
|
||||
console.log(reindexType);
|
||||
if (!selectedProvider) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ export default function LoginPage({
|
||||
authTypeMetadata,
|
||||
nextUrl,
|
||||
searchParams,
|
||||
hidePageRedirect,
|
||||
showPageRedirect,
|
||||
}: {
|
||||
authUrl: string | null;
|
||||
authTypeMetadata: AuthTypeMetadata | null;
|
||||
@@ -24,7 +24,7 @@ export default function LoginPage({
|
||||
[key: string]: string | string[] | undefined;
|
||||
}
|
||||
| undefined;
|
||||
hidePageRedirect?: boolean;
|
||||
showPageRedirect?: boolean;
|
||||
}) {
|
||||
useSendAuthRequiredMessage();
|
||||
return (
|
||||
@@ -75,7 +75,7 @@ export default function LoginPage({
|
||||
<div className="flex flex-col gap-y-2 items-center"></div>
|
||||
</>
|
||||
)}
|
||||
{!hidePageRedirect && (
|
||||
{showPageRedirect && (
|
||||
<p className="text-center mt-4">
|
||||
Don't have an account?{" "}
|
||||
<span
|
||||
|
||||
@@ -72,7 +72,6 @@ const Page = async (props: {
|
||||
authTypeMetadata={authTypeMetadata}
|
||||
nextUrl={nextUrl!}
|
||||
searchParams={searchParams}
|
||||
hidePageRedirect={true}
|
||||
/>
|
||||
</AuthFlowContainer>
|
||||
</div>
|
||||
|
||||
@@ -91,7 +91,7 @@ export function AgenticToggle({
|
||||
>
|
||||
<div className="flex items-center space-x-2 mb-3">
|
||||
<h3 className="text-sm font-semibold text-neutral-900">
|
||||
Agent Search
|
||||
Agent Search (BETA)
|
||||
</h3>
|
||||
</div>
|
||||
<p className="text-xs text-neutral-600 dark:text-neutral-700 mb-2">
|
||||
|
||||
@@ -347,6 +347,7 @@ export default function NRFPage({
|
||||
<p className="p-4">Loading login info…</p>
|
||||
) : authType == "basic" ? (
|
||||
<LoginPage
|
||||
showPageRedirect
|
||||
authUrl={null}
|
||||
authTypeMetadata={{
|
||||
authType: authType as AuthType,
|
||||
|
||||
@@ -20,7 +20,7 @@ import {
|
||||
import { fetchAssistantData } from "@/lib/chat/fetchAssistantdata";
|
||||
import { AppProvider } from "@/components/context/AppProvider";
|
||||
import { PHProvider } from "./providers";
|
||||
import { getAuthTypeMetadataSS, getCurrentUserSS } from "@/lib/userSS";
|
||||
import { getCurrentUserSS } from "@/lib/userSS";
|
||||
import { Suspense } from "react";
|
||||
import PostHogPageView from "./PostHogPageView";
|
||||
import Script from "next/script";
|
||||
@@ -55,7 +55,7 @@ export async function generateMetadata(): Promise<Metadata> {
|
||||
}
|
||||
|
||||
return {
|
||||
title: enterpriseSettings?.application_name || "Onyx",
|
||||
title: enterpriseSettings?.application_name ?? "Onyx",
|
||||
description: "Question answering for your documents",
|
||||
icons: {
|
||||
icon: logoLocation,
|
||||
@@ -70,13 +70,11 @@ export default async function RootLayout({
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
const [combinedSettings, assistantsData, user, authTypeMetadata] =
|
||||
await Promise.all([
|
||||
fetchSettingsSS(),
|
||||
fetchAssistantData(),
|
||||
getCurrentUserSS(),
|
||||
getAuthTypeMetadataSS(),
|
||||
]);
|
||||
const [combinedSettings, assistantsData, user] = await Promise.all([
|
||||
fetchSettingsSS(),
|
||||
fetchAssistantData(),
|
||||
getCurrentUserSS(),
|
||||
]);
|
||||
|
||||
const productGating =
|
||||
combinedSettings?.settings.application_status ?? ApplicationStatus.ACTIVE;
|
||||
@@ -149,7 +147,6 @@ export default async function RootLayout({
|
||||
|
||||
return getPageContent(
|
||||
<AppProvider
|
||||
authTypeMetadata={authTypeMetadata}
|
||||
user={user}
|
||||
settings={combinedSettings}
|
||||
assistants={assistants}
|
||||
|
||||
@@ -676,7 +676,6 @@ interface SelectorFormFieldProps {
|
||||
includeReset?: boolean;
|
||||
fontSize?: "sm" | "md" | "lg";
|
||||
small?: boolean;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export function SelectorFormField({
|
||||
@@ -692,7 +691,6 @@ export function SelectorFormField({
|
||||
includeReset = false,
|
||||
fontSize = "md",
|
||||
small = false,
|
||||
disabled = false,
|
||||
}: SelectorFormFieldProps) {
|
||||
const [field] = useField<string>(name);
|
||||
const { setFieldValue } = useFormikContext();
|
||||
@@ -744,9 +742,8 @@ export function SelectorFormField({
|
||||
: setFieldValue(name, selected))
|
||||
}
|
||||
defaultValue={defaultValue}
|
||||
disabled={disabled}
|
||||
>
|
||||
<SelectTrigger className={sizeClass.input} disabled={disabled}>
|
||||
<SelectTrigger className={sizeClass.input}>
|
||||
<SelectValue placeholder="Select...">
|
||||
{currentlySelected?.name || defaultValue || ""}
|
||||
</SelectValue>
|
||||
|
||||
@@ -7,7 +7,6 @@ import { AssistantsProvider } from "./AssistantsContext";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { User } from "@/lib/types";
|
||||
import { ModalProvider } from "./ModalContext";
|
||||
import { AuthTypeMetadata } from "@/lib/userSS";
|
||||
|
||||
interface AppProviderProps {
|
||||
children: React.ReactNode;
|
||||
@@ -16,7 +15,6 @@ interface AppProviderProps {
|
||||
assistants: Persona[];
|
||||
hasAnyConnectors: boolean;
|
||||
hasImageCompatibleModel: boolean;
|
||||
authTypeMetadata: AuthTypeMetadata;
|
||||
}
|
||||
|
||||
export const AppProvider = ({
|
||||
@@ -26,15 +24,10 @@ export const AppProvider = ({
|
||||
assistants,
|
||||
hasAnyConnectors,
|
||||
hasImageCompatibleModel,
|
||||
authTypeMetadata,
|
||||
}: AppProviderProps) => {
|
||||
return (
|
||||
<SettingsProvider settings={settings}>
|
||||
<UserProvider
|
||||
settings={settings}
|
||||
user={user}
|
||||
authTypeMetadata={authTypeMetadata}
|
||||
>
|
||||
<UserProvider settings={settings} user={user}>
|
||||
<ProviderContextProvider>
|
||||
<AssistantsProvider
|
||||
initialAssistants={assistants}
|
||||
|
||||
@@ -13,7 +13,6 @@ import { usePostHog } from "posthog-js/react";
|
||||
import { CombinedSettings } from "@/app/admin/settings/interfaces";
|
||||
import { SettingsContext } from "../settings/SettingsProvider";
|
||||
import { useTokenRefresh } from "@/hooks/useTokenRefresh";
|
||||
import { AuthTypeMetadata } from "@/lib/userSS";
|
||||
|
||||
interface UserContextType {
|
||||
user: User | null;
|
||||
@@ -34,12 +33,10 @@ interface UserContextType {
|
||||
const UserContext = createContext<UserContextType | undefined>(undefined);
|
||||
|
||||
export function UserProvider({
|
||||
authTypeMetadata,
|
||||
children,
|
||||
user,
|
||||
settings,
|
||||
}: {
|
||||
authTypeMetadata: AuthTypeMetadata;
|
||||
children: React.ReactNode;
|
||||
user: User | null;
|
||||
settings: CombinedSettings;
|
||||
@@ -105,7 +102,7 @@ export function UserProvider({
|
||||
};
|
||||
|
||||
// Use the custom token refresh hook
|
||||
useTokenRefresh(upToDateUser, authTypeMetadata, fetchUser);
|
||||
useTokenRefresh(upToDateUser, fetchUser);
|
||||
|
||||
const updateUserTemperatureOverrideEnabled = async (enabled: boolean) => {
|
||||
try {
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
import { User } from "@/lib/types";
|
||||
import { NO_AUTH_USER_ID } from "@/lib/extension/constants";
|
||||
import { AuthTypeMetadata } from "@/lib/userSS";
|
||||
|
||||
// Refresh token every 10 minutes (600000ms)
|
||||
// This is shorter than the session expiry time to ensure tokens stay valid
|
||||
@@ -10,7 +9,6 @@ const REFRESH_INTERVAL = 600000;
|
||||
// Custom hook for handling JWT token refresh for current user
|
||||
export function useTokenRefresh(
|
||||
user: User | null,
|
||||
authTypeMetadata: AuthTypeMetadata,
|
||||
onRefreshFail: () => Promise<void>
|
||||
) {
|
||||
// Track last refresh time to avoid unnecessary calls
|
||||
@@ -20,13 +18,7 @@ export function useTokenRefresh(
|
||||
const isFirstLoad = useRef(true);
|
||||
|
||||
useEffect(() => {
|
||||
if (
|
||||
!user ||
|
||||
user.id === NO_AUTH_USER_ID ||
|
||||
authTypeMetadata.authType === "oidc" ||
|
||||
authTypeMetadata.authType === "saml"
|
||||
)
|
||||
return;
|
||||
if (!user || user.id === NO_AUTH_USER_ID) return;
|
||||
|
||||
const refreshTokenPeriodically = async () => {
|
||||
try {
|
||||
|
||||
Reference in New Issue
Block a user