Compare commits


6 Commits
pdf_fix ... max

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| pablonyx | e823919892 | fix | 2025-04-01 11:27:58 -07:00 |
| pablonyx | 2f3020a4d3 | Update migration (#4410) | 2025-04-01 09:10:24 -07:00 |
| SubashMohan | 4bae1318bb | refactor tests for Highspot connector to use mocking for API key retrieval (#4346) | 2025-04-01 02:39:05 +00:00 |
| Weves | 11c3f44c76 | Init engine in slackbot | 2025-03-31 17:04:20 -07:00 |
| rkuo-danswer | cb38ac8a97 | also set permission upsert to medium priority (#4405); Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app> | 2025-03-31 14:59:31 -07:00 |
| pablonyx | b2120b9f39 | add user files (#4152) | 2025-03-31 21:06:59 +00:00 |
171 changed files with 4641 additions and 8712 deletions

View File

@@ -23,10 +23,6 @@ env:
# Jira
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
GONG_ACCESS_KEY: ${{ secrets.GONG_ACCESS_KEY }}
GONG_ACCESS_KEY_SECRET: ${{ secrets.GONG_ACCESS_KEY_SECRET }}
# Google
GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR }}
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1 }}

View File

@@ -30,26 +30,30 @@ Keep knowledge and access controls sync-ed across over 40 connectors like Google
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
<h3>Feature Highlights</h3>
**Deep research over your team's knowledge:**
https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95d0-4fb5-8650-a396e05e0a32.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk5Mjg2MzYsIm5iZiI6MTczOTkyODMzNiwicGF0aCI6Ii8zMjUyMDc2OS80MTQ1MDkzMTItNDgzOTJlODMtOTVkMC00ZmI1LTg2NTAtYTM5NmUwNWUwYTMyLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE5VDAxMjUzNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWFhMzk5Njg2Y2Y5YjFmNDNiYTQ2YzM5ZTg5YWJiYTU2NWMyY2YwNmUyODE2NWUxMDRiMWQxZWJmODI4YTA0MTUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.a9D8A0sgKE9AoaoE-mfFbJ6_OKYeqaf7TZ4Han2JfW8
**Use Onyx as a secure AI Chat with any LLM:**
![Onyx Chat Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxChatSilentDemo.gif)
**Easily set up connectors to your apps:**
![Onyx Connector Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxConnectorSilentDemo.gif)
**Access Onyx where your team already works:**
![Onyx Bot Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxBot.png)
## Deployment
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
@@ -58,23 +62,23 @@ Onyx can also be run locally (even on a laptop) or deployed on a virtual machine
We also have built-in support for high-availability/scalable deployment on Kubernetes.
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands of users and hundreds of millions of documents.
## 🚧 Roadmap
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- Personalized Search
- Organizational understanding and ability to locate and suggest experts from your team.
- Code Search
- SQL and Structured Query Language
## 🔌 Connectors
Keep knowledge and access in sync across 40+ connectors:
- Google Drive
@@ -95,65 +99,19 @@ Keep knowledge and access up to sync across 40+ connectors:
See the full list [here](https://docs.onyx.app/connectors).
## 📚 Licensing
There are two editions of Onyx:
- Onyx Community Edition (CE) is available freely under the MIT Expat license. Simply follow the Deployment guide above.
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
For feature details, check out [our website](https://www.onyx.app/pricing).
To try the Onyx Enterprise Edition:
1. Check out [Onyx Cloud](https://cloud.onyx.app/signup).
2. For self-hosting the Enterprise Edition, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
# YC Company Twitter Scraper
A script that scrapes YC company pages and extracts Twitter/X.com links.
## Requirements
- Python 3.7+
- Playwright
## Installation
1. Install the required packages:
```
pip install -r requirements.txt
```
2. Install Playwright browsers:
```
playwright install
```
## Usage
Run the script with default settings:
```
python scrape_yc_twitter.py
```
This will scrape the YC companies from recent batches (W23, S23, S24, F24, S22, W22) and save the Twitter links to `twitter_links.txt`.
### Custom URL and Output
```
python scrape_yc_twitter.py --url "https://www.ycombinator.com/companies?batch=W24" --output "w24_twitter.txt"
```
## How it works
1. Navigates to the specified YC companies page
2. Scrolls down to load all company cards
3. Extracts links to individual company pages
4. Visits each company page and extracts Twitter/X.com links
5. Saves the results to a text file
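The steps above map roughly onto the following minimal Playwright sketch. This is an illustration only: the selector strings, scroll loop, and helper name are assumptions for readability, not the actual internals of `scrape_yc_twitter.py`.
```python
# Minimal sketch of the scrape flow described above (assumed selectors and helpers,
# not the real scrape_yc_twitter.py internals).
from playwright.sync_api import sync_playwright


def collect_twitter_links(url: str, output: str) -> None:
    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        page = browser.new_page()

        # 1) Navigate to the specified YC companies page
        page.goto(url, wait_until="networkidle")

        # 2) Scroll down to load all (lazy-loaded) company cards
        for _ in range(20):
            page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
            page.wait_for_timeout(500)

        # 3) Extract links to individual company pages (selector is an assumption)
        company_links = {
            a.get_attribute("href")
            for a in page.query_selector_all("a[href^='/companies/']")
        }

        # 4) Visit each company page and pull Twitter/X.com links
        twitter_links = set()
        for path in sorted(company_links):
            page.goto(f"https://www.ycombinator.com{path}")
            for a in page.query_selector_all("a[href]"):
                href = a.get_attribute("href") or ""
                if "twitter.com" in href or "//x.com" in href:
                    twitter_links.add(href)

        # 5) Save the results to a text file
        with open(output, "w") as f:
            f.write("\n".join(sorted(twitter_links)))

        browser.close()
```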

View File

@@ -1,45 +0,0 @@
# YC Company Twitter Scraper
A script that scrapes YC company pages and extracts Twitter/X.com links.
## Requirements
- Python 3.7+
- Playwright
## Installation
1. Install the required packages:
```
pip install -r requirements.txt
```
2. Install Playwright browsers:
```
playwright install
```
## Usage
Run the script with default settings:
```
python scrape_yc_twitter.py
```
This will scrape the YC companies from recent batches (W23, S23, S24, F24, S22, W22) and save the Twitter links to `twitter_links.txt`.
### Custom URL and Output
```
python scrape_yc_twitter.py --url "https://www.ycombinator.com/companies?batch=W24" --output "w24_twitter.txt"
```
## How it works
1. Navigates to the specified YC companies page
2. Scrolls down to load all company cards
3. Extracts links to individual company pages
4. Visits each company page and extracts Twitter/X.com links
5. Saves the results to a text file

View File

@@ -46,7 +46,6 @@ WORKDIR /app
# Utils used by model server
COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py
COPY ./onyx/utils/middleware.py /app/onyx/utils/middleware.py
# Place to fetch version information
COPY ./onyx/__init__.py /app/onyx/__init__.py

View File

@@ -1,50 +0,0 @@
"""update prompt length
Revision ID: 4794bc13e484
Revises: f7505c5b0284
Create Date: 2025-04-02 11:26:36.180328
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4794bc13e484"
down_revision = "f7505c5b0284"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=5000000),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=5000000),
existing_nullable=False,
)
def downgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.String(length=5000000),
type_=sa.TEXT(),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.String(length=5000000),
type_=sa.TEXT(),
existing_nullable=False,
)

View File

@@ -0,0 +1,52 @@
"""max_length_for_instruction_system_prompt
Revision ID: e995bdf0d6f7
Revises: 8e1ac4f39a9f
Create Date: 2025-04-01 18:32:45.123456
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e995bdf0d6f7"
down_revision = "8e1ac4f39a9f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Alter system_prompt and task_prompt columns to have a maximum length of 8000 characters
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.Text(),
type_=sa.String(8000),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.Text(),
type_=sa.String(8000),
existing_nullable=False,
)
def downgrade() -> None:
# Revert system_prompt and task_prompt columns back to Text type
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.String(8000),
type_=sa.Text(),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.String(8000),
type_=sa.Text(),
existing_nullable=False,
)

View File

@@ -1,50 +0,0 @@
"""add prompt length limit
Revision ID: f71470ba9274
Revises: 6a804aeb4830
Create Date: 2025-04-01 15:07:14.977435
"""
# revision identifiers, used by Alembic.
revision = "f71470ba9274"
down_revision = "6a804aeb4830"
branch_labels = None
depends_on = None
def upgrade() -> None:
# op.alter_column(
# "prompt",
# "system_prompt",
# existing_type=sa.TEXT(),
# type_=sa.String(length=8000),
# existing_nullable=False,
# )
# op.alter_column(
# "prompt",
# "task_prompt",
# existing_type=sa.TEXT(),
# type_=sa.String(length=8000),
# existing_nullable=False,
# )
pass
def downgrade() -> None:
# op.alter_column(
# "prompt",
# "system_prompt",
# existing_type=sa.String(length=8000),
# type_=sa.TEXT(),
# existing_nullable=False,
# )
# op.alter_column(
# "prompt",
# "task_prompt",
# existing_type=sa.String(length=8000),
# type_=sa.TEXT(),
# existing_nullable=False,
# )
pass

View File

@@ -1,77 +0,0 @@
"""updated constraints for ccpairs
Revision ID: f7505c5b0284
Revises: f71470ba9274
Create Date: 2025-04-01 17:50:42.504818
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f7505c5b0284"
down_revision = "f71470ba9274"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Drop the old foreign-key constraints
op.drop_constraint(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
op.drop_constraint(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
# 2) Re-add them with ondelete='CASCADE'
op.create_foreign_key(
"document_by_connector_credential_pair_connector_id_fkey",
source_table="document_by_connector_credential_pair",
referent_table="connector",
local_cols=["connector_id"],
remote_cols=["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"document_by_connector_credential_pair_credential_id_fkey",
source_table="document_by_connector_credential_pair",
referent_table="credential",
local_cols=["credential_id"],
remote_cols=["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Reverse the changes for rollback
op.drop_constraint(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
op.drop_constraint(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
# Recreate without CASCADE
op.create_foreign_key(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
"connector",
["connector_id"],
["id"],
)
op.create_foreign_key(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
"credential",
["credential_id"],
["id"],
)

View File

@@ -159,9 +159,6 @@ def _get_space_permissions(
# Stores the permissions for each space
space_permissions_by_space_key[space_key] = space_permissions
logger.info(
f"Found space permissions for space '{space_key}': {space_permissions}"
)
return space_permissions_by_space_key

View File

@@ -55,7 +55,7 @@ def _post_query_chunk_censoring(
# if user is None, permissions are not enforced
return chunks
final_chunk_dict: dict[str, InferenceChunk] = {}
chunks_to_keep = []
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
sources_to_censor = _get_all_censoring_enabled_sources()
@@ -64,7 +64,7 @@ def _post_query_chunk_censoring(
if chunk.source_type in sources_to_censor:
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
else:
final_chunk_dict[chunk.unique_id] = chunk
chunks_to_keep.append(chunk)
# For each source, filter out the chunks using the permission
# check function for that source
@@ -79,16 +79,6 @@ def _post_query_chunk_censoring(
f" chunks for this source and continuing: {e}"
)
continue
chunks_to_keep.extend(censored_chunks)
for censored_chunk in censored_chunks:
final_chunk_dict[censored_chunk.unique_id] = censored_chunk
# IMPORTANT: make sure to retain the same ordering as the original `chunks` passed in
final_chunk_list: list[InferenceChunk] = []
for chunk in chunks:
# only if the chunk is in the final censored chunks, add it to the final list
# if it is missing, that means it was intentionally left out
if chunk.unique_id in final_chunk_dict:
final_chunk_list.append(final_chunk_dict[chunk.unique_id])
return final_chunk_list
return chunks_to_keep

View File

@@ -51,14 +51,13 @@ def _get_objects_access_for_user_email_from_salesforce(
# This is cached in the function so the first query takes an extra 0.1-0.3 seconds
# but subsequent queries by the same user are essentially instant
start_time = time.monotonic()
start_time = time.time()
user_id = get_salesforce_user_id_from_email(salesforce_client, user_email)
end_time = time.monotonic()
end_time = time.time()
logger.info(
f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
)
if user_id is None:
logger.warning(f"User '{user_email}' not found in Salesforce")
return None
# This is the only query that is not cached in the function
@@ -66,7 +65,6 @@ def _get_objects_access_for_user_email_from_salesforce(
object_id_to_access = get_objects_access_for_user_id(
salesforce_client, user_id, list(object_ids)
)
logger.debug(f"Object ID to access: {object_id_to_access}")
return object_id_to_access

View File

@@ -1,6 +1,10 @@
from simple_salesforce import Salesforce
from sqlalchemy.orm import Session
from onyx.connectors.salesforce.sqlite_functions import get_user_id_by_email
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import NULL_ID_STRING
from onyx.connectors.salesforce.sqlite_functions import update_email_to_id_table
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_cc_pairs_for_document
from onyx.utils.logger import setup_logger
@@ -24,8 +28,6 @@ def get_any_salesforce_client_for_doc_id(
E.g. there are 2 different credential sets for 2 different salesforce cc_pairs
but only one has the permissions to access the permissions needed for the query.
"""
# NOTE: this global seems very very bad
global _ANY_SALESFORCE_CLIENT
if _ANY_SALESFORCE_CLIENT is None:
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
@@ -40,18 +42,11 @@ def get_any_salesforce_client_for_doc_id(
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
query = f"SELECT Id FROM User WHERE Username = '{user_email}' AND IsActive = true"
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
result = sf_client.query(query)
if len(result["records"]) > 0:
return result["records"][0]["Id"]
# try emails
query = f"SELECT Id FROM User WHERE Email = '{user_email}' AND IsActive = true"
result = sf_client.query(query)
if len(result["records"]) > 0:
return result["records"][0]["Id"]
return None
if len(result["records"]) == 0:
return None
return result["records"][0]["Id"]
# This contains only the user_ids that we have found in Salesforce.
@@ -82,21 +77,35 @@ def get_salesforce_user_id_from_email(
salesforce database. (Around 0.1-0.3 seconds)
If it's cached or stored in the local salesforce database, it's fast (<0.001 seconds).
"""
# NOTE: this global seems bad
global _CACHED_SF_EMAIL_TO_ID_MAP
if user_email in _CACHED_SF_EMAIL_TO_ID_MAP:
if _CACHED_SF_EMAIL_TO_ID_MAP[user_email] is not None:
return _CACHED_SF_EMAIL_TO_ID_MAP[user_email]
# some caching via sqlite existed here before ... check history if interested
# ...query Salesforce and store the result in the database
user_id = _query_salesforce_user_id(sf_client, user_email)
db_exists = True
try:
# Check if the user is already in the database
user_id = get_user_id_by_email(user_email)
except Exception:
init_db()
try:
user_id = get_user_id_by_email(user_email)
except Exception as e:
logger.error(f"Error checking if user is in database: {e}")
user_id = None
db_exists = False
# If no entry is found in the database (indicated by user_id being None)...
if user_id is None:
# ...query Salesforce and store the result in the database
user_id = _query_salesforce_user_id(sf_client, user_email)
if db_exists:
update_email_to_id_table(user_email, user_id)
return user_id
elif user_id is None:
return None
elif user_id == NULL_ID_STRING:
return None
# If the found user_id is real, cache it
_CACHED_SF_EMAIL_TO_ID_MAP[user_email] = user_id
return user_id

View File

@@ -5,14 +5,12 @@ from slack_sdk import WebClient
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -103,12 +101,7 @@ def _get_slack_document_access(
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
# Use credentials provider instead of directly loading credentials
provider = OnyxDBCredentialsProvider(
get_current_tenant_id(), "slack", cc_pair.credential.id
)
slack_connector.set_credentials_provider(provider)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)

View File

@@ -51,7 +51,6 @@ def _get_slack_group_members_email(
def slack_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
slack_client = WebClient(

View File

@@ -15,7 +15,6 @@ from ee.onyx.external_permissions.post_query_censoring import (
DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION,
)
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from ee.onyx.external_permissions.slack.group_sync import slack_group_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
@@ -57,7 +56,6 @@ DOC_PERMISSIONS_FUNC_MAP: dict[DocumentSource, DocSyncFuncType] = {
GROUP_PERMISSIONS_FUNC_MAP: dict[DocumentSource, GroupSyncFuncType] = {
DocumentSource.GOOGLE_DRIVE: gdrive_group_sync,
DocumentSource.CONFLUENCE: confluence_group_sync,
DocumentSource.SLACK: slack_group_sync,
}

View File

@@ -36,6 +36,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/auth/saml")
# Define non-authenticated user roles that should be re-created during SAML login
NON_AUTHENTICATED_ROLES = {UserRole.SLACK_USER, UserRole.EXT_PERM_USER}
async def upsert_saml_user(email: str) -> User:
logger.debug(f"Attempting to upsert SAML user with email: {email}")
@@ -51,7 +54,7 @@ async def upsert_saml_user(email: str) -> User:
try:
user = await user_manager.get_by_email(email)
# If user has a non-authenticated role, treat as non-existent
if not user.role.is_web_login():
if user.role in NON_AUTHENTICATED_ROLES:
raise exceptions.UserNotExists()
return user
except exceptions.UserNotExists:

View File

@@ -1,4 +1,3 @@
import logging
import os
import shutil
from collections.abc import AsyncGenerator
@@ -9,7 +8,6 @@ import sentry_sdk
import torch
import uvicorn
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging # type:ignore
@@ -22,8 +20,6 @@ from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_onyx_request_id_middleware
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
@@ -40,12 +36,6 @@ transformer_logging.set_verbosity_error()
logger = setup_logger()
file_handlers = [
h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
]
setup_uvicorn_logger(shared_file_handlers=file_handlers)
def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -> None:
"""
@@ -122,15 +112,6 @@ def get_model_app() -> FastAPI:
application.include_router(encoders_router)
application.include_router(custom_models_router)
request_id_prefix = "INF"
if INDEXING_ONLY:
request_id_prefix = "IDX"
add_onyx_request_id_middleware(application, request_id_prefix, logger)
# Initialize and instrument the app
Instrumentator().instrument(application).expose(application)
return application

View File

@@ -15,22 +15,6 @@ class ExternalAccess:
# Whether the document is public in the external system or Onyx
is_public: bool
def __str__(self) -> str:
"""Prevent extremely long logs"""
def truncate_set(s: set[str], max_len: int = 100) -> str:
s_str = str(s)
if len(s_str) > max_len:
return f"{s_str[:max_len]}... ({len(s)} items)"
return s_str
return (
f"ExternalAccess("
f"external_user_emails={truncate_set(self.external_user_emails)}, "
f"external_user_group_ids={truncate_set(self.external_user_group_ids)}, "
f"is_public={self.is_public})"
)
@dataclass(frozen=True)
class DocExternalAccess:

View File

@@ -1,62 +0,0 @@
from collections.abc import Hashable
from typing import cast
from langchain_core.runnables.config import RunnableConfig
from langgraph.types import Send
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
def parallel_object_source_research_edge(
state: SearchSourcesObjectsUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the research for an individual object and source
"""
search_objects = state.analysis_objects
search_sources = state.analysis_sources
object_source_combinations = [
(object, source) for object in search_objects for source in search_sources
]
return [
Send(
"research_object_source",
ObjectSourceInput(
object_source_combination=object_source_combination,
log_messages=[],
),
)
for object_source_combination in object_source_combinations
]
def parallel_object_research_consolidation_edge(
state: ObjectResearchInformationUpdate, config: RunnableConfig
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the research for an individual object and source
"""
cast(GraphConfig, config["metadata"]["config"])
object_research_information_results = state.object_research_information_results
return [
Send(
"consolidate_object_research",
ObjectInformationInput(
object_information=object_information,
log_messages=[],
),
)
for object_information in object_research_information_results
]

View File

@@ -1,103 +0,0 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_research_consolidation_edge,
)
from onyx.agents.agent_search.dc_search_analysis.edges import (
parallel_object_source_research_edge,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a1_search_objects import (
search_objects,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a2_research_object_source import (
research_object_source,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a3_structure_research_by_object import (
structure_research_by_object,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a4_consolidate_object_research import (
consolidate_object_research,
)
from onyx.agents.agent_search.dc_search_analysis.nodes.a5_consolidate_research import (
consolidate_research,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.utils.logger import setup_logger
logger = setup_logger()
test_mode = False
def divide_and_conquer_graph_builder(test_mode: bool = False) -> StateGraph:
"""
LangGraph graph builder for the knowledge graph search process.
"""
graph = StateGraph(
state_schema=MainState,
input=MainInput,
)
### Add nodes ###
graph.add_node(
"search_objects",
search_objects,
)
graph.add_node(
"structure_research_by_source",
structure_research_by_object,
)
graph.add_node(
"research_object_source",
research_object_source,
)
graph.add_node(
"consolidate_object_research",
consolidate_object_research,
)
graph.add_node(
"consolidate_research",
consolidate_research,
)
### Add edges ###
graph.add_edge(start_key=START, end_key="search_objects")
graph.add_conditional_edges(
source="search_objects",
path=parallel_object_source_research_edge,
path_map=["research_object_source"],
)
graph.add_edge(
start_key="research_object_source",
end_key="structure_research_by_source",
)
graph.add_conditional_edges(
source="structure_research_by_source",
path=parallel_object_research_consolidation_edge,
path_map=["consolidate_object_research"],
)
graph.add_edge(
start_key="consolidate_object_research",
end_key="consolidate_research",
)
graph.add_edge(
start_key="consolidate_research",
end_key=END,
)
return graph

View File

@@ -1,159 +0,0 @@
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
SearchSourcesObjectsUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.configs.constants import DocumentSource
from onyx.prompts.agents.dc_prompts import DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SEPARATOR
from onyx.prompts.agents.dc_prompts import DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def search_objects(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> SearchSourcesObjectsUpdate:
"""
LangGraph node to start the agentic search process.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
question = graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
if search_tool is None or graph_config.inputs.search_request.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
try:
instructions = graph_config.inputs.search_request.persona.prompts[
0
].system_prompt
agent_1_instructions = extract_section(
instructions, "Agent Step 1:", "Agent Step 2:"
)
if agent_1_instructions is None:
raise ValueError("Agent 1 instructions not found")
agent_1_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
agent_1_task = extract_section(
agent_1_instructions, "Task:", "Independent Research Sources:"
)
if agent_1_task is None:
raise ValueError("Agent 1 task not found")
agent_1_independent_sources_str = extract_section(
agent_1_instructions, "Independent Research Sources:", "Output Objective:"
)
if agent_1_independent_sources_str is None:
raise ValueError("Agent 1 Independent Research Sources not found")
document_sources = [
DocumentSource(x.strip().lower())
for x in agent_1_independent_sources_str.split(DC_OBJECT_SEPARATOR)
]
agent_1_output_objective = extract_section(
agent_1_instructions, "Output Objective:"
)
if agent_1_output_objective is None:
raise ValueError("Agent 1 output objective not found")
except Exception as e:
raise ValueError(
f"Agent 1 instructions not found or not formatted correctly: {e}"
)
# Extract objects
if agent_1_base_data is None:
# Retrieve chunks for objects
retrieved_docs = research(question, search_tool)[:10]
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
dc_object_extraction_prompt = DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT.format(
question=question,
task=agent_1_task,
document_text=document_texts,
objects_of_interest=agent_1_output_objective,
)
else:
dc_object_extraction_prompt = DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT.format(
question=question,
task=agent_1_task,
base_data=agent_1_base_data,
objects_of_interest=agent_1_output_objective,
)
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_extraction_prompt,
reserved_str="",
),
)
]
primary_llm = graph_config.tooling.primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
primary_llm.invoke,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
cleaned_response = (
str(llm_response.content)
.replace("```json\n", "")
.replace("\n```", "")
.replace("\n", "")
)
cleaned_response = cleaned_response.split("OBJECTS:")[1]
object_list = [x.strip() for x in cleaned_response.split(";")]
except Exception as e:
raise ValueError(f"Error in search_objects: {e}")
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=" Researching the individual objects for each source type... ",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
return SearchSourcesObjectsUpdate(
analysis_objects=object_list,
analysis_sources=document_sources,
log_messages=["Agent 1 Task done"],
)

View File

@@ -1,185 +0,0 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.ops import research
from onyx.agents.agent_search.dc_search_analysis.states import ObjectSourceInput
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectSourceResearchUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_SOURCE_RESEARCH_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def research_object_source(
state: ObjectSourceInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ObjectSourceResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.search_request.query
object, document_source = state.object_source_combination
if search_tool is None or graph_config.inputs.search_request.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
try:
instructions = graph_config.inputs.search_request.persona.prompts[
0
].system_prompt
agent_2_instructions = extract_section(
instructions, "Agent Step 2:", "Agent Step 3:"
)
if agent_2_instructions is None:
raise ValueError("Agent 2 instructions not found")
agent_2_task = extract_section(
agent_2_instructions, "Task:", "Independent Research Sources:"
)
if agent_2_task is None:
raise ValueError("Agent 2 task not found")
agent_2_time_cutoff = extract_section(
agent_2_instructions, "Time Cutoff:", "Research Topics:"
)
agent_2_research_topics = extract_section(
agent_2_instructions, "Research Topics:", "Output Objective"
)
agent_2_output_objective = extract_section(
agent_2_instructions, "Output Objective:"
)
if agent_2_output_objective is None:
raise ValueError("Agent 2 output objective not found")
except Exception:
raise ValueError(
"Agent 1 instructions not found or not formatted correctly: {e}"
)
# Populate prompt
# Retrieve chunks for objects
if agent_2_time_cutoff is not None and agent_2_time_cutoff.strip() != "":
if agent_2_time_cutoff.strip().endswith("d"):
try:
days = int(agent_2_time_cutoff.strip()[:-1])
agent_2_source_start_time = datetime.now(timezone.utc) - timedelta(
days=days
)
except ValueError:
raise ValueError(
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
)
else:
raise ValueError(
f"Invalid time cutoff format: {agent_2_time_cutoff}. Expected format: '<number>d'"
)
else:
agent_2_source_start_time = None
document_sources = [document_source] if document_source else None
if len(question.strip()) > 0:
research_area = f"{question} for {object}"
elif agent_2_research_topics and len(agent_2_research_topics.strip()) > 0:
research_area = f"{agent_2_research_topics} for {object}"
else:
research_area = object
retrieved_docs = research(
question=research_area,
search_tool=search_tool,
document_sources=document_sources,
time_cutoff=agent_2_source_start_time,
)
# Generate document text
document_texts_list = []
for doc_num, doc in enumerate(retrieved_docs):
chunk_text = "Document " + str(doc_num) + ":\n" + doc.content
document_texts_list.append(chunk_text)
document_texts = "\n\n".join(document_texts_list)
# Built prompt
today = datetime.now().strftime("%A, %Y-%m-%d")
dc_object_source_research_prompt = (
DC_OBJECT_SOURCE_RESEARCH_PROMPT.format(
today=today,
question=question,
task=agent_2_task,
document_text=document_texts,
format=agent_2_output_objective,
)
.replace("---object---", object)
.replace("---source---", document_source.value)
)
# Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_source_research_prompt,
reserved_str="",
),
)
]
# fast_llm = graph_config.tooling.fast_llm
primary_llm = graph_config.tooling.primary_llm
llm = primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
llm.invoke,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
cleaned_response = str(llm_response.content).replace("```json\n", "")
cleaned_response = cleaned_response.split("RESEARCH RESULTS:")[1]
object_research_results = {
"object": object,
"source": document_source.value,
"research_result": cleaned_response,
}
except Exception as e:
raise ValueError(f"Error in research_object_source: {e}")
logger.debug("DivCon Step A2 - Object Source Research - completed for an object")
return ObjectSourceResearchUpdate(
object_source_research_results=[object_research_results],
log_messages=["Agent Step 2 done for one object"],
)

View File

@@ -1,68 +0,0 @@
from collections import defaultdict
from datetime import datetime
from typing import cast
from typing import Dict
from typing import List
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import (
ObjectResearchInformationUpdate,
)
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.utils.logger import setup_logger
logger = setup_logger()
def structure_research_by_object(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ObjectResearchInformationUpdate:
"""
LangGraph node to start the agentic search process.
"""
datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=" consolidating the information across source types for each object...",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
object_source_research_results = state.object_source_research_results
object_research_information_results: List[Dict[str, str]] = []
object_research_information_results_list: Dict[str, List[str]] = defaultdict(list)
for object_source_research in object_source_research_results:
object = object_source_research["object"]
source = object_source_research["source"]
research_result = object_source_research["research_result"]
object_research_information_results_list[object].append(
f"Source: {source}\n{research_result}"
)
for object, information in object_research_information_results_list.items():
object_research_information_results.append(
{"object": object, "information": "\n".join(information)}
)
logger.debug("DivCon Step A3 - Object Research Information Structuring - completed")
return ObjectResearchInformationUpdate(
object_research_information_results=object_research_information_results,
log_messages=["A3 - Object Research Information structured"],
)

View File

@@ -1,107 +0,0 @@
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import ObjectInformationInput
from onyx.agents.agent_search.dc_search_analysis.states import ObjectResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.prompts.agents.dc_prompts import DC_OBJECT_CONSOLIDATION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def consolidate_object_research(
state: ObjectInformationInput,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,
) -> ObjectResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
question = graph_config.inputs.search_request.query
if search_tool is None or graph_config.inputs.search_request.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
instructions = graph_config.inputs.search_request.persona.prompts[0].system_prompt
agent_4_instructions = extract_section(
instructions, "Agent Step 4:", "Agent Step 5:"
)
if agent_4_instructions is None:
raise ValueError("Agent 4 instructions not found")
agent_4_output_objective = extract_section(
agent_4_instructions, "Output Objective:"
)
if agent_4_output_objective is None:
raise ValueError("Agent 4 output objective not found")
object_information = state.object_information
object = object_information["object"]
information = object_information["information"]
# Create a prompt for the object consolidation
dc_object_consolidation_prompt = DC_OBJECT_CONSOLIDATION_PROMPT.format(
question=question,
object=object,
information=information,
format=agent_4_output_objective,
)
# Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_object_consolidation_prompt,
reserved_str="",
),
)
]
graph_config.tooling.primary_llm
# fast_llm = graph_config.tooling.fast_llm
primary_llm = graph_config.tooling.primary_llm
llm = primary_llm
# Grader
try:
llm_response = run_with_timeout(
30,
llm.invoke,
prompt=msg,
timeout_override=30,
max_tokens=300,
)
cleaned_response = str(llm_response.content).replace("```json\n", "")
consolidated_information = cleaned_response.split("INFORMATION:")[1]
except Exception as e:
raise ValueError(f"Error in consolidate_object_research: {e}")
object_research_results = {
"object": object,
"research_result": consolidated_information,
}
logger.debug(
"DivCon Step A4 - Object Research Consolidation - completed for an object"
)
return ObjectResearchUpdate(
object_research_results=[object_research_results],
log_messages=["Agent Source Consilidation done"],
)

View File

@@ -1,164 +0,0 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableConfig
from langgraph.types import StreamWriter
from onyx.agents.agent_search.dc_search_analysis.ops import extract_section
from onyx.agents.agent_search.dc_search_analysis.states import MainState
from onyx.agents.agent_search.dc_search_analysis.states import ResearchUpdate
from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_WITH_BASE_DATA_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
def consolidate_research(
state: MainState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> ResearchUpdate:
"""
LangGraph node to start the agentic search process.
"""
datetime.now()
graph_config = cast(GraphConfig, config["metadata"]["config"])
graph_config.inputs.search_request.query
search_tool = graph_config.tooling.search_tool
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=" generating the answer\n\n\n",
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
if search_tool is None or graph_config.inputs.search_request.persona is None:
raise ValueError("Search tool and persona must be provided for DivCon search")
# Populate prompt
instructions = graph_config.inputs.search_request.persona.prompts[0].system_prompt
try:
agent_5_instructions = extract_section(
instructions, "Agent Step 5:", "Agent End"
)
if agent_5_instructions is None:
raise ValueError("Agent 5 instructions not found")
agent_5_base_data = extract_section(instructions, "|Start Data|", "|End Data|")
agent_5_task = extract_section(
agent_5_instructions, "Task:", "Independent Research Sources:"
)
if agent_5_task is None:
raise ValueError("Agent 5 task not found")
agent_5_output_objective = extract_section(
agent_5_instructions, "Output Objective:"
)
if agent_5_output_objective is None:
raise ValueError("Agent 5 output objective not found")
except ValueError as e:
raise ValueError(
f"Instructions for Agent Step 5 were not properly formatted: {e}"
)
research_result_list = []
if agent_5_task.strip() == "*concatenate*":
object_research_results = state.object_research_results
for object_research_result in object_research_results:
object = object_research_result["object"]
research_result = object_research_result["research_result"]
research_result_list.append(f"Object: {object}\n\n{research_result}")
research_results = "\n\n".join(research_result_list)
else:
raise NotImplementedError("Only '*concatenate*' is currently supported")
# Create a prompt for the object consolidation
if agent_5_base_data is None:
dc_formatting_prompt = DC_FORMATTING_NO_BASE_DATA_PROMPT.format(
text=research_results,
format=agent_5_output_objective,
)
else:
dc_formatting_prompt = DC_FORMATTING_WITH_BASE_DATA_PROMPT.format(
base_data=agent_5_base_data,
text=research_results,
format=agent_5_output_objective,
)
# Run LLM
msg = [
HumanMessage(
content=trim_prompt_piece(
config=graph_config.tooling.primary_llm.config,
prompt_piece=dc_formatting_prompt,
reserved_str="",
),
)
]
dispatch_timings: list[float] = []
primary_model = graph_config.tooling.primary_llm
def stream_initial_answer() -> list[str]:
response: list[str] = []
for message in primary_model.stream(msg, timeout_override=30, max_tokens=None):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response
try:
_ = run_with_timeout(
60,
stream_initial_answer,
)
except Exception as e:
raise ValueError(f"Error in consolidate_research: {e}")
logger.debug("DivCon Step A5 - Final Generation - completed")
return ResearchUpdate(
research_results=research_results,
log_messages=["Agent Source Consilidation done"],
)

View File

@@ -1,61 +0,0 @@
from datetime import datetime
from typing import cast
from onyx.chat.models import LlmDoc
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
def research(
question: str,
search_tool: SearchTool,
document_sources: list[DocumentSource] | None = None,
time_cutoff: datetime | None = None,
) -> list[LlmDoc]:
# new db session to avoid concurrency issues
callback_container: list[list[InferenceSection]] = []
retrieved_docs: list[LlmDoc] = []
with get_session_with_current_tenant() as db_session:
for tool_response in search_tool.run(
query=question,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=True,
document_sources=document_sources,
time_cutoff=time_cutoff,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == FINAL_CONTEXT_DOCUMENTS_ID:
retrieved_docs = cast(list[LlmDoc], tool_response.response)[:10]
break
return retrieved_docs
def extract_section(
text: str, start_marker: str, end_marker: str | None = None
) -> str | None:
"""Extract text between markers, returning None if markers not found"""
parts = text.split(start_marker)
if len(parts) == 1:
return None
after_start = parts[1].strip()
if not end_marker:
return after_start
extract = after_start.split(end_marker)[0]
return extract.strip()
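For reference, a tiny usage sketch of the `extract_section` helper defined above; the instruction string is invented purely to illustrate the marker-splitting behavior.
```python
# Hypothetical instruction text, used only to show how extract_section splits on markers.
instructions = "Agent Step 1: Task: find object owners Output Objective: a semicolon-separated list"

# Text between the two markers, stripped of surrounding whitespace
task = extract_section(instructions, "Task:", "Output Objective:")
assert task == "find object owners"

# Missing start marker -> None
assert extract_section(instructions, "Agent Step 2:") is None
```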

View File

@@ -1,72 +0,0 @@
from operator import add
from typing import Annotated
from typing import Dict
from typing import TypedDict
from pydantic import BaseModel
from onyx.agents.agent_search.core_state import CoreState
from onyx.agents.agent_search.orchestration.states import ToolCallUpdate
from onyx.agents.agent_search.orchestration.states import ToolChoiceInput
from onyx.agents.agent_search.orchestration.states import ToolChoiceUpdate
from onyx.configs.constants import DocumentSource
### States ###
class LoggerUpdate(BaseModel):
log_messages: Annotated[list[str], add] = []
class SearchSourcesObjectsUpdate(LoggerUpdate):
analysis_objects: list[str] = []
analysis_sources: list[DocumentSource] = []
class ObjectSourceInput(LoggerUpdate):
object_source_combination: tuple[str, DocumentSource]
class ObjectSourceResearchUpdate(LoggerUpdate):
object_source_research_results: Annotated[list[Dict[str, str]], add] = []
class ObjectInformationInput(LoggerUpdate):
object_information: Dict[str, str]
class ObjectResearchInformationUpdate(LoggerUpdate):
object_research_information_results: Annotated[list[Dict[str, str]], add] = []
class ObjectResearchUpdate(LoggerUpdate):
object_research_results: Annotated[list[Dict[str, str]], add] = []
class ResearchUpdate(LoggerUpdate):
research_results: str | None = None
## Graph Input State
class MainInput(CoreState):
pass
## Graph State
class MainState(
# This includes the core state
MainInput,
ToolChoiceInput,
ToolCallUpdate,
ToolChoiceUpdate,
SearchSourcesObjectsUpdate,
ObjectSourceResearchUpdate,
ObjectResearchInformationUpdate,
ObjectResearchUpdate,
ResearchUpdate,
):
pass
## Graph Output State - presently not used
class MainOutput(TypedDict):
log_messages: list[str]

View File

@@ -8,10 +8,6 @@ from langgraph.graph.state import CompiledStateGraph
from onyx.agents.agent_search.basic.graph_builder import basic_graph_builder
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.dc_search_analysis.graph_builder import (
divide_and_conquer_graph_builder,
)
from onyx.agents.agent_search.dc_search_analysis.states import MainInput as DCMainInput
from onyx.agents.agent_search.deep_search.main.graph_builder import (
main_graph_builder as main_graph_builder_a,
)
@@ -86,7 +82,7 @@ def _parse_agent_event(
def manage_sync_streaming(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
graph_input: BasicInput | MainInput | DCMainInput,
graph_input: BasicInput | MainInput,
) -> Iterable[StreamEvent]:
message_id = config.persistence.message_id if config.persistence else None
for event in compiled_graph.stream(
@@ -100,7 +96,7 @@ def manage_sync_streaming(
def run_graph(
compiled_graph: CompiledStateGraph,
config: GraphConfig,
input: BasicInput | MainInput | DCMainInput,
input: BasicInput | MainInput,
) -> AnswerStream:
config.behavior.perform_initial_search_decomposition = (
INITIAL_SEARCH_DECOMPOSITION_ENABLED
@@ -150,16 +146,6 @@ def run_basic_graph(
return run_graph(compiled_graph, config, input)
def run_dc_graph(
config: GraphConfig,
) -> AnswerStream:
graph = divide_and_conquer_graph_builder()
compiled_graph = graph.compile()
input = DCMainInput(log_messages=[])
config.inputs.search_request.query = config.inputs.search_request.query.strip()
return run_graph(compiled_graph, config, input)
if __name__ == "__main__":
for _ in range(1):
query_start_time = datetime.now()

View File

@@ -180,35 +180,3 @@ def binary_string_test_after_answer_separator(
relevant_text = text.split(f"{separator}")[-1]
return binary_string_test(relevant_text, positive_value)
def build_dc_search_prompt(
question: str,
original_question: str,
docs: list[InferenceSection],
persona_specification: str,
config: LLMConfig,
) -> list[SystemMessage | HumanMessage | AIMessage | ToolMessage]:
system_message = SystemMessage(
content=persona_specification,
)
date_str = build_date_time_string()
docs_str = format_docs(docs)
docs_str = trim_prompt_piece(
config,
docs_str,
SUB_QUESTION_RAG_PROMPT + question + original_question + date_str,
)
human_message = HumanMessage(
content=SUB_QUESTION_RAG_PROMPT.format(
question=question,
original_question=original_question,
context=docs_str,
date_prompt=date_str,
)
)
return [system_message, human_message]

View File

@@ -23,7 +23,6 @@ from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
<html lang="en">

View File

@@ -56,7 +56,6 @@ from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
@@ -514,25 +513,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
return user
async def on_after_login(
self,
user: User,
request: Optional[Request] = None,
response: Optional[Response] = None,
) -> None:
try:
if response and request and ANONYMOUS_USER_COOKIE_NAME in request.cookies:
response.delete_cookie(
ANONYMOUS_USER_COOKIE_NAME,
# Ensure cookie deletion doesn't override other cookies by setting the same path/domain
path="/",
domain=None,
secure=WEB_DOMAIN.startswith("https"),
)
logger.debug(f"Deleted anonymous user cookie for user {user.email}")
except Exception:
logger.exception("Error deleting anonymous user cookie")
async def on_after_register(
self, user: User, request: Optional[Request] = None
) -> None:
@@ -1322,7 +1302,6 @@ def get_oauth_router(
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
if tenant_id is None:
# Use URL utility to add parameters
@@ -1332,14 +1311,9 @@ def get_oauth_router(
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers from auth response to redirect response, with special handling for Set-Cookie
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
# FastAPI can have multiple Set-Cookie headers as a list
if header_name.lower() == "set-cookie" and isinstance(header_value, list):
for cookie_value in header_value:
redirect_response.headers.append(header_name, cookie_value)
else:
redirect_response.headers[header_name] = header_value
redirect_response.headers[header_name] = header_value
if hasattr(response, "body"):
redirect_response.body = response.body

View File

@@ -1,6 +1,5 @@
import logging
import multiprocessing
import os
import time
from typing import Any
from typing import cast
@@ -306,7 +305,7 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Running as a secondary celery worker: pid={os.getpid()}")
logger.info("Running as a secondary celery worker.")
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5

View File

@@ -1,7 +0,0 @@
from celery import Celery
import onyx.background.celery.apps.app_base as app_base
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.client")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]

View File

@@ -1,5 +1,4 @@
import logging
import os
from typing import Any
from typing import cast
@@ -96,7 +95,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
logger.info(f"Running as the primary celery worker: pid={os.getpid()}")
logger.info("Running as the primary celery worker.")
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -1,16 +0,0 @@
import onyx.background.celery.configs.base as shared_config
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

View File

@@ -886,8 +886,11 @@ def monitor_ccpair_permissions_taskset(
record_type=RecordType.PERMISSION_SYNC_PROGRESS,
data={
"cc_pair_id": cc_pair_id,
"total_docs_synced": initial if initial is not None else 0,
"remaining_docs_to_sync": remaining,
"id": payload.id if payload else None,
"total_docs": initial if initial is not None else 0,
"remaining_docs": remaining,
"synced_docs": (initial - remaining) if initial is not None else 0,
"is_complete": remaining == 0,
},
tenant_id=tenant_id,
)
@@ -903,13 +906,6 @@ def monitor_ccpair_permissions_taskset(
f"num_synced={initial}"
)
# Add telemetry for permission syncing complete
optional_telemetry(
record_type=RecordType.PERMISSION_SYNC_COMPLETE,
data={"cc_pair_id": cc_pair_id},
tenant_id=tenant_id,
)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,

View File

@@ -1,20 +0,0 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker.
This is an app stub purely for sending tasks as a client.
"""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.client import celery_app
return celery_app
app = get_app()

View File

@@ -56,6 +56,7 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.redis.redis_connector import RedisConnector
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
@@ -577,8 +578,11 @@ def _run_indexing(
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"current_docs_indexed": document_count,
"current_chunks_indexed": chunk_count,
"connector_id": ctx.connector_id,
"credential_id": ctx.credential_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"batch_num": batch_num,
"source": ctx.source.value,
},
tenant_id=tenant_id,
@@ -599,15 +603,26 @@ def _run_indexing(
checkpoint=checkpoint,
)
# Add telemetry for completed indexing
redis_connector = RedisConnector(tenant_id, ctx.cc_pair_id)
redis_connector_index = redis_connector.new_index(
index_attempt_start.search_settings_id
)
final_progress = redis_connector_index.get_progress() or 0
optional_telemetry(
record_type=RecordType.INDEXING_COMPLETE,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"connector_id": ctx.connector_id,
"credential_id": ctx.credential_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"batch_count": batch_num,
"time_elapsed_seconds": time.monotonic() - start_time,
"source": ctx.source.value,
"redis_progress": final_progress,
},
tenant_id=tenant_id,
)

View File

@@ -10,7 +10,6 @@ from onyx.agents.agent_search.models import GraphPersistence
from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_main_graph
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AnswerPacket
@@ -143,18 +142,11 @@ class Answer:
yield from self._processed_stream
return
if self.graph_config.behavior.use_agentic_search:
run_langgraph = run_main_graph
elif (
self.graph_config.inputs.search_request.persona
and self.graph_config.inputs.search_request.persona.description.startswith(
"DivCon Beta Agent"
)
):
run_langgraph = run_dc_graph
else:
run_langgraph = run_basic_graph
run_langgraph = (
run_main_graph
if self.graph_config.behavior.use_agentic_search
else run_basic_graph
)
stream = run_langgraph(
self.graph_config,
)

View File

@@ -43,7 +43,6 @@ from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_me
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
@@ -693,13 +692,8 @@ def stream_chat_message_objects(
doc_identifiers=identifier_tuples,
document_index=document_index,
)
# Add a maximum context size in the case of user-selected docs to prevent
# slight inaccuracies in context window size pruning from causing
# the entire query to fail
document_pruning_config = DocumentPruningConfig(
is_manually_selected_docs=True,
max_window_percentage=SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE,
is_manually_selected_docs=True
)
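# For context: SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE is 0.8, so user-selected
# sections are capped at roughly 80% of the context window, leaving headroom for
# the prompt and chat history even when token-count pruning is slightly off.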
# In case the search doc is deleted, just don't include it

View File

@@ -312,14 +312,11 @@ def prune_sections(
)
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, int]:
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
assert (
len(set([chunk.document_id for chunk in chunks])) == 1
), "One distinct document must be passed into merge_doc_chunks"
ADJACENT_CHUNK_SEP = "\n"
DISTANT_CHUNK_SEP = "\n\n...\n\n"
# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
@@ -327,48 +324,33 @@ def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, i
chunks, key=lambda x: x.score if x.score is not None else float("-inf")
)
added_chars = 0
merged_content = []
for i, chunk in enumerate(sorted_chunks):
if i > 0:
prev_chunk_id = sorted_chunks[i - 1].chunk_id
sep = (
ADJACENT_CHUNK_SEP
if chunk.chunk_id == prev_chunk_id + 1
else DISTANT_CHUNK_SEP
)
merged_content.append(sep)
added_chars += len(sep)
if chunk.chunk_id == prev_chunk_id + 1:
merged_content.append("\n")
else:
merged_content.append("\n\n...\n\n")
merged_content.append(chunk.content)
combined_content = "".join(merged_content)
return (
InferenceSection(
center_chunk=center_chunk,
chunks=sorted_chunks,
combined_content=combined_content,
),
added_chars,
return InferenceSection(
center_chunk=center_chunk,
chunks=sorted_chunks,
combined_content=combined_content,
)
def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
doc_order: dict[str, int] = {}
combined_section_lengths: dict[str, int] = defaultdict(lambda: 0)
# chunk de-duping and doc ordering
for index, section in enumerate(sections):
if section.center_chunk.document_id not in doc_order:
doc_order[section.center_chunk.document_id] = index
combined_section_lengths[section.center_chunk.document_id] += len(
section.combined_content
)
chunks_map = docs_map[section.center_chunk.document_id]
for chunk in [section.center_chunk] + section.chunks:
chunks_map = docs_map[section.center_chunk.document_id]
existing_chunk = chunks_map.get(chunk.chunk_id)
if (
existing_chunk is None
@@ -379,22 +361,8 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
chunks_map[chunk.chunk_id] = chunk
new_sections = []
for doc_id, section_chunks in docs_map.items():
section_chunks_list = list(section_chunks.values())
merged_section, added_chars = _merge_doc_chunks(chunks=section_chunks_list)
previous_length = combined_section_lengths[doc_id] + added_chars
# After merging, ensure the content respects the pruning done earlier. Each
# combined section is restricted to the sum of the lengths of the sections
# from the pruning step. Technically the correct approach would be to prune based
# on tokens AGAIN, but this is a good approximation and worth not adding the
# tokenization overhead. This could also be fixed if we added a way of removing
# chunks from sections in the pruning step; at the moment this issue largely
# exists because we only trim the final section's combined_content.
merged_section.combined_content = merged_section.combined_content[
:previous_length
]
new_sections.append(merged_section)
for section_chunks in docs_map.values():
new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))
# Sort by highest score, then by original document order
# It is now 1 large section per doc, the center chunk being the one with the highest score
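The separator rule above in one self-contained sketch (illustrative only, not the connector's actual helper): consecutive chunk ids are joined with a single newline, while a gap in ids inserts an ellipsis marker.

def join_chunks(chunks: list[tuple[int, str]]) -> str:
    """Join (chunk_id, content) pairs the way _merge_doc_chunks does."""
    ordered = sorted(chunks)  # assumed de-duplicated already
    parts: list[str] = []
    for i, (chunk_id, content) in enumerate(ordered):
        if i > 0:
            prev_id = ordered[i - 1][0]
            parts.append("\n" if chunk_id == prev_id + 1 else "\n\n...\n\n")
        parts.append(content)
    return "".join(parts)

# join_chunks([(1, "a"), (2, "b"), (5, "c")]) -> "a\nb\n\n...\n\nc"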

View File

@@ -16,9 +16,6 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
# ~3k input, half for docs, half for chat history + prompts
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
# Maximum percentage of the context window to fill with selected sections
SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(

View File

@@ -13,7 +13,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urljoin
from urllib.parse import urlparse
import requests
@@ -343,14 +342,9 @@ def build_confluence_document_id(
Returns:
str: The document id
"""
# NOTE: urljoin is tricky and will drop the last segment of the base if it doesn't
# end with "/" because it believes that makes it a file.
final_url = base_url.rstrip("/") + "/"
if is_cloud and not final_url.endswith("/wiki/"):
final_url = urljoin(final_url, "wiki") + "/"
final_url = urljoin(final_url, content_url.lstrip("/"))
return final_url
if is_cloud and not base_url.endswith("/wiki"):
base_url += "/wiki"
return f"{base_url}{content_url}"
def datetime_from_string(datetime_string: str) -> datetime:
@@ -460,19 +454,6 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
logger.warning("HTTPError with `None` as response or as headers")
raise e
# Confluence Server returns 403 when rate limited
if e.response.status_code == 403:
FORBIDDEN_MAX_RETRY_ATTEMPTS = 7
FORBIDDEN_RETRY_DELAY = 10
if attempt < FORBIDDEN_MAX_RETRY_ATTEMPTS:
logger.warning(
"403 error. This sometimes happens when we hit "
f"Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds..."
)
return FORBIDDEN_RETRY_DELAY
raise e
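# With these constants, a persistently rate-limited Confluence Server request is
# retried roughly every 10 seconds until the attempt counter reaches 7, after
# which the 403 is re-raised to the caller.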
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()

View File

@@ -1,5 +1,4 @@
import base64
import time
from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
@@ -8,8 +7,6 @@ from typing import Any
from typing import cast
import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.app_configs import GONG_CONNECTOR_START_TIME
@@ -24,14 +21,13 @@ from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
GONG_BASE_URL = "https://us-34014.api.gong.io"
class GongConnector(LoadConnector, PollConnector):
BASE_URL = "https://api.gong.io"
MAX_CALL_DETAILS_ATTEMPTS = 6
CALL_DETAILS_DELAY = 30 # in seconds
def __init__(
self,
workspaces: list[str] | None = None,
@@ -45,23 +41,15 @@ class GongConnector(LoadConnector, PollConnector):
self.auth_token_basic: str | None = None
self.hide_user_info = hide_user_info
retry_strategy = Retry(
total=5,
backoff_factor=2,
status_forcelist=[429, 500, 502, 503, 504],
)
def _get_auth_header(self) -> dict[str, str]:
if self.auth_token_basic is None:
raise ConnectorMissingCredentialError("Gong")
session = requests.Session()
session.mount(GongConnector.BASE_URL, HTTPAdapter(max_retries=retry_strategy))
self._session = session
@staticmethod
def make_url(endpoint: str) -> str:
url = f"{GongConnector.BASE_URL}{endpoint}"
return url
return {"Authorization": f"Basic {self.auth_token_basic}"}
def _get_workspace_id_map(self) -> dict[str, str]:
response = self._session.get(GongConnector.make_url("/v2/workspaces"))
url = f"{GONG_BASE_URL}/v2/workspaces"
response = requests.get(url, headers=self._get_auth_header())
response.raise_for_status()
workspaces_details = response.json().get("workspaces")
@@ -78,6 +66,7 @@ class GongConnector(LoadConnector, PollConnector):
def _get_transcript_batches(
self, start_datetime: str | None = None, end_datetime: str | None = None
) -> Generator[list[dict[str, Any]], None, None]:
url = f"{GONG_BASE_URL}/v2/calls/transcript"
body: dict[str, dict] = {"filter": {}}
if start_datetime:
body["filter"]["fromDateTime"] = start_datetime
@@ -105,8 +94,8 @@ class GongConnector(LoadConnector, PollConnector):
del body["filter"]["workspaceId"]
while True:
response = self._session.post(
GongConnector.make_url("/v2/calls/transcript"), json=body
response = requests.post(
url, headers=self._get_auth_header(), json=body
)
# If no calls in the range, just break out
if response.status_code == 404:
@@ -136,14 +125,14 @@ class GongConnector(LoadConnector, PollConnector):
yield transcripts
def _get_call_details_by_ids(self, call_ids: list[str]) -> dict:
url = f"{GONG_BASE_URL}/v2/calls/extensive"
body = {
"filter": {"callIds": call_ids},
"contentSelector": {"exposedFields": {"parties": True}},
}
response = self._session.post(
GongConnector.make_url("/v2/calls/extensive"), json=body
)
response = requests.post(url, headers=self._get_auth_header(), json=body)
response.raise_for_status()
calls = response.json().get("calls")
@@ -176,74 +165,24 @@ class GongConnector(LoadConnector, PollConnector):
def _fetch_calls(
self, start_datetime: str | None = None, end_datetime: str | None = None
) -> GenerateDocumentsOutput:
num_calls = 0
for transcript_batch in self._get_transcript_batches(
start_datetime, end_datetime
):
doc_batch: list[Document] = []
transcript_call_ids = cast(
call_ids = cast(
list[str],
[t.get("callId") for t in transcript_batch if t.get("callId")],
)
call_details_map = self._get_call_details_by_ids(call_ids)
call_details_map: dict[str, Any] = {}
# There's a likely race condition in the API where a transcript will have a
# call id but the call to v2/calls/extensive will not return all of the ids.
# Retrying with exponential backoff has been observed to mitigate this
# in ~2 minutes
current_attempt = 0
while True:
current_attempt += 1
call_details_map = self._get_call_details_by_ids(transcript_call_ids)
if set(transcript_call_ids) == set(call_details_map.keys()):
# we got all the id's we were expecting ... break and continue
break
# we are missing some id's. Log and retry with exponential backoff
missing_call_ids = set(transcript_call_ids) - set(
call_details_map.keys()
)
logger.warning(
f"_get_call_details_by_ids is missing call id's: "
f"current_attempt={current_attempt} "
f"missing_call_ids={missing_call_ids}"
)
if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS:
raise RuntimeError(
f"Attempt count exceeded for _get_call_details_by_ids: "
f"missing_call_ids={missing_call_ids} "
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
)
wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1)
logger.warning(
f"_get_call_details_by_ids waiting to retry: "
f"wait={wait_seconds}s "
f"current_attempt={current_attempt} "
f"next_attempt={current_attempt+1} "
f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}"
)
time.sleep(wait_seconds)
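# Worked schedule under the class defaults (CALL_DETAILS_DELAY=30, MAX_CALL_DETAILS_ATTEMPTS=6):
# attempts 1-5 wait 30s, 60s, 120s, 240s and 480s respectively, and attempt 6 raises,
# so the loop gives up after roughly 15 minutes of cumulative waiting.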
# now we can iterate per call/transcript
for transcript in transcript_batch:
call_id = transcript.get("callId")
if not call_id or call_id not in call_details_map:
# NOTE(rkuo): seeing odd behavior where call_ids from the transcript
# don't have call details. Adding error debugging logs to trace.
logger.error(
f"Couldn't get call information for Call ID: {call_id}"
)
if call_id:
logger.error(
f"Call debug info: call_id={call_id} "
f"call_ids={transcript_call_ids} "
f"call_details_map={call_details_map.keys()}"
)
if not self.continue_on_fail:
raise RuntimeError(
f"Couldn't get call information for Call ID: {call_id}"
@@ -256,8 +195,7 @@ class GongConnector(LoadConnector, PollConnector):
call_time_str = call_metadata["started"]
call_title = call_metadata["title"]
logger.info(
f"{num_calls+1}: Indexing Gong call id {call_id} "
f"from {call_time_str.split('T', 1)[0]}: {call_title}"
f"Indexing Gong call from {call_time_str.split('T', 1)[0]}: {call_title}"
)
call_parties = cast(list[dict] | None, call_details.get("parties"))
@@ -316,13 +254,8 @@ class GongConnector(LoadConnector, PollConnector):
metadata={"client": call_metadata.get("system")},
)
)
num_calls += 1
yield doc_batch
logger.info(f"_fetch_calls finished: num_calls={num_calls}")
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
combined = (
f'{credentials["gong_access_key"]}:{credentials["gong_access_key_secret"]}'
@@ -330,13 +263,6 @@ class GongConnector(LoadConnector, PollConnector):
self.auth_token_basic = base64.b64encode(combined.encode("utf-8")).decode(
"utf-8"
)
if self.auth_token_basic is None:
raise ConnectorMissingCredentialError("Gong")
self._session.headers.update(
{"Authorization": f"Basic {self.auth_token_basic}"}
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:

View File

@@ -445,9 +445,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
logger.warning(
f"User '{user_email}' does not have access to the drive APIs."
)
# mark this user as done so we don't try to retrieve anything for them
# again
curr_stage.stage = DriveRetrievalStage.DONE
return
raise
@@ -584,25 +581,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
drive_ids_to_retrieve, checkpoint
)
# only process emails that we haven't already completed retrieval for
non_completed_org_emails = [
user_email
for user_email, stage in checkpoint.completion_map.items()
if stage != DriveRetrievalStage.DONE
]
# don't process too many emails before returning a checkpoint. This is
# to resolve the case where there are a ton of emails that don't have access
# to the drive APIs. Without this, we could loop through these emails for
# more than 3 hours, causing a timeout and stalling progress.
email_batch_takes_us_to_completion = True
MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = 50
if len(non_completed_org_emails) > MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING:
non_completed_org_emails = non_completed_org_emails[
:MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING
]
email_batch_takes_us_to_completion = False
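# e.g. with 180 org emails still pending, this pass covers 50 of them and then
# returns a checkpoint; roughly ceil(180 / 50) = 4 passes complete the stage.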
user_retrieval_gens = [
self._impersonate_user_for_retrieval(
email,
@@ -613,14 +591,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
start,
end,
)
for email in non_completed_org_emails
for email in all_org_emails
]
yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS)
# if there are more emails to process, don't mark as complete
if not email_batch_takes_us_to_completion:
return
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
) - self._retrieved_ids

View File

@@ -20,8 +20,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -85,21 +84,14 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
Populate the spot ID map with all available spots.
Keys are stored as lowercase for case-insensitive lookups.
"""
try:
spots = self.client.get_spots()
for spot in spots:
if "title" in spot and "id" in spot:
spot_name = spot["title"]
self._spot_id_map[spot_name.lower()] = spot["id"]
spots = self.client.get_spots()
for spot in spots:
if "title" in spot and "id" in spot:
spot_name = spot["title"]
self._spot_id_map[spot_name.lower()] = spot["id"]
self._all_spots_fetched = True
logger.info(f"Retrieved {len(self._spot_id_map)} spots from Highspot")
except HighspotClientError as e:
logger.error(f"Error retrieving spots from Highspot: {str(e)}")
raise
except Exception as e:
logger.error(f"Unexpected error retrieving spots from Highspot: {str(e)}")
raise
self._all_spots_fetched = True
logger.info(f"Retrieved {len(self._spot_id_map)} spots from Highspot")
def _get_all_spot_names(self) -> List[str]:
"""
@@ -159,142 +151,116 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
Batches of Document objects
"""
doc_batch: list[Document] = []
try:
# If no spots specified, get all spots
spot_names_to_process = self.spot_names
if not spot_names_to_process:
spot_names_to_process = self._get_all_spot_names()
if not spot_names_to_process:
logger.warning("No spots found in Highspot")
raise ValueError("No spots found in Highspot")
logger.info(
f"No spots specified, using all {len(spot_names_to_process)} available spots"
)
for spot_name in spot_names_to_process:
try:
spot_id = self._get_spot_id_from_name(spot_name)
if spot_id is None:
logger.warning(f"Spot ID not found for spot {spot_name}")
continue
offset = 0
has_more = True
# If no spots specified, get all spots
spot_names_to_process = self.spot_names
if not spot_names_to_process:
spot_names_to_process = self._get_all_spot_names()
logger.info(
f"No spots specified, using all {len(spot_names_to_process)} available spots"
)
while has_more:
logger.info(
f"Retrieving items from spot {spot_name}, offset {offset}"
)
response = self.client.get_spot_items(
spot_id=spot_id, offset=offset, page_size=self.batch_size
)
items = response.get("collection", [])
logger.info(f"Received Items: {items}")
if not items:
has_more = False
continue
for spot_name in spot_names_to_process:
try:
spot_id = self._get_spot_id_from_name(spot_name)
if spot_id is None:
logger.warning(f"Spot ID not found for spot {spot_name}")
continue
offset = 0
has_more = True
for item in items:
try:
item_id = item.get("id")
if not item_id:
logger.warning("Item without ID found, skipping")
continue
item_details = self.client.get_item(item_id)
if not item_details:
logger.warning(
f"Item {item_id} details not found, skipping"
)
continue
# Apply time filter if specified
if start or end:
updated_at = item_details.get("date_updated")
if updated_at:
# Convert to datetime for comparison
try:
updated_time = datetime.fromisoformat(
updated_at.replace("Z", "+00:00")
)
if (
start
and updated_time.timestamp() < start
) or (
end and updated_time.timestamp() > end
):
continue
except (ValueError, TypeError):
# Skip if date cannot be parsed
logger.warning(
f"Invalid date format for item {item_id}: {updated_at}"
)
continue
content = self._get_item_content(item_details)
title = item_details.get("title", "")
doc_batch.append(
Document(
id=f"HIGHSPOT_{item_id}",
sections=[
TextSection(
link=item_details.get(
"url",
f"https://www.highspot.com/items/{item_id}",
),
text=content,
)
],
source=DocumentSource.HIGHSPOT,
semantic_identifier=title,
metadata={
"spot_name": spot_name,
"type": item_details.get(
"content_type", ""
),
"created_at": item_details.get(
"date_added", ""
),
"author": item_details.get("author", ""),
"language": item_details.get(
"language", ""
),
"can_download": str(
item_details.get("can_download", False)
),
},
doc_updated_at=item_details.get("date_updated"),
)
)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
except HighspotClientError as e:
item_id = "ID" if not item_id else item_id
logger.error(
f"Error retrieving item {item_id}: {str(e)}"
)
except Exception as e:
item_id = "ID" if not item_id else item_id
logger.error(
f"Unexpected error for item {item_id}: {str(e)}"
)
has_more = len(items) >= self.batch_size
offset += self.batch_size
except (HighspotClientError, ValueError) as e:
logger.error(f"Error processing spot {spot_name}: {str(e)}")
except Exception as e:
logger.error(
f"Unexpected error processing spot {spot_name}: {str(e)}"
while has_more:
logger.info(
f"Retrieving items from spot {spot_name}, offset {offset}"
)
response = self.client.get_spot_items(
spot_id=spot_id, offset=offset, page_size=self.batch_size
)
items = response.get("collection", [])
logger.info(f"Received Items: {items}")
if not items:
has_more = False
continue
except Exception as e:
logger.error(f"Error in Highspot connector: {str(e)}")
raise
for item in items:
try:
item_id = item.get("id")
if not item_id:
logger.warning("Item without ID found, skipping")
continue
item_details = self.client.get_item(item_id)
if not item_details:
logger.warning(
f"Item {item_id} details not found, skipping"
)
continue
# Apply time filter if specified
if start or end:
updated_at = item_details.get("date_updated")
if updated_at:
# Convert to datetime for comparison
try:
updated_time = datetime.fromisoformat(
updated_at.replace("Z", "+00:00")
)
if (
start and updated_time.timestamp() < start
) or (end and updated_time.timestamp() > end):
continue
except (ValueError, TypeError):
# Skip if date cannot be parsed
logger.warning(
f"Invalid date format for item {item_id}: {updated_at}"
)
continue
content = self._get_item_content(item_details)
title = item_details.get("title", "")
doc_batch.append(
Document(
id=f"HIGHSPOT_{item_id}",
sections=[
TextSection(
link=item_details.get(
"url",
f"https://www.highspot.com/items/{item_id}",
),
text=content,
)
],
source=DocumentSource.HIGHSPOT,
semantic_identifier=title,
metadata={
"spot_name": spot_name,
"type": item_details.get("content_type", ""),
"created_at": item_details.get(
"date_added", ""
),
"author": item_details.get("author", ""),
"language": item_details.get("language", ""),
"can_download": str(
item_details.get("can_download", False)
),
},
doc_updated_at=item_details.get("date_updated"),
)
)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
except HighspotClientError as e:
item_id = "ID" if not item_id else item_id
logger.error(f"Error retrieving item {item_id}: {str(e)}")
has_more = len(items) >= self.batch_size
offset += self.batch_size
except (HighspotClientError, ValueError) as e:
logger.error(f"Error processing spot {spot_name}: {str(e)}")
if doc_batch:
yield doc_batch
@@ -320,9 +286,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
# Extract title and description once at the beginning
title, description = self._extract_title_and_description(item_details)
default_content = f"{title}\n{description}"
logger.info(
f"Processing item {item_id} with extension {file_extension} and file name {content_name}"
)
logger.info(f"Processing item {item_id} with extension {file_extension}")
try:
if content_type == "WebLink":
@@ -334,39 +298,30 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
elif (
is_valid_format
and (
file_extension in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
or file_extension in ACCEPTED_DOCUMENT_FILE_EXTENSIONS
)
and file_extension in ALL_ACCEPTED_FILE_EXTENSIONS
and can_download
):
# For documents, try to get the text content
if not item_id: # Ensure item_id is defined
return default_content
content_response = self.client.get_item_content(item_id)
# Process and extract text from binary content based on type
if content_response:
text_content = extract_file_text(
BytesIO(content_response), content_name, False
BytesIO(content_response), content_name
)
return text_content if text_content else default_content
return text_content
return default_content
else:
return default_content
except HighspotClientError as e:
error_context = f"item {item_id}" if item_id else "(item id not found)"
# Use item_id safely in the warning message
error_context = f"item {item_id}" if item_id else "item"
logger.warning(f"Could not retrieve content for {error_context}: {str(e)}")
return default_content
except ValueError as e:
error_context = f"item {item_id}" if item_id else "(item id not found)"
logger.error(f"Value error for {error_context}: {str(e)}")
return default_content
except Exception as e:
error_context = f"item {item_id}" if item_id else "(item id not found)"
logger.error(
f"Unexpected error retrieving content for {error_context}: {str(e)}"
)
return default_content
return ""
def _extract_title_and_description(
self, item_details: Dict[str, Any]
@@ -403,63 +358,55 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
Batches of SlimDocument objects
"""
slim_doc_batch: list[SlimDocument] = []
try:
# If no spots specified, get all spots
spot_names_to_process = self.spot_names
if not spot_names_to_process:
spot_names_to_process = self._get_all_spot_names()
if not spot_names_to_process:
logger.warning("No spots found in Highspot")
raise ValueError("No spots found in Highspot")
logger.info(
f"No spots specified, using all {len(spot_names_to_process)} available spots for slim documents"
)
for spot_name in spot_names_to_process:
try:
spot_id = self._get_spot_id_from_name(spot_name)
offset = 0
has_more = True
# If no spots specified, get all spots
spot_names_to_process = self.spot_names
if not spot_names_to_process:
spot_names_to_process = self._get_all_spot_names()
logger.info(
f"No spots specified, using all {len(spot_names_to_process)} available spots for slim documents"
)
while has_more:
logger.info(
f"Retrieving slim documents from spot {spot_name}, offset {offset}"
)
response = self.client.get_spot_items(
spot_id=spot_id, offset=offset, page_size=self.batch_size
)
for spot_name in spot_names_to_process:
try:
spot_id = self._get_spot_id_from_name(spot_name)
offset = 0
has_more = True
items = response.get("collection", [])
if not items:
has_more = False
continue
for item in items:
item_id = item.get("id")
if not item_id:
continue
slim_doc_batch.append(
SlimDocument(id=f"HIGHSPOT_{item_id}")
)
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
has_more = len(items) >= self.batch_size
offset += self.batch_size
except (HighspotClientError, ValueError) as e:
logger.error(
f"Error retrieving slim documents from spot {spot_name}: {str(e)}"
while has_more:
logger.info(
f"Retrieving slim documents from spot {spot_name}, offset {offset}"
)
response = self.client.get_spot_items(
spot_id=spot_id, offset=offset, page_size=self.batch_size
)
if slim_doc_batch:
yield slim_doc_batch
except Exception as e:
logger.error(f"Error in Highspot Slim Connector: {str(e)}")
raise
items = response.get("collection", [])
if not items:
has_more = False
continue
for item in items:
item_id = item.get("id")
if not item_id:
continue
slim_doc_batch.append(SlimDocument(id=f"HIGHSPOT_{item_id}"))
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
has_more = len(items) >= self.batch_size
offset += self.batch_size
except (HighspotClientError, ValueError) as e:
logger.error(
f"Error retrieving slim documents from spot {spot_name}: {str(e)}"
)
if slim_doc_batch:
yield slim_doc_batch
def validate_credentials(self) -> bool:
"""

View File

@@ -1,4 +1,3 @@
import sys
from datetime import datetime
from enum import Enum
from typing import Any
@@ -41,9 +40,6 @@ class TextSection(Section):
text: str
link: str | None = None
def __sizeof__(self) -> int:
return sys.getsizeof(self.text) + sys.getsizeof(self.link)
class ImageSection(Section):
"""Section containing an image reference"""
@@ -51,9 +47,6 @@ class ImageSection(Section):
image_file_name: str
link: str | None = None
def __sizeof__(self) -> int:
return sys.getsizeof(self.image_file_name) + sys.getsizeof(self.link)
class BasicExpertInfo(BaseModel):
"""Basic Information for the owner of a document, any of the fields can be left as None
@@ -117,14 +110,6 @@ class BasicExpertInfo(BaseModel):
)
)
def __sizeof__(self) -> int:
size = sys.getsizeof(self.display_name)
size += sys.getsizeof(self.first_name)
size += sys.getsizeof(self.middle_initial)
size += sys.getsizeof(self.last_name)
size += sys.getsizeof(self.email)
return size
class DocumentBase(BaseModel):
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
@@ -178,32 +163,6 @@ class DocumentBase(BaseModel):
attributes.append(k + INDEX_SEPARATOR + v)
return attributes
def __sizeof__(self) -> int:
size = sys.getsizeof(self.id)
for section in self.sections:
size += sys.getsizeof(section)
size += sys.getsizeof(self.source)
size += sys.getsizeof(self.semantic_identifier)
size += sys.getsizeof(self.doc_updated_at)
size += sys.getsizeof(self.chunk_count)
if self.primary_owners is not None:
for primary_owner in self.primary_owners:
size += sys.getsizeof(primary_owner)
else:
size += sys.getsizeof(self.primary_owners)
if self.secondary_owners is not None:
for secondary_owner in self.secondary_owners:
size += sys.getsizeof(secondary_owner)
else:
size += sys.getsizeof(self.secondary_owners)
size += sys.getsizeof(self.title)
size += sys.getsizeof(self.from_ingestion_api)
size += sys.getsizeof(self.additional_info)
return size
def get_text_content(self) -> str:
return " ".join([section.text for section in self.sections if section.text])
@@ -235,12 +194,6 @@ class Document(DocumentBase):
from_ingestion_api=base.from_ingestion_api,
)
def __sizeof__(self) -> int:
size = super().__sizeof__()
size += sys.getsizeof(self.id)
size += sys.getsizeof(self.source)
return size
class IndexingDocument(Document):
"""Document with processed sections for indexing"""

View File

@@ -1,9 +1,4 @@
import gc
import os
import sys
import tempfile
from collections import defaultdict
from pathlib import Path
from typing import Any
from simple_salesforce import Salesforce
@@ -26,13 +21,9 @@ from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_t
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import sqlite_log_stats
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -41,8 +32,6 @@ _DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
MAX_BATCH_BYTES = 1024 * 1024
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@@ -75,45 +64,22 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
raise ConnectorMissingCredentialError("Salesforce")
return self._sf_client
@staticmethod
def reconstruct_object_types(directory: str) -> dict[str, list[str] | None]:
"""
Scans the given directory for all CSV files and reconstructs the available object types.
Assumes filenames are formatted as "ObjectType.filename.csv" or "ObjectType.csv".
Args:
directory (str): The path to the directory containing CSV files.
Returns:
dict[str, list[str]]: A dictionary mapping object types to lists of file paths.
"""
object_types = defaultdict(list)
for filename in os.listdir(directory):
if filename.endswith(".csv"):
parts = filename.split(".", 1) # Split on the first period
object_type = parts[0] # Take the first part as the object type
object_types[object_type].append(os.path.join(directory, filename))
return dict(object_types)
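# e.g. a directory holding "Account.part1.csv", "Account.part2.csv" and "User.csv"
# (hypothetical file names) maps to {"Account": [<two paths>], "User": [<one path>]}.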
@staticmethod
def _download_object_csvs(
directory: str,
parent_object_list: list[str],
sf_client: Salesforce,
def _fetch_from_salesforce(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> None:
all_object_types: set[str] = set(parent_object_list)
) -> GenerateDocumentsOutput:
init_db()
all_object_types: set[str] = set(self.parent_object_list)
logger.info(
f"Parent object types: num={len(parent_object_list)} list={parent_object_list}"
)
logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
logger.debug(f"Parent object types: {self.parent_object_list}")
# This takes like 20 seconds
for parent_object_type in parent_object_list:
child_types = get_all_children_of_sf_type(sf_client, parent_object_type)
for parent_object_type in self.parent_object_list:
child_types = get_all_children_of_sf_type(
self.sf_client, parent_object_type
)
all_object_types.update(child_types)
logger.debug(
f"Found {len(child_types)} child types for {parent_object_type}"
@@ -122,53 +88,20 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# Always want to make sure user is grabbed for permissioning purposes
all_object_types.add("User")
logger.info(
f"All object types: num={len(all_object_types)} list={all_object_types}"
)
# gc.collect()
logger.info(f"Found total of {len(all_object_types)} object types to fetch")
logger.debug(f"All object types: {all_object_types}")
# checkpoint - we've found all object types, now time to fetch the data
logger.info("Fetching CSVs for all object types")
logger.info("Starting to fetch CSVs for all object types")
# This takes like 30 minutes first time and <2 minutes for updates
object_type_to_csv_path = fetch_all_csvs_in_parallel(
sf_client=sf_client,
sf_client=self.sf_client,
object_types=all_object_types,
start=start,
end=end,
target_dir=directory,
)
# print useful information
num_csvs = 0
num_bytes = 0
for object_type, csv_paths in object_type_to_csv_path.items():
if not csv_paths:
continue
for csv_path in csv_paths:
if not csv_path:
continue
file_path = Path(csv_path)
file_size = file_path.stat().st_size
num_csvs += 1
num_bytes += file_size
logger.info(
f"CSV info: object_type={object_type} path={csv_path} bytes={file_size}"
)
logger.info(f"CSV info total: total_csvs={num_csvs} total_bytes={num_bytes}")
@staticmethod
def _load_csvs_to_db(csv_directory: str, db_directory: str) -> set[str]:
updated_ids: set[str] = set()
object_type_to_csv_path = SalesforceConnector.reconstruct_object_types(
csv_directory
)
# This takes like 10 seconds
# This is for testing the rest of the functionality if data has
# already been fetched and put in sqlite
@@ -187,16 +120,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# If path is None, it means it failed to fetch the csv
if csv_paths is None:
continue
# Go through each csv path and use it to update the db
for csv_path in csv_paths:
logger.debug(
f"Processing CSV: object_type={object_type} "
f"csv={csv_path} "
f"len={Path(csv_path).stat().st_size}"
)
logger.debug(f"Updating {object_type} with {csv_path}")
new_ids = update_sf_db_with_csv(
db_directory,
object_type=object_type,
csv_download_path=csv_path,
)
@@ -205,127 +132,49 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
f"Added {len(new_ids)} new/updated records for {object_type}"
)
os.remove(csv_path)
return updated_ids
def _fetch_from_salesforce(
self,
temp_dir: str,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
logger.info("_fetch_from_salesforce starting.")
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
init_db(temp_dir)
sqlite_log_stats(temp_dir)
# Step 1 - download
SalesforceConnector._download_object_csvs(
temp_dir, self.parent_object_list, self._sf_client, start, end
)
gc.collect()
# Step 2 - load CSV's to sqlite
updated_ids = SalesforceConnector._load_csvs_to_db(temp_dir, temp_dir)
gc.collect()
logger.info(f"Found {len(updated_ids)} total updated records")
logger.info(
f"Starting to process parent objects of types: {self.parent_object_list}"
)
# Step 3 - extract and index docs
batches_processed = 0
docs_processed = 0
docs_to_yield: list[Document] = []
docs_to_yield_bytes = 0
docs_processed = 0
# Takes 15-20 seconds per batch
for parent_type, parent_id_batch in get_affected_parent_ids_by_type(
temp_dir,
updated_ids=list(updated_ids),
parent_types=self.parent_object_list,
):
batches_processed += 1
logger.info(
f"Processing batch: index={batches_processed} "
f"object_type={parent_type} "
f"len={len(parent_id_batch)} "
f"processed={docs_processed} "
f"remaining={len(updated_ids) - docs_processed}"
f"Processing batch of {len(parent_id_batch)} {parent_type} objects"
)
for parent_id in parent_id_batch:
if not (parent_object := get_record(temp_dir, parent_id, parent_type)):
if not (parent_object := get_record(parent_id, parent_type)):
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
)
continue
doc = convert_sf_object_to_doc(
temp_dir,
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
docs_to_yield.append(
convert_sf_object_to_doc(
sf_object=parent_object,
sf_instance=self.sf_client.sf_instance,
)
)
doc_sizeof = sys.getsizeof(doc)
docs_to_yield_bytes += doc_sizeof
docs_to_yield.append(doc)
docs_processed += 1
# memory usage is sensitive to the input length, so we're yielding immediately
# if the batch exceeds a certain byte length
if (
len(docs_to_yield) >= self.batch_size
or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES
):
if len(docs_to_yield) >= self.batch_size:
yield docs_to_yield
docs_to_yield = []
docs_to_yield_bytes = 0
# observed a memory leak / size issue with the account table if we don't gc.collect here.
gc.collect()
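# Illustrative effect: with MAX_BATCH_BYTES at 1 MiB, a run of unusually large
# parent docs flushes the batch early even when fewer than batch_size docs have
# accumulated, keeping the peak size of each yielded batch roughly bounded.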
yield docs_to_yield
logger.info(
f"Final processing stats: "
f"processed={docs_processed} "
f"remaining={len(updated_ids) - docs_processed}"
)
def load_from_state(self) -> GenerateDocumentsOutput:
if MULTI_TENANT:
# if multi tenant, we cannot expect the sqlite db to be cached/present
with tempfile.TemporaryDirectory() as temp_dir:
return self._fetch_from_salesforce(temp_dir)
# nuke the db since we're starting from scratch
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
if os.path.exists(sqlite_db_path):
logger.info(f"load_from_state: Removing db at {sqlite_db_path}.")
os.remove(sqlite_db_path)
return self._fetch_from_salesforce(BASE_DATA_PATH)
return self._fetch_from_salesforce()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if MULTI_TENANT:
# if multi tenant, we cannot expect the sqlite db to be cached/present
with tempfile.TemporaryDirectory() as temp_dir:
return self._fetch_from_salesforce(temp_dir, start=start, end=end)
if start == 0:
# nuke the db if we're starting from scratch
sqlite_db_path = get_sqlite_db_path(BASE_DATA_PATH)
if os.path.exists(sqlite_db_path):
logger.info(
f"poll_source: Starting at time 0, removing db at {sqlite_db_path}."
)
os.remove(sqlite_db_path)
return self._fetch_from_salesforce(BASE_DATA_PATH)
return self._fetch_from_salesforce(start=start, end=end)
def retrieve_all_slim_documents(
self,
@@ -360,7 +209,7 @@ if __name__ == "__main__":
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
}
)
start_time = time.monotonic()
start_time = time.time()
doc_count = 0
section_count = 0
text_count = 0
@@ -372,7 +221,7 @@ if __name__ == "__main__":
for section in doc.sections:
if isinstance(section, TextSection) and section.text is not None:
text_count += len(section.text)
end_time = time.monotonic()
end_time = time.time()
print(f"Doc count: {doc_count}")
print(f"Section count: {section_count}")

View File

@@ -124,14 +124,13 @@ def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> Text
def _extract_primary_owners(
directory: str,
sf_object: SalesforceObject,
) -> list[BasicExpertInfo] | None:
object_dict = sf_object.data
if not (last_modified_by_id := object_dict.get("LastModifiedById")):
logger.warning(f"No LastModifiedById found for {sf_object.id}")
return None
if not (last_modified_by := get_record(directory, last_modified_by_id)):
if not (last_modified_by := get_record(last_modified_by_id)):
logger.warning(f"No LastModifiedBy found for {last_modified_by_id}")
return None
@@ -160,7 +159,6 @@ def _extract_primary_owners(
def convert_sf_object_to_doc(
directory: str,
sf_object: SalesforceObject,
sf_instance: str,
) -> Document:
@@ -172,8 +170,8 @@ def convert_sf_object_to_doc(
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
sections = [_extract_section(sf_object, base_url)]
for id in get_child_ids(directory, sf_object.id):
if not (child_object := get_record(directory, id)):
for id in get_child_ids(sf_object.id):
if not (child_object := get_record(id)):
continue
sections.append(_extract_section(child_object, base_url))
@@ -183,7 +181,7 @@ def convert_sf_object_to_doc(
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,
primary_owners=_extract_primary_owners(directory, sf_object),
primary_owners=_extract_primary_owners(sf_object),
metadata={},
)
return doc

View File

@@ -11,12 +11,13 @@ from simple_salesforce.bulk2 import SFBulk2Type
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.salesforce.sqlite_functions import has_at_least_one_object_of_type
from onyx.connectors.salesforce.utils import get_object_type_path
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _build_last_modified_time_filter_for_salesforce(
def _build_time_filter_for_salesforce(
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
if start is None or end is None:
@@ -29,19 +30,6 @@ def _build_last_modified_time_filter_for_salesforce(
)
def _build_created_date_time_filter_for_salesforce(
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
if start is None or end is None:
return ""
start_datetime = datetime.fromtimestamp(start, UTC)
end_datetime = datetime.fromtimestamp(end, UTC)
return (
f" WHERE CreatedDate > {start_datetime.isoformat()} "
f"AND CreatedDate < {end_datetime.isoformat()}"
)
def _get_sf_type_object_json(sf_client: Salesforce, type_name: str) -> Any:
sf_object = SFType(type_name, sf_client.session_id, sf_client.sf_instance)
return sf_object.describe()
@@ -121,6 +109,23 @@ def _check_if_object_type_is_empty(
return True
def _check_for_existing_csvs(sf_type: str) -> list[str] | None:
# Check if the csv already exists
if os.path.exists(get_object_type_path(sf_type)):
existing_csvs = [
os.path.join(get_object_type_path(sf_type), f)
for f in os.listdir(get_object_type_path(sf_type))
if f.endswith(".csv")
]
# If the csv already exists, return the path
# This is likely due to a previous run that failed
# after downloading the csv but before the data was
# written to the db
if existing_csvs:
return existing_csvs
return None
def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) -> str:
queryable_fields = _get_all_queryable_fields_of_sf_type(sf_client, sf_type)
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
@@ -128,15 +133,16 @@ def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) ->
def _bulk_retrieve_from_salesforce(
sf_client: Salesforce, sf_type: str, time_filter: str, target_dir: str
sf_client: Salesforce,
sf_type: str,
time_filter: str,
) -> tuple[str, list[str] | None]:
"""Returns a tuple of
1. the salesforce object type
2. the list of CSV's
"""
if not _check_if_object_type_is_empty(sf_client, sf_type, time_filter):
return sf_type, None
if existing_csvs := _check_for_existing_csvs(sf_type):
return sf_type, existing_csvs
query = _build_bulk_query(sf_client, sf_type, time_filter)
bulk_2_handler = SFBulk2Handler(
@@ -153,33 +159,20 @@ def _bulk_retrieve_from_salesforce(
)
logger.info(f"Downloading {sf_type}")
logger.debug(f"Query: {query}")
logger.info(f"Query: {query}")
try:
# This downloads the query results to a file in the target path with a random name
results = bulk_2_type.download(
query=query,
path=target_dir,
path=get_object_type_path(sf_type),
max_records=1000000,
)
# prepend each downloaded csv with the object type (delimiter = '.')
all_download_paths: list[str] = []
for result in results:
original_file_path = result["file"]
directory, filename = os.path.split(original_file_path)
new_filename = f"{sf_type}.{filename}"
new_file_path = os.path.join(directory, new_filename)
os.rename(original_file_path, new_file_path)
all_download_paths.append(new_file_path)
all_download_paths = [result["file"] for result in results]
logger.info(f"Downloaded {sf_type} to {all_download_paths}")
return sf_type, all_download_paths
except Exception as e:
logger.error(
f"Failed to download salesforce csv for object type {sf_type}: {e}"
)
logger.warning(f"Exceptioning query for object type {sf_type}: {query}")
logger.info(f"Failed to download salesforce csv for object type {sf_type}: {e}")
return sf_type, None
@@ -188,35 +181,12 @@ def fetch_all_csvs_in_parallel(
object_types: set[str],
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
target_dir: str,
) -> dict[str, list[str] | None]:
"""
Fetches all the csvs in parallel for the given object types
Returns a dict of (sf_type, full_download_path)
"""
# these types don't query properly and need looking at
# problem_types: set[str] = {
# "ContentDocumentLink",
# "RecordActionHistory",
# "PendingOrderSummary",
# "UnifiedActivityRelation",
# }
# these types don't have a LastModifiedDate field and instead use CreatedDate
created_date_types: set[str] = {
"AccountHistory",
"AccountTag",
"EntitySubscription",
}
last_modified_time_filter = _build_last_modified_time_filter_for_salesforce(
start, end
)
created_date_time_filter = _build_created_date_time_filter_for_salesforce(
start, end
)
time_filter = _build_time_filter_for_salesforce(start, end)
time_filter_for_each_object_type = {}
# We do this outside of the thread pool executor because this requires
# a database connection and we don't want to block the thread pool
@@ -225,11 +195,8 @@ def fetch_all_csvs_in_parallel(
"""Only add time filter if there is at least one object of the type
in the database. We aren't worried about partially completed object update runs
because this occurs after we check for existing csvs which covers this case"""
if has_at_least_one_object_of_type(target_dir, sf_type):
if sf_type in created_date_types:
time_filter_for_each_object_type[sf_type] = created_date_time_filter
else:
time_filter_for_each_object_type[sf_type] = last_modified_time_filter
if has_at_least_one_object_of_type(sf_type):
time_filter_for_each_object_type[sf_type] = time_filter
else:
time_filter_for_each_object_type[sf_type] = ""
@@ -240,7 +207,6 @@ def fetch_all_csvs_in_parallel(
sf_client=sf_client,
sf_type=object_type,
time_filter=time_filter_for_each_object_type[object_type],
target_dir=target_dir,
),
object_types,
)

View File

@@ -2,10 +2,8 @@ import csv
import json
import os
import sqlite3
import time
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import SalesforceObject
@@ -18,7 +16,6 @@ logger = setup_logger()
@contextmanager
def get_db_connection(
directory: str,
isolation_level: str | None = None,
) -> Iterator[sqlite3.Connection]:
"""Get a database connection with proper isolation level and error handling.
@@ -28,7 +25,7 @@ def get_db_connection(
can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
"""
# 60 second timeout for locks
conn = sqlite3.connect(get_sqlite_db_path(directory), timeout=60.0)
conn = sqlite3.connect(get_sqlite_db_path(), timeout=60.0)
if isolation_level is not None:
conn.isolation_level = isolation_level
@@ -41,41 +38,17 @@ def get_db_connection(
conn.close()
def sqlite_log_stats(directory: str) -> None:
with get_db_connection(directory, "EXCLUSIVE") as conn:
cache_pages = conn.execute("PRAGMA cache_size").fetchone()[0]
page_size = conn.execute("PRAGMA page_size").fetchone()[0]
if cache_pages >= 0:
cache_bytes = cache_pages * page_size
else:
cache_bytes = abs(cache_pages * 1024)
logger.info(
f"SQLite stats: sqlite_version={sqlite3.sqlite_version} "
f"cache_pages={cache_pages} "
f"page_size={page_size} "
f"cache_bytes={cache_bytes}"
)
def init_db(directory: str) -> None:
def init_db() -> None:
"""Initialize the SQLite database with required tables if they don't exist."""
# Create database directory if it doesn't exist
start = time.monotonic()
os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)
os.makedirs(os.path.dirname(get_sqlite_db_path(directory)), exist_ok=True)
with get_db_connection(directory, "EXCLUSIVE") as conn:
with get_db_connection("EXCLUSIVE") as conn:
cursor = conn.cursor()
db_exists = os.path.exists(get_sqlite_db_path(directory))
if db_exists:
file_path = Path(get_sqlite_db_path(directory))
file_size = file_path.stat().st_size
logger.info(f"init_db - found existing sqlite db: len={file_size}")
else:
# why is this only if the db doesn't exist?
db_exists = os.path.exists(get_sqlite_db_path())
if not db_exists:
# Enable WAL mode for better concurrent access and write performance
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
@@ -170,31 +143,16 @@ def init_db(directory: str) -> None:
""",
)
elapsed = time.monotonic() - start
logger.info(f"init_db - create tables and indices: elapsed={elapsed:.2f}")
# Analyze tables to help query planner
# NOTE(rkuo): skip ANALYZE - it takes too long and we likely don't have
# complicated queries that need this
# start = time.monotonic()
# cursor.execute("ANALYZE relationships")
# cursor.execute("ANALYZE salesforce_objects")
# cursor.execute("ANALYZE relationship_types")
# cursor.execute("ANALYZE user_email_map")
# elapsed = time.monotonic() - start
# logger.info(f"init_db - analyze: elapsed={elapsed:.2f}")
cursor.execute("ANALYZE relationships")
cursor.execute("ANALYZE salesforce_objects")
cursor.execute("ANALYZE relationship_types")
cursor.execute("ANALYZE user_email_map")
# If database already existed but user_email_map needs to be populated
start = time.monotonic()
cursor.execute("SELECT COUNT(*) FROM user_email_map")
elapsed = time.monotonic() - start
logger.info(f"init_db - count user_email_map: elapsed={elapsed:.2f}")
start = time.monotonic()
if cursor.fetchone()[0] == 0:
_update_user_email_map(conn)
elapsed = time.monotonic() - start
logger.info(f"init_db - update_user_email_map: elapsed={elapsed:.2f}")
conn.commit()
@@ -282,15 +240,15 @@ def _update_user_email_map(conn: sqlite3.Connection) -> None:
def update_sf_db_with_csv(
directory: str,
object_type: str,
csv_download_path: str,
delete_csv_after_use: bool = True,
) -> list[str]:
"""Update the SF DB with a CSV file using SQLite storage."""
updated_ids = []
# Use IMMEDIATE to get a write lock at the start of the transaction
with get_db_connection(directory, "IMMEDIATE") as conn:
with get_db_connection("IMMEDIATE") as conn:
cursor = conn.cursor()
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
@@ -337,12 +295,17 @@ def update_sf_db_with_csv(
conn.commit()
if delete_csv_after_use:
# Remove the csv file after it has been used
# to successfully update the db
os.remove(csv_download_path)
return updated_ids
def get_child_ids(directory: str, parent_id: str) -> set[str]:
def get_child_ids(parent_id: str) -> set[str]:
"""Get all child IDs for a given parent ID."""
with get_db_connection(directory) as conn:
with get_db_connection() as conn:
cursor = conn.cursor()
# Force index usage with INDEXED BY
@@ -354,9 +317,9 @@ def get_child_ids(directory: str, parent_id: str) -> set[str]:
return child_ids
def get_type_from_id(directory: str, object_id: str) -> str | None:
def get_type_from_id(object_id: str) -> str | None:
"""Get the type of an object from its ID."""
with get_db_connection(directory) as conn:
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT object_type FROM salesforce_objects WHERE id = ?", (object_id,)
@@ -369,15 +332,15 @@ def get_type_from_id(directory: str, object_id: str) -> str | None:
def get_record(
directory: str, object_id: str, object_type: str | None = None
object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
"""Retrieve the record and return it as a SalesforceObject."""
if object_type is None:
object_type = get_type_from_id(directory, object_id)
object_type = get_type_from_id(object_id)
if not object_type:
return None
with get_db_connection(directory) as conn:
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT data FROM salesforce_objects WHERE id = ?", (object_id,))
result = cursor.fetchone()
@@ -389,9 +352,9 @@ def get_record(
return SalesforceObject(id=object_id, type=object_type, data=data)
def find_ids_by_type(directory: str, object_type: str) -> list[str]:
def find_ids_by_type(object_type: str) -> list[str]:
"""Find all object IDs for rows of the specified type."""
with get_db_connection(directory) as conn:
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT id FROM salesforce_objects WHERE object_type = ?", (object_type,)
@@ -400,7 +363,6 @@ def find_ids_by_type(directory: str, object_type: str) -> list[str]:
def get_affected_parent_ids_by_type(
directory: str,
updated_ids: list[str],
parent_types: list[str],
batch_size: int = 500,
@@ -412,7 +374,7 @@ def get_affected_parent_ids_by_type(
updated_ids_batches = batch_list(updated_ids, batch_size)
updated_parent_ids: set[str] = set()
with get_db_connection(directory) as conn:
with get_db_connection() as conn:
cursor = conn.cursor()
for batch_ids in updated_ids_batches:
@@ -457,7 +419,7 @@ def get_affected_parent_ids_by_type(
yield parent_type, new_affected_ids
def has_at_least_one_object_of_type(directory: str, object_type: str) -> bool:
def has_at_least_one_object_of_type(object_type: str) -> bool:
"""Check if there is at least one object of the specified type in the database.
Args:
@@ -466,7 +428,7 @@ def has_at_least_one_object_of_type(directory: str, object_type: str) -> bool:
Returns:
bool: True if at least one object exists, False otherwise
"""
with get_db_connection(directory) as conn:
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT COUNT(*) FROM salesforce_objects WHERE object_type = ?",
@@ -481,7 +443,7 @@ def has_at_least_one_object_of_type(directory: str, object_type: str) -> bool:
NULL_ID_STRING = "N/A"
def get_user_id_by_email(directory: str, email: str) -> str | None:
def get_user_id_by_email(email: str) -> str | None:
"""Get the Salesforce User ID for a given email address.
Args:
@@ -492,7 +454,7 @@ def get_user_id_by_email(directory: str, email: str) -> str | None:
- was_found: True if the email exists in the table, False if not found
- user_id: The Salesforce User ID if exists, None otherwise
"""
with get_db_connection(directory) as conn:
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT user_id FROM user_email_map WHERE email = ?", (email,))
result = cursor.fetchone()
@@ -501,10 +463,10 @@ def get_user_id_by_email(directory: str, email: str) -> str | None:
return result[0]
def update_email_to_id_table(directory: str, email: str, id: str | None) -> None:
def update_email_to_id_table(email: str, id: str | None) -> None:
"""Update the email to ID map table with a new email and ID."""
id_to_use = id or NULL_ID_STRING
with get_db_connection(directory) as conn:
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"INSERT OR REPLACE INTO user_email_map (email, user_id) VALUES (?, ?)",

View File

@@ -30,9 +30,9 @@ class SalesforceObject:
BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")
def get_sqlite_db_path(directory: str) -> str:
def get_sqlite_db_path() -> str:
"""Get the path to the sqlite db file."""
return os.path.join(directory, "salesforce_db.sqlite")
return os.path.join(BASE_DATA_PATH, "salesforce_db.sqlite")
def get_object_type_path(object_type: str) -> str:

View File

@@ -255,9 +255,7 @@ _DISALLOWED_MSG_SUBTYPES = {
def default_msg_filter(message: MessageType) -> bool:
# Don't keep messages from bots
if message.get("bot_id") or message.get("app_id"):
bot_profile_name = message.get("bot_profile", {}).get("name")
print(f"bot_profile_name: {bot_profile_name}")
if bot_profile_name == "DanswerBot Testing":
if message.get("bot_profile", {}).get("name") == "OnyxConnector":
return False
return True

View File

@@ -5,13 +5,11 @@ from typing import cast
from sqlalchemy.orm import Session
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.prune_and_merge import _merge_sections
from onyx.chat.prune_and_merge import ChunkRange
from onyx.chat.prune_and_merge import merge_chunk_intervals
from onyx.chat.prune_and_merge import prune_and_merge_sections
from onyx.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import QueryFlow
@@ -63,7 +61,6 @@ class SearchPipeline:
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None,
contextual_pruning_config: ContextualPruningConfig | None = None,
):
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
@@ -80,9 +77,6 @@ class SearchPipeline:
self.search_settings = get_current_search_settings(db_session)
self.document_index = get_default_document_index(self.search_settings, None)
self.prompt_config: PromptConfig | None = prompt_config
self.contextual_pruning_config: ContextualPruningConfig | None = (
contextual_pruning_config
)
# Preprocessing steps generate this
self._search_query: SearchQuery | None = None
@@ -227,7 +221,7 @@ class SearchPipeline:
# If ee is enabled, censor the chunk sections based on user access
# Otherwise, return the retrieved chunks
censored_chunks: list[InferenceChunk] = fetch_ee_implementation_or_noop(
censored_chunks = fetch_ee_implementation_or_noop(
"onyx.external_permissions.post_query_censoring",
"_post_query_chunk_censoring",
retrieved_chunks,
@@ -426,26 +420,7 @@ class SearchPipeline:
if self._final_context_sections is not None:
return self._final_context_sections
if (
self.contextual_pruning_config is not None
and self.prompt_config is not None
):
self._final_context_sections = prune_and_merge_sections(
sections=self.reranked_sections,
section_relevance_list=None,
prompt_config=self.prompt_config,
llm_config=self.llm.config,
question=self.search_query.query,
contextual_pruning_config=self.contextual_pruning_config,
)
else:
logger.error(
"Contextual pruning or prompt config not set, using default merge"
)
self._final_context_sections = _merge_sections(
sections=self.reranked_sections
)
self._final_context_sections = _merge_sections(sections=self.reranked_sections)
return self._final_context_sections
@property

View File

@@ -613,19 +613,8 @@ def fetch_connector_credential_pairs(
def resync_cc_pair(
cc_pair: ConnectorCredentialPair,
search_settings_id: int,
db_session: Session,
) -> None:
"""
Updates state stored in the connector_credential_pair table based on the
latest index attempt for the given search settings.
Args:
cc_pair: ConnectorCredentialPair to resync
search_settings_id: SearchSettings to use for resync
db_session: Database session
"""
def find_latest_index_attempt(
connector_id: int,
credential_id: int,
@@ -638,10 +627,11 @@ def resync_cc_pair(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
.filter(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
IndexAttempt.search_settings_id == search_settings_id,
SearchSettings.status == IndexModelStatus.PRESENT,
)
)

View File

@@ -43,8 +43,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
ONE_HOUR_IN_SECONDS = 60 * 60
def check_docs_exist(db_session: Session) -> bool:
stmt = select(exists(DbDocument))
@@ -609,46 +607,6 @@ def delete_documents_complete__no_commit(
delete_documents__no_commit(db_session, document_ids)
def delete_all_documents_for_connector_credential_pair(
db_session: Session,
connector_id: int,
credential_id: int,
timeout: int = ONE_HOUR_IN_SECONDS,
) -> None:
"""Delete all documents for a given connector credential pair.
This will delete all documents and their associated data (chunks, feedback, tags, etc.)
NOTE: a bit inefficient, but it's not a big deal since this is done rarely - only during
an index swap. If we wanted to make this more efficient, we could use a single delete
statement + cascade.
"""
batch_size = 1000
start_time = time.monotonic()
while True:
# Get document IDs in batches
stmt = (
select(DocumentByConnectorCredentialPair.id)
.where(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
.limit(batch_size)
)
document_ids = db_session.scalars(stmt).all()
if not document_ids:
break
delete_documents_complete__no_commit(
db_session=db_session, document_ids=list(document_ids)
)
db_session.commit()
if time.monotonic() - start_time > timeout:
raise RuntimeError("Timeout reached while deleting documents")
def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool:
"""Acquire locks for the specified documents. Ideally this shouldn't be
called with large list of document_ids (an exception could be made if the

View File

@@ -217,6 +217,7 @@ def mark_attempt_in_progress(
"index_attempt_id": index_attempt.id,
"status": IndexingStatus.IN_PROGRESS.value,
"cc_pair_id": index_attempt.connector_credential_pair_id,
"search_settings_id": index_attempt.search_settings_id,
},
)
except Exception:
@@ -245,6 +246,9 @@ def mark_attempt_succeeded(
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.SUCCESS.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
@@ -273,6 +277,9 @@ def mark_attempt_partially_succeeded(
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.COMPLETED_WITH_ERRORS.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
@@ -305,6 +312,10 @@ def mark_attempt_canceled(
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.CANCELED.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"reason": reason,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
@@ -339,6 +350,10 @@ def mark_attempt_failed(
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.FAILED.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"reason": failure_reason,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
@@ -710,25 +725,6 @@ def cancel_indexing_attempts_past_model(
)
def cancel_indexing_attempts_for_search_settings(
search_settings_id: int,
db_session: Session,
) -> None:
"""Stops all indexing attempts that are in progress or not started for
the specified search settings."""
db_session.execute(
update(IndexAttempt)
.where(
IndexAttempt.status.in_(
[IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED]
),
IndexAttempt.search_settings_id == search_settings_id,
)
.values(status=IndexingStatus.FAILED)
)
def count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id: int | None,
db_session: Session,

View File

@@ -703,11 +703,7 @@ class Connector(Base):
)
documents_by_connector: Mapped[
list["DocumentByConnectorCredentialPair"]
] = relationship(
"DocumentByConnectorCredentialPair",
back_populates="connector",
passive_deletes=True,
)
] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")
# synchronize this validation logic with RefreshFrequencySchema etc on front end
# until we have a centralized validation schema
@@ -761,11 +757,7 @@ class Credential(Base):
)
documents_by_credential: Mapped[
list["DocumentByConnectorCredentialPair"]
] = relationship(
"DocumentByConnectorCredentialPair",
back_populates="credential",
passive_deletes=True,
)
] = relationship("DocumentByConnectorCredentialPair", back_populates="credential")
user: Mapped[User | None] = relationship("User", back_populates="credentials")
@@ -1118,10 +1110,10 @@ class DocumentByConnectorCredentialPair(Base):
id: Mapped[str] = mapped_column(ForeignKey("document.id"), primary_key=True)
# TODO: transition this to use the ConnectorCredentialPair id directly
connector_id: Mapped[int] = mapped_column(
ForeignKey("connector.id", ondelete="CASCADE"), primary_key=True
ForeignKey("connector.id"), primary_key=True
)
credential_id: Mapped[int] = mapped_column(
ForeignKey("credential.id", ondelete="CASCADE"), primary_key=True
ForeignKey("credential.id"), primary_key=True
)
# used to better keep track of document counts at a connector level
@@ -1131,10 +1123,10 @@ class DocumentByConnectorCredentialPair(Base):
has_been_indexed: Mapped[bool] = mapped_column(Boolean)
connector: Mapped[Connector] = relationship(
"Connector", back_populates="documents_by_connector", passive_deletes=True
"Connector", back_populates="documents_by_connector"
)
credential: Mapped[Credential] = relationship(
"Credential", back_populates="documents_by_credential", passive_deletes=True
"Credential", back_populates="documents_by_credential"
)
__table_args__ = (
@@ -1658,8 +1650,8 @@ class Prompt(Base):
)
name: Mapped[str] = mapped_column(String)
description: Mapped[str] = mapped_column(String)
system_prompt: Mapped[str] = mapped_column(String(length=8000))
task_prompt: Mapped[str] = mapped_column(String(length=8000))
system_prompt: Mapped[str] = mapped_column(Text)
task_prompt: Mapped[str] = mapped_column(Text)
include_citations: Mapped[bool] = mapped_column(Boolean, default=True)
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
# Default prompts are configured via backend during deployment
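The `system_prompt`/`task_prompt` change from `String(length=8000)` to `Text` above would normally ship with an Alembic migration. Below is a hedged sketch of what such a migration could look like; the `prompt` table name and the revision identifiers are assumptions, not taken from this changeset.

```python
"""Illustrative sketch: widen prompt columns from VARCHAR(8000) to TEXT."""
import sqlalchemy as sa
from alembic import op

# Placeholder revision identifiers -- not real revisions from this repo.
revision = "000000000000"
down_revision = "111111111111"


def upgrade() -> None:
    # Assumed table name "prompt"; widen both prompt columns to TEXT.
    for column in ("system_prompt", "task_prompt"):
        op.alter_column(
            "prompt",
            column,
            existing_type=sa.String(length=8000),
            type_=sa.Text(),
        )


def downgrade() -> None:
    for column in ("system_prompt", "task_prompt"):
        op.alter_column(
            "prompt",
            column,
            existing_type=sa.Text(),
            type_=sa.String(length=8000),
        )
```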

View File

@@ -37,8 +37,8 @@ from onyx.db.models import UserFile
from onyx.db.models import UserFolder
from onyx.db.models import UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaSharedNotificationData
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -201,7 +201,7 @@ def create_update_persona(
create_persona_request: PersonaUpsertRequest,
user: User | None,
db_session: Session,
) -> FullPersonaSnapshot:
) -> PersonaSnapshot:
"""Higher level function than upsert_persona, although either is valid to use."""
# Permission to actually use these is checked later
@@ -271,7 +271,7 @@ def create_update_persona(
logger.exception("Failed to create persona")
raise HTTPException(status_code=400, detail=str(e))
return FullPersonaSnapshot.from_model(persona)
return PersonaSnapshot.from_model(persona)
def update_persona_shared_users(

View File

@@ -3,9 +3,8 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.connector_credential_pair import resync_cc_pair
from onyx.db.document import delete_all_documents_for_connector_credential_pair
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import cancel_indexing_attempts_for_search_settings
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
@@ -27,50 +26,32 @@ def _perform_index_swap(
current_search_settings: SearchSettings,
secondary_search_settings: SearchSettings,
all_cc_pairs: list[ConnectorCredentialPair],
cleanup_documents: bool = False,
) -> None:
"""Swap the indices and expire the old one."""
if len(all_cc_pairs) > 0:
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_for_search_settings(
search_settings_id=current_search_settings.id,
db_session=db_session,
)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(
cc_pair=cc_pair,
# sync based on the new search settings
search_settings_id=secondary_search_settings.id,
db_session=db_session,
)
if cleanup_documents:
# clean up all DocumentByConnectorCredentialPair / Document rows, since we're
# doing an instant swap and no documents will exist in the new index.
for cc_pair in all_cc_pairs:
delete_all_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# swap over search settings
current_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
if len(all_cc_pairs) > 0:
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
# remove the old index from the vector db
document_index = get_default_document_index(secondary_search_settings, None)
document_index.ensure_indices_exist(
@@ -107,9 +88,6 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
current_search_settings=current_search_settings,
secondary_search_settings=secondary_search_settings,
all_cc_pairs=all_cc_pairs,
# clean up all DocumentByConnectorCredentialPair / Document rows, since we're
# doing an instant swap.
cleanup_documents=True,
)
return current_search_settings

View File

@@ -5,7 +5,6 @@ from datetime import timezone
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
@@ -75,10 +74,8 @@ def build_vespa_filters(
filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and '
# ACL filters
if filters.access_control_list is not None:
filter_str += _build_or_filters(
ACCESS_CONTROL_LIST, filters.access_control_list
)
# if filters.access_control_list is not None:
# filter_str += _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list)
# Source type filters
source_strs = (

View File

@@ -2,7 +2,6 @@ import io
import json
import os
import re
import uuid
import zipfile
from collections.abc import Callable
from collections.abc import Iterator
@@ -15,7 +14,6 @@ from pathlib import Path
from typing import Any
from typing import IO
from typing import NamedTuple
from typing import Optional
import chardet
import docx # type: ignore
@@ -570,8 +568,8 @@ def extract_text_and_images(
def convert_docx_to_txt(
file: UploadFile, file_store: FileStore, file_path: Optional[str] = None
) -> str:
file: UploadFile, file_store: FileStore, file_path: str
) -> None:
"""
Helper to convert docx to a .txt file in the same filestore.
"""
@@ -583,41 +581,15 @@ def convert_docx_to_txt(
all_paras = [p.text for p in doc.paragraphs]
text_content = "\n".join(all_paras)
file_name = file.filename or f"docx_{uuid.uuid4()}"
text_file_name = docx_to_txt_filename(file_path if file_path else file_name)
txt_file_path = docx_to_txt_filename(file_path)
file_store.save_file(
file_name=text_file_name,
file_name=txt_file_path,
content=BytesIO(text_content.encode("utf-8")),
display_name=file.filename,
file_origin=FileOrigin.CONNECTOR,
file_type="text/plain",
)
return text_file_name
def docx_to_txt_filename(file_path: str) -> str:
return file_path.rsplit(".", 1)[0] + ".txt"
def convert_pdf_to_txt(file: UploadFile, file_store: FileStore, file_path: str) -> str:
"""
Helper to convert PDF to a .txt file in the same filestore.
"""
file.file.seek(0)
# Extract text from the PDF
text_content, _, _ = read_pdf_file(file.file)
text_file_name = pdf_to_txt_filename(file_path)
file_store.save_file(
file_name=text_file_name,
content=BytesIO(text_content.encode("utf-8")),
display_name=file.filename,
file_origin=FileOrigin.CONNECTOR,
file_type="text/plain",
)
return text_file_name
def pdf_to_txt_filename(file_path: str) -> str:
return file_path.rsplit(".", 1)[0] + ".txt"

View File

@@ -459,6 +459,10 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
llm = get_default_llm_with_vision()
if not llm:
logger.warning(
"No vision-capable LLM available. Image sections will not be processed."
)
# Even without LLM, we still convert to IndexingDocument with base Sections
return [
IndexingDocument(
@@ -925,12 +929,10 @@ def index_doc_batch(
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
short_descriptor_list = [
chunk.to_short_descriptor() for chunk in access_aware_chunks
]
short_descriptor_log = str(short_descriptor_list)[:1024]
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
logger.debug(
"Indexing the following chunks: "
f"{[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
)
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
# in this set

View File

@@ -602,7 +602,7 @@ def get_max_input_tokens(
)
if input_toks <= 0:
return GEN_AI_MODEL_FALLBACK_MAX_TOKENS
raise RuntimeError("No tokens for input for the LLM given settings")
return input_toks

View File

@@ -1,4 +1,3 @@
import logging
import sys
import traceback
from collections.abc import AsyncGenerator
@@ -17,7 +16,6 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from httpx_oauth.clients.google import GoogleOAuth2
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from sqlalchemy.orm import Session
@@ -104,8 +102,6 @@ from onyx.server.utils import BasicAuthenticationError
from onyx.setup import setup_multitenant_onyx
from onyx.setup import setup_onyx
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_onyx_request_id_middleware
from onyx.utils.telemetry import get_or_generate_uuid
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
@@ -120,12 +116,6 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
file_handlers = [
h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
]
setup_uvicorn_logger(shared_file_handlers=file_handlers)
def validation_exception_handler(request: Request, exc: Exception) -> JSONResponse:
if not isinstance(exc, RequestValidationError):
@@ -431,14 +421,9 @@ def get_application() -> FastAPI:
if LOG_ENDPOINT_LATENCY:
add_latency_logging_middleware(application, logger)
add_onyx_request_id_middleware(application, "API", logger)
# Ensure all routes have auth enabled or are explicitly marked as public
check_router_auth(application)
# Initialize and instrument the app
Instrumentator().instrument(application).expose(application)
return application
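For readers unfamiliar with `prometheus_fastapi_instrumentator`, here is a minimal standalone sketch of the one-liner touched above; `expose()` is what mounts the GET `/metrics` endpoint that also appears in the PUBLIC_ENDPOINT_SPECS hunk further down. The toy app is illustrative, not Onyx's application factory.

```python
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator

app = FastAPI()


@app.get("/ping")
def ping() -> dict[str, str]:
    return {"status": "ok"}


# instrument() attaches middleware that records per-request metrics;
# expose() mounts GET /metrics, serving them in Prometheus text format.
Instrumentator().instrument(app).expose(app)
```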

View File

@@ -175,7 +175,7 @@ class EmbeddingModel:
embeddings: list[Embedding] = []
def process_batch(
batch_idx: int, batch_len: int, text_batch: list[str]
batch_idx: int, text_batch: list[str]
) -> tuple[int, list[Embedding]]:
if self.callback:
if self.callback.should_stop():
@@ -202,8 +202,8 @@ class EmbeddingModel:
end_time = time.time()
processing_time = end_time - start_time
logger.debug(
f"EmbeddingModel.process_batch: Batch {batch_idx}/{batch_len} processing time: {processing_time:.2f} seconds"
logger.info(
f"Batch {batch_idx} processing time: {processing_time:.2f} seconds"
)
return batch_idx, response.embeddings
@@ -215,7 +215,7 @@ class EmbeddingModel:
if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
with ThreadPoolExecutor(max_workers=num_threads) as executor:
future_to_batch = {
executor.submit(process_batch, idx, len(text_batches), batch): idx
executor.submit(process_batch, idx, batch): idx
for idx, batch in enumerate(text_batches, start=1)
}
@@ -238,7 +238,7 @@ class EmbeddingModel:
else:
# Original sequential processing
for idx, text_batch in enumerate(text_batches, start=1):
_, batch_embeddings = process_batch(idx, len(text_batches), text_batch)
_, batch_embeddings = process_batch(idx, text_batch)
embeddings.extend(batch_embeddings)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
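A self-contained sketch (not from the diff) of the submit-and-reorder pattern used by `_batch_encode_texts` above: futures are keyed by batch index so results can be reassembled in input order even though batches finish out of order. The worker function is a stand-in, not the real `process_batch`.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def fake_process_batch(batch_idx: int, batch: list[str]) -> tuple[int, list[int]]:
    # Stand-in worker: return the batch index plus fake "embeddings".
    return batch_idx, [len(text) for text in batch]


text_batches = [["a", "bb"], ["ccc"], ["dddd", "e"]]
results: dict[int, list[int]] = {}

with ThreadPoolExecutor(max_workers=4) as executor:
    future_to_batch = {
        executor.submit(fake_process_batch, idx, batch): idx
        for idx, batch in enumerate(text_batches, start=1)
    }
    for future in as_completed(future_to_batch):
        batch_idx, batch_embeddings = future.result()
        results[batch_idx] = batch_embeddings

# Reassemble in the original batch order (enumerate above is 1-based).
embeddings = [emb for idx in sorted(results) for emb in results[idx]]
print(embeddings)  # [1, 2, 3, 4, 1]
```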

View File

@@ -1,147 +0,0 @@
# Standards
SEPARATOR_LINE = "-------"
SEPARATOR_LINE_LONG = "---------------"
NO_EXTRACTION = "No extraction of knowledge graph objects was feasable."
YES = "yes"
NO = "no"
DC_OBJECT_SEPARATOR = ";"
DC_OBJECT_NO_BASE_DATA_EXTRACTION_PROMPT = f"""
You are an expert in finding relevant objects/object specifications of the same type in a list of documents. \
In this case you are interested \
in generating: {{objects_of_interest}}.
You should look at the documents - in no particular order! - and extract each object you find in the documents.
{SEPARATOR_LINE}
Here are the documents you are supposed to search through:
--
{{document_text}}
{SEPARATOR_LINE}
Here are the task instructions you should use to help you find the desired objects:
{SEPARATOR_LINE}
{{task}}
{SEPARATOR_LINE}
Here is the question that may provide critical additional context for the task:
{SEPARATOR_LINE}
{{question}}
{SEPARATOR_LINE}
Please answer the question in the following format:
REASONING: <your reasoning for the classification> - OBJECTS: <the objects - just their names - that you found, \
separated by ';'>
""".strip()
DC_OBJECT_WITH_BASE_DATA_EXTRACTION_PROMPT = f"""
You are an expert in finding relevant objects/object specifications of the same type in a list of documents. \
In this case you are interested \
in generating: {{objects_of_interest}}.
You should look at the provided data - in no particular order! - and extract each object you find in the documents.
{SEPARATOR_LINE}
Here are the data provided by the user:
--
{{base_data}}
{SEPARATOR_LINE}
Here are the task instructions you should use to help you find the desired objects:
{SEPARATOR_LINE}
{{task}}
{SEPARATOR_LINE}
Here is the request that may provide critical additional context for the task:
{SEPARATOR_LINE}
{{question}}
{SEPARATOR_LINE}
Please address the request in the following format:
REASONING: <your reasoning for the classification> - OBJECTS: <the objects - just their names - that you found, \
separated by ';'>
""".strip()
DC_OBJECT_SOURCE_RESEARCH_PROMPT = f"""
Today is {{today}}. You are an expert in extracting relevant structured information from a list of documents that \
should relate to one object. (Try to make sure that you know it relates to that one object!).
You should look at the documents - in no particular order! - and extract the information asked for this task:
{SEPARATOR_LINE}
{{task}}
{SEPARATOR_LINE}
Here is the user question that may provide critical additional context for the task:
{SEPARATOR_LINE}
{{question}}
{SEPARATOR_LINE}
Here are the documents you are supposed to search through:
--
{{document_text}}
{SEPARATOR_LINE}
Note: please cite your sources inline as you generate the results! Use the format [1], etc. Infer the \
number from the provided context documents. This is very important!
Please address the task in the following format:
REASONING:
-- <your reasoning for the classification>
RESEARCH RESULTS:
{{format}}
""".strip()
DC_OBJECT_CONSOLIDATION_PROMPT = f"""
You are a helpful assistant that consolidates information about a specific object \
from multiple sources.
The object is:
{SEPARATOR_LINE}
{{object}}
{SEPARATOR_LINE}
and the information is
{SEPARATOR_LINE}
{{information}}
{SEPARATOR_LINE}
Here is the user question that may provide critical additional context for the task:
{SEPARATOR_LINE}
{{question}}
{SEPARATOR_LINE}
Please consolidate the information into a single, concise answer. The consolidated information \
for the object should be in the following format:
{SEPARATOR_LINE}
{{format}}
{SEPARATOR_LINE}
Overall, please use this structure to communicate the consolidated information:
{SEPARATOR_LINE}
REASONING: <your reasoning for consolidating the information>
INFORMATION:
<consolidated information in the proper format that you have created>
"""
DC_FORMATTING_NO_BASE_DATA_PROMPT = f"""
You are an expert in text formatting. Your task is to take a given text and convert it 100 percent accurately \
into a new format.
Here is the text you are supposed to format:
{SEPARATOR_LINE}
{{text}}
{SEPARATOR_LINE}
Here is the format you are supposed to use:
{SEPARATOR_LINE}
{{format}}
{SEPARATOR_LINE}
Please start the generation directly with the formatted text. (Note that the output should not be code, but text.)
"""
DC_FORMATTING_WITH_BASE_DATA_PROMPT = f"""
You are an expert in text formatting. Your task is to take a given text and the initial \
base data provided by the user, and convert it 100 percent accurately \
into a new format. The base data may also contain important relationships that are critical \
for the formatting.
Here is the initial data provided by the user:
{SEPARATOR_LINE}
{{base_data}}
{SEPARATOR_LINE}
Here is the text you are supposed to combine (and format) with the initial data, adhering to the \
format instructions provided later in the prompt:
{SEPARATOR_LINE}
{{text}}
{SEPARATOR_LINE}
And here are the format instructions you are supposed to use:
{SEPARATOR_LINE}
{{format}}
{SEPARATOR_LINE}
Please start the generation directly with the formatted text. (Note that the output should not be code, but text.)
"""

View File

@@ -49,7 +49,6 @@ PUBLIC_ENDPOINT_SPECS = [
("/auth/oauth/callback", {"GET"}),
# anonymous user on cloud
("/tenants/anonymous-user", {"POST"}),
("/metrics", {"GET"}), # added by prometheus_fastapi_instrumentator
]

View File

@@ -21,7 +21,7 @@ from onyx.background.celery.tasks.external_group_syncing.tasks import (
from onyx.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
@@ -219,7 +219,7 @@ def update_cc_pair_status(
continue
# Revoke the task to prevent it from running
client_app.control.revoke(index_payload.celery_task_id)
primary_app.control.revoke(index_payload.celery_task_id)
# If it is running, then signaling for termination will get the
# watchdog thread to kill the spawned task
@@ -238,7 +238,7 @@ def update_cc_pair_status(
db_session.commit()
# this speeds up the start of indexing by firing the check immediately
client_app.send_task(
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
kwargs=dict(tenant_id=tenant_id),
priority=OnyxCeleryPriority.HIGH,
@@ -376,7 +376,7 @@ def prune_cc_pair(
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_prune_generator_task(
client_app, cc_pair, db_session, r, tenant_id
primary_app, cc_pair, db_session, r, tenant_id
)
if not payload_id:
raise HTTPException(
@@ -450,7 +450,7 @@ def sync_cc_pair(
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_permissions_sync_task(
client_app, cc_pair_id, r, tenant_id
primary_app, cc_pair_id, r, tenant_id
)
if not payload_id:
raise HTTPException(
@@ -524,7 +524,7 @@ def sync_cc_pair_groups(
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_external_group_sync_task(
client_app, cc_pair_id, r, tenant_id
primary_app, cc_pair_id, r, tenant_id
)
if not payload_id:
raise HTTPException(
@@ -634,7 +634,7 @@ def associate_credential_to_connector(
)
# trigger indexing immediately
client_app.send_task(
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},

View File

@@ -20,7 +20,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
from onyx.configs.app_configs import MOCK_CONNECTOR_FILE_PATH
from onyx.configs.constants import DocumentSource
@@ -100,7 +100,6 @@ from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.file_processing.extract_file_text import convert_docx_to_txt
from onyx.file_processing.extract_file_text import convert_pdf_to_txt
from onyx.file_store.file_store import get_default_file_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.redis.redis_connector import RedisConnector
@@ -129,7 +128,6 @@ from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
_GMAIL_CREDENTIAL_ID_COOKIE_NAME = "gmail_credential_id"
@@ -432,23 +430,6 @@ def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResp
)
continue
# Special handling for docx files - only store the plaintext version
if file.content_type and file.content_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
file_path = os.path.join(str(uuid.uuid4()), cast(str, file.filename))
text_file_path = convert_docx_to_txt(file, file_store)
deduped_file_paths.append(text_file_path)
continue
# Special handling for PDF files - only store the plaintext version
if file.content_type and file.content_type.startswith("application/pdf"):
file_path = os.path.join(str(uuid.uuid4()), cast(str, file.filename))
text_file_path = convert_pdf_to_txt(file, file_store, file_path)
deduped_file_paths.append(text_file_path)
continue
# Default handling for all other file types
file_path = os.path.join(str(uuid.uuid4()), cast(str, file.filename))
deduped_file_paths.append(file_path)
file_store.save_file(
@@ -459,6 +440,11 @@ def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResp
file_type=file.content_type or "text/plain",
)
if file.content_type and file.content_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
convert_docx_to_txt(file, file_store, file_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return FileUploadResponse(file_paths=deduped_file_paths)
@@ -942,7 +928,7 @@ def create_connector_with_mock_credential(
)
# trigger indexing immediately
client_app.send_task(
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
@@ -1328,7 +1314,7 @@ def trigger_indexing_for_cc_pair(
# run the beat task to pick up the triggers immediately
priority = OnyxCeleryPriority.HIGHEST if is_user_file else OnyxCeleryPriority.HIGH
logger.info(f"Sending indexing check task with priority {priority}")
client_app.send_task(
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=priority,
kwargs={"tenant_id": tenant_id},

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
@@ -52,7 +52,7 @@ def create_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
client_app.send_task(
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
@@ -85,7 +85,7 @@ def patch_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
client_app.send_task(
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
@@ -108,7 +108,7 @@ def delete_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
client_app.send_task(
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,

View File

@@ -43,7 +43,6 @@ from onyx.file_store.models import ChatFileType
from onyx.secondary_llm_flows.starter_message_creation import (
generate_starter_messages,
)
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import GenerateStarterMessageRequest
from onyx.server.features.persona.models import ImageGenerationToolStatus
from onyx.server.features.persona.models import PersonaLabelCreate
@@ -425,8 +424,8 @@ def get_persona(
persona_id: int,
user: User | None = Depends(current_limited_user),
db_session: Session = Depends(get_session),
) -> FullPersonaSnapshot:
return FullPersonaSnapshot.from_model(
) -> PersonaSnapshot:
return PersonaSnapshot.from_model(
get_persona_by_id(
persona_id=persona_id,
user=user,

View File

@@ -91,80 +91,37 @@ class PersonaUpsertRequest(BaseModel):
class PersonaSnapshot(BaseModel):
id: int
owner: MinimalUserSnapshot | None
name: str
description: str
is_public: bool
is_visible: bool
icon_shape: int | None = None
icon_color: str | None = None
is_public: bool
display_priority: int | None
description: str
num_chunks: float | None
llm_relevance_filter: bool
llm_filter_extraction: bool
llm_model_provider_override: str | None
llm_model_version_override: str | None
starter_messages: list[StarterMessage] | None
builtin_persona: bool
prompts: list[PromptSnapshot]
tools: list[ToolSnapshot]
document_sets: list[DocumentSet]
users: list[MinimalUserSnapshot]
groups: list[int]
icon_color: str | None
icon_shape: int | None
uploaded_image_id: str | None = None
user_file_ids: list[int] = Field(default_factory=list)
user_folder_ids: list[int] = Field(default_factory=list)
display_priority: int | None = None
is_default_persona: bool = False
builtin_persona: bool = False
starter_messages: list[StarterMessage] | None = None
tools: list[ToolSnapshot] = Field(default_factory=list)
labels: list["PersonaLabelSnapshot"] = Field(default_factory=list)
owner: MinimalUserSnapshot | None = None
users: list[MinimalUserSnapshot] = Field(default_factory=list)
groups: list[int] = Field(default_factory=list)
document_sets: list[DocumentSet] = Field(default_factory=list)
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
num_chunks: float | None = None
@classmethod
def from_model(cls, persona: Persona) -> "PersonaSnapshot":
return PersonaSnapshot(
id=persona.id,
name=persona.name,
description=persona.description,
is_public=persona.is_public,
is_visible=persona.is_visible,
icon_shape=persona.icon_shape,
icon_color=persona.icon_color,
uploaded_image_id=persona.uploaded_image_id,
user_file_ids=[file.id for file in persona.user_files],
user_folder_ids=[folder.id for folder in persona.user_folders],
display_priority=persona.display_priority,
is_default_persona=persona.is_default_persona,
builtin_persona=persona.builtin_persona,
starter_messages=persona.starter_messages,
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
owner=(
MinimalUserSnapshot(id=persona.user.id, email=persona.user.email)
if persona.user
else None
),
users=[
MinimalUserSnapshot(id=user.id, email=user.email)
for user in persona.users
],
groups=[user_group.id for user_group in persona.groups],
document_sets=[
DocumentSet.from_model(document_set_model)
for document_set_model in persona.document_sets
],
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
num_chunks=persona.num_chunks,
)
# Model with full context on persona's internal settings
# This is used for flows which need to know all settings
class FullPersonaSnapshot(PersonaSnapshot):
is_default_persona: bool
search_start_date: datetime | None = None
prompts: list[PromptSnapshot] = Field(default_factory=list)
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
labels: list["PersonaLabelSnapshot"] = []
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
@classmethod
def from_model(
cls, persona: Persona, allow_deleted: bool = False
) -> "FullPersonaSnapshot":
) -> "PersonaSnapshot":
if persona.deleted:
error_msg = f"Persona with ID {persona.id} has been deleted"
if not allow_deleted:
@@ -172,32 +129,44 @@ class FullPersonaSnapshot(PersonaSnapshot):
else:
logger.warning(error_msg)
return FullPersonaSnapshot(
return PersonaSnapshot(
id=persona.id,
name=persona.name,
description=persona.description,
is_public=persona.is_public,
is_visible=persona.is_visible,
icon_shape=persona.icon_shape,
icon_color=persona.icon_color,
uploaded_image_id=persona.uploaded_image_id,
user_file_ids=[file.id for file in persona.user_files],
user_folder_ids=[folder.id for folder in persona.user_folders],
display_priority=persona.display_priority,
is_default_persona=persona.is_default_persona,
builtin_persona=persona.builtin_persona,
starter_messages=persona.starter_messages,
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
owner=(
MinimalUserSnapshot(id=persona.user.id, email=persona.user.email)
if persona.user
else None
),
search_start_date=persona.search_start_date,
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
is_visible=persona.is_visible,
is_public=persona.is_public,
display_priority=persona.display_priority,
description=persona.description,
num_chunks=persona.num_chunks,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
builtin_persona=persona.builtin_persona,
is_default_persona=persona.is_default_persona,
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
document_sets=[
DocumentSet.from_model(document_set_model)
for document_set_model in persona.document_sets
],
users=[
MinimalUserSnapshot(id=user.id, email=user.email)
for user in persona.users
],
groups=[user_group.id for user_group in persona.groups],
icon_color=persona.icon_color,
icon_shape=persona.icon_shape,
uploaded_image_id=persona.uploaded_image_id,
search_start_date=persona.search_start_date,
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
user_file_ids=[file.id for file in persona.user_files],
user_folder_ids=[folder.id for folder in persona.user_folders],
)

View File

@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
@@ -192,7 +192,7 @@ def create_deletion_attempt_for_connector_id(
db_session.commit()
# run the beat task to pick up this deletion from the db immediately
client_app.send_task(
primary_app.send_task(
OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},

View File

@@ -19,7 +19,6 @@ from onyx.db.models import SlackBot as SlackAppModel
from onyx.db.models import SlackChannelConfig as SlackChannelConfigModel
from onyx.db.models import User
from onyx.onyxbot.slack.config import VALID_SLACK_FILTERS
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
@@ -246,7 +245,7 @@ class SlackChannelConfig(BaseModel):
id=slack_channel_config_model.id,
slack_bot_id=slack_channel_config_model.slack_bot_id,
persona=(
FullPersonaSnapshot.from_model(
PersonaSnapshot.from_model(
slack_channel_config_model.persona, allow_deleted=True
)
if slack_channel_config_model.persona

View File

@@ -117,11 +117,7 @@ def set_new_search_settings(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(
cc_pair=cc_pair,
search_settings_id=new_search_settings.id,
db_session=db_session,
)
resync_cc_pair(cc_pair, db_session=db_session)
db_session.commit()
return IdReturn(id=new_search_settings.id)

View File

@@ -96,11 +96,7 @@ def setup_onyx(
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(
cc_pair=cc_pair,
search_settings_id=search_settings.id,
db_session=db_session,
)
resync_cc_pair(cc_pair, db_session=db_session)
# Expire all old embedding models indexing attempts, technically redundant
cancel_indexing_attempts_past_model(db_session)

View File

@@ -1,5 +1,4 @@
from collections.abc import Callable
from datetime import datetime
from typing import Any
from uuid import UUID
@@ -7,7 +6,6 @@ from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection
@@ -77,8 +75,6 @@ class SearchToolOverrideKwargs(BaseModel):
ordering_only: bool | None = (
None # Flag for fast path when search is only needed for ordering
)
document_sources: list[DocumentSource] | None = None
time_cutoff: datetime | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -292,8 +292,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
user_file_ids = None
user_folder_ids = None
ordering_only = False
document_sources = None
time_cutoff = None
if override_kwargs:
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
alternate_db_session = override_kwargs.alternate_db_session
@@ -304,8 +302,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
user_file_ids = override_kwargs.user_file_ids
user_folder_ids = override_kwargs.user_folder_ids
ordering_only = use_alt_not_None(override_kwargs.ordering_only, False)
document_sources = override_kwargs.document_sources
time_cutoff = override_kwargs.time_cutoff
# Fast path for ordering-only search
if ordering_only:
@@ -338,23 +334,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
retrieval_options = RetrievalDetails(filters=filters)
if document_sources or time_cutoff:
# Get retrieval_options and filters, or create if they don't exist
retrieval_options = retrieval_options or RetrievalDetails()
retrieval_options.filters = retrieval_options.filters or BaseFilters()
# Handle document sources
if document_sources:
source_types = retrieval_options.filters.source_type or []
retrieval_options.filters.source_type = list(
set(source_types + document_sources)
)
# Handle time cutoff
if time_cutoff:
# The override time-cutoff should supersede the existing time-cutoff, even if defined
retrieval_options.filters.time_cutoff = time_cutoff
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,
@@ -397,7 +376,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
db_session=alternate_db_session or self.db_session,
prompt_config=self.prompt_config,
retrieved_sections_callback=retrieved_sections_callback,
contextual_pruning_config=self.contextual_pruning_config,
)
search_query_info = SearchQueryInfo(
@@ -469,7 +447,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
db_session=self.db_session,
bypass_acl=self.bypass_acl,
prompt_config=self.prompt_config,
contextual_pruning_config=self.contextual_pruning_config,
)
# Log what we're doing

View File

@@ -13,7 +13,6 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
logging.addLevelName(logging.INFO + 5, "NOTICE")
@@ -72,14 +71,6 @@ def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
return log_level_dict.get(log_level_str.upper(), logging.getLevelName("NOTICE"))
class OnyxRequestIDFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
record.request_id = ONYX_REQUEST_ID_CONTEXTVAR.get() or "-"
return True
class OnyxLoggingAdapter(logging.LoggerAdapter):
def process(
self, msg: str, kwargs: MutableMapping[str, Any]
@@ -112,7 +103,6 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
msg = f"[CC Pair: {cc_pair_id}] {msg}"
break
# Add tenant information if it differs from default
# This will always be the case for authenticated API requests
if MULTI_TENANT:
@@ -125,11 +115,6 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
)
msg = f"[t:{short_tenant}] {msg}"
# request id within a fastapi route
fastapi_request_id = ONYX_REQUEST_ID_CONTEXTVAR.get()
if fastapi_request_id:
msg = f"[{fastapi_request_id}] {msg}"
# For Slack Bot, logs the channel relevant to the request
channel_id = self.extra.get(SLACK_CHANNEL_ID) if self.extra else None
if channel_id:
@@ -180,14 +165,6 @@ class ColoredFormatter(logging.Formatter):
return super().format(record)
def get_uvicorn_standard_formatter() -> ColoredFormatter:
"""Returns a standard colored logging formatter."""
return ColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: [%(request_id)s] %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
def get_standard_formatter() -> ColoredFormatter:
"""Returns a standard colored logging formatter."""
return ColoredFormatter(
@@ -224,6 +201,12 @@ def setup_logger(
logger.addHandler(handler)
uvicorn_logger = logging.getLogger("uvicorn.access")
if uvicorn_logger:
uvicorn_logger.handlers = []
uvicorn_logger.addHandler(handler)
uvicorn_logger.setLevel(log_level)
is_containerized = is_running_in_container()
if LOG_FILE_NAME and (is_containerized or DEV_LOGGING_ENABLED):
log_levels = ["debug", "info", "notice"]
@@ -242,37 +225,14 @@ def setup_logger(
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if uvicorn_logger:
uvicorn_logger.addHandler(file_handler)
logger.notice = lambda msg, *args, **kwargs: logger.log(logging.getLevelName("NOTICE"), msg, *args, **kwargs) # type: ignore
return OnyxLoggingAdapter(logger, extra=extra)
def setup_uvicorn_logger(
log_level: int = get_log_level_from_str(),
shared_file_handlers: list[logging.FileHandler] | None = None,
) -> None:
uvicorn_logger = logging.getLogger("uvicorn.access")
if not uvicorn_logger:
return
formatter = get_uvicorn_standard_formatter()
handler = logging.StreamHandler()
handler.setLevel(log_level)
handler.setFormatter(formatter)
uvicorn_logger.handlers = []
uvicorn_logger.addHandler(handler)
uvicorn_logger.setLevel(log_level)
uvicorn_logger.addFilter(OnyxRequestIDFilter())
if shared_file_handlers:
for fh in shared_file_handlers:
uvicorn_logger.addHandler(fh)
return
def print_loggers() -> None:
"""Print information about all loggers. Use to debug logging issues."""
root_logger = logging.getLogger()

View File

@@ -1,62 +0,0 @@
import base64
import hashlib
import logging
import uuid
from collections.abc import Awaitable
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from fastapi import FastAPI
from fastapi import Request
from fastapi import Response
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
def add_onyx_request_id_middleware(
app: FastAPI, prefix: str, logger: logging.LoggerAdapter
) -> None:
@app.middleware("http")
async def set_request_id(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Generate a request hash that can be used to track the lifecycle
of a request. The hash is prefixed to help indicate where the request id
originated.
Format is f"{PREFIX}:{ID}" where PREFIX is 3 chars and ID is 8 chars.
Total length is 12 chars.
"""
onyx_request_id = request.headers.get("X-Onyx-Request-ID")
if not onyx_request_id:
onyx_request_id = make_randomized_onyx_request_id(prefix)
ONYX_REQUEST_ID_CONTEXTVAR.set(onyx_request_id)
return await call_next(request)
def make_randomized_onyx_request_id(prefix: str) -> str:
"""generates a randomized request id"""
hash_input = str(uuid.uuid4())
return _make_onyx_request_id(prefix, hash_input)
def make_structured_onyx_request_id(prefix: str, request_url: str) -> str:
"""Not used yet, but could be in the future!"""
hash_input = f"{request_url}:{datetime.now(timezone.utc)}"
return _make_onyx_request_id(prefix, hash_input)
def _make_onyx_request_id(prefix: str, hash_input: str) -> str:
"""helper function to return an id given a string input"""
hash_obj = hashlib.md5(hash_input.encode("utf-8"))
hash_bytes = hash_obj.digest()[:6] # Truncate to 6 bytes
# 6 bytes becomes 8 bytes. we shouldn't need to strip but just in case
# NOTE: possible we'll want more input bytes if id's aren't unique enough
hash_str = base64.urlsafe_b64encode(hash_bytes).decode("utf-8").rstrip("=")
onyx_request_id = f"{prefix}:{hash_str}"
return onyx_request_id

View File

@@ -39,7 +39,6 @@ class RecordType(str, Enum):
INDEXING_PROGRESS = "indexing_progress"
INDEXING_COMPLETE = "indexing_complete"
PERMISSION_SYNC_PROGRESS = "permission_sync_progress"
PERMISSION_SYNC_COMPLETE = "permission_sync_complete"
INDEX_ATTEMPT_STATUS = "index_attempt_status"

View File

@@ -332,15 +332,14 @@ def wait_on_background(task: TimeoutThread[R]) -> R:
return task.result
def _next_or_none(ind: int, gen: Iterator[R]) -> tuple[int, R | None]:
return ind, next(gen, None)
def _next_or_none(ind: int, g: Iterator[R]) -> tuple[int, R | None]:
return ind, next(g, None)
def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_index: dict[Future[tuple[int, R | None]], int] = {
executor.submit(_next_or_none, ind, gen): ind
for ind, gen in enumerate(gens)
executor.submit(_next_or_none, i, g): i for i, g in enumerate(gens)
}
next_ind = len(gens)

View File

@@ -95,5 +95,4 @@ urllib3==2.2.3
mistune==0.8.4
sentry-sdk==2.14.0
prometheus_client==0.21.0
fastapi-limiter==0.1.6
prometheus_fastapi_instrumentator==7.1.0
fastapi-limiter==0.1.6

View File

@@ -15,5 +15,4 @@ uvicorn==0.21.1
voyageai==0.2.3
litellm==1.61.16
sentry-sdk[fastapi,celery,starlette]==2.14.0
aioboto3==13.4.0
prometheus_fastapi_instrumentator==7.1.0
aioboto3==13.4.0

View File

@@ -887,7 +887,6 @@ def main() -> None:
type=int,
help="Maximum number of documents to delete (for delete-all-documents)",
)
parser.add_argument("--link", help="Document link (for get_acls filter)")
args = parser.parse_args()
vespa_debug = VespaDebugging(args.tenant_id)
@@ -925,11 +924,7 @@ def main() -> None:
elif args.action == "get_acls":
if args.cc_pair_id is None:
parser.error("--cc-pair-id is required for get_acls action")
if args.link is None:
vespa_debug.acls(args.cc_pair_id, args.n)
else:
vespa_debug.acls_by_link(args.cc_pair_id, args.link)
vespa_debug.acls(args.cc_pair_id, args.n)
if __name__ == "__main__":

View File

@@ -58,7 +58,6 @@ INDEXING_ONLY = os.environ.get("INDEXING_ONLY", "").lower() == "true"
# The process needs this set so there is a log file to write to;
# otherwise, it will not create additional log files
# This should just be the filename base without extension or path.
LOG_FILE_NAME = os.environ.get("LOG_FILE_NAME") or "onyx"
# Enable generating persistent log files for local dev environments

View File

@@ -11,15 +11,6 @@ CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
"current_tenant_id", default=None if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
)
# set by every route in the API server
INDEXING_REQUEST_ID_CONTEXTVAR: contextvars.ContextVar[
str | None
] = contextvars.ContextVar("indexing_request_id", default=None)
# set by every route in the API server
ONYX_REQUEST_ID_CONTEXTVAR: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"onyx_request_id", default=None
)
"""Utils related to contextvars"""

View File

@@ -34,7 +34,7 @@ def confluence_connector(space: str) -> ConfluenceConnector:
return connector
@pytest.mark.parametrize("space", [os.getenv("CONFLUENCE_TEST_SPACE") or "DailyConne"])
@pytest.mark.parametrize("space", [os.environ["CONFLUENCE_TEST_SPACE"]])
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,

View File

@@ -1,44 +0,0 @@
import os
import time
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.connectors.gong.connector import GongConnector
from onyx.connectors.models import Document
@pytest.fixture
def gong_connector() -> GongConnector:
connector = GongConnector()
connector.load_credentials(
{
"gong_access_key": os.environ["GONG_ACCESS_KEY"],
"gong_access_key_secret": os.environ["GONG_ACCESS_KEY_SECRET"],
}
)
return connector
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_gong_basic(mock_get_api_key: MagicMock, gong_connector: GongConnector) -> None:
doc_batch_generator = gong_connector.poll_source(0, time.time())
doc_batch = next(doc_batch_generator)
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 2
docs: list[Document] = []
for doc in doc_batch:
docs.append(doc)
assert docs[0].semantic_identifier == "test with chris"
assert docs[1].semantic_identifier == "Testing Gong"

View File

@@ -1,7 +1,6 @@
import json
import os
import time
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import patch
@@ -106,54 +105,6 @@ def test_highspot_connector_slim(
assert len(all_slim_doc_ids) > 0
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_highspot_connector_poll_source(
mock_get_api_key: MagicMock, highspot_connector: HighspotConnector
) -> None:
"""Test poll_source functionality with date range filtering."""
# Define date range: April 3, 2025 to April 4, 2025
start_date = datetime(2025, 4, 3, 0, 0, 0)
end_date = datetime(2025, 4, 4, 23, 59, 59)
# Convert to seconds since Unix epoch
start_time = int(time.mktime(start_date.timetuple()))
end_time = int(time.mktime(end_date.timetuple()))
# Load test data for assertions
test_data = load_test_data()
poll_source_data = test_data.get("poll_source", {})
target_doc_id = poll_source_data.get("target_doc_id")
# Call poll_source with date range
all_docs: list[Document] = []
target_doc: Document | None = None
for doc_batch in highspot_connector.poll_source(start_time, end_time):
for doc in doc_batch:
all_docs.append(doc)
if doc.id == f"HIGHSPOT_{target_doc_id}":
target_doc = doc
# Verify documents were loaded
assert len(all_docs) > 0
# Verify the specific test document was found and has correct properties
assert target_doc is not None
assert target_doc.semantic_identifier == poll_source_data.get("semantic_identifier")
assert target_doc.source == DocumentSource.HIGHSPOT
assert target_doc.metadata is not None
# Verify sections
assert len(target_doc.sections) == 1
section = target_doc.sections[0]
assert section.link == poll_source_data.get("link")
assert section.text is not None
assert len(section.text) > 0
def test_highspot_connector_validate_credentials(
highspot_connector: HighspotConnector,
) -> None:

View File

@@ -1,10 +1,5 @@
{
"target_doc_id": "67cd8eb35d3ee0487de2e704",
"semantic_identifier": "Highspot in Action _ Salesforce Integration",
"link": "https://www.highspot.com/items/67cd8eb35d3ee0487de2e704",
"poll_source": {
"target_doc_id":"67ef9edcc3f40b2bf3d816a8",
"semantic_identifier":"A Brief Introduction To AI",
"link":"https://www.highspot.com/items/67ef9edcc3f40b2bf3d816a8"
}
"link": "https://www.highspot.com/items/67cd8eb35d3ee0487de2e704"
}

View File

@@ -35,22 +35,23 @@ def salesforce_connector() -> SalesforceConnector:
connector = SalesforceConnector(
requested_objects=["Account", "Contact", "Opportunity"],
)
username = os.environ["SF_USERNAME"]
password = os.environ["SF_PASSWORD"]
security_token = os.environ["SF_SECURITY_TOKEN"]
connector.load_credentials(
{
"sf_username": username,
"sf_password": password,
"sf_security_token": security_token,
"sf_username": os.environ["SF_USERNAME"],
"sf_password": os.environ["SF_PASSWORD"],
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
}
)
return connector
# TODO: make the credentials not expire
@pytest.mark.xfail(
reason=(
"Credentials change over time, so this test will fail if run when "
"the credentials expire."
)
)
def test_salesforce_connector_basic(salesforce_connector: SalesforceConnector) -> None:
test_data = load_test_data()
target_test_doc: Document | None = None
@@ -60,26 +61,21 @@ def test_salesforce_connector_basic(salesforce_connector: SalesforceConnector) -
all_docs.append(doc)
if doc.id == test_data["id"]:
target_test_doc = doc
break
# The number of docs here seems to change actively so do a very loose check
# as of 2025-03-28 it was around 32472
assert len(all_docs) > 32000
assert len(all_docs) < 40000
assert len(all_docs) == 6
assert target_test_doc is not None
# Set of received links
received_links: set[str] = set()
# List of received text fields, which contain key-value pairs seperated by newlines
received_text: list[str] = []
recieved_text: list[str] = []
# Iterate over the sections of the target test doc to extract the links and text
for section in target_test_doc.sections:
assert section.link
assert section.text
received_links.add(section.link)
received_text.append(section.text)
recieved_text.append(section.text)
# Check that the received links match the expected links from the test data json
expected_links = set(test_data["expected_links"])
@@ -89,9 +85,8 @@ def test_salesforce_connector_basic(salesforce_connector: SalesforceConnector) -
expected_text = test_data["expected_text"]
if not isinstance(expected_text, list):
raise ValueError("Expected text is not a list")
unparsed_expected_key_value_pairs: list[str] = expected_text
received_key_value_pairs = extract_key_value_pairs_to_set(received_text)
received_key_value_pairs = extract_key_value_pairs_to_set(recieved_text)
expected_key_value_pairs = extract_key_value_pairs_to_set(
unparsed_expected_key_value_pairs
)
@@ -101,21 +96,13 @@ def test_salesforce_connector_basic(salesforce_connector: SalesforceConnector) -
assert target_test_doc.source == DocumentSource.SALESFORCE
assert target_test_doc.semantic_identifier == test_data["semantic_identifier"]
assert target_test_doc.metadata == test_data["metadata"]
assert target_test_doc.primary_owners is not None
primary_owner = target_test_doc.primary_owners[0]
expected_primary_owner = test_data["primary_owners"]
assert isinstance(expected_primary_owner, dict)
assert primary_owner.email == expected_primary_owner["email"]
assert primary_owner.first_name == expected_primary_owner["first_name"]
assert primary_owner.last_name == expected_primary_owner["last_name"]
assert target_test_doc.primary_owners == test_data["primary_owners"]
assert target_test_doc.secondary_owners == test_data["secondary_owners"]
assert target_test_doc.title == test_data["title"]
# TODO: make the credentials not expire
@pytest.mark.skip(
@pytest.mark.xfail(
reason=(
"Credentials change over time, so this test will fail if run when "
"the credentials expire."

View File

@@ -1,162 +1,20 @@
{
"id": "SALESFORCE_001bm00000eu6n5AAA",
"id": "SALESFORCE_001fI000005drUcQAI",
"expected_links": [
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESpEeAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESqd3AAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESoKiAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESvDSAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESrmHAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESrl2AAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESvejAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000EStlvAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESpPfAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESrP9AAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESvlMAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESt3JAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESoBkAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000EStw2AAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESrkMAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESojKAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESuLEAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESoSIAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESu2YAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESvgSAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESurnAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESrnqAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESoB5AAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESuJuAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESrfyAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/001bm00000eu6n5AAA",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESpUHAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESsgGAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESr7UAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESu1BAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESpqzAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESplZAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESvJ3AAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESurKAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000EStSiAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESuJFAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESu8xAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESqfzAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESqsrAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000EStoZAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESsIUAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESsAGAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESv8GAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESrOKAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESoUmAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESudKAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESuJ8AAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESvf2AAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESw3qAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESugRAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESr18AAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESqV1AAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESuLVAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESpjoAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESqULAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESuCAAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESrfpAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESp5YAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESrMNAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000EStaUAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESt5LAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESrtcAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESomaAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESrtIAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESoToAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESuWLAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESrWvAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESsJEAA1",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESsxwAAD",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESvUgAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESvWjAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000EStBuAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESpZiAAL",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESuhYAAT",
"https://danswer-dev-ed.develop.my.salesforce.com/003bm00000ESuWAAA1"
"https://customization-ruby-2195.my.salesforce.com/001fI000005drUcQAI",
"https://customization-ruby-2195.my.salesforce.com/003fI000001jiCPQAY",
"https://customization-ruby-2195.my.salesforce.com/017fI00000T7hvsQAB",
"https://customization-ruby-2195.my.salesforce.com/006fI000000rDvBQAU"
],
"expected_text": [
"IsDeleted: false\nBillingCity: Shaykh al \u00e1\u00b8\u00a8ad\u00c4\u00abd\nName: Voonder\nCleanStatus: Pending\nBillingStreet: 12 Cambridge Parkway",
"Email: eslayqzs@icio.us\nIsDeleted: false\nLastName: Slay\nIsEmailBounced: false\nFirstName: Ebeneser\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: ptweedgdh@umich.edu\nIsDeleted: false\nLastName: Tweed\nIsEmailBounced: false\nFirstName: Paulita\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: ehurnellnlx@facebook.com\nIsDeleted: false\nLastName: Hurnell\nIsEmailBounced: false\nFirstName: Eliot\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: ccarik4q4@google.it\nIsDeleted: false\nLastName: Carik\nIsEmailBounced: false\nFirstName: Chadwick\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: cvannozziina6@moonfruit.com\nIsDeleted: false\nLastName: Vannozzii\nIsEmailBounced: false\nFirstName: Christophorus\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: mikringill2kz@hugedomains.com\nIsDeleted: false\nLastName: Ikringill\nIsEmailBounced: false\nFirstName: Meghann\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: bgrinvalray@fda.gov\nIsDeleted: false\nLastName: Grinval\nIsEmailBounced: false\nFirstName: Berti\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: aollanderhr7@cam.ac.uk\nIsDeleted: false\nLastName: Ollander\nIsEmailBounced: false\nFirstName: Annemarie\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: rwhitesideq38@gravatar.com\nIsDeleted: false\nLastName: Whiteside\nIsEmailBounced: false\nFirstName: Rolando\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: vkrafthmz@techcrunch.com\nIsDeleted: false\nLastName: Kraft\nIsEmailBounced: false\nFirstName: Vidovik\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: jhillaut@4shared.com\nIsDeleted: false\nLastName: Hill\nIsEmailBounced: false\nFirstName: Janel\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: lralstonycs@discovery.com\nIsDeleted: false\nLastName: Ralston\nIsEmailBounced: false\nFirstName: Lorrayne\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: blyttlewba@networkadvertising.org\nIsDeleted: false\nLastName: Lyttle\nIsEmailBounced: false\nFirstName: Ban\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: pplummernvf@technorati.com\nIsDeleted: false\nLastName: Plummer\nIsEmailBounced: false\nFirstName: Pete\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: babrahamoffxpb@theatlantic.com\nIsDeleted: false\nLastName: Abrahamoff\nIsEmailBounced: false\nFirstName: Brander\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: ahargieym0@homestead.com\nIsDeleted: false\nLastName: Hargie\nIsEmailBounced: false\nFirstName: Aili\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: hstotthp2@yelp.com\nIsDeleted: false\nLastName: Stott\nIsEmailBounced: false\nFirstName: Hartley\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: jganniclifftuvj@blinklist.com\nIsDeleted: false\nLastName: Ganniclifft\nIsEmailBounced: false\nFirstName: Jamima\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: ldodelly8q@ed.gov\nIsDeleted: false\nLastName: Dodell\nIsEmailBounced: false\nFirstName: Lynde\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: rmilner3cp@smh.com.au\nIsDeleted: false\nLastName: Milner\nIsEmailBounced: false\nFirstName: Ralph\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: gghiriardellic19@state.tx.us\nIsDeleted: false\nLastName: Ghiriardelli\nIsEmailBounced: false\nFirstName: Garv\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: rhubatschfpu@nature.com\nIsDeleted: false\nLastName: Hubatsch\nIsEmailBounced: false\nFirstName: Rose\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: mtrenholme1ws@quantcast.com\nIsDeleted: false\nLastName: Trenholme\nIsEmailBounced: false\nFirstName: Mariejeanne\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: jmussettpbd@over-blog.com\nIsDeleted: false\nLastName: Mussett\nIsEmailBounced: false\nFirstName: Juliann\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: bgoroni145@illinois.edu\nIsDeleted: false\nLastName: Goroni\nIsEmailBounced: false\nFirstName: Bernarr\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: afalls3ph@theguardian.com\nIsDeleted: false\nLastName: Falls\nIsEmailBounced: false\nFirstName: Angelia\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: lswettjoi@go.com\nIsDeleted: false\nLastName: Swett\nIsEmailBounced: false\nFirstName: Levon\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: emullinsz38@dailymotion.com\nIsDeleted: false\nLastName: Mullins\nIsEmailBounced: false\nFirstName: Elsa\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: ibernettehco@ebay.co.uk\nIsDeleted: false\nLastName: Bernette\nIsEmailBounced: false\nFirstName: Ingrid\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: trisleybtt@simplemachines.org\nIsDeleted: false\nLastName: Risley\nIsEmailBounced: false\nFirstName: Toma\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: rgypsonqx1@goodreads.com\nIsDeleted: false\nLastName: Gypson\nIsEmailBounced: false\nFirstName: Reed\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: cposvneri28@jiathis.com\nIsDeleted: false\nLastName: Posvner\nIsEmailBounced: false\nFirstName: Culley\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: awilmut2rz@geocities.jp\nIsDeleted: false\nLastName: Wilmut\nIsEmailBounced: false\nFirstName: Andy\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: aluckwellra5@exblog.jp\nIsDeleted: false\nLastName: Luckwell\nIsEmailBounced: false\nFirstName: Andreana\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: irollings26j@timesonline.co.uk\nIsDeleted: false\nLastName: Rollings\nIsEmailBounced: false\nFirstName: Ibrahim\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: gspireqpd@g.co\nIsDeleted: false\nLastName: Spire\nIsEmailBounced: false\nFirstName: Gaelan\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: sbezleyk2y@acquirethisname.com\nIsDeleted: false\nLastName: Bezley\nIsEmailBounced: false\nFirstName: Sindee\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: icollerrr@flickr.com\nIsDeleted: false\nLastName: Coller\nIsEmailBounced: false\nFirstName: Inesita\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: kfolliott1bo@nature.com\nIsDeleted: false\nLastName: Folliott\nIsEmailBounced: false\nFirstName: Kennan\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: kroofjfo@gnu.org\nIsDeleted: false\nLastName: Roof\nIsEmailBounced: false\nFirstName: Karlik\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: lcovotti8s4@rediff.com\nIsDeleted: false\nLastName: Covotti\nIsEmailBounced: false\nFirstName: Lucho\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: gpatriskson1rs@census.gov\nIsDeleted: false\nLastName: Patriskson\nIsEmailBounced: false\nFirstName: Gardener\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: spidgleyqvw@usgs.gov\nIsDeleted: false\nLastName: Pidgley\nIsEmailBounced: false\nFirstName: Simona\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: cbecarrak0i@over-blog.com\nIsDeleted: false\nLastName: Becarra\nIsEmailBounced: false\nFirstName: Cally\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: aparkman9td@bbc.co.uk\nIsDeleted: false\nLastName: Parkman\nIsEmailBounced: false\nFirstName: Agneta\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: bboddingtonhn@quantcast.com\nIsDeleted: false\nLastName: Boddington\nIsEmailBounced: false\nFirstName: Betta\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: dcasementx0p@cafepress.com\nIsDeleted: false\nLastName: Casement\nIsEmailBounced: false\nFirstName: Dannie\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: hzornbhe@latimes.com\nIsDeleted: false\nLastName: Zorn\nIsEmailBounced: false\nFirstName: Haleigh\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: cfifieldbjb@blogspot.com\nIsDeleted: false\nLastName: Fifield\nIsEmailBounced: false\nFirstName: Christalle\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: ddewerson4t3@skype.com\nIsDeleted: false\nLastName: Dewerson\nIsEmailBounced: false\nFirstName: Dyann\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: khullock52p@sohu.com\nIsDeleted: false\nLastName: Hullock\nIsEmailBounced: false\nFirstName: Kellina\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: tfremantle32n@bandcamp.com\nIsDeleted: false\nLastName: Fremantle\nIsEmailBounced: false\nFirstName: Turner\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: sbernardtylp@nps.gov\nIsDeleted: false\nLastName: Bernardt\nIsEmailBounced: false\nFirstName: Selina\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: smcgettigan8kk@slideshare.net\nIsDeleted: false\nLastName: McGettigan\nIsEmailBounced: false\nFirstName: Sada\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: wdelafontvgn@businesswire.com\nIsDeleted: false\nLastName: Delafont\nIsEmailBounced: false\nFirstName: West\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: lbelsher9ne@indiatimes.com\nIsDeleted: false\nLastName: Belsher\nIsEmailBounced: false\nFirstName: Lou\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: cgoody27y@blogtalkradio.com\nIsDeleted: false\nLastName: Goody\nIsEmailBounced: false\nFirstName: Colene\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: cstodejzz@ucoz.ru\nIsDeleted: false\nLastName: Stode\nIsEmailBounced: false\nFirstName: Curcio\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: abromidgejb@china.com.cn\nIsDeleted: false\nLastName: Bromidge\nIsEmailBounced: false\nFirstName: Ariela\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: ldelgardilloqvp@xrea.com\nIsDeleted: false\nLastName: Delgardillo\nIsEmailBounced: false\nFirstName: Lauralee\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: dcroal9t4@businessinsider.com\nIsDeleted: false\nLastName: Croal\nIsEmailBounced: false\nFirstName: Devlin\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: dclarageqzb@wordpress.com\nIsDeleted: false\nLastName: Clarage\nIsEmailBounced: false\nFirstName: Dre\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: dthirlwall3jf@taobao.com\nIsDeleted: false\nLastName: Thirlwall\nIsEmailBounced: false\nFirstName: Dareen\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: tkeddie2lj@wiley.com\nIsDeleted: false\nLastName: Keddie\nIsEmailBounced: false\nFirstName: Tandi\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: jrimingtoni3i@istockphoto.com\nIsDeleted: false\nLastName: Rimington\nIsEmailBounced: false\nFirstName: Judy\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: gtroynet@slashdot.org\nIsDeleted: false\nLastName: Troy\nIsEmailBounced: false\nFirstName: Gail\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: ebunneyh0n@meetup.com\nIsDeleted: false\nLastName: Bunney\nIsEmailBounced: false\nFirstName: Efren\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: yhaken8p3@slate.com\nIsDeleted: false\nLastName: Haken\nIsEmailBounced: false\nFirstName: Yard\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: nolliffeq6q@biblegateway.com\nIsDeleted: false\nLastName: Olliffe\nIsEmailBounced: false\nFirstName: Nani\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: bgalia9jz@odnoklassniki.ru\nIsDeleted: false\nLastName: Galia\nIsEmailBounced: false\nFirstName: Berrie\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: djedrzej3v1@google.com\nIsDeleted: false\nLastName: Jedrzej\nIsEmailBounced: false\nFirstName: Deanne\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: mcamiesh1t@fc2.com\nIsDeleted: false\nLastName: Camies\nIsEmailBounced: false\nFirstName: Mikaela\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: csunshineqni@state.tx.us\nIsDeleted: false\nLastName: Sunshine\nIsEmailBounced: false\nFirstName: Curtis\nIsPriorityRecord: false\nCleanStatus: Pending",
"Email: fiannellib46@marriott.com\nIsDeleted: false\nLastName: Iannelli\nIsEmailBounced: false\nFirstName: Felicio\nIsPriorityRecord: false\nCleanStatus: Pending"
"BillingPostalCode: 60601\nType: Prospect\nWebsite: www.globalistindustries.com\nBillingCity: Chicago\nDescription: Globalist company\nIsDeleted: false\nIsPartner: false\nPhone: (312) 555-0456\nShippingCountry: USA\nShippingState: IL\nIsBuyer: false\nBillingCountry: USA\nBillingState: IL\nShippingPostalCode: 60601\nBillingStreet: 456 Market St\nIsCustomerPortal: false\nPersonActiveTrackerCount: 0\nShippingCity: Chicago\nShippingStreet: 456 Market St",
"FirstName: Michael\nMailingCountry: USA\nActiveTrackerCount: 0\nEmail: m.brown@globalindustries.com\nMailingState: IL\nMailingStreet: 456 Market St\nMailingCity: Chicago\nLastName: Brown\nTitle: CTO\nIsDeleted: false\nPhone: (312) 555-0456\nHasOptedOutOfEmail: false\nIsEmailBounced: false\nMailingPostalCode: 60601",
"ForecastCategory: Closed\nName: Global Industries Equipment Sale\nIsDeleted: false\nForecastCategoryName: Closed\nFiscalYear: 2024\nFiscalQuarter: 4\nIsClosed: true\nIsWon: true\nAmount: 5000000.0\nProbability: 100.0\nPushCount: 0\nHasOverdueTask: false\nStageName: Closed Won\nHasOpenActivity: false\nHasOpportunityLineItem: false",
"Field: created\nDataType: Text\nIsDeleted: false"
],
"semantic_identifier": "Voonder",
"semantic_identifier": "Unknown Object",
"metadata": {},
"primary_owners": {"email": "hagen@danswer.ai", "first_name": "Hagen", "last_name": "oneill"},
"primary_owners": null,
"secondary_owners": null,
"title": null
}

View File

@@ -444,7 +444,6 @@ class CCPairManager:
)
if group_sync_result.status_code != 409:
group_sync_result.raise_for_status()
time.sleep(2)
@staticmethod
def get_doc_sync_task(

View File

@@ -165,18 +165,17 @@ class DocumentManager:
doc["fields"]["document_id"]: doc["fields"] for doc in retrieved_docs_dict
}
# NOTE(rkuo): too much log spam
# Left this here for debugging purposes.
# import json
import json
# print("DEBUGGING DOCUMENTS")
# print(retrieved_docs)
# for doc in retrieved_docs.values():
# printable_doc = doc.copy()
# print(printable_doc.keys())
# printable_doc.pop("embeddings")
# printable_doc.pop("title_embedding")
# print(json.dumps(printable_doc, indent=2))
print("DEBUGGING DOCUMENTS")
print(retrieved_docs)
for doc in retrieved_docs.values():
printable_doc = doc.copy()
print(printable_doc.keys())
printable_doc.pop("embeddings")
printable_doc.pop("title_embedding")
print(json.dumps(printable_doc, indent=2))
for document in cc_pair.documents:
retrieved_doc = retrieved_docs.get(document.id)

View File

@@ -1,4 +1,3 @@
import time
from datetime import datetime
from datetime import timedelta
from urllib.parse import urlencode
@@ -192,7 +191,7 @@ class IndexAttemptManager:
user_performing_action: DATestUser | None = None,
) -> None:
"""Wait for an IndexAttempt to complete"""
start = time.monotonic()
start = datetime.now()
while True:
index_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=index_attempt_id,
@@ -204,7 +203,7 @@ class IndexAttemptManager:
print(f"IndexAttempt {index_attempt_id} completed")
return
elapsed = time.monotonic() - start
elapsed = (datetime.now() - start).total_seconds()
if elapsed > timeout:
raise TimeoutError(
f"IndexAttempt {index_attempt_id} did not complete within {timeout} seconds"

View File

@@ -4,7 +4,7 @@ from uuid import uuid4
import requests
from onyx.context.search.enums import RecencyBiasSetting
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
@@ -181,7 +181,7 @@ class PersonaManager:
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
) -> list[FullPersonaSnapshot]:
) -> list[PersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/admin/persona",
headers=user_performing_action.headers
@@ -189,13 +189,13 @@ class PersonaManager:
else GENERAL_HEADERS,
)
response.raise_for_status()
return [FullPersonaSnapshot(**persona) for persona in response.json()]
return [PersonaSnapshot(**persona) for persona in response.json()]
@staticmethod
def get_one(
persona_id: int,
user_performing_action: DATestUser | None = None,
) -> list[FullPersonaSnapshot]:
) -> list[PersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/persona/{persona_id}",
headers=user_performing_action.headers
@@ -203,7 +203,7 @@ class PersonaManager:
else GENERAL_HEADERS,
)
response.raise_for_status()
return [FullPersonaSnapshot(**response.json())]
return [PersonaSnapshot(**response.json())]
@staticmethod
def verify(

View File

@@ -14,8 +14,9 @@ from tests.integration.connector_job_tests.slack.slack_api_utils import SlackMan
@pytest.fixture()
def slack_test_setup() -> Generator[tuple[dict[str, Any], dict[str, Any]], None, None]:
slack_client = SlackManager.get_slack_client(os.environ["SLACK_BOT_TOKEN"])
user_map = SlackManager.build_slack_user_email_id_map(slack_client)
admin_user_id = user_map["admin@onyx-test.com"]
admin_user_id = SlackManager.build_slack_user_email_id_map(slack_client)[
"admin@onyx-test.com"
]
(
public_channel,

View File

@@ -3,6 +3,8 @@ from datetime import datetime
from datetime import timezone
from typing import Any
import pytest
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
@@ -23,6 +25,7 @@ from tests.integration.common_utils.vespa import vespa_fixture
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
@pytest.mark.xfail(reason="flaky - see DAN-789 for example", strict=False)
def test_slack_permission_sync(
reset: None,
vespa_client: vespa_fixture,
@@ -218,6 +221,7 @@ def test_slack_permission_sync(
assert private_message not in onyx_doc_message_strings
@pytest.mark.xfail(reason="flaky", strict=False)
def test_slack_group_permission_sync(
reset: None,
vespa_client: vespa_fixture,

Some files were not shown because too many files have changed in this diff.