Compare commits

..

17 Commits

Author SHA1 Message Date
pablonyx
25d9266da4 update 2025-02-26 08:48:35 -08:00
Weves
23073d91b9 reduce number of chars to index for search 2025-02-25 19:27:50 -08:00
Chris Weaver
f767b1f476 Fix confluence permission syncing at scale (#4129)
* Fix confluence permission syncing at scale

* Remove line

* Better log message

* Adjust log
2025-02-25 19:22:52 -08:00
pablonyx
9ffc8cb2c4 k 2025-02-25 18:15:49 -08:00
pablonyx
98bfb58147 Handle bad slack configurations– multi tenant (#4118)
* k

* quick nit

* k

* k
2025-02-25 22:22:54 +00:00
evan-danswer
6ce810e957 faster indexing status at scale plus minor cleanups (#4081)
* faster indexing status at scale plus minor cleanups

* mypy

* address chris comments

* remove extra prints
2025-02-25 21:22:26 +00:00
pablonyx
07b0b57b31 (nit) bump timeout 2025-02-25 14:10:30 -08:00
pablonyx
118cdd7701 Chat search (#4113)
* add chat search

* don't add the bible

* base functional

* k

* k

* functioning

* functioning well

* functioning well

* k

* delete bible

* quick cleanup

* quick cleanup

* k

* fixed frontend hooks

* delete bible

* nit

* nit

* nit

* fix build

* k

* improved debouncing

* address comments

* fix alembic

* k
2025-02-25 20:49:46 +00:00
rkuo-danswer
ac83b4c365 validate connector deletion (#4108)
* validate connector deletion

* fixes

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-25 20:35:21 +00:00
pablonyx
fa408ff447 add 3.7 (#4116) 2025-02-25 12:41:40 -08:00
rkuo-danswer
4aa8eb8b75 fix scrolling test (#4117)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-02-25 10:23:04 -08:00
rkuo-danswer
60bd9271f7 Bugfix/model tests (#4092)
* trying out a fix

* add ability to manually run model tests

* add log dump

* check status code, not text?

* just the model server

* add port mapping to host

* pass through more api keys

* add azure tests

* fix litellm env vars

* fix env vars in github workflow

* temp disable litellm test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-25 04:53:51 +00:00
Weves
5d58a5e3ea Add ability to index all of Github 2025-02-24 18:56:36 -08:00
Chris Weaver
a99dd05533 Add option to index all Jira projects (#4106)
* Add option to index all Jira projects

* Fix test

* Fix web build

* Address comment
2025-02-25 02:07:00 +00:00
pablonyx
0dce67094e Prettier formatting for bedrock (#4111)
* k

* k
2025-02-25 02:05:29 +00:00
pablonyx
ffd14435a4 Text overflow logic (#4051)
* proper components

* k

* k

* k
2025-02-25 01:05:22 +00:00
rkuo-danswer
c9a3b45ad4 more aggressive handling of tasks blocking deletion (#4093)
* more aggressive handling of tasks blocking deletion

* comment updated

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-24 22:41:13 +00:00
55 changed files with 2144 additions and 339 deletions

View File

@@ -17,8 +17,13 @@ env:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
# OpenAI
# API keys for testing
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
AZURE_API_URL: ${{ secrets.AZURE_API_URL }}
jobs:
model-check:
@@ -72,7 +77,7 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d indexing_model_server
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
id: start_docker
- name: Wait for service to be ready
@@ -123,9 +128,22 @@ jobs:
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v

View File

@@ -0,0 +1,31 @@
"""add index
Revision ID: 8f43500ee275
Revises: da42808081e3
Create Date: 2025-02-24 17:35:33.072714
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8f43500ee275"
down_revision = "da42808081e3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create a basic index on the lowercase message column for direct text matching
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
op.execute(
"""
CREATE INDEX idx_chat_message_message_lower
ON chat_message (LOWER(substring(message, 1, 1500)))
"""
)
def downgrade() -> None:
# Drop the index
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -0,0 +1,120 @@
"""migrate jira connectors to new format
Revision ID: da42808081e3
Revises: f13db29f3101
Create Date: 2025-02-24 11:24:54.396040
"""
from alembic import op
import sqlalchemy as sa
import json
from onyx.configs.constants import DocumentSource
from onyx.connectors.onyx_jira.utils import extract_jira_project
# revision identifiers, used by Alembic.
revision = "da42808081e3"
down_revision = "f13db29f3101"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config
for connector_id, old_config in jira_connectors:
if not old_config:
continue
# Extract project key from URL if it exists
new_config: dict[str, str | None] = {}
if project_url := old_config.get("jira_project_url"):
# Parse the URL to get base and project
try:
jira_base, project_key = extract_jira_project(project_url)
new_config = {"jira_base_url": jira_base, "project_key": project_key}
except ValueError:
# If URL parsing fails, just use the URL as the base
new_config = {
"jira_base_url": project_url.split("/projects/")[0],
"project_key": None,
}
else:
# For connectors without a project URL, we need admin intervention
# Mark these for review
print(
f"WARNING: Jira connector {connector_id} has no project URL configured"
)
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :id
"""
),
{"id": connector_id, "new_config": json.dumps(new_config)},
)
def downgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config back to the old format
for connector_id, new_config in jira_connectors:
if not new_config:
continue
old_config = {}
base_url = new_config.get("jira_base_url")
project_key = new_config.get("project_key")
if base_url and project_key:
old_config = {"jira_project_url": f"{base_url}/projects/{project_key}"}
elif base_url:
old_config = {"jira_project_url": base_url}
else:
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :old_config
WHERE id = :id
"""
),
{"id": connector_id, "old_config": old_config},
)

View File

@@ -224,7 +224,7 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20241022",
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
)

View File

@@ -420,8 +420,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except exceptions.UserNotExists:
try:
# Attempt to get user by email
user = cast(User, await self.user_db.get_by_email(account_email))
user = await self.get_by_email(account_email)
if not associate_by_email:
raise exceptions.UserAlreadyExists()

View File

@@ -57,6 +57,51 @@ class TaskDependencyError(RuntimeError):
with connector deletion."""
def revoke_tasks_blocking_deletion(
    redis_connector: RedisConnector, db_session: Session, app: Celery
) -> None:
    """Best-effort revocation of every celery task that can block a
    connector deletion: one indexing task per search-settings entry, plus
    the permissions-sync, pruning, and external-group-sync tasks.

    Each revoke is wrapped in its own try/except so one failure does not
    prevent the remaining revokes from being attempted.
    """
    search_settings_list = get_all_search_settings(db_session)
    for search_settings in search_settings_list:
        redis_connector_index = redis_connector.new_index(search_settings.id)
        try:
            index_payload = redis_connector_index.payload
            if index_payload and index_payload.celery_task_id:
                app.control.revoke(index_payload.celery_task_id)
                task_logger.info(
                    f"Revoked indexing task {index_payload.celery_task_id}."
                )
        except Exception:
            task_logger.exception("Exception while revoking indexing task")

    try:
        permissions_sync_payload = redis_connector.permissions.payload
        if permissions_sync_payload and permissions_sync_payload.celery_task_id:
            app.control.revoke(permissions_sync_payload.celery_task_id)
            task_logger.info(
                f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
            )
    except Exception:
        # BUGFIX: this previously logged "pruning task" (messages were swapped
        # with the prune block below)
        task_logger.exception("Exception while revoking permissions sync task")

    try:
        prune_payload = redis_connector.prune.payload
        if prune_payload and prune_payload.celery_task_id:
            app.control.revoke(prune_payload.celery_task_id)
            task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
    except Exception:
        # BUGFIX: this previously logged "permissions sync task"
        task_logger.exception("Exception while revoking pruning task")

    try:
        external_group_sync_payload = redis_connector.external_group_sync.payload
        if external_group_sync_payload and external_group_sync_payload.celery_task_id:
            app.control.revoke(external_group_sync_payload.celery_task_id)
            task_logger.info(
                f"Revoked external group sync task {external_group_sync_payload.celery_task_id}."
            )
    except Exception:
        task_logger.exception("Exception while revoking external group sync task")
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
ignore_result=True,
@@ -76,7 +121,7 @@ def check_for_connector_deletion_task(
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
# Prevent this task from overlapping with itself
if not lock_beat.acquire(blocking=False):
return None
@@ -113,9 +158,38 @@ def check_for_connector_deletion_task(
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
# Leave a stop signal to clear indexing and pruning tasks more quickly
# on the first error, we set a stop signal and revoke the dependent tasks
# on subsequent errors, we hard reset blocking fences after our specified timeout
# is exceeded
task_logger.info(str(e))
redis_connector.stop.set_fence(True)
if not redis_connector.stop.fenced:
# one time revoke of celery tasks
task_logger.info("Revoking any tasks blocking deletion.")
revoke_tasks_blocking_deletion(
redis_connector, db_session, self.app
)
redis_connector.stop.set_fence(True)
redis_connector.stop.set_timeout()
else:
# stop signal already set
if redis_connector.stop.timed_out:
# waiting too long, just reset blocking fences
task_logger.info(
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
)
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
redis_connector_index.reset()
redis_connector.prune.reset()
redis_connector.permissions.reset()
redis_connector.external_group_sync.reset()
else:
# just wait
pass
else:
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)

View File

@@ -11,6 +11,8 @@ from atlassian import Confluence # type:ignore
from pydantic import BaseModel
from requests import HTTPError
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.utils.logger import setup_logger
@@ -161,7 +163,7 @@ class OnyxConfluence(Confluence):
)
def _paginate_url(
self, url_suffix: str, limit: int | None = None
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
@@ -236,9 +238,41 @@ class OnyxConfluence(Confluence):
raise e
# yield the results individually
yield from next_response.get("results", [])
results = cast(list[dict[str, Any]], next_response.get("results", []))
yield from results
url_suffix = next_response.get("_links", {}).get("next")
old_url_suffix = url_suffix
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
# make sure we don't update the start by more than the amount
# of results we were able to retrieve. The Confluence API has a
# weird behavior where if you pass in a limit that is too large for
# the configured server, it will artificially limit the amount of
# results returned BUT will not apply this to the start parameter.
# This will cause us to miss results.
if url_suffix and "start" in url_suffix:
new_start = get_start_param_from_url(url_suffix)
previous_start = get_start_param_from_url(old_url_suffix)
if new_start - previous_start > len(results):
logger.warning(
f"Start was updated by more than the amount of results "
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
f"Previous Start: {previous_start}, Len Results: {len(results)}."
)
# Update the url_suffix to use the adjusted start
adjusted_start = previous_start + len(results)
url_suffix = update_param_in_path(
url_suffix, "start", str(adjusted_start)
)
# some APIs don't properly paginate, so we need to manually update the `start` param
if auto_paginate and len(results) > 0:
previous_start = get_start_param_from_url(old_url_suffix)
updated_start = previous_start + len(results)
url_suffix = update_param_in_path(
old_url_suffix, "start", str(updated_start)
)
def paginated_cql_retrieval(
self,
@@ -298,7 +332,9 @@ class OnyxConfluence(Confluence):
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
for user_result in self._paginate_url(url, limit):
# endpoint doesn't properly paginate, so we need to manually update the `start` param
# thus the auto_paginate flag
for user_result in self._paginate_url(url, limit, auto_paginate=True):
# Example response:
# {
# 'user': {

View File

@@ -2,7 +2,10 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TYPE_CHECKING
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urlparse
import bs4
@@ -10,13 +13,13 @@ from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.onyx_confluence import (
OnyxConfluence,
)
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
logger = setup_logger()
@@ -24,7 +27,7 @@ _USER_EMAIL_CACHE: dict[str, str | None] = {}
def get_user_email_from_username__server(
confluence_client: OnyxConfluence, user_name: str
confluence_client: "OnyxConfluence", user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
@@ -47,7 +50,7 @@ _USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
@@ -78,7 +81,7 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_client: "OnyxConfluence",
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
@@ -191,7 +194,7 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
def attachment_to_content(
confluence_client: OnyxConfluence,
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
@@ -279,3 +282,32 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def get_single_param_from_url(url: str, param: str) -> str | None:
"""Get a parameter from a url"""
parsed_url = urlparse(url)
return parse_qs(parsed_url.query).get(param, [None])[0]
def get_start_param_from_url(url: str) -> int:
    """Return the integer ``start`` query parameter of *url*, defaulting to
    0 when the parameter is not present."""
    values = parse_qs(urlparse(url).query).get("start")
    return int(values[0]) if values else 0
def update_param_in_path(path: str, param: str, value: str) -> str:
    """Return *path* with query parameter *param* set to *value*.

    Expects a path of the form ``/api/rest/users?start=0&limit=10``.
    NOTE: only the first value of any repeated parameter is preserved.
    """
    params = parse_qs(urlparse(path).query)
    params[param] = [value]
    base = path.split("?")[0]
    pieces = [f"{name}={quote(vals[0])}" for name, vals in params.items()]
    return base + "?" + "&".join(pieces)

View File

@@ -124,7 +124,7 @@ class GithubConnector(LoadConnector, PollConnector):
def __init__(
self,
repo_owner: str,
repo_name: str,
repo_name: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
@@ -162,53 +162,81 @@ class GithubConnector(LoadConnector, PollConnector):
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repo(github_client, attempt_num + 1)
def _get_all_repos(
    self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
    """Fetch every repository belonging to ``self.repo_owner``.

    The owner is tried as an organization first, falling back to a user
    account. On rate limiting, sleeps and retries up to the cap.
    """
    if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
        raise RuntimeError(
            "Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
        )

    try:
        try:
            # The owner may be an organization...
            organization = github_client.get_organization(self.repo_owner)
            return list(organization.get_repos())
        except GithubException:
            # ...or a plain user account.
            owner_user = github_client.get_user(self.repo_owner)
            return list(owner_user.get_repos())
    except RateLimitExceededException:
        _sleep_after_rate_limit_exception(github_client)
        return self._get_all_repos(github_client, attempt_num + 1)
def _fetch_from_github(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
repo = self._get_github_repo(self.github_client)
repos = (
[self._get_github_repo(self.github_client)]
if self.repo_name
else self._get_all_repos(self.github_client)
)
if self.include_prs:
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
for repo in repos:
if self.include_prs:
logger.info(f"Fetching PRs for repo: {repo.name}")
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
for pr_batch in _batch_github_objects(
pull_requests, self.github_client, self.batch_size
):
doc_batch: list[Document] = []
for pr in pr_batch:
if start is not None and pr.updated_at < start:
yield doc_batch
return
if end is not None and pr.updated_at > end:
continue
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
yield doc_batch
for pr_batch in _batch_github_objects(
pull_requests, self.github_client, self.batch_size
):
doc_batch: list[Document] = []
for pr in pr_batch:
if start is not None and pr.updated_at < start:
yield doc_batch
break
if end is not None and pr.updated_at > end:
continue
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
yield doc_batch
if self.include_issues:
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
if self.include_issues:
logger.info(f"Fetching issues for repo: {repo.name}")
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
for issue_batch in _batch_github_objects(
issues, self.github_client, self.batch_size
):
doc_batch = []
for issue in issue_batch:
issue = cast(Issue, issue)
if start is not None and issue.updated_at < start:
yield doc_batch
return
if end is not None and issue.updated_at > end:
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
for issue_batch in _batch_github_objects(
issues, self.github_client, self.batch_size
):
doc_batch = []
for issue in issue_batch:
issue = cast(Issue, issue)
if start is not None and issue.updated_at < start:
yield doc_batch
break
if end is not None and issue.updated_at > end:
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_github()
@@ -234,16 +262,26 @@ class GithubConnector(LoadConnector, PollConnector):
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")
if not self.repo_owner or not self.repo_name:
if not self.repo_owner:
raise ConnectorValidationError(
"Invalid connector settings: 'repo_owner' and 'repo_name' must be provided."
"Invalid connector settings: 'repo_owner' must be provided."
)
try:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
if self.repo_name:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
else:
# Try to get organization first
try:
org = self.github_client.get_organization(self.repo_owner)
org.get_repos().totalCount # Just check if we can access repos
except GithubException:
# If not an org, try as a user
user = self.github_client.get_user(self.repo_owner)
user.get_repos().totalCount # Just check if we can access repos
except RateLimitExceededException:
raise UnexpectedError(
@@ -260,9 +298,14 @@ class GithubConnector(LoadConnector, PollConnector):
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
if self.repo_name:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
else:
raise ConnectorValidationError(
f"GitHub user or organization not found: {self.repo_owner}"
)
else:
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"

View File

@@ -29,7 +29,6 @@ from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
from onyx.connectors.onyx_jira.utils import best_effort_get_field_from_issue
from onyx.connectors.onyx_jira.utils import build_jira_client
from onyx.connectors.onyx_jira.utils import build_jira_url
from onyx.connectors.onyx_jira.utils import extract_jira_project
from onyx.connectors.onyx_jira.utils import extract_text_from_adf
from onyx.connectors.onyx_jira.utils import get_comment_strs
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -160,7 +159,8 @@ def fetch_jira_issues_batch(
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
jira_project_url: str,
jira_base_url: str,
project_key: str | None = None,
comment_email_blacklist: list[str] | None = None,
batch_size: int = INDEX_BATCH_SIZE,
# if a ticket has one of the labels specified in this list, we will just
@@ -169,12 +169,13 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
) -> None:
self.batch_size = batch_size
self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
self._jira_client: JIRA | None = None
self.jira_base = jira_base_url.rstrip("/") # Remove trailing slash if present
self.jira_project = project_key
self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip)
self._jira_client: JIRA | None = None
@property
def comment_email_blacklist(self) -> tuple:
return tuple(email.strip() for email in self._comment_email_blacklist)
@@ -188,7 +189,9 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
@property
def quoted_jira_project(self) -> str:
# Quote the project name to handle reserved words
return f'"{self._jira_project}"'
if not self.jira_project:
return ""
return f'"{self.jira_project}"'
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._jira_client = build_jira_client(
@@ -197,8 +200,14 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def _get_jql_query(self) -> str:
    """Build the base JQL clause: project-scoped when a project key is
    configured, empty otherwise."""
    if not self.jira_project:
        # Empty string means all accessible projects
        return ""
    return f"project = {self.quoted_jira_project}"
def load_from_state(self) -> GenerateDocumentsOutput:
jql = f"project = {self.quoted_jira_project}"
jql = self._get_jql_query()
document_batch = []
for doc in fetch_jira_issues_batch(
@@ -225,11 +234,10 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
"%Y-%m-%d %H:%M"
)
base_jql = self._get_jql_query()
jql = (
f"project = {self.quoted_jira_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)
f"{base_jql} AND " if base_jql else ""
) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
document_batch = []
for doc in fetch_jira_issues_batch(
@@ -252,7 +260,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
jql = f"project = {self.quoted_jira_project}"
jql = self._get_jql_query()
slim_doc_batch = []
for issue in _paginate_jql_search(
@@ -279,43 +287,63 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
if self._jira_client is None:
raise ConnectorMissingCredentialError("Jira")
if not self._jira_project:
raise ConnectorValidationError(
"Invalid connector settings: 'jira_project' must be provided."
)
# If a specific project is set, validate it exists
if self.jira_project:
try:
self.jira_client.project(self.jira_project)
except Exception as e:
status_code = getattr(e, "status_code", None)
try:
self.jira_client.project(self._jira_project)
if status_code == 401:
raise CredentialExpiredError(
"Jira credential appears to be expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Jira token does not have sufficient permissions for this project (HTTP 403)."
)
elif status_code == 404:
raise ConnectorValidationError(
f"Jira project not found with key: {self.jira_project}"
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
)
except Exception as e:
status_code = getattr(e, "status_code", None)
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
else:
# If no project specified, validate we can access the Jira API
try:
# Try to list projects to validate access
self.jira_client.projects()
except Exception as e:
status_code = getattr(e, "status_code", None)
if status_code == 401:
raise CredentialExpiredError(
"Jira credential appears to be expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Jira token does not have sufficient permissions to list projects (HTTP 403)."
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
)
if status_code == 401:
raise CredentialExpiredError(
"Jira credential appears to be expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Jira token does not have sufficient permissions for this project (HTTP 403)."
)
elif status_code == 404:
raise ConnectorValidationError(
f"Jira project not found with key: {self._jira_project}"
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
)
else:
raise Exception(f"Unexpected Jira error during validation: {e}")
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
if __name__ == "__main__":
import os
connector = JiraConnector(
os.environ["JIRA_PROJECT_URL"], comment_email_blacklist=[]
jira_base_url=os.environ["JIRA_BASE_URL"],
project_key=os.environ.get("JIRA_PROJECT_KEY"),
comment_email_blacklist=[],
)
connector.load_credentials(
{
"jira_user_email": os.environ["JIRA_USER_EMAIL"],

View File

@@ -0,0 +1,152 @@
from typing import List
from typing import Optional
from typing import Tuple
from uuid import UUID
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import literal
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import union_all
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
def search_chat_sessions(
    user_id: UUID | None,
    db_session: Session,
    query: Optional[str] = None,
    page: int = 1,
    page_size: int = 10,
    include_deleted: bool = False,
    include_onyxbot_flows: bool = False,
) -> Tuple[List[ChatSession], bool]:
    """Search chat sessions by message/description text, or list recent ones.

    With no query, recent sessions are returned ordered by creation time.
    With a query, sessions are ranked: a message containing any search term
    scores 1.0, a description containing the whole query scores 0.5, and a
    session's rank is the max over its matches.
    Returns a tuple of ``(sessions_for_page, has_more)``.
    """
    offset = (page - 1) * page_size

    # ---- no-query path: plain recency-ordered pagination -----------------
    if not query or not query.strip():
        recent_stmt = select(ChatSession)
        if user_id:
            recent_stmt = recent_stmt.where(ChatSession.user_id == user_id)
        if not include_onyxbot_flows:
            recent_stmt = recent_stmt.where(ChatSession.onyxbot_flow.is_(False))
        if not include_deleted:
            recent_stmt = recent_stmt.where(ChatSession.deleted.is_(False))

        # Fetch one extra row so we can detect whether another page exists.
        recent_stmt = (
            recent_stmt.order_by(desc(ChatSession.time_created))
            .offset(offset)
            .limit(page_size + 1)
        )

        sessions = (
            db_session.execute(recent_stmt.options(joinedload(ChatSession.persona)))
            .scalars()
            .all()
        )
        has_more = len(sessions) > page_size
        if has_more:
            sessions = sessions[:page_size]
        return list(sessions), has_more

    # ---- query path ------------------------------------------------------
    terms = query.lower().strip().split()

    # Message-match subqueries: one UNION branch per search term, each
    # contributing rank 1.0.
    term_selects: List[Select] = []
    for term in terms:
        pattern = f"%{term}%"
        term_select: Select = (
            select(ChatMessage.chat_session_id, literal(1.0).label("search_rank"))
            .join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
            .where(func.lower(ChatMessage.message).like(pattern))
        )
        if user_id:
            term_select = term_select.where(ChatSession.user_id == user_id)
        term_selects.append(term_select)

    if not term_selects:
        # No usable terms -> nothing can match.
        return [], False
    message_matches_query = union_all(*term_selects).alias("message_matches")

    # Description matches score lower (0.5) than message matches.
    description_match: Select = select(
        ChatSession.id.label("chat_session_id"), literal(0.5).label("search_rank")
    ).where(func.lower(ChatSession.description).like(f"%{query.lower()}%"))
    if user_id:
        description_match = description_match.where(ChatSession.user_id == user_id)
    if not include_onyxbot_flows:
        description_match = description_match.where(ChatSession.onyxbot_flow.is_(False))
    if not include_deleted:
        description_match = description_match.where(ChatSession.deleted.is_(False))

    # Combine all match sources, then collapse to one row per session with
    # its best (max) rank.
    combined_matches = union_all(
        message_matches_query.select(), description_match
    ).alias("combined_matches")
    session_ranks = (
        select(
            combined_matches.c.chat_session_id,
            func.max(combined_matches.c.search_rank).label("rank"),
        )
        .group_by(combined_matches.c.chat_session_id)
        .alias("session_ranks")
    )

    # Page through the ranked ids (one extra row for has_more detection).
    ranked_rows = (
        db_session.query(session_ranks.c.chat_session_id, session_ranks.c.rank)
        .order_by(desc(session_ranks.c.rank), session_ranks.c.chat_session_id)
        .offset(offset)
        .limit(page_size + 1)
        .all()
    )

    rank_by_id = {row.chat_session_id: row.rank for row in ranked_rows}
    matched_ids = list(rank_by_id.keys())
    if not matched_ids:
        return [], False

    # Load the full ChatSession objects (persona eagerly) for this page.
    session_stmt = select(ChatSession).where(ChatSession.id.in_(matched_ids))
    if user_id:
        session_stmt = session_stmt.where(ChatSession.user_id == user_id)
    if not include_onyxbot_flows:
        session_stmt = session_stmt.where(ChatSession.onyxbot_flow.is_(False))
    if not include_deleted:
        session_stmt = session_stmt.where(ChatSession.deleted.is_(False))

    loaded = (
        db_session.execute(session_stmt.options(joinedload(ChatSession.persona)))
        .scalars()
        .all()
    )

    # Order by rank (desc), then newest first.
    ordered = sorted(
        loaded,
        key=lambda s: (
            -rank_by_id.get(s.id, 0),
            -s.time_created.timestamp(),
        ),
    )

    has_more = len(ordered) > page_size
    if has_more:
        ordered = ordered[:page_size]
    return ordered, has_more

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import TypeVarTuple
from fastapi import HTTPException
from sqlalchemy import delete
@@ -8,15 +9,18 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
@@ -31,10 +35,12 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
R = TypeVarTuple("R")
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
stmt: Select[tuple[*R]], user: User | None, get_editable: bool = True
) -> Select[tuple[*R]]:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@@ -98,17 +104,52 @@ def get_connector_credential_pairs_for_user(
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
if eager_load_user:
assert (
eager_load_credential
), "eager_load_credential must be True if eager_load_user is True"
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
if eager_load_user:
load_opts = load_opts.joinedload(Credential.user)
stmt = stmt.options(load_opts)
stmt = _add_user_filters(stmt, user, get_editable)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).all())
return list(db_session.scalars(stmt).unique().all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_connector_credential_pairs_for_user_parallel(
    user: User | None,
    get_editable: bool = True,
    ids: list[int] | None = None,
    eager_load_connector: bool = False,
    eager_load_credential: bool = False,
    eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
    """Session-owning wrapper around get_connector_credential_pairs_for_user.

    For use with our thread-level parallelism utils. Any relationships you
    wish to use MUST be eagerly loaded, as the session is closed before this
    function returns, so lazy loading will fail.
    """
    with get_session_context_manager() as db_session:
        # Pass by keyword: the long run of boolean flags is easy to transpose
        # silently if passed positionally and the wrapped signature changes.
        return get_connector_credential_pairs_for_user(
            db_session=db_session,
            user=user,
            get_editable=get_editable,
            ids=ids,
            eager_load_connector=eager_load_connector,
            eager_load_credential=eager_load_credential,
            eager_load_user=eager_load_user,
        )
def get_connector_credential_pairs(
@@ -151,6 +192,16 @@ def get_cc_pair_groups_for_ids(
return list(db_session.scalars(stmt).all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_cc_pair_groups_for_ids_parallel(
    cc_pair_ids: list[int],
) -> list[UserGroup__ConnectorCredentialPair]:
    """Session-owning variant of get_cc_pair_groups_for_ids.

    For use with our thread-level parallelism utils; eagerly load any
    relationships you need, since the session is gone once this returns.
    """
    with get_session_context_manager() as session:
        return get_cc_pair_groups_for_ids(session, cc_pair_ids)
def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,

View File

@@ -24,6 +24,7 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
@@ -229,12 +230,12 @@ def get_document_connector_counts(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
db_session: Session, cc_pairs: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
# Prepare a list of (connector_id, credential_id) tuples
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pair_identifiers]
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
stmt = (
select(
@@ -260,6 +261,16 @@ def get_document_counts_for_cc_pairs(
return db_session.execute(stmt).all() # type: ignore
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_document_counts_for_cc_pairs_parallel(
    cc_pairs: list[ConnectorCredentialPairIdentifier],
) -> Sequence[tuple[int, int, int]]:
    """Session-owning variant of get_document_counts_for_cc_pairs.

    Returns (connector_id, credential_id, document count) tuples. Safe for
    thread-level parallelism utils since it opens and closes its own session.
    """
    with get_session_context_manager() as session:
        return get_document_counts_for_cc_pairs(session, cc_pairs)
def get_access_info_for_document(
db_session: Session,
document_id: str,

View File

@@ -218,6 +218,7 @@ class SqlEngine:
final_engine_kwargs.update(engine_kwargs)
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
# echo=True here for inspecting all emitted db queries
engine = create_engine(connection_string, **final_engine_kwargs)
if USE_IAM_AUTH:

View File

@@ -2,6 +2,7 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import TypeVarTuple
from sqlalchemy import and_
from sqlalchemy import delete
@@ -9,9 +10,13 @@ from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_context_manager
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
@@ -368,19 +373,33 @@ def get_latest_index_attempts_by_status(
return db_session.execute(stmt).scalars().all()
T = TypeVarTuple("T")
def _add_only_finished_clause(stmt: Select[tuple[*T]]) -> Select[tuple[*T]]:
return stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
def get_latest_index_attempts(
secondary_index: bool,
db_session: Session,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
ids_stmt = select(
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.id).label("max_id"),
).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
if secondary_index:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.FUTURE)
else:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.PRESENT)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
ids_stmt = ids_stmt.where(SearchSettings.status == status)
if only_finished:
ids_stmt = _add_only_finished_clause(ids_stmt)
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
ids_subquery = ids_stmt.subquery()
@@ -395,7 +414,53 @@ def get_latest_index_attempts(
.where(IndexAttempt.id == ids_subquery.c.max_id)
)
return db_session.execute(stmt).scalars().all()
if only_finished:
stmt = _add_only_finished_clause(stmt)
if eager_load_cc_pair:
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
return db_session.execute(stmt).scalars().unique().all()
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_latest_index_attempts_parallel(
    secondary_index: bool,
    eager_load_cc_pair: bool = False,
    only_finished: bool = False,
) -> Sequence[IndexAttempt]:
    """Session-owning wrapper around get_latest_index_attempts.

    For use with our thread-level parallelism utils; eagerly load any
    relationships you need, since the session is gone once this returns.
    """
    with get_session_context_manager() as session:
        return get_latest_index_attempts(
            secondary_index=secondary_index,
            db_session=session,
            eager_load_cc_pair=eager_load_cc_pair,
            only_finished=only_finished,
        )
def get_latest_index_attempt_for_cc_pair_id(
    db_session: Session,
    connector_credential_pair_id: int,
    secondary_index: bool,
    only_finished: bool = True,
) -> IndexAttempt | None:
    """Return the most recently created IndexAttempt for the given cc-pair,
    scoped to the FUTURE (secondary) or PRESENT (primary) search settings.

    When only_finished is True, NOT_STARTED / IN_PROGRESS attempts are
    excluded. Returns None if no matching attempt exists.
    """
    wanted_status = (
        IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
    )
    stmt = (
        select(IndexAttempt)
        .where(
            IndexAttempt.connector_credential_pair_id
            == connector_credential_pair_id
        )
        .join(SearchSettings)
        .where(SearchSettings.status == wanted_status)
        .order_by(desc(IndexAttempt.time_created))
        .limit(1)
    )
    if only_finished:
        stmt = _add_only_finished_clause(stmt)
    return db_session.execute(stmt).scalar_one_or_none()
def count_index_attempts_for_connector(
@@ -453,37 +518,12 @@ def get_paginated_index_attempts_for_cc_pair_id(
# Apply pagination
stmt = stmt.offset(page * page_size).limit(page_size)
return list(db_session.execute(stmt).scalars().all())
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
stmt = stmt.options(
contains_eager(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
if only_finished:
stmt = stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
if secondary_index:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.FUTURE
)
else:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
return list(db_session.execute(stmt).scalars().unique().all())
def get_index_attempts_for_cc_pair(

View File

@@ -103,7 +103,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
api_version_required=False,
custom_config_keys=[],
llm_names=fetch_models_for_provider(ANTHROPIC_PROVIDER_NAME),
default_model="claude-3-5-sonnet-20241022",
default_model="claude-3-7-sonnet-20250219",
default_fast_model="claude-3-5-sonnet-20241022",
),
WellKnownLLMProviderDescriptor(

View File

@@ -17,10 +17,12 @@ from prometheus_client import Gauge
from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.chat.models import ThreadMessage
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import POD_NAME
@@ -249,7 +251,12 @@ class SlackbotHandler:
- If yes, store them in self.tenant_ids and manage the socket connections.
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
"""
all_tenants = get_all_tenant_ids()
all_tenants = [
tenant_id
for tenant_id in get_all_tenant_ids()
if tenant_id not in get_gated_tenants()
]
token: Token[str | None]
@@ -416,6 +423,7 @@ class SlackbotHandler:
try:
bot_info = socket_client.web_client.auth_test()
if bot_info["ok"]:
bot_user_id = bot_info["user_id"]
user_info = socket_client.web_client.users_info(user=bot_user_id)
@@ -426,9 +434,23 @@ class SlackbotHandler:
logger.info(
f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
)
except SlackApiError as e:
# Only error out if we get a not_authed error
if "not_authed" in str(e):
self.tenant_ids.add(tenant_id)
logger.error(
f"Authentication error: Invalid or expired credentials for tenant: {tenant_id}, app: {slack_bot_id}. "
"Error: {e}"
)
return
# Log other Slack API errors but continue
logger.error(
f"Slack API error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
)
except Exception as e:
logger.warning(
f"Could not fetch bot name: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
# Log other exceptions but continue
logger.error(
f"Error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
)
# Append the event handler

View File

@@ -93,10 +93,7 @@ class RedisConnectorIndex:
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
@property
def payload(self) -> RedisConnectorIndexPayload | None:
@@ -106,9 +103,7 @@ class RedisConnectorIndex:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
return payload
return RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
def set_fence(
self,
@@ -123,10 +118,7 @@ class RedisConnectorIndex:
self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
def terminating(self, celery_task_id: str) -> bool:
if self.redis.exists(f"{self.terminate_key}_{celery_task_id}"):
return True
return False
return bool(self.redis.exists(f"{self.terminate_key}_{celery_task_id}"))
def set_terminate(self, celery_task_id: str) -> None:
"""This sets a signal. It does not block!"""
@@ -146,10 +138,7 @@ class RedisConnectorIndex:
def watchdog_signaled(self) -> bool:
"""Check the state of the watchdog."""
if self.redis.exists(self.watchdog_key):
return True
return False
return bool(self.redis.exists(self.watchdog_key))
def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -160,10 +149,7 @@ class RedisConnectorIndex:
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
return bool(self.redis.exists(self.active_key))
def set_connector_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -180,10 +166,7 @@ class RedisConnectorIndex:
return False
def generator_locked(self) -> bool:
if self.redis.exists(self.generator_lock_key):
return True
return False
return bool(self.redis.exists(self.generator_lock_key))
def set_generator_complete(self, payload: int | None) -> None:
if not payload:

View File

@@ -5,7 +5,13 @@ class RedisConnectorStop:
"""Manages interactions with redis for stop signaling. Should only be accessed
through RedisConnector."""
FENCE_PREFIX = "connectorstop_fence"
PREFIX = "connectorstop"
FENCE_PREFIX = f"{PREFIX}_fence"
# if this timeout is exceeded, the caller may decide to take more
# drastic measures
TIMEOUT_PREFIX = f"{PREFIX}_timeout"
TIMEOUT_TTL = 300
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
@@ -13,6 +19,7 @@ class RedisConnectorStop:
self.redis = redis
self.fence_key: str = f"{self.FENCE_PREFIX}_{id}"
self.timeout_key: str = f"{self.TIMEOUT_PREFIX}_{id}"
@property
def fenced(self) -> bool:
@@ -28,7 +35,22 @@ class RedisConnectorStop:
self.redis.set(self.fence_key, 0)
@property
def timed_out(self) -> bool:
    """True once the timeout window set by set_timeout() has elapsed."""
    # set_timeout() writes the key with a TTL; while the key still exists we
    # are inside the allowed window. Once redis expires it, the caller has
    # exceeded the timeout. Collapsed the if/return-True/return-False form
    # to match the `return bool(...)` style used by the sibling redis helpers.
    return not bool(self.redis.exists(self.timeout_key))
def set_timeout(self) -> None:
    """After calling this, call timed_out to determine if the timeout has been
    exceeded."""
    # The value is a dummy; only the key's existence matters. TIMEOUT_TTL lets
    # redis expire it, which is what timed_out detects. (Dropped the redundant
    # f-string wrapper around the plain attribute.)
    self.redis.set(self.timeout_key, 0, ex=self.TIMEOUT_TTL)
@staticmethod
def reset_all(r: redis.Redis) -> None:
    """Delete every stop fence and timeout key (best-effort cleanup)."""
    for prefix in (
        RedisConnectorStop.FENCE_PREFIX,
        RedisConnectorStop.TIMEOUT_PREFIX,
    ):
        for key in r.scan_iter(prefix + "*"):
            r.delete(key)

View File

@@ -123,15 +123,15 @@ def get_cc_pair_full_info(
)
is_editable_for_current_user = editable_cc_pair is not None
cc_pair_identifier = ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
document_count_info_list = list(
get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=[cc_pair_identifier],
cc_pairs=[
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
],
)
)
documents_indexed = (

View File

@@ -72,25 +72,31 @@ from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector import update_connector
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids_parallel
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
from onyx.db.connector_credential_pair import (
get_connector_credential_pairs_for_user_parallel,
)
from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import cleanup_google_drive_credentials
from onyx.db.credentials import create_credential
from onyx.db.credentials import delete_service_account_credentials
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed
from onyx.db.document import get_document_counts_for_cc_pairs
from onyx.db.document import get_document_counts_for_cc_pairs_parallel
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import IndexingMode
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.index_attempt import get_latest_index_attempts
from onyx.db.index_attempt import get_latest_index_attempts_by_status
from onyx.db.index_attempt import get_latest_index_attempts_parallel
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.file_processing.extract_file_text import convert_docx_to_txt
@@ -119,8 +125,8 @@ from onyx.server.documents.models import RunConnectorRequest
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -578,6 +584,8 @@ def get_connector_status(
cc_pairs = get_connector_credential_pairs_for_user(
db_session=db_session,
user=user,
eager_load_connector=True,
eager_load_credential=True,
)
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
@@ -632,23 +640,35 @@ def get_connector_indexing_status(
# Additional checks are done to make sure the connector and credential still exist.
# TODO: make this one query ... possibly eager load or wrap in a read transaction
# to avoid the complexity of trying to error check throughout the function
cc_pairs = get_connector_credential_pairs_for_user(
db_session=db_session,
user=user,
get_editable=get_editable,
)
cc_pair_identifiers = [
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
)
for cc_pair in cc_pairs
]
latest_index_attempts = get_latest_index_attempts(
secondary_index=secondary_index,
db_session=db_session,
# see https://stackoverflow.com/questions/75758327/
# sqlalchemy-method-connection-for-bind-is-already-in-progress
# for why we can't pass in the current db_session to these functions
(
cc_pairs,
latest_index_attempts,
latest_finished_index_attempts,
) = run_functions_tuples_in_parallel(
[
(
# Gets the connector/credential pairs for the user
get_connector_credential_pairs_for_user_parallel,
(user, get_editable, None, True, True, True),
),
(
# Gets the most recent index attempt for each connector/credential pair
get_latest_index_attempts_parallel,
(secondary_index, True, False),
),
(
# Gets the most recent FINISHED index attempt for each connector/credential pair
get_latest_index_attempts_parallel,
(secondary_index, True, True),
),
]
)
cc_pairs = cast(list[ConnectorCredentialPair], cc_pairs)
latest_index_attempts = cast(list[IndexAttempt], latest_index_attempts)
cc_pair_to_latest_index_attempt = {
(
@@ -658,31 +678,60 @@ def get_connector_indexing_status(
for index_attempt in latest_index_attempts
}
document_count_info = get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=cc_pair_identifiers,
cc_pair_to_latest_finished_index_attempt = {
(
index_attempt.connector_credential_pair.connector_id,
index_attempt.connector_credential_pair.credential_id,
): index_attempt
for index_attempt in latest_finished_index_attempts
}
document_count_info, group_cc_pair_relationships = run_functions_tuples_in_parallel(
[
(
get_document_counts_for_cc_pairs_parallel,
(
[
ConnectorCredentialPairIdentifier(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
for cc_pair in cc_pairs
],
),
),
(
get_cc_pair_groups_for_ids_parallel,
([cc_pair.id for cc_pair in cc_pairs],),
),
]
)
document_count_info = cast(list[tuple[int, int, int]], document_count_info)
group_cc_pair_relationships = cast(
list[UserGroup__ConnectorCredentialPair], group_cc_pair_relationships
)
cc_pair_to_document_cnt = {
(connector_id, credential_id): cnt
for connector_id, credential_id, cnt in document_count_info
}
group_cc_pair_relationships = get_cc_pair_groups_for_ids(
db_session=db_session,
cc_pair_ids=[cc_pair.id for cc_pair in cc_pairs],
)
group_cc_pair_relationships_dict: dict[int, list[int]] = {}
for relationship in group_cc_pair_relationships:
group_cc_pair_relationships_dict.setdefault(relationship.cc_pair_id, []).append(
relationship.user_group_id
)
search_settings: SearchSettings | None = None
if not secondary_index:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_secondary_search_settings(db_session)
connector_to_cc_pair_ids: dict[int, list[int]] = {}
for cc_pair in cc_pairs:
connector_to_cc_pair_ids.setdefault(cc_pair.connector_id, []).append(cc_pair.id)
get_search_settings = (
get_secondary_search_settings
if secondary_index
else get_current_search_settings
)
search_settings = get_search_settings(db_session)
for cc_pair in cc_pairs:
# TODO remove this to enable ingestion API
if cc_pair.name == "DefaultCCPair":
@@ -705,11 +754,8 @@ def get_connector_indexing_status(
(connector.id, credential.id)
)
latest_finished_attempt = get_latest_index_attempt_for_cc_pair_id(
db_session=db_session,
connector_credential_pair_id=cc_pair.id,
secondary_index=secondary_index,
only_finished=True,
latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get(
(connector.id, credential.id)
)
indexing_statuses.append(
@@ -718,7 +764,9 @@ def get_connector_indexing_status(
name=cc_pair.name,
in_progress=in_progress,
cc_pair_status=cc_pair.status,
connector=ConnectorSnapshot.from_connector_db_model(connector),
connector=ConnectorSnapshot.from_connector_db_model(
connector, connector_to_cc_pair_ids.get(connector.id, [])
),
credential=CredentialSnapshot.from_credential_db_model(credential),
access_type=cc_pair.access_type,
owner=credential.user.email if credential.user else "",

View File

@@ -83,7 +83,9 @@ class ConnectorSnapshot(ConnectorBase):
source: DocumentSource
@classmethod
def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
def from_connector_db_model(
cls, connector: Connector, credential_ids: list[int] | None = None
) -> "ConnectorSnapshot":
return ConnectorSnapshot(
id=connector.id,
name=connector.name,
@@ -92,9 +94,10 @@ class ConnectorSnapshot(ConnectorBase):
connector_specific_config=connector.connector_specific_config,
refresh_freq=connector.refresh_freq,
prune_freq=connector.prune_freq,
credential_ids=[
association.credential.id for association in connector.credentials
],
credential_ids=(
credential_ids
or [association.credential.id for association in connector.credentials]
),
indexing_start=connector.indexing_start,
time_created=connector.time_created,
time_updated=connector.time_updated,

View File

@@ -1,15 +1,18 @@
import asyncio
import datetime
import io
import json
import os
import uuid
from collections.abc import Callable
from collections.abc import Generator
from datetime import timedelta
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
@@ -44,6 +47,7 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import set_as_latest_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import update_chat_session
from onyx.db.chat_search import search_chat_sessions
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.feedback import create_chat_message_feedback
@@ -65,10 +69,13 @@ from onyx.secondary_llm_flows.chat_session_naming import (
from onyx.server.query_and_chat.models import ChatFeedbackRequest
from onyx.server.query_and_chat.models import ChatMessageIdentifier
from onyx.server.query_and_chat.models import ChatRenameRequest
from onyx.server.query_and_chat.models import ChatSearchResponse
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import ChatSessionDetailResponse
from onyx.server.query_and_chat.models import ChatSessionDetails
from onyx.server.query_and_chat.models import ChatSessionGroup
from onyx.server.query_and_chat.models import ChatSessionsResponse
from onyx.server.query_and_chat.models import ChatSessionSummary
from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
from onyx.server.query_and_chat.models import CreateChatMessageRequest
from onyx.server.query_and_chat.models import CreateChatSessionID
@@ -794,3 +801,84 @@ def fetch_chat_file(
file_io = file_store.read_file(file_id, mode="b")
return StreamingResponse(file_io, media_type=media_type)
@router.get("/search")
async def search_chats(
    query: str | None = Query(None),
    page: int = Query(1),
    page_size: int = Query(10),
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatSearchResponse:
    """
    Search for chat sessions based on the provided query.
    If no query is provided, returns recent chat sessions.
    """
    # The DB layer handles matching, filtering, and pagination.
    chat_sessions, has_more = search_chat_sessions(
        user_id=user.id if user else None,
        db_session=db_session,
        query=query,
        page=page,
        page_size=page_size,
        include_deleted=False,
        include_onyxbot_flows=False,
    )

    # Recency bucket boundaries.
    # NOTE(review): naive local "now" is compared against session timestamps;
    # confirm time_created's timezone matches before trusting the buckets.
    today = datetime.datetime.now().date()
    yesterday = today - timedelta(days=1)
    week_cutoff = today - timedelta(days=7)
    month_cutoff = today - timedelta(days=30)

    # Insertion order here fixes the order of groups in the response.
    buckets: dict[str, list[ChatSessionSummary]] = {
        "Today": [],
        "Yesterday": [],
        "This Week": [],
        "This Month": [],
        "Older": [],
    }

    for chat_session in chat_sessions:
        created = chat_session.time_created.date()
        if created == today:
            title = "Today"
        elif created == yesterday:
            title = "Yesterday"
        elif created > week_cutoff:
            title = "This Week"
        elif created > month_cutoff:
            title = "This Month"
        else:
            title = "Older"

        buckets[title].append(
            ChatSessionSummary(
                id=chat_session.id,
                name=chat_session.description,
                persona_id=chat_session.persona_id,
                time_created=chat_session.time_created,
                shared_status=chat_session.shared_status,
                folder_id=chat_session.folder_id,
                current_alternate_model=chat_session.current_alternate_model,
                current_temperature_override=chat_session.temperature_override,
            )
        )

    # Only non-empty buckets make it into the response.
    groups = [
        ChatSessionGroup(title=title, chats=chats)
        for title, chats in buckets.items()
        if chats
    ]

    return ChatSearchResponse(
        groups=groups,
        has_more=has_more,
        next_page=page + 1 if has_more else None,
    )

View File

@@ -24,6 +24,7 @@ from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.tools.models import ToolCallFinalResult
if TYPE_CHECKING:
pass
@@ -282,3 +283,35 @@ class AdminSearchRequest(BaseModel):
class AdminSearchResponse(BaseModel):
documents: list[SearchDoc]
class ChatSessionSummary(BaseModel):
    """Lightweight view of a single chat session, used in chat-search results."""
    id: UUID
    # Display name of the session (built from its description); may be unset.
    name: str | None = None
    persona_id: int | None = None
    time_created: datetime
    shared_status: ChatSessionSharedStatus
    folder_id: int | None = None
    # Per-session model override, if the user picked one.
    current_alternate_model: str | None = None
    # Per-session temperature override, if any.
    current_temperature_override: float | None = None
class ChatSessionGroup(BaseModel):
    """A titled recency bucket of chat sessions (e.g. "Today", "Yesterday")."""
    title: str
    chats: list[ChatSessionSummary]
class ChatSearchResponse(BaseModel):
    """Paginated chat-search results, grouped into recency buckets."""
    groups: list[ChatSessionGroup]
    has_more: bool
    # Index of the next page to request, or None when this is the last page.
    next_page: int | None = None
class ChatSearchRequest(BaseModel):
    """Parameters for chat search; a None query returns recent sessions."""
    query: str | None = None
    page: int = 1
    page_size: int = 10
class CreateChatResponse(BaseModel):
chat_session_id: str

View File

@@ -37,7 +37,7 @@ langchainhub==0.1.21
langgraph==0.2.72
langgraph-checkpoint==2.0.13
langgraph-sdk==0.1.44
litellm==1.60.2
litellm==1.61.16
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45

View File

@@ -12,5 +12,5 @@ torch==2.2.0
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.60.2
litellm==1.61.16
sentry-sdk[fastapi,celery,starlette]==2.14.0

View File

@@ -10,7 +10,8 @@ from onyx.connectors.onyx_jira.connector import JiraConnector
@pytest.fixture
def jira_connector() -> JiraConnector:
connector = JiraConnector(
"https://danswerai.atlassian.net/jira/software/c/projects/AS/boards/6",
jira_base_url="https://danswerai.atlassian.net",
project_key="AS",
comment_email_blacklist=[],
)
connector.load_credentials(

View File

@@ -4,6 +4,10 @@ from onyx.connectors.models import Document
from onyx.connectors.web.connector import WEB_CONNECTOR_VALID_SETTINGS
from onyx.connectors.web.connector import WebConnector
EXPECTED_QUOTE = (
"If you can't explain it to a six year old, you don't understand it yourself."
)
# NOTE(rkuo): we will probably need to adjust this test to point at our own test site
# to avoid depending on a third party site
@@ -11,7 +15,7 @@ from onyx.connectors.web.connector import WebConnector
def web_connector(request: pytest.FixtureRequest) -> WebConnector:
scroll_before_scraping = request.param
connector = WebConnector(
base_url="https://developer.onewelcome.com",
base_url="https://quotes.toscrape.com/scroll",
web_connector_type=WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value,
scroll_before_scraping=scroll_before_scraping,
)
@@ -28,7 +32,7 @@ def test_web_connector_scroll(web_connector: WebConnector) -> None:
assert len(all_docs) == 1
doc = all_docs[0]
assert "Onegini Identity Cloud" in doc.sections[0].text
assert EXPECTED_QUOTE in doc.sections[0].text
@pytest.mark.parametrize("web_connector", [False], indirect=True)
@@ -41,4 +45,4 @@ def test_web_connector_no_scroll(web_connector: WebConnector) -> None:
assert len(all_docs) == 1
doc = all_docs[0]
assert "Onegini Identity Cloud" not in doc.sections[0].text
assert EXPECTED_QUOTE not in doc.sections[0].text

View File

@@ -71,12 +71,13 @@ def litellm_embedding_model() -> EmbeddingModel:
normalize=True,
query_prefix=None,
passage_prefix=None,
api_key=os.getenv("LITE_LLM_API_KEY"),
api_key=os.getenv("LITELLM_API_KEY"),
provider_type=EmbeddingProvider.LITELLM,
api_url=os.getenv("LITE_LLM_API_URL"),
api_url=os.getenv("LITELLM_API_URL"),
)
@pytest.mark.skip(reason="re-enable when we can get the correct litellm key and url")
def test_litellm_embedding(litellm_embedding_model: EmbeddingModel) -> None:
    # Expect 1536-dim embeddings for both a normal-length and an over-length
    # sample (the latter presumably exercises server-side truncation/chunking
    # -- TODO confirm which behavior applies).
    _run_embeddings(VALID_SAMPLE, litellm_embedding_model, 1536)
    _run_embeddings(TOO_LONG_SAMPLE, litellm_embedding_model, 1536)
@@ -117,6 +118,11 @@ def azure_embedding_model() -> EmbeddingModel:
)
def test_azure_embedding(azure_embedding_model: EmbeddingModel) -> None:
    # Azure should return 1536-dim vectors for both a normal-length and an
    # over-length sample (over-length handling -- TODO confirm: truncation?).
    _run_embeddings(VALID_SAMPLE, azure_embedding_model, 1536)
    _run_embeddings(TOO_LONG_SAMPLE, azure_embedding_model, 1536)
# NOTE (chris): this test doesn't work, and I do not know why
# def test_azure_embedding_model_rate_limit(azure_embedding_model: EmbeddingModel):
# """NOTE: this test relies on a very low rate limit for the Azure API +

View File

@@ -0,0 +1,37 @@
services:
indexing_model_server:
image: onyxdotapp/onyx-model-server:${IMAGE_TAG:-latest}
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- INDEX_BATCH_SIZE=${INDEX_BATCH_SIZE:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- INDEXING_ONLY=True
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
- CLIENT_EMBEDDING_TIMEOUT=${CLIENT_EMBEDDING_TIMEOUT:-}
# Analytics Configs
- SENTRY_DSN=${SENTRY_DSN:-}
volumes:
# Not necessary, this is just to reduce download time during startup
- indexing_huggingface_model_cache:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
ports:
      - "9000:9000" # expose the model server port to the host
volumes:
indexing_huggingface_model_cache:

View File

@@ -68,6 +68,28 @@ const nextConfig = {
},
];
},
async rewrites() {
return [
{
source: "/api/docs/:path*", // catch /api/docs and /api/docs/...
destination: `${
process.env.INTERNAL_URL || "http://localhost:8080"
}/docs/:path*`,
},
{
source: "/api/docs", // if you also need the exact /api/docs
destination: `${
process.env.INTERNAL_URL || "http://localhost:8080"
}/docs`,
},
{
source: "/openapi.json",
destination: `${
process.env.INTERNAL_URL || "http://localhost:8080"
}/openapi.json`,
},
];
},
};
// Sentry configuration for error monitoring:

View File

@@ -142,6 +142,7 @@ import {
import { Button } from "@/components/ui/button";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import { MessageChannel } from "node:worker_threads";
import { ChatSearchModal } from "./chat_search/ChatSearchModal";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@@ -870,6 +871,7 @@ export function ChatPage({
}, [liveAssistant]);
const filterManager = useFilters();
const [isChatSearchModalOpen, setIsChatSearchModalOpen] = useState(false);
const [currentFeedback, setCurrentFeedback] = useState<
[FeedbackType, number] | null
@@ -2329,6 +2331,11 @@ export function ChatPage({
/>
)}
<ChatSearchModal
open={isChatSearchModalOpen}
onCloseModal={() => setIsChatSearchModalOpen(false)}
/>
{retrievalEnabled && documentSidebarVisible && settings?.isMobile && (
<div className="md:hidden">
<Modal
@@ -2436,6 +2443,9 @@ export function ChatPage({
>
<div className="w-full relative">
<HistorySidebar
toggleChatSessionSearchModal={() =>
setIsChatSearchModalOpen((open) => !open)
}
liveAssistant={liveAssistant}
setShowAssistantsModal={setShowAssistantsModal}
explicitlyUntoggle={explicitlyUntoggle}
@@ -2452,6 +2462,7 @@ export function ChatPage({
showDeleteAllModal={() => setShowDeleteAllModal(true)}
/>
</div>
<div
className={`
flex-none

View File

@@ -0,0 +1,31 @@
import React from "react";
import { ChatSearchItem } from "./ChatSearchItem";
import { ChatSessionSummary } from "./interfaces";
interface ChatSearchGroupProps {
title: string;
chats: ChatSessionSummary[];
onSelectChat: (id: string) => void;
}
/**
 * A titled group of chat sessions (e.g. "Today", "Yesterday") with a sticky
 * header that stays visible while the group's list scrolls beneath it.
 *
 * Fix: the sticky-header className contained a duplicated `px-4` utility.
 */
export function ChatSearchGroup({
  title,
  chats,
  onSelectChat,
}: ChatSearchGroupProps) {
  return (
    <div className="mb-4">
      {/* -top-1 (rather than top-0) hides the seam above the header while stuck */}
      <div className="sticky -top-1 mt-1 z-10 bg-[#fff]/90 dark:bg-gray-800/90 py-2 px-4">
        <div className="text-xs font-medium leading-4 text-gray-600 dark:text-gray-400">
          {title}
        </div>
      </div>
      <ol>
        {chats.map((chat) => (
          <ChatSearchItem key={chat.id} chat={chat} onSelect={onSelectChat} />
        ))}
      </ol>
    </div>
  );
}

View File

@@ -0,0 +1,30 @@
import React from "react";
import { MessageSquare } from "lucide-react";
import { ChatSessionSummary } from "./interfaces";
interface ChatSearchItemProps {
chat: ChatSessionSummary;
onSelect: (id: string) => void;
}
// One row in the chat-search results: message icon, session name, and (on
// hover) the creation date. Clicking anywhere on the row selects the session.
export function ChatSearchItem({ chat, onSelect }: ChatSearchItemProps) {
  const displayName = chat.name || "Untitled Chat";
  const createdOn = new Date(chat.time_created).toLocaleDateString();

  return (
    <li>
      <div className="cursor-pointer" onClick={() => onSelect(chat.id)}>
        <div className="group relative flex flex-col rounded-lg px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-800">
          <div className="flex items-center">
            <MessageSquare className="h-5 w-5 text-neutral-600 dark:text-neutral-400" />
            <div className="relative grow overflow-hidden whitespace-nowrap pl-4">
              <div className="text-sm dark:text-neutral-200">{displayName}</div>
            </div>
            <div className="opacity-0 group-hover:opacity-100 transition-opacity text-xs text-neutral-500 dark:text-neutral-400">
              {createdOn}
            </div>
          </div>
        </div>
      </div>
    </li>
  );
}

View File

@@ -0,0 +1,122 @@
import React, { useRef } from "react";
import { Dialog, DialogContent } from "@/components/ui/dialog";
import { ScrollArea } from "@/components/ui/scroll-area";
import { ChatSearchGroup } from "./ChatSearchGroup";
import { NewChatButton } from "./NewChatButton";
import { useChatSearch } from "./hooks/useChatSearch";
import { LoadingSpinner } from "./LoadingSpinner";
import { useRouter } from "next/navigation";
import { SearchInput } from "./components/SearchInput";
import { ChatSearchSkeletonList } from "./components/ChatSearchSkeleton";
import { useIntersectionObserver } from "./hooks/useIntersectionObserver";
interface ChatSearchModalProps {
open: boolean;
onCloseModal: () => void;
}
/**
 * Modal for searching past chat sessions.
 *
 * Owns no data itself: search state, grouped results and pagination come from
 * useChatSearch; infinite scrolling is driven by an IntersectionObserver
 * sentinel placed at the bottom of the result list.
 */
export function ChatSearchModal({ open, onCloseModal }: ChatSearchModalProps) {
  const {
    searchQuery,
    setSearchQuery,
    chatGroups,
    isLoading,
    isSearching,
    hasMore,
    fetchMoreChats,
  } = useChatSearch();

  // Reset the query on close so reopening starts from a clean slate.
  const onClose = () => {
    setSearchQuery("");
    onCloseModal();
  };

  const router = useRouter();
  const scrollAreaRef = useRef<HTMLDivElement>(null);

  // Sentinel observed for infinite scroll; disabled while closed, exhausted,
  // or already loading to avoid duplicate page fetches.
  // NOTE(review): scrollAreaRef.current is null on the first render and the
  // observer's root option is not re-evaluated afterwards -- confirm intended.
  const { targetRef } = useIntersectionObserver({
    root: scrollAreaRef.current,
    onIntersect: fetchMoreChats,
    enabled: open && hasMore && !isLoading,
  });

  // Navigate to the chosen session and dismiss the modal.
  const handleChatSelect = (chatId: string) => {
    router.push(`/chat?chatId=${chatId}`);
    onClose();
  };

  // Start a fresh chat; no session is created here, just navigation.
  const handleNewChat = async () => {
    try {
      onClose();
      router.push(`/chat`);
    } catch (error) {
      console.error("Error creating new chat:", error);
    }
  };

  return (
    <Dialog open={open} onOpenChange={(open) => !open && onClose()}>
      <DialogContent
        hideCloseIcon
        className="!rounded-xl overflow-hidden p-0 w-full max-w-2xl"
        backgroundColor="bg-neutral-950/20 shadow-xl"
      >
        <div className="w-full flex flex-col bg-white dark:bg-neutral-800 h-[80vh] max-h-[600px]">
          {/* Sticky search bar */}
          <div className="sticky top-0 z-20 px-6 py-3 w-full flex items-center justify-between bg-white dark:bg-neutral-800 border-b border-neutral-200 dark:border-neutral-700">
            <SearchInput
              searchQuery={searchQuery}
              setSearchQuery={setSearchQuery}
              isSearching={isSearching}
            />
          </div>
          <ScrollArea
            className="flex-grow bg-white relative dark:bg-neutral-800"
            ref={scrollAreaRef}
            type="auto"
          >
            <div className="px-4 py-2">
              <NewChatButton onClick={handleNewChat} />
              {/* Display states: searching skeleton -> initial-load spinner ->
                  grouped results with infinite-scroll sentinel -> empty state */}
              {isSearching ? (
                <ChatSearchSkeletonList />
              ) : isLoading && chatGroups.length === 0 ? (
                <div className="py-8">
                  <LoadingSpinner size="large" className="mx-auto" />
                </div>
              ) : chatGroups.length > 0 ? (
                <>
                  {chatGroups.map((group, groupIndex) => (
                    <ChatSearchGroup
                      key={groupIndex}
                      title={group.title}
                      chats={group.chats}
                      onSelectChat={handleChatSelect}
                    />
                  ))}
                  {/* Intersection target that triggers loading the next page */}
                  <div ref={targetRef} className="py-4">
                    {isLoading && hasMore && (
                      <LoadingSpinner className="mx-auto" />
                    )}
                    {!hasMore && chatGroups.length > 0 && (
                      <div className="text-center text-xs text-neutral-500 dark:text-neutral-400 py-2">
                        No more chats to load
                      </div>
                    )}
                  </div>
                </>
              ) : (
                !isLoading && (
                  <div className="px-4 py-3 text-sm text-neutral-500 dark:text-neutral-400">
                    No chats found
                  </div>
                )
              )}
            </div>
          </ScrollArea>
        </div>
      </DialogContent>
    </Dialog>
  );
}

View File

@@ -0,0 +1,25 @@
import React from "react";
interface LoadingSpinnerProps {
size?: "small" | "medium" | "large";
className?: string;
}
/** Animated circular spinner; `size` selects diameter and border width. */
export function LoadingSpinner({
  size = "medium",
  className = "",
}: LoadingSpinnerProps) {
  let dimensions: string;
  switch (size) {
    case "small":
      dimensions = "h-4 w-4 border-2";
      break;
    case "large":
      dimensions = "h-8 w-8 border-3";
      break;
    default:
      dimensions = "h-6 w-6 border-2";
      break;
  }

  return (
    <div className={`flex justify-center items-center ${className}`}>
      <div
        className={`${dimensions} animate-spin rounded-full border-solid border-gray-300 border-t-gray-600 dark:border-gray-600 dark:border-t-gray-300`}
      />
    </div>
  );
}

View File

@@ -0,0 +1,22 @@
import React from "react";
import { PlusCircle } from "lucide-react";
import { NewChatIcon } from "@/components/icons/icons";
interface NewChatButtonProps {
onClick: () => void;
}
// Row-styled button that starts a fresh chat session when clicked.
export function NewChatButton({ onClick }: NewChatButtonProps) {
  const rowClasses =
    "group relative flex items-center rounded-lg px-4 py-3 hover:bg-neutral-100 dark:bg-neutral-800 dark:hover:bg-neutral-700";

  return (
    <div className="mb-2">
      <div className="cursor-pointer" onClick={onClick}>
        <div className={rowClasses}>
          <NewChatIcon className="h-5 w-5 text-neutral-600 dark:text-neutral-400" />
          <div className="relative grow overflow-hidden whitespace-nowrap pl-4">
            <div className="text-sm dark:text-neutral-200">New Chat</div>
          </div>
        </div>
      </div>
    </div>
  );
}

View File

@@ -0,0 +1,25 @@
import React from "react";
// Pulsing placeholder matching ChatSearchItem's layout (icon circle plus two
// text bars), shown while search results are loading.
export function ChatSearchItemSkeleton() {
  return (
    <div className="animate-pulse px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-700 rounded-lg">
      <div className="flex items-center">
        <div className="h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-700"></div>
        <div className="ml-4 flex-1">
          <div className="h-2 my-1 w-3/4 bg-neutral-200 dark:bg-neutral-700 rounded"></div>
          <div className="mt-2 h-3 w-1/2 bg-neutral-200 dark:bg-neutral-700 rounded"></div>
        </div>
      </div>
    </div>
  );
}
// Placeholder list of five skeleton rows shown while a search is in flight.
export function ChatSearchSkeletonList() {
  const placeholders = Array.from({ length: 5 }, (_, idx) => (
    <ChatSearchItemSkeleton key={idx} />
  ));

  return <div>{placeholders}</div>;
}

View File

@@ -0,0 +1,42 @@
import React from "react";
import { Input } from "@/components/ui/input";
import { XIcon } from "lucide-react";
import { LoadingSpinner } from "../LoadingSpinner";
interface SearchInputProps {
searchQuery: string;
setSearchQuery: (query: string) => void;
isSearching: boolean;
}
/**
 * Borderless search box for the chat-search modal. Fully controlled: state
 * lives in the parent (useChatSearch). While a search is in flight a small
 * spinner replaces the clear (X) button on the right edge.
 */
export function SearchInput({
  searchQuery,
  setSearchQuery,
  isSearching,
}: SearchInputProps) {
  return (
    <div className="relative w-full">
      <div className="flex items-center">
        <Input
          removeFocusRing
          className="w-full !focus-visible:ring-offset-0 !focus-visible:ring-none !focus-visible:ring-0 hover:focus-none border-none bg-transparent placeholder:text-neutral-400 focus:border-transparent focus:outline-none focus:ring-0 dark:placeholder:text-neutral-500 dark:text-neutral-200"
          placeholder="Search chats..."
          value={searchQuery}
          onChange={(e) => setSearchQuery(e.target.value)}
        />
        {/* Right-side adornment: spinner while searching, otherwise a clear button */}
        {searchQuery &&
          (isSearching ? (
            <div className="absolute right-2 top-1/2 -translate-y-1/2">
              <LoadingSpinner size="small" />
            </div>
          ) : (
            <XIcon
              size={16}
              className="absolute right-2 top-1/2 -translate-y-1/2 cursor-pointer"
              onClick={() => setSearchQuery("")}
            />
          ))}
      </div>
    </div>
  );
}

View File

@@ -0,0 +1,255 @@
import { useState, useEffect, useCallback, useRef } from "react";
import { fetchChatSessions } from "../utils";
import { ChatSessionGroup, ChatSessionSummary } from "../interfaces";
interface UseChatSearchOptions {
pageSize?: number;
}
interface UseChatSearchResult {
searchQuery: string;
setSearchQuery: (query: string) => void;
chatGroups: ChatSessionGroup[];
isLoading: boolean;
isSearching: boolean;
hasMore: boolean;
fetchMoreChats: () => Promise<void>;
refreshChats: () => Promise<void>;
}
/**
 * Hook powering the chat-search modal.
 *
 * Responsibilities:
 *  - debounced text search over chat sessions (500ms before fetching)
 *  - paginated "infinite scroll" loading via fetchMoreChats
 *  - cancellation of stale requests: each search gets an AbortController plus
 *    a monotonically increasing id, so out-of-order responses can never
 *    clobber newer state even if their abort did not land in time.
 */
export function useChatSearch(
  options: UseChatSearchOptions = {}
): UseChatSearchResult {
  const { pageSize = 10 } = options;
  const [searchQuery, setSearchQueryInternal] = useState("");
  const [chatGroups, setChatGroups] = useState<ChatSessionGroup[]>([]);
  const [isLoading, setIsLoading] = useState(false);
  const [isSearching, setIsSearching] = useState(false);
  const [hasMore, setHasMore] = useState(true);
  // Trailing-edge debounced copy of isSearching, exposed to consumers so the
  // skeleton UI does not flicker on very fast responses.
  const [debouncedIsSearching, setDebouncedIsSearching] = useState(false);
  const [page, setPage] = useState(1);
  const searchTimeoutRef = useRef<NodeJS.Timeout | null>(null);
  const currentAbortController = useRef<AbortController | null>(null);
  // Monotonic id of the most recent search; responses tagged with an older id
  // are ignored.
  const activeSearchIdRef = useRef<number>(0);
  const PAGE_SIZE = pageSize;

  useEffect(() => {
    // Turning the indicator ON is immediate; turning it OFF is delayed 300ms
    // so brief gaps between consecutive requests don't flash the UI.
    if (!isSearching) {
      const timeout = setTimeout(() => {
        setDebouncedIsSearching(isSearching);
      }, 300);

      // Keep the timeout reference so cleanup can cancel it.
      const timeoutRef = timeout;
      return () => clearTimeout(timeoutRef);
    } else {
      setDebouncedIsSearching(isSearching);
    }
  }, [isSearching, debouncedIsSearching]);

  // Merge two lists of date-bucketed groups: chats sharing a bucket title are
  // concatenated, then buckets are re-sorted into the canonical time order.
  const mergeGroups = useCallback(
    (
      existingGroups: ChatSessionGroup[],
      newGroups: ChatSessionGroup[]
    ): ChatSessionGroup[] => {
      const mergedGroups: Record<string, ChatSessionSummary[]> = {};

      // Initialize with existing groups
      existingGroups.forEach((group) => {
        mergedGroups[group.title] = [
          ...(mergedGroups[group.title] || []),
          ...group.chats,
        ];
      });

      // Merge in new groups
      newGroups.forEach((group) => {
        mergedGroups[group.title] = [
          ...(mergedGroups[group.title] || []),
          ...group.chats,
        ];
      });

      // Convert back to array format, ordered by fixed time-period buckets.
      return Object.entries(mergedGroups)
        .map(([title, chats]) => ({ title, chats }))
        .sort((a, b) => {
          const order = [
            "Today",
            "Yesterday",
            "This Week",
            "This Month",
            "Older",
          ];
          return order.indexOf(a.title) - order.indexOf(b.title);
        });
    },
    []
  );

  // Fetch page 1 for a (possibly empty) query. Results are applied only if
  // this call is still the active search and was not aborted.
  const fetchInitialChats = useCallback(
    async (query: string, searchId: number, signal?: AbortSignal) => {
      try {
        setIsLoading(true);
        setPage(1);

        const response = await fetchChatSessions({
          query,
          page: 1,
          page_size: PAGE_SIZE,
          signal,
        });

        // Only update state if this is still the active search
        if (activeSearchIdRef.current === searchId && !signal?.aborted) {
          setChatGroups(response.groups);
          setHasMore(response.has_more);
        }
      } catch (error: any) {
        // AbortErrors are expected whenever a newer search supersedes this one.
        if (
          error?.name !== "AbortError" &&
          activeSearchIdRef.current === searchId
        ) {
          console.error("Error fetching chats:", error);
        }
      } finally {
        // Only update loading state if this is still the active search
        if (activeSearchIdRef.current === searchId) {
          setIsLoading(false);
          setIsSearching(false);
        }
      }
    },
    [PAGE_SIZE]
  );

  // Load the next page for the current query (infinite scroll). Guarded
  // against concurrent loads and exhausted result sets.
  const fetchMoreChats = useCallback(async () => {
    if (isLoading || !hasMore) return;

    setIsLoading(true);

    if (currentAbortController.current) {
      currentAbortController.current.abort();
    }
    const newSearchId = activeSearchIdRef.current + 1;
    activeSearchIdRef.current = newSearchId;

    const controller = new AbortController();
    currentAbortController.current = controller;
    const localSignal = controller.signal;

    try {
      const nextPage = page + 1;
      const response = await fetchChatSessions({
        query: searchQuery,
        page: nextPage,
        page_size: PAGE_SIZE,
        signal: localSignal,
      });

      if (activeSearchIdRef.current === newSearchId && !localSignal.aborted) {
        // Merge rather than concatenate so a bucket split across pages
        // (e.g. "Today" on pages 1 and 2) stays a single group.
        setChatGroups((prevGroups) => mergeGroups(prevGroups, response.groups));
        setHasMore(response.has_more);
        setPage(nextPage);
      }
    } catch (error: any) {
      if (
        error?.name !== "AbortError" &&
        activeSearchIdRef.current === newSearchId
      ) {
        console.error("Error fetching more chats:", error);
      }
    } finally {
      if (activeSearchIdRef.current === newSearchId) {
        setIsLoading(false);
      }
    }
  }, [isLoading, hasMore, page, searchQuery, PAGE_SIZE, mergeGroups]);

  // Public setter: updates the input immediately, cancels any pending or
  // in-flight search, then schedules a debounced (500ms) fetch for non-empty
  // queries -- or fetches unfiltered results right away for empty ones.
  const setSearchQuery = useCallback(
    (query: string) => {
      setSearchQueryInternal(query);

      // Clear any pending timeouts
      if (searchTimeoutRef.current) {
        clearTimeout(searchTimeoutRef.current);
        searchTimeoutRef.current = null;
      }

      // Abort any in-flight requests
      if (currentAbortController.current) {
        currentAbortController.current.abort();
        currentAbortController.current = null;
      }

      // Create a new search ID
      const newSearchId = activeSearchIdRef.current + 1;
      activeSearchIdRef.current = newSearchId;

      if (query.trim()) {
        setIsSearching(true);

        const controller = new AbortController();
        currentAbortController.current = controller;

        searchTimeoutRef.current = setTimeout(() => {
          fetchInitialChats(query, newSearchId, controller.signal);
        }, 500);
      } else {
        // For empty queries, clear search state immediately
        setIsSearching(false);

        // Optionally fetch initial unfiltered results
        fetchInitialChats("", newSearchId);
      }
    },
    [fetchInitialChats]
  );

  // Initial fetch on mount (and whenever searchQuery changes externally);
  // cleanup cancels both the debounce timer and the in-flight request.
  useEffect(() => {
    const newSearchId = activeSearchIdRef.current + 1;
    activeSearchIdRef.current = newSearchId;

    const controller = new AbortController();
    currentAbortController.current = controller;

    fetchInitialChats(searchQuery, newSearchId, controller.signal);

    return () => {
      if (searchTimeoutRef.current) {
        clearTimeout(searchTimeoutRef.current);
      }
      controller.abort();
    };
  }, [fetchInitialChats, searchQuery]);

  return {
    searchQuery,
    setSearchQuery,
    chatGroups,
    isLoading,
    isSearching: debouncedIsSearching,
    hasMore,
    fetchMoreChats,
    // Re-run the current query from page 1, superseding any in-flight search.
    refreshChats: () => {
      const newSearchId = activeSearchIdRef.current + 1;
      activeSearchIdRef.current = newSearchId;

      if (currentAbortController.current) {
        currentAbortController.current.abort();
      }

      const controller = new AbortController();
      currentAbortController.current = controller;

      return fetchInitialChats(searchQuery, newSearchId, controller.signal);
    },
  };
}

View File

@@ -0,0 +1,51 @@
import { useEffect, useRef } from "react";
interface UseIntersectionObserverOptions {
root?: Element | null;
rootMargin?: string;
threshold?: number;
onIntersect: () => void;
enabled?: boolean;
}
/**
 * Observe a sentinel element and invoke `onIntersect` whenever it enters the
 * viewport (or the supplied `root`). Attach the returned `targetRef` to the
 * element to watch; set `enabled` to false to pause observation entirely.
 */
export function useIntersectionObserver({
  root = null,
  rootMargin = "0px",
  threshold = 0.1,
  onIntersect,
  enabled = true,
}: UseIntersectionObserverOptions) {
  const targetRef = useRef<HTMLDivElement | null>(null);

  useEffect(() => {
    // Nothing to do (and nothing to clean up) while disabled.
    if (!enabled) {
      return;
    }

    const observer = new IntersectionObserver(
      ([firstEntry]) => {
        if (firstEntry.isIntersecting) {
          onIntersect();
        }
      },
      { root, rootMargin, threshold }
    );

    const sentinel = targetRef.current;
    if (sentinel) {
      observer.observe(sentinel);
    }

    return () => observer.disconnect();
  }, [root, rootMargin, threshold, onIntersect, enabled]);

  return { targetRef };
}

View File

@@ -0,0 +1,34 @@
import { ChatSessionSharedStatus } from "../interfaces";
/** Lightweight summary of one chat session as returned by the search API. */
export interface ChatSessionSummary {
  id: string;
  name: string | null; // null when the session has no (auto-)generated name
  persona_id: number | null;
  time_created: string; // creation timestamp; presumably ISO-8601 -- confirm with backend
  shared_status: ChatSessionSharedStatus;
  folder_id: number | null;
  current_alternate_model: string | null;
  current_temperature_override: number | null;
  highlights?: string[]; // matched-text snippets, only present on search hits
}

/** A date bucket (e.g. "Today") together with the sessions that fall into it. */
export interface ChatSessionGroup {
  title: string;
  chats: ChatSessionSummary[];
}

/** Flat (ungrouped) list of sessions. */
export interface ChatSessionsResponse {
  sessions: ChatSessionSummary[];
}

/** One page of grouped search results. */
export interface ChatSearchResponse {
  groups: ChatSessionGroup[];
  has_more: boolean; // true when another page can be fetched
  next_page: number | null;
}

/** Query parameters accepted by the chat-search endpoint. */
export interface ChatSearchRequest {
  query?: string;
  page?: number; // 1-indexed
  page_size?: number;
}

View File

@@ -0,0 +1,79 @@
import { ChatSearchRequest, ChatSearchResponse } from "./interfaces";
const API_BASE_URL = "/api";
export interface ExtendedChatSearchRequest extends ChatSearchRequest {
include_highlights?: boolean;
signal?: AbortSignal;
}
/**
 * Fetch one page of chat sessions from the search endpoint.
 *
 * Builds the query string from the provided params and forwards the optional
 * AbortSignal so callers (e.g. useChatSearch) can cancel in-flight requests.
 *
 * Fix: `params.signal` was accepted on ExtendedChatSearchRequest but never
 * passed to fetch(), so aborting the controller never cancelled the request.
 *
 * @throws Error when the response status is not ok.
 */
export async function fetchChatSessions(
  params: ExtendedChatSearchRequest = {}
): Promise<ChatSearchResponse> {
  const queryParams = new URLSearchParams();

  if (params.query) {
    queryParams.append("query", params.query);
  }
  if (params.page) {
    queryParams.append("page", params.page.toString());
  }
  if (params.page_size) {
    queryParams.append("page_size", params.page_size.toString());
  }
  if (params.include_highlights !== undefined) {
    queryParams.append(
      "include_highlights",
      params.include_highlights.toString()
    );
  }

  const queryString = queryParams.toString()
    ? `?${queryParams.toString()}`
    : "";

  const response = await fetch(`${API_BASE_URL}/chat/search${queryString}`, {
    method: "GET",
    headers: {
      "Content-Type": "application/json",
    },
    // Forward the abort signal so stale searches are actually cancelled.
    signal: params.signal,
  });

  if (!response.ok) {
    throw new Error(`Failed to fetch chat sessions: ${response.statusText}`);
  }

  return response.json();
}
// POST an empty payload to create a brand-new chat session; returns its id.
// Throws on any non-2xx response.
export async function createNewChat(): Promise<{ chat_session_id: string }> {
  const requestInit: RequestInit = {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({}),
  };

  const response = await fetch(`${API_BASE_URL}/chat/sessions`, requestInit);
  if (!response.ok) {
    throw new Error(`Failed to create new chat: ${response.statusText}`);
  }
  return response.json();
}
// Delete the chat session with the given id. Throws on any non-2xx response;
// resolves with no value on success.
export async function deleteChat(chatId: string): Promise<void> {
  const requestInit: RequestInit = {
    method: "DELETE",
    headers: { "Content-Type": "application/json" },
  };

  const response = await fetch(
    `${API_BASE_URL}/chat/sessions/${chatId}`,
    requestInit
  );
  if (!response.ok) {
    throw new Error(`Failed to delete chat: ${response.statusText}`);
  }
}

View File

@@ -1,4 +1,9 @@
import React, { useState, useEffect } from "react";
import React, {
useState,
useEffect,
useCallback,
useLayoutEffect,
} from "react";
import {
Popover,
PopoverContent,
@@ -28,6 +33,7 @@ import { FiAlertTriangle } from "react-icons/fi";
import { Slider } from "@/components/ui/slider";
import { useUser } from "@/components/user/UserProvider";
import { TruncatedText } from "@/components/ui/truncatedText";
interface LLMPopoverProps {
llmProviders: LLMProviderDescriptor[];
@@ -158,9 +164,7 @@ export default function LLMPopover({
size: 16,
className: "flex-none my-auto text-black",
})}
<span className="line-clamp-1 ">
{getDisplayNameForModel(name)}
</span>
<TruncatedText text={getDisplayNameForModel(name)} />
{(() => {
if (currentAssistant?.llm_model_version_override === name) {
return (

View File

@@ -4,10 +4,7 @@ import React, {
ForwardedRef,
forwardRef,
useContext,
useState,
useCallback,
useLayoutEffect,
useRef,
} from "react";
import Link from "next/link";
import {
@@ -50,9 +47,9 @@ import {
} from "@dnd-kit/sortable";
import { useSortable } from "@dnd-kit/sortable";
import { CSS } from "@dnd-kit/utilities";
import { CirclePlus, CircleX, PinIcon } from "lucide-react";
import { CircleX, PinIcon } from "lucide-react";
import { restrictToVerticalAxis } from "@dnd-kit/modifiers";
import { turborepoTraceAccess } from "next/dist/build/turborepo-access-trace";
import { TruncatedText } from "@/components/ui/truncatedText";
interface HistorySidebarProps {
liveAssistant?: Persona | null;
@@ -69,6 +66,7 @@ interface HistorySidebarProps {
explicitlyUntoggle: () => void;
showDeleteAllModal?: () => void;
setShowAssistantsModal: (show: boolean) => void;
toggleChatSessionSearchModal?: () => void;
}
interface SortableAssistantProps {
@@ -101,24 +99,6 @@ const SortableAssistant: React.FC<SortableAssistantProps> = ({
...(isDragging ? { zIndex: 1000, position: "relative" as const } : {}),
};
const nameRef = useRef<HTMLParagraphElement>(null);
const hiddenNameRef = useRef<HTMLSpanElement>(null);
const [isNameTruncated, setIsNameTruncated] = useState(false);
useLayoutEffect(() => {
const checkTruncation = () => {
if (nameRef.current && hiddenNameRef.current) {
const visibleWidth = nameRef.current.offsetWidth;
const fullTextWidth = hiddenNameRef.current.offsetWidth;
setIsNameTruncated(fullTextWidth > visibleWidth);
}
};
checkTruncation();
window.addEventListener("resize", checkTruncation);
return () => window.removeEventListener("resize", checkTruncation);
}, [assistant.name]);
return (
<div
ref={setNodeRef}
@@ -146,27 +126,11 @@ const SortableAssistant: React.FC<SortableAssistantProps> = ({
} relative flex items-center gap-x-2 py-1 px-2 rounded-md`}
>
<AssistantIcon assistant={assistant} size={16} className="flex-none" />
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<p
ref={nameRef}
className="text-base text-left w-fit line-clamp-1 text-ellipsis text-black dark:text-[#D4D4D4]"
>
{assistant.name}
</p>
</TooltipTrigger>
{isNameTruncated && (
<TooltipContent>{assistant.name}</TooltipContent>
)}
</Tooltip>
</TooltipProvider>
<span
ref={hiddenNameRef}
className="absolute left-[-9999px] whitespace-nowrap"
>
{assistant.name}
</span>
<TruncatedText
className="text-base mr-4 text-left w-fit line-clamp-1 text-ellipsis text-black dark:text-[#D4D4D4]"
text={assistant.name}
/>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
@@ -217,6 +181,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
toggleSidebar,
removeToggle,
showShareModal,
toggleChatSessionSearchModal,
showDeleteModal,
showDeleteAllModal,
},
@@ -355,7 +320,6 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
)}
</div>
)}
<div className="h-full relative overflow-x-hidden overflow-y-auto">
<div className="flex px-4 font-normal text-sm gap-x-2 leading-normal text-text-500/80 dark:text-[#D4D4D4] items-center font-normal leading-normal">
Assistants
@@ -432,6 +396,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
</div>
<PagesTab
toggleChatSessionSearchModal={toggleChatSessionSearchModal}
showDeleteModal={showDeleteModal}
showShareModal={showShareModal}
closeSidebar={removeToggle}

View File

@@ -17,6 +17,13 @@ import { useState, useCallback, useRef, useContext, useEffect } from "react";
import { Caret } from "@/components/icons/icons";
import { groupSessionsByDateRange } from "../lib";
import React from "react";
import {
Tooltip,
TooltipProvider,
TooltipTrigger,
TooltipContent,
} from "@/components/ui/tooltip";
import { Search } from "lucide-react";
import {
DndContext,
closestCenter,
@@ -101,10 +108,12 @@ export function PagesTab({
showShareModal,
showDeleteModal,
showDeleteAllModal,
toggleChatSessionSearchModal,
}: {
existingChats?: ChatSession[];
currentChatId?: string;
folders?: Folder[];
toggleChatSessionSearchModal?: () => void;
closeSidebar?: () => void;
showShareModal?: (chatSession: ChatSession) => void;
showDeleteModal?: (chatSession: ChatSession) => void;
@@ -318,8 +327,28 @@ export function PagesTab({
<div className="flex flex-col gap-y-2 flex-grow">
{popup}
<div className="px-4 mt-2 group mr-2 bg-background-sidebar dark:bg-transparent z-20">
<div className="flex justify-between text-sm gap-x-2 text-text-300/80 items-center font-normal leading-normal">
<div className="flex group justify-between text-sm gap-x-2 text-text-300/80 items-center font-normal leading-normal">
<p>Chats</p>
<TooltipProvider delayDuration={1000}>
<Tooltip>
<TooltipTrigger asChild>
<button
className="my-auto mr-auto group-hover:opacity-100 opacity-0 transition duration-200 cursor-pointer gap-x-1 items-center text-black text-xs font-medium leading-normal mobile:hidden"
onClick={() => {
toggleChatSessionSearchModal?.();
}}
>
<Search
className="flex-none text-text-mobile-sidebar"
size={12}
/>
</button>
</TooltipTrigger>
<TooltipContent>Search Chats</TooltipContent>
</Tooltip>
</TooltipProvider>
<button
onClick={handleCreateFolder}
className="flex group-hover:opacity-100 opacity-0 transition duration-200 cursor-pointer gap-x-1 items-center text-black text-xs font-medium leading-normal"

View File

@@ -38,7 +38,9 @@ export const ConnectorTitle = ({
const typedConnector = connector as Connector<GithubConfig>;
additionalMetadata.set(
"Repo",
`${typedConnector.connector_specific_config.repo_owner}/${typedConnector.connector_specific_config.repo_name}`
typedConnector.connector_specific_config.repo_name
? `${typedConnector.connector_specific_config.repo_owner}/${typedConnector.connector_specific_config.repo_name}`
: `${typedConnector.connector_specific_config.repo_owner}/*`
);
} else if (connector.source === "gitlab") {
const typedConnector = connector as Connector<GitlabConfig>;

View File

@@ -110,37 +110,38 @@ export default function LogoWithText({
</Tooltip>
</TooltipProvider>
)}
{showArrow && toggleSidebar && (
<TooltipProvider delayDuration={0}>
<Tooltip>
<TooltipTrigger asChild>
<button
className="mr-2 my-auto ml-auto"
onClick={() => {
toggleSidebar();
if (toggled) {
explicitlyUntoggle();
}
}}
>
{!toggled && !combinedSettings?.isMobile ? (
<RightToLineIcon className="mobile:hidden text-sidebar-toggle" />
) : (
<LeftToLineIcon className="mobile:hidden text-sidebar-toggle" />
)}
<FiSidebar
size={20}
className="hidden mobile:block text-text-mobile-sidebar"
/>
</button>
</TooltipTrigger>
<TooltipContent className="!border-none">
{toggled ? `Unpin sidebar` : "Pin sidebar"}
</TooltipContent>
</Tooltip>
</TooltipProvider>
)}
<div className="flex ml-auto gap-x-4">
{showArrow && toggleSidebar && (
<TooltipProvider delayDuration={0}>
<Tooltip>
<TooltipTrigger asChild>
<button
className="mr-2 my-auto"
onClick={() => {
toggleSidebar();
if (toggled) {
explicitlyUntoggle();
}
}}
>
{!toggled && !combinedSettings?.isMobile ? (
<RightToLineIcon className="mobile:hidden text-sidebar-toggle" />
) : (
<LeftToLineIcon className="mobile:hidden text-sidebar-toggle" />
)}
<FiSidebar
size={20}
className="hidden mobile:block text-text-mobile-sidebar"
/>
</button>
</TooltipTrigger>
<TooltipContent className="!border-none">
{toggled ? `Unpin sidebar` : "Pin sidebar"}
</TooltipContent>
</Tooltip>
</TooltipProvider>
)}
</div>
</div>
);
}

View File

@@ -16,12 +16,15 @@ const DialogClose = DialogPrimitive.Close;
const DialogOverlay = React.forwardRef<
React.ElementRef<typeof DialogPrimitive.Overlay>,
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Overlay>
>(({ className, ...props }, ref) => (
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Overlay> & {
backgroundColor?: string;
}
>(({ className, backgroundColor, ...props }, ref) => (
<DialogPrimitive.Overlay
ref={ref}
className={cn(
"fixed inset-0 z-50 bg-neutral-950/80 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0",
backgroundColor || "bg-neutral-950/60",
"fixed inset-0 z-50 data-[state=open]:animate-in data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0",
className
)}
{...props}
@@ -33,10 +36,11 @@ const DialogContent = React.forwardRef<
React.ElementRef<typeof DialogPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Content> & {
hideCloseIcon?: boolean;
backgroundColor?: string;
}
>(({ className, children, hideCloseIcon, ...props }, ref) => (
>(({ className, children, hideCloseIcon, backgroundColor, ...props }, ref) => (
<DialogPortal>
<DialogOverlay />
<DialogOverlay backgroundColor={backgroundColor} />
<DialogPrimitive.Content
ref={ref}
className={cn(

View File

@@ -2,13 +2,20 @@ import * as React from "react";
import { cn } from "@/lib/utils";
const Input = React.forwardRef<HTMLInputElement, React.ComponentProps<"input">>(
({ className, type, ...props }, ref) => {
interface InputProps extends React.ComponentProps<"input"> {
removeFocusRing?: boolean;
}
const Input = React.forwardRef<HTMLInputElement, InputProps>(
({ className, type, removeFocusRing, ...props }, ref) => {
return (
<input
type={type}
className={cn(
"flex h-10 w-full rounded-md border border-neutral-200 bg-white px-3 py-2 text-base ring-offset-white file:border-0 file:bg-transparent file:text-sm file:font-medium file:text-neutral-950 placeholder:text-neutral-500 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-neutral-950 focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 md:text-sm dark:border-neutral-800 dark:bg-neutral-950 dark:ring-offset-neutral-950 dark:file:text-neutral-50 dark:placeholder:text-neutral-400 dark:focus-visible:ring-neutral-300",
"flex h-10 w-full rounded-md border border-neutral-200 bg-white px-3 py-2 text-base ring-offset-white file:border-0 file:bg-transparent file:text-sm file:font-medium file:text-neutral-950 placeholder:text-neutral-500",
removeFocusRing
? ""
: "focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-neutral-950 focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50 md:text-sm dark:border-neutral-800 dark:bg-neutral-950 dark:ring-offset-neutral-950 dark:file:text-neutral-50 dark:placeholder:text-neutral-400 dark:focus-visible:ring-neutral-300",
className
)}
ref={ref}

View File

@@ -0,0 +1,86 @@
import React, {
useState,
useRef,
useLayoutEffect,
HTMLAttributes,
} from "react";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
interface TruncatedTextProps extends HTMLAttributes<HTMLSpanElement> {
text: string;
tooltipClassName?: string;
tooltipSide?: "top" | "right" | "bottom" | "left";
tooltipSideOffset?: number;
}
/**
* Renders passed in text on a single line. If text is truncated,
* shows a tooltip on hover with the full text.
*/
export function TruncatedText({
text,
tooltipClassName,
tooltipSide = "right",
tooltipSideOffset = 5,
className = "",
...rest
}: TruncatedTextProps) {
const [isTruncated, setIsTruncated] = useState(false);
const visibleRef = useRef<HTMLSpanElement>(null);
const hiddenRef = useRef<HTMLSpanElement>(null);
useLayoutEffect(() => {
function checkTruncation() {
if (visibleRef.current && hiddenRef.current) {
const visibleWidth = visibleRef.current.offsetWidth;
const fullTextWidth = hiddenRef.current.offsetWidth;
setIsTruncated(fullTextWidth > visibleWidth);
}
}
checkTruncation();
window.addEventListener("resize", checkTruncation);
return () => window.removeEventListener("resize", checkTruncation);
}, [text]);
return (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<span
ref={visibleRef}
// Ensure the text can actually truncate via line-clamp or overflow
className={`line-clamp-1 break-all flex-grow ${className}`}
{...rest}
>
{text}
</span>
</TooltipTrigger>
{/* Hide offscreen to measure full text width */}
<span
ref={hiddenRef}
className="absolute left-[-9999px] whitespace-nowrap pointer-events-none"
aria-hidden="true"
>
{text}
</span>
{isTruncated && (
<TooltipContent
side={tooltipSide}
sideOffset={tooltipSideOffset}
className={tooltipClassName}
>
<p className="text-xs max-w-[200px] whitespace-normal break-words">
{text}
</p>
</TooltipContent>
)}
</Tooltip>
</TooltipProvider>
);
}

View File

@@ -170,32 +170,59 @@ export const connectorConfigs: Record<
values: [
{
type: "text",
query: "Enter the repository owner:",
query: "Enter the GitHub username or organization:",
label: "Repository Owner",
name: "repo_owner",
optional: false,
},
{
type: "text",
query: "Enter the repository name:",
label: "Repository Name",
name: "repo_name",
optional: false,
type: "tab",
name: "github_mode",
label: "What should we index from GitHub?",
optional: true,
tabs: [
{
value: "repo",
label: "Specific Repository",
fields: [
{
type: "text",
query: "Enter the repository name:",
label: "Repository Name",
name: "repo_name",
optional: false,
},
],
},
{
value: "everything",
label: "Everything",
fields: [
{
type: "string_tab",
label: "Everything",
name: "everything",
description:
"This connector will index all repositories the provided credentials have access to!",
},
],
},
],
},
{
type: "checkbox",
query: "Include pull requests?",
label: "Include pull requests?",
description: "Index pull requests from this repository",
description: "Index pull requests from repositories",
name: "include_prs",
optional: true,
},
{
type: "checkbox",
query: "Include issues?",
label: "Include Issues",
label: "Include Issues?",
name: "include_issues",
description: "Index issues from this repository",
description: "Index issues from repositories",
optional: true,
},
],
@@ -462,14 +489,52 @@ export const connectorConfigs: Record<
},
jira: {
description: "Configure Jira connector",
subtext: `Specify any link to a Jira page below and click "Index" to Index. Based on the provided link, we will index the ENTIRE PROJECT, not just the specified page. For example, entering https://onyx.atlassian.net/jira/software/projects/DAN/boards/1 and clicking the Index button will index the whole DAN Jira project.`,
subtext: `Configure which Jira content to index. You can index everything or specify a particular project.`,
values: [
{
type: "text",
query: "Enter the Jira project URL:",
label: "Jira Project URL",
name: "jira_project_url",
query: "Enter the Jira base URL:",
label: "Jira Base URL",
name: "jira_base_url",
optional: false,
description:
"The base URL of your Jira instance (e.g., https://your-domain.atlassian.net)",
},
{
type: "tab",
name: "indexing_scope",
label: "How Should We Index Your Jira?",
optional: true,
tabs: [
{
value: "everything",
label: "Everything",
fields: [
{
type: "string_tab",
label: "Everything",
name: "everything",
description:
"This connector will index all issues the provided credentials have access to!",
},
],
},
{
value: "project",
label: "Project",
fields: [
{
type: "text",
query: "Enter the project key:",
label: "Project Key",
name: "project_key",
description:
"The key of a specific project to index (e.g., 'PROJ').",
},
],
},
],
defaultTab: "everything",
},
{
type: "list",
@@ -1309,6 +1374,7 @@ export interface ConfluenceConfig {
export interface JiraConfig {
jira_project_url: string;
project_key?: string;
comment_email_blacklist?: string[];
}

View File

@@ -714,10 +714,11 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
"claude-2.1": "Claude 2.1",
"claude-2.0": "Claude 2.0",
"claude-instant-1.2": "Claude Instant 1.2",
"claude-3-5-sonnet-20240620": "Claude 3.5 Sonnet",
"claude-3-5-sonnet-20241022": "Claude 3.5 Sonnet (New)",
"claude-3-5-sonnet-v2@20241022": "Claude 3.5 Sonnet (New)",
"claude-3.5-sonnet-v2@20241022": "Claude 3.5 Sonnet (New)",
"claude-3-5-sonnet-20240620": "Claude 3.5 Sonnet (June 2024)",
"claude-3-5-sonnet-20241022": "Claude 3.5 Sonnet",
"claude-3-7-sonnet-20250219": "Claude 3.7 Sonnet",
"claude-3-5-sonnet-v2@20241022": "Claude 3.5 Sonnet",
"claude-3.5-sonnet-v2@20241022": "Claude 3.5 Sonnet",
"claude-3-5-haiku-20241022": "Claude 3.5 Haiku",
"claude-3-5-haiku@20241022": "Claude 3.5 Haiku",
"claude-3.5-haiku@20241022": "Claude 3.5 Haiku",
@@ -770,6 +771,12 @@ const MODEL_DISPLAY_NAMES: { [key: string]: string } = {
};
export function getDisplayNameForModel(modelName: string): string {
if (modelName.startsWith("bedrock/")) {
const parts = modelName.split("/");
const lastPart = parts[parts.length - 1];
return MODEL_DISPLAY_NAMES[lastPart] || lastPart;
}
return MODEL_DISPLAY_NAMES[modelName] || modelName;
}

View File

@@ -71,6 +71,7 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
// standard claude names
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-7-sonnet-20250219",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
@@ -88,6 +89,7 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
"anthropic.claude-3-haiku-20240307-v1:0",
"anthropic.claude-3-5-sonnet-20240620-v1:0",
"anthropic.claude-3-5-sonnet-20241022-v2:0",
"anthropic.claude-3-7-sonnet-20250219-v1:0",
// google gemini model names
"gemini-1.5-pro",
"gemini-1.5-flash",

View File

@@ -18,7 +18,7 @@ async function verifyAdminPageNavigation(
try {
await expect(page.locator("h1.text-3xl")).toHaveText(pageTitle, {
timeout: 3000,
timeout: 5000,
});
} catch (error) {
console.error(