Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 07:45:47 +00:00)

Compare commits: 29 commits
61676620ab
edaca1f58b
1f280bafca
cb5bbd3812
742d29e504
ecc155d082
0857e4809d
22e00a1f5c
0d0588a0c1
aab777f844
babbe7689a
a123661c92
c554889baf
f08fa878a6
d307534781
6f54791910
0d5497bb6b
7648627503
927554d5ca
7dcec6caf5
036648146d
2aa4697ac8
bc9b4e4f45
178a64f298
c79f1edf1d
7c8e23aa54
d37b427d52
a65fefd226
bb09bde519
.github/workflows/pr-mit-integration-tests.yml (vendored, new file, 209 lines)

@@ -0,0 +1,209 @@
name: Run MIT Integration Tests v2
concurrency:
  group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches:
      - main
      - "release/**"

env:
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
  CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
  CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
  CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}

jobs:
  integration-tests:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      # tag every docker image with "test" so that we can spin up the correct set
      # of images during testing

      # We don't need to build the Web Docker image since it's not yet used
      # in the integration tests. We have a separate action to verify that it builds
      # successfully.
      - name: Pull Web Docker image
        run: |
          docker pull onyxdotapp/onyx-web-server:latest
          docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test

      # we use the runs-on cache for docker builds
      # in conjunction with runs-on runners, it has better speed and unlimited caching
      # https://runs-on.com/caching/s3-cache-for-github-actions/
      # https://runs-on.com/caching/docker/
      # https://github.com/moby/buildkit#s3-cache-experimental

      # images are built and run locally for testing purposes. Not pushed.
      - name: Build Backend Docker image
        uses: ./.github/actions/custom-build-and-push
        with:
          context: ./backend
          file: ./backend/Dockerfile
          platforms: linux/amd64
          tags: onyxdotapp/onyx-backend:test
          push: false
          load: true
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      - name: Build Model Server Docker image
        uses: ./.github/actions/custom-build-and-push
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server
          platforms: linux/amd64
          tags: onyxdotapp/onyx-model-server:test
          push: false
          load: true
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      - name: Build integration test Docker image
        uses: ./.github/actions/custom-build-and-push
        with:
          context: ./backend
          file: ./backend/tests/integration/Dockerfile
          platforms: linux/amd64
          tags: onyxdotapp/onyx-integration:test
          push: false
          load: true
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      # NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
      - name: Start Docker containers
        run: |
          cd deployment/docker_compose
          AUTH_TYPE=basic \
          POSTGRES_POOL_PRE_PING=true \
          POSTGRES_USE_NULL_POOL=true \
          REQUIRE_EMAIL_VERIFICATION=false \
          DISABLE_TELEMETRY=true \
          IMAGE_TAG=test \
          INTEGRATION_TESTS_MODE=true \
          docker compose -f docker-compose.dev.yml -p onyx-stack up -d
        id: start_docker

      - name: Wait for service to be ready
        run: |
          echo "Starting wait-for-service script..."

          docker logs -f onyx-stack-api_server-1 &

          start_time=$(date +%s)
          timeout=300 # 5 minutes in seconds

          while true; do
            current_time=$(date +%s)
            elapsed_time=$((current_time - start_time))

            if [ $elapsed_time -ge $timeout ]; then
              echo "Timeout reached. Service did not become ready in 5 minutes."
              exit 1
            fi

            # Use curl with error handling to ignore specific exit code 56
            response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")

            if [ "$response" = "200" ]; then
              echo "Service is ready!"
              break
            elif [ "$response" = "curl_error" ]; then
              echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
            else
              echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
            fi

            sleep 5
          done
          echo "Finished waiting for service."

      - name: Start Mock Services
        run: |
          cd backend/tests/integration/mock_services
          docker compose -f docker-compose.mock-it-services.yml \
            -p mock-it-services-stack up -d

      # NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
      - name: Run Standard Integration Tests
        run: |
          echo "Running integration tests..."
          docker run --rm --network onyx-stack_default \
            --name test-runner \
            -e POSTGRES_HOST=relational_db \
            -e POSTGRES_USER=postgres \
            -e POSTGRES_PASSWORD=password \
            -e POSTGRES_DB=postgres \
            -e POSTGRES_POOL_PRE_PING=true \
            -e POSTGRES_USE_NULL_POOL=true \
            -e VESPA_HOST=index \
            -e REDIS_HOST=cache \
            -e API_SERVER_HOST=api_server \
            -e OPENAI_API_KEY=${OPENAI_API_KEY} \
            -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
            -e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
            -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
            -e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
            -e TEST_WEB_HOSTNAME=test-runner \
            -e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
            -e MOCK_CONNECTOR_SERVER_PORT=8001 \
            onyxdotapp/onyx-integration:test \
            /app/tests/integration/tests \
            /app/tests/integration/connector_job_tests
        continue-on-error: true
        id: run_tests

      - name: Check test results
        run: |
          if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
            echo "Integration tests failed. Exiting with error."
            exit 1
          else
            echo "All integration tests passed successfully."
          fi

      # ------------------------------------------------------------
      # Always gather logs BEFORE "down":
      - name: Dump API server logs
        if: always()
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true

      - name: Dump all-container logs (optional)
        if: always()
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

      - name: Upload logs
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: docker-all-logs
          path: ${{ github.workspace }}/docker-compose.log
      # ------------------------------------------------------------

      - name: Stop Docker containers
        if: always()
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.dev.yml -p onyx-stack down -v
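The POSTGRES_POOL_PRE_PING / POSTGRES_USE_NULL_POOL flags above exist to reduce flaky dropped-connection failures. A minimal sketch of what such flags typically map to in SQLAlchemy; the variable names and wiring here are assumptions for illustration, not Onyx's actual engine setup:

import os
from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool

pre_ping = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
use_null_pool = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"

kwargs = {"pool_pre_ping": pre_ping}  # validate each connection before checkout
if use_null_pool:
    kwargs["poolclass"] = NullPool  # open/close a fresh connection per checkout

engine = create_engine(
    "postgresql://postgres:password@relational_db/postgres", **kwargs
)

Pre-ping trades a cheap round-trip per checkout for resilience against stale connections; the null pool goes further and never reuses connections at all.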
@@ -9,6 +9,10 @@ on:
     - cron: "0 16 * * *"

 env:
+  # AWS
+  AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
+  AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
+
   # Confluence
   CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
   CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
@@ -102,6 +102,7 @@ COPY ./alembic /app/alembic
 COPY ./alembic_tenants /app/alembic_tenants
 COPY ./alembic.ini /app/alembic.ini
 COPY supervisord.conf /usr/etc/supervisord.conf
+COPY ./static /app/static

 # Escape hatch scripts
 COPY ./scripts/debugging /app/scripts/debugging
@@ -28,6 +28,20 @@ depends_on = None


 def upgrade() -> None:
+    # First, drop any existing indexes to avoid conflicts
+    op.execute("COMMIT")
+    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
+
+    op.execute("COMMIT")
+    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
+
+    op.execute("COMMIT")
+    op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
+
+    # Drop existing columns if they exist
+    op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
+    op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
+
     # Create a GIN index for full-text search on chat_message.message
     op.execute(
         """
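Context for the repeated op.execute("COMMIT") calls above: PostgreSQL refuses to run DROP/CREATE INDEX CONCURRENTLY inside a transaction block, and Alembic wraps each migration in one by default, so the migration commits first. A hedged sketch of an equivalent approach using Alembic's autocommit_block (available in recent Alembic versions) instead of manual COMMITs:

from alembic import op

def upgrade() -> None:
    # autocommit_block temporarily releases the migration transaction,
    # which is what CONCURRENTLY statements require
    with op.get_context().autocommit_block():
        op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
        op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")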
@@ -93,12 +93,12 @@ def _get_access_for_documents(
         )

         # To avoid collisions of group namings between connectors, they need to be prefixed
-        access_map[document_id] = DocumentAccess(
-            user_emails=non_ee_access.user_emails,
-            user_groups=set(user_group_info.get(document_id, [])),
+        access_map[document_id] = DocumentAccess.build(
+            user_emails=list(non_ee_access.user_emails),
+            user_groups=user_group_info.get(document_id, []),
             is_public=is_public_anywhere,
-            external_user_emails=ext_u_emails,
-            external_user_group_ids=ext_u_groups,
+            external_user_emails=list(ext_u_emails),
+            external_user_group_ids=list(ext_u_groups),
         )
     return access_map
@@ -2,7 +2,6 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
 from onyx.chat.models import AllCitations
 from onyx.chat.models import LLMRelevanceFilterResponse
 from onyx.chat.models import OnyxAnswerPiece
-from onyx.chat.models import OnyxContexts
 from onyx.chat.models import QADocsResponse
 from onyx.chat.models import StreamingError
 from onyx.chat.process_message import ChatPacketStream

@@ -32,8 +31,6 @@ def gather_stream_for_answer_api(
             response.llm_selected_doc_indices = packet.llm_selected_doc_indices
         elif isinstance(packet, AllCitations):
             response.citations = packet.citations
-        elif isinstance(packet, OnyxContexts):
-            response.contexts = packet

     if answer:
         response.answer = answer
@@ -25,6 +25,10 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
|
||||
#####
|
||||
# Auto Permission Sync
|
||||
#####
|
||||
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
# In seconds, default is 5 minutes
|
||||
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
@@ -39,6 +43,7 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
|
||||
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
|
||||
|
||||
|
||||
@@ -72,6 +77,13 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
)
|
||||
|
||||
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
# The posthog client does not accept empty API keys or hosts however it fails silently
|
||||
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
|
||||
|
||||
@@ -3,6 +3,8 @@ from collections.abc import Generator

 from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
 from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
+from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
+from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
 from ee.onyx.db.external_perm import ExternalUserGroup
 from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
 from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync

@@ -66,13 +68,13 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
 DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
     # Polling is not supported so we fetch all doc permissions every 5 minutes
     DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
-    DocumentSource.SLACK: 5 * 60,
+    DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
 }

 # If nothing is specified here, we run the doc_sync every time the celery beat runs
 EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
     # Polling is not supported so we fetch all group permissions every 30 minutes
-    DocumentSource.GOOGLE_DRIVE: 5 * 60,
+    DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
     DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
 }
@@ -14,7 +14,6 @@ from ee.onyx.server.query_and_chat.models import (
     BasicCreateChatMessageWithHistoryRequest,
 )
 from ee.onyx.server.query_and_chat.models import ChatBasicResponse
-from ee.onyx.server.query_and_chat.models import SimpleDoc
 from onyx.auth.users import current_user
 from onyx.chat.chat_utils import combine_message_thread
 from onyx.chat.chat_utils import create_chat_chain

@@ -56,25 +55,6 @@ logger = setup_logger()
 router = APIRouter(prefix="/chat")


-def _translate_doc_response_to_simple_doc(
-    doc_response: QADocsResponse,
-) -> list[SimpleDoc]:
-    return [
-        SimpleDoc(
-            id=doc.document_id,
-            semantic_identifier=doc.semantic_identifier,
-            link=doc.link,
-            blurb=doc.blurb,
-            match_highlights=[
-                highlight for highlight in doc.match_highlights if highlight
-            ],
-            source_type=doc.source_type,
-            metadata=doc.metadata,
-        )
-        for doc in doc_response.top_documents
-    ]
-
-
 def _get_final_context_doc_indices(
     final_context_docs: list[LlmDoc] | None,
     top_docs: list[SavedSearchDoc] | None,

@@ -111,9 +91,6 @@ def _convert_packet_stream_to_response(
         elif isinstance(packet, QADocsResponse):
             response.top_documents = packet.top_documents

-            # TODO: deprecate `simple_search_docs`
-            response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
-
             # This is a no-op if agent_sub_questions hasn't already been filled
             if packet.level is not None and packet.level_question_num is not None:
                 id = (packet.level, packet.level_question_num)
@@ -8,7 +8,6 @@ from pydantic import model_validator

 from ee.onyx.server.manage.models import StandardAnswer
 from onyx.chat.models import CitationInfo
-from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PersonaOverrideConfig
 from onyx.chat.models import QADocsResponse
 from onyx.chat.models import SubQuestionIdentifier

@@ -164,8 +163,6 @@ class ChatBasicResponse(BaseModel):
     cited_documents: dict[int, str] | None = None

     # FOR BACKWARDS COMPATIBILITY
     # TODO: deprecate both of these
-    simple_search_docs: list[SimpleDoc] | None = None
     llm_chunks_indices: list[int] | None = None

     # agentic fields

@@ -220,4 +217,3 @@ class OneShotQAResponse(BaseModel):
     llm_selected_doc_indices: list[int] | None = None
     error_msg: str | None = None
     chat_message_id: int | None = None
-    contexts: OnyxContexts | None = None
@@ -36,8 +36,12 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()
 router = APIRouter(prefix="/auth/saml")

+# Define non-authenticated user roles that should be re-created during SAML login
+NON_AUTHENTICATED_ROLES = {UserRole.SLACK_USER, UserRole.EXT_PERM_USER}
+

 async def upsert_saml_user(email: str) -> User:
+    logger.debug(f"Attempting to upsert SAML user with email: {email}")
     get_async_session_context = contextlib.asynccontextmanager(
         get_async_session
     )  # type:ignore

@@ -48,9 +52,13 @@ async def upsert_saml_user(email: str) -> User:
         async with get_user_db_context(session) as user_db:
             async with get_user_manager_context(user_db) as user_manager:
                 try:
-                    return await user_manager.get_by_email(email)
+                    user = await user_manager.get_by_email(email)
+                    # If user has a non-authenticated role, treat as non-existent
+                    if user.role in NON_AUTHENTICATED_ROLES:
+                        raise exceptions.UserNotExists()
+                    return user
                 except exceptions.UserNotExists:
-                    logger.notice("Creating user from SAML login")
+                    logger.info("Creating user from SAML login")

                     user_count = await get_user_count()
                     role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC

@@ -59,11 +67,10 @@ async def upsert_saml_user(email: str) -> User:
                     password = fastapi_users_pw_helper.generate()
                     hashed_pass = fastapi_users_pw_helper.hash(password)

-                    user: User = await user_manager.create(
+                    user = await user_manager.create(
                         UserCreate(
                             email=email,
                             password=hashed_pass,
                             is_verified=True,
                             role=role,
                         )
                     )
@@ -70,6 +70,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
|
||||
"""
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
@@ -88,9 +89,25 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
.first()
|
||||
)
|
||||
|
||||
# If user already has an active mapping, add this one as inactive
|
||||
if not existing_mapping:
|
||||
# Only add if mapping doesn't exist
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
# Check if the user already has an active mapping to any tenant
|
||||
has_active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=False if has_active_mapping else True,
|
||||
)
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
|
||||
@@ -18,7 +18,7 @@ def _get_access_for_document(
         document_id=document_id,
     )

-    return DocumentAccess.build(
+    doc_access = DocumentAccess.build(
         user_emails=info[1] if info and info[1] else [],
         user_groups=[],
         external_user_emails=[],

@@ -26,6 +26,8 @@ def _get_access_for_document(
         is_public=info[2] if info else False,
     )

+    return doc_access
+

 def get_access_for_document(
     document_id: str,

@@ -38,12 +40,12 @@


 def get_null_document_access() -> DocumentAccess:
-    return DocumentAccess(
-        user_emails=set(),
-        user_groups=set(),
+    return DocumentAccess.build(
+        user_emails=[],
+        user_groups=[],
         is_public=False,
-        external_user_emails=set(),
-        external_user_group_ids=set(),
+        external_user_emails=[],
+        external_user_group_ids=[],
     )
@@ -56,18 +58,18 @@ def _get_access_for_documents(
         document_ids=document_ids,
     )
     doc_access = {
-        document_id: DocumentAccess(
-            user_emails=set([email for email in user_emails if email]),
+        document_id: DocumentAccess.build(
+            user_emails=[email for email in user_emails if email],
             # MIT version will wipe all groups and external groups on update
-            user_groups=set(),
+            user_groups=[],
             is_public=is_public,
-            external_user_emails=set(),
-            external_user_group_ids=set(),
+            external_user_emails=[],
+            external_user_group_ids=[],
         )
         for document_id, user_emails, is_public in document_access_info
     }

-    # Sometimes the document has not be indexed by the indexing job yet, in those cases
+    # Sometimes the document has not been indexed by the indexing job yet, in those cases
     # the document does not exist and so we use least permissive. Specifically the EE version
     # checks the MIT version permissions and creates a superset. This ensures that this flow
     # does not fail even if the Document has not yet been indexed.
@@ -56,34 +56,46 @@ class DocExternalAccess:
         )


-@dataclass(frozen=True)
+@dataclass(frozen=True, init=False)
 class DocumentAccess(ExternalAccess):
     # User emails for Onyx users, None indicates admin
     user_emails: set[str | None]

     # Names of user groups associated with this document
     user_groups: set[str]

-    def to_acl(self) -> set[str]:
-        return set(
-            [
-                prefix_user_email(user_email)
-                for user_email in self.user_emails
-                if user_email
-            ]
-            + [prefix_user_group(group_name) for group_name in self.user_groups]
-            + [
-                prefix_user_email(user_email)
-                for user_email in self.external_user_emails
-            ]
-            + [
-                # The group names are already prefixed by the source type
-                # This adds an additional prefix of "external_group:"
-                prefix_external_group(group_name)
-                for group_name in self.external_user_group_ids
-            ]
-            + ([PUBLIC_DOC_PAT] if self.is_public else [])
-        )
+    external_user_emails: set[str]
+    external_user_group_ids: set[str]
+    is_public: bool
+
+    def __init__(self) -> None:
+        raise TypeError(
+            "Use `DocumentAccess.build(...)` instead of creating an instance directly."
+        )
+
+    def to_acl(self) -> set[str]:
+        # the acl's emitted by this function are prefixed by type
+        # to get the native objects, access the member variables directly
+
+        acl_set: set[str] = set()
+        for user_email in self.user_emails:
+            if user_email:
+                acl_set.add(prefix_user_email(user_email))
+
+        for group_name in self.user_groups:
+            acl_set.add(prefix_user_group(group_name))
+
+        for external_user_email in self.external_user_emails:
+            acl_set.add(prefix_user_email(external_user_email))
+
+        for external_group_id in self.external_user_group_ids:
+            acl_set.add(prefix_external_group(external_group_id))
+
+        if self.is_public:
+            acl_set.add(PUBLIC_DOC_PAT)
+
+        return acl_set

     @classmethod
     def build(
         cls,

@@ -93,29 +105,32 @@ class DocumentAccess(ExternalAccess):
         external_user_group_ids: list[str],
         is_public: bool,
     ) -> "DocumentAccess":
-        return cls(
-            external_user_emails={
-                prefix_user_email(external_email)
-                for external_email in external_user_emails
-            },
-            external_user_group_ids={
-                prefix_external_group(external_group_id)
-                for external_group_id in external_user_group_ids
-            },
-            user_emails={
-                prefix_user_email(user_email)
-                for user_email in user_emails
-                if user_email
-            },
-            user_groups=set(user_groups),
-            is_public=is_public,
-        )
+        """Don't prefix incoming data wth acl type, prefix on read from to_acl!"""
+
+        obj = object.__new__(cls)
+        object.__setattr__(
+            obj, "user_emails", {user_email for user_email in user_emails if user_email}
+        )
+        object.__setattr__(obj, "user_groups", set(user_groups))
+        object.__setattr__(
+            obj,
+            "external_user_emails",
+            {external_email for external_email in external_user_emails},
+        )
+        object.__setattr__(
+            obj,
+            "external_user_group_ids",
+            {external_group_id for external_group_id in external_user_group_ids},
+        )
+        object.__setattr__(obj, "is_public", is_public)
+
+        return obj


-default_public_access = DocumentAccess(
-    external_user_emails=set(),
-    external_user_group_ids=set(),
-    user_emails=set(),
-    user_groups=set(),
+default_public_access = DocumentAccess.build(
+    external_user_emails=[],
+    external_user_group_ids=[],
+    user_emails=[],
+    user_groups=[],
     is_public=True,
 )
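The hunk above is a pattern worth calling out: a frozen dataclass whose direct constructor is disabled, forcing every call site through a normalizing factory. A minimal standalone sketch of the same technique (the names here are illustrative, not Onyx's):

from dataclasses import dataclass

@dataclass(frozen=True, init=False)
class Access:
    emails: frozenset  # normalized storage, always built by the factory

    def __init__(self) -> None:
        raise TypeError("Use Access.build(...)")

    @classmethod
    def build(cls, emails: list) -> "Access":
        # object.__new__ bypasses the disabled __init__; frozen dataclasses
        # block normal assignment, so fields are set via object.__setattr__
        obj = object.__new__(cls)
        object.__setattr__(obj, "emails", frozenset(e for e in emails if e))
        return obj

print(Access.build(["a@example.com", ""]).emails)  # frozenset({'a@example.com'})

The payoff is that normalization (here, dropping falsy entries) cannot be skipped, which is exactly what the diff wants: raw values stored once, prefixing deferred to to_acl().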
@@ -7,7 +7,6 @@ from langgraph.types import StreamWriter

 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContext
 from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
 from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
 from onyx.chat.stream_processing.answer_response_handler import (

@@ -24,7 +23,7 @@ def process_llm_stream(
     should_stream_answer: bool,
     writer: StreamWriter,
     final_search_results: list[LlmDoc] | None = None,
-    displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
+    displayed_search_results: list[LlmDoc] | None = None,
 ) -> AIMessageChunk:
     tool_call_chunk = AIMessageChunk(content="")
@@ -156,7 +156,6 @@ def generate_initial_answer(
         for tool_response in yield_search_responses(
             query=question,
-            get_retrieved_sections=lambda: answer_generation_documents.context_documents,
             get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
             get_final_context_sections=lambda: answer_generation_documents.context_documents,
             search_query_info=query_info,
             get_section_relevance=lambda: relevance_list,

@@ -183,7 +183,6 @@ def generate_validate_refined_answer(
         for tool_response in yield_search_responses(
             query=question,
-            get_retrieved_sections=lambda: answer_generation_documents.context_documents,
             get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
             get_final_context_sections=lambda: answer_generation_documents.context_documents,
             search_query_info=query_info,
             get_section_relevance=lambda: relevance_list,
@@ -57,7 +57,6 @@ def format_results(
     for tool_response in yield_search_responses(
         query=state.question,
-        get_retrieved_sections=lambda: reranked_documents,
         get_reranked_sections=lambda: state.retrieved_documents,
         get_final_context_sections=lambda: reranked_documents,
         search_query_info=query_info,
         get_section_relevance=lambda: relevance_list,
@@ -13,9 +13,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
     SEARCH_RESPONSE_SUMMARY_ID,
 )
 from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
-from onyx.tools.tool_implementations.search.search_utils import (
-    context_from_inference_section,
-)
+from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
 from onyx.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )

@@ -59,9 +57,7 @@ def basic_use_tool_response(
         search_response_summary = cast(SearchResponseSummary, yield_item.response)
         for section in search_response_summary.top_sections:
             if section.center_chunk.document_id not in initial_search_results:
-                initial_search_results.append(
-                    context_from_inference_section(section)
-                )
+                initial_search_results.append(section_to_llm_doc(section))

     new_tool_call_chunk = AIMessageChunk(content="")
     if not agent_config.behavior.skip_gen_ai_answer_generation:
@@ -1,6 +1,5 @@
 from datetime import timedelta
-from typing import Any
 from typing import cast

 from celery import Celery
 from celery import signals

@@ -10,12 +9,10 @@ from celery.utils.log import get_task_logger

 import onyx.background.celery.apps.app_base as app_base
-from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
-from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
 from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
 from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
 from onyx.db.engine import get_all_tenant_ids
 from onyx.db.engine import SqlEngine
-from onyx.redis.redis_pool import get_redis_replica_client
+from onyx.server.runtime.onyx_runtime import OnyxRuntime
 from onyx.utils.variable_functionality import fetch_versioned_implementation
 from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
 from shared_configs.configs import MULTI_TENANT

@@ -141,8 +138,6 @@ class DynamicTenantScheduler(PersistentScheduler):
         """Only updates the actual beat schedule on the celery app when it changes"""
         do_update = False

-        r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
-
         task_logger.debug("_try_updating_schedule starting")

         tenant_ids = get_all_tenant_ids()

@@ -152,16 +147,7 @@ class DynamicTenantScheduler(PersistentScheduler):
         current_schedule = self.schedule.items()

         # get potential new state
-        beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
-        beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
-        if beat_multiplier_raw is not None:
-            try:
-                beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
-                beat_multiplier = float(beat_multiplier_bytes.decode())
-            except ValueError:
-                task_logger.error(
-                    f"Invalid beat_multiplier value: {beat_multiplier_raw}"
-                )
+        beat_multiplier = OnyxRuntime.get_beat_multiplier()

         new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)
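The inline Redis read removed above is centralized behind OnyxRuntime.get_beat_multiplier(). A minimal sketch of what such an accessor typically does, a Redis key overriding a compiled-in default with bad values falling back safely; the key name and wiring are assumptions, not Onyx's actual implementation:

import redis

def get_float_setting(r: redis.Redis, key: str, default: float) -> float:
    # Runtime override lives in Redis; missing or malformed values
    # fall back to the shipped default instead of raising.
    raw = r.get(key)
    if raw is None:
        return default
    try:
        return float(raw.decode())
    except ValueError:
        return default

# usage shape: get_float_setting(r, "runtime:beat_multiplier", 8.0)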
@@ -14,7 +14,7 @@ logger = setup_logger()
 # Only set up memory monitoring in container environment
 if is_running_in_container():
     # Set up a dedicated memory monitoring logger
-    MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
+    MEMORY_LOG_DIR = "/var/log/memory"
     MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
     MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024  # 10MB
     MEMORY_LOG_BACKUP_COUNT = 5  # Keep 5 backup files
@@ -21,6 +21,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60  # 15 minutes (in seconds)
 # we have a better implementation (backpressure, etc)
 # Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
 CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
+CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0

 # tasks that run in either self-hosted on cloud
 beat_task_templates: list[dict] = []
@@ -25,14 +25,8 @@ from onyx.configs.constants import OnyxRedisLocks
 from onyx.configs.constants import OnyxRedisSignals
 from onyx.db.connector import fetch_connector_by_id
 from onyx.db.connector_credential_pair import add_deletion_failure_message
-from onyx.db.connector_credential_pair import (
-    delete_connector_credential_pair__no_commit,
-)
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
 from onyx.db.connector_credential_pair import get_connector_credential_pairs
-from onyx.db.document import (
-    delete_all_documents_by_connector_credential_pair__no_commit,
-)
 from onyx.db.document import get_document_ids_for_connector_credential_pair
 from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
 from onyx.db.engine import get_session_with_current_tenant

@@ -389,6 +383,8 @@ def monitor_connector_deletion_taskset(
             db_session=db_session,
             cc_pair_id=cc_pair_id,
         )
+        credential_id_to_delete: int | None = None
+        connector_id_to_delete: int | None = None
         if not cc_pair:
             task_logger.warning(
                 f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"

@@ -443,26 +439,23 @@ def monitor_connector_deletion_taskset(
                 db_session=db_session,
             )

-            # Explicitly delete document by connector credential pair records before deleting the connector
-            # This is needed because connector_id is a primary key in that table and cascading deletes won't work
-            delete_all_documents_by_connector_credential_pair__no_commit(
-                db_session=db_session,
-                connector_id=cc_pair.connector_id,
-                credential_id=cc_pair.credential_id,
-            )
+            # Store IDs before potentially expiring cc_pair
+            connector_id_to_delete = cc_pair.connector_id
+            credential_id_to_delete = cc_pair.credential_id
+
+            # No need to explicitly delete DocumentByConnectorCredentialPair records anymore
+            # as we have proper cascade relationships set up in the models

             # Flush to ensure all operations happen in sequence
             db_session.flush()

-            # finally, delete the cc-pair
-            delete_connector_credential_pair__no_commit(
-                db_session=db_session,
-                connector_id=cc_pair.connector_id,
-                credential_id=cc_pair.credential_id,
-            )
+            # Delete the cc-pair directly
+            db_session.delete(cc_pair)

             # if there are no credentials left, delete the connector
             connector = fetch_connector_by_id(
                 db_session=db_session,
-                connector_id=cc_pair.connector_id,
+                connector_id=connector_id_to_delete,
             )
             if not connector or not len(connector.credentials):
                 task_logger.info(

@@ -495,15 +488,15 @@ def monitor_connector_deletion_taskset(

         task_logger.exception(
             f"Connector deletion exceptioned: "
-            f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
+            f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
         )
         raise e

     task_logger.info(
         f"Connector deletion succeeded: "
         f"cc_pair={cc_pair_id} "
-        f"connector={cc_pair.connector_id} "
-        f"credential={cc_pair.credential_id} "
+        f"connector={connector_id_to_delete} "
+        f"credential={credential_id_to_delete} "
         f"docs_deleted={fence_data.num_tasks}"
     )

@@ -553,7 +546,7 @@ def validate_connector_deletion_fences(
 def validate_connector_deletion_fence(
     tenant_id: str,
     key_bytes: bytes,
-    queued_tasks: set[str],
+    queued_upsert_tasks: set[str],
     r: Redis,
 ) -> None:
     """Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.

@@ -640,7 +633,7 @@ def validate_connector_deletion_fence(

         member_bytes = cast(bytes, member)
         member_str = member_bytes.decode("utf-8")
-        if member_str in queued_tasks:
+        if member_str in queued_upsert_tasks:
             continue

         tasks_not_in_celery += 1
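The "Store IDs before potentially expiring cc_pair" change above guards against a common SQLAlchemy pitfall: after db_session.delete() and a flush, attribute access on the instance may hit an expired or deleted row. A minimal sketch of the pattern (the model and function names are illustrative):

from sqlalchemy.orm import Session

def delete_cc_pair(db_session: Session, cc_pair) -> tuple[int, int]:
    # Copy primary-key fields into plain locals before deleting; after the
    # flush the instance is expired and attribute access can go stale or raise.
    connector_id = cc_pair.connector_id
    credential_id = cc_pair.credential_id
    db_session.delete(cc_pair)
    db_session.flush()
    return connector_id, credential_id  # safe to log or reuse afterwards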
@@ -17,6 +17,7 @@ from redis.exceptions import LockError
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session

+from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
 from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
 from ee.onyx.db.document import upsert_document_external_perms
 from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS

@@ -63,6 +64,7 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
 from onyx.redis.redis_pool import get_redis_client
 from onyx.redis.redis_pool import get_redis_replica_client
 from onyx.redis.redis_pool import redis_lock_dump
+from onyx.server.runtime.onyx_runtime import OnyxRuntime
 from onyx.server.utils import make_short_id
 from onyx.utils.logger import doc_permission_sync_ctx
 from onyx.utils.logger import format_error_for_logging

@@ -106,9 +108,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b

     source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)

-    # If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
     if not source_sync_period:
-        return True
+        source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
+
+    source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())

     # If the last sync is greater than the full fetch period, we run the sync
     next_sync = last_perm_sync + timedelta(seconds=source_sync_period)

@@ -286,7 +289,7 @@ def try_creating_permissions_sync_task(
             ),
             queue=OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
             task_id=custom_task_id,
-            priority=OnyxCeleryPriority.HIGH,
+            priority=OnyxCeleryPriority.MEDIUM,
         )

         # fill in the celery task id
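The due-check above now defaults missing per-source periods and scales them by a runtime multiplier. A self-contained sketch of the same logic (names and signature are illustrative):

from datetime import datetime, timedelta, timezone

def sync_is_due(
    last_sync: datetime,
    source_period_s: int | None,
    default_period_s: int,
    multiplier: float,
) -> bool:
    # Sources without an explicit period fall back to the default,
    # then the whole period is stretched by the runtime multiplier.
    period_s = (source_period_s or default_period_s) * int(multiplier)
    return datetime.now(timezone.utc) >= last_sync + timedelta(seconds=period_s)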
@@ -271,7 +271,7 @@ def try_creating_external_group_sync_task(
             ),
             queue=OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
             task_id=custom_task_id,
-            priority=OnyxCeleryPriority.HIGH,
+            priority=OnyxCeleryPriority.MEDIUM,
         )

         payload.celery_task_id = result.id
@@ -72,6 +72,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
 from onyx.redis.redis_pool import redis_lock_dump
 from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
 from onyx.redis.redis_utils import is_fence
+from onyx.server.runtime.onyx_runtime import OnyxRuntime
 from onyx.utils.logger import setup_logger
 from onyx.utils.variable_functionality import global_version
 from shared_configs.configs import INDEXING_MODEL_SERVER_HOST

@@ -401,7 +402,11 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
                 logger.warning(f"Adding {key_bytes} to the lookup table.")
                 redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)

-        redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
+        redis_client.set(
+            OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE,
+            1,
+            ex=OnyxRuntime.get_build_fence_lookup_table_interval(),
+        )

     # 1/3: KICKOFF
@@ -194,17 +194,6 @@ class StreamingError(BaseModel):
     stack_trace: str | None = None


-class OnyxContext(BaseModel):
-    content: str
-    document_id: str
-    semantic_identifier: str
-    blurb: str
-
-
-class OnyxContexts(BaseModel):
-    contexts: list[OnyxContext]
-
-
 class OnyxAnswer(BaseModel):
     answer: str | None

@@ -270,7 +259,6 @@ class PersonaOverrideConfig(BaseModel):
 AnswerQuestionPossibleReturn = (
     OnyxAnswerPiece
     | CitationInfo
-    | OnyxContexts
     | FileChatDisplay
     | CustomToolResponse
     | StreamingError
@@ -29,7 +29,6 @@ from onyx.chat.models import LLMRelevanceFilterResponse
 from onyx.chat.models import MessageResponseIDInfo
 from onyx.chat.models import MessageSpecificCitations
 from onyx.chat.models import OnyxAnswerPiece
-from onyx.chat.models import OnyxContexts
 from onyx.chat.models import PromptConfig
 from onyx.chat.models import QADocsResponse
 from onyx.chat.models import RefinedAnswerImprovement

@@ -73,6 +72,7 @@ from onyx.db.chat import get_or_create_root_message
 from onyx.db.chat import reserve_message_id
 from onyx.db.chat import translate_db_message_to_chat_message_detail
 from onyx.db.chat import translate_db_search_doc_to_server_search_doc
+from onyx.db.chat import update_chat_session_updated_at_timestamp
 from onyx.db.engine import get_session_context_manager
 from onyx.db.milestone import check_multi_assistant_milestone
 from onyx.db.milestone import create_milestone_if_not_exists

@@ -130,7 +130,6 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
 from onyx.tools.tool_implementations.search.search_tool import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
-from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
 from onyx.tools.tool_implementations.search.search_tool import (
     SEARCH_RESPONSE_SUMMARY_ID,
 )

@@ -299,7 +298,6 @@ def _get_force_search_settings(
 ChatPacket = (
     StreamingError
     | QADocsResponse
-    | OnyxContexts
     | LLMRelevanceFilterResponse
     | FinalUsedContextDocsResponse
     | ChatMessageDetail

@@ -918,8 +916,6 @@ def stream_chat_message_objects(
                         response=custom_tool_response.tool_result,
                         tool_name=custom_tool_response.tool_name,
                     )
-                elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
-                    yield cast(OnyxContexts, packet.response)

             elif isinstance(packet, StreamStopInfo):
                 if packet.stop_reason == StreamStopReason.FINISHED:

@@ -1069,6 +1065,8 @@ def stream_chat_message_objects(
         prev_message = next_answer_message

     logger.debug("Committing messages")
+    # Explicitly update the timestamp on the chat session
+    update_chat_session_updated_at_timestamp(chat_session_id, db_session)
     db_session.commit()  # actually save user / assistant message

     yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)
@@ -301,6 +301,10 @@ def prune_sections(


 def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
+    assert (
+        len(set([chunk.document_id for chunk in chunks])) == 1
+    ), "One distinct document must be passed into merge_doc_chunks"
+
     # Assuming there are no duplicates by this point
     sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
@@ -3,7 +3,6 @@ from collections.abc import Sequence

 from pydantic import BaseModel

 from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContext
 from onyx.context.search.models import InferenceChunk

@@ -12,7 +11,7 @@ class DocumentIdOrderMapping(BaseModel):


 def map_document_id_order(
-    chunks: Sequence[InferenceChunk | LlmDoc | OnyxContext], one_indexed: bool = True
+    chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
 ) -> DocumentIdOrderMapping:
     order_mapping = {}
     current = 1 if one_indexed else 0
@@ -382,6 +382,7 @@ ONYX_CLOUD_TENANT_ID = "cloud"
|
||||
|
||||
# the redis namespace for runtime variables
|
||||
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
|
||||
CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT = 600
|
||||
|
||||
|
||||
class OnyxCeleryTask:
|
||||
|
||||
@@ -87,7 +87,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
             credentials.get(key)
             for key in ["aws_access_key_id", "aws_secret_access_key"]
         ):
-            raise ConnectorMissingCredentialError("Google Cloud Storage")
+            raise ConnectorMissingCredentialError("Amazon S3")

         session = boto3.Session(
             aws_access_key_id=credentials["aws_access_key_id"],
@@ -65,20 +65,6 @@ _RESTRICTIONS_EXPANSION_FIELDS = [

 _SLIM_DOC_BATCH_SIZE = 5000

-_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
-    "gif",
-    "mp4",
-    "mov",
-    "mp3",
-    "wav",
-]
-_FULL_EXTENSION_FILTER_STRING = "".join(
-    [
-        f" and title!~'*.{extension}'"
-        for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
-    ]
-)
-
 ONE_HOUR = 3600

@@ -209,7 +195,6 @@ class ConfluenceConnector(
     def _construct_attachment_query(self, confluence_page_id: str) -> str:
         attachment_query = f"type=attachment and container='{confluence_page_id}'"
         attachment_query += self.cql_label_filter
-        attachment_query += _FULL_EXTENSION_FILTER_STRING
         return attachment_query

     def _get_comment_string_for_page_id(self, page_id: str) -> str:

@@ -374,11 +359,13 @@ class ConfluenceConnector(
                 if not validate_attachment_filetype(
                     attachment,
                 ):
+                    logger.info(f"Skipping attachment: {attachment['title']}")
                     continue

+                logger.info(f"Processing attachment: {attachment['title']}")
+
                 # Attempt to get textual content or image summarization:
                 try:
-                    logger.info(f"Processing attachment: {attachment['title']}")
                     response = convert_attachment_to_content(
                         confluence_client=self.confluence_client,
                         attachment=attachment,
@@ -28,8 +28,9 @@ from onyx.connectors.models import TextSection
 from onyx.file_processing.extract_file_text import detect_encoding
 from onyx.file_processing.extract_file_text import extract_file_text
 from onyx.file_processing.extract_file_text import get_file_ext
+from onyx.file_processing.extract_file_text import is_accepted_file_ext
 from onyx.file_processing.extract_file_text import is_text_file_extension
-from onyx.file_processing.extract_file_text import is_valid_file_ext
+from onyx.file_processing.extract_file_text import OnyxExtensionType
 from onyx.file_processing.extract_file_text import read_text_file
 from onyx.utils.logger import setup_logger
 from onyx.utils.retry_wrapper import request_with_retries

@@ -69,7 +70,9 @@ def _process_egnyte_file(

     file_name = file_metadata["name"]
     extension = get_file_ext(file_name)
-    if not is_valid_file_ext(extension):
+    if not is_accepted_file_ext(
+        extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
+    ):
         logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
         return None
@@ -22,8 +22,9 @@ from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.pg_file_store import get_pgfilestore_by_file_name
 from onyx.file_processing.extract_file_text import extract_text_and_images
 from onyx.file_processing.extract_file_text import get_file_ext
-from onyx.file_processing.extract_file_text import is_valid_file_ext
+from onyx.file_processing.extract_file_text import is_accepted_file_ext
 from onyx.file_processing.extract_file_text import load_files_from_zip
+from onyx.file_processing.extract_file_text import OnyxExtensionType
 from onyx.file_processing.image_utils import store_image_and_create_section
 from onyx.file_store.file_store import get_default_file_store
 from onyx.utils.logger import setup_logger

@@ -51,7 +52,7 @@ def _read_files_and_metadata(
             file_content, ignore_dirs=True
         ):
             yield os.path.join(directory_path, file_info.filename), subfile, metadata
-        elif is_valid_file_ext(extension):
+        elif is_accepted_file_ext(extension, OnyxExtensionType.All):
             yield file_name, file_content, metadata
         else:
             logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")

@@ -122,7 +123,7 @@ def _process_file(
         logger.warning(f"No file record found for '{file_name}' in PG; skipping.")
         return []

-    if not is_valid_file_ext(extension):
+    if not is_accepted_file_ext(extension, OnyxExtensionType.All):
         logger.warning(
             f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
         )
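Calls like is_accepted_file_ext(extension, OnyxExtensionType.Plain | OnyxExtensionType.Document) suggest a flag-style enum so callers can combine categories. A hypothetical sketch of how such a check can work; the category names and extension lists below are made up for illustration, not Onyx's actual definitions:

from enum import IntFlag, auto

class ExtensionType(IntFlag):
    Plain = auto()
    Document = auto()
    Multimedia = auto()
    All = Plain | Document | Multimedia

PLAIN_EXTS = {".txt", ".md"}
DOCUMENT_EXTS = {".pdf", ".docx"}
MULTIMEDIA_EXTS = {".png", ".mp4"}

def is_accepted(ext: str, accepted: ExtensionType) -> bool:
    # Bitwise AND tests whether a category is enabled in the accepted mask.
    if accepted & ExtensionType.Plain and ext in PLAIN_EXTS:
        return True
    if accepted & ExtensionType.Document and ext in DOCUMENT_EXTS:
        return True
    if accepted & ExtensionType.Multimedia and ext in MULTIMEDIA_EXTS:
        return True
    return False

print(is_accepted(".pdf", ExtensionType.Plain | ExtensionType.Document))  # True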
@@ -2,9 +2,11 @@ import copy
 import threading
 from collections.abc import Callable
 from collections.abc import Iterator
+from datetime import datetime
 from enum import Enum
 from functools import partial
 from typing import Any
 from typing import cast
 from typing import Protocol
 from urllib.parse import urlparse

@@ -26,7 +28,9 @@ from onyx.connectors.google_drive.doc_conversion import (
 )
 from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
 from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
-from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
+from onyx.connectors.google_drive.file_retrieval import (
+    get_all_files_in_my_drive_and_shared,
+)
 from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
 from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
 from onyx.connectors.google_drive.models import DriveRetrievalStage

@@ -56,7 +60,6 @@ from onyx.connectors.interfaces import SlimConnector
 from onyx.connectors.models import ConnectorFailure
 from onyx.connectors.models import ConnectorMissingCredentialError
 from onyx.connectors.models import Document
-from onyx.connectors.models import DocumentFailure
 from onyx.connectors.models import EntityFailure
 from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.lazy import lazy_eval

@@ -85,13 +88,18 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:

 def _convert_single_file(
     creds: Any,
     primary_admin_email: str,
     allow_images: bool,
     size_threshold: int,
+    retriever_email: str,
     file: dict[str, Any],
 ) -> Document | ConnectorFailure | None:
-    user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
+    # We used to always get the user email from the file owners when available,
+    # but this was causing issues with shared folders where the owner was not included in the service account
+    # now we use the email of the account that successfully listed the file. Leaving this in case we end up
+    # wanting to retry with file owners and/or admin email at some point.
+    # user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
+
+    user_email = retriever_email
     # Only construct these services when needed
     user_drive_service = lazy_eval(
         lambda: get_drive_service(creds, user_email=user_email)

@@ -449,10 +457,11 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
             logger.info(f"Getting all files in my drive as '{user_email}'")

             yield from add_retrieval_info(
-                get_all_files_in_my_drive(
+                get_all_files_in_my_drive_and_shared(
                     service=drive_service,
                     update_traversed_ids_func=self._update_traversed_parent_ids,
                     is_slim=is_slim,
+                    include_shared_with_me=self.include_files_shared_with_me,
                     start=curr_stage.completed_until if resuming else start,
                     end=end,
                 ),

@@ -460,6 +469,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
                 DriveRetrievalStage.MY_DRIVE_FILES,
             )
             curr_stage.stage = DriveRetrievalStage.SHARED_DRIVE_FILES
+            resuming = False  # we are starting the next stage for the first time

         if curr_stage.stage == DriveRetrievalStage.SHARED_DRIVE_FILES:

@@ -495,7 +505,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
                 )
                 yield from _yield_from_drive(drive_id, start)
             curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
-
+            resuming = False  # we are starting the next stage for the first time
         if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES:

             def _yield_from_folder_crawl(

@@ -548,6 +558,16 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
                 checkpoint, is_slim, DriveRetrievalStage.MY_DRIVE_FILES
             )

+        # Setup initial completion map on first connector run
+        for email in all_org_emails:
+            # don't overwrite existing completion map on resuming runs
+            if email in checkpoint.completion_map:
+                continue
+            checkpoint.completion_map[email] = StageCompletion(
+                stage=DriveRetrievalStage.START,
+                completed_until=0,
+            )
+
         # we've found all users and drives, now time to actually start
         # fetching stuff
         logger.info(f"Found {len(all_org_emails)} users to impersonate")

@@ -561,11 +581,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
             drive_ids_to_retrieve, checkpoint
         )

-        for email in all_org_emails:
-            checkpoint.completion_map[email] = StageCompletion(
-                stage=DriveRetrievalStage.START,
-                completed_until=0,
-            )
         user_retrieval_gens = [
             self._impersonate_user_for_retrieval(
                 email,

@@ -796,10 +811,12 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
             return

         for file in drive_files:
-            if file.error is not None:
+            if file.error is None:
                 checkpoint.completion_map[file.user_email].update(
                     stage=file.completion_stage,
-                    completed_until=file.drive_file[GoogleFields.MODIFIED_TIME.value],
+                    completed_until=datetime.fromisoformat(
+                        file.drive_file[GoogleFields.MODIFIED_TIME.value]
+                    ).timestamp(),
                     completed_until_parent_id=file.parent_id,
                 )
             yield file
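The completed_until fix above converts Drive's RFC 3339 modifiedTime string into epoch seconds before storing it in the checkpoint. A small sketch of the conversion; the sample timestamp is illustrative, and note that datetime.fromisoformat only accepts a trailing "Z" directly on Python 3.11+:

from datetime import datetime

modified_time = "2024-05-01T12:34:56.000Z"  # shape of Drive's modifiedTime field
# On older Pythons, normalize the "Z" suffix to an explicit UTC offset first.
epoch_seconds = datetime.fromisoformat(
    modified_time.replace("Z", "+00:00")
).timestamp()

Storing a float epoch keeps the checkpoint's completed_until comparable to the SecondsSinceUnixEpoch start/end bounds used elsewhere in the connector.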
@@ -901,25 +918,45 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
         checkpoint: GoogleDriveCheckpoint,
         start: SecondsSinceUnixEpoch | None = None,
         end: SecondsSinceUnixEpoch | None = None,
-    ) -> Iterator[list[Document | ConnectorFailure]]:
+    ) -> Iterator[Document | ConnectorFailure]:
         try:
             # Prepare a partial function with the credentials and admin email
             convert_func = partial(
                 _convert_single_file,
                 self.creds,
                 self.primary_admin_email,
                 self.allow_images,
                 self.size_threshold,
             )

             # Fetch files in batches
             batches_complete = 0
-            files_batch: list[GoogleDriveFileType] = []
-            func_with_args: list[
-                tuple[
-                    Callable[..., Document | ConnectorFailure | None], tuple[Any, ...]
-                ]
-            ] = []
+            files_batch: list[RetrievedDriveFile] = []
+
+            def _yield_batch(
+                files_batch: list[RetrievedDriveFile],
+            ) -> Iterator[Document | ConnectorFailure]:
+                nonlocal batches_complete
+                # Process the batch using run_functions_tuples_in_parallel
+                func_with_args = [
+                    (
+                        convert_func,
+                        (
+                            file.user_email,
+                            file.drive_file,
+                        ),
+                    )
+                    for file in files_batch
+                ]
+                results = cast(
+                    list[Document | ConnectorFailure | None],
+                    run_functions_tuples_in_parallel(func_with_args, max_workers=8),
+                )
+
+                docs_and_failures = [result for result in results if result is not None]
+
+                if docs_and_failures:
+                    yield from docs_and_failures
+                    batches_complete += 1

             for retrieved_file in self._fetch_drive_items(
                 is_slim=False,
                 checkpoint=checkpoint,
@@ -937,48 +974,21 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
)
|
||||
failure_message += f"error: {retrieved_file.error}"
|
||||
logger.error(failure_message)
|
||||
yield [
|
||||
ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=failure_stage,
|
||||
),
|
||||
failure_message=failure_message,
|
||||
exception=retrieved_file.error,
|
||||
)
|
||||
]
|
||||
yield ConnectorFailure(
|
||||
failed_entity=EntityFailure(
|
||||
entity_id=failure_stage,
|
||||
),
|
||||
failure_message=failure_message,
|
||||
exception=retrieved_file.error,
|
||||
)
|
||||
|
||||
continue
|
||||
files_batch.append(retrieved_file.drive_file)
|
||||
files_batch.append(retrieved_file)
|
||||
|
||||
if len(files_batch) < self.batch_size:
|
||||
continue
|
||||
|
||||
# Process the batch using run_functions_tuples_in_parallel
|
||||
func_with_args = [(convert_func, (file,)) for file in files_batch]
|
||||
results = run_functions_tuples_in_parallel(
|
||||
func_with_args, max_workers=8
|
||||
)
|
||||
|
||||
documents = []
|
||||
for idx, (result, exception) in enumerate(results):
|
||||
if exception:
|
||||
error_str = f"Error converting file: {exception}"
|
||||
logger.error(error_str)
|
||||
yield [
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=files_batch[idx]["id"],
|
||||
document_link=files_batch[idx]["webViewLink"],
|
||||
),
|
||||
failure_message=error_str,
|
||||
exception=exception,
|
||||
)
|
||||
]
|
||||
elif result is not None:
|
||||
documents.append(result)
|
||||
|
||||
if documents:
|
||||
yield documents
|
||||
batches_complete += 1
|
||||
yield from _yield_batch(files_batch)
|
||||
files_batch = []
|
||||
|
||||
if batches_complete > BATCHES_PER_CHECKPOINT:
|
||||
@@ -987,31 +997,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
|
||||
|
||||
# Process any remaining files
|
||||
if files_batch:
|
||||
func_with_args = [(convert_func, (file,)) for file in files_batch]
|
||||
results = run_functions_tuples_in_parallel(
|
||||
func_with_args, max_workers=8
|
||||
)
|
||||
|
||||
documents = []
|
||||
for idx, (result, exception) in enumerate(results):
|
||||
if exception:
|
||||
error_str = f"Error converting file: {exception}"
|
||||
logger.error(error_str)
|
||||
yield [
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=files_batch[idx]["id"],
|
||||
document_link=files_batch[idx]["webViewLink"],
|
||||
),
|
||||
failure_message=error_str,
|
||||
exception=exception,
|
||||
)
|
||||
]
|
||||
elif result is not None:
|
||||
documents.append(result)
|
||||
|
||||
if documents:
|
||||
yield documents
|
||||
yield from _yield_batch(files_batch)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error extracting documents from Google Drive: {e}")
|
||||
raise e
|
||||
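The refactor above folds the duplicated conversion logic into a single `_yield_batch` helper, called both when a batch fills up and once for the remainder. A stand-alone sketch of the same accumulate-and-flush pattern (names and the batch size are illustrative, not the connector's actual API):

```python
from collections.abc import Iterator


def batched_convert(items: Iterator[str], batch_size: int = 4) -> Iterator[str]:
    batch: list[str] = []

    def _flush(batch: list[str]) -> Iterator[str]:
        # stand-in for the parallel conversion step
        yield from (item.upper() for item in batch)

    for item in items:
        batch.append(item)
        if len(batch) < batch_size:
            continue
        yield from _flush(batch)
        batch = []

    # flush the remainder exactly once, mirroring the single trailing call
    if batch:
        yield from _flush(batch)


print(list(batched_convert(iter(["a", "b", "c", "d", "e"]))))
```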
@@ -1033,10 +1019,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint = copy.deepcopy(checkpoint)
self._retrieved_ids = checkpoint.retrieved_folder_and_drive_ids
try:
for doc_list in self._extract_docs_from_google_drive(
checkpoint, start, end
):
yield from doc_list
yield from self._extract_docs_from_google_drive(checkpoint, start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e

@@ -87,35 +87,17 @@ def _download_and_extract_sections_basic(
mime_type = file["mimeType"]
link = file.get("webViewLink", "")

try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []

# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()

response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []

text = response.decode("utf-8")
return [TextSection(link=link, text=text)]

# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
@@ -124,88 +106,100 @@ def _download_and_extract_sections_basic(

response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []

# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]

elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()

elif (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
return []

elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]

elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]

elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]

elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]

elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections

elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]

# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections

elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]

# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections

else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []

except Exception as e:
logger.error(f"Error processing file {file_name}: {e}")
return []
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []


def convert_drive_item_to_document(

@@ -123,7 +123,7 @@ def crawl_folders_for_files(
end=end,
):
found_files = True
logger.info(f"Found file: {file['name']}")
logger.info(f"Found file: {file['name']}, user email: {user_email}")
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,
@@ -214,10 +214,11 @@ def get_files_in_shared_drive(
yield file


def get_all_files_in_my_drive(
def get_all_files_in_my_drive_and_shared(
service: GoogleDriveService,
update_traversed_ids_func: Callable,
is_slim: bool,
include_shared_with_me: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
@@ -229,7 +230,8 @@ def get_all_files_in_my_drive(
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
folder_query += " and 'me' in owners"
if not include_shared_with_me:
folder_query += " and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
@@ -246,7 +248,8 @@ def get_all_files_in_my_drive(
# Then get the files
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += " and 'me' in owners"
if not include_shared_with_me:
file_query += " and 'me' in owners"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
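A hedged sketch of the query strings this change produces; `DRIVE_FOLDER_TYPE` is assumed to be the standard Drive folder MIME type:

```python
DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"  # assumed constant value


def build_file_query(include_shared_with_me: bool) -> str:
    query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and trashed = false"
    if not include_shared_with_me:
        # restrict to files the impersonated user owns
        query += " and 'me' in owners"
    return query


print(build_file_query(False))  # ownership clause present: my-drive files only
print(build_file_query(True))   # clause dropped: shared-with-me files included
```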
@@ -75,7 +75,7 @@ class HighspotClient:

self.key = key
self.secret = secret
self.base_url = base_url
self.base_url = base_url.rstrip("/") + "/"
self.timeout = timeout

# Set up session with retry logic
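The `rstrip` change normalizes the base URL so later path joins behave the same whether or not the caller supplied a trailing slash. A quick illustration with made-up URLs:

```python
for base_url in ("https://api.example.com/v1.0", "https://api.example.com/v1.0/"):
    normalized = base_url.rstrip("/") + "/"
    # both inputs normalize to the same value with exactly one trailing slash
    print(normalized)  # https://api.example.com/v1.0/
```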
@@ -20,8 +20,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import VALID_FILE_EXTENSIONS
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

@@ -298,7 +298,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):

elif (
is_valid_format
and file_extension in VALID_FILE_EXTENSIONS
and file_extension in ALL_ACCEPTED_FILE_EXTENSIONS
and can_download
):
# For documents, try to get the text content

@@ -339,6 +339,12 @@ class SearchPipeline:
self._retrieved_sections = self._get_sections()
return self._retrieved_sections

@property
def merged_retrieved_sections(self) -> list[InferenceSection]:
"""Should be used to display in the UI in order to prevent displaying
multiple sections for the same document as separate "documents"."""
return _merge_sections(sections=self.retrieved_sections)

@property
def reranked_sections(self) -> list[InferenceSection]:
"""Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -415,6 +421,10 @@ class SearchPipeline:
raise ValueError(
"Basic search evaluation operation called while DISABLE_LLM_DOC_RELEVANCE is enabled."
)
# NOTE: final_context_sections must be accessed before accessing self._postprocessing_generator
# since the property sets the generator. DO NOT REMOVE.
_ = self.final_context_sections

self._section_relevance = next(
cast(
Iterator[list[SectionRelevancePiece]],
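The new `merged_retrieved_sections` property collapses multiple sections of one document before they reach the UI. A stand-alone sketch of the idea, with a simplified stand-in for `InferenceSection` (the real merge logic lives in `_merge_sections` and is richer than this):

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Section:  # stand-in for InferenceSection; the real type has more fields
    document_id: str
    text: str


def merge_sections(sections: list[Section]) -> list[Section]:
    grouped: dict[str, list[Section]] = defaultdict(list)
    for section in sections:
        grouped[section.document_id].append(section)
    # one merged entry per document, so the UI shows each document once
    return [
        Section(doc_id, "\n".join(s.text for s in parts))
        for doc_id, parts in grouped.items()
    ]
```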
@@ -1089,3 +1089,20 @@ def log_agent_sub_question_results(
db_session.commit()

return None


def update_chat_session_updated_at_timestamp(
chat_session_id: UUID, db_session: Session
) -> None:
"""
Explicitly update the timestamp on a chat session without modifying other fields.
This is useful when adding messages to a chat session to reflect recent activity.
"""

# Direct SQL update to avoid loading the entire object if it's not already loaded
db_session.execute(
update(ChatSession)
.where(ChatSession.id == chat_session_id)
.values(time_updated=func.now())
)
# No commit - the caller is responsible for committing the transaction
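A hedged usage sketch of the new helper. The caller shown here is hypothetical; what it illustrates is that the helper is meant to run inside a larger unit of work, with the caller owning the commit:

```python
from uuid import UUID


# Hypothetical caller: bump the session's activity timestamp alongside
# other writes, then commit once for the whole unit of work.
def add_message_and_touch(db_session, chat_session_id: UUID, message) -> None:
    db_session.add(message)
    update_chat_session_updated_at_timestamp(chat_session_id, db_session)
    db_session.commit()  # the helper deliberately leaves committing to us
```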
@@ -555,28 +555,6 @@ def delete_documents_by_connector_credential_pair__no_commit(
db_session.execute(stmt)


def delete_all_documents_by_connector_credential_pair__no_commit(
db_session: Session,
connector_id: int,
credential_id: int,
) -> None:
"""Deletes all document by connector credential pair entries for a specific connector and credential.
This is primarily used during connector deletion to ensure all references are removed
before deleting the connector itself. This is crucial because connector_id is part of the
primary key in DocumentByConnectorCredentialPair, and attempting to delete the Connector
would otherwise try to set the foreign key to NULL, which fails for primary keys.

NOTE: Does not commit the transaction, this must be done by the caller.
"""
stmt = delete(DocumentByConnectorCredentialPair).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
db_session.execute(stmt)


def delete_documents__no_commit(db_session: Session, document_ids: list[str]) -> None:
db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))


@@ -694,7 +694,11 @@ class Connector(Base):
)
documents_by_connector: Mapped[
list["DocumentByConnectorCredentialPair"]
] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")
] = relationship(
"DocumentByConnectorCredentialPair",
back_populates="connector",
cascade="all, delete-orphan",
)

# synchronize this validation logic with RefreshFrequencySchema etc on front end
# until we have a centralized validation schema
@@ -748,7 +752,11 @@ class Credential(Base):
)
documents_by_credential: Mapped[
list["DocumentByConnectorCredentialPair"]
] = relationship("DocumentByConnectorCredentialPair", back_populates="credential")
] = relationship(
"DocumentByConnectorCredentialPair",
back_populates="credential",
cascade="all, delete-orphan",
)

user: Mapped[User | None] = relationship("User", back_populates="credentials")
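With `cascade="all, delete-orphan"` on these relationships, deleting a `Connector` or `Credential` through the ORM removes its `DocumentByConnectorCredentialPair` rows rather than attempting to NULL a primary-key foreign key, which is the failure mode the docstring above describes. A hedged sketch of the effect (session and model objects assumed to exist):

```python
# Hypothetical deletion path relying on the new cascade (names assumed):
def delete_connector(db_session, connector) -> None:
    # With cascade="all, delete-orphan", SQLAlchemy deletes the child
    # DocumentByConnectorCredentialPair rows before the Connector itself,
    # instead of trying to NULL their primary-key foreign keys.
    db_session.delete(connector)
    db_session.commit()
```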
@@ -24,7 +24,9 @@ from onyx.db.models import User__UserGroup
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop


def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
def validate_user_role_update(
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
) -> None:
"""
Validate that a user role update is valid.
Assumed only admins can hit this endpoint.
@@ -57,6 +59,9 @@ def validate_user_role_update(requested_role: UserRole, current_role: UserRole)
detail="To change a Limited User's role, they must first login to Onyx via the web app.",
)

if explicit_override:
return

if requested_role == UserRole.CURATOR:
# This shouldn't happen, but just in case
raise HTTPException(
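The new flag acts as an escape hatch: `explicit_override=True` returns early after the limited-user check, skipping the role-specific validation that follows. A hedged sketch of the two call shapes (the role values here are only examples):

```python
# Default path: all validation rules apply.
validate_user_role_update(
    requested_role=UserRole.BASIC,
    current_role=UserRole.ADMIN,
)

# Trusted internal flow: skip the role-specific checks once the shared
# preconditions above the early return have passed.
validate_user_role_update(
    requested_role=UserRole.BASIC,
    current_role=UserRole.ADMIN,
    explicit_override=True,
)
```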
@@ -7,6 +7,8 @@ from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from email.parser import Parser as EmailParser
from enum import auto
from enum import IntFlag
from io import BytesIO
from pathlib import Path
from typing import Any
@@ -35,7 +37,7 @@ logger = setup_logger()

TEXT_SECTION_SEPARATOR = "\n\n"

PLAIN_TEXT_FILE_EXTENSIONS = [
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
".txt",
".md",
".mdx",
@@ -49,7 +51,7 @@ PLAIN_TEXT_FILE_EXTENSIONS = [
".yaml",
]

VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
".pdf",
".docx",
".pptx",
@@ -57,12 +59,21 @@ VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
".eml",
".epub",
".html",
]

ACCEPTED_IMAGE_FILE_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".webp",
]

ALL_ACCEPTED_FILE_EXTENSIONS = (
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
+ ACCEPTED_DOCUMENT_FILE_EXTENSIONS
+ ACCEPTED_IMAGE_FILE_EXTENSIONS
)

IMAGE_MEDIA_TYPES = [
"image/png",
"image/jpeg",
@@ -70,8 +81,15 @@ IMAGE_MEDIA_TYPES = [
]


class OnyxExtensionType(IntFlag):
Plain = auto()
Document = auto()
Multimedia = auto()
All = Plain | Document | Multimedia


def is_text_file_extension(file_name: str) -> bool:
return any(file_name.endswith(ext) for ext in PLAIN_TEXT_FILE_EXTENSIONS)
return any(file_name.endswith(ext) for ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS)


def get_file_ext(file_path_or_name: str | Path) -> str:
@@ -83,8 +101,20 @@ def is_valid_media_type(media_type: str) -> bool:
return media_type in IMAGE_MEDIA_TYPES


def is_valid_file_ext(ext: str) -> bool:
return ext in VALID_FILE_EXTENSIONS
def is_accepted_file_ext(ext: str, ext_type: OnyxExtensionType) -> bool:
if ext_type & OnyxExtensionType.Plain:
if ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
return True

if ext_type & OnyxExtensionType.Document:
if ext in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
return True

if ext_type & OnyxExtensionType.Multimedia:
if ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
return True

return False


def is_text_file(file: IO[bytes]) -> bool:
@@ -382,6 +412,9 @@ def extract_file_text(
"""
Legacy function that returns *only text*, ignoring embedded images.
For backward-compatibility in code that only wants text.

NOTE: Ignoring seems to be defined as returning an empty string for files it can't
handle (such as images).
"""
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
@@ -405,7 +438,9 @@ def extract_file_text(
if extension is None:
extension = get_file_ext(file_name)

if is_valid_file_ext(extension):
if is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
func = extension_to_function.get(extension, file_io_to_text)
file.seek(0)
return func(file)
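Because `OnyxExtensionType` is an `IntFlag`, callers can combine categories with `|`. A small sketch of the checks this enables, assuming the definitions above are imported:

```python
# Plain-text extension matches the Plain category
assert is_accepted_file_ext(".md", OnyxExtensionType.Plain)

# Combined mask, as used in extract_file_text above
assert is_accepted_file_ext(
    ".pdf", OnyxExtensionType.Plain | OnyxExtensionType.Document
)

# Images only match when Multimedia is included in the mask
assert is_accepted_file_ext(".png", OnyxExtensionType.All)
assert not is_accepted_file_ext(
    ".png", OnyxExtensionType.Plain | OnyxExtensionType.Document
)
```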
@@ -15,6 +15,7 @@ EXCLUDED_IMAGE_TYPES = [
"image/tiff",
"image/gif",
"image/svg+xml",
"image/avif",
]


@@ -15,7 +15,6 @@ from onyx.configs.constants import MessageType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import get_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
@@ -553,8 +552,7 @@ def handle_followup_resolved_button(

# Delete the message with the option to mark resolved
if not immediate:
slack_call = make_slack_api_rate_limited(client.web_client.chat_delete)
response = slack_call(
response = client.web_client.chat_delete(
channel=channel_id,
ts=message_ts,
)

@@ -18,6 +18,9 @@ from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.http_retry import ConnectionErrorRetryHandler
from slack_sdk.http_retry import RateLimitErrorRetryHandler
from slack_sdk.http_retry import RetryHandler
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
@@ -944,10 +947,21 @@ def _get_socket_client(
) -> TenantSocketModeClient:
# For more info on how to set this up, checkout the docs:
# https://docs.onyx.app/slack_bot_setup

# use the retry handlers built into the slack sdk
connection_error_retry_handler = ConnectionErrorRetryHandler()
rate_limit_error_retry_handler = RateLimitErrorRetryHandler(max_retry_count=7)
slack_retry_handlers: list[RetryHandler] = [
connection_error_retry_handler,
rate_limit_error_retry_handler,
]

return TenantSocketModeClient(
# This app-level token will be used only for establishing a connection
app_token=slack_bot_tokens.app_token,
web_client=WebClient(token=slack_bot_tokens.bot_token),
web_client=WebClient(
token=slack_bot_tokens.bot_token, retry_handlers=slack_retry_handlers
),
tenant_id=tenant_id,
slack_bot_id=slack_bot_id,
)
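These changes replace the hand-rolled `make_slack_api_rate_limited` wrapper with the retry handlers that ship in `slack_sdk`. A minimal stand-alone sketch of that pattern (the token is a placeholder; the handler classes and `retry_handlers` parameter are real `slack_sdk` APIs):

```python
from slack_sdk import WebClient
from slack_sdk.http_retry import (
    ConnectionErrorRetryHandler,
    RateLimitErrorRetryHandler,
)

client = WebClient(
    token="xoxb-placeholder-token",
    retry_handlers=[
        ConnectionErrorRetryHandler(),
        RateLimitErrorRetryHandler(max_retry_count=7),
    ],
)
# Calls such as client.chat_postMessage(...) now retry automatically on
# connection errors and 429 rate-limit responses, honoring Retry-After.
```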
@@ -30,7 +30,6 @@ from onyx.configs.onyxbot_configs import (
from onyx.configs.onyxbot_configs import (
DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS,
)
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.users import get_user_by_email
@@ -125,13 +124,18 @@ def update_emote_react(
)
return

func = client.reactions_remove if remove else client.reactions_add
slack_call = make_slack_api_rate_limited(func) # type: ignore
slack_call(
name=emoji,
channel=channel,
timestamp=message_ts,
)
if remove:
client.reactions_remove(
name=emoji,
channel=channel,
timestamp=message_ts,
)
else:
client.reactions_add(
name=emoji,
channel=channel,
timestamp=message_ts,
)
except SlackApiError as e:
if remove:
logger.error(f"Failed to remove Reaction due to: {e}")
@@ -200,9 +204,8 @@ def respond_in_thread_or_channel(

message_ids: list[str] = []
if not receiver_ids:
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
try:
response = slack_call(
response = client.chat_postMessage(
channel=channel,
text=text,
blocks=blocks,
@@ -224,7 +227,7 @@ def respond_in_thread_or_channel(
blocks_without_urls.append(_build_error_block(str(e)))

# Try again without blocks containing url
response = slack_call(
response = client.chat_postMessage(
channel=channel,
text=text,
blocks=blocks_without_urls,
@@ -236,11 +239,9 @@ def respond_in_thread_or_channel(

message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)

for receiver in receiver_ids:
try:
response = slack_call(
response = client.chat_postEphemeral(
channel=channel,
user=receiver,
text=text,
@@ -263,7 +264,7 @@ def respond_in_thread_or_channel(
blocks_without_urls.append(_build_error_block(str(e)))

# Try again without blocks containing url
response = slack_call(
response = client.chat_postEphemeral(
channel=channel,
user=receiver,
text=text,
@@ -500,7 +501,7 @@ def fetch_user_semantic_id_from_id(
if not user_id:
return None

response = make_slack_api_rate_limited(client.users_info)(user=user_id)
response = client.users_info(user=user_id)
if not response["ok"]:
return None


@@ -132,6 +132,7 @@ class UserByEmail(BaseModel):
class UserRoleUpdateRequest(BaseModel):
user_email: str
new_role: UserRole
explicit_override: bool = False


class UserRoleResponse(BaseModel):

@@ -102,6 +102,7 @@ def set_user_role(
validate_user_role_update(
requested_role=requested_role,
current_role=current_role,
explicit_override=user_role_update_request.explicit_override,
)

if user_to_update.id == current_user.id:
@@ -122,6 +123,22 @@ def set_user_role(
db_session.commit()


class TestUpsertRequest(BaseModel):
email: str


@router.post("/manage/users/test-upsert-user")
async def test_upsert_user(
request: TestUpsertRequest,
_: User = Depends(current_admin_user),
) -> None | FullUserSnapshot:
"""Test endpoint for upsert_saml_user. Only used for integration testing."""
user = await fetch_ee_implementation_or_noop(
"onyx.server.saml", "upsert_saml_user", None
)(email=request.email)
return FullUserSnapshot.from_user_model(user) if user else None


@router.get("/manage/users/accepted")
def list_accepted_users(
q: str | None = Query(default=None),
@@ -296,7 +313,7 @@ def bulk_invite_users(
detail=f"Invalid email address: {email} - {str(e)}",
)

if MULTI_TENANT and not DEV_MODE:
if MULTI_TENANT:
try:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning", "add_users_to_tenant", None
@@ -318,7 +335,7 @@ def bulk_invite_users(
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")

if not MULTI_TENANT:
if not MULTI_TENANT or DEV_MODE:
return number_of_invited_users

# for billing purposes, write to the control plane about the number of new users
@@ -359,7 +376,7 @@ def remove_invited_user(
number_of_invited_users = write_invited_users(remaining_users)

try:
if MULTI_TENANT:
if MULTI_TENANT and not DEV_MODE:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_total_users_count(db_session))

@@ -1,10 +1,19 @@
import io
from typing import cast

from PIL import Image

from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT,
)
from onyx.configs.constants import CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM
from onyx.db.engine import get_session_with_shared_schema
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.file import FileWithMimeType
from onyx.utils.file import OnyxStaticFileManager
from onyx.utils.variable_functionality import (
@@ -87,3 +96,72 @@ class OnyxRuntime:
)

return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME)

@staticmethod
def get_beat_multiplier() -> float:
"""the beat multiplier is used to scale up or down the frequency of certain beat
tasks in the cloud. It has a significant effect on load and is useful to adjust
in real time."""

beat_multiplier: float = CLOUD_BEAT_MULTIPLIER_DEFAULT

r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)

beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
pass

if beat_multiplier <= 0.0:
return 1.0

return beat_multiplier

@staticmethod
def get_doc_permission_sync_multiplier() -> float:
"""Permission syncs are a significant source of load / queueing in the cloud."""

value: float = CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT

r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)

value_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:doc_permission_sync_multiplier")
if value_raw is not None:
try:
value_bytes = cast(bytes, value_raw)
value = float(value_bytes.decode())
except ValueError:
pass

if value <= 0.0:
return 1.0

return value

@staticmethod
def get_build_fence_lookup_table_interval() -> int:
"""We maintain an active fence table to make lookups of existing fences efficient.
However, reconstructing the table is expensive, so adjusting it in realtime is useful.
"""

interval: int = CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT

r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)

interval_raw = r.get(
f"{ONYX_CLOUD_REDIS_RUNTIME}:build_fence_lookup_table_interval"
)
if interval_raw is not None:
try:
interval_bytes = cast(bytes, interval_raw)
interval = int(interval_bytes.decode())
except ValueError:
pass

if interval <= 0.0:
return CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT

return interval
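These getters turn Redis keys into live tuning knobs: each call reads the key from a replica, falls back to the compiled-in default on parse failure, and guards against non-positive values. A hedged sketch of how an operator might set one of them (the connection details and the value of `ONYX_CLOUD_REDIS_RUNTIME` are assumptions, not taken from the code):

```python
import redis

r = redis.Redis(host="localhost", port=6379)  # placeholder connection
ONYX_CLOUD_REDIS_RUNTIME = "runtime"  # assumed value of the real constant

# Double the beat-frequency scaling; get_beat_multiplier() picks this up on
# its next call and falls back to 1.0 if the stored value is non-positive.
r.set(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier", "2.0")
```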
@@ -12,7 +12,6 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
@@ -42,9 +41,6 @@ from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
from onyx.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,
@@ -58,7 +54,6 @@ from onyx.utils.special_types import JSON_ro
logger = setup_logger()

SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"
@@ -357,13 +352,13 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
)
yield from yield_search_responses(
query,
lambda: search_pipeline.retrieved_sections,
lambda: search_pipeline.reranked_sections,
lambda: search_pipeline.final_context_sections,
search_query_info,
lambda: search_pipeline.section_relevance,
self,
query=query,
# give back the merged sections to prevent duplicate docs from appearing in the UI
get_retrieved_sections=lambda: search_pipeline.merged_retrieved_sections,
get_final_context_sections=lambda: search_pipeline.final_context_sections,
search_query_info=search_query_info,
get_section_relevance=lambda: search_pipeline.section_relevance,
search_tool=self,
)

def final_result(self, *args: ToolResponse) -> JSON_ro:
@@ -405,7 +400,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
def yield_search_responses(
query: str,
get_retrieved_sections: Callable[[], list[InferenceSection]],
get_reranked_sections: Callable[[], list[InferenceSection]],
get_final_context_sections: Callable[[], list[InferenceSection]],
search_query_info: SearchQueryInfo,
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
@@ -423,16 +417,6 @@ def yield_search_responses(
),
)

yield ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
context_from_inference_section(section)
for section in get_reranked_sections()
]
),
)

section_relevance = get_section_relevance()
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,

@@ -1,5 +1,4 @@
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceSection
from onyx.prompts.prompt_utils import clean_up_source

@@ -32,10 +31,23 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
return doc_dict


def context_from_inference_section(section: InferenceSection) -> OnyxContext:
return OnyxContext(
content=section.combined_content,
def section_to_llm_doc(section: InferenceSection) -> LlmDoc:
possible_link_chunks = [section.center_chunk] + section.chunks
link: str | None = None
for chunk in possible_link_chunks:
if chunk.source_links:
link = list(chunk.source_links.values())[0]
break

return LlmDoc(
document_id=section.center_chunk.document_id,
content=section.combined_content,
source_type=section.center_chunk.source_type,
semantic_identifier=section.center_chunk.semantic_identifier,
metadata=section.center_chunk.metadata,
updated_at=section.center_chunk.updated_at,
blurb=section.center_chunk.blurb,
link=link,
source_links=section.center_chunk.source_links,
match_highlights=section.center_chunk.match_highlights,
)
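The link-selection rule in `section_to_llm_doc` prefers the center chunk's first source link and falls back to the remaining chunks. A tiny stand-alone illustration with made-up data:

```python
# Each entry stands in for a chunk's source_links mapping (or None).
center = None
others = [None, {0: "https://example.com/doc#section-2"}]

link = None
for source_links in [center] + others:
    if source_links:
        # first available link wins, checked center-chunk first
        link = list(source_links.values())[0]
        break

print(link)  # https://example.com/doc#section-2
```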
@@ -78,19 +78,19 @@ def generate_dummy_chunk(
for i in range(number_of_document_sets):
document_set_names.append(f"Document Set {i}")

user_emails: set[str | None] = set()
user_groups: set[str] = set()
external_user_emails: set[str] = set()
external_user_group_ids: set[str] = set()
user_emails: list[str | None] = []
user_groups: list[str] = []
external_user_emails: list[str] = []
external_user_group_ids: list[str] = []
for i in range(number_of_acl_entries):
user_emails.add(f"user_{i}@example.com")
user_groups.add(f"group_{i}")
external_user_emails.add(f"external_user_{i}@example.com")
external_user_group_ids.add(f"external_group_{i}")
user_emails.append(f"user_{i}@example.com")
user_groups.append(f"group_{i}")
external_user_emails.append(f"external_user_{i}@example.com")
external_user_group_ids.append(f"external_group_{i}")

return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=DocumentAccess(
access=DocumentAccess.build(
user_emails=user_emails,
user_groups=user_groups,
external_user_emails=external_user_emails,
backend/tests/daily/connectors/blob/test_blob_connector.py (new file, 77 lines)
@@ -0,0 +1,77 @@
import os
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

from onyx.configs.constants import BlobType
from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import get_file_ext


@pytest.fixture
def blob_connector(request: pytest.FixtureRequest) -> BlobStorageConnector:
connector = BlobStorageConnector(
bucket_type=BlobType.S3, bucket_name="onyx-connector-tests"
)

connector.load_credentials(
{
"aws_access_key_id": os.environ["AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS"],
"aws_secret_access_key": os.environ[
"AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS"
],
}
)

return connector


@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_blob_s3_connector(
mock_get_api_key: MagicMock, blob_connector: BlobStorageConnector
) -> None:
"""
Plain and document file types should be fully indexed.

Multimedia and unknown file types will be indexed by title only with one empty section.

This is intentional in order to allow searching by just the title even if we can't
index the file content.
"""
all_docs: list[Document] = []
document_batches = blob_connector.load_from_state()
for doc_batch in document_batches:
for doc in doc_batch:
all_docs.append(doc)

#
assert len(all_docs) == 19

for doc in all_docs:
section = doc.sections[0]
assert isinstance(section, TextSection)

file_extension = get_file_ext(doc.semantic_identifier)
if file_extension in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
assert len(section.text) > 0
continue

if file_extension in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
assert len(section.text) > 0
continue

if file_extension in ACCEPTED_IMAGE_FILE_EXTENSIONS:
assert len(section.text) == 0
continue

# unknown extension
assert len(section.text) == 0
@@ -58,6 +58,16 @@ SECTIONS_FOLDER_URL = (
"https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
)

EXTERNAL_SHARED_FOLDER_URL = (
"https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS"
)
EXTERNAL_SHARED_DOCS_IN_FOLDER = [
"https://docs.google.com/document/d/1Sywmv1-H6ENk2GcgieKou3kQHR_0te1mhIUcq8XlcdY"
]
EXTERNAL_SHARED_DOC_SINGLETON = (
"https://docs.google.com/document/d/11kmisDfdvNcw5LYZbkdPVjTOdj-Uc5ma6Jep68xzeeA"
)

SHARED_DRIVE_3_URL = "https://drive.google.com/drive/folders/0AJYm2K_I_vtNUk9PVA"

ADMIN_EMAIL = "admin@onyx-test.com"
@@ -1,6 +1,7 @@
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from urllib.parse import urlparse

from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
@@ -9,6 +10,15 @@ from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOC_SINGLETON,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOCS_IN_FOLDER,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_FOLDER_URL,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
@@ -100,7 +110,8 @@ def test_include_shared_drives_only_with_size_threshold(

retrieved_docs = load_all_docs(connector)

assert len(retrieved_docs) == 50
# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 52


@patch(
@@ -137,7 +148,8 @@ def test_include_shared_drives_only(
+ SECTIONS_FILE_IDS
)

assert len(retrieved_docs) == 51
# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 53

assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
@@ -294,6 +306,64 @@ def test_folders_only(
)


def test_shared_folder_owned_by_external_user(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_folder_owned_by_external_user")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=False,
include_files_shared_with_me=False,
shared_drive_urls=None,
shared_folder_urls=EXTERNAL_SHARED_FOLDER_URL,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)

expected_docs = EXTERNAL_SHARED_DOCS_IN_FOLDER

assert len(retrieved_docs) == len(expected_docs)  # 1 for now
assert expected_docs[0] in retrieved_docs[0].id


def test_shared_with_me(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_with_me")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=True,
include_files_shared_with_me=True,
shared_drive_urls=None,
shared_folder_urls=None,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)

print(retrieved_docs)

expected_file_ids = (
ADMIN_FILE_IDS
+ ADMIN_FOLDER_3_FILE_IDS
+ TEST_USER_1_FILE_IDS
+ TEST_USER_2_FILE_IDS
+ TEST_USER_3_FILE_IDS
)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)

retrieved_ids = {urlparse(doc.id).path.split("/")[-2] for doc in retrieved_docs}
for id in retrieved_ids:
print(id)

assert EXTERNAL_SHARED_DOC_SINGLETON.split("/")[-1] in retrieved_ids
assert EXTERNAL_SHARED_DOCS_IN_FOLDER[0].split("/")[-1] in retrieved_ids


@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
@@ -6,7 +6,7 @@ API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 45
MAX_DELAY = 60

GENERAL_HEADERS = {"Content-Type": "application/json"}


@@ -5,6 +5,7 @@ import requests
from requests.models import Response

from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
@@ -97,17 +98,24 @@ class ChatSessionManager:
for data in response_data:
if "rephrased_query" in data:
analyzed.rephrased_query = data["rephrased_query"]
elif "tool_name" in data:
if "tool_name" in data:
analyzed.tool_name = data["tool_name"]
analyzed.tool_result = (
data.get("tool_result")
if analyzed.tool_name == "run_search"
else None
)
elif "relevance_summaries" in data:
if "relevance_summaries" in data:
analyzed.relevance_summaries = data["relevance_summaries"]
elif "answer_piece" in data and data["answer_piece"]:
if "answer_piece" in data and data["answer_piece"]:
analyzed.full_message += data["answer_piece"]
if "top_documents" in data:
assert (
analyzed.top_documents is None
), "top_documents should only be set once"
analyzed.top_documents = [
SavedSearchDoc(**doc) for doc in data["top_documents"]
]

return analyzed


@@ -9,7 +9,9 @@ from requests import HTTPError
from onyx.auth.schemas import UserRole
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import UserInfo
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -123,10 +125,15 @@ class UserManager:
user_to_set: DATestUser,
target_role: UserRole,
user_performing_action: DATestUser,
explicit_override: bool = False,
) -> DATestUser:
response = requests.patch(
url=f"{API_SERVER_URL}/manage/set-user-role",
json={"user_email": user_to_set.email, "new_role": target_role.value},
json={
"user_email": user_to_set.email,
"new_role": target_role.value,
"explicit_override": explicit_override,
},
headers=user_performing_action.headers,
)
response.raise_for_status()
@@ -240,3 +247,69 @@ class UserManager:
total_items=data["total_items"],
)
return paginated_result

@staticmethod
def invite_user(
user_to_invite_email: str, user_performing_action: DATestUser
) -> None:
"""Invite a user by email to join the organization.

Args:
user_to_invite_email: Email of the user to invite
user_performing_action: User with admin permissions performing the invitation
"""
response = requests.put(
url=f"{API_SERVER_URL}/manage/admin/users",
headers=user_performing_action.headers,
json={"emails": [user_to_invite_email]},
)
response.raise_for_status()

@staticmethod
def accept_invitation(tenant_id: str, user_performing_action: DATestUser) -> None:
"""Accept an invitation to join the organization.

Args:
tenant_id: ID of the tenant/organization to accept invitation for
user_performing_action: User accepting the invitation
"""
response = requests.post(
url=f"{API_SERVER_URL}/tenants/users/invite/accept",
headers=user_performing_action.headers,
json={"tenant_id": tenant_id},
)
response.raise_for_status()

@staticmethod
def get_invited_users(
user_performing_action: DATestUser,
) -> list[InvitedUserSnapshot]:
"""Get a list of all invited users.

Args:
user_performing_action: User with admin permissions performing the action

Returns:
List of invited user snapshots
"""
response = requests.get(
url=f"{API_SERVER_URL}/manage/users/invited",
headers=user_performing_action.headers,
)
response.raise_for_status()

return [InvitedUserSnapshot(**user) for user in response.json()]

@staticmethod
def get_user_info(user_performing_action: DATestUser) -> UserInfo:
"""Get user info for the current user.

Args:
user_performing_action: User performing the action
"""
response = requests.get(
url=f"{API_SERVER_URL}/me",
headers=user_performing_action.headers,
)
response.raise_for_status()
return UserInfo(**response.json())
@@ -10,6 +10,7 @@ from pydantic import Field
from onyx.auth.schemas import UserRole
from onyx.configs.constants import QAFeedbackType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.models import SavedSearchDoc
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
from onyx.server.documents.models import IndexAttemptSnapshot
@@ -157,7 +158,7 @@ class StreamedResponse(BaseModel):
full_message: str = ""
rephrased_query: str | None = None
tool_name: str | None = None
top_documents: list[dict[str, Any]] | None = None
top_documents: list[SavedSearchDoc] | None = None
relevance_summaries: list[dict[str, Any]] | None = None
tool_result: Any | None = None
user: str | None = None
@@ -0,0 +1,70 @@
from onyx.db.models import UserRole
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser

INVITED_BASIC_USER = "basic_user"
INVITED_BASIC_USER_EMAIL = "basic_user@test.com"


def test_user_invitation_flow(reset_multitenant: None) -> None:
# Create first user (admin)
admin_user: DATestUser = UserManager.create(name="admin")
assert UserManager.is_role(admin_user, UserRole.ADMIN)

# Create second user
invited_user: DATestUser = UserManager.create(name="admin_invited")
assert UserManager.is_role(invited_user, UserRole.ADMIN)

# Admin user invites the previously registered and non-registered user
UserManager.invite_user(invited_user.email, admin_user)
UserManager.invite_user(INVITED_BASIC_USER_EMAIL, admin_user)

invited_basic_user: DATestUser = UserManager.create(
name=INVITED_BASIC_USER, email=INVITED_BASIC_USER_EMAIL
)
assert UserManager.is_role(invited_basic_user, UserRole.BASIC)

# Verify the user is in the invited users list
invited_users = UserManager.get_invited_users(admin_user)
assert invited_user.email in [
user.email for user in invited_users
], f"User {invited_user.email} not found in invited users list"

# Get user info to check tenant information
user_info = UserManager.get_user_info(invited_user)

# Extract the tenant_id from the invitation
invited_tenant_id = (
user_info.tenant_info.invitation.tenant_id
if user_info.tenant_info and user_info.tenant_info.invitation
else None
)
assert invited_tenant_id is not None, "Expected to find an invitation tenant_id"

UserManager.accept_invitation(invited_tenant_id, invited_user)

# Get updated user info after accepting invitation
updated_user_info = UserManager.get_user_info(invited_user)

# Verify the user is no longer in the invited users list
updated_invited_users = UserManager.get_invited_users(admin_user)
assert invited_user.email not in [
user.email for user in updated_invited_users
], f"User {invited_user.email} should not be in invited users list after accepting"

# Verify the user has BASIC role in the organization
assert (
updated_user_info.role == UserRole.BASIC
), f"Expected user to have BASIC role, but got {updated_user_info.role}"

# Verify user is in the organization
user_page = UserManager.get_user_page(
user_performing_action=admin_user, role_filter=[UserRole.BASIC]
)

# Check if the invited user is in the list of users with BASIC role
invited_user_emails = [user.email for user in user_page.items]
assert invited_user.email in invited_user_emails, (
f"User {invited_user.email} not found in the list of basic users "
f"in the organization. Available users: {invited_user_emails}"
)
@@ -0,0 +1,97 @@
import os

import pytest
import requests

from onyx.auth.schemas import UserRole
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser


@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="SAML tests are enterprise only",
)
def test_saml_user_conversion(reset: None) -> None:
    """
    Test that SAML login correctly converts users with non-authenticated roles
    (SLACK_USER or EXT_PERM_USER) to authenticated roles (BASIC).

    This test:
    1. Creates an admin and a regular user
    2. Changes the regular user's role to EXT_PERM_USER
    3. Simulates a SAML login by calling the test endpoint
    4. Verifies the user's role is converted to BASIC

    This tests the fix that ensures users with non-authenticated roles (SLACK_USER or
    EXT_PERM_USER) are properly converted to authenticated roles during SAML login.
    """
    # Create an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(email="admin@onyx-test.com")

    # Create a regular user that we'll convert to EXT_PERM_USER
    test_user_email = "ext_perm_user@example.com"
    test_user = UserManager.create(email=test_user_email)

    # Verify the user was created with BASIC role initially
    assert UserManager.is_role(test_user, UserRole.BASIC)

    # Change the user's role to EXT_PERM_USER using the UserManager
    UserManager.set_role(
        user_to_set=test_user,
        target_role=UserRole.EXT_PERM_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )

    # Verify the user has EXT_PERM_USER role now
    assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)

    # Simulate SAML login by calling the test endpoint
    response = requests.post(
        f"{API_SERVER_URL}/manage/users/test-upsert-user",
        json={"email": test_user_email},
        headers=admin_user.headers,  # Use admin headers for authorization
    )
    response.raise_for_status()

    # Verify the response indicates the role changed to BASIC
    user_data = response.json()
    assert user_data["role"] == UserRole.BASIC.value

    # Verify the user's role was changed in the database
    assert UserManager.is_role(test_user, UserRole.BASIC)

    # Do the same test with SLACK_USER
    slack_user_email = "slack_user@example.com"
    slack_user = UserManager.create(email=slack_user_email)

    # Verify the user was created with BASIC role initially
    assert UserManager.is_role(slack_user, UserRole.BASIC)

    # Change the user's role to SLACK_USER
    UserManager.set_role(
        user_to_set=slack_user,
        target_role=UserRole.SLACK_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )

    # Verify the user has SLACK_USER role
    assert UserManager.is_role(slack_user, UserRole.SLACK_USER)

    # Simulate SAML login again
    response = requests.post(
        f"{API_SERVER_URL}/manage/users/test-upsert-user",
        json={"email": slack_user_email},
        headers=admin_user.headers,
    )
    response.raise_for_status()

    # Verify the response indicates the role changed to BASIC
    user_data = response.json()
    assert user_data["role"] == UserRole.BASIC.value

    # Verify the user's role was changed in the database
    assert UserManager.is_role(slack_user, UserRole.BASIC)
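The same environment check recurs across the enterprise-only test modules that follow. One way to avoid repeating it would be a shared marker defined once in a conftest; this is a sketch only, and `requires_ee` is a hypothetical name that this change does not introduce:

    # Hypothetical shared marker (not part of this change): define once, reuse everywhere.
    import os

    import pytest

    requires_ee = pytest.mark.skipif(
        os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
        reason="enterprise only",
    )


    @requires_ee
    def test_example_enterprise_feature() -> None:
        ...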
@@ -5,6 +5,7 @@ This file contains tests for the following:
    - updates the document sets and user groups to remove the connector
    - Ensure that deleting a connector that is part of an overlapping document set and/or user group works as expected
"""
+import os
from uuid import uuid4

from sqlalchemy.orm import Session
@@ -32,6 +33,13 @@ from tests.integration.common_utils.vespa import vespa_fixture


def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
+    user_group_1: DATestUserGroup
+    user_group_2: DATestUserGroup
+
+    is_ee = (
+        os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
+    )
+
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
    # create api key
@@ -78,16 +86,17 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:

    print("Document sets created and synced")

-    # create user groups
-    user_group_1: DATestUserGroup = UserGroupManager.create(
-        cc_pair_ids=[cc_pair_1.id],
-        user_performing_action=admin_user,
-    )
-    user_group_2: DATestUserGroup = UserGroupManager.create(
-        cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
-        user_performing_action=admin_user,
-    )
-    UserGroupManager.wait_for_sync(user_performing_action=admin_user)
+    if is_ee:
+        # create user groups
+        user_group_1 = UserGroupManager.create(
+            cc_pair_ids=[cc_pair_1.id],
+            user_performing_action=admin_user,
+        )
+        user_group_2 = UserGroupManager.create(
+            cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
+            user_performing_action=admin_user,
+        )
+        UserGroupManager.wait_for_sync(user_performing_action=admin_user)

    # inject a finished index attempt and index attempt error (exercises foreign key errors)
    with Session(get_sqlalchemy_engine()) as db_session:
@@ -147,12 +156,13 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
    )

    # Update local records to match the database for later comparison
-    user_group_1.cc_pair_ids = []
-    user_group_2.cc_pair_ids = [cc_pair_2.id]
    doc_set_1.cc_pair_ids = []
    doc_set_2.cc_pair_ids = [cc_pair_2.id]
    cc_pair_1.groups = []
-    cc_pair_2.groups = [user_group_2.id]
+    if is_ee:
+        cc_pair_2.groups = [user_group_2.id]
+    else:
+        cc_pair_2.groups = []

    CCPairManager.wait_for_deletion_completion(
        cc_pair_id=cc_pair_1.id, user_performing_action=admin_user
@@ -168,11 +178,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
        verify_deleted=True,
    )

+    cc_pair_2_group_name_expected = []
+    if is_ee:
+        cc_pair_2_group_name_expected = [user_group_2.name]
+
    DocumentManager.verify(
        vespa_client=vespa_client,
        cc_pair=cc_pair_2,
        doc_set_names=[doc_set_2.name],
-        group_names=[user_group_2.name],
+        group_names=cc_pair_2_group_name_expected,
        doc_creating_user=admin_user,
        verify_deleted=False,
    )
@@ -193,15 +207,19 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
        user_performing_action=admin_user,
    )

-    # validate user groups
-    UserGroupManager.verify(
-        user_group=user_group_1,
-        user_performing_action=admin_user,
-    )
-    UserGroupManager.verify(
-        user_group=user_group_2,
-        user_performing_action=admin_user,
-    )
+    if is_ee:
+        user_group_1.cc_pair_ids = []
+        user_group_2.cc_pair_ids = [cc_pair_2.id]
+
+        # validate user groups
+        UserGroupManager.verify(
+            user_group=user_group_1,
+            user_performing_action=admin_user,
+        )
+        UserGroupManager.verify(
+            user_group=user_group_2,
+            user_performing_action=admin_user,
+        )

def test_connector_deletion_for_overlapping_connectors(
@@ -210,6 +228,13 @@ def test_connector_deletion_for_overlapping_connectors(
    """Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping
    document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors.
    """
+    user_group_1: DATestUserGroup
+    user_group_2: DATestUserGroup
+
+    is_ee = (
+        os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
+    )
+
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
    # create api key
@@ -281,47 +306,48 @@ def test_connector_deletion_for_overlapping_connectors(
        doc_creating_user=admin_user,
    )

-    # create a user group and attach it to connector 1
-    user_group_1: DATestUserGroup = UserGroupManager.create(
-        name="Test User Group 1",
-        cc_pair_ids=[cc_pair_1.id],
-        user_performing_action=admin_user,
-    )
-    UserGroupManager.wait_for_sync(
-        user_groups_to_check=[user_group_1],
-        user_performing_action=admin_user,
-    )
-    cc_pair_1.groups = [user_group_1.id]
+    if is_ee:
+        # create a user group and attach it to connector 1
+        user_group_1 = UserGroupManager.create(
+            name="Test User Group 1",
+            cc_pair_ids=[cc_pair_1.id],
+            user_performing_action=admin_user,
+        )
+        UserGroupManager.wait_for_sync(
+            user_groups_to_check=[user_group_1],
+            user_performing_action=admin_user,
+        )
+        cc_pair_1.groups = [user_group_1.id]

-    print("User group 1 created and synced")
+        print("User group 1 created and synced")

-    # create a user group and attach it to connector 2
-    user_group_2: DATestUserGroup = UserGroupManager.create(
-        name="Test User Group 2",
-        cc_pair_ids=[cc_pair_2.id],
-        user_performing_action=admin_user,
-    )
-    UserGroupManager.wait_for_sync(
-        user_groups_to_check=[user_group_2],
-        user_performing_action=admin_user,
-    )
-    cc_pair_2.groups = [user_group_2.id]
+        # create a user group and attach it to connector 2
+        user_group_2 = UserGroupManager.create(
+            name="Test User Group 2",
+            cc_pair_ids=[cc_pair_2.id],
+            user_performing_action=admin_user,
+        )
+        UserGroupManager.wait_for_sync(
+            user_groups_to_check=[user_group_2],
+            user_performing_action=admin_user,
+        )
+        cc_pair_2.groups = [user_group_2.id]

-    print("User group 2 created and synced")
+        print("User group 2 created and synced")

-    # verify vespa document is in the user group
-    DocumentManager.verify(
-        vespa_client=vespa_client,
-        cc_pair=cc_pair_1,
-        group_names=[user_group_1.name, user_group_2.name],
-        doc_creating_user=admin_user,
-    )
-    DocumentManager.verify(
-        vespa_client=vespa_client,
-        cc_pair=cc_pair_2,
-        group_names=[user_group_1.name, user_group_2.name],
-        doc_creating_user=admin_user,
-    )
+        # verify vespa document is in the user group
+        DocumentManager.verify(
+            vespa_client=vespa_client,
+            cc_pair=cc_pair_1,
+            group_names=[user_group_1.name, user_group_2.name],
+            doc_creating_user=admin_user,
+        )
+        DocumentManager.verify(
+            vespa_client=vespa_client,
+            cc_pair=cc_pair_2,
+            group_names=[user_group_1.name, user_group_2.name],
+            doc_creating_user=admin_user,
+        )

    # delete connector 1
    CCPairManager.pause_cc_pair(
@@ -354,11 +380,15 @@ def test_connector_deletion_for_overlapping_connectors(

    # verify the document is not in any document sets
    # verify the document is only in user group 2
+    group_names_expected = []
+    if is_ee:
+        group_names_expected = [user_group_2.name]
+
    DocumentManager.verify(
        vespa_client=vespa_client,
        cc_pair=cc_pair_2,
        doc_set_names=[],
-        group_names=[user_group_2.name],
+        group_names=group_names_expected,
        doc_creating_user=admin_user,
        verify_deleted=False,
    )
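Unlike the module-level skipif used elsewhere in this diff, the inline `is_ee` flag lets these deletion tests still run under MIT while skipping only the user-group branches. A hypothetical helper expressing the same check, shown as a sketch rather than anything this change adds:

    # Hypothetical helper mirroring the inline check (illustrative only).
    import os


    def is_ee_enabled() -> bool:
        """True when enterprise features are enabled for the test run."""
        return (
            os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower()
            == "true"
        )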
@@ -1,3 +1,6 @@
+import os
+
+import pytest
import requests

from onyx.configs.constants import MessageType
@@ -12,6 +15,10 @@ from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="/chat/send-message-simple-with-history is enterprise only",
+)
def test_all_stream_chat_message_objects_outputs(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -1,5 +1,7 @@
import json
+import os

+import pytest
import requests

from onyx.configs.constants import MessageType
@@ -16,10 +18,11 @@ from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser


-def test_send_message_simple_with_history(reset: None) -> None:
-    # Creating an admin user (first user created is automatically an admin)
-    admin_user: DATestUser = UserManager.create(name="admin_user")
-
+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="/chat/send-message-simple-with-history tests are enterprise only",
+)
+def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -> None:
    # create connectors
    cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
        user_performing_action=admin_user,
@@ -53,18 +56,22 @@ def test_send_message_simple_with_history(reset: None) -> None:
    response_json = response.json()

    # Check that the top document is the correct document
-    assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
+    assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id

    # assert that the metadata is correct
    for doc in cc_pair_1.documents:
        found_doc = next(
-            (x for x in response_json["simple_search_docs"] if x["id"] == doc.id), None
+            (x for x in response_json["top_documents"] if x["document_id"] == doc.id),
+            None,
        )
        assert found_doc
        assert found_doc["metadata"]["document_id"] == doc.id


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="/chat/send-message-simple-with-history tests are enterprise only",
+)
def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -154,6 +161,10 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
    assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="/chat/send-message-simple-with-history tests are enterprise only",
+)
def test_send_message_simple_with_history_strict_json(
    new_admin_user: DATestUser | None,
) -> None:
@@ -2,6 +2,8 @@
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating connector-credential pairs.
"""
+import os
+
import pytest
from requests.exceptions import HTTPError

@@ -15,6 +17,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="Curator and User Group tests are enterprise only",
+)
def test_cc_pair_permissions(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -2,6 +2,8 @@
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating connectors.
"""
+import os
+
import pytest
from requests.exceptions import HTTPError

@@ -13,6 +15,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="Curator and user group tests are enterprise only",
+)
def test_connector_permissions(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -2,6 +2,8 @@
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating credentials.
"""
+import os
+
import pytest
from requests.exceptions import HTTPError

@@ -12,6 +14,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="Curator and user group tests are enterprise only",
+)
def test_credential_permissions(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -1,3 +1,5 @@
+import os
+
import pytest
from requests.exceptions import HTTPError

@@ -10,6 +12,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="Curator and user group tests are enterprise only",
+)
def test_doc_set_permissions_setup(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -4,6 +4,8 @@ This file tests the permissions for creating and editing personas for different
- Curators can edit personas that belong exclusively to groups they curate
- Admins can edit all personas
"""
+import os
+
import pytest
from requests.exceptions import HTTPError

@@ -13,6 +15,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="Curator and user group tests are enterprise only",
+)
def test_persona_permissions(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -1,6 +1,8 @@
"""
This file tests the ability of different user types to set the role of other users.
"""
+import os
+
import pytest
from requests.exceptions import HTTPError

@@ -10,6 +12,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="Curator and user group tests are enterprise only",
+)
def test_user_role_setting_permissions(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -1,6 +1,10 @@
"""
This test tests the happy path for curator permissions
"""
+import os
+
+import pytest
+
from onyx.db.enums import AccessType
from onyx.db.models import UserRole
from onyx.server.documents.models import DocumentSource
@@ -12,6 +16,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="Curator tests are enterprise only",
+)
def test_whole_curator_flow(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -89,6 +97,10 @@ def test_whole_curator_flow(reset: None) -> None:
    )


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="Curator tests are enterprise only",
+)
def test_global_curator_flow(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -1,3 +1,4 @@
+import os
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -63,6 +64,10 @@ def setup_chat_session(reset: None) -> tuple[DATestUser, str]:
    return admin_user, str(chat_session.id)


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="Chat history tests are enterprise only",
+)
def test_chat_history_endpoints(
    reset: None, setup_chat_session: tuple[DATestUser, str]
) -> None:
@@ -116,6 +121,10 @@ def test_chat_history_endpoints(
    assert len(history_response.items) == 0


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="Chat history tests are enterprise only",
+)
def test_chat_history_csv_export(
    reset: None, setup_chat_session: tuple[DATestUser, str]
) -> None:
@@ -1,5 +1,8 @@
+import os
from datetime import datetime

+import pytest
+
from onyx.configs.constants import QAFeedbackType
from tests.integration.common_utils.managers.query_history import QueryHistoryManager
from tests.integration.common_utils.test_models import DAQueryHistoryEntry
@@ -47,6 +50,10 @@ def _verify_query_history_pagination(
    assert all_expected_sessions == all_retrieved_sessions


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="Query history tests are enterprise only",
+)
def test_query_history_pagination(reset: None) -> None:
    (
        admin_user,
@@ -0,0 +1,42 @@
from collections.abc import Callable

import pytest

from onyx.configs.constants import DocumentSource
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import SimpleTestDocument


DocumentBuilderType = Callable[[list[str]], list[SimpleTestDocument]]


@pytest.fixture
def document_builder(admin_user: DATestUser) -> DocumentBuilderType:
    api_key: DATestAPIKey = APIKeyManager.create(
        user_performing_action=admin_user,
    )

    # create connector
    cc_pair_1 = CCPairManager.create_from_scratch(
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )

    def _document_builder(contents: list[str]) -> list[SimpleTestDocument]:
        # seed documents
        docs: list[SimpleTestDocument] = [
            DocumentManager.seed_doc_with_content(
                cc_pair=cc_pair_1,
                content=content,
                api_key=api_key,
            )
            for content in contents
        ]

        return docs

    return _document_builder
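Because the fixture returns a closure over a freshly created ingestion connector and API key, each test can seed exactly the documents it needs. A brief usage sketch, mirroring the search test that follows (the test name is illustrative):

    # Sketch of fixture usage (mirrors test_send_message__basic_searches below).
    def test_example(
        admin_user: DATestUser, document_builder: DocumentBuilderType
    ) -> None:
        docs = document_builder(["first doc", "second doc"])  # seeds via the ingestion API
        assert len(docs) == 2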
@@ -5,12 +5,11 @@ import pytest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
-from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
+from tests.integration.tests.streaming_endpoints.conftest import DocumentBuilderType


-def test_send_message_simple_with_history(reset: None) -> None:
-    admin_user: DATestUser = UserManager.create(name="admin_user")
+def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -> None:
    LLMProviderManager.create(user_performing_action=admin_user)

    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
@@ -24,6 +23,44 @@ def test_send_message_simple_with_history(reset: None) -> None:
    assert len(response.full_message) > 0


+def test_send_message__basic_searches(
+    reset: None, admin_user: DATestUser, document_builder: DocumentBuilderType
+) -> None:
+    MESSAGE = "run a search for 'test'"
+    SHORT_DOC_CONTENT = "test"
+    LONG_DOC_CONTENT = "blah blah blah blah" * 100
+
+    LLMProviderManager.create(user_performing_action=admin_user)
+
+    short_doc = document_builder([SHORT_DOC_CONTENT])[0]
+
+    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
+    response = ChatSessionManager.send_message(
+        chat_session_id=test_chat_session.id,
+        message=MESSAGE,
+        user_performing_action=admin_user,
+    )
+    assert response.top_documents is not None
+    assert len(response.top_documents) == 1
+    assert response.top_documents[0].document_id == short_doc.id
+
+    # make sure this doc is really long so that it will be split into multiple chunks
+    long_doc = document_builder([LONG_DOC_CONTENT])[0]
+
+    # new chat session for simplicity
+    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
+    response = ChatSessionManager.send_message(
+        chat_session_id=test_chat_session.id,
+        message=MESSAGE,
+        user_performing_action=admin_user,
+    )
+    assert response.top_documents is not None
+    assert len(response.top_documents) == 2
+    # short doc should be more relevant and thus first
+    assert response.top_documents[0].document_id == short_doc.id
+    assert response.top_documents[1].document_id == long_doc.id
+
+
@pytest.mark.skip(
    reason="enable for autorun when we have a testing environment with semantically useful data"
)
@@ -8,6 +8,10 @@ This tests the deletion of a user group with the following foreign key constrain
- token_rate_limit (Not Implemented)
- persona
"""
+import os
+
+import pytest
+
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.credential import CredentialManager
@@ -25,6 +29,10 @@ from tests.integration.common_utils.test_models import DATestUserGroup
from tests.integration.common_utils.vespa import vespa_fixture


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="User group tests are enterprise only",
+)
def test_user_group_deletion(reset: None, vespa_client: vespa_fixture) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -1,3 +1,7 @@
+import os
+
+import pytest
+
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import APIKeyManager
@@ -11,6 +15,10 @@ from tests.integration.common_utils.test_models import DATestUserGroup
from tests.integration.common_utils.vespa import vespa_fixture


+@pytest.mark.skipif(
+    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
+    reason="User group tests are enterprise only",
+)
def test_removing_connector(reset: None, vespa_client: vespa_fixture) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -269,5 +269,5 @@ async def test_test_expire_oauth_token(

    # Now should be within 10-11 seconds of the set expiration
    now = datetime.now(timezone.utc).timestamp()
-    assert update_data["expires_at"] - now >= 9  # Allow 1 second for test execution
-    assert update_data["expires_at"] - now <= 11  # Allow 1 second for test execution
+    assert update_data["expires_at"] - now >= 8.9  # Allow 1 second for test execution
+    assert update_data["expires_at"] - now <= 11.1  # Allow 1 second for test execution
@@ -9,8 +9,6 @@ from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import LlmDoc
-from onyx.chat.models import OnyxContext
-from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import DocumentSource
@@ -19,7 +17,6 @@ from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.tools.models import ToolResponse
-from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search_like_tool_utils import (
    FINAL_CONTEXT_DOCUMENTS_ID,
@@ -120,24 +117,7 @@ def mock_search_results(


@pytest.fixture
-def mock_contexts(mock_inference_sections: list[InferenceSection]) -> OnyxContexts:
-    return OnyxContexts(
-        contexts=[
-            OnyxContext(
-                content=section.combined_content,
-                document_id=section.center_chunk.document_id,
-                semantic_identifier=section.center_chunk.semantic_identifier,
-                blurb=section.center_chunk.blurb,
-            )
-            for section in mock_inference_sections
-        ]
-    )
-
-
-@pytest.fixture
-def mock_search_tool(
-    mock_contexts: OnyxContexts, mock_search_results: list[LlmDoc]
-) -> MagicMock:
+def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
    mock_tool = MagicMock(spec=SearchTool)
    mock_tool.name = "search"
    mock_tool.build_tool_message_content.return_value = "search_response"
@@ -146,7 +126,6 @@ def mock_search_tool(
        json.loads(doc.model_dump_json()) for doc in mock_search_results
    ]
    mock_tool.run.return_value = [
-        ToolResponse(id=SEARCH_DOC_CONTENT_ID, response=mock_contexts),
        ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results),
    ]
    mock_tool.tool_definition.return_value = {
@@ -19,7 +19,6 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
-from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
@@ -33,7 +32,6 @@ from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
-from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search_like_tool_utils import (
    FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -141,7 +139,6 @@ def test_basic_answer(answer_instance: Answer, mocker: MockerFixture) -> None:
def test_answer_with_search_call(
    answer_instance: Answer,
    mock_search_results: list[LlmDoc],
-    mock_contexts: OnyxContexts,
    mock_search_tool: MagicMock,
    force_use_tool: ForceUseTool,
    expected_tool_args: dict,
@@ -197,25 +194,21 @@ def test_answer_with_search_call(
        tool_name="search", tool_args=expected_tool_args
    )
-    assert output[1] == ToolResponse(
-        id=SEARCH_DOC_CONTENT_ID,
-        response=mock_contexts,
-    )
-    assert output[2] == ToolResponse(
+    assert output[1] == ToolResponse(
        id="final_context_documents",
        response=mock_search_results,
    )
-    assert output[3] == ToolCallFinalResult(
+    assert output[2] == ToolCallFinalResult(
        tool_name="search",
        tool_args=expected_tool_args,
        tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
    )
-    assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
+    assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
    expected_citation = CitationInfo(citation_num=1, document_id="doc1")
-    assert output[5] == expected_citation
-    assert output[6] == OnyxAnswerPiece(
+    assert output[4] == expected_citation
+    assert output[5] == OnyxAnswerPiece(
        answer_piece="the answer is abc[[1]](https://example.com/doc1). "
    )
-    assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
+    assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")

    expected_answer = (
        "Based on the search results, "
@@ -268,7 +261,6 @@ def test_answer_with_search_call(
def test_answer_with_search_no_tool_calling(
    answer_instance: Answer,
    mock_search_results: list[LlmDoc],
-    mock_contexts: OnyxContexts,
    mock_search_tool: MagicMock,
) -> None:
    answer_instance.graph_config.tooling.tools = [mock_search_tool]
@@ -288,30 +280,26 @@ def test_answer_with_search_no_tool_calling(
    output = list(answer_instance.processed_streamed_output)

    # Assertions
-    assert len(output) == 8
+    assert len(output) == 7
    assert output[0] == ToolCallKickoff(
        tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
    )
-    assert output[1] == ToolResponse(
-        id=SEARCH_DOC_CONTENT_ID,
-        response=mock_contexts,
-    )
-    assert output[2] == ToolResponse(
+    assert output[1] == ToolResponse(
        id=FINAL_CONTEXT_DOCUMENTS_ID,
        response=mock_search_results,
    )
-    assert output[3] == ToolCallFinalResult(
+    assert output[2] == ToolCallFinalResult(
        tool_name="search",
        tool_args=DEFAULT_SEARCH_ARGS,
        tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
    )
-    assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
+    assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
    expected_citation = CitationInfo(citation_num=1, document_id="doc1")
-    assert output[5] == expected_citation
-    assert output[6] == OnyxAnswerPiece(
+    assert output[4] == expected_citation
+    assert output[5] == OnyxAnswerPiece(
        answer_piece="the answer is abc[[1]](https://example.com/doc1). "
    )
-    assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
+    assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")

    expected_answer = (
        "Based on the search results, "
@@ -79,7 +79,7 @@ def test_skip_gen_ai_answer_generation_flag(
    for res in results:
        print(res)

-    expected_count = 4 if skip_gen_ai_answer_generation else 5
+    expected_count = 3 if skip_gen_ai_answer_generation else 4
    assert len(results) == expected_count
    if not skip_gen_ai_answer_generation:
        mock_llm.stream.assert_called_once()
@@ -89,7 +89,8 @@ def test_run_in_background_and_wait_success() -> None:
    elapsed = time.time() - start_time

    assert result == 42
-    assert elapsed >= 0.1  # Verify we actually waited for the sleep
+    # sometimes slightly flaky
+    assert elapsed >= 0.095  # Verify we actually waited for the sleep


@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
@@ -129,6 +129,9 @@ services:
      options:
        max-size: "50m"
        max-file: "6"
+    # optional, only for debugging purposes
+    volumes:
+      - api_server_logs:/var/log

  background:
    image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
@@ -256,7 +259,7 @@ services:
      - "host.docker.internal:host-gateway"
    # optional, only for debugging purposes
    volumes:
-      - log_store:/var/log/persisted-logs
+      - background_logs:/var/log
    logging:
      driver: json-file
      options:
@@ -325,6 +328,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - model_cache_huggingface:/root/.cache/huggingface/
+      # optional, only for debugging purposes
+      - inference_model_server_logs:/var/log
    logging:
      driver: json-file
      options:
@@ -357,6 +362,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - indexing_huggingface_model_cache:/root/.cache/huggingface/
+      # optional, only for debugging purposes
+      - indexing_model_server_logs:/var/log
    logging:
      driver: json-file
      options:
@@ -434,4 +441,8 @@ volumes:

  model_cache_huggingface:
  indexing_huggingface_model_cache:
-  log_store: # for logs that we don't want to lose on container restarts
+  # for logs that we don't want to lose on container restarts
+  api_server_logs:
+  background_logs:
+  inference_model_server_logs:
+  indexing_model_server_logs:
@@ -106,6 +106,9 @@ services:
      options:
        max-size: "50m"
        max-file: "6"
+    volumes:
+      # optional, only for debugging purposes
+      - api_server_logs:/var/log

  background:
    image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
@@ -211,7 +214,7 @@ services:
      - "host.docker.internal:host-gateway"
    # optional, only for debugging purposes
    volumes:
-      - log_store:/var/log/persisted-logs
+      - background_logs:/var/log
    logging:
      driver: json-file
      options:
@@ -273,6 +276,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - model_cache_huggingface:/root/.cache/huggingface/
+      # optional, only for debugging purposes
+      - inference_model_server_logs:/var/log
    logging:
      driver: json-file
      options:
@@ -310,6 +315,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - indexing_huggingface_model_cache:/root/.cache/huggingface/
+      # optional, only for debugging purposes
+      - indexing_model_server_logs:/var/log
    logging:
      driver: json-file
      options:
@@ -387,4 +394,8 @@ volumes:
  # Created by the container itself
  model_cache_huggingface:
  indexing_huggingface_model_cache:
-  log_store: # for logs that we don't want to lose on container restarts
+  # for logs that we don't want to lose on container restarts
+  api_server_logs:
+  background_logs:
+  inference_model_server_logs:
+  indexing_model_server_logs:
@@ -244,8 +244,6 @@ services:
    # - ./bundle.pem:/app/bundle.pem:ro
    extra_hosts:
      - "host.docker.internal:host-gateway"
-    volumes:
-      - log_store:/var/log/persisted-logs
    logging:
      driver: json-file
      options:
@@ -423,4 +421,3 @@ volumes:

  model_cache_huggingface:
  indexing_huggingface_model_cache:
-  log_store: # for logs that we don't want to lose on container restarts
@@ -54,9 +54,6 @@ services:
      - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
    extra_hosts:
      - "host.docker.internal:host-gateway"
-    # optional, only for debugging purposes
-    volumes:
-      - log_store:/var/log/persisted-logs
    logging:
      driver: json-file
      options:
@@ -236,4 +233,3 @@ volumes:
  # Created by the container itself
  model_cache_huggingface:
  indexing_huggingface_model_cache:
-  log_store: # for logs that we don't want to lose on container restarts
@@ -36,6 +36,10 @@ services:
      options:
        max-size: "50m"
        max-file: "6"
+    volumes:
+      # optional, only for debugging purposes
+      - api_server_logs:/var/log
+

  background:
    image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
@@ -69,7 +73,7 @@ services:
    extra_hosts:
      - "host.docker.internal:host-gateway"
    volumes:
-      - log_store:/var/log/persisted-logs
+      - background_logs:/var/log
    logging:
      driver: json-file
      options:
@@ -122,6 +126,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - model_cache_huggingface:/root/.cache/huggingface/
+      # optional, only for debugging purposes
+      - inference_model_server_logs:/var/log
    logging:
      driver: json-file
      options:
@@ -150,6 +156,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - indexing_huggingface_model_cache:/root/.cache/huggingface/
+      # optional, only for debugging purposes
+      - indexing_model_server_logs:/var/log
    logging:
      driver: json-file
      options:
@@ -231,4 +239,8 @@ volumes:
  # Created by the container itself
  model_cache_huggingface:
  indexing_huggingface_model_cache:
-  log_store: # for logs that we don't want to lose on container restarts
+  # for logs that we don't want to lose on container restarts
+  api_server_logs:
+  background_logs:
+  inference_model_server_logs:
+  indexing_model_server_logs:
@@ -32,13 +32,14 @@ services:
    # - ./bundle.pem:/app/bundle.pem:ro
    extra_hosts:
      - "host.docker.internal:host-gateway"
-    volumes:
-      - log_store:/var/log/persisted-logs
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"
+    volumes:
+      - api_server_logs:/var/log

  background:
    image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
    build:
@@ -76,7 +77,7 @@ services:
    extra_hosts:
      - "host.docker.internal:host-gateway"
    volumes:
-      - log_store:/var/log/persisted-logs
+      - background_logs:/var/log
    logging:
      driver: json-file
      options:
@@ -152,6 +153,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - model_cache_huggingface:/root/.cache/huggingface/
+      # optional, only for debugging purposes
+      - inference_model_server_logs:/var/log
    logging:
      driver: json-file
      options:
@@ -180,6 +183,8 @@ services:
    volumes:
      # Not necessary, this is just to reduce download time during startup
      - indexing_huggingface_model_cache:/root/.cache/huggingface/
+      # optional, only for debugging purposes
+      - indexing_model_server_logs:/var/log
    logging:
      driver: json-file
      options:
@@ -264,4 +269,8 @@ volumes:
  # Created by the container itself
  model_cache_huggingface:
  indexing_huggingface_model_cache:
-  log_store: # for logs that we don't want to lose on container restarts
+  # for logs that we don't want to lose on container restarts
+  api_server_logs:
+  background_logs:
+  inference_model_server_logs:
+  indexing_model_server_logs:
@@ -63,7 +63,7 @@ services:
    extra_hosts:
      - "host.docker.internal:host-gateway"
    volumes:
-      - log_store:/var/log/persisted-logs
+      - log_store:/var/log
    logging:
      driver: json-file
      options:
openapi.json (new file; diff suppressed because one or more lines are too long)
@@ -45,7 +45,7 @@ export function ActionsTable({ tools }: { tools: ToolSnapshot[] }) {
          className="mr-1 my-auto cursor-pointer"
          onClick={() =>
            router.push(
-              `/admin/tools/edit/${tool.id}?u=${Date.now()}`
+              `/admin/actions/edit/${tool.id}?u=${Date.now()}`
            )
          }
        />
@@ -281,7 +281,7 @@ export default function AddConnector({
  return (
    <Formik
      initialValues={{
-        ...createConnectorInitialValues(connector),
+        ...createConnectorInitialValues(connector, currentCredential),
        ...Object.fromEntries(
          connectorConfigs[connector].advanced_values.map((field) => [
            field.name,
@@ -1384,6 +1384,7 @@ export function ChatPage({
        if (!packet) {
          continue;
        }
+        console.log("Packet:", JSON.stringify(packet));

        if (!initialFetchDetails) {
          if (!Object.hasOwn(packet, "user_message_id")) {
@@ -1729,6 +1730,7 @@ export function ChatPage({
        }
      }
    } catch (e: any) {
+      console.log("Error:", e);
      const errorMsg = e.message;
      upsertToCompleteMessageMap({
        messages: [
@@ -1756,11 +1758,13 @@ export function ChatPage({
        completeMessageMapOverride: currentMessageMap(completeMessageDetail),
      });
    }
+    console.log("Finished streaming");
    setAgenticGenerating(false);
    resetRegenerationState(currentSessionId());

    updateChatState("input");
    if (isNewSession) {
+      console.log("Setting up new session");
      if (finalMessage) {
        setSelectedMessageForDocDisplay(finalMessage.message_id);
      }
@@ -91,7 +91,7 @@ export function AgenticToggle({
        >
          <div className="flex items-center space-x-2 mb-3">
            <h3 className="text-sm font-semibold text-neutral-900">
-              Agent Search (BETA)
+              Agent Search
            </h3>
          </div>
          <p className="text-xs text-neutral-600 dark:text-neutral-700 mb-2">
@@ -34,11 +34,12 @@
  /* -------------------------------------------------------
   * 2. Keep special, custom, or near-duplicate background
   * ------------------------------------------------------- */
-  --background: #fefcfa; /* slightly off-white, keep it */
+  --background: #fefcfa; /* slightly off-white */
+  --background-50: #fffdfb; /* a little lighter than background but not quite white */
  --input-background: #fefcfa;
  --input-border: #f1eee8;
  --text-text: #f4f2ed;
-  --background-dark: #e9e6e0;
+  --background-dark: #141414;
  --new-background: #ebe7de;
  --new-background-light: #d9d1c0;
  --background-chatbar: #f5f3ee;
@@ -234,6 +235,7 @@

  --text-text: #1d1d1d;
  --background-dark: #252525;
+  --background-50: #252525;

  /* --new-background: #fff; */
  --new-background: #2c2c2c;
Some files were not shown because too many files have changed in this diff.