Compare commits

..

71 Commits

Author SHA1 Message Date
pablodanswer
25b38212e9 nit 2025-01-19 09:50:35 -08:00
pablodanswer
3096b0b2a7 add linear check 2025-01-19 09:49:26 -08:00
Chris Weaver
342bb9f685 Fix document counts (#3671)
* Various fixes/improvements to document counting

* Add new column + index

* Avoid double scan

* comment fixes

* Fix revision history

* Fix IT

* Fix IT

* Fix migration

* Rebase
2025-01-19 05:36:07 +00:00
hagen-danswer
b25668c83a fixed group sync to account for changes in drive permissions (#3666)
* fixed group sync to account for changes in drive permissions

* mypy

* addressed

* reeeeeeeee
2025-01-19 00:08:50 +00:00
Weves
a72bd31f5d Small background telemetry fix 2025-01-18 16:19:28 -08:00
hagen-danswer
896e716d02 query history pagination tests (#3700)
* dummy pr

* Update prompts.yaml

* fixed tests and added query history pagination test

* done

* fixed

* utils!
2025-01-18 21:28:03 +00:00
pablonyx
eec3ce8162 Markdown rendering (#3698)
* nit

* update comment
2025-01-18 12:12:19 -08:00
pablonyx
2761a837c6 quick nit for no-longer living files (#3702) 2025-01-18 11:09:34 -08:00
hagen-danswer
da43abe644 Made copy button and cmd+c work for cmd+v and cmd+shift+v (#3693)
* Made copy button and cmd+c work for cmd+v and cmd+shift+v

* made sub selections work as well

* ok it works

* fixed npm run build

* im not from earth

* added logging

* more logging

* bye logs

* should work now

* whoops

* added stuff

* made it robust

* ctrl shift v behavior
2025-01-18 10:34:32 -08:00
skylares
af953ff8a3 Paginate Query History table (#3592)
* Add pagination for query history table

* Fix method name

* Fix mypy
2025-01-17 15:31:42 -08:00
rkuo-danswer
6fc52c81ab Bugfix/beat redux (#3639)
* WIP

* WIP

* try spinning out check for indexing into a system task

* check for the correct delimiter

* use constants

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-01-17 20:59:43 +00:00
hagen-danswer
1ad2128b2a Combined Persona and Prompt API (#3690)
* Combined Persona and Prompt API

* quality

* added tests

* consolidated models and got rid of redundant fields

* tenant appreciation day

* reverted default
2025-01-17 20:21:20 +00:00
Kaveen Jayamanna
880c42ad41 Validating slackbot tokens (#3695)
* added missing dependency, missing api key placeholder, updated docs

* Apply black formatting and validate bot token functionality

* acknowledging black formatting

* added the validation to update tokens as well

* Made the token validation errors looks nicer

* getting rid of duplicate dependency
2025-01-17 11:50:22 -08:00
pablonyx
c9e0d77c93 Minor large PR cleanup (misc files)
Minor large PR cleanup
2025-01-16 09:41:06 -08:00
pablodanswer
7a750dc2ca Minor large PR cleanup 2025-01-16 09:39:27 -08:00
pablonyx
44b70a87df UX Refresh (#3687)
* add new ux

* quick nit

* additional nit

* finalize

* quick fix

* fix typing
2025-01-16 08:08:01 +00:00
Chris Weaver
a05addec19 Add is_cloud info to telemetry + get consistent customer_uuid's for a… (#3684)
* Add is_cloud info to telemetry + get consistent customer_uuid's for a given tenant

* Address Richard's comments
2025-01-16 02:43:21 +00:00
Chris Weaver
8a4d762798 Fix follow ups in thread + fix user name (#3686)
* Fix follow ups in thread + fix user name

* Add back single history str

* Remove newline
2025-01-16 02:40:25 +00:00
rkuo-danswer
c9a420ec49 better logging and reduce long sessions (#3673)
* testing some tweaks based on issues seen with okteto

* shorten session usage in indexing. still a couple of long running sessions to clean up

* merge sessions

* fixing detached session issues

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-16 01:27:12 +00:00
pablodanswer
beccca5fa2 Remove stranded file 2025-01-15 16:34:13 -08:00
pablonyx
66d8b8bb10 Add chrome extension pages (#3629) 2025-01-15 15:09:49 -08:00
pablonyx
76ca650972 Admin usage for seeding (#3683)
* admin usage for seeding

* functional

* proper fix

* k

* typing
2025-01-15 19:04:25 +00:00
hagen-danswer
eb70699c0b temp test fixes (#3682)
* fix discord test

* Fix discord test

* fixed fireflies test too
2025-01-15 09:07:05 -08:00
skylares
b401f83eb6 Salesforce daily test (#3611)
* Add daily salesforce test

* Add more assertions

* Add assertions for data by parsing the key-value strings

* Fix grammar
2025-01-15 07:53:50 -08:00
skylares
993a1a6caf Add discord daily test (#3676)
* Add discord daily test

* Fix mypy error
2025-01-15 07:50:33 -08:00
skylares
c3481c7356 Fireflies daily test (#3663)
* Init test files for fireflies

* Finish creating daily test and update parsing of sections

* Added comment
2025-01-15 06:40:31 -08:00
Chris Weaver
3b7695539f Add monitoring worker (#3677)
* Add monitoring worker

* Add locks

* Add tenant id to lock

* Remove unneeded tenant postfix
2025-01-15 01:39:56 +00:00
hagen-danswer
b1957737f2 refactored _add_user_filter usage (#3674)
* refactored db.connector_credential_pair

* Refactored the db.credentials user filtering

* the rest
2025-01-14 23:35:52 +00:00
rkuo-danswer
5f462056f6 Merge pull request #3660 from onyx-dot-app/bugfix/index_attempt_query
optimize another index attempt check
2025-01-13 20:02:54 -08:00
Richard Kuo (Danswer)
0de4d61b6d Merge branch 'main' of https://github.com/onyx-dot-app/onyx into bugfix/index_attempt_query 2025-01-13 16:26:22 -08:00
rkuo-danswer
7a28a5c216 Merge pull request #3669 from onyx-dot-app/bugfix/fix_time_updated
fix missed var names
2025-01-13 15:04:17 -08:00
Richard Kuo (Danswer)
d8aa21ca3a fix missed var names 2025-01-13 14:32:26 -08:00
Richard Kuo (Danswer)
c4323573d2 fix alembic 2025-01-13 13:23:40 -08:00
Richard Kuo (Danswer)
46cfaa96b7 Merge branch 'main' of https://github.com/danswer-ai/danswer into bugfix/index_attempt_query 2025-01-13 13:23:30 -08:00
Weves
a610b6bd8d Support new model for image input 2025-01-13 13:17:51 -08:00
rkuo-danswer
cb66aadd80 Merge pull request #3648 from onyx-dot-app/bugfix/light_cpu
figuring out why multiprocessing set_start_method isn't working.
2025-01-13 13:08:55 -08:00
Chris Weaver
9ea2ae267e Performance monitoring (#3658)
* Initial scaffolding for metrics

* iterate

* more

* More metrics + SyncRecord concept

* Add indices, standardize timing

* Small cleanup

* Address comments
2025-01-13 12:36:45 -08:00
Richard Kuo (Danswer)
7d86b28335 maybe we don't need pre ping yet 2025-01-13 12:14:32 -08:00
Richard Kuo (Danswer)
4f8e48df7c try more sql settings 2025-01-13 11:50:04 -08:00
Richard Kuo (Danswer)
d96d2fc6e9 add comment 2025-01-13 11:35:58 -08:00
Richard Kuo (Danswer)
b6dd999c1b add some type hints 2025-01-13 11:31:57 -08:00
Richard Kuo (Danswer)
9a09222b7d add comments 2025-01-13 10:58:33 -08:00
Richard Kuo (Danswer)
be3cfdd4a6 saved files 2025-01-13 10:46:20 -08:00
Richard Kuo (Danswer)
f5bdf9d2c9 move to celeryd_init 2025-01-13 02:46:03 -08:00
hagen-danswer
6afd27f9c9 fix group sync name capitalization (#3653)
* fix group sync name capitalization

* everything is lowercased now

* comments

* Added test for be2ab2aa50ee migration

* polish
2025-01-10 16:51:33 -08:00
Richard Kuo (Danswer)
ccef350287 try using spawn specifically 2025-01-10 14:19:31 -08:00
Richard Kuo (Danswer)
4400a945e3 optimize another index attempt check 2025-01-10 14:18:49 -08:00
Richard Kuo (Danswer)
384a38418b test set_spawn_method and handle exceptions 2025-01-10 12:59:34 -08:00
Richard Kuo (Danswer)
2163a138ed logging 2025-01-10 12:41:05 -08:00
Richard Kuo (Danswer)
b6c2ecfecb more debugging of start method 2025-01-10 12:16:13 -08:00
Richard Kuo (Danswer)
ac182c74b3 log all start methods 2025-01-10 12:11:33 -08:00
pablonyx
cab7e60542 Proper anonymous user restricting (#3645) 2025-01-10 11:31:11 -08:00
Richard Kuo (Danswer)
8e25c3c412 Merge branch 'main' of https://github.com/danswer-ai/danswer into bugfix/light_cpu 2025-01-10 11:01:12 -08:00
Weves
1470b7e038 Add tests for some LLM provider endpoints + small logic change to ensure that display_model_names is not empty 2025-01-10 08:55:53 -08:00
rkuo-danswer
bf78fb79f8 possible fix for gdrive oauth in the cloud (#3642)
* possible fix for gd oauth in the cloud

* missed code in rename/merge

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-10 02:10:59 +00:00
rkuo-danswer
d972a78f45 Make connector pause and delete fast (#3646)
* first cut

* refresh on delete

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-10 01:39:45 +00:00
Richard Kuo (Danswer)
962240031f figuring out why multiprocessing set_start_method isn't working. 2025-01-09 16:29:37 -08:00
hagen-danswer
50131ba22c Better logging for confluence space permissions 2025-01-09 15:13:02 -08:00
rkuo-danswer
439217317f Merge pull request #3644 from onyx-dot-app/bugfix/model-server-build-fix
hope this env var works.
2025-01-09 14:34:25 -08:00
hagen-danswer
c55de28423 added distinct when outer joining for user filters (#3641)
* added distinct when outer joining for user filters

* Added distinct when outer joining for user filters for all
2025-01-09 14:15:38 -08:00
Richard Kuo (Danswer)
91e32e801d hope this env var works. 2025-01-09 13:51:58 -08:00
rkuo-danswer
2ae91f0f2b Feature/redis prod tool (#3619)
* prototype tools for handling prod issues

* add some commands

* add batching and dry run options

* custom redis tool

* comment

* default to app config settings for redis

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-09 21:34:07 +00:00
hagen-danswer
d40fd82803 Conf doc sync improvements (#3643)
* Reduce number of requests to Confluence

* undo

* added a way to dynamically adjust the pagination limit

* undo
2025-01-09 12:56:56 -08:00
rkuo-danswer
97a963b4bf add index to speed up get last attempt (#3636)
* add index to speed up get last attempt

* use descending order

* put back unique param

* how did this not get formatted?

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-09 00:56:55 +00:00
pablonyx
7f6ef1ff57 Remove unnecessary logspam
Remove unnecessary logs
2025-01-08 17:03:52 -08:00
rkuo-danswer
a76f1b4c1b Merge pull request #3628 from onyx-dot-app/bugfix/debug_tenant
add more debug logging for locking issue
2025-01-08 15:14:37 -08:00
hagen-danswer
4c4ff46fe3 Fixing google drive tests (#3634)
* Fixing google drive tests

* Update conftest.py
2025-01-08 22:34:38 +00:00
hagen-danswer
0f9842064f Added env var to skip warm up (#3633) 2025-01-08 14:29:15 -08:00
Richard Kuo (Danswer)
1f48de9731 more logging 2025-01-08 12:49:24 -08:00
Richard Kuo (Danswer)
a22d02ff70 add another log line 2025-01-08 10:01:24 -08:00
Richard Kuo (Danswer)
dcfc621a66 add more debug logging for locking issue 2025-01-08 09:43:47 -08:00
347 changed files with 22187 additions and 7525 deletions

View File

@@ -1,11 +1,15 @@
## Description
[Provide a brief description of the changes in this PR]
## How Has This Been Tested?
[Describe the tests you ran to verify your changes]
## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes; otherwise, resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] I have included a link to a Linear ticket in my description.
- [ ] [Optional] Override Linear Check

View File

@@ -118,6 +118,6 @@ jobs:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/onyxdotapp/onyx-model-server:${{ github.ref_name }}
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
timeout: "10m"

.github/workflows/pr-linear-check.yml (new file, 29 lines)
View File

@@ -0,0 +1,29 @@
name: Ensure PR references Linear
on:
pull_request:
types: [opened, edited, reopened, synchronize]
jobs:
linear-check:
runs-on: ubuntu-latest
steps:
- name: Check PR body for Linear link or override
run: |
PR_BODY="${{ github.event.pull_request.body }}"
# Looking for "https://linear.app" in the body
if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
echo "Found a Linear link. Check passed."
exit 0
fi
# Looking for a checked override: "[x] Override Linear Check"
if echo "$PR_BODY" | grep -q "\[x\].*Override Linear Check"; then
echo "Override box is checked. Check passed."
exit 0
fi
# Otherwise, fail the run
echo "No Linear link or override found in the PR description."
exit 1
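
For a quick local sanity check of the same rules, the two grep patterns above translate directly into Python. The sketch below is illustrative only; the helper name and sample PR body are not part of the repository.

import re

def pr_references_linear(pr_body: str) -> bool:
    # Mirrors the workflow's two checks: a Linear link anywhere in the body,
    # or a checked "Override Linear Check" box.
    if re.search(r"https://linear\.app", pr_body):
        return True
    if re.search(r"\[x\].*Override Linear Check", pr_body):
        return True
    return False

# Example: the override checkbox from the updated PR template passes the check
print(pr_references_linear("- [x] [Optional] Override Linear Check"))  # True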

View File

@@ -5,6 +5,8 @@
# For local dev, often user Authentication is not needed
AUTH_TYPE=disabled
# Skip warm up for dev
SKIP_WARM_UP=True
# Always keep these on for Dev
# Logs all model prompts to stdout
@@ -27,6 +29,7 @@ REQUIRE_EMAIL_VERIFICATION=False
# Set these so if you wipe the DB, you don't end up having to go through the UI every time
GEN_AI_API_KEY=<REPLACE THIS>
OPENAI_API_KEY=<REPLACE THIS>
# If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper
GEN_AI_MODEL_VERSION=gpt-4o
FAST_GEN_AI_MODEL_VERSION=gpt-4o

View File

@@ -28,6 +28,7 @@
"Celery heavy",
"Celery indexing",
"Celery beat",
"Celery monitoring",
],
"presentation": {
"group": "1",
@@ -51,7 +52,8 @@
"Celery light",
"Celery heavy",
"Celery indexing",
"Celery beat"
"Celery beat",
"Celery monitoring",
],
"presentation": {
"group": "1",
@@ -269,6 +271,31 @@
},
"consoleTitle": "Celery indexing Console"
},
{
"name": "Celery monitoring",
"type": "debugpy",
"request": "launch",
"module": "celery",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.vscode/.env",
"env": {},
"args": [
"-A",
"onyx.background.celery.versioned_apps.monitoring",
"worker",
"--pool=solo",
"--concurrency=1",
"--prefetch-multiplier=1",
"--loglevel=INFO",
"--hostname=monitoring@%n",
"-Q",
"monitoring",
],
"presentation": {
"group": "2",
},
"consoleTitle": "Celery monitoring Console"
},
{
"name": "Celery beat",
"type": "debugpy",

View File

@@ -17,9 +17,10 @@ Before starting, make sure the Docker Daemon is running.
1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
4. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
5. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
6. Use the debug toolbar to step through code, inspect variables, etc.
4. cd into the web directory and run "npm i" followed by "npm run dev"
5. Now, you can navigate to onyx in your browser (default is http://localhost:3000) and start using the app
6. You can set breakpoints by clicking to the left of line numbers to help debug while the app is running
7. Use the debug toolbar to step through code, inspect variables, etc.
## Features

View File

@@ -0,0 +1,29 @@
"""add shortcut option for users
Revision ID: 027381bce97c
Revises: 6fc7886d665d
Create Date: 2025-01-14 12:14:00.814390
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "027381bce97c"
down_revision = "6fc7886d665d"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"shortcut_enabled", sa.Boolean(), nullable=False, server_default="false"
),
)
def downgrade() -> None:
op.drop_column("user", "shortcut_enabled")

View File

@@ -0,0 +1,36 @@
"""add index to index_attempt.time_created
Revision ID: 0f7ff6d75b57
Revises: 369644546676
Create Date: 2025-01-10 14:01:14.067144
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "0f7ff6d75b57"
down_revision = "fec3db967bf7"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
op.f("ix_index_attempt_status"),
"index_attempt",
["status"],
unique=False,
)
op.create_index(
op.f("ix_index_attempt_time_created"),
"index_attempt",
["time_created"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_index_attempt_time_created"), table_name="index_attempt")
op.drop_index(op.f("ix_index_attempt_status"), table_name="index_attempt")

View File

@@ -0,0 +1,35 @@
"""add composite index for index attempt time updated
Revision ID: 369644546676
Revises: 2955778aa44c
Create Date: 2025-01-08 15:38:17.224380
"""
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = "369644546676"
down_revision = "2955778aa44c"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
"ix_index_attempt_ccpair_search_settings_time_updated",
"index_attempt",
[
"connector_credential_pair_id",
"search_settings_id",
text("time_updated DESC"),
],
unique=False,
)
def downgrade() -> None:
op.drop_index(
"ix_index_attempt_ccpair_search_settings_time_updated",
table_name="index_attempt",
)

View File

@@ -0,0 +1,59 @@
"""add back input prompts
Revision ID: 3c6531f32351
Revises: aeda5f2df4f6
Create Date: 2025-01-13 12:49:51.705235
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "3c6531f32351"
down_revision = "aeda5f2df4f6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"inputprompt",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("prompt", sa.String(), nullable=False),
sa.Column("content", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.Column("is_public", sa.Boolean(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"inputprompt__user",
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
sa.Column(
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
),
sa.Column("disabled", sa.Boolean(), nullable=False, default=False),
sa.ForeignKeyConstraint(
["input_prompt_id"],
["inputprompt.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
)
def downgrade() -> None:
op.drop_table("inputprompt__user")
op.drop_table("inputprompt")

View File

@@ -40,6 +40,6 @@ def upgrade() -> None:
def downgrade() -> None:
op.drop_constraint("fk_persona_category", "persona", type_="foreignkey")
op.drop_constraint("persona_category_id_fkey", "persona", type_="foreignkey")
op.drop_column("persona", "category_id")
op.drop_table("persona_category")

View File

@@ -0,0 +1,80 @@
"""make categories labels and many to many
Revision ID: 6fc7886d665d
Revises: 3c6531f32351
Create Date: 2025-01-13 18:12:18.029112
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "6fc7886d665d"
down_revision = "3c6531f32351"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Rename persona_category table to persona_label
op.rename_table("persona_category", "persona_label")
# Create the new association table
op.create_table(
"persona__persona_label",
sa.Column("persona_id", sa.Integer(), nullable=False),
sa.Column("persona_label_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["persona_id"],
["persona.id"],
),
sa.ForeignKeyConstraint(
["persona_label_id"],
["persona_label.id"],
ondelete="CASCADE",
),
sa.PrimaryKeyConstraint("persona_id", "persona_label_id"),
)
# Copy existing relationships to the new table
op.execute(
"""
INSERT INTO persona__persona_label (persona_id, persona_label_id)
SELECT id, category_id FROM persona WHERE category_id IS NOT NULL
"""
)
# Remove the old category_id column from persona table
op.drop_column("persona", "category_id")
def downgrade() -> None:
# Rename persona_label table back to persona_category
op.rename_table("persona_label", "persona_category")
# Add back the category_id column to persona table
op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"persona_category_id_fkey",
"persona",
"persona_category",
["category_id"],
["id"],
)
# Copy the first label relationship back to the persona table
op.execute(
"""
UPDATE persona
SET category_id = (
SELECT persona_label_id
FROM persona__persona_label
WHERE persona__persona_label.persona_id = persona.id
LIMIT 1
)
"""
)
# Drop the association table
op.drop_table("persona__persona_label")

View File

@@ -0,0 +1,72 @@
"""Add SyncRecord
Revision ID: 97dbb53fa8c8
Revises: 369644546676
Create Date: 2025-01-11 19:39:50.426302
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "97dbb53fa8c8"
down_revision = "be2ab2aa50ee"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"sync_record",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("entity_id", sa.Integer(), nullable=False),
sa.Column(
"sync_type",
sa.Enum(
"DOCUMENT_SET",
"USER_GROUP",
"CONNECTOR_DELETION",
name="synctype",
native_enum=False,
length=40,
),
nullable=False,
),
sa.Column(
"sync_status",
sa.Enum(
"IN_PROGRESS",
"SUCCESS",
"FAILED",
"CANCELED",
name="syncstatus",
native_enum=False,
length=40,
),
nullable=False,
),
sa.Column("num_docs_synced", sa.Integer(), nullable=False),
sa.Column("sync_start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("sync_end_time", sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
# Add index for fetch_latest_sync_record query
op.create_index(
"ix_sync_record_entity_id_sync_type_sync_start_time",
"sync_record",
["entity_id", "sync_type", "sync_start_time"],
)
# Add index for cleanup_sync_records query
op.create_index(
"ix_sync_record_entity_id_sync_type_sync_status",
"sync_record",
["entity_id", "sync_type", "sync_status"],
)
def downgrade() -> None:
op.drop_index("ix_sync_record_entity_id_sync_type_sync_status")
op.drop_index("ix_sync_record_entity_id_sync_type_sync_start_time")
op.drop_table("sync_record")
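
The inline comments above name the queries these indexes serve (fetch_latest_sync_record and cleanup_sync_records). A hedged sketch of the "latest record" lookup follows, assuming a SyncRecord ORM model in onyx.db.models that mirrors this table; SyncType comes from onyx.db.enums, which other files in this diff already import.

import sqlalchemy as sa
from sqlalchemy.orm import Session

from onyx.db.enums import SyncType
from onyx.db.models import SyncRecord  # assumed ORM model for the sync_record table

def fetch_latest_sync_record_sketch(
    db_session: Session, entity_id: int, sync_type: SyncType
) -> SyncRecord | None:
    # Filters on (entity_id, sync_type) and sorts by sync_start_time, matching the
    # column order of ix_sync_record_entity_id_sync_type_sync_start_time.
    stmt = (
        sa.select(SyncRecord)
        .where(SyncRecord.entity_id == entity_id, SyncRecord.sync_type == sync_type)
        .order_by(sa.desc(SyncRecord.sync_start_time))
        .limit(1)
    )
    return db_session.scalars(stmt).first()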

View File

@@ -0,0 +1,27 @@
"""add pinned assistants
Revision ID: aeda5f2df4f6
Revises: c5eae4a75a1b
Create Date: 2025-01-09 16:04:10.770636
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "aeda5f2df4f6"
down_revision = "c5eae4a75a1b"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"user", sa.Column("pinned_assistants", postgresql.JSONB(), nullable=True)
)
op.execute('UPDATE "user" SET pinned_assistants = chosen_assistants')
def downgrade() -> None:
op.drop_column("user", "pinned_assistants")

View File

@@ -0,0 +1,38 @@
"""fix_capitalization
Revision ID: be2ab2aa50ee
Revises: 369644546676
Create Date: 2025-01-10 13:13:26.228960
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "be2ab2aa50ee"
down_revision = "369644546676"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.execute(
"""
UPDATE document
SET
external_user_group_ids = ARRAY(
SELECT LOWER(unnest(external_user_group_ids))
),
last_modified = NOW()
WHERE
external_user_group_ids IS NOT NULL
AND external_user_group_ids::text[] <> ARRAY(
SELECT LOWER(unnest(external_user_group_ids))
)::text[]
"""
)
def downgrade() -> None:
# No way to cleanly persist the bad state through an upgrade/downgrade
# cycle, so we just pass
pass

View File

@@ -0,0 +1,36 @@
"""Add chat_message__standard_answer table
Revision ID: c5eae4a75a1b
Revises: 0f7ff6d75b57
Create Date: 2025-01-15 14:08:49.688998
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c5eae4a75a1b"
down_revision = "0f7ff6d75b57"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"chat_message__standard_answer",
sa.Column("chat_message_id", sa.Integer(), nullable=False),
sa.Column("standard_answer_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["chat_message_id"],
["chat_message.id"],
),
sa.ForeignKeyConstraint(
["standard_answer_id"],
["standard_answer.id"],
),
sa.PrimaryKeyConstraint("chat_message_id", "standard_answer_id"),
)
def downgrade() -> None:
op.drop_table("chat_message__standard_answer")

View File

@@ -0,0 +1,48 @@
"""Add has_been_indexed to DocumentByConnectorCredentialPair
Revision ID: c7bf5721733e
Revises: fec3db967bf7
Create Date: 2025-01-13 12:39:05.831693
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c7bf5721733e"
down_revision = "027381bce97c"
branch_labels = None
depends_on = None
def upgrade() -> None:
# assume all existing rows have been indexed, no better approach
op.add_column(
"document_by_connector_credential_pair",
sa.Column("has_been_indexed", sa.Boolean(), nullable=True),
)
op.execute(
"UPDATE document_by_connector_credential_pair SET has_been_indexed = TRUE"
)
op.alter_column(
"document_by_connector_credential_pair",
"has_been_indexed",
nullable=False,
)
# Add index to optimize get_document_counts_for_cc_pairs query pattern
op.create_index(
"idx_document_cc_pair_counts",
"document_by_connector_credential_pair",
["connector_id", "credential_id", "has_been_indexed"],
unique=False,
)
def downgrade() -> None:
# Remove the index first before removing the column
op.drop_index(
"idx_document_cc_pair_counts",
table_name="document_by_connector_credential_pair",
)
op.drop_column("document_by_connector_credential_pair", "has_been_indexed")
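
The comment above ties idx_document_cc_pair_counts to the get_document_counts_for_cc_pairs query pattern. A rough sketch of that query shape is below, assuming a DocumentByConnectorCredentialPair ORM model that mirrors the table; the repository's real function may additionally filter to specific cc pairs.

import sqlalchemy as sa
from sqlalchemy.orm import Session

from onyx.db.models import DocumentByConnectorCredentialPair  # assumed ORM mapping

def get_document_counts_for_cc_pairs_sketch(db_session: Session):
    # Groups indexed documents by (connector_id, credential_id); the new composite
    # index covers both the has_been_indexed filter and the grouping columns.
    stmt = (
        sa.select(
            DocumentByConnectorCredentialPair.connector_id,
            DocumentByConnectorCredentialPair.credential_id,
            sa.func.count().label("num_docs"),
        )
        .where(DocumentByConnectorCredentialPair.has_been_indexed.is_(True))
        .group_by(
            DocumentByConnectorCredentialPair.connector_id,
            DocumentByConnectorCredentialPair.credential_id,
        )
    )
    return db_session.execute(stmt).all()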

View File

@@ -0,0 +1,41 @@
"""Add time_updated to UserGroup and DocumentSet
Revision ID: fec3db967bf7
Revises: 97dbb53fa8c8
Create Date: 2025-01-12 15:49:02.289100
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "fec3db967bf7"
down_revision = "97dbb53fa8c8"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"document_set",
sa.Column(
"time_last_modified_by_user",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
op.add_column(
"user_group",
sa.Column(
"time_last_modified_by_user",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
)
def downgrade() -> None:
op.drop_column("user_group", "time_last_modified_by_user")
op.drop_column("document_set", "time_last_modified_by_user")

View File

@@ -1,6 +1,9 @@
from datetime import timedelta
from typing import Any
from onyx.background.celery.tasks.beat_schedule import (
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
)
from onyx.background.celery.tasks.beat_schedule import (
tasks_to_schedule as base_tasks_to_schedule,
)
@@ -8,7 +11,7 @@ from onyx.configs.constants import OnyxCeleryTask
ee_tasks_to_schedule = [
{
"name": "autogenerate_usage_report",
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30), # TODO: change this to config flag
},
@@ -20,5 +23,9 @@ ee_tasks_to_schedule = [
]
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return base_cloud_tasks_to_schedule
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_tasks_to_schedule + base_tasks_to_schedule

View File

@@ -8,6 +8,9 @@ from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import mark_user_group_as_synced
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from onyx.background.celery.apps.app_base import task_logger
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
@@ -43,24 +46,59 @@ def monitor_usergroup_taskset(
f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
)
if count > 0:
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=count,
)
return
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
if user_group:
usergroup_name = user_group.name
if user_group.is_up_for_deletion:
# this prepare should have been run when the deletion was scheduled,
# but run it again to be sure we're ready to go
mark_user_group_as_synced(db_session, user_group)
prepare_user_group_for_deletion(db_session, usergroup_id)
delete_user_group(db_session=db_session, user_group=user_group)
task_logger.info(
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
)
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
task_logger.info(
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
try:
if user_group.is_up_for_deletion:
# this prepare should have been run when the deletion was scheduled,
# but run it again to be sure we're ready to go
mark_user_group_as_synced(db_session, user_group)
prepare_user_group_for_deletion(db_session, usergroup_id)
delete_user_group(db_session=db_session, user_group=user_group)
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
task_logger.info(
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
)
else:
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
task_logger.info(
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
)
except Exception as e:
update_sync_record_status(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
sync_status=SyncStatus.FAILED,
num_docs_synced=initial_count,
)
raise e
rug.reset()

View File

@@ -345,7 +345,8 @@ def fetch_assistant_unique_users_total(
def user_can_view_assistant_stats(
db_session: Session, user: User | None, assistant_id: int
) -> bool:
# If user is None, assume the user is an admin or auth is disabled
# If user is None and auth is disabled, assume the user is an admin
if user is None or user.role == UserRole.ADMIN:
return True

View File

@@ -5,7 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.access.models import ExternalAccess
from onyx.access.utils import prefix_group_w_source
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.db.models import Document as DbDocument
@@ -25,7 +25,7 @@ def upsert_document_external_perms__no_commit(
).first()
prefixed_external_groups = [
prefix_group_w_source(
build_ext_group_name_for_onyx(
ext_group_name=group_id,
source=source_type,
)
@@ -66,7 +66,7 @@ def upsert_document_external_perms(
).first()
prefixed_external_groups: set[str] = {
prefix_group_w_source(
build_ext_group_name_for_onyx(
ext_group_name=group_id,
source=source_type,
)

View File

@@ -6,8 +6,9 @@ from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.access.utils import prefix_group_w_source
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.users import get_user_by_email
@@ -61,8 +62,10 @@ def replace_user__ext_group_for_cc_pair(
all_group_member_emails.add(user_email)
# batch add users if they don't exist and get their ids
all_group_members = batch_add_ext_perm_user_if_not_exists(
db_session=db_session, emails=list(all_group_member_emails)
all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
# NOTE: this function handles case sensitivity for emails
emails=list(all_group_member_emails),
)
delete_user__ext_group_for_cc_pair__no_commit(
@@ -84,12 +87,14 @@ def replace_user__ext_group_for_cc_pair(
f" with email {user_email} not found"
)
continue
external_group_id = build_ext_group_name_for_onyx(
ext_group_name=external_group.id,
source=source,
)
new_external_permissions.append(
User__ExternalUserGroupId(
user_id=user_id,
external_user_group_id=prefix_group_w_source(
external_group.id, source
),
external_user_group_id=external_group_id,
cc_pair_id=cc_pair_id,
)
)

View File

@@ -1,27 +1,135 @@
import datetime
from typing import Literal
from collections.abc import Sequence
from datetime import datetime
from sqlalchemy import asc
from sqlalchemy import BinaryExpression
from sqlalchemy import ColumnElement
from sqlalchemy import desc
from sqlalchemy import distinct
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import case
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.expression import UnaryExpression
from onyx.configs.constants import QAFeedbackType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessageFeedback
from onyx.db.models import ChatSession
SortByOptions = Literal["time_sent"]
def _build_filter_conditions(
start_time: datetime | None,
end_time: datetime | None,
feedback_filter: QAFeedbackType | None,
) -> list[ColumnElement]:
"""
Helper function to build all filter conditions for chat sessions.
Filters by start and end time, feedback type, and any sessions without messages.
start_time: Date from which to filter
end_time: Date to which to filter
feedback_filter: Feedback type to filter by
Returns: List of filter conditions
"""
conditions = []
if start_time is not None:
conditions.append(ChatSession.time_created >= start_time)
if end_time is not None:
conditions.append(ChatSession.time_created <= end_time)
if feedback_filter is not None:
feedback_subq = (
select(ChatMessage.chat_session_id)
.join(ChatMessageFeedback)
.group_by(ChatMessage.chat_session_id)
.having(
case(
(
case(
{literal(feedback_filter == QAFeedbackType.LIKE): True},
else_=False,
),
func.bool_and(ChatMessageFeedback.is_positive),
),
(
case(
{literal(feedback_filter == QAFeedbackType.DISLIKE): True},
else_=False,
),
func.bool_and(func.not_(ChatMessageFeedback.is_positive)),
),
else_=func.bool_or(ChatMessageFeedback.is_positive)
& func.bool_or(func.not_(ChatMessageFeedback.is_positive)),
)
)
)
conditions.append(ChatSession.id.in_(feedback_subq))
return conditions
def get_total_filtered_chat_sessions_count(
db_session: Session,
start_time: datetime | None,
end_time: datetime | None,
feedback_filter: QAFeedbackType | None,
) -> int:
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
stmt = (
select(func.count(distinct(ChatSession.id)))
.select_from(ChatSession)
.filter(*conditions)
)
return db_session.scalar(stmt) or 0
def get_page_of_chat_sessions(
start_time: datetime | None,
end_time: datetime | None,
db_session: Session,
page_num: int,
page_size: int,
feedback_filter: QAFeedbackType | None = None,
) -> Sequence[ChatSession]:
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
subquery = (
select(ChatSession.id, ChatSession.time_created)
.filter(*conditions)
.order_by(ChatSession.id, desc(ChatSession.time_created))
.distinct(ChatSession.id)
.limit(page_size)
.offset(page_num * page_size)
.subquery()
)
stmt = (
select(ChatSession)
.join(subquery, ChatSession.id == subquery.c.id)
.outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
.options(
joinedload(ChatSession.user),
joinedload(ChatSession.persona),
contains_eager(ChatSession.messages).joinedload(
ChatMessage.chat_message_feedbacks
),
)
.order_by(desc(ChatSession.time_created), asc(ChatMessage.id))
)
return db_session.scalars(stmt).unique().all()
def fetch_chat_sessions_eagerly_by_time(
start: datetime.datetime,
end: datetime.datetime,
start: datetime,
end: datetime,
db_session: Session,
limit: int | None = 500,
initial_time: datetime.datetime | None = None,
initial_time: datetime | None = None,
) -> list[ChatSession]:
time_order: UnaryExpression = desc(ChatSession.time_created)
message_order: UnaryExpression = asc(ChatMessage.id)

View File

@@ -7,6 +7,7 @@ from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.constants import TokenRateLimitScope
from onyx.db.models import TokenRateLimit
from onyx.db.models import TokenRateLimit__UserGroup
@@ -20,10 +21,11 @@ from onyx.server.token_rate_limits.models import TokenRateLimitArgs
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
stmt = stmt.distinct()
TRLimit_UG = aliased(TokenRateLimit__UserGroup)
User__UG = aliased(User__UserGroup)
@@ -46,6 +48,12 @@ def _add_user_filters(
that the user isn't a curator for
- if we are not editing, we show all token_rate_limits in the groups the user curates
"""
# If user is None, this is an anonymous user and we should only show public token_rate_limits
if user is None:
where_clause = TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
return stmt.where(where_clause)
where_clause = User__UG.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UG.is_curator == True # noqa: E712
@@ -103,10 +111,10 @@ def insert_user_group_token_rate_limit(
return token_limit
def fetch_user_group_token_rate_limits(
def fetch_user_group_token_rate_limits_for_user(
db_session: Session,
group_id: int,
user: User | None = None,
user: User | None,
enabled_only: bool = False,
ordered: bool = True,
get_editable: bool = True,

View File

@@ -374,7 +374,9 @@ def _add_user_group__cc_pair_relationships__no_commit(
def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
db_user_group = UserGroup(name=user_group.name)
db_user_group = UserGroup(
name=user_group.name, time_last_modified_by_user=func.now()
)
db_session.add(db_user_group)
db_session.flush() # give the group an ID
@@ -630,6 +632,10 @@ def update_user_group(
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
).unique()
_validate_curator_status__no_commit(db_session, list(removed_users))
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()
db_session.commit()
return db_user_group
@@ -699,7 +705,10 @@ def delete_user_group_cc_pair_relationship__no_commit(
connector_credential_pair_id matches the given cc_pair_id.
Should be used very carefully (only for connectors that are being deleted)."""
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
raise ValueError(f"Connector Credential Pair '{cc_pair_id}' does not exist")

View File

@@ -24,7 +24,9 @@ _REQUEST_PAGINATION_LIMIT = 5000
def _get_server_space_permissions(
confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
space_permissions = confluence_client.get_space_permissions(space_key=space_key)
space_permissions = confluence_client.get_all_space_permissions_server(
space_key=space_key
)
viewspace_permissions = []
for permission_category in space_permissions:
@@ -67,6 +69,13 @@ def _get_server_space_permissions(
else:
logger.warning(f"Email for user {user_name} not found in Confluence")
if not user_emails and not group_names:
logger.warning(
"No user emails or group names found in Confluence space permissions"
f"\nSpace key: {space_key}"
f"\nSpace permissions: {space_permissions}"
)
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_names,

View File

@@ -120,9 +120,12 @@ def _get_permissions_from_slim_doc(
elif permission_type == "anyone":
public = True
drive_id = permission_info.get("drive_id")
group_ids = group_emails | ({drive_id} if drive_id is not None else set())
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_emails,
external_user_group_ids=group_ids,
is_public=public,
)

View File

@@ -1,16 +1,127 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import AdminService
from onyx.connectors.google_utils.resources import get_admin_service
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_drive_members(
google_drive_connector: GoogleDriveConnector,
) -> dict[str, tuple[set[str], set[str]]]:
"""
This builds a map of drive ids to their members (group and user emails).
E.g. {
"drive_id_1": ({"group_email_1"}, {"user_email_1", "user_email_2"}),
"drive_id_2": ({"group_email_3"}, {"user_email_3"}),
}
"""
drive_ids = google_drive_connector.get_all_drive_ids()
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]] = {}
drive_service = get_drive_service(
google_drive_connector.creds,
google_drive_connector.primary_admin_email,
)
for drive_id in drive_ids:
group_emails: set[str] = set()
user_emails: set[str] = set()
for permission in execute_paginated_retrieval(
drive_service.permissions().list,
list_key="permissions",
fileId=drive_id,
fields="permissions(emailAddress, type)",
supportsAllDrives=True,
):
if permission["type"] == "group":
group_emails.add(permission["emailAddress"])
elif permission["type"] == "user":
user_emails.add(permission["emailAddress"])
drive_id_to_members_map[drive_id] = (group_emails, user_emails)
return drive_id_to_members_map
def _get_all_groups(
admin_service: AdminService,
google_domain: str,
) -> set[str]:
"""
This gets all the group emails.
"""
group_emails: set[str] = set()
for group in execute_paginated_retrieval(
admin_service.groups().list,
list_key="groups",
domain=google_domain,
fields="groups(email)",
):
group_emails.add(group["email"])
return group_emails
def _map_group_email_to_member_emails(
admin_service: AdminService,
group_emails: set[str],
) -> dict[str, set[str]]:
"""
This maps group emails to their member emails.
"""
group_to_member_map: dict[str, set[str]] = {}
for group_email in group_emails:
group_member_emails: set[str] = set()
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email)",
):
group_member_emails.add(member["email"])
group_to_member_map[group_email] = group_member_emails
return group_to_member_map
def _build_onyx_groups(
drive_id_to_members_map: dict[str, tuple[set[str], set[str]]],
group_email_to_member_emails_map: dict[str, set[str]],
) -> list[ExternalUserGroup]:
onyx_groups: list[ExternalUserGroup] = []
# Convert all drive member definitions to onyx groups
# This is because having drive level access means you have
# irrevocable access to all the files in the drive.
for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items():
all_member_emails: set[str] = user_emails
for group_email in group_emails:
all_member_emails.update(group_email_to_member_emails_map[group_email])
onyx_groups.append(
ExternalUserGroup(
id=drive_id,
user_emails=list(all_member_emails),
)
)
# Convert all group member definitions to onyx groups
for group_email, member_emails in group_email_to_member_emails_map.items():
onyx_groups.append(
ExternalUserGroup(
id=group_email,
user_emails=list(member_emails),
)
)
return onyx_groups
def gdrive_group_sync(
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
# Initialize connector and build credential/service objects
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
@@ -19,34 +130,23 @@ def gdrive_group_sync(
google_drive_connector.creds, google_drive_connector.primary_admin_email
)
onyx_groups: list[ExternalUserGroup] = []
for group in execute_paginated_retrieval(
admin_service.groups().list,
list_key="groups",
domain=google_drive_connector.google_domain,
fields="groups(email)",
):
# The id is the group email
group_email = group["email"]
# Get all drive members
drive_id_to_members_map = _get_drive_members(google_drive_connector)
# Gather group member emails
group_member_emails: list[str] = []
for member in execute_paginated_retrieval(
admin_service.members().list,
list_key="members",
groupKey=group_email,
fields="members(email)",
):
group_member_emails.append(member["email"])
# Get all group emails
all_group_emails = _get_all_groups(
admin_service, google_drive_connector.google_domain
)
if not group_member_emails:
continue
# Map group emails to their members
group_email_to_member_emails_map = _map_group_email_to_member_emails(
admin_service, all_group_emails
)
onyx_groups.append(
ExternalUserGroup(
id=group_email,
user_emails=list(group_member_emails),
)
)
# Convert the maps to onyx groups
onyx_groups = _build_onyx_groups(
drive_id_to_members_map=drive_id_to_members_map,
group_email_to_member_emails_map=group_email_to_member_emails_map,
)
return onyx_groups

View File

@@ -161,7 +161,10 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales
cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id]
if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if cc_pair is None:
raise ValueError(f"CC pair {cc_pair_id} not found")
credential_json = cc_pair.credential.credential_json

View File

@@ -150,9 +150,9 @@ def _handle_standard_answers(
db_session=db_session,
description="",
user_id=None,
persona_id=slack_channel_config.persona.id
if slack_channel_config.persona
else 0,
persona_id=(
slack_channel_config.persona.id if slack_channel_config.persona else 0
),
onyxbot_flow=True,
slack_thread_id=slack_thread_id,
)
@@ -182,7 +182,7 @@ def _handle_standard_answers(
formatted_answers.append(formatted_answer)
answer_message = "\n\n".join(formatted_answers)
_ = create_new_chat_message(
chat_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=new_user_message,
prompt_id=prompt.id if prompt else None,
@@ -191,8 +191,13 @@ def _handle_standard_answers(
message_type=MessageType.ASSISTANT,
error=None,
db_session=db_session,
commit=True,
commit=False,
)
# attach the standard answers to the chat message
chat_message.standard_answers = [
standard_answer for standard_answer, _ in matching_standard_answers
]
db_session.commit()
update_emote_react(
emoji=DANSWER_REACT_EMOJI,

View File

@@ -1,19 +1,23 @@
import csv
import io
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Literal
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time
from ee.onyx.db.query_history import get_page_of_chat_sessions
from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count
from ee.onyx.server.query_history.models import ChatSessionMinimal
from ee.onyx.server.query_history.models import ChatSessionSnapshot
from ee.onyx.server.query_history.models import MessageSnapshot
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.chat.chat_utils import create_chat_chain
@@ -23,257 +27,15 @@ from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
from onyx.db.engine import get_session
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import User
from onyx.server.documents.models import PaginatedReturn
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter()
class AbridgedSearchDoc(BaseModel):
"""A subset of the info present in `SearchDoc`"""
document_id: str
semantic_identifier: str
link: str | None
class MessageSnapshot(BaseModel):
message: str
message_type: MessageType
documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
time_created: datetime
@classmethod
def build(cls, message: ChatMessage) -> "MessageSnapshot":
latest_messages_feedback_obj = (
message.chat_message_feedbacks[-1]
if len(message.chat_message_feedbacks) > 0
else None
)
feedback_type = (
(
QAFeedbackType.LIKE
if latest_messages_feedback_obj.is_positive
else QAFeedbackType.DISLIKE
)
if latest_messages_feedback_obj
else None
)
feedback_text = (
latest_messages_feedback_obj.feedback_text
if latest_messages_feedback_obj
else None
)
return cls(
message=message.message,
message_type=message.message_type,
documents=[
AbridgedSearchDoc(
document_id=document.document_id,
semantic_identifier=document.semantic_id,
link=document.link,
)
for document in message.search_docs
],
feedback_type=feedback_type,
feedback_text=feedback_text,
time_created=message.time_sent,
)
class ChatSessionMinimal(BaseModel):
id: UUID
user_email: str
name: str | None
first_user_message: str
first_ai_message: str
assistant_id: int | None
assistant_name: str | None
time_created: datetime
feedback_type: QAFeedbackType | Literal["mixed"] | None
flow_type: SessionType
conversation_length: int
class ChatSessionSnapshot(BaseModel):
id: UUID
user_email: str
name: str | None
messages: list[MessageSnapshot]
assistant_id: int | None
assistant_name: str | None
time_created: datetime
flow_type: SessionType
class QuestionAnswerPairSnapshot(BaseModel):
chat_session_id: UUID
# 1-indexed message number in the chat_session
# e.g. the first message pair in the chat_session is 1, the second is 2, etc.
message_pair_num: int
user_message: str
ai_response: str
retrieved_documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
persona_name: str | None
user_email: str
time_created: datetime
flow_type: SessionType
@classmethod
def from_chat_session_snapshot(
cls,
chat_session_snapshot: ChatSessionSnapshot,
) -> list["QuestionAnswerPairSnapshot"]:
message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
for ind in range(1, len(chat_session_snapshot.messages), 2):
message_pairs.append(
(
chat_session_snapshot.messages[ind - 1],
chat_session_snapshot.messages[ind],
)
)
return [
cls(
chat_session_id=chat_session_snapshot.id,
message_pair_num=ind + 1,
user_message=user_message.message,
ai_response=ai_message.message,
retrieved_documents=ai_message.documents,
feedback_type=ai_message.feedback_type,
feedback_text=ai_message.feedback_text,
persona_name=chat_session_snapshot.assistant_name,
user_email=get_display_email(chat_session_snapshot.user_email),
time_created=user_message.time_created,
flow_type=chat_session_snapshot.flow_type,
)
for ind, (user_message, ai_message) in enumerate(message_pairs)
]
def to_json(self) -> dict[str, str | None]:
return {
"chat_session_id": str(self.chat_session_id),
"message_pair_num": str(self.message_pair_num),
"user_message": self.user_message,
"ai_response": self.ai_response,
"retrieved_documents": "|".join(
[
doc.link or doc.semantic_identifier
for doc in self.retrieved_documents
]
),
"feedback_type": self.feedback_type.value if self.feedback_type else "",
"feedback_text": self.feedback_text or "",
"persona_name": self.persona_name,
"user_email": self.user_email,
"time_created": str(self.time_created),
"flow_type": self.flow_type,
}
def determine_flow_type(chat_session: ChatSession) -> SessionType:
return SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
def fetch_and_process_chat_session_history_minimal(
db_session: Session,
start: datetime,
end: datetime,
feedback_filter: QAFeedbackType | None = None,
limit: int | None = 500,
) -> list[ChatSessionMinimal]:
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=start, end=end, db_session=db_session, limit=limit
)
minimal_sessions = []
for chat_session in chat_sessions:
if not chat_session.messages:
continue
first_user_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.USER
),
"",
)
first_ai_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.ASSISTANT
),
"",
)
has_positive_feedback = any(
feedback.is_positive
for message in chat_session.messages
for feedback in message.chat_message_feedbacks
)
has_negative_feedback = any(
not feedback.is_positive
for message in chat_session.messages
for feedback in message.chat_message_feedbacks
)
feedback_type: QAFeedbackType | Literal["mixed"] | None = (
"mixed"
if has_positive_feedback and has_negative_feedback
else QAFeedbackType.LIKE
if has_positive_feedback
else QAFeedbackType.DISLIKE
if has_negative_feedback
else None
)
if feedback_filter:
if feedback_filter == QAFeedbackType.LIKE and not has_positive_feedback:
continue
if feedback_filter == QAFeedbackType.DISLIKE and not has_negative_feedback:
continue
flow_type = determine_flow_type(chat_session)
minimal_sessions.append(
ChatSessionMinimal(
id=chat_session.id,
user_email=get_display_email(
chat_session.user.email if chat_session.user else None
),
name=chat_session.description,
first_user_message=first_user_message,
first_ai_message=first_ai_message,
assistant_id=chat_session.persona_id,
assistant_name=(
chat_session.persona.name if chat_session.persona else None
),
time_created=chat_session.time_created,
feedback_type=feedback_type,
flow_type=flow_type,
conversation_length=len(
[
m
for m in chat_session.messages
if m.message_type != MessageType.SYSTEM
]
),
)
)
return minimal_sessions
def fetch_and_process_chat_session_history(
db_session: Session,
start: datetime,
@@ -319,7 +81,7 @@ def snapshot_from_chat_session(
except RuntimeError:
return None
flow_type = determine_flow_type(chat_session)
flow_type = SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
return ChatSessionSnapshot(
id=chat_session.id,
@@ -371,22 +133,38 @@ def get_user_chat_sessions(
@router.get("/admin/chat-session-history")
def get_chat_session_history(
page_num: int = Query(0, ge=0),
page_size: int = Query(10, ge=1),
feedback_type: QAFeedbackType | None = None,
start: datetime | None = None,
end: datetime | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[ChatSessionMinimal]:
return fetch_and_process_chat_session_history_minimal(
) -> PaginatedReturn[ChatSessionMinimal]:
page_of_chat_sessions = get_page_of_chat_sessions(
page_num=page_num,
page_size=page_size,
db_session=db_session,
start=start
or (
datetime.now(tz=timezone.utc) - timedelta(days=30)
), # default is 30d lookback
end=end or datetime.now(tz=timezone.utc),
start_time=start_time,
end_time=end_time,
feedback_filter=feedback_type,
)
total_filtered_chat_sessions_count = get_total_filtered_chat_sessions_count(
db_session=db_session,
start_time=start_time,
end_time=end_time,
feedback_filter=feedback_type,
)
return PaginatedReturn(
items=[
ChatSessionMinimal.from_chat_session(chat_session)
for chat_session in page_of_chat_sessions
],
total_items=total_filtered_chat_sessions_count,
)
@router.get("/admin/chat-session-history/{chat_session_id}")
def get_chat_session_admin(

View File

@@ -0,0 +1,218 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
from onyx.auth.users import get_display_email
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import SessionType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
class AbridgedSearchDoc(BaseModel):
"""A subset of the info present in `SearchDoc`"""
document_id: str
semantic_identifier: str
link: str | None
class MessageSnapshot(BaseModel):
id: int
message: str
message_type: MessageType
documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
time_created: datetime
@classmethod
def build(cls, message: ChatMessage) -> "MessageSnapshot":
latest_messages_feedback_obj = (
message.chat_message_feedbacks[-1]
if len(message.chat_message_feedbacks) > 0
else None
)
feedback_type = (
(
QAFeedbackType.LIKE
if latest_messages_feedback_obj.is_positive
else QAFeedbackType.DISLIKE
)
if latest_messages_feedback_obj
else None
)
feedback_text = (
latest_messages_feedback_obj.feedback_text
if latest_messages_feedback_obj
else None
)
return cls(
id=message.id,
message=message.message,
message_type=message.message_type,
documents=[
AbridgedSearchDoc(
document_id=document.document_id,
semantic_identifier=document.semantic_id,
link=document.link,
)
for document in message.search_docs
],
feedback_type=feedback_type,
feedback_text=feedback_text,
time_created=message.time_sent,
)
class ChatSessionMinimal(BaseModel):
id: UUID
user_email: str
name: str | None
first_user_message: str
first_ai_message: str
assistant_id: int | None
assistant_name: str | None
time_created: datetime
feedback_type: QAFeedbackType | None
flow_type: SessionType
conversation_length: int
@classmethod
def from_chat_session(cls, chat_session: ChatSession) -> "ChatSessionMinimal":
first_user_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.USER
),
"",
)
first_ai_message = next(
(
message.message
for message in chat_session.messages
if message.message_type == MessageType.ASSISTANT
),
"",
)
list_of_message_feedbacks = [
feedback.is_positive
for message in chat_session.messages
for feedback in message.chat_message_feedbacks
]
session_feedback_type = None
if list_of_message_feedbacks:
if all(list_of_message_feedbacks):
session_feedback_type = QAFeedbackType.LIKE
elif not any(list_of_message_feedbacks):
session_feedback_type = QAFeedbackType.DISLIKE
else:
session_feedback_type = QAFeedbackType.MIXED
return cls(
id=chat_session.id,
user_email=get_display_email(
chat_session.user.email if chat_session.user else None
),
name=chat_session.description,
first_user_message=first_user_message,
first_ai_message=first_ai_message,
assistant_id=chat_session.persona_id,
assistant_name=(
chat_session.persona.name if chat_session.persona else None
),
time_created=chat_session.time_created,
feedback_type=session_feedback_type,
flow_type=SessionType.SLACK
if chat_session.onyxbot_flow
else SessionType.CHAT,
conversation_length=len(
[
message
for message in chat_session.messages
if message.message_type != MessageType.SYSTEM
]
),
)
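To make the session-level feedback rollup in ChatSessionMinimal.from_chat_session above concrete, here is a minimal standalone sketch of the same rule (hypothetical helper, not part of the diff): LIKE when every piece of message feedback is positive, DISLIKE when none of it is, MIXED otherwise, and None when there is no feedback at all.

from onyx.configs.constants import QAFeedbackType

def aggregate_session_feedback(is_positive_flags: list[bool]) -> QAFeedbackType | None:
    # no feedback on any message -> no session-level feedback type
    if not is_positive_flags:
        return None
    if all(is_positive_flags):
        return QAFeedbackType.LIKE
    if not any(is_positive_flags):
        return QAFeedbackType.DISLIKE
    return QAFeedbackType.MIXED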
class ChatSessionSnapshot(BaseModel):
id: UUID
user_email: str
name: str | None
messages: list[MessageSnapshot]
assistant_id: int | None
assistant_name: str | None
time_created: datetime
flow_type: SessionType
class QuestionAnswerPairSnapshot(BaseModel):
chat_session_id: UUID
# 1-indexed message number in the chat_session
# e.g. the first message pair in the chat_session is 1, the second is 2, etc.
message_pair_num: int
user_message: str
ai_response: str
retrieved_documents: list[AbridgedSearchDoc]
feedback_type: QAFeedbackType | None
feedback_text: str | None
persona_name: str | None
user_email: str
time_created: datetime
flow_type: SessionType
@classmethod
def from_chat_session_snapshot(
cls,
chat_session_snapshot: ChatSessionSnapshot,
) -> list["QuestionAnswerPairSnapshot"]:
message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
for ind in range(1, len(chat_session_snapshot.messages), 2):
message_pairs.append(
(
chat_session_snapshot.messages[ind - 1],
chat_session_snapshot.messages[ind],
)
)
return [
cls(
chat_session_id=chat_session_snapshot.id,
message_pair_num=ind + 1,
user_message=user_message.message,
ai_response=ai_message.message,
retrieved_documents=ai_message.documents,
feedback_type=ai_message.feedback_type,
feedback_text=ai_message.feedback_text,
persona_name=chat_session_snapshot.assistant_name,
user_email=get_display_email(chat_session_snapshot.user_email),
time_created=user_message.time_created,
flow_type=chat_session_snapshot.flow_type,
)
for ind, (user_message, ai_message) in enumerate(message_pairs)
]
def to_json(self) -> dict[str, str | None]:
return {
"chat_session_id": str(self.chat_session_id),
"message_pair_num": str(self.message_pair_num),
"user_message": self.user_message,
"ai_response": self.ai_response,
"retrieved_documents": "|".join(
[
doc.link or doc.semantic_identifier
for doc in self.retrieved_documents
]
),
"feedback_type": self.feedback_type.value if self.feedback_type else "",
"feedback_text": self.feedback_text or "",
"persona_name": self.persona_name,
"user_email": self.user_email,
"time_created": str(self.time_created),
"flow_type": self.flow_type,
}
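As a quick illustration of the message pairing done by QuestionAnswerPairSnapshot.from_chat_session_snapshot above (hypothetical data, not part of the diff): messages are walked two at a time, so a trailing user message with no assistant reply is simply dropped.

messages = ["u1", "a1", "u2", "a2"]
pairs = [(messages[ind - 1], messages[ind]) for ind in range(1, len(messages), 2)]
# pairs == [("u1", "a1"), ("u2", "a2")]
# message_pair_num is 1-indexed, so ("u1", "a1") is pair 1 and ("u2", "a2") is pair 2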

View File

@@ -24,7 +24,7 @@ from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
from onyx.db.persona import upsert_persona
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.settings.models import Settings
from onyx.server.settings.store import store_settings as store_base_settings
@@ -57,7 +57,7 @@ class SeedConfiguration(BaseModel):
llms: list[LLMProviderUpsertRequest] | None = None
admin_user_emails: list[str] | None = None
seeded_logo_path: str | None = None
personas: list[CreatePersonaRequest] | None = None
personas: list[PersonaUpsertRequest] | None = None
settings: Settings | None = None
enterprise_settings: EnterpriseSettings | None = None
@@ -128,7 +128,7 @@ def _seed_llms(
)
def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) -> None:
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
if personas:
logger.notice("Seeding Personas")
for persona in personas:

View File

@@ -5,7 +5,7 @@ from fastapi import Depends
from sqlalchemy.orm import Session
from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits
from ee.onyx.db.token_limit import fetch_user_group_token_rate_limits_for_user
from ee.onyx.db.token_limit import insert_user_group_token_rate_limit
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
@@ -51,8 +51,10 @@ def get_group_token_limit_settings(
) -> list[TokenRateLimitDisplay]:
return [
TokenRateLimitDisplay.from_db(token_rate_limit)
for token_rate_limit in fetch_user_group_token_rate_limits(
db_session, group_id, user
for token_rate_limit in fetch_user_group_token_rate_limits_for_user(
db_session=db_session,
group_id=group_id,
user=user,
)
]

View File

@@ -19,6 +19,9 @@ def prefix_external_group(ext_group_name: str) -> str:
return f"external_group:{ext_group_name}"
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
"""External groups may collide across sources, every source needs its own prefix."""
return f"{source.value.upper()}_{ext_group_name}"
def build_ext_group_name_for_onyx(ext_group_name: str, source: DocumentSource) -> str:
"""
External groups may collide across sources, every source needs its own prefix.
NOTE: the name is lowercased to handle case sensitivity for group names
"""
return f"{source.value}_{ext_group_name}".lower()

View File

@@ -23,6 +23,7 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
preferences_data = cast(
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
)
print("preferences_data", preferences_data)
return UserPreferences(**preferences_data)
except KvKeyNotFoundError:
return UserPreferences(

View File

@@ -20,6 +20,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
@@ -100,6 +101,10 @@ def on_task_postrun(
if not task_id:
return
if task.name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
# this is a cloud / all tenant task ... no postrun is needed
return
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
@@ -161,9 +166,34 @@ def on_task_postrun(
return
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
# NOTE(rkuo): start method "fork" is unsafe and we really need it to be "spawn".
# But something is blocking set_start_method from working in the cloud unless
# force=True is passed, so we use force=True as a fallback.
all_start_methods: list[str] = multiprocessing.get_all_start_methods()
logger.info(f"Multiprocessing all start methods: {all_start_methods}")
try:
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
except Exception:
logger.info(
"Multiprocessing set_start_method exceptioned. Trying force=True..."
)
try:
multiprocessing.set_start_method(
"spawn", force=True
) # fork is unsafe, set to spawn
except Exception:
logger.info(
"Multiprocessing set_start_method force=True exceptioned even with force=True."
)
logger.info(
f"Multiprocessing selected start method: {multiprocessing.get_start_method()}"
)
def wait_for_redis(sender: Any, **kwargs: Any) -> None:

View File

@@ -1,5 +1,6 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -7,12 +8,14 @@ from celery.beat import PersistentScheduler # type: ignore
from celery.signals import beat_init
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
@@ -28,7 +31,7 @@ class DynamicTenantScheduler(PersistentScheduler):
self._last_reload = self.app.now() - self._reload_interval
# Let the parent class handle store initialization
self.setup_schedule()
self._update_tenant_tasks()
self._try_updating_schedule()
logger.info(f"Set reload interval to {self._reload_interval}")
def setup_schedule(self) -> None:
@@ -44,105 +47,154 @@ class DynamicTenantScheduler(PersistentScheduler):
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating task update")
self._update_tenant_tasks()
try:
self._try_updating_schedule()
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tasks: {str(e)}")
self._last_reload = now
logger.info("Task update completed, reset reload timer")
return retval
def _update_tenant_tasks(self) -> None:
logger.info("Starting task update process")
try:
logger.info("Fetching all IDs")
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} IDs")
def _generate_schedule(
self, tenant_ids: list[str] | list[None]
) -> dict[str, dict[str, Any]]:
"""Given a list of tenant id's, generates a new beat schedule for celery."""
logger.info("Fetching tasks to schedule")
logger.info("Fetching tasks to schedule")
tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
new_schedule: dict[str, dict[str, Any]] = {}
if MULTI_TENANT:
# cloud tasks only need the single task beat across all tenants
get_cloud_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule",
"get_cloud_tasks_to_schedule",
)
new_beat_schedule: dict[str, dict[str, Any]] = {}
cloud_tasks_to_schedule: list[
dict[str, Any]
] = get_cloud_tasks_to_schedule()
for task in cloud_tasks_to_schedule:
task_name = task["name"]
cloud_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
cloud_task["options"] = options
new_schedule[task_name] = cloud_task
current_schedule = self.schedule.items()
# regular task beats are multiplied across all tenants
get_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
existing_tenants = set()
for task_name, _ in current_schedule:
if "-" in task_name:
existing_tenants.add(task_name.split("-")[-1])
logger.info(f"Found {len(existing_tenants)} existing items in schedule")
tasks_to_schedule: list[dict[str, Any]] = get_tasks_to_schedule()
for tenant_id in tenant_ids:
if (
IGNORED_SYNCING_TENANT_LIST
and tenant_id in IGNORED_SYNCING_TENANT_LIST
):
logger.info(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
if tenant_id not in existing_tenants:
logger.info(f"Processing new item: {tenant_id}")
for task in tasks_to_schedule():
task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {task_name}")
new_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
new_task["options"] = options
new_beat_schedule[task_name] = new_task
if self._should_update_schedule(current_schedule, new_beat_schedule):
for tenant_id in tenant_ids:
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_beat_schedule),
"current_tasks": len(current_schedule),
},
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
# Create schedule entries
entries = {}
for name, entry in new_beat_schedule.items():
entries[name] = self.Entry(
name=name,
app=self.app,
task=entry["task"],
schedule=entry["schedule"],
options=entry.get("options", {}),
kwargs=entry.get("kwargs", {}),
for task in tasks_to_schedule:
task_name = task["name"]
tenant_task_name = f"{task['name']}-{tenant_id}"
logger.debug(f"Creating task configuration for {tenant_task_name}")
tenant_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(
f"Adding options to task {tenant_task_name}: {options}"
)
tenant_task["options"] = options
new_schedule[tenant_task_name] = tenant_task
# Update the schedule using the scheduler's methods
self.schedule.clear()
self.schedule.update(entries)
return new_schedule
# Ensure changes are persisted
self.sync()
def _try_updating_schedule(self) -> None:
"""Only updates the actual beat schedule on the celery app when it changes"""
logger.info("Schedule update completed successfully")
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tasks: {str(e)}")
logger.info("_try_updating_schedule starting")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
) -> bool:
"""Compare schedules to determine if an update is needed."""
logger.debug("Comparing current and new schedules")
current_tasks = set(name for name, _ in current_schedule)
new_tasks = set(new_schedule.keys())
needs_update = current_tasks != new_tasks
logger.debug(f"Schedule update needed: {needs_update}")
return needs_update
tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} IDs")
# get current schedule and extract current tenants
current_schedule = self.schedule.items()
current_tenants = set()
for task_name, _ in current_schedule:
task_name = cast(str, task_name)
if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
continue
if "_" in task_name:
# example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
# -> "12345678-abcd-efgh-ijkl-12345678"
current_tenants.add(task_name.split("_")[-1])
logger.info(f"Found {len(current_tenants)} existing items in schedule")
for tenant_id in tenant_ids:
if tenant_id not in current_tenants:
logger.info(f"Processing new tenant: {tenant_id}")
new_schedule = self._generate_schedule(tenant_ids)
if DynamicTenantScheduler._compare_schedules(current_schedule, new_schedule):
logger.info(
"_try_updating_schedule: Current schedule is up to date, no changes needed"
)
return
logger.info(
"Schedule update required",
extra={
"new_tasks": len(new_schedule),
"current_tasks": len(current_schedule),
},
)
# Create schedule entries
entries = {}
for name, entry in new_schedule.items():
entries[name] = self.Entry(
name=name,
app=self.app,
task=entry["task"],
schedule=entry["schedule"],
options=entry.get("options", {}),
kwargs=entry.get("kwargs", {}),
)
# Update the schedule using the scheduler's methods
self.schedule.clear()
self.schedule.update(entries)
# Ensure changes are persisted
self.sync()
logger.info("_try_updating_schedule: Schedule updated successfully")
@staticmethod
def _compare_schedules(schedule1: dict, schedule2: dict) -> bool:
"""Compare schedules to determine if an update is needed.
True if equivalent, False if not."""
current_tasks = set(name for name, _ in schedule1)
new_tasks = set(schedule2.keys())
if current_tasks != new_tasks:
return False
return True
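Worth noting: _compare_schedules only compares the sets of task names, so two schedules with identical names but different timings or options are treated as equivalent and no rewrite occurs. A tiny sketch (hypothetical entries; assumes DynamicTenantScheduler is importable from this beat module):

current = {"cloud_check-for-indexing": {"schedule": 15}}.items()
proposed = {"cloud_check-for-indexing": {"schedule": 30}}

# same task names -> schedules considered equivalent despite the differing values
assert DynamicTenantScheduler._compare_schedules(current, proposed) is True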
@beat_init.connect

View File

@@ -1,9 +1,9 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@@ -49,17 +49,16 @@ def on_task_postrun(
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=4, max_overflow=12)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -1,9 +1,9 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_process_init
@@ -50,22 +50,21 @@ def on_task_postrun(
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
# rkuo: been seeing transient connection exceptions here, so upping the connection count
# from just concurrency/concurrency to concurrency/concurrency*2
SqlEngine.init_engine(
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
)
# rkuo: Transient errors keep happening in the indexing watchdog threads.
# "SSL connection has been closed unexpectedly"
# actually setting the spawn method in the cloud fixes 95% of these.
# setting pre ping might help even more, but not worrying about that yet
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -1,9 +1,9 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
@@ -15,7 +15,6 @@ from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
@@ -49,17 +48,18 @@ def on_task_postrun(
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -0,0 +1,95 @@
import multiprocessing
from typing import Any
from celery import Celery
from celery import signals
from celery import Task
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_MONITORING_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
@signals.task_postrun.connect
def on_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_MONITORING_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=3)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
app_base.on_secondary_worker_init(sender, **kwargs)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_ready(sender, **kwargs)
@worker_shutdown.connect
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
@signals.setup_logging.connect
def on_setup_logging(
loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.monitoring",
]
)

View File

@@ -1,5 +1,4 @@
import logging
import multiprocessing
from typing import Any
from typing import cast
@@ -7,6 +6,7 @@ from celery import bootsteps # type: ignore
from celery import Celery
from celery import signals
from celery import Task
from celery.apps.worker import Worker
from celery.exceptions import WorkerShutdown
from celery.signals import celeryd_init
from celery.signals import worker_init
@@ -17,7 +17,7 @@ from redis.lock import Lock as RedisLock
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.tasks.indexing.tasks import (
from onyx.background.celery.tasks.indexing.utils import (
get_unfenced_index_attempt_ids,
)
from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
@@ -73,14 +73,13 @@ def on_task_postrun(
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
app_base.on_celeryd_init(sender, conf, **kwargs)
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
@@ -135,7 +134,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
raise WorkerShutdown("Primary worker lock could not be acquired!")
# tacking on our own user data to the sender
sender.primary_worker_lock = lock
sender.primary_worker_lock = lock # type: ignore
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)

View File

@@ -14,6 +14,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import Document
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import TaskStatus
from onyx.db.models import TaskQueueState
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -41,14 +42,21 @@ def _get_deletion_status(
return None
redis_connector = RedisConnector(tenant_id, cc_pair.id)
if not redis_connector.delete.fenced:
return None
if redis_connector.delete.fenced:
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
)
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.STARTED,
)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return TaskQueueState(
task_id="",
task_name=redis_connector.delete.fence_key,
status=TaskStatus.PENDING,
)
return None
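A hedged condensation of the deletion-status logic above (hypothetical helper, just to make the three outcomes explicit):

def deletion_status_sketch(delete_fenced: bool, is_deleting: bool) -> str | None:
    if delete_fenced:
        return "STARTED"  # a deletion task is already in flight
    if is_deleting:
        return "PENDING"  # marked for deletion but not yet picked up
    return None           # nothing deletion-related going on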
def get_deletion_attempt_snapshot(

View File

@@ -0,0 +1,21 @@
import onyx.background.celery.configs.base as shared_config
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Monitoring worker specific settings
worker_concurrency = 1 # Single worker is sufficient for monitoring
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -2,24 +2,43 @@ from datetime import timedelta
from typing import Any
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT
# choosing 15 minutes because it roughly gives us enough time to process many tasks
# we might be able to reduce this greatly if we can run a unified
# loop across all tenants rather than tasks per tenant
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# we set expires because it isn't necessary to queue up these tasks
# it's only important that they run relatively regularly
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# tasks that only run in the cloud
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
# by the DynamicTenantScheduler
cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
"task": OnyxCeleryTask.CLOUD_CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
# tasks that run in either self-hosted or cloud
tasks_to_schedule = [
{
"name": "check-for-vespa-sync",
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
@@ -28,16 +47,7 @@ tasks_to_schedule = [
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
@@ -46,7 +56,7 @@ tasks_to_schedule = [
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
@@ -64,16 +74,26 @@ tasks_to_schedule = [
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-background-processes",
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
{
"name": "check-for-doc-permissions-sync",
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
@@ -82,12 +102,25 @@ tasks_to_schedule = [
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
if not MULTI_TENANT:
tasks_to_schedule.append(
{
"name": "check-for-indexing",
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
}
)
# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
tasks_to_schedule.append(
@@ -103,5 +136,9 @@ if LLM_MODEL_UPDATE_API_URL:
)
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return cloud_tasks_to_schedule
def get_tasks_to_schedule() -> list[dict[str, Any]]:
return tasks_to_schedule
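Tying this back to the DynamicTenantScheduler changes above: cloud tasks are installed once under their own name, while each entry returned by get_tasks_to_schedule() is fanned out per tenant, with the tenant id appended to the task name and passed along as a kwarg. A rough sketch of the resulting beat entry for one tenant (hypothetical tenant id, illustrative only):

# for tenant "tenant_1234", the "check-for-vespa-sync" entry above becomes roughly:
{
    "check-for-vespa-sync-tenant_1234": {
        "task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
        "schedule": timedelta(seconds=20),
        "kwargs": {"tenant_id": "tenant_1234"},
        "options": {
            "priority": OnyxCeleryPriority.MEDIUM,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
    }
}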

View File

@@ -10,14 +10,17 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncType
from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
from onyx.redis.redis_pool import get_redis_client
@@ -41,7 +44,7 @@ def check_for_connector_deletion_task(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
@@ -113,11 +116,21 @@ def try_generate_document_cc_pair_cleanup_tasks(
# we need to load the state of the object inside the fence
# to avoid a race condition with db.commit/fence deletion
# at the end of this taskset
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
return None
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
# there should be no in-progress sync records if this is up to date
# clean it up just in case things got into a bad state
cleanup_sync_records(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
)
return None
# set a basic fence to start
@@ -126,6 +139,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
submitted=datetime.now(timezone.utc),
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
)
redis_connector.delete.set_fence(fence_payload)
try:

View File

@@ -22,9 +22,9 @@ from ee.onyx.external_permissions.sync_params import (
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
@@ -99,7 +99,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
@@ -279,7 +279,10 @@ def connector_permission_sync_generator_task(
try:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
@@ -391,5 +394,7 @@ def update_external_document_permissions_task(
)
return True
except Exception:
logger.exception("Error Syncing Document Permissions")
logger.exception(
f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
)
return False

View File

@@ -22,7 +22,7 @@ from ee.onyx.external_permissions.sync_params import (
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -99,7 +99,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
@@ -250,7 +250,10 @@ def connector_external_group_sync_generator_task(
return None
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"

File diff suppressed because it is too large

View File

@@ -0,0 +1,519 @@
import time
from datetime import datetime
from datetime import timezone
import redis
from celery import Celery
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import create_index_attempt
from onyx.db.index_attempt import delete_index_attempt
from onyx.db.index_attempt import get_all_index_attempts_by_status
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# inner/outer/inner double check pattern to avoid race conditions when checking for
# bad state
# inner = index_attempt in non terminal state
# outer = r.fence_key down
# check the db for index attempts in a non terminal state
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
# if the fence is down / doesn't exist, possible error but not confirmed
if r.exists(fence_key):
continue
# Between the time the attempts are first looked up and the time we see the fence down,
# the attempt may have completed and taken down the fence normally.
# We need to double check that the index attempt is still in a non terminal state
# and matches the original state, which confirms we are really in a bad state.
attempt_2 = get_index_attempt(db_session, attempt.id)
if not attempt_2:
continue
if attempt.status != attempt_2.status:
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
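The inner/outer/inner double check above generalizes to: read the database state, observe the external signal, then re-read the database and require it to be unchanged before declaring a bad state. A minimal sketch of the pattern with hypothetical callables (not the real helpers):

def confirmed_bad_state(read_attempt_status, fence_exists) -> bool:
    # inner: the attempt looks non-terminal in the DB
    before = read_attempt_status()
    if before not in ("NOT_STARTED", "IN_PROGRESS"):
        return False
    # outer: the redis fence is down
    if fence_exists():
        return False
    # inner again: still in the same non-terminal state, so the attempt did not
    # simply finish (and drop its fence) between the two reads
    return read_attempt_status() == before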
class IndexingCallback(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
def __init__(
self,
parent_pid: int,
stop_key: str,
generator_progress_key: str,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.parent_pid = parent_pid
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.last_parent_check = time.monotonic()
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
return True
return False
def progress(self, tag: str, amount: int) -> None:
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
# so leave this code in until we're ready to test it.
# if self.parent_pid:
# # check if the parent pid is alive so we aren't running as a zombie
# now = time.monotonic()
# if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
# try:
# # this is unintuitive, but it checks if the parent pid is still running
# os.kill(self.parent_pid, 0)
# except Exception:
# logger.exception("IndexingCallback - parent pid check exceptioned")
# raise
# self.last_parent_check = now
try:
current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
):
self.redis_lock.reacquire()
self.last_lock_reacquire = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.last_tag = tag
except LockError:
logger.exception(
f"IndexingCallback - lock.reacquire exceptioned: "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
redis_lock_dump(self.redis_lock, self.redis_client)
raise
self.redis_client.incrby(self.generator_progress_key, amount)
def validate_indexing_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
provides that help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions:
1.1. When the task is seen in the redis queue
1.2. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"validate_indexing_fence - could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# check to see if the fence/payload exists
if not redis_connector_index.fenced:
return
payload = redis_connector_index.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
if redis_connector_index.active():
return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
if found:
# the celery task exists in the redis queue
redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector_index.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
f"index_attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"fence={fence_key}"
)
if payload.index_attempt_id:
try:
mark_attempt_failed(
payload.index_attempt_id,
db_session,
"validate_indexing_fence - Canceling index attempt due to missing celery tasks: "
f"index_attempt={payload.index_attempt_id}",
)
except Exception:
logger.exception(
"validate_indexing_fence - Exception while marking index attempt as failed: "
f"index_attempt={payload.index_attempt_id}",
)
redis_connector_index.reset()
return
def validate_indexing_fences(
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_indexing_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
reserved_indexing_tasks,
r_celery,
db_session,
)
return
def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
"""Checks various global settings and past indexing attempts to determine if
we should try to start indexing the cc pair / search setting combination.
Note that tactical checks such as preventing overlap with a currently running task
are not handled here.
Return True if we should try to index, False if not.
"""
connector = cc_pair.connector
# uncomment for debugging
# task_logger.info(f"_should_index: "
# f"cc_pair={cc_pair.id} "
# f"connector={cc_pair.connector_id} "
# f"refresh_freq={connector.refresh_freq}")
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
return False
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if (
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
return False
# When switching over models, always index at least once
if search_settings_instance.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if (
not cc_pair.status.is_active()
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
return False
if search_settings_primary:
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
if connector.refresh_freq is None:
return False
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index.time_updated
if time_since_index.total_seconds() < connector.refresh_freq:
return False
return True
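For the refresh-frequency gate at the end of _should_index, a quick worked check (illustrative numbers only):

from datetime import timedelta

refresh_freq = 3600  # seconds; hypothetical connector setting
time_since_index = timedelta(minutes=10)

# 600 elapsed seconds < 3600 second refresh frequency -> skip indexing this cycle
assert time_since_index.total_seconds() < refresh_freq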
def try_creating_indexing_task(
celery_app: Celery,
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
"""
LOCK_TIMEOUT = 30
index_attempt_id: int | None = None
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
redis_connector = RedisConnector(tenant_id, cc_pair.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
# skip if already indexing
if redis_connector_index.fenced:
return None
# skip indexing if the cc_pair is deleting
if redis_connector.delete.fenced:
return None
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
redis_connector_index.generator_clear()
# set a basic fence to start
payload = RedisConnectorIndexPayload(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=None,
)
redis_connector_index.set_active()
redis_connector_index.set_fence(payload)
# create the index attempt for tracking purposes
# code elsewhere checks for index attempts without an associated redis key
# and cleans them up
# therefore we must create the attempt and the task after the fence goes up
index_attempt_id = create_index_attempt(
cc_pair.id,
search_settings.id,
from_beginning=reindex,
db_session=db_session,
)
custom_task_id = redis_connector_index.generate_generator_task_id()
# when the task is sent, we have yet to finish setting up the fence
# therefore, the task must contain code that blocks until the fence is ready
result = celery_app.send_task(
OnyxCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=OnyxCeleryQueues.CONNECTOR_INDEXING,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
)
if not result:
raise RuntimeError("send_task for connector_indexing_proxy_task failed.")
# now fill out the fence with the rest of the data
redis_connector_index.set_active()
payload.index_attempt_id = index_attempt_id
payload.celery_task_id = result.id
redis_connector_index.set_fence(payload)
except Exception:
task_logger.exception(
f"try_creating_indexing_task - Unexpected exception: "
f"cc_pair={cc_pair.id} "
f"search_settings={search_settings.id}"
)
if index_attempt_id is not None:
delete_index_attempt(db_session, index_attempt_id)
redis_connector_index.set_fence(None)
return None
finally:
if lock.owned():
lock.release()
return index_attempt_id

View File

@@ -0,0 +1,458 @@
import json
from collections.abc import Callable
from datetime import timedelta
from typing import Any
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import BaseModel
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import IndexAttempt
from onyx.db.models import SyncRecord
from onyx.db.models import UserGroup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
_CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT = (
"monitoring_connector_index_attempt_start_latency:{cc_pair_id}:{index_attempt_id}"
)
_CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT = (
"monitoring_connector_index_attempt_run_success:{cc_pair_id}:{index_attempt_id}"
)
def _mark_metric_as_emitted(redis_std: Redis, key: str) -> None:
"""Mark a metric as having been emitted by setting a Redis key with expiration"""
redis_std.set(key, "1", ex=24 * 60 * 60) # Expire after 1 day
def _has_metric_been_emitted(redis_std: Redis, key: str) -> bool:
"""Check if a metric has been emitted by checking for existence of Redis key"""
return bool(redis_std.exists(key))
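A short usage sketch of the emit-once guard above (hypothetical ids; assumes a standard Redis client is already available as redis_std):

key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
    cc_pair_id=1, index_attempt_id=42
)
if not _has_metric_been_emitted(redis_std, key):
    # ... build and emit the Metric here ...
    _mark_metric_as_emitted(redis_std, key)  # suppressed for the next 24 hours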
class Metric(BaseModel):
key: str | None # only required if we need to store that we have emitted this metric
name: str
value: Any
tags: dict[str, str]
def log(self) -> None:
"""Log the metric in a standardized format"""
data = {
"metric": self.name,
"value": self.value,
"tags": self.tags,
}
task_logger.info(json.dumps(data))
def emit(self, tenant_id: str | None) -> None:
# Convert value to appropriate type based on the input value
bool_value = None
float_value = None
int_value = None
string_value = None
# NOTE: have to do bool first, since `isinstance(True, int)` is true
# i.e. bool is a subclass of int
if isinstance(self.value, bool):
bool_value = self.value
elif isinstance(self.value, int):
int_value = self.value
elif isinstance(self.value, float):
float_value = self.value
elif isinstance(self.value, str):
string_value = self.value
else:
task_logger.error(
f"Invalid metric value type: {type(self.value)} "
f"({self.value}) for metric {self.name}."
)
return
# don't send None values over the wire
data = {
k: v
for k, v in {
"metric_name": self.name,
"float_value": float_value,
"int_value": int_value,
"string_value": string_value,
"bool_value": bool_value,
"tags": self.tags,
}.items()
if v is not None
}
optional_telemetry(
record_type=RecordType.METRIC,
data=data,
tenant_id=tenant_id,
)
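A minimal usage sketch of the Metric helper and the emit-once Redis keys defined above; the key, value, and tags below are placeholders, not metrics this task actually emits.

# Illustrative only: emit a metric at most once per day for a given key.
redis_std = get_redis_client(tenant_id=None)
example_key = "monitoring_example_metric:cc_pair_1:attempt_42"  # hypothetical key
metric = Metric(
    key=example_key,
    name="connector_start_latency",
    value=12.5,  # seconds (placeholder)
    tags={"source": "web"},
)
if not _has_metric_been_emitted(redis_std, example_key):
    metric.log()  # structured JSON log line for operators
    metric.emit(tenant_id=None)  # forwards the value to optional_telemetry
    _mark_metric_as_emitted(redis_std, example_key)  # suppress re-emission for 1 day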
def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
"""Collect metrics about queue lengths for different Celery queues"""
metrics = []
queue_mappings = {
"celery_queue_length": "celery",
"indexing_queue_length": "indexing",
"sync_queue_length": "sync",
"deletion_queue_length": "deletion",
"pruning_queue_length": "pruning",
"permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
"external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
"permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
}
for name, queue in queue_mappings.items():
metrics.append(
Metric(
key=None,
name=name,
value=celery_get_queue_length(queue, redis_celery),
tags={"queue": name},
)
)
return metrics
def _build_connector_start_latency_metric(
cc_pair: ConnectorCredentialPair,
recent_attempt: IndexAttempt,
second_most_recent_attempt: IndexAttempt | None,
redis_std: Redis,
) -> Metric | None:
if not recent_attempt.time_started:
return None
# check if we already emitted a metric for this index attempt
metric_key = _CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=recent_attempt.id,
)
if _has_metric_been_emitted(redis_std, metric_key):
task_logger.info(
f"Skipping metric for connector {cc_pair.connector.id} "
f"index attempt {recent_attempt.id} because it has already been "
"emitted"
)
return None
# Connector start latency
# first run case - we should start as soon as it's created
if not second_most_recent_attempt:
desired_start_time = cc_pair.connector.time_created
else:
if not cc_pair.connector.refresh_freq:
task_logger.error(
"Found non-initial index attempt for connector "
"without refresh_freq. This should never happen."
)
return None
desired_start_time = second_most_recent_attempt.time_updated + timedelta(
seconds=cc_pair.connector.refresh_freq
)
start_latency = (recent_attempt.time_started - desired_start_time).total_seconds()
return Metric(
key=metric_key,
name="connector_start_latency",
value=start_latency,
tags={},
)
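For concreteness, a small worked example of the start-latency computation above (the timestamps and refresh_freq are invented):

# With refresh_freq=3600s, the next attempt should start one hour after the
# previous attempt was last updated; anything beyond that is start latency.
from datetime import datetime, timedelta, timezone

previous_attempt_time_updated = datetime(2025, 1, 19, 9, 0, tzinfo=timezone.utc)
desired_start_time = previous_attempt_time_updated + timedelta(seconds=3600)
actual_time_started = datetime(2025, 1, 19, 10, 5, tzinfo=timezone.utc)
start_latency = (actual_time_started - desired_start_time).total_seconds()  # 300.0s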
def _build_run_success_metrics(
cc_pair: ConnectorCredentialPair,
recent_attempts: list[IndexAttempt],
redis_std: Redis,
) -> list[Metric]:
metrics = []
for attempt in recent_attempts:
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
cc_pair_id=cc_pair.id,
index_attempt_id=attempt.id,
)
if _has_metric_been_emitted(redis_std, metric_key):
task_logger.info(
f"Skipping metric for connector {cc_pair.connector.id} "
f"index attempt {attempt.id} because it has already been "
"emitted"
)
continue
if attempt.status in [
IndexingStatus.SUCCESS,
IndexingStatus.FAILED,
IndexingStatus.CANCELED,
]:
metrics.append(
Metric(
key=metric_key,
name="connector_run_succeeded",
value=attempt.status == IndexingStatus.SUCCESS,
tags={"source": str(cc_pair.connector.source)},
)
)
return metrics
def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
"""Collect metrics about connector runs from the past hour"""
# NOTE: use get_db_current_time since the IndexAttempt times are set based on DB time
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
# Get all connector credential pairs
cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()
metrics = []
for cc_pair in cc_pairs:
# Get all attempts in the last hour
recent_attempts = (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair.id,
IndexAttempt.time_created >= one_hour_ago,
)
.order_by(IndexAttempt.time_created.desc())
.all()
)
most_recent_attempt = recent_attempts[0] if recent_attempts else None
second_most_recent_attempt = (
recent_attempts[1] if len(recent_attempts) > 1 else None
)
# if no metric to emit, skip
if most_recent_attempt is None:
continue
# Connector start latency
start_latency_metric = _build_connector_start_latency_metric(
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
)
if start_latency_metric:
metrics.append(start_latency_metric)
# Connector run success/failure
run_success_metrics = _build_run_success_metrics(
cc_pair, recent_attempts, redis_std
)
metrics.extend(run_success_metrics)
return metrics
def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
"""Collect metrics about document set and group syncing speed"""
# NOTE: use get_db_current_time since the SyncRecord times are set based on DB time
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
# Get all sync records from the last hour
recent_sync_records = db_session.scalars(
select(SyncRecord)
.where(SyncRecord.sync_start_time >= one_hour_ago)
.order_by(SyncRecord.sync_start_time.desc())
).all()
metrics = []
for sync_record in recent_sync_records:
# Skip if no end time (sync still in progress)
if not sync_record.sync_end_time:
continue
# Check if we already emitted a metric for this sync record
metric_key = (
f"sync_speed:{sync_record.sync_type}:"
f"{sync_record.entity_id}:{sync_record.id}"
)
if _has_metric_been_emitted(redis_std, metric_key):
task_logger.debug(
f"Skipping metric for sync record {sync_record.id} "
"because it has already been emitted"
)
continue
# Calculate sync duration in minutes
sync_duration_mins = (
sync_record.sync_end_time - sync_record.sync_start_time
).total_seconds() / 60.0
# Calculate sync speed (docs/min) - avoid division by zero
sync_speed = (
sync_record.num_docs_synced / sync_duration_mins
if sync_duration_mins > 0
else None
)
if sync_speed is None:
task_logger.error(
"Something went wrong with sync speed calculation. "
f"Sync record: {sync_record.id}"
)
continue
metrics.append(
Metric(
key=metric_key,
name="sync_speed_docs_per_min",
value=sync_speed,
tags={
"sync_type": str(sync_record.sync_type),
"status": str(sync_record.sync_status),
},
)
)
# Add sync start latency metric
start_latency_key = (
f"sync_start_latency:{sync_record.sync_type}"
f":{sync_record.entity_id}:{sync_record.id}"
)
if _has_metric_been_emitted(redis_std, start_latency_key):
task_logger.debug(
f"Skipping start latency metric for sync record {sync_record.id} "
"because it has already been emitted"
)
continue
# Get the entity's last update time based on sync type
entity: DocumentSet | UserGroup | None = None
if sync_record.sync_type == SyncType.DOCUMENT_SET:
entity = db_session.scalar(
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
)
elif sync_record.sync_type == SyncType.USER_GROUP:
entity = db_session.scalar(
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
)
else:
# Skip other sync types
task_logger.debug(
f"Skipping sync record {sync_record.id} "
f"with type {sync_record.sync_type} "
f"and id {sync_record.entity_id} "
"because it is not a document set or user group"
)
continue
if entity is None:
task_logger.error(
f"Could not find entity for sync record {sync_record.id} "
f"with type {sync_record.sync_type} and id {sync_record.entity_id}"
)
continue
# Calculate start latency in seconds
start_latency = (
sync_record.sync_start_time - entity.time_last_modified_by_user
).total_seconds()
if start_latency < 0:
task_logger.error(
f"Start latency is negative for sync record {sync_record.id} "
f"with type {sync_record.sync_type} and id {sync_record.entity_id}."
"This is likely because the entity was updated between the time the "
"time the sync finished and this job ran. Skipping."
)
continue
metrics.append(
Metric(
key=start_latency_key,
name="sync_start_latency_seconds",
value=start_latency,
tags={
"sync_type": str(sync_record.sync_type),
},
)
)
return metrics
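A short worked example of the two sync metrics collected above (counts and times are invented):

# A sync that moved 1200 docs in 4 minutes runs at 300 docs/min; if the user
# last edited the entity 90 seconds before the sync began, start latency is 90s.
from datetime import datetime, timedelta, timezone

sync_start = datetime(2025, 1, 19, 9, 0, tzinfo=timezone.utc)
sync_end = sync_start + timedelta(minutes=4)
sync_duration_mins = (sync_end - sync_start).total_seconds() / 60.0
sync_speed = 1200 / sync_duration_mins  # 300.0 docs/min

time_last_modified_by_user = sync_start - timedelta(seconds=90)
start_latency = (sync_start - time_last_modified_by_user).total_seconds()  # 90.0s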
@shared_task(
name=OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
time_limit=_MONITORING_TIME_LIMIT,
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
"""Collect and emit metrics about background processes.
This task runs periodically to gather metrics about:
- Queue lengths for different Celery queues
- Connector run metrics (start latency, success rate)
- Syncing speed metrics
- Worker status and task counts
"""
task_logger.info("Starting background monitoring")
r = get_redis_client(tenant_id=tenant_id)
lock_monitoring: RedisLock = r.lock(
OnyxRedisLocks.MONITOR_BACKGROUND_PROCESSES_LOCK,
timeout=_MONITORING_SOFT_TIME_LIMIT,
)
# these tasks should never overlap
if not lock_monitoring.acquire(blocking=False):
task_logger.info("Skipping monitoring task because it is already running")
return None
try:
# Get Redis client for Celery broker
redis_celery = self.app.broker_connection().channel().client # type: ignore
redis_std = get_redis_client(tenant_id=tenant_id)
# Define metric collection functions and their dependencies
metric_functions: list[Callable[[], list[Metric]]] = [
lambda: _collect_queue_metrics(redis_celery),
lambda: _collect_connector_metrics(db_session, redis_std),
lambda: _collect_sync_metrics(db_session, redis_std),
]
# Collect and log each metric
with get_session_with_tenant(tenant_id) as db_session:
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
metric.log()
metric.emit(tenant_id)
if metric.key:
_mark_metric_as_emitted(redis_std, metric.key)
task_logger.info("Successfully collected background metrics")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception as e:
task_logger.exception("Error collecting background process metrics")
raise e
finally:
if lock_monitoring.owned():
lock_monitoring.release()
task_logger.info("Background monitoring task finished")

View File

@@ -13,11 +13,11 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.indexing.tasks import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -86,7 +86,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
@@ -103,7 +103,10 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
continue
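The loop above refreshes its beat lock with reacquire() on every iteration; condensed, the pattern looks roughly like this (r, item_ids, and process are generic stand-ins):

def example_beat_task(r, item_ids, process):
    # Hold a short-TTL Redis lock across a long loop by refreshing it per item.
    lock = r.lock("da_lock:example_beat", timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT)
    if not lock.acquire(blocking=False):
        return None  # another worker already holds the lock; these tasks never overlap
    try:
        for item_id in item_ids:
            lock.reacquire()  # reset the TTL so the lock survives slow iterations
            process(item_id)
    finally:
        if lock.owned():
            lock.release()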

View File

@@ -1,6 +1,7 @@
import random
import time
import traceback
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
@@ -53,10 +54,16 @@ from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import mark_document_set_as_synced
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.models import DocumentSet
from onyx.db.models import UserGroup
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
@@ -278,11 +285,21 @@ def try_generate_document_set_sync_tasks(
# don't generate sync tasks if we're up to date
# race condition with the monitor/cleanup function if we use a cached result!
document_set = get_document_set_by_id(db_session, document_set_id)
document_set = get_document_set_by_id(
db_session=db_session,
document_set_id=document_set_id,
)
if not document_set:
return None
if document_set.is_up_to_date:
# there should be no in-progress sync records if this is up to date
# clean it up just in case things got into a bad state
cleanup_sync_records(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
)
return None
# add tasks to celery and build up the task set to monitor in redis
@@ -311,6 +328,13 @@ def try_generate_document_set_sync_tasks(
f"document_set={document_set.id} tasks_generated={tasks_generated}"
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
)
# set this only after all tasks have been added
rds.set_fence(tasks_generated)
return tasks_generated
@@ -332,8 +356,9 @@ def try_generate_user_group_sync_tasks(
return None
# race condition with the monitor/cleanup function if we use a cached result!
fetch_user_group = fetch_versioned_implementation(
"onyx.db.user_group", "fetch_user_group"
fetch_user_group = cast(
Callable[[Session, int], UserGroup | None],
fetch_versioned_implementation("onyx.db.user_group", "fetch_user_group"),
)
usergroup = fetch_user_group(db_session, usergroup_id)
@@ -341,6 +366,13 @@ def try_generate_user_group_sync_tasks(
return None
if usergroup.is_up_to_date:
# there should be no in-progress sync records if this is up to date
# clean it up just in case things got into a bad state
cleanup_sync_records(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
)
return None
# add tasks to celery and build up the task set to monitor in redis
@@ -368,8 +400,16 @@ def try_generate_user_group_sync_tasks(
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
insert_sync_record(
db_session=db_session,
entity_id=usergroup_id,
sync_type=SyncType.USER_GROUP,
)
# set this only after all tasks have been added
rug.set_fence(tasks_generated)
return tasks_generated
@@ -419,6 +459,13 @@ def monitor_document_set_taskset(
f"remaining={count} initial={initial_count}"
)
if count > 0:
update_sync_record_status(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=count,
)
return
document_set = cast(
@@ -437,6 +484,13 @@ def monitor_document_set_taskset(
task_logger.info(
f"Successfully synced document set: document_set={document_set_id}"
)
update_sync_record_status(
db_session=db_session,
entity_id=document_set_id,
sync_type=SyncType.DOCUMENT_SET,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=initial_count,
)
rds.reset()
@@ -470,10 +524,21 @@ def monitor_connector_deletion_taskset(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if remaining > 0:
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=remaining,
)
return
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
@@ -545,11 +610,29 @@ def monitor_connector_deletion_taskset(
)
db_session.delete(connector)
db_session.commit()
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=fence_data.num_tasks,
)
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.FAILED,
num_docs_synced=fence_data.num_tasks,
)
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"

View File

@@ -0,0 +1,15 @@
"""Factory stub for running celery worker / celery beat."""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.monitoring import celery_app
return celery_app
app = get_app()

View File

@@ -4,9 +4,10 @@ not follow the expected behavior, etc.
NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing as mp
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing import Process
from multiprocessing.context import SpawnProcess
from typing import Any
from typing import Literal
from typing import Optional
@@ -46,7 +47,9 @@ def _initializer(
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# Initialize a new engine with desired parameters
SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
SqlEngine.init_engine(
pool_size=4, max_overflow=12, pool_recycle=60, pool_pre_ping=True
)
# Proceed with executing the target function
return func(*args, **kwargs)
@@ -63,7 +66,7 @@ class SimpleJob:
"""Drop in replacement for `dask.distributed.Future`"""
id: int
process: Optional["Process"] = None
process: Optional["SpawnProcess"] = None
def cancel(self) -> bool:
return self.release()
@@ -131,7 +134,10 @@ class SimpleJobClient:
job_id = self.job_id_counter
self.job_id_counter += 1
process = Process(target=_run_in_process, args=(func, args), daemon=True)
# this approach allows us to always "spawn" a new process regardless of
# get_start_method's current setting
ctx = mp.get_context("spawn")
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()

View File

@@ -4,6 +4,7 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.background.indexing.checkpointing import get_time_windows_for_index_attempt
@@ -11,6 +12,7 @@ from onyx.background.indexing.tracer import OnyxTracer
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.factory import instantiate_connector
@@ -21,12 +23,14 @@ from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
@@ -75,7 +79,8 @@ def _get_connector_runner(
# it will never succeed
cc_pair = get_connector_credential_pair_from_id(
attempt.connector_credential_pair.id, db_session
db_session=db_session,
cc_pair_id=attempt.connector_credential_pair.id,
)
if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE:
update_connector_credential_pair(
@@ -96,10 +101,17 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
for doc in doc_batch:
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if cleaned_doc.title and "\x00" in cleaned_doc.title:
logger.warning(
f"NUL characters found in document title: {cleaned_doc.title}"
)
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
@@ -115,6 +127,9 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
)
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
return cleaned_batch
@@ -124,9 +139,21 @@ class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
class RunIndexingContext(BaseModel):
index_name: str
cc_pair_id: int
connector_id: int
credential_id: int
source: DocumentSource
earliest_index_time: float
from_beginning: bool
is_primary: bool
search_settings_status: IndexModelStatus
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
index_attempt_id: int,
tenant_id: str | None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
@@ -140,61 +167,76 @@ def _run_indexing(
"""
start_time = time.time()
if index_attempt.search_settings is None:
raise ValueError(
"Search settings must be set for indexing. This should not be possible."
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_start:
raise ValueError(
f"Index attempt {index_attempt_id} does not exist in DB. This should not be possible."
)
if index_attempt_start.search_settings is None:
raise ValueError(
"Search settings must be set for indexing. This should not be possible."
)
# search_settings = index_attempt_start.search_settings
db_connector = index_attempt_start.connector_credential_pair.connector
db_credential = index_attempt_start.connector_credential_pair.credential
ctx = RunIndexingContext(
index_name=index_attempt_start.search_settings.index_name,
cc_pair_id=index_attempt_start.connector_credential_pair.id,
connector_id=db_connector.id,
credential_id=db_credential.id,
source=db_connector.source,
earliest_index_time=(
db_connector.indexing_start.timestamp()
if db_connector.indexing_start
else 0
),
from_beginning=index_attempt_start.from_beginning,
# Only update cc-pair status for primary index jobs
# Secondary index syncs at the end when swapping
is_primary=(
index_attempt_start.search_settings.status == IndexModelStatus.PRESENT
),
search_settings_status=index_attempt_start.search_settings.status,
)
search_settings = index_attempt.search_settings
last_successful_index_time = (
ctx.earliest_index_time
if ctx.from_beginning
else get_last_successful_attempt_time(
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
earliest_index=ctx.earliest_index_time,
search_settings=index_attempt_start.search_settings,
db_session=db_session_temp,
)
)
index_name = search_settings.index_name
# Only update cc-pair status for primary index jobs
# Secondary index syncs at the end when swapping
is_primary = search_settings.status == IndexModelStatus.PRESENT
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=index_attempt_start.search_settings,
callback=callback,
)
# Indexing is only done into one index at a time
document_index = get_default_document_index(
primary_index_name=index_name, secondary_index_name=None
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings,
callback=callback,
primary_index_name=ctx.index_name, secondary_index_name=None
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt.id,
attempt_id=index_attempt_id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=(
index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE)
ctx.from_beginning
or (ctx.search_settings_status == IndexModelStatus.FUTURE)
),
db_session=db_session,
tenant_id=tenant_id,
callback=callback,
)
db_cc_pair = index_attempt.connector_credential_pair
db_connector = index_attempt.connector_credential_pair.connector
db_credential = index_attempt.connector_credential_pair.credential
earliest_index_time = (
db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0
)
last_successful_index_time = (
earliest_index_time
if index_attempt.from_beginning
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
earliest_index=earliest_index_time,
search_settings=index_attempt.search_settings,
db_session=db_session,
)
)
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = OnyxTracer()
@@ -202,8 +244,8 @@ def _run_indexing(
tracer.snap()
index_attempt_md = IndexAttemptMetadata(
connector_id=db_connector.id,
credential_id=db_credential.id,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
)
batch_num = 0
@@ -219,21 +261,31 @@ def _run_indexing(
source_type=db_connector.source,
)
):
cc_pair_loop: ConnectorCredentialPair | None = None
index_attempt_loop: IndexAttempt | None = None
try:
window_start = max(
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
datetime(1970, 1, 1, tzinfo=timezone.utc),
)
connector_runner = _get_connector_runner(
db_session=db_session,
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt_loop_start = get_index_attempt(
db_session_temp, index_attempt_id
)
if not index_attempt_loop_start:
raise RuntimeError(
f"Index attempt {index_attempt_id} not found in DB."
)
all_connector_doc_ids: set[str] = set()
connector_runner = _get_connector_runner(
db_session=db_session_temp,
attempt=index_attempt_loop_start,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
@@ -248,24 +300,38 @@ def _run_indexing(
raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
db_session.refresh(db_cc_pair)
if (
(
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
and search_settings.status != IndexModelStatus.FUTURE
with get_session_with_tenant(tenant_id) as db_session_temp:
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
ctx.cc_pair_id,
)
# if it's deleting, we don't care if this is a secondary index
or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING
):
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
db_session.refresh(index_attempt)
if index_attempt.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt.status}"
if (
(
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and ctx.search_settings_status != IndexModelStatus.FUTURE
)
# if it's deleting, we don't care if this is a secondary index
or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING
):
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
index_attempt_loop = get_index_attempt(
db_session_temp, index_attempt_id
)
if not index_attempt_loop:
raise RuntimeError(
f"Index attempt {index_attempt_id} not found in DB."
)
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
)
batch_description = []
@@ -289,16 +355,15 @@ def _run_indexing(
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
new_docs, total_batch_chunks = indexing_pipeline(
index_pipeline_result = indexing_pipeline(
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
)
batch_num += 1
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch_cleaned)
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
net_doc_change += index_pipeline_result.new_docs
chunk_count += index_pipeline_result.total_chunks
document_count += index_pipeline_result.total_docs
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
@@ -307,18 +372,19 @@ def _run_indexing(
# be inaccurate
db_session.commit()
# This new value is updated every batch, so UI can refresh per batch update
with get_session_with_tenant(tenant_id) as db_session_temp:
update_docs_indexed(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=0,
)
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
db_session=db_session,
index_attempt=index_attempt,
total_docs_indexed=document_count,
new_docs_indexed=net_doc_change,
docs_removed_from_index=0,
)
tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
@@ -331,34 +397,36 @@ def _run_indexing(
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
run_end_dt = window_end
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
if ctx.is_primary:
with get_session_with_tenant(tenant_id) as db_session_temp:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
except Exception as e:
logger.exception(
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
if isinstance(e, ConnectorStopSignal):
mark_attempt_canceled(
index_attempt.id,
db_session,
reason=str(e),
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
@@ -372,24 +440,30 @@ def _run_indexing(
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or not db_cc_pair.status.is_active()
or index_attempt.status != IndexingStatus.IN_PROGRESS
):
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
or (
cc_pair_loop is not None and not cc_pair_loop.status.is_active()
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
net_docs=net_doc_change,
or (
index_attempt_loop is not None
and index_attempt_loop.status != IndexingStatus.IN_PROGRESS
)
):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
@@ -411,56 +485,58 @@ def _run_indexing(
index_attempt_md.num_exceptions > 0
and index_attempt_md.num_exceptions >= batch_num
):
mark_attempt_failed(
index_attempt.id,
db_session,
failure_reason="All batches exceptioned.",
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=index_attempt.connector_credential_pair.connector.id,
credential_id=index_attempt.connector_credential_pair.credential.id,
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason="All batches exceptioned.",
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
)
elapsed_time = time.time() - start_time
if index_attempt_md.num_exceptions == 0:
mark_attempt_succeeded(index_attempt, db_session)
with get_session_with_tenant(tenant_id) as db_session_temp:
if index_attempt_md.num_exceptions == 0:
mark_attempt_succeeded(index_attempt_id, db_session_temp)
create_milestone_and_report(
user=None,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
properties=None,
db_session=db_session,
)
create_milestone_and_report(
user=None,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CONNECTOR_SUCCEEDED,
properties=None,
db_session=db_session_temp,
)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt, db_session)
logger.info(
f"Connector completed with some errors: "
f"exceptions={index_attempt_md.num_exceptions} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
f"elapsed={elapsed_time:.2f}s"
)
logger.info(
f"Connector succeeded: "
f"docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
)
else:
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
logger.info(
f"Connector completed with some errors: "
f"exceptions={index_attempt_md.num_exceptions} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
f"elapsed={elapsed_time:.2f}s"
)
if is_primary:
update_connector_credential_pair(
db_session=db_session,
connector_id=db_connector.id,
credential_id=db_credential.id,
run_dt=run_end_dt,
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
run_dt=run_end_dt,
)
def run_indexing_entrypoint(
@@ -480,27 +556,35 @@ def run_indexing_entrypoint(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
connector_name = attempt.connector_credential_pair.connector.name
connector_config = (
attempt.connector_credential_pair.connector.connector_specific_config
)
credential_id = attempt.connector_credential_pair.credential_id
_run_indexing(db_session, attempt, tenant_id, callback)
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
with get_session_with_tenant(tenant_id) as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
except Exception as e:
logger.exception(
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"

View File

@@ -12,10 +12,10 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.build import AnswerPromptBuilder
from onyx.chat.prompt_builder.build import default_build_system_message
from onyx.chat.prompt_builder.build import default_build_user_message
from onyx.chat.prompt_builder.build import LLMCall
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import (
CitationResponseHandler,
)
@@ -212,19 +212,6 @@ class Answer:
current_llm_call
) or ([], [])
# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
# if self.answer_style_config.citation_config:
# answer_handler = CitationResponseHandler(
# context_docs=search_result,
# doc_id_to_rank_map=map_document_id_order(search_result),
# )
# elif self.answer_style_config.quotes_config:
# answer_handler = QuotesResponseHandler(
# context_docs=search_result,
# )
# else:
# raise ValueError("No answer style config provided")
answer_handler = CitationResponseHandler(
context_docs=final_search_results,
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
@@ -265,11 +252,13 @@ class Answer:
user_query=self.question,
prompt_config=self.prompt_config,
files=self.latest_query_files,
single_message_history=self.single_message_history,
),
message_history=self.message_history,
llm_config=self.llm.config,
raw_user_query=self.question,
raw_user_uploaded_files=self.latest_query_files or [],
single_message_history=self.single_message_history,
raw_user_text=self.question,
)
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)

View File

@@ -25,7 +25,7 @@ from onyx.db.models import Persona
from onyx.db.models import Prompt
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.persona import get_prompts_by_ids
from onyx.db.prompts import get_prompts_by_ids
from onyx.llm.models import PreviousMessage
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.server.query_and_chat.models import CreateChatMessageRequest

View File

@@ -7,7 +7,7 @@ from langchain_core.messages import BaseMessage
from onyx.chat.models import ResponsePart
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.build import LLMCall
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler

View File

@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
@@ -261,13 +260,8 @@ class CitationConfig(BaseModel):
all_docs_useful: bool = False
class QuotesConfig(BaseModel):
pass
class AnswerStyleConfig(BaseModel):
citation_config: CitationConfig | None = None
quotes_config: QuotesConfig | None = None
citation_config: CitationConfig
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
@@ -276,20 +270,6 @@ class AnswerStyleConfig(BaseModel):
# right now, only used by the simple chat API
structured_response_format: dict | None = None
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
if self.citation_config is None and self.quotes_config is None:
raise ValueError(
"One of `citation_config` or `quotes_config` must be provided"
)
if self.citation_config is not None and self.quotes_config is not None:
raise ValueError(
"Only one of `citation_config` or `quotes_config` must be provided"
)
return self
class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed

View File

@@ -302,6 +302,11 @@ def stream_chat_message_objects(
enforce_chat_session_id_for_search_docs: bool = True,
bypass_acl: bool = False,
include_contexts: bool = False,
# a string which represents the history of a conversation. Used in cases like
# Slack threads where the conversation cannot be represented by a chain of User/Assistant
# messages.
# NOTE: this is not stored in the database at all.
single_message_history: str | None = None,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -707,6 +712,7 @@ def stream_chat_message_objects(
],
tools=tools,
force_use_tool=_get_force_search_settings(new_msg_req, tools),
single_message_history=single_message_history,
)
reference_db_search_docs = None
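single_message_history is just a plain string; a hypothetical helper (not part of this change) showing how a Slack thread could be flattened into it:

def flatten_slack_thread(thread_messages: list[dict]) -> str:
    # thread_messages is an assumed list of {"user": ..., "text": ...} dicts.
    lines = []
    for message in thread_messages:
        sender = message.get("user") or "Unknown"
        lines.append(f"{sender}: {message.get('text', '')}")
    return "\n".join(lines)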

View File

@@ -17,6 +17,7 @@ from onyx.llm.utils import check_message_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.tools.force import ForceUseTool
@@ -42,11 +43,22 @@ def default_build_system_message(
def default_build_user_message(
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
user_query: str,
prompt_config: PromptConfig,
files: list[InMemoryChatFile] = [],
single_message_history: str | None = None,
) -> HumanMessage:
history_block = (
HISTORY_BLOCK.format(history_str=single_message_history)
if single_message_history
else ""
)
user_prompt = (
CHAT_USER_CONTEXT_FREE_PROMPT.format(
task_prompt=prompt_config.task_prompt, user_query=user_query
history_block=history_block,
task_prompt=prompt_config.task_prompt,
user_query=user_query,
)
if prompt_config.task_prompt
else user_query
@@ -64,7 +76,8 @@ class AnswerPromptBuilder:
user_message: HumanMessage,
message_history: list[PreviousMessage],
llm_config: LLMConfig,
raw_user_text: str,
raw_user_query: str,
raw_user_uploaded_files: list[InMemoryChatFile],
single_message_history: str | None = None,
) -> None:
self.max_tokens = compute_max_llm_input_tokens(llm_config)
@@ -83,10 +96,6 @@ class AnswerPromptBuilder:
self.history_token_cnts,
) = translate_history_to_basemessages(message_history)
# for cases where like the QA flow where we want to condense the chat history
# into a single message rather than a sequence of User / Assistant messages
self.single_message_history = single_message_history
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
self.user_message_and_token_cnt = (
user_message,
@@ -95,7 +104,10 @@ class AnswerPromptBuilder:
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
self.raw_user_message = raw_user_text
# used for building a new prompt after a tool-call
self.raw_user_query = raw_user_query
self.raw_user_uploaded_files = raw_user_uploaded_files
self.single_message_history = single_message_history
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:

View File

@@ -1,12 +1,13 @@
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session
from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from onyx.context.search.models import InferenceChunk
from onyx.db.models import Persona
from onyx.db.persona import get_default_prompt__read_only
from onyx.db.prompts import get_default_prompt
from onyx.db.search_settings import get_multilingual_expansion
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
@@ -97,11 +98,12 @@ def compute_max_document_tokens(
def compute_max_document_tokens_for_persona(
db_session: Session,
persona: Persona,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
prompt = persona.prompts[0] if persona.prompts else get_default_prompt(db_session)
return compute_max_document_tokens(
prompt_config=PromptConfig.from_model(prompt),
llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
@@ -144,9 +146,7 @@ def build_citations_user_message(
)
history_block = (
HISTORY_BLOCK.format(history_str=history_message) + "\n"
if history_message
else ""
HISTORY_BLOCK.format(history_str=history_message) if history_message else ""
)
query, img_urls = message_to_prompt_and_imgs(message)

View File

@@ -7,26 +7,6 @@ from onyx.db.models import ChatMessage
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT
def build_dummy_prompt(
system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
if retrieval_disabled:
return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
return PARAMATERIZED_PROMPT.format(
context_docs_str="<CONTEXT_DOCS>",
user_query="<USER_QUERY>",
system_prompt=system_prompt,
task_prompt=task_prompt,
).strip()
def translate_onyx_msg_to_langchain(

View File

@@ -5,7 +5,7 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall
from onyx.chat.models import ResponsePart
from onyx.chat.prompt_builder.build import LLMCall
from onyx.chat.prompt_builder.answer_prompt_builder import LLMCall
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.message import build_tool_message
@@ -62,7 +62,7 @@ class ToolResponseHandler:
llm_call.force_use_tool.args
if llm_call.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=llm_call.prompt_builder.raw_user_message,
query=llm_call.prompt_builder.raw_user_query,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
force_run=True,
@@ -76,7 +76,7 @@ class ToolResponseHandler:
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=llm_call.tools,
query=llm_call.prompt_builder.raw_user_message,
query=llm_call.prompt_builder.raw_user_query,
history=llm_call.prompt_builder.raw_message_history,
llm=llm,
)
@@ -95,7 +95,7 @@ class ToolResponseHandler:
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=llm_call.prompt_builder.raw_message_history,
query=llm_call.prompt_builder.raw_user_message,
query=llm_call.prompt_builder.raw_user_query,
llm=llm,
)
if available_tools_and_args

View File

@@ -17,6 +17,7 @@ APP_PORT = 8080
# prefix from requests directed towards the API server. In these cases, set this to `/api`
APP_API_PREFIX = os.environ.get("API_PREFIX", "")
SKIP_WARM_UP = os.environ.get("SKIP_WARM_UP", "").lower() == "true"
#####
# User Facing Features Configs

View File

@@ -1,6 +1,6 @@
import os
INPUT_PROMPT_YAML = "./onyx/seeding/input_prompts.yaml"
PROMPTS_YAML = "./onyx/seeding/prompts.yaml"
PERSONAS_YAML = "./onyx/seeding/personas.yaml"

View File

@@ -47,6 +47,7 @@ POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
@@ -78,6 +79,8 @@ KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
# NOTE: we use this timeout / 4 in various places to refresh a lock
# might be worth separating this timeout into separate timeouts for each situation
CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
@@ -197,6 +200,7 @@ class SessionType(str, Enum):
class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics
MIXED = "mixed" # User likes some answers and dislikes other, used for chat session metrics
class SearchFeedbackType(str, Enum):
@@ -260,6 +264,9 @@ class OnyxCeleryQueues:
# Indexing queue
CONNECTOR_INDEXING = "connector_indexing"
# Monitoring queue
MONITORING = "monitoring"
class OnyxRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
@@ -274,6 +281,7 @@ class OnyxRedisLocks:
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
"da_lock:connector_doc_permissions_sync"
@@ -286,6 +294,8 @@ class OnyxRedisLocks:
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
CLOUD_CHECK_INDEXING_BEAT_LOCK = "da_lock:cloud_check_indexing_beat"
class OnyxRedisSignals:
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
@@ -299,6 +309,13 @@ class OnyxCeleryPriority(int, Enum):
LOWEST = auto()
# a prefix used to distinguish system wide tasks in the cloud
ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud"
# the tenant id we use for system level redis operations
ONYX_CLOUD_TENANT_ID = "cloud"
class OnyxCeleryTask:
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
@@ -308,6 +325,7 @@ class OnyxCeleryTask:
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"
@@ -325,6 +343,8 @@ class OnyxCeleryTask:
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
CLOUD_CHECK_FOR_INDEXING = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_for_indexing"
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15

View File

@@ -121,6 +121,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 50
class OnyxConfluence(Confluence):
@@ -134,32 +135,6 @@ class OnyxConfluence(Confluence):
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
def get_current_user(self, expand: str | None = None) -> Any:
"""
Implements a method that isn't in the third party client.
Get information about the current user
:param expand: OPTIONAL expand for get status of user.
Possible param is "status". Results are "Active, Deactivated"
:return: Returns the user details
"""
from atlassian.errors import ApiPermissionError # type:ignore
url = "rest/api/user/current"
params = {}
if expand:
params["expand"] = expand
try:
response = self.get(url, params=params)
except HTTPError as e:
if e.response.status_code == 403:
raise ApiPermissionError(
"The calling user does not have permission", reason=e
)
raise
return response
def _wrap_methods(self) -> None:
"""
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
@@ -204,24 +179,41 @@ class OnyxConfluence(Confluence):
# If the problematic expansion is in the url, replace it
# with the replacement expansion and try again
# If that fails, raise the error
if _PROBLEMATIC_EXPANSIONS not in url_suffix:
logger.exception(
if _PROBLEMATIC_EXPANSIONS in url_suffix:
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
)
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
)
continue
if (
raw_response.status_code == 500
and limit > _MINIMUM_PAGINATION_LIMIT
):
new_limit = limit // 2
logger.warning(
f"Error in confluence call to {url_suffix} \n"
f"Raw Response Text: {raw_response.text} \n"
f"Full Response: {raw_response.__dict__} \n"
f"Error: {e} \n"
f"Reducing limit from {limit} to {new_limit} and trying again."
)
raise e
url_suffix = url_suffix.replace(
f"limit={limit}", f"limit={new_limit}"
)
limit = new_limit
continue
logger.warning(
f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS}"
" and trying again."
logger.exception(
f"Error in confluence call to {url_suffix} \n"
f"Raw Response Text: {raw_response.text} \n"
f"Full Response: {raw_response.__dict__} \n"
f"Error: {e} \n"
)
url_suffix = url_suffix.replace(
_PROBLEMATIC_EXPANSIONS,
_REPLACEMENT_EXPANSIONS,
)
continue
raise e
try:
next_response = raw_response.json()
@@ -336,6 +328,62 @@ class OnyxConfluence(Confluence):
group_name = quote(group_name)
yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit)
def get_all_space_permissions_server(
self,
space_key: str,
) -> list[dict[str, Any]]:
"""
This is a confluence server specific method that can be used to
fetch the permissions of a space.
This provides better logging than calling the get_space_permissions method
because it returns a jsonrpc response.
TODO: Have this call these endpoints for newer confluence versions:
- /rest/api/space/{spaceKey}/permissions
- /rest/api/space/{spaceKey}/permissions/anonymous
"""
url = "rpc/json-rpc/confluenceservice-v2"
data = {
"jsonrpc": "2.0",
"method": "getSpacePermissionSets",
"id": 7,
"params": [space_key],
}
response = self.post(url, data=data)
logger.debug(f"jsonrpc response: {response}")
if not response.get("result"):
logger.warning(
f"No jsonrpc response for space permissions for space {space_key}"
f"\nResponse: {response}"
)
return response.get("result", [])
def get_current_user(self, expand: str | None = None) -> Any:
"""
Implements a method that isn't in the third party client.
Get information about the current user
:param expand: OPTIONAL expand for get status of user.
Possible param is "status". Results are "Active, Deactivated"
:return: Returns the user details
"""
from atlassian.errors import ApiPermissionError # type:ignore
url = "rest/api/user/current"
params = {}
if expand:
params["expand"] = expand
try:
response = self.get(url, params=params)
except HTTPError as e:
if e.response.status_code == 403:
raise ApiPermissionError(
"The calling user does not have permission", reason=e
)
raise
return response
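The retry handling above both swaps out the problematic expansion and halves the page size on HTTP 500s; a generic sketch of the limit-halving behavior in isolation (fetch_page is an assumed callable, not a Confluence API method):

def fetch_with_shrinking_limit(fetch_page, limit, minimum_limit=_MINIMUM_PAGINATION_LIMIT):
    # Shrink the page size on server errors until a floor is reached, then give up.
    while True:
        status_code, payload = fetch_page(limit)
        if status_code != 500:
            return payload
        if limit <= minimum_limit:
            raise RuntimeError(f"Confluence call kept failing even at limit={limit}")
        limit = limit // 2  # smaller pages are less likely to trip server-side limits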
def _validate_connector_configuration(
credentials: dict[str, Any],

View File

@@ -30,13 +30,14 @@ _FIREFLIES_API_QUERY = """
transcripts(fromDate: $fromDate, toDate: $toDate, limit: $limit, skip: $skip) {
id
title
host_email
organizer_email
participants
date
transcript_url
sentences {
text
speaker_name
start_time
}
}
}
@@ -44,16 +45,34 @@ _FIREFLIES_API_QUERY = """
def _create_doc_from_transcript(transcript: dict) -> Document | None:
meeting_text = ""
sentences = transcript.get("sentences", [])
if sentences:
for sentence in sentences:
meeting_text += sentence.get("speaker_name") or "Unknown Speaker"
meeting_text += ": " + sentence.get("text", "") + "\n\n"
else:
return None
sections: List[Section] = []
current_speaker_name = None
current_link = ""
current_text = ""
meeting_link = transcript["transcript_url"]
for sentence in transcript["sentences"]:
if sentence["speaker_name"] != current_speaker_name:
if current_speaker_name is not None:
sections.append(
Section(
link=current_link,
text=current_text.strip(),
)
)
current_speaker_name = sentence.get("speaker_name") or "Unknown Speaker"
current_link = f"{transcript['transcript_url']}?t={sentence['start_time']}"
current_text = f"{current_speaker_name}: "
cleaned_text = sentence["text"].replace("\xa0", " ")
current_text += f"{cleaned_text} "
# Sometimes these links (links with a timestamp) do not work; it is a bug with Fireflies.
sections.append(
Section(
link=current_link,
text=current_text.strip(),
)
)
fireflies_id = _FIREFLIES_ID_PREFIX + transcript["id"]
@@ -62,27 +81,22 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
meeting_date_unix = transcript["date"]
meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc)
meeting_host_email = transcript["host_email"]
host_email_user_info = [BasicExpertInfo(email=meeting_host_email)]
meeting_organizer_email = transcript["organizer_email"]
organizer_email_user_info = [BasicExpertInfo(email=meeting_organizer_email)]
meeting_participants_email_list = []
for participant in transcript.get("participants", []):
if participant != meeting_host_email and participant:
if participant != meeting_organizer_email and participant:
meeting_participants_email_list.append(BasicExpertInfo(email=participant))
return Document(
id=fireflies_id,
sections=[
Section(
link=meeting_link,
text=meeting_text,
)
],
sections=sections,
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={},
doc_updated_at=meeting_date,
primary_owners=host_email_user_info,
primary_owners=organizer_email_user_info,
secondary_owners=meeting_participants_email_list,
)
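A standalone illustration of the per-speaker grouping above: consecutive sentences from one speaker are merged into a single section whose link carries the first sentence's timestamp (the sample data below is made up, not the Fireflies API shape):

sentences = [
    {"speaker_name": "Ana", "text": "Hi all.", "start_time": 0},
    {"speaker_name": "Ana", "text": "Quick update.", "start_time": 4},
    {"speaker_name": "Ben", "text": "Thanks Ana.", "start_time": 9},
]
transcript_url = "https://app.fireflies.ai/view/example"

sections: list[tuple[str, str]] = []
current_speaker, current_link, current_text = None, "", ""
for sentence in sentences:
    if sentence["speaker_name"] != current_speaker:
        if current_speaker is not None:
            sections.append((current_link, current_text.strip()))
        current_speaker = sentence["speaker_name"]
        current_link = f"{transcript_url}?t={sentence['start_time']}"
        current_text = f"{current_speaker}: "
    current_text += sentence["text"] + " "
sections.append((current_link, current_text.strip()))
# -> [("...?t=0", "Ana: Hi all. Quick update."), ("...?t=9", "Ben: Thanks Ana.")]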

View File

@@ -258,7 +258,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
user_emails.append(email)
return user_emails
def _get_all_drive_ids(self) -> set[str]:
def get_all_drive_ids(self) -> set[str]:
primary_drive_service = get_drive_service(
creds=self.creds,
user_email=self.primary_admin_email,
@@ -353,7 +353,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
) -> Iterator[GoogleDriveFileType]:
all_org_emails: list[str] = self._get_all_user_emails()
all_drive_ids: set[str] = self._get_all_drive_ids()
all_drive_ids: set[str] = self.get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
@@ -437,7 +437,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
# If all 3 are true, we already yielded from get_all_files_for_oauth
return
all_drive_ids = self._get_all_drive_ids()
all_drive_ids = self.get_all_drive_ids()
drive_ids_to_retrieve: set[str] = set()
folder_ids_to_retrieve: set[str] = set()
if self._requested_shared_drive_ids or self._requested_folder_ids:

View File

@@ -252,6 +252,7 @@ def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
id=file["webViewLink"],
perm_sync_data={
"doc_id": file.get("id"),
"drive_id": file.get("driveId"),
"permissions": file.get("permissions", []),
"permission_ids": file.get("permissionIds", []),
"name": file.get("name"),

View File

@@ -19,7 +19,7 @@ FILE_FIELDS = (
"shortcutDetails, owners(emailAddress), size)"
)
SLIM_FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
"nextPageToken, files(mimeType, driveId, id, name, permissions(emailAddress, type), "
"permissionIds, webViewLink, owners(emailAddress))"
)
FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)"

View File

@@ -17,6 +17,9 @@ from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.google_utils.resources import get_gmail_service
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
@@ -29,6 +32,9 @@ from onyx.connectors.google_utils.shared_constants import (
from onyx.connectors.google_utils.shared_constants import (
GOOGLE_SCOPES,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.connectors.google_utils.shared_constants import (
MISSING_SCOPES_ERROR_STR,
)
@@ -96,6 +102,7 @@ def update_credential_access_tokens(
user: User,
db_session: Session,
source: DocumentSource,
auth_method: GoogleOAuthAuthenticationMethod,
) -> OAuthCredentials | None:
app_credentials = get_google_app_cred(source)
flow = InstalledAppFlow.from_client_config(
@@ -119,6 +126,7 @@ def update_credential_access_tokens(
new_creds_dict = {
DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str,
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email,
DB_CREDENTIALS_AUTHENTICATION_METHOD: auth_method.value,
}
if not update_credential_json(credential_id, new_creds_dict, user, db_session):
@@ -129,6 +137,7 @@ def update_credential_access_tokens(
def build_service_account_creds(
source: DocumentSource,
primary_admin_email: str | None = None,
name: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key(source=source)
@@ -138,10 +147,15 @@ def build_service_account_creds(
if primary_admin_email:
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.UPLOADED.value
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=source,
name=name,
)

View File

@@ -0,0 +1,737 @@
import csv
import os
import shutil
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
get_affected_parent_ids_by_type,
)
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
update_sf_db_with_csv,
)
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path
_VALID_SALESFORCE_IDS = [
"001bm00000fd9Z3AAI",
"001bm00000fdYTdAAM",
"001bm00000fdYTeAAM",
"001bm00000fdYTfAAM",
"001bm00000fdYTgAAM",
"001bm00000fdYThAAM",
"001bm00000fdYTiAAM",
"001bm00000fdYTjAAM",
"001bm00000fdYTkAAM",
"001bm00000fdYTlAAM",
"001bm00000fdYTmAAM",
"001bm00000fdYTnAAM",
"001bm00000fdYToAAM",
"500bm00000XoOxtAAF",
"500bm00000XoOxuAAF",
"500bm00000XoOxvAAF",
"500bm00000XoOxwAAF",
"500bm00000XoOxxAAF",
"500bm00000XoOxyAAF",
"500bm00000XoOxzAAF",
"500bm00000XoOy0AAF",
"500bm00000XoOy1AAF",
"500bm00000XoOy2AAF",
"500bm00000XoOy3AAF",
"500bm00000XoOy4AAF",
"500bm00000XoOy5AAF",
"500bm00000XoOy6AAF",
"500bm00000XoOy7AAF",
"500bm00000XoOy8AAF",
"500bm00000XoOy9AAF",
"500bm00000XoOyAAAV",
"500bm00000XoOyBAAV",
"500bm00000XoOyCAAV",
"500bm00000XoOyDAAV",
"500bm00000XoOyEAAV",
"500bm00000XoOyFAAV",
"500bm00000XoOyGAAV",
"500bm00000XoOyHAAV",
"500bm00000XoOyIAAV",
"003bm00000EjHCjAAN",
"003bm00000EjHCkAAN",
"003bm00000EjHClAAN",
"003bm00000EjHCmAAN",
"003bm00000EjHCnAAN",
"003bm00000EjHCoAAN",
"003bm00000EjHCpAAN",
"003bm00000EjHCqAAN",
"003bm00000EjHCrAAN",
"003bm00000EjHCsAAN",
"003bm00000EjHCtAAN",
"003bm00000EjHCuAAN",
"003bm00000EjHCvAAN",
"003bm00000EjHCwAAN",
"003bm00000EjHCxAAN",
"003bm00000EjHCyAAN",
"003bm00000EjHCzAAN",
"003bm00000EjHD0AAN",
"003bm00000EjHD1AAN",
"003bm00000EjHD2AAN",
"550bm00000EXc2tAAD",
"006bm000006kyDpAAI",
"006bm000006kyDqAAI",
"006bm000006kyDrAAI",
"006bm000006kyDsAAI",
"006bm000006kyDtAAI",
"006bm000006kyDuAAI",
"006bm000006kyDvAAI",
"006bm000006kyDwAAI",
"006bm000006kyDxAAI",
"006bm000006kyDyAAI",
"006bm000006kyDzAAI",
"006bm000006kyE0AAI",
"006bm000006kyE1AAI",
"006bm000006kyE2AAI",
"006bm000006kyE3AAI",
"006bm000006kyE4AAI",
"006bm000006kyE5AAI",
"006bm000006kyE6AAI",
"006bm000006kyE7AAI",
"006bm000006kyE8AAI",
"006bm000006kyE9AAI",
"006bm000006kyEAAAY",
"006bm000006kyEBAAY",
"006bm000006kyECAAY",
"006bm000006kyEDAAY",
"006bm000006kyEEAAY",
"006bm000006kyEFAAY",
"006bm000006kyEGAAY",
"006bm000006kyEHAAY",
"006bm000006kyEIAAY",
"006bm000006kyEJAAY",
"005bm000009zy0TAAQ",
"005bm000009zy25AAA",
"005bm000009zy26AAA",
"005bm000009zy28AAA",
"005bm000009zy29AAA",
"005bm000009zy2AAAQ",
"005bm000009zy2BAAQ",
]
def clear_sf_db() -> None:
"""
Clears the SF DB by deleting all files in the data directory.
"""
shutil.rmtree(BASE_DATA_PATH)
def create_csv_file(
object_type: str, records: list[dict], filename: str = "test_data.csv"
) -> None:
"""
Creates a CSV file for the given object type and records.
Args:
object_type: The Salesforce object type (e.g. "Account", "Contact")
records: List of dictionaries containing the record data
filename: Name of the CSV file to create (default: test_data.csv)
"""
if not records:
return
# Get all unique fields from records
fields: set[str] = set()
for record in records:
fields.update(record.keys())
ordered_fields = sorted(fields)  # sort so the CSV column order is deterministic
# Create CSV file
csv_path = os.path.join(get_object_type_path(object_type), filename)
with open(csv_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=ordered_fields)
writer.writeheader()
for record in records:
writer.writerow(record)
# Update the database with the CSV
update_sf_db_with_csv(object_type, csv_path)
def create_csv_with_example_data() -> None:
"""
Creates CSV files with example data, organized by object type.
"""
example_data: dict[str, list[dict]] = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
{
"Id": _VALID_SALESFORCE_IDS[3],
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
{
"Id": _VALID_SALESFORCE_IDS[4],
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
{
"Id": _VALID_SALESFORCE_IDS[5],
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
{
"Id": _VALID_SALESFORCE_IDS[6],
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
{
"Id": _VALID_SALESFORCE_IDS[7],
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"FirstName": "John",
"LastName": "Doe",
"Email": "john.doe@acme.com",
"Title": "CEO",
},
{
"Id": _VALID_SALESFORCE_IDS[41],
"FirstName": "Jane",
"LastName": "Smith",
"Email": "jane.smith@acme.com",
"Title": "CTO",
},
{
"Id": _VALID_SALESFORCE_IDS[42],
"FirstName": "Bob",
"LastName": "Johnson",
"Email": "bob.j@globex.com",
"Title": "Sales Director",
},
{
"Id": _VALID_SALESFORCE_IDS[43],
"FirstName": "Sarah",
"LastName": "Chen",
"Email": "sarah.chen@techcorp.com",
"Title": "Product Manager",
"Phone": "415-555-0101",
},
{
"Id": _VALID_SALESFORCE_IDS[44],
"FirstName": "Michael",
"LastName": "Rodriguez",
"Email": "m.rodriguez@biomed.com",
"Title": "Research Director",
"Phone": "617-555-0202",
},
{
"Id": _VALID_SALESFORCE_IDS[45],
"FirstName": "Emily",
"LastName": "Green",
"Email": "emily.g@greenenergy.com",
"Title": "Sustainability Lead",
"Phone": "503-555-0303",
},
{
"Id": _VALID_SALESFORCE_IDS[46],
"FirstName": "David",
"LastName": "Kim",
"Email": "david.kim@dataflow.com",
"Title": "Data Scientist",
"Phone": "206-555-0404",
},
{
"Id": _VALID_SALESFORCE_IDS[47],
"FirstName": "Rachel",
"LastName": "Taylor",
"Email": "r.taylor@cloudnine.com",
"Title": "Cloud Architect",
"Phone": "303-555-0505",
},
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"Name": "Acme Server Upgrade",
"Amount": 50000,
"Stage": "Prospecting",
"CloseDate": "2024-06-30",
},
{
"Id": _VALID_SALESFORCE_IDS[63],
"Name": "Globex Manufacturing Line",
"Amount": 150000,
"Stage": "Negotiation",
"CloseDate": "2024-03-15",
},
{
"Id": _VALID_SALESFORCE_IDS[64],
"Name": "Initech Software License",
"Amount": 75000,
"Stage": "Closed Won",
"CloseDate": "2024-01-30",
},
{
"Id": _VALID_SALESFORCE_IDS[65],
"Name": "TechCorp AI Implementation",
"Amount": 250000,
"Stage": "Needs Analysis",
"CloseDate": "2024-08-15",
"Probability": 60,
},
{
"Id": _VALID_SALESFORCE_IDS[66],
"Name": "BioMed Lab Equipment",
"Amount": 500000,
"Stage": "Value Proposition",
"CloseDate": "2024-09-30",
"Probability": 75,
},
{
"Id": _VALID_SALESFORCE_IDS[67],
"Name": "Green Energy Solar Project",
"Amount": 750000,
"Stage": "Proposal",
"CloseDate": "2024-07-15",
"Probability": 80,
},
{
"Id": _VALID_SALESFORCE_IDS[68],
"Name": "DataFlow Analytics Platform",
"Amount": 180000,
"Stage": "Negotiation",
"CloseDate": "2024-05-30",
"Probability": 90,
},
{
"Id": _VALID_SALESFORCE_IDS[69],
"Name": "Cloud Nine Infrastructure",
"Amount": 300000,
"Stage": "Qualification",
"CloseDate": "2024-10-15",
"Probability": 40,
},
],
}
# Create CSV files for each object type
for object_type, records in example_data.items():
create_csv_file(object_type, records)
def test_query() -> None:
"""
Tests querying functionality by verifying:
1. All expected Account IDs are found
2. Each Account's data matches what was inserted
"""
# Expected test data for verification
expected_accounts: dict[str, dict[str, str | int]] = {
_VALID_SALESFORCE_IDS[0]: {
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
_VALID_SALESFORCE_IDS[1]: {
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
_VALID_SALESFORCE_IDS[2]: {
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
_VALID_SALESFORCE_IDS[3]: {
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
_VALID_SALESFORCE_IDS[4]: {
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
_VALID_SALESFORCE_IDS[5]: {
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
_VALID_SALESFORCE_IDS[6]: {
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
_VALID_SALESFORCE_IDS[7]: {
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
}
# Get all Account IDs
account_ids = find_ids_by_type("Account")
# Verify we found all expected accounts
assert len(account_ids) == len(
expected_accounts
), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
assert set(account_ids) == set(
expected_accounts.keys()
), "Found account IDs don't match expected IDs"
# Verify each account's data
for acc_id in account_ids:
combined = get_record(acc_id)
assert combined is not None, f"Could not find account {acc_id}"
expected = expected_accounts[acc_id]
# Verify account data matches
for key, value in expected.items():
value = str(value)
assert (
combined.data[key] == value
), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"
print("All query tests passed successfully!")
def test_upsert() -> None:
"""
Tests upsert functionality by:
1. Updating an existing account
2. Creating a new account
3. Verifying both operations were successful
"""
# Create CSV for updating an existing account and adding a new one
update_data: list[dict[str, str | int]] = [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc. Updated",
"BillingCity": "New York",
"Industry": "Technology",
"Description": "Updated company info",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "New Company Inc.",
"BillingCity": "Miami",
"Industry": "Finance",
"AnnualRevenue": 1000000,
},
]
create_csv_file("Account", update_data, "update_data.csv")
# Verify the update worked
updated_record = get_record(_VALID_SALESFORCE_IDS[0])
assert updated_record is not None, "Updated record not found"
assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
assert (
updated_record.data["Description"] == "Updated company info"
), "Description not added"
# Verify the new record was created
new_record = get_record(_VALID_SALESFORCE_IDS[2])
assert new_record is not None, "New record not found"
assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"
print("All upsert tests passed successfully!")
def test_relationships() -> None:
"""
Tests relationship shelf updates and queries by:
1. Creating test data with relationships
2. Verifying the relationships are correctly stored
3. Testing relationship queries
"""
# Create test data for each object type
test_data: dict[str, list[dict[str, str | int]]] = {
"Case": [
{
"Id": _VALID_SALESFORCE_IDS[13],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 1",
},
{
"Id": _VALID_SALESFORCE_IDS[14],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 2",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[48],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Name": "Test Opportunity",
"Amount": 100000,
}
],
}
# Create and update CSV files for each object type
for object_type, records in test_data.items():
create_csv_file(object_type, records, "relationship_test.csv")
# Test relationship queries
# All these objects should be children of Acme Inc.
child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
assert (
_VALID_SALESFORCE_IDS[62] in child_ids
), "Opportunity not found in relationship"
# Test querying relationships for a different account (should be empty)
other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert (
len(other_account_children) == 0
), "Expected no children for different account"
print("All relationship tests passed successfully!")
def test_account_with_children() -> None:
"""
Tests querying all accounts and retrieving their child objects.
This test verifies that:
1. All accounts can be retrieved
2. Child objects are correctly linked
3. Child object data is complete and accurate
"""
# First get all account IDs
account_ids = find_ids_by_type("Account")
assert len(account_ids) > 0, "No accounts found"
# For each account, get its children and verify the data
for account_id in account_ids:
account = get_record(account_id)
assert account is not None, f"Could not find account {account_id}"
# Get all child objects
child_ids = get_child_ids(account_id)
# For Acme Inc., verify specific relationships
if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc.
assert (
len(child_ids) == 4
), f"Expected 4 children for Acme Inc., found {len(child_ids)}"
# Get all child records
child_records = []
for child_id in child_ids:
child_record = get_record(child_id)
if child_record is not None:
child_records.append(child_record)
# Verify Cases
cases = [r for r in child_records if r.type == "Case"]
assert (
len(cases) == 2
), f"Expected 2 cases for Acme Inc., found {len(cases)}"
case_subjects = {case.data["Subject"] for case in cases}
assert "Test Case 1" in case_subjects, "Test Case 1 not found"
assert "Test Case 2" in case_subjects, "Test Case 2 not found"
# Verify Contacts
contacts = [r for r in child_records if r.type == "Contact"]
assert (
len(contacts) == 1
), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
contact = contacts[0]
assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
assert contact.data["LastName"] == "Contact", "Contact last name mismatch"
# Verify Opportunities
opportunities = [r for r in child_records if r.type == "Opportunity"]
assert (
len(opportunities) == 1
), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
opportunity = opportunities[0]
assert (
opportunity.data["Name"] == "Test Opportunity"
), "Opportunity name mismatch"
assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"
print("All account with children tests passed successfully!")
def test_relationship_updates() -> None:
"""
Tests that relationships are properly updated when a child object's parent reference changes.
This test verifies:
1. Initial relationship is created correctly
2. When parent reference is updated, old relationship is removed
3. New relationship is created correctly
"""
# Create initial test data - Contact linked to Acme Inc.
initial_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
]
create_csv_file("Contact", initial_contact, "initial_contact.csv")
# Verify initial relationship
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] in acme_children
), "Initial relationship not created"
# Update contact to be linked to Globex Corp instead
updated_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[1],
"FirstName": "Test",
"LastName": "Contact",
}
]
create_csv_file("Contact", updated_contact, "updated_contact.csv")
# Verify old relationship is removed
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] not in acme_children
), "Old relationship not removed"
# Verify new relationship is created
globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"
print("All relationship update tests passed successfully!")
def test_get_affected_parent_ids() -> None:
"""
Tests get_affected_parent_ids functionality by verifying:
1. IDs that are directly in the parent_types list are included
2. IDs that have children in the updated_ids list are included
3. IDs that are neither of the above are not included
"""
# Create test data with relationships
test_data = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Parent Account 1",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Parent Account 2",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Not Affected Account",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Child",
"LastName": "Contact",
}
],
}
# Create and update CSV files for test data
for object_type, records in test_data.items():
create_csv_file(object_type, records)
# Test Case 1: Account directly in updated_ids and parent_types
updated_ids = {_VALID_SALESFORCE_IDS[1]} # Parent Account 2
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
# Test Case 2: Account with child in updated_ids
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids
), "Parent of updated child not included"
# Test Case 3: Both direct and indirect affects
updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]} # Both cases
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
assert (
_VALID_SALESFORCE_IDS[2] not in affected_ids
), "Unaffected ID incorrectly included"
# Test Case 4: No matches
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
parent_types = ["Opportunity"] # Wrong type
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert len(affected_ids) == 0, "Should return empty list when no matches"
print("All get_affected_parent_ids tests passed successfully!")
def main_build() -> None:
clear_sf_db()
create_csv_with_example_data()
test_query()
test_upsert()
test_relationships()
test_account_with_children()
test_relationship_updates()
test_get_affected_parent_ids()
if __name__ == "__main__":
main_build()

View File

@@ -10,9 +10,10 @@ from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
@@ -28,17 +29,17 @@ from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
stmt = stmt.distinct()
UG__CCpair = aliased(UserGroup__ConnectorCredentialPair)
User__UG = aliased(User__UserGroup)
@@ -62,6 +63,12 @@ def _add_user_filters(
- if we are not editing, we show all cc_pairs in the groups the user is a curator
for (as well as public cc_pairs)
"""
# If user is None, this is an anonymous user and we should only show public cc_pairs
if user is None:
where_clause = ConnectorCredentialPair.access_type == AccessType.PUBLIC
return stmt.where(where_clause)
where_clause = User__UG.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UG.is_curator == True # noqa: E712
@@ -85,10 +92,9 @@ def _add_user_filters(
return stmt.where(where_clause)
def get_connector_credential_pairs(
def get_connector_credential_pairs_for_user(
db_session: Session,
include_disabled: bool = True,
user: User | None = None,
user: User | None,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
@@ -99,11 +105,18 @@ def get_connector_credential_pairs(
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = _add_user_filters(stmt, user, get_editable)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).all())
def get_connector_credential_pairs(
db_session: Session,
ids: list[int] | None = None,
) -> list[ConnectorCredentialPair]:
stmt = select(ConnectorCredentialPair).distinct()
if not include_disabled:
stmt = stmt.where(
ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE
)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
@@ -115,7 +128,10 @@ def add_deletion_failure_message(
cc_pair_id: int,
failure_message: str,
) -> None:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
return
cc_pair.deletion_failure_message = failure_message
@@ -125,24 +141,21 @@ def add_deletion_failure_message(
def get_cc_pair_groups_for_ids(
db_session: Session,
cc_pair_ids: list[int],
user: User | None = None,
get_editable: bool = True,
) -> list[UserGroup__ConnectorCredentialPair]:
stmt = select(UserGroup__ConnectorCredentialPair).distinct()
stmt = stmt.outerjoin(
ConnectorCredentialPair,
UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id,
)
stmt = _add_user_filters(stmt, user, get_editable)
stmt = stmt.where(UserGroup__ConnectorCredentialPair.cc_pair_id.in_(cc_pair_ids))
return list(db_session.scalars(stmt).all())
def get_connector_credential_pair(
def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,
credential_id: int,
db_session: Session,
user: User | None = None,
user: User | None,
get_editable: bool = True,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
@@ -153,24 +166,22 @@ def get_connector_credential_pair(
return result.scalar_one_or_none()
def get_connector_credential_source_from_id(
cc_pair_id: int,
def get_connector_credential_pair(
db_session: Session,
user: User | None = None,
get_editable: bool = True,
) -> DocumentSource | None:
connector_id: int,
credential_id: int,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
stmt = _add_user_filters(stmt, user, get_editable)
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
result = db_session.execute(stmt)
cc_pair = result.scalar_one_or_none()
return cc_pair.connector.source if cc_pair else None
return result.scalar_one_or_none()
def get_connector_credential_pair_from_id(
def get_connector_credential_pair_from_id_for_user(
cc_pair_id: int,
db_session: Session,
user: User | None = None,
user: User | None,
get_editable: bool = True,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
@@ -180,6 +191,16 @@ def get_connector_credential_pair_from_id(
return result.scalar_one_or_none()
def get_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
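The split above (a *_for_user variant that applies permission filters next to a plain variant that does not) recurs throughout this change set. A hedged sketch of the intended call sites, with cc_pair_id, user, and db_session assumed to be in scope:

# user-facing code path: respect the caller's visibility
cc_pair = get_connector_credential_pair_from_id_for_user(
    cc_pair_id=cc_pair_id,
    db_session=db_session,
    user=user,  # may be None only when auth is disabled or the object is public
    get_editable=False,
)

# background task with no acting user: fetch unconditionally
cc_pair = get_connector_credential_pair_from_id(
    db_session=db_session,
    cc_pair_id=cc_pair_id,
)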
def get_last_successful_attempt_time(
connector_id: int,
credential_id: int,
@@ -191,7 +212,9 @@ def get_last_successful_attempt_time(
the CC Pair row in the database"""
if search_settings.status == IndexModelStatus.PRESENT:
connector_credential_pair = get_connector_credential_pair(
connector_id, credential_id, db_session
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if (
connector_credential_pair is None
@@ -252,7 +275,10 @@ def update_connector_credential_pair_from_id(
net_docs: int | None = None,
run_dt: datetime | None = None,
) -> None:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
logger.warning(
f"Attempted to update pair for Connector Credential Pair '{cc_pair_id}'"
@@ -277,7 +303,11 @@ def update_connector_credential_pair(
net_docs: int | None = None,
run_dt: datetime | None = None,
) -> None:
cc_pair = get_connector_credential_pair(connector_id, credential_id, db_session)
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
logger.warning(
f"Attempted to update pair for connector id {connector_id} "
@@ -359,14 +389,23 @@ def add_credential_to_connector(
auto_sync_options: dict | None = None,
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.ACTIVE,
last_successful_index_time: datetime | None = None,
seeding_flow: bool = False,
) -> StatusResponse:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(
credential_id,
user,
db_session,
get_editable=False,
)
# If we are in the seeding flow, we shouldn't need to check if the credential belongs to the user
if seeding_flow:
credential = fetch_credential_by_id(
db_session=db_session,
credential_id=credential_id,
)
else:
credential = fetch_credential_by_id_for_user(
credential_id,
user,
db_session,
get_editable=False,
)
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")
@@ -443,7 +482,7 @@ def remove_credential_from_connector(
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(
credential = fetch_credential_by_id_for_user(
credential_id,
user,
db_session,
@@ -459,10 +498,10 @@ def remove_credential_from_connector(
detail="Credential does not exist or does not belong to user",
)
association = get_connector_credential_pair(
association = get_connector_credential_pair_for_user(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
db_session=db_session,
user=user,
get_editable=True,
)

View File

@@ -9,6 +9,7 @@ from sqlalchemy.sql.expression import and_
from sqlalchemy.sql.expression import or_
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
@@ -42,22 +43,21 @@ PUBLIC_CREDENTIAL_ID = 0
def _add_user_filters(
stmt: Select,
user: User | None,
assume_admin: bool = False, # Used with API key
get_editable: bool = True,
) -> Select:
"""Attaches filters to the statement to ensure that the user can only
access the appropriate credentials"""
if not user:
if assume_admin:
# apply admin filters minus the user_id check
stmt = stmt.where(
or_(
Credential.user_id.is_(None),
Credential.admin_public == True, # noqa: E712
Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE),
)
if user is None:
if not DISABLE_AUTH:
raise ValueError("Anonymous users are not allowed to access credentials")
# If user is None and auth is disabled, assume the user is an admin
return stmt.where(
or_(
Credential.user_id.is_(None),
Credential.admin_public == True, # noqa: E712
Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE),
)
return stmt
)
if user.role == UserRole.ADMIN:
# Admins can access all credentials that are public or owned by them
@@ -74,6 +74,7 @@ def _add_user_filters(
# Basic users can only access credentials that are owned by them
return stmt.where(Credential.user_id == user.id)
stmt = stmt.distinct()
"""
THIS PART IS FOR CURATORS AND GLOBAL CURATORS
Here we select cc_pairs by relation:
@@ -137,9 +138,9 @@ def _relate_credential_to_user_groups__no_commit(
db_session.add_all(credential_user_groups)
def fetch_credentials(
def fetch_credentials_for_user(
db_session: Session,
user: User | None = None,
user: User | None,
get_editable: bool = True,
) -> list[Credential]:
stmt = select(Credential)
@@ -148,11 +149,10 @@ def fetch_credentials(
return list(results.all())
def fetch_credential_by_id(
def fetch_credential_by_id_for_user(
credential_id: int,
user: User | None,
db_session: Session,
assume_admin: bool = False,
get_editable: bool = True,
) -> Credential | None:
stmt = select(Credential).distinct()
@@ -160,7 +160,6 @@ def fetch_credential_by_id(
stmt = _add_user_filters(
stmt=stmt,
user=user,
assume_admin=assume_admin,
get_editable=get_editable,
)
result = db_session.execute(stmt)
@@ -168,7 +167,18 @@ def fetch_credential_by_id(
return credential
def fetch_credentials_by_source(
def fetch_credential_by_id(
db_session: Session,
credential_id: int,
) -> Credential | None:
stmt = select(Credential).distinct()
stmt = stmt.where(Credential.id == credential_id)
result = db_session.execute(stmt)
credential = result.scalar_one_or_none()
return credential
def fetch_credentials_by_source_for_user(
db_session: Session,
user: User | None,
document_source: DocumentSource | None = None,
@@ -180,11 +190,22 @@ def fetch_credentials_by_source(
return list(credentials)
def fetch_credentials_by_source(
db_session: Session,
document_source: DocumentSource | None = None,
) -> list[Credential]:
base_query = select(Credential).where(Credential.source == document_source)
credentials = db_session.execute(base_query).scalars().all()
return list(credentials)
def swap_credentials_connector(
new_credential_id: int, connector_id: int, user: User | None, db_session: Session
) -> ConnectorCredentialPair:
# Check if the user has permission to use the new credential
new_credential = fetch_credential_by_id(new_credential_id, user, db_session)
new_credential = fetch_credential_by_id_for_user(
new_credential_id, user, db_session
)
if not new_credential:
raise ValueError(
f"No Credential found with id {new_credential_id} or user doesn't have permission to use it"
@@ -274,7 +295,7 @@ def alter_credential(
db_session: Session,
) -> Credential | None:
# TODO: add user group relationship update
credential = fetch_credential_by_id(credential_id, user, db_session)
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if credential is None:
return None
@@ -298,7 +319,7 @@ def update_credential(
user: User,
db_session: Session,
) -> Credential | None:
credential = fetch_credential_by_id(credential_id, user, db_session)
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if credential is None:
return None
@@ -315,7 +336,7 @@ def update_credential_json(
user: User,
db_session: Session,
) -> Credential | None:
credential = fetch_credential_by_id(credential_id, user, db_session)
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if credential is None:
return None
@@ -340,7 +361,7 @@ def delete_credential(
db_session: Session,
force: bool = False,
) -> None:
credential = fetch_credential_by_id(credential_id, user, db_session)
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if credential is None:
raise ValueError(
f"Credential by provided id {credential_id} does not exist or does not belong to user"
@@ -395,7 +416,10 @@ def create_initial_public_credential(db_session: Session) -> None:
"DB is not in a valid initial state."
"There must exist an empty public credential for data connectors that do not require additional Auth."
)
first_credential = fetch_credential_by_id(PUBLIC_CREDENTIAL_ID, None, db_session)
first_credential = fetch_credential_by_id(
db_session=db_session,
credential_id=PUBLIC_CREDENTIAL_ID,
)
if first_credential is not None:
if first_credential.credential_json != {} or first_credential.user is not None:
@@ -413,7 +437,7 @@ def create_initial_public_credential(db_session: Session) -> None:
def cleanup_gmail_credentials(db_session: Session) -> None:
gmail_credentials = fetch_credentials_by_source(
db_session=db_session, user=None, document_source=DocumentSource.GMAIL
db_session=db_session, document_source=DocumentSource.GMAIL
)
for credential in gmail_credentials:
db_session.delete(credential)
@@ -422,7 +446,7 @@ def cleanup_gmail_credentials(db_session: Session) -> None:
def cleanup_google_drive_credentials(db_session: Session) -> None:
google_drive_credentials = fetch_credentials_by_source(
db_session=db_session, user=None, document_source=DocumentSource.GOOGLE_DRIVE
db_session=db_session, document_source=DocumentSource.GOOGLE_DRIVE
)
for credential in google_drive_credentials:
db_session.delete(credential)
@@ -432,7 +456,7 @@ def cleanup_google_drive_credentials(db_session: Session) -> None:
def delete_service_account_credentials(
user: User | None, db_session: Session, source: DocumentSource
) -> None:
credentials = fetch_credentials(db_session=db_session, user=user)
credentials = fetch_credentials_for_user(db_session=db_session, user=user)
for credential in credentials:
if (
credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)

View File

@@ -1,6 +1,7 @@
import contextlib
import time
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
@@ -13,6 +14,7 @@ from sqlalchemy import or_
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import tuple_
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine.util import TransactionalContext
from sqlalchemy.exc import OperationalError
@@ -107,7 +109,8 @@ def get_all_documents_needing_vespa_sync_for_cc_pair(
db_session: Session, cc_pair_id: int
) -> list[DbDocument]:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id, db_session=db_session
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
raise ValueError(f"No CC pair found with ID: {cc_pair_id}")
@@ -137,7 +140,8 @@ def get_documents_for_cc_pair(
cc_pair_id: int,
) -> list[DbDocument]:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id, db_session=db_session
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
raise ValueError(f"No CC pair found with ID: {cc_pair_id}")
@@ -224,10 +228,13 @@ def get_document_counts_for_cc_pairs(
func.count(),
)
.where(
tuple_(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
).in_(cc_ids)
and_(
tuple_(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
).in_(cc_ids),
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
)
)
.group_by(
DocumentByConnectorCredentialPair.connector_id,
@@ -380,18 +387,40 @@ def upsert_document_by_connector_credential_pair(
id=doc_id,
connector_id=connector_id,
credential_id=credential_id,
has_been_indexed=False,
)
)
for doc_id in document_ids
]
)
# for now, there are no columns to update. If more metadata is added, then this
# needs to change to an `on_conflict_do_update`
# this must be `on_conflict_do_nothing` rather than `on_conflict_do_update`
# since we don't want to update the `has_been_indexed` field for documents
# that already exist
on_conflict_stmt = insert_stmt.on_conflict_do_nothing()
db_session.execute(on_conflict_stmt)
db_session.commit()
def mark_document_as_indexed_for_cc_pair__no_commit(
db_session: Session,
connector_id: int,
credential_id: int,
document_ids: Iterable[str],
) -> None:
"""Should be called only after a successful index operation for a batch."""
db_session.execute(
update(DocumentByConnectorCredentialPair)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
DocumentByConnectorCredentialPair.id.in_(document_ids),
)
)
.values(has_been_indexed=True)
)
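Taken together, the two hunks above imply a two-phase flow: rows are upserted with has_been_indexed=False and only flipped to True once the batch has actually been written to the index. A hedged sketch of that sequencing (index_batch is a placeholder, and the positional argument order for the upsert is assumed):

doc_ids = [doc.id for doc in document_batch]
upsert_document_by_connector_credential_pair(
    db_session, connector_id, credential_id, doc_ids
)
index_batch(document_batch)  # placeholder for the actual index write
mark_document_as_indexed_for_cc_pair__no_commit(
    db_session=db_session,
    connector_id=connector_id,
    credential_id=credential_id,
    document_ids=doc_ids,
)
db_session.commit()  # the __no_commit helper leaves the commit to the caller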
def update_docs_updated_at__no_commit(
ids_to_new_updated_at: dict[str, datetime],
db_session: Session,

View File

@@ -12,6 +12,7 @@ from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.enums import AccessType
@@ -36,10 +37,11 @@ logger = setup_logger()
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
stmt = stmt.distinct()
DocumentSet__UG = aliased(DocumentSet__UserGroup)
User__UG = aliased(User__UserGroup)
"""
@@ -60,6 +62,12 @@ def _add_user_filters(
- if we are not editing, we show all DocumentSets in the groups the user is a curator
for (as well as public DocumentSets)
"""
# If user is None, this is an anonymous user and we should only show public DocumentSets
if user is None:
where_clause = DocumentSetDBModel.is_public == True # noqa: E712
return stmt.where(where_clause)
where_clause = User__UserGroup.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UserGroup.is_curator == True # noqa: E712
@@ -108,10 +116,10 @@ def delete_document_set_privacy__no_commit(
"""No private document sets in Onyx MIT"""
def get_document_set_by_id(
def get_document_set_by_id_for_user(
db_session: Session,
document_set_id: int,
user: User | None = None,
user: User | None,
get_editable: bool = True,
) -> DocumentSetDBModel | None:
stmt = select(DocumentSetDBModel).distinct()
@@ -120,6 +128,15 @@ def get_document_set_by_id(
return db_session.scalar(stmt)
def get_document_set_by_id(
db_session: Session,
document_set_id: int,
) -> DocumentSetDBModel | None:
stmt = select(DocumentSetDBModel).distinct()
stmt = stmt.where(DocumentSetDBModel.id == document_set_id)
return db_session.scalar(stmt)
def get_document_set_by_name(
db_session: Session, document_set_name: str
) -> DocumentSetDBModel | None:
@@ -210,6 +227,7 @@ def insert_document_set(
description=document_set_creation_request.description,
user_id=user_id,
is_public=document_set_creation_request.is_public,
time_last_modified_by_user=func.now(),
)
db_session.add(new_document_set_row)
db_session.flush() # ensure the new document set gets assigned an ID
@@ -266,7 +284,7 @@ def update_document_set(
try:
# update the description
document_set_row = get_document_set_by_id(
document_set_row = get_document_set_by_id_for_user(
db_session=db_session,
document_set_id=document_set_update_request.id,
user=user,
@@ -285,7 +303,7 @@ def update_document_set(
document_set_row.description = document_set_update_request.description
document_set_row.is_up_to_date = False
document_set_row.is_public = document_set_update_request.is_public
document_set_row.time_last_modified_by_user = func.now()
versioned_private_doc_set_fn = fetch_versioned_implementation(
"onyx.db.document_set", "make_doc_set_private"
)
@@ -357,7 +375,7 @@ def mark_document_set_as_to_be_deleted(
job which syncs these changes to Vespa."""
try:
document_set_row = get_document_set_by_id(
document_set_row = get_document_set_by_id_for_user(
db_session=db_session,
document_set_id=document_set_id,
user=user,
@@ -469,7 +487,7 @@ def fetch_document_sets(
def fetch_all_document_sets_for_user(
db_session: Session,
user: User | None = None,
user: User | None,
get_editable: bool = True,
) -> Sequence[DocumentSetDBModel]:
stmt = select(DocumentSetDBModel).distinct()

View File

@@ -240,8 +240,11 @@ class SqlEngine:
def get_all_tenant_ids() -> list[str] | list[None]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
if not MULTI_TENANT:
return [None]
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
result = session.execute(
text(
@@ -354,6 +357,26 @@ async def get_current_tenant_id(request: Request) -> str:
raise HTTPException(status_code=500, detail="Internal server error")
# Listen for events on the synchronous Session class
@event.listens_for(Session, "after_begin")
def _set_search_path(
session: Session, transaction: Any, connection: Any, *args: Any, **kwargs: Any
) -> None:
"""Every time a new transaction is started,
set the search_path from the session's info."""
tenant_id = session.info.get("tenant_id")
if tenant_id:
connection.exec_driver_sql(f'SET search_path = "{tenant_id}"')
engine = get_sqlalchemy_async_engine()
AsyncSessionLocal = sessionmaker( # type: ignore
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
)
@asynccontextmanager
async def get_async_session_with_tenant(
tenant_id: str | None = None,
@@ -363,41 +386,22 @@ async def get_async_session_with_tenant(
if not is_valid_schema_name(tenant_id):
logger.error(f"Invalid tenant ID: {tenant_id}")
raise Exception("Invalid tenant ID")
raise ValueError("Invalid tenant ID")
engine = get_sqlalchemy_async_engine()
async_session_factory = sessionmaker(
bind=engine, expire_on_commit=False, class_=AsyncSession
) # type: ignore
async with AsyncSessionLocal() as session:
session.sync_session.info["tenant_id"] = tenant_id
async def _set_search_path(session: AsyncSession, tenant_id: str) -> None:
await session.execute(text(f'SET search_path = "{tenant_id}"'))
async with async_session_factory() as session:
# Register an event listener that is called whenever a new transaction starts
@event.listens_for(session.sync_session, "after_begin")
def after_begin(session_: Any, transaction: Any, connection: Any) -> None:
# Because the event is sync, we can't directly await here.
# Instead we queue up an asyncio task to ensure
# the next statement sets the search_path
session_.do_orm_execute = lambda state: connection.exec_driver_sql(
f'SET search_path = "{tenant_id}"'
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
text(
f"SET idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
try:
await _set_search_path(session, tenant_id)
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
except Exception:
logger.exception("Error setting search_path.")
raise
else:
yield session
finally:
pass
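A hypothetical caller of the rewritten context manager above; "tenant_abc123" and the query are placeholders, and every transaction opened on the yielded session should run with its search_path pinned to that tenant schema:

from sqlalchemy import text

async def count_users_for_tenant() -> int:
    async with get_async_session_with_tenant("tenant_abc123") as session:
        # hypothetical query; the table name is illustrative
        result = await session.execute(text('SELECT COUNT(*) FROM "user"'))
        return result.scalar_one()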
@contextmanager

View File

@@ -24,12 +24,27 @@ class IndexingMode(str, PyEnum):
REINDEX = "reindex"
# these may differ in the future, which is why we're okay with this duplication
class DeletionStatus(str, PyEnum):
NOT_STARTED = "not_started"
class SyncType(str, PyEnum):
DOCUMENT_SET = "document_set"
USER_GROUP = "user_group"
CONNECTOR_DELETION = "connector_deletion"
def __str__(self) -> str:
return self.value
class SyncStatus(str, PyEnum):
IN_PROGRESS = "in_progress"
SUCCESS = "success"
FAILED = "failed"
CANCELED = "canceled"
def is_terminal(self) -> bool:
terminal_states = {
SyncStatus.SUCCESS,
SyncStatus.FAILED,
}
return self in terminal_states
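A small usage note: is_terminal lets polling code decide whether to keep waiting without enumerating statuses at each call site. A trivial, hypothetical example:

status = SyncStatus.IN_PROGRESS
if not status.is_terminal():
    print("sync still running, will check again later")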
# Consistent with Celery task statuses

View File

@@ -13,6 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.constants import MessageType
from onyx.configs.constants import SearchFeedbackType
from onyx.db.chat import get_chat_message
@@ -26,7 +27,6 @@ from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.models import UserRole
from onyx.document_index.interfaces import DocumentIndex
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -46,10 +46,11 @@ def _fetch_db_doc_by_id(doc_id: str, db_session: Session) -> DbDocument:
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
stmt = stmt.distinct()
DocByCC = aliased(DocumentByConnectorCredentialPair)
CCPair = aliased(ConnectorCredentialPair)
UG__CCpair = aliased(UserGroup__ConnectorCredentialPair)
@@ -83,6 +84,12 @@ def _add_user_filters(
- if we are not editing, we show all objects in the groups the user is a curator
for (as well as public objects)
"""
# If user is None, this is an anonymous user and we should only show public documents
if user is None:
where_clause = CCPair.access_type == AccessType.PUBLIC
return stmt.where(where_clause)
where_clause = User__UG.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UG.is_curator == True # noqa: E712
@@ -100,9 +107,9 @@ def _add_user_filters(
return stmt.where(where_clause)
def fetch_docs_ranked_by_boost(
def fetch_docs_ranked_by_boost_for_user(
db_session: Session,
user: User | None = None,
user: User | None,
ascending: bool = False,
limit: int = 100,
) -> list[DbDocument]:
@@ -121,11 +128,11 @@ def fetch_docs_ranked_by_boost(
return list(doc_list)
def update_document_boost(
def update_document_boost_for_user(
db_session: Session,
document_id: str,
boost: int,
user: User | None = None,
user: User | None,
) -> None:
stmt = select(DbDocument).where(DbDocument.id == document_id)
stmt = _add_user_filters(stmt, user, get_editable=True)
@@ -143,12 +150,11 @@ def update_document_boost(
db_session.commit()
def update_document_hidden(
def update_document_hidden_for_user(
db_session: Session,
document_id: str,
hidden: bool,
document_index: DocumentIndex,
user: User | None = None,
user: User | None,
) -> None:
stmt = select(DbDocument).where(DbDocument.id == document_id)
stmt = _add_user_filters(stmt, user, get_editable=True)
@@ -170,7 +176,6 @@ def create_doc_retrieval_feedback(
message_id: int,
document_id: str,
document_rank: int,
document_index: DocumentIndex,
db_session: Session,
clicked: bool = False,
feedback: SearchFeedbackType | None = None,

View File

@@ -9,7 +9,6 @@ from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.connectors.models import Document
@@ -118,21 +117,14 @@ def get_in_progress_index_attempts(
def get_all_index_attempts_by_status(
status: IndexingStatus, db_session: Session
) -> list[IndexAttempt]:
"""This eagerly loads the connector and credential so that the db_session can be expired
before running long-living indexing jobs, which causes increasing memory usage.
"""Returns index attempts with the given status.
Only recommended for non-terminal states, since the full list of attempts in
terminal states may be quite large.
Results are ordered by time_created (oldest to newest)."""
stmt = select(IndexAttempt)
stmt = stmt.where(IndexAttempt.status == status)
stmt = stmt.order_by(IndexAttempt.time_created)
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair).joinedload(
ConnectorCredentialPair.connector
),
joinedload(IndexAttempt.connector_credential_pair).joinedload(
ConnectorCredentialPair.credential
),
)
new_attempts = db_session.scalars(stmt)
return list(new_attempts.all())
@@ -190,13 +182,13 @@ def mark_attempt_in_progress(
def mark_attempt_succeeded(
index_attempt: IndexAttempt,
index_attempt_id: int,
db_session: Session,
) -> None:
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update()
).scalar_one()
@@ -208,13 +200,13 @@ def mark_attempt_succeeded(
def mark_attempt_partially_succeeded(
index_attempt: IndexAttempt,
index_attempt_id: int,
db_session: Session,
) -> None:
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update()
).scalar_one()
@@ -273,17 +265,26 @@ def mark_attempt_failed(
def update_docs_indexed(
db_session: Session,
index_attempt: IndexAttempt,
index_attempt_id: int,
total_docs_indexed: int,
new_docs_indexed: int,
docs_removed_from_index: int,
) -> None:
index_attempt.total_docs_indexed = total_docs_indexed
index_attempt.new_docs_indexed = new_docs_indexed
index_attempt.docs_removed_from_index = docs_removed_from_index
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update()
).scalar_one()
db_session.add(index_attempt)
db_session.commit()
attempt.total_docs_indexed = total_docs_indexed
attempt.new_docs_indexed = new_docs_indexed
attempt.docs_removed_from_index = docs_removed_from_index
db_session.commit()
except Exception:
db_session.rollback()
logger.exception("update_docs_indexed exceptioned.")
raise
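The refactor above stops mutating a detached IndexAttempt object and instead re-selects the row by id with with_for_update(), presumably so concurrent workers serialize on the row while updating progress counters. A condensed sketch of that locking pattern, reusing the names from the hunk (session setup assumed):

from sqlalchemy import select

attempt = db_session.execute(
    select(IndexAttempt)
    .where(IndexAttempt.id == index_attempt_id)
    .with_for_update()  # row lock held until commit/rollback
).scalar_one()
attempt.new_docs_indexed = new_docs_indexed
db_session.commit()  # releases the lock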
def get_last_attempt(

View File

@@ -0,0 +1,262 @@
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.constants import AuthType
from onyx.db.models import InputPrompt
from onyx.db.models import InputPrompt__User
from onyx.db.models import User
from onyx.server.features.input_prompt.models import InputPromptSnapshot
from onyx.server.manage.models import UserInfo
from onyx.utils.logger import setup_logger
logger = setup_logger()
def insert_input_prompt_if_not_exists(
user: User | None,
input_prompt_id: int | None,
prompt: str,
content: str,
active: bool,
is_public: bool,
db_session: Session,
commit: bool = True,
) -> InputPrompt:
if input_prompt_id is not None:
input_prompt = (
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
)
else:
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
if user:
query = query.filter(InputPrompt.user_id == user.id)
else:
query = query.filter(InputPrompt.user_id.is_(None))
input_prompt = query.first()
if input_prompt is None:
input_prompt = InputPrompt(
id=input_prompt_id,
prompt=prompt,
content=content,
active=active,
is_public=is_public or user is None,
user_id=user.id if user else None,
)
db_session.add(input_prompt)
if commit:
db_session.commit()
return input_prompt
def insert_input_prompt(
prompt: str,
content: str,
is_public: bool,
user: User | None,
db_session: Session,
) -> InputPrompt:
input_prompt = InputPrompt(
prompt=prompt,
content=content,
active=True,
is_public=is_public,
user_id=user.id if user is not None else None,
)
db_session.add(input_prompt)
db_session.commit()
return input_prompt
def update_input_prompt(
user: User | None,
input_prompt_id: int,
prompt: str,
content: str,
active: bool,
db_session: Session,
) -> InputPrompt:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You don't own this prompt")
input_prompt.prompt = prompt
input_prompt.content = content
input_prompt.active = active
db_session.commit()
return input_prompt
def validate_user_prompt_authorization(
user: User | None, input_prompt: InputPrompt
) -> bool:
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)
if prompt.user_id is not None:
if user is None:
return False
user_details = UserInfo.from_model(user)
if str(user_details.id) != str(prompt.user_id):
return False
return True
def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if not input_prompt.is_public:
raise HTTPException(status_code=400, detail="This prompt is not public")
db_session.delete(input_prompt)
db_session.commit()
def remove_input_prompt(
user: User | None,
input_prompt_id: int,
db_session: Session,
delete_public: bool = False,
) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if input_prompt.is_public and not delete_public:
raise HTTPException(
status_code=400, detail="Cannot delete public prompts with this method"
)
if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You do not own this prompt")
db_session.delete(input_prompt)
db_session.commit()
def fetch_input_prompt_by_id(
id: int, user_id: UUID | None, db_session: Session
) -> InputPrompt:
query = select(InputPrompt).where(InputPrompt.id == id)
if user_id:
query = query.where(
            (InputPrompt.user_id == user_id) | (InputPrompt.user_id.is_(None))
)
else:
# If no user_id is provided, only fetch prompts without a user_id (aka public)
query = query.where(InputPrompt.user_id == None) # noqa
result = db_session.scalar(query)
if result is None:
raise HTTPException(422, "No input prompt found")
return result
def fetch_public_input_prompts(
db_session: Session,
) -> list[InputPrompt]:
query = select(InputPrompt).where(InputPrompt.is_public)
return list(db_session.scalars(query).all())
def fetch_input_prompts_by_user(
db_session: Session,
user_id: UUID | None,
active: bool | None = None,
include_public: bool = False,
) -> list[InputPrompt]:
"""
Returns all prompts belonging to the user or public prompts,
excluding those the user has specifically disabled.
"""
# Start with a basic query for InputPrompt
query = select(InputPrompt)
# If we have a user, left join to InputPrompt__User so we can check "disabled"
if user_id is not None:
IPU = aliased(InputPrompt__User)
query = query.join(
IPU,
(IPU.input_prompt_id == InputPrompt.id) & (IPU.user_id == user_id),
isouter=True,
)
# Exclude disabled prompts
# i.e. keep only those where (IPU.disabled is NULL or False)
query = query.where(or_(IPU.disabled.is_(None), IPU.disabled.is_(False)))
if include_public:
# user-owned or public
query = query.where(
(InputPrompt.user_id == user_id) | (InputPrompt.is_public)
)
else:
# only user-owned prompts
query = query.where(InputPrompt.user_id == user_id)
# If no user is logged in, get all prompts (public and private)
if user_id is None and AUTH_TYPE == AuthType.DISABLED:
query = query.where(True) # type: ignore
# If no user is logged in but we want to include public prompts
elif include_public:
query = query.where(InputPrompt.is_public)
if active is not None:
query = query.where(InputPrompt.active == active)
return list(db_session.scalars(query).all())
def disable_input_prompt_for_user(
input_prompt_id: int,
user_id: UUID,
db_session: Session,
) -> None:
"""
Sets (or creates) a record in InputPrompt__User with disabled=True
so that this prompt is hidden for the user.
"""
ipu = (
db_session.query(InputPrompt__User)
.filter_by(input_prompt_id=input_prompt_id, user_id=user_id)
.first()
)
if ipu is None:
# Create a new association row
ipu = InputPrompt__User(
input_prompt_id=input_prompt_id, user_id=user_id, disabled=True
)
db_session.add(ipu)
else:
# Just update the existing record
ipu.disabled = True
db_session.commit()
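Taken together, the helpers in this new module cover the input-prompt lifecycle: create, list (user-owned plus public, minus the ones the user has disabled), and per-user hiding. A usage sketch under the assumption that the module lives at onyx.db.input_prompt (the file path is not shown in this diff) and with a hypothetical wrapper function:

from sqlalchemy.orm import Session

from onyx.db.input_prompt import (  # module path assumed
    disable_input_prompt_for_user,
    fetch_input_prompts_by_user,
    insert_input_prompt,
)
from onyx.db.models import User


def _demo_input_prompt_flow(db_session: Session, user: User) -> None:
    # Create a private prompt owned by `user`
    prompt = insert_input_prompt(
        prompt="Summarize",
        content="Summarize the following text:",
        is_public=False,
        user=user,
        db_session=db_session,
    )

    # List the user's active prompts plus public ones, excluding any they disabled
    visible = fetch_input_prompts_by_user(
        db_session=db_session,
        user_id=user.id,
        active=True,
        include_public=True,
    )
    assert prompt in visible

    # Hide one prompt for this user only (sets disabled=True on InputPrompt__User)
    disable_input_prompt_for_user(
        input_prompt_id=prompt.id, user_id=user.id, db_session=db_session
    )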


@@ -18,6 +18,7 @@ from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTa
from fastapi_users_db_sqlalchemy.generics import TIMESTAMPAware
from sqlalchemy import Boolean
from sqlalchemy import DateTime
from sqlalchemy import desc
from sqlalchemy import Enum
from sqlalchemy import Float
from sqlalchemy import ForeignKey
@@ -43,7 +44,7 @@ from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.db.enums import AccessType, IndexingMode
from onyx.db.enums import AccessType, IndexingMode, SyncType, SyncStatus
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.constants import TokenRateLimitScope
@@ -150,6 +151,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
@@ -162,6 +164,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
recent_assistants: Mapped[list[dict]] = mapped_column(
postgresql.JSONB(), nullable=False, default=list, server_default="[]"
)
pinned_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
@@ -183,7 +188,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
input_prompts: Mapped[list["InputPrompt"]] = relationship(
"InputPrompt", back_populates="user"
)
# Personas owned by this user
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
# Custom tools created by this user
@@ -762,7 +769,7 @@ class IndexAttempt(Base):
# the run once API
from_beginning: Mapped[bool] = mapped_column(Boolean)
status: Mapped[IndexingStatus] = mapped_column(
Enum(IndexingStatus, native_enum=False)
Enum(IndexingStatus, native_enum=False, index=True)
)
# The two below may be slightly out of sync if user switches Embedding Model
new_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0)
@@ -781,6 +788,7 @@ class IndexAttempt(Base):
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
index=True,
)
# when the actual indexing run began
# NOTE: will use the api_server clock rather than DB server clock
@@ -813,6 +821,13 @@ class IndexAttempt(Base):
"connector_credential_pair_id",
"time_created",
),
Index(
"ix_index_attempt_ccpair_search_settings_time_updated",
"connector_credential_pair_id",
"search_settings_id",
desc("time_updated"),
unique=False,
),
)
def __repr__(self) -> str:
@@ -872,6 +887,46 @@ class IndexAttemptError(Base):
)
class SyncRecord(Base):
"""
Represents the status of a "sync" operation (e.g. document set, user group, deletion).
A "sync" operation is an operation which needs to update a set of documents within
Vespa, usually to match the state of Postgres.
"""
__tablename__ = "sync_record"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
# document set id, user group id, or deletion id
entity_id: Mapped[int] = mapped_column(Integer)
sync_type: Mapped[SyncType] = mapped_column(Enum(SyncType, native_enum=False))
sync_status: Mapped[SyncStatus] = mapped_column(Enum(SyncStatus, native_enum=False))
num_docs_synced: Mapped[int] = mapped_column(Integer, default=0)
sync_start_time: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
sync_end_time: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
__table_args__ = (
Index(
"ix_sync_record_entity_id_sync_type_sync_start_time",
"entity_id",
"sync_type",
"sync_start_time",
),
Index(
"ix_sync_record_entity_id_sync_type_sync_status",
"entity_id",
"sync_type",
"sync_status",
),
)
class DocumentByConnectorCredentialPair(Base):
"""Represents an indexing of a document by a specific connector / credential pair"""
@@ -886,6 +941,12 @@ class DocumentByConnectorCredentialPair(Base):
ForeignKey("credential.id"), primary_key=True
)
# used to better keep track of document counts at a connector level
# e.g. if a document is added as part of permission syncing, it should
# not be counted as part of the connector's document count until
# the actual indexing is complete
has_been_indexed: Mapped[bool] = mapped_column(Boolean)
connector: Mapped[Connector] = relationship(
"Connector", back_populates="documents_by_connector"
)
@@ -900,6 +961,14 @@ class DocumentByConnectorCredentialPair(Base):
"credential_id",
unique=False,
),
# Index to optimize get_document_counts_for_cc_pairs query pattern
Index(
"idx_document_cc_pair_counts",
"connector_id",
"credential_id",
"has_been_indexed",
unique=False,
),
)
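The new has_been_indexed column and the idx_document_cc_pair_counts index above exist so that per-connector document counts only include documents whose indexing has actually completed; rows created solely by permission syncing are excluded. A sketch of the kind of count query this supports (the function name is illustrative, not the PR's actual query):

from sqlalchemy import func, select
from sqlalchemy.orm import Session

from onyx.db.models import DocumentByConnectorCredentialPair


def count_indexed_docs_for_cc_pair(
    db_session: Session, connector_id: int, credential_id: int
) -> int:
    # Only rows marked has_been_indexed count toward the connector's total
    stmt = (
        select(func.count())
        .select_from(DocumentByConnectorCredentialPair)
        .where(
            DocumentByConnectorCredentialPair.connector_id == connector_id,
            DocumentByConnectorCredentialPair.credential_id == credential_id,
            DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
        )
    )
    return db_session.execute(stmt).scalar_one()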
@@ -1275,6 +1344,11 @@ class DocumentSet(Base):
# given access to it either via the `users` or `groups` relationships
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
# Last time a user updated this document set
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
connector_credential_pairs: Mapped[list[ConnectorCredentialPair]] = relationship(
"ConnectorCredentialPair",
secondary=DocumentSet__ConnectorCredentialPair.__table__,
@@ -1375,8 +1449,17 @@ class StarterMessage(TypedDict):
class StarterMessageModel(BaseModel):
name: str
message: str
name: str
class Persona__PersonaLabel(Base):
__tablename__ = "persona__persona_label"
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
persona_label_id: Mapped[int] = mapped_column(
ForeignKey("persona_label.id", ondelete="CASCADE"), primary_key=True
)
class Persona(Base):
@@ -1401,9 +1484,7 @@ class Persona(Base):
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
Enum(RecencyBiasSetting, native_enum=False)
)
category_id: Mapped[int | None] = mapped_column(
ForeignKey("persona_category.id"), nullable=True
)
# Allows the Persona to specify a different LLM version than is controlled
    # globally via env variables. For flexibility, validity is not currently enforced
# NOTE: only is applied on the actual response generation - is not used for things like
@@ -1475,10 +1556,11 @@ class Persona(Base):
secondary="persona__user_group",
viewonly=True,
)
category: Mapped["PersonaCategory"] = relationship(
"PersonaCategory", back_populates="personas"
labels: Mapped[list["PersonaLabel"]] = relationship(
"PersonaLabel",
secondary=Persona__PersonaLabel.__table__,
back_populates="personas",
)
# Default personas loaded via yaml cannot have the same name
__table_args__ = (
Index(
@@ -1490,14 +1572,17 @@ class Persona(Base):
)
class PersonaCategory(Base):
__tablename__ = "persona_category"
class PersonaLabel(Base):
__tablename__ = "persona_label"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
description: Mapped[str | None] = mapped_column(String, nullable=True)
personas: Mapped[list["Persona"]] = relationship(
"Persona", back_populates="category"
"Persona",
secondary=Persona__PersonaLabel.__table__,
back_populates="labels",
cascade="all, delete-orphan",
single_parent=True,
)
@@ -1754,6 +1839,11 @@ class UserGroup(Base):
Boolean, nullable=False, default=False
)
# Last time a user updated this user group
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
users: Mapped[list[User]] = relationship(
"User",
secondary=User__UserGroup.__table__,
@@ -1915,6 +2005,32 @@ class UsageReport(Base):
file = relationship("PGFileStore")
class InputPrompt(Base):
__tablename__ = "inputprompt"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
prompt: Mapped[str] = mapped_column(String)
content: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
class InputPrompt__User(Base):
__tablename__ = "inputprompt__user"
input_prompt_id: Mapped[int] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
disabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
"""
Multi-tenancy related tables
"""


@@ -1,6 +1,5 @@
from collections.abc import Sequence
from datetime import datetime
from functools import lru_cache
from uuid import UUID
from fastapi import HTTPException
@@ -8,7 +7,6 @@ from sqlalchemy import delete
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
@@ -17,25 +15,25 @@ from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.chat_configs import BING_API_KEY
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import DocumentSet
from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import Persona__UserGroup
from onyx.db.models import PersonaCategory
from onyx.db.models import PersonaLabel
from onyx.db.models import Prompt
from onyx.db.models import StarterMessage
from onyx.db.models import Tool
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.server.features.persona.models import CreatePersonaRequest
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -45,10 +43,11 @@ logger = setup_logger()
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
# If user is None, assume the user is an admin or auth is disabled
if user is None or user.role == UserRole.ADMIN:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
stmt = stmt.distinct()
Persona__UG = aliased(Persona__UserGroup)
User__UG = aliased(User__UserGroup)
"""
@@ -77,6 +76,12 @@ def _add_user_filters(
for (as well as public Personas)
- if we are not editing, we return all Personas directly connected to the user
"""
# If user is None, this is an anonymous user and we should only show public Personas
if user is None:
where_clause = Persona.is_public == True # noqa: E712
return stmt.where(where_clause)
where_clause = User__UserGroup.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UserGroup.is_curator == True # noqa: E712
@@ -99,10 +104,7 @@ def _add_user_filters(
return stmt.where(where_clause)
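With this change _add_user_filters distinguishes three cases: auth disabled (or an admin user) sees every Persona, an anonymous user with auth enabled sees only public Personas, and everyone else sees Personas reachable through their groups or direct shares. A small illustrative sketch of applying the helper inside this module (the wrapper function name is hypothetical):

from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.db.models import Persona, User


def _personas_visible_to(db_session: Session, user: User | None) -> list[Persona]:
    stmt = select(Persona).where(Persona.deleted.is_(False))
    # user=None + auth enabled  -> public Personas only
    # user=None + DISABLE_AUTH  -> no filtering (treated as admin)
    # admin user                -> no filtering
    stmt = _add_user_filters(stmt, user, get_editable=False)
    return list(db_session.execute(stmt).unique().scalars().all())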
# fetch_persona_by_id is used to fetch a persona by its ID.
def fetch_persona_by_id(
def fetch_persona_by_id_for_user(
db_session: Session, persona_id: int, user: User | None, get_editable: bool = True
) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id).distinct()
@@ -176,7 +178,7 @@ def make_persona_private(
def create_update_persona(
persona_id: int | None,
create_persona_request: CreatePersonaRequest,
create_persona_request: PersonaUpsertRequest,
user: User | None,
db_session: Session,
) -> PersonaSnapshot:
@@ -184,14 +186,36 @@ def create_update_persona(
# Permission to actually use these is checked later
try:
persona_data = {
"persona_id": persona_id,
"user": user,
"db_session": db_session,
**create_persona_request.model_dump(exclude={"users", "groups"}),
}
all_prompt_ids = create_persona_request.prompt_ids
persona = upsert_persona(**persona_data)
if not all_prompt_ids:
raise ValueError("No prompt IDs provided")
persona = upsert_persona(
persona_id=persona_id,
user=user,
db_session=db_session,
description=create_persona_request.description,
name=create_persona_request.name,
prompt_ids=all_prompt_ids,
document_set_ids=create_persona_request.document_set_ids,
tool_ids=create_persona_request.tool_ids,
is_public=create_persona_request.is_public,
recency_bias=create_persona_request.recency_bias,
llm_model_provider_override=create_persona_request.llm_model_provider_override,
llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
icon_color=create_persona_request.icon_color,
icon_shape=create_persona_request.icon_shape,
uploaded_image_id=create_persona_request.uploaded_image_id,
display_priority=create_persona_request.display_priority,
remove_image=create_persona_request.remove_image,
search_start_date=create_persona_request.search_start_date,
label_ids=create_persona_request.label_ids,
num_chunks=create_persona_request.num_chunks,
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
)
versioned_make_persona_private = fetch_versioned_implementation(
"onyx.db.persona", "make_persona_private"
@@ -221,7 +245,7 @@ def update_persona_shared_users(
"""Simplified version of `create_update_persona` which only touches the
accessibility rather than any of the logic (e.g. prompt, connected data sources,
etc.)."""
persona = fetch_persona_by_id(
persona = fetch_persona_by_id_for_user(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
@@ -247,7 +271,7 @@ def update_persona_public_status(
db_session: Session,
user: User | None,
) -> None:
persona = fetch_persona_by_id(
persona = fetch_persona_by_id_for_user(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
@@ -257,25 +281,7 @@ def update_persona_public_status(
db_session.commit()
def get_prompts(
user_id: UUID | None,
db_session: Session,
include_default: bool = True,
include_deleted: bool = False,
) -> Sequence[Prompt]:
stmt = select(Prompt).where(
or_(Prompt.user_id == user_id, Prompt.user_id.is_(None))
)
if not include_default:
stmt = stmt.where(Prompt.default_prompt.is_(False))
if not include_deleted:
stmt = stmt.where(Prompt.deleted.is_(False))
return db_session.scalars(stmt).all()
def get_personas(
def get_personas_for_user(
# if user is `None` assume the user is an admin or auth is disabled
user: User | None,
db_session: Session,
@@ -306,6 +312,13 @@ def get_personas(
return db_session.execute(stmt).unique().scalars().all()
def get_personas(db_session: Session) -> Sequence[Persona]:
stmt = select(Persona).distinct()
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
stmt = stmt.where(Persona.deleted.is_(False))
return db_session.execute(stmt).unique().scalars().all()
def mark_persona_as_deleted(
persona_id: int,
user: User | None,
@@ -349,7 +362,7 @@ def update_all_personas_display_priority(
db_session: Session,
) -> None:
"""Updates the display priority of all lives Personas"""
personas = get_personas(user=None, db_session=db_session)
personas = get_personas(db_session=db_session)
available_persona_ids = {persona.id for persona in personas}
if available_persona_ids != set(display_priority_map.keys()):
raise ValueError("Invalid persona IDs provided")
@@ -359,65 +372,6 @@ def update_all_personas_display_priority(
db_session.commit()
def upsert_prompt(
user: User | None,
name: str,
description: str,
system_prompt: str,
task_prompt: str,
include_citations: bool,
datetime_aware: bool,
personas: list[Persona] | None,
db_session: Session,
prompt_id: int | None = None,
default_prompt: bool = True,
commit: bool = True,
) -> Prompt:
if prompt_id is not None:
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
else:
prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session)
if prompt:
if not default_prompt and prompt.default_prompt:
raise ValueError("Cannot update default prompt with non-default.")
prompt.name = name
prompt.description = description
prompt.system_prompt = system_prompt
prompt.task_prompt = task_prompt
prompt.include_citations = include_citations
prompt.datetime_aware = datetime_aware
prompt.default_prompt = default_prompt
if personas is not None:
prompt.personas.clear()
prompt.personas = personas
else:
prompt = Prompt(
id=prompt_id,
user_id=user.id if user else None,
name=name,
description=description,
system_prompt=system_prompt,
task_prompt=task_prompt,
include_citations=include_citations,
datetime_aware=datetime_aware,
default_prompt=default_prompt,
personas=personas or [],
)
db_session.add(prompt)
if commit:
db_session.commit()
else:
# Flush the session so that the Prompt has an ID
db_session.flush()
return prompt
def upsert_persona(
user: User | None,
name: str,
@@ -445,7 +399,7 @@ def upsert_persona(
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool = False,
category_id: int | None = None,
label_ids: list[int] | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona:
@@ -462,6 +416,15 @@ def upsert_persona(
persona_name=name, user=user, db_session=db_session
)
if existing_persona:
# this checks if the user has permission to edit the persona
# will raise an Exception if the user does not have permission
existing_persona = fetch_persona_by_id_for_user(
db_session=db_session,
persona_id=existing_persona.id,
user=user,
get_editable=True,
)
# Fetch and attach tools by IDs
tools = None
if tool_ids is not None:
@@ -491,6 +454,12 @@ def upsert_persona(
f"specified. Specified IDs were: '{prompt_ids}'"
)
labels = None
if label_ids is not None:
labels = (
db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
)
# ensure all specified tools are valid
if tools:
validate_persona_tools(tools)
@@ -501,15 +470,6 @@ def upsert_persona(
if existing_persona.builtin_persona and not builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")
# this checks if the user has permission to edit the persona
# will raise an Exception if the user does not have permission
existing_persona = fetch_persona_by_id(
db_session=db_session,
persona_id=existing_persona.id,
user=user,
get_editable=True,
)
# The following update excludes `default`, `built-in`, and display priority.
# Display priority is handled separately in the `display-priority` endpoint.
# `default` and `built-in` properties can only be set when creating a persona.
@@ -532,7 +492,7 @@ def upsert_persona(
existing_persona.uploaded_image_id = uploaded_image_id
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.category_id = category_id
existing_persona.labels = labels or []
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
@@ -585,7 +545,7 @@ def upsert_persona(
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=is_default_persona,
category_id=category_id,
labels=labels or [],
)
db_session.add(new_persona)
persona = new_persona
@@ -598,16 +558,6 @@ def upsert_persona(
return persona
def mark_prompt_as_deleted(
prompt_id: int,
user: User | None,
db_session: Session,
) -> None:
prompt = get_prompt_by_id(prompt_id=prompt_id, user=user, db_session=db_session)
prompt.deleted = True
db_session.commit()
def delete_old_default_personas(
db_session: Session,
) -> None:
@@ -629,7 +579,7 @@ def update_persona_visibility(
db_session: Session,
user: User | None = None,
) -> None:
persona = fetch_persona_by_id(
persona = fetch_persona_by_id_for_user(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
@@ -645,69 +595,6 @@ def validate_persona_tools(tools: list[Tool]) -> None:
)
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
).all()
return list(prompts)
def get_prompt_by_id(
prompt_id: int,
user: User | None,
db_session: Session,
include_deleted: bool = False,
) -> Prompt:
stmt = select(Prompt).where(Prompt.id == prompt_id)
# if user is not specified OR they are an admin, they should
# have access to all prompts, so this where clause is not needed
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(or_(Prompt.user_id == user.id, Prompt.user_id.is_(None)))
if not include_deleted:
stmt = stmt.where(Prompt.deleted.is_(False))
result = db_session.execute(stmt)
prompt = result.scalar_one_or_none()
if prompt is None:
raise ValueError(
f"Prompt with ID {prompt_id} does not exist or does not belong to user"
)
return prompt
def _get_default_prompt(db_session: Session) -> Prompt:
stmt = select(Prompt).where(Prompt.id == 0)
result = db_session.execute(stmt)
prompt = result.scalar_one_or_none()
if prompt is None:
raise RuntimeError("Default Prompt not found")
return prompt
def get_default_prompt(db_session: Session) -> Prompt:
return _get_default_prompt(db_session)
@lru_cache()
def get_default_prompt__read_only() -> Prompt:
"""Due to the way lru_cache / SQLAlchemy works, this can cause issues
when trying to attach the returned `Prompt` object to a `Persona`. If you are
doing anything other than reading, you should use the `get_default_prompt`
method instead."""
with Session(get_sqlalchemy_engine()) as db_session:
return _get_default_prompt(db_session)
# TODO: since this gets called with every chat message, could it be more efficient to pregenerate
# a direct mapping indicating whether a user has access to a specific persona?
def get_persona_by_id(
@@ -779,22 +666,6 @@ def get_personas_by_ids(
return personas
def get_prompt_by_name(
prompt_name: str, user: User | None, db_session: Session
) -> Prompt | None:
stmt = select(Prompt).where(Prompt.name == prompt_name)
# if user is not specified OR they are an admin, they should
# have access to all prompts, so this where clause is not needed
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)
# Order by ID to ensure consistent result when multiple prompts exist
stmt = stmt.order_by(Prompt.id).limit(1)
result = db_session.execute(stmt).scalar_one_or_none()
return result
def delete_persona_by_name(
persona_name: str, db_session: Session, is_default: bool = True
) -> None:
@@ -806,37 +677,31 @@ def delete_persona_by_name(
db_session.commit()
def get_assistant_categories(db_session: Session) -> list[PersonaCategory]:
return db_session.query(PersonaCategory).all()
def get_assistant_labels(db_session: Session) -> list[PersonaLabel]:
return db_session.query(PersonaLabel).all()
def create_assistant_category(
db_session: Session, name: str, description: str
) -> PersonaCategory:
category = PersonaCategory(name=name, description=description)
db_session.add(category)
def create_assistant_label(db_session: Session, name: str) -> PersonaLabel:
label = PersonaLabel(name=name)
db_session.add(label)
db_session.commit()
return category
return label
def update_persona_category(
category_id: int,
category_description: str,
category_name: str,
def update_persona_label(
label_id: int,
label_name: str,
db_session: Session,
) -> None:
persona_category = (
db_session.query(PersonaCategory)
.filter(PersonaCategory.id == category_id)
.one_or_none()
persona_label = (
db_session.query(PersonaLabel).filter(PersonaLabel.id == label_id).one_or_none()
)
if persona_category is None:
raise ValueError(f"Persona category with ID {category_id} does not exist")
persona_category.description = category_description
persona_category.name = category_name
if persona_label is None:
raise ValueError(f"Persona label with ID {label_id} does not exist")
persona_label.name = label_name
db_session.commit()
def delete_persona_category(category_id: int, db_session: Session) -> None:
db_session.query(PersonaCategory).filter(PersonaCategory.id == category_id).delete()
def delete_persona_label(label_id: int, db_session: Session) -> None:
db_session.query(PersonaLabel).filter(PersonaLabel.id == label_id).delete()
db_session.commit()
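The category-to-label rename also renames the CRUD helpers shown above. A short usage sketch of the new names (the demo function itself is hypothetical):

from sqlalchemy.orm import Session

from onyx.db.persona import (
    create_assistant_label,
    delete_persona_label,
    update_persona_label,
)


def _demo_label_lifecycle(db_session: Session) -> None:
    label = create_assistant_label(db_session, name="Engineering")
    update_persona_label(
        label_id=label.id, label_name="Engineering Team", db_session=db_session
    )
    # Personas attach labels through upsert_persona(label_ids=[label.id], ...)
    delete_persona_label(label_id=label.id, db_session=db_session)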

backend/onyx/db/prompts.py (new file, 119 lines)

@@ -0,0 +1,119 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.db.models import Persona
from onyx.db.models import Prompt
from onyx.db.models import User
from onyx.utils.logger import setup_logger
# Note: As prompts are fairly innocuous/harmless, there are no protections
# to prevent users from messing with prompts of other users.
logger = setup_logger()
def _get_default_prompt(db_session: Session) -> Prompt:
stmt = select(Prompt).where(Prompt.id == 0)
result = db_session.execute(stmt)
prompt = result.scalar_one_or_none()
if prompt is None:
raise RuntimeError("Default Prompt not found")
return prompt
def get_default_prompt(db_session: Session) -> Prompt:
return _get_default_prompt(db_session)
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
).all()
return list(prompts)
def get_prompt_by_name(
prompt_name: str, user: User | None, db_session: Session
) -> Prompt | None:
stmt = select(Prompt).where(Prompt.name == prompt_name)
# if user is not specified OR they are an admin, they should
# have access to all prompts, so this where clause is not needed
if user and user.role != UserRole.ADMIN:
stmt = stmt.where(Prompt.user_id == user.id)
# Order by ID to ensure consistent result when multiple prompts exist
stmt = stmt.order_by(Prompt.id).limit(1)
result = db_session.execute(stmt).scalar_one_or_none()
return result
def build_prompt_name_from_persona_name(persona_name: str) -> str:
return f"default-prompt__{persona_name}"
def upsert_prompt(
db_session: Session,
user: User | None,
name: str,
system_prompt: str,
task_prompt: str,
datetime_aware: bool,
prompt_id: int | None = None,
personas: list[Persona] | None = None,
include_citations: bool = False,
default_prompt: bool = True,
# Support backwards compatibility
description: str | None = None,
) -> Prompt:
if description is None:
description = f"Default prompt for {name}"
if prompt_id is not None:
prompt = db_session.query(Prompt).filter_by(id=prompt_id).first()
else:
prompt = get_prompt_by_name(prompt_name=name, user=user, db_session=db_session)
if prompt:
if not default_prompt and prompt.default_prompt:
raise ValueError("Cannot update default prompt with non-default.")
prompt.name = name
prompt.description = description
prompt.system_prompt = system_prompt
prompt.task_prompt = task_prompt
prompt.include_citations = include_citations
prompt.datetime_aware = datetime_aware
prompt.default_prompt = default_prompt
if personas is not None:
prompt.personas.clear()
prompt.personas = personas
else:
prompt = Prompt(
id=prompt_id,
user_id=user.id if user else None,
name=name,
description=description,
system_prompt=system_prompt,
task_prompt=task_prompt,
include_citations=include_citations,
datetime_aware=datetime_aware,
default_prompt=default_prompt,
personas=personas or [],
)
db_session.add(prompt)
# Flush the session so that the Prompt has an ID
db_session.flush()
return prompt
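Unlike the version removed from persona.py, this upsert_prompt only flushes (so the returned Prompt has an ID) and leaves committing to the caller. A sketch of a caller using it together with build_prompt_name_from_persona_name; the wrapper function and the prompt text are illustrative:

from sqlalchemy.orm import Session

from onyx.db.models import Prompt
from onyx.db.prompts import (
    build_prompt_name_from_persona_name,
    upsert_prompt,
)


def create_default_prompt_for_persona(
    db_session: Session, persona_name: str
) -> Prompt:
    prompt = upsert_prompt(
        db_session=db_session,
        user=None,
        name=build_prompt_name_from_persona_name(persona_name),
        system_prompt="You are a helpful assistant.",  # placeholder prompt text
        task_prompt="",
        datetime_aware=True,
    )
    # upsert_prompt flushes (so prompt.id is set) but does not commit;
    # the caller owns the transaction boundary
    db_session.commit()
    return prompt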


@@ -12,9 +12,9 @@ from onyx.db.models import Persona
from onyx.db.models import Persona__DocumentSet
from onyx.db.models import SlackChannelConfig
from onyx.db.models import User
from onyx.db.persona import get_default_prompt
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import upsert_persona
from onyx.db.prompts import get_default_prompt
from onyx.utils.errors import EERequiredError
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,


@@ -0,0 +1,110 @@
from sqlalchemy import and_
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import SyncRecord
def insert_sync_record(
db_session: Session,
entity_id: int | None,
sync_type: SyncType,
) -> SyncRecord:
"""Insert a new sync record into the database.
Args:
db_session: The database session to use
entity_id: The ID of the entity being synced (document set ID, user group ID, etc.)
sync_type: The type of sync operation
"""
sync_record = SyncRecord(
entity_id=entity_id,
sync_type=sync_type,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=0,
sync_start_time=func.now(),
)
db_session.add(sync_record)
db_session.commit()
return sync_record
def fetch_latest_sync_record(
db_session: Session,
entity_id: int,
sync_type: SyncType,
) -> SyncRecord | None:
"""Fetch the most recent sync record for a given entity ID and status.
Args:
db_session: The database session to use
entity_id: The ID of the entity to fetch sync record for
sync_type: The type of sync operation
"""
stmt = (
select(SyncRecord)
.where(
and_(
SyncRecord.entity_id == entity_id,
SyncRecord.sync_type == sync_type,
)
)
.order_by(desc(SyncRecord.sync_start_time))
.limit(1)
)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
def update_sync_record_status(
db_session: Session,
entity_id: int,
sync_type: SyncType,
sync_status: SyncStatus,
num_docs_synced: int | None = None,
) -> None:
"""Update the status of a sync record.
Args:
db_session: The database session to use
entity_id: The ID of the entity being synced
sync_type: The type of sync operation
sync_status: The new status to set
num_docs_synced: Optional number of documents synced to update
"""
sync_record = fetch_latest_sync_record(db_session, entity_id, sync_type)
if sync_record is None:
raise ValueError(
f"No sync record found for entity_id={entity_id} sync_type={sync_type}"
)
sync_record.sync_status = sync_status
if num_docs_synced is not None:
sync_record.num_docs_synced = num_docs_synced
if sync_status.is_terminal():
sync_record.sync_end_time = func.now() # type: ignore
db_session.commit()
def cleanup_sync_records(
db_session: Session, entity_id: int, sync_type: SyncType
) -> None:
"""Cleanup sync records for a given entity ID and sync type by marking them as failed."""
stmt = (
update(SyncRecord)
.where(SyncRecord.entity_id == entity_id)
.where(SyncRecord.sync_type == sync_type)
.where(SyncRecord.sync_status == SyncStatus.IN_PROGRESS)
.values(sync_status=SyncStatus.CANCELED, sync_end_time=func.now())
)
db_session.execute(stmt)
db_session.commit()
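The three helpers above describe the full SyncRecord lifecycle: clean up any record left IN_PROGRESS, insert a fresh record, and mark it terminal with a document count. A hedged sketch of that flow; the module path and the SyncType/SyncStatus member names used here are assumptions for illustration (only IN_PROGRESS and CANCELED appear in this diff):

from sqlalchemy.orm import Session

from onyx.db.enums import SyncStatus, SyncType
from onyx.db.sync_record import (  # module path assumed
    cleanup_sync_records,
    insert_sync_record,
    update_sync_record_status,
)


def run_document_set_sync(db_session: Session, document_set_id: int) -> None:
    # Mark any sync left IN_PROGRESS (e.g. by a crashed worker) as CANCELED
    cleanup_sync_records(db_session, document_set_id, SyncType.DOCUMENT_SET)

    insert_sync_record(db_session, document_set_id, SyncType.DOCUMENT_SET)
    try:
        num_docs = 0  # stand-in for the number of documents actually synced
        update_sync_record_status(
            db_session,
            document_set_id,
            SyncType.DOCUMENT_SET,
            SyncStatus.SUCCESS,
            num_docs_synced=num_docs,
        )
    except Exception:
        update_sync_record_status(
            db_session, document_set_id, SyncType.DOCUMENT_SET, SyncStatus.FAILED
        )
        raise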


@@ -21,6 +21,7 @@ from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.document import fetch_chunk_counts_for_documents
from onyx.db.document import get_documents_by_ids
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
from onyx.db.document import prepare_to_modify_documents
from onyx.db.document import update_docs_chunk_count__no_commit
from onyx.db.document import update_docs_last_modified__no_commit
@@ -55,12 +56,23 @@ class DocumentBatchPrepareContext(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class IndexingPipelineResult(BaseModel):
# number of documents that are completely new (e.g. did
# not exist as a part of this OR any other connector)
new_docs: int
# NOTE: need total_docs, since the pipeline can skip some docs
# (e.g. not even insert them into Postgres)
total_docs: int
# number of chunks that were inserted into Vespa
total_chunks: int
class IndexingPipelineProtocol(Protocol):
def __call__(
self,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
) -> tuple[int, int]:
) -> IndexingPipelineResult:
...
@@ -147,10 +159,12 @@ def index_doc_batch_with_handler(
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> tuple[int, int]:
r = (0, 0)
) -> IndexingPipelineResult:
index_pipeline_result = IndexingPipelineResult(
new_docs=0, total_docs=len(document_batch), total_chunks=0
)
try:
r = index_doc_batch(
index_pipeline_result = index_doc_batch(
chunker=chunker,
embedder=embedder,
document_index=document_index,
@@ -203,7 +217,7 @@ def index_doc_batch_with_handler(
else:
pass
return r
return index_pipeline_result
def index_doc_batch_prepare(
@@ -227,6 +241,15 @@ def index_doc_batch_prepare(
if not ignore_time_skip
else documents
)
if len(updatable_docs) != len(documents):
updatable_doc_ids = [doc.id for doc in updatable_docs]
skipped_doc_ids = [
doc.id for doc in documents if doc.id not in updatable_doc_ids
]
logger.info(
f"Skipping {len(skipped_doc_ids)} documents "
f"because they are up to date. Skipped doc IDs: {skipped_doc_ids}"
)
# for all updatable docs, upsert into the DB
# Does not include doc_updated_at which is also used to indicate a successful update
@@ -263,21 +286,6 @@ def index_doc_batch_prepare(
def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
# Remove any NUL characters from title/semantic_id
# This is a known issue with the Zendesk connector
# Postgres cannot handle NUL characters in text fields
if document.title:
document.title = document.title.replace("\x00", "")
if document.semantic_identifier:
document.semantic_identifier = document.semantic_identifier.replace(
"\x00", ""
)
# Remove NUL characters from all sections
for section in document.sections:
if section.text is not None:
section.text = section.text.replace("\x00", "")
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
@@ -333,7 +341,7 @@ def index_doc_batch(
ignore_time_skip: bool = False,
tenant_id: str | None = None,
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
) -> tuple[int, int]:
) -> IndexingPipelineResult:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements
@@ -359,7 +367,18 @@ def index_doc_batch(
db_session=db_session,
)
if not ctx:
return 0, 0
# even though we didn't actually index anything, we should still
# mark them as "completed" for the CC Pair in order to make the
# counts match
mark_document_as_indexed_for_cc_pair__no_commit(
connector_id=index_attempt_metadata.connector_id,
credential_id=index_attempt_metadata.credential_id,
document_ids=[doc.id for doc in filtered_documents],
db_session=db_session,
)
return IndexingPipelineResult(
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
)
logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
@@ -425,7 +444,8 @@ def index_doc_batch(
]
logger.debug(
f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
"Indexing the following chunks: "
f"{[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
)
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
@@ -440,14 +460,17 @@ def index_doc_batch(
),
)
successful_doc_ids = [record.document_id for record in insertion_records]
successful_docs = [
doc for doc in ctx.updatable_docs if doc.id in successful_doc_ids
]
successful_doc_ids = {record.document_id for record in insertion_records}
if successful_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Successful IDs: {successful_doc_ids}"
)
last_modified_ids = []
ids_to_new_updated_at = {}
for doc in successful_docs:
for doc in ctx.updatable_docs:
last_modified_ids.append(doc.id)
# doc_updated_at is the source's idea (on the other end of the connector)
# of when the doc was last modified
@@ -469,11 +492,24 @@ def index_doc_batch(
db_session=db_session,
)
# these documents can now be counted as part of the CC Pairs
# document count, so we need to mark them as indexed
# NOTE: even documents we skipped since they were already up
# to date should be counted here in order to maintain parity
# between CC Pair and index attempt counts
mark_document_as_indexed_for_cc_pair__no_commit(
connector_id=index_attempt_metadata.connector_id,
credential_id=index_attempt_metadata.credential_id,
document_ids=[doc.id for doc in filtered_documents],
db_session=db_session,
)
db_session.commit()
result = (
len([r for r in insertion_records if r.already_existed is False]),
len(access_aware_chunks),
result = IndexingPipelineResult(
new_docs=len([r for r in insertion_records if r.already_existed is False]),
total_docs=len(filtered_documents),
total_chunks=len(access_aware_chunks),
)
return result
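With this hunk the pipeline returns a structured IndexingPipelineResult instead of a bare (new_docs, total_chunks) tuple, and additionally reports total_docs because some documents can be skipped before insertion. A sketch of a caller consuming the result through the IndexingPipelineProtocol; the wrapper function is illustrative and is assumed to live alongside IndexingPipelineResult / IndexingPipelineProtocol:

from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.utils.logger import setup_logger

logger = setup_logger()


def run_batch(
    indexing_pipeline: IndexingPipelineProtocol,
    document_batch: list[Document],
    index_attempt_metadata: IndexAttemptMetadata,
) -> IndexingPipelineResult:
    result = indexing_pipeline(
        document_batch=document_batch,
        index_attempt_metadata=index_attempt_metadata,
    )
    # Named fields replace the old tuple unpacking; total_docs counts every document
    # handed to the pipeline, including ones skipped because they were already up to date
    logger.info(
        f"Indexed batch: {result.new_docs} new docs, "
        f"{result.total_docs} total docs, {result.total_chunks} chunks"
    )
    return result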


@@ -55,10 +55,15 @@ from onyx.server.documents.indexing import router as indexing_router
from onyx.server.documents.standard_oauth import router as oauth_router
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.folder.api import router as folder_router
from onyx.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
from onyx.server.features.input_prompt.api import (
basic_router as input_prompt_router,
)
from onyx.server.features.notifications.api import router as notification_router
from onyx.server.features.persona.api import admin_router as admin_persona_router
from onyx.server.features.persona.api import basic_router as persona_router
from onyx.server.features.prompt.api import basic_router as prompt_router
from onyx.server.features.tool.api import admin_router as admin_tool_router
from onyx.server.features.tool.api import router as tool_router
from onyx.server.gpts.api import router as gpts_router
@@ -215,7 +220,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
else:
setup_multitenant_onyx()
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
if not MULTI_TENANT:
# don't emit a metric for every pod rollover/restart
optional_telemetry(
record_type=RecordType.VERSION, data={"version": __version__}
)
if AUTH_RATE_LIMITING_ENABLED:
await setup_auth_limiter()
@@ -274,6 +283,8 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, connector_router)
include_router_with_global_prefix_prepended(application, user_router)
include_router_with_global_prefix_prepended(application, credential_router)
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, folder_router)
include_router_with_global_prefix_prepended(application, document_set_router)
@@ -284,7 +295,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, notification_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
include_router_with_global_prefix_prepended(application, admin_tool_router)
include_router_with_global_prefix_prepended(application, state_router)


@@ -12,6 +12,7 @@ from requests import Response
from retry import retry
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import SKIP_WARM_UP
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from onyx.configs.model_configs import (
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
@@ -384,6 +385,9 @@ def warm_up_bi_encoder(
embedding_model: EmbeddingModel,
non_blocking: bool = False,
) -> None:
if SKIP_WARM_UP:
return
warm_up_str = " ".join(WARM_UP_STRINGS)
logger.debug(f"Warming up encoder model: {embedding_model.model_name}")


@@ -14,8 +14,6 @@ from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.db.engine import get_session_with_tenant
from onyx.db.feedback import create_chat_message_feedback
from onyx.db.feedback import create_doc_retrieval_feedback
from onyx.document_index.document_index_utils import get_both_index_names
from onyx.document_index.factory import get_default_document_index
from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks
from onyx.onyxbot.slack.blocks import get_document_feedback_blocks
from onyx.onyxbot.slack.config import get_slack_channel_config_for_bot_and_channel
@@ -129,7 +127,7 @@ def handle_generate_answer_button(
channel_to_respond=channel_id,
msg_to_respond=cast(str, message_ts or thread_ts),
thread_to_respond=cast(str, thread_ts or message_ts),
sender=user_id or None,
sender_id=user_id or None,
email=email or None,
bypass_filters=True,
is_bot_msg=False,
@@ -186,16 +184,10 @@ def handle_slack_feedback(
else:
feedback = SearchFeedbackType.HIDE
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
create_doc_retrieval_feedback(
message_id=message_id,
document_id=doc_id,
document_rank=doc_rank,
document_index=document_index,
db_session=db_session,
clicked=False, # Not tracking this for Slack
feedback=feedback,


@@ -28,12 +28,12 @@ logger_base = setup_logger()
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
if details.is_bot_msg and details.sender:
if details.is_bot_msg and details.sender_id:
respond_in_thread(
client=client,
channel=details.channel_to_respond,
thread_ts=details.msg_to_respond,
receiver_ids=[details.sender],
receiver_ids=[details.sender_id],
text="Hi, we're evaluating your query :face_with_monocle:",
)
return
@@ -70,7 +70,7 @@ def schedule_feedback_reminder(
try:
response = client.chat_scheduleMessage(
channel=details.sender, # type:ignore
channel=details.sender_id, # type:ignore
post_at=int(future.timestamp()),
blocks=[
get_feedback_reminder_blocks(
@@ -123,7 +123,7 @@ def handle_message(
logger = setup_logger(extra={SLACK_CHANNEL_ID: channel})
messages = message_info.thread_messages
sender_id = message_info.sender
sender_id = message_info.sender_id
bypass_filters = message_info.bypass_filters
is_bot_msg = message_info.is_bot_msg
is_bot_dm = message_info.is_bot_dm


@@ -126,7 +126,12 @@ def handle_regular_answer(
# messages, max_tokens=max_history_tokens, llm_tokenizer=llm_tokenizer
# )
combined_message = slackify_message_thread(messages)
# NOTE: only the message history will contain the person asking. This is likely
# fine since the most common use case for this info is when referring to a user
# who previously posted in the thread.
user_message = messages[-1]
history_messages = messages[:-1]
single_message_history = slackify_message_thread(history_messages) or None
bypass_acl = False
if (
@@ -159,6 +164,7 @@ def handle_regular_answer(
user=onyx_user,
db_session=db_session,
bypass_acl=bypass_acl,
single_message_history=single_message_history,
)
answer = gather_stream_for_slack(packets)
@@ -198,7 +204,7 @@ def handle_regular_answer(
with get_session_with_tenant(tenant_id) as db_session:
answer_request = prepare_chat_message_request(
message_text=combined_message,
message_text=user_message.message,
user=user,
persona_id=persona.id,
# This is not used in the Slack flow, only in the answer API
@@ -312,7 +318,7 @@ def handle_regular_answer(
top_docs = retrieval_info.top_documents
if not top_docs and not should_respond_even_with_no_docs:
logger.error(
f"Unable to answer question: '{combined_message}' - no documents found"
f"Unable to answer question: '{user_message}' - no documents found"
)
# Optionally, respond in thread with the error message
# Used primarily for debugging purposes
@@ -371,8 +377,8 @@ def handle_regular_answer(
respond_in_thread(
client=client,
channel=channel,
receiver_ids=[message_info.sender]
if message_info.is_bot_msg and message_info.sender
receiver_ids=[message_info.sender_id]
if message_info.is_bot_msg and message_info.sender_id
else receiver_ids,
text="Hello! Onyx has some results for you!",
blocks=all_blocks,


@@ -540,9 +540,9 @@ def build_request_details(
tagged = event.get("type") == "app_mention"
message_ts = event.get("ts")
thread_ts = event.get("thread_ts")
sender = event.get("user") or None
sender_id = event.get("user") or None
expert_info = expert_info_from_slack_id(
sender, client.web_client, user_cache={}
sender_id, client.web_client, user_cache={}
)
email = expert_info.email if expert_info else None
@@ -566,8 +566,21 @@ def build_request_details(
channel=channel, thread=thread_ts, client=client.web_client
)
else:
sender_display_name = None
if expert_info:
sender_display_name = expert_info.display_name
if sender_display_name is None:
sender_display_name = (
f"{expert_info.first_name} {expert_info.last_name}"
if expert_info.last_name
else expert_info.first_name
)
if sender_display_name is None:
sender_display_name = expert_info.email
thread_messages = [
ThreadMessage(message=msg, sender=None, role=MessageType.USER)
ThreadMessage(
message=msg, sender=sender_display_name, role=MessageType.USER
)
]
return SlackMessageInfo(
@@ -575,7 +588,7 @@ def build_request_details(
channel_to_respond=channel,
msg_to_respond=cast(str, message_ts or thread_ts),
thread_to_respond=cast(str, thread_ts or message_ts),
sender=sender,
sender_id=sender_id,
email=email,
bypass_filters=tagged,
is_bot_msg=False,
@@ -598,7 +611,7 @@ def build_request_details(
channel_to_respond=channel,
msg_to_respond=None,
thread_to_respond=None,
sender=sender,
sender_id=sender,
email=email,
bypass_filters=True,
is_bot_msg=True,
@@ -687,7 +700,7 @@ def process_message(
if feedback_reminder_id:
remove_scheduled_feedback_reminder(
client=client.web_client,
channel=details.sender,
channel=details.sender_id,
msg_id=feedback_reminder_id,
)
# Skipping answering due to pre-filtering is not considered a failure


@@ -8,7 +8,7 @@ class SlackMessageInfo(BaseModel):
channel_to_respond: str
msg_to_respond: str | None
thread_to_respond: str | None
sender: str | None
sender_id: str | None
email: str | None
bypass_filters: bool # User has tagged @OnyxBot
is_bot_msg: bool # User is using /OnyxBot
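The rename from sender to sender_id makes explicit that the field carries a Slack user ID; the human-readable name now travels on ThreadMessage.sender instead (see the sender_display_name logic earlier in this diff). A construction sketch with illustrative values; the import paths for SlackMessageInfo and ThreadMessage, and the helper name, are assumptions:

from onyx.chat.models import ThreadMessage  # import path assumed
from onyx.configs.constants import MessageType
from onyx.onyxbot.slack.models import SlackMessageInfo  # import path assumed


def build_dm_message_info(
    channel: str,
    message_ts: str,
    slack_user_id: str,
    email: str | None,
    text: str,
    display_name: str | None,
) -> SlackMessageInfo:
    return SlackMessageInfo(
        thread_messages=[
            ThreadMessage(message=text, sender=display_name, role=MessageType.USER)
        ],
        channel_to_respond=channel,
        msg_to_respond=message_ts,
        thread_to_respond=message_ts,
        sender_id=slack_user_id,  # Slack user ID (renamed from `sender` in this PR)
        email=email,
        bypass_filters=False,
        is_bot_msg=False,
        is_bot_dm=True,
    )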


@@ -590,6 +590,7 @@ def slack_usage_report(
record_type=RecordType.USAGE,
data={"action": action},
user_id=str(onyx_user.id) if onyx_user else "Non-Onyx-Or-No-Auth-User",
tenant_id=tenant_id,
)

Some files were not shown because too many files have changed in this diff.