Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 07:45:47 +00:00)

Compare commits (80 commits)
Commit SHAs:
2d960a477f, 8837b8ea79, 3dfb214f73, 18d7262608, 09b879ee73, aaa668c963, edb877f4bc, eb369caefb,
b9567eabd7, 13bbf67091, 457a4c73f0, ce37688b5b, 4e2c90f4af, 513dd8a319, 71c5043832, 64b6f15e95,
35022f5f09, 0d44014c16, 1b9e9f48fa, 05fb5aa27b, 3b645b72a3, fe770b5c3a, 1eaf885f50, a187aa508c,
aa4bfa2a78, 9011b8a139, 59c774353a, b458d504af, f83e7bfcd9, 4d2e26ce4b, 817fdc1f36, e9b10e8b41,
a0fa4adb60, ca9ba925bd, 833cc5c97c, 23ecf654ed, ddc6a6d2b3, 571c8ece32, 884bdb4b01, b3ecf0d59f,
f56fda27c9, b1e4d4ea8d, 8db6d49fe5, 28598694b1, b5d0df90b9, 48be6338ec, ed9014f03d, 2dd51230ed,
8b249cbe63, 6b50f86cd2, bd2805b6df, 2847ab003e, 1df6a506ec, f1541d1fbe, dd0c4b64df, 788b3015bc,
cbbf10f450, d954914a0a, bee74ac360, 29ef64272a, 01bf6ee4b7, 0502417cbe, d0483dd269, eefa872d60,
3f3d4da611, 469068052e, 9032b05606, 334bc6be8c, 814f97c2c7, 4f5a2b47c4, f545508268, 590986ec65,
531bab5409, 29c44007c4, d388643a04, 8a422683e3, ddc0230d68, 6711e91dbf, cff2346db5, 8d3fad1f12
.github/pull_request_template.md (vendored; 18 lines changed)

@@ -6,24 +6,6 @@
[Describe the tests you ran to verify your changes]

## Accepted Risk (provide if relevant)
N/A

## Related Issue(s) (provide if relevant)
N/A

## Mental Checklist:
- All of the automated tests pass
- All PR comments are addressed and marked resolved
- If there are migrations, they have been rebased to latest main
- If there are new dependencies, they are added to the requirements
- If there are new environment variables, they are added to all of the deployment methods
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- Docker images build and basic functionalities work
- Author has done a final read through of the PR right before merge

## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)

@@ -66,6 +66,7 @@ jobs:
            NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
            NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
            NEXT_PUBLIC_GTM_ENABLED=true
            NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
          # needed due to weird interactions with the builds for different platforms
          no-cache: true
          labels: ${{ steps.meta.outputs.labels }}

@@ -8,18 +8,29 @@ on:
env:
  REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
  LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
  DOCKER_BUILDKIT: 1
  BUILDKIT_PROGRESS: plain

jobs:
  build-and-push:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]

  build-amd64:
    runs-on:
      [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: System Info
        run: |
          df -h
          free -h
          docker system prune -af --volumes

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
          driver-opts: |
            image=moby/buildkit:latest
            network=host

      - name: Login to Docker Hub
        uses: docker/login-action@v3
@@ -27,24 +38,80 @@ jobs:
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Model Server Image Docker Build and Push
      - name: Build and Push AMD64
        uses: docker/build-push-action@v5
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server
          platforms: linux/amd64,linux/arm64
          platforms: linux/amd64
          push: true
          tags: |
            ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
            ${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
          tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
            DANSWER_VERSION=${{ github.ref_name }}
          outputs: type=registry
          provenance: false

  build-arm64:
    runs-on:
      [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: System Info
        run: |
          df -h
          free -h
          docker system prune -af --volumes

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
        with:
          driver-opts: |
            image=moby/buildkit:latest
            network=host

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and Push ARM64
        uses: docker/build-push-action@v5
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server
          platforms: linux/arm64
          push: true
          tags: ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
          build-args: |
            DANSWER_VERSION=${{ github.ref_name }}
          outputs: type=registry
          provenance: false

  merge-and-scan:
    needs: [build-amd64, build-arm64]
    runs-on: ubuntu-latest
    steps:
      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Create and Push Multi-arch Manifest
        run: |
          docker buildx create --use
          docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }} \
            ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
            ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
          if [[ "${{ env.LATEST_TAG }}" == "true" ]]; then
            docker buildx imagetools create -t ${{ env.REGISTRY_IMAGE }}:latest \
              ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-amd64 \
              ${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}-arm64
          fi

      # trivy has their own rate limiting issues causing this action to flake
      # we worked around it by hardcoding to different db repos in env
      # can re-enable when they figure it out
      # https://github.com/aquasecurity/trivy/discussions/7538
      # https://github.com/aquasecurity/trivy-action/issues/389
      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@master
        env:
@@ -53,3 +120,4 @@ jobs:
        with:
          image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
          severity: "CRITICAL,HIGH"
          timeout: "10m"

.github/workflows/pr-chromatic-tests.yml (vendored; 14 lines changed)

@@ -15,7 +15,12 @@ jobs:

    # See https://runs-on.com/runners/linux/
    runs-on:
      [runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
      [
        runs-on,
        runner=32cpu-linux-x64,
        disk=large,
        "run-id=${{ github.run_id }}",
      ]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
@@ -196,7 +201,12 @@ jobs:

    needs: playwright-tests
    runs-on:
      [runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
      [
        runs-on,
        runner=32cpu-linux-x64,
        disk=large,
        "run-id=${{ github.run_id }}",
      ]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

.github/workflows/pr-integration-tests.yml (vendored; 3 lines changed)

@@ -20,8 +20,7 @@ env:
jobs:
  integration-tests:
    # See https://runs-on.com/runners/linux/
    runs-on:
      [runs-on, runner=8cpu-linux-x64, ram=16, "run-id=${{ github.run_id }}"]
    runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

README.md (16 lines changed)

@@ -3,7 +3,7 @@
<a name="readme-top"></a>

<h2 align="center">
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
</h2>

<p align="center">
@@ -24,7 +24,7 @@
</a>
</p>

<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
@@ -133,15 +133,3 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
## ⭐Star History

[](https://star-history.com/#onyx-dot-app/onyx&Date)

## ✨Contributors

<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
  <img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
</a>

<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
  <a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
    ↑ Back to Top ↑
  </a>
</p>

@@ -1,39 +1,49 @@
|
||||
from typing import Any, Literal
|
||||
from onyx.db.engine import get_iam_auth_token
|
||||
from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import AWS_REGION
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from typing import Literal
|
||||
import os
|
||||
import ssl
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
|
||||
from onyx.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# Alembic Config object
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
if config.config_file_name is not None and config.attributes.get(
|
||||
"configure_logger", True
|
||||
):
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Add your model's MetaData object here for 'autogenerate' support
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ssl_context: ssl.SSLContext | None = None
|
||||
if USE_IAM_AUTH:
|
||||
if not os.path.exists(SSL_CERT_FILE):
|
||||
raise FileNotFoundError(f"Expected {SSL_CERT_FILE} when USE_IAM_AUTH is true.")
|
||||
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
@@ -49,20 +59,12 @@ def include_object(
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
Excludes specified tables from migrations.
|
||||
"""
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_schema_options() -> tuple[str, bool, bool]:
|
||||
"""
|
||||
Parses command-line options passed via '-x' in Alembic commands.
|
||||
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
|
||||
"""
|
||||
x_args_raw = context.get_x_argument()
|
||||
x_args = {}
|
||||
for arg in x_args_raw:
|
||||
@@ -90,16 +92,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
|
||||
def do_run_migrations(
|
||||
connection: Connection, schema_name: str, create_schema: bool
|
||||
) -> None:
|
||||
"""
|
||||
Executes migrations in the specified schema.
|
||||
"""
|
||||
logger.info(f"About to migrate schema: {schema_name}")
|
||||
|
||||
if create_schema:
|
||||
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
# Set search_path to the target schema
|
||||
connection.execute(text(f'SET search_path TO "{schema_name}"'))
|
||||
|
||||
context.configure(
|
||||
@@ -117,11 +115,25 @@ def do_run_migrations(
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def provide_iam_token_for_alembic(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
if USE_IAM_AUTH:
|
||||
# Database connection settings
|
||||
region = AWS_REGION
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
|
||||
# Get IAM authentication token
|
||||
token = get_iam_auth_token(host, port, user, region)
|
||||
|
||||
# For Alembic / SQLAlchemy in this context, set SSL and password
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
|
||||
|
||||
async def run_async_migrations() -> None:
|
||||
"""
|
||||
Determines whether to run migrations for a single schema or all schemas,
|
||||
and executes migrations accordingly.
|
||||
"""
|
||||
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
|
||||
|
||||
engine = create_async_engine(
|
||||
@@ -129,10 +141,16 @@ async def run_async_migrations() -> None:
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
if upgrade_all_tenants:
|
||||
# Run migrations for all tenant schemas sequentially
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "do_connect")
|
||||
def event_provide_iam_token_for_alembic(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
|
||||
|
||||
if upgrade_all_tenants:
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
for schema in tenant_schemas:
|
||||
try:
|
||||
logger.info(f"Migrating schema: {schema}")
|
||||
@@ -162,15 +180,20 @@ async def run_async_migrations() -> None:
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""
|
||||
Run migrations in 'offline' mode.
|
||||
"""
|
||||
schema_name, _, upgrade_all_tenants = get_schema_options()
|
||||
url = build_connection_string()
|
||||
|
||||
if upgrade_all_tenants:
|
||||
# Run offline migrations for all tenant schemas
|
||||
engine = create_async_engine(url)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@event.listens_for(engine.sync_engine, "do_connect")
|
||||
def event_provide_iam_token_for_alembic_offline(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
|
||||
|
||||
tenant_schemas = get_all_tenant_ids()
|
||||
engine.sync_engine.dispose()
|
||||
|
||||
@@ -207,9 +230,6 @@ def run_migrations_offline() -> None:
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""
|
||||
Runs migrations in 'online' mode using an asynchronous engine.
|
||||
"""
|
||||
asyncio.run(run_async_migrations())
|
||||
|
||||
|
||||
|
||||
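The Alembic env changes above take the database password from `get_iam_auth_token` when `USE_IAM_AUTH` is set, but that helper (imported from `onyx.db.engine`) is not part of this diff. As a hedged sketch only, such a helper is typically a thin wrapper around boto3's RDS `generate_db_auth_token`; the body below is an assumption for illustration, not the actual Onyx implementation:

```python
# Hypothetical sketch of an IAM token helper like the get_iam_auth_token imported above.
# Assumes boto3 is installed and AWS credentials are resolvable from the environment.
import boto3


def get_iam_auth_token(host: str, port: str, user: str, region: str) -> str:
    """Return a short-lived RDS IAM auth token to use as the database password."""
    rds_client = boto3.client("rds", region_name=region)
    return rds_client.generate_db_auth_token(
        DBHostname=host,
        Port=int(port),
        DBUsername=user,
        Region=region,
    )
```

In the migration flow above, the token is injected as `cparams["password"]` on each `do_connect` event, and the connection is forced over TLS via the `ssl_context` built from `SSL_CERT_FILE`.
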
backend/alembic/versions/35e518e0ddf4_properly_cascade.py (new file, 121 lines)
@@ -0,0 +1,121 @@
|
||||
"""properly_cascade
|
||||
|
||||
Revision ID: 35e518e0ddf4
|
||||
Revises: 91a0a4d62b14
|
||||
Create Date: 2024-09-20 21:24:04.891018
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "35e518e0ddf4"
|
||||
down_revision = "91a0a4d62b14"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Update chat_message foreign key constraint
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Update chat_message__search_doc foreign key constraints
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"search_doc",
|
||||
["search_doc_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
# Add CASCADE delete for tool_call foreign key
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey",
|
||||
"tool_call",
|
||||
"chat_message",
|
||||
["message_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert chat_message foreign key constraint
|
||||
op.drop_constraint(
|
||||
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message_chat_session_id_fkey",
|
||||
"chat_message",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Revert chat_message__search_doc foreign key constraints
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_chat_message_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"chat_message",
|
||||
["chat_message_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"chat_message__search_doc_search_doc_id_fkey",
|
||||
"chat_message__search_doc",
|
||||
"search_doc",
|
||||
["search_doc_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Revert tool_call foreign key constraint
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey",
|
||||
"tool_call",
|
||||
"chat_message",
|
||||
["message_id"],
|
||||
["id"],
|
||||
)
|
||||
backend/alembic/versions/91a0a4d62b14_milestone.py (new file, 45 lines)
@@ -0,0 +1,45 @@
|
||||
"""Milestone
|
||||
|
||||
Revision ID: 91a0a4d62b14
|
||||
Revises: dab04867cd88
|
||||
Create Date: 2024-12-13 19:03:30.947551
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "91a0a4d62b14"
|
||||
down_revision = "dab04867cd88"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"milestone",
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.Column("tenant_id", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("event_type", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("milestone")
|
||||
backend/alembic/versions/c0aab6edb6dd_delete_workspace.py (new file, 87 lines)
@@ -0,0 +1,87 @@
|
||||
"""delete workspace
|
||||
|
||||
Revision ID: c0aab6edb6dd
|
||||
Revises: 35e518e0ddf4
|
||||
Create Date: 2024-12-17 14:37:07.660631
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c0aab6edb6dd"
|
||||
down_revision = "35e518e0ddf4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = connector_specific_config - 'workspace'
|
||||
WHERE source = 'SLACK'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
import json
|
||||
from sqlalchemy import text
|
||||
from slack_sdk import WebClient
|
||||
|
||||
conn = op.get_bind()
|
||||
|
||||
# Fetch all Slack credentials
|
||||
creds_result = conn.execute(
|
||||
text("SELECT id, credential_json FROM credential WHERE source = 'SLACK'")
|
||||
)
|
||||
all_slack_creds = creds_result.fetchall()
|
||||
if not all_slack_creds:
|
||||
return
|
||||
|
||||
for cred_row in all_slack_creds:
|
||||
credential_id, credential_json = cred_row
|
||||
|
||||
credential_json = (
|
||||
credential_json.tobytes().decode("utf-8")
|
||||
if isinstance(credential_json, memoryview)
|
||||
else credential_json.decode("utf-8")
|
||||
)
|
||||
credential_data = json.loads(credential_json)
|
||||
slack_bot_token = credential_data.get("slack_bot_token")
|
||||
if not slack_bot_token:
|
||||
print(
|
||||
f"No slack_bot_token found for credential {credential_id}. "
|
||||
"Your Slack connector will not function until you upgrade and provide a valid token."
|
||||
)
|
||||
continue
|
||||
|
||||
client = WebClient(token=slack_bot_token)
|
||||
try:
|
||||
auth_response = client.auth_test()
|
||||
workspace = auth_response["url"].split("//")[1].split(".")[0]
|
||||
|
||||
# Update only the connectors linked to this credential
|
||||
# (and which are Slack connectors).
|
||||
op.execute(
|
||||
f"""
|
||||
UPDATE connector AS c
|
||||
SET connector_specific_config = jsonb_set(
|
||||
connector_specific_config,
|
||||
'{{workspace}}',
|
||||
to_jsonb('{workspace}'::text)
|
||||
)
|
||||
FROM connector_credential_pair AS ccp
|
||||
WHERE ccp.connector_id = c.id
|
||||
AND c.source = 'SLACK'
|
||||
AND ccp.credential_id = {credential_id}
|
||||
"""
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
f"We were unable to get the workspace url for your Slack Connector with id {credential_id}."
|
||||
)
|
||||
print("This connector will no longer work until you upgrade.")
|
||||
continue
|
||||
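The `downgrade()` above builds its UPDATE statement by interpolating `workspace` and `credential_id` into the SQL with an f-string. Purely as an illustrative sketch, not part of the migration, the same `jsonb_set` update can be expressed with bound parameters:

```python
# Illustrative alternative: the same connector update with bound parameters.
# Assumes `conn`, `workspace`, and `credential_id` exist as in the downgrade above.
from sqlalchemy import text

conn.execute(
    text(
        """
        UPDATE connector AS c
        SET connector_specific_config = jsonb_set(
            connector_specific_config,
            '{workspace}',
            to_jsonb(CAST(:workspace AS text))
        )
        FROM connector_credential_pair AS ccp
        WHERE ccp.connector_id = c.id
          AND c.source = 'SLACK'
          AND ccp.credential_id = :credential_id
        """
    ),
    {"workspace": workspace, "credential_id": credential_id},
)
```
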
@@ -47,3 +47,11 @@ OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", ""
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
    "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)


# The posthog client does not accept empty API keys or hosts however it fails silently
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"

HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")

@@ -122,7 +122,7 @@ def _cleanup_document_set__user_group_relationships__no_commit(
|
||||
)
|
||||
|
||||
|
||||
def validate_user_creation_permissions(
|
||||
def validate_object_creation_for_user(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
target_group_ids: list[int] | None = None,
|
||||
@@ -440,32 +440,108 @@ def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
|
||||
_validate_curator_status__no_commit(db_session, [user])
|
||||
|
||||
|
||||
def update_user_curator_relationship(
|
||||
def _validate_curator_relationship_update_requester(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user_making_change: User | None = None,
|
||||
) -> None:
|
||||
user = fetch_user_by_id(db_session, set_curator_request.user_id)
|
||||
if not user:
|
||||
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
|
||||
"""
|
||||
This function validates that the user making the change has the necessary permissions
|
||||
to update the curator relationship for the target user in the given user group.
|
||||
"""
|
||||
|
||||
if user.role == UserRole.ADMIN:
|
||||
if user_making_change is None or user_making_change.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
# check if the user making the change is a curator in the group they are changing the curator relationship for
|
||||
user_making_change_curator_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user_making_change.id,
|
||||
# only check if the user making the change is a curator if they are a curator
|
||||
# otherwise, they are a global_curator and can update the curator relationship
|
||||
# for any group they are a member of
|
||||
only_curator_groups=user_making_change.role == UserRole.CURATOR,
|
||||
)
|
||||
requestor_curator_group_ids = [
|
||||
group.id for group in user_making_change_curator_groups
|
||||
]
|
||||
if user_group_id not in requestor_curator_group_ids:
|
||||
raise ValueError(
|
||||
f"User '{user.email}' is an admin and therefore has all permissions "
|
||||
f"user making change {user_making_change.email} is not a curator,"
|
||||
f" admin, or global_curator for group '{user_group_id}'"
|
||||
)
|
||||
|
||||
|
||||
def _validate_curator_relationship_update_request(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
target_user: User,
|
||||
) -> None:
|
||||
"""
|
||||
This function validates that the curator_relationship_update request itself is valid.
|
||||
"""
|
||||
if target_user.role == UserRole.ADMIN:
|
||||
raise ValueError(
|
||||
f"User '{target_user.email}' is an admin and therefore has all permissions "
|
||||
"of a curator. If you'd like this user to only have curator permissions, "
|
||||
"you must update their role to BASIC then assign them to be CURATOR in the "
|
||||
"appropriate groups."
|
||||
)
|
||||
elif target_user.role == UserRole.GLOBAL_CURATOR:
|
||||
raise ValueError(
|
||||
f"User '{target_user.email}' is a global_curator and therefore has all "
|
||||
"permissions of a curator for all groups. If you'd like this user to only "
|
||||
"have curator permissions for a specific group, you must update their role "
|
||||
"to BASIC then assign them to be CURATOR in the appropriate groups."
|
||||
)
|
||||
elif target_user.role not in [UserRole.CURATOR, UserRole.BASIC]:
|
||||
raise ValueError(
|
||||
f"This endpoint can only be used to update the curator relationship for "
|
||||
"users with the CURATOR or BASIC role. \n"
|
||||
f"Target user: {target_user.email} \n"
|
||||
f"Target user role: {target_user.role} \n"
|
||||
)
|
||||
|
||||
# check if the target user is in the group they are changing the curator relationship for
|
||||
requested_user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=set_curator_request.user_id,
|
||||
user_id=target_user.id,
|
||||
only_curator_groups=False,
|
||||
)
|
||||
|
||||
group_ids = [group.id for group in requested_user_groups]
|
||||
if user_group_id not in group_ids:
|
||||
raise ValueError(f"user is not in group '{user_group_id}'")
|
||||
raise ValueError(
|
||||
f"target user {target_user.email} is not in group '{user_group_id}'"
|
||||
)
|
||||
|
||||
|
||||
def update_user_curator_relationship(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user_making_change: User | None = None,
|
||||
) -> None:
|
||||
target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
|
||||
if not target_user:
|
||||
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
|
||||
|
||||
_validate_curator_relationship_update_request(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
target_user=target_user,
|
||||
)
|
||||
|
||||
_validate_curator_relationship_update_requester(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
user_making_change=user_making_change,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"user_making_change={user_making_change.email if user_making_change else 'None'} is "
|
||||
f"updating the curator relationship for user={target_user.email} "
|
||||
f"in group={user_group_id} to is_curator={set_curator_request.is_curator}"
|
||||
)
|
||||
|
||||
relationship_to_update = (
|
||||
db_session.query(User__UserGroup)
|
||||
@@ -486,7 +562,7 @@ def update_user_curator_relationship(
|
||||
)
|
||||
db_session.add(relationship_to_update)
|
||||
|
||||
_validate_curator_status__no_commit(db_session, [user])
|
||||
_validate_curator_status__no_commit(db_session, [target_user])
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
@@ -12,15 +13,23 @@ from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_jwt_strategy
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.settings.store import store_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -114,3 +123,48 @@ async def impersonate_user(
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/leave-organization")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> None:
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
|
||||
@@ -39,3 +39,8 @@ class TenantCreationPayload(BaseModel):
    tenant_id: str
    email: str
    referral_source: str | None = None


class TenantDeletionPayload(BaseModel):
    tenant_id: str
    email: str

@@ -3,15 +3,19 @@ import logging
|
||||
import uuid
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
import httpx
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
from ee.onyx.server.tenants.models import TenantDeletionPayload
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import drop_schema
|
||||
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
|
||||
@@ -20,6 +24,7 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
|
||||
from onyx.auth.users import exceptions
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -35,22 +40,27 @@ from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.setup import setup_onyx
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_or_create_tenant_id(
|
||||
email: str, referral_source: str | None = None
|
||||
async def get_or_provision_tenant(
|
||||
email: str, referral_source: str | None = None, request: Request | None = None
|
||||
) -> str:
|
||||
"""Get existing tenant ID for an email or create a new tenant if none exists."""
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if referral_source and request:
|
||||
await submit_to_hubspot(email, referral_source, request)
|
||||
|
||||
try:
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
except exceptions.UserNotExists:
|
||||
@@ -122,6 +132,17 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}")
|
||||
raise HTTPException(
|
||||
@@ -165,6 +186,7 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
try:
|
||||
# Drop the tenant's schema to rollback provisioning
|
||||
drop_schema(tenant_id)
|
||||
|
||||
# Remove tenant mapping
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
@@ -267,3 +289,59 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
logger.info(
|
||||
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
|
||||
)
|
||||
|
||||
|
||||
async def submit_to_hubspot(
|
||||
email: str, referral_source: str | None, request: Request
|
||||
) -> None:
|
||||
if not HUBSPOT_TRACKING_URL:
|
||||
logger.info("HUBSPOT_TRACKING_URL not set, skipping HubSpot submission")
|
||||
return
|
||||
|
||||
# HubSpot tracking cookie
|
||||
hubspot_cookie = request.cookies.get("hubspotutk")
|
||||
|
||||
# IP address
|
||||
ip_address = request.client.host if request.client else None
|
||||
|
||||
data = {
|
||||
"fields": [
|
||||
{"name": "email", "value": email},
|
||||
{"name": "referral_source", "value": referral_source or ""},
|
||||
],
|
||||
"context": {
|
||||
"hutk": hubspot_cookie,
|
||||
"ipAddress": ip_address,
|
||||
"pageUri": str(request.url),
|
||||
"pageName": "User Registration",
|
||||
},
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(HUBSPOT_TRACKING_URL, json=data)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to submit to HubSpot: {response.text}")
|
||||
|
||||
|
||||
async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.delete(
|
||||
f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
|
||||
headers=headers,
|
||||
json=payload.model_dump(),
|
||||
) as response:
|
||||
print(response)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Control plane tenant creation failed: {error_text}")
|
||||
raise Exception(
|
||||
f"Failed to delete tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
@@ -68,3 +68,11 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
            f"Failed to remove users from tenant {tenant_id}: {str(e)}"
        )
        db_session.rollback()


def remove_all_users_from_tenant(tenant_id: str) -> None:
    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
        db_session.query(UserTenantMapping).filter(
            UserTenantMapping.tenant_id == tenant_id
        ).delete()
        db_session.commit()

@@ -83,7 +83,7 @@ def patch_user_group(
def set_user_curator(
    user_group_id: int,
    set_curator_request: SetCuratorRequest,
    _: User | None = Depends(current_admin_user),
    user: User | None = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    try:
@@ -91,6 +91,7 @@ def set_user_curator(
            db_session=db_session,
            user_group_id=user_group_id,
            set_curator_request=set_curator_request,
            user_making_change=user,
        )
    except ValueError as e:
        logger.error(f"Error setting user curator: {e}")

backend/ee/onyx/utils/telemetry.py (new file, 34 lines)

@@ -0,0 +1,34 @@
from typing import Any

from posthog import Posthog

from ee.onyx.configs.app_configs import POSTHOG_API_KEY
from ee.onyx.configs.app_configs import POSTHOG_HOST
from onyx.utils.logger import setup_logger

logger = setup_logger()


def posthog_on_error(error: Any, items: Any) -> None:
    """Log any PostHog delivery errors."""
    logger.error(f"PostHog error: {error}, items: {items}")


posthog = Posthog(
    project_api_key=POSTHOG_API_KEY,
    host=POSTHOG_HOST,
    debug=True,
    on_error=posthog_on_error,
)


def event_telemetry(
    distinct_id: str, event: str, properties: dict | None = None
) -> None:
    """Capture and send an event to PostHog, flushing immediately."""
    logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}")
    try:
        posthog.capture(distinct_id, event, properties)
        posthog.flush()
    except Exception as e:
        logger.error(f"Error capturing PostHog event: {e}")
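For reference, a minimal usage sketch of the `event_telemetry` helper added above; the call site and values are made up for illustration, since the real callers live elsewhere in the codebase:

```python
# Hypothetical call site for the helper defined in backend/ee/onyx/utils/telemetry.py.
from ee.onyx.utils.telemetry import event_telemetry

event_telemetry(
    distinct_id="tenant_abc123",  # any stable identifier for the actor
    event="tenant_created",  # event name as it should appear in PostHog
    properties={"email": "admin@example.com"},
)
```
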
@@ -27,8 +27,8 @@ from shared_configs.configs import SENTRY_DSN
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

HF_CACHE_PATH = Path("/root/.cache/huggingface/")
TEMP_HF_CACHE_PATH = Path("/root/.cache/temp_huggingface/")
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"

transformer_logging.set_verbosity_error()

backend/onyx/auth/email_utils.py (new file, 80 lines)
@@ -0,0 +1,80 @@
|
||||
import smtplib
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from textwrap import dedent
|
||||
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
from onyx.configs.app_configs import SMTP_PASS
|
||||
from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
def send_email(
|
||||
user_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
if not EMAIL_CONFIGURED:
|
||||
raise ValueError("Email is not configured.")
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
msg.attach(MIMEText(body))
|
||||
|
||||
try:
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
s.starttls()
|
||||
s.login(SMTP_USER, SMTP_PASS)
|
||||
s.send_message(msg)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def send_user_email_invite(user_email: str, current_user: User) -> None:
|
||||
subject = "Invitation to Join Onyx Workspace"
|
||||
body = dedent(
|
||||
f"""\
|
||||
Hello,
|
||||
|
||||
You have been invited to join a workspace on Onyx.
|
||||
|
||||
To join the workspace, please visit the following link:
|
||||
|
||||
{WEB_DOMAIN}/auth/login
|
||||
|
||||
Best regards,
|
||||
The Onyx Team
|
||||
"""
|
||||
)
|
||||
send_email(user_email, subject, body, current_user.email)
|
||||
|
||||
|
||||
def send_forgot_password_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
body = f"Click the following link to reset your password: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
subject = "Onyx Email Verification"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
body = f"Click the following link to verify your email address: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
@@ -4,6 +4,8 @@ from typing import cast

from onyx.auth.schemas import UserRole
from onyx.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
from onyx.configs.constants import NO_AUTH_USER_EMAIL
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.key_value_store.store import KeyValueStore
from onyx.key_value_store.store import KvKeyNotFoundError
from onyx.server.manage.models import UserInfo
@@ -30,8 +32,8 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:

def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
    return UserInfo(
        id="__no_auth_user__",
        email="anonymous@onyx.app",
        id=NO_AUTH_USER_ID,
        email=NO_AUTH_USER_EMAIL,
        is_active=True,
        is_superuser=False,
        is_verified=True,

@@ -1,10 +1,8 @@
|
||||
import smtplib
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
@@ -52,19 +50,17 @@ from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdate
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import SMTP_PASS
|
||||
from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
@@ -72,6 +68,8 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.api_key import fetch_user_for_api_key
|
||||
from onyx.db.auth import get_access_token_db
|
||||
@@ -88,6 +86,7 @@ from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -189,30 +188,6 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = "Onyx Email Verification"
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
|
||||
body = MIMEText(f"Click the following link to verify your email address: {link}")
|
||||
msg.attach(body)
|
||||
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
s.starttls()
|
||||
# If credentials fails with gmail, check (You need an app password, not just the basic email password)
|
||||
# https://support.google.com/accounts/answer/185833?sjid=8512343437447396151-NA
|
||||
s.login(SMTP_USER, SMTP_PASS)
|
||||
s.send_message(msg)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
@@ -225,17 +200,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
referral_source = None
|
||||
if request is not None:
|
||||
referral_source = request.cookies.get("referral_source", None)
|
||||
# We verify the password here to make sure it's valid before we proceed
|
||||
await self.validate_password(
|
||||
user_create.password, cast(schemas.UC, user_create)
|
||||
)
|
||||
|
||||
user_count: int | None = None
|
||||
referral_source = (
|
||||
request.cookies.get("referral_source", None)
|
||||
if request is not None
|
||||
else None
|
||||
)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user_create.email,
|
||||
referral_source=referral_source,
|
||||
request=request,
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -278,7 +262,37 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return user
|
||||
return user
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
# Validate password according to basic security guidelines
|
||||
if len(password) < 12:
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must be at least 12 characters long."
|
||||
)
|
||||
if len(password) > 64:
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must not exceed 64 characters."
|
||||
)
|
||||
if not any(char.isupper() for char in password):
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must contain at least one uppercase letter."
|
||||
)
|
||||
if not any(char.islower() for char in password):
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must contain at least one lowercase letter."
|
||||
)
|
||||
if not any(char.isdigit() for char in password):
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must contain at least one number."
|
||||
)
|
||||
if not any(char in PASSWORD_SPECIAL_CHARS for char in password):
|
||||
raise exceptions.InvalidPasswordException(
|
||||
reason="Password must contain at least one special character from the following set: "
|
||||
f"{PASSWORD_SPECIAL_CHARS}."
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
async def oauth_callback(
|
||||
self,
|
||||
@@ -293,17 +307,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> User:
|
||||
referral_source = None
|
||||
if request:
|
||||
referral_source = getattr(request.state, "referral_source", None)
|
||||
referral_source = (
|
||||
getattr(request.state, "referral_source", None) if request else None
|
||||
)
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=account_email,
|
||||
referral_source=referral_source,
|
||||
request=request,
|
||||
)
|
||||
|
||||
if not tenant_id:
|
||||
@@ -365,6 +380,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
@@ -418,6 +434,39 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
request=request,
|
||||
)
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
user_count = await get_user_count()
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if user_count == 1:
|
||||
create_milestone_and_report(
|
||||
user=user,
|
||||
distinct_id=user.email,
|
||||
event_type=MilestoneRecordType.USER_SIGNED_UP,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
create_milestone_and_report(
|
||||
user=user,
|
||||
distinct_id=user.email,
|
||||
event_type=MilestoneRecordType.MULTIPLE_USERS,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
logger.notice(f"User {user.id} has registered.")
|
||||
optional_telemetry(
|
||||
record_type=RecordType.SIGN_UP,
|
||||
@@ -428,7 +477,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_forgot_password(
    self, user: User, token: str, request: Optional[Request] = None
) -> None:
    logger.notice(f"User {user.id} has forgotten their password. Reset token: {token}")
    if not EMAIL_CONFIGURED:
        logger.error(
            "Email is not configured. Please configure email in the admin panel"
        )
        raise HTTPException(
            status.HTTP_500_INTERNAL_SERVER_ERROR,
            "Your admin has not enabled this feature.",
        )
    send_forgot_password_email(user.email, token)

async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
@@ -449,7 +506,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Get tenant_id from mapping table
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=email,
|
||||
@@ -510,7 +567,7 @@ class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
@@ -546,9 +603,7 @@ def get_database_strategy(
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="jwt" if MULTI_TENANT else "database",
|
||||
transport=cookie_transport,
|
||||
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
|
||||
name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy
|
||||
) # type: ignore
|
||||
|
||||
|
||||
|
||||
@@ -3,11 +3,12 @@ import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
@@ -21,6 +22,7 @@ from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from onyx.background.celery.celery_utils import celery_is_worker_primary
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
@@ -34,8 +36,11 @@ from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -56,8 +61,8 @@ def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple | None = None,
|
||||
kwargs: dict | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
pass
|
||||
@@ -257,7 +262,8 @@ def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("Vespa: Readiness probe starting.")
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
|
||||
client = get_vespa_http_client()
|
||||
response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
|
||||
response.raise_for_status()
|
||||
|
||||
response_dict = response.json()
|
||||
@@ -346,26 +352,36 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
|
||||
loglevel: int,
|
||||
logfile: str | None,
|
||||
format: str,
|
||||
colorize: bool,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# TODO: could unhardcode format and colorize and accept these as options from
|
||||
# celery's config
|
||||
|
||||
# reformats the root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.handlers = []
|
||||
|
||||
root_handler = logging.StreamHandler() # Set up a handler for the root logger
|
||||
# Define the log format
|
||||
log_format = (
|
||||
"%(levelname)-8s %(asctime)s %(filename)15s:%(lineno)-4d: %(name)s %(message)s"
|
||||
)
|
||||
|
||||
# Set up the root handler
|
||||
root_handler = logging.StreamHandler()
|
||||
root_formatter = ColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
log_format,
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_handler.setFormatter(root_formatter)
|
||||
root_logger.addHandler(root_handler) # Apply the handler to the root logger
|
||||
root_logger.addHandler(root_handler)
|
||||
|
||||
if logfile:
|
||||
root_file_handler = logging.FileHandler(logfile)
|
||||
root_file_formatter = PlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
log_format,
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
root_file_handler.setFormatter(root_file_formatter)
|
||||
@@ -373,19 +389,23 @@ def on_setup_logging(
|
||||
|
||||
root_logger.setLevel(loglevel)
|
||||
|
||||
# reformats celery's task logger
|
||||
# Configure the task logger
|
||||
task_logger.handlers = []
|
||||
|
||||
task_handler = logging.StreamHandler()
|
||||
task_handler.addFilter(TenantContextFilter())
|
||||
task_formatter = CeleryTaskColoredFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
log_format,
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_handler = logging.StreamHandler() # Set up a handler for the task logger
|
||||
task_handler.setFormatter(task_formatter)
|
||||
task_logger.addHandler(task_handler) # Apply the handler to the task logger
|
||||
task_logger.addHandler(task_handler)
|
||||
|
||||
if logfile:
|
||||
task_file_handler = logging.FileHandler(logfile)
|
||||
task_file_handler.addFilter(TenantContextFilter())
|
||||
task_file_formatter = CeleryTaskPlainFormatter(
|
||||
"%(asctime)s %(filename)30s %(lineno)4s: %(message)s",
|
||||
log_format,
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
task_file_handler.setFormatter(task_file_formatter)
|
||||
@@ -394,10 +414,55 @@ def on_setup_logging(
|
||||
task_logger.setLevel(loglevel)
|
||||
task_logger.propagate = False
|
||||
|
||||
# hide celery task received spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
|
||||
# Hide celery task received and succeeded/failed messages
|
||||
strategy.logger.setLevel(logging.WARNING)
|
||||
|
||||
# hide celery task succeeded/failed spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
|
||||
trace.logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
class TenantContextFilter(logging.Filter):
|
||||
|
||||
"""Logging filter to inject tenant ID into the logger's name."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if not MULTI_TENANT:
|
||||
record.name = ""
|
||||
return True
|
||||
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id:
|
||||
tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:5]
|
||||
record.name = f"[t:{tenant_id}]"
|
||||
else:
|
||||
record.name = ""
|
||||
return True
|
||||
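For reference, a minimal usage sketch of the filter above with the standard logging module (the handler and format are illustrative, not part of this change):

import logging

handler = logging.StreamHandler()
handler.addFilter(TenantContextFilter())  # prepends "[t:<tenant>]" via record.name when MULTI_TENANT
handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(message)s"))
logging.getLogger("example").addHandler(handler)
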
@task_prerun.connect
|
||||
def set_tenant_id(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
"""Signal handler to set tenant ID in context var before task starts."""
|
||||
tenant_id = (
|
||||
kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
if kwargs
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
|
||||
@task_postrun.connect
|
||||
def reset_tenant_id(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
"""Signal handler to reset tenant ID in context var after task ends."""
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
|
||||
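The two signal handlers above implement a set/reset pattern around each task; a simplified sketch of that pattern with a plain contextvar (the real one is CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.contextvars):

from contextvars import ContextVar

CURRENT_TENANT_ID: ContextVar[str] = ContextVar("current_tenant_id", default="public")

def run_with_tenant(tenant_id: str) -> None:
    token = CURRENT_TENANT_ID.set(tenant_id)  # before the task body runs
    try:
        pass  # task body reads CURRENT_TENANT_ID.get() wherever it needs the tenant
    finally:
        CURRENT_TENANT_ID.reset(token)  # always restore afterwards
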
@@ -13,7 +13,6 @@ from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
@@ -44,18 +43,18 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
self._last_reload is None
|
||||
or (now - self._last_reload) > self._reload_interval
|
||||
):
|
||||
logger.info("Reload interval reached, initiating tenant task update")
|
||||
logger.info("Reload interval reached, initiating task update")
|
||||
self._update_tenant_tasks()
|
||||
self._last_reload = now
|
||||
logger.info("Tenant task update completed, reset reload timer")
|
||||
logger.info("Task update completed, reset reload timer")
|
||||
return retval
|
||||
|
||||
def _update_tenant_tasks(self) -> None:
|
||||
logger.info("Starting tenant task update process")
|
||||
logger.info("Starting task update process")
|
||||
try:
|
||||
logger.info("Fetching all tenant IDs")
|
||||
logger.info("Fetching all IDs")
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.info(f"Found {len(tenant_ids)} tenants")
|
||||
logger.info(f"Found {len(tenant_ids)} IDs")
|
||||
|
||||
logger.info("Fetching tasks to schedule")
|
||||
tasks_to_schedule = fetch_versioned_implementation(
|
||||
@@ -70,7 +69,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
for task_name, _ in current_schedule:
|
||||
if "-" in task_name:
|
||||
existing_tenants.add(task_name.split("-")[-1])
|
||||
logger.info(f"Found {len(existing_tenants)} existing tenants in schedule")
|
||||
logger.info(f"Found {len(existing_tenants)} existing items in schedule")
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if (
|
||||
@@ -83,7 +82,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
continue
|
||||
|
||||
if tenant_id not in existing_tenants:
|
||||
logger.info(f"Processing new tenant: {tenant_id}")
|
||||
logger.info(f"Processing new item: {tenant_id}")
|
||||
|
||||
for task in tasks_to_schedule():
|
||||
task_name = f"{task['name']}-{tenant_id}"
|
||||
@@ -129,11 +128,10 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
logger.info("Schedule update completed successfully")
|
||||
else:
|
||||
logger.info("Schedule is up to date, no changes needed")
|
||||
|
||||
except (AttributeError, KeyError):
|
||||
logger.exception("Failed to process task configuration")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error updating tenant tasks")
|
||||
except (AttributeError, KeyError) as e:
|
||||
logger.exception(f"Failed to process task configuration: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error updating tasks: {str(e)}")
|
||||
|
||||
def _should_update_schedule(
|
||||
self, current_schedule: dict, new_schedule: dict
|
||||
@@ -155,10 +153,6 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=2, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -61,13 +61,14 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -62,13 +62,14 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -60,13 +60,15 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -84,14 +84,14 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
|
||||
@@ -1,12 +1,56 @@
|
||||
# These are helper objects for tracking the keys we need to write in redis
|
||||
import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
|
||||
|
||||
def celery_get_unacked_length(r: Redis) -> int:
|
||||
"""Checking the unacked queue is useful because a non-zero length tells us there
|
||||
may be prefetched tasks.
|
||||
|
||||
There can be other tasks in here besides indexing tasks, so this is mostly useful
|
||||
just to see if the task count is non zero.
|
||||
|
||||
ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html
|
||||
"""
|
||||
length = cast(int, r.hlen("unacked"))
|
||||
return length
|
||||
|
||||
|
||||
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""Gets the set of task id's matching the given queue in the unacked hash.
|
||||
|
||||
Unacked entries belonging to the indexing queue are "prefetched", so this gives
|
||||
us crucial visibility as to what tasks are in that state.
|
||||
"""
|
||||
tasks: set[str] = set()
|
||||
|
||||
for _, v in r.hscan_iter("unacked"):
|
||||
v_bytes = cast(bytes, v)
|
||||
v_str = v_bytes.decode("utf-8")
|
||||
task = json.loads(v_str)
|
||||
|
||||
task_description = task[0]
|
||||
task_queue = task[2]
|
||||
|
||||
if task_queue != queue:
|
||||
continue
|
||||
|
||||
task_id = task_description.get("headers", {}).get("id")
|
||||
if not task_id:
|
||||
continue
|
||||
|
||||
# if the queue matches and we see the task_id, add it
|
||||
tasks.add(task_id)
|
||||
return tasks
|
||||
|
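A hedged usage sketch for the helper above (connection details and the queue name are assumptions; point the client at the same redis database the celery broker uses):

from redis import Redis

r_celery = Redis(host="localhost", port=6379, db=15)  # assumed broker db
prefetched = celery_get_unacked_task_ids("connector_indexing", r_celery)  # assumed queue name
print(f"prefetched indexing tasks: {len(prefetched)}")
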
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
@@ -23,3 +67,96 @@ def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
total_length += cast(int, length)
|
||||
|
||||
return total_length
|
||||
|
||||
|
||||
def celery_find_task(task_id: str, queue: str, r: Redis) -> bool:
    """This is a redis specific way to find a task for a particular queue in redis.
    It is priority aware and knows how to look through the multiple redis lists
    used to implement task prioritization.
    This operation is not atomic.

    This is a linear search O(n) ... so be careful using it when the task queues can be large.

    Returns True if the id is in the queue, False if not.
    """
    for priority in range(len(OnyxCeleryPriority)):
        queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue

        tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
        for task in tasks:
            task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
            if task_dict.get("headers", {}).get("id") == task_id:
                return True

    return False

def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
|
||||
"""Returns a list of current workers containing name_filter, or all workers if
|
||||
name_filter is None.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
worker_names: list[str] = []
|
||||
|
||||
# filter for and create an indexing specific inspect object
|
||||
inspect = app.control.inspect()
|
||||
workers: dict[str, Any] = inspect.ping() # type: ignore
|
||||
if workers:
|
||||
for worker_name in list(workers.keys()):
|
||||
# if the name filter is not set, return all worker names
|
||||
if not name_filter:
|
||||
worker_names.append(worker_name)
|
||||
continue
|
||||
|
||||
# if the name filter is set, return only worker names that contain the name filter
|
||||
if name_filter not in worker_name:
|
||||
continue
|
||||
|
||||
worker_names.append(worker_name)
|
||||
|
||||
return worker_names
|
||||
|
||||
|
||||
def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]:
|
||||
"""Returns a list of reserved tasks on the specified workers.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
reserved_task_ids: set[str] = set()
|
||||
|
||||
inspect = app.control.inspect(destination=worker_names)
|
||||
|
||||
# get the list of reserved tasks
|
||||
reserved_tasks: dict[str, list] | None = inspect.reserved() # type: ignore
|
||||
if reserved_tasks:
|
||||
for _, task_list in reserved_tasks.items():
|
||||
for task in task_list:
|
||||
reserved_task_ids.add(task["id"])
|
||||
|
||||
return reserved_task_ids
|
||||
|
||||
|
||||
def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
|
||||
"""Returns a list of active tasks on the specified workers.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
active_task_ids: set[str] = set()
|
||||
|
||||
inspect = app.control.inspect(destination=worker_names)
|
||||
|
||||
# get the list of active tasks
|
||||
active_tasks: dict[str, list] | None = inspect.active() # type: ignore
|
||||
if active_tasks:
|
||||
for _, task_list in active_tasks.items():
|
||||
for task in task_list:
|
||||
active_task_ids.add(task["id"])
|
||||
|
||||
return active_task_ids
|
||||
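As the docstrings note, these inspect helpers are best treated as debugging aids; a minimal sketch of chaining them (the name filter and the celery_app handle are assumptions):

# debugging only: list indexing workers and the task ids they currently hold
workers = celery_inspect_get_workers("indexing", celery_app)
reserved_ids = celery_inspect_get_reserved(workers, celery_app)
active_ids = celery_inspect_get_active(workers, celery_app)
print(f"workers={len(workers)} reserved={len(reserved_ids)} active={len(active_ids)}")
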
@@ -16,6 +16,11 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# Indexing worker specific ... this lets us track the transition to STARTED in redis
|
||||
# We don't currently rely on this but it has the potential to be useful and
|
||||
# indexing tasks are not high volume
|
||||
task_track_started = True
|
||||
|
||||
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
|
||||
@@ -4,55 +4,80 @@ from typing import Any
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
|
||||
|
||||
# we set expires because it isn't necessary to queue up these tasks
|
||||
# it's only important that they run relatively regularly
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": OnyxCeleryPriority.LOWEST},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOWEST,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": OnyxCeleryPriority.HIGH},
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
]
|
||||
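Each entry above becomes a celery beat entry; the added expires option means a tick that sits in the queue longer than 60 seconds is dropped rather than piling up. A minimal sketch of registering one such entry directly (the task name string is an assumption for the OnyxCeleryTask constant):

from datetime import timedelta
from celery import Celery

app = Celery("sketch")
app.conf.beat_schedule = {
    "check-for-indexing": {
        "task": "check_for_indexing",  # assumed value of OnyxCeleryTask.CHECK_FOR_INDEXING
        "schedule": timedelta(seconds=15),
        "options": {"priority": 1, "expires": 60},  # stale ticks expire instead of queueing up
    },
}
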
@@ -76,7 +76,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception during connector deletion check")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
@@ -131,14 +131,14 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
if redis_connector_index.fenced:
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (indexing in progress): "
|
||||
"Connector deletion - Delayed (indexing in progress): "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
|
||||
if redis_connector.prune.fenced:
|
||||
raise TaskDependencyError(
|
||||
f"Connector deletion - Delayed (pruning in progress): "
|
||||
"Connector deletion - Delayed (pruning in progress): "
|
||||
f"cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
@@ -175,7 +175,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks finished. "
|
||||
"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
import sentry_sdk
|
||||
@@ -15,6 +17,8 @@ from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
@@ -26,6 +30,7 @@ from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -96,13 +101,37 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"IndexingCallback - lock.reacquire exceptioned. "
|
||||
f"IndexingCallback - lock.reacquire exceptioned: "
|
||||
f"lock_timeout={self.redis_lock.timeout} "
|
||||
f"start={self.started} "
|
||||
f"last_tag={self.last_tag} "
|
||||
f"last_reacquired={self.last_lock_reacquire} "
|
||||
f"now={datetime.now(timezone.utc)}"
|
||||
)
|
||||
|
||||
# diagnostic logging for lock errors
|
||||
name = self.redis_lock.name
|
||||
ttl = self.redis_client.ttl(name)
|
||||
locked = self.redis_lock.locked()
|
||||
owned = self.redis_lock.owned()
|
||||
local_token: str | None = self.redis_lock.local.token # type: ignore
|
||||
|
||||
remote_token_raw = self.redis_client.get(self.redis_lock.name)
|
||||
if remote_token_raw:
|
||||
remote_token_bytes = cast(bytes, remote_token_raw)
|
||||
remote_token = remote_token_bytes.decode("utf-8")
|
||||
else:
|
||||
remote_token = None
|
||||
|
||||
logger.warning(
|
||||
f"IndexingCallback - lock diagnostics: "
|
||||
f"name={name} "
|
||||
f"locked={locked} "
|
||||
f"owned={owned} "
|
||||
f"local_token={local_token} "
|
||||
f"remote_token={remote_token} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
raise
|
||||
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
@@ -162,11 +191,19 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
time_start = time.monotonic()
|
||||
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -271,7 +308,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
search_settings_instance,
|
||||
reindex,
|
||||
db_session,
|
||||
r,
|
||||
redis_client,
|
||||
tenant_id,
|
||||
)
|
||||
if attempt_id:
|
||||
@@ -286,7 +323,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
# This shouldn't ever happen!
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
|
||||
db_session, redis_client
|
||||
)
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -304,12 +343,27 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
mark_attempt_failed(
|
||||
attempt.id, db_session, failure_reason=failure_reason
|
||||
)
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
|
||||
# clear any indexing fences that don't have associated celery tasks in progress
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
try:
|
||||
validate_indexing_fences(
|
||||
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating indexing fences")
|
||||
|
||||
redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception during indexing check")
|
||||
finally:
|
||||
if locked:
|
||||
if lock_beat.owned():
|
||||
@@ -320,9 +374,157 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.debug(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
return tasks_created
|
||||
|
||||
|
||||
def validate_indexing_fences(
|
||||
tenant_id: str | None,
|
||||
celery_app: Celery,
|
||||
r: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
reserved_indexing_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
# validate all existing indexing jobs
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
validate_indexing_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_indexing_tasks,
|
||||
r_celery,
|
||||
db_session,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def validate_indexing_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.1. When the task is seen in the redis queue
1.2. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
# the fence is just barely set up.
|
||||
if redis_connector_index.active():
|
||||
return
|
||||
|
||||
# it would be odd to get here as there isn't that much that can go wrong during
|
||||
# initial fence setup, but it's still worth making sure we can recover
|
||||
logger.info(
|
||||
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
|
||||
)
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
found = celery_find_task(
|
||||
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
if found:
|
||||
# the celery task exists in the redis queue
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in reserved_tasks:
|
||||
# the celery task was prefetched and is reserved within the indexing worker
|
||||
redis_connector_index.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector_index.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
logger.warning(
|
||||
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
|
||||
f"index_attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
if payload.index_attempt_id:
|
||||
try:
|
||||
mark_attempt_failed(
|
||||
payload.index_attempt_id,
|
||||
db_session,
|
||||
"validate_indexing_fence - Canceling index attempt due to missing celery tasks",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"validate_indexing_fence - Exception while marking index attempt as failed."
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
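The "active signal" that validate_indexing_fence relies on is, in essence, a short-TTL redis key that interested parties keep renewing; if nothing renews it, the fence is treated as orphaned and reset. A simplified sketch of that pattern (key name and TTL are illustrative):

import redis

r = redis.Redis()
ACTIVE_KEY = "connectorindex_active:13/2"  # hypothetical fence-scoped key

def renew_active_signal(ttl_seconds: int = 300) -> None:
    # refreshed when the fence is created, when the task is seen queued or prefetched,
    # and by the watchdog while the spawned task runs
    r.set(ACTIVE_KEY, 1, ex=ttl_seconds)

def fence_is_active() -> bool:
    return bool(r.exists(ACTIVE_KEY))
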
def _should_index(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
@@ -469,6 +671,7 @@ def try_creating_indexing_task(
|
||||
celery_task_id=None,
|
||||
)
|
||||
|
||||
redis_connector_index.set_active()
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
# create the index attempt for tracking purposes
|
||||
@@ -502,13 +705,14 @@ def try_creating_indexing_task(
|
||||
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
|
||||
|
||||
# now fill out the fence with the rest of the data
|
||||
redis_connector_index.set_active()
|
||||
|
||||
payload.index_attempt_id = index_attempt_id
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector_index.set_fence(payload)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"try_creating_indexing_task - Unexpected exception: "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"search_settings={search_settings.id}"
|
||||
)
|
||||
@@ -540,7 +744,6 @@ def connector_indexing_proxy_task(
|
||||
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - starting: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
@@ -563,15 +766,14 @@ def connector_indexing_proxy_task(
|
||||
if not job:
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
@@ -582,11 +784,56 @@ def connector_indexing_proxy_task(
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
# renew active signal
|
||||
redis_connector_index.set_active()
|
||||
|
||||
# if the job is done, clean up and break
|
||||
if job.done():
|
||||
if job.status == "error":
|
||||
ignore_exitcode = False
|
||||
|
||||
exit_code: int | None = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
|
||||
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
|
||||
# even though logging clearly indicates that they completed successfully
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if ignore_exitcode:
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code}"
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
|
||||
job.release()
|
||||
break
|
||||
|
||||
# if a termination signal is detected, clean up and break
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
@@ -609,79 +856,36 @@ def connector_indexing_proxy_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
|
||||
job.cancel()
|
||||
break
|
||||
|
||||
if not job.done():
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
ignore_exitcode = False
|
||||
|
||||
exit_code: int | None = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
|
||||
# seeing non-deterministic behavior where spawned tasks occasionally return exit code 1
|
||||
# even though logging clearly indicates that they completed successfully
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if ignore_exitcode:
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code}"
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
job.release()
|
||||
break
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing watchdog - finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
@@ -877,6 +1081,7 @@ def connector_indexing_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
# This is where the heavy/real work happens
|
||||
run_indexing_entrypoint(
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
@@ -906,7 +1111,6 @@ def connector_indexing_task(
|
||||
|
||||
logger.info(
|
||||
f"Indexing spawned task finished: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
@@ -122,7 +122,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception during pruning check")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
@@ -308,7 +308,7 @@ def connector_pruning_generator_task(
|
||||
doc_ids_to_remove = list(all_indexed_document_ids - all_connector_doc_ids)
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning set collected: "
|
||||
"Pruning set collected: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source} "
|
||||
f"docs_to_remove={len(doc_ids_to_remove)}"
|
||||
@@ -324,7 +324,7 @@ def connector_pruning_generator_task(
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnector.prune.generate_tasks finished. "
|
||||
"RedisConnector.prune.generate_tasks finished. "
|
||||
f"cc_pair={cc_pair_id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
task_logger.debug(f"Task start: tenant={tenant_id} doc={document_id}")
|
||||
task_logger.debug(f"Task start: doc={document_id}")
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -129,16 +129,13 @@ def document_by_cc_pair_cleanup_task(
|
||||
db_session.commit()
|
||||
|
||||
task_logger.info(
|
||||
f"tenant={tenant_id} "
|
||||
f"doc={document_id} "
|
||||
f"action={action} "
|
||||
f"refcount={count} "
|
||||
f"chunks={chunks_affected}"
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
f"SoftTimeLimitExceeded exception. tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
return False
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
@@ -157,15 +154,12 @@ def document_by_cc_pair_cleanup_task(
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"tenant={tenant_id} "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
task_logger.exception(f"Unexpected exception: doc={document_id}")
|
||||
|
||||
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
|
||||
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
@@ -176,7 +170,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
# eventually gets fixed out of band via stale document reconciliation
|
||||
task_logger.warning(
|
||||
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
|
||||
f"tenant={tenant_id} doc={document_id}"
|
||||
f"doc={document_id}"
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# delete the cc pair relationship now and let reconciliation clean it up
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
@@ -19,6 +20,7 @@ from tenacity import RetryError
|
||||
from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
@@ -89,10 +91,11 @@ logger = setup_logger()
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -156,11 +159,15 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
task_logger.exception("Unexpected exception during vespa metadata sync")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.debug(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
|
||||
return
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
celery_app: Celery,
|
||||
@@ -630,15 +637,23 @@ def monitor_ccpair_indexing_taskset(
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_started_str = None
|
||||
if payload.started:
|
||||
elapsed_started = datetime.now(timezone.utc) - payload.started
|
||||
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: cc_pair={cc_pair_id} "
|
||||
f"Connector indexing progress: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
@@ -709,11 +724,14 @@ def monitor_ccpair_indexing_taskset(
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: cc_pair={cc_pair_id} "
|
||||
f"Connector indexing finished: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
@@ -730,6 +748,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
"""
|
||||
time_start = time.monotonic()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -759,31 +778,34 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
|
||||
prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"indexing_prefetched={len(prefetched)} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
)
|
||||
|
||||
# scan and monitor activity to completion
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||
@@ -794,28 +816,21 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
# uncomment for debugging if needed
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
# length = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
# task_logger.warning(f"queue={OnyxCeleryQueues.VESPA_METADATA_SYNC} length={length}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -824,6 +839,8 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.debug(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
|
||||
return True
|
||||
|
||||
|
||||
@@ -873,13 +890,9 @@ def vespa_metadata_sync_task(
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
|
||||
task_logger.info(
|
||||
f"tenant={tenant_id} doc={document_id} action=sync chunks={chunks_affected}"
|
||||
)
|
||||
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
f"SoftTimeLimitExceeded exception. tenant={tenant_id} doc={document_id}"
|
||||
)
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
except Exception as ex:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
@@ -897,14 +910,13 @@ def vespa_metadata_sync_task(
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"tenant={tenant_id} "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
task_logger.exception(
|
||||
f"Unexpected exception: tenant={tenant_id} doc={document_id}"
|
||||
f"Unexpected exception during vespa metadata sync: doc={document_id}"
|
||||
)
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.background.indexing.tracer import OnyxTracer
|
||||
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
|
||||
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
@@ -34,6 +35,7 @@ from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -396,6 +398,15 @@ def _run_indexing(
|
||||
|
||||
if index_attempt_md.num_exceptions == 0:
|
||||
mark_attempt_succeeded(index_attempt, db_session)
|
||||
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id or "N/A",
|
||||
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Connector succeeded: "
|
||||
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
|
||||
@@ -31,6 +31,8 @@ from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import NO_AUTH_USER_ID
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import SearchType
|
||||
@@ -53,6 +55,9 @@ from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.milestone import check_multi_assistant_milestone
|
||||
from onyx.db.milestone import create_milestone_if_not_exists
|
||||
from onyx.db.milestone import update_user_assistant_milestone
|
||||
from onyx.db.models import SearchDoc as DbSearchDoc
|
||||
from onyx.db.models import ToolCall
|
||||
from onyx.db.models import User
|
||||
@@ -117,6 +122,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
|
||||
from onyx.tools.tool_runner import ToolCallFinalResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.timing import log_generator_function_time
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
@@ -356,6 +362,31 @@ def stream_chat_message_objects(
|
||||
if not persona:
|
||||
raise RuntimeError("No persona specified or found for chat session")
|
||||
|
||||
multi_assistant_milestone, _is_new = create_milestone_if_not_exists(
|
||||
user=user,
|
||||
event_type=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
update_user_assistant_milestone(
|
||||
milestone=multi_assistant_milestone,
|
||||
user_id=str(user.id) if user else NO_AUTH_USER_ID,
|
||||
assistant_id=persona.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
_, just_hit_multi_assistant_milestone = check_multi_assistant_milestone(
|
||||
milestone=multi_assistant_milestone,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
if just_hit_multi_assistant_milestone:
|
||||
mt_cloud_telemetry(
|
||||
distinct_id=tenant_id,
|
||||
event=MilestoneRecordType.MULTIPLE_ASSISTANTS,
|
||||
properties=None,
|
||||
)
|
||||
|
||||
# If a prompt override is specified via the API, use that with highest priority
|
||||
# but for saving it, we are just mapping it to an existing prompt
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
|
||||
@@ -65,7 +65,7 @@ class CitationProcessor:
|
||||
# Handle code blocks without language tags
|
||||
if "`" in self.curr_segment:
|
||||
if self.curr_segment.endswith("`"):
|
||||
return
|
||||
pass
|
||||
elif "```" in self.curr_segment:
|
||||
piece_that_comes_after = self.curr_segment.split("```")[1][0]
|
||||
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
import os
|
||||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
@@ -91,6 +92,7 @@ SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
|
||||
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
|
||||
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
|
||||
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
|
||||
EMAIL_CONFIGURED = all([SMTP_SERVER, SMTP_USER, SMTP_PASS])
|
||||
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
|
||||
|
||||
# If set, Onyx will listen to the `expires_at` returned by the identity
|
||||
@@ -144,6 +146,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
|
||||
@@ -174,6 +177,9 @@ try:
|
||||
except ValueError:
|
||||
POSTGRES_IDLE_SESSIONS_TIMEOUT = POSTGRES_IDLE_SESSIONS_TIMEOUT_DEFAULT
|
||||
|
||||
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
|
||||
|
||||
|
||||
REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
|
||||
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
@@ -483,6 +489,21 @@ SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
|
||||
|
||||
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
|
||||
|
||||
# allow for custom error messages for different errors returned by litellm
|
||||
# for example, can specify: {"Violated content safety policy": "EVIL REQUEST!!!"}
|
||||
# to make it so that if an LLM call returns an error containing "Violated content safety policy"
|
||||
# the end user will see "EVIL REQUEST!!!" instead of the default error message.
|
||||
_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = os.environ.get(
|
||||
"LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS", ""
|
||||
)
|
||||
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS: dict[str, str] | None = None
|
||||
try:
|
||||
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = cast(
|
||||
dict[str, str], json.loads(_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS)
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
#####
|
||||
|
||||
@@ -63,6 +63,10 @@ LANGUAGE_CHAT_NAMING_HINT = (
|
||||
or "The name of the conversation must be in the same language as the user query."
|
||||
)
|
||||
|
||||
# Number of prompts each persona should have
|
||||
NUM_PERSONA_PROMPTS = 4
|
||||
NUM_PERSONA_PROMPT_GENERATION_CHUNKS = 5
|
||||
|
||||
# Agentic search takes significantly more tokens and therefore has much higher cost.
|
||||
# This configuration allows users to get a search-only experience with instant results
|
||||
# and no involvement from the LLM.
|
||||
|
||||
@@ -15,6 +15,9 @@ ID_SEPARATOR = ":;:"
|
||||
DEFAULT_BOOST = 0
|
||||
SESSION_KEY = "session"
|
||||
|
||||
NO_AUTH_USER_ID = "__no_auth_user__"
|
||||
NO_AUTH_USER_EMAIL = "anonymous@onyx.app"
|
||||
|
||||
# For chunking/processing chunks
|
||||
RETURN_SEPARATOR = "\n\r\n"
|
||||
SECTION_SEPARATOR = "\n\n"
|
||||
@@ -33,6 +36,8 @@ DISABLED_GEN_AI_MSG = (
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
DEFAULT_CC_PAIR_ID = 1
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
@@ -46,6 +51,7 @@ POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
|
||||
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
|
||||
POSTGRES_UNKNOWN_APP_NAME = "unknown"
|
||||
|
||||
SSL_CERT_FILE = "bundle.pem"
|
||||
# API Keys
|
||||
DANSWER_API_KEY_PREFIX = "API_KEY__"
|
||||
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai"
|
||||
@@ -170,6 +176,10 @@ class AuthType(str, Enum):
|
||||
CLOUD = "cloud"
|
||||
|
||||
|
||||
# Special characters for password validation
|
||||
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
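
A minimal, hypothetical sketch of how the new constant could back a special-character check; the actual password validation rules are not shown in this diff.

```python
# Hypothetical sketch only; the real validation logic is not part of this diff.
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"


def has_special_char(password: str) -> bool:
    # True if at least one character is in the special-character set.
    return any(ch in PASSWORD_SPECIAL_CHARS for ch in password)


assert has_special_char("hunter2!")
assert not has_special_char("hunter2")
```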
|
||||
|
||||
|
||||
class SessionType(str, Enum):
|
||||
CHAT = "Chat"
|
||||
SEARCH = "Search"
|
||||
@@ -210,6 +220,19 @@ class FileOrigin(str, Enum):
|
||||
OTHER = "other"
|
||||
|
||||
|
||||
class MilestoneRecordType(str, Enum):
|
||||
TENANT_CREATED = "tenant_created"
|
||||
USER_SIGNED_UP = "user_signed_up"
|
||||
MULTIPLE_USERS = "multiple_users"
|
||||
VISITED_ADMIN_PAGE = "visited_admin_page"
|
||||
CREATED_CONNECTOR = "created_connector"
|
||||
CONNECTOR_SUCCEEDED = "connector_succeeded"
|
||||
RAN_QUERY = "ran_query"
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
|
||||
|
||||
@@ -254,6 +277,10 @@ class OnyxRedisLocks:
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
|
||||
|
||||
|
||||
class OnyxCeleryPriority(int, Enum):
|
||||
HIGHEST = 0
|
||||
HIGH = auto()
|
||||
|
||||
@@ -56,6 +56,23 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
|
||||
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
|
||||
"png",
|
||||
"jpg",
|
||||
"jpeg",
|
||||
"gif",
|
||||
"mp4",
|
||||
"mov",
|
||||
"mp3",
|
||||
"wav",
|
||||
]
|
||||
_FULL_EXTENSION_FILTER_STRING = "".join(
|
||||
[
|
||||
f" and title!~'*.{extension}'"
|
||||
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
@@ -64,7 +81,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
is_cloud: bool,
|
||||
space: str = "",
|
||||
page_id: str = "",
|
||||
index_recursively: bool = True,
|
||||
index_recursively: bool = False,
|
||||
cql_query: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
@@ -82,23 +99,25 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
|
||||
# if nothing is provided, we will fetch all pages
|
||||
cql_page_query = "type=page"
|
||||
"""
|
||||
If nothing is provided, we default to fetching all pages
|
||||
Only one or none of the following options should be specified so
|
||||
the order shouldn't matter
|
||||
However, we use elif to ensure that only one of the following is enforced
|
||||
"""
|
||||
base_cql_page_query = "type=page"
|
||||
if cql_query:
|
||||
# if a cql_query is provided, we will use it to fetch the pages
|
||||
cql_page_query = cql_query
|
||||
base_cql_page_query = cql_query
|
||||
elif page_id:
|
||||
# if a cql_query is not provided, we will use the page_id to fetch the page
|
||||
if index_recursively:
|
||||
cql_page_query += f" and ancestor='{page_id}'"
|
||||
base_cql_page_query += f" and (ancestor='{page_id}' or id='{page_id}')"
|
||||
else:
|
||||
cql_page_query += f" and id='{page_id}'"
|
||||
base_cql_page_query += f" and id='{page_id}'"
|
||||
elif space:
|
||||
# if no cql_query or page_id is provided, we will use the space to fetch the pages
|
||||
cql_page_query += f" and space='{quote(space)}'"
|
||||
uri_safe_space = quote(space)
|
||||
base_cql_page_query += f" and space='{uri_safe_space}'"
|
||||
|
||||
self.cql_page_query = cql_page_query
|
||||
self.cql_time_filter = ""
|
||||
self.base_cql_page_query = base_cql_page_query
|
||||
|
||||
self.cql_label_filter = ""
|
||||
if labels_to_skip:
|
||||
@@ -126,6 +145,33 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
return None
|
||||
|
||||
def _construct_page_query(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str:
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
|
||||
# Add time filters
|
||||
if start:
|
||||
formatted_start_time = datetime.fromtimestamp(
|
||||
start, tz=self.timezone
|
||||
).strftime("%Y-%m-%d %H:%M")
|
||||
page_query += f" and lastmodified >= '{formatted_start_time}'"
|
||||
if end:
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
page_query += f" and lastmodified <= '{formatted_end_time}'"
|
||||
|
||||
return page_query
|
||||
|
||||
def _construct_attachment_query(self, confluence_page_id: str) -> str:
|
||||
attachment_query = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_query += self.cql_label_filter
|
||||
attachment_query += _FULL_EXTENSION_FILTER_STRING
|
||||
return attachment_query
|
||||
|
||||
def _get_comment_string_for_page_id(self, page_id: str) -> str:
|
||||
comment_string = ""
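
A rough standalone sketch of the CQL strings the new construction yields (assuming the same defaults as above and no label or time filters):

```python
from urllib.parse import quote


def build_page_cql(
    space: str = "",
    page_id: str = "",
    index_recursively: bool = False,
    cql_query: str | None = None,
) -> str:
    # Standalone re-statement of the branching above: a custom CQL query wins,
    # then page_id (optionally including descendants), then space.
    base = "type=page"
    if cql_query:
        base = cql_query
    elif page_id:
        if index_recursively:
            base += f" and (ancestor='{page_id}' or id='{page_id}')"
        else:
            base += f" and id='{page_id}'"
    elif space:
        base += f" and space='{quote(space)}'"
    return base


print(build_page_cql(page_id="12345", index_recursively=True))
# type=page and (ancestor='12345' or id='12345')
print(build_page_cql(space="Engineering Docs"))
# type=page and space='Engineering%20Docs'
```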
|
||||
|
||||
@@ -205,11 +251,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
|
||||
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
|
||||
def _fetch_document_batches(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch: list[Document] = []
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
|
||||
page_query = self._construct_page_query(start, end)
|
||||
logger.debug(f"page_query: {page_query}")
|
||||
# Fetch pages as Documents
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
@@ -228,11 +278,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
# Fetch attachments as Documents
|
||||
for confluence_page_id in confluence_page_ids:
|
||||
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
attachment_query = self._construct_attachment_query(confluence_page_id)
|
||||
# TODO: maybe should add time filter as well?
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_cql,
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
@@ -248,17 +297,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_document_batches()
|
||||
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
# Add time filters
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
|
||||
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
|
||||
return self._fetch_document_batches()
|
||||
def poll_source(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
return self._fetch_document_batches(start, end)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
@@ -269,7 +313,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
@@ -294,10 +338,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
perm_sync_data=page_perm_sync_data,
|
||||
)
|
||||
)
|
||||
attachment_cql = f"type=attachment and container='{page['id']}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
attachment_query = self._construct_attachment_query(page["id"])
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_cql,
|
||||
cql=attachment_query,
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
|
||||
@@ -190,7 +190,7 @@ class DiscourseConnector(PollConnector):
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
page = 1
|
||||
page = 0
|
||||
while topic_ids := self._get_latest_topics(start, end, page):
|
||||
doc_batch: list[Document] = []
|
||||
for topic_id in topic_ids:
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Dict
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -249,17 +250,36 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
return new_creds_dict
|
||||
|
||||
def _get_all_user_emails(self) -> list[str]:
|
||||
admin_service = get_admin_service(self.creds, self.primary_admin_email)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
"""
|
||||
List all user emails if we are on a Google Workspace domain.
|
||||
If the domain is gmail.com, or if we attempt to call the Admin SDK and
|
||||
get a 404, fall back to using only the primary admin user.
|
||||
"""
|
||||
|
||||
try:
|
||||
admin_service = get_admin_service(self.creds, self.primary_admin_email)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
|
||||
except HttpError as e:
|
||||
if e.resp.status == 404:
|
||||
logger.warning(
|
||||
"Received 404 from Admin SDK; this may indicate a personal Gmail account "
|
||||
"with no Workspace domain. Falling back to single user."
|
||||
)
|
||||
return [self.primary_admin_email]
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def _fetch_threads(
|
||||
self,
|
||||
|
||||
@@ -54,9 +54,11 @@ def get_total_users_count(db_session: Session) -> int:
|
||||
return user_count + invited_users
|
||||
|
||||
|
||||
async def get_user_count() -> int:
|
||||
async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
async with get_async_session_with_tenant() as session:
|
||||
stmt = select(func.count(User.id))
|
||||
if only_admin_users:
|
||||
stmt = stmt.where(User.role == UserRole.ADMIN)
|
||||
result = await session.execute(stmt)
|
||||
user_count = result.scalar()
|
||||
if user_count is None:
|
||||
|
||||
@@ -141,14 +141,20 @@ def get_valid_messages_from_query_sessions(
|
||||
return {row.chat_session_id: row.message for row in first_messages}
|
||||
|
||||
|
||||
# Retrieves chat sessions by user
|
||||
# Chat sessions do not include onyxbot flows
|
||||
def get_chat_sessions_by_user(
|
||||
user_id: UUID | None,
|
||||
deleted: bool | None,
|
||||
db_session: Session,
|
||||
include_onyxbot_flows: bool = False,
|
||||
limit: int = 50,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
if deleted is not None:
|
||||
@@ -310,6 +316,23 @@ def update_chat_session(
|
||||
return chat_session
|
||||
|
||||
|
||||
def delete_all_chat_sessions_for_user(
|
||||
user: User | None, db_session: Session, hard_delete: bool = HARD_DELETE_CHATS
|
||||
) -> None:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
query = db_session.query(ChatSession).filter(
|
||||
ChatSession.user_id == user_id, ChatSession.onyxbot_flow.is_(False)
|
||||
)
|
||||
|
||||
if hard_delete:
|
||||
query.delete(synchronize_session=False)
|
||||
else:
|
||||
query.update({ChatSession.deleted: True}, synchronize_session=False)
|
||||
|
||||
db_session.commit()
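
A hedged usage sketch of the new helper; `current_user` and the import paths are assumptions taken from the surrounding context, not part of the diff.

```python
from onyx.db.chat import delete_all_chat_sessions_for_user
from onyx.db.engine import get_session_context_manager

# `current_user` is assumed to be a User object obtained elsewhere.
with get_session_context_manager() as db_session:
    # Soft delete (default): sessions are flagged deleted but kept in Postgres.
    delete_all_chat_sessions_for_user(user=current_user, db_session=db_session)

    # Hard delete: rows are removed outright; the new cascade on
    # ChatSession.messages cleans up the attached messages as well.
    delete_all_chat_sessions_for_user(
        user=current_user, db_session=db_session, hard_delete=True
    )
```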
|
||||
|
||||
|
||||
def delete_chat_session(
|
||||
user_id: UUID | None,
|
||||
chat_session_id: UUID,
|
||||
|
||||
@@ -310,6 +310,9 @@ def associate_default_cc_pair(db_session: Session) -> None:
|
||||
if existing_association is not None:
|
||||
return
|
||||
|
||||
# DefaultCCPair has id 1 since it is the first CC pair created
|
||||
# It is DEFAULT_CC_PAIR_ID, but we can't set it explicitly because it messes with the
|
||||
# auto-incrementing id
|
||||
association = ConnectorCredentialPair(
|
||||
connector_id=0,
|
||||
credential_id=0,
|
||||
@@ -350,7 +353,12 @@ def add_credential_to_connector(
|
||||
last_successful_index_time: datetime | None = None,
|
||||
) -> StatusResponse:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
credential = fetch_credential_by_id(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
get_editable=False,
|
||||
)
|
||||
|
||||
if connector is None:
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
@@ -427,7 +435,12 @@ def remove_credential_from_connector(
|
||||
db_session: Session,
|
||||
) -> StatusResponse[int]:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
credential = fetch_credential_by_id(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
get_editable=False,
|
||||
)
|
||||
|
||||
if connector is None:
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
|
||||
@@ -86,7 +86,7 @@ def _add_user_filters(
|
||||
"""
|
||||
Filter Credentials by:
|
||||
- if the user is in the user_group that owns the Credential
|
||||
- if the user is not a global_curator, they must also have a curator relationship
|
||||
- if the user is a curator, they must also have a curator relationship
|
||||
to the user_group
|
||||
- if editing is being done, we also filter out Credentials that are owned by groups
|
||||
that the user isn't a curator for
|
||||
@@ -97,6 +97,7 @@ def _add_user_filters(
|
||||
where_clause = User__UserGroup.user_id == user.id
|
||||
if user.role == UserRole.CURATOR:
|
||||
where_clause &= User__UserGroup.is_curator == True # noqa: E712
|
||||
|
||||
if get_editable:
|
||||
user_groups = select(User__UserGroup.user_group_id).where(
|
||||
User__UserGroup.user_id == user.id
|
||||
@@ -152,10 +153,16 @@ def fetch_credential_by_id(
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
assume_admin: bool = False,
|
||||
get_editable: bool = True,
|
||||
) -> Credential | None:
|
||||
stmt = select(Credential).distinct()
|
||||
stmt = stmt.where(Credential.id == credential_id)
|
||||
stmt = _add_user_filters(stmt, user, assume_admin=assume_admin)
|
||||
stmt = _add_user_filters(
|
||||
stmt=stmt,
|
||||
user=user,
|
||||
assume_admin=assume_admin,
|
||||
get_editable=get_editable,
|
||||
)
|
||||
result = db_session.execute(stmt)
|
||||
credential = result.scalar_one_or_none()
|
||||
return credential
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
import ssl
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
@@ -10,6 +12,8 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import ContextManager
|
||||
|
||||
import asyncpg # type: ignore
|
||||
import boto3
|
||||
import jwt
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
@@ -23,6 +27,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from onyx.configs.app_configs import AWS_REGION
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
@@ -37,6 +42,7 @@ from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -49,28 +55,87 @@ logger = setup_logger()
|
||||
SYNC_DB_API = "psycopg2"
|
||||
ASYNC_DB_API = "asyncpg"
|
||||
|
||||
# global so we don't create more than one engine per process
|
||||
# outside of being best practice, this is needed so we can properly pool
|
||||
# connections and not create a new pool on every request
|
||||
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"
|
||||
|
||||
# Global so we don't create more than one engine per process
|
||||
_ASYNC_ENGINE: AsyncEngine | None = None
|
||||
SessionFactory: sessionmaker[Session] | None = None
|
||||
|
||||
|
||||
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
|
||||
"""Create an SSL context if IAM authentication is enabled, else return None."""
|
||||
if USE_IAM_AUTH:
|
||||
return ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
return None
|
||||
|
||||
|
||||
ssl_context = create_ssl_context_if_iam()
|
||||
|
||||
|
||||
def get_iam_auth_token(
|
||||
host: str, port: str, user: str, region: str = "us-east-2"
|
||||
) -> str:
|
||||
"""
|
||||
Generate an IAM authentication token using boto3.
|
||||
"""
|
||||
client = boto3.client("rds", region_name=region)
|
||||
token = client.generate_db_auth_token(
|
||||
DBHostname=host, Port=int(port), DBUsername=user
|
||||
)
|
||||
return token
|
||||
|
||||
|
||||
def configure_psycopg2_iam_auth(
|
||||
cparams: dict[str, Any], host: str, port: str, user: str, region: str
|
||||
) -> None:
|
||||
"""
|
||||
Configure cparams for psycopg2 with IAM token and SSL.
|
||||
"""
|
||||
token = get_iam_auth_token(host, port, user, region)
|
||||
cparams["password"] = token
|
||||
cparams["sslmode"] = "require"
|
||||
cparams["sslrootcert"] = SSL_CERT_FILE
|
||||
|
||||
|
||||
def build_connection_string(
|
||||
*,
|
||||
db_api: str = ASYNC_DB_API,
|
||||
user: str = POSTGRES_USER,
|
||||
password: str = POSTGRES_PASSWORD,
|
||||
host: str = POSTGRES_HOST,
|
||||
port: str = POSTGRES_PORT,
|
||||
db: str = POSTGRES_DB,
|
||||
app_name: str | None = None,
|
||||
use_iam: bool = USE_IAM_AUTH,
|
||||
region: str = "us-west-2",
|
||||
) -> str:
|
||||
if use_iam:
|
||||
base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
|
||||
else:
|
||||
base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
|
||||
|
||||
# For asyncpg, do not include application_name in the connection string
|
||||
if app_name and db_api != "asyncpg":
|
||||
if "?" in base_conn_str:
|
||||
return f"{base_conn_str}&application_name={app_name}"
|
||||
else:
|
||||
return f"{base_conn_str}?application_name={app_name}"
|
||||
return base_conn_str
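
A sketch of the two connection-string shapes the updated helper produces; host, user, and password values are placeholders.

```python
# Placeholder credentials/hosts for illustration only.
print(
    build_connection_string(
        db_api="psycopg2",
        user="onyx",
        password="secret",
        host="localhost",
        port="5432",
        db="postgres",
        app_name="api_server_sync",
        use_iam=False,
    )
)
# postgresql+psycopg2://onyx:secret@localhost:5432/postgres?application_name=api_server_sync

print(
    build_connection_string(
        db_api="asyncpg",
        user="onyx",
        password="ignored",
        host="db.example.internal",
        port="5432",
        db="postgres",
        use_iam=True,
    )
)
# postgresql+asyncpg://onyx@db.example.internal:5432/postgres
# No password in the URL: with IAM the token is injected later by the
# "do_connect" listeners (provide_iam_token / provide_iam_token_async).
```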
|
||||
|
||||
|
||||
if LOG_POSTGRES_LATENCY:
|
||||
# Function to log before query execution
|
||||
|
||||
@event.listens_for(Engine, "before_cursor_execute")
|
||||
def before_cursor_execute( # type: ignore
|
||||
conn, cursor, statement, parameters, context, executemany
|
||||
):
|
||||
conn.info["query_start_time"] = time.time()
|
||||
|
||||
# Function to log after query execution
|
||||
@event.listens_for(Engine, "after_cursor_execute")
|
||||
def after_cursor_execute( # type: ignore
|
||||
conn, cursor, statement, parameters, context, executemany
|
||||
):
|
||||
total_time = time.time() - conn.info["query_start_time"]
|
||||
# don't spam TOO hard
|
||||
if total_time > 0.1:
|
||||
logger.debug(
|
||||
f"Query Complete: {statement}\n\nTotal Time: {total_time:.4f} seconds"
|
||||
@@ -78,7 +143,6 @@ if LOG_POSTGRES_LATENCY:
|
||||
|
||||
|
||||
if LOG_POSTGRES_CONN_COUNTS:
|
||||
# Global counter for connection checkouts and checkins
|
||||
checkout_count = 0
|
||||
checkin_count = 0
|
||||
|
||||
@@ -105,21 +169,13 @@ if LOG_POSTGRES_CONN_COUNTS:
|
||||
logger.debug(f"Total connection checkins: {checkin_count}")
|
||||
|
||||
|
||||
"""END DEBUGGING LOGGING"""
|
||||
|
||||
|
||||
def get_db_current_time(db_session: Session) -> datetime:
|
||||
"""Get the current time from Postgres representing the start of the transaction
|
||||
Within the same transaction this value will not update
|
||||
This datetime object returned should be timezone aware, default Postgres timezone is UTC
|
||||
"""
|
||||
result = db_session.execute(text("SELECT NOW()")).scalar()
|
||||
if result is None:
|
||||
raise ValueError("Database did not return a time")
|
||||
return result
|
||||
|
||||
|
||||
# Regular expression to validate schema names to prevent SQL injection
|
||||
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")
|
||||
|
||||
|
||||
@@ -128,16 +184,9 @@ def is_valid_schema_name(name: str) -> bool:
|
||||
|
||||
|
||||
class SqlEngine:
|
||||
"""Class to manage a global SQLAlchemy engine (needed for proper resource control).
|
||||
Will eventually subsume most of the standalone functions in this file.
|
||||
Sync only for now.
|
||||
"""
|
||||
|
||||
_engine: Engine | None = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
|
||||
|
||||
# Default parameters for engine creation
|
||||
DEFAULT_ENGINE_KWARGS = {
|
||||
"pool_size": 20,
|
||||
"max_overflow": 5,
|
||||
@@ -145,33 +194,27 @@ class SqlEngine:
|
||||
"pool_recycle": POSTGRES_POOL_RECYCLE,
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
|
||||
"""Private helper method to create and return an Engine."""
|
||||
connection_string = build_connection_string(
|
||||
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync"
|
||||
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
|
||||
)
|
||||
merged_kwargs = {**cls.DEFAULT_ENGINE_KWARGS, **engine_kwargs}
|
||||
return create_engine(connection_string, **merged_kwargs)
|
||||
engine = create_engine(connection_string, **merged_kwargs)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
event.listen(engine, "do_connect", provide_iam_token)
|
||||
|
||||
return engine
|
||||
|
||||
@classmethod
|
||||
def init_engine(cls, **engine_kwargs: Any) -> None:
|
||||
"""Allow the caller to init the engine with extra params. Different clients
|
||||
such as the API server and different Celery workers and tasks
|
||||
need different settings.
|
||||
"""
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
cls._engine = cls._init_engine(**engine_kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_engine(cls) -> Engine:
|
||||
"""Gets the SQLAlchemy engine. Will init a default engine if init hasn't
|
||||
already been called. You probably want to init first!
|
||||
"""
|
||||
if not cls._engine:
|
||||
with cls._lock:
|
||||
if not cls._engine:
|
||||
@@ -180,12 +223,10 @@ class SqlEngine:
|
||||
|
||||
@classmethod
|
||||
def set_app_name(cls, app_name: str) -> None:
|
||||
"""Class method to set the app name."""
|
||||
cls._app_name = app_name
|
||||
|
||||
@classmethod
|
||||
def get_app_name(cls) -> str:
|
||||
"""Class method to get current app name."""
|
||||
if not cls._app_name:
|
||||
return ""
|
||||
return cls._app_name
|
||||
@@ -217,56 +258,71 @@ def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
for tenant in tenant_ids
|
||||
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
|
||||
]
|
||||
|
||||
return valid_tenants
|
||||
|
||||
|
||||
def build_connection_string(
|
||||
*,
|
||||
db_api: str = ASYNC_DB_API,
|
||||
user: str = POSTGRES_USER,
|
||||
password: str = POSTGRES_PASSWORD,
|
||||
host: str = POSTGRES_HOST,
|
||||
port: str = POSTGRES_PORT,
|
||||
db: str = POSTGRES_DB,
|
||||
app_name: str | None = None,
|
||||
) -> str:
|
||||
if app_name:
|
||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"
|
||||
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"
|
||||
|
||||
|
||||
def get_sqlalchemy_engine() -> Engine:
|
||||
return SqlEngine.get_engine()
|
||||
|
||||
|
||||
async def get_async_connection() -> Any:
|
||||
"""
|
||||
Custom connection function for async engine when using IAM auth.
|
||||
"""
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
db = POSTGRES_DB
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION)
|
||||
|
||||
# asyncpg requires 'ssl="require"' if SSL needed
|
||||
return await asyncpg.connect(
|
||||
user=user, password=token, host=host, port=int(port), database=db, ssl="require"
|
||||
)
|
||||
|
||||
|
||||
def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
global _ASYNC_ENGINE
|
||||
if _ASYNC_ENGINE is None:
|
||||
# Underlying asyncpg cannot accept application_name directly in the connection string
|
||||
# https://github.com/MagicStack/asyncpg/issues/798
|
||||
connection_string = build_connection_string()
|
||||
app_name = SqlEngine.get_app_name() + "_async"
|
||||
connection_string = build_connection_string(
|
||||
db_api=ASYNC_DB_API,
|
||||
use_iam=USE_IAM_AUTH,
|
||||
)
|
||||
|
||||
connect_args: dict[str, Any] = {}
|
||||
if app_name:
|
||||
connect_args["server_settings"] = {"application_name": app_name}
|
||||
|
||||
connect_args["ssl"] = ssl_context
|
||||
|
||||
_ASYNC_ENGINE = create_async_engine(
|
||||
connection_string,
|
||||
connect_args={
|
||||
"server_settings": {
|
||||
"application_name": SqlEngine.get_app_name() + "_async"
|
||||
}
|
||||
},
|
||||
# async engine is only used by API server, so we can use those values
|
||||
# here as well
|
||||
connect_args=connect_args,
|
||||
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
|
||||
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
|
||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||
)
|
||||
|
||||
if USE_IAM_AUTH:
|
||||
|
||||
@event.listens_for(_ASYNC_ENGINE.sync_engine, "do_connect")
|
||||
def provide_iam_token_async(
|
||||
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
|
||||
) -> None:
|
||||
# For async engine using asyncpg, we still need to set the IAM token here.
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION)
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
# Dependency to get the current tenant ID
|
||||
# If no token is present, uses the default schema for this use case
|
||||
def get_current_tenant_id(request: Request) -> str:
|
||||
"""Dependency that extracts the tenant ID from the JWT token in the request and sets the context variable."""
|
||||
if not MULTI_TENANT:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -275,7 +331,6 @@ def get_current_tenant_id(request: Request) -> str:
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
# If no token is present, use the default schema or handle accordingly
|
||||
return current_value
|
||||
|
||||
try:
|
||||
@@ -289,7 +344,6 @@ def get_current_tenant_id(request: Request) -> str:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
return tenant_id
|
||||
except jwt.InvalidTokenError:
|
||||
return CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
@@ -316,7 +370,6 @@ async def get_async_session_with_tenant(
|
||||
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
# Set the search_path to the tenant's schema
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await session.execute(
|
||||
@@ -326,8 +379,6 @@ async def get_async_session_with_tenant(
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error setting search_path.")
|
||||
# You can choose to re-raise the exception or handle it
|
||||
# Here, we'll re-raise to prevent proceeding with an incorrect session
|
||||
raise
|
||||
else:
|
||||
yield session
|
||||
@@ -335,9 +386,6 @@ async def get_async_session_with_tenant(
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_default_tenant() -> Generator[Session, None, None]:
|
||||
"""
|
||||
Get a database session using the current tenant ID from the context variable.
|
||||
"""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
with get_session_with_tenant(tenant_id) as session:
|
||||
yield session
|
||||
@@ -349,7 +397,6 @@ def get_session_with_tenant(
|
||||
) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session for a specific tenant.
|
||||
|
||||
This function:
|
||||
1. Sets the database schema to the specified tenant's schema.
|
||||
2. Preserves the tenant ID across the session.
|
||||
@@ -357,27 +404,20 @@ def get_session_with_tenant(
|
||||
4. Uses the default schema if no tenant ID is provided.
|
||||
"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Store the previous tenant ID
|
||||
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
try:
|
||||
# Establish a raw connection
|
||||
with engine.connect() as connection:
|
||||
# Access the raw DBAPI connection and set the search_path
|
||||
dbapi_connection = connection.connection
|
||||
|
||||
# Set the search_path outside of any transaction
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
@@ -390,21 +430,17 @@ def get_session_with_tenant(
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
# Bind the session to the connection
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# Reset search_path to default after the session is used
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
finally:
|
||||
# Restore the previous tenant ID
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
|
||||
|
||||
|
||||
@@ -424,12 +460,9 @@ def get_session_generator_with_tenant() -> Generator[Session, None, None]:
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise BasicAuthenticationError(
|
||||
detail="User must authenticate",
|
||||
)
|
||||
raise BasicAuthenticationError(detail="User must authenticate")
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
@@ -437,20 +470,17 @@ def get_session() -> Generator[Session, None, None]:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield session
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Generate an async database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
async with AsyncSession(engine, expire_on_commit=False) as async_session:
|
||||
if MULTI_TENANT:
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
# Set the search_path to the tenant's schema
|
||||
await async_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
yield async_session
|
||||
|
||||
@@ -461,7 +491,6 @@ def get_session_context_manager() -> ContextManager[Session]:
|
||||
|
||||
|
||||
def get_session_factory() -> sessionmaker[Session]:
|
||||
"""Get a session factory."""
|
||||
global SessionFactory
|
||||
if SessionFactory is None:
|
||||
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
|
||||
@@ -489,3 +518,13 @@ async def warm_up_connections(
|
||||
await async_conn.execute(text("SELECT 1"))
|
||||
for async_conn in async_connections:
|
||||
await async_conn.close()
|
||||
|
||||
|
||||
def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) -> None:
|
||||
if USE_IAM_AUTH:
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
region = os.getenv("AWS_REGION", "us-east-2")
|
||||
# Configure for psycopg2 with IAM token
|
||||
configure_psycopg2_iam_auth(cparams, host, port, user, region)
|
||||
|
||||
backend/onyx/db/milestone.py (new file, 99 lines)
@@ -0,0 +1,99 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.attributes import flag_modified
|
||||
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.models import Milestone
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
USER_ASSISTANT_PREFIX = "user_assistants_used_"
|
||||
MULTI_ASSISTANT_USED = "multi_assistant_used"
|
||||
|
||||
|
||||
def create_milestone(
|
||||
user: User | None,
|
||||
event_type: MilestoneRecordType,
|
||||
db_session: Session,
|
||||
) -> Milestone:
|
||||
milestone = Milestone(
|
||||
event_type=event_type,
|
||||
user_id=user.id if user else None,
|
||||
)
|
||||
db_session.add(milestone)
|
||||
db_session.commit()
|
||||
|
||||
return milestone
|
||||
|
||||
|
||||
def create_milestone_if_not_exists(
|
||||
user: User | None, event_type: MilestoneRecordType, db_session: Session
|
||||
) -> tuple[Milestone, bool]:
|
||||
# Check if it exists
|
||||
milestone = db_session.execute(
|
||||
select(Milestone).where(Milestone.event_type == event_type)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if milestone is not None:
|
||||
return milestone, False
|
||||
|
||||
# If it doesn't exist, try to create it.
|
||||
try:
|
||||
milestone = create_milestone(user, event_type, db_session)
|
||||
return milestone, True
|
||||
except IntegrityError:
|
||||
# Another thread or process inserted it in the meantime
|
||||
db_session.rollback()
|
||||
# Fetch again to return the existing record
|
||||
milestone = db_session.execute(
|
||||
select(Milestone).where(Milestone.event_type == event_type)
|
||||
).scalar_one() # Now should exist
|
||||
return milestone, False
|
||||
|
||||
|
||||
def update_user_assistant_milestone(
|
||||
milestone: Milestone,
|
||||
user_id: str | None,
|
||||
assistant_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
event_tracker = milestone.event_tracker
|
||||
if event_tracker is None:
|
||||
milestone.event_tracker = event_tracker = {}
|
||||
|
||||
if event_tracker.get(MULTI_ASSISTANT_USED):
|
||||
# No need to keep tracking and populating if the milestone has already been hit
|
||||
return
|
||||
|
||||
user_key = f"{USER_ASSISTANT_PREFIX}{user_id}"
|
||||
|
||||
if event_tracker.get(user_key) is None:
|
||||
event_tracker[user_key] = [assistant_id]
|
||||
elif assistant_id not in event_tracker[user_key]:
|
||||
event_tracker[user_key].append(assistant_id)
|
||||
|
||||
flag_modified(milestone, "event_tracker")
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def check_multi_assistant_milestone(
|
||||
milestone: Milestone,
|
||||
db_session: Session,
|
||||
) -> tuple[bool, bool]:
|
||||
"""Returns if the milestone was hit and if it was just hit for the first time"""
|
||||
event_tracker = milestone.event_tracker
|
||||
if event_tracker is None:
|
||||
return False, False
|
||||
|
||||
if event_tracker.get(MULTI_ASSISTANT_USED):
|
||||
return True, False
|
||||
|
||||
for key, value in event_tracker.items():
|
||||
if key.startswith(USER_ASSISTANT_PREFIX) and len(value) > 1:
|
||||
event_tracker[MULTI_ASSISTANT_USED] = True
|
||||
flag_modified(milestone, "event_tracker")
|
||||
db_session.commit()
|
||||
return True, True
|
||||
|
||||
return False, False
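
Condensed from the chat-processing changes earlier in this diff, the intended calling sequence looks roughly like this (`user`, `persona`, and `db_session` come from the caller):

```python
from onyx.configs.constants import MilestoneRecordType, NO_AUTH_USER_ID
from onyx.db.milestone import (
    check_multi_assistant_milestone,
    create_milestone_if_not_exists,
    update_user_assistant_milestone,
)

# Create (or fetch) the milestone, record this assistant usage, then check
# whether the user has now used more than one assistant.
milestone, _is_new = create_milestone_if_not_exists(
    user=user,
    event_type=MilestoneRecordType.MULTIPLE_ASSISTANTS,
    db_session=db_session,
)
update_user_assistant_milestone(
    milestone=milestone,
    user_id=str(user.id) if user else NO_AUTH_USER_ID,
    assistant_id=persona.id,
    db_session=db_session,
)
_hit, just_hit = check_multi_assistant_milestone(
    milestone=milestone,
    db_session=db_session,
)
if just_hit:
    # Report once, the first time any single user crosses the
    # "used more than one assistant" threshold.
    ...
```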
|
||||
@@ -5,6 +5,8 @@ from typing import Literal
|
||||
from typing import NotRequired
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
from uuid import UUID
|
||||
|
||||
@@ -37,7 +39,7 @@ from sqlalchemy.types import TypeDecorator
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -1008,7 +1010,7 @@ class ChatSession(Base):
|
||||
"ChatFolder", back_populates="chat_sessions"
|
||||
)
|
||||
messages: Mapped[list["ChatMessage"]] = relationship(
|
||||
"ChatMessage", back_populates="chat_session"
|
||||
"ChatMessage", back_populates="chat_session", cascade="all, delete-orphan"
|
||||
)
|
||||
persona: Mapped["Persona"] = relationship("Persona")
|
||||
|
||||
@@ -1076,6 +1078,8 @@ class ChatMessage(Base):
|
||||
"SearchDoc",
|
||||
secondary=ChatMessage__SearchDoc.__table__,
|
||||
back_populates="chat_messages",
|
||||
cascade="all, delete-orphan",
|
||||
single_parent=True,
|
||||
)
|
||||
|
||||
tool_call: Mapped["ToolCall"] = relationship(
|
||||
@@ -1344,6 +1348,11 @@ class StarterMessage(TypedDict):
|
||||
message: str
|
||||
|
||||
|
||||
class StarterMessageModel(BaseModel):
|
||||
name: str
|
||||
message: str
|
||||
|
||||
|
||||
class Persona(Base):
|
||||
__tablename__ = "persona"
|
||||
|
||||
@@ -1534,6 +1543,32 @@ class SlackBot(Base):
|
||||
)
|
||||
|
||||
|
||||
class Milestone(Base):
|
||||
# This table is used to track significant events for a deployment towards finding value
|
||||
# The table is currently not used for features but it may be used in the future to inform
|
||||
# users about the product features and encourage usage/exploration.
|
||||
__tablename__ = "milestone"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
event_type: Mapped[MilestoneRecordType] = mapped_column(String)
|
||||
# Need to track counts and specific ids of certain events to know if the Milestone has been reached
|
||||
event_tracker: Mapped[dict | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User | None] = relationship("User")
|
||||
|
||||
__table_args__ = (UniqueConstraint("event_type", name="uq_milestone_event_type"),)
|
||||
|
||||
|
||||
class TaskQueueState(Base):
|
||||
# Currently refers to Celery Tasks
|
||||
__tablename__ = "task_queue_jobs"
|
||||
|
||||
@@ -543,6 +543,10 @@ def upsert_persona(
|
||||
if tools is not None:
|
||||
existing_persona.tools = tools or []
|
||||
|
||||
# We should only update display priority if it is not already set
|
||||
if existing_persona.display_priority is None:
|
||||
existing_persona.display_priority = display_priority
|
||||
|
||||
persona = existing_persona
|
||||
|
||||
else:
|
||||
|
||||
@@ -7,8 +7,15 @@ from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
|
||||
@@ -185,3 +192,43 @@ def batch_add_ext_perm_user_if_not_exists(
|
||||
db_session.commit()
|
||||
|
||||
return found_users + new_users
|
||||
|
||||
|
||||
def delete_user_from_db(
|
||||
user_to_delete: User,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
for oauth_account in user_to_delete.oauth_accounts:
|
||||
db_session.delete(oauth_account)
|
||||
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.db.external_perm",
|
||||
"delete_user__ext_group_for_user__no_commit",
|
||||
)(
|
||||
db_session=db_session,
|
||||
user_id=user_to_delete.id,
|
||||
)
|
||||
db_session.query(SamlAccount).filter(
|
||||
SamlAccount.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.query(DocumentSet__User).filter(
|
||||
DocumentSet__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.query(User__UserGroup).filter(
|
||||
User__UserGroup.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.delete(user_to_delete)
|
||||
db_session.commit()
|
||||
|
||||
# NOTE: edge cases may exist due to race conditions
# with this `invited user` scheme in general.
|
||||
user_emails = get_invited_users()
|
||||
remaining_users = [
|
||||
remaining_user_email
|
||||
for remaining_user_email in user_emails
|
||||
if remaining_user_email != user_to_delete.email
|
||||
]
|
||||
write_invited_users(remaining_users)
|
||||
|
||||
@@ -369,6 +369,19 @@ class AdminCapable(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RandomCapable(abc.ABC):
|
||||
"""Class must implement random document retrieval capability"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def random_retrieval(
|
||||
self,
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int = 10,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""Retrieve random chunks matching the filters"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseIndex(
|
||||
Verifiable,
|
||||
Indexable,
|
||||
@@ -376,6 +389,7 @@ class BaseIndex(
|
||||
Deletable,
|
||||
AdminCapable,
|
||||
IdRetrievalCapable,
|
||||
RandomCapable,
|
||||
abc.ABC,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -218,4 +218,10 @@ schema DANSWER_CHUNK_NAME {
|
||||
expression: bm25(content) + (5 * bm25(title))
|
||||
}
|
||||
}
|
||||
|
||||
rank-profile random_ {
|
||||
first-phase {
|
||||
expression: random.match
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
<resource-limits>
|
||||
<!-- Default is 75% but this can be increased for Dockerized deployments -->
|
||||
<!-- https://docs.vespa.ai/en/operations/feed-block.html -->
|
||||
<disk>0.75</disk>
|
||||
<disk>0.85</disk>
|
||||
</resource-limits>
|
||||
</tuning>
|
||||
<engine>
|
||||
|
||||
@@ -2,6 +2,7 @@ import concurrent.futures
import io
import logging
import os
import random
import re
import time
import urllib
@@ -534,7 +535,7 @@ class VespaIndex(DocumentIndex):
if self.secondary_index_name:
index_names.append(self.secondary_index_name)

with get_vespa_http_client() as http_client:
with get_vespa_http_client(http2=False) as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
@@ -545,8 +546,12 @@ class VespaIndex(DocumentIndex):

while True:
try:
vespa_url = (
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}"
)
logger.debug(f'update_single PUT on URL "{vespa_url}"')
resp = http_client.put(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}",
vespa_url,
params=params,
headers={"Content-Type": "application/json"},
json=update_dict,
@@ -618,7 +623,7 @@ class VespaIndex(DocumentIndex):
if self.secondary_index_name:
index_names.append(self.secondary_index_name)

with get_vespa_http_client() as http_client:
with get_vespa_http_client(http2=False) as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
@@ -629,8 +634,12 @@ class VespaIndex(DocumentIndex):

while True:
try:
vespa_url = (
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}"
)
logger.debug(f'delete_single DELETE on URL "{vespa_url}"')
resp = http_client.delete(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}",
vespa_url,
params=params,
)
resp.raise_for_status()
@@ -903,6 +912,32 @@ class VespaIndex(DocumentIndex):

logger.info("Batch deletion completed")

def random_retrieval(
self,
filters: IndexFilters,
num_to_retrieve: int = 10,
) -> list[InferenceChunkUncleaned]:
"""Retrieve random chunks matching the filters using Vespa's random ranking

This method is currently used for random chunk retrieval in the context of
assistant starter message creation (passed as sample context for usage by the assistant).
"""
vespa_where_clauses = build_vespa_filters(filters, remove_trailing_and=True)

yql = YQL_BASE.format(index_name=self.index_name) + vespa_where_clauses

random_seed = random.randint(0, 1000000)

params: dict[str, str | int | float] = {
"yql": yql,
"hits": num_to_retrieve,
"timeout": VESPA_TIMEOUT,
"ranking.profile": "random_",
"ranking.properties.random.seed": random_seed,
}

return query_vespa(params)


class _VespaDeleteRequest:
def __init__(self, document_id: str, index_name: str) -> None:

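For context, a rough sketch of how the new random retrieval path might be exercised end to end. The document-set name and the choice to leave ACL filters out are illustrative assumptions; the call signatures mirror the hunks above.

```python
from sqlalchemy.orm import Session

from onyx.context.search.models import IndexFilters
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index


def sample_chunks(db_session: Session, num_chunks: int = 10):
    # Resolve the current (and possibly secondary) index names, then build the index client.
    curr_ind_name, sec_ind_name = get_both_index_names(db_session)
    document_index = get_default_document_index(curr_ind_name, sec_ind_name)

    filters = IndexFilters(
        document_set=["Engineering Docs"],  # hypothetical document set name
        access_control_list=None,  # ACL filters omitted in this sketch
    )
    # Vespa scores every matching chunk with the new `random_` rank profile,
    # seeded per call, and returns the top `num_chunks` hits.
    return document_index.random_retrieval(filters=filters, num_to_retrieve=num_chunks)
```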
@@ -55,7 +55,7 @@ def remove_invalid_unicode_chars(text: str) -> str:
return _illegal_xml_chars_RE.sub("", text)


def get_vespa_http_client(no_timeout: bool = False) -> httpx.Client:
def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
"""
Configure and return an HTTP client for communicating with Vespa,
including authentication if needed.
@@ -67,5 +67,5 @@ def get_vespa_http_client(no_timeout: bool = False) -> httpx.Client:
else None,
verify=False if not MANAGED_VESPA else True,
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
http2=True,
http2=http2,
)

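A minimal standalone sketch of what the new `http2` flag controls, using plain httpx; the base URL and timeout value are placeholders, not Onyx defaults.

```python
import httpx


def make_vespa_client(http2: bool = True, no_timeout: bool = False) -> httpx.Client:
    # HTTP/2 multiplexing is generally preferable for bulk feeding, but the
    # single-document update and delete paths above now opt out with http2=False.
    # Note: http2=True requires the optional `h2` package to be installed.
    return httpx.Client(
        base_url="http://localhost:8080",  # placeholder Vespa endpoint
        http2=http2,
        timeout=None if no_timeout else 30.0,  # placeholder timeout
    )


# Example: a client for single-document PUT/DELETE calls over HTTP/1.1.
client = make_vespa_client(http2=False)
```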
@@ -19,7 +19,12 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) -> str:
def build_vespa_filters(
filters: IndexFilters,
*,
include_hidden: bool = False,
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
) -> str:
def _build_or_filters(key: str, vals: list[str] | None) -> str:
if vals is None:
return ""
@@ -78,6 +83,9 @@ def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->

filter_str += _build_time_filter(filters.time_cutoff)

if remove_trailing_and and filter_str.endswith(" and "):
filter_str = filter_str[:-5] # We remove the trailing " and "

return filter_str


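A toy illustration of why the new `remove_trailing_and` flag exists: the filter builder emits clauses joined with a trailing " and " so callers can append further conditions, while a complete standalone where-clause (as used by random retrieval) needs that dangling connective trimmed. The clause strings below are invented.

```python
def join_filter_clauses(clauses: list[str], remove_trailing_and: bool = False) -> str:
    # Mirror the builder's convention: every clause keeps a trailing " and "
    # so additional conditions can simply be concatenated afterwards.
    filter_str = "".join(f"({c}) and " for c in clauses)
    if remove_trailing_and and filter_str.endswith(" and "):
        filter_str = filter_str[:-5]
    return filter_str


# Used as a prefix for further conditions:
print(join_filter_clauses(['document_sets contains "docs"']))
# -> '(document_sets contains "docs") and '

# Used as a complete where-clause:
print(join_filter_clauses(['document_sets contains "docs"'], remove_trailing_and=True))
# -> '(document_sets contains "docs")'
```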
@@ -260,6 +260,21 @@ def index_doc_batch_prepare(
def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
# Remove any NUL characters from title/semantic_id
# This is a known issue with the Zendesk connector
# Postgres cannot handle NUL characters in text fields
if document.title:
document.title = document.title.replace("\x00", "")
if document.semantic_identifier:
document.semantic_identifier = document.semantic_identifier.replace(
"\x00", ""
)

# Remove NUL characters from all sections
for section in document.sections:
if section.text is not None:
section.text = section.text.replace("\x00", "")

empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())

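A self-contained sketch of the NUL-stripping behaviour added above, using a plain dataclass stand-in for the connector Document model (the real model lives in onyx.connectors.models; these toy types are assumptions for illustration).

```python
from dataclasses import dataclass, field


@dataclass
class Section:
    text: str | None = None


@dataclass
class Doc:
    title: str | None = None
    semantic_identifier: str | None = None
    sections: list[Section] = field(default_factory=list)


def strip_nul_chars(doc: Doc) -> Doc:
    # Postgres text columns reject NUL bytes, so they are dropped everywhere.
    if doc.title:
        doc.title = doc.title.replace("\x00", "")
    if doc.semantic_identifier:
        doc.semantic_identifier = doc.semantic_identifier.replace("\x00", "")
    for section in doc.sections:
        if section.text is not None:
            section.text = section.text.replace("\x00", "")
    return doc


doc = strip_nul_chars(Doc(title="Zendesk\x00Ticket", sections=[Section("body\x00text")]))
assert doc.title == "ZendeskTicket" and doc.sections[0].text == "bodytext"
```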
@@ -266,18 +266,27 @@ class DefaultMultiLLM(LLM):
# )
self._custom_config = custom_config

# Create a dictionary for model-specific arguments if it's None
model_kwargs = model_kwargs or {}

# NOTE: have to set these as environment variables for Litellm since
# not all are able to passed in but they always support them set as env
# variables. We'll also try passing them in, since litellm just ignores
# addtional kwargs (and some kwargs MUST be passed in rather than set as
# env variables)
if custom_config:
for k, v in custom_config.items():
os.environ[k] = v
# Specifically pass in "vertex_credentials" as a model_kwarg to the
# completion call for vertex AI. More details here:
# https://docs.litellm.ai/docs/providers/vertex
vertex_credentials_key = "vertex_credentials"
vertex_credentials = custom_config.get(vertex_credentials_key)
if vertex_credentials and model_provider == "vertex_ai":
model_kwargs[vertex_credentials_key] = vertex_credentials
else:
# standard case
for k, v in custom_config.items():
os.environ[k] = v

model_kwargs = model_kwargs or {}
if custom_config:
model_kwargs.update(custom_config)
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:
@@ -453,7 +462,9 @@ class DefaultMultiLLM(LLM):
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()

if DISABLE_LITELLM_STREAMING:
if (
DISABLE_LITELLM_STREAMING or self.config.model_name == "o1-2024-12-17"
): # TODO: remove once litellm supports streaming
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
return

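A distilled sketch of the credential-routing rule introduced above: for the `vertex_ai` provider, the `vertex_credentials` entry travels as a completion kwarg instead of being exported to the environment. The config dict and helper name here are illustrative, not part of the Onyx codebase.

```python
import os


def apply_custom_config(
    custom_config: dict[str, str] | None,
    model_provider: str,
    model_kwargs: dict,
) -> dict:
    if not custom_config:
        return model_kwargs

    vertex_credentials = custom_config.get("vertex_credentials")
    if vertex_credentials and model_provider == "vertex_ai":
        # Vertex AI credentials must accompany the completion call itself.
        model_kwargs["vertex_credentials"] = vertex_credentials
    else:
        # Standard case: expose everything to litellm via environment variables.
        for k, v in custom_config.items():
            os.environ[k] = v
    return model_kwargs


kwargs = apply_custom_config(
    {"vertex_credentials": '{"type": "service_account"}'},  # truncated, illustrative
    model_provider="vertex_ai",
    model_kwargs={},
)
```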
@@ -29,6 +29,7 @@ OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"o1-mini",
"o1-preview",
"o1-2024-12-17",
"gpt-4",
"gpt-4o",
"gpt-4o-mini",

@@ -28,6 +28,7 @@ from litellm.exceptions import RateLimitError # type: ignore
from litellm.exceptions import Timeout # type: ignore
from litellm.exceptions import UnprocessableEntityError # type: ignore

from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
from onyx.configs.constants import MessageType
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
@@ -45,10 +46,19 @@ logger = setup_logger()


def litellm_exception_to_error_msg(
e: Exception, llm: LLM, fallback_to_error_msg: bool = False
e: Exception,
llm: LLM,
fallback_to_error_msg: bool = False,
custom_error_msg_mappings: dict[str, str]
| None = LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS,
) -> str:
error_msg = str(e)

if custom_error_msg_mappings:
for error_msg_pattern, custom_error_msg in custom_error_msg_mappings.items():
if error_msg_pattern in error_msg:
return custom_error_msg

if isinstance(e, BadRequestError):
error_msg = "Bad request: The server couldn't process your request. Please check your input."
elif isinstance(e, AuthenticationError):

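A small sketch of the substring-mapping behaviour the new parameter enables; the mapping value below is made up. In the deployment files later in this diff, the same mapping is supplied through the LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS environment variable.

```python
def map_error_message(error: Exception, mappings: dict[str, str] | None) -> str | None:
    # First match wins: any configured pattern found in the raw error text
    # is replaced by an operator-supplied, user-friendly message.
    error_msg = str(error)
    if mappings:
        for pattern, friendly_msg in mappings.items():
            if pattern in error_msg:
                return friendly_msg
    return None


mappings = {"context_length_exceeded": "Your message is too long for this model."}
print(map_error_message(Exception("... context_length_exceeded ..."), mappings))
# -> Your message is too long for this model.
```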
@@ -243,6 +243,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_query_router)
include_router_with_global_prefix_prepended(application, admin_router)
include_router_with_global_prefix_prepended(application, connector_router)
include_router_with_global_prefix_prepended(application, user_router)
include_router_with_global_prefix_prepended(application, credential_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, folder_router)

backend/onyx/prompts/starter_messages.py (new file, 46 lines)
@@ -0,0 +1,46 @@
PERSONA_CATEGORY_GENERATION_PROMPT = """
Based on the assistant's name, description, and instructions, generate a list of {num_categories}
**unique and diverse** categories that represent different types of starter messages a user
might send to initiate a conversation with this chatbot assistant.

**Ensure that the categories are varied and cover a wide range of topics related to the assistant's capabilities.**

Provide the categories as a JSON array of strings **without any code fences or additional text**.

**Context about the assistant:**
- **Name**: {name}
- **Description**: {description}
- **Instructions**: {instructions}
""".strip()

PERSONA_STARTER_MESSAGE_CREATION_PROMPT = """
Create a starter message that a **user** might send to initiate a conversation with a chatbot assistant.

**Category**: {category}

Your response should include two parts:

1. **Title**: A short, engaging title that reflects the user's intent
(e.g., 'Need Travel Advice', 'Question About Coding', 'Looking for Book Recommendations').

2. **Message**: The actual message that the user would send to the assistant.
This should be natural, engaging, and encourage a helpful response from the assistant.
**Avoid overly specific details; keep the message general and broadly applicable.**

For example:
- Instead of "I've just adopted a 6-month-old Labrador puppy who's pulling on the leash,"
write "I'm having trouble training my new puppy to walk nicely on a leash."

Ensure each part is clearly labeled and separated as shown above.
Do not provide any additional text or explanation and be extremely concise

**Context about the assistant:**
- **Name**: {name}
- **Description**: {description}
- **Instructions**: {instructions}
""".strip()


if __name__ == "__main__":
print(PERSONA_CATEGORY_GENERATION_PROMPT)
print(PERSONA_STARTER_MESSAGE_CREATION_PROMPT)
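To make the template concrete, a hedged example of how the category prompt might be rendered; the assistant details are invented.

```python
from onyx.prompts.starter_messages import PERSONA_CATEGORY_GENERATION_PROMPT

prompt = PERSONA_CATEGORY_GENERATION_PROMPT.format(
    num_categories=4,  # NUM_PERSONA_PROMPTS in the real flow
    name="Support Helper",  # illustrative assistant
    description="Answers questions about our internal support docs.",
    instructions="Be concise and cite the relevant runbook when possible.",
)
print(prompt)  # the LLM is expected to reply with a bare JSON array of strings
```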
@@ -31,6 +31,10 @@ class RedisConnectorIndex:

TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate

# used to signal the overall workflow is still active
# it's difficult to prevent
ACTIVE_PREFIX = PREFIX + "_active"

def __init__(
self,
tenant_id: str | None,
@@ -54,6 +58,7 @@ class RedisConnectorIndex:
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"

@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -107,6 +112,26 @@ class RedisConnectorIndex:
# 10 minute TTL is good.
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)

def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
the expiration time.

The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=3600)

def active(self) -> bool:
if self.redis.exists(self.active_key):
return True

return False

def generator_locked(self) -> bool:
if self.redis.exists(self.generator_lock_key):
return True

return False

def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)
@@ -138,6 +163,7 @@ class RedisConnectorIndex:
return status

def reset(self) -> None:
self.redis.delete(self.active_key)
self.redis.delete(self.generator_lock_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
@@ -146,6 +172,9 @@ class RedisConnectorIndex:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorIndex.GENERATOR_LOCK_PREFIX + "*"):
r.delete(key)

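A minimal standalone sketch of the "active" heartbeat pattern these hunks add: a short-lived key is refreshed while work is in flight, and cleanup logic treats a missing key as permission to tear the run down. The key name, ids, and connection settings below are placeholders.

```python
import redis

r = redis.Redis(host="localhost", port=6379)  # placeholder connection
ACTIVE_KEY = "connectorindexing_active_42/7"  # hypothetical cc_pair/search_settings ids


def set_active() -> None:
    # Refreshed periodically by the running flow; the key expires on its own
    # if the flow dies, so stale runs are eventually cleaned up.
    r.set(ACTIVE_KEY, 0, ex=3600)


def is_active() -> bool:
    return bool(r.exists(ACTIVE_KEY))


set_active()
if not is_active():
    print("no active signal - safe to clean up the indexing attempt")
```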
backend/onyx/secondary_llm_flows/starter_message_creation.py (new file, 271 lines)
@@ -0,0 +1,271 @@
import json
import re
from typing import Any
from typing import cast
from typing import Dict
from typing import List

from litellm import get_supported_openai_params
from sqlalchemy.orm import Session

from onyx.configs.chat_configs import NUM_PERSONA_PROMPT_GENERATION_CHUNKS
from onyx.configs.chat_configs import NUM_PERSONA_PROMPTS
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.postprocessing.postprocessing import cleanup_chunks
from onyx.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from onyx.db.document_set import get_document_sets_by_ids
from onyx.db.models import StarterMessageModel as StarterMessage
from onyx.db.models import User
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.llm.factory import get_default_llms
from onyx.prompts.starter_messages import PERSONA_CATEGORY_GENERATION_PROMPT
from onyx.prompts.starter_messages import PERSONA_STARTER_MESSAGE_CREATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel

logger = setup_logger()


def get_random_chunks_from_doc_sets(
doc_sets: List[str], db_session: Session, user: User | None = None
) -> List[InferenceChunk]:
"""
Retrieves random chunks from the specified document sets.
"""
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(curr_ind_name, sec_ind_name)

acl_filters = build_access_filters_for_user(user, db_session)
filters = IndexFilters(document_set=doc_sets, access_control_list=acl_filters)

chunks = document_index.random_retrieval(
filters=filters, num_to_retrieve=NUM_PERSONA_PROMPT_GENERATION_CHUNKS
)
return cleanup_chunks(chunks)


def parse_categories(content: str) -> List[str]:
"""
Parses the JSON array of categories from the LLM response.
"""
# Clean the response to remove code fences and extra whitespace
content = content.strip().strip("```").strip()
if content.startswith("json"):
content = content[4:].strip()

try:
categories = json.loads(content)
if not isinstance(categories, list):
logger.error("Categories are not a list.")
return []
return categories
except json.JSONDecodeError as e:
logger.error(f"Failed to parse categories: {e}")
return []


def generate_start_message_prompts(
name: str,
description: str,
instructions: str,
categories: List[str],
chunk_contents: str,
supports_structured_output: bool,
fast_llm: Any,
) -> List[FunctionCall]:
"""
Generates the list of FunctionCall objects for starter message generation.
"""
functions = []
for category in categories:
# Create a prompt specific to the category
start_message_generation_prompt = (
PERSONA_STARTER_MESSAGE_CREATION_PROMPT.format(
name=name,
description=description,
instructions=instructions,
category=category,
)
)

if chunk_contents:
start_message_generation_prompt += (
"\n\nExample content this assistant has access to:\n"
"'''\n"
f"{chunk_contents}"
"\n'''"
)

if supports_structured_output:
functions.append(
FunctionCall(
fast_llm.invoke,
(start_message_generation_prompt, None, None, StarterMessage),
)
)
else:
functions.append(
FunctionCall(
fast_llm.invoke,
(start_message_generation_prompt,),
)
)
return functions


def parse_unstructured_output(output: str) -> Dict[str, str]:
"""
Parses the assistant's unstructured output into a dictionary with keys:
- 'name' (Title)
- 'message' (Message)
"""

# Debug output
logger.debug(f"LLM Output for starter message creation: {output}")

# Patterns to match
title_pattern = r"(?i)^\**Title\**\s*:\s*(.+)"
message_pattern = r"(?i)^\**Message\**\s*:\s*(.+)"

# Initialize the response dictionary
response_dict = {}

# Split the output into lines
lines = output.strip().split("\n")

# Variables to keep track of the current key being processed
current_key = None
current_value_lines = []

for line in lines:
# Check for title
title_match = re.match(title_pattern, line.strip())
if title_match:
# Save previous key-value pair if any
if current_key and current_value_lines:
response_dict[current_key] = " ".join(current_value_lines).strip()
current_value_lines = []
current_key = "name"
current_value_lines.append(title_match.group(1).strip())
continue

# Check for message
message_match = re.match(message_pattern, line.strip())
if message_match:
if current_key and current_value_lines:
response_dict[current_key] = " ".join(current_value_lines).strip()
current_value_lines = []
current_key = "message"
current_value_lines.append(message_match.group(1).strip())
continue

# If the line doesn't match a new key, append it to the current value
if current_key:
current_value_lines.append(line.strip())

# Add the last key-value pair
if current_key and current_value_lines:
response_dict[current_key] = " ".join(current_value_lines).strip()

# Validate that the necessary keys are present
if not all(k in response_dict for k in ["name", "message"]):
raise ValueError("Failed to parse the assistant's response.")

return response_dict


def generate_starter_messages(
name: str,
description: str,
instructions: str,
document_set_ids: List[int],
db_session: Session,
user: User | None,
) -> List[StarterMessage]:
"""
Generates starter messages by first obtaining categories and then generating messages for each category.
On failure, returns an empty list (or list with processed starter messages if some messages are processed successfully).
"""
_, fast_llm = get_default_llms(temperature=0.5)

provider = fast_llm.config.model_provider
model = fast_llm.config.model_name

params = get_supported_openai_params(model=model, custom_llm_provider=provider)
supports_structured_output = (
isinstance(params, list) and "response_format" in params
)

# Generate categories
category_generation_prompt = PERSONA_CATEGORY_GENERATION_PROMPT.format(
name=name,
description=description,
instructions=instructions,
num_categories=NUM_PERSONA_PROMPTS,
)

category_response = fast_llm.invoke(category_generation_prompt)
categories = parse_categories(cast(str, category_response.content))

if not categories:
logger.error("No categories were generated.")
return []

# Fetch example content if document sets are provided
if document_set_ids:
document_sets = get_document_sets_by_ids(
document_set_ids=document_set_ids,
db_session=db_session,
)

chunks = get_random_chunks_from_doc_sets(
doc_sets=[doc_set.name for doc_set in document_sets],
db_session=db_session,
user=user,
)

# Add example content context
chunk_contents = "\n".join(chunk.content.strip() for chunk in chunks)
else:
chunk_contents = ""

# Generate prompts for starter messages
functions = generate_start_message_prompts(
name,
description,
instructions,
categories,
chunk_contents,
supports_structured_output,
fast_llm,
)

# Run LLM calls in parallel
if not functions:
logger.error("No functions to execute for starter message generation.")
return []

results = run_functions_in_parallel(function_calls=functions)
prompts = []

for response in results.values():
try:
if supports_structured_output:
response_dict = json.loads(response.content)
else:
response_dict = parse_unstructured_output(response.content)
starter_message = StarterMessage(
name=response_dict["name"],
message=response_dict["message"],
)
prompts.append(starter_message)
except (json.JSONDecodeError, ValueError) as e:
logger.error(f"Failed to parse starter message: {e}")
continue

return prompts
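For clarity, a hedged example of the fallback parsing path used for models without structured-output support; the LLM reply shown is fabricated.

```python
from onyx.secondary_llm_flows.starter_message_creation import parse_unstructured_output

raw_llm_reply = (
    "**Title**: Need Travel Advice\n"
    "**Message**: I'm planning a week-long trip and would love some help "
    "narrowing down destinations."
)
parsed = parse_unstructured_output(raw_llm_reply)
# {'name': 'Need Travel Advice', 'message': "I'm planning a week-long trip ..."}
print(parsed["name"], "->", parsed["message"])
```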
@@ -9,6 +9,7 @@ from onyx.access.models import default_public_access
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import KV_DOCUMENTS_SEEDED_KEY
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
@@ -71,7 +72,7 @@ def _create_indexable_chunks(
source_links={0: preprocessed_doc["url"]},
section_continuation=False,
source_document=document,
title_prefix=preprocessed_doc["title"],
title_prefix=preprocessed_doc["title"] + RETURN_SEPARATOR,
metadata_suffix_semantic="",
metadata_suffix_keyword="",
mini_chunk_texts=None,
@@ -216,7 +217,7 @@ def seed_initial_documents(
# as we just sent over the Vespa schema and there is a slight delay

index_with_retries = retry_builder()(document_index.index)
index_with_retries(chunks=chunks, fresh_index=True)
index_with_retries(chunks=chunks, fresh_index=cohere_enabled)

# Mock a run for the UI even though it did not actually call out to anything
mock_successful_index_attempt(

@@ -48,6 +48,7 @@ def load_personas_from_yaml(
data = yaml.safe_load(file)

all_personas = data.get("personas", [])

for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
@@ -127,6 +128,7 @@ def load_personas_from_yaml(
display_priority=(
existing_persona.display_priority
if existing_persona is not None
and persona.get("display_priority") is None
else persona.get("display_priority")
),
is_visible=(

@@ -7,7 +7,7 @@ personas:
- id: 0
name: "Search"
description: >
Assistant with access to documents from your Connected Sources.
Assistant with access to documents and knowledge from Connected Sources.
# Default Prompt objects attached to the persona, see prompts.yaml
prompts:
- "Answer-Question"
@@ -39,7 +39,7 @@ personas:
document_sets: []
icon_shape: 23013
icon_color: "#6FB1FF"
display_priority: 1
display_priority: 0
is_visible: true
starter_messages:
- name: "Give me an overview of what's here"
@@ -54,7 +54,7 @@ personas:
- id: 1
name: "General"
description: >
Assistant with no access to documents. Chat with just the Large Language Model.
Assistant with no search functionalities. Chat directly with the Large Language Model.
prompts:
- "OnlyLLM"
num_chunks: 0
@@ -64,7 +64,7 @@ personas:
document_sets: []
icon_shape: 50910
icon_color: "#FF6F6F"
display_priority: 0
display_priority: 1
is_visible: true
starter_messages:
- name: "Summarize a document"

@@ -510,7 +510,7 @@ def associate_credential_to_connector(
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
@@ -532,7 +532,8 @@ def associate_credential_to_connector(
)

return response
except IntegrityError:
except IntegrityError as e:
logger.error(f"IntegrityError: {e}")
raise HTTPException(status_code=400, detail="Name must be unique")


@@ -21,6 +21,7 @@ from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.google_utils.google_auth import (
@@ -110,6 +111,7 @@ from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.documents.models import RunConnectorRequest
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop

logger = setup_logger()
@@ -639,6 +641,15 @@ def get_connector_indexing_status(
)
)

# Visiting admin page brings the user to the current connectors page which calls this endpoint
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
event_type=MilestoneRecordType.VISITED_ADMIN_PAGE,
properties=None,
db_session=db_session,
)

return indexing_statuses


@@ -663,12 +674,13 @@ def create_connector_from_model(
connector_data: ConnectorUpdateRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> ObjectCreationIdResponse:
try:
_validate_connector_allowed(connector_data.source)

fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
@@ -677,10 +689,20 @@ def create_connector_from_model(
object_is_perm_sync=connector_data.access_type == AccessType.SYNC,
)
connector_base = connector_data.to_connector_base()
return create_connector(
connector_response = create_connector(
db_session=db_session,
connector_data=connector_base,
)

create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_CONNECTOR,
properties=None,
db_session=db_session,
)

return connector_response
except ValueError as e:
logger.error(f"Error creating connector: {e}")
raise HTTPException(status_code=400, detail=str(e))
@@ -691,9 +713,10 @@ def create_connector_with_mock_credential(
connector_data: ConnectorUpdateRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
@@ -728,6 +751,15 @@ def create_connector_with_mock_credential(
cc_pair_name=connector_data.name,
groups=connector_data.groups,
)

create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_CONNECTOR,
properties=None,
db_session=db_session,
)

return response

except ValueError as e:
@@ -744,7 +776,7 @@ def update_connector_from_model(
try:
_validate_connector_allowed(connector_data.source)
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,

@@ -122,7 +122,7 @@ def create_credential_from_model(
) -> ObjectCreationIdResponse:
if not _ignore_credential_permissions(credential_info.source):
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
@@ -164,7 +164,12 @@ def get_credential_by_id(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> CredentialSnapshot | StatusResponse[int]:
credential = fetch_credential_by_id(credential_id, user, db_session)
credential = fetch_credential_by_id(
credential_id,
user,
db_session,
get_editable=False,
)
if credential is None:
raise HTTPException(
status_code=401,

@@ -31,7 +31,7 @@ def create_document_set(
db_session: Session = Depends(get_session),
) -> int:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
@@ -56,7 +56,7 @@ def patch_document_set(
db_session: Session = Depends(get_session),
) -> None:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,

@@ -15,8 +15,11 @@ from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.chat.prompt_builder.utils import build_dummy_prompt
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NotificationType
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import StarterMessageModel as StarterMessage
from onyx.db.models import User
from onyx.db.notification import create_notification
from onyx.db.persona import create_assistant_category
@@ -34,7 +37,11 @@ from onyx.db.persona import update_persona_shared_users
from onyx.db.persona import update_persona_visibility
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.secondary_llm_flows.starter_message_creation import (
generate_starter_messages,
)
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.features.persona.models import GenerateStarterMessageRequest
from onyx.server.features.persona.models import ImageGenerationToolStatus
from onyx.server.features.persona.models import PersonaCategoryCreate
from onyx.server.features.persona.models import PersonaCategoryResponse
@@ -44,6 +51,7 @@ from onyx.server.features.persona.models import PromptTemplateResponse
from onyx.server.models import DisplayPriorityRequest
from onyx.tools.utils import is_image_generation_available
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report


logger = setup_logger()
@@ -167,14 +175,25 @@ def create_persona(
create_persona_request: CreatePersonaRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> PersonaSnapshot:
return create_update_persona(
persona_snapshot = create_update_persona(
persona_id=None,
create_persona_request=create_persona_request,
user=user,
db_session=db_session,
)

create_milestone_and_report(
user=user,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_ASSISTANT,
properties=None,
db_session=db_session,
)

return persona_snapshot


# NOTE: This endpoint cannot update persona configuration options that
# are core to the persona, such as its display priority and
@@ -363,3 +382,26 @@ def build_final_template_prompt(
retrieval_disabled=retrieval_disabled,
)
)


@basic_router.post("/assistant-prompt-refresh")
def build_assistant_prompts(
generate_persona_prompt_request: GenerateStarterMessageRequest,
db_session: Session = Depends(get_session),
user: User | None = Depends(current_user),
) -> list[StarterMessage]:
try:
logger.info(
"Generating starter messages for user: %s", user.id if user else "Anonymous"
)
return generate_starter_messages(
name=generate_persona_prompt_request.name,
description=generate_persona_prompt_request.description,
instructions=generate_persona_prompt_request.instructions,
document_set_ids=generate_persona_prompt_request.document_set_ids,
db_session=db_session,
user=user,
)
except Exception as e:
logger.exception("Failed to generate starter messages")
raise HTTPException(status_code=500, detail=str(e))

@@ -17,6 +17,14 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


# More minimal request for generating a persona prompt
class GenerateStarterMessageRequest(BaseModel):
name: str
description: str
instructions: str
document_set_ids: list[int]


class CreatePersonaRequest(BaseModel):
name: str
description: str

@@ -57,7 +57,6 @@ def test_llm_configuration(
)

functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))]

if (
test_llm_request.fast_default_model_name
and test_llm_request.fast_default_model_name

@@ -4,7 +4,9 @@ from fastapi import HTTPException
from sqlalchemy.orm import Session

from onyx.auth.users import current_admin_user
from onyx.configs.constants import MilestoneRecordType
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import ChannelConfig
from onyx.db.models import User
@@ -25,6 +27,7 @@ from onyx.server.manage.models import SlackBot
from onyx.server.manage.models import SlackBotCreationRequest
from onyx.server.manage.models import SlackChannelConfig
from onyx.server.manage.models import SlackChannelConfigCreationRequest
from onyx.utils.telemetry import create_milestone_and_report


router = APIRouter(prefix="/manage")
@@ -217,6 +220,7 @@ def create_bot(
slack_bot_creation_request: SlackBotCreationRequest,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> SlackBot:
slack_bot_model = insert_slack_bot(
db_session=db_session,
@@ -225,6 +229,15 @@ def create_bot(
bot_token=slack_bot_creation_request.bot_token,
app_token=slack_bot_creation_request.app_token,
)

create_milestone_and_report(
user=None,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_ONYX_BOT,
properties=None,
db_session=db_session,
)

return SlackBot.from_model(slack_bot_model)


@@ -21,6 +21,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from ee.onyx.configs.app_configs import SUPER_USERS
from onyx.auth.email_utils import send_user_email_invite
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.noauth_user import fetch_no_auth_user
@@ -41,11 +42,8 @@ from onyx.db.auth import get_total_users_count
from onyx.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from onyx.db.engine import get_session
from onyx.db.models import AccessToken
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.db.users import list_users
from onyx.db.users import validate_user_role_update
@@ -61,7 +59,6 @@ from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from onyx.server.models import MinimalUserSnapshot
from onyx.server.utils import BasicAuthenticationError
from onyx.server.utils import send_user_email_invite
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.configs import MULTI_TENANT
@@ -370,45 +367,10 @@ async def delete_user(
db_session.expunge(user_to_delete)

try:
for oauth_account in user_to_delete.oauth_accounts:
db_session.delete(oauth_account)

fetch_ee_implementation_or_noop(
"onyx.db.external_perm",
"delete_user__ext_group_for_user__no_commit",
)(
db_session=db_session,
user_id=user_to_delete.id,
)
db_session.query(SamlAccount).filter(
SamlAccount.user_id == user_to_delete.id
).delete()
db_session.query(DocumentSet__User).filter(
DocumentSet__User.user_id == user_to_delete.id
).delete()
db_session.query(Persona__User).filter(
Persona__User.user_id == user_to_delete.id
).delete()
db_session.query(User__UserGroup).filter(
User__UserGroup.user_id == user_to_delete.id
).delete()
db_session.delete(user_to_delete)
db_session.commit()

# NOTE: edge case may exist with race conditions
# with this `invited user` scheme generally.
user_emails = get_invited_users()
remaining_users = [
user for user in user_emails if user != user_email.user_email
]
write_invited_users(remaining_users)

delete_user_from_db(user_to_delete, db_session)
logger.info(f"Deleted user {user_to_delete.email}")
except Exception as e:
import traceback

full_traceback = traceback.format_exc()
logger.error(f"Full stack trace:\n{full_traceback}")
except Exception as e:
db_session.rollback()
logger.error(f"Error deleting user {user_to_delete.email}: {str(e)}")
raise HTTPException(status_code=500, detail="Error deleting user")

@@ -4,6 +4,7 @@ from fastapi import HTTPException
from sqlalchemy.orm import Session

from onyx.auth.users import api_key_dep
from onyx.configs.constants import DEFAULT_CC_PAIR_ID
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
@@ -79,7 +80,7 @@ def upsert_ingestion_doc(
document.source = DocumentSource.FILE

cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=doc_info.cc_pair_id or 0, db_session=db_session
cc_pair_id=doc_info.cc_pair_id or DEFAULT_CC_PAIR_ID, db_session=db_session
)
if cc_pair is None:
raise HTTPException(

@@ -30,10 +30,12 @@ from onyx.chat.prompt_builder.citations_prompt import (
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from onyx.db.chat import add_chats_to_session_from_slack_thread
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import delete_all_chat_sessions_for_user
from onyx.db.chat import delete_chat_session
from onyx.db.chat import duplicate_chat_session_for_user_from_slack
from onyx.db.chat import get_chat_message
@@ -44,7 +46,9 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import create_doc_retrieval_feedback
from onyx.db.models import User
@@ -81,6 +85,7 @@ from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.token_limit import check_token_rate_limits
from onyx.utils.headers import get_custom_tool_additional_request_headers
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report


logger = setup_logger()
@@ -177,12 +182,15 @@ def get_chat_session(
description=chat_session.description,
persona_id=chat_session.persona_id,
persona_name=chat_session.persona.name if chat_session.persona else None,
persona_icon_color=chat_session.persona.icon_color
if chat_session.persona
else None,
persona_icon_shape=chat_session.persona.icon_shape
if chat_session.persona
else None,
current_alternate_model=chat_session.current_alternate_model,
messages=[
translate_db_message_to_chat_message_detail(
msg, remove_doc_content=is_shared # if shared, don't leak doc content
)
for msg in session_messages
translate_db_message_to_chat_message_detail(msg) for msg in session_messages
],
time_created=chat_session.time_created,
shared_status=chat_session.shared_status,
@@ -192,7 +200,7 @@ def get_chat_session(
@router.post("/create-chat-session")
def create_new_chat_session(
chat_session_creation_request: ChatSessionCreationRequest,
user: User | None = Depends(current_user),
user: User | None = Depends(current_limited_user),
db_session: Session = Depends(get_session),
) -> CreateChatSessionID:
user_id = user.id if user is not None else None
@@ -276,6 +284,17 @@ def patch_chat_session(
return None


@router.delete("/delete-all-chat-sessions")
def delete_all_chat_sessions(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
delete_all_chat_sessions_for_user(user=user, db_session=db_session)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))


@router.delete("/delete-chat-session/{session_id}")
def delete_chat_session_by_id(
session_id: UUID,
@@ -315,8 +334,9 @@ def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
user: User | None = Depends(current_limited_user),
_: None = Depends(check_token_rate_limits),
_rate_limit_check: None = Depends(check_token_rate_limits),
is_connected_func: Callable[[], bool] = Depends(is_connected),
tenant_id: str = Depends(get_current_tenant_id),
) -> StreamingResponse:
"""
This endpoint is both used for all the following purposes:
@@ -347,6 +367,15 @@ def handle_new_chat_message(
):
raise HTTPException(status_code=400, detail="Empty chat message is invalid")

with get_session_with_tenant(tenant_id) as db_session:
create_milestone_and_report(
user=user,
distinct_id=user.email if user else tenant_id or "N/A",
event_type=MilestoneRecordType.RAN_QUERY,
properties=None,
db_session=db_session,
)

def stream_generator() -> Generator[str, None, None]:
try:
for packet in stream_chat_message(

@@ -11,6 +11,7 @@ from onyx.chat.models import RetrievalDocs
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.constants import SessionType
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import ChunkContext
from onyx.context.search.models import RerankingDetails
@@ -151,6 +152,10 @@ class ChatSessionUpdateRequest(BaseModel):
sharing_status: ChatSessionSharedStatus


class DeleteAllSessionsRequest(BaseModel):
session_type: SessionType


class RenameChatSessionResponse(BaseModel):
new_name: str # This is only really useful if the name is generated

@@ -220,6 +225,8 @@ class ChatSessionDetailResponse(BaseModel):
description: str | None
persona_id: int | None = None
persona_name: str | None
persona_icon_color: str | None
persona_icon_shape: int | None
messages: list[ChatMessageDetail]
time_created: datetime
shared_status: ChatSessionSharedStatus

@@ -1,21 +1,10 @@
import json
import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from textwrap import dedent
from typing import Any

from fastapi import HTTPException
from fastapi import status

from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.models import User


class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
@@ -62,31 +51,3 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:

masked_creds[key] = mask_string(val)
return masked_creds


def send_user_email_invite(user_email: str, current_user: User) -> None:
msg = MIMEMultipart()
msg["Subject"] = "Invitation to Join Onyx Workspace"
msg["From"] = current_user.email
msg["To"] = user_email

email_body = dedent(
f"""\
Hello,

You have been invited to join a workspace on Onyx.

To join the workspace, please visit the following link:

{WEB_DOMAIN}/auth/login

Best regards,
The Onyx Team
"""
)

msg.attach(MIMEText(email_body, "plain"))
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp_server:
smtp_server.starttls()
smtp_server.login(SMTP_USER, SMTP_PASS)
smtp_server.send_message(msg)

@@ -10,10 +10,17 @@ from onyx.configs.app_configs import DISABLE_TELEMETRY
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.constants import KV_CUSTOMER_UUID_KEY
from onyx.configs.constants import KV_INSTANCE_DOMAIN_KEY
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.models import User
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.configs import MULTI_TENANT

_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.onyx.app/anonymous_telemetry"
_CACHED_UUID: str | None = None
@@ -103,3 +110,37 @@ def optional_telemetry(
except Exception:
# Should never interfere with normal functions of Onyx
pass


def mt_cloud_telemetry(
distinct_id: str,
event: MilestoneRecordType,
properties: dict | None = None,
) -> None:
if not MULTI_TENANT:
return

# MIT version should not need to include any Posthog code
# This is only for Onyx MT Cloud, this code should also never be hit, no reason for any orgs to
# be running the Multi Tenant version of Onyx.
fetch_versioned_implementation_with_fallback(
module="onyx.utils.telemetry",
attribute="event_telemetry",
fallback=noop_fallback,
)(distinct_id, event, properties)


def create_milestone_and_report(
user: User | None,
distinct_id: str,
event_type: MilestoneRecordType,
properties: dict | None,
db_session: Session,
) -> None:
_, is_new = create_milestone_if_not_exists(user, event_type, db_session)
if is_new:
mt_cloud_telemetry(
distinct_id=distinct_id,
event=event_type,
properties=properties,
)

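A compact sketch of the reporting flow these helpers implement: a milestone is recorded once per event type, and the PostHog-side telemetry call fires only the first time (and only in the multi-tenant cloud build). The in-memory set below stands in for the real database helper; it is not the Onyx implementation.

```python
_seen_milestones: set[str] = set()


def milestone_is_new(event_type: str) -> bool:
    # Stand-in for the DB-backed create_milestone_if_not_exists: True only on first sighting.
    if event_type in _seen_milestones:
        return False
    _seen_milestones.add(event_type)
    return True


def report_milestone(distinct_id: str, event_type: str) -> None:
    if milestone_is_new(event_type):
        # In the real code this forwards to the cloud-only telemetry shim,
        # which is a no-op for self-hosted (non multi-tenant) deployments.
        print(f"reporting {event_type} for {distinct_id}")


report_milestone("tenant-123", "CREATED_CONNECTOR")  # reports
report_milestone("tenant-123", "CREATED_CONNECTOR")  # deduplicated, silent
```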
@@ -29,7 +29,7 @@ trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.54.1
litellm==1.55.4
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45

@@ -9,6 +9,7 @@ mypy-extensions==1.0.0
mypy==1.8.0
pandas-stubs==2.2.3.241009
pandas==2.2.3
posthog==3.7.4
pre-commit==3.2.2
pytest-asyncio==0.22.0
pytest==7.4.4

@@ -1,2 +1,3 @@
cohere==5.6.1
posthog==3.7.4
python3-saml==1.15.0
cohere==5.6.1
@@ -12,5 +12,5 @@ torch==2.2.0
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.54.1
litellm==1.55.4
sentry-sdk[fastapi,celery,starlette]==2.14.0
@@ -48,4 +48,7 @@ sleep 1
echo "Running Alembic migration..."
alembic upgrade head

# Run the following instead of the above if using MT cloud
# alembic -n schema_private upgrade head

echo "Containers restarted and migration completed."

@@ -14,7 +14,7 @@ from tests.integration.common_utils.test_models import DATestUser


DOMAIN = "test.com"
DEFAULT_PASSWORD = "test"
DEFAULT_PASSWORD = "TestPassword123!"


def build_email(name: str) -> str:

@@ -219,6 +219,7 @@ def test_slack_permission_sync(
assert private_message not in onyx_doc_message_strings


@pytest.mark.xfail(reason="flaky", strict=False)
def test_slack_group_permission_sync(
reset: None,
vespa_client: vespa_fixture,

@@ -376,6 +376,26 @@ def process_text(
"The code demonstrates variable assignment.",
[],
),
(
"Long JSON string in code block",
[
"```json\n{",
'"name": "John Doe",',
'"age": 30,',
'"city": "New York",',
'"hobbies": ["reading", "swimming", "cycling"],',
'"education": {',
' "degree": "Bachelor\'s",',
' "major": "Computer Science",',
' "university": "Example University"',
"}",
"}\n```",
],
'```json\n{"name": "John Doe","age": 30,"city": "New York","hobbies": '
'["reading", "swimming", "cycling"],"education": { '
'"degree": "Bachelor\'s", "major": "Computer Science", "university": "Example University"}}\n```',
[],
),
(
"Citation as a single token",
[

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-beat
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-heavy
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-indexing
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-light
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-primary
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.20
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
imagePullPolicy: IfNotPresent
command:
[

@@ -92,6 +92,7 @@ services:
- LOG_POSTGRES_LATENCY=${LOG_POSTGRES_LATENCY:-}
- LOG_POSTGRES_CONN_COUNTS=${LOG_POSTGRES_CONN_COUNTS:-}
- CELERY_BROKER_POOL_LIMIT=${CELERY_BROKER_POOL_LIMIT:-}
- LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS=${LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS:-}

# Analytics Configs
- SENTRY_DSN=${SENTRY_DSN:-}
@@ -103,6 +104,13 @@ services:
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
# Seeding configuration
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@@ -223,6 +231,13 @@ services:

# Enterprise Edition stuff
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
- AWS_REGION=${AWS_REGION-}
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
# volumes:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@@ -252,7 +267,7 @@ services:
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}

- NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED:-}
# Enterprise Edition only
- NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
# DO NOT TURN ON unless you have EXPLICIT PERMISSION from Onyx.

Some files were not shown because too many files have changed in this diff.