Compare commits

..

38 Commits

Author SHA1 Message Date
pablonyx
25d9266da4 update 2025-02-26 08:48:35 -08:00
Weves
23073d91b9 reduce number of chars to index for search 2025-02-25 19:27:50 -08:00
Chris Weaver
f767b1f476 Fix confluence permission syncing at scale (#4129)
* Fix confluence permission syncing at scale

* Remove line

* Better log message

* Adjust log
2025-02-25 19:22:52 -08:00
pablonyx
9ffc8cb2c4 k 2025-02-25 18:15:49 -08:00
pablonyx
98bfb58147 Handle bad slack configurations– multi tenant (#4118)
* k

* quick nit

* k

* k
2025-02-25 22:22:54 +00:00
evan-danswer
6ce810e957 faster indexing status at scale plus minor cleanups (#4081)
* faster indexing status at scale plus minor cleanups

* mypy

* address chris comments

* remove extra prints
2025-02-25 21:22:26 +00:00
pablonyx
07b0b57b31 (nit) bump timeout 2025-02-25 14:10:30 -08:00
pablonyx
118cdd7701 Chat search (#4113)
* add chat search

* don't add the bible

* base functional

* k

* k

* functioning

* functioning well

* functioning well

* k

* delete bible

* quick cleanup

* quick cleanup

* k

* fixed frontend hooks

* delete bible

* nit

* nit

* nit

* fix build

* k

* improved debouncing

* address comments

* fix alembic

* k
2025-02-25 20:49:46 +00:00
rkuo-danswer
ac83b4c365 validate connector deletion (#4108)
* validate connector deletion

* fixes

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-25 20:35:21 +00:00
pablonyx
fa408ff447 add 3.7 (#4116) 2025-02-25 12:41:40 -08:00
rkuo-danswer
4aa8eb8b75 fix scrolling test (#4117)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-02-25 10:23:04 -08:00
rkuo-danswer
60bd9271f7 Bugfix/model tests (#4092)
* trying out a fix

* add ability to manually run model tests

* add log dump

* check status code, not text?

* just the model server

* add port mapping to host

* pass through more api keys

* add azure tests

* fix litellm env vars

* fix env vars in github workflow

* temp disable litellm test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-25 04:53:51 +00:00
Weves
5d58a5e3ea Add ability to index all of Github 2025-02-24 18:56:36 -08:00
Chris Weaver
a99dd05533 Add option to index all Jira projects (#4106)
* Add option to index all Jira projects

* Fix test

* Fix web build

* Address comment
2025-02-25 02:07:00 +00:00
pablonyx
0dce67094e Prettier formatting for bedrock (#4111)
* k

* k
2025-02-25 02:05:29 +00:00
pablonyx
ffd14435a4 Text overflow logic (#4051)
* proper components

* k

* k

* k
2025-02-25 01:05:22 +00:00
rkuo-danswer
c9a3b45ad4 more aggressive handling of tasks blocking deletion (#4093)
* more aggressive handling of tasks blocking deletion

* comment updated

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-24 22:41:13 +00:00
pablonyx
7d40676398 Heavy task improvements, logging, and validation (#4058) 2025-02-24 13:48:53 -08:00
rkuo-danswer
b9e79e5db3 tighten up logs (#4076)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-24 19:23:00 +00:00
rkuo-danswer
558bbe16e4 Bugfix/termination cleanup (#4077)
* move activity timeout cleanup to the function exit

* fix excessive logging

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-24 19:21:55 +00:00
evan-danswer
076619ce2c make Settings model match db (#4087) 2025-02-24 19:04:36 +00:00
pablonyx
1263e21eb5 k (#4102) 2025-02-24 17:44:18 +00:00
pablonyx
f0c13b6558 fix starter message editing (#4101) 2025-02-24 01:01:01 +00:00
evan-danswer
a7125662f1 Fix gpt o-series code block formatting (#4089)
* prompt addition for gpt o-series to encourage markdown formatting of code blocks

* fix to match https://simonwillison.net/tags/markdown/

* chris comment

* chris comment
2025-02-24 00:59:48 +00:00
evan-danswer
4a4e4a6c50 thread utils respect contextvars (#4074)
* thread utils respect contextvars now

* address pablo comments

* removed tenant id from places it was already being passed

* fix rate limit check and pablo comment
2025-02-24 00:43:21 +00:00
pablonyx
1f2af373e1 improve scroll (#4096) 2025-02-23 19:20:07 +00:00
Weves
bdaa293ae4 Fix nginx for prod compose file 2025-02-21 16:57:54 -08:00
pablonyx
5a131f4547 Fix integration tests (#4059) 2025-02-21 15:56:11 -08:00
rkuo-danswer
ffb7d5b85b enable manual testing for model server (#4003)
* trying out a fix

* add ability to manually run model tests

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-21 14:00:32 -08:00
rkuo-danswer
fe8a5d671a don't spam the logs with texts on auth errors (#4085)
* don't spam the logs with texts on auth errors

* refactor the logging a bit

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-21 13:40:07 -08:00
Yuhong Sun
6de53ebf60 README Touchup (#4088) 2025-02-21 13:31:07 -08:00
rkuo-danswer
61d536c782 tool fixes (#4075) 2025-02-21 12:30:33 -08:00
Chris Weaver
e1ff9086a4 Fix LLM selection (#4078) 2025-02-21 11:32:57 -08:00
evan-danswer
ba21bacbbf coerce useLanggraph to boolean (#4084)
* coerce useLanggraph to boolean
2025-02-21 09:43:46 -08:00
pablonyx
158bccc3fc Default on for non-ee (#4083) 2025-02-21 09:11:45 -08:00
Weves
599b7705c2 Fix gitbook connector issues 2025-02-20 15:29:11 -08:00
rkuo-danswer
4958a5355d try more efficient query (#4047) 2025-02-20 12:58:50 -08:00
Chris Weaver
c4b8519381 Add support for sending email invites for single tenant users (#4065) 2025-02-19 21:05:23 -08:00
130 changed files with 3945 additions and 1056 deletions

View File

@@ -145,7 +145,7 @@ jobs:
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
@@ -157,6 +157,7 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
@@ -199,7 +200,7 @@ jobs:
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |

View File

@@ -1,18 +1,29 @@
name: Connector Tests
name: Model Server Tests
on:
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
workflow_dispatch:
inputs:
branch:
description: 'Branch to run the workflow on'
required: false
default: 'main'
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
# OpenAI
# API keys for testing
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
AZURE_API_URL: ${{ secrets.AZURE_API_URL }}
jobs:
model-check:
@@ -26,6 +37,23 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Model Server Docker image
run: |
docker pull onyxdotapp/onyx-model-server:latest
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
- name: Set up Python
uses: actions/setup-python@v5
with:
@@ -41,6 +69,49 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
@@ -56,3 +127,23 @@ jobs:
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v

View File

@@ -26,12 +26,12 @@
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
There are over 40 supported connectors such as Google Drive, Slack, Confluence, Salesforce, etc. which keep knowledge and permissions up to date.
Create custom AI agents with unique prompts, knowledge, and actions the agents can take.
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
<h3>Feature Showcase</h3>
<h3>Feature Highlights</h3>
**Deep research over your team's knowledge:**
@@ -63,22 +63,21 @@ We also have built-in support for high-availability/scalable deployment on Kuber
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
## 🚧 Roadmap
- Extensions to the Chrome Plugin
- Latest methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- Personalized Search
- Organizational understanding and ability to locate and suggest experts from your team.
- Code Search
- SQL and Structured Query Language
## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models only through Onyx + learn from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
## 🔌 Connectors
Keep knowledge and access up to sync across 40+ connectors:

View File

@@ -0,0 +1,31 @@
"""add index
Revision ID: 8f43500ee275
Revises: da42808081e3
Create Date: 2025-02-24 17:35:33.072714
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8f43500ee275"
down_revision = "da42808081e3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create a basic index on the lowercase message column for direct text matching
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
op.execute(
"""
CREATE INDEX idx_chat_message_message_lower
ON chat_message (LOWER(substring(message, 1, 1500)))
"""
)
def downgrade() -> None:
# Drop the index
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -0,0 +1,120 @@
"""migrate jira connectors to new format
Revision ID: da42808081e3
Revises: f13db29f3101
Create Date: 2025-02-24 11:24:54.396040
"""
from alembic import op
import sqlalchemy as sa
import json
from onyx.configs.constants import DocumentSource
from onyx.connectors.onyx_jira.utils import extract_jira_project
# revision identifiers, used by Alembic.
revision = "da42808081e3"
down_revision = "f13db29f3101"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config
for connector_id, old_config in jira_connectors:
if not old_config:
continue
# Extract project key from URL if it exists
new_config: dict[str, str | None] = {}
if project_url := old_config.get("jira_project_url"):
# Parse the URL to get base and project
try:
jira_base, project_key = extract_jira_project(project_url)
new_config = {"jira_base_url": jira_base, "project_key": project_key}
except ValueError:
# If URL parsing fails, just use the URL as the base
new_config = {
"jira_base_url": project_url.split("/projects/")[0],
"project_key": None,
}
else:
# For connectors without a project URL, we need admin intervention
# Mark these for review
print(
f"WARNING: Jira connector {connector_id} has no project URL configured"
)
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :id
"""
),
{"id": connector_id, "new_config": json.dumps(new_config)},
)
def downgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config back to the old format
for connector_id, new_config in jira_connectors:
if not new_config:
continue
old_config = {}
base_url = new_config.get("jira_base_url")
project_key = new_config.get("project_key")
if base_url and project_key:
old_config = {"jira_project_url": f"{base_url}/projects/{project_key}"}
elif base_url:
old_config = {"jira_project_url": base_url}
else:
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :old_config
WHERE id = :id
"""
),
{"id": connector_id, "old_config": old_config},
)

View File

@@ -0,0 +1,27 @@
"""Add composite index for last_modified and last_synced to document
Revision ID: f13db29f3101
Revises: b388730a2899
Create Date: 2025-02-18 22:48:11.511389
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f13db29f3101"
down_revision = "acaab4ef4507"
branch_labels: str | None = None
depends_on: str | None = None
def upgrade() -> None:
op.create_index(
"ix_document_sync_status",
"document",
["last_modified", "last_synced"],
unique=False,
)
def downgrade() -> None:
op.drop_index("ix_document_sync_status", table_name="document")

View File

@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import UserGroup__ConnectorCredentialPair
@@ -35,10 +36,11 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit(
def get_cc_pairs_by_source(
db_session: Session,
source_type: DocumentSource,
only_sync: bool,
access_type: AccessType | None = None,
status: ConnectorCredentialPairStatus | None = None,
) -> list[ConnectorCredentialPair]:
"""
Get all cc_pairs for a given source type (and optionally only sync)
Get all cc_pairs for a given source type with optional filtering by access_type and status
result is sorted by cc_pair id
"""
query = (
@@ -48,8 +50,11 @@ def get_cc_pairs_by_source(
.order_by(ConnectorCredentialPair.id)
)
if only_sync:
query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC)
if access_type is not None:
query = query.filter(ConnectorCredentialPair.access_type == access_type)
if status is not None:
query = query.filter(ConnectorCredentialPair.status == status)
cc_pairs = query.all()
return cc_pairs

View File

@@ -62,12 +62,14 @@ def _fetch_permissions_for_permission_ids(
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# We continue on 404 or 403 because the document may not exist or the user may not have access to it
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
fileId=doc_id,
fields="permissions(id, emailAddress, type, domain)",
supportsAllDrives=True,
continue_on_404_or_403=True,
)
permissions_for_doc_id = []
@@ -104,7 +106,13 @@ def _get_permissions_from_slim_doc(
user_emails: set[str] = set()
group_emails: set[str] = set()
public = False
skipped_permissions = 0
for permission in permissions_list:
if not permission:
skipped_permissions += 1
continue
permission_type = permission["type"]
if permission_type == "user":
user_emails.add(permission["emailAddress"])
@@ -121,6 +129,11 @@ def _get_permissions_from_slim_doc(
elif permission_type == "anyone":
public = True
if skipped_permissions > 0:
logger.warning(
f"Skipped {skipped_permissions} permissions of {len(permissions_list)} for document {slim_doc.id}"
)
drive_id = permission_info.get("drive_id")
group_ids = group_emails | ({drive_id} if drive_id is not None else set())

View File

@@ -13,7 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.api_key import is_api_key_email_address
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit
@@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
def _check_token_rate_limits(user: User | None, tenant_id: str) -> None:
def _check_token_rate_limits(user: User | None) -> None:
if user is None:
# Unauthenticated users are only rate limited by global settings
_user_is_rate_limited_by_global(tenant_id)
_user_is_rate_limited_by_global()
elif is_api_key_email_address(user.email):
# API keys are only rate limited by global settings
_user_is_rate_limited_by_global(tenant_id)
_user_is_rate_limited_by_global()
else:
run_functions_tuples_in_parallel(
[
(_user_is_rate_limited, (user.id, tenant_id)),
(_user_is_rate_limited_by_group, (user.id, tenant_id)),
(_user_is_rate_limited_by_global, (tenant_id,)),
(_user_is_rate_limited, (user.id,)),
(_user_is_rate_limited_by_group, (user.id,)),
(_user_is_rate_limited_by_global, ()),
]
)
@@ -52,8 +52,8 @@ User rate limits
"""
def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def _user_is_rate_limited(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
user_rate_limits = fetch_all_user_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
@@ -93,8 +93,8 @@ User Group rate limits
"""
def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
if group_rate_limits:

View File

@@ -224,7 +224,7 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20241022",
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
)

View File

@@ -98,12 +98,17 @@ class CloudEmbedding:
return final_embeddings
except Exception as e:
error_string = (
f"Error embedding text with OpenAI: {str(e)} \n"
f"Model: {model} \n"
f"Provider: {self.provider} \n"
f"Texts: {texts}"
f"Exception embedding text with OpenAI - {type(e)}: "
f"Model: {model} "
f"Provider: {self.provider} "
f"Exception: {e}"
)
logger.error(error_string)
# only log text when it's not an authentication error.
if not isinstance(e, openai.AuthenticationError):
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
async def _embed_cohere(

View File

@@ -10,6 +10,7 @@ from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
@@ -187,23 +188,51 @@ def send_subscription_cancellation_email(user_email: str) -> None:
send_email(user_email, subject, html_content, text_content)
def send_user_email_invite(user_email: str, current_user: User) -> None:
def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
subject = "Invitation to Join Onyx Organization"
heading = "You've Been Invited!"
message = (
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
# the exact action taken by the user, and thus the message, depends on the auth type
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
if auth_type == AuthType.CLOUD:
message += (
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
elif auth_type == AuthType.BASIC:
message += (
"<p>To join the organization, please click the button below to set a password "
"and complete your registration.</p>"
)
elif auth_type == AuthType.GOOGLE_OAUTH:
message += (
"<p>To join the organization, please click the button below to login with Google "
"and complete your registration.</p>"
)
elif auth_type == AuthType.OIDC or auth_type == AuthType.SAML:
message += (
"<p>To join the organization, please click the button below to"
" complete your registration.</p>"
)
else:
raise ValueError(f"Invalid auth type: {auth_type}")
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
# text content is the fallback for clients that don't support HTML
# not as critical, so not having special cases for each auth type
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
"You'll be asked to set a password or login with Google to complete your registration."
)
if auth_type == AuthType.CLOUD:
text_content += "You'll be asked to set a password or login with Google to complete your registration."
send_email(user_email, subject, html_content, text_content)

View File

@@ -92,7 +92,8 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
"""This is a redis specific way to build a list of tasks in a queue.
"""This is a redis specific way to build a list of tasks in a queue and return them
as a set.
This helps us read the queue once and then efficiently look for missing tasks
in the queue.

View File

@@ -8,16 +8,21 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
@@ -52,6 +57,51 @@ class TaskDependencyError(RuntimeError):
with connector deletion."""
def revoke_tasks_blocking_deletion(
redis_connector: RedisConnector, db_session: Session, app: Celery
) -> None:
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
try:
index_payload = redis_connector_index.payload
if index_payload and index_payload.celery_task_id:
app.control.revoke(index_payload.celery_task_id)
task_logger.info(
f"Revoked indexing task {index_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking indexing task")
try:
permissions_sync_payload = redis_connector.permissions.payload
if permissions_sync_payload and permissions_sync_payload.celery_task_id:
app.control.revoke(permissions_sync_payload.celery_task_id)
task_logger.info(
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking pruning task")
try:
prune_payload = redis_connector.prune.payload
if prune_payload and prune_payload.celery_task_id:
app.control.revoke(prune_payload.celery_task_id)
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
except Exception:
task_logger.exception("Exception while revoking permissions sync task")
try:
external_group_sync_payload = redis_connector.external_group_sync.payload
if external_group_sync_payload and external_group_sync_payload.celery_task_id:
app.control.revoke(external_group_sync_payload.celery_task_id)
task_logger.info(
f"Revoked external group sync task {external_group_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking external group sync task")
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
ignore_result=True,
@@ -64,17 +114,33 @@ def check_for_connector_deletion_task(
) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
# Prevent this task from overlapping with itself
if not lock_beat.acquire(blocking=False):
return None
try:
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
# clear fences that don't have associated celery tasks in progress
try:
validate_connector_deletion_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)
except Exception:
task_logger.exception(
"Exception while validating connector deletion fences"
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
@@ -92,9 +158,38 @@ def check_for_connector_deletion_task(
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
# Leave a stop signal to clear indexing and pruning tasks more quickly
# on the first error, we set a stop signal and revoke the dependent tasks
# on subsequent errors, we hard reset blocking fences after our specified timeout
# is exceeded
task_logger.info(str(e))
redis_connector.stop.set_fence(True)
if not redis_connector.stop.fenced:
# one time revoke of celery tasks
task_logger.info("Revoking any tasks blocking deletion.")
revoke_tasks_blocking_deletion(
redis_connector, db_session, self.app
)
redis_connector.stop.set_fence(True)
redis_connector.stop.set_timeout()
else:
# stop signal already set
if redis_connector.stop.timed_out:
# waiting too long, just reset blocking fences
task_logger.info(
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
)
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
redis_connector_index.reset()
redis_connector.prune.reset()
redis_connector.permissions.reset()
redis_connector.external_group_sync.reset()
else:
# just wait
pass
else:
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)
@@ -169,6 +264,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
redis_connector.delete.set_active()
fence_payload = RedisConnectorDeletePayload(
num_tasks=None,
submitted=datetime.now(timezone.utc),
@@ -401,3 +497,171 @@ def monitor_connector_deletion_taskset(
)
redis_connector.delete.reset()
def validate_connector_deletion_fences(
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN:
return
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
# validate all existing connector deletion jobs
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
continue
validate_connector_deletion_fence(
tenant_id,
key_bytes,
queued_upsert_tasks,
r,
)
lock_beat.reacquire()
return
def validate_connector_deletion_fence(
tenant_id: str | None,
key_bytes: bytes,
queued_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_connector_deletion_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.delete.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.delete.payload
except ValidationError:
task_logger.exception(
"validate_connector_deletion_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return
if not payload:
return
# OK, there's actually something for us to validate
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own permissions taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
# TODO: if the number of tasks in celery is much lower than than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.delete.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_connector_deletion_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
# we're active if there are still tasks to run and those tasks all exist in celery
if tasks_scanned > 0 and tasks_not_in_celery == 0:
redis_connector.delete.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.delete.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_connector_deletion_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return

View File

@@ -30,6 +30,7 @@ from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -42,8 +43,10 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
@@ -63,6 +66,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
@@ -193,12 +197,19 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
task_logger.info(f"check_for_doc_permissions_sync finished: tenant={tenant_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id} {error_msg}"
)
task_logger.exception(
f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id}"
)
finally:
if lock_beat.owned():
lock_beat.release()
@@ -282,13 +293,19 @@ def try_creating_permissions_sync_task(
redis_connector.permissions.set_fence(payload)
payload_id = payload.id
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_permissions_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
)
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"try_creating_permissions_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
return payload_id
@@ -388,6 +405,30 @@ def connector_permission_sync_generator_task(
f"No connector credential pair found for id: {cc_pair_id}"
)
try:
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
tenant_id,
enforce_creation=False,
)
if not created:
task_logger.warning(
f"Unable to create connector credential pair for id: {cc_pair_id}"
)
except Exception:
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
@@ -439,6 +480,10 @@ def connector_permission_sync_generator_task(
redis_connector.permissions.generator_complete = tasks_generated
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id} {error_msg}"
)
task_logger.exception(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
)
@@ -473,6 +518,8 @@ def update_external_document_permissions_task(
) -> bool:
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
@@ -512,13 +559,28 @@ def update_external_document_permissions_task(
f"elapsed={elapsed:.2f}"
)
except Exception:
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
)
task_logger.exception(
f"Exception in update_external_document_permissions_task: "
f"update_external_document_permissions_task exceptioned: "
f"connector_id={connector_id} doc_id={doc_id}"
)
completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
finally:
task_logger.info(
f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
task_logger.info(
f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
)
return True

View File

@@ -37,8 +37,11 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -55,6 +58,7 @@ from onyx.redis.redis_connector_ext_group_sync import (
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -148,7 +152,10 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
# These are ordered by cc_pair id so the first one is the one we want
cc_pairs_to_dedupe = get_cc_pairs_by_source(
db_session, source, only_sync=True
db_session,
source,
access_type=AccessType.SYNC,
status=ConnectorCredentialPairStatus.ACTIVE,
)
# We only want to sync one cc_pair per source type
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
@@ -195,12 +202,17 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected check_for_external_group_sync exception: tenant={tenant_id} {error_msg}"
)
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
task_logger.info(f"check_for_external_group_sync finished: tenant={tenant_id}")
return True
@@ -267,12 +279,19 @@ def try_creating_external_group_sync_task(
redis_connector.external_group_sync.set_fence(payload)
payload_id = payload.id
except Exception:
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_external_group_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
)
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
)
return None
task_logger.info(
f"try_creating_external_group_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
return payload_id
@@ -361,12 +380,37 @@ def connector_external_group_sync_generator_task(
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
try:
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
tenant_id,
enforce_creation=False,
)
if not created:
task_logger.warning(
f"Unable to create connector credential pair for id: {cc_pair_id}"
)
except Exception:
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise
source_type = cc_pair.connector.source
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
@@ -378,8 +422,18 @@ def connector_external_group_sync_generator_task(
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
external_user_groups: list[ExternalUserGroup] = []
try:
external_user_groups = ext_group_sync_func(cc_pair)
except ConnectorValidationError as e:
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise e
logger.info(
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
@@ -405,6 +459,14 @@ def connector_external_group_sync_generator_task(
sync_status=SyncStatus.SUCCESS,
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id} {error_msg}"
)
task_logger.exception(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
)
msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
task_logger.exception(msg)
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)

View File

@@ -48,7 +48,7 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -899,6 +899,9 @@ def connector_indexing_proxy_task(
TODO(rkuo): refactor this so that there is a single return path where we canonically
log the result of running this function.
NOTE: we try/except all db access in this function because as a watchdog, this function
needs to be extremely stable.
"""
start = time.monotonic()
@@ -924,6 +927,7 @@ def connector_indexing_proxy_task(
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
task_logger.info(f"submitting connector_indexing_task with tenant_id={tenant_id}")
job = client.submit(
connector_indexing_task,
@@ -1016,7 +1020,7 @@ def connector_indexing_proxy_task(
job.release()
break
# if a termination signal is detected, clean up and break
# if a termination signal is detected, break (exit point will clean up)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
log_builder.build("Indexing watchdog - termination signal detected")
@@ -1025,6 +1029,7 @@ def connector_indexing_proxy_task(
result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
break
# if activity timeout is detected, break (exit point will clean up)
if not redis_connector_index.connector_active():
task_logger.warning(
log_builder.build(
@@ -1033,25 +1038,6 @@ def connector_indexing_proxy_task(
)
)
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
result.status = (
IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
)
@@ -1070,15 +1056,15 @@ def connector_indexing_proxy_task(
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception looking up index attempt"
)
)
continue
except Exception as e:
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
if isinstance(e, ConnectorValidationError):
@@ -1139,8 +1125,6 @@ def connector_indexing_proxy_task(
"Connector termination signal detected",
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as canceled"
@@ -1148,6 +1132,25 @@ def connector_indexing_proxy_task(
)
job.cancel()
elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
except Exception:
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
else:
pass
task_logger.info(
log_builder.build(

View File

@@ -55,6 +55,7 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
@@ -194,12 +195,14 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(f"Unexpected pruning check exception: {error_msg}")
task_logger.exception("Unexpected exception during pruning check")
finally:
if lock_beat.owned():
lock_beat.release()
task_logger.info(f"check_for_pruning finished: tenant={tenant_id}")
return True
@@ -301,13 +304,19 @@ def try_creating_prune_generator_task(
redis_connector.prune.set_fence(payload)
payload_id = payload.id
except Exception:
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_prune_generator_task exception: cc_pair={cc_pair.id} {error_msg}"
)
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"try_creating_prune_generator_task finished: cc_pair={cc_pair.id} payload_id={payload_id}"
)
return payload_id

View File

@@ -1,4 +1,5 @@
import time
from enum import Enum
from http import HTTPStatus
import httpx
@@ -45,6 +46,24 @@ LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
class OnyxCeleryTaskCompletionStatus(str, Enum):
"""The different statuses the watchdog can finish with.
TODO: create broader success/failure/abort categories
"""
UNDEFINED = "undefined"
SUCCEEDED = "succeeded"
SKIPPED = "skipped"
SOFT_TIME_LIMIT = "soft_time_limit"
NON_RETRYABLE_EXCEPTION = "non_retryable_exception"
RETRYABLE_EXCEPTION = "retryable_exception"
@shared_task(
name=OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
@@ -78,6 +97,8 @@ def document_by_cc_pair_cleanup_task(
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
action = "skip"
@@ -110,6 +131,9 @@ def document_by_cc_pair_cleanup_task(
db_session=db_session,
document_ids=[document_id],
)
db_session.commit()
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
elif count > 1:
action = "update"
@@ -153,10 +177,11 @@ def document_by_cc_pair_cleanup_task(
)
mark_document_as_synced(document_id, db_session)
else:
pass
db_session.commit()
db_session.commit()
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
else:
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
elapsed = time.monotonic() - start
task_logger.info(
@@ -168,57 +193,79 @@ def document_by_cc_pair_cleanup_task(
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
task_logger.exception(
f"document_by_cc_pair_cleanup_task exceptioned: doc={document_id}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
)
return False
with get_session_with_current_tenant() as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_modified(document_id, db_session)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
task_logger.exception(f"Unexpected exception: doc={document_id}")
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
else:
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
)
with get_session_with_current_tenant() as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_modified(document_id, db_session)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"document_by_cc_pair_cleanup_task completed: status={completion_status.value} doc={document_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
task_logger.info(f"document_by_cc_pair_cleanup_task finished: doc={document_id}")
return True

View File

@@ -19,6 +19,7 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -527,6 +528,8 @@ def vespa_metadata_sync_task(
) -> bool:
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
active_search_settings = get_active_search_settings(db_session)
@@ -540,75 +543,103 @@ def vespa_metadata_sync_task(
doc = get_document(document_id, db_session)
if not doc:
return False
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=no_operation "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
else:
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
task_logger.exception(
f"vespa_metadata_sync_task exceptioned: doc={document_id}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
return False
task_logger.exception(
f"Unexpected exception during vespa metadata sync: doc={document_id}"
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"vespa_metadata_sync_task completed: status={completion_status.value} doc={document_id}"
)
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
return True

View File

@@ -1,3 +1,5 @@
from sqlalchemy.exc import IntegrityError
from onyx.db.background_error import create_background_error
from onyx.db.engine import get_session_with_current_tenant
@@ -10,4 +12,9 @@ def emit_background_error(
In the future, could create notifications based on the severity."""
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)
try:
create_background_error(db_session, message, cc_pair_id)
except IntegrityError as e:
# Log an error if the cc_pair_id was deleted or any other exception occurs
error_message = f"Failed to create background error: {str(e)}. Original message: {message}"
create_background_error(db_session, error_message, None)

View File

@@ -17,6 +17,9 @@ from typing import Optional
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -54,6 +57,15 @@ def _initializer(
kwargs = {}
logger.info("Initializing spawned worker child process.")
# 1. Get tenant_id from args or fallback to default
tenant_id = POSTGRES_DEFAULT_SCHEMA
for arg in reversed(args):
if isinstance(arg, str) and arg.startswith(TENANT_ID_PREFIX):
tenant_id = arg
break
# 2. Set the tenant context before running anything
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Reset the engine in the child process
SqlEngine.reset_engine()
@@ -81,6 +93,8 @@ def _initializer(
queue.put(error_msg) # Send the exception to the parent process
sys.exit(255) # use 255 to indicate a generic exception
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _run_in_process(

View File

@@ -15,13 +15,14 @@ from onyx.background.indexing.memory_tracer import MemoryTracer
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
@@ -89,8 +90,8 @@ def _get_connector_runner(
)
# validate the connector settings
runnable_connector.validate_connector_settings()
if not INTEGRATION_TESTS_MODE:
runnable_connector.validate_connector_settings()
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")

View File

@@ -747,14 +747,13 @@ def stream_chat_message_objects(
files=latest_query_files,
single_message_history=single_message_history,
),
system_message=default_build_system_message(prompt_config),
system_message=default_build_system_message(prompt_config, llm.config),
message_history=message_history,
llm_config=llm.config,
raw_user_query=final_msg.message,
raw_user_uploaded_files=latest_query_files or [],
single_message_history=single_message_history,
)
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
# LLM prompt building, response capturing, etc.
answer = Answer(
@@ -870,7 +869,6 @@ def stream_chat_message_objects(
for img in img_generation_response
if img.image_data
],
tenant_id=tenant_id,
)
info.ai_message_files.extend(
[

View File

@@ -12,6 +12,7 @@ from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_toke
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLMConfig
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
@@ -19,6 +20,7 @@ from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
@@ -31,8 +33,16 @@ from onyx.tools.tool import Tool
def default_build_system_message(
prompt_config: PromptConfig,
llm_config: LLMConfig,
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
# for o-series markdown generation
if (
llm_config.model_provider == OPENAI_PROVIDER_NAME
and llm_config.model_name.startswith("o")
):
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt,
prompt_config,
@@ -110,21 +120,8 @@ class AnswerPromptBuilder:
),
)
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
(
system_message,
check_message_tokens(system_message, self.llm_tokenizer_encode_func),
)
if system_message
else None
)
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(
user_message,
self.llm_tokenizer_encode_func,
),
)
self.update_system_prompt(system_message)
self.update_user_prompt(user_message)
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []

View File

@@ -626,6 +626,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"

View File

@@ -342,6 +342,9 @@ class OnyxRedisSignals:
BLOCK_PRUNING = "signal:block_pruning"
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = (
"signal:block_validate_connector_deletion_fences"
)
class OnyxRedisConstants:

View File

@@ -15,14 +15,14 @@ from mypy_boto3_s3 import S3Client # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import BlobType
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section

View File

@@ -9,10 +9,10 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.bookstack.client import BookStackApiClient
from onyx.connectors.bookstack.client import BookStackClientRequestFailedError
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch

View File

@@ -4,6 +4,8 @@ from datetime import timezone
from typing import Any
from urllib.parse import quote
from requests.exceptions import HTTPError
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@@ -16,6 +18,10 @@ from onyx.connectors.confluence.utils import build_confluence_document_id
from onyx.connectors.confluence.utils import datetime_from_string
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -397,3 +403,33 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
callback.progress("retrieve_all_slim_documents", 1)
yield doc_metadata_list
def validate_connector_settings(self) -> None:
if self._confluence_client is None:
raise ConnectorMissingCredentialError("Confluence credentials not loaded.")
try:
spaces = self._confluence_client.get_all_spaces(limit=1)
except HTTPError as e:
status_code = e.response.status_code if e.response else None
if status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Confluence credentials (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Insufficient permissions to access Confluence resources (HTTP 403)."
)
raise UnexpectedError(
f"Unexpected Confluence error (status={status_code}): {e}"
)
except Exception as e:
raise UnexpectedError(
f"Unexpected error while validating Confluence settings: {e}"
)
if not spaces or not spaces.get("results"):
raise ConnectorValidationError(
"No Confluence spaces found. Either your credentials lack permissions, or "
"there truly are no spaces in this Confluence instance."
)

View File

@@ -11,6 +11,9 @@ from atlassian import Confluence # type:ignore
from pydantic import BaseModel
from requests import HTTPError
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -160,7 +163,7 @@ class OnyxConfluence(Confluence):
)
def _paginate_url(
self, url_suffix: str, limit: int | None = None
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
@@ -235,9 +238,41 @@ class OnyxConfluence(Confluence):
raise e
# yield the results individually
yield from next_response.get("results", [])
results = cast(list[dict[str, Any]], next_response.get("results", []))
yield from results
url_suffix = next_response.get("_links", {}).get("next")
old_url_suffix = url_suffix
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
# make sure we don't update the start by more than the amount
# of results we were able to retrieve. The Confluence API has a
# weird behavior where if you pass in a limit that is too large for
# the configured server, it will artificially limit the amount of
# results returned BUT will not apply this to the start parameter.
# This will cause us to miss results.
if url_suffix and "start" in url_suffix:
new_start = get_start_param_from_url(url_suffix)
previous_start = get_start_param_from_url(old_url_suffix)
if new_start - previous_start > len(results):
logger.warning(
f"Start was updated by more than the amount of results "
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
f"Previous Start: {previous_start}, Len Results: {len(results)}."
)
# Update the url_suffix to use the adjusted start
adjusted_start = previous_start + len(results)
url_suffix = update_param_in_path(
url_suffix, "start", str(adjusted_start)
)
# some APIs don't properly paginate, so we need to manually update the `start` param
if auto_paginate and len(results) > 0:
previous_start = get_start_param_from_url(old_url_suffix)
updated_start = previous_start + len(results)
url_suffix = update_param_in_path(
old_url_suffix, "start", str(updated_start)
)
def paginated_cql_retrieval(
self,
@@ -297,7 +332,9 @@ class OnyxConfluence(Confluence):
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
for user_result in self._paginate_url(url, limit):
# endpoint doesn't properly paginate, so we need to manually update the `start` param
# thus the auto_paginate flag
for user_result in self._paginate_url(url, limit, auto_paginate=True):
# Example response:
# {
# 'user': {
@@ -508,11 +545,15 @@ def build_confluence_client(
is_cloud: bool,
wiki_base: str,
) -> OnyxConfluence:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
try:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
except Exception as e:
raise ConnectorValidationError(str(e))
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present

View File

@@ -2,7 +2,10 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TYPE_CHECKING
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urlparse
import bs4
@@ -10,13 +13,13 @@ from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.onyx_confluence import (
OnyxConfluence,
)
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
logger = setup_logger()
@@ -24,7 +27,7 @@ _USER_EMAIL_CACHE: dict[str, str | None] = {}
def get_user_email_from_username__server(
confluence_client: OnyxConfluence, user_name: str
confluence_client: "OnyxConfluence", user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
@@ -47,7 +50,7 @@ _USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
@@ -78,7 +81,7 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_client: "OnyxConfluence",
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
@@ -191,7 +194,7 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
def attachment_to_content(
confluence_client: OnyxConfluence,
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
@@ -279,3 +282,32 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def get_single_param_from_url(url: str, param: str) -> str | None:
"""Get a parameter from a url"""
parsed_url = urlparse(url)
return parse_qs(parsed_url.query).get(param, [None])[0]
def get_start_param_from_url(url: str) -> int:
"""Get the start parameter from a url"""
start_str = get_single_param_from_url(url, "start")
if start_str is None:
return 0
return int(start_str)
def update_param_in_path(path: str, param: str, value: str) -> str:
"""Update a parameter in a path. Path should look something like:
/api/rest/users?start=0&limit=10
"""
parsed_url = urlparse(path)
query_params = parse_qs(parsed_url.query)
query_params[param] = [value]
return (
path.split("?")[0]
+ "?"
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
)

View File

@@ -10,10 +10,10 @@ from dropbox.files import FolderMetadata # type:ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialInvalidError
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialInvalidError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch

View File

@@ -0,0 +1,49 @@
class ValidationError(Exception):
"""General exception for validation errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class ConnectorValidationError(ValidationError):
"""General exception for connector validation errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class UnexpectedError(ValidationError):
"""Raised when an unexpected error occurs during connector validation.
Unexpected errors don't necessarily mean the credential is invalid,
but rather that there was an error during the validation process
or we encountered a currently unhandled error case.
"""
def __init__(self, message: str = "Unexpected error during connector validation"):
super().__init__(message)
class CredentialInvalidError(ConnectorValidationError):
"""Raised when a connector's credential is invalid."""
def __init__(self, message: str = "Credential is invalid"):
super().__init__(message)
class CredentialExpiredError(ConnectorValidationError):
"""Raised when a connector's credential is expired."""
def __init__(self, message: str = "Credential has expired"):
super().__init__(message)
class InsufficientPermissionsError(ConnectorValidationError):
"""Raised when the credential does not have sufficient API permissions."""
def __init__(
self, message: str = "Insufficient permissions for the requested operation"
):
super().__init__(message)

View File

@@ -3,6 +3,7 @@ from typing import Type
from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceRequiringTenantContext
from onyx.connectors.airtable.airtable_connector import AirtableConnector
@@ -17,6 +18,7 @@ from onyx.connectors.discourse.connector import DiscourseConnector
from onyx.connectors.document360.connector import Document360Connector
from onyx.connectors.dropbox.connector import DropboxConnector
from onyx.connectors.egnyte.connector import EgnyteConnector
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.fireflies.connector import FirefliesConnector
from onyx.connectors.freshdesk.connector import FreshdeskConnector
@@ -31,7 +33,6 @@ from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -55,9 +56,8 @@ from onyx.connectors.zendesk.connector import ZendeskConnector
from onyx.connectors.zulip.connector import ZulipConnector
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.models import Credential
from onyx.db.models import User
class ConnectorMissingException(Exception):
@@ -184,23 +184,27 @@ def validate_ccpair_for_user(
connector_id: int,
credential_id: int,
db_session: Session,
user: User | None,
tenant_id: str | None,
) -> None:
enforce_creation: bool = True,
) -> bool:
if INTEGRATION_TESTS_MODE:
return True
# Validate the connector settings
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id_for_user(
credential = fetch_credential_by_id(
credential_id,
user,
db_session,
get_editable=False,
)
if not connector:
raise ValueError("Connector not found")
if connector.source == DocumentSource.INGESTION_API:
return
if (
connector.source == DocumentSource.INGESTION_API
or connector.source == DocumentSource.MOCK_CONNECTOR
):
return True
if not credential:
raise ValueError("Credential not found")
@@ -214,7 +218,13 @@ def validate_ccpair_for_user(
credential=credential,
tenant_id=tenant_id,
)
except ConnectorValidationError as e:
raise e
except Exception as e:
raise ConnectorValidationError(str(e))
if enforce_creation:
raise ConnectorValidationError(str(e))
else:
return False
runnable_connector.validate_connector_settings()
return True

View File

@@ -229,16 +229,20 @@ class GitbookConnector(LoadConnector, PollConnector):
try:
content = self.client.get(f"/spaces/{self.space_id}/content")
pages = content.get("pages", [])
pages: list[dict[str, Any]] = content.get("pages", [])
current_batch: list[Document] = []
for page in pages:
updated_at = datetime.fromisoformat(page["updatedAt"])
while pages:
page = pages.pop(0)
updated_at_raw = page.get("updatedAt")
if updated_at_raw is None:
# if updatedAt is not present, that means the page has never been edited
continue
updated_at = datetime.fromisoformat(updated_at_raw)
if start and updated_at < start:
if current_batch:
yield current_batch
return
continue
if end and updated_at > end:
continue
@@ -250,6 +254,8 @@ class GitbookConnector(LoadConnector, PollConnector):
yield current_batch
current_batch = []
pages.extend(page.get("pages", []))
if current_batch:
yield current_batch

View File

@@ -17,14 +17,14 @@ from github.PullRequest import PullRequest
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
@@ -124,7 +124,7 @@ class GithubConnector(LoadConnector, PollConnector):
def __init__(
self,
repo_owner: str,
repo_name: str,
repo_name: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
@@ -162,53 +162,81 @@ class GithubConnector(LoadConnector, PollConnector):
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repo(github_client, attempt_num + 1)
def _get_all_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
)
try:
# Try to get organization first
try:
org = github_client.get_organization(self.repo_owner)
return list(org.get_repos())
except GithubException:
# If not an org, try as a user
user = github_client.get_user(self.repo_owner)
return list(user.get_repos())
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_all_repos(github_client, attempt_num + 1)
def _fetch_from_github(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
repo = self._get_github_repo(self.github_client)
repos = (
[self._get_github_repo(self.github_client)]
if self.repo_name
else self._get_all_repos(self.github_client)
)
if self.include_prs:
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
for repo in repos:
if self.include_prs:
logger.info(f"Fetching PRs for repo: {repo.name}")
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
for pr_batch in _batch_github_objects(
pull_requests, self.github_client, self.batch_size
):
doc_batch: list[Document] = []
for pr in pr_batch:
if start is not None and pr.updated_at < start:
yield doc_batch
return
if end is not None and pr.updated_at > end:
continue
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
yield doc_batch
for pr_batch in _batch_github_objects(
pull_requests, self.github_client, self.batch_size
):
doc_batch: list[Document] = []
for pr in pr_batch:
if start is not None and pr.updated_at < start:
yield doc_batch
break
if end is not None and pr.updated_at > end:
continue
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
yield doc_batch
if self.include_issues:
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
if self.include_issues:
logger.info(f"Fetching issues for repo: {repo.name}")
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
for issue_batch in _batch_github_objects(
issues, self.github_client, self.batch_size
):
doc_batch = []
for issue in issue_batch:
issue = cast(Issue, issue)
if start is not None and issue.updated_at < start:
yield doc_batch
return
if end is not None and issue.updated_at > end:
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
for issue_batch in _batch_github_objects(
issues, self.github_client, self.batch_size
):
doc_batch = []
for issue in issue_batch:
issue = cast(Issue, issue)
if start is not None and issue.updated_at < start:
yield doc_batch
break
if end is not None and issue.updated_at > end:
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_github()
@@ -234,16 +262,26 @@ class GithubConnector(LoadConnector, PollConnector):
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")
if not self.repo_owner or not self.repo_name:
if not self.repo_owner:
raise ConnectorValidationError(
"Invalid connector settings: 'repo_owner' and 'repo_name' must be provided."
"Invalid connector settings: 'repo_owner' must be provided."
)
try:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
if self.repo_name:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
else:
# Try to get organization first
try:
org = self.github_client.get_organization(self.repo_owner)
org.get_repos().totalCount # Just check if we can access repos
except GithubException:
# If not an org, try as a user
user = self.github_client.get_user(self.repo_owner)
user.get_repos().totalCount # Just check if we can access repos
except RateLimitExceededException:
raise UnexpectedError(
@@ -260,9 +298,14 @@ class GithubConnector(LoadConnector, PollConnector):
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
if self.repo_name:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
else:
raise ConnectorValidationError(
f"GitHub user or organization not found: {self.repo_owner}"
)
else:
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"

View File

@@ -305,6 +305,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
userId=user_email,
fields=THREAD_FIELDS,
id=thread["id"],
continue_on_404_or_403=True,
)
# full_threads is an iterator containing a single thread
# so we need to convert it to a list and grab the first element
@@ -336,6 +337,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
doc_batch.append(
SlimDocument(

View File

@@ -13,6 +13,9 @@ from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.google_drive.doc_conversion import build_slim_document
from onyx.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
@@ -42,6 +45,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -137,7 +141,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
raise ValueError(
raise ConnectorValidationError(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
@@ -151,7 +155,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
and not my_drive_emails
and not shared_drive_urls
):
raise ValueError(
raise ConnectorValidationError(
"Nothing to index. Please specify at least one of the following: "
"include_shared_drives, include_my_drives, include_files_shared_with_me, "
"shared_folder_urls, or my_drive_emails"
@@ -220,7 +224,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
return self._creds
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
try:
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
except KeyError:
raise ValueError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
@@ -602,3 +613,50 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def validate_connector_settings(self) -> None:
if self._creds is None:
raise ConnectorMissingCredentialError(
"Google Drive credentials not loaded."
)
if self._primary_admin_email is None:
raise ConnectorValidationError(
"Primary admin email not found in credentials. "
"Ensure DB_CREDENTIALS_PRIMARY_ADMIN_KEY is set."
)
try:
drive_service = get_drive_service(self._creds, self._primary_admin_email)
drive_service.files().list(pageSize=1, fields="files(id)").execute()
if isinstance(self._creds, ServiceAccountCredentials):
retry_builder()(get_root_folder_id)(drive_service)
except HttpError as e:
status_code = e.resp.status if e.resp else None
if status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Google Drive credentials (401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Google Drive app lacks required permissions (403). "
"Please ensure the necessary scopes are granted and Drive "
"apps are enabled."
)
else:
raise ConnectorValidationError(
f"Unexpected Google Drive error (status={status_code}): {e}"
)
except Exception as e:
# Check for scope-related hints from the error message
if MISSING_SCOPES_ERROR_STR in str(e):
raise InsufficientPermissionsError(
"Google Drive credentials are missing required scopes. "
f"{ONYX_SCOPE_INSTRUCTIONS}"
)
raise ConnectorValidationError(
f"Unexpected error during Google Drive validation: {e}"
)

View File

@@ -146,46 +146,3 @@ class CheckpointConnector(BaseConnector):
```
"""
raise NotImplementedError
class ConnectorValidationError(Exception):
"""General exception for connector validation errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class UnexpectedError(Exception):
"""Raised when an unexpected error occurs during connector validation.
Unexpected errors don't necessarily mean the credential is invalid,
but rather that there was an error during the validation process
or we encountered a currently unhandled error case.
"""
def __init__(self, message: str = "Unexpected error during connector validation"):
super().__init__(message)
class CredentialInvalidError(ConnectorValidationError):
"""Raised when a connector's credential is invalid."""
def __init__(self, message: str = "Credential is invalid"):
super().__init__(message)
class CredentialExpiredError(ConnectorValidationError):
"""Raised when a connector's credential is expired."""
def __init__(self, message: str = "Credential has expired"):
super().__init__(message)
class InsufficientPermissionsError(ConnectorValidationError):
"""Raised when the credential does not have sufficient API permissions."""
def __init__(
self, message: str = "Insufficient permissions for the requested operation"
):
super().__init__(message)

View File

@@ -16,10 +16,11 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rl_requests,
)
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@@ -670,12 +671,12 @@ class NotionConnector(LoadConnector, PollConnector):
"Please try again later."
)
else:
raise Exception(
raise UnexpectedError(
f"Unexpected Notion HTTP error (status={status_code}): {http_err}"
) from http_err
except Exception as exc:
raise Exception(
raise UnexpectedError(
f"Unexpected error during Notion settings validation: {exc}"
)

View File

@@ -12,11 +12,11 @@ from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@@ -29,7 +29,6 @@ from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
from onyx.connectors.onyx_jira.utils import best_effort_get_field_from_issue
from onyx.connectors.onyx_jira.utils import build_jira_client
from onyx.connectors.onyx_jira.utils import build_jira_url
from onyx.connectors.onyx_jira.utils import extract_jira_project
from onyx.connectors.onyx_jira.utils import extract_text_from_adf
from onyx.connectors.onyx_jira.utils import get_comment_strs
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -160,7 +159,8 @@ def fetch_jira_issues_batch(
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
jira_project_url: str,
jira_base_url: str,
project_key: str | None = None,
comment_email_blacklist: list[str] | None = None,
batch_size: int = INDEX_BATCH_SIZE,
# if a ticket has one of the labels specified in this list, we will just
@@ -169,12 +169,13 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
) -> None:
self.batch_size = batch_size
self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
self._jira_client: JIRA | None = None
self.jira_base = jira_base_url.rstrip("/") # Remove trailing slash if present
self.jira_project = project_key
self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip)
self._jira_client: JIRA | None = None
@property
def comment_email_blacklist(self) -> tuple:
return tuple(email.strip() for email in self._comment_email_blacklist)
@@ -188,7 +189,9 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
@property
def quoted_jira_project(self) -> str:
# Quote the project name to handle reserved words
return f'"{self._jira_project}"'
if not self.jira_project:
return ""
return f'"{self.jira_project}"'
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._jira_client = build_jira_client(
@@ -197,8 +200,14 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def _get_jql_query(self) -> str:
"""Get the JQL query based on whether a specific project is set"""
if self.jira_project:
return f"project = {self.quoted_jira_project}"
return "" # Empty string means all accessible projects
def load_from_state(self) -> GenerateDocumentsOutput:
jql = f"project = {self.quoted_jira_project}"
jql = self._get_jql_query()
document_batch = []
for doc in fetch_jira_issues_batch(
@@ -225,11 +234,10 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
"%Y-%m-%d %H:%M"
)
base_jql = self._get_jql_query()
jql = (
f"project = {self.quoted_jira_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)
f"{base_jql} AND " if base_jql else ""
) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
document_batch = []
for doc in fetch_jira_issues_batch(
@@ -252,7 +260,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
jql = f"project = {self.quoted_jira_project}"
jql = self._get_jql_query()
slim_doc_batch = []
for issue in _paginate_jql_search(
@@ -279,43 +287,63 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
if self._jira_client is None:
raise ConnectorMissingCredentialError("Jira")
if not self._jira_project:
raise ConnectorValidationError(
"Invalid connector settings: 'jira_project' must be provided."
)
# If a specific project is set, validate it exists
if self.jira_project:
try:
self.jira_client.project(self.jira_project)
except Exception as e:
status_code = getattr(e, "status_code", None)
try:
self.jira_client.project(self._jira_project)
if status_code == 401:
raise CredentialExpiredError(
"Jira credential appears to be expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Jira token does not have sufficient permissions for this project (HTTP 403)."
)
elif status_code == 404:
raise ConnectorValidationError(
f"Jira project not found with key: {self.jira_project}"
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
)
except Exception as e:
status_code = getattr(e, "status_code", None)
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
else:
# If no project specified, validate we can access the Jira API
try:
# Try to list projects to validate access
self.jira_client.projects()
except Exception as e:
status_code = getattr(e, "status_code", None)
if status_code == 401:
raise CredentialExpiredError(
"Jira credential appears to be expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Jira token does not have sufficient permissions to list projects (HTTP 403)."
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
)
if status_code == 401:
raise CredentialExpiredError(
"Jira credential appears to be expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Jira token does not have sufficient permissions for this project (HTTP 403)."
)
elif status_code == 404:
raise ConnectorValidationError(
f"Jira project not found with key: {self._jira_project}"
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
)
else:
raise Exception(f"Unexpected Jira error during validation: {e}")
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
if __name__ == "__main__":
import os
connector = JiraConnector(
os.environ["JIRA_PROJECT_URL"], comment_email_blacklist=[]
jira_base_url=os.environ["JIRA_BASE_URL"],
project_key=os.environ.get("JIRA_PROJECT_KEY"),
comment_email_blacklist=[],
)
connector.load_credentials(
{
"jira_user_email": os.environ["JIRA_USER_EMAIL"],

View File

@@ -18,15 +18,15 @@ from slack_sdk.errors import SlackApiError
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@@ -86,14 +86,14 @@ def get_channels(
get_public: bool = True,
get_private: bool = True,
) -> list[ChannelType]:
"""Get all channels in the workspace"""
"""Get all channels in the workspace."""
channels: list[dict[str, Any]] = []
channel_types = []
if get_public:
channel_types.append("public_channel")
if get_private:
channel_types.append("private_channel")
# try getting private channels as well at first
# Try fetching both public and private channels first:
try:
channels = _collect_paginated_channels(
client=client,
@@ -101,19 +101,19 @@ def get_channels(
channel_types=channel_types,
)
except SlackApiError as e:
logger.info(f"Unable to fetch private channels due to - {e}")
logger.info("trying again without private channels")
logger.info(
f"Unable to fetch private channels due to: {e}. Trying again without private channels."
)
if get_public:
channel_types = ["public_channel"]
else:
logger.warning("No channels to fetch")
logger.warning("No channels to fetch.")
return []
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
return channels
@@ -671,19 +671,32 @@ class SlackConnector(SlimConnector, CheckpointConnector):
return checkpoint
def validate_connector_settings(self) -> None:
"""
1. Verify the bot token is valid for the workspace (via auth_test).
2. Ensure the bot has enough scope to list channels.
3. Check that every channel specified in self.channels exists.
"""
if self.client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
try:
# Minimal API call to confirm we can list channels
# We set limit=1 for a lightweight check
response = self.client.conversations_list(limit=1, types=["public_channel"])
# Just ensure Slack responded "ok: True"
if not response.get("ok", False):
error_msg = response.get("error", "Unknown error from Slack")
# 1) Validate connection to workspace
auth_response = self.client.auth_test()
if not auth_response.get("ok", False):
error_msg = auth_response.get(
"error", "Unknown error from Slack auth_test"
)
raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}")
# 2) Minimal test to confirm listing channels works
test_resp = self.client.conversations_list(
limit=1, types=["public_channel"]
)
if not test_resp.get("ok", False):
error_msg = test_resp.get("error", "Unknown error from Slack")
if error_msg == "invalid_auth":
raise ConnectorValidationError(
f"Invalid or expired Slack bot token ({error_msg})."
f"Invalid Slack bot token ({error_msg})."
)
elif error_msg == "not_authed":
raise CredentialExpiredError(
@@ -691,31 +704,48 @@ class SlackConnector(SlimConnector, CheckpointConnector):
)
raise UnexpectedError(f"Slack API returned a failure: {error_msg}")
# 3) If channels are specified, verify each is accessible
if self.channels:
accessible_channels = get_channels(
client=self.client,
exclude_archived=True,
get_public=True,
get_private=True,
)
# For quick lookups by name or ID, build a map:
accessible_channel_names = {ch["name"] for ch in accessible_channels}
accessible_channel_ids = {ch["id"] for ch in accessible_channels}
for user_channel in self.channels:
if (
user_channel not in accessible_channel_names
and user_channel not in accessible_channel_ids
):
raise ConnectorValidationError(
f"Channel '{user_channel}' not found or inaccessible in this workspace."
)
except SlackApiError as e:
slack_error = e.response.get("error", "")
if slack_error == "missing_scope":
# The needed scope is typically "channels:read" or "groups:read"
# for viewing channels. The error message might also contain the
# specific scope needed vs. provided.
raise InsufficientPermissionsError(
"Slack bot token lacks the necessary scope to list channels. "
"Please ensure your Slack app has 'channels:read' (or 'groups:read' for private channels) enabled."
"Slack bot token lacks the necessary scope to list/access channels. "
"Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)."
)
elif slack_error == "invalid_auth":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({slack_error})."
f"Invalid Slack bot token ({slack_error})."
)
elif slack_error == "not_authed":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({slack_error})."
)
else:
# Generic Slack error
raise UnexpectedError(
f"Unexpected Slack error '{slack_error}' during settings validation."
)
raise UnexpectedError(
f"Unexpected Slack error '{slack_error}' during settings validation."
)
except ConnectorValidationError as e:
raise e
except Exception as e:
# Catch-all for unexpected exceptions
raise UnexpectedError(
f"Unexpected error during Slack settings validation: {e}"
)

View File

@@ -13,14 +13,14 @@ from office365.teams.team import Team # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
@@ -285,17 +285,6 @@ class TeamsConnector(LoadConnector, PollConnector):
return self._fetch_from_teams(start=start_datetime, end=end_datetime)
def validate_connector_settings(self) -> None:
"""
Validate that we can connect to Microsoft Teams with the provided MSAL/Graph credentials
and that we can see at least one Team. If the user has specified a list of Teams by name,
confirm at least one of them is found.
Raises:
ConnectorMissingCredentialError: If the Graph client is not yet set (missing credentials).
CredentialExpiredError: If credentials appear invalid/expired (e.g. 401 Unauthorized).
InsufficientPermissionsError: If the app lacks required permissions to read Teams.
ConnectorValidationError: If no Teams are found, or if requested Teams are not found.
"""
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams credentials not loaded.")
@@ -303,7 +292,6 @@ class TeamsConnector(LoadConnector, PollConnector):
# Minimal call to confirm we can retrieve Teams
found_teams = self._get_all_teams()
# You may optionally catch the Graph/Office365 request exception if available:
except ClientRequestException as e:
status_code = e.response.status_code
if status_code == 401:
@@ -314,8 +302,7 @@ class TeamsConnector(LoadConnector, PollConnector):
raise InsufficientPermissionsError(
"Your app lacks sufficient permissions to read Teams (403 Forbidden)."
)
else:
raise UnexpectedError(f"Unexpected error retrieving teams: {e}")
raise UnexpectedError(f"Unexpected error retrieving teams: {e}")
except Exception as e:
error_str = str(e).lower()
@@ -335,7 +322,6 @@ class TeamsConnector(LoadConnector, PollConnector):
f"Unexpected error during Teams validation: {e}"
)
# If we get this far, the Graph call succeeded. Check for presence of Teams:
if not found_teams:
raise ConnectorValidationError(
"No Teams found for the given credentials. "

View File

@@ -25,12 +25,12 @@ from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_TOKEN_URL
from onyx.configs.app_configs import WEB_CONNECTOR_VALIDATE_URLS
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import read_pdf_file

View File

@@ -0,0 +1,152 @@
from typing import List
from typing import Optional
from typing import Tuple
from uuid import UUID
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import literal
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import union_all
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
def search_chat_sessions(
user_id: UUID | None,
db_session: Session,
query: Optional[str] = None,
page: int = 1,
page_size: int = 10,
include_deleted: bool = False,
include_onyxbot_flows: bool = False,
) -> Tuple[List[ChatSession], bool]:
"""
Search for chat sessions based on the provided query.
If no query is provided, returns recent chat sessions.
Returns a tuple of (chat_sessions, has_more)
"""
offset = (page - 1) * page_size
# If no search query, we use standard SQLAlchemy pagination
if not query or not query.strip():
stmt = select(ChatSession)
if user_id:
stmt = stmt.where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
stmt = stmt.where(ChatSession.deleted.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
# Apply pagination
stmt = stmt.offset(offset).limit(page_size + 1)
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
chat_sessions = result.scalars().all()
has_more = len(chat_sessions) > page_size
if has_more:
chat_sessions = chat_sessions[:page_size]
return list(chat_sessions), has_more
words = query.lower().strip().split()
# Message mach subquery
message_matches = []
for word in words:
word_like = f"%{word}%"
message_match: Select = (
select(ChatMessage.chat_session_id, literal(1.0).label("search_rank"))
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
.where(func.lower(ChatMessage.message).like(word_like))
)
if user_id:
message_match = message_match.where(ChatSession.user_id == user_id)
message_matches.append(message_match)
if message_matches:
message_matches_query = union_all(*message_matches).alias("message_matches")
else:
return [], False
# Description matches
description_match: Select = select(
ChatSession.id.label("chat_session_id"), literal(0.5).label("search_rank")
).where(func.lower(ChatSession.description).like(f"%{query.lower()}%"))
if user_id:
description_match = description_match.where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
description_match = description_match.where(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
description_match = description_match.where(ChatSession.deleted.is_(False))
# Combine all match sources
combined_matches = union_all(
message_matches_query.select(), description_match
).alias("combined_matches")
# Use CTE to group and get max rank
session_ranks = (
select(
combined_matches.c.chat_session_id,
func.max(combined_matches.c.search_rank).label("rank"),
)
.group_by(combined_matches.c.chat_session_id)
.alias("session_ranks")
)
# Get ranked sessions with pagination
ranked_query = (
db_session.query(session_ranks.c.chat_session_id, session_ranks.c.rank)
.order_by(desc(session_ranks.c.rank), session_ranks.c.chat_session_id)
.offset(offset)
.limit(page_size + 1)
)
result = ranked_query.all()
# Extract session IDs and ranks
session_ids_with_ranks = {row.chat_session_id: row.rank for row in result}
session_ids = list(session_ids_with_ranks.keys())
if not session_ids:
return [], False
# Now, let's query the actual ChatSession objects using the IDs
stmt = select(ChatSession).where(ChatSession.id.in_(session_ids))
if user_id:
stmt = stmt.where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
stmt = stmt.where(ChatSession.deleted.is_(False))
# Full objects with eager loading
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
chat_sessions = result.scalars().all()
# Sort based on above ranking
chat_sessions = sorted(
chat_sessions,
key=lambda session: (
-session_ids_with_ranks.get(session.id, 0), # Rank (higher first)
session.time_created.timestamp() * -1, # Then by time (newest first)
),
)
has_more = len(chat_sessions) > page_size
if has_more:
chat_sessions = chat_sessions[:page_size]
return chat_sessions, has_more

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import TypeVarTuple
from fastapi import HTTPException
from sqlalchemy import delete
@@ -8,15 +9,18 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
@@ -31,10 +35,12 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
R = TypeVarTuple("R")
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
stmt: Select[tuple[*R]], user: User | None, get_editable: bool = True
) -> Select[tuple[*R]]:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@@ -98,17 +104,52 @@ def get_connector_credential_pairs_for_user(
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
if eager_load_user:
assert (
eager_load_credential
), "eager_load_credential must be True if eager_load_user is True"
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
if eager_load_user:
load_opts = load_opts.joinedload(Credential.user)
stmt = stmt.options(load_opts)
stmt = _add_user_filters(stmt, user, get_editable)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).all())
return list(db_session.scalars(stmt).unique().all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_connector_credential_pairs_for_user_parallel(
user: User | None,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
return get_connector_credential_pairs_for_user(
db_session,
user,
get_editable,
ids,
eager_load_connector,
eager_load_credential,
eager_load_user,
)
def get_connector_credential_pairs(
@@ -151,6 +192,16 @@ def get_cc_pair_groups_for_ids(
return list(db_session.scalars(stmt).all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_cc_pair_groups_for_ids_parallel(
cc_pair_ids: list[int],
) -> list[UserGroup__ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
return get_cc_pair_groups_for_ids(db_session, cc_pair_ids)
def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,
@@ -194,9 +245,14 @@ def get_connector_credential_pair_from_id_for_user(
def get_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
eager_load_credential: bool = False,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
if eager_load_credential:
stmt = stmt.options(joinedload(ConnectorCredentialPair.credential))
result = db_session.execute(stmt)
return result.scalar_one_or_none()
@@ -396,8 +452,8 @@ def add_credential_to_connector(
# If we are in the seeding flow, we shouldn't need to check if the credential belongs to the user
if seeding_flow:
credential = fetch_credential_by_id(
db_session=db_session,
credential_id=credential_id,
db_session=db_session,
)
else:
credential = fetch_credential_by_id_for_user(

View File

@@ -169,8 +169,8 @@ def fetch_credential_by_id_for_user(
def fetch_credential_by_id(
db_session: Session,
credential_id: int,
db_session: Session,
) -> Credential | None:
stmt = select(Credential).distinct()
stmt = stmt.where(Credential.id == credential_id)
@@ -422,8 +422,8 @@ def create_initial_public_credential(db_session: Session) -> None:
"There must exist an empty public credential for data connectors that do not require additional Auth."
)
first_credential = fetch_credential_by_id(
db_session=db_session,
credential_id=PUBLIC_CREDENTIAL_ID,
db_session=db_session,
)
if first_credential is not None:

View File

@@ -24,6 +24,7 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
@@ -60,9 +61,8 @@ def count_documents_by_needs_sync(session: Session) -> int:
This function executes the query and returns the count of
documents matching the criteria."""
count = (
session.query(func.count(DbDocument.id.distinct()))
.select_from(DbDocument)
return (
session.query(DbDocument.id)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
@@ -73,63 +73,53 @@ def count_documents_by_needs_sync(session: Session) -> int:
DbDocument.last_synced.is_(None),
)
)
.scalar()
.count()
)
return count
def construct_document_select_for_connector_credential_pair_by_needs_sync(
connector_id: int, credential_id: int
) -> Select:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = (
return (
select(DbDocument)
.where(
DbDocument.id.in_(initial_doc_ids_stmt),
or_(
DbDocument.last_modified
> DbDocument.last_synced, # last_modified is newer than last_synced
DbDocument.last_synced.is_(None), # never synced
),
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
or_(
DbDocument.last_modified > DbDocument.last_synced,
DbDocument.last_synced.is_(None),
),
)
)
.distinct()
)
return stmt
def construct_document_id_select_for_connector_credential_pair_by_needs_sync(
connector_id: int, credential_id: int
) -> Select:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = (
return (
select(DbDocument.id)
.where(
DbDocument.id.in_(initial_doc_ids_stmt),
or_(
DbDocument.last_modified
> DbDocument.last_synced, # last_modified is newer than last_synced
DbDocument.last_synced.is_(None), # never synced
),
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
or_(
DbDocument.last_modified > DbDocument.last_synced,
DbDocument.last_synced.is_(None),
),
)
)
.distinct()
)
return stmt
def get_all_documents_needing_vespa_sync_for_cc_pair(
db_session: Session, cc_pair_id: int
@@ -240,12 +230,12 @@ def get_document_connector_counts(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
db_session: Session, cc_pairs: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
# Prepare a list of (connector_id, credential_id) tuples
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pair_identifiers]
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
stmt = (
select(
@@ -271,6 +261,16 @@ def get_document_counts_for_cc_pairs(
return db_session.execute(stmt).all() # type: ignore
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_document_counts_for_cc_pairs_parallel(
cc_pairs: list[ConnectorCredentialPairIdentifier],
) -> Sequence[tuple[int, int, int]]:
with get_session_context_manager() as db_session:
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
def get_access_info_for_document(
db_session: Session,
document_id: str,

View File

@@ -218,6 +218,7 @@ class SqlEngine:
final_engine_kwargs.update(engine_kwargs)
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
# echo=True here for inspecting all emitted db queries
engine = create_engine(connection_string, **final_engine_kwargs)
if USE_IAM_AUTH:

View File

@@ -2,6 +2,7 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import TypeVarTuple
from sqlalchemy import and_
from sqlalchemy import delete
@@ -9,9 +10,13 @@ from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_context_manager
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
@@ -368,19 +373,33 @@ def get_latest_index_attempts_by_status(
return db_session.execute(stmt).scalars().all()
T = TypeVarTuple("T")
def _add_only_finished_clause(stmt: Select[tuple[*T]]) -> Select[tuple[*T]]:
return stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
def get_latest_index_attempts(
secondary_index: bool,
db_session: Session,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
ids_stmt = select(
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.id).label("max_id"),
).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
if secondary_index:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.FUTURE)
else:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.PRESENT)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
ids_stmt = ids_stmt.where(SearchSettings.status == status)
if only_finished:
ids_stmt = _add_only_finished_clause(ids_stmt)
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
ids_subquery = ids_stmt.subquery()
@@ -395,7 +414,53 @@ def get_latest_index_attempts(
.where(IndexAttempt.id == ids_subquery.c.max_id)
)
return db_session.execute(stmt).scalars().all()
if only_finished:
stmt = _add_only_finished_clause(stmt)
if eager_load_cc_pair:
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
return db_session.execute(stmt).scalars().unique().all()
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_latest_index_attempts_parallel(
secondary_index: bool,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
with get_session_context_manager() as db_session:
return get_latest_index_attempts(
secondary_index,
db_session,
eager_load_cc_pair,
only_finished,
)
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
)
if only_finished:
stmt = _add_only_finished_clause(stmt)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
stmt = stmt.join(SearchSettings).where(SearchSettings.status == status)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
def count_index_attempts_for_connector(
@@ -453,37 +518,12 @@ def get_paginated_index_attempts_for_cc_pair_id(
# Apply pagination
stmt = stmt.offset(page * page_size).limit(page_size)
return list(db_session.execute(stmt).scalars().all())
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
stmt = stmt.options(
contains_eager(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
if only_finished:
stmt = stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
if secondary_index:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.FUTURE
)
else:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
return list(db_session.execute(stmt).scalars().unique().all())
def get_index_attempts_for_cc_pair(

View File

@@ -570,6 +570,14 @@ class Document(Base):
back_populates="documents",
)
__table_args__ = (
Index(
"ix_document_sync_status",
last_modified,
last_synced,
),
)
class Tag(Base):
__tablename__ = "tag"

View File

@@ -8,7 +8,7 @@ import requests
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import FileDescriptor
@@ -53,11 +53,11 @@ def load_all_chat_files(
return files
def save_file_from_url(url: str, tenant_id: str) -> str:
def save_file_from_url(url: str) -> str:
"""NOTE: using multiple sessions here, since this is often called
using multithreading. In practice, sharing a session has resulted in
weird errors."""
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
response = requests.get(url)
response.raise_for_status()
@@ -75,8 +75,8 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
return unique_id
def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def save_file_from_base64(base64_string: str) -> str:
with get_session_with_current_tenant() as db_session:
unique_id = str(uuid4())
file_store = get_default_file_store(db_session)
file_store.save_file(
@@ -90,14 +90,12 @@ def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
def save_file(
tenant_id: str,
url: str | None = None,
base64_data: str | None = None,
) -> str:
"""Save a file from either a URL or base64 encoded string.
Args:
tenant_id: The tenant ID to save the file under
url: URL to download file from
base64_data: Base64 encoded file data
@@ -111,22 +109,22 @@ def save_file(
raise ValueError("Cannot specify both url and base64_data")
if url is not None:
return save_file_from_url(url, tenant_id)
return save_file_from_url(url)
elif base64_data is not None:
return save_file_from_base64(base64_data, tenant_id)
return save_file_from_base64(base64_data)
else:
raise ValueError("Must specify either url or base64_data")
def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]:
def save_files(urls: list[str], base64_files: list[str]) -> list[str]:
# NOTE: be explicit about typing so that if we change things, we get notified
funcs: list[
tuple[
Callable[[str, str | None, str | None], str],
tuple[str, str | None, str | None],
Callable[[str | None, str | None], str],
tuple[str | None, str | None],
]
] = [(save_file, (tenant_id, url, None)) for url in urls] + [
(save_file, (tenant_id, None, base64_file)) for base64_file in base64_files
] = [(save_file, (url, None)) for url in urls] + [
(save_file, (None, base64_file)) for base64_file in base64_files
]
return run_functions_tuples_in_parallel(funcs)

View File

@@ -173,7 +173,10 @@ def index_doc_batch_with_handler(
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(f"Failed to index document batch: {document_batch}")
# don't log the batch directly, it's too much text
document_ids = [doc.id for doc in document_batch]
logger.exception(f"Failed to index document batch: {document_ids}")
index_pipeline_result = IndexingPipelineResult(
new_docs=0,
total_docs=len(document_batch),

View File

@@ -103,7 +103,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
api_version_required=False,
custom_config_keys=[],
llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME),
default_model="claude-3-5-sonnet-20241022",
default_model="claude-3-7-sonnet-20250219",
default_fast_model="claude-3-5-sonnet-20241022",
),
WellKnownLLMProviderDescriptor(

View File

@@ -17,10 +17,12 @@ from prometheus_client import Gauge
from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.chat.models import ThreadMessage
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import POD_NAME
@@ -249,7 +251,12 @@ class SlackbotHandler:
- If yes, store them in self.tenant_ids and manage the socket connections.
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
"""
all_tenants = get_all_tenant_ids()
all_tenants = [
tenant_id
for tenant_id in get_all_tenant_ids()
if tenant_id not in get_gated_tenants()
]
token: Token[str | None]
@@ -416,6 +423,7 @@ class SlackbotHandler:
try:
bot_info = socket_client.web_client.auth_test()
if bot_info["ok"]:
bot_user_id = bot_info["user_id"]
user_info = socket_client.web_client.users_info(user=bot_user_id)
@@ -426,9 +434,23 @@ class SlackbotHandler:
logger.info(
f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
)
except SlackApiError as e:
# Only error out if we get a not_authed error
if "not_authed" in str(e):
self.tenant_ids.add(tenant_id)
logger.error(
f"Authentication error: Invalid or expired credentials for tenant: {tenant_id}, app: {slack_bot_id}. "
"Error: {e}"
)
return
# Log other Slack API errors but continue
logger.error(
f"Slack API error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
)
except Exception as e:
logger.warning(
f"Could not fetch bot name: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
# Log other exceptions but continue
logger.error(
f"Error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
)
# Append the event handler

View File

@@ -18,6 +18,7 @@ Remember to provide inline citations in the format [1], [2], [3], etc.
ADDITIONAL_INFO = "\n\nAdditional Information:\n\t- {datetime_info}."
CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "
CHAT_USER_PROMPT = f"""
Refer to the following context documents when responding to me.{{optional_ignore_statement}}

View File

@@ -33,6 +33,12 @@ class RedisConnectorDelete:
FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence"
TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset"
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
@@ -41,6 +47,8 @@ class RedisConnectorDelete:
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -77,6 +85,20 @@ class RedisConnectorDelete:
self.redis.set(self.fence_key, payload.model_dump_json())
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def set_active(self) -> None:
"""This sets a signal to keep the permissioning flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
def _generate_task_id(self) -> str:
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
@@ -141,6 +163,7 @@ class RedisConnectorDelete:
def reset(self) -> None:
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
self.redis.delete(self.active_key)
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@@ -153,6 +176,9 @@ class RedisConnectorDelete:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorDelete.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDelete.TASKSET_PREFIX + "*"):
r.delete(key)

View File

@@ -93,10 +93,7 @@ class RedisConnectorIndex:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorIndexPayload | None:
@@ -106,9 +103,7 @@ class RedisConnectorIndex:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
return payload
return RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
def set_fence(
self,
@@ -123,10 +118,7 @@ class RedisConnectorIndex:
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def terminating(self, celery_task_id: str) -> bool:
if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
return True
return False
return bool(self.redis.exists(f"{self.terminate_key}_{celery_task_id}"))
def set_terminate(self, celery_task_id: str) -> None:
"""This sets a signal. It does not block!"""
@@ -146,10 +138,7 @@ class RedisConnectorIndex:
def watchdog_signaled(self) -> bool:
"""Check the state of the watchdog."""
if self.redis.exists(self.watchdog_key):
return True
return False
return bool(self.redis.exists(self.watchdog_key))
def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -160,10 +149,7 @@ class RedisConnectorIndex:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
def set_connector_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -180,10 +166,7 @@ class RedisConnectorIndex:
return False
def generator_locked(self) -> bool:
if self.redis.exists(self.generator_lock_key):
return True
return False
return bool(self.redis.exists(self.generator_lock_key))
def set_generator_complete(self, payload: int | None) -> None:
if not payload:

View File

@@ -5,7 +5,13 @@ class RedisConnectorStop:
"""Manages interactions with redis for stop signaling. Should only be accessed
through RedisConnector."""
FENCE_PREFIX = "connectorstop_fence"
PREFIX = "connectorstop"
FENCE_PREFIX = f"{PREFIX}_fence"
# if this timeout is exceeded, the caller may decide to take more
# drastic measures
TIMEOUT_PREFIX = f"{PREFIX}_timeout"
TIMEOUT_TTL = 300
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
@@ -13,6 +19,7 @@ class RedisConnectorStop:
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
self.timeout_key: str = f"{self.TIMEOUT_PREFIX}_{id}"
@property
def fenced(self) -> bool:
@@ -28,7 +35,22 @@ class RedisConnectorStop:
self.redis.set(self.fence_key, 0)
@property
def timed_out(self) -> bool:
if self.redis.exists(self.timeout_key):
return False
return True
def set_timeout(self) -> None:
"""After calling this, call timed_out to determine if the timeout has been
exceeded."""
self.redis.set(f"{self.timeout_key}", 0, ex=self.TIMEOUT_TTL)
@staticmethod
def reset_all(r: redis.Redis) -> None:
for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorStop.TIMEOUT_PREFIX + "*"):
r.delete(key)

View File

@@ -25,8 +25,8 @@ from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.exceptions import ValidationError
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.db.connector import delete_connector
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.connector_credential_pair import (
@@ -123,15 +123,15 @@ def get_cc_pair_full_info(
)
is_editable_for_current_user = editable_cc_pair is not None
cc_pair_identifier = ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
document_count_info_list = list(
get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=[cc_pair_identifier],
cc_pairs=[
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
],
)
)
documents_indexed = (
@@ -620,9 +620,7 @@ def associate_credential_to_connector(
)
try:
validate_ccpair_for_user(
connector_id, credential_id, db_session, user, tenant_id
)
validate_ccpair_for_user(connector_id, credential_id, db_session, tenant_id)
response = add_credential_to_connector(
db_session=db_session,
@@ -649,7 +647,7 @@ def associate_credential_to_connector(
return response
except ConnectorValidationError as e:
except ValidationError as e:
# If validation fails, delete the connector and commit the changes
# Ensures we don't leave invalid connectors in the database
# NOTE: consensus is that it makes sense to unify connector and ccpair creation flows
@@ -660,7 +658,6 @@ def associate_credential_to_connector(
raise HTTPException(
status_code=400, detail="Connector validation error: " + str(e)
)
except IntegrityError as e:
logger.error(f"IntegrityError: {e}")
raise HTTPException(status_code=400, detail="Name must be unique")

View File

@@ -28,6 +28,7 @@ from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.connectors.google_utils.google_auth import (
get_google_oauth_creds,
@@ -62,7 +63,6 @@ from onyx.connectors.google_utils.shared_constants import DB_CREDENTIALS_DICT_TO
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.db.connector import create_connector
from onyx.db.connector import delete_connector
from onyx.db.connector import fetch_connector_by_id
@@ -72,25 +72,31 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector import update_connector
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids_parallel
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
from onyx.db.connector_credential_pair import (
get_connector_credential_pairs_for_user_parallel,
)
from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import cleanup_google_drive_credentials
from onyx.db.credentials import create_credential
from onyx.db.credentials import delete_service_account_credentials
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed
from onyx.db.document import get_document_counts_for_cc_pairs
from onyx.db.document import get_document_counts_for_cc_pairs_parallel
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import IndexingMode
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.index_attempt import get_latest_index_attempts
from onyx.db.index_attempt import get_latest_index_attempts_by_status
from onyx.db.index_attempt import get_latest_index_attempts_parallel
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.file_processing.extract_file_text import convert_docx_to_txt
@@ -119,8 +125,8 @@ from onyx.server.documents.models import RunConnectorRequest
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -578,6 +584,8 @@ def get_connector_status(
cc_pairs = get_connector_credential_pairs_for_user(
db_session=db_session,
user=user,
eager_load_connector=True,
eager_load_credential=True,
)
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
@@ -632,23 +640,35 @@ def get_connector_indexing_status(
# Additional checks are done to make sure the connector and credential still exist.
# TODO: make this one query ... possibly eager load or wrap in a read transaction
# to avoid the complexity of trying to error check throughout the function
cc_pairs = get_connector_credential_pairs_for_user(
db_session=db_session,
user=user,
get_editable=get_editable,
)
cc_pair_identifiers = [
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
for cc_pair in cc_pairs
]
latest_index_attempts = get_latest_index_attempts(
secondary_index=secondary_index,
db_session=db_session,
# see https://stackoverflow.com/questions/75758327/
# sqlalchemy-method-connection-for-bind-is-already-in-progress
# for why we can't pass in the current db_session to these functions
(
cc_pairs,
latest_index_attempts,
latest_finished_index_attempts,
) = run_functions_tuples_in_parallel(
[
(
# Gets the connector/credential pairs for the user
get_connector_credential_pairs_for_user_parallel,
(user, get_editable, None, True, True, True),
),
(
# Gets the most recent index attempt for each connector/credential pair
get_latest_index_attempts_parallel,
(secondary_index, True, False),
),
(
# Gets the most recent FINISHED index attempt for each connector/credential pair
get_latest_index_attempts_parallel,
(secondary_index, True, True),
),
]
)
cc_pairs = cast(list[ConnectorCredentialPair], cc_pairs)
latest_index_attempts = cast(list[IndexAttempt], latest_index_attempts)
cc_pair_to_latest_index_attempt = {
(
@@ -658,31 +678,60 @@ def get_connector_indexing_status(
for index_attempt in latest_index_attempts
}
document_count_info = get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=cc_pair_identifiers,
cc_pair_to_latest_finished_index_attempt = {
(
index_attempt.connector_credential_pair.connector_id,
index_attempt.connector_credential_pair.credential_id,
): index_attempt
for index_attempt in latest_finished_index_attempts
}
document_count_info, group_cc_pair_relationships = run_functions_tuples_in_parallel(
[
(
get_document_counts_for_cc_pairs_parallel,
(
[
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
for cc_pair in cc_pairs
],
),
),
(
get_cc_pair_groups_for_ids_parallel,
([cc_pair.id for cc_pair in cc_pairs],),
),
]
)
document_count_info = cast(list[tuple[int, int, int]], document_count_info)
group_cc_pair_relationships = cast(
list[UserGroup__ConnectorCredentialPair], group_cc_pair_relationships
)
cc_pair_to_document_cnt = {
(connector_id, credential_id): cnt
for connector_id, credential_id, cnt in document_count_info
}
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
db_session=db_session,
cc_pair_ids=[cc_pair.id for cc_pair in cc_pairs],
)
group_cc_pair_relationships_dict: dict[int, list[int]] = {}
for relationship in group_cc_pair_relationships:
group_cc_pair_relationships_dict.setdefault(relationship.cc_pair_id, []).append(
relationship.user_group_id
)
search_settings: SearchSettings | None = None
if not secondary_index:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_secondary_search_settings(db_session)
connector_to_cc_pair_ids: dict[int, list[int]] = {}
for cc_pair in cc_pairs:
connector_to_cc_pair_ids.setdefault(cc_pair.connector_id, []).append(cc_pair.id)
get_search_settings = (
get_secondary_search_settings
if secondary_index
else get_current_search_settings
)
search_settings = get_search_settings(db_session)
for cc_pair in cc_pairs:
# TODO remove this to enable ingestion API
if cc_pair.name == "DefaultCCPair":
@@ -705,11 +754,8 @@ def get_connector_indexing_status(
(connector.id, credential.id)
)
latest_finished_attempt = get_latest_index_attempt_for_cc_pair_id(
db_session=db_session,
connector_credential_pair_id=cc_pair.id,
secondary_index=secondary_index,
only_finished=True,
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
(connector.id, credential.id)
)
indexing_statuses.append(
@@ -718,7 +764,9 @@ def get_connector_indexing_status(
name=cc_pair.name,
in_progress=in_progress,
cc_pair_status=cc_pair.status,
connector=ConnectorSnapshot.from_connector_db_model(connector),
connector=ConnectorSnapshot.from_connector_db_model(
connector, connector_to_cc_pair_ids.get(connector.id, [])
),
credential=CredentialSnapshot.from_credential_db_model(credential),
access_type=cc_pair.access_type,
owner=credential.user.email if credential.user else "",
@@ -854,7 +902,6 @@ def create_connector_with_mock_credential(
connector_id=connector_id,
credential_id=credential_id,
db_session=db_session,
user=user,
tenant_id=tenant_id,
)
response = add_credential_to_connector(

View File

@@ -106,7 +106,6 @@ def swap_credentials_for_connector(
credential_swap_req.connector_id,
credential_swap_req.new_credential_id,
db_session,
user,
tenant_id,
)

View File

@@ -83,7 +83,9 @@ class ConnectorSnapshot(ConnectorBase):
source: DocumentSource
@classmethod
def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
def from_connector_db_model(
cls, connector: Connector, credential_ids: list[int] | None = None
) -> "ConnectorSnapshot":
return ConnectorSnapshot(
id=connector.id,
name=connector.name,
@@ -92,9 +94,10 @@ class ConnectorSnapshot(ConnectorBase):
connector_specific_config=connector.connector_specific_config,
refresh_freq=connector.refresh_freq,
prune_freq=connector.prune_freq,
credential_ids=[
association.credential.id for association in connector.credentials
],
credential_ids=(
credential_ids
or [association.credential.id for association in connector.credentials]
),
indexing_start=connector.indexing_start,
time_created=connector.time_created,
time_updated=connector.time_updated,

View File

@@ -1,5 +1,3 @@
from typing import Any
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
@@ -345,6 +343,9 @@ def list_bot_configs(
]
MAX_CHANNELS = 200
@router.get(
"/admin/slack-app/bots/{bot_id}/channels",
)
@@ -353,38 +354,40 @@ def get_all_channels_from_slack_api(
db_session: Session = Depends(get_session),
_: User | None = Depends(current_admin_user),
) -> list[SlackChannel]:
"""
Fetches all channels from the Slack API.
If the workspace has 200 or more channels, we raise an error.
"""
tokens = fetch_slack_bot_tokens(db_session, bot_id)
if not tokens or "bot_token" not in tokens:
raise HTTPException(
status_code=404, detail="Bot token not found for the given bot ID"
)
bot_token = tokens["bot_token"]
client = WebClient(token=bot_token)
client = WebClient(token=tokens["bot_token"])
try:
channels = []
cursor = None
while True:
response = client.conversations_list(
types="public_channel,private_channel",
exclude_archived=True,
limit=1000,
cursor=cursor,
)
for channel in response["channels"]:
channels.append(SlackChannel(id=channel["id"], name=channel["name"]))
response = client.conversations_list(
types="public_channel,private_channel",
exclude_archived=True,
limit=MAX_CHANNELS,
)
response_metadata: dict[str, Any] = response.get("response_metadata", {})
if isinstance(response_metadata, dict):
cursor = response_metadata.get("next_cursor")
if not cursor:
break
else:
break
channels = [
SlackChannel(id=channel["id"], name=channel["name"])
for channel in response["channels"]
]
if len(channels) == MAX_CHANNELS:
raise HTTPException(
status_code=400,
detail=f"Workspace has {MAX_CHANNELS} or more channels.",
)
return channels
except SlackApiError as e:
raise HTTPException(
status_code=500, detail=f"Error fetching channels from Slack API: {str(e)}"
status_code=500,
detail=f"Error fetching channels from Slack API: {str(e)}",
)

View File

@@ -311,19 +311,23 @@ def bulk_invite_users(
all_emails = list(set(new_invited_emails) | set(initial_invited_users))
number_of_invited_users = write_invited_users(all_emails)
# send out email invitations if enabled
if ENABLE_EMAIL_INVITES:
try:
for email in new_invited_emails:
send_user_email_invite(email, current_user, AUTH_TYPE)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
if not MULTI_TENANT:
return number_of_invited_users
# for billing purposes, write to the control plane about the number of new users
try:
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_total_users_count(db_session))
if ENABLE_EMAIL_INVITES:
try:
for email in new_invited_emails:
send_user_email_invite(email, current_user)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
return number_of_invited_users
except Exception as e:

View File

@@ -1,15 +1,18 @@
import asyncio
import datetime
import io
import json
import os
import uuid
from collections.abc import Callable
from collections.abc import Generator
from datetime import timedelta
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
@@ -44,6 +47,7 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.feedback import create_chat_message_feedback
@@ -65,10 +69,13 @@ from onyx.secondary_llm_flows.chat_session_naming import (
from onyx.server.query_and_chat.models import ChatFeedbackRequest
from onyx.server.query_and_chat.models import ChatMessageIdentifier
from onyx.server.query_and_chat.models import ChatRenameRequest
from onyx.server.query_and_chat.models import ChatSearchResponse
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import ChatSessionDetailResponse
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionGroup
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.server.query_and_chat.models import ChatSessionSummary
from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import CreateChatSessionID
@@ -794,3 +801,84 @@ def fetch_chat_file(
file_io = file_store.read_file(file_id, mode="b")
return StreamingResponse(file_io, media_type=media_type)
@router.get("/search")
async def search_chats(
query: str | None = Query(None),
page: int = Query(1),
page_size: int = Query(10),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatSearchResponse:
"""
Search for chat sessions based on the provided query.
If no query is provided, returns recent chat sessions.
"""
# Use the enhanced database function for chat search
chat_sessions, has_more = search_chat_sessions(
user_id=user.id if user else None,
db_session=db_session,
query=query,
page=page,
page_size=page_size,
include_deleted=False,
include_onyxbot_flows=False,
)
# Group chat sessions by time period
today = datetime.datetime.now().date()
yesterday = today - timedelta(days=1)
this_week = today - timedelta(days=7)
this_month = today - timedelta(days=30)
today_chats: list[ChatSessionSummary] = []
yesterday_chats: list[ChatSessionSummary] = []
this_week_chats: list[ChatSessionSummary] = []
this_month_chats: list[ChatSessionSummary] = []
older_chats: list[ChatSessionSummary] = []
for session in chat_sessions:
session_date = session.time_created.date()
chat_summary = ChatSessionSummary(
id=session.id,
name=session.description,
persona_id=session.persona_id,
time_created=session.time_created,
shared_status=session.shared_status,
folder_id=session.folder_id,
current_alternate_model=session.current_alternate_model,
current_temperature_override=session.temperature_override,
)
if session_date == today:
today_chats.append(chat_summary)
elif session_date == yesterday:
yesterday_chats.append(chat_summary)
elif session_date > this_week:
this_week_chats.append(chat_summary)
elif session_date > this_month:
this_month_chats.append(chat_summary)
else:
older_chats.append(chat_summary)
# Create groups
groups = []
if today_chats:
groups.append(ChatSessionGroup(title="Today", chats=today_chats))
if yesterday_chats:
groups.append(ChatSessionGroup(title="Yesterday", chats=yesterday_chats))
if this_week_chats:
groups.append(ChatSessionGroup(title="This Week", chats=this_week_chats))
if this_month_chats:
groups.append(ChatSessionGroup(title="This Month", chats=this_month_chats))
if older_chats:
groups.append(ChatSessionGroup(title="Older", chats=older_chats))
return ChatSearchResponse(
groups=groups,
has_more=has_more,
next_page=page + 1 if has_more else None,
)

View File

@@ -24,6 +24,7 @@ from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.tools.models import ToolCallFinalResult
if TYPE_CHECKING:
pass
@@ -282,3 +283,35 @@ class AdminSearchRequest(BaseModel):
class AdminSearchResponse(BaseModel):
documents: list[SearchDoc]
class ChatSessionSummary(BaseModel):
id: UUID
name: str | None = None
persona_id: int | None = None
time_created: datetime
shared_status: ChatSessionSharedStatus
folder_id: int | None = None
current_alternate_model: str | None = None
current_temperature_override: float | None = None
class ChatSessionGroup(BaseModel):
title: str
chats: list[ChatSessionSummary]
class ChatSearchResponse(BaseModel):
groups: list[ChatSessionGroup]
has_more: bool
next_page: int | None = None
class ChatSearchRequest(BaseModel):
query: str | None = None
page: int = 1
page_size: int = 10
class CreateChatResponse(BaseModel):
chat_session_id: str

View File

@@ -13,7 +13,6 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accesssible_user
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit
@@ -21,7 +20,6 @@ from onyx.db.models import User
from onyx.db.token_limit import fetch_all_global_token_rate_limits
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -39,13 +37,13 @@ def check_token_rate_limits(
return
versioned_rate_limit_strategy = fetch_versioned_implementation(
"onyx.server.query_and_chat.token_limit", "_check_token_rate_limits"
"onyx.server.query_and_chat.token_limit", _check_token_rate_limits.__name__
)
return versioned_rate_limit_strategy(user, get_current_tenant_id())
return versioned_rate_limit_strategy(user)
def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
_user_is_rate_limited_by_global(tenant_id)
def _check_token_rate_limits(_: User | None) -> None:
_user_is_rate_limited_by_global()
"""
@@ -53,8 +51,8 @@ Global rate limits
"""
def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def _user_is_rate_limited_by_global() -> None:
with get_session_context_manager() as db_session:
global_rate_limits = fetch_all_global_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)

View File

@@ -47,8 +47,8 @@ class Settings(BaseModel):
anonymous_user_enabled: bool | None = None
pro_search_enabled: bool | None = None
temperature_override_enabled: bool = False
auto_scroll: bool = False
temperature_override_enabled: bool | None = False
auto_scroll: bool | None = False
class UserSettings(Settings):

View File

@@ -253,3 +253,8 @@ def print_loggers() -> None:
print(f" Propagate: {logger.propagate}")
print()
def format_error_for_logging(e: Exception) -> str:
"""Clean error message by removing newlines for better logging."""
return str(e).replace("\n", " ")

View File

@@ -1,3 +1,4 @@
import contextvars
import threading
import uuid
from collections.abc import Callable
@@ -14,10 +15,6 @@ logger = setup_logger()
R = TypeVar("R")
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_tuples_in_parallel(
functions_with_args: list[tuple[Callable, tuple]],
allow_failures: bool = False,
@@ -45,8 +42,11 @@ def run_functions_tuples_in_parallel(
results = []
with ThreadPoolExecutor(max_workers=workers) as executor:
# The primary reason for propagating contextvars is to allow acquiring a db session
# that respects tenant id. Context.run is expected to be low-overhead, but if we later
# find that it is increasing latency we can make using it optional.
future_to_index = {
executor.submit(func, *args): i
executor.submit(contextvars.copy_context().run, func, *args): i
for i, (func, args) in enumerate(functions_with_args)
}
@@ -83,10 +83,6 @@ class FunctionCall(Generic[R]):
return self.func(*self.args, **self.kwargs)
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_in_parallel(
function_calls: list[FunctionCall],
allow_failures: bool = False,
@@ -102,7 +98,9 @@ def run_functions_in_parallel(
with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
future_to_id = {
executor.submit(func_call.execute): func_call.result_id
executor.submit(
contextvars.copy_context().run, func_call.execute
): func_call.result_id
for func_call in function_calls
}
@@ -143,10 +141,6 @@ class TimeoutThread(threading.Thread):
)
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_with_timeout(
timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
) -> R:
@@ -154,7 +148,8 @@ def run_with_timeout(
Executes a function with a timeout. If the function doesn't complete within the specified
timeout, raises TimeoutError.
"""
task = TimeoutThread(timeout, func, *args, **kwargs)
context = contextvars.copy_context()
task = TimeoutThread(timeout, context.run, func, *args, **kwargs)
task.start()
task.join(timeout)

View File

@@ -37,7 +37,7 @@ langchainhub==0.1.21
langgraph==0.2.72
langgraph-checkpoint==2.0.13
langgraph-sdk==0.1.44
litellm==1.60.2
litellm==1.61.16
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45

View File

@@ -12,5 +12,5 @@ torch==2.2.0
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.60.2
litellm==1.61.16
sentry-sdk[fastapi,celery,starlette]==2.14.0

View File

@@ -3,6 +3,7 @@ import json
import logging
import sys
import time
from enum import Enum
from logging import getLogger
from typing import cast
from uuid import UUID
@@ -20,10 +21,13 @@ from onyx.configs.app_configs import REDIS_PORT
from onyx.configs.app_configs import REDIS_SSL
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_pool import RedisPool
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
# Tool to run helpful operations on Redis in production
# This is targeted for internal usage and may not have all the necessary parameters
@@ -42,6 +46,19 @@ SCAN_ITER_COUNT = 10000
BATCH_DEFAULT = 1000
class OnyxRedisCommand(Enum):
purge_connectorsync_taskset = "purge_connectorsync_taskset"
purge_documentset_taskset = "purge_documentset_taskset"
purge_usergroup_taskset = "purge_usergroup_taskset"
purge_locks_blocking_deletion = "purge_locks_blocking_deletion"
purge_vespa_syncing = "purge_vespa_syncing"
get_user_token = "get_user_token"
delete_user_token = "delete_user_token"
def __str__(self) -> str:
return self.value
def get_user_id(user_email: str) -> tuple[UUID, str]:
tenant_id = (
get_tenant_id_for_email(user_email) if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
@@ -55,50 +72,79 @@ def get_user_id(user_email: str) -> tuple[UUID, str]:
def onyx_redis(
command: str,
command: OnyxRedisCommand,
batch: int,
dry_run: bool,
ssl: bool,
host: str,
port: int,
db: int,
password: str | None,
user_email: str | None = None,
cc_pair_id: int | None = None,
) -> int:
# this is global and not tenant aware
pool = RedisPool.create_pool(
host=host,
port=port,
db=db,
password=password if password else "",
ssl=REDIS_SSL,
ssl=ssl,
ssl_cert_reqs="optional",
ssl_ca_certs=None,
)
r = Redis(connection_pool=pool)
logger.info("Redis ping starting. This may hang if your settings are incorrect.")
try:
r.ping()
except:
logger.exception("Redis ping exceptioned")
raise
if command == "purge_connectorsync_taskset":
logger.info("Redis ping succeeded.")
if command == OnyxRedisCommand.purge_connectorsync_taskset:
"""Purge connector tasksets. Used when the tasks represented in the tasksets
have been purged."""
return purge_by_match_and_type(
"*connectorsync_taskset*", "set", batch, dry_run, r
)
elif command == "purge_documentset_taskset":
elif command == OnyxRedisCommand.purge_documentset_taskset:
return purge_by_match_and_type(
"*documentset_taskset*", "set", batch, dry_run, r
)
elif command == "purge_usergroup_taskset":
elif command == OnyxRedisCommand.purge_usergroup_taskset:
return purge_by_match_and_type("*usergroup_taskset*", "set", batch, dry_run, r)
elif command == "purge_vespa_syncing":
elif command == OnyxRedisCommand.purge_locks_blocking_deletion:
if cc_pair_id is None:
logger.error("You must specify --cc-pair with purge_deletion_locks")
return 1
tenant_id = get_current_tenant_id()
logger.info(f"Purging locks associated with deleting cc_pair={cc_pair_id}.")
redis_connector = RedisConnector(tenant_id, cc_pair_id)
match_pattern = f"{tenant_id}:{RedisConnectorIndex.FENCE_PREFIX}_{cc_pair_id}/*"
purge_by_match_and_type(match_pattern, "string", batch, dry_run, r)
redis_delete_if_exists_helper(
f"{tenant_id}:{redis_connector.prune.fence_key}", dry_run, r
)
redis_delete_if_exists_helper(
f"{tenant_id}:{redis_connector.permissions.fence_key}", dry_run, r
)
redis_delete_if_exists_helper(
f"{tenant_id}:{redis_connector.external_group_sync.fence_key}", dry_run, r
)
return 0
elif command == OnyxRedisCommand.purge_vespa_syncing:
return purge_by_match_and_type(
"*connectorsync:vespa_syncing*", "string", batch, dry_run, r
)
elif command == "get_user_token":
elif command == OnyxRedisCommand.get_user_token:
if not user_email:
logger.error("You must specify --user-email with get_user_token")
return 1
@@ -109,7 +155,7 @@ def onyx_redis(
else:
print(f"No token found for user {user_email}")
return 2
elif command == "delete_user_token":
elif command == OnyxRedisCommand.delete_user_token:
if not user_email:
logger.error("You must specify --user-email with delete_user_token")
return 1
@@ -131,6 +177,25 @@ def flush_batch_delete(batch_keys: list[bytes], r: Redis) -> None:
pipe.execute()
def redis_delete_if_exists_helper(key: str, dry_run: bool, r: Redis) -> bool:
"""Returns True if the key was found, False if not.
This function exists for logging purposes as the delete operation itself
doesn't really need to check the existence of the key.
"""
if not r.exists(key):
logger.info(f"Did not find {key}.")
return False
if dry_run:
logger.info(f"(DRY-RUN) Deleting {key}.")
else:
logger.info(f"Deleting {key}.")
r.delete(key)
return True
def purge_by_match_and_type(
match_pattern: str, match_type: str, batch_size: int, dry_run: bool, r: Redis
) -> int:
@@ -138,6 +203,12 @@ def purge_by_match_and_type(
match_type: https://redis.io/docs/latest/commands/type/
"""
logger.info(
f"purge_by_match_and_type start: "
f"match_pattern={match_pattern} "
f"match_type={match_type}"
)
# cursor = "0"
# while cursor != 0:
# cursor, data = self.scan(
@@ -164,13 +235,15 @@ def purge_by_match_and_type(
logger.info(f"Deleting item {count}: {key_str}")
batch_keys.append(key)
# flush if batch size has been reached
if len(batch_keys) >= batch_size:
flush_batch_delete(batch_keys, r)
batch_keys.clear()
if len(batch_keys) >= batch_size:
flush_batch_delete(batch_keys, r)
batch_keys.clear()
# final flush
flush_batch_delete(batch_keys, r)
batch_keys.clear()
logger.info(f"Deleted {count} matches.")
@@ -279,7 +352,21 @@ def delete_user_token_from_redis(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Onyx Redis Manager")
parser.add_argument("--command", type=str, help="Operation to run", required=True)
parser.add_argument(
"--command",
type=OnyxRedisCommand,
help="The command to run",
choices=list(OnyxRedisCommand),
required=True,
)
parser.add_argument(
"--ssl",
type=bool,
default=REDIS_SSL,
help="Use SSL when connecting to Redis. Usually True for prod and False for local testing",
required=False,
)
parser.add_argument(
"--host",
@@ -342,6 +429,13 @@ if __name__ == "__main__":
required=False,
)
parser.add_argument(
"--cc-pair",
type=int,
help="A connector credential pair id. Used with the purge_deletion_locks command.",
required=False,
)
args = parser.parse_args()
if args.tenant_id:
@@ -368,10 +462,12 @@ if __name__ == "__main__":
command=args.command,
batch=args.batch,
dry_run=args.dry_run,
ssl=args.ssl,
host=args.host,
port=args.port,
db=args.db,
password=args.password,
user_email=args.user_email,
cc_pair_id=args.cc_pair,
)
sys.exit(exitcode)

View File

@@ -508,6 +508,7 @@ def get_number_of_chunks_we_think_exist(
class VespaDebugging:
# Class for managing Vespa debugging actions.
def __init__(self, tenant_id: str | None = None):
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
self.tenant_id = POSTGRES_DEFAULT_SCHEMA if not tenant_id else tenant_id
self.index_name = get_index_name(self.tenant_id)

View File

@@ -10,7 +10,8 @@ from onyx.connectors.onyx_jira.connector import JiraConnector
@pytest.fixture
def jira_connector() -> JiraConnector:
connector = JiraConnector(
"https://danswerai.atlassian.net/jira/software/c/projects/AS/boards/6",
jira_base_url="https://danswerai.atlassian.net",
project_key="AS",
comment_email_blacklist=[],
)
connector.load_credentials(

View File

@@ -4,6 +4,10 @@ from onyx.connectors.models import Document
from onyx.connectors.web.connector import WEB_CONNECTOR_VALID_SETTINGS
from onyx.connectors.web.connector import WebConnector
EXPECTED_QUOTE = (
"If you can't explain it to a six year old, you don't understand it yourself."
)
# NOTE(rkuo): we will probably need to adjust this test to point at our own test site
# to avoid depending on a third party site
@@ -11,7 +15,7 @@ from onyx.connectors.web.connector import WebConnector
def web_connector(request: pytest.FixtureRequest) -> WebConnector:
scroll_before_scraping = request.param
connector = WebConnector(
base_url="https://developer.onewelcome.com",
base_url="https://quotes.toscrape.com/scroll",
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
scroll_before_scraping=scroll_before_scraping,
)
@@ -28,7 +32,7 @@ def test_web_connector_scroll(web_connector: WebConnector) -> None:
assert len(all_docs) == 1
doc = all_docs[0]
assert "Onegini Identity Cloud" in doc.sections[0].text
assert EXPECTED_QUOTE in doc.sections[0].text
@pytest.mark.parametrize("web_connector", [False], indirect=True)
@@ -41,4 +45,4 @@ def test_web_connector_no_scroll(web_connector: WebConnector) -> None:
assert len(all_docs) == 1
doc = all_docs[0]
assert "Onegini Identity Cloud" not in doc.sections[0].text
assert EXPECTED_QUOTE not in doc.sections[0].text

View File

@@ -71,12 +71,13 @@ def litellm_embedding_model() -> EmbeddingModel:
normalize=True,
query_prefix=None,
passage_prefix=None,
api_key=os.getenv("LITE_LLM_API_KEY"),
api_key=os.getenv("LITELLM_API_KEY"),
provider_type=EmbeddingProvider.LITELLM,
api_url=os.getenv("LITE_LLM_API_URL"),
api_url=os.getenv("LITELLM_API_URL"),
)
@pytest.mark.skip(reason="re-enable when we can get the correct litellm key and url")
def test_litellm_embedding(litellm_embedding_model: EmbeddingModel) -> None:
_run_embeddings(VALID_SAMPLE, litellm_embedding_model, 1536)
_run_embeddings(TOO_LONG_SAMPLE, litellm_embedding_model, 1536)
@@ -117,6 +118,11 @@ def azure_embedding_model() -> EmbeddingModel:
)
def test_azure_embedding(azure_embedding_model: EmbeddingModel) -> None:
_run_embeddings(VALID_SAMPLE, azure_embedding_model, 1536)
_run_embeddings(TOO_LONG_SAMPLE, azure_embedding_model, 1536)
# NOTE (chris): this test doesn't work, and I do not know why
# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
# """NOTE: this test relies on a very low rate limit for the Azure API +

View File

@@ -30,8 +30,10 @@ class ConnectorManager:
name=name,
source=source,
input_type=input_type,
connector_specific_config=connector_specific_config
or {"file_locations": []},
connector_specific_config=(
connector_specific_config
or ({"file_locations": []} if source == DocumentSource.FILE else {})
),
access_type=access_type,
groups=groups or [],
)

View File

@@ -88,8 +88,6 @@ class UserManager:
if not session_cookie:
raise Exception("Failed to login")
print(f"Logged in as {test_user.email}")
# Set cookies in the headers
test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
test_user.cookies = {"fastapiusersauth": session_cookie}

View File

@@ -70,7 +70,7 @@ def _answer_fixture_impl(
files=[],
single_message_history=None,
),
system_message=default_build_system_message(prompt_config),
system_message=default_build_system_message(prompt_config, mock_llm.config),
message_history=[],
llm_config=mock_llm.config,
raw_user_query=QUERY,

View File

@@ -0,0 +1,131 @@
import contextvars
import time
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import run_with_timeout
# Create a test contextvar
test_var = contextvars.ContextVar("test_var", default="default")
def get_contextvar_value() -> str:
"""Helper function that runs in a thread and returns the contextvar value"""
# Add a small sleep to ensure we're actually running in a different thread
time.sleep(0.1)
return test_var.get()
def test_run_with_timeout_preserves_contextvar() -> None:
"""Test that run_with_timeout preserves contextvar values"""
# Set a value in the main thread
test_var.set("test_value")
# Run function with timeout and verify the value is preserved
result = run_with_timeout(1.0, get_contextvar_value)
assert result == "test_value"
def test_run_functions_in_parallel_preserves_contextvar() -> None:
"""Test that run_functions_in_parallel preserves contextvar values"""
# Set a value in the main thread
test_var.set("parallel_test")
# Create multiple function calls
function_calls = [
FunctionCall(get_contextvar_value),
FunctionCall(get_contextvar_value),
]
# Run in parallel and verify all results have the correct value
results = run_functions_in_parallel(function_calls)
for result_id, value in results.items():
assert value == "parallel_test"
def test_run_functions_tuples_preserves_contextvar() -> None:
"""Test that run_functions_tuples_in_parallel preserves contextvar values"""
# Set a value in the main thread
test_var.set("tuple_test")
# Create list of function tuples
functions_with_args = [
(get_contextvar_value, ()),
(get_contextvar_value, ()),
]
# Run in parallel and verify all results have the correct value
results = run_functions_tuples_in_parallel(functions_with_args)
for result in results:
assert result == "tuple_test"
def test_nested_contextvar_modifications() -> None:
"""Test that modifications to contextvars in threads don't affect other threads"""
def modify_and_return_contextvar(new_value: str) -> tuple[str, str]:
"""Helper that modifies the contextvar and returns both values"""
original = test_var.get()
test_var.set(new_value)
time.sleep(0.1) # Ensure threads overlap
return original, test_var.get()
# Set initial value
test_var.set("initial")
# Run multiple functions that modify the contextvar
functions_with_args = [
(modify_and_return_contextvar, ("thread1",)),
(modify_and_return_contextvar, ("thread2",)),
]
results = run_functions_tuples_in_parallel(functions_with_args)
# Verify each thread saw the initial value and its own modification
for original, modified in results:
assert original == "initial" # Each thread should see the initial value
assert modified in [
"thread1",
"thread2",
] # Each thread should see its own modification
# Verify the main thread's value wasn't affected
assert test_var.get() == "initial"
def test_contextvar_isolation_between_runs() -> None:
"""Test that contextvar changes don't leak between separate parallel runs"""
def set_and_return_contextvar(value: str) -> str:
test_var.set(value)
return test_var.get()
# First run
test_var.set("first_run")
first_results = run_functions_tuples_in_parallel(
[
(set_and_return_contextvar, ("thread1",)),
(set_and_return_contextvar, ("thread2",)),
]
)
# Verify first run results
assert all(result in ["thread1", "thread2"] for result in first_results)
# Second run should still see the main thread's value
assert test_var.get() == "first_run"
# Second run with different value
test_var.set("second_run")
second_results = run_functions_tuples_in_parallel(
[
(set_and_return_contextvar, ("thread3",)),
(set_and_return_contextvar, ("thread4",)),
]
)
# Verify second run results
assert all(result in ["thread3", "thread4"] for result in second_results)

View File

@@ -1,5 +1,5 @@
# fill in the template
envsubst '$SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME' < "/etc/nginx/conf.d/$1" > /etc/nginx/conf.d/app.conf
envsubst '$DOMAIN $SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME' < "/etc/nginx/conf.d/$1" > /etc/nginx/conf.d/app.conf
# wait for the api_server to be ready
echo "Waiting for API server to boot up; this may take a minute or two..."

View File

@@ -36,6 +36,7 @@ services:
- OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
- TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-}
- CORS_ALLOWED_ORIGIN=${CORS_ALLOWED_ORIGIN:-}
- INTEGRATION_TESTS_MODE=${INTEGRATION_TESTS_MODE:-}
# Gen AI Settings
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}

View File

@@ -0,0 +1,37 @@
services:
indexing_model_server:
image: onyxdotapp/onyx-model-server:${IMAGE_TAG:-latest}
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- INDEX_BATCH_SIZE=${INDEX_BATCH_SIZE:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- INDEXING_ONLY=True
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
- CLIENT_EMBEDDING_TIMEOUT=${CLIENT_EMBEDDING_TIMEOUT:-}
# Analytics Configs
- SENTRY_DSN=${SENTRY_DSN:-}
volumes:
# Not necessary, this is just to reduce download time during startup
- indexing_huggingface_model_cache:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
ports:
- "9000:9000" # <-- Add this line to expose the port to the host
volumes:
indexing_huggingface_model_cache:

View File

@@ -68,6 +68,28 @@ const nextConfig = {
},
];
},
async rewrites() {
return [
{
source: "/api/docs/:path*", // catch /api/docs and /api/docs/...
destination: `${
process.env.INTERNAL_URL || "http://localhost:8080"
}/docs/:path*`,
},
{
source: "/api/docs", // if you also need the exact /api/docs
destination: `${
process.env.INTERNAL_URL || "http://localhost:8080"
}/docs`,
},
{
source: "/openapi.json",
destination: `${
process.env.INTERNAL_URL || "http://localhost:8080"
}/openapi.json`,
},
];
},
};
// Sentry configuration for error monitoring:

View File

@@ -113,7 +113,6 @@ export function AssistantEditor({
documentSets,
user,
defaultPublic,
redirectType,
llmProviders,
tools,
shouldAddAssistantToUserPreferences,
@@ -124,7 +123,6 @@ export function AssistantEditor({
documentSets: DocumentSet[];
user: User | null;
defaultPublic: boolean;
redirectType: SuccessfulPersonaUpdateRedirectType;
llmProviders: FullLLMProvider[];
tools: ToolSnapshot[];
shouldAddAssistantToUserPreferences?: boolean;
@@ -502,7 +500,7 @@ export function AssistantEditor({
)
.map((message: { message: string; name?: string }) => ({
message: message.message,
name: message.name || message.message,
name: message.message,
}));
// don't set groups if marked as public

View File

@@ -1,7 +1,7 @@
"use client";
import React, { useMemo, useState, useEffect } from "react";
import { Formik, Form, Field } from "formik";
import React, { useMemo } from "react";
import { Formik, Form } from "formik";
import * as Yup from "yup";
import { usePopup } from "@/components/admin/connectors/Popup";
import {
@@ -13,17 +13,13 @@ import {
createSlackChannelConfig,
isPersonaASlackBotPersona,
updateSlackChannelConfig,
fetchSlackChannels,
} from "../lib";
import CardSection from "@/components/admin/CardSection";
import { useRouter } from "next/navigation";
import { Persona } from "@/app/admin/assistants/interfaces";
import { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
import { SEARCH_TOOL_ID, SEARCH_TOOL_NAME } from "@/app/chat/tools/constants";
import {
SlackChannelConfigFormFields,
SlackChannelConfigFormFieldsProps,
} from "./SlackChannelConfigFormFields";
import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants";
import { SlackChannelConfigFormFields } from "./SlackChannelConfigFormFields";
export const SlackChannelConfigCreationForm = ({
slack_bot_id,

View File

@@ -1,13 +1,7 @@
"use client";
import React, { useState, useEffect, useMemo } from "react";
import {
FieldArray,
Form,
useFormikContext,
ErrorMessage,
Field,
} from "formik";
import { FieldArray, useFormikContext, ErrorMessage, Field } from "formik";
import { CCPairDescriptor, DocumentSet } from "@/lib/types";
import {
Label,
@@ -18,14 +12,13 @@ import {
} from "@/components/admin/connectors/Field";
import { Button } from "@/components/ui/button";
import { Persona } from "@/app/admin/assistants/interfaces";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable";
import CollapsibleSection from "@/app/admin/assistants/CollapsibleSection";
import { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
import { StandardAnswerCategoryDropdownField } from "@/components/standardAnswers/StandardAnswerCategoryDropdown";
import { RadioGroup } from "@/components/ui/radio-group";
import { RadioGroupItemField } from "@/components/ui/RadioGroupItemField";
import { AlertCircle, View } from "lucide-react";
import { AlertCircle } from "lucide-react";
import { useRouter } from "next/navigation";
import {
Tooltip,
@@ -50,6 +43,7 @@ import {
import { Separator } from "@/components/ui/separator";
import { CheckFormField } from "@/components/ui/CheckField";
import { Input } from "@/components/ui/input";
export interface SlackChannelConfigFormFieldsProps {
isUpdate: boolean;
@@ -178,9 +172,13 @@ export function SlackChannelConfigFormFields({
);
}, [documentSets]);
const { data: channelOptions, isLoading } = useSWR(
const {
data: channelOptions,
error,
isLoading,
} = useSWR(
`/api/manage/admin/slack-app/bots/${slack_bot_id}/channels`,
async (url: string) => {
async () => {
const channels = await fetchSlackChannels(slack_bot_id);
return channels.map((channel: any) => ({
name: channel.name,
@@ -227,20 +225,34 @@ export function SlackChannelConfigFormFields({
>
Select A Slack Channel:
</label>{" "}
<Field name="channel_name">
{({ field, form }: { field: any; form: any }) => (
<SearchMultiSelectDropdown
options={channelOptions || []}
onSelect={(selected) => {
form.setFieldValue("channel_name", selected.name);
}}
initialSearchTerm={field.value}
onSearchTermChange={(term) => {
form.setFieldValue("channel_name", term);
}}
{error ? (
<div>
<div className="text-red-600 text-sm mb-4">
{error.message || "Unable to fetch Slack channels."}
{" Please enter the channel name manually."}
</div>
<TextFormField
name="channel_name"
label="Channel Name"
placeholder="Enter channel name"
/>
)}
</Field>
</div>
) : (
<Field name="channel_name">
{({ field, form }: { field: any; form: any }) => (
<SearchMultiSelectDropdown
options={channelOptions || []}
onSelect={(selected) => {
form.setFieldValue("channel_name", selected.name);
}}
initialSearchTerm={field.value}
onSearchTermChange={(term) => {
form.setFieldValue("channel_name", term);
}}
/>
)}
</Field>
)}
</>
)}
<div className="space-y-2 mt-4">

View File

@@ -1,13 +1,18 @@
"use client";
import { Button } from "@/components/ui/button";
import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
import {
CCPairFullInfo,
ConnectorCredentialPairStatus,
statusIsNotCurrentlyActive,
} from "./types";
import { usePopup } from "@/components/admin/connectors/Popup";
import { mutate } from "swr";
import { buildCCPairInfoUrl } from "./lib";
import { setCCPairStatus } from "@/lib/ccPair";
import { useState } from "react";
import { LoadingAnimation } from "@/components/Loading";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
export function ModifyStatusButtonCluster({
ccPair,
@@ -16,11 +21,24 @@ export function ModifyStatusButtonCluster({
}) {
const { popup, setPopup } = usePopup();
const [isUpdating, setIsUpdating] = useState(false);
const [showConfirmModal, setShowConfirmModal] = useState(false);
const handleStatusChange = async (
newStatus: ConnectorCredentialPairStatus
) => {
if (isUpdating) return; // Prevent double-clicks or multiple requests
if (
ccPair.status === ConnectorCredentialPairStatus.INVALID &&
newStatus === ConnectorCredentialPairStatus.ACTIVE
) {
setShowConfirmModal(true);
} else {
await updateStatus(newStatus);
}
};
const updateStatus = async (newStatus: ConnectorCredentialPairStatus) => {
setIsUpdating(true);
try {
@@ -38,30 +56,23 @@ export function ModifyStatusButtonCluster({
};
// Compute the button text based on current state and backend status
const buttonText =
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Re-Enable"
: "Pause";
const isNotActive = statusIsNotCurrentlyActive(ccPair.status);
const buttonText = isNotActive ? "Re-Enable" : "Pause";
const tooltip =
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Click to start indexing again!"
: "When paused, the connector's documents will still be visible. However, no new documents will be indexed.";
const tooltip = isNotActive
? "Click to start indexing again!"
: "When paused, the connector's documents will still be visible. However, no new documents will be indexed.";
return (
<>
{popup}
<Button
className="flex items-center justify-center w-auto min-w-[100px] px-4 py-2"
variant={
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "success-reverse"
: "default"
}
variant={isNotActive ? "success-reverse" : "default"}
disabled={isUpdating}
onClick={() =>
handleStatusChange(
ccPair.status === ConnectorCredentialPairStatus.PAUSED
isNotActive
? ConnectorCredentialPairStatus.ACTIVE
: ConnectorCredentialPairStatus.PAUSED
)
@@ -70,17 +81,27 @@ export function ModifyStatusButtonCluster({
>
{isUpdating ? (
<LoadingAnimation
text={
ccPair.status === ConnectorCredentialPairStatus.PAUSED
? "Resuming"
: "Pausing"
}
text={isNotActive ? "Resuming" : "Pausing"}
size="text-md"
/>
) : (
buttonText
)}
</Button>
{showConfirmModal && (
<ConfirmEntityModal
entityType="Invalid Connector"
entityName={ccPair.name}
onClose={() => setShowConfirmModal(false)}
onSubmit={() => {
setShowConfirmModal(false);
updateStatus(ConnectorCredentialPairStatus.ACTIVE);
}}
additionalDetails="This connector was previously marked as invalid. Please verify that your configuration is correct before re-enabling. Are you sure you want to proceed?"
actionButtonText="Re-Enable"
variant="action"
/>
)}
</>
);
}

View File

@@ -123,7 +123,8 @@ export function ReIndexButton({
disabled={
isDisabled ||
ccPairStatus == ConnectorCredentialPairStatus.DELETING ||
ccPairStatus == ConnectorCredentialPairStatus.PAUSED
ccPairStatus == ConnectorCredentialPairStatus.PAUSED ||
ccPairStatus == ConnectorCredentialPairStatus.INVALID
}
tooltip={getCCPairStatusMessage(isDisabled, isIndexing, ccPairStatus)}
>

View File

@@ -15,6 +15,18 @@ export enum ConnectorCredentialPairStatus {
INVALID = "INVALID",
}
/**
* Returns true if the status is not currently active (i.e. paused or invalid), but not deleting
*/
export function statusIsNotCurrentlyActive(
status: ConnectorCredentialPairStatus
): boolean {
return (
status === ConnectorCredentialPairStatus.PAUSED ||
status === ConnectorCredentialPairStatus.INVALID
);
}
export interface CCPairFullInfo {
id: number;
name: string;

View File

@@ -21,11 +21,7 @@ export default async function Page(props: { params: Promise<{ id: string }> }) {
<div className="px-32">
<div className="mx-auto container">
<CardSection className="!border-none !bg-transparent !ring-none">
<AssistantEditor
{...values}
defaultPublic={false}
redirectType={SuccessfulPersonaUpdateRedirectType.CHAT}
/>
<AssistantEditor {...values} defaultPublic={false} />
</CardSection>
</div>
</div>

View File

@@ -26,7 +26,6 @@ export default async function Page() {
<AssistantEditor
{...values}
defaultPublic={false}
redirectType={SuccessfulPersonaUpdateRedirectType.CHAT}
shouldAddAssistantToUserPreferences={true}
/>
</CardSection>

View File

@@ -47,6 +47,7 @@ import {
removeMessage,
sendMessage,
setMessageAsLatest,
updateLlmOverrideForChatSession,
updateParentChildren,
uploadFilesForChat,
useScrollonStream,
@@ -65,7 +66,7 @@ import {
import { usePopup } from "@/components/admin/connectors/Popup";
import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams";
import { useDocumentSelection } from "./useDocumentSelection";
import { LlmOverride, useFilters, useLlmOverride } from "@/lib/hooks";
import { LlmDescriptor, useFilters, useLlmManager } from "@/lib/hooks";
import { ChatState, FeedbackType, RegenerationState } from "./types";
import { DocumentResults } from "./documentSidebar/DocumentResults";
import { OnyxInitializingLoader } from "@/components/OnyxInitializingLoader";
@@ -89,7 +90,11 @@ import {
import { buildFilters } from "@/lib/search/utils";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
import {
checkLLMSupportsImageInput,
getFinalLLM,
structureValue,
} from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar";
import { useChatContext } from "@/components/context/ChatContext";
import { v4 as uuidv4 } from "uuid";
@@ -137,6 +142,7 @@ import {
import { Button } from "@/components/ui/button";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import { MessageChannel } from "node:worker_threads";
import { ChatSearchModal } from "./chat_search/ChatSearchModal";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@@ -194,16 +200,6 @@ export function ChatPage({
return screenSize;
}
const { height: screenHeight } = useScreenSize();
const getContainerHeight = () => {
if (autoScrollEnabled) return undefined;
if (screenHeight < 600) return "20vh";
if (screenHeight < 1200) return "30vh";
return "40vh";
};
// handle redirect if chat page is disabled
// NOTE: this must be done here, in a client component since
// settings are passed in via Context and therefore aren't
@@ -222,6 +218,7 @@ export function ChatPage({
setProSearchEnabled(!proSearchEnabled);
};
const isInitialLoad = useRef(true);
const [userSettingsToggled, setUserSettingsToggled] = useState(false);
const {
@@ -356,7 +353,7 @@ export function ChatPage({
]
);
const llmOverrideManager = useLlmOverride(
const llmManager = useLlmManager(
llmProviders,
selectedChatSession,
liveAssistant
@@ -520,8 +517,17 @@ export function ChatPage({
scrollInitialized.current = false;
if (!hasPerformedInitialScroll) {
if (isInitialLoad.current) {
setHasPerformedInitialScroll(true);
isInitialLoad.current = false;
}
clientScrollToBottom();
setTimeout(() => {
setHasPerformedInitialScroll(true);
}, 100);
} else if (isChatSessionSwitch) {
setHasPerformedInitialScroll(true);
clientScrollToBottom(true);
}
@@ -865,6 +871,7 @@ export function ChatPage({
}, [liveAssistant]);
const filterManager = useFilters();
const [isChatSearchModalOpen, setIsChatSearchModalOpen] = useState(false);
const [currentFeedback, setCurrentFeedback] = useState<
[FeedbackType, number] | null
@@ -1130,6 +1137,56 @@ export function ChatPage({
});
};
const [uncaughtError, setUncaughtError] = useState<string | null>(null);
const [agenticGenerating, setAgenticGenerating] = useState(false);
const autoScrollEnabled =
(user?.preferences?.auto_scroll && !agenticGenerating) ?? false;
useScrollonStream({
chatState: currentSessionChatState,
scrollableDivRef,
scrollDist,
endDivRef,
debounceNumber,
mobile: settings?.isMobile,
enableAutoScroll: autoScrollEnabled,
});
// Track whether a message has been sent during this page load, keyed by chat session id
const [sessionHasSentLocalUserMessage, setSessionHasSentLocalUserMessage] =
useState<Map<string | null, boolean>>(new Map());
// Update the local state for a session once the user sends a message
const markSessionMessageSent = (sessionId: string | null) => {
setSessionHasSentLocalUserMessage((prev) => {
const newMap = new Map(prev);
newMap.set(sessionId, true);
return newMap;
});
};
const currentSessionHasSentLocalUserMessage = useMemo(
() => (sessionId: string | null) => {
return sessionHasSentLocalUserMessage.size === 0
? undefined
: sessionHasSentLocalUserMessage.get(sessionId) || false;
},
[sessionHasSentLocalUserMessage]
);
const { height: screenHeight } = useScreenSize();
const getContainerHeight = useMemo(() => {
return () => {
if (!currentSessionHasSentLocalUserMessage(chatSessionIdRef.current)) {
return undefined;
}
if (autoScrollEnabled) return undefined;
if (screenHeight < 600) return "40vh";
if (screenHeight < 1200) return "50vh";
return "60vh";
};
}, [autoScrollEnabled, screenHeight, currentSessionHasSentLocalUserMessage]);
const onSubmit = async ({
messageIdToResend,
@@ -1138,7 +1195,7 @@ export function ChatPage({
forceSearch,
isSeededChat,
alternativeAssistantOverride = null,
modelOverRide,
modelOverride,
regenerationRequest,
overrideFileDescriptors,
}: {
@@ -1148,7 +1205,7 @@ export function ChatPage({
forceSearch?: boolean;
isSeededChat?: boolean;
alternativeAssistantOverride?: Persona | null;
modelOverRide?: LlmOverride;
modelOverride?: LlmDescriptor;
regenerationRequest?: RegenerationRequest | null;
overrideFileDescriptors?: FileDescriptor[];
} = {}) => {
@@ -1156,6 +1213,9 @@ export function ChatPage({
let frozenSessionId = currentSessionId();
updateCanContinue(false, frozenSessionId);
// Mark that we've sent a message for this session in the current page load
markSessionMessageSent(frozenSessionId);
if (currentChatState() != "input") {
if (currentChatState() == "uploading") {
setPopup({
@@ -1191,6 +1251,22 @@ export function ChatPage({
currChatSessionId = chatSessionIdRef.current as string;
}
frozenSessionId = currChatSessionId;
// update the selected model for the chat session if one is specified so that
// it persists across page reloads. Do not `await` here so that the message
// request can continue and this will just happen in the background.
// NOTE: only set the model override for the chat session once we send a
// message with it. If the user switches models and then starts a new
// chat session, it is unexpected for that model to be used when they
// return to this session the next day.
let finalLLM = modelOverride || llmManager.currentLlm;
updateLlmOverrideForChatSession(
currChatSessionId,
structureValue(
finalLLM.name || "",
finalLLM.provider || "",
finalLLM.modelName || ""
)
);
updateStatesWithNewSessionId(currChatSessionId);
@@ -1250,11 +1326,14 @@ export function ChatPage({
: null) ||
(messageMap.size === 1 ? Array.from(messageMap.values())[0] : null);
const currentAssistantId = alternativeAssistantOverride
? alternativeAssistantOverride.id
: alternativeAssistant
? alternativeAssistant.id
: liveAssistant.id;
let currentAssistantId;
if (alternativeAssistantOverride) {
currentAssistantId = alternativeAssistantOverride.id;
} else if (alternativeAssistant) {
currentAssistantId = alternativeAssistant.id;
} else {
currentAssistantId = liveAssistant.id;
}
resetInputBar();
let messageUpdates: Message[] | null = null;
@@ -1326,15 +1405,13 @@ export function ChatPage({
forceSearch,
regenerate: regenerationRequest !== undefined,
modelProvider:
modelOverRide?.name ||
llmOverrideManager.llmOverride.name ||
undefined,
modelOverride?.name || llmManager.currentLlm.name || undefined,
modelVersion:
modelOverRide?.modelName ||
llmOverrideManager.llmOverride.modelName ||
modelOverride?.modelName ||
llmManager.currentLlm.modelName ||
searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
undefined,
temperature: llmOverrideManager.temperature || undefined,
temperature: llmManager.temperature || undefined,
systemPromptOverride:
searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
useExistingUserMessage: isSeededChat,
@@ -1802,7 +1879,7 @@ export function ChatPage({
const [_, llmModel] = getFinalLLM(
llmProviders,
liveAssistant,
llmOverrideManager.llmOverride
llmManager.currentLlm
);
const llmAcceptsImages = checkLLMSupportsImageInput(llmModel);
@@ -1857,7 +1934,6 @@ export function ChatPage({
// Used to maintain a "time out" for history sidebar so our existing refs can have time to process change
const [untoggled, setUntoggled] = useState(false);
const [loadingError, setLoadingError] = useState<string | null>(null);
const [agenticGenerating, setAgenticGenerating] = useState(false);
const explicitlyUntoggle = () => {
setShowHistorySidebar(false);
@@ -1899,19 +1975,6 @@ export function ChatPage({
isAnonymousUser: user?.is_anonymous_user,
});
const autoScrollEnabled =
(user?.preferences?.auto_scroll && !agenticGenerating) ?? false;
useScrollonStream({
chatState: currentSessionChatState,
scrollableDivRef,
scrollDist,
endDivRef,
debounceNumber,
mobile: settings?.isMobile,
enableAutoScroll: autoScrollEnabled,
});
// Virtualization + Scrolling related effects and functions
const scrollInitialized = useRef(false);
interface VisibleRange {
@@ -2121,7 +2184,7 @@ export function ChatPage({
}, [searchParams, router]);
useEffect(() => {
llmOverrideManager.updateImageFilesPresent(imageFileInMessageHistory);
llmManager.updateImageFilesPresent(imageFileInMessageHistory);
}, [imageFileInMessageHistory]);
const pathname = usePathname();
@@ -2175,9 +2238,9 @@ export function ChatPage({
function createRegenerator(regenerationRequest: RegenerationRequest) {
// Returns new function that only needs `modelOverRide` to be specified when called
return async function (modelOverRide: LlmOverride) {
return async function (modelOverride: LlmDescriptor) {
return await onSubmit({
modelOverRide,
modelOverride,
messageIdToResend: regenerationRequest.parentMessage.messageId,
regenerationRequest,
forceSearch: regenerationRequest.forceSearch,
@@ -2258,9 +2321,7 @@ export function ChatPage({
{(settingsToggled || userSettingsToggled) && (
<UserSettingsModal
setPopup={setPopup}
setLlmOverride={(newOverride) =>
llmOverrideManager.updateLLMOverride(newOverride)
}
setCurrentLlm={(newLlm) => llmManager.updateCurrentLlm(newLlm)}
defaultModel={user?.preferences.default_model!}
llmProviders={llmProviders}
onClose={() => {
@@ -2270,6 +2331,11 @@ export function ChatPage({
/>
)}
<ChatSearchModal
open={isChatSearchModalOpen}
onCloseModal={() => setIsChatSearchModalOpen(false)}
/>
{retrievalEnabled && documentSidebarVisible && settings?.isMobile && (
<div className="md:hidden">
<Modal
@@ -2324,7 +2390,7 @@ export function ChatPage({
<ShareChatSessionModal
assistantId={liveAssistant?.id}
message={message}
modelOverride={llmOverrideManager.llmOverride}
modelOverride={llmManager.currentLlm}
chatSessionId={sharedChatSession.id}
existingSharedStatus={sharedChatSession.shared_status}
onClose={() => setSharedChatSession(null)}
@@ -2342,7 +2408,7 @@ export function ChatPage({
<ShareChatSessionModal
message={message}
assistantId={liveAssistant?.id}
modelOverride={llmOverrideManager.llmOverride}
modelOverride={llmManager.currentLlm}
chatSessionId={chatSessionIdRef.current}
existingSharedStatus={chatSessionSharedStatus}
onClose={() => setSharingModalVisible(false)}
@@ -2377,6 +2443,9 @@ export function ChatPage({
>
<div className="w-full relative">
<HistorySidebar
toggleChatSessionSearchModal={() =>
setIsChatSearchModalOpen((open) => !open)
}
liveAssistant={liveAssistant}
setShowAssistantsModal={setShowAssistantsModal}
explicitlyUntoggle={explicitlyUntoggle}
@@ -2393,6 +2462,7 @@ export function ChatPage({
showDeleteAllModal={() => setShowDeleteAllModal(true)}
/>
</div>
<div
className={`
flex-none
@@ -2572,6 +2642,7 @@ export function ChatPage({
style={{ overflowAnchor: "none" }}
key={currentSessionId()}
className={
(hasPerformedInitialScroll ? "" : " hidden ") +
"desktop:-ml-4 w-full mx-auto " +
"absolute mobile:top-0 desktop:top-0 left-0 " +
(settings?.enterpriseSettings
@@ -3058,7 +3129,7 @@ export function ChatPage({
messageId: message.messageId,
parentMessage: parentMessage!,
forceSearch: true,
})(llmOverrideManager.llmOverride);
})(llmManager.currentLlm);
} else {
setPopup({
type: "error",
@@ -3203,7 +3274,7 @@ export function ChatPage({
availableDocumentSets={documentSets}
availableTags={tags}
filterManager={filterManager}
llmOverrideManager={llmOverrideManager}
llmManager={llmManager}
removeDocs={() => {
clearSelectedDocuments();
}}

View File

@@ -1,8 +1,8 @@
import { useChatContext } from "@/components/context/ChatContext";
import {
getDisplayNameForModel,
LlmOverride,
useLlmOverride,
LlmDescriptor,
useLlmManager,
} from "@/lib/hooks";
import { StringOrNumberOption } from "@/components/Dropdown";
@@ -106,13 +106,13 @@ export default function RegenerateOption({
onDropdownVisibleChange,
}: {
selectedAssistant: Persona;
regenerate: (modelOverRide: LlmOverride) => Promise<void>;
regenerate: (modelOverRide: LlmDescriptor) => Promise<void>;
overriddenModel?: string;
onHoverChange: (isHovered: boolean) => void;
onDropdownVisibleChange: (isVisible: boolean) => void;
}) {
const { llmProviders } = useChatContext();
const llmOverrideManager = useLlmOverride(llmProviders);
const llmManager = useLlmManager(llmProviders);
const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);
@@ -148,7 +148,7 @@ export default function RegenerateOption({
);
const currentModelName =
llmOverrideManager?.llmOverride.modelName ||
llmManager?.currentLlm.modelName ||
(selectedAssistant
? selectedAssistant.llm_model_version_override || llmName
: llmName);

View File

@@ -0,0 +1,31 @@
import React from "react";
import { ChatSearchItem } from "./ChatSearchItem";
import { ChatSessionSummary } from "./interfaces";
interface ChatSearchGroupProps {
title: string;
chats: ChatSessionSummary[];
onSelectChat: (id: string) => void;
}
export function ChatSearchGroup({
title,
chats,
onSelectChat,
}: ChatSearchGroupProps) {
return (
<div className="mb-4">
<div className="sticky -top-1 mt-1 z-10 bg-[#fff]/90 dark:bg-gray-800/90 py-2 px-4 px-4">
<div className="text-xs font-medium leading-4 text-gray-600 dark:text-gray-400">
{title}
</div>
</div>
<ol>
{chats.map((chat) => (
<ChatSearchItem key={chat.id} chat={chat} onSelect={onSelectChat} />
))}
</ol>
</div>
);
}

View File

@@ -0,0 +1,30 @@
import React from "react";
import { MessageSquare } from "lucide-react";
import { ChatSessionSummary } from "./interfaces";
interface ChatSearchItemProps {
chat: ChatSessionSummary;
onSelect: (id: string) => void;
}
export function ChatSearchItem({ chat, onSelect }: ChatSearchItemProps) {
return (
<li>
<div className="cursor-pointer" onClick={() => onSelect(chat.id)}>
<div className="group relative flex flex-col rounded-lg px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-800">
<div className="flex items-center">
<MessageSquare className="h-5 w-5 text-neutral-600 dark:text-neutral-400" />
<div className="relative grow overflow-hidden whitespace-nowrap pl-4">
<div className="text-sm dark:text-neutral-200">
{chat.name || "Untitled Chat"}
</div>
</div>
<div className="opacity-0 group-hover:opacity-100 transition-opacity text-xs text-neutral-500 dark:text-neutral-400">
{new Date(chat.time_created).toLocaleDateString()}
</div>
</div>
</div>
</div>
</li>
);
}

View File

@@ -0,0 +1,122 @@
import React, { useRef } from "react";
import { Dialog, DialogContent } from "@/components/ui/dialog";
import { ScrollArea } from "@/components/ui/scroll-area";
import { ChatSearchGroup } from "./ChatSearchGroup";
import { NewChatButton } from "./NewChatButton";
import { useChatSearch } from "./hooks/useChatSearch";
import { LoadingSpinner } from "./LoadingSpinner";
import { useRouter } from "next/navigation";
import { SearchInput } from "./components/SearchInput";
import { ChatSearchSkeletonList } from "./components/ChatSearchSkeleton";
import { useIntersectionObserver } from "./hooks/useIntersectionObserver";
interface ChatSearchModalProps {
open: boolean;
onCloseModal: () => void;
}
export function ChatSearchModal({ open, onCloseModal }: ChatSearchModalProps) {
const {
searchQuery,
setSearchQuery,
chatGroups,
isLoading,
isSearching,
hasMore,
fetchMoreChats,
} = useChatSearch();
const onClose = () => {
setSearchQuery("");
onCloseModal();
};
const router = useRouter();
const scrollAreaRef = useRef<HTMLDivElement>(null);
const { targetRef } = useIntersectionObserver({
root: scrollAreaRef.current,
onIntersect: fetchMoreChats,
enabled: open && hasMore && !isLoading,
});
const handleChatSelect = (chatId: string) => {
router.push(`/chat?chatId=${chatId}`);
onClose();
};
const handleNewChat = async () => {
try {
onClose();
router.push(`/chat`);
} catch (error) {
console.error("Error creating new chat:", error);
}
};
return (
<Dialog open={open} onOpenChange={(open) => !open && onClose()}>
<DialogContent
hideCloseIcon
className="!rounded-xl overflow-hidden p-0 w-full max-w-2xl"
backgroundColor="bg-neutral-950/20 shadow-xl"
>
<div className="w-full flex flex-col bg-white dark:bg-neutral-800 h-[80vh] max-h-[600px]">
<div className="sticky top-0 z-20 px-6 py-3 w-full flex items-center justify-between bg-white dark:bg-neutral-800 border-b border-neutral-200 dark:border-neutral-700">
<SearchInput
searchQuery={searchQuery}
setSearchQuery={setSearchQuery}
isSearching={isSearching}
/>
</div>
<ScrollArea
className="flex-grow bg-white relative dark:bg-neutral-800"
ref={scrollAreaRef}
type="auto"
>
<div className="px-4 py-2">
<NewChatButton onClick={handleNewChat} />
{isSearching ? (
<ChatSearchSkeletonList />
) : isLoading && chatGroups.length === 0 ? (
<div className="py-8">
<LoadingSpinner size="large" className="mx-auto" />
</div>
) : chatGroups.length > 0 ? (
<>
{chatGroups.map((group, groupIndex) => (
<ChatSearchGroup
key={groupIndex}
title={group.title}
chats={group.chats}
onSelectChat={handleChatSelect}
/>
))}
<div ref={targetRef} className="py-4">
{isLoading && hasMore && (
<LoadingSpinner className="mx-auto" />
)}
{!hasMore && chatGroups.length > 0 && (
<div className="text-center text-xs text-neutral-500 dark:text-neutral-400 py-2">
No more chats to load
</div>
)}
</div>
</>
) : (
!isLoading && (
<div className="px-4 py-3 text-sm text-neutral-500 dark:text-neutral-400">
No chats found
</div>
)
)}
</div>
</ScrollArea>
</div>
</DialogContent>
</Dialog>
);
}

Some files were not shown because too many files have changed in this diff Show More