Compare commits

..

3 Commits

Author SHA1 Message Date
Bo-Onyx
4da54f34b3 address comments 2026-03-23 13:11:52 -07:00
Bo-Onyx
5ebb80aeb2 address comments 2026-03-22 12:11:03 -07:00
Bo-Onyx
f77d5d2d01 feat(hook): integrate query processing hook point 2026-03-20 17:37:01 -07:00
212 changed files with 2804 additions and 5870 deletions

View File

@@ -44,7 +44,7 @@ jobs:
fetch-tags: true
- name: Setup uv
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
with:
version: "0.9.9"
enable-cache: false
@@ -165,7 +165,7 @@ jobs:
fetch-depth: 0
- name: Setup uv
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
with:
version: "0.9.9"
# NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.
@@ -307,7 +307,7 @@ jobs:
xdg-utils
- name: setup node
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v6.3.0
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6.2.0
with:
node-version: 24
package-manager-cache: false

View File

@@ -114,7 +114,7 @@ jobs:
ref: main
- name: Install the latest version of uv
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -50,7 +50,7 @@ jobs:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238
with:
node-version: 24
cache: "npm" # zizmor: ignore[cache-poisoning]

View File

@@ -28,7 +28,7 @@ jobs:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning] test-only workflow; no deploy artifacts

View File

@@ -272,7 +272,7 @@ jobs:
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]
@@ -471,7 +471,7 @@ jobs:
- name: Install the latest version of uv
if: always()
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
@@ -614,7 +614,7 @@ jobs:
- name: Setup node
# zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm" # zizmor: ignore[cache-poisoning]

View File

@@ -73,7 +73,7 @@ jobs:
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f
- name: Build and load
uses: docker/bake-action@82490499d2e5613fcead7e128237ef0b0ea210f7 # ratchet:docker/bake-action@v7.0.0
uses: docker/bake-action@5be5f02ff8819ecd3092ea6b2e6261c31774f2b4 # ratchet:docker/bake-action@v6
env:
TAG: model-server-${{ github.run_id }}
with:

View File

@@ -30,7 +30,7 @@ jobs:
- name: Setup Terraform
uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0
- name: Setup node
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v6
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
with: # zizmor: ignore[cache-poisoning]
node-version: 22
cache: "npm"

View File

@@ -22,7 +22,7 @@ jobs:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm"

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -32,7 +32,7 @@ jobs:
persist-credentials: false
- name: Setup node
uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
with:
node-version: 22
cache: "npm"

View File

@@ -24,7 +24,7 @@ jobs:
persist-credentials: false
- name: Install the latest version of uv
uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"

View File

@@ -1,26 +0,0 @@
"""rename persona is_visible to is_listed and featured to is_featured
Revision ID: b728689f45b1
Revises: 689433b0d8de
Create Date: 2026-03-23 12:36:26.607305
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b728689f45b1"
down_revision = "689433b0d8de"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column("persona", "is_visible", new_column_name="is_listed")
op.alter_column("persona", "featured", new_column_name="is_featured")
def downgrade() -> None:
op.alter_column("persona", "is_listed", new_column_name="is_visible")
op.alter_column("persona", "is_featured", new_column_name="featured")

View File

@@ -36,56 +36,6 @@ TABLES_WITH_USER_ID = [
]
def _dedupe_null_notifications(connection: sa.Connection) -> None:
# Multiple NULL-owned notifications can exist because the unique index treats
# NULL user_id values as distinct. Before migrating them to the anonymous
# user, collapse duplicates and remove rows that would conflict with an
# already-existing anonymous notification.
result = connection.execute(
sa.text(
"""
WITH ranked_null_notifications AS (
SELECT
id,
ROW_NUMBER() OVER (
PARTITION BY notif_type, COALESCE(additional_data, '{}'::jsonb)
ORDER BY first_shown DESC, last_shown DESC, id DESC
) AS row_num
FROM notification
WHERE user_id IS NULL
)
DELETE FROM notification
WHERE id IN (
SELECT id
FROM ranked_null_notifications
WHERE row_num > 1
)
"""
)
)
if result.rowcount > 0:
print(f"Deleted {result.rowcount} duplicate NULL-owned notifications")
result = connection.execute(
sa.text(
"""
DELETE FROM notification AS null_owned
USING notification AS anonymous_owned
WHERE null_owned.user_id IS NULL
AND anonymous_owned.user_id = :user_id
AND null_owned.notif_type = anonymous_owned.notif_type
AND COALESCE(null_owned.additional_data, '{}'::jsonb) =
COALESCE(anonymous_owned.additional_data, '{}'::jsonb)
"""
),
{"user_id": ANONYMOUS_USER_UUID},
)
if result.rowcount > 0:
print(
f"Deleted {result.rowcount} NULL-owned notifications that conflict with existing anonymous-owned notifications"
)
def upgrade() -> None:
"""
Create the anonymous user for anonymous access feature.
@@ -115,12 +65,7 @@ def upgrade() -> None:
# Migrate any remaining user_id=NULL records to anonymous user
for table in TABLES_WITH_USER_ID:
# Dedup notifications outside the savepoint so deletions persist
# even if the subsequent UPDATE rolls back
if table == "notification":
_dedupe_null_notifications(connection)
with connection.begin_nested():
try:
# Exclude public credential (id=0) which must remain user_id=NULL
# Exclude builtin tools (in_code_tool_id IS NOT NULL) which must remain user_id=NULL
# Exclude builtin personas (builtin_persona=True) which must remain user_id=NULL
@@ -135,7 +80,6 @@ def upgrade() -> None:
condition = "user_id IS NULL AND is_public = false"
else:
condition = "user_id IS NULL"
result = connection.execute(
sa.text(
f"""
@@ -148,19 +92,19 @@ def upgrade() -> None:
)
if result.rowcount > 0:
print(f"Updated {result.rowcount} rows in {table} to anonymous user")
except Exception as e:
print(f"Skipping {table}: {e}")
def downgrade() -> None:
"""
Set anonymous user's records back to NULL and delete the anonymous user.
Note: Duplicate NULL-owned notifications removed during upgrade are not restored.
"""
connection = op.get_bind()
# Set records back to NULL
for table in TABLES_WITH_USER_ID:
with connection.begin_nested():
try:
connection.execute(
sa.text(
f"""
@@ -171,6 +115,8 @@ def downgrade() -> None:
),
{"user_id": ANONYMOUS_USER_UUID},
)
except Exception:
pass
# Delete the anonymous user
connection.execute(

View File

@@ -25,6 +25,9 @@ from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX
# Default number of pre-provisioned tenants to maintain
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
# Soft time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
# Hard time limit for tenant pre-provisioning tasks (in seconds)
@@ -55,7 +58,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_check: RedisLock = r.lock(
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
)
# These tasks should never overlap
@@ -71,7 +74,9 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
num_available_tenants = db_session.query(AvailableTenant).count()
# Get the target number of available tenants
num_minimum_available_tenants = TARGET_AVAILABLE_TENANTS
num_minimum_available_tenants = getattr(
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
)
# Calculate how many new tenants we need to provision
if num_available_tenants < num_minimum_available_tenants:
@@ -93,12 +98,7 @@ def check_available_tenants(self: Task) -> None: # noqa: ARG001
task_logger.exception("Error in check_available_tenants task")
finally:
try:
lock_check.release()
except Exception:
task_logger.warning(
"Could not release check lock (likely expired), continuing"
)
lock_check.release()
def pre_provision_tenant() -> None:
@@ -113,7 +113,7 @@ def pre_provision_tenant() -> None:
r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
lock_provision: RedisLock = r.lock(
OnyxRedisLocks.CLOUD_PRE_PROVISION_TENANT_LOCK,
timeout=_TENANT_PROVISIONING_TIME_LIMIT,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
)
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
@@ -185,9 +185,4 @@ def pre_provision_tenant() -> None:
except Exception:
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
finally:
try:
lock_provision.release()
except Exception:
task_logger.warning(
"Could not release provision lock (likely expired), continuing"
)
lock_provision.release()

View File

@@ -157,11 +157,7 @@ def fetch_logo_helper(db_session: Session) -> Response: # noqa: ARG001
detail="No logo file found",
)
else:
return Response(
content=onyx_file.data,
media_type=onyx_file.mime_type,
headers={"Cache-Control": "no-cache"},
)
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
def fetch_logotype_helper(db_session: Session) -> Response: # noqa: ARG001

View File

@@ -178,7 +178,7 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
system_prompt=persona.system_prompt,
task_prompt=persona.task_prompt,
datetime_aware=persona.datetime_aware,
is_featured=persona.is_featured,
featured=persona.featured,
commit=False,
)
db_session.commit()

View File

@@ -80,45 +80,15 @@ def capture_and_sync_with_alternate_posthog(
logger.error(f"Error identifying cloud posthog user: {e}")
def alias_user(distinct_id: str, anonymous_id: str) -> None:
"""Link an anonymous distinct_id to an identified user, merging person profiles.
No-ops when the IDs match (e.g. returning users whose PostHog cookie
already contains their identified user ID).
"""
if not posthog or anonymous_id == distinct_id:
return
try:
posthog.alias(previous_id=anonymous_id, distinct_id=distinct_id)
posthog.flush()
except Exception as e:
logger.error(f"Error aliasing PostHog user: {e}")
def get_anon_id_from_request(request: Any) -> str | None:
"""Extract the anonymous distinct_id from the app PostHog cookie on a request."""
if not POSTHOG_API_KEY:
return None
cookie_name = f"ph_{POSTHOG_API_KEY}_posthog"
if (cookie_value := request.cookies.get(cookie_name)) and (
parsed := parse_posthog_cookie(cookie_value)
):
return parsed.get("distinct_id")
return None
def get_marketing_posthog_cookie_name() -> str | None:
if not MARKETING_POSTHOG_API_KEY:
return None
return f"onyx_custom_ph_{MARKETING_POSTHOG_API_KEY}_posthog"
def parse_posthog_cookie(cookie_value: str) -> dict[str, Any] | None:
def parse_marketing_cookie(cookie_value: str) -> dict[str, Any] | None:
"""
Parse a URL-encoded JSON PostHog cookie
Parse the URL-encoded JSON marketing cookie.
Expected format (URL-encoded):
{"distinct_id":"...", "featureFlags":{"landing_page_variant":"..."}, ...}
@@ -132,7 +102,7 @@ def parse_posthog_cookie(cookie_value: str) -> dict[str, Any] | None:
cookie_data = json.loads(decoded_cookie)
distinct_id = cookie_data.get("distinct_id")
if not distinct_id or not isinstance(distinct_id, str):
if not distinct_id:
return None
return cookie_data

View File

@@ -135,8 +135,6 @@ from onyx.redis.redis_pool import retrieve_ws_token_data
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_alias
from onyx.utils.telemetry import mt_cloud_get_anon_id
from onyx.utils.telemetry import mt_cloud_identify
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.telemetry import optional_telemetry
@@ -253,12 +251,18 @@ def verify_email_is_invited(email: str) -> None:
whitelist = get_invited_users()
if not email:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email must be specified")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email must be specified"},
)
try:
email_info = validate_email(email, check_deliverability=False)
except EmailUndeliverableError:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email is not valid")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email is not valid"},
)
for email_whitelist in whitelist:
try:
@@ -275,9 +279,12 @@ def verify_email_is_invited(email: str) -> None:
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise OnyxError(
OnyxErrorCode.UNAUTHORIZED,
"This workspace is invite-only. Please ask your admin to invite you.",
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": REGISTER_INVITE_ONLY_CODE,
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
},
)
@@ -287,47 +294,48 @@ def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
verify_email_is_invited(email)
def verify_email_domain(email: str, *, is_registration: bool = False) -> None:
def verify_email_domain(email: str) -> None:
if email.count("@") != 1:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email is not valid")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email is not valid",
)
local_part, domain = email.split("@")
domain = domain.lower()
local_part = local_part.lower()
if AUTH_TYPE == AuthType.CLOUD:
# Normalize googlemail.com to gmail.com (they deliver to the same inbox)
if domain == "googlemail.com":
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Please use @gmail.com instead of @googlemail.com.",
)
# Only block dotted Gmail on new signups — existing users must still be
# able to sign in with the address they originally registered with.
if is_registration and domain == "gmail.com" and "." in local_part:
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Gmail addresses with '.' are not allowed. Please use your base email address.",
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
)
if "+" in local_part and domain != "onyx.app":
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Email addresses with '+' are not allowed. Please use your base email address.",
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
},
)
# Check if email uses a disposable/temporary domain
if is_disposable_email(email):
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Disposable email addresses are not allowed. Please use a permanent email address.",
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
},
)
# Check domain whitelist if configured
if VALID_EMAIL_DOMAINS:
if domain not in VALID_EMAIL_DOMAINS:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email domain is not valid")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email domain is not valid",
)
def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
@@ -343,7 +351,7 @@ def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None:
)(db_session, seats_needed=seats_needed)
if result is not None and not result.available:
raise OnyxError(OnyxErrorCode.SEAT_LIMIT_EXCEEDED, result.error_message)
raise HTTPException(status_code=402, detail=result.error_message)
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
@@ -396,7 +404,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
captcha_token or "", expected_action="signup"
)
except CaptchaVerificationError as e:
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": str(e)},
)
# We verify the password here to make sure it's valid before we proceed
await self.validate_password(
@@ -406,10 +417,13 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Check for disposable emails BEFORE provisioning tenant
# This prevents creating tenants for throwaway email addresses
try:
verify_email_domain(user_create.email, is_registration=True)
except OnyxError as e:
verify_email_domain(user_create.email)
except HTTPException as e:
# Log blocked disposable email attempts
if "Disposable email" in e.detail:
if (
e.status_code == status.HTTP_400_BAD_REQUEST
and "Disposable email" in str(e.detail)
):
domain = (
user_create.email.split("@")[-1]
if "@" in user_create.email
@@ -553,9 +567,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
result = await db_session.execute(
select(Persona.id)
.where(
Persona.is_featured.is_(True),
Persona.featured.is_(True),
Persona.is_public.is_(True),
Persona.is_listed.is_(True),
Persona.is_visible.is_(True),
Persona.deleted.is_(False),
)
.order_by(
@@ -683,8 +697,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
raise exceptions.UserNotExists()
except exceptions.UserNotExists:
verify_email_domain(account_email, is_registration=True)
# Check seat availability before creating (single-tenant only)
with get_session_with_current_tenant() as sync_db:
enforce_seat_limit(sync_db)
@@ -783,12 +795,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
logger.exception("Error deleting anonymous user cookie")
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
# Link the anonymous PostHog session to the identified user so that
# pre-login session recordings and events merge into one person profile.
if anon_id := mt_cloud_get_anon_id(request):
mt_cloud_alias(distinct_id=str(user.id), anonymous_id=anon_id)
mt_cloud_identify(
distinct_id=str(user.id),
properties={"email": user.email, "tenant_id": tenant_id},
@@ -812,11 +818,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_count = await get_user_count()
logger.debug(f"Current tenant user count: {user_count}")
# Link the anonymous PostHog session to the identified user so
# that pre-signup session recordings merge into one person profile.
if anon_id := mt_cloud_get_anon_id(request):
mt_cloud_alias(distinct_id=str(user.id), anonymous_id=anon_id)
# Ensure a PostHog person profile exists for this user.
mt_cloud_identify(
distinct_id=str(user.id),
@@ -845,9 +846,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
attribute="get_marketing_posthog_cookie_name",
noop_return_value=None,
)
parse_posthog_cookie = fetch_ee_implementation_or_noop(
parse_marketing_cookie = fetch_ee_implementation_or_noop(
module="onyx.utils.posthog_client",
attribute="parse_posthog_cookie",
attribute="parse_marketing_cookie",
noop_return_value=None,
)
capture_and_sync_with_alternate_posthog = fetch_ee_implementation_or_noop(
@@ -861,7 +862,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
and user_count is not None
and (marketing_cookie_name := get_marketing_posthog_cookie_name())
and (marketing_cookie_value := request.cookies.get(marketing_cookie_name))
and (parsed_cookie := parse_posthog_cookie(marketing_cookie_value))
and (parsed_cookie := parse_marketing_cookie(marketing_cookie_value))
):
marketing_anonymous_id = parsed_cookie["distinct_id"]

View File

@@ -30,8 +30,6 @@ from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.utils import plaintext_file_name_for_id
from onyx.file_store.utils import store_plaintext
from onyx.kg.models import KGException
from onyx.kg.setup.kg_default_entity_definitions import (
populate_missing_default_entity_types__commit,
@@ -291,33 +289,6 @@ def process_kg_commands(
raise KGException("KG setup done")
def _get_or_extract_plaintext(
file_id: str,
extract_fn: Callable[[], str],
) -> str:
"""Load cached plaintext for a file, or extract and store it.
Tries to read pre-stored plaintext from the file store. On a miss,
calls extract_fn to produce the text, then stores the result so
future calls skip the expensive extraction.
"""
file_store = get_default_file_store()
plaintext_key = plaintext_file_name_for_id(file_id)
# Try cached plaintext first.
try:
plaintext_io = file_store.read_file(plaintext_key, mode="b")
return plaintext_io.read().decode("utf-8")
except Exception:
logger.exception(f"Error when reading file, id={file_id}")
# Cache miss — extract and store.
content_text = extract_fn()
if content_text:
store_plaintext(file_id, content_text)
return content_text
@log_function_time(print_only=True)
def load_chat_file(
file_descriptor: FileDescriptor, db_session: Session
@@ -332,23 +303,12 @@ def load_chat_file(
file_type = ChatFileType(file_descriptor["type"])
if file_type.is_text_file():
file_id = file_descriptor["id"]
def _extract() -> str:
return extract_file_text(
try:
content_text = extract_file_text(
file=file_io,
file_name=file_descriptor.get("name") or "",
break_on_unprocessable=False,
)
# Use the user_file_id as cache key when available (matches what
# the celery indexing worker stores), otherwise fall back to the
# file store id (covers code-interpreter-generated files, etc.).
user_file_id_str = file_descriptor.get("user_file_id")
cache_key = user_file_id_str or file_id
try:
content_text = _get_or_extract_plaintext(cache_key, _extract)
except Exception as e:
logger.warning(
f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}"

View File

@@ -177,8 +177,8 @@ class ExtractedContextFiles(BaseModel):
class SearchParams(BaseModel):
"""Resolved search filter IDs and search-tool usage for a chat turn."""
project_id_filter: int | None
persona_id_filter: int | None
search_project_id: int | None
search_persona_id: int | None
search_usage: SearchToolUsage

View File

@@ -59,6 +59,7 @@ from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.enums import HookPoint
from onyx.db.memory import get_memories
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -68,11 +69,19 @@ from onyx.db.models import UserFile
from onyx.db.projects import get_user_files_from_project
from onyx.db.tools import get_tools
from onyx.deep_research.dr_loop import run_deep_research_llm_loop
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
from onyx.error_handling.exceptions import OnyxError
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import verify_user_files
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingPayload
from onyx.hooks.points.query_processing import QueryProcessingResponse
from onyx.llm.factory import get_llm_for_persona
from onyx.llm.factory import get_llm_token_counter
from onyx.llm.interfaces import LLM
@@ -399,13 +408,13 @@ def determine_search_params(
"""
is_custom_persona = persona_id != DEFAULT_PERSONA_ID
project_id_filter: int | None = None
persona_id_filter: int | None = None
search_project_id: int | None = None
search_persona_id: int | None = None
if extracted_context_files.use_as_search_filter:
if is_custom_persona:
persona_id_filter = persona_id
search_persona_id = persona_id
else:
project_id_filter = project_id
search_project_id = project_id
search_usage = SearchToolUsage.AUTO
if not is_custom_persona and project_id:
@@ -418,12 +427,38 @@ def determine_search_params(
search_usage = SearchToolUsage.DISABLED
return SearchParams(
project_id_filter=project_id_filter,
persona_id_filter=persona_id_filter,
search_project_id=search_project_id,
search_persona_id=search_persona_id,
search_usage=search_usage,
)
def _resolve_query_processing_hook_result(
hook_result: BaseModel | HookSkipped | HookSoftFailed,
message_text: str,
) -> str:
"""Apply the Query Processing hook result to the message text.
Returns the (possibly rewritten) message text, or raises OnyxError with
QUERY_REJECTED if the hook signals rejection (query is null or empty).
HookSkipped and HookSoftFailed are pass-throughs — the original text is
returned unchanged.
"""
if isinstance(hook_result, (HookSkipped, HookSoftFailed)):
return message_text
if not isinstance(hook_result, QueryProcessingResponse):
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"Expected QueryProcessingResponse from hook, got {type(hook_result).__name__}",
)
if not (hook_result.query and hook_result.query.strip()):
raise OnyxError(
OnyxErrorCode.QUERY_REJECTED,
hook_result.rejection_message or "Your query was rejected.",
)
return hook_result.query.strip()
def handle_stream_message_objects(
new_msg_req: SendMessageRequest,
user: User,
@@ -484,6 +519,7 @@ def handle_stream_message_objects(
persona = chat_session.persona
message_text = new_msg_req.message
user_identity = LLMUserIdentity(
user_id=llm_user_identifier, session_id=str(chat_session.id)
)
@@ -575,6 +611,27 @@ def handle_stream_message_objects(
if parent_message.message_type == MessageType.USER:
user_message = parent_message
else:
# New message — run the Query Processing hook before saving to DB.
# Skipped on regeneration: the message already exists and was accepted previously.
# Skip the hook for empty/whitespace-only messages — no meaningful query
# to process, and SendMessageRequest.message has no min_length guard.
if message_text.strip():
hook_result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=QueryProcessingPayload(
query=message_text,
# Pass None for anonymous users or authenticated users without an email
# (e.g. some SSO flows). QueryProcessingPayload.user_email is str | None,
# so None is accepted and serialised as null in both cases.
user_email=None if user.is_anonymous else user.email,
chat_session_id=str(chat_session.id),
).model_dump(),
)
message_text = _resolve_query_processing_hook_result(
hook_result, message_text
)
user_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=parent_message,
@@ -711,8 +768,8 @@ def handle_stream_message_objects(
llm=llm,
search_tool_config=SearchToolConfig(
user_selected_filters=new_msg_req.internal_search_filters,
project_id_filter=search_params.project_id_filter,
persona_id_filter=search_params.persona_id_filter,
project_id=search_params.search_project_id,
persona_id=search_params.search_persona_id,
bypass_acl=bypass_acl,
slack_context=slack_context,
enable_slack_search=_should_enable_slack_search(
@@ -914,6 +971,17 @@ def handle_stream_message_objects(
state_container=state_container,
)
except OnyxError as e:
if e.error_code is not OnyxErrorCode.QUERY_REJECTED:
log_onyx_error(e)
yield StreamingError(
error=e.detail,
error_code=e.error_code.code,
is_retryable=e.status_code >= 500,
)
db_session.rollback()
return
except ValueError as e:
logger.exception("Failed to process chat message.")

View File

@@ -88,9 +88,8 @@ WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
IFRAME_TEXT_LENGTH_THRESHOLD = 700
# Message indicating JavaScript is disabled, which often appears when scraping fails
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
# Grace period after page navigation to allow bot-detection challenges
# and SPA content rendering to complete
PAGE_RENDER_TIMEOUT_MS = 5000
# Grace period after page navigation to allow bot-detection challenges to complete
BOT_DETECTION_GRACE_PERIOD_MS = 5000
# Define common headers that mimic a real browser
DEFAULT_USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"
@@ -548,15 +547,7 @@ class WebConnector(LoadConnector):
)
# Give the page a moment to start rendering after navigation commits.
# Allows CloudFlare and other bot-detection challenges to complete.
page.wait_for_timeout(PAGE_RENDER_TIMEOUT_MS)
# Wait for network activity to settle so SPAs that fetch content
# asynchronously after the initial JS bundle have time to render.
try:
# A bit of extra time to account for long-polling, websockets, etc.
page.wait_for_load_state("networkidle", timeout=PAGE_RENDER_TIMEOUT_MS)
except TimeoutError:
pass
page.wait_for_timeout(BOT_DETECTION_GRACE_PERIOD_MS)
last_modified = (
page_response.header_value("Last-Modified") if page_response else None
@@ -585,7 +576,7 @@ class WebConnector(LoadConnector):
# (e.g., CloudFlare protection keeps making requests)
try:
page.wait_for_load_state(
"networkidle", timeout=PAGE_RENDER_TIMEOUT_MS
"networkidle", timeout=BOT_DETECTION_GRACE_PERIOD_MS
)
except TimeoutError:
# If networkidle times out, just give it a moment for content to render

View File

@@ -2,6 +2,7 @@ from collections.abc import Sequence
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
@@ -69,13 +70,9 @@ class BaseFilters(BaseModel):
class UserFileFilters(BaseModel):
# Scopes search to user files tagged with a given project/persona in Vespa.
# These are NOT simply the IDs of the current project or persona — they are
# only set when the persona's/project's user files overflowed the LLM
# context window and must be searched via vector DB instead of being loaded
# directly into the prompt.
project_id_filter: int | None = None
persona_id_filter: int | None = None
user_file_ids: list[UUID] | None = None
project_id: int | None = None
persona_id: int | None = None
class AssistantKnowledgeFilters(BaseModel):

View File

@@ -1,5 +1,6 @@
from collections import defaultdict
from datetime import datetime
from uuid import UUID
from sqlalchemy.orm import Session
@@ -38,8 +39,9 @@ logger = setup_logger()
def _build_index_filters(
user_provided_filters: BaseFilters | None,
user: User, # Used for ACLs, anonymous users only see public docs
project_id_filter: int | None,
persona_id_filter: int | None,
project_id: int | None,
persona_id: int | None,
user_file_ids: list[UUID] | None,
persona_document_sets: list[str] | None,
persona_time_cutoff: datetime | None,
db_session: Session | None = None,
@@ -95,6 +97,16 @@ def _build_index_filters(
if not source_filter and detected_source_filter:
source_filter = detected_source_filter
# CRITICAL FIX: If user_file_ids are present, we must ensure "user_file"
# source type is included in the filter, otherwise user files will be excluded!
if user_file_ids and source_filter:
from onyx.configs.constants import DocumentSource
# Add user_file to the source filter if not already present
if DocumentSource.USER_FILE not in source_filter:
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
logger.debug("Added USER_FILE to source_filter for user knowledge search")
if bypass_acl:
user_acl_filters = None
elif acl_filters is not None:
@@ -105,8 +117,9 @@ def _build_index_filters(
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = IndexFilters(
project_id_filter=project_id_filter,
persona_id_filter=persona_id_filter,
user_file_ids=user_file_ids,
project_id=project_id,
persona_id=persona_id,
source_type=source_filter,
document_set=document_set_filter,
time_cutoff=time_filter,
@@ -252,16 +265,19 @@ def search_pipeline(
db_session: Session | None = None,
auto_detect_filters: bool = False,
llm: LLM | None = None,
# Vespa metadata filters for overflowing user files. NOT the raw IDs
# of the current project/persona — only set when user files couldn't fit
# in the LLM context and need to be searched via vector DB.
project_id_filter: int | None = None,
persona_id_filter: int | None = None,
# If a project ID is provided, it will be exclusively scoped to that project
project_id: int | None = None,
# If a persona_id is provided, search scopes to files attached to this persona
persona_id: int | None = None,
# Pre-fetched data — when provided, avoids DB queries (no session needed)
acl_filters: list[str] | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
user_uploaded_persona_files: list[UUID] | None = (
[user_file.id for user_file in persona.user_files] if persona else None
)
persona_document_sets: list[str] | None = (
[persona_document_set.name for persona_document_set in persona.document_sets]
if persona
@@ -286,8 +302,9 @@ def search_pipeline(
filters = _build_index_filters(
user_provided_filters=chunk_search_request.user_selected_filters,
user=user,
project_id_filter=project_id_filter,
persona_id_filter=persona_id_filter,
project_id=project_id,
persona_id=persona_id,
user_file_ids=user_uploaded_persona_files,
persona_document_sets=persona_document_sets,
persona_time_cutoff=persona_time_cutoff,
db_session=db_session,

View File

@@ -110,6 +110,7 @@ def search_chunks(
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
federated_sources = set(

View File

@@ -583,67 +583,6 @@ def get_latest_index_attempt_for_cc_pair_id(
return db_session.execute(stmt).scalar_one_or_none()
def get_latest_successful_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool = False,
) -> IndexAttempt | None:
"""Returns the most recent successful index attempt for the given cc pair,
filtered to the current (or future) search settings.
Uses MAX(id) semantics to match get_latest_index_attempts_by_status."""
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
stmt = (
select(IndexAttempt)
.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
IndexAttempt.status.in_(
[IndexingStatus.SUCCESS, IndexingStatus.COMPLETED_WITH_ERRORS]
),
)
.join(SearchSettings)
.where(SearchSettings.status == status)
.order_by(desc(IndexAttempt.id))
.limit(1)
)
return db_session.execute(stmt).scalar_one_or_none()
def get_latest_successful_index_attempts_parallel(
secondary_index: bool = False,
) -> Sequence[IndexAttempt]:
"""Batch version: returns the latest successful index attempt per cc pair.
Covers both SUCCESS and COMPLETED_WITH_ERRORS (matching is_successful())."""
model_status = (
IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
)
with get_session_with_current_tenant() as db_session:
latest_ids = (
select(
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.id).label("max_id"),
)
.join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
.where(
SearchSettings.status == model_status,
IndexAttempt.status.in_(
[IndexingStatus.SUCCESS, IndexingStatus.COMPLETED_WITH_ERRORS]
),
)
.group_by(IndexAttempt.connector_credential_pair_id)
.subquery()
)
stmt = select(IndexAttempt).join(
latest_ids,
(
IndexAttempt.connector_credential_pair_id
== latest_ids.c.connector_credential_pair_id
)
& (IndexAttempt.id == latest_ids.c.max_id),
)
return db_session.execute(stmt).scalars().all()
def count_index_attempts_for_cc_pair(
db_session: Session,
cc_pair_id: int,

View File

@@ -3467,9 +3467,9 @@ class Persona(Base):
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
# Featured personas are highlighted in the UI
is_featured: Mapped[bool] = mapped_column(Boolean, default=False)
# controls whether the persona is listed in user-facing agent lists
is_listed: Mapped[bool] = mapped_column(Boolean, default=True)
featured: Mapped[bool] = mapped_column(Boolean, default=False)
# controls whether the persona is available to be selected by users
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
# controls the ordering of personas in the UI
# higher priority personas are displayed first, ties are resolved by the ID,
# where lower value IDs (e.g. created earlier) are displayed first

View File

@@ -126,7 +126,7 @@ def _add_user_filters(
else:
# Group the public persona conditions
public_condition = (Persona.is_public == True) & ( # noqa: E712
Persona.is_listed == True # noqa: E712
Persona.is_visible == True # noqa: E712
)
where_clause |= public_condition
@@ -260,7 +260,7 @@ def create_update_persona(
try:
# Featured persona validation
if create_persona_request.is_featured:
if create_persona_request.featured:
# Curators can edit featured personas, but not make them
# TODO this will be reworked soon with RBAC permissions feature
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
@@ -300,7 +300,7 @@ def create_update_persona(
remove_image=create_persona_request.remove_image,
search_start_date=create_persona_request.search_start_date,
label_ids=create_persona_request.label_ids,
is_featured=create_persona_request.is_featured,
featured=create_persona_request.featured,
user_file_ids=converted_user_file_ids,
commit=False,
hierarchy_node_ids=create_persona_request.hierarchy_node_ids,
@@ -910,11 +910,11 @@ def upsert_persona(
uploaded_image_id: str | None = None,
icon_name: str | None = None,
display_priority: int | None = None,
is_listed: bool = True,
is_visible: bool = True,
remove_image: bool | None = None,
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_featured: bool | None = None,
featured: bool | None = None,
label_ids: list[int] | None = None,
user_file_ids: list[UUID] | None = None,
hierarchy_node_ids: list[int] | None = None,
@@ -1037,13 +1037,13 @@ def upsert_persona(
if remove_image or uploaded_image_id:
existing_persona.uploaded_image_id = uploaded_image_id
existing_persona.icon_name = icon_name
existing_persona.is_listed = is_listed
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
if label_ids is not None:
existing_persona.labels.clear()
existing_persona.labels = labels or []
existing_persona.is_featured = (
is_featured if is_featured is not None else existing_persona.is_featured
existing_persona.featured = (
featured if featured is not None else existing_persona.featured
)
# Update embedded prompt fields if provided
if system_prompt is not None:
@@ -1109,9 +1109,9 @@ def upsert_persona(
uploaded_image_id=uploaded_image_id,
icon_name=icon_name,
display_priority=display_priority,
is_listed=is_listed,
is_visible=is_visible,
search_start_date=search_start_date,
is_featured=(is_featured if is_featured is not None else False),
featured=(featured if featured is not None else False),
user_files=user_files or [],
labels=labels or [],
hierarchy_nodes=hierarchy_nodes or [],
@@ -1158,7 +1158,7 @@ def delete_old_default_personas(
def update_persona_featured(
persona_id: int,
is_featured: bool,
featured: bool,
db_session: Session,
user: User,
) -> None:
@@ -1166,13 +1166,13 @@ def update_persona_featured(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
persona.is_featured = is_featured
persona.featured = featured
db_session.commit()
def update_persona_visibility(
persona_id: int,
is_listed: bool,
is_visible: bool,
db_session: Session,
user: User,
) -> None:
@@ -1180,7 +1180,7 @@ def update_persona_visibility(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
persona.is_listed = is_listed
persona.is_visible = is_visible
db_session.commit()

View File

@@ -75,7 +75,7 @@ def create_slack_channel_persona(
llm_model_version_override=None,
starter_messages=None,
is_public=True,
is_featured=False,
featured=False,
db_session=db_session,
commit=False,
)

View File

@@ -10,8 +10,8 @@ How `IndexFilters` fields combine into the final query filter. Applies to both V
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
| **ACL** | `access_control_list` | OR within, AND with rest |
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
| **Knowledge scope** | `document_set`, `attached_document_ids`, `hierarchy_node_ids`, `persona_id_filter` | OR within group, AND with rest |
| **Additive scope** | `project_id_filter` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
## How filters combine
@@ -31,22 +31,12 @@ AND time >= cutoff -- if set
The knowledge scope filter controls **what knowledge an assistant can access**.
### Primary vs additive triggers
- **`persona_id_filter`** is a **primary** trigger. A persona with user files IS explicit
knowledge, so `persona_id_filter` alone can start a knowledge scope. Note: this is
NOT the raw ID of the persona being used — it is only set when the persona's
user files overflowed the LLM context window.
- **`project_id_filter`** is **additive**. It widens an existing scope to include project
files but never restricts on its own — a chat inside a project should still search
team knowledge when no other knowledge is attached.
### No explicit knowledge attached
When `document_set`, `attached_document_ids`, `hierarchy_node_ids`, and `persona_id_filter` are all empty/None:
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
- `project_id_filter` is ignored — it never restricts on its own.
- `project_id` and `persona_id` are ignored — they never restrict on their own.
### One explicit knowledge type
@@ -54,40 +44,39 @@ When `document_set`, `attached_document_ids`, `hierarchy_node_ids`, and `persona
-- Only document sets
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
-- Only persona user files (overflowed context)
AND (personas contains 42)
-- Only user files
AND (document_id = "uuid-1" OR document_id = "uuid-2")
```
### Multiple explicit knowledge types (OR'd)
```
-- Document sets + persona user files
-- Document sets + user files
AND (
document_sets contains "Engineering"
OR document_id = "uuid-1"
)
```
### Explicit knowledge + overflowing user files
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
```
-- Document sets + persona user files overflowed
AND (
document_sets contains "Engineering"
OR personas contains 42
)
```
### Explicit knowledge + overflowing project files
When an explicit knowledge restriction is in effect **and** `project_id_filter` is set (project files overflowed the LLM context window), `project_id_filter` widens the filter:
```
-- Document sets + project files overflowed
-- User files + project files overflowed
AND (
document_sets contains "Engineering"
OR user_project contains 7
)
-- Persona user files + project files (won't happen in practice;
-- custom personas ignore project files per the precedence rule)
AND (
personas contains 42
document_id = "uuid-1"
OR user_project contains 7
)
```
### Only project_id_filter (no explicit knowledge)
### Only project_id or persona_id (no explicit knowledge)
No knowledge scope filter. The assistant searches everything.
@@ -102,10 +91,11 @@ AND (acl contains ...)
| Filter field | Vespa field | Vespa type | Purpose |
|---|---|---|---|
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
| `persona_id_filter` | `personas` | `array<int>` | Persona tag for overflowing user files (**primary** trigger) |
| `project_id_filter` | `user_project` | `array<int>` | Project tag for overflowing project files (**additive** only) |
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from uuid import UUID
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
@@ -218,8 +219,9 @@ class DocumentQuery:
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
project_id_filter=index_filters.project_id_filter,
persona_id_filter=index_filters.persona_id_filter,
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=min_chunk_index,
max_chunk_index=max_chunk_index,
@@ -284,8 +286,9 @@ class DocumentQuery:
source_types=[],
tags=[],
document_sets=[],
project_id_filter=None,
persona_id_filter=None,
user_file_ids=[],
project_id=None,
persona_id=None,
time_cutoff=None,
min_chunk_index=None,
max_chunk_index=None,
@@ -353,8 +356,9 @@ class DocumentQuery:
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
project_id_filter=index_filters.project_id_filter,
persona_id_filter=index_filters.persona_id_filter,
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
@@ -445,8 +449,9 @@ class DocumentQuery:
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
project_id_filter=index_filters.project_id_filter,
persona_id_filter=index_filters.persona_id_filter,
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
@@ -524,8 +529,9 @@ class DocumentQuery:
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
project_id_filter=index_filters.project_id_filter,
persona_id_filter=index_filters.persona_id_filter,
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
@@ -585,8 +591,9 @@ class DocumentQuery:
source_types=index_filters.source_type or [],
tags=index_filters.tags or [],
document_sets=index_filters.document_set or [],
project_id_filter=index_filters.project_id_filter,
persona_id_filter=index_filters.persona_id_filter,
user_file_ids=index_filters.user_file_ids or [],
project_id=index_filters.project_id,
persona_id=index_filters.persona_id,
time_cutoff=index_filters.time_cutoff,
min_chunk_index=None,
max_chunk_index=None,
@@ -817,8 +824,9 @@ class DocumentQuery:
source_types: list[DocumentSource],
tags: list[Tag],
document_sets: list[str],
project_id_filter: int | None,
persona_id_filter: int | None,
user_file_ids: list[UUID],
project_id: int | None,
persona_id: int | None,
time_cutoff: datetime | None,
min_chunk_index: int | None,
max_chunk_index: int | None,
@@ -849,12 +857,12 @@ class DocumentQuery:
list corresponding to a tag will be retrieved.
document_sets: If supplied, only documents with at least one
document set ID from this list will be retrieved.
project_id_filter: If not None, only documents with this project ID
in user projects will be retrieved. Additive — only applied
when a knowledge scope already exists.
persona_id_filter: If not None, only documents whose personas array
contains this persona ID will be retrieved. Primary — creates
a knowledge scope on its own.
user_file_ids: If supplied, only document IDs in this list will be
retrieved.
project_id: If not None, only documents with this project ID in user
projects will be retrieved.
persona_id: If not None, only documents whose personas array
contains this persona ID will be retrieved.
time_cutoff: Time cutoff for the documents to retrieve. If not None,
Documents which were last updated before this date will not be
returned. For documents which do not have a value for their last
@@ -871,6 +879,10 @@ class DocumentQuery:
NOTE: See DocumentChunk.max_chunk_size.
document_id: The document ID to retrieve. If None, no filter will be
applied for this. Defaults to None.
WARNING: This filters on the same property as user_file_ids.
Although it would never make sense to supply both, note that if
user_file_ids is supplied and does not contain document_id, no
matches will be retrieved.
attached_document_ids: Document IDs explicitly attached to the
assistant. If provided along with hierarchy_node_ids, documents
matching EITHER criteria will be retrieved (OR logic).
@@ -931,6 +943,15 @@ class DocumentQuery:
)
return document_set_filter
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
# Logical OR operator on its elements.
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
for user_file_id in user_file_ids:
user_file_id_filter["bool"]["should"].append(
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
)
return user_file_id_filter
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
# Logical OR operator on its elements.
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
@@ -1031,17 +1052,14 @@ class DocumentQuery:
# assistant can see. When none are set the assistant searches
# everything.
#
# persona_id_filter is a primary trigger — a persona with user files IS
# explicit knowledge, so it can start a knowledge scope on its own.
#
# project_id_filter is additive — it widens the scope to also cover
# overflowing project files but never restricts on its own (a chat
# inside a project should still search team knowledge).
# project_id / persona_id are additive: they make overflowing user files
# findable but must NOT trigger the restriction on their own (an agent
# with no explicit knowledge should search everything).
has_knowledge_scope = (
attached_document_ids
or hierarchy_node_ids
or user_file_ids
or document_sets
or persona_id_filter is not None
)
if has_knowledge_scope:
@@ -1056,17 +1074,23 @@ class DocumentQuery:
knowledge_filter["bool"]["should"].append(
_get_hierarchy_node_filter(hierarchy_node_ids)
)
if user_file_ids:
knowledge_filter["bool"]["should"].append(
_get_user_file_id_filter(user_file_ids)
)
if document_sets:
knowledge_filter["bool"]["should"].append(
_get_document_set_filter(document_sets)
)
if persona_id_filter is not None:
# Additive: widen scope to also cover overflowing user files, but
# only when an explicit restriction is already in effect.
if project_id is not None:
knowledge_filter["bool"]["should"].append(
_get_persona_filter(persona_id_filter)
_get_user_project_filter(project_id)
)
if project_id_filter is not None:
if persona_id is not None:
knowledge_filter["bool"]["should"].append(
_get_user_project_filter(project_id_filter)
_get_persona_filter(persona_id)
)
filter_clauses.append(knowledge_filter)
@@ -1084,6 +1108,8 @@ class DocumentQuery:
)
if document_id is not None:
# WARNING: If user_file_ids has elements and if none of them are
# document_id, no matches will be retrieved.
filter_clauses.append(
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
)

View File

@@ -199,29 +199,31 @@ def build_vespa_filters(
]
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
# Knowledge scope: explicit knowledge attachments restrict what an
# assistant can see. When none are set, the assistant can see
# everything.
# Knowledge scope: explicit knowledge attachments (document_sets,
# user_file_ids) restrict what an assistant can see. When none are
# set, the assistant can see everything.
#
# persona_id_filter is a primary trigger — a persona with user files IS
# explicit knowledge, so it can start a knowledge scope on its own.
#
# project_id_filter is additive — it widens the scope to also cover
# overflowing project files but never restricts on its own (a chat
# inside a project should still search team knowledge).
# project_id / persona_id are additive: they make overflowing user
# files findable in Vespa but must NOT trigger the restriction on
# their own (an agent with no explicit knowledge should search
# everything).
knowledge_scope_parts: list[str] = []
_append(
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
)
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id_filter))
# project_id_filter only widens an existing scope.
user_file_ids_str = (
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
)
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
# Only include project/persona scopes when an explicit knowledge
# restriction is already in effect — they widen the scope to also
# cover overflowing user files but never restrict on their own.
if knowledge_scope_parts:
_append(
knowledge_scope_parts,
_build_user_project_filter(filters.project_id_filter),
)
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
if len(knowledge_scope_parts) > 1:
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")

View File

@@ -44,6 +44,7 @@ class OnyxErrorCode(Enum):
VALIDATION_ERROR = ("VALIDATION_ERROR", 400)
INVALID_INPUT = ("INVALID_INPUT", 400)
MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400)
QUERY_REJECTED = ("QUERY_REJECTED", 400)
# ------------------------------------------------------------------
# Not Found (404)

View File

@@ -38,7 +38,17 @@ def get_federated_retrieval_functions(
source_types: list[DocumentSource] | None,
document_set_names: list[str] | None,
slack_context: SlackContext | None = None,
user_file_ids: list[UUID] | None = None,
) -> list[FederatedRetrievalInfo]:
# When User Knowledge (user files) is the only knowledge source enabled,
# skip federated connectors entirely. User Knowledge mode means the agent
# should ONLY use uploaded files, not team connectors like Slack.
if user_file_ids and not document_set_names:
logger.debug(
"Skipping all federated connectors: User Knowledge mode enabled "
f"with {len(user_file_ids)} user files and no document sets"
)
return []
# Check for Slack bot context first (regardless of user_id)
if slack_context:

View File

@@ -23,55 +23,45 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
def plaintext_file_name_for_id(file_id: str) -> str:
"""Generate a consistent file name for storing plaintext content of a file."""
return f"plaintext_{file_id}"
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
"""Generate a consistent file name for storing plaintext content of a user file."""
return f"plaintext_{user_file_id}"
def store_plaintext(file_id: str, plaintext_content: str) -> bool:
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
"""
Store plaintext content for a file in the file store.
Store plaintext content for a user file in the file store.
Args:
file_id: The ID of the file (user_file or artifact_file)
user_file_id: The ID of the user file
plaintext_content: The plaintext content to store
Returns:
bool: True if storage was successful, False otherwise
"""
# Skip empty content
if not plaintext_content:
return False
plaintext_file_name = plaintext_file_name_for_id(file_id)
# Get plaintext file name
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
try:
file_store = get_default_file_store()
file_content = BytesIO(plaintext_content.encode("utf-8"))
file_store.save_file(
content=file_content,
display_name=f"Plaintext for {file_id}",
display_name=f"Plaintext for user file {user_file_id}",
file_origin=FileOrigin.PLAINTEXT_CACHE,
file_type="text/plain",
file_id=plaintext_file_name,
)
return True
except Exception as e:
logger.warning(f"Failed to store plaintext for {file_id}: {e}")
logger.warning(f"Failed to store plaintext for user file {user_file_id}: {e}")
return False
# --- Convenience wrappers for callers that use user-file UUIDs ---
def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str:
"""Generate a consistent file name for storing plaintext content of a user file."""
return plaintext_file_name_for_id(str(user_file_id))
def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool:
"""Store plaintext content for a user file (delegates to :func:`store_plaintext`)."""
return store_plaintext(str(user_file_id), plaintext_content)
def load_chat_file_by_id(file_id: str) -> InMemoryChatFile:
"""Load a file directly from the file store using its file_record ID.

View File

@@ -14,7 +14,7 @@ Usage (Celery tasks and FastAPI handlers):
# hook failed but fail strategy is SOFT — continue with original behavior
...
else:
# result is the response payload dict from the customer's endpoint
# result is a validated Pydantic model instance (spec.response_model)
...
is_reachable update policy
@@ -56,6 +56,7 @@ from typing import Any
import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
@@ -67,6 +68,7 @@ from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.registry import get_hook_point_spec
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.utils.logger import setup_logger
@@ -268,22 +270,21 @@ def _persist_result(
# ---------------------------------------------------------------------------
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
def _execute_hook_inner(
hook: Hook,
payload: dict[str, Any],
) -> dict[str, Any] | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously."""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
) -> BaseModel | HookSoftFailed:
"""Make the HTTP call, validate the response, and return a typed model.
Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
"""
timeout = hook.timeout_seconds
hook_id = hook.id
fail_strategy = hook.fail_strategy
endpoint_url = hook.endpoint_url
current_is_reachable: bool | None = hook.is_reachable
hook_point = hook.hook_point # extract before HTTP call per design intent
if not endpoint_url:
raise ValueError(
f"hook_id={hook_id} is active but has no endpoint_url — "
@@ -300,13 +301,37 @@ def execute_hook(
headers: dict[str, str] = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
with httpx.Client(timeout=timeout) as client:
with httpx.Client(
timeout=timeout, follow_redirects=False
) as client: # SSRF guard: never follow redirects
response = client.post(endpoint_url, json=payload, headers=headers)
except Exception as e:
exc = e
duration_ms = int((time.monotonic() - start) * 1000)
outcome = _process_response(response=response, exc=exc, timeout=timeout)
# Validate the response payload against the spec's response_model.
# A validation failure downgrades the outcome to a failure so it is logged,
# is_reachable is left unchanged (server responded — just a bad payload),
# and fail_strategy is respected below.
validated_model: BaseModel | None = None
if outcome.is_success and outcome.response_payload is not None:
spec = get_hook_point_spec(hook_point)
try:
validated_model = spec.response_model.model_validate(
outcome.response_payload
)
except ValidationError as e:
msg = f"Hook response failed validation against {spec.response_model.__name__}: {e}"
outcome = _HttpOutcome(
is_success=False,
updated_is_reachable=None, # server responded — reachability unchanged
status_code=outcome.status_code,
error_message=msg,
response_payload=None,
)
# Skip the is_reachable write when the value would not change — avoids a
# no-op DB round-trip on every call when the hook is already in the expected state.
if outcome.updated_is_reachable == current_is_reachable:
@@ -323,8 +348,40 @@ def execute_hook(
f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
)
return HookSoftFailed()
if outcome.response_payload is None:
raise ValueError(
f"response_payload is None for successful hook call (hook_id={hook_id})"
if validated_model is None:
raise OnyxError(
OnyxErrorCode.INTERNAL_ERROR,
f"validated_model is None for successful hook call (hook_id={hook_id})",
)
return outcome.response_payload
return validated_model
def execute_hook(
*,
db_session: Session,
hook_point: HookPoint,
payload: dict[str, Any],
) -> BaseModel | HookSkipped | HookSoftFailed:
"""Execute the hook for the given hook point synchronously.
Returns HookSkipped if no active hook is configured, HookSoftFailed if the
hook failed with SOFT fail strategy, or a validated response model on success.
Raises OnyxError on HARD failure or if the hook is misconfigured.
"""
hook = _lookup_hook(db_session, hook_point)
if isinstance(hook, HookSkipped):
return hook
fail_strategy = hook.fail_strategy
hook_id = hook.id
try:
return _execute_hook_inner(hook, payload)
except Exception:
if fail_strategy == HookFailStrategy.SOFT:
logger.exception(
f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
)
return HookSoftFailed()
raise

View File

@@ -51,13 +51,12 @@ class HookPointSpec:
output_schema: ClassVar[dict[str, Any]]
def __init_subclass__(cls, **kwargs: object) -> None:
"""Enforce that every concrete subclass declares all required class attributes.
"""Enforce that every subclass declares all required class attributes.
Called automatically by Python whenever a class inherits from HookPointSpec.
Abstract subclasses (those still carrying unimplemented abstract methods) are
skipped — they are intermediate base classes and may not yet define everything.
Only fully concrete subclasses are validated, ensuring a clear TypeError at
import time rather than a confusing AttributeError at runtime.
Raises TypeError at import time if any required attribute is missing or if
payload_model / response_model are not Pydantic BaseModel subclasses.
input_schema and output_schema are derived automatically from the models.
"""
super().__init_subclass__(**kwargs)
missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]

View File

@@ -26,8 +26,6 @@ class DocumentIngestionSpec(HookPointSpec):
default_timeout_seconds = 30.0
fail_hard_description = "The document will not be indexed."
default_fail_strategy = HookFailStrategy.HARD
# TODO(Bo-Onyx): update later
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.ue263ual5vdi"
payload_model = DocumentIngestionPayload
response_model = DocumentIngestionResponse

View File

@@ -15,7 +15,7 @@ class QueryProcessingPayload(BaseModel):
description="Email of the user submitting the query, or null if unauthenticated."
)
chat_session_id: str = Field(
description="UUID of the chat session. Always present — the session is guaranteed to exist by the time this hook fires."
description="UUID of the chat session, formatted as a hyphenated lowercase string (e.g. '550e8400-e29b-41d4-a716-446655440000'). Always present — the session is guaranteed to exist by the time this hook fires."
)
@@ -25,7 +25,7 @@ class QueryProcessingResponse(BaseModel):
default=None,
description=(
"The query to use in the pipeline. "
"Null, empty string, or absent = reject the query."
"Null, empty string, whitespace-only, or absent = reject the query."
),
)
rejection_message: str | None = Field(
@@ -65,8 +65,6 @@ class QueryProcessingSpec(HookPointSpec):
"The query will be blocked and the user will see an error message."
)
default_fail_strategy = HookFailStrategy.HARD
# TODO(Bo-Onyx): update later
docs_url = "https://docs.google.com/document/d/1pGhB8Wcnhhj8rS4baEJL6CX05yFhuIDNk1gbBRiWu94/edit?tab=t.g2r1a1699u87"
payload_model = QueryProcessingPayload
response_model = QueryProcessingResponse

View File

@@ -43,9 +43,6 @@ from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import count_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.index_attempt import (
get_latest_successful_index_attempt_for_cc_pair_id,
)
from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
@@ -193,11 +190,6 @@ def get_cc_pair_full_info(
only_finished=False,
)
latest_successful_attempt = get_latest_successful_index_attempt_for_cc_pair_id(
db_session=db_session,
connector_credential_pair_id=cc_pair_id,
)
# Get latest permission sync attempt for status
latest_permission_sync_attempt = None
if cc_pair.access_type == AccessType.SYNC:
@@ -215,11 +207,6 @@ def get_cc_pair_full_info(
cc_pair_id=cc_pair_id,
),
last_index_attempt=latest_attempt,
last_successful_index_time=(
latest_successful_attempt.time_started
if latest_successful_attempt
else None
),
latest_deletion_attempt=get_deletion_attempt_snapshot(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,

View File

@@ -3,7 +3,6 @@ import math
import mimetypes
import os
import zipfile
from datetime import datetime
from io import BytesIO
from typing import Any
from typing import cast
@@ -110,9 +109,6 @@ from onyx.db.federated import fetch_all_federated_connectors_parallel
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempts_by_status
from onyx.db.index_attempt import get_latest_index_attempts_parallel
from onyx.db.index_attempt import (
get_latest_successful_index_attempts_parallel,
)
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import FederatedConnector
from onyx.db.models import IndexAttempt
@@ -1162,26 +1158,21 @@ def get_connector_indexing_status(
),
(),
),
# Get most recent successful index attempts
(
lambda: get_latest_successful_index_attempts_parallel(
request.secondary_index,
),
(),
),
]
if user and user.role == UserRole.ADMIN:
# For Admin users, we already got all the cc pair in editable_cc_pairs
# its not needed to get them again
(
editable_cc_pairs,
federated_connectors,
latest_index_attempts,
latest_finished_index_attempts,
latest_successful_index_attempts,
) = run_functions_tuples_in_parallel(parallel_functions)
non_editable_cc_pairs = []
else:
parallel_functions.append(
# Get non-editable connector/credential pairs
(
lambda: get_connector_credential_pairs_for_user_parallel(
user, False, None, True, True, False, True, request.source
@@ -1195,7 +1186,6 @@ def get_connector_indexing_status(
federated_connectors,
latest_index_attempts,
latest_finished_index_attempts,
latest_successful_index_attempts,
non_editable_cc_pairs,
) = run_functions_tuples_in_parallel(parallel_functions)
@@ -1207,9 +1197,6 @@ def get_connector_indexing_status(
latest_finished_index_attempts = cast(
list[IndexAttempt], latest_finished_index_attempts
)
latest_successful_index_attempts = cast(
list[IndexAttempt], latest_successful_index_attempts
)
document_count_info = get_document_counts_for_all_cc_pairs(db_session)
@@ -1219,48 +1206,42 @@ def get_connector_indexing_status(
for connector_id, credential_id, cnt in document_count_info
}
def _attempt_lookup(
attempts: list[IndexAttempt],
) -> dict[int, IndexAttempt]:
return {attempt.connector_credential_pair_id: attempt for attempt in attempts}
cc_pair_to_latest_index_attempt: dict[tuple[int, int], IndexAttempt] = {
(
attempt.connector_credential_pair.connector_id,
attempt.connector_credential_pair.credential_id,
): attempt
for attempt in latest_index_attempts
}
cc_pair_to_latest_index_attempt = _attempt_lookup(latest_index_attempts)
cc_pair_to_latest_finished_index_attempt = _attempt_lookup(
latest_finished_index_attempts
)
cc_pair_to_latest_successful_index_attempt = _attempt_lookup(
latest_successful_index_attempts
)
cc_pair_to_latest_finished_index_attempt: dict[tuple[int, int], IndexAttempt] = {
(
attempt.connector_credential_pair.connector_id,
attempt.connector_credential_pair.credential_id,
): attempt
for attempt in latest_finished_index_attempts
}
def build_connector_indexing_status(
cc_pair: ConnectorCredentialPair,
is_editable: bool,
) -> ConnectorIndexingStatusLite | None:
# TODO remove this to enable ingestion API
if cc_pair.name == "DefaultCCPair":
return None
latest_attempt = cc_pair_to_latest_index_attempt.get(cc_pair.id)
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
cc_pair.id
latest_attempt = cc_pair_to_latest_index_attempt.get(
(cc_pair.connector_id, cc_pair.credential_id)
)
latest_successful_attempt = cc_pair_to_latest_successful_index_attempt.get(
cc_pair.id
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
(cc_pair.connector_id, cc_pair.credential_id)
)
doc_count = cc_pair_to_document_cnt.get(
(cc_pair.connector_id, cc_pair.credential_id), 0
)
return _get_connector_indexing_status_lite(
cc_pair,
latest_attempt,
latest_finished_attempt,
(
latest_successful_attempt.time_started
if latest_successful_attempt
else None
),
is_editable,
doc_count,
cc_pair, latest_attempt, latest_finished_attempt, is_editable, doc_count
)
# Process editable cc_pairs
@@ -1421,7 +1402,6 @@ def _get_connector_indexing_status_lite(
cc_pair: ConnectorCredentialPair,
latest_index_attempt: IndexAttempt | None,
latest_finished_index_attempt: IndexAttempt | None,
last_successful_index_time: datetime | None,
is_editable: bool,
document_cnt: int,
) -> ConnectorIndexingStatusLite | None:
@@ -1455,7 +1435,7 @@ def _get_connector_indexing_status_lite(
else None
),
last_status=latest_index_attempt.status if latest_index_attempt else None,
last_success=last_successful_index_time,
last_success=cc_pair.last_successful_index_time,
docs_indexed=document_cnt,
latest_index_attempt_docs_indexed=(
latest_index_attempt.total_docs_indexed if latest_index_attempt else None

View File

@@ -330,7 +330,6 @@ class CCPairFullInfo(BaseModel):
num_docs_indexed: int, # not ideal, but this must be computed separately
is_editable_for_current_user: bool,
indexing: bool,
last_successful_index_time: datetime | None = None,
last_permission_sync_attempt_status: PermissionSyncStatus | None = None,
permission_syncing: bool = False,
last_permission_sync_attempt_finished: datetime | None = None,
@@ -383,7 +382,9 @@ class CCPairFullInfo(BaseModel):
creator_email=(
cc_pair_model.creator.email if cc_pair_model.creator else None
),
last_indexed=last_successful_index_time,
last_indexed=(
last_index_attempt.time_started if last_index_attempt else None
),
last_pruned=cc_pair_model.last_pruned,
last_full_permission_sync=cls._get_last_full_permission_sync(cc_pair_model),
overall_indexing_speed=overall_indexing_speed,

View File

@@ -6978,9 +6978,9 @@
}
},
"node_modules/flatted": {
"version": "3.4.2",
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz",
"integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==",
"version": "3.3.3",
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz",
"integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==",
"dev": true,
"license": "ISC"
},

View File

@@ -119,8 +119,8 @@ admin_agents_router = APIRouter(prefix=ADMIN_AGENTS_RESOURCE)
agents_router = APIRouter(prefix=AGENTS_RESOURCE)
class IsListedRequest(BaseModel):
is_listed: bool
class IsVisibleRequest(BaseModel):
is_visible: bool
class IsPublicRequest(BaseModel):
@@ -128,19 +128,19 @@ class IsPublicRequest(BaseModel):
class IsFeaturedRequest(BaseModel):
is_featured: bool
featured: bool
@admin_router.patch("/{persona_id}/listed")
@admin_router.patch("/{persona_id}/visible")
def patch_persona_visibility(
persona_id: int,
is_listed_request: IsListedRequest,
is_visible_request: IsVisibleRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_visibility(
persona_id=persona_id,
is_listed=is_listed_request.is_listed,
is_visible=is_visible_request.is_visible,
db_session=db_session,
user=user,
)
@@ -175,7 +175,7 @@ def patch_persona_featured_status(
try:
update_persona_featured(
persona_id=persona_id,
is_featured=is_featured_request.is_featured,
featured=is_featured_request.featured,
db_session=db_session,
user=user,
)

View File

@@ -123,7 +123,7 @@ class PersonaUpsertRequest(BaseModel):
)
search_start_date: datetime | None = None
label_ids: list[int] | None = None
is_featured: bool = False
featured: bool = False
display_priority: int | None = None
# Accept string UUIDs from frontend
user_file_ids: list[str] | None = None
@@ -165,9 +165,9 @@ class MinimalPersonaSnapshot(BaseModel):
icon_name: str | None
is_public: bool
is_listed: bool
is_visible: bool
display_priority: int | None
is_featured: bool
featured: bool
builtin_persona: bool
# Used for filtering
@@ -218,9 +218,9 @@ class MinimalPersonaSnapshot(BaseModel):
uploaded_image_id=persona.uploaded_image_id,
icon_name=persona.icon_name,
is_public=persona.is_public,
is_listed=persona.is_listed,
is_visible=persona.is_visible,
display_priority=persona.display_priority,
is_featured=persona.is_featured,
featured=persona.featured,
builtin_persona=persona.builtin_persona,
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
owner=(
@@ -236,13 +236,13 @@ class PersonaSnapshot(BaseModel):
name: str
description: str
is_public: bool
is_listed: bool
is_visible: bool
uploaded_image_id: str | None
icon_name: str | None
# Return string UUIDs to frontend for consistency
user_file_ids: list[str]
display_priority: int | None
is_featured: bool
featured: bool
builtin_persona: bool
starter_messages: list[StarterMessage] | None
tools: list[ToolSnapshot]
@@ -271,12 +271,12 @@ class PersonaSnapshot(BaseModel):
name=persona.name,
description=persona.description,
is_public=persona.is_public,
is_listed=persona.is_listed,
is_visible=persona.is_visible,
uploaded_image_id=persona.uploaded_image_id,
icon_name=persona.icon_name,
user_file_ids=[str(file.id) for file in persona.user_files],
display_priority=persona.display_priority,
is_featured=persona.is_featured,
featured=persona.featured,
builtin_persona=persona.builtin_persona,
starter_messages=persona.starter_messages,
tools=[
@@ -337,12 +337,12 @@ class FullPersonaSnapshot(PersonaSnapshot):
name=persona.name,
description=persona.description,
is_public=persona.is_public,
is_listed=persona.is_listed,
is_visible=persona.is_visible,
uploaded_image_id=persona.uploaded_image_id,
icon_name=persona.icon_name,
user_file_ids=[str(file.id) for file in persona.user_files],
display_priority=persona.display_priority,
is_featured=persona.is_featured,
featured=persona.featured,
builtin_persona=persona.builtin_persona,
starter_messages=persona.starter_messages,
users=[

View File

@@ -351,7 +351,7 @@ def upsert_project_instructions(
class ProjectPayload(BaseModel):
project: UserProjectSnapshot
files: list[UserFileSnapshot] | None = None
persona_id_to_is_featured: dict[int, bool] | None = None
persona_id_to_featured: dict[int, bool] | None = None
@router.get(
@@ -370,13 +370,11 @@ def get_project_details(
if session.persona_id is not None
]
personas = get_personas_by_ids(persona_ids, db_session)
persona_id_to_is_featured = {
persona.id: persona.is_featured for persona in personas
}
persona_id_to_featured = {persona.id: persona.featured for persona in personas}
return ProjectPayload(
project=project,
files=files,
persona_id_to_is_featured=persona_id_to_is_featured,
persona_id_to_featured=persona_id_to_featured,
)

View File

@@ -142,7 +142,7 @@ def enable_or_disable_kg(
users=[user.id],
groups=[],
label_ids=[],
is_featured=False,
featured=False,
display_priority=0,
user_file_ids=[],
)

View File

@@ -5,7 +5,6 @@ from fastapi import Depends
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from onyx import __version__ as onyx_version
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import is_user_admin
@@ -17,7 +16,6 @@ from onyx.db.models import User
from onyx.db.notification import dismiss_all_notifications
from onyx.db.notification import get_notifications
from onyx.db.notification import update_notification_last_shown
from onyx.hooks.utils import HOOKS_AVAILABLE
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.features.build.utils import is_onyx_craft_enabled
@@ -81,8 +79,6 @@ def fetch_settings(
needs_reindexing=needs_reindexing,
onyx_craft_enabled=onyx_craft_enabled_for_user,
vector_db_enabled=not DISABLE_VECTOR_DB,
hooks_enabled=HOOKS_AVAILABLE,
version=onyx_version,
)

View File

@@ -104,7 +104,3 @@ class UserSettings(Settings):
# False when DISABLE_VECTOR_DB is set — connectors, RAG search, and
# document sets are unavailable.
vector_db_enabled: bool = True
# True when hooks are available: single-tenant deployment with HOOK_ENABLED=true.
hooks_enabled: bool = False
# Application version, read from the ONYX_VERSION env var at startup.
version: str | None = None

View File

@@ -53,12 +53,8 @@ logger = setup_logger()
class SearchToolConfig(BaseModel):
user_selected_filters: BaseFilters | None = None
# Vespa metadata filters for overflowing user files. These are NOT the
# IDs of the current project/persona — they are only set when the
# project's/persona's user files didn't fit in the LLM context window and
# must be found via vector DB search instead.
project_id_filter: int | None = None
persona_id_filter: int | None = None
project_id: int | None = None
persona_id: int | None = None
bypass_acl: bool = False
additional_context: str | None = None
slack_context: SlackContext | None = None
@@ -184,8 +180,8 @@ def construct_tools(
llm=llm,
document_index=document_index,
user_selected_filters=search_tool_config.user_selected_filters,
project_id_filter=search_tool_config.project_id_filter,
persona_id_filter=search_tool_config.persona_id_filter,
project_id=search_tool_config.project_id,
persona_id=search_tool_config.persona_id,
bypass_acl=search_tool_config.bypass_acl,
slack_context=search_tool_config.slack_context,
enable_slack_search=search_tool_config.enable_slack_search,
@@ -400,7 +396,6 @@ def construct_tools(
tool_definition=saved_tool.mcp_input_schema or {},
connection_config=connection_config,
user_email=user_email,
user_id=str(user.id),
user_oauth_token=mcp_user_oauth_token,
additional_headers=additional_mcp_headers,
)
@@ -433,8 +428,8 @@ def construct_tools(
llm=llm,
document_index=document_index,
user_selected_filters=search_tool_config.user_selected_filters,
project_id_filter=search_tool_config.project_id_filter,
persona_id_filter=search_tool_config.persona_id_filter,
project_id=search_tool_config.project_id,
persona_id=search_tool_config.persona_id,
bypass_acl=search_tool_config.bypass_acl,
slack_context=search_tool_config.slack_context,
enable_slack_search=search_tool_config.enable_slack_search,

View File

@@ -1,8 +1,6 @@
import json
from typing import Any
from mcp.client.auth import OAuthClientProvider
from onyx.chat.emitter import Emitter
from onyx.db.enums import MCPAuthenticationType
from onyx.db.enums import MCPTransport
@@ -49,7 +47,6 @@ class MCPTool(Tool[None]):
tool_definition: dict[str, Any],
connection_config: MCPConnectionConfig | None = None,
user_email: str = "",
user_id: str = "",
user_oauth_token: str | None = None,
additional_headers: dict[str, str] | None = None,
) -> None:
@@ -59,7 +56,6 @@ class MCPTool(Tool[None]):
self.mcp_server = mcp_server
self.connection_config = connection_config
self.user_email = user_email
self._user_id = user_id
self._user_oauth_token = user_oauth_token
self._additional_headers = additional_headers or {}
@@ -202,42 +198,12 @@ class MCPTool(Tool[None]):
llm_facing_response=llm_facing_response,
)
# For OAuth servers, construct OAuthClientProvider so the MCP SDK
# can refresh expired tokens automatically
auth: OAuthClientProvider | None = None
if (
self.mcp_server.auth_type == MCPAuthenticationType.OAUTH
and self.connection_config is not None
and self._user_id
):
if self.mcp_server.transport == MCPTransport.SSE:
logger.warning(
f"MCP tool '{self._name}': OAuth token refresh is not supported "
f"for SSE transport — auth provider will be ignored. "
f"Re-authentication may be required after token expiry."
)
else:
from onyx.server.features.mcp.api import UNUSED_RETURN_PATH
from onyx.server.features.mcp.api import make_oauth_provider
# user_id is the requesting user's UUID; safe here because
# UNUSED_RETURN_PATH ensures redirect_handler raises immediately
# and user_id is never consulted for Redis state lookups.
auth = make_oauth_provider(
self.mcp_server,
self._user_id,
UNUSED_RETURN_PATH,
self.connection_config.id,
None,
)
tool_result = call_mcp_tool(
self.mcp_server.server_url,
self._name,
llm_kwargs,
connection_headers=headers,
transport=self.mcp_server.transport or MCPTransport.STREAMABLE_HTTP,
auth=auth,
)
logger.info(f"MCP tool '{self._name}' executed successfully")
@@ -282,7 +248,6 @@ class MCPTool(Tool[None]):
"invalid token",
"invalid api key",
"invalid credentials",
"please reconnect to the server",
]
is_auth_error = any(

View File

@@ -764,7 +764,8 @@ class OpenURLTool(Tool[OpenURLToolOverrideKwargs]):
tags=None,
access_control_list=access_control_list,
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
project_id_filter=None,
user_file_ids=None,
project_id=None,
)
def _merge_indexed_and_crawled_results(

View File

@@ -244,11 +244,10 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
document_index: DocumentIndex,
# Respecting user selections
user_selected_filters: BaseFilters | None,
# Vespa metadata filters for overflowing user files. NOT the raw IDs
# of the current project/persona — only set when user files couldn't
# fit in the LLM context and need to be searched via vector DB.
project_id_filter: int | None,
persona_id_filter: int | None = None,
# If the chat is part of a project
project_id: int | None,
# If set, search scopes to files attached to this persona
persona_id: int | None = None,
bypass_acl: bool = False,
# Slack context for federated Slack search (tokens fetched internally)
slack_context: SlackContext | None = None,
@@ -262,8 +261,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
self.llm = llm
self.document_index = document_index
self.user_selected_filters = user_selected_filters
self.project_id_filter = project_id_filter
self.persona_id_filter = persona_id_filter
self.project_id = project_id
self.persona_id = persona_id
self.bypass_acl = bypass_acl
self.slack_context = slack_context
self.enable_slack_search = enable_slack_search
@@ -452,15 +451,13 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
hybrid_alpha=hybrid_alpha,
# For projects, the search scope is the project and has no other limits
user_selected_filters=(
self.user_selected_filters
if self.project_id_filter is None
else None
self.user_selected_filters if self.project_id is None else None
),
bypass_acl=self.bypass_acl,
limit=num_hits,
),
project_id_filter=self.project_id_filter,
persona_id_filter=self.persona_id_filter,
project_id=self.project_id,
persona_id=self.persona_id,
document_index=self.document_index,
user=self.user,
persona=self.persona,
@@ -577,7 +574,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
# Federated retrieval functions (non-Slack; Slack is separate)
if self.project_id_filter is not None:
if self.project_id is not None:
# Project mode ignores user filters → no federated sources
prefetch_source_types = None
else:
@@ -590,12 +587,16 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
persona_document_sets = (
[ds.name for ds in self.persona.document_sets] if self.persona else None
)
user_file_ids = (
[uf.id for uf in self.persona.user_files] if self.persona else None
)
federated_retrieval_infos = (
get_federated_retrieval_functions(
db_session=db_session,
user_id=self.user.id if self.user else None,
source_types=prefetch_source_types,
document_set_names=persona_document_sets,
user_file_ids=user_file_ids,
)
or []
)

View File

@@ -189,30 +189,3 @@ def mt_cloud_identify(
attribute="identify_user",
fallback=noop_fallback,
)(distinct_id, properties)
def mt_cloud_alias(
distinct_id: str,
anonymous_id: str,
) -> None:
"""Link an anonymous distinct_id to an identified user (Cloud only)."""
if not MULTI_TENANT:
return
fetch_versioned_implementation_with_fallback(
module="onyx.utils.posthog_client",
attribute="alias_user",
fallback=noop_fallback,
)(distinct_id, anonymous_id)
def mt_cloud_get_anon_id(request: Any) -> str | None:
"""Extract the anonymous distinct_id from the app PostHog cookie (Cloud only)."""
if not MULTI_TENANT or not request:
return None
return fetch_versioned_implementation_with_fallback(
module="onyx.utils.posthog_client",
attribute="get_anon_id_from_request",
fallback=noop_fallback,
)(request)

View File

@@ -1,170 +0,0 @@
#!/usr/bin/env python3
"""Benchmarks OpenSearchDocumentIndex latency.
Requires Onyx to be running as it reads search settings from the database.
Usage:
source .venv/bin/activate
python backend/scripts/debugging/opensearch/benchmark_retrieval.py --help
"""
import argparse
import statistics
import time
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.context.search.enums import QueryType
from onyx.context.search.models import IndexFilters
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.indexing.models import IndexingSetting
from scripts.debugging.opensearch.constants import DEV_TENANT_ID
from scripts.debugging.opensearch.embedding_io import load_query_embedding_from_file
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
DEFAULT_N = 50
def main() -> None:
def add_query_embedding_argument(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"-e",
"--embedding-file-path",
type=str,
required=True,
help="Path to the query embedding file.",
)
def add_query_string_argument(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"-q",
"--query",
type=str,
required=True,
help="Query string.",
)
parser = argparse.ArgumentParser(
description="A benchmarking tool to measure OpenSearch retrieval latency."
)
parser.add_argument(
"-n",
type=int,
default=DEFAULT_N,
help=f"Number of samples to take (default: {DEFAULT_N}).",
)
subparsers = parser.add_subparsers(
dest="query_type",
help="Query type to benchmark.",
required=True,
)
hybrid_parser = subparsers.add_parser(
"hybrid", help="Benchmark hybrid retrieval latency."
)
add_query_embedding_argument(hybrid_parser)
add_query_string_argument(hybrid_parser)
keyword_parser = subparsers.add_parser(
"keyword", help="Benchmark keyword retrieval latency."
)
add_query_string_argument(keyword_parser)
semantic_parser = subparsers.add_parser(
"semantic", help="Benchmark semantic retrieval latency."
)
add_query_embedding_argument(semantic_parser)
args = parser.parse_args()
if args.n < 1:
parser.error("Number of samples (-n) must be at least 1.")
if MULTI_TENANT:
CURRENT_TENANT_ID_CONTEXTVAR.set(DEV_TENANT_ID)
SqlEngine.init_engine(pool_size=1, max_overflow=0)
with get_session_with_current_tenant() as session:
search_settings = get_current_search_settings(session)
indexing_setting = IndexingSetting.from_db_model(search_settings)
tenant_state = TenantState(
tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT
)
index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
index_name=search_settings.index_name,
embedding_dim=indexing_setting.final_embedding_dim,
embedding_precision=indexing_setting.embedding_precision,
)
filters = IndexFilters(
access_control_list=[],
tenant_id=get_current_tenant_id(),
)
if args.query_type == "hybrid":
embedding = load_query_embedding_from_file(args.embedding_file_path)
search_callable = lambda: index.hybrid_retrieval( # noqa: E731
query=args.query,
query_embedding=embedding,
final_keywords=None,
# This arg doesn't do anything right now.
query_type=QueryType.KEYWORD,
filters=filters,
num_to_retrieve=NUM_RETURNED_HITS,
)
elif args.query_type == "keyword":
search_callable = lambda: index.keyword_retrieval( # noqa: E731
query=args.query,
filters=filters,
num_to_retrieve=NUM_RETURNED_HITS,
)
elif args.query_type == "semantic":
embedding = load_query_embedding_from_file(args.embedding_file_path)
search_callable = lambda: index.semantic_retrieval( # noqa: E731
query_embedding=embedding,
filters=filters,
num_to_retrieve=NUM_RETURNED_HITS,
)
else:
raise ValueError(f"Invalid query type: {args.query_type}")
print(f"Running {args.n} invocations of {args.query_type} retrieval...")
latencies: list[float] = []
for i in range(args.n):
start = time.perf_counter()
results = search_callable()
elapsed_ms = (time.perf_counter() - start) * 1000
latencies.append(elapsed_ms)
# Print the current iteration and its elapsed time on the same line.
print(
f" [{i:>{len(str(args.n))}}] {elapsed_ms:7.1f} ms ({len(results)} results) (top result doc ID, chunk idx: {results[0].document_id if results else 'N/A'}, {results[0].chunk_id if results else 'N/A'})",
end="\r",
flush=True,
)
print()
print(f"Results over {args.n} invocations:")
print(f" mean: {statistics.mean(latencies):7.1f} ms")
print(
f" stdev: {statistics.stdev(latencies):7.1f} ms"
if args.n > 1
else " stdev: N/A (only 1 sample)"
)
print(f" max: {max(latencies):7.1f} ms (i: {latencies.index(max(latencies))})")
print(f" min: {min(latencies):7.1f} ms (i: {latencies.index(min(latencies))})")
if args.n >= 20:
print(f" p50: {statistics.median(latencies):7.1f} ms")
print(f" p95: {statistics.quantiles(latencies, n=20)[-1]:7.1f} ms")
if __name__ == "__main__":
main()

View File

@@ -1 +0,0 @@
DEV_TENANT_ID = "tenant_dev"

View File

@@ -1,64 +0,0 @@
#!/usr/bin/env python3
"""Embeds a query and saves the embedding to a file.
Requires Onyx to be running as it reads search settings from the database.
Usage:
source .venv/bin/activate
python backend/scripts/debugging/opensearch/embed_and_save.py --help
"""
import argparse
import time
from onyx.context.search.utils import get_query_embedding
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from scripts.debugging.opensearch.constants import DEV_TENANT_ID
from scripts.debugging.opensearch.embedding_io import save_query_embedding_to_file
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def main() -> None:
parser = argparse.ArgumentParser(
description="A tool to embed a query and save the embedding to a file."
)
parser.add_argument(
"-q",
"--query",
type=str,
required=True,
help="Query string to embed.",
)
parser.add_argument(
"-f",
"--file-path",
type=str,
required=True,
help="Path to the output file to save the embedding to.",
)
args = parser.parse_args()
if MULTI_TENANT:
CURRENT_TENANT_ID_CONTEXTVAR.set(DEV_TENANT_ID)
SqlEngine.init_engine(pool_size=1, max_overflow=0)
with get_session_with_current_tenant() as session:
start = time.perf_counter()
query_embedding = get_query_embedding(
query=args.query,
db_session=session,
embedding_model=None,
)
elapsed_ms = (time.perf_counter() - start) * 1000
save_query_embedding_to_file(query_embedding, args.file_path)
print(
f"Query embedding of dimension {len(query_embedding)} generated in {elapsed_ms:.1f} ms and saved to {args.file_path}."
)
if __name__ == "__main__":
main()

View File

@@ -1,43 +0,0 @@
from shared_configs.model_server_models import Embedding
def load_query_embedding_from_file(file_path: str) -> Embedding:
"""Returns an embedding vector read from a file.
The file should be formatted as follows:
- The first line should contain an integer representing the embedding
dimension.
- Every subsequent line should contain a float value representing a
component of the embedding vector.
- The size and embedding content should all be delimited by a newline.
Args:
file_path: Path to the file containing the embedding vector.
Returns:
Embedding: The embedding vector.
"""
with open(file_path, "r") as f:
dimension = int(f.readline().strip())
embedding = [float(line.strip()) for line in f.readlines()]
assert len(embedding) == dimension, "Embedding dimension mismatch."
return embedding
def save_query_embedding_to_file(embedding: Embedding, file_path: str) -> None:
"""Saves an embedding vector to a file.
The file will be formatted as follows:
- The first line will contain the embedding dimension.
- Every subsequent line will contain a float value representing a
component of the embedding vector.
- The size and embedding content will all be delimited by a newline.
Args:
embedding: The embedding vector to save.
file_path: Path to the file to save the embedding vector to.
"""
with open(file_path, "w") as f:
f.write(f"{len(embedding)}\n")
for component in embedding:
f.write(f"{component}\n")

View File

@@ -2,10 +2,9 @@
"""A utility to interact with OpenSearch.
Usage:
source .venv/bin/activate
python backend/scripts/debugging/opensearch/opensearch_debug.py --help
python backend/scripts/debugging/opensearch/opensearch_debug.py list
python backend/scripts/debugging/opensearch/opensearch_debug.py delete <index_name>
python3 opensearch_debug.py --help
python3 opensearch_debug.py list
python3 opensearch_debug.py delete <index_name>
Environment Variables:
OPENSEARCH_HOST: OpenSearch host
@@ -108,15 +107,16 @@ def main() -> None:
parser = argparse.ArgumentParser(
description="A utility to interact with OpenSearch."
)
add_standard_arguments(parser)
subparsers = parser.add_subparsers(
dest="command", help="Command to execute.", required=True
)
subparsers.add_parser("list", help="List all indices with info.")
list_parser = subparsers.add_parser("list", help="List all indices with info.")
add_standard_arguments(list_parser)
delete_parser = subparsers.add_parser("delete", help="Delete an index.")
delete_parser.add_argument("index", help="Index name.", type=str)
add_standard_arguments(delete_parser)
args = parser.parse_args()

View File

@@ -83,7 +83,7 @@ def test_stream_chat_message_objects_without_web_search(
db_session=db_session,
tool_ids=[], # Explicitly no tools
document_set_ids=None,
is_listed=True,
is_visible=True,
)
# Create a chat session with our test persona

View File

@@ -91,7 +91,7 @@ def _create_test_persona(
document_sets=[],
users=[user],
groups=[],
is_listed=True,
is_visible=True,
is_public=True,
display_priority=None,
starter_messages=None,

View File

@@ -63,7 +63,7 @@ def _create_persona(db_session: Session, user: User) -> Persona:
document_sets=[],
users=[user],
groups=[],
is_listed=True,
is_visible=True,
is_public=True,
display_priority=None,
starter_messages=None,

View File

@@ -1,248 +0,0 @@
"""Shared fixtures for document_index external dependency tests.
Provides Vespa and OpenSearch index setup, tenant context, and chunk helpers.
"""
import os
import time
import uuid
from collections.abc import Generator
from unittest.mock import patch
import httpx
import pytest
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchOldDocumentIndex,
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
from tests.external_dependency_unit.constants import TEST_TENANT_ID
EMBEDDING_DIM = 128
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_chunk(
doc_id: str,
chunk_id: int = 0,
content: str = "test content",
) -> DocMetadataAwareIndexChunk:
"""Create a chunk suitable for external dependency testing (128-dim embeddings)."""
tenant_id = get_current_tenant_id()
access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
)
embeddings = ChunkEmbedding(
full_embedding=[1.0] + [0.0] * (EMBEDDING_DIM - 1),
mini_chunk_embeddings=[],
)
source_document = Document(
id=doc_id,
semantic_identifier="test_doc",
source=DocumentSource.FILE,
sections=[],
metadata={},
title="test title",
)
return DocMetadataAwareIndexChunk(
tenant_id=tenant_id,
access=access,
document_sets=set(),
user_project=[],
personas=[],
boost=0,
aggregated_chunk_boost_factor=0,
ancestor_hierarchy_node_ids=[],
embeddings=embeddings,
title_embedding=[1.0] + [0.0] * (EMBEDDING_DIM - 1),
source_document=source_document,
title_prefix="",
metadata_suffix_keyword="",
metadata_suffix_semantic="",
contextual_rag_reserved_tokens=0,
doc_summary="",
chunk_context="",
mini_chunk_texts=None,
large_chunk_id=None,
chunk_id=chunk_id,
blurb=content[:50],
content=content,
source_links={0: ""},
image_file_id=None,
section_continuation=False,
)
def make_indexing_metadata(
doc_ids: list[str],
old_counts: list[int],
new_counts: list[int],
) -> IndexingMetadata:
return IndexingMetadata(
doc_id_to_chunk_cnt_diff={
doc_id: IndexingMetadata.ChunkCounts(
old_chunk_cnt=old,
new_chunk_cnt=new,
)
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
}
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def tenant_context() -> Generator[None, None, None]:
"""Sets up tenant context for testing."""
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
try:
yield
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@pytest.fixture(scope="module")
def test_index_name() -> Generator[str, None, None]:
yield f"test_index_{uuid.uuid4().hex[:8]}"
@pytest.fixture(scope="module")
def httpx_client() -> Generator[httpx.Client, None, None]:
client = get_vespa_http_client()
try:
yield client
finally:
client.close()
@pytest.fixture(scope="module")
def vespa_index(
httpx_client: httpx.Client,
tenant_context: None, # noqa: ARG001
test_index_name: str,
) -> Generator[VespaIndex, None, None]:
"""Create a Vespa index, wait for schema readiness, and yield it."""
vespa_idx = VespaIndex(
index_name=test_index_name,
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)
backend_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "..", "..")
)
with patch("os.getcwd", return_value=backend_dir):
vespa_idx.ensure_indices_exist(
primary_embedding_dim=EMBEDDING_DIM,
primary_embedding_precision=EmbeddingPrecision.FLOAT,
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
if not wait_for_vespa_with_timeout(wait_limit=90):
pytest.fail("Vespa is not available.")
# Wait until the schema is actually ready for writes on content nodes. We
# probe by attempting a PUT; 200 means the schema is live, 400 means not
# yet. This is only temporary until we entirely move off of Vespa.
probe_doc = {
"fields": {
"document_id": "__probe__",
"chunk_id": 0,
"blurb": "",
"title": "",
"skip_title": True,
"content": "",
"content_summary": "",
"source_type": "file",
"source_links": "null",
"semantic_identifier": "",
"section_continuation": False,
"large_chunk_reference_ids": [],
"metadata": "{}",
"metadata_list": [],
"metadata_suffix": "",
"chunk_context": "",
"doc_summary": "",
"embeddings": {"full_chunk": [1.0] + [0.0] * (EMBEDDING_DIM - 1)},
"access_control_list": {},
"document_sets": {},
"image_file_name": None,
"user_project": [],
"personas": [],
"boost": 0.0,
"aggregated_chunk_boost_factor": 0.0,
"primary_owners": [],
"secondary_owners": [],
}
}
probe_url = (
f"http://localhost:8081/document/v1/default/{test_index_name}/docid/__probe__"
)
schema_ready = False
for _ in range(60):
resp = httpx_client.post(probe_url, json=probe_doc)
if resp.status_code == 200:
schema_ready = True
httpx_client.delete(probe_url)
break
time.sleep(1)
if not schema_ready:
pytest.fail(f"Vespa schema '{test_index_name}' did not become ready in time.")
yield vespa_idx
@pytest.fixture(scope="module")
def opensearch_old_index(
tenant_context: None, # noqa: ARG001
test_index_name: str,
) -> Generator[OpenSearchOldDocumentIndex, None, None]:
"""Create an OpenSearch index via the old adapter and yield it."""
if not wait_for_opensearch_with_timeout():
pytest.fail("OpenSearch is not available.")
opensearch_idx = OpenSearchOldDocumentIndex(
index_name=test_index_name,
embedding_dim=EMBEDDING_DIM,
embedding_precision=EmbeddingPrecision.FLOAT,
secondary_index_name=None,
secondary_embedding_dim=None,
secondary_embedding_precision=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,
multitenant=MULTI_TENANT,
)
opensearch_idx.ensure_indices_exist(
primary_embedding_dim=EMBEDDING_DIM,
primary_embedding_precision=EmbeddingPrecision.FLOAT,
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
yield opensearch_idx

View File

@@ -1,203 +0,0 @@
"""External dependency tests for the new DocumentIndex interface.
These tests assume Vespa and OpenSearch are running.
"""
import time
import uuid
from collections.abc import Generator
import httpx
import pytest
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.interfaces_new import DocumentIndex as DocumentIndexNew
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchDocumentIndex,
)
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchOldDocumentIndex,
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from tests.external_dependency_unit.constants import TEST_TENANT_ID
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
from tests.external_dependency_unit.document_index.conftest import make_chunk
from tests.external_dependency_unit.document_index.conftest import (
make_indexing_metadata,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(scope="module")
def vespa_document_index(
vespa_index: VespaIndex, # noqa: ARG001 — ensures schema exists
httpx_client: httpx.Client,
test_index_name: str,
) -> Generator[VespaDocumentIndex, None, None]:
yield VespaDocumentIndex(
index_name=test_index_name,
tenant_state=TenantState(tenant_id=TEST_TENANT_ID, multitenant=False),
large_chunks_enabled=False,
httpx_client=httpx_client,
)
@pytest.fixture(scope="module")
def opensearch_document_index(
opensearch_old_index: OpenSearchOldDocumentIndex, # noqa: ARG001 — ensures index exists
test_index_name: str,
) -> Generator[OpenSearchDocumentIndex, None, None]:
yield OpenSearchDocumentIndex(
tenant_state=TenantState(tenant_id=TEST_TENANT_ID, multitenant=False),
index_name=test_index_name,
embedding_dim=EMBEDDING_DIM,
embedding_precision=EmbeddingPrecision.FLOAT,
)
@pytest.fixture(scope="module")
def document_indices(
vespa_document_index: VespaDocumentIndex,
opensearch_document_index: OpenSearchDocumentIndex,
) -> Generator[list[DocumentIndexNew], None, None]:
yield [opensearch_document_index, vespa_document_index]
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestDocumentIndexNew:
"""Tests the new DocumentIndex interface against real Vespa and OpenSearch."""
def test_index_single_new_doc(
self,
document_indices: list[DocumentIndexNew],
tenant_context: None, # noqa: ARG002
) -> None:
"""Indexing a single new document returns one record with already_existed=False."""
for document_index in document_indices:
doc_id = f"test_single_new_{uuid.uuid4().hex[:8]}"
chunk = make_chunk(doc_id)
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[1])
results = document_index.index(chunks=[chunk], indexing_metadata=metadata)
assert len(results) == 1
assert results[0].document_id == doc_id
assert results[0].already_existed is False
def test_index_existing_doc_already_existed_true(
self,
document_indices: list[DocumentIndexNew],
tenant_context: None, # noqa: ARG002
) -> None:
"""Re-indexing a doc with previous chunks returns already_existed=True."""
for document_index in document_indices:
doc_id = f"test_existing_{uuid.uuid4().hex[:8]}"
chunk = make_chunk(doc_id)
# First index — brand new document.
metadata_first = make_indexing_metadata(
[doc_id], old_counts=[0], new_counts=[1]
)
document_index.index(chunks=[chunk], indexing_metadata=metadata_first)
# Allow near-real-time indexing to settle (needed for Vespa).
time.sleep(1)
# Re-index — old_chunk_cnt=1 signals the document already existed.
metadata_second = make_indexing_metadata(
[doc_id], old_counts=[1], new_counts=[1]
)
results = document_index.index(
chunks=[chunk], indexing_metadata=metadata_second
)
assert len(results) == 1
assert results[0].already_existed is True
def test_index_multiple_docs(
self,
document_indices: list[DocumentIndexNew],
tenant_context: None, # noqa: ARG002
) -> None:
"""Indexing multiple documents returns one record per unique document."""
for document_index in document_indices:
doc1 = f"test_multi_1_{uuid.uuid4().hex[:8]}"
doc2 = f"test_multi_2_{uuid.uuid4().hex[:8]}"
chunks = [
make_chunk(doc1, chunk_id=0),
make_chunk(doc1, chunk_id=1),
make_chunk(doc2, chunk_id=0),
]
metadata = make_indexing_metadata(
[doc1, doc2], old_counts=[0, 0], new_counts=[2, 1]
)
results = document_index.index(chunks=chunks, indexing_metadata=metadata)
result_map = {r.document_id: r.already_existed for r in results}
assert len(result_map) == 2
assert result_map[doc1] is False
assert result_map[doc2] is False
def test_index_deduplicates_doc_ids_in_results(
self,
document_indices: list[DocumentIndexNew],
tenant_context: None, # noqa: ARG002
) -> None:
"""Multiple chunks from the same document produce only one
DocumentInsertionRecord."""
for document_index in document_indices:
doc_id = f"test_dedup_{uuid.uuid4().hex[:8]}"
chunks = [make_chunk(doc_id, chunk_id=i) for i in range(5)]
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[5])
results = document_index.index(chunks=chunks, indexing_metadata=metadata)
assert len(results) == 1
assert results[0].document_id == doc_id
def test_index_mixed_new_and_existing_docs(
self,
document_indices: list[DocumentIndexNew],
tenant_context: None, # noqa: ARG002
) -> None:
"""A batch with both new and existing documents returns the correct
already_existed flag for each."""
for document_index in document_indices:
existing_doc = f"test_mixed_exist_{uuid.uuid4().hex[:8]}"
new_doc = f"test_mixed_new_{uuid.uuid4().hex[:8]}"
# Pre-index the existing document.
pre_chunk = make_chunk(existing_doc)
pre_metadata = make_indexing_metadata(
[existing_doc], old_counts=[0], new_counts=[1]
)
document_index.index(chunks=[pre_chunk], indexing_metadata=pre_metadata)
time.sleep(1)
# Now index a batch with the existing doc and a new doc.
chunks = [
make_chunk(existing_doc, chunk_id=0),
make_chunk(new_doc, chunk_id=0),
]
metadata = make_indexing_metadata(
[existing_doc, new_doc], old_counts=[1, 0], new_counts=[1, 1]
)
results = document_index.index(chunks=chunks, indexing_metadata=metadata)
result_map = {r.document_id: r.already_existed for r in results}
assert len(result_map) == 2
assert result_map[existing_doc] is True
assert result_map[new_doc] is False

View File

@@ -1,41 +1,275 @@
"""External dependency tests for the old DocumentIndex interface.
These tests assume Vespa and OpenSearch are running.
TODO(ENG-3764)(andrei): Consolidate some of these test fixtures.
"""
import os
import time
import uuid
from collections.abc import Generator
from unittest.mock import patch
import httpx
import pytest
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.context.search.models import IndexFilters
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import IndexBatchParams
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.opensearch.client import wait_for_opensearch_with_timeout
from onyx.document_index.opensearch.opensearch_document_index import (
OpenSearchOldDocumentIndex,
)
from onyx.document_index.vespa.index import VespaIndex
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
from tests.external_dependency_unit.document_index.conftest import make_chunk
from tests.external_dependency_unit.constants import TEST_TENANT_ID
@pytest.fixture(scope="module")
def opensearch_available() -> Generator[None, None, None]:
"""Verifies OpenSearch is running, fails the test if not."""
if not wait_for_opensearch_with_timeout():
pytest.fail("OpenSearch is not available.")
yield # Test runs here.
@pytest.fixture(scope="module")
def test_index_name() -> Generator[str, None, None]:
yield f"test_index_{uuid.uuid4().hex[:8]}" # Test runs here.
@pytest.fixture(scope="module")
def tenant_context() -> Generator[None, None, None]:
"""Sets up tenant context for testing."""
token = CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
try:
yield # Test runs here.
finally:
# Reset the tenant context after the test
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@pytest.fixture(scope="module")
def httpx_client() -> Generator[httpx.Client, None, None]:
client = get_vespa_http_client()
try:
yield client
finally:
client.close()
@pytest.fixture(scope="module")
def vespa_document_index(
httpx_client: httpx.Client,
tenant_context: None, # noqa: ARG001
test_index_name: str,
) -> Generator[VespaIndex, None, None]:
vespa_index = VespaIndex(
index_name=test_index_name,
secondary_index_name=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
)
backend_dir = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "..", "..")
)
with patch("os.getcwd", return_value=backend_dir):
vespa_index.ensure_indices_exist(
primary_embedding_dim=128,
primary_embedding_precision=EmbeddingPrecision.FLOAT,
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
# Verify Vespa is running, fails the test if not. Try 90 seconds for testing
# in CI. We have to do this here because this endpoint only becomes live
# once we create an index.
if not wait_for_vespa_with_timeout(wait_limit=90):
pytest.fail("Vespa is not available.")
# Wait until the schema is actually ready for writes on content nodes. We
# probe by attempting a PUT; 200 means the schema is live, 400 means not
# yet. This is so scuffed but running the test is really flakey otherwise;
# this is only temporary until we entirely move off of Vespa.
probe_doc = {
"fields": {
"document_id": "__probe__",
"chunk_id": 0,
"blurb": "",
"title": "",
"skip_title": True,
"content": "",
"content_summary": "",
"source_type": "file",
"source_links": "null",
"semantic_identifier": "",
"section_continuation": False,
"large_chunk_reference_ids": [],
"metadata": "{}",
"metadata_list": [],
"metadata_suffix": "",
"chunk_context": "",
"doc_summary": "",
"embeddings": {"full_chunk": [1.0] + [0.0] * 127},
"access_control_list": {},
"document_sets": {},
"image_file_name": None,
"user_project": [],
"personas": [],
"boost": 0.0,
"aggregated_chunk_boost_factor": 0.0,
"primary_owners": [],
"secondary_owners": [],
}
}
schema_ready = False
probe_url = (
f"http://localhost:8081/document/v1/default/{test_index_name}/docid/__probe__"
)
for _ in range(60):
resp = httpx_client.post(probe_url, json=probe_doc)
if resp.status_code == 200:
schema_ready = True
# Clean up the probe document.
httpx_client.delete(probe_url)
break
time.sleep(1)
if not schema_ready:
pytest.fail(f"Vespa schema '{test_index_name}' did not become ready in time.")
yield vespa_index # Test runs here.
# TODO(ENG-3765)(andrei): Explicitly cleanup index. Not immediately
# pressing; in CI we should be using fresh instances of dependencies each
# time anyway.
@pytest.fixture(scope="module")
def opensearch_document_index(
opensearch_available: None, # noqa: ARG001
tenant_context: None, # noqa: ARG001
test_index_name: str,
) -> Generator[OpenSearchOldDocumentIndex, None, None]:
opensearch_index = OpenSearchOldDocumentIndex(
index_name=test_index_name,
embedding_dim=128,
embedding_precision=EmbeddingPrecision.FLOAT,
secondary_index_name=None,
secondary_embedding_dim=None,
secondary_embedding_precision=None,
large_chunks_enabled=False,
secondary_large_chunks_enabled=None,
multitenant=MULTI_TENANT,
)
opensearch_index.ensure_indices_exist(
primary_embedding_dim=128,
primary_embedding_precision=EmbeddingPrecision.FLOAT,
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
yield opensearch_index # Test runs here.
# TODO(ENG-3765)(andrei): Explicitly cleanup index. Not immediately
# pressing; in CI we should be using fresh instances of dependencies each
# time anyway.
@pytest.fixture(scope="module")
def document_indices(
vespa_index: VespaIndex,
opensearch_old_index: OpenSearchOldDocumentIndex,
vespa_document_index: VespaIndex,
opensearch_document_index: OpenSearchOldDocumentIndex,
) -> Generator[list[DocumentIndex], None, None]:
# Ideally these are parametrized; doing so with pytest fixtures is tricky.
yield [opensearch_old_index, vespa_index]
yield [opensearch_document_index, vespa_document_index] # Test runs here.
@pytest.fixture(scope="function")
def chunks(
tenant_context: None, # noqa: ARG001
) -> Generator[list[DocMetadataAwareIndexChunk], None, None]:
yield [make_chunk("test_doc", chunk_id=i) for i in range(5)]
result = []
chunk_count = 5
doc_id = "test_doc"
tenant_id = get_current_tenant_id()
access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=True,
)
document_sets: set[str] = set()
user_project: list[int] = list()
personas: list[int] = list()
boost = 0
blurb = "blurb"
content = "content"
title_prefix = ""
doc_summary = ""
chunk_context = ""
title_embedding = [1.0] + [0] * 127
# Full 0 vectors are not supported for cos similarity.
embeddings = ChunkEmbedding(
full_embedding=[1.0] + [0] * 127, mini_chunk_embeddings=[]
)
source_document = Document(
id=doc_id,
semantic_identifier="semantic identifier",
source=DocumentSource.FILE,
sections=[],
metadata={},
title="title",
)
metadata_suffix_keyword = ""
image_file_id = None
source_links: dict[int, str] = {0: ""}
ancestor_hierarchy_node_ids: list[int] = []
for i in range(chunk_count):
result.append(
DocMetadataAwareIndexChunk(
tenant_id=tenant_id,
access=access,
document_sets=document_sets,
user_project=user_project,
personas=personas,
boost=boost,
aggregated_chunk_boost_factor=0,
ancestor_hierarchy_node_ids=ancestor_hierarchy_node_ids,
embeddings=embeddings,
title_embedding=title_embedding,
source_document=source_document,
title_prefix=title_prefix,
metadata_suffix_keyword=metadata_suffix_keyword,
metadata_suffix_semantic="",
contextual_rag_reserved_tokens=0,
doc_summary=doc_summary,
chunk_context=chunk_context,
mini_chunk_texts=None,
large_chunk_id=None,
chunk_id=i,
blurb=blurb,
content=content,
source_links=source_links,
image_file_id=image_file_id,
section_continuation=False,
)
)
yield result # Test runs here.
@pytest.fixture(scope="function")
@@ -102,8 +336,8 @@ class TestDocumentIndexOld:
project_persona_filters = IndexFilters(
access_control_list=None,
tenant_id=tenant_id,
project_id_filter=1,
persona_id_filter=2,
project_id=1,
persona_id=2,
# We need this even though none of the chunks belong to a
# document set because project_id and persona_id are only
# additive filters in the event the agent has knowledge scope;

View File

@@ -1,30 +1,34 @@
"""Tests for OpenSearch assistant knowledge filter construction.
These tests verify that when an assistant (persona) has knowledge attached,
the search filter includes the appropriate scope filters with OR logic (not AND),
ensuring documents are discoverable across knowledge types like attached documents,
hierarchy nodes, document sets, and persona/project user files.
These tests verify that when an assistant (persona) has user files attached,
the search filter includes those user file IDs in the assistant knowledge filter
with OR logic (not AND), ensuring user files are discoverable alongside other
knowledge types like attached documents and hierarchy nodes.
This prevents a regression where user_file_ids were added as a separate AND
filter, making it impossible to find user files when the assistant also had
attached documents or hierarchy nodes (since no document could match both).
"""
from typing import Any
from uuid import UUID
from onyx.configs.constants import DocumentSource
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
from onyx.document_index.opensearch.search import DocumentQuery
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
USER_FILE_ID = UUID("6ad84e45-4450-406c-9d36-fcb5e74aca6b")
ATTACHED_DOCUMENT_ID = "https://docs.google.com/document/d/test-doc-id"
HIERARCHY_NODE_ID = 42
PERSONA_ID = 7
def _get_search_filters(
source_types: list[DocumentSource],
user_file_ids: list[UUID],
attached_document_ids: list[str] | None,
hierarchy_node_ids: list[int] | None,
persona_id_filter: int | None = None,
document_sets: list[str] | None = None,
) -> list[dict[str, Any]]:
return DocumentQuery._get_search_filters(
tenant_state=TenantState(tenant_id=POSTGRES_DEFAULT_SCHEMA, multitenant=False),
@@ -32,14 +36,15 @@ def _get_search_filters(
access_control_list=["user_email:test@example.com"],
source_types=source_types,
tags=[],
document_sets=document_sets or [],
project_id_filter=None,
persona_id_filter=persona_id_filter,
document_sets=[],
project_id=None,
persona_id=None,
time_cutoff=None,
min_chunk_index=None,
max_chunk_index=None,
max_chunk_size=None,
document_id=None,
user_file_ids=user_file_ids,
attached_document_ids=attached_document_ids,
hierarchy_node_ids=hierarchy_node_ids,
)
@@ -48,97 +53,137 @@ def _get_search_filters(
class TestAssistantKnowledgeFilter:
"""Tests for assistant knowledge filter construction in OpenSearch queries."""
def test_persona_id_filter_added_when_knowledge_scope_exists(self) -> None:
"""persona_id_filter should be OR'd into the knowledge scope filter
when explicit knowledge attachments (attached_document_ids,
hierarchy_node_ids, document_sets) are present."""
def test_user_file_ids_included_in_assistant_knowledge_filter(self) -> None:
"""
Tests that user_file_ids are included in the assistant knowledge filter
with OR logic when the assistant has both user files and attached documents.
This prevents the regression where user files were ANDed with other
knowledge types, making them unfindable.
"""
# Under test: Call the filter construction method directly
filter_clauses = _get_search_filters(
source_types=[DocumentSource.FILE],
source_types=[DocumentSource.FILE, DocumentSource.USER_FILE],
user_file_ids=[USER_FILE_ID],
attached_document_ids=[ATTACHED_DOCUMENT_ID],
hierarchy_node_ids=[HIERARCHY_NODE_ID],
persona_id_filter=PERSONA_ID,
)
# Postcondition: Find the assistant knowledge filter (bool with should clauses)
knowledge_filter = None
for clause in filter_clauses:
if "bool" in clause and "should" in clause["bool"]:
# Check if this is the knowledge filter (has minimum_should_match=1)
if clause["bool"].get("minimum_should_match") == 1:
knowledge_filter = clause
break
assert knowledge_filter is not None, (
"Expected to find an assistant knowledge filter with "
"'minimum_should_match: 1'"
)
assert (
knowledge_filter is not None
), "Expected to find an assistant knowledge filter with 'minimum_should_match: 1'"
# The knowledge filter should have 3 should clauses (user files, attached docs, hierarchy nodes)
should_clauses = knowledge_filter["bool"]["should"]
persona_found = any(
clause.get("term", {}).get(PERSONAS_FIELD_NAME, {}).get("value")
== PERSONA_ID
for clause in should_clauses
)
assert persona_found, (
f"Expected persona_id={PERSONA_ID} filter on {PERSONAS_FIELD_NAME} "
f"in should clauses. Got: {should_clauses}"
assert (
len(should_clauses) == 3
), f"Expected 3 should clauses (user_file, attached_doc, hierarchy_node), got {len(should_clauses)}"
# Verify user_file_id is in one of the should clauses
user_file_filter_found = False
for should_clause in should_clauses:
# The user file filter uses a nested bool with should for each file ID
if "bool" in should_clause and "should" in should_clause["bool"]:
for term_clause in should_clause["bool"]["should"]:
if "term" in term_clause:
term_value = term_clause["term"].get(DOCUMENT_ID_FIELD_NAME, {})
if term_value.get("value") == str(USER_FILE_ID):
user_file_filter_found = True
break
assert user_file_filter_found, (
f"Expected user_file_id {USER_FILE_ID} to be in the assistant knowledge "
f"filter's should clauses. Filter structure: {knowledge_filter}"
)
def test_persona_id_filter_alone_creates_knowledge_scope(self) -> None:
"""persona_id_filter IS a primary knowledge scope trigger — a persona
with user files is explicit knowledge, so it should restrict
search on its own."""
def test_user_file_ids_only_creates_knowledge_filter(self) -> None:
"""
Tests that when only user_file_ids are provided (no attached_documents or
hierarchy_nodes), the assistant knowledge filter is still created with the
user file IDs.
"""
# Precondition
filter_clauses = _get_search_filters(
source_types=[],
source_types=[DocumentSource.USER_FILE],
user_file_ids=[USER_FILE_ID],
attached_document_ids=None,
hierarchy_node_ids=None,
persona_id_filter=PERSONA_ID,
)
knowledge_filter = None
# Postcondition: Find filter that contains our user file ID
user_file_filter_found = False
for clause in filter_clauses:
if "bool" in clause and "should" in clause["bool"]:
if clause["bool"].get("minimum_should_match") == 1:
knowledge_filter = clause
break
clause_str = str(clause)
if str(USER_FILE_ID) in clause_str:
user_file_filter_found = True
break
assert (
knowledge_filter is not None
), "Expected persona_id_filter alone to create a knowledge scope filter"
persona_found = any(
clause.get("term", {}).get(PERSONAS_FIELD_NAME, {}).get("value")
== PERSONA_ID
for clause in knowledge_filter["bool"]["should"]
)
assert persona_found, (
f"Expected persona_id={PERSONA_ID} filter in knowledge scope. "
f"Got: {knowledge_filter}"
)
user_file_filter_found
), f"Expected user_file_id {USER_FILE_ID} to be in the filter clauses. Got: {filter_clauses}"
def test_no_separate_user_file_filter_when_assistant_has_knowledge(self) -> None:
"""
Tests that user_file_ids are NOT added as a separate AND filter when the
assistant has other knowledge attached (attached_documents or hierarchy_nodes).
"""
def test_knowledge_filter_with_document_sets_and_persona_filter(self) -> None:
"""document_sets and persona_id_filter should be OR'd together in
the knowledge scope filter."""
filter_clauses = _get_search_filters(
source_types=[],
attached_document_ids=None,
source_types=[DocumentSource.FILE, DocumentSource.USER_FILE],
user_file_ids=[USER_FILE_ID],
attached_document_ids=[ATTACHED_DOCUMENT_ID],
hierarchy_node_ids=None,
persona_id_filter=PERSONA_ID,
document_sets=["engineering"],
)
knowledge_filter = None
# Postcondition: Count how many times user_file_id appears in filter clauses
# It should appear exactly once (in the knowledge filter), not twice
user_file_id_str = str(USER_FILE_ID)
occurrences = 0
for clause in filter_clauses:
if "bool" in clause and "should" in clause["bool"]:
if clause["bool"].get("minimum_should_match") == 1:
knowledge_filter = clause
break
if user_file_id_str in str(clause):
occurrences += 1
assert (
knowledge_filter is not None
), "Expected knowledge filter when document_sets is provided"
assert occurrences == 1, (
f"Expected user_file_id to appear exactly once in filter clauses "
f"(inside the assistant knowledge filter), but found {occurrences} "
f"occurrences. This suggests user_file_ids is being added as both a "
f"separate AND filter and inside the knowledge filter. "
f"Filter clauses: {filter_clauses}"
)
filter_str = str(knowledge_filter)
assert (
"engineering" in filter_str
), "Expected document_set 'engineering' in knowledge filter"
assert (
str(PERSONA_ID) in filter_str
), f"Expected persona_id_filter {PERSONA_ID} in knowledge filter"
def test_multiple_user_files_all_included_in_filter(self) -> None:
"""
Tests that when multiple user files are attached to an assistant,
all of them are included in the filter.
"""
# Precondition
user_file_ids = [
UUID("6ad84e45-4450-406c-9d36-fcb5e74aca6b"),
UUID("7be95f56-5561-517d-ae47-acd6f85bdb7c"),
UUID("8cf06a67-6672-628e-bf58-ade7a96cec8d"),
]
filter_clauses = _get_search_filters(
source_types=[DocumentSource.USER_FILE],
user_file_ids=user_file_ids,
attached_document_ids=[ATTACHED_DOCUMENT_ID],
hierarchy_node_ids=None,
)
# Postcondition: All user file IDs should be in the filter
filter_str = str(filter_clauses)
for user_file_id in user_file_ids:
assert (
str(user_file_id) in filter_str
), f"Expected user_file_id {user_file_id} to be in the filter clauses"

View File

@@ -52,7 +52,7 @@ def _create_test_persona_with_mcp_tool(
document_sets=[],
users=[user],
groups=[],
is_listed=True,
is_visible=True,
is_public=True,
display_priority=None,
starter_messages=None,
@@ -368,10 +368,9 @@ class TestMCPPassThroughOAuth:
def mock_call_mcp_tool(
server_url: str, # noqa: ARG001
tool_name: str, # noqa: ARG001
arguments: dict[str, Any], # noqa: ARG001
kwargs: dict[str, Any], # noqa: ARG001
connection_headers: dict[str, str],
transport: MCPTransport, # noqa: ARG001
auth: Any = None, # noqa: ARG001
) -> dict[str, Any]:
captured_headers.update(connection_headers)
return mocked_response

View File

@@ -62,7 +62,7 @@ def _create_test_persona(db_session: Session, user: User, tools: list[Tool]) ->
document_sets=[],
users=[user],
groups=[],
is_listed=True,
is_visible=True,
is_public=True,
display_priority=None,
starter_messages=None,

View File

@@ -53,7 +53,7 @@ class PersonaManager:
label_ids=label_ids or [],
user_file_ids=user_file_ids or [],
display_priority=display_priority,
is_featured=featured,
featured=featured,
)
response = requests.post(
@@ -79,7 +79,7 @@ class PersonaManager:
users=users or [],
groups=groups or [],
label_ids=label_ids or [],
is_featured=featured,
featured=featured,
)
@staticmethod
@@ -122,7 +122,7 @@ class PersonaManager:
users=[UUID(user) for user in (users or persona.users)],
groups=groups or persona.groups,
label_ids=label_ids or persona.label_ids,
is_featured=featured if featured is not None else persona.is_featured,
featured=featured if featured is not None else persona.featured,
)
response = requests.patch(
@@ -152,7 +152,7 @@ class PersonaManager:
users=[user["email"] for user in updated_persona_data["users"]],
groups=updated_persona_data["groups"],
label_ids=[label["id"] for label in updated_persona_data["labels"]],
is_featured=updated_persona_data["is_featured"],
featured=updated_persona_data["featured"],
)
@staticmethod
@@ -205,13 +205,9 @@ class PersonaManager:
mismatches.append(
("is_public", persona.is_public, fetched_persona.is_public)
)
if fetched_persona.is_featured != persona.is_featured:
if fetched_persona.featured != persona.featured:
mismatches.append(
(
"is_featured",
persona.is_featured,
fetched_persona.is_featured,
)
("featured", persona.featured, fetched_persona.featured)
)
if (
fetched_persona.llm_model_provider_override

View File

@@ -169,7 +169,7 @@ class DATestPersona(BaseModel):
users: list[str]
groups: list[int]
label_ids: list[int]
is_featured: bool = False
featured: bool = False
# Embedded prompt fields (no longer separate prompt_ids)
system_prompt: str | None = None

View File

@@ -14,7 +14,6 @@ from __future__ import annotations
import os
import subprocess
import sys
import time
import uuid
from collections.abc import Generator
@@ -29,9 +28,6 @@ _BACKEND_DIR = os.path.normpath(
os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")
)
_DROP_SCHEMA_MAX_RETRIES = 3
_DROP_SCHEMA_RETRY_DELAY_SEC = 2
# ---------------------------------------------------------------------------
# Helpers
@@ -54,39 +50,6 @@ def _run_script(
)
def _force_drop_schema(engine: Engine, schema: str) -> None:
"""Terminate backends using *schema* then drop it, retrying on deadlock.
Background Celery workers may discover test schemas (they match the
``tenant_`` prefix) and hold locks on tables inside them. A bare
``DROP SCHEMA … CASCADE`` can deadlock with those workers, so we
first kill their connections and retry if we still hit a deadlock.
"""
for attempt in range(_DROP_SCHEMA_MAX_RETRIES):
try:
with engine.connect() as conn:
conn.execute(
text(
"""
SELECT pg_terminate_backend(l.pid)
FROM pg_locks l
JOIN pg_class c ON c.oid = l.relation
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE n.nspname = :schema
AND l.pid != pg_backend_pid()
"""
),
{"schema": schema},
)
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
return
except Exception:
if attempt == _DROP_SCHEMA_MAX_RETRIES - 1:
raise
time.sleep(_DROP_SCHEMA_RETRY_DELAY_SEC)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@@ -141,7 +104,9 @@ def tenant_schema_at_head(
yield schema
_force_drop_schema(engine, schema)
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
@pytest.fixture
@@ -158,7 +123,9 @@ def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
yield schema
_force_drop_schema(engine, schema)
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
@pytest.fixture
@@ -183,7 +150,9 @@ def tenant_schema_bad_rev(engine: Engine) -> Generator[str, None, None]:
yield schema
_force_drop_schema(engine, schema)
with engine.connect() as conn:
conn.execute(text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE'))
conn.commit()
# ---------------------------------------------------------------------------

View File

@@ -1,237 +0,0 @@
"""
Integration tests for the "Last Indexed" time displayed on both the
per-connector detail page and the all-connectors listing page.
Expected behavior: "Last Indexed" = time_started of the most recent
successful index attempt for the cc pair, regardless of pagination.
Edge cases:
1. First page of index attempts is entirely errors — last_indexed should
still reflect the older successful attempt beyond page 1.
2. Credential swap — successful attempts, then failures after a
"credential change"; last_indexed should reflect the most recent
successful attempt.
3. Mix of statuses — only the most recent successful attempt matters.
4. COMPLETED_WITH_ERRORS counts as a success for last_indexed purposes.
"""
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from onyx.db.models import IndexingStatus
from onyx.server.documents.models import CCPairFullInfo
from onyx.server.documents.models import ConnectorIndexingStatusLite
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
def _wait_for_real_success(
cc_pair: DATestCCPair,
admin: DATestUser,
) -> None:
"""Wait for the initial index attempt to complete successfully."""
CCPairManager.wait_for_indexing_completion(
cc_pair,
after=datetime(2000, 1, 1, tzinfo=timezone.utc),
user_performing_action=admin,
timeout=120,
)
def _get_detail(cc_pair_id: int, admin: DATestUser) -> CCPairFullInfo:
result = CCPairManager.get_single(cc_pair_id, admin)
assert result is not None
return result
def _get_listing(cc_pair_id: int, admin: DATestUser) -> ConnectorIndexingStatusLite:
result = CCPairManager.get_indexing_status_by_id(cc_pair_id, admin)
assert result is not None
return result
def test_last_indexed_first_page_all_errors(reset: None) -> None: # noqa: ARG001
"""When the first page of index attempts is entirely errors but an
older successful attempt exists, both the detail page and the listing
page should still show the time of that successful attempt.
The detail page UI uses page size 8. We insert 10 failed attempts
more recent than the initial success to push the success off page 1.
"""
admin = UserManager.create(name="admin_first_page_errors")
cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin)
_wait_for_real_success(cc_pair, admin)
# Baseline: last_success should be set from the initial successful run
listing_before = _get_listing(cc_pair.id, admin)
assert listing_before.last_success is not None
# 10 recent failures push the success off page 1
IndexAttemptManager.create_test_index_attempts(
num_attempts=10,
cc_pair_id=cc_pair.id,
status=IndexingStatus.FAILED,
error_msg="simulated failure",
base_time=datetime.now(tz=timezone.utc),
)
detail = _get_detail(cc_pair.id, admin)
listing = _get_listing(cc_pair.id, admin)
assert (
detail.last_indexed is not None
), "Detail page last_indexed is None even though a successful attempt exists"
assert (
listing.last_success is not None
), "Listing page last_success is None even though a successful attempt exists"
# Both surfaces must agree
assert detail.last_indexed == listing.last_success, (
f"Detail last_indexed={detail.last_indexed} != "
f"listing last_success={listing.last_success}"
)
def test_last_indexed_credential_swap_scenario(reset: None) -> None: # noqa: ARG001
"""Perform an actual credential swap: create connector + cred1 (cc_pair_1),
wait for success, then associate a new cred2 with the same connector
(cc_pair_2), wait for that to succeed, and inject failures on cc_pair_2.
cc_pair_2's last_indexed must reflect cc_pair_2's own success, not
cc_pair_1's older one. Both the detail page and listing page must agree.
"""
admin = UserManager.create(name="admin_cred_swap")
connector = ConnectorManager.create(user_performing_action=admin)
cred1 = CredentialManager.create(user_performing_action=admin)
cc_pair_1 = CCPairManager.create(
connector_id=connector.id,
credential_id=cred1.id,
user_performing_action=admin,
)
_wait_for_real_success(cc_pair_1, admin)
cred2 = CredentialManager.create(user_performing_action=admin, name="swapped-cred")
cc_pair_2 = CCPairManager.create(
connector_id=connector.id,
credential_id=cred2.id,
user_performing_action=admin,
)
_wait_for_real_success(cc_pair_2, admin)
listing_after_swap = _get_listing(cc_pair_2.id, admin)
assert listing_after_swap.last_success is not None
IndexAttemptManager.create_test_index_attempts(
num_attempts=10,
cc_pair_id=cc_pair_2.id,
status=IndexingStatus.FAILED,
error_msg="credential expired",
base_time=datetime.now(tz=timezone.utc),
)
detail = _get_detail(cc_pair_2.id, admin)
listing = _get_listing(cc_pair_2.id, admin)
assert detail.last_indexed is not None
assert listing.last_success is not None
assert detail.last_indexed == listing.last_success, (
f"Detail last_indexed={detail.last_indexed} != "
f"listing last_success={listing.last_success}"
)
def test_last_indexed_mixed_statuses(reset: None) -> None: # noqa: ARG001
"""Mix of in_progress, failed, and successful attempts. Only the most
recent successful attempt's time matters."""
admin = UserManager.create(name="admin_mixed")
cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin)
_wait_for_real_success(cc_pair, admin)
now = datetime.now(tz=timezone.utc)
# Success 5 hours ago
IndexAttemptManager.create_test_index_attempts(
num_attempts=1,
cc_pair_id=cc_pair.id,
status=IndexingStatus.SUCCESS,
base_time=now - timedelta(hours=5),
)
# Failures 3 hours ago
IndexAttemptManager.create_test_index_attempts(
num_attempts=3,
cc_pair_id=cc_pair.id,
status=IndexingStatus.FAILED,
error_msg="transient failure",
base_time=now - timedelta(hours=3),
)
# In-progress 1 hour ago
IndexAttemptManager.create_test_index_attempts(
num_attempts=1,
cc_pair_id=cc_pair.id,
status=IndexingStatus.IN_PROGRESS,
base_time=now - timedelta(hours=1),
)
detail = _get_detail(cc_pair.id, admin)
listing = _get_listing(cc_pair.id, admin)
assert detail.last_indexed is not None
assert listing.last_success is not None
assert detail.last_indexed == listing.last_success, (
f"Detail last_indexed={detail.last_indexed} != "
f"listing last_success={listing.last_success}"
)
def test_last_indexed_completed_with_errors(reset: None) -> None: # noqa: ARG001
"""COMPLETED_WITH_ERRORS is treated as a successful attempt (matching
IndexingStatus.is_successful()). When it is the most recent "success"
and later attempts all failed, both surfaces should reflect its time."""
admin = UserManager.create(name="admin_completed_errors")
cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin)
_wait_for_real_success(cc_pair, admin)
now = datetime.now(tz=timezone.utc)
# COMPLETED_WITH_ERRORS 2 hours ago
IndexAttemptManager.create_test_index_attempts(
num_attempts=1,
cc_pair_id=cc_pair.id,
status=IndexingStatus.COMPLETED_WITH_ERRORS,
base_time=now - timedelta(hours=2),
)
# 10 failures after — push everything else off page 1
IndexAttemptManager.create_test_index_attempts(
num_attempts=10,
cc_pair_id=cc_pair.id,
status=IndexingStatus.FAILED,
error_msg="post-partial failure",
base_time=now,
)
detail = _get_detail(cc_pair.id, admin)
listing = _get_listing(cc_pair.id, admin)
assert (
detail.last_indexed is not None
), "COMPLETED_WITH_ERRORS should count as a success for last_indexed"
assert (
listing.last_success is not None
), "COMPLETED_WITH_ERRORS should count as a success for last_success"
assert detail.last_indexed == listing.last_success, (
f"Detail last_indexed={detail.last_indexed} != "
f"listing last_success={listing.last_success}"
)

View File

@@ -35,8 +35,8 @@ def _create_test_persona(db_session: Session, persona_id: int, name: str) -> Per
id=persona_id,
name=name,
description="Test persona for Discord bot tests",
is_listed=True,
is_featured=False,
is_visible=True,
featured=False,
deleted=False,
builtin_persona=False,
)

View File

@@ -25,7 +25,7 @@ def test_cold_startup_default_assistant() -> None:
result = db_session.execute(
text(
"""
SELECT id, name, builtin_persona, is_featured, deleted
SELECT id, name, builtin_persona, featured, deleted
FROM persona
WHERE builtin_persona = true
ORDER BY id
@@ -40,7 +40,7 @@ def test_cold_startup_default_assistant() -> None:
assert default[0] == 0, "Default assistant should have ID 0"
assert default[1] == "Assistant", "Should be named 'Assistant'"
assert default[2] is True, "Should be builtin"
assert default[3] is True, "Should be is_featured"
assert default[3] is True, "Should be featured"
assert default[4] is False, "Should not be deleted"
# Check tools are properly associated

View File

@@ -7,7 +7,6 @@ import json
import pytest
from sqlalchemy import text
from onyx.configs.constants import ANONYMOUS_USER_UUID
from onyx.configs.constants import DEFAULT_BOOST
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from tests.integration.common_utils.reset import downgrade_postgres
@@ -238,6 +237,7 @@ def test_jira_connector_migration() -> None:
upgrade_postgres(
database="postgres", config_name="alembic", revision="da42808081e3"
)
# Verify the upgrade was applied correctly
with get_session_with_current_tenant() as db_session:
results = db_session.execute(
@@ -322,165 +322,3 @@ def test_jira_connector_migration() -> None:
== "https://example.atlassian.net/projects/TEST"
)
assert config_2["batch_size"] == 50
def test_anonymous_user_migration_dedupes_null_notifications() -> None:
downgrade_postgres(
database="postgres", config_name="alembic", revision="base", clear_data=True
)
upgrade_postgres(
database="postgres",
config_name="alembic",
revision="f7ca3e2f45d9",
)
with get_session_with_current_tenant() as db_session:
db_session.execute(
text(
"""
INSERT INTO notification (
id,
notif_type,
user_id,
dismissed,
last_shown,
first_shown,
title,
description,
additional_data
)
VALUES
(
1,
'RELEASE_NOTES',
NULL,
FALSE,
NOW(),
NOW(),
'Onyx v2.10.0 is available!',
'Check out what''s new in v2.10.0',
'{"version":"v2.10.0","link":"https://docs.onyx.app/changelog#v2-10-0"}'::jsonb
),
(
2,
'RELEASE_NOTES',
NULL,
FALSE,
NOW(),
NOW(),
'Onyx v2.10.0 is available!',
'Check out what''s new in v2.10.0',
'{"version":"v2.10.0","link":"https://docs.onyx.app/changelog#v2-10-0"}'::jsonb
)
"""
)
)
db_session.commit()
upgrade_postgres(
database="postgres", config_name="alembic", revision="e7f8a9b0c1d2"
)
with get_session_with_current_tenant() as db_session:
notifications = db_session.execute(
text(
"""
SELECT id, user_id
FROM notification
ORDER BY id
"""
)
).fetchall()
anonymous_user = db_session.execute(
text(
"""
SELECT id, email, role
FROM "user"
WHERE id = :user_id
"""
),
{"user_id": ANONYMOUS_USER_UUID},
).fetchone()
assert len(notifications) == 1
assert notifications[0].id == 2 # Higher id wins when timestamps are equal
assert str(notifications[0].user_id) == ANONYMOUS_USER_UUID
assert anonymous_user is not None
assert anonymous_user.email == "anonymous@onyx.app"
assert anonymous_user.role == "LIMITED"
def test_anonymous_user_migration_collision_with_existing_anonymous_notification() -> (
None
):
"""Test that a NULL-owned notification that collides with an already-existing
anonymous-owned notification is removed during migration."""
downgrade_postgres(
database="postgres", config_name="alembic", revision="base", clear_data=True
)
upgrade_postgres(
database="postgres",
config_name="alembic",
revision="f7ca3e2f45d9",
)
with get_session_with_current_tenant() as db_session:
# Create the anonymous user early so we can insert a notification owned by it
db_session.execute(
text(
"""
INSERT INTO "user" (id, email, hashed_password, is_active, is_superuser, is_verified, role)
VALUES (:id, 'anonymous@onyx.app', '', TRUE, FALSE, TRUE, 'LIMITED')
ON CONFLICT (id) DO NOTHING
"""
),
{"id": ANONYMOUS_USER_UUID},
)
# Insert an anonymous-owned notification (already migrated in a prior partial run)
db_session.execute(
text(
"""
INSERT INTO notification (
id, notif_type, user_id, dismissed, last_shown, first_shown,
title, description, additional_data
)
VALUES
(
1, 'RELEASE_NOTES', :user_id, FALSE, NOW(), NOW(),
'Onyx v2.10.0 is available!',
'Check out what''s new in v2.10.0',
'{"version":"v2.10.0","link":"https://docs.onyx.app/changelog#v2-10-0"}'::jsonb
),
(
2, 'RELEASE_NOTES', NULL, FALSE, NOW(), NOW(),
'Onyx v2.10.0 is available!',
'Check out what''s new in v2.10.0',
'{"version":"v2.10.0","link":"https://docs.onyx.app/changelog#v2-10-0"}'::jsonb
)
"""
),
{"user_id": ANONYMOUS_USER_UUID},
)
db_session.commit()
upgrade_postgres(
database="postgres", config_name="alembic", revision="e7f8a9b0c1d2"
)
with get_session_with_current_tenant() as db_session:
notifications = db_session.execute(
text(
"""
SELECT id, user_id
FROM notification
ORDER BY id
"""
)
).fetchall()
# Only the original anonymous-owned notification should remain;
# the NULL-owned duplicate should have been deleted
assert len(notifications) == 1
assert notifications[0].id == 1
assert str(notifications[0].user_id) == ANONYMOUS_USER_UUID

View File

@@ -33,8 +33,8 @@ def test_unified_assistant(
"search, web browsing, and image generation"
in unified_assistant.description.lower()
)
assert unified_assistant.is_featured is True
assert unified_assistant.is_listed is True
assert unified_assistant.featured is True
assert unified_assistant.is_visible is True
# Verify tools
tools = unified_assistant.tools

View File

@@ -1,5 +1,3 @@
import csv
import io
import os
from datetime import datetime
from datetime import timedelta
@@ -141,12 +139,12 @@ def test_chat_history_csv_export(
assert headers["Content-Type"] == "text/csv; charset=utf-8"
assert "Content-Disposition" in headers
# Use csv.reader to properly handle newlines inside quoted fields
csv_rows = list(csv.reader(io.StringIO(csv_content)))
assert len(csv_rows) == 3 # Header + 2 QA pairs
assert csv_rows[0][0] == "chat_session_id"
assert "user_message" in csv_rows[0]
assert "ai_response" in csv_rows[0]
# Verify CSV content
csv_lines = csv_content.strip().split("\n")
assert len(csv_lines) == 3 # Header + 2 QA pairs
assert "chat_session_id" in csv_content
assert "user_message" in csv_content
assert "ai_response" in csv_content
assert "What was the Q1 revenue?" in csv_content
assert "What about Q2 revenue?" in csv_content
@@ -158,5 +156,5 @@ def test_chat_history_csv_export(
end_time=past_end,
user_performing_action=admin_user,
)
csv_rows = list(csv.reader(io.StringIO(csv_content)))
assert len(csv_rows) == 1 # Only header, no data rows
csv_lines = csv_content.strip().split("\n")
assert len(csv_lines) == 1 # Only header, no data rows

View File

@@ -86,7 +86,7 @@ async def test_get_or_create_user_skips_inactive(
"""Inactive users should not be re-authenticated via JWT."""
monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", True)
monkeypatch.setattr(users_module, "verify_email_is_invited", lambda _: None)
monkeypatch.setattr(users_module, "verify_email_domain", lambda *_a, **_kw: None)
monkeypatch.setattr(users_module, "verify_email_domain", lambda _: None)
email = "inactive@example.com"
payload: dict[str, Any] = {"email": email}
@@ -126,7 +126,7 @@ async def test_get_or_create_user_handles_race_conditions(
"""If provisioning races, newly inactive users should still be blocked."""
monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", True)
monkeypatch.setattr(users_module, "verify_email_is_invited", lambda _: None)
monkeypatch.setattr(users_module, "verify_email_domain", lambda *_a, **_kw: None)
monkeypatch.setattr(users_module, "verify_email_domain", lambda _: None)
email = "race@example.com"
payload: dict[str, Any] = {"email": email}
@@ -182,7 +182,7 @@ async def test_get_or_create_user_provisions_new_user(
monkeypatch.setattr(users_module, "TRACK_EXTERNAL_IDP_EXPIRY", False)
monkeypatch.setattr(users_module, "generate_password", lambda: "TempPass123!")
monkeypatch.setattr(users_module, "verify_email_is_invited", lambda _: None)
monkeypatch.setattr(users_module, "verify_email_domain", lambda *_a, **_kw: None)
monkeypatch.setattr(users_module, "verify_email_domain", lambda _: None)
recorded: dict[str, Any] = {}

View File

@@ -15,11 +15,11 @@ from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from fastapi import HTTPException
from onyx.auth.schemas import UserCreate
from onyx.auth.users import UserManager
from onyx.configs.constants import AuthType
from onyx.error_handling.exceptions import OnyxError
# Note: Only async test methods are marked with @pytest.mark.asyncio individually
# to avoid warnings on synchronous tests
@@ -89,11 +89,11 @@ class TestDisposableEmailValidation:
user_manager = UserManager(MagicMock())
# Execute & Assert
with pytest.raises(OnyxError) as exc:
with pytest.raises(HTTPException) as exc:
await user_manager.create(mock_user_create)
assert exc.value.status_code == 400
assert "Disposable email" in exc.value.detail
assert "Disposable email" in str(exc.value.detail)
# Verify we never got to tenant provisioning
mock_fetch_ee.assert_not_called()
@@ -138,9 +138,7 @@ class TestDisposableEmailValidation:
pass # We just want to verify domain check passed
# Verify domain validation was called
mock_verify_domain.assert_called_once_with(
mock_user_create.email, is_registration=True
)
mock_verify_domain.assert_called_once_with(mock_user_create.email)
class TestMultiTenantInviteLogic:
@@ -333,7 +331,7 @@ class TestSAMLOIDCBehavior:
mock_get_invited.return_value = ["allowed@example.com"]
# Execute & Assert
with pytest.raises(OnyxError) as exc:
with pytest.raises(HTTPException) as exc:
verify_email_is_invited("newuser@example.com")
assert exc.value.status_code == 403
@@ -387,7 +385,7 @@ class TestWhitelistBehavior:
mock_get_invited.return_value = ["allowed@example.com"]
# Execute & Assert
with pytest.raises(OnyxError) as exc:
with pytest.raises(HTTPException) as exc:
verify_email_is_invited("notallowed@example.com")
assert exc.value.status_code == 403
@@ -422,7 +420,7 @@ class TestSeatLimitEnforcement:
"onyx.auth.users.fetch_ee_implementation_or_noop",
return_value=lambda *_a, **_kw: seat_result,
):
with pytest.raises(OnyxError) as exc:
with pytest.raises(HTTPException) as exc:
enforce_seat_limit(MagicMock())
assert exc.value.status_code == 402
@@ -492,9 +490,7 @@ class TestCaseInsensitiveEmailMatching:
pass
# Verify flow
mock_verify_domain.assert_called_once_with(
user_create.email, is_registration=True
)
mock_verify_domain.assert_called_once_with(user_create.email)
@patch("onyx.auth.users.is_disposable_email")
@patch("onyx.auth.users.verify_email_domain")
@@ -544,7 +540,5 @@ class TestCaseInsensitiveEmailMatching:
pass
# Verify flow
mock_verify_domain.assert_called_once_with(
mock_user_create.email, is_registration=True
)
mock_verify_domain.assert_called_once_with(mock_user_create.email)
mock_verify_invited.assert_called_once() # Existing tenant = invite needed

View File

@@ -1,9 +1,9 @@
import pytest
from fastapi import HTTPException
import onyx.auth.users as users
from onyx.auth.users import verify_email_domain
from onyx.configs.constants import AuthType
from onyx.error_handling.exceptions import OnyxError
def test_verify_email_domain_allows_case_insensitive_match(
@@ -21,7 +21,7 @@ def test_verify_email_domain_rejects_non_whitelisted_domain(
) -> None:
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", ["example.com"], raising=False)
with pytest.raises(OnyxError) as exc:
with pytest.raises(HTTPException) as exc:
verify_email_domain("user@another.com")
assert exc.value.status_code == 400
assert "Email domain is not valid" in exc.value.detail
@@ -32,7 +32,7 @@ def test_verify_email_domain_invalid_email_format(
) -> None:
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", ["example.com"], raising=False)
with pytest.raises(OnyxError) as exc:
with pytest.raises(HTTPException) as exc:
verify_email_domain("userexample.com") # missing '@'
assert exc.value.status_code == 400
assert "Email is not valid" in exc.value.detail
@@ -44,10 +44,10 @@ def test_verify_email_domain_rejects_plus_addressing(
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
with pytest.raises(OnyxError) as exc:
with pytest.raises(HTTPException) as exc:
verify_email_domain("user+tag@gmail.com")
assert exc.value.status_code == 400
assert "'+'" in exc.value.detail
assert "'+'" in str(exc.value.detail)
def test_verify_email_domain_allows_plus_for_onyx_app(
@@ -60,53 +60,13 @@ def test_verify_email_domain_allows_plus_for_onyx_app(
verify_email_domain("user+tag@onyx.app")
def test_verify_email_domain_rejects_dotted_gmail_on_registration(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
with pytest.raises(OnyxError) as exc:
verify_email_domain("first.last@gmail.com", is_registration=True)
assert exc.value.status_code == 400
assert "'.'" in exc.value.detail
def test_verify_email_domain_dotted_gmail_allowed_when_not_registration(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
# Existing user signing in — should not be blocked
verify_email_domain("first.last@gmail.com", is_registration=False)
def test_verify_email_domain_allows_dotted_non_gmail_on_registration(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
verify_email_domain("first.last@example.com", is_registration=True)
def test_verify_email_domain_dotted_gmail_allowed_when_not_cloud(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.BASIC, raising=False)
verify_email_domain("first.last@gmail.com", is_registration=True)
def test_verify_email_domain_rejects_googlemail(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(users, "VALID_EMAIL_DOMAINS", [], raising=False)
monkeypatch.setattr(users, "AUTH_TYPE", AuthType.CLOUD, raising=False)
with pytest.raises(OnyxError) as exc:
with pytest.raises(HTTPException) as exc:
verify_email_domain("user@googlemail.com")
assert exc.value.status_code == 400
assert "gmail.com" in exc.value.detail
assert "gmail.com" in str(exc.value.detail)

View File

@@ -1,9 +1,9 @@
import pytest
from fastapi import HTTPException
import onyx.auth.users as users
from onyx.auth.users import verify_email_is_invited
from onyx.configs.constants import AuthType
from onyx.error_handling.exceptions import OnyxError
@pytest.mark.parametrize("auth_type", [AuthType.SAML, AuthType.OIDC])
@@ -35,7 +35,7 @@ def test_verify_email_is_invited_enforced_for_basic_auth(
raising=False,
)
with pytest.raises(OnyxError) as exc:
with pytest.raises(HTTPException) as exc:
verify_email_is_invited("newuser@example.com")
assert exc.value.status_code == 403

View File

@@ -324,7 +324,7 @@ class TestExtractContextFiles:
class TestSearchFilterDetermination:
"""Verify that determine_search_params correctly resolves
project_id_filter, persona_id_filter, and search_usage based on
search_project_id, search_persona_id, and search_usage based on
the extraction result and the precedence rule.
"""
@@ -353,8 +353,8 @@ class TestSearchFilterDetermination:
uncapped_token_count=100,
),
)
assert result.project_id_filter is None
assert result.persona_id_filter is None
assert result.search_project_id is None
assert result.search_persona_id is None
assert result.search_usage == SearchToolUsage.AUTO
def test_custom_persona_files_overflow_persona_filter(self) -> None:
@@ -364,8 +364,8 @@ class TestSearchFilterDetermination:
project_id=99,
extracted_context_files=self._make_context(use_as_search_filter=True),
)
assert result.persona_id_filter == 42
assert result.project_id_filter is None
assert result.search_persona_id == 42
assert result.search_project_id is None
assert result.search_usage == SearchToolUsage.AUTO
def test_custom_persona_no_files_no_project_leak(self) -> None:
@@ -375,8 +375,8 @@ class TestSearchFilterDetermination:
project_id=99,
extracted_context_files=self._make_context(),
)
assert result.project_id_filter is None
assert result.persona_id_filter is None
assert result.search_project_id is None
assert result.search_persona_id is None
assert result.search_usage == SearchToolUsage.AUTO
def test_default_persona_project_files_fit_disables_search(self) -> None:
@@ -389,7 +389,7 @@ class TestSearchFilterDetermination:
uncapped_token_count=100,
),
)
assert result.project_id_filter is None
assert result.search_project_id is None
assert result.search_usage == SearchToolUsage.DISABLED
def test_default_persona_project_files_overflow_enables_search(self) -> None:
@@ -402,8 +402,8 @@ class TestSearchFilterDetermination:
uncapped_token_count=7000,
),
)
assert result.project_id_filter == 99
assert result.persona_id_filter is None
assert result.search_project_id == 99
assert result.search_persona_id is None
assert result.search_usage == SearchToolUsage.ENABLED
def test_default_persona_no_project_auto(self) -> None:
@@ -413,7 +413,7 @@ class TestSearchFilterDetermination:
project_id=None,
extracted_context_files=self._make_context(),
)
assert result.project_id_filter is None
assert result.search_project_id is None
assert result.search_usage == SearchToolUsage.AUTO
def test_default_persona_project_no_files_disables_search(self) -> None:

View File

@@ -1,4 +1,12 @@
import pytest
from onyx.chat.process_message import _resolve_query_processing_hook_result
from onyx.chat.process_message import remove_answer_citations
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
def test_remove_answer_citations_strips_http_markdown_citation() -> None:
@@ -32,3 +40,94 @@ def test_remove_answer_citations_preserves_non_citation_markdown_links() -> None
remove_answer_citations(answer)
== "See [reference](https://example.com/Function_(mathematics)) for context."
)
# ---------------------------------------------------------------------------
# Query Processing hook response handling (_resolve_query_processing_hook_result)
# ---------------------------------------------------------------------------
def test_wrong_model_type_raises_internal_error() -> None:
"""If the executor ever returns an unexpected BaseModel type, raise INTERNAL_ERROR
rather than an AssertionError or AttributeError."""
from pydantic import BaseModel as PydanticBaseModel
class _OtherModel(PydanticBaseModel):
pass
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(_OtherModel(), "original query")
assert exc_info.value.error_code is OnyxErrorCode.INTERNAL_ERROR
def test_hook_skipped_leaves_message_text_unchanged() -> None:
result = _resolve_query_processing_hook_result(HookSkipped(), "original query")
assert result == "original query"
def test_hook_soft_failed_leaves_message_text_unchanged() -> None:
result = _resolve_query_processing_hook_result(HookSoftFailed(), "original query")
assert result == "original query"
def test_null_query_raises_query_rejected() -> None:
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(query=None), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_empty_string_query_raises_query_rejected() -> None:
"""Empty string is falsy — must be treated as rejection, same as None."""
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(query=""), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_whitespace_only_query_raises_query_rejected() -> None:
"""Whitespace-only string is truthy but meaningless — must be treated as rejection."""
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(query=" "), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_absent_query_field_raises_query_rejected() -> None:
"""query defaults to None when not provided."""
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(), "original query"
)
assert exc_info.value.error_code is OnyxErrorCode.QUERY_REJECTED
def test_rejection_message_surfaced_in_error_when_provided() -> None:
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(
query=None, rejection_message="Queries about X are not allowed."
),
"original query",
)
assert "Queries about X are not allowed." in str(exc_info.value)
def test_fallback_rejection_message_when_none() -> None:
"""No rejection_message → generic fallback used in OnyxError detail."""
with pytest.raises(OnyxError) as exc_info:
_resolve_query_processing_hook_result(
QueryProcessingResponse(query=None, rejection_message=None),
"original query",
)
assert "Your query was rejected." in str(exc_info.value)
def test_nonempty_query_rewrites_message_text() -> None:
result = _resolve_query_processing_hook_result(
QueryProcessingResponse(query="rewritten query"), "original query"
)
assert result == "rewritten query"

View File

@@ -7,6 +7,7 @@ from unittest.mock import patch
import httpx
import pytest
from pydantic import BaseModel
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
@@ -15,13 +16,15 @@ from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.query_processing import QueryProcessingResponse
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_PAYLOAD: dict[str, Any] = {"query": "test", "user_email": "u@example.com"}
_RESPONSE_PAYLOAD: dict[str, Any] = {"rewritten_query": "better test"}
# A valid QueryProcessingResponse payload — used by success-path tests.
_RESPONSE_PAYLOAD: dict[str, Any] = {"query": "better test"}
def _make_hook(
@@ -33,6 +36,7 @@ def _make_hook(
fail_strategy: HookFailStrategy = HookFailStrategy.SOFT,
hook_id: int = 1,
is_reachable: bool | None = None,
hook_point: HookPoint = HookPoint.QUERY_PROCESSING,
) -> MagicMock:
hook = MagicMock()
hook.is_active = is_active
@@ -42,6 +46,7 @@ def _make_hook(
hook.id = hook_id
hook.fail_strategy = fail_strategy
hook.is_reachable = is_reachable
hook.hook_point = hook_point
return hook
@@ -152,7 +157,9 @@ def test_early_exit_returns_skipped_with_no_db_writes(
# ---------------------------------------------------------------------------
def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> None:
def test_success_returns_validated_model_and_sets_reachable(
db_session: MagicMock,
) -> None:
hook = _make_hook()
with (
@@ -173,7 +180,8 @@ def test_success_returns_payload_and_sets_reachable(db_session: MagicMock) -> No
payload=_PAYLOAD,
)
assert result == _RESPONSE_PAYLOAD
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
_, update_kwargs = mock_update.call_args
assert update_kwargs["is_reachable"] is True
mock_log.assert_not_called()
@@ -202,7 +210,8 @@ def test_success_skips_reachable_write_when_already_true(db_session: MagicMock)
payload=_PAYLOAD,
)
assert result == _RESPONSE_PAYLOAD
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
mock_update.assert_not_called()
@@ -457,16 +466,16 @@ def test_authorization_header(
@pytest.mark.parametrize(
"http_exception,expected_result",
"http_exception,expect_onyx_error",
[
pytest.param(None, _RESPONSE_PAYLOAD, id="success_path"),
pytest.param(httpx.ConnectError("refused"), OnyxError, id="hard_fail_path"),
pytest.param(None, False, id="success_path"),
pytest.param(httpx.ConnectError("refused"), True, id="hard_fail_path"),
],
)
def test_persist_session_failure_is_swallowed(
db_session: MagicMock,
http_exception: Exception | None,
expected_result: Any,
expect_onyx_error: bool,
) -> None:
"""DB session failure in _persist_result must not mask the real return value or OnyxError."""
hook = _make_hook(fail_strategy=HookFailStrategy.HARD)
@@ -489,7 +498,7 @@ def test_persist_session_failure_is_swallowed(
side_effect=http_exception,
)
if expected_result is OnyxError:
if expect_onyx_error:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
@@ -503,7 +512,135 @@ def test_persist_session_failure_is_swallowed(
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert result == expected_result
assert isinstance(result, QueryProcessingResponse)
assert result.query == _RESPONSE_PAYLOAD["query"]
# ---------------------------------------------------------------------------
# Response model validation
# ---------------------------------------------------------------------------
class _StrictResponse(BaseModel):
"""Strict model used to reliably trigger a ValidationError in tests."""
required_field: str # no default → missing key raises ValidationError
def _make_strict_spec() -> MagicMock:
spec = MagicMock()
spec.response_model = _StrictResponse
return spec
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(
HookFailStrategy.SOFT, HookSoftFailed, id="validation_failure_soft"
),
pytest.param(HookFailStrategy.HARD, OnyxError, id="validation_failure_hard"),
],
)
def test_response_validation_failure_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""A response that fails response_model validation is treated like any other
hook failure: logged, is_reachable left unchanged, fail_strategy respected."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch("onyx.hooks.executor.get_session_with_current_tenant"),
patch("onyx.hooks.executor.update_hook__no_commit") as mock_update,
patch("onyx.hooks.executor.create_hook_execution_log__no_commit") as mock_log,
patch(
"onyx.hooks.executor.get_hook_point_spec",
return_value=_make_strict_spec(),
),
patch("httpx.Client") as mock_client_cls,
):
# Response payload is missing required_field → ValidationError
_setup_client(mock_client_cls, response=_make_response(json_return={}))
if expected_type is OnyxError:
with pytest.raises(OnyxError) as exc_info:
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert exc_info.value.error_code is OnyxErrorCode.HOOK_EXECUTION_FAILED
else:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
# is_reachable must not be updated — server responded correctly
mock_update.assert_not_called()
# failure must be logged
mock_log.assert_called_once()
_, log_kwargs = mock_log.call_args
assert log_kwargs["is_success"] is False
assert "validation" in (log_kwargs["error_message"] or "").lower()
# ---------------------------------------------------------------------------
# Outer soft-fail guard in execute_hook
# ---------------------------------------------------------------------------
@pytest.mark.parametrize(
"fail_strategy,expected_type",
[
pytest.param(HookFailStrategy.SOFT, HookSoftFailed, id="unexpected_exc_soft"),
pytest.param(HookFailStrategy.HARD, ValueError, id="unexpected_exc_hard"),
],
)
def test_unexpected_exception_in_inner_respects_fail_strategy(
db_session: MagicMock,
fail_strategy: HookFailStrategy,
expected_type: type,
) -> None:
"""An unexpected exception raised by _execute_hook_inner (not an OnyxError from
HARD fail — e.g. a bug or an assertion error) must be swallowed and return
HookSoftFailed for SOFT strategy, or re-raised for HARD strategy."""
hook = _make_hook(fail_strategy=fail_strategy)
with (
patch("onyx.hooks.executor.HOOKS_AVAILABLE", True),
patch(
"onyx.hooks.executor.get_non_deleted_hook_by_hook_point",
return_value=hook,
),
patch(
"onyx.hooks.executor._execute_hook_inner",
side_effect=ValueError("unexpected bug"),
),
):
if expected_type is HookSoftFailed:
result = execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
assert isinstance(result, HookSoftFailed)
else:
with pytest.raises(ValueError, match="unexpected bug"):
execute_hook(
db_session=db_session,
hook_point=HookPoint.QUERY_PROCESSING,
payload=_PAYLOAD,
)
def test_is_reachable_failure_does_not_prevent_log(db_session: MagicMock) -> None:

View File

@@ -1,6 +1,7 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import UUID
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
@@ -10,10 +11,10 @@ from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
build_vespa_filters,
)
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_PROJECT
@@ -150,30 +151,56 @@ class TestBuildVespaFilters:
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_user_project_filter(self) -> None:
"""Test user project filtering.
def test_user_file_ids_filter(self) -> None:
"""Test user file IDs filtering."""
id1 = UUID("00000000-0000-0000-0000-000000000123")
id2 = UUID("00000000-0000-0000-0000-000000000456")
project_id_filter alone does NOT trigger a knowledge scope restriction
(an agent with no explicit knowledge should search everything).
It only participates when explicit knowledge filters are present.
"""
# project_id_filter alone → no restriction
filters = IndexFilters(access_control_list=[], project_id_filter=789)
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
# project_id_filter with document_set → both OR'd
filters = IndexFilters(
access_control_list=[], project_id_filter=789, document_set=["set1"]
)
# Single user file ID (UUID)
filters = IndexFilters(access_control_list=[], user_file_ids=[id1])
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and (({DOCUMENT_SETS} contains "set1") or ({USER_PROJECT} contains "789")) and '
f'!({HIDDEN}=true) and ({DOCUMENT_ID} contains "{str(id1)}") and ' == result
)
# Multiple user file IDs (UUIDs)
filters = IndexFilters(access_control_list=[], user_file_ids=[id1, id2])
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and ({DOCUMENT_ID} contains "{str(id1)}" or {DOCUMENT_ID} contains "{str(id2)}") and '
== result
)
# No project id filter
filters = IndexFilters(access_control_list=[], project_id_filter=None)
# Empty user file IDs
filters = IndexFilters(access_control_list=[], user_file_ids=[])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_user_project_filter(self) -> None:
"""Test user project filtering.
project_id alone does NOT trigger a knowledge scope restriction
(an agent with no explicit knowledge should search everything).
It only participates when explicit knowledge filters are present.
"""
# project_id alone → no restriction
filters = IndexFilters(access_control_list=[], project_id=789)
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
# project_id with user_file_ids → both OR'd
id1 = UUID("00000000-0000-0000-0000-000000000123")
filters = IndexFilters(
access_control_list=[], project_id=789, user_file_ids=[id1]
)
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and (({DOCUMENT_ID} contains "{str(id1)}") or ({USER_PROJECT} contains "789")) and '
== result
)
# No project id
filters = IndexFilters(access_control_list=[], project_id=None)
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
@@ -206,16 +233,17 @@ class TestBuildVespaFilters:
def test_combined_filters(self) -> None:
"""Test combining multiple filter types.
Knowledge-scope filters (document_set, project_id_filter, persona_id_filter)
are OR'd together, while all other filters are AND'd.
Knowledge-scope filters (document_set, user_file_ids, project_id,
persona_id) are OR'd together, while all other filters are AND'd.
"""
id1 = UUID("00000000-0000-0000-0000-000000000123")
filters = IndexFilters(
access_control_list=["user1", "group1"],
source_type=[DocumentSource.WEB],
tags=[Tag(tag_key="color", tag_value="red")],
document_set=["set1"],
project_id_filter=789,
persona_id_filter=42,
user_file_ids=[id1],
project_id=789,
time_cutoff=datetime(2023, 1, 1, tzinfo=timezone.utc),
)
@@ -226,10 +254,9 @@ class TestBuildVespaFilters:
expected += f'({SOURCE_TYPE} contains "web") and '
expected += f'({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
# Knowledge scope filters are OR'd together
# (persona_id_filter is primary, project_id_filter is additive — order reflects this)
expected += (
f'(({DOCUMENT_SETS} contains "set1")'
f' or ({PERSONAS} contains "42")'
f' or ({DOCUMENT_ID} contains "{str(id1)}")'
f' or ({USER_PROJECT} contains "789")'
f") and "
)
@@ -249,37 +276,18 @@ class TestBuildVespaFilters:
result = build_vespa_filters(filters)
assert f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1") and ' == result
def test_persona_id_filter_is_primary_knowledge_scope(self) -> None:
"""persona_id_filter alone should trigger a knowledge scope restriction
(a persona with user files IS explicit knowledge)."""
filters = IndexFilters(access_control_list=[], persona_id_filter=42)
result = build_vespa_filters(filters)
assert f'!({HIDDEN}=true) and ({PERSONAS} contains "42") and ' == result
def test_persona_id_filter_with_project_id_filter(self) -> None:
"""When persona_id_filter triggers the scope, project_id_filter should be
OR'd in additively."""
filters = IndexFilters(
access_control_list=[], persona_id_filter=42, project_id_filter=789
)
result = build_vespa_filters(filters)
expected = (
f"!({HIDDEN}=true) and "
f'(({PERSONAS} contains "42") or ({USER_PROJECT} contains "789")) and '
)
assert expected == result
def test_knowledge_scope_document_set_and_persona_filter_ored(self) -> None:
"""Document set filter and persona_id_filter must be OR'd so that
connector documents (in the set) and persona user files can
both be found."""
def test_knowledge_scope_document_set_and_user_files_ored(self) -> None:
"""Document set filter and user file IDs must be OR'd so that
connector documents (in the set) and user files (with specific
IDs) can both be found."""
id1 = UUID("00000000-0000-0000-0000-000000000123")
filters = IndexFilters(
access_control_list=[],
document_set=["engineering"],
persona_id_filter=42,
user_file_ids=[id1],
)
result = build_vespa_filters(filters)
expected = f'!({HIDDEN}=true) and (({DOCUMENT_SETS} contains "engineering") or ({PERSONAS} contains "42")) and '
expected = f'!({HIDDEN}=true) and (({DOCUMENT_SETS} contains "engineering") or ({DOCUMENT_ID} contains "{str(id1)}")) and '
assert expected == result
def test_acl_large_list_uses_weighted_set(self) -> None:

View File

@@ -70,8 +70,8 @@ function Prompt-OrDefault {
function Confirm-Action {
param([string]$Description)
$reply = (Prompt-OrDefault "Install $Description? (Y/n) [default: Y]" "Y").Trim().ToLower()
if ($reply -match '^n') {
$reply = Prompt-OrDefault "Install $Description? (Y/n) [default: Y]" "Y"
if ($reply -match '^[Nn]') {
Print-Warning "Skipping: $Description"
return $false
}
@@ -364,7 +364,7 @@ function Invoke-OnyxDeleteData {
Write-Host "`n=== WARNING: This will permanently delete all Onyx data ===`n" -ForegroundColor Red
Print-Warning "This action will remove all Onyx containers, volumes, files, and user data."
if (Test-Interactive) {
$confirm = Prompt-OrDefault "Type 'DELETE' to confirm" ""
$confirm = Read-Host "Type 'DELETE' to confirm"
if ($confirm -ne "DELETE") { Print-Info "Operation cancelled."; return }
} else {
Print-OnyxError "Cannot confirm destructive operation in non-interactive mode."
@@ -720,7 +720,6 @@ function Invoke-WslInstall {
# Ensure WSL2 is available
Invoke-NativeQuiet { wsl --status }
if ($LASTEXITCODE -ne 0) {
if (-not (Confirm-Action "WSL2 (Windows Subsystem for Linux)")) { exit 1 }
Print-Info "Installing WSL2..."
try {
$proc = Start-Process wsl -ArgumentList "--install", "--no-distribution" -Wait -PassThru -NoNewWindow
@@ -807,7 +806,7 @@ function Main {
if (Test-Interactive) {
Write-Host "`nPlease acknowledge and press Enter to continue..." -ForegroundColor Yellow
$null = Prompt-OrDefault "" ""
Read-Host | Out-Null
} else {
Write-Host "`nRunning in non-interactive mode - proceeding automatically..." -ForegroundColor Yellow
}
@@ -903,8 +902,8 @@ function Main {
if ($resourceWarning) {
Print-Warning "Onyx recommends at least $($script:ExpectedDockerRamGB)GB RAM and $($script:ExpectedDiskGB)GB disk for standard mode."
Print-Warning "Lite mode requires less (1-4GB RAM, 8-16GB disk) but has no vector database."
$reply = (Prompt-OrDefault "Do you want to continue anyway? (Y/n)" "y").Trim().ToLower()
if ($reply -notmatch '^y') { Print-Info "Installation cancelled."; exit 1 }
$reply = Prompt-OrDefault "Do you want to continue anyway? (Y/n)" "y"
if ($reply -notmatch '^[Yy]') { Print-Info "Installation cancelled."; exit 1 }
Print-Info "Proceeding despite resource limitations..."
}
@@ -926,8 +925,8 @@ function Main {
if ($composeVersion -ne "unknown" -and (Compare-SemVer $composeVersion "2.24.0") -lt 0) {
Print-Warning "Docker Compose $composeVersion is older than 2.24.0 (required for env_file format)."
Print-Info "Update Docker Desktop or install a newer Docker Compose. Installation may fail."
$reply = (Prompt-OrDefault "Continue anyway? (Y/n)" "y").Trim().ToLower()
if ($reply -notmatch '^y') { exit 1 }
$reply = Prompt-OrDefault "Continue anyway? (Y/n)" "y"
if ($reply -notmatch '^[Yy]') { exit 1 }
}
$liteOverlayPath = Join-Path $deploymentDir $script:LiteComposeFile

View File

@@ -174,42 +174,34 @@ ensure_file() {
# --- Interactive prompt helpers ---
is_interactive() {
[[ "$NO_PROMPT" = false ]] && [[ -r /dev/tty ]] && [[ -w /dev/tty ]]
}
read_prompt_line() {
local prompt_text="$1"
if ! is_interactive; then
REPLY=""
return
fi
[[ -n "$prompt_text" ]] && printf "%s" "$prompt_text" > /dev/tty
IFS= read -r REPLY < /dev/tty || REPLY=""
}
read_prompt_char() {
local prompt_text="$1"
if ! is_interactive; then
REPLY=""
return
fi
[[ -n "$prompt_text" ]] && printf "%s" "$prompt_text" > /dev/tty
IFS= read -r -n 1 REPLY < /dev/tty || REPLY=""
printf "\n" > /dev/tty
[[ "$NO_PROMPT" = false ]]
}
prompt_or_default() {
local prompt_text="$1"
local default_value="$2"
read_prompt_line "$prompt_text"
[[ -z "$REPLY" ]] && REPLY="$default_value"
if is_interactive; then
read -p "$prompt_text" -r REPLY
if [[ -z "$REPLY" ]]; then
REPLY="$default_value"
fi
else
REPLY="$default_value"
fi
}
prompt_yn_or_default() {
local prompt_text="$1"
local default_value="$2"
read_prompt_char "$prompt_text"
[[ -z "$REPLY" ]] && REPLY="$default_value"
if is_interactive; then
read -p "$prompt_text" -n 1 -r
echo ""
if [[ -z "$REPLY" ]]; then
REPLY="$default_value"
fi
else
REPLY="$default_value"
fi
}
confirm_action() {
@@ -310,8 +302,8 @@ if [ "$DELETE_DATA_MODE" = true ]; then
echo " • All user data and documents"
echo ""
if is_interactive; then
prompt_or_default "Are you sure you want to continue? Type 'DELETE' to confirm: " ""
echo "" > /dev/tty
read -p "Are you sure you want to continue? Type 'DELETE' to confirm: " -r
echo ""
if [ "$REPLY" != "DELETE" ]; then
print_info "Operation cancelled."
exit 0
@@ -505,7 +497,7 @@ echo ""
if is_interactive; then
echo -e "${YELLOW}${BOLD}Please acknowledge and press Enter to continue...${NC}"
read_prompt_line ""
read -r
echo ""
else
echo -e "${YELLOW}${BOLD}Running in non-interactive mode - proceeding automatically...${NC}"

View File

@@ -8,7 +8,6 @@ This directory contains Terraform modules to provision the core AWS infrastructu
- `postgres`: Creates an Amazon RDS for PostgreSQL instance and returns a connection URL
- `redis`: Creates an ElastiCache for Redis replication group
- `s3`: Creates an S3 bucket and locks access to a provided S3 VPC endpoint
- `opensearch`: Creates an Amazon OpenSearch domain for managed search workloads
- `onyx`: A higher-level composition that wires the above modules together for a complete, opinionated stack
Use the `onyx` module if you want a working EKS + Postgres + Redis + S3 stack with sane defaults. Use the individual modules if you need more granular control.
@@ -129,7 +128,6 @@ Inputs (common):
- `postgres_username`, `postgres_password`
- `create_vpc` (default true) or existing VPC details and `s3_vpc_endpoint_id`
- WAF controls such as `waf_allowed_ip_cidrs`, `waf_common_rule_set_count_rules`, rate limits, geo restrictions, and logging retention
- Optional OpenSearch controls such as `enable_opensearch`, sizing, credentials, and log retention
### `vpc`
- Builds a VPC sized for EKS with multiple private and public subnets
@@ -161,11 +159,6 @@ Key inputs include:
### `s3`
- Creates an S3 bucket for file storage and scopes access to the provided S3 gateway VPC endpoint
### `opensearch`
- Creates an Amazon OpenSearch domain inside the VPC
- Supports custom subnets, security groups, fine-grained access control, encryption, and CloudWatch log publishing
- Outputs domain endpoints, ARN, and the managed security group ID when it creates one
## Installing the Onyx Helm chart (after Terraform)
Once the cluster is active, deploy application workloads via Helm. You can use the chart in `deployment/helm/charts/onyx`.

View File

@@ -1,13 +1,12 @@
locals {
workspace = terraform.workspace
name = var.name
merged_tags = merge(var.tags, { tenant = local.name, environment = local.workspace })
vpc_name = "${var.name}-vpc-${local.workspace}"
cluster_name = "${var.name}-${local.workspace}"
bucket_name = "${var.name}-file-store-${local.workspace}"
redis_name = "${var.name}-redis-${local.workspace}"
postgres_name = "${var.name}-postgres-${local.workspace}"
opensearch_name = var.opensearch_domain_name != null ? var.opensearch_domain_name : "${var.name}-opensearch-${local.workspace}"
workspace = terraform.workspace
name = var.name
merged_tags = merge(var.tags, { tenant = local.name, environment = local.workspace })
vpc_name = "${var.name}-vpc-${local.workspace}"
cluster_name = "${var.name}-${local.workspace}"
bucket_name = "${var.name}-file-store-${local.workspace}"
redis_name = "${var.name}-redis-${local.workspace}"
postgres_name = "${var.name}-postgres-${local.workspace}"
vpc_id = var.create_vpc ? module.vpc[0].vpc_id : var.vpc_id
private_subnets = var.create_vpc ? module.vpc[0].private_subnets : var.private_subnets
@@ -97,38 +96,3 @@ module "waf" {
enable_logging = var.waf_enable_logging
log_retention_days = var.waf_log_retention_days
}
module "opensearch" {
source = "../opensearch"
count = var.enable_opensearch ? 1 : 0
name = local.opensearch_name
vpc_id = local.vpc_id
# Prefer setting subnet_ids explicitly if the state of private_subnets is
# unclear.
subnet_ids = length(var.opensearch_subnet_ids) > 0 ? var.opensearch_subnet_ids : slice(local.private_subnets, 0, 3)
ingress_cidrs = [local.vpc_cidr_block]
tags = local.merged_tags
# Reuse EKS security groups
security_group_ids = [module.eks.node_security_group_id, module.eks.cluster_security_group_id]
# Configuration
engine_version = var.opensearch_engine_version
instance_type = var.opensearch_instance_type
instance_count = var.opensearch_instance_count
dedicated_master_enabled = var.opensearch_dedicated_master_enabled
dedicated_master_type = var.opensearch_dedicated_master_type
multi_az_with_standby_enabled = var.opensearch_multi_az_with_standby_enabled
ebs_volume_size = var.opensearch_ebs_volume_size
ebs_throughput = var.opensearch_ebs_throughput
# Authentication
internal_user_database_enabled = var.opensearch_internal_user_database_enabled
master_user_name = var.opensearch_master_user_name
master_user_password = var.opensearch_master_user_password
# Logging
enable_logging = var.opensearch_enable_logging
log_retention_days = var.opensearch_log_retention_days
}

View File

@@ -32,18 +32,3 @@ output "postgres_dbi_resource_id" {
description = "RDS DB instance resource id"
value = module.postgres.dbi_resource_id
}
output "opensearch_endpoint" {
description = "OpenSearch domain endpoint"
value = var.enable_opensearch ? module.opensearch[0].domain_endpoint : null
}
output "opensearch_dashboard_endpoint" {
description = "OpenSearch Dashboards endpoint"
value = var.enable_opensearch ? module.opensearch[0].kibana_endpoint : null
}
output "opensearch_domain_arn" {
description = "OpenSearch domain ARN"
value = var.enable_opensearch ? module.opensearch[0].domain_arn : null
}

View File

@@ -152,101 +152,3 @@ variable "waf_log_retention_days" {
description = "Number of days to retain WAF logs"
default = 90
}
# OpenSearch Configuration Variables
variable "enable_opensearch" {
type = bool
description = "Whether to create an OpenSearch domain"
default = false
}
variable "opensearch_engine_version" {
type = string
description = "OpenSearch engine version"
default = "3.3"
}
variable "opensearch_instance_type" {
type = string
description = "Instance type for OpenSearch data nodes"
default = "r8g.large.search"
}
variable "opensearch_instance_count" {
type = number
description = "Number of OpenSearch data nodes"
default = 3
}
variable "opensearch_dedicated_master_enabled" {
type = bool
description = "Whether to enable dedicated master nodes for OpenSearch"
default = true
}
variable "opensearch_dedicated_master_type" {
type = string
description = "Instance type for dedicated master nodes"
default = "m7g.large.search"
}
variable "opensearch_multi_az_with_standby_enabled" {
type = bool
description = "Whether to enable Multi-AZ with Standby deployment"
default = true
}
variable "opensearch_ebs_volume_size" {
type = number
description = "EBS volume size in GiB per OpenSearch node"
default = 512
}
variable "opensearch_ebs_throughput" {
type = number
description = "Throughput in MiB/s for gp3 volumes"
default = 256
}
variable "opensearch_internal_user_database_enabled" {
type = bool
description = "Whether to enable the internal user database for fine-grained access control"
default = true
}
variable "opensearch_master_user_name" {
type = string
description = "Master user name for OpenSearch internal user database"
default = null
sensitive = true
}
variable "opensearch_master_user_password" {
type = string
description = "Master user password for OpenSearch internal user database"
default = null
sensitive = true
}
variable "opensearch_domain_name" {
type = string
description = "Override the OpenSearch domain name. If null, defaults to {name}-opensearch-{workspace}."
default = null
}
variable "opensearch_enable_logging" {
type = bool
default = false
}
variable "opensearch_log_retention_days" {
type = number
description = "Number of days to retain OpenSearch CloudWatch logs (0 = never expire)"
default = 0
}
variable "opensearch_subnet_ids" {
type = list(string)
description = "Subnet IDs for OpenSearch. If empty, uses first 3 private subnets."
default = []
}

View File

@@ -1,229 +0,0 @@
# OpenSearch domain security group
resource "aws_security_group" "opensearch_sg" {
count = length(var.security_group_ids) > 0 ? 0 : 1
name = "${var.name}-sg"
description = "Allow inbound traffic to OpenSearch from VPC"
vpc_id = var.vpc_id
tags = var.tags
ingress {
from_port = 443
to_port = 443
protocol = "tcp"
cidr_blocks = var.ingress_cidrs
}
egress {
from_port = 0
to_port = 0
protocol = "-1"
cidr_blocks = ["0.0.0.0/0"]
}
}
# Service-linked role for OpenSearch (required for VPC deployment)
# This may already exist in your account - if so, import it or set create_service_linked_role = false
resource "aws_iam_service_linked_role" "opensearch" {
count = var.create_service_linked_role ? 1 : 0
aws_service_name = "opensearchservice.amazonaws.com"
}
# IAM policy for OpenSearch access
data "aws_caller_identity" "current" {}
data "aws_region" "current" {}
# KMS key lookup for encryption at rest
data "aws_kms_key" "opensearch" {
key_id = "alias/aws/es"
}
# Access policy - allows all principals within the VPC (secured by VPC + security groups)
data "aws_iam_policy_document" "opensearch_access" {
statement {
effect = "Allow"
principals {
type = "AWS"
identifiers = ["*"]
}
actions = ["es:*"]
resources = [
"arn:aws:es:${data.aws_region.current.id}:${data.aws_caller_identity.current.account_id}:domain/${var.name}/*"
]
}
}
# OpenSearch domain
resource "aws_opensearch_domain" "main" {
domain_name = var.name
engine_version = "OpenSearch_${var.engine_version}"
cluster_config {
instance_type = var.instance_type
instance_count = var.instance_count
zone_awareness_enabled = var.zone_awareness_enabled
dedicated_master_enabled = var.dedicated_master_enabled
dedicated_master_type = var.dedicated_master_enabled ? var.dedicated_master_type : null
dedicated_master_count = var.dedicated_master_enabled ? var.dedicated_master_count : null
multi_az_with_standby_enabled = var.multi_az_with_standby_enabled
warm_enabled = var.warm_enabled
warm_type = var.warm_enabled ? var.warm_type : null
warm_count = var.warm_enabled ? var.warm_count : null
dynamic "zone_awareness_config" {
for_each = var.zone_awareness_enabled ? [1] : []
content {
availability_zone_count = var.availability_zone_count
}
}
dynamic "cold_storage_options" {
for_each = var.cold_storage_enabled ? [1] : []
content {
enabled = true
}
}
}
ebs_options {
ebs_enabled = true
volume_type = var.ebs_volume_type
volume_size = var.ebs_volume_size
iops = var.ebs_volume_type == "gp3" || var.ebs_volume_type == "io1" ? var.ebs_iops : null
throughput = var.ebs_volume_type == "gp3" ? var.ebs_throughput : null
}
vpc_options {
subnet_ids = var.subnet_ids
security_group_ids = length(var.security_group_ids) > 0 ? var.security_group_ids : [aws_security_group.opensearch_sg[0].id]
}
encrypt_at_rest {
enabled = true
kms_key_id = var.kms_key_id != null ? var.kms_key_id : data.aws_kms_key.opensearch.arn
}
node_to_node_encryption {
enabled = true
}
domain_endpoint_options {
enforce_https = true
tls_security_policy = var.tls_security_policy
}
advanced_security_options {
enabled = true
anonymous_auth_enabled = false
internal_user_database_enabled = var.internal_user_database_enabled
dynamic "master_user_options" {
for_each = var.internal_user_database_enabled ? [1] : []
content {
master_user_name = var.master_user_name
master_user_password = var.master_user_password
}
}
dynamic "master_user_options" {
for_each = var.internal_user_database_enabled ? [] : [1]
content {
master_user_arn = var.master_user_arn
}
}
}
advanced_options = var.advanced_options
access_policies = data.aws_iam_policy_document.opensearch_access.json
auto_tune_options {
desired_state = var.auto_tune_enabled ? "ENABLED" : "DISABLED"
rollback_on_disable = var.auto_tune_rollback_on_disable
}
off_peak_window_options {
enabled = var.off_peak_window_enabled
dynamic "off_peak_window" {
for_each = var.off_peak_window_enabled ? [1] : []
content {
window_start_time {
hours = var.off_peak_window_start_hours
minutes = var.off_peak_window_start_minutes
}
}
}
}
software_update_options {
auto_software_update_enabled = var.auto_software_update_enabled
}
dynamic "log_publishing_options" {
for_each = var.enable_logging ? ["INDEX_SLOW_LOGS", "SEARCH_SLOW_LOGS", "ES_APPLICATION_LOGS"] : []
content {
cloudwatch_log_group_arn = "arn:aws:logs:${data.aws_region.current.name}:${data.aws_caller_identity.current.account_id}:log-group:${local.log_group_name}"
log_type = log_publishing_options.value
}
}
tags = var.tags
depends_on = [
aws_iam_service_linked_role.opensearch,
aws_cloudwatch_log_resource_policy.opensearch
]
lifecycle {
precondition {
condition = !var.internal_user_database_enabled || var.master_user_name != null
error_message = "master_user_name is required when internal_user_database_enabled is true."
}
precondition {
condition = !var.internal_user_database_enabled || var.master_user_password != null
error_message = "master_user_password is required when internal_user_database_enabled is true."
}
}
}
# CloudWatch log group for OpenSearch
locals {
log_group_name = var.log_group_name != null ? var.log_group_name : "/aws/OpenSearchService/domains/${var.name}/search-logs"
}
resource "aws_cloudwatch_log_group" "opensearch" {
count = var.enable_logging ? 1 : 0
name = local.log_group_name
retention_in_days = var.log_retention_days
tags = var.tags
}
# CloudWatch log resource policy for OpenSearch
data "aws_iam_policy_document" "opensearch_log_policy" {
count = var.enable_logging ? 1 : 0
statement {
effect = "Allow"
principals {
type = "Service"
identifiers = ["es.amazonaws.com"]
}
actions = [
"logs:PutLogEvents",
"logs:CreateLogStream",
]
resources = ["arn:aws:logs:${data.aws_region.current.name}:${data.aws_caller_identity.current.account_id}:log-group:${local.log_group_name}:*"]
}
}
resource "aws_cloudwatch_log_resource_policy" "opensearch" {
count = var.enable_logging ? 1 : 0
policy_name = "OpenSearchService-${var.name}-Search-logs"
policy_document = data.aws_iam_policy_document.opensearch_log_policy[0].json
}

View File

@@ -1,29 +0,0 @@
output "domain_endpoint" {
description = "The endpoint of the OpenSearch domain"
value = aws_opensearch_domain.main.endpoint
}
output "domain_arn" {
description = "The ARN of the OpenSearch domain"
value = aws_opensearch_domain.main.arn
}
output "domain_id" {
description = "The unique identifier for the OpenSearch domain"
value = aws_opensearch_domain.main.domain_id
}
output "domain_name" {
description = "The name of the OpenSearch domain"
value = aws_opensearch_domain.main.domain_name
}
output "kibana_endpoint" {
description = "The OpenSearch Dashboards endpoint"
value = aws_opensearch_domain.main.dashboard_endpoint
}
output "security_group_id" {
description = "The ID of the OpenSearch security group"
value = length(aws_security_group.opensearch_sg) > 0 ? aws_security_group.opensearch_sg[0].id : null
}

View File

@@ -1,242 +0,0 @@
variable "name" {
description = "Name of the OpenSearch domain"
type = string
}
variable "vpc_id" {
description = "ID of the VPC to deploy the OpenSearch domain into"
type = string
}
variable "subnet_ids" {
description = "List of subnet IDs for the OpenSearch domain"
type = list(string)
}
variable "ingress_cidrs" {
description = "CIDR blocks allowed to access OpenSearch"
type = list(string)
}
variable "engine_version" {
description = "OpenSearch engine version (e.g., 2.17, 3.3)"
type = string
default = "3.3"
}
variable "instance_type" {
description = "Instance type for data nodes"
type = string
default = "r8g.large.search"
}
variable "instance_count" {
description = "Number of data nodes"
type = number
default = 3
}
variable "zone_awareness_enabled" {
description = "Whether to enable zone awareness for the cluster"
type = bool
default = true
}
variable "availability_zone_count" {
description = "Number of availability zones (2 or 3)"
type = number
default = 3
}
variable "dedicated_master_enabled" {
description = "Whether to enable dedicated master nodes"
type = bool
default = true
}
variable "dedicated_master_type" {
description = "Instance type for dedicated master nodes"
type = string
default = "m7g.large.search"
}
variable "dedicated_master_count" {
description = "Number of dedicated master nodes (must be 3 or 5)"
type = number
default = 3
}
variable "multi_az_with_standby_enabled" {
description = "Whether to enable Multi-AZ with Standby deployment"
type = bool
default = true
}
variable "warm_enabled" {
description = "Whether to enable warm storage"
type = bool
default = false
}
variable "warm_type" {
description = "Instance type for warm nodes"
type = string
default = "ultrawarm1.medium.search"
}
variable "warm_count" {
description = "Number of warm nodes"
type = number
default = 2
}
variable "cold_storage_enabled" {
description = "Whether to enable cold storage"
type = bool
default = false
}
variable "ebs_volume_type" {
description = "EBS volume type (gp3, gp2, io1)"
type = string
default = "gp3"
}
variable "ebs_volume_size" {
description = "EBS volume size in GB per node"
type = number
default = 512
}
variable "ebs_iops" {
description = "IOPS for gp3/io1 volumes"
type = number
default = 3000
}
variable "ebs_throughput" {
description = "Throughput in MiB/s for gp3 volumes"
type = number
default = 256
}
variable "kms_key_id" {
description = "KMS key ID for encryption at rest (uses AWS managed key if not specified)"
type = string
default = null
}
variable "tls_security_policy" {
description = "TLS security policy for HTTPS endpoints"
type = string
default = "Policy-Min-TLS-1-2-2019-07"
}
variable "internal_user_database_enabled" {
description = "Whether to enable the internal user database for fine-grained access control"
type = bool
default = true
}
variable "master_user_name" {
description = "Master user name for internal user database"
type = string
default = null
sensitive = true
}
variable "master_user_password" {
description = "Master user password for internal user database"
type = string
default = null
sensitive = true
}
variable "master_user_arn" {
description = "IAM ARN for the master user (used when internal_user_database_enabled is false)"
type = string
default = null
}
variable "advanced_options" {
description = "Advanced options for OpenSearch"
type = map(string)
default = {
"indices.fielddata.cache.size" = "20"
"indices.query.bool.max_clause_count" = "1024"
"override_main_response_version" = "false"
"rest.action.multi.allow_explicit_index" = "true"
}
}
variable "auto_tune_enabled" {
description = "Whether to enable Auto-Tune"
type = bool
default = true
}
variable "auto_tune_rollback_on_disable" {
description = "Whether to roll back Auto-Tune changes when disabled"
type = string
default = "NO_ROLLBACK"
}
variable "off_peak_window_enabled" {
description = "Whether to enable off-peak window for maintenance"
type = bool
default = true
}
variable "off_peak_window_start_hours" {
description = "Hour (UTC) when off-peak window starts (0-23)"
type = number
default = 6
}
variable "off_peak_window_start_minutes" {
description = "Minutes when off-peak window starts (0-59)"
type = number
default = 0
}
variable "auto_software_update_enabled" {
description = "Whether to enable automatic software updates"
type = bool
default = false
}
variable "enable_logging" {
description = "Whether to enable CloudWatch logging"
type = bool
default = false
}
variable "create_service_linked_role" {
description = "Whether to create the OpenSearch service-linked role (set to false if it already exists)"
type = bool
default = false
}
variable "log_retention_days" {
description = "Number of days to retain CloudWatch logs"
type = number
default = 30
}
variable "security_group_ids" {
description = "Existing security group IDs to attach. If empty, a new SG is created."
type = list(string)
default = []
}
variable "log_group_name" {
description = "CloudWatch log group name. Defaults to AWS console convention."
type = string
default = null
}
variable "tags" {
description = "Tags to apply to OpenSearch resources"
type = map(string)
default = {}
}

View File

@@ -3839,9 +3839,9 @@
}
},
"node_modules/flatted": {
"version": "3.4.2",
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.4.2.tgz",
"integrity": "sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA==",
"version": "3.3.3",
"resolved": "https://registry.npmjs.org/flatted/-/flatted-3.3.3.tgz",
"integrity": "sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==",
"dev": true,
"license": "ISC"
},

View File

@@ -144,7 +144,6 @@ module.exports = {
"**/src/app/**/services/*.test.ts",
"**/src/app/**/utils/*.test.ts",
"**/src/app/**/hooks/*.test.ts", // Pure packet processor tests
"**/src/hooks/**/*.test.ts",
"**/src/refresh-components/**/*.test.ts",
"**/src/refresh-pages/**/*.test.ts",
"**/src/sections/**/*.test.ts",

View File

@@ -1,30 +1,33 @@
"use client";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import type { TableSize } from "@opal/components/table/TableSizeContext";
interface ActionsContainerProps {
type: "head" | "cell";
children: React.ReactNode;
size?: TableSize;
/** Pass-through click handler (e.g. stopPropagation on body cells). */
onClick?: (e: React.MouseEvent) => void;
children: React.ReactNode;
}
export default function ActionsContainer({
type,
children,
size,
onClick,
}: ActionsContainerProps) {
const size = useTableSize();
const contextSize = useTableSize();
const resolvedSize = size ?? contextSize;
const Tag = type === "head" ? "th" : "td";
return (
<Tag
className="tbl-actions"
data-type={type}
data-size={size}
data-size={resolvedSize}
onClick={onClick}
>
<div className="flex h-full items-center justify-end">{children}</div>
<div className="flex h-full items-center justify-center">{children}</div>
</Tag>
);
}

View File

@@ -8,7 +8,6 @@ import {
type SortingState,
} from "@tanstack/react-table";
import { Button, LineItemButton } from "@opal/components";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import { SvgArrowUpDown, SvgSortOrder, SvgCheck } from "@opal/icons";
import Popover from "@/refresh-components/Popover";
import Divider from "@/refresh-components/Divider";
@@ -21,6 +20,7 @@ import Text from "@/refresh-components/texts/Text";
interface SortingPopoverProps<TData extends RowData = RowData> {
table: Table<TData>;
sorting: SortingState;
size?: "md" | "lg";
footerText?: string;
ascendingLabel?: string;
descendingLabel?: string;
@@ -29,11 +29,11 @@ interface SortingPopoverProps<TData extends RowData = RowData> {
function SortingPopover<TData extends RowData>({
table,
sorting,
size = "lg",
footerText,
ascendingLabel = "Ascending",
descendingLabel = "Descending",
}: SortingPopoverProps<TData>) {
const size = useTableSize();
const [open, setOpen] = useState(false);
const sortableColumns = table
.getAllLeafColumns()
@@ -158,6 +158,7 @@ function SortingPopover<TData extends RowData>({
// ---------------------------------------------------------------------------
interface CreateSortingColumnOptions {
size?: "md" | "lg";
footerText?: string;
ascendingLabel?: string;
descendingLabel?: string;
@@ -176,6 +177,7 @@ function createSortingColumn<TData>(
<SortingPopover
table={table}
sorting={table.getState().sorting}
size={options?.size}
footerText={options?.footerText}
ascendingLabel={options?.ascendingLabel}
descendingLabel={options?.descendingLabel}

View File

@@ -8,7 +8,6 @@ import {
type VisibilityState,
} from "@tanstack/react-table";
import { Button, LineItemButton, Tag } from "@opal/components";
import { useTableSize } from "@opal/components/table/TableSizeContext";
import { SvgColumn, SvgCheck } from "@opal/icons";
import Popover from "@/refresh-components/Popover";
import Divider from "@/refresh-components/Divider";
@@ -20,13 +19,14 @@ import Divider from "@/refresh-components/Divider";
interface ColumnVisibilityPopoverProps<TData extends RowData = RowData> {
table: Table<TData>;
columnVisibility: VisibilityState;
size?: "md" | "lg";
}
function ColumnVisibilityPopover<TData extends RowData>({
table,
columnVisibility,
size = "lg",
}: ColumnVisibilityPopoverProps<TData>) {
const size = useTableSize();
const [open, setOpen] = useState(false);
// User-defined columns only (exclude internal qualifier/actions)
@@ -87,7 +87,13 @@ function ColumnVisibilityPopover<TData extends RowData>({
// Column definition factory
// ---------------------------------------------------------------------------
function createColumnVisibilityColumn<TData>(): ColumnDef<TData, unknown> {
interface CreateColumnVisibilityColumnOptions {
size?: "md" | "lg";
}
function createColumnVisibilityColumn<TData>(
options?: CreateColumnVisibilityColumnOptions
): ColumnDef<TData, unknown> {
return {
id: "__columnVisibility",
size: 44,
@@ -98,6 +104,7 @@ function createColumnVisibilityColumn<TData>(): ColumnDef<TData, unknown> {
<ColumnVisibilityPopover
table={table}
columnVisibility={table.getState().columnVisibility}
size={options?.size}
/>
),
cell: () => null,

Some files were not shown because too many files have changed in this diff Show More