mirror of https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
2 Commits
fix_stream...nit_error
| Author | SHA1 | Date | |
|---|---|---|---|
| | d1b9f11750 | | |
| | 9b067d8179 | | |
@@ -74,9 +74,7 @@ jobs:
           python -m pip install --upgrade pip
           pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
           pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
-          playwright install chromium
-          playwright install-deps chromium
-
+
       - name: Run Tests
         shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
         run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
123
README.md
@@ -24,94 +24,113 @@
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
|
||||
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
|
||||
There are over 40 supported connectors such as Google Drive, Slack, Confluence, Salesforce, etc. which keep knowledge and permissions up to date.
|
||||
Create custom AI agents with unique prompts, knowledge, and actions the agents can take.
|
||||
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
|
||||
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
|
||||
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
|
||||
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
|
||||
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
|
||||
configuring AI Assistants.
|
||||
|
||||
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
|
||||
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
|
||||
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
|
||||
supported?" or "Where's the pull request for feature Y?"
|
||||
|
||||
<h3>Feature Showcase</h3>
|
||||
<h3>Usage</h3>
|
||||
|
||||
**Deep research over your team's knowledge:**
|
||||
Onyx Web App:
|
||||
|
||||
https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95d0-4fb5-8650-a396e05e0a32.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk5Mjg2MzYsIm5iZiI6MTczOTkyODMzNiwicGF0aCI6Ii8zMjUyMDc2OS80MTQ1MDkzMTItNDgzOTJlODMtOTVkMC00ZmI1LTg2NTAtYTM5NmUwNWUwYTMyLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE5VDAxMjUzNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWFhMzk5Njg2Y2Y5YjFmNDNiYTQ2YzM5ZTg5YWJiYTU2NWMyY2YwNmUyODE2NWUxMDRiMWQxZWJmODI4YTA0MTUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.a9D8A0sgKE9AoaoE-mfFbJ6_OKYeqaf7TZ4Han2JfW8
|
||||
https://github.com/onyx-dot-app/onyx/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
|
||||
|
||||
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
|
||||
|
||||
**Use Onyx as a secure AI Chat with any LLM:**
|
||||
|
||||

|
||||
|
||||
|
||||
**Easily set up connectors to your apps:**
|
||||
|
||||

|
||||
|
||||
|
||||
**Access Onyx where your team already works:**
|
||||
|
||||

|
||||
https://github.com/onyx-dot-app/onyx/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
|
||||
|
||||
For more details on the Admin UI to manage connectors and users, check out our
|
||||
<strong><a href="https://www.youtube.com/watch?v=geNzY1nbCnU">Full Video Demo</a></strong>!
|
||||
|
||||
## Deployment
|
||||
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
|
||||
|
||||
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
|
||||
|
||||
We also have built-in support for high-availability/scalable deployment on Kubernetes.
|
||||
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
|
||||
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
|
||||
|
||||
## 💃 Main Features
|
||||
|
||||
- Chat UI with the ability to select documents to chat with.
|
||||
- Create custom AI Assistants with different prompts and backing knowledge sets.
|
||||
- Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
|
||||
- Document Search + AI Answers for natural language queries.
|
||||
- Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
|
||||
- Slack integration to get answers and search results directly in Slack.
|
||||
|
||||
## 🚧 Roadmap
|
||||
- Extensions to the Chrome Plugin
|
||||
- Latest methods in information retrieval (StructRAG, LightGraphRAG, etc.)
|
||||
- Personalized Search
|
||||
|
||||
- Chat/Prompt sharing with specific teammates and user groups.
|
||||
- Multimodal model support, chat with images, video etc.
|
||||
- Choosing between LLMs and parameters during chat session.
|
||||
- Tool calling and agent configurations options.
|
||||
- Organizational understanding and ability to locate and suggest experts from your team.
|
||||
- Code Search
|
||||
- SQL and Structured Query Language
|
||||
|
||||
## Other Notable Benefits of Onyx
|
||||
|
||||
## 🔍 Other Notable Benefits of Onyx
|
||||
- Custom deep learning models only through Onyx + learn from user feedback.
|
||||
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
|
||||
- Knowledge curation features like document-sets, query history, usage analytics, etc.
|
||||
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
|
||||
|
||||
- User Authentication with document level access management.
|
||||
- Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
|
||||
- Admin Dashboard to configure connectors, document-sets, access, etc.
|
||||
- Custom deep learning models + learn from user feedback.
|
||||
- Easy deployment and ability to host Onyx anywhere of your choosing.
|
||||
|
||||
## 🔌 Connectors
|
||||
Keep knowledge and access up to sync across 40+ connectors:
|
||||
|
||||
Efficiently pulls the latest changes from:
|
||||
|
||||
- Slack
|
||||
- GitHub
|
||||
- Google Drive
|
||||
- Confluence
|
||||
- Slack
|
||||
- Gmail
|
||||
- Salesforce
|
||||
- Microsoft Sharepoint
|
||||
- Github
|
||||
- Jira
|
||||
- Zendesk
|
||||
- Gmail
|
||||
- Notion
|
||||
- Gong
|
||||
- Microsoft Teams
|
||||
- Dropbox
|
||||
- Slab
|
||||
- Linear
|
||||
- Productboard
|
||||
- Guru
|
||||
- Bookstack
|
||||
- Document360
|
||||
- Sharepoint
|
||||
- Hubspot
|
||||
- Local Files
|
||||
- Websites
|
||||
- And more ...
|
||||
|
||||
See the full list [here](https://docs.onyx.app/connectors).
|
||||
## 📚 Editions
|
||||
|
||||
|
||||
## 📚 Licensing
|
||||
There are two editions of Onyx:
|
||||
|
||||
- Onyx Community Edition (CE) is available freely under the MIT Expat license. Simply follow the Deployment guide above.
|
||||
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
|
||||
For feature details, check out [our website](https://www.onyx.app/pricing).
|
||||
- Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
|
||||
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
|
||||
- Single Sign-On (SSO), with support for both SAML and OIDC
|
||||
- Role-based access control
|
||||
- Document permission inheritance from connected sources
|
||||
- Usage analytics and query history accessible to admins
|
||||
- Whitelabeling
|
||||
- API key authentication
|
||||
- Encryption of secrets
|
||||
- And many more! Checkout [our website](https://www.onyx.app/) for the latest.
|
||||
|
||||
To try the Onyx Enterprise Edition:
|
||||
1. Checkout [Onyx Cloud](https://cloud.onyx.app/signup).
|
||||
2. For self-hosting the Enterprise Edition, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
|
||||
|
||||
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
|
||||
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
|
||||
|
||||
## 💡 Contributing
|
||||
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
## ⭐Star History
|
||||
|
||||
[](https://star-history.com/#onyx-dot-app/onyx&Date)
|
||||
|
||||
|
||||
@@ -1,27 +0,0 @@
-"""Add indexes to document__tag
-
-Revision ID: 1a03d2c2856b
-Revises: 9c00a2bccb83
-Create Date: 2025-02-18 10:45:13.957807
-
-"""
-from alembic import op
-
-# revision identifiers, used by Alembic.
-revision = "1a03d2c2856b"
-down_revision = "9c00a2bccb83"
-branch_labels: None = None
-depends_on: None = None
-
-
-def upgrade() -> None:
-    op.create_index(
-        op.f("ix_document__tag_tag_id"),
-        "document__tag",
-        ["tag_id"],
-        unique=False,
-    )
-
-
-def downgrade() -> None:
-    op.drop_index(op.f("ix_document__tag_tag_id"), table_name="document__tag")
@@ -1,43 +0,0 @@
-"""chat_message_agentic
-
-Revision ID: 9c00a2bccb83
-Revises: b7a7eee5aa15
-Create Date: 2025-02-17 11:15:43.081150
-
-"""
-from alembic import op
-import sqlalchemy as sa
-
-
-# revision identifiers, used by Alembic.
-revision = "9c00a2bccb83"
-down_revision = "b7a7eee5aa15"
-branch_labels = None
-depends_on = None
-
-
-def upgrade() -> None:
-    # First add the column as nullable
-    op.add_column("chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=True))
-
-    # Update existing rows based on presence of SubQuestions
-    op.execute(
-        """
-        UPDATE chat_message
-        SET is_agentic = EXISTS (
-            SELECT 1
-            FROM agent__sub_question
-            WHERE agent__sub_question.primary_question_id = chat_message.id
-        )
-        WHERE is_agentic IS NULL
-    """
-    )
-
-    # Make the column non-nullable with a default value of False
-    op.alter_column(
-        "chat_message", "is_agentic", nullable=False, server_default=sa.text("false")
-    )
-
-
-def downgrade() -> None:
-    op.drop_column("chat_message", "is_agentic")
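The `chat_message_agentic` migration above follows a common three-step pattern: add the column as nullable, backfill it, then tighten the constraint. A minimal sketch of the backfill step, using an in-memory SQLite database as a stand-in for Postgres. Only the table and column names come from the migration; the rows are made up for illustration.

```python
import sqlite3

# In-memory database standing in for the real Postgres schema (assumption).
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE chat_message (id INTEGER PRIMARY KEY, is_agentic BOOLEAN);
    CREATE TABLE agent__sub_question (
        id INTEGER PRIMARY KEY,
        primary_question_id INTEGER
    );
    INSERT INTO chat_message (id) VALUES (1), (2), (3);
    INSERT INTO agent__sub_question (primary_question_id) VALUES (2), (2);
    """
)

# Same backfill statement as the migration: a message counts as agentic
# when at least one sub-question row points at it.
conn.execute(
    """
    UPDATE chat_message
    SET is_agentic = EXISTS (
        SELECT 1
        FROM agent__sub_question
        WHERE agent__sub_question.primary_question_id = chat_message.id
    )
    WHERE is_agentic IS NULL
    """
)

backfilled = conn.execute(
    "SELECT id, is_agentic FROM chat_message ORDER BY id"
).fetchall()
print(backfilled)
```

Because the `UPDATE` only touches rows where `is_agentic IS NULL`, rerunning it leaves already-backfilled rows alone, which is what makes the later `nullable=False` step safe.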
@@ -1,29 +0,0 @@
-"""remove inactive ccpair status on downgrade
-
-Revision ID: acaab4ef4507
-Revises: b388730a2899
-Create Date: 2025-02-16 18:21:41.330212
-
-"""
-from alembic import op
-from onyx.db.models import ConnectorCredentialPair
-from onyx.db.enums import ConnectorCredentialPairStatus
-from sqlalchemy import update
-
-# revision identifiers, used by Alembic.
-revision = "acaab4ef4507"
-down_revision = "b388730a2899"
-branch_labels = None
-depends_on = None
-
-
-def upgrade() -> None:
-    pass
-
-
-def downgrade() -> None:
-    op.execute(
-        update(ConnectorCredentialPair)
-        .where(ConnectorCredentialPair.status == ConnectorCredentialPairStatus.INVALID)
-        .values(status=ConnectorCredentialPairStatus.ACTIVE)
-    )
@@ -1,31 +0,0 @@
-"""nullable preferences
-
-Revision ID: b388730a2899
-Revises: 1a03d2c2856b
-Create Date: 2025-02-17 18:49:22.643902
-
-"""
-from alembic import op
-
-
-# revision identifiers, used by Alembic.
-revision = "b388730a2899"
-down_revision = "1a03d2c2856b"
-branch_labels = None
-depends_on = None
-
-
-def upgrade() -> None:
-    op.alter_column("user", "temperature_override_enabled", nullable=True)
-    op.alter_column("user", "auto_scroll", nullable=True)
-
-
-def downgrade() -> None:
-    # Ensure no null values before making columns non-nullable
-    op.execute(
-        'UPDATE "user" SET temperature_override_enabled = false WHERE temperature_override_enabled IS NULL'
-    )
-    op.execute('UPDATE "user" SET auto_scroll = false WHERE auto_scroll IS NULL')
-
-    op.alter_column("user", "temperature_override_enabled", nullable=False)
-    op.alter_column("user", "auto_scroll", nullable=False)
@@ -21,7 +21,7 @@ logger = setup_logger()
 def perform_ttl_management_task(
     retention_limit_days: int, *, tenant_id: str | None
 ) -> None:
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         delete_chat_sessions_older_than(retention_limit_days, db_session)


@@ -44,7 +44,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:

     settings = load_settings()
     retention_limit_days = settings.maximum_chat_retention_days
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         if should_perform_chat_ttl_check(retention_limit_days, db_session):
             perform_ttl_management_task.apply_async(
                 kwargs=dict(
@@ -62,7 +62,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
 )
 def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
     """This generates usage report under the /admin/generate-usage/report endpoint"""
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         create_new_usage_report(
             db_session=db_session,
             user_id=None,
@@ -14,24 +14,30 @@ def _build_group_member_email_map(
     confluence_client: OnyxConfluence, cc_pair_id: int
 ) -> dict[str, set[str]]:
     group_member_emails: dict[str, set[str]] = {}
-    for user in confluence_client.paginated_cql_user_retrieval():
-        logger.debug(f"Processing groups for user: {user}")
+    for user_result in confluence_client.paginated_cql_user_retrieval():
+        logger.debug(f"Processing groups for user: {user_result}")

-        email = user.email
+        user = user_result.get("user", {})
+        if not user:
+            msg = f"user result missing user field: {user_result}"
+            emit_background_error(msg, cc_pair_id=cc_pair_id)
+            logger.error(msg)
+            continue
+
+        email = user.get("email")
         if not email:
             # This field is only present in Confluence Server
-            user_name = user.username
+            user_name = user.get("username")
             # If it is present, try to get the email using a Server-specific method
             if user_name:
                 email = get_user_email_from_username__server(
                     confluence_client=confluence_client,
                     user_name=user_name,
                 )
-
         if not email:
             # If we still don't have an email, skip this user
-            msg = f"user result missing email field: {user}"
-            if user.type == "app":
+            msg = f"user result missing email field: {user_result}"
+            if user.get("type") == "app":
                 logger.warning(msg)
             else:
                 emit_background_error(msg, cc_pair_id=cc_pair_id)
@@ -39,7 +45,7 @@ def _build_group_member_email_map(
             continue

         all_users_groups: set[str] = set()
-        for group in confluence_client.paginated_groups_by_user_retrieval(user.user_id):
+        for group in confluence_client.paginated_groups_by_user_retrieval(user):
             # group name uniqueness is enforced by Confluence, so we can use it as a group ID
             group_id = group["name"]
             group_member_emails.setdefault(group_id, set()).add(email)
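The Confluence hunks above differ in how a user record is read: attribute access on a user object versus `.get()` lookups on a raw result dict. A minimal sketch of the dict-based variant, assuming each result is shaped like `{"user": {...}}`. The function name, the `groups_by_user` lookup, and the sample data are hypothetical stand-ins for the Confluence client calls, and the Server-specific username fallback is left out.

```python
def build_group_member_email_map(
    user_results: list[dict],
    groups_by_user: dict[str, list[str]],
) -> dict[str, set[str]]:
    """Map each group name to the emails of its members."""
    group_member_emails: dict[str, set[str]] = {}
    for user_result in user_results:
        # .get() with a default avoids a KeyError on malformed results
        user = user_result.get("user", {})
        if not user:
            continue  # the diff additionally reports a background error here

        email = user.get("email")
        if not email:
            continue  # no fallback lookup in this sketch

        # hypothetical lookup table replacing the paginated group retrieval
        for group_name in groups_by_user.get(email, []):
            group_member_emails.setdefault(group_name, set()).add(email)
    return group_member_emails


results = [
    {"user": {"email": "a@example.com", "type": "known"}},
    {"user": {"type": "app"}},  # app users may have no email
    {},  # malformed result without a "user" field
    {"user": {"email": "b@example.com"}},
]
groups = {
    "a@example.com": ["engineering", "admins"],
    "b@example.com": ["engineering"],
}
email_map = build_group_member_email_map(results, groups)
print(email_map)
```

The dict-based form tolerates missing keys at the cost of losing attribute-level type checking, which is why the skipped records are handled explicitly rather than raising.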
@@ -33,7 +33,7 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
             return await call_next(request)

         except Exception as e:
-            logger.exception(f"Error in tenant ID middleware: {str(e)}")
+            logger.error(f"Error in tenant ID middleware: {str(e)}")
             raise


@@ -49,7 +49,7 @@ async def _get_tenant_id_from_request(
     """
     # Check for API key
     tenant_id = extract_tenant_from_api_key_header(request)
-    if tenant_id is not None:
+    if tenant_id:
         return tenant_id

     # Check for anonymous user cookie
@@ -36,12 +36,12 @@ from onyx.connectors.google_utils.shared_constants import (
     GoogleOAuthAuthenticationMethod,
 )
 from onyx.db.credentials import create_credential
+from onyx.db.engine import get_current_tenant_id
 from onyx.db.engine import get_session
 from onyx.db.models import User
 from onyx.redis.redis_pool import get_redis_client
 from onyx.server.documents.models import CredentialBase
 from onyx.utils.logger import setup_logger
-from shared_configs.contextvars import get_current_tenant_id


 logger = setup_logger()
@@ -271,12 +271,12 @@ def prepare_authorization_request(
     connector: DocumentSource,
     redirect_on_success: str | None,
     user: User = Depends(current_user),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> JSONResponse:
     """Used by the frontend to generate the url for the user's browser during auth request.

     Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
     """
-    tenant_id = get_current_tenant_id()

     # create random oauth state param for security and to retrieve user data later
     oauth_uuid = uuid.uuid4()
@@ -329,6 +329,7 @@ def handle_slack_oauth_callback(
     state: str,
     user: User = Depends(current_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> JSONResponse:
     if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
         raise HTTPException(
@@ -336,7 +337,7 @@ def handle_slack_oauth_callback(
             detail="Slack client ID or client secret is not configured.",
         )

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     # recover the state
     padded_state = state + "=" * (
@@ -522,6 +523,7 @@ def handle_google_drive_oauth_callback(
     state: str,
     user: User = Depends(current_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> JSONResponse:
     if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
         raise HTTPException(
@@ -529,7 +531,7 @@ def handle_google_drive_oauth_callback(
             detail="Google Drive client ID or client secret is not configured.",
         )

-    r = get_redis_client()
+    r = get_redis_client(tenant_id=tenant_id)

     # recover the state
     padded_state = state + "=" * (
@@ -28,7 +28,7 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel


-def _check_token_rate_limits(user: User | None, tenant_id: str) -> None:
+def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
     if user is None:
         # Unauthenticated users are only rate limited by global settings
         _user_is_rate_limited_by_global(tenant_id)
@@ -52,8 +52,8 @@ User rate limits
 """


-def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None:
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
+    with get_session_with_tenant(tenant_id) as db_session:
         user_rate_limits = fetch_all_user_token_rate_limits(
             db_session=db_session, enabled_only=True, ordered=False
         )
@@ -94,7 +94,7 @@ User Group rate limits


 def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)

         if group_rate_limits:
@@ -41,15 +41,14 @@ from onyx.auth.users import User
 from onyx.configs.app_configs import WEB_DOMAIN
 from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
 from onyx.db.auth import get_user_count
+from onyx.db.engine import get_current_tenant_id
 from onyx.db.engine import get_session
-from onyx.db.engine import get_session_with_shared_schema
 from onyx.db.engine import get_session_with_tenant
 from onyx.db.users import delete_user_from_db
 from onyx.db.users import get_user_by_email
 from onyx.server.manage.models import UserByEmail
 from onyx.utils.logger import setup_logger
 from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
-from shared_configs.contextvars import get_current_tenant_id

 stripe.api_key = STRIPE_SECRET_KEY
 logger = setup_logger()
@@ -58,14 +57,13 @@ router = APIRouter(prefix="/tenants")

 @router.get("/anonymous-user-path")
 async def get_anonymous_user_path_api(
+    tenant_id: str | None = Depends(get_current_tenant_id),
     _: User | None = Depends(current_admin_user),
 ) -> AnonymousUserPath:
-    tenant_id = get_current_tenant_id()
-
     if tenant_id is None:
         raise HTTPException(status_code=404, detail="Tenant not found")

-    with get_session_with_shared_schema() as db_session:
+    with get_session_with_tenant(tenant_id=None) as db_session:
         current_path = get_anonymous_user_path(tenant_id, db_session)

     return AnonymousUserPath(anonymous_user_path=current_path)
@@ -74,15 +72,15 @@ async def get_anonymous_user_path_api(
 @router.post("/anonymous-user-path")
 async def set_anonymous_user_path_api(
     anonymous_user_path: str,
+    tenant_id: str = Depends(get_current_tenant_id),
     _: User | None = Depends(current_admin_user),
 ) -> None:
-    tenant_id = get_current_tenant_id()
     try:
         validate_anonymous_user_path(anonymous_user_path)
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))

-    with get_session_with_shared_schema() as db_session:
+    with get_session_with_tenant(tenant_id=None) as db_session:
         try:
             modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
         except IntegrityError:
@@ -103,7 +101,7 @@ async def login_as_anonymous_user(
     anonymous_user_path: str,
     _: User | None = Depends(optional_user),
 ) -> Response:
-    with get_session_with_shared_schema() as db_session:
+    with get_session_with_tenant(tenant_id=None) as db_session:
         tenant_id = get_tenant_id_for_anonymous_user_path(
             anonymous_user_path, db_session
         )
@@ -152,17 +150,14 @@ async def billing_information(
     _: User = Depends(current_admin_user),
 ) -> BillingInformation | SubscriptionStatusResponse:
     logger.info("Fetching billing information")
-    tenant_id = get_current_tenant_id()
-    return fetch_billing_information(tenant_id)
+    return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())


 @router.post("/create-customer-portal-session")
-async def create_customer_portal_session(
-    _: User = Depends(current_admin_user),
-) -> dict:
-    tenant_id = get_current_tenant_id()
-
+async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
     try:
+        # Fetch tenant_id and current tenant's information
+        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
         stripe_info = fetch_tenant_stripe_information(tenant_id)
         stripe_customer_id = stripe_info.get("stripe_customer_id")
         if not stripe_customer_id:
@@ -186,8 +181,6 @@ async def create_subscription_session(
 ) -> SubscriptionSessionResponse:
     try:
         tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
-        if not tenant_id:
-            raise HTTPException(status_code=400, detail="Tenant ID not found")
         session_id = fetch_stripe_checkout_session(tenant_id)
         return SubscriptionSessionResponse(sessionId=session_id)

@@ -204,7 +197,7 @@ async def impersonate_user(
     """Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
     tenant_id = get_tenant_id_for_email(impersonate_request.email)

-    with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
+    with get_session_with_tenant(tenant_id) as tenant_session:
         user_to_impersonate = get_user_by_email(
             impersonate_request.email, tenant_session
         )
@@ -228,9 +221,8 @@ async def leave_organization(
     user_email: UserByEmail,
     current_user: User | None = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
+    tenant_id: str = Depends(get_current_tenant_id),
 ) -> None:
-    tenant_id = get_current_tenant_id()
-
     if current_user is None or current_user.email != user_email.user_email:
         raise HTTPException(
             status_code=403, detail="You can only leave the organization as yourself"
@@ -118,7 +118,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
         # Await the Alembic migrations
         await asyncio.to_thread(run_alembic_migrations, tenant_id)

-        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+        with get_session_with_tenant(tenant_id) as db_session:
             configure_default_api_keys(db_session)

             current_search_settings = (
@@ -134,7 +134,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:

         add_users_to_tenant([email], tenant_id)

-        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+        with get_session_with_tenant(tenant_id) as db_session:
             create_milestone_and_report(
                 user=None,
                 distinct_id=tenant_id,
@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:


 def user_owns_a_tenant(email: str) -> bool:
-    with get_session_with_tenant(tenant_id=None) as db_session:
+    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
         result = (
             db_session.query(UserTenantMapping)
             .filter(UserTenantMapping.email == email)
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:


 def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
-    with get_session_with_tenant(tenant_id=None) as db_session:
+    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
         try:
             for email in emails:
                 db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:


 def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
-    with get_session_with_tenant(tenant_id=None) as db_session:
+    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
         try:
             mappings_to_delete = (
                 db_session.query(UserTenantMapping)
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:


 def remove_all_users_from_tenant(tenant_id: str) -> None:
-    with get_session_with_tenant(tenant_id=None) as db_session:
+    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
         db_session.query(UserTenantMapping).filter(
             UserTenantMapping.tenant_id == tenant_id
         ).delete()
@@ -42,5 +42,4 @@ def fetch_no_auth_user(
         role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
         preferences=load_no_auth_user_preferences(store),
         is_anonymous_user=anonymous_user_enabled,
-        password_configured=False,
     )
@@ -1,7 +1,5 @@
 import json
-import random
 import secrets
-import string
 import uuid
 from collections.abc import AsyncGenerator
 from datetime import datetime
@@ -88,6 +86,7 @@ from onyx.db.auth import get_user_db
 from onyx.db.auth import SQLAlchemyUserAdminDB
 from onyx.db.engine import get_async_session
 from onyx.db.engine import get_async_session_with_tenant
+from onyx.db.engine import get_current_tenant_id
 from onyx.db.engine import get_session_with_tenant
 from onyx.db.models import AccessToken
 from onyx.db.models import OAuthAccount
@@ -105,7 +104,6 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
 from shared_configs.configs import async_return_default_schema
 from shared_configs.configs import MULTI_TENANT
 from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
-from shared_configs.contextvars import get_current_tenant_id

 logger = setup_logger()

@@ -141,30 +139,6 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
     return email or ""


-def generate_password() -> str:
-    lowercase_letters = string.ascii_lowercase
-    uppercase_letters = string.ascii_uppercase
-    digits = string.digits
-    special_characters = string.punctuation
-
-    # Ensure at least one of each required character type
-    password = [
-        secrets.choice(uppercase_letters),
-        secrets.choice(digits),
-        secrets.choice(special_characters),
-    ]
-
-    # Fill the rest with a mix of characters
-    remaining_length = 12 - len(password)
-    all_characters = lowercase_letters + uppercase_letters + digits + special_characters
-    password.extend(secrets.choice(all_characters) for _ in range(remaining_length))
-
-    # Shuffle the password to randomize the position of the required characters
-    random.shuffle(password)
-
-    return "".join(password)
-
-
 def user_needs_to_be_verified() -> bool:
     if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
         return REQUIRE_EMAIL_VERIFICATION
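The `generate_password` helper in the hunk above draws its characters with `secrets` but shuffles them with `random.shuffle`, which uses the default non-cryptographic generator. A variant sketch, not the project's code, that keeps the shuffle on the OS entropy source through `secrets.SystemRandom`. It also forces a lowercase letter, which the original's required set (uppercase, digit, special character) does not; the `length` parameter is an addition for illustration.

```python
import secrets
import string


def generate_password(length: int = 12) -> str:
    """Return a random password containing every required character class."""
    required_pools = [
        string.ascii_lowercase,
        string.ascii_uppercase,
        string.digits,
        string.punctuation,
    ]
    if length < len(required_pools):
        raise ValueError("length is too short to hold every character class")

    # One character from each required class
    chars = [secrets.choice(pool) for pool in required_pools]

    # Fill the rest from the combined alphabet
    all_characters = "".join(required_pools)
    chars.extend(secrets.choice(all_characters) for _ in range(length - len(chars)))

    # Shuffle with the OS-backed generator so positions are unpredictable too
    secrets.SystemRandom().shuffle(chars)
    return "".join(chars)


password = generate_password()
print(len(password))
```

In the original the shuffle only hides where the three guaranteed characters sit, so the practical exposure is small; using one entropy source throughout mainly makes the function easier to reason about.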
@@ -215,7 +189,7 @@ def verify_email_is_invited(email: str) -> None:


 def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+    with get_session_with_tenant(tenant_id) as db_session:
         if not get_user_by_email(email, db_session):
             verify_email_is_invited(email)

@@ -617,39 +591,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):

         return user

-    async def reset_password_as_admin(self, user_id: uuid.UUID) -> str:
-        """Admin-only. Generate a random password for a user and return it."""
-        user = await self.get(user_id)
-        new_password = generate_password()
-        await self._update(user, {"password": new_password})
-        return new_password
-
-    async def change_password_if_old_matches(
-        self, user: User, old_password: str, new_password: str
-    ) -> None:
-        """
-        For normal users to change password if they know the old one.
-        Raises 400 if old password doesn't match.
-        """
-        verified, updated_password_hash = self.password_helper.verify_and_update(
-            old_password, user.hashed_password
-        )
-        if not verified:
-            # Raise some HTTPException (or your custom exception) if old password is invalid:
-            from fastapi import HTTPException, status
-
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail="Invalid current password",
-            )
-
-        # If the hash was upgraded behind the scenes, we can keep it before setting the new password:
-        if updated_password_hash:
-            user.hashed_password = updated_password_hash
-
-        # Now apply and validate the new password
-        await self._update(user, {"password": new_password})
-

 async def get_user_manager(
     user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
@@ -874,9 +815,8 @@ async def current_limited_user(

 async def current_chat_accesssible_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> User | None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
return await double_check_user(
|
||||
user, allow_anonymous_access=anonymous_user_enabled(tenant_id=tenant_id)
|
||||
)
|
||||
|
||||
@@ -33,7 +33,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
@@ -59,35 +58,13 @@ else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
|
||||
class TenantAwareTask(Task):
|
||||
"""A custom base Task that sets tenant_id in a contextvar before running."""
|
||||
|
||||
abstract = True # So Celery knows not to register this as a real task.
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# Grab tenant_id from the kwargs, or fallback to default if missing.
|
||||
tenant_id = kwargs.get("tenant_id", None) or POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# Set the context var
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Actually run the task now
|
||||
try:
|
||||
return super().__call__(*args, **kwargs)
|
||||
finally:
|
||||
# Clear or reset after the task runs
|
||||
# so it does not leak into any subsequent tasks on the same worker process
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**other_kwargs: Any,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -140,7 +117,7 @@ def on_task_postrun(
|
||||
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
|
||||
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
|
||||
@@ -224,7 +201,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout
|
||||
is reached."""
|
||||
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
@@ -310,7 +287,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# Set up variables for waiting on primary worker
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=None)
|
||||
time_start = time.monotonic()
|
||||
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
@@ -462,6 +439,24 @@ class TenantContextFilter(logging.Filter):
|
||||
return True
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def set_tenant_id(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
"""Signal handler to set tenant ID in context var before task starts."""
|
||||
tenant_id = (
|
||||
kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
if kwargs
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
|
||||
@task_postrun.connect
|
||||
def reset_tenant_id(
|
||||
sender: Any | None = None,
|
||||
|
||||
@@ -132,7 +132,6 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
f"Adding options to task {tenant_task_name}: {options}"
|
||||
)
|
||||
tenant_task["options"] = options
|
||||
|
||||
new_schedule[tenant_task_name] = tenant_task
|
||||
|
||||
return new_schedule
|
||||
@@ -257,4 +256,3 @@ def on_setup_logging(
|
||||
|
||||
|
||||
celery_app.conf.beat_scheduler = DynamicTenantScheduler
|
||||
celery_app.conf.task_default_base = app_base.TenantAwareTask
|
||||
|
||||
@@ -20,7 +20,6 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.heavy")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
|
||||
@@ -21,7 +21,6 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.indexing")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
|
||||
@@ -23,7 +23,6 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.light")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
|
||||
@@ -20,7 +20,6 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
|
||||
@@ -24,7 +24,7 @@ from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_default_tenant
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
@@ -38,7 +38,7 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_stop import RedisConnectorStop
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -47,7 +47,6 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.primary")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
@@ -102,7 +101,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
|
||||
info: dict[str, Any] = cast(dict, r.info("replication"))
|
||||
@@ -159,7 +158,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
RedisConnectorExternalGroupSync.reset_all(r)
|
||||
|
||||
# mark orphaned index attempts as failed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
attempt = get_index_attempt(db_session, attempt_id)
|
||||
@@ -235,7 +234,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
|
||||
lock: RedisLock = worker.primary_worker_lock
|
||||
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug("Reacquiring primary worker lock.")
|
||||
|
||||
@@ -27,7 +27,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
@@ -62,8 +62,8 @@ class TaskDependencyError(RuntimeError):
|
||||
def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -77,14 +77,14 @@ def check_for_connector_deletion_task(
|
||||
try:
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair.id)
|
||||
|
||||
# try running cleanup on the cc_pair_ids
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
try:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
@@ -277,7 +277,7 @@ def monitor_connector_deletion_taskset(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
@@ -287,7 +287,7 @@ def monitor_connector_deletion_taskset(
|
||||
)
|
||||
return
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
|
||||
@@ -45,7 +45,7 @@ from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
@@ -119,13 +119,13 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
# TODO(rkuo): merge into check function after lookup table for fences is added
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -140,7 +140,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
try:
|
||||
# get all cc pairs that need to be synced
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
@@ -189,7 +189,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(
|
||||
tenant_id, key_bytes, r, db_session
|
||||
)
|
||||
@@ -247,7 +247,7 @@ def try_creating_permissions_sync_task(
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
@@ -321,7 +321,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
@@ -378,7 +378,7 @@ def connector_permission_sync_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -480,8 +480,7 @@ def update_external_document_permissions_task(
|
||||
external_access = document_external_access.external_access
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Add the users to the DB if they don't exist
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
emails=list(external_access.external_user_emails),
|
||||
|
||||
@@ -39,7 +39,7 @@ from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
@@ -122,8 +122,8 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -140,7 +140,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
|
||||
try:
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
# We only want to sync one cc_pair per source type in
|
||||
@@ -230,7 +230,7 @@ def try_creating_external_group_sync_task(
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
@@ -296,7 +296,7 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
@@ -357,7 +357,7 @@ def connector_external_group_sync_generator_task(
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -384,7 +384,6 @@ def connector_external_group_sync_generator_task(
|
||||
logger.info(
|
||||
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
logger.debug(f"New external user groups: {external_user_groups}")
|
||||
|
||||
replace_user__ext_group_for_cc_pair(
|
||||
db_session=db_session,
|
||||
@@ -409,7 +408,7 @@ def connector_external_group_sync_generator_task(
|
||||
task_logger.exception(msg)
|
||||
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
@@ -460,6 +459,7 @@ def validate_external_group_sync_fences(
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
return
|
||||
|
||||
|
||||
|
||||
@@ -41,18 +41,16 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.interfaces import ConnectorValidationError
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingMode
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
@@ -92,9 +90,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
|
||||
SUCCEEDED = "succeeded"
|
||||
|
||||
SPAWN_FAILED = "spawn_failed" # connector spawn failed
|
||||
SPAWN_NOT_ALIVE = (
|
||||
"spawn_not_alive" # spawn succeeded but process did not come alive
|
||||
)
|
||||
|
||||
BLOCKED_BY_DELETION = "blocked_by_deletion"
|
||||
BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
|
||||
@@ -108,9 +103,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
|
||||
"index_attempt_mismatch" # expected index attempt metadata not found in db
|
||||
)
|
||||
|
||||
CONNECTOR_VALIDATION_ERROR = (
|
||||
"connector_validation_error" # the connector validation failed
|
||||
)
|
||||
CONNECTOR_EXCEPTIONED = "connector_exceptioned" # the connector itself exceptioned
|
||||
WATCHDOG_EXCEPTIONED = "watchdog_exceptioned" # the watchdog exceptioned
|
||||
|
||||
@@ -120,8 +112,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
|
||||
# the watchdog terminated the task due to no activity
|
||||
TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"
|
||||
|
||||
# NOTE: this may actually be the same as SIGKILL, but parsed differently by python
|
||||
# consolidate once we know more
|
||||
OUT_OF_MEMORY = "out_of_memory"
|
||||
|
||||
PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
|
||||
@@ -131,7 +121,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
|
||||
_ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
|
||||
IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
|
||||
IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
|
||||
IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
|
||||
IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
|
||||
IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
|
||||
IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
|
||||
@@ -148,8 +137,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
|
||||
def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
|
||||
_CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
|
||||
-9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
|
||||
137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
|
||||
247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
|
||||
248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
|
||||
249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
|
||||
250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
|
||||
@@ -361,13 +348,12 @@ def monitor_ccpair_indexing_taskset(
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occasionally does some validation of existing state to clear up error conditions"""
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
redis_client = get_redis_client()
|
||||
redis_client_replica = get_redis_replica_client()
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
@@ -405,7 +391,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
# 1/3: KICKOFF
|
||||
|
||||
# check for search settings swap
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
@@ -426,7 +412,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
# gather cc_pair_ids
|
||||
lock_beat.reacquire()
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
@@ -436,7 +422,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
for search_settings_instance in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
@@ -514,7 +500,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
# This shouldn't ever happen!
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
|
||||
db_session, redis_client
|
||||
)
|
||||
@@ -566,7 +552,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(
|
||||
tenant_id, key_bytes, redis_client_replica, db_session
|
||||
)
|
||||
@@ -597,8 +583,8 @@ def connector_indexing_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
is_ee: bool,
|
||||
tenant_id: str | None,
|
||||
is_ee: bool,
|
||||
) -> int | None:
|
||||
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
@@ -649,7 +635,7 @@ def connector_indexing_task(
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
raise SimpleJobException(
|
||||
@@ -743,7 +729,7 @@ def connector_indexing_task(
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
raise SimpleJobException(
|
||||
@@ -778,9 +764,9 @@ def connector_indexing_task(
|
||||
callback = IndexingCallback(
|
||||
os.getppid(),
|
||||
redis_connector,
|
||||
redis_connector_index,
|
||||
lock,
|
||||
r,
|
||||
redis_connector_index,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -802,15 +788,6 @@ def connector_indexing_task(
|
||||
# get back the total number of indexed docs and return it
|
||||
n_final_progress = redis_connector_index.get_progress()
|
||||
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
|
||||
except ConnectorValidationError:
|
||||
raise SimpleJobException(
|
||||
f"Indexing task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}",
|
||||
code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing spawned task failed: attempt={index_attempt_id} "
|
||||
@@ -818,8 +795,8 @@ def connector_indexing_task(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
@@ -930,11 +907,12 @@ def connector_indexing_proxy_task(
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
global_version.is_ee_version(),
|
||||
tenant_id,
|
||||
global_version.is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
|
||||
if not job or not job.process:
|
||||
if not job:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
@@ -945,39 +923,13 @@ def connector_indexing_proxy_task(
|
||||
)
|
||||
return
|
||||
|
||||
# Ensure the process has moved out of the starting state
|
||||
num_waits = 0
|
||||
while True:
|
||||
if num_waits > 15:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
)
|
||||
job.release()
|
||||
return
|
||||
|
||||
if job.process.is_alive() or job.process.exitcode is not None:
|
||||
break
|
||||
|
||||
sleep(1)
|
||||
num_waits += 1
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawn succeeded",
|
||||
pid=str(job.process.pid),
|
||||
)
|
||||
)
|
||||
task_logger.info(log_builder.build("Indexing watchdog - spawn succeeded"))
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
@@ -988,9 +940,6 @@ def connector_indexing_proxy_task(
|
||||
index_attempt.connector_credential_pair.connector.source.value
|
||||
)
|
||||
|
||||
redis_connector_index.set_active() # renew active signal
|
||||
redis_connector_index.set_connector_active() # prime the connector active signal
|
||||
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
@@ -1025,42 +974,10 @@ def connector_indexing_proxy_task(
|
||||
result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
|
||||
break
|
||||
|
||||
if not redis_connector_index.connector_active():
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - activity timeout exceeded",
|
||||
timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Indexing watchdog - activity timeout exceeded: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
|
||||
)
|
||||
except Exception:
|
||||
# if the DB raises an exception, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as failed"
|
||||
)
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
result.status = (
|
||||
IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
|
||||
)
|
||||
break
|
||||
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
@@ -1079,20 +996,16 @@ def connector_indexing_proxy_task(
|
||||
)
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
|
||||
if isinstance(e, ConnectorValidationError):
|
||||
# No need to expose full stack trace for validation errors
|
||||
result.exception_str = str(e)
|
||||
else:
|
||||
result.exception_str = traceback.format_exc()
|
||||
result.exception_str = traceback.format_exc()
|
||||
|
||||
# handle exit and reporting
|
||||
elapsed = time.monotonic() - start
|
||||
if result.exception_str is not None:
|
||||
# print with exception
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
failure_reason = (
|
||||
f"Spawned task exceptioned: exit_code={result.exit_code}"
|
||||
)
|
||||
@@ -1132,7 +1045,7 @@ def connector_indexing_proxy_task(
|
||||
# print without exception
|
||||
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
@@ -1182,7 +1095,7 @@ def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
|
||||
|
||||
try:
|
||||
locked = True
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_attempts = get_index_attempts_with_old_checkpoints(db_session)
|
||||
for attempt in old_attempts:
|
||||
task_logger.info(
|
||||
@@ -1218,5 +1131,5 @@ def cleanup_checkpoint_task(
|
||||
self: Task, *, index_attempt_id: int, tenant_id: str | None
|
||||
) -> None:
|
||||
"""Clean up a checkpoint for a given index attempt"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
cleanup_checkpoint(db_session, index_attempt_id)
|
||||
|
||||
@@ -23,7 +23,7 @@ from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
@@ -93,25 +93,27 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
return unfenced_attempts
|
||||
|
||||
|
||||
class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
class IndexingCallback(IndexingHeartbeatInterface):
|
||||
PARENT_CHECK_INTERVAL = 60
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_pid: int,
|
||||
redis_connector: RedisConnector,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent_pid = parent_pid
|
||||
self.redis_connector: RedisConnector = redis_connector
|
||||
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.redis_client = redis_client
|
||||
self.started: datetime = datetime.now(timezone.utc)
|
||||
self.redis_lock.reacquire()
|
||||
|
||||
self.last_tag: str = f"{self.__class__.__name__}.__init__"
|
||||
self.last_tag: str = "IndexingCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
@@ -125,8 +127,8 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
|
||||
# with daemon=True. It seems likely some indexing tasks will need to spawn other processes
|
||||
# eventually, which daemon=True prevents, so leave this code in until we're ready to test it.
|
||||
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
|
||||
# so leave this code in until we're ready to test it.
|
||||
|
||||
# if self.parent_pid:
|
||||
# # check if the parent pid is alive so we aren't running as a zombie
|
||||
@@ -141,6 +143,8 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
# self.last_parent_check = now
|
||||
|
||||
try:
|
||||
self.redis_connector.prune.set_active()
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - self.last_lock_monotonic >= (
|
||||
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
|
||||
@@ -152,7 +156,7 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
self.last_tag = tag
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"{self.__class__.__name__} - lock.reacquire exceptioned: "
|
||||
f"IndexingCallback - lock.reacquire exceptioned: "
|
||||
f"lock_timeout={self.redis_lock.timeout} "
|
||||
f"start={self.started} "
|
||||
f"last_tag={self.last_tag} "
|
||||
@@ -163,24 +167,6 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
redis_lock_dump(self.redis_lock, self.redis_client)
|
||||
raise
|
||||
|
||||
|
||||
class IndexingCallback(IndexingCallbackBase):
|
||||
def __init__(
|
||||
self,
|
||||
parent_pid: int,
|
||||
redis_connector: RedisConnector,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
):
|
||||
super().__init__(parent_pid, redis_connector, redis_lock, redis_client)
|
||||
|
||||
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
self.redis_connector_index.set_active()
|
||||
self.redis_connector_index.set_connector_active()
|
||||
super().progress(tag, amount)
|
||||
self.redis_client.incrby(
|
||||
self.redis_connector_index.generator_progress_key, amount
|
||||
)
|
||||
@@ -332,7 +318,7 @@ def validate_indexing_fences(
|
||||
if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
continue
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
validate_indexing_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
|
||||
@@ -8,7 +8,7 @@ from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import LLMProvider
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | N
|
||||
return None
|
||||
|
||||
# Then update the database with the fetched models
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the default LLM provider
|
||||
default_provider = (
|
||||
db_session.query(LLMProvider)
|
||||
|
||||
@@ -26,8 +26,7 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
@@ -43,6 +42,7 @@ from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
|
||||
|
||||
@@ -668,7 +668,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
task_logger.info("Starting background monitoring")
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_monitoring: RedisLock = r.lock(
|
||||
OnyxRedisLocks.MONITOR_BACKGROUND_PROCESSES_LOCK,
|
||||
@@ -683,7 +683,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
try:
|
||||
# Get Redis client for Celery broker
|
||||
redis_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_std = get_redis_client()
|
||||
redis_std = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Define metric collection functions and their dependencies
|
||||
metric_functions: list[Callable[[], list[Metric]]] = [
|
||||
@@ -693,7 +693,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
]
|
||||
|
||||
# Collect and log each metric
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
@@ -771,11 +771,12 @@ def cloud_check_alembic() -> bool | None:
|
||||
if tenant_id is None:
|
||||
continue
|
||||
|
||||
with get_session_with_shared_schema() as session:
|
||||
with get_session_with_tenant(tenant_id=None) as session:
|
||||
try:
|
||||
result = session.execute(
|
||||
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
|
||||
)
|
||||
|
||||
result_scalar: str | None = result.scalar_one_or_none()
|
||||
if result_scalar is None:
|
||||
raise ValueError("Alembic version should not be None.")
|
||||
|
||||
@@ -15,7 +15,7 @@ from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import PostgresAdvisoryLocks
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -36,7 +36,7 @@ def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
|
||||
ctx["deleted"] = 0
|
||||
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
|
||||
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Exit the task if we can't take the advisory lock
|
||||
result = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
|
||||
@@ -21,7 +21,7 @@ from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallbackBase
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
@@ -41,7 +41,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import get_documents_for_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
@@ -62,12 +62,6 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class PruneCallback(IndexingCallbackBase):
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
self.redis_connector.prune.set_active()
|
||||
super().progress(tag, amount)
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off pruning tasks."""
|
||||
|
||||
|
||||
@@ -114,8 +108,8 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
bind=True,
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -133,14 +127,14 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
# but pruning only kicks off once per hour
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -188,7 +182,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
@@ -343,7 +337,7 @@ def connector_pruning_generator_task(
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
@@ -401,7 +395,7 @@ def connector_pruning_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
@@ -431,7 +425,6 @@ def connector_pruning_generator_task(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source}"
|
||||
)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
cc_pair.connector.source,
|
||||
@@ -441,11 +434,12 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
redis_connector.new_index(search_settings.id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
|
||||
callback = PruneCallback(
|
||||
callback = IndexingCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
redis_connector_index,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ from onyx.db.document import mark_document_as_modified
|
||||
from onyx.db.document import mark_document_as_synced
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
@@ -79,7 +79,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
start = time.monotonic()
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
action = "skip"
|
||||
chunks_affected = 0
|
||||
|
||||
@@ -205,7 +205,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
|
||||
f"doc={document_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# delete the cc pair relationship now and let reconciliation clean it up
|
||||
# in vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
|
||||
@@ -34,7 +34,7 @@ from onyx.db.document_set import fetch_document_sets
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.document_set import get_document_set_by_id
|
||||
from onyx.db.document_set import mark_document_set_as_synced
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import DocumentSet
|
||||
@@ -84,8 +84,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
@@ -98,7 +98,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
|
||||
try:
|
||||
# 1/3: KICKOFF
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(
|
||||
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
@@ -106,7 +106,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
# region document set scan
|
||||
lock_beat.reacquire()
|
||||
document_set_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
@@ -117,7 +117,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
|
||||
for document_set_id in document_set_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_document_set_sync_tasks(
|
||||
self.app, document_set_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
@@ -136,7 +136,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
pass
|
||||
else:
|
||||
usergroup_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session=db_session, only_up_to_date=False
|
||||
)
|
||||
@@ -146,7 +146,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
|
||||
for usergroup_id in usergroup_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_user_group_sync_tasks(
|
||||
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
@@ -167,7 +167,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
|
||||
monitor_connector_taskset(r)
|
||||
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
|
||||
monitor_usergroup_taskset = (
|
||||
@@ -177,7 +177,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
@@ -523,12 +523,12 @@ def monitor_document_set_taskset(
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, *, tenant_id: str | None
|
||||
self: Task, document_id: str, tenant_id: str | None
|
||||
) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
search_settings=active_search_settings.primary,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from onyx.db.background_error import create_background_error
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
def emit_background_error(
|
||||
@@ -9,5 +9,5 @@ def emit_background_error(
|
||||
"""Currently just saves a row in the background_errors table.
|
||||
|
||||
In the future, could create notifications based on the severity."""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant() as db_session:
|
||||
create_background_error(db_session, message, cc_pair_id)
|
||||
|
||||
@@ -21,7 +21,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.interfaces import ConnectorValidationError
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
@@ -29,7 +28,7 @@ from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
@@ -87,11 +86,6 @@ def _get_connector_runner(
|
||||
credential=attempt.connector_credential_pair.credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# validate the connector settings
|
||||
|
||||
runnable_connector.validate_connector_settings()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
|
||||
@@ -250,7 +244,7 @@ def _run_indexing(
|
||||
"""
|
||||
start_time = time.monotonic() # just used for logging
|
||||
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if not index_attempt_start:
|
||||
raise ValueError(
|
||||
@@ -376,7 +370,7 @@ def _run_indexing(
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
@@ -436,7 +430,7 @@ def _run_indexing(
|
||||
raise ConnectorStopSignal("Connector stop signal detected")
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
# will exception if the connector/index attempt is marked as paused/failed
|
||||
_check_connector_and_attempt_status(
|
||||
db_session_temp, ctx, index_attempt_id
|
||||
@@ -445,7 +439,7 @@ def _run_indexing(
|
||||
# save record of any failures at the connector level
|
||||
if failure is not None:
|
||||
total_failures += 1
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
ctx.cc_pair_id,
|
||||
@@ -509,7 +503,7 @@ def _run_indexing(
|
||||
if document.id not in failed_document_ids
|
||||
]
|
||||
for document_id in successful_document_ids:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
if document_id in doc_id_to_unresolved_errors:
|
||||
logger.info(
|
||||
f"Resolving IndexAttemptError for document '{document_id}'"
|
||||
@@ -522,7 +516,7 @@ def _run_indexing(
|
||||
# add brand new failures
|
||||
if index_pipeline_result.failures:
|
||||
total_failures += len(index_pipeline_result.failures)
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
for failure in index_pipeline_result.failures:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
@@ -539,7 +533,7 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
# NOTE: Postgres uses the start of the transaction when computing `NOW()`
|
||||
# so we need either to commit() or to use a new session
|
||||
update_docs_indexed(
|
||||
@@ -561,7 +555,7 @@ def _run_indexing(
|
||||
check_checkpoint_size(checkpoint)
|
||||
|
||||
# save latest checkpoint
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
save_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
index_attempt_id=index_attempt_id,
|
||||
@@ -573,29 +567,9 @@ def _run_indexing(
|
||||
"Connector run exceptioned after elapsed time: "
|
||||
f"{time.monotonic() - start_time} seconds"
|
||||
)
|
||||
if isinstance(e, ConnectorValidationError):
|
||||
# On validation errors during indexing, we want to cancel the indexing attempt
|
||||
# and mark the CCPair as invalid. This prevents the connector from being
|
||||
# used in the future until the credentials are updated.
|
||||
            with get_session_with_current_tenant() as db_session_temp:
                mark_attempt_canceled(
                    index_attempt_id,
                    db_session_temp,
                    reason=str(e),
                )

                if ctx.is_primary:
                    update_connector_credential_pair(
                        db_session=db_session_temp,
                        connector_id=ctx.connector_id,
                        credential_id=ctx.credential_id,
                        status=ConnectorCredentialPairStatus.INVALID,
                    )
            memory_tracer.stop()
            raise e

        elif isinstance(e, ConnectorStopSignal):
            with get_session_with_current_tenant() as db_session_temp:
        if isinstance(e, ConnectorStopSignal):
            with get_session_with_tenant(tenant_id) as db_session_temp:
                mark_attempt_canceled(
                    index_attempt_id,
                    db_session_temp,
@@ -613,7 +587,7 @@ def _run_indexing(
            memory_tracer.stop()
            raise e
        else:
            with get_session_with_current_tenant() as db_session_temp:
            with get_session_with_tenant(tenant_id) as db_session_temp:
                mark_attempt_failed(
                    index_attempt_id,
                    db_session_temp,
@@ -635,7 +609,7 @@ def _run_indexing(
            memory_tracer.stop()

    elapsed_time = time.monotonic() - start_time
    with get_session_with_current_tenant() as db_session_temp:
    with get_session_with_tenant(tenant_id) as db_session_temp:
        # resolve entity-based errors
        for error in entity_based_unresolved_errors:
            logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
@@ -695,7 +669,7 @@ def run_indexing_entrypoint(
        TaskAttemptSingleton.set_cc_and_index_id(
            index_attempt_id, connector_credential_pair_id
        )
        with get_session_with_current_tenant() as db_session:
        with get_session_with_tenant(tenant_id) as db_session:
            # TODO: remove long running session entirely
            attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)

@@ -716,7 +690,7 @@ def run_indexing_entrypoint(
                f"credentials='{credential_id}'"
            )

        with get_session_with_current_tenant() as db_session:
        with get_session_with_tenant(tenant_id) as db_session:
            _run_indexing(db_session, index_attempt_id, tenant_id, callback)

        logger.info(
@@ -143,10 +143,9 @@ from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import get_current_tenant_id
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"


def _translate_citations(
@@ -343,7 +342,7 @@ def stream_chat_message_objects(
        3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
        4. [always] Details on the final AI response message that is created
    """
    tenant_id = get_current_tenant_id()
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    use_existing_user_message = new_msg_req.use_existing_user_message
    existing_assistant_message_id = new_msg_req.existing_assistant_message_id

@@ -632,7 +631,6 @@ def stream_chat_message_objects(
                    db_session=db_session,
                    commit=False,
                    reserved_message_id=reserved_message_id,
                    is_agentic=new_msg_req.use_agentic_search,
                )

            prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
@@ -1017,7 +1015,7 @@ def stream_chat_message_objects(
                        if info.message_specific_citations
                        else None
                    ),
                    error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
                    error=None,
                    tool_call=(
                        ToolCall(
                            tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
@@ -1055,9 +1053,7 @@ def stream_chat_message_objects(
                    citations=info.message_specific_citations.citation_map
                    if info.message_specific_citations
                    else None,
                    error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
                    refined_answer_improvement=refined_answer_improvement,
                    is_agentic=True,
                )
                next_level += 1
                prev_message = next_answer_message
@@ -178,13 +178,13 @@ AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int(
)


AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 3  # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 2  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
)

AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 5  # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 4  # in seconds
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION")
    or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
@@ -204,7 +204,7 @@ AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
)


AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 4  # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 2  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION
@@ -217,7 +217,7 @@ AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
)


AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4  # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 3  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
@@ -256,7 +256,7 @@ AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
)


AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 4  # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 2  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
@@ -269,7 +269,7 @@ AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int(
)


AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 4  # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 3  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION
@@ -282,7 +282,7 @@ AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int(
)


AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 2  # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 1  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION
@@ -295,7 +295,7 @@ AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int(
)


AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 4  # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 2  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION
@@ -308,7 +308,7 @@ AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int(
)


AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 4  # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 2  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
@@ -321,7 +321,7 @@ AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int(
)


AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 4  # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 2  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION
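
Review note: every timeout in the hunks above uses the same expression, an integer default that an environment variable of the matching name may override. A minimal sketch of that pattern, for readers checking how the changed defaults take effect. The helper `env_int` and the variable name `EXAMPLE_TIMEOUT` are illustrative assumptions and do not appear in the repository.

```python
import os


def env_int(name: str, default: int) -> int:
    """Hypothetical helper mirroring `int(os.environ.get(NAME) or DEFAULT)`."""
    # `or` treats a missing variable (None) and an empty string the same way:
    # both fall through to the default.
    return int(os.environ.get(name) or default)


# Variable not set: the default applies.
os.environ.pop("EXAMPLE_TIMEOUT", None)
unset_value = env_int("EXAMPLE_TIMEOUT", 2)

# Variable set: the string is converted to int and overrides the default.
os.environ["EXAMPLE_TIMEOUT"] = "7"
override_value = env_int("EXAMPLE_TIMEOUT", 2)
```

Under this pattern, lowering a default only affects deployments that do not set the variable, and a non-numeric value raises `ValueError` when the expression is evaluated.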
@@ -98,18 +98,9 @@ CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120

CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120


# hard timeout applied by the watchdog to the indexing connector run
# to handle hung connectors
CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60  # 3 hours (in seconds)

# soft timeout for the lock taken by the indexing connector run
# allows the lock to eventually expire if the managing code around it dies
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
# CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 15 minutes
# hard termination should always fire first if the connector is hung
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900

CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60  # 60 min

# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60  # 5 min

@@ -5,8 +5,6 @@ import requests

class BookStackClientRequestFailedError(ConnectionError):
    def __init__(self, status: int, error: str) -> None:
        self.status_code = status
        self.error = error
        super().__init__(
            "BookStack Client request failed with status {status}: {error}".format(
                status=status, error=error

@@ -7,12 +7,8 @@ from typing import Any
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.bookstack.client import BookStackApiClient
from onyx.connectors.bookstack.client import BookStackClientRequestFailedError
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@@ -218,39 +214,3 @@ class BookstackConnector(LoadConnector, PollConnector):
                    break
                else:
                    time.sleep(0.2)

    def validate_connector_settings(self) -> None:
        """
        Validate that the BookStack credentials and connector settings are correct.
        Specifically checks that we can make an authenticated request to BookStack.
        """
        if not self.bookstack_client:
            raise ConnectorMissingCredentialError(
                "BookStack credentials have not been loaded."
            )

        try:
            # Attempt to fetch a small batch of books (arbitrary endpoint) to verify credentials
            _ = self.bookstack_client.get(
                "/books", params={"count": "1", "offset": "0"}
            )

        except BookStackClientRequestFailedError as e:
            # Check for HTTP status codes
            if e.status_code == 401:
                raise CredentialExpiredError(
                    "Your BookStack credentials appear to be invalid or expired (HTTP 401)."
                ) from e
            elif e.status_code == 403:
                raise InsufficientPermissionsError(
                    "The configured BookStack token does not have sufficient permissions (HTTP 403)."
                ) from e
            else:
                raise ConnectorValidationError(
                    f"Unexpected BookStack error (status={e.status_code}): {e}"
                ) from e

        except Exception as exc:
            raise ConnectorValidationError(
                f"Unexpected error while validating BookStack connector settings: {exc}"
            ) from exc

@@ -8,7 +8,6 @@ from typing import TypeVar
from urllib.parse import quote

from atlassian import Confluence  # type:ignore
from pydantic import BaseModel
from requests import HTTPError

from onyx.utils.logger import setup_logger
@@ -30,16 +29,6 @@ class ConfluenceRateLimitError(Exception):
    pass


class ConfluenceUser(BaseModel):
    user_id: str  # accountId in Cloud, userKey in Server
    username: str | None  # Confluence Cloud doesn't give usernames
    display_name: str
    # Confluence Data Center doesn't give email back by default,
    # have to fetch it with a different endpoint
    email: str | None
    type: str


def _handle_http_error(e: HTTPError, attempt: int) -> int:
    MIN_DELAY = 2
    MAX_DELAY = 60
@@ -286,95 +275,21 @@ class OnyxConfluence(Confluence):
        self,
        expand: str | None = None,
        limit: int | None = None,
    ) -> Iterator[ConfluenceUser]:
    ) -> Iterator[dict[str, Any]]:
        """
        The search/user endpoint can be used to fetch users.
        It's a separate endpoint from the content/search endpoint used only for users.
        Otherwise it's very similar to the content/search endpoint.
        """
        if self.cloud:
            cql = "type=user"
            url = "rest/api/search/user"
            expand_string = f"&expand={expand}" if expand else ""
            url += f"?cql={cql}{expand_string}"
            for user_result in self._paginate_url(url, limit):
                # Example response:
                # {
                #     'user': {
                #         'type': 'known',
                #         'accountId': '712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
                #         'accountType': 'atlassian',
                #         'email': 'chris@danswer.ai',
                #         'publicName': 'Chris Weaver',
                #         'profilePicture': {
                #             'path': '/wiki/aa-avatar/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
                #             'width': 48,
                #             'height': 48,
                #             'isDefault': False
                #         },
                #         'displayName': 'Chris Weaver',
                #         'isExternalCollaborator': False,
                #         '_expandable': {
                #             'operations': '',
                #             'personalSpace': ''
                #         },
                #         '_links': {
                #             'self': 'https://danswerai.atlassian.net/wiki/rest/api/user?accountId=712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d'
                #         }
                #     },
                #     'title': 'Chris Weaver',
                #     'excerpt': '',
                #     'url': '/people/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
                #     'breadcrumbs': [],
                #     'entityType': 'user',
                #     'iconCssClass': 'aui-icon content-type-profile',
                #     'lastModified': '2025-02-18T04:08:03.579Z',
                #     'score': 0.0
                # }
                user = user_result["user"]
                yield ConfluenceUser(
                    user_id=user["accountId"],
                    username=None,
                    display_name=user["displayName"],
                    email=user.get("email"),
                    type=user["accountType"],
                )
        else:
            # https://developer.atlassian.com/server/confluence/rest/v900/api-group-user/#api-rest-api-user-list-get
            # ^ is only available on data center deployments
            # Example response:
            # [
            #     {
            #         'type': 'known',
            #         'username': 'admin',
            #         'userKey': '40281082950c5fe901950c61c55d0000',
            #         'profilePicture': {
            #             'path': '/images/icons/profilepics/default.svg',
            #             'width': 48,
            #             'height': 48,
            #             'isDefault': True
            #         },
            #         'displayName': 'Admin Test',
            #         '_links': {
            #             'self': 'http://localhost:8090/rest/api/user?key=40281082950c5fe901950c61c55d0000'
            #         },
            #         '_expandable': {
            #             'status': ''
            #         }
            #     }
            # ]
            for user in self._paginate_url("rest/api/user/list", limit):
                yield ConfluenceUser(
                    user_id=user["userKey"],
                    username=user["username"],
                    display_name=user["displayName"],
                    email=None,
                    type=user.get("type", "user"),
                )
        cql = "type=user"
        url = "rest/api/search/user" if self.cloud else "rest/api/search"
        expand_string = f"&expand={expand}" if expand else ""
        url += f"?cql={cql}{expand_string}"
        yield from self._paginate_url(url, limit)

    def paginated_groups_by_user_retrieval(
        self,
        user_id: str,  # accountId in Cloud, userKey in Server
        user: dict[str, Any],
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
        """
@@ -382,7 +297,7 @@ class OnyxConfluence(Confluence):
        It's a confluence specific endpoint that can be used to fetch groups.
        """
        user_field = "accountId" if self.cloud else "key"
        user_value = user_id
        user_value = user["accountId"] if self.cloud else user["userKey"]
        # Server uses userKey (but calls it key during the API call), Cloud uses accountId
        user_query = f"{user_field}={quote(user_value)}"
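
Review note: the Confluence hunks above differ in whether callers receive raw response dictionaries or one normalized user record regardless of deployment type. A minimal sketch of the normalization idea follows. It uses a standard-library dataclass named `UserRecord` as a hypothetical stand-in for the pydantic `ConfluenceUser` model, and the function names `normalize_user` and `group_query` are assumptions made for illustration only.

```python
from dataclasses import dataclass
from typing import Any, Optional
from urllib.parse import quote


@dataclass
class UserRecord:
    """Hypothetical stand-in for the ConfluenceUser model in the diff."""

    user_id: str  # accountId in Cloud, userKey in Server
    username: Optional[str]  # Cloud responses carry no username
    display_name: str
    email: Optional[str]  # not returned by the Server user list
    type: str


def normalize_user(raw: dict[str, Any], is_cloud: bool) -> UserRecord:
    """Map one Cloud search result or one Server list entry to UserRecord."""
    if is_cloud:
        user = raw["user"]  # Cloud nests the user under a 'user' key
        return UserRecord(
            user_id=user["accountId"],
            username=None,
            display_name=user["displayName"],
            email=user.get("email"),
            type=user["accountType"],
        )
    return UserRecord(
        user_id=raw["userKey"],
        username=raw["username"],
        display_name=raw["displayName"],
        email=None,
        type=raw.get("type", "user"),
    )


def group_query(user_id: str, is_cloud: bool) -> str:
    """Build the user filter: Cloud uses accountId, Server uses key."""
    field = "accountId" if is_cloud else "key"
    return f"{field}={quote(user_id)}"
```

With a normalized record, the group lookup only needs the `user_id` string and the deployment flag, which is the shape of the `user_id: str` signature in the hunk.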
@@ -1,4 +1,3 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
@@ -25,22 +24,16 @@ def datetime_to_utc(dt: datetime) -> datetime:


def time_str_to_utc(datetime_str: str) -> datetime:
    # Remove all timezone abbreviations in parentheses
    datetime_str = re.sub(r"\([A-Z]+\)", "", datetime_str).strip()

    # Remove any remaining parentheses and their contents
    datetime_str = re.sub(r"\(.*?\)", "", datetime_str).strip()

    try:
        dt = parse(datetime_str)
    except ValueError:
        # Fix common format issues (e.g. "0000" => "+0000")
        # Handle malformed timezone by attempting to fix common format issues
        if "0000" in datetime_str:
            datetime_str = datetime_str.replace(" 0000", " +0000")
            dt = parse(datetime_str)
            # Convert "0000" to "+0000" for proper timezone parsing
            fixed_dt_str = datetime_str.replace(" 0000", " +0000")
            dt = parse(fixed_dt_str)
        else:
            raise

    return datetime_to_utc(dt)
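
Review note: the two steps that appear in this hunk, stripping parenthesized annotations and retrying after adding a sign to a bare ` 0000` offset, can be exercised in isolation. The sketch below is an approximation under stated assumptions: it uses the standard library's `datetime.strptime` with one fixed input layout instead of the `parse` function the diff calls, and the name `time_str_to_utc_sketch` is hypothetical.

```python
import re
from datetime import datetime, timezone

# Assumed input layout for this sketch only,
# e.g. "Tue, 18 Feb 2025 04:08:03 +0000 (UTC)".
_FORMAT = "%a, %d %b %Y %H:%M:%S %z"


def time_str_to_utc_sketch(datetime_str: str) -> datetime:
    # Drop parenthesized annotations such as "(UTC)".
    cleaned = re.sub(r"\(.*?\)", "", datetime_str).strip()
    try:
        dt = datetime.strptime(cleaned, _FORMAT)
    except ValueError:
        # %z requires a sign, so a bare " 0000" offset fails to parse.
        # Retry once with the sign added; anything else is re-raised.
        if " 0000" not in cleaned:
            raise
        dt = datetime.strptime(cleaned.replace(" 0000", " +0000"), _FORMAT)
    return dt.astimezone(timezone.utc)
```

The retry is deliberately narrow: only the unsigned zero offset is repaired, and every other parse failure still surfaces to the caller.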
@@ -4,16 +4,12 @@ from typing import Any

from dropbox import Dropbox  # type: ignore
from dropbox.exceptions import ApiError  # type:ignore
from dropbox.exceptions import AuthError  # type:ignore
from dropbox.files import FileMetadata  # type:ignore
from dropbox.files import FolderMetadata  # type:ignore

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialInvalidError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@@ -145,29 +141,6 @@ class DropboxConnector(LoadConnector, PollConnector):

        return None

    def validate_connector_settings(self) -> None:
        if self.dropbox_client is None:
            raise ConnectorMissingCredentialError("Dropbox credentials not loaded.")

        try:
            self.dropbox_client.files_list_folder(path="", limit=1)
        except AuthError as e:
            logger.exception("Failed to validate Dropbox credentials")
            raise CredentialInvalidError(f"Dropbox credential is invalid: {e.error}")
        except ApiError as e:
            if (
                e.error is not None
                and "insufficient_permissions" in str(e.error).lower()
            ):
                raise InsufficientPermissionsError(
                    "Your Dropbox token does not have sufficient permissions."
                )
            raise ConnectorValidationError(
                f"Unexpected Dropbox error during validation: {e.user_message_text or e}"
            )
        except Exception as e:
            raise Exception(f"Unexpected error during Dropbox settings validation: {e}")


if __name__ == "__main__":
    import os

@@ -31,7 +31,6 @@ from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -53,11 +52,8 @@ from onyx.connectors.wikipedia.connector import WikipediaConnector
from onyx.connectors.xenforo.connector import XenforoConnector
from onyx.connectors.zendesk.connector import ZendeskConnector
from onyx.connectors.zulip.connector import ZulipConnector
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.models import Credential
from onyx.db.models import User


class ConnectorMissingException(Exception):
@@ -178,38 +174,3 @@ def instantiate_connector(
        backend_update_credential_json(credential, new_credentials, db_session)

    return connector


def validate_ccpair_for_user(
    connector_id: int,
    credential_id: int,
    db_session: Session,
    user: User | None,
    tenant_id: str | None,
) -> None:
    # Validate the connector settings
    connector = fetch_connector_by_id(connector_id, db_session)
    credential = fetch_credential_by_id_for_user(
        credential_id,
        user,
        db_session,
        get_editable=False,
    )
    if not credential:
        raise ValueError("Credential not found")
    if not connector:
        raise ValueError("Connector not found")

    try:
        runnable_connector = instantiate_connector(
            db_session=db_session,
            source=connector.source,
            input_type=connector.input_type,
            connector_specific_config=connector.connector_specific_config,
            credential=credential,
            tenant_id=tenant_id,
        )
    except Exception as e:
        raise ConnectorValidationError(str(e))

    runnable_connector.validate_connector_settings()

@@ -181,7 +181,7 @@ class LocalFileConnector(LoadConnector):
        documents: list[Document] = []
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)

        with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
        with get_session_with_tenant(self.tenant_id) as db_session:
            for file_path in self.file_locations:
                current_datetime = datetime.now(timezone.utc)
                files = _read_files_and_metadata(

@@ -187,12 +187,12 @@ class FirefliesConnector(PollConnector, LoadConnector):
        return self._process_transcripts()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
        self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
            "%Y-%m-%dT%H:%M:%S.000Z"
        )
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
        start_datetime = datetime.fromtimestamp(
            start_unixtime, tz=timezone.utc
        ).strftime("%Y-%m-%dT%H:%M:%S.000Z")
        end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
            "%Y-%m-%dT%H:%M:%S.000Z"
        )

@@ -9,7 +9,6 @@ from typing import cast
from github import Github
from github import RateLimitExceededException
from github import Repository
from github.GithubException import GithubException
from github.Issue import Issue
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
@@ -17,20 +16,17 @@ from github.PullRequest import PullRequest
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger


logger = setup_logger()


@@ -230,48 +226,6 @@ class GithubConnector(LoadConnector, PollConnector):

        return self._fetch_from_github(adjusted_start_datetime, end_datetime)

    def validate_connector_settings(self) -> None:
        if self.github_client is None:
            raise ConnectorMissingCredentialError("GitHub credentials not loaded.")

        if not self.repo_owner or not self.repo_name:
            raise ConnectorValidationError(
                "Invalid connector settings: 'repo_owner' and 'repo_name' must be provided."
            )

        try:
            test_repo = self.github_client.get_repo(
                f"{self.repo_owner}/{self.repo_name}"
            )
            test_repo.get_contents("")

        except RateLimitExceededException:
            raise UnexpectedError(
                "Validation failed due to GitHub rate-limits being exceeded. Please try again later."
            )

        except GithubException as e:
            if e.status == 401:
                raise CredentialExpiredError(
                    "GitHub credential appears to be invalid or expired (HTTP 401)."
                )
            elif e.status == 403:
                raise InsufficientPermissionsError(
                    "Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
                )
            elif e.status == 404:
                raise ConnectorValidationError(
                    f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
                )
            else:
                raise ConnectorValidationError(
                    f"Unexpected GitHub error (status={e.status}): {e.data}"
                )
        except Exception as exc:
            raise Exception(
                f"Unexpected error during GitHub settings validation: {exc}"
            )


if __name__ == "__main__":
    import os

@@ -297,7 +297,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
                userId=user_email,
                fields=THREAD_LIST_FIELDS,
                q=query,
                continue_on_404_or_403=True,
            ):
                full_threads = execute_paginated_retrieval(
                    retrieval_function=gmail_service.users().threads().get,

@@ -12,6 +12,7 @@ from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface


SecondsSinceUnixEpoch = float

GenerateDocumentsOutput = Iterator[list[Document]]
@@ -44,14 +45,6 @@ class BaseConnector(abc.ABC):
            raise RuntimeError(custom_parser_req_msg)
        return metadata_lines

    def validate_connector_settings(self) -> None:
        """
        Override this if your connector needs to validate credentials or settings.
        Raise an exception if invalid, otherwise do nothing.

        Default is a no-op (always successful).
        """


# Large set update or reindex, generally pulling a complete state or from a savestate file
class LoadConnector(BaseConnector):
@@ -146,46 +139,3 @@ class CheckpointConnector(BaseConnector):
        ```
        """
        raise NotImplementedError


class ConnectorValidationError(Exception):
    """General exception for connector validation errors."""

    def __init__(self, message: str):
        self.message = message
        super().__init__(self.message)


class UnexpectedError(Exception):
    """Raised when an unexpected error occurs during connector validation.

    Unexpected errors don't necessarily mean the credential is invalid,
    but rather that there was an error during the validation process
    or we encountered a currently unhandled error case.
    """

    def __init__(self, message: str = "Unexpected error during connector validation"):
        super().__init__(message)


class CredentialInvalidError(ConnectorValidationError):
    """Raised when a connector's credential is invalid."""

    def __init__(self, message: str = "Credential is invalid"):
        super().__init__(message)


class CredentialExpiredError(ConnectorValidationError):
    """Raised when a connector's credential is expired."""

    def __init__(self, message: str = "Credential has expired"):
        super().__init__(message)


class InsufficientPermissionsError(ConnectorValidationError):
    """Raised when the credential does not have sufficient API permissions."""

    def __init__(
        self, message: str = "Insufficient permissions for the requested operation"
    ):
        super().__init__(message)
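
Review note: the connector hunks in this compare (BookStack, GitHub, Notion, Jira) each translate an HTTP status code into one of the exception classes listed above. A minimal, self-contained sketch of that mapping follows. The class bodies are trimmed re-declarations so the block runs on its own, and the function `error_for_status` and its `service` parameter are hypothetical names chosen for illustration, not part of the repository.

```python
from typing import Optional


class ConnectorValidationError(Exception):
    """General exception for connector validation errors."""

    def __init__(self, message: str):
        self.message = message
        super().__init__(self.message)


class CredentialExpiredError(ConnectorValidationError):
    """Raised when a connector's credential is expired."""


class InsufficientPermissionsError(ConnectorValidationError):
    """Raised when the credential does not have sufficient API permissions."""


def error_for_status(
    status_code: Optional[int], service: str
) -> ConnectorValidationError:
    """Hypothetical helper: pick the exception type for an HTTP status."""
    if status_code == 401:
        return CredentialExpiredError(
            f"{service} credential appears to be invalid or expired (HTTP 401)."
        )
    if status_code == 403:
        return InsufficientPermissionsError(
            f"{service} token does not have sufficient permissions (HTTP 403)."
        )
    # Any other status falls back to the general validation error.
    return ConnectorValidationError(
        f"Unexpected {service} error (status={status_code})"
    )
```

Because the specific errors subclass `ConnectorValidationError`, a caller can catch the base class once and still inspect the concrete type when it needs to distinguish expired credentials from missing permissions.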
|
||||
|
||||
@@ -7,7 +7,6 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -16,14 +15,10 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rl_requests,
|
||||
)
|
||||
from onyx.connectors.interfaces import ConnectorValidationError
|
||||
from onyx.connectors.interfaces import CredentialExpiredError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import InsufficientPermissionsError
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.batching import batch_generator
|
||||
@@ -621,64 +616,6 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
else:
|
||||
break
|
||||
|
||||
    def validate_connector_settings(self) -> None:
        if not self.headers.get("Authorization"):
            raise ConnectorMissingCredentialError("Notion credentials not loaded.")

        try:
            # We'll do a minimal search call (page_size=1) to confirm accessibility
            if self.root_page_id:
                # If root_page_id is set, fetch the specific page
                res = rl_requests.get(
                    f"https://api.notion.com/v1/pages/{self.root_page_id}",
                    headers=self.headers,
                    timeout=_NOTION_CALL_TIMEOUT,
                )
            else:
                # If root_page_id is not set, perform a minimal search
                test_query = {
                    "filter": {"property": "object", "value": "page"},
                    "page_size": 1,
                }
                res = rl_requests.post(
                    "https://api.notion.com/v1/search",
                    headers=self.headers,
                    json=test_query,
                    timeout=_NOTION_CALL_TIMEOUT,
                )
            res.raise_for_status()

        except requests.exceptions.HTTPError as http_err:
            # Compare against None explicitly: a requests.Response is falsy for
            # 4xx/5xx statuses, so a plain truthiness check would always yield None here.
            status_code = (
                http_err.response.status_code if http_err.response is not None else None
            )

            if status_code == 401:
                raise CredentialExpiredError(
                    "Notion credential appears to be invalid or expired (HTTP 401)."
                )
            elif status_code == 403:
                raise InsufficientPermissionsError(
                    "Your Notion token does not have sufficient permissions (HTTP 403)."
                )
            elif status_code == 404:
                # Typically means resource not found or not shared. Could be root_page_id is invalid.
                raise ConnectorValidationError(
                    "Notion resource not found or not shared with the integration (HTTP 404)."
                )
            elif status_code == 429:
                raise ConnectorValidationError(
                    "Validation failed due to Notion rate-limits being exceeded (HTTP 429). "
                    "Please try again later."
                )
            else:
                raise Exception(
                    f"Unexpected Notion HTTP error (status={status_code}): {http_err}"
                ) from http_err

        except Exception as exc:
            raise Exception(
                f"Unexpected error during Notion settings validation: {exc}"
            )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
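The validation method in the hunk above maps HTTP status codes to connector-specific exception types. A minimal, self-contained sketch of that mapping follows. The exception classes here are local stand-ins that only reuse the names from the diff, and `error_for_status` is a hypothetical helper, not a function from the repository.

```python
from typing import Optional


# Local stand-ins for the exception types named in the diff (assumed shapes).
class ConnectorValidationError(Exception):
    pass


class CredentialExpiredError(Exception):
    pass


class InsufficientPermissionsError(Exception):
    pass


def error_for_status(status_code: Optional[int], service: str) -> Exception:
    """Return the exception that best describes an HTTP error status."""
    if status_code == 401:
        return CredentialExpiredError(
            f"{service} credential appears to be invalid or expired (HTTP 401)."
        )
    if status_code == 403:
        return InsufficientPermissionsError(
            f"{service} token does not have sufficient permissions (HTTP 403)."
        )
    if status_code == 404:
        return ConnectorValidationError(
            f"{service} resource not found or not shared (HTTP 404)."
        )
    if status_code == 429:
        return ConnectorValidationError(
            f"{service} rate limit exceeded (HTTP 429). Please try again later."
        )
    # Anything else (including a missing status) is treated as unexpected.
    return Exception(f"Unexpected {service} HTTP error (status={status_code})")
```

A caller would typically write `raise error_for_status(code, "Notion") from http_err` so that the original traceback is preserved.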
@@ -12,11 +12,8 @@ from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.interfaces import ConnectorValidationError
|
||||
from onyx.connectors.interfaces import CredentialExpiredError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import InsufficientPermissionsError
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
@@ -275,40 +272,6 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
yield slim_doc_batch
|
||||
|
||||
    def validate_connector_settings(self) -> None:
        if self._jira_client is None:
            raise ConnectorMissingCredentialError("Jira")

        if not self._jira_project:
            raise ConnectorValidationError(
                "Invalid connector settings: 'jira_project' must be provided."
            )

        try:
            self.jira_client.project(self._jira_project)

        except Exception as e:
            status_code = getattr(e, "status_code", None)

            if status_code == 401:
                raise CredentialExpiredError(
                    "Jira credential appears to be expired or invalid (HTTP 401)."
                )
            elif status_code == 403:
                raise InsufficientPermissionsError(
                    "Your Jira token does not have sufficient permissions for this project (HTTP 403)."
                )
            elif status_code == 404:
                raise ConnectorValidationError(
                    f"Jira project not found with key: {self._jira_project}"
                )
            elif status_code == 429:
                raise ConnectorValidationError(
                    "Validation failed due to Jira rate-limits being exceeded. Please try again later."
                )
            else:
                raise Exception(f"Unexpected Jira error during validation: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
@@ -25,12 +25,8 @@ from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_TOKEN_URL
|
||||
from onyx.configs.app_configs import WEB_CONNECTOR_VALIDATE_URLS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import ConnectorValidationError
|
||||
from onyx.connectors.interfaces import CredentialExpiredError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import InsufficientPermissionsError
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import UnexpectedError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
@@ -41,8 +37,6 @@ from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
|
||||
|
||||
|
||||
class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
|
||||
# Given a base site, index everything under that path
|
||||
@@ -176,35 +170,26 @@ def start_playwright() -> Tuple[Playwright, BrowserContext]:
|
||||
|
||||
|
||||
def extract_urls_from_sitemap(sitemap_url: str) -> list[str]:
|
||||
try:
|
||||
response = requests.get(sitemap_url)
|
||||
response.raise_for_status()
|
||||
response = requests.get(sitemap_url)
|
||||
response.raise_for_status()
|
||||
|
||||
soup = BeautifulSoup(response.content, "html.parser")
|
||||
urls = [
|
||||
_ensure_absolute_url(sitemap_url, loc_tag.text)
|
||||
for loc_tag in soup.find_all("loc")
|
||||
]
|
||||
soup = BeautifulSoup(response.content, "html.parser")
|
||||
urls = [
|
||||
_ensure_absolute_url(sitemap_url, loc_tag.text)
|
||||
for loc_tag in soup.find_all("loc")
|
||||
]
|
||||
|
||||
if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
|
||||
# the given url doesn't look like a sitemap, let's try to find one
|
||||
urls = list_pages_for_site(sitemap_url)
|
||||
if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
|
||||
# the given url doesn't look like a sitemap, let's try to find one
|
||||
urls = list_pages_for_site(sitemap_url)
|
||||
|
||||
if len(urls) == 0:
|
||||
raise ValueError(
|
||||
f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
|
||||
)
|
||||
|
||||
return urls
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"Failed to fetch sitemap from {sitemap_url}: {e}")
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error processing sitemap {sitemap_url}: {e}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Unexpected error while processing sitemap {sitemap_url}: {e}"
|
||||
if len(urls) == 0:
|
||||
raise ValueError(
|
||||
f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
|
||||
)
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def _ensure_absolute_url(source_url: str, maybe_relative_url: str) -> str:
|
||||
if not urlparse(maybe_relative_url).netloc:
|
||||
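The sitemap hunk above collects every `<loc>` entry and resolves relative entries against the sitemap URL. The sketch below illustrates the same idea using only the standard library: it swaps BeautifulSoup for `xml.etree.ElementTree`, and both function names are hypothetical. The fallback to `list_pages_for_site` from the diff is not reproduced.

```python
import xml.etree.ElementTree as ET
from typing import List
from urllib.parse import urljoin, urlparse


def ensure_absolute_url(source_url: str, maybe_relative_url: str) -> str:
    """Resolve a relative <loc> value against the sitemap's own URL."""
    if not urlparse(maybe_relative_url).netloc:
        return urljoin(source_url, maybe_relative_url)
    return maybe_relative_url


def urls_from_sitemap_xml(sitemap_url: str, xml_text: str) -> List[str]:
    """Extract all <loc> values from sitemap XML that was already fetched."""
    root = ET.fromstring(xml_text)
    urls = []
    for elem in root.iter():
        # Namespaced tags look like "{namespace}loc"; compare the local name only.
        local_name = elem.tag.rsplit("}", 1)[-1]
        if local_name == "loc" and elem.text:
            urls.append(ensure_absolute_url(sitemap_url, elem.text.strip()))
    if not urls:
        raise ValueError(f"No URLs found in sitemap {sitemap_url}.")
    return urls
```

Keeping the fetch and the parse separate, as assumed here, makes the parsing step easy to test without network access.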
@@ -240,14 +225,10 @@ class WebConnector(LoadConnector):
|
||||
web_connector_type: str = WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value,
|
||||
mintlify_cleanup: bool = True, # Mostly ok to apply to other websites as well
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
scroll_before_scraping: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.mintlify_cleanup = mintlify_cleanup
|
||||
self.batch_size = batch_size
|
||||
self.recursive = False
|
||||
self.scroll_before_scraping = scroll_before_scraping
|
||||
self.web_connector_type = web_connector_type
|
||||
|
||||
if web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value:
|
||||
self.recursive = True
|
||||
@@ -363,18 +344,6 @@ class WebConnector(LoadConnector):
|
||||
continue
|
||||
visited_links.add(current_url)
|
||||
|
||||
if self.scroll_before_scraping:
|
||||
scroll_attempts = 0
|
||||
previous_height = page.evaluate("document.body.scrollHeight")
|
||||
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
|
||||
page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
|
||||
page.wait_for_load_state("networkidle", timeout=30000)
|
||||
new_height = page.evaluate("document.body.scrollHeight")
|
||||
if new_height == previous_height:
|
||||
break # Stop scrolling when no more content is loaded
|
||||
previous_height = new_height
|
||||
scroll_attempts += 1
|
||||
|
||||
content = page.content()
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
|
||||
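The `scroll_before_scraping` block in the hunk above scrolls until the page height stops growing or an attempt cap is reached. A minimal sketch of that loop, written as a standalone function, follows. `scroll_to_bottom` and `FakePage` are hypothetical names; `FakePage` only imitates the two Playwright page methods the loop calls, so the example runs without a browser.

```python
MAX_SCROLL_ATTEMPTS = 20  # mirrors WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS in the diff


def scroll_to_bottom(page, max_attempts: int = MAX_SCROLL_ATTEMPTS) -> int:
    """Scroll until the document height stops changing; return scrolls that loaded content."""
    attempts = 0
    previous_height = page.evaluate("document.body.scrollHeight")
    while attempts < max_attempts:
        page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
        page.wait_for_load_state("networkidle", timeout=30000)
        new_height = page.evaluate("document.body.scrollHeight")
        if new_height == previous_height:
            break  # no new content was loaded
        previous_height = new_height
        attempts += 1
    return attempts


class FakePage:
    """Test double: each scroll reveals the next height in the list, if any."""

    def __init__(self, heights):
        self._heights = list(heights)
        self._index = 0

    def evaluate(self, expression):
        if expression.startswith("window.scrollTo"):
            if self._index < len(self._heights) - 1:
                self._index += 1
            return None
        return self._heights[self._index]

    def wait_for_load_state(self, state, timeout=None):
        return None
```

With a real Playwright `Page`, the same function could be called right after navigation and before reading `page.content()`.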
@@ -433,53 +402,6 @@ class WebConnector(LoadConnector):
|
||||
raise RuntimeError(last_error)
|
||||
raise RuntimeError("No valid pages found.")
|
||||
|
||||
    def validate_connector_settings(self) -> None:
        # Make sure we have at least one valid URL to check
        if not self.to_visit_list:
            raise ConnectorValidationError(
                "No URL configured. Please provide at least one valid URL."
            )

        if self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.SITEMAP.value:
            return None

        # We'll just test the first URL for connectivity and correctness
        test_url = self.to_visit_list[0]

        # Check that the URL is allowed and well-formed
        try:
            protected_url_check(test_url)
        except ValueError as e:
            raise ConnectorValidationError(
                f"Protected URL check failed for '{test_url}': {e}"
            )
        except ConnectionError as e:
            # Typically DNS or other network issues
            raise ConnectorValidationError(str(e))

        # Make a quick request to see if we get a valid response
        try:
            check_internet_connection(test_url)
        except Exception as e:
            err_str = str(e)
            if "401" in err_str:
                raise CredentialExpiredError(
                    f"Unauthorized access to '{test_url}': {e}"
                )
            elif "403" in err_str:
                raise InsufficientPermissionsError(
                    f"Forbidden access to '{test_url}': {e}"
                )
            elif "404" in err_str:
                raise ConnectorValidationError(f"Page not found for '{test_url}': {e}")
            elif "Max retries exceeded" in err_str and "NameResolutionError" in err_str:
                raise ConnectorValidationError(
                    f"Unable to resolve hostname for '{test_url}'. Please check the URL and your internet connection."
                )
            else:
                # Could be a 5xx or another error, treat as unexpected
                raise UnexpectedError(f"Unexpected error validating '{test_url}': {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = WebConnector("https://docs.onyx.app/")
|
||||
|
||||
@@ -23,6 +23,7 @@ from onyx.context.search.preprocessing.access_filters import (
|
||||
from onyx.context.search.retrieval.search_runner import (
|
||||
remove_stop_words_and_punctuation,
|
||||
)
|
||||
from onyx.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -34,7 +35,6 @@ from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -166,7 +166,7 @@ def retrieval_preprocessing(
|
||||
time_cutoff=time_filter or predicted_time_cutoff,
|
||||
tags=preset_filters.tags, # Tags are never auto-extracted
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get() if MULTI_TENANT else None,
|
||||
)
|
||||
|
||||
llm_evaluation_type = LLMEvaluationType.BASIC
|
||||
|
||||
@@ -17,7 +17,7 @@ from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
def get_api_key_email_pattern() -> str:
|
||||
@@ -71,7 +71,7 @@ def insert_api_key(
|
||||
std_password_helper = PasswordHelper()
|
||||
|
||||
# Get tenant_id from context var (will be default schema for single tenant)
|
||||
tenant_id = get_current_tenant_id()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
|
||||
api_key_user_id = uuid.uuid4()
|
||||
|
||||
@@ -629,7 +629,6 @@ def create_new_chat_message(
|
||||
reserved_message_id: int | None = None,
|
||||
overridden_model: str | None = None,
|
||||
refined_answer_improvement: bool | None = None,
|
||||
is_agentic: bool = False,
|
||||
) -> ChatMessage:
|
||||
if reserved_message_id is not None:
|
||||
# Edit existing message
|
||||
@@ -651,7 +650,7 @@ def create_new_chat_message(
|
||||
existing_message.alternate_assistant_id = alternate_assistant_id
|
||||
existing_message.overridden_model = overridden_model
|
||||
existing_message.refined_answer_improvement = refined_answer_improvement
|
||||
existing_message.is_agentic = is_agentic
|
||||
|
||||
new_chat_message = existing_message
|
||||
else:
|
||||
# Create new message
|
||||
@@ -671,7 +670,6 @@ def create_new_chat_message(
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
overridden_model=overridden_model,
|
||||
refined_answer_improvement=refined_answer_improvement,
|
||||
is_agentic=is_agentic,
|
||||
)
|
||||
db_session.add(new_chat_message)
|
||||
|
||||
@@ -962,7 +960,6 @@ def translate_db_message_to_chat_message_detail(
|
||||
chat_message.sub_questions
|
||||
),
|
||||
refined_answer_improvement=chat_message.refined_answer_improvement,
|
||||
error=chat_message.error,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -14,7 +14,6 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
@@ -246,10 +245,6 @@ def swap_credentials_connector(
|
||||
existing_pair.credential_id = new_credential_id
|
||||
existing_pair.credential = new_credential
|
||||
|
||||
# Update ccpair status if it's in INVALID state
|
||||
if existing_pair.status == ConnectorCredentialPairStatus.INVALID:
|
||||
existing_pair.status = ConnectorCredentialPairStatus.ACTIVE
|
||||
|
||||
# Commit the changes
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import ssl
|
||||
@@ -15,6 +16,7 @@ from typing import ContextManager
|
||||
import asyncpg # type: ignore
|
||||
import boto3
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import text
|
||||
@@ -42,13 +44,13 @@ from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -263,7 +265,7 @@ def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
if not MULTI_TENANT:
|
||||
return [None]
|
||||
|
||||
with get_session_with_shared_schema() as session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
|
||||
result = session.execute(
|
||||
text(
|
||||
f"""
|
||||
@@ -351,6 +353,38 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
async def get_current_tenant_id(request: Request) -> str:
    if not MULTI_TENANT:
        tenant_id = POSTGRES_DEFAULT_SCHEMA
        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
        return tenant_id

    try:
        # Look up token data in Redis
        token_data = await retrieve_auth_token_data_from_redis(request)

        if not token_data:
            current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
            logger.debug(
                f"Token data not found or expired in Redis, defaulting to {current_value}"
            )
            return current_value

        tenant_id = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)

        if not is_valid_schema_name(tenant_id):
            raise HTTPException(status_code=400, detail="Invalid tenant ID format")

        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
        return tenant_id
    except json.JSONDecodeError:
        logger.error("Error decoding token data from Redis")
        return POSTGRES_DEFAULT_SCHEMA
    except Exception as e:
        logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
# Listen for events on the synchronous Session class
|
||||
@event.listens_for(Session, "after_begin")
|
||||
def _set_search_path(
|
||||
@@ -376,7 +410,7 @@ async def get_async_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> AsyncGenerator[AsyncSession, None]:
|
||||
if tenant_id is None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
logger.error(f"Invalid tenant ID: {tenant_id}")
|
||||
@@ -399,80 +433,82 @@ async def get_async_session_with_tenant(
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_current_tenant() -> Generator[Session, None, None]:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as session:
|
||||
def get_session_with_default_tenant() -> Generator[Session, None, None]:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
with get_session_with_tenant(tenant_id) as session:
|
||||
yield session
|
||||
|
||||
|
||||
# Used in multi-tenant mode when we need to refer to the shared `public` schema
|
||||
@contextmanager
|
||||
def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
|
||||
yield session
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None, None]:
|
||||
def get_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session for a specific tenant.
|
||||
This function:
|
||||
1. Sets the database schema to the specified tenant's schema.
|
||||
2. Preserves the tenant ID across the session.
|
||||
3. Reverts to the previous tenant ID after the session is closed.
|
||||
4. Uses the default schema if no tenant ID is provided.
|
||||
"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID")
|
||||
|
||||
with engine.connect() as connection:
|
||||
dbapi_connection = connection.connection
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
cursor.execute(
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
with engine.connect() as connection:
|
||||
dbapi_connection = connection.connection
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
yield session
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
cursor.execute(
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
cursor.close()
|
||||
|
||||
with Session(bind=connection, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
if MULTI_TENANT:
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute('SET search_path TO "$user", public')
|
||||
finally:
|
||||
cursor.close()
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
|
||||
|
||||
|
||||
def set_search_path_on_checkout(
|
||||
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id and is_valid_schema_name(tenant_id):
|
||||
with dbapi_conn.cursor() as cursor:
|
||||
cursor.execute(f'SET search_path TO "{tenant_id}"')
|
||||
|
||||
|
||||
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
|
||||
tenant_id = get_current_tenant_id()
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as session:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
with get_session_with_tenant(tenant_id) as session:
|
||||
yield session
|
||||
|
||||
|
||||
def get_session() -> Generator[Session, None, None]:
|
||||
tenant_id = get_current_tenant_id()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise BasicAuthenticationError(detail="User must authenticate")
|
||||
|
||||
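The session helpers in this hunk temporarily switch the tenant stored in a context variable and restore the previous value afterwards. A minimal sketch of that pattern follows, using the standard `contextvars` token API with `try`/`finally` so the previous value is restored even when the body raises. `CURRENT_TENANT`, `tenant_scope`, and the `"public"` default are illustrative assumptions, not names or values from the repository.

```python
import contextvars
from contextlib import contextmanager
from typing import Iterator

# Hypothetical stand-in for CURRENT_TENANT_ID_CONTEXTVAR; the default is assumed.
CURRENT_TENANT: contextvars.ContextVar = contextvars.ContextVar(
    "current_tenant", default="public"
)


@contextmanager
def tenant_scope(tenant_id: str) -> Iterator[str]:
    """Set the current tenant for the duration of the block, then restore it."""
    token = CURRENT_TENANT.set(tenant_id)
    try:
        yield tenant_id
    finally:
        # reset() restores whatever value was active before set(), which
        # also makes nested scopes unwind in the right order.
        CURRENT_TENANT.reset(token)
```

Usage would look like `with tenant_scope("tenant_a"): ...`, with database work inside the block reading `CURRENT_TENANT.get()`.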
@@ -487,7 +523,7 @@ def get_session() -> Generator[Session, None, None]:
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
tenant_id = get_current_tenant_id()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
engine = get_sqlalchemy_async_engine()
|
||||
async with AsyncSession(engine, expire_on_commit=False) as async_session:
|
||||
if MULTI_TENANT:
|
||||
|
||||
@@ -73,7 +73,6 @@ class ConnectorCredentialPairStatus(str, PyEnum):
|
||||
ACTIVE = "ACTIVE"
|
||||
PAUSED = "PAUSED"
|
||||
DELETING = "DELETING"
|
||||
INVALID = "INVALID"
|
||||
|
||||
def is_active(self) -> bool:
|
||||
return self == ConnectorCredentialPairStatus.ACTIVE
|
||||
|
||||
@@ -148,12 +148,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
putting here for simplicity
|
||||
"""
|
||||
|
||||
temperature_override_enabled: Mapped[bool | None] = mapped_column(
|
||||
Boolean, default=None
|
||||
)
|
||||
auto_scroll: Mapped[bool | None] = mapped_column(Boolean, default=None)
|
||||
# if specified, controls the assistants that are shown to the user + their order
|
||||
# if not specified, all assistants are shown
|
||||
temperature_override_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
chosen_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
@@ -205,13 +204,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
|
||||
)
|
||||
|
||||
@property
|
||||
def password_configured(self) -> bool:
|
||||
"""
|
||||
Returns True if the user has no OAuth (or OIDC) accounts.
|
||||
"""
|
||||
return not bool(self.oauth_accounts)
|
||||
|
||||
|
||||
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
|
||||
pass
|
||||
@@ -350,9 +342,7 @@ class Document__Tag(Base):
|
||||
document_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("document.id"), primary_key=True
|
||||
)
|
||||
tag_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("tag.id"), primary_key=True, index=True
|
||||
)
|
||||
tag_id: Mapped[int] = mapped_column(ForeignKey("tag.id"), primary_key=True)
|
||||
|
||||
|
||||
class Persona__Tool(Base):
|
||||
@@ -1231,7 +1221,6 @@ class ChatMessage(Base):
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
is_agentic: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
refined_answer_improvement: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
||||
|
||||
chat_session: Mapped[ChatSession] = relationship("ChatSession")
|
||||
@@ -1753,7 +1742,6 @@ class ChannelConfig(TypedDict):
|
||||
# If empty list, follow up with no tags
|
||||
follow_up_tags: NotRequired[list[str]]
|
||||
show_continue_in_web_ui: NotRequired[bool] # defaults to False
|
||||
disabled: NotRequired[bool] # defaults to False
|
||||
|
||||
|
||||
class SlackChannelConfig(Base):
|
||||
@@ -1777,7 +1765,6 @@ class SlackChannelConfig(Base):
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
persona: Mapped[Persona | None] = relationship("Persona")
|
||||
|
||||
slack_bot: Mapped["SlackBot"] = relationship(
|
||||
"SlackBot",
|
||||
back_populates="slack_channel_configs",
|
||||
|
||||
@@ -13,7 +13,7 @@ from onyx.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from onyx.context.search.models import SavedSearchSettings
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_default_tenant
|
||||
from onyx.db.llm import fetch_embedding_provider
|
||||
from onyx.db.models import CloudEmbeddingProvider
|
||||
from onyx.db.models import IndexAttempt
|
||||
@@ -189,7 +189,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
|
||||
|
||||
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
|
||||
if db_session is None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
else:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
@@ -151,7 +151,6 @@ def update_slack_channel_config(
|
||||
channel_config: ChannelConfig,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
disabled: bool,
|
||||
) -> SlackChannelConfig:
|
||||
slack_channel_config = db_session.scalar(
|
||||
select(SlackChannelConfig).where(
|
||||
|
||||
@@ -73,7 +73,7 @@ from onyx.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
|
||||
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
|
||||
from onyx.document_index.vespa_constants import YQL_BASE
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.key_value_store.factory import get_shared_kv_store
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -193,7 +193,7 @@ class VespaIndex(DocumentIndex):
|
||||
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
|
||||
)
|
||||
|
||||
kv_store = get_shared_kv_store()
|
||||
kv_store = get_kv_store()
|
||||
|
||||
needs_reindexing = False
|
||||
try:
|
||||
@@ -240,9 +240,6 @@ class VespaIndex(DocumentIndex):
|
||||
headers = {"Content-Type": "application/zip"}
|
||||
response = requests.post(deploy_url, headers=headers, data=zip_file)
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"Failed to prepare Vespa Onyx Index. Response: {response.text}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to prepare Vespa Onyx Index. Response: {response.text}"
|
||||
)
|
||||
@@ -280,7 +277,7 @@ class VespaIndex(DocumentIndex):
|
||||
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
|
||||
)
|
||||
|
||||
kv_store = get_shared_kv_store()
|
||||
kv_store = get_kv_store()
|
||||
|
||||
needs_reindexing = False
|
||||
try:
|
||||
|
||||
@@ -36,9 +36,7 @@ def build_vespa_filters(
|
||||
|
||||
eq_elems = [f'{key} contains "{elem}"' for elem in valid_vals]
|
||||
or_clause = " or ".join(eq_elems)
|
||||
result = f"({or_clause}) and "
|
||||
|
||||
return result
|
||||
return f"({or_clause}) and "
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
|
||||
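The filter hunk above joins one `contains` clause per value with `or` and appends a trailing `and` so that further filters can be chained. A small sketch of that string construction follows. `build_or_filter` is a hypothetical helper, and returning an empty string when there are no usable values is an assumption, since that part of the original function is not visible in the diff.

```python
from typing import Iterable, Optional


def build_or_filter(key: str, vals: Optional[Iterable[str]]) -> str:
    """Build '(key contains "a" or key contains "b") and ' for the given values."""
    valid_vals = [val for val in (vals or []) if val]
    if not valid_vals:
        # Assumed behavior: contribute nothing to the query when there is no filter.
        return ""
    eq_elems = [f'{key} contains "{elem}"' for elem in valid_vals]
    or_clause = " or ".join(eq_elems)
    return f"({or_clause}) and "
```

Because each fragment ends in `and `, fragments can be concatenated directly, provided the final query ends with a clause that is always present.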
@@ -320,13 +320,7 @@ def eml_to_text(file: IO[Any]) -> str:
|
||||
text_content = []
|
||||
for part in message.walk():
|
||||
if part.get_content_type().startswith("text/plain"):
|
||||
payload = part.get_payload()
|
||||
if isinstance(payload, str):
|
||||
text_content.append(payload)
|
||||
elif isinstance(payload, list):
|
||||
text_content.extend(item for item in payload if isinstance(item, str))
|
||||
else:
|
||||
logger.warning(f"Unexpected payload type: {type(payload)}")
|
||||
text_content.append(part.get_payload())
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
|
||||
|
||||
|
||||
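The `eml_to_text` hunk above shows a variant that checks the payload type before collecting it, because `get_payload()` can return either a string or a list. A minimal sketch of that defensive handling follows, using only the standard `email` package. `plain_text_parts` is a hypothetical name, and the separator and logging from the diff are left out.

```python
from email import message_from_string
from email.message import Message
from typing import List


def plain_text_parts(message: Message) -> List[str]:
    """Collect string payloads from every text/plain part of a message."""
    parts: List[str] = []
    for part in message.walk():
        if not part.get_content_type().startswith("text/plain"):
            continue
        payload = part.get_payload()
        if isinstance(payload, str):
            parts.append(payload)
        elif isinstance(payload, list):
            # Keep only string items; nested Message objects are skipped here.
            parts.extend(item for item in payload if isinstance(item, str))
    return parts


raw_email = (
    "From: sender@example.com\n"
    "Subject: greeting\n"
    "Content-Type: text/plain\n"
    "\n"
    "hello world\n"
)
parsed = message_from_string(raw_email)
text_parts = plain_text_parts(parsed)
```

Filtering by type keeps a later `str.join` over the collected parts from failing on non-string items.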
@@ -57,7 +57,7 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
|
||||
"""NOTE: using multiple sessions here, since this is often called
|
||||
using multithreading. In practice, sharing a session has resulted in
|
||||
weird errors."""
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -76,7 +76,7 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
|
||||
|
||||
|
||||
def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
unique_id = str(uuid4())
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store.save_file(
|
||||
|
||||
@@ -1,18 +1,8 @@
|
||||
from onyx.key_value_store.interface import KeyValueStore
|
||||
from onyx.key_value_store.store import PgRedisKVStore
|
||||
from shared_configs.configs import DEFAULT_REDIS_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
def get_kv_store() -> KeyValueStore:
|
||||
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
|
||||
# It's read from the global thread level variable
|
||||
return PgRedisKVStore()
|
||||
|
||||
|
||||
def get_shared_kv_store() -> KeyValueStore:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(DEFAULT_REDIS_PREFIX)
|
||||
try:
|
||||
return get_kv_store()
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
@@ -434,17 +434,7 @@ class DefaultMultiLLM(LLM):
|
||||
# or else OpenAI throws an error
|
||||
**(
|
||||
{"parallel_tool_calls": False}
|
||||
if tools
|
||||
and self.config.model_name
|
||||
not in [
|
||||
"o3-mini",
|
||||
"o3-preview",
|
||||
"o1",
|
||||
"o1-preview",
|
||||
"o1-mini",
|
||||
"o1-mini-2024-09-12",
|
||||
"o3-mini-2025-01-31",
|
||||
]
|
||||
if tools and self.config.model_name != "o3-mini"
|
||||
else {}
|
||||
), # TODO: remove once LITELLM has patched
|
||||
**(
|
||||
|
||||
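The hunk above conditionally adds `parallel_tool_calls` to the completion arguments by unpacking either a one-entry dict or an empty dict. A minimal sketch of that idiom follows. `build_completion_kwargs` and the contents of `MODELS_WITHOUT_PARALLEL_TOOL_CALLS` are illustrative assumptions; the set only borrows a few of the model names that appear in the diff.

```python
from typing import Any, Dict, List, Optional

# Assumed example set; the real exclusion list lives in the repository.
MODELS_WITHOUT_PARALLEL_TOOL_CALLS = {"o3-mini", "o1", "o1-mini", "o1-preview"}


def build_completion_kwargs(
    model_name: str, tools: Optional[List[Dict[str, Any]]]
) -> Dict[str, Any]:
    """Assemble keyword arguments, including optional keys only when they apply."""
    send_flag = bool(tools) and model_name not in MODELS_WITHOUT_PARALLEL_TOOL_CALLS
    return {
        "model": model_name,
        # Unpacking an empty dict adds nothing, so the key is simply absent.
        **({"tools": tools} if tools else {}),
        **({"parallel_tool_calls": False} if send_flag else {}),
    }
```

Omitting the key entirely, rather than passing `None`, matters when the receiving API rejects unknown or unsupported parameters for some models.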
@@ -61,7 +61,6 @@ from onyx.server.features.input_prompt.api import (
|
||||
basic_router as input_prompt_router,
|
||||
)
|
||||
from onyx.server.features.notifications.api import router as notification_router
|
||||
from onyx.server.features.password.api import router as password_router
|
||||
from onyx.server.features.persona.api import admin_router as admin_persona_router
|
||||
from onyx.server.features.persona.api import basic_router as persona_router
|
||||
from onyx.server.features.tool.api import admin_router as admin_tool_router
|
||||
@@ -282,7 +281,6 @@ def get_application() -> FastAPI:
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, log_http_error
|
||||
)
|
||||
|
||||
include_router_with_global_prefix_prepended(application, password_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
include_router_with_global_prefix_prepended(application, document_router)
|
||||
|
||||
@@ -415,7 +415,7 @@ def _build_continue_in_web_ui_block(
|
||||
) -> Block:
|
||||
if message_id is None:
|
||||
raise ValueError("No message id provided to build continue in web ui block")
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
chat_session = get_chat_session_by_message_id(
|
||||
db_session=db_session,
|
||||
message_id=message_id,
|
||||
|
||||
@@ -114,7 +114,7 @@ def handle_generate_answer_button(
|
||||
thread_ts=thread_ts,
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
with get_session_with_tenant(client.tenant_id) as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
@@ -155,7 +155,7 @@ def handle_slack_feedback(
|
||||
) -> None:
|
||||
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
|
||||
create_chat_message_feedback(
|
||||
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
|
||||
@@ -246,7 +246,7 @@ def handle_followup_button(
|
||||
|
||||
tag_ids: list[str] = []
|
||||
group_ids: list[str] = []
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
with get_session_with_tenant(client.tenant_id) as db_session:
|
||||
channel_name, is_dm = get_channel_name_from_id(
|
||||
client=client.web_client, channel_id=channel_id
|
||||
)
|
||||
|
||||
@@ -180,13 +180,6 @@ def handle_message(
|
||||
)
|
||||
return False
|
||||
|
||||
if slack_channel_config.channel_config.get("disabled") and not bypass_filters:
|
||||
logger.info(
|
||||
"Skipping message since the channel is configured such that "
|
||||
"OnyxBot is disabled"
|
||||
)
|
||||
return False
|
||||
|
||||
# List of user id to send message to, if None, send to everyone in channel
|
||||
send_to: list[str] | None = None
|
||||
missing_users: list[str] | None = None
|
||||
@@ -218,7 +211,7 @@ def handle_message(
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
if message_info.email:
|
||||
add_slack_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ from onyx.configs.onyxbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
@@ -87,7 +86,7 @@ def handle_regular_answer(
|
||||
user = None
|
||||
if message_info.is_bot_dm:
|
||||
if message_info.email:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
@@ -96,7 +95,7 @@ def handle_regular_answer(
|
||||
# This way slack flow always has a persona
|
||||
persona = slack_channel_config.persona
|
||||
if not persona:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
@@ -108,7 +107,7 @@ def handle_regular_answer(
|
||||
]
|
||||
prompt = persona.prompts[0] if persona.prompts else None
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
expecting_search_result = persona_has_search_tool(persona.id, db_session)
|
||||
|
||||
# TODO: Add in support for Slack to truncate messages based on max LLM context
|
||||
@@ -157,7 +156,7 @@ def handle_regular_answer(
|
||||
def _get_slack_answer(
|
||||
new_message_request: CreateChatMessageRequest, onyx_user: User | None
|
||||
) -> ChatOnyxBotResponse:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=new_message_request,
|
||||
user=onyx_user,
|
||||
@@ -197,7 +196,7 @@ def handle_regular_answer(
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
answer_request = prepare_chat_message_request(
|
||||
message_text=user_message.message,
|
||||
user=user,
|
||||
|
||||
@@ -251,7 +251,7 @@ class SlackbotHandler:
|
||||
"""
|
||||
all_tenants = get_all_tenant_ids()
|
||||
|
||||
token: Token[str | None]
|
||||
token: Token[str]
|
||||
|
||||
# 1) Try to acquire locks for new tenants
|
||||
for tenant_id in all_tenants:
|
||||
@@ -300,7 +300,7 @@ class SlackbotHandler:
|
||||
tenant_id or POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
bots: list[SlackBot] = []
|
||||
try:
|
||||
bots = list(fetch_slack_bots(db_session=db_session))
|
||||
@@ -340,7 +340,7 @@ class SlackbotHandler:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Attempt to fetch Slack bots
|
||||
try:
|
||||
bots = list(fetch_slack_bots(db_session=db_session))
|
||||
@@ -564,7 +564,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
channel_name, _ = get_channel_name_from_id(
|
||||
client=client.web_client, channel_id=channel
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
with get_session_with_tenant(client.tenant_id) as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
@@ -788,13 +788,13 @@ def process_message(
|
||||
client=client.web_client, channel_id=channel
|
||||
)
|
||||
|
||||
token: Token[str | None] | None = None
|
||||
token: Token[str] | None = None
|
||||
# Set the current tenant ID at the beginning for all DB calls within this thread
|
||||
if client.tenant_id:
|
||||
logger.info(f"Setting tenant ID to {client.tenant_id}")
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
with get_session_with_tenant(client.tenant_id) as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
@@ -806,7 +806,6 @@ def process_message(
|
||||
and slack_channel_config.channel_config.get("follow_up_tags")
|
||||
is not None
|
||||
)
|
||||
|
||||
feedback_reminder_id = schedule_feedback_reminder(
|
||||
details=details, client=client.web_client, include_followup=follow_up
|
||||
)
|
||||
|
||||
@@ -583,7 +583,7 @@ def slack_usage_report(
|
||||
logger.warning("Unable to find sender email")
|
||||
|
||||
if sender_email is not None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
onyx_user = get_user_by_email(email=sender_email, db_session=db_session)
|
||||
|
||||
optional_telemetry(
|
||||
|
||||
@@ -6,7 +6,6 @@ from uuid import uuid4
|
||||
import redis
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
|
||||
|
||||
@@ -46,10 +45,6 @@ class RedisConnectorIndex:
|
||||
WATCHDOG_PREFIX = PREFIX + "_watchdog"
|
||||
WATCHDOG_TTL = 300
|
||||
|
||||
# used to signal that the connector itself is still running
|
||||
CONNECTOR_ACTIVE_PREFIX = PREFIX + "_connector_active"
|
||||
CONNECTOR_ACTIVE_TTL = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str | None,
|
||||
@@ -73,12 +68,8 @@ class RedisConnectorIndex:
|
||||
f"{self.GENERATOR_LOCK_PREFIX}_{id}/{search_settings_id}"
|
||||
)
|
||||
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
|
||||
self.watchdog_key = f"{self.WATCHDOG_PREFIX}_{id}/{search_settings_id}"
|
||||
|
||||
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
|
||||
self.connector_active_key = (
|
||||
f"{self.CONNECTOR_ACTIVE_PREFIX}_{id}/{search_settings_id}"
|
||||
)
|
||||
self.watchdog_key = f"{self.WATCHDOG_PREFIX}_{id}/{search_settings_id}"
|
||||
|
||||
@classmethod
|
||||
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
|
||||
@@ -165,20 +156,6 @@ class RedisConnectorIndex:
|
||||
|
||||
return False
|
||||
|
||||
def set_connector_active(self) -> None:
|
||||
"""This sets a signal to keep the indexing flow from getting cleaned up within
|
||||
the expiration time.
|
||||
|
||||
The slack in timing is needed to avoid race conditions where simply checking
|
||||
the celery queue and task status could result in race conditions."""
|
||||
self.redis.set(self.connector_active_key, 0, ex=self.CONNECTOR_ACTIVE_TTL)
|
||||
|
||||
def connector_active(self) -> bool:
|
||||
if self.redis.exists(self.connector_active_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def generator_locked(self) -> bool:
|
||||
if self.redis.exists(self.generator_lock_key):
|
||||
return True
|
||||
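The `set_connector_active` / `connector_active` pair above is a key-with-expiry liveness signal: the worker refreshes a key with a TTL, and a monitor treats the key's absence as "no longer running". The sketch below imitates that behavior in memory so it runs without a Redis server. `TtlFlags` and its injectable clock are hypothetical, introduced only for illustration.

```python
import time
from typing import Callable


class TtlFlags:
    """In-memory imitation of setting a key with an expiry and checking existence."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._clock = clock
        self._deadlines: dict[str, float] = {}

    def set(self, key: str, ttl_seconds: float) -> None:
        # Each call pushes the deadline out again, like refreshing a heartbeat.
        self._deadlines[key] = self._clock() + ttl_seconds

    def exists(self, key: str) -> bool:
        deadline = self._deadlines.get(key)
        if deadline is None:
            return False
        if self._clock() >= deadline:
            # Expired: drop the key so it behaves as if it was never set.
            del self._deadlines[key]
            return False
        return True
```

The TTL gives the monitor some slack: a worker that is briefly slow still counts as active until the deadline passes.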
@@ -217,7 +194,6 @@ class RedisConnectorIndex:
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
|
||||
self.redis.delete(self.connector_active_key)
|
||||
self.redis.delete(self.active_key)
|
||||
self.redis.delete(self.generator_lock_key)
|
||||
self.redis.delete(self.generator_progress_key)
|
||||
@@ -227,9 +203,6 @@ class RedisConnectorIndex:
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
for key in r.scan_iter(RedisConnectorIndex.CONNECTOR_ACTIVE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
@@ -28,8 +28,6 @@ from onyx.configs.app_configs import REDIS_SSL_CERT_REQS
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import DEFAULT_REDIS_PREFIX
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -154,10 +152,14 @@ class RedisPool:
|
||||
host=REDIS_REPLICA_HOST, ssl=REDIS_SSL
|
||||
)
|
||||
|
||||
def get_client(self, tenant_id: str) -> Redis:
|
||||
def get_client(self, tenant_id: str | None) -> Redis:
|
||||
if tenant_id is None:
|
||||
tenant_id = "public"
|
||||
return TenantRedis(tenant_id, connection_pool=self._pool)
|
||||
|
||||
def get_replica_client(self, tenant_id: str) -> Redis:
|
||||
def get_replica_client(self, tenant_id: str | None) -> Redis:
|
||||
if tenant_id is None:
|
||||
tenant_id = "public"
|
||||
return TenantRedis(tenant_id, connection_pool=self._replica_pool)
|
||||
|
||||
@staticmethod
|
||||
@@ -219,36 +221,14 @@ redis_pool = RedisPool()
|
||||
# print(value.decode()) # Output: 'value'
|
||||
|
||||
|
||||
def get_redis_client(
|
||||
*,
|
||||
# This argument will be deprecated in the future
|
||||
tenant_id: str | None = None,
|
||||
) -> Redis:
|
||||
if tenant_id is None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
def get_redis_client(*, tenant_id: str | None) -> Redis:
|
||||
return redis_pool.get_client(tenant_id)
|
||||
|
||||
|
||||
def get_redis_replica_client(
|
||||
*,
|
||||
# this argument will be deprecated in the future
|
||||
tenant_id: str | None = None,
|
||||
) -> Redis:
|
||||
if tenant_id is None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
def get_redis_replica_client(*, tenant_id: str | None) -> Redis:
|
||||
return redis_pool.get_replica_client(tenant_id)
|
||||
|
||||
|
||||
def get_shared_redis_client() -> Redis:
|
||||
return redis_pool.get_client(DEFAULT_REDIS_PREFIX)
|
||||
|
||||
|
||||
def get_shared_redis_replica_client() -> Redis:
|
||||
return redis_pool.get_replica_client(DEFAULT_REDIS_PREFIX)
|
||||
|
||||
|
||||
SSL_CERT_REQS_MAP = {
|
||||
"none": ssl.CERT_NONE,
|
||||
"optional": ssl.CERT_OPTIONAL,
|
||||
|
||||
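The two variants of `get_redis_client` in the hunk above differ in where the tenant comes from: one requires the caller to pass it, the other falls back to a context variable when it is `None`. A minimal sketch of the fallback variant follows. `CURRENT_TENANT`, `resolve_tenant`, and `prefixed_key` are hypothetical names, and the `tenant:key` prefixing is an assumption made for illustration, not a statement of how `TenantRedis` is implemented.

```python
from __future__ import annotations

from contextvars import ContextVar

# Hypothetical context variable; "public" mirrors the default seen in the diff.
CURRENT_TENANT: ContextVar[str] = ContextVar("current_tenant", default="public")


def resolve_tenant(*, tenant_id: str | None = None) -> str:
    """Prefer an explicit tenant; otherwise read the ambient one."""
    if tenant_id is None:
        tenant_id = CURRENT_TENANT.get()
    return tenant_id


def prefixed_key(key: str, *, tenant_id: str | None = None) -> str:
    """Namespace a key by tenant (assumed scheme, for illustration only)."""
    return f"{resolve_tenant(tenant_id=tenant_id)}:{key}"
```

Keeping `tenant_id` keyword-only, as the diff does, makes call sites show explicitly which tenant they are addressing.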
@@ -25,9 +25,6 @@ from onyx.background.celery.versioned_apps.primary import app as primary_app
|
||||
from onyx.background.indexing.models import IndexAttemptErrorPydantic
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.connectors.interfaces import ConnectorValidationError
|
||||
from onyx.db.connector import delete_connector
|
||||
from onyx.db.connector_credential_pair import add_credential_to_connector
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pair_from_id_for_user,
|
||||
@@ -38,6 +35,8 @@ from onyx.db.connector_credential_pair import (
|
||||
)
|
||||
from onyx.db.document import get_document_counts_for_cc_pairs
|
||||
from onyx.db.document import get_documents_for_cc_pair
|
||||
from onyx.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -63,7 +62,6 @@ from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.models import StatusResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/manage")
|
||||
@@ -108,9 +106,8 @@ def get_cc_pair_full_info(
|
||||
cc_pair_id: int,
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> CCPairFullInfo:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id, db_session, user, get_editable=False
|
||||
)
|
||||
@@ -175,6 +172,7 @@ def update_cc_pair_status(
|
||||
status_update_request: CCStatusUpdateRequest,
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""This method returns nearly immediately. It simply sets some signals and
|
||||
optimistically assumes any running background processes will clean themselves up.
|
||||
@@ -182,8 +180,6 @@ def update_cc_pair_status(
|
||||
|
||||
Returns HTTPStatus.OK if everything finished.
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
@@ -344,9 +340,9 @@ def prune_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse[list[int]]:
|
||||
"""Triggers pruning on a particular cc_pair immediately"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -360,7 +356,7 @@ def prune_cc_pair(
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if redis_connector.prune.fenced:
|
||||
@@ -376,7 +372,7 @@ def prune_cc_pair(
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
payload_id = try_creating_prune_generator_task(
|
||||
primary_app, cc_pair, db_session, r, tenant_id
|
||||
primary_app, cc_pair, db_session, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
)
|
||||
if not payload_id:
|
||||
raise HTTPException(
|
||||
@@ -418,9 +414,9 @@ def sync_cc_pair(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse[list[int]]:
|
||||
"""Triggers permissions sync on a particular cc_pair immediately"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -434,7 +430,7 @@ def sync_cc_pair(
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if redis_connector.permissions.fenced:
|
||||
@@ -450,7 +446,7 @@ def sync_cc_pair(
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
payload_id = try_creating_permissions_sync_task(
|
||||
primary_app, cc_pair_id, r, tenant_id
|
||||
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
)
|
||||
if not payload_id:
|
||||
raise HTTPException(
|
||||
@@ -492,9 +488,9 @@ def sync_cc_pair_groups(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse[list[int]]:
|
||||
"""Triggers group sync on a particular cc_pair immediately"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -508,7 +504,7 @@ def sync_cc_pair_groups(
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if redis_connector.external_group_sync.fenced:
|
||||
@@ -524,7 +520,7 @@ def sync_cc_pair_groups(
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
payload_id = try_creating_external_group_sync_task(
|
||||
primary_app, cc_pair_id, r, tenant_id
|
||||
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
)
|
||||
if not payload_id:
|
||||
raise HTTPException(
|
||||
@@ -620,10 +616,6 @@ def associate_credential_to_connector(
|
||||
)
|
||||
|
||||
try:
|
||||
validate_ccpair_for_user(
|
||||
connector_id, credential_id, db_session, user, tenant_id
|
||||
)
|
||||
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
@@ -648,27 +640,10 @@ def associate_credential_to_connector(
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except ConnectorValidationError as e:
|
||||
# If validation fails, delete the connector and commit the changes
|
||||
# Ensures we don't leave invalid connectors in the database
|
||||
# NOTE: consensus is that it makes sense to unify connector and ccpair creation flows
|
||||
# which would rid us of needing to handle cases like these
|
||||
delete_connector(db_session, connector_id)
|
||||
db_session.commit()
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Connector validation error: " + str(e)
|
||||
)
|
||||
|
||||
except IntegrityError as e:
|
||||
logger.error(f"IntegrityError: {e}")
|
||||
raise HTTPException(status_code=400, detail="Name must be unique")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error: {e}")
|
||||
raise HTTPException(status_code=500, detail="Unexpected error")
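The `except ConnectorValidationError` branch above deletes the connector before reporting the failure, so an invalid connector is not left behind. That validate-then-clean-up flow can be sketched without a database. Everything here (`ValidationFailed`, `attach_credential`, the dict used as a store) is a hypothetical stand-in rather than the Onyx API.

```python
from typing import Callable


class ValidationFailed(Exception):
    """Stand-in for a connector validation error."""


def attach_credential(
    connectors: dict[int, str],
    connector_id: int,
    validate: Callable[[int], None],
) -> str:
    """Validate a connector; on failure remove it and re-raise as ValueError."""
    try:
        validate(connector_id)
        return "attached"
    except ValidationFailed as e:
        # Clean up first so the invalid connector does not linger.
        connectors.pop(connector_id, None)
        raise ValueError("Connector validation error: " + str(e)) from e
```

The source comment notes that unifying connector and cc-pair creation would remove the need for this kind of cleanup.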
|
||||
|
||||
|
||||
@router.delete("/connector/{connector_id}/credential/{credential_id}")
|
||||
def dissociate_credential_from_connector(
|
||||
|
||||
@@ -28,7 +28,6 @@ from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.connectors.google_utils.google_auth import (
|
||||
get_google_oauth_creds,
|
||||
)
|
||||
@@ -62,7 +61,6 @@ from onyx.connectors.google_utils.shared_constants import DB_CREDENTIALS_DICT_TO
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.connectors.interfaces import ConnectorValidationError
|
||||
from onyx.db.connector import create_connector
|
||||
from onyx.db.connector import delete_connector
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
@@ -81,6 +79,7 @@ from onyx.db.credentials import delete_service_account_credentials
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from onyx.db.document import get_document_counts_for_cc_pairs
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import IndexingMode
|
||||
@@ -120,7 +119,6 @@ from onyx.server.models import StatusResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -612,8 +610,8 @@ def get_connector_indexing_status(
|
||||
get_editable: bool = Query(
|
||||
False, description="If true, return editable document sets"
|
||||
),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> list[ConnectorIndexingStatus]:
|
||||
tenant_id = get_current_tenant_id()
|
||||
indexing_statuses: list[ConnectorIndexingStatus] = []
|
||||
|
||||
if MOCK_CONNECTOR_FILE_PATH:
|
||||
@@ -776,9 +774,8 @@ def create_connector_from_model(
|
||||
connector_data: ConnectorUpdateRequest,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> ObjectCreationIdResponse:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
_validate_connector_allowed(connector_data.source)
|
||||
|
||||
@@ -816,8 +813,15 @@ def create_connector_with_mock_credential(
|
||||
connector_data: ConnectorUpdateRequest,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse:
|
||||
tenant_id = get_current_tenant_id()
|
||||
"""NOTE(rkuo): internally discussed and the consensus is this endpoint
|
||||
and associate_credential_to_connector should be combined.
|
||||
|
||||
The intent of this endpoint is to handle connectors that don't need credentials,
|
||||
AKA web, file, etc ... but there isn't any reason a single endpoint couldn't
|
||||
serve this purpose.
|
||||
"""
|
||||
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.db.user_group", "validate_object_creation_for_user", None
|
||||
@@ -846,22 +850,11 @@ def create_connector_with_mock_credential(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Store the created connector and credential IDs
|
||||
connector_id = cast(int, connector_response.id)
|
||||
credential_id = credential.id
|
||||
|
||||
validate_ccpair_for_user(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
connector_id=cast(int, connector_response.id), # will always be an int
|
||||
credential_id=credential.id,
|
||||
access_type=connector_data.access_type,
|
||||
cc_pair_name=connector_data.name,
|
||||
groups=connector_data.groups,
|
||||
@@ -886,12 +879,9 @@ def create_connector_with_mock_credential(
|
||||
properties=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except ConnectorValidationError as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Connector validation error: " + str(e)
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@@ -962,10 +952,10 @@ def connector_run_once(
|
||||
run_info: RunConnectorRequest,
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse[int]:
|
||||
"""Used to trigger indexing on a set of cc_pairs associated with a
|
||||
single connector."""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
connector_id = run_info.connector_id
|
||||
specified_credential_ids = run_info.credential_ids
|
||||
|
||||
@@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.db.credentials import alter_credential
|
||||
from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
@@ -18,7 +17,6 @@ from onyx.db.credentials import fetch_credentials_by_source_for_user
|
||||
from onyx.db.credentials import fetch_credentials_for_user
|
||||
from onyx.db.credentials import swap_credentials_connector
|
||||
from onyx.db.credentials import update_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import DocumentSource
|
||||
from onyx.db.models import User
|
||||
@@ -100,16 +98,7 @@ def swap_credentials_for_connector(
|
||||
credential_swap_req: CredentialSwapRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse:
|
||||
validate_ccpair_for_user(
|
||||
credential_swap_req.connector_id,
|
||||
credential_swap_req.new_credential_id,
|
||||
db_session,
|
||||
user,
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
connector_credential_pair = swap_credentials_connector(
|
||||
new_credential_id=credential_swap_req.new_credential_id,
|
||||
connector_id=credential_swap_req.connector_id,
|
||||
|
||||
@@ -17,13 +17,13 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import OAuthConnector
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.subclasses import find_all_subclasses_in_dir
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -89,10 +89,9 @@ def oauth_authorize(
|
||||
source: DocumentSource,
|
||||
desired_return_url: Annotated[str | None, Query()] = None,
|
||||
_: User = Depends(current_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> AuthorizeResponse:
|
||||
"""Initiates the OAuth flow by redirecting to the provider's auth page"""
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
oauth_connectors = _discover_oauth_connectors()
|
||||
|
||||
if source not in oauth_connectors:
|
||||
@@ -141,6 +140,7 @@ def oauth_callback(
|
||||
state: Annotated[str, Query()],
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User = Depends(current_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> CallbackResponse:
|
||||
"""Handles the OAuth callback and exchanges the code for tokens"""
|
||||
oauth_connectors = _discover_oauth_connectors()
|
||||
@@ -151,7 +151,7 @@ def oauth_callback(
|
||||
connector_cls = oauth_connectors[source]
|
||||
|
||||
# get state from redis
|
||||
redis_client = get_redis_client()
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
oauth_state_bytes = cast(
|
||||
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
|
||||
)
|
||||
|
||||
@@ -1,61 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi_users.exceptions import InvalidPasswordException
from sqlalchemy.orm import Session

from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import get_user_manager
from onyx.auth.users import User
from onyx.auth.users import UserManager
from onyx.db.engine import get_session
from onyx.db.users import get_user_by_email
from onyx.server.features.password.models import ChangePasswordRequest
from onyx.server.features.password.models import UserResetRequest
from onyx.server.features.password.models import UserResetResponse

router = APIRouter(prefix="/password")


@router.post("/change-password")
async def change_my_password(
    form_data: ChangePasswordRequest,
    user_manager: UserManager = Depends(get_user_manager),
    current_user: User = Depends(current_user),
) -> None:
    """
    Change the password for the current user.
    """
    try:
        await user_manager.change_password_if_old_matches(
            user=current_user,
            old_password=form_data.old_password,
            new_password=form_data.new_password,
        )
    except InvalidPasswordException as e:
        raise HTTPException(status_code=400, detail=str(e.reason))
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"An unexpected error occurred: {str(e)}"
        )


@router.post("/reset_password")
async def admin_reset_user_password(
    user_reset_request: UserResetRequest,
    user_manager: UserManager = Depends(get_user_manager),
    db_session: Session = Depends(get_session),
    _: User = Depends(current_admin_user),
) -> UserResetResponse:
    """
    Reset the password for a user (admin only).
    """
    user = get_user_by_email(user_reset_request.user_email, db_session)
    if not user:
        raise HTTPException(status_code=404, detail="User not found")
    new_password = await user_manager.reset_password_as_admin(user.id)
    return UserResetResponse(
        user_id=str(user.id),
        new_password=new_password,
    )
@@ -1,15 +0,0 @@
from pydantic import BaseModel


class UserResetRequest(BaseModel):
    user_email: str


class UserResetResponse(BaseModel):
    user_id: str
    new_password: str


class ChangePasswordRequest(BaseModel):
    old_password: str
    new_password: str
@@ -18,6 +18,7 @@ from onyx.auth.users import current_user
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NotificationType
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import StarterMessageModel as StarterMessage
from onyx.db.models import User
@@ -55,7 +56,6 @@ from onyx.server.models import DisplayPriorityRequest
from onyx.tools.utils import is_image_generation_available
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

@@ -201,9 +201,8 @@ def create_persona(
    persona_upsert_request: PersonaUpsertRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> PersonaSnapshot:
    tenant_id = get_current_tenant_id()

    prompt_id = (
        persona_upsert_request.prompt_ids[0]
        if persona_upsert_request.prompt_ids

@@ -20,6 +20,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_for_
from onyx.db.connector_credential_pair import (
    update_connector_credential_pair_from_id,
)
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import fetch_docs_ranked_by_boost_for_user
@@ -38,7 +39,6 @@ from onyx.server.manage.models import BoostUpdateRequest
from onyx.server.manage.models import HiddenUpdateRequest
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

router = APIRouter(prefix="/manage")
logger = setup_logger()
@@ -139,9 +139,8 @@ def create_deletion_attempt_for_connector_id(
    connector_credential_pair_identifier: ConnectorCredentialPairIdentifier,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> None:
    tenant_id = get_current_tenant_id()

    connector_id = connector_credential_pair_identifier.connector_id
    credential_id = connector_credential_pair_identifier.credential_id


@@ -45,11 +45,9 @@ class UserPreferences(BaseModel):
    hidden_assistants: list[int] = []
    visible_assistants: list[int] = []
    default_model: str | None = None
    auto_scroll: bool | None = None
    pinned_assistants: list[int] | None = None
    shortcut_enabled: bool | None = None

    # These will default to workspace settings on the frontend if not set
    auto_scroll: bool | None = None
    temperature_override_enabled: bool | None = None


@@ -67,7 +65,6 @@ class UserInfo(BaseModel):
    is_cloud_superuser: bool = False
    organization_name: str | None = None
    is_anonymous_user: bool | None = None
    password_configured: bool | None = None

    @classmethod
    def from_model(
@@ -86,16 +83,15 @@ class UserInfo(BaseModel):
            is_superuser=user.is_superuser,
            is_verified=user.is_verified,
            role=user.role,
            password_configured=user.password_configured,
            preferences=(
                UserPreferences(
                    shortcut_enabled=user.shortcut_enabled,
                    auto_scroll=user.auto_scroll,
                    chosen_assistants=user.chosen_assistants,
                    default_model=user.default_model,
                    hidden_assistants=user.hidden_assistants,
                    pinned_assistants=user.pinned_assistants,
                    visible_assistants=user.visible_assistants,
                    auto_scroll=user.auto_scroll,
                    temperature_override_enabled=user.temperature_override_enabled,
                )
            ),
@@ -191,7 +187,6 @@ class SlackChannelConfigCreationRequest(BaseModel):
    response_type: SlackBotResponseType
    # XXX this is going away soon
    standard_answer_categories: list[int] = Field(default_factory=list)
    disabled: bool = False

    @field_validator("answer_filters", mode="before")
    @classmethod

@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.configs.constants import MilestoneRecordType
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import ChannelConfig
from onyx.db.models import User
@@ -35,7 +36,6 @@ from onyx.server.manage.models import SlackChannelConfigCreationRequest
from onyx.server.manage.validate_tokens import validate_app_token
from onyx.server.manage.validate_tokens import validate_bot_token
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id


router = APIRouter(prefix="/manage")
@@ -93,8 +93,6 @@ def _form_channel_config(
        "respond_to_bots"
    ] = slack_channel_config_creation_request.respond_to_bots

    channel_config["disabled"] = slack_channel_config_creation_request.disabled

    return channel_config


@@ -196,7 +194,6 @@ def patch_slack_channel_config(
        channel_config=channel_config,
        standard_answer_category_ids=slack_channel_config_creation_request.standard_answer_categories,
        enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
        disabled=slack_channel_config_creation_request.disabled,
    )
    return SlackChannelConfig.from_model(slack_channel_config_model)

@@ -231,9 +228,8 @@ def create_bot(
    slack_bot_creation_request: SlackBotCreationRequest,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> SlackBot:
    tenant_id = get_current_tenant_id()

    validate_app_token(slack_bot_creation_request.app_token)
    validate_bot_token(slack_bot_creation_request.bot_token)


@@ -42,6 +42,8 @@ from onyx.configs.constants import AuthType
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.api_key import is_api_key_email_address
from onyx.db.auth import get_total_users_count
from onyx.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import AccessToken
from onyx.db.models import User
@@ -67,7 +69,6 @@ from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()
router = APIRouter()
@@ -206,7 +207,6 @@ def list_all_users(
                    email=user.email,
                    role=user.role,
                    is_active=user.is_active,
                    password_configured=user.password_configured,
                )
                for user in accepted_users
            ],
@@ -216,7 +216,6 @@ def list_all_users(
                    email=user.email,
                    role=user.role,
                    is_active=user.is_active,
                    password_configured=user.password_configured,
                )
                for user in slack_users
            ],
@@ -234,7 +233,6 @@ def list_all_users(
                email=user.email,
                role=user.role,
                is_active=user.is_active,
                password_configured=user.password_configured,
            )
            for user in accepted_users
        ][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE],
@@ -244,7 +242,6 @@ def list_all_users(
                email=user.email,
                role=user.role,
                is_active=user.is_active,
                password_configured=user.password_configured,
            )
            for user in slack_users
        ][
@@ -269,13 +266,13 @@ def bulk_invite_users(
) -> int:
    """emails are string validated. If any email fails validation, no emails are
    invited and an exception is raised."""
    tenant_id = get_current_tenant_id()

    if current_user is None:
        raise HTTPException(
            status_code=400, detail="Auth is disabled, cannot invite users"
        )

    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    new_invited_emails = []
    email: str

@@ -317,7 +314,7 @@ def bulk_invite_users(
        logger.info("Registering tenant users")
        fetch_ee_implementation_or_noop(
            "onyx.server.tenants.billing", "register_tenant_users", None
        )(tenant_id, get_total_users_count(db_session))
        )(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
        if ENABLE_EMAIL_INVITES:
            try:
                for email in new_invited_emails:
@@ -344,10 +341,10 @@ def remove_invited_user(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> int:
    tenant_id = get_current_tenant_id()
    user_emails = get_invited_users()
    remaining_users = [user for user in user_emails if user != user_email.user_email]

    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    fetch_ee_implementation_or_noop(
        "onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
    )([user_email.user_email], tenant_id)
@@ -357,7 +354,7 @@ def remove_invited_user(
        if MULTI_TENANT:
            fetch_ee_implementation_or_noop(
                "onyx.server.tenants.billing", "register_tenant_users", None
            )(tenant_id, get_total_users_count(db_session))
            )(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
    except Exception:
        logger.error(
            "Request to update number of seats taken in control plane failed. "
@@ -534,9 +531,8 @@ def get_current_token_creation(
def verify_user_logged_in(
    user: User | None = Depends(optional_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> UserInfo:
    tenant_id = get_current_tenant_id()

    # NOTE: this does not use `current_user` / `current_admin_user` because we don't want
    # to enforce user verification here - the frontend always wants to get the info about
    # the current user regardless of if they are currently verified

@@ -36,7 +36,6 @@ class FullUserSnapshot(BaseModel):
    email: str
    role: UserRole
    is_active: bool
    password_configured: bool

    @classmethod
    def from_user_model(cls, user: User) -> "FullUserSnapshot":
@@ -45,7 +44,6 @@ class FullUserSnapshot(BaseModel):
            email=user.email,
            role=user.role,
            is_active=user.is_active,
            password_configured=user.password_configured,
        )



@@ -11,6 +11,7 @@ from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_documents_by_cc_pair
from onyx.db.document import get_ingestion_documents
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.search_settings import get_active_search_settings
@@ -23,7 +24,6 @@ from onyx.server.onyx_api.models import DocMinimalInfo
from onyx.server.onyx_api.models import IngestionDocument
from onyx.server.onyx_api.models import IngestionResult
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

@@ -69,9 +69,8 @@ def upsert_ingestion_doc(
    doc_info: IngestionDocument,
    _: User | None = Depends(api_key_dep),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> IngestionResult:
    tenant_id = get_current_tenant_id()

    doc_info.document.from_ingestion_api = True

    document = Document.from_base(doc_info.document)

@@ -44,6 +44,7 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.feedback import create_chat_message_feedback
@@ -82,7 +83,6 @@ from onyx.server.query_and_chat.token_limit import check_token_rate_limits
from onyx.utils.headers import get_custom_tool_additional_request_headers
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id


logger = setup_logger()
@@ -376,6 +376,7 @@ def handle_new_chat_message(
    user: User | None = Depends(current_chat_accesssible_user),
    _rate_limit_check: None = Depends(check_token_rate_limits),
    is_connected_func: Callable[[], bool] = Depends(is_connected),
    tenant_id: str = Depends(get_current_tenant_id),
) -> StreamingResponse:
    """
    This endpoint is both used for all the following purposes:
@@ -397,7 +398,6 @@ def handle_new_chat_message(
    Returns:
        StreamingResponse: Streams the response to the new chat message.
    """
    tenant_id = get_current_tenant_id()
    logger.debug(f"Received new chat message: {chat_message_req.message}")

    if (
@@ -407,7 +407,7 @@ def handle_new_chat_message(
    ):
        raise HTTPException(status_code=400, detail="Empty chat message is invalid")

    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
    with get_session_with_tenant(tenant_id) as db_session:
        create_milestone_and_report(
            user=user,
            distinct_id=user.email if user else tenant_id or "N/A",

@@ -240,7 +240,6 @@ class ChatMessageDetail(BaseModel):
    files: list[FileDescriptor]
    tool_call: ToolCallFinalResult | None
    refined_answer_improvement: bool | None = None
    error: str | None = None

    def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
        initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore

@@ -22,6 +22,7 @@ from onyx.db.chat import get_search_docs_for_chat_message
from onyx.db.chat import get_valid_messages_from_query_sessions
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
@@ -36,7 +37,6 @@ from onyx.server.query_and_chat.models import SearchSessionDetailResponse
from onyx.server.query_and_chat.models import SourceTag
from onyx.server.query_and_chat.models import TagResponse
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

@@ -49,9 +49,8 @@ def admin_search(
    question: AdminSearchRequest,
    user: User | None = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> AdminSearchResponse:
    tenant_id = get_current_tenant_id()

    query = question.query
    logger.notice(f"Received admin search query: {query}")
    user_acl_filters = build_access_filters_for_user(user, db_session)

@@ -21,7 +21,7 @@ from onyx.db.models import User
from onyx.db.token_limit import fetch_all_global_token_rate_limits
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.contextvars import get_current_tenant_id
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()
@@ -41,7 +41,7 @@ def check_token_rate_limits(
    versioned_rate_limit_strategy = fetch_versioned_implementation(
        "onyx.server.query_and_chat.token_limit", "_check_token_rate_limits"
    )
    return versioned_rate_limit_strategy(user, get_current_tenant_id())
    return versioned_rate_limit_strategy(user, CURRENT_TENANT_ID_CONTEXTVAR.get())


def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
@@ -54,7 +54,7 @@ Global rate limits


def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
    with get_session_with_tenant(tenant_id) as db_session:
        global_rate_limits = fetch_all_global_token_rate_limits(
            db_session=db_session, enabled_only=True, ordered=False
        )

@@ -46,9 +46,7 @@ class Settings(BaseModel):
    application_status: ApplicationStatus = ApplicationStatus.ACTIVE
    anonymous_user_enabled: bool | None = None
    pro_search_disabled: bool | None = None

    temperature_override_enabled: bool = False
    auto_scroll: bool = False
    auto_scroll: bool | None = None


class UserSettings(Settings):

@@ -6,7 +6,7 @@ from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.models import Settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

@@ -26,7 +26,7 @@ def load_settings() -> Settings:
        logger.error(f"Error loading settings from KV store: {str(e)}")
        settings = Settings()

    tenant_id = get_current_tenant_id() if MULTI_TENANT else None
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() if MULTI_TENANT else None
    redis_client = get_redis_client(tenant_id=tenant_id)

    try:
@@ -49,7 +49,7 @@ def load_settings() -> Settings:


def store_settings(settings: Settings) -> None:
    tenant_id = get_current_tenant_id() if MULTI_TENANT else None
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() if MULTI_TENANT else None
    redis_client = get_redis_client(tenant_id=tenant_id)

    if settings.anonymous_user_enabled is not None:

@@ -252,7 +252,7 @@ def setup_vespa(
            logger.notice("Vespa setup complete.")
            return True
        except Exception:
            logger.exception(
            logger.notice(
                f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
            )
            time.sleep(WAIT_SECONDS)

@@ -17,7 +17,7 @@ from requests import JSONDecodeError

from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import FileOrigin
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_default_tenant
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
@@ -205,7 +205,7 @@ class CustomTool(BaseTool):
    def _save_and_get_file_references(
        self, file_content: bytes | str, content_type: str
    ) -> List[str]:
        with get_session_with_current_tenant() as db_session:
        with get_session_with_default_tenant() as db_session:
            file_store = get_default_file_store(db_session)

            file_id = str(uuid.uuid4())
@@ -328,7 +328,7 @@ class CustomTool(BaseTool):

        # Load files from storage
        files = []
        with get_session_with_current_tenant() as db_session:
        with get_session_with_default_tenant() as db_session:
            file_store = get_default_file_store(db_session)

            for file_id in response.tool_result.file_ids:

@@ -107,7 +107,7 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
        # This will always be the case for authenticated API requests
        if MULTI_TENANT:
            tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
            if tenant_id != POSTGRES_DEFAULT_SCHEMA and tenant_id is not None:
            if tenant_id != POSTGRES_DEFAULT_SCHEMA:
                # Strip tenant_ prefix and take first 8 chars for cleaner logs
                tenant_display = tenant_id.removeprefix(TENANT_ID_PREFIX)
                short_tenant = (

@@ -12,7 +12,7 @@ from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.configs.constants import KV_CUSTOMER_UUID_KEY
from onyx.configs.constants import KV_INSTANCE_DOMAIN_KEY
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.milestone import create_milestone_if_not_exists
from onyx.db.models import User
from onyx.key_value_store.factory import get_kv_store
@@ -22,7 +22,7 @@ from onyx.utils.variable_functionality import (
)
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.onyx.app/anonymous_telemetry"
_CACHED_UUID: str | None = None
@@ -75,7 +75,7 @@ def _get_or_generate_instance_domain() -> str | None: #
        try:
            _CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
        except KvKeyNotFoundError:
            with get_session_with_current_tenant() as db_session:
            with get_session_with_tenant() as db_session:
                first_user = db_session.query(User).first()
                if first_user:
                    _CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
@@ -90,12 +90,12 @@ def optional_telemetry(
    record_type: RecordType,
    data: dict,
    user_id: str | None = None,
    tenant_id: str | None = None, # Allows for override of tenant_id
    tenant_id: str | None = None,
) -> None:
    if DISABLE_TELEMETRY:
        return

    tenant_id = tenant_id or get_current_tenant_id()
    tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()

    try:


Some files were not shown because too many files have changed in this diff