Compare commits

..

13 Commits

Author SHA1 Message Date
pablonyx
9cf5cdba2c improve scroll 2025-02-22 13:25:23 -08:00
Weves
bdaa293ae4 Fix nginx for prod compose file 2025-02-21 16:57:54 -08:00
pablonyx
5a131f4547 Fix integration tests (#4059) 2025-02-21 15:56:11 -08:00
rkuo-danswer
ffb7d5b85b enable manual testing for model server (#4003)
* trying out a fix

* add ability to manually run model tests

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-21 14:00:32 -08:00
rkuo-danswer
fe8a5d671a don't spam the logs with texts on auth errors (#4085)
* don't spam the logs with texts on auth errors

* refactor the logging a bit

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-21 13:40:07 -08:00
Yuhong Sun
6de53ebf60 README Touchup (#4088) 2025-02-21 13:31:07 -08:00
rkuo-danswer
61d536c782 tool fixes (#4075) 2025-02-21 12:30:33 -08:00
Chris Weaver
e1ff9086a4 Fix LLM selection (#4078) 2025-02-21 11:32:57 -08:00
evan-danswer
ba21bacbbf coerce useLanggraph to boolean (#4084)
* coerce useLanggraph to boolean
2025-02-21 09:43:46 -08:00
pablonyx
158bccc3fc Default on for non-ee (#4083) 2025-02-21 09:11:45 -08:00
Weves
599b7705c2 Fix gitbook connector issues 2025-02-20 15:29:11 -08:00
rkuo-danswer
4958a5355d try more efficient query (#4047) 2025-02-20 12:58:50 -08:00
Chris Weaver
c4b8519381 Add support for sending email invites for single tenant users (#4065) 2025-02-19 21:05:23 -08:00
40 changed files with 609 additions and 474 deletions

View File

@@ -145,7 +145,7 @@ jobs:
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
@@ -157,6 +157,7 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
@@ -199,7 +200,7 @@ jobs:
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |

View File

@@ -1,10 +1,16 @@
name: Connector Tests
name: Model Server Tests
on:
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
workflow_dispatch:
inputs:
branch:
description: 'Branch to run the workflow on'
required: false
default: 'main'
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
@@ -26,6 +32,23 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Model Server Docker image
run: |
docker pull onyxdotapp/onyx-model-server:latest
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
- name: Set up Python
uses: actions/setup-python@v5
with:
@@ -41,6 +64,49 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d indexing_model_server
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
@@ -56,3 +122,10 @@ jobs:
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v

View File

@@ -26,12 +26,12 @@
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
There are over 40 supported connectors such as Google Drive, Slack, Confluence, Salesforce, etc. which keep knowledge and permissions up to date.
Create custom AI agents with unique prompts, knowledge, and actions the agents can take.
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
<h3>Feature Showcase</h3>
<h3>Feature Highlights</h3>
**Deep research over your team's knowledge:**
@@ -63,22 +63,21 @@ We also have built-in support for high-availability/scalable deployment on Kuber
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
## 🚧 Roadmap
- Extensions to the Chrome Plugin
- Latest methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- Personalized Search
- Organizational understanding and ability to locate and suggest experts from your team.
- Code Search
- SQL and Structured Query Language
## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models only through Onyx + learn from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
## 🔌 Connectors
Keep knowledge and access up to sync across 40+ connectors:

View File

@@ -0,0 +1,27 @@
"""Add composite index for last_modified and last_synced to document
Revision ID: f13db29f3101
Revises: b388730a2899
Create Date: 2025-02-18 22:48:11.511389
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f13db29f3101"
down_revision = "acaab4ef4507"
branch_labels: str | None = None
depends_on: str | None = None
def upgrade() -> None:
op.create_index(
"ix_document_sync_status",
"document",
["last_modified", "last_synced"],
unique=False,
)
def downgrade() -> None:
op.drop_index("ix_document_sync_status", table_name="document")

View File

@@ -98,12 +98,17 @@ class CloudEmbedding:
return final_embeddings
except Exception as e:
error_string = (
f"Error embedding text with OpenAI: {str(e)} \n"
f"Model: {model} \n"
f"Provider: {self.provider} \n"
f"Texts: {texts}"
f"Exception embedding text with OpenAI - {type(e)}: "
f"Model: {model} "
f"Provider: {self.provider} "
f"Exception: {e}"
)
logger.error(error_string)
# only log text when it's not an authentication error.
if not isinstance(e, openai.AuthenticationError):
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
async def _embed_cohere(

View File

@@ -10,6 +10,7 @@ from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
@@ -187,23 +188,51 @@ def send_subscription_cancellation_email(user_email: str) -> None:
send_email(user_email, subject, html_content, text_content)
def send_user_email_invite(user_email: str, current_user: User) -> None:
def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
subject = "Invitation to Join Onyx Organization"
heading = "You've Been Invited!"
message = (
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
# the exact action taken by the user, and thus the message, depends on the auth type
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
if auth_type == AuthType.CLOUD:
message += (
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
elif auth_type == AuthType.BASIC:
message += (
"<p>To join the organization, please click the button below to set a password "
"and complete your registration.</p>"
)
elif auth_type == AuthType.GOOGLE_OAUTH:
message += (
"<p>To join the organization, please click the button below to login with Google "
"and complete your registration.</p>"
)
elif auth_type == AuthType.OIDC or auth_type == AuthType.SAML:
message += (
"<p>To join the organization, please click the button below to"
" complete your registration.</p>"
)
else:
raise ValueError(f"Invalid auth type: {auth_type}")
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
# text content is the fallback for clients that don't support HTML
# not as critical, so not having special cases for each auth type
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
"You'll be asked to set a password or login with Google to complete your registration."
)
if auth_type == AuthType.CLOUD:
text_content += "You'll be asked to set a password or login with Google to complete your registration."
send_email(user_email, subject, html_content, text_content)

View File

@@ -361,6 +361,7 @@ def connector_external_group_sync_generator_task(
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(

View File

@@ -15,6 +15,7 @@ from onyx.background.indexing.memory_tracer import MemoryTracer
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
@@ -89,8 +90,8 @@ def _get_connector_runner(
)
# validate the connector settings
runnable_connector.validate_connector_settings()
if not INTEGRATION_TESTS_MODE:
runnable_connector.validate_connector_settings()
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")

View File

@@ -158,7 +158,7 @@ POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
@@ -626,6 +626,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"

View File

@@ -7,22 +7,15 @@ from typing import Optional
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from botocore.exceptions import PartialCredentialsError
from mypy_boto3_s3 import S3Client # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import BlobType
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
@@ -247,73 +240,6 @@ class BlobStorageConnector(LoadConnector, PollConnector):
return None
def validate_connector_settings(self) -> None:
if self.s3_client is None:
raise ConnectorMissingCredentialError(
"Blob storage credentials not loaded."
)
if not self.bucket_name:
raise ConnectorValidationError(
"No bucket name was provided in connector settings."
)
try:
# We only fetch one object/page as a light-weight validation step.
# This ensures we trigger typical S3 permission checks (ListObjectsV2, etc.).
self.s3_client.list_objects_v2(
Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1
)
except NoCredentialsError:
raise ConnectorMissingCredentialError(
"No valid blob storage credentials found or provided to boto3."
)
except PartialCredentialsError:
raise ConnectorMissingCredentialError(
"Partial or incomplete blob storage credentials provided to boto3."
)
except ClientError as e:
error_code = e.response["Error"].get("Code", "")
status_code = e.response["ResponseMetadata"].get("HTTPStatusCode")
# Most common S3 error cases
if error_code in [
"AccessDenied",
"InvalidAccessKeyId",
"SignatureDoesNotMatch",
]:
if status_code == 403 or error_code == "AccessDenied":
raise InsufficientPermissionsError(
f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. "
"Please check your bucket policy and/or IAM policy."
)
if status_code == 401 or error_code == "SignatureDoesNotMatch":
raise CredentialExpiredError(
"Provided blob storage credentials appear invalid or expired."
)
raise CredentialExpiredError(
f"Credential issue encountered ({error_code})."
)
if error_code == "NoSuchBucket" or status_code == 404:
raise ConnectorValidationError(
f"Bucket '{self.bucket_name}' does not exist or cannot be found."
)
raise ConnectorValidationError(
f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}"
)
except Exception as e:
# Catch-all for anything not captured by the above
# Since we are unsure of the error and it may not disable the connector,
# raise an unexpected error (does not disable connector)
raise UnexpectedError(
f"Unexpected error during blob storage settings validation: {e}"
)
if __name__ == "__main__":
credentials_dict = {

View File

@@ -3,6 +3,7 @@ from typing import Type
from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceRequiringTenantContext
from onyx.connectors.airtable.airtable_connector import AirtableConnector
@@ -187,6 +188,9 @@ def validate_ccpair_for_user(
user: User | None,
tenant_id: str | None,
) -> None:
if INTEGRATION_TESTS_MODE:
return
# Validate the connector settings
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id_for_user(
@@ -199,7 +203,10 @@ def validate_ccpair_for_user(
if not connector:
raise ValueError("Connector not found")
if connector.source == DocumentSource.INGESTION_API:
if (
connector.source == DocumentSource.INGESTION_API
or connector.source == DocumentSource.MOCK_CONNECTOR
):
return
if not credential:

View File

@@ -229,16 +229,20 @@ class GitbookConnector(LoadConnector, PollConnector):
try:
content = self.client.get(f"/spaces/{self.space_id}/content")
pages = content.get("pages", [])
pages: list[dict[str, Any]] = content.get("pages", [])
current_batch: list[Document] = []
for page in pages:
updated_at = datetime.fromisoformat(page["updatedAt"])
while pages:
page = pages.pop(0)
updated_at_raw = page.get("updatedAt")
if updated_at_raw is None:
# if updatedAt is not present, that means the page has never been edited
continue
updated_at = datetime.fromisoformat(updated_at_raw)
if start and updated_at < start:
if current_batch:
yield current_batch
return
continue
if end and updated_at > end:
continue
@@ -250,6 +254,8 @@ class GitbookConnector(LoadConnector, PollConnector):
yield current_batch
current_batch = []
pages.extend(page.get("pages", []))
if current_batch:
yield current_batch

View File

@@ -220,7 +220,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
return self._creds
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
try:
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
except KeyError:
raise ValueError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,

View File

@@ -87,18 +87,16 @@ class HubSpotConnector(LoadConnector, PollConnector):
contact = api_client.crm.contacts.basic_api.get_by_id(
contact_id=contact.id
)
email = contact.properties.get("email")
if email is not None:
associated_emails.append(email)
associated_emails.append(contact.properties["email"])
if notes:
for note in notes.results:
note = api_client.crm.objects.notes.basic_api.get_by_id(
note_id=note.id, properties=["content", "hs_body_preview"]
)
preview = note.properties.get("hs_body_preview")
if preview is not None:
associated_notes.append(preview)
if note.properties["hs_body_preview"] is None:
continue
associated_notes.append(note.properties["hs_body_preview"])
associated_emails_str = " ,".join(associated_emails)
associated_notes_str = " ".join(associated_notes)

View File

@@ -20,13 +20,9 @@ from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@@ -670,56 +666,6 @@ class SlackConnector(SlimConnector, CheckpointConnector):
)
return checkpoint
def validate_connector_settings(self) -> None:
if self.client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
try:
# Minimal API call to confirm we can list channels
# We set limit=1 for a lightweight check
response = self.client.conversations_list(limit=1, types=["public_channel"])
# Just ensure Slack responded "ok: True"
if not response.get("ok", False):
error_msg = response.get("error", "Unknown error from Slack")
if error_msg == "invalid_auth":
raise ConnectorValidationError(
f"Invalid or expired Slack bot token ({error_msg})."
)
elif error_msg == "not_authed":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({error_msg})."
)
raise UnexpectedError(f"Slack API returned a failure: {error_msg}")
except SlackApiError as e:
slack_error = e.response.get("error", "")
if slack_error == "missing_scope":
# The needed scope is typically "channels:read" or "groups:read"
# for viewing channels. The error message might also contain the
# specific scope needed vs. provided.
raise InsufficientPermissionsError(
"Slack bot token lacks the necessary scope to list channels. "
"Please ensure your Slack app has 'channels:read' (or 'groups:read' for private channels) enabled."
)
elif slack_error == "invalid_auth":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({slack_error})."
)
elif slack_error == "not_authed":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({slack_error})."
)
else:
# Generic Slack error
raise UnexpectedError(
f"Unexpected Slack error '{slack_error}' during settings validation."
)
except Exception as e:
# Catch-all for unexpected exceptions
raise UnexpectedError(
f"Unexpected error during Slack settings validation: {e}"
)
if __name__ == "__main__":
import os

View File

@@ -5,7 +5,6 @@ from typing import Any
import msal # type: ignore
from office365.graph_client import GraphClient # type: ignore
from office365.runtime.client_request_exception import ClientRequestException # type: ignore
from office365.teams.channels.channel import Channel # type: ignore
from office365.teams.chats.messages.message import ChatMessage # type: ignore
from office365.teams.team import Team # type: ignore
@@ -13,14 +12,10 @@ from office365.teams.team import Team # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
@@ -284,64 +279,6 @@ class TeamsConnector(LoadConnector, PollConnector):
end_datetime = datetime.fromtimestamp(end, timezone.utc)
return self._fetch_from_teams(start=start_datetime, end=end_datetime)
def validate_connector_settings(self) -> None:
"""
Validate that we can connect to Microsoft Teams with the provided MSAL/Graph credentials
and that we can see at least one Team. If the user has specified a list of Teams by name,
confirm at least one of them is found.
Raises:
ConnectorMissingCredentialError: If the Graph client is not yet set (missing credentials).
CredentialExpiredError: If credentials appear invalid/expired (e.g. 401 Unauthorized).
InsufficientPermissionsError: If the app lacks required permissions to read Teams.
ConnectorValidationError: If no Teams are found, or if requested Teams are not found.
"""
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams credentials not loaded.")
try:
# Minimal call to confirm we can retrieve Teams
found_teams = self._get_all_teams()
# You may optionally catch the Graph/Office365 request exception if available:
except ClientRequestException as e:
status_code = e.response.status_code
if status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Microsoft Teams credentials (401 Unauthorized)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your app lacks sufficient permissions to read Teams (403 Forbidden)."
)
else:
raise UnexpectedError(f"Unexpected error retrieving teams: {e}")
except Exception as e:
error_str = str(e).lower()
if (
"unauthorized" in error_str
or "401" in error_str
or "invalid_grant" in error_str
):
raise CredentialExpiredError(
"Invalid or expired Microsoft Teams credentials."
)
elif "forbidden" in error_str or "403" in error_str:
raise InsufficientPermissionsError(
"App lacks required permissions to read from Microsoft Teams."
)
raise ConnectorValidationError(
f"Unexpected error during Teams validation: {e}"
)
# If we get this far, the Graph call succeeded. Check for presence of Teams:
if not found_teams:
raise ConnectorValidationError(
"No Teams found for the given credentials. "
"Either there are no Teams in this tenant, or your app does not have permission to view them."
)
if __name__ == "__main__":
connector = TeamsConnector(teams=os.environ["TEAMS"].split(","))

View File

@@ -440,10 +440,7 @@ class WebConnector(LoadConnector):
"No URL configured. Please provide at least one valid URL."
)
if (
self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.SITEMAP.value
or self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value
):
if self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.SITEMAP.value:
return None
# We'll just test the first URL for connectivity and correctness

View File

@@ -194,9 +194,14 @@ def get_connector_credential_pair_from_id_for_user(
def get_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
eager_load_credential: bool = False,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
if eager_load_credential:
stmt = stmt.options(joinedload(ConnectorCredentialPair.credential))
result = db_session.execute(stmt)
return result.scalar_one_or_none()

View File

@@ -60,9 +60,8 @@ def count_documents_by_needs_sync(session: Session) -> int:
This function executes the query and returns the count of
documents matching the criteria."""
count = (
session.query(func.count(DbDocument.id.distinct()))
.select_from(DbDocument)
return (
session.query(DbDocument.id)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
@@ -73,63 +72,53 @@ def count_documents_by_needs_sync(session: Session) -> int:
DbDocument.last_synced.is_(None),
)
)
.scalar()
.count()
)
return count
def construct_document_select_for_connector_credential_pair_by_needs_sync(
connector_id: int, credential_id: int
) -> Select:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = (
return (
select(DbDocument)
.where(
DbDocument.id.in_(initial_doc_ids_stmt),
or_(
DbDocument.last_modified
> DbDocument.last_synced, # last_modified is newer than last_synced
DbDocument.last_synced.is_(None), # never synced
),
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
or_(
DbDocument.last_modified > DbDocument.last_synced,
DbDocument.last_synced.is_(None),
),
)
)
.distinct()
)
return stmt
def construct_document_id_select_for_connector_credential_pair_by_needs_sync(
connector_id: int, credential_id: int
) -> Select:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = (
return (
select(DbDocument.id)
.where(
DbDocument.id.in_(initial_doc_ids_stmt),
or_(
DbDocument.last_modified
> DbDocument.last_synced, # last_modified is newer than last_synced
DbDocument.last_synced.is_(None), # never synced
),
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
or_(
DbDocument.last_modified > DbDocument.last_synced,
DbDocument.last_synced.is_(None),
),
)
)
.distinct()
)
return stmt
def get_all_documents_needing_vespa_sync_for_cc_pair(
db_session: Session, cc_pair_id: int

View File

@@ -570,6 +570,14 @@ class Document(Base):
back_populates="documents",
)
__table_args__ = (
Index(
"ix_document_sync_status",
last_modified,
last_synced,
),
)
class Tag(Base):
__tablename__ = "tag"

View File

@@ -311,19 +311,23 @@ def bulk_invite_users(
all_emails = list(set(new_invited_emails) | set(initial_invited_users))
number_of_invited_users = write_invited_users(all_emails)
# send out email invitations if enabled
if ENABLE_EMAIL_INVITES:
try:
for email in new_invited_emails:
send_user_email_invite(email, current_user, AUTH_TYPE)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
if not MULTI_TENANT:
return number_of_invited_users
# for billing purposes, write to the control plane about the number of new users
try:
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_total_users_count(db_session))
if ENABLE_EMAIL_INVITES:
try:
for email in new_invited_emails:
send_user_email_invite(email, current_user)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
return number_of_invited_users
except Exception as e:

View File

@@ -3,6 +3,7 @@ import json
import logging
import sys
import time
from enum import Enum
from logging import getLogger
from typing import cast
from uuid import UUID
@@ -20,10 +21,13 @@ from onyx.configs.app_configs import REDIS_PORT
from onyx.configs.app_configs import REDIS_SSL
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_pool import RedisPool
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
# Tool to run helpful operations on Redis in production
# This is targeted for internal usage and may not have all the necessary parameters
@@ -42,6 +46,19 @@ SCAN_ITER_COUNT = 10000
BATCH_DEFAULT = 1000
class OnyxRedisCommand(Enum):
purge_connectorsync_taskset = "purge_connectorsync_taskset"
purge_documentset_taskset = "purge_documentset_taskset"
purge_usergroup_taskset = "purge_usergroup_taskset"
purge_locks_blocking_deletion = "purge_locks_blocking_deletion"
purge_vespa_syncing = "purge_vespa_syncing"
get_user_token = "get_user_token"
delete_user_token = "delete_user_token"
def __str__(self) -> str:
return self.value
def get_user_id(user_email: str) -> tuple[UUID, str]:
tenant_id = (
get_tenant_id_for_email(user_email) if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
@@ -55,50 +72,79 @@ def get_user_id(user_email: str) -> tuple[UUID, str]:
def onyx_redis(
command: str,
command: OnyxRedisCommand,
batch: int,
dry_run: bool,
ssl: bool,
host: str,
port: int,
db: int,
password: str | None,
user_email: str | None = None,
cc_pair_id: int | None = None,
) -> int:
# this is global and not tenant aware
pool = RedisPool.create_pool(
host=host,
port=port,
db=db,
password=password if password else "",
ssl=REDIS_SSL,
ssl=ssl,
ssl_cert_reqs="optional",
ssl_ca_certs=None,
)
r = Redis(connection_pool=pool)
logger.info("Redis ping starting. This may hang if your settings are incorrect.")
try:
r.ping()
except:
logger.exception("Redis ping exceptioned")
raise
if command == "purge_connectorsync_taskset":
logger.info("Redis ping succeeded.")
if command == OnyxRedisCommand.purge_connectorsync_taskset:
"""Purge connector tasksets. Used when the tasks represented in the tasksets
have been purged."""
return purge_by_match_and_type(
"*connectorsync_taskset*", "set", batch, dry_run, r
)
elif command == "purge_documentset_taskset":
elif command == OnyxRedisCommand.purge_documentset_taskset:
return purge_by_match_and_type(
"*documentset_taskset*", "set", batch, dry_run, r
)
elif command == "purge_usergroup_taskset":
elif command == OnyxRedisCommand.purge_usergroup_taskset:
return purge_by_match_and_type("*usergroup_taskset*", "set", batch, dry_run, r)
elif command == "purge_vespa_syncing":
elif command == OnyxRedisCommand.purge_locks_blocking_deletion:
if cc_pair_id is None:
logger.error("You must specify --cc-pair with purge_deletion_locks")
return 1
tenant_id = get_current_tenant_id()
logger.info(f"Purging locks associated with deleting cc_pair={cc_pair_id}.")
redis_connector = RedisConnector(tenant_id, cc_pair_id)
match_pattern = f"{tenant_id}:{RedisConnectorIndex.FENCE_PREFIX}_{cc_pair_id}/*"
purge_by_match_and_type(match_pattern, "string", batch, dry_run, r)
redis_delete_if_exists_helper(
f"{tenant_id}:{redis_connector.prune.fence_key}", dry_run, r
)
redis_delete_if_exists_helper(
f"{tenant_id}:{redis_connector.permissions.fence_key}", dry_run, r
)
redis_delete_if_exists_helper(
f"{tenant_id}:{redis_connector.external_group_sync.fence_key}", dry_run, r
)
return 0
elif command == OnyxRedisCommand.purge_vespa_syncing:
return purge_by_match_and_type(
"*connectorsync:vespa_syncing*", "string", batch, dry_run, r
)
elif command == "get_user_token":
elif command == OnyxRedisCommand.get_user_token:
if not user_email:
logger.error("You must specify --user-email with get_user_token")
return 1
@@ -109,7 +155,7 @@ def onyx_redis(
else:
print(f"No token found for user {user_email}")
return 2
elif command == "delete_user_token":
elif command == OnyxRedisCommand.delete_user_token:
if not user_email:
logger.error("You must specify --user-email with delete_user_token")
return 1
@@ -131,6 +177,25 @@ def flush_batch_delete(batch_keys: list[bytes], r: Redis) -> None:
pipe.execute()
def redis_delete_if_exists_helper(key: str, dry_run: bool, r: Redis) -> bool:
"""Returns True if the key was found, False if not.
This function exists for logging purposes as the delete operation itself
doesn't really need to check the existence of the key.
"""
if not r.exists(key):
logger.info(f"Did not find {key}.")
return False
if dry_run:
logger.info(f"(DRY-RUN) Deleting {key}.")
else:
logger.info(f"Deleting {key}.")
r.delete(key)
return True
def purge_by_match_and_type(
match_pattern: str, match_type: str, batch_size: int, dry_run: bool, r: Redis
) -> int:
@@ -138,6 +203,12 @@ def purge_by_match_and_type(
match_type: https://redis.io/docs/latest/commands/type/
"""
logger.info(
f"purge_by_match_and_type start: "
f"match_pattern={match_pattern} "
f"match_type={match_type}"
)
# cursor = "0"
# while cursor != 0:
# cursor, data = self.scan(
@@ -164,13 +235,15 @@ def purge_by_match_and_type(
logger.info(f"Deleting item {count}: {key_str}")
batch_keys.append(key)
# flush if batch size has been reached
if len(batch_keys) >= batch_size:
flush_batch_delete(batch_keys, r)
batch_keys.clear()
if len(batch_keys) >= batch_size:
flush_batch_delete(batch_keys, r)
batch_keys.clear()
# final flush
flush_batch_delete(batch_keys, r)
batch_keys.clear()
logger.info(f"Deleted {count} matches.")
@@ -279,7 +352,21 @@ def delete_user_token_from_redis(
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Onyx Redis Manager")
parser.add_argument("--command", type=str, help="Operation to run", required=True)
parser.add_argument(
"--command",
type=OnyxRedisCommand,
help="The command to run",
choices=list(OnyxRedisCommand),
required=True,
)
parser.add_argument(
"--ssl",
type=bool,
default=REDIS_SSL,
help="Use SSL when connecting to Redis. Usually True for prod and False for local testing",
required=False,
)
parser.add_argument(
"--host",
@@ -342,6 +429,13 @@ if __name__ == "__main__":
required=False,
)
parser.add_argument(
"--cc-pair",
type=int,
help="A connector credential pair id. Used with the purge_deletion_locks command.",
required=False,
)
args = parser.parse_args()
if args.tenant_id:
@@ -368,10 +462,12 @@ if __name__ == "__main__":
command=args.command,
batch=args.batch,
dry_run=args.dry_run,
ssl=args.ssl,
host=args.host,
port=args.port,
db=args.db,
password=args.password,
user_email=args.user_email,
cc_pair_id=args.cc_pair,
)
sys.exit(exitcode)

View File

@@ -3,7 +3,7 @@ import os
ADMIN_USER_NAME = "admin_user"
API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "127.0.0.1"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 45

View File

@@ -30,8 +30,10 @@ class ConnectorManager:
name=name,
source=source,
input_type=input_type,
connector_specific_config=connector_specific_config
or {"file_locations": []},
connector_specific_config=(
connector_specific_config
or ({"file_locations": []} if source == DocumentSource.FILE else {})
),
access_type=access_type,
groups=groups or [],
)

View File

@@ -88,8 +88,6 @@ class UserManager:
if not session_cookie:
raise Exception("Failed to login")
print(f"Logged in as {test_user.email}")
# Set cookies in the headers
test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
test_user.cookies = {"fastapiusersauth": session_cookie}

View File

@@ -1,5 +1,5 @@
# fill in the template
envsubst '$SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME' < "/etc/nginx/conf.d/$1" > /etc/nginx/conf.d/app.conf
envsubst '$DOMAIN $SSL_CERT_FILE_NAME $SSL_CERT_KEY_FILE_NAME' < "/etc/nginx/conf.d/$1" > /etc/nginx/conf.d/app.conf
# wait for the api_server to be ready
echo "Waiting for API server to boot up; this may take a minute or two..."

View File

@@ -36,6 +36,7 @@ services:
- OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
- TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-}
- CORS_ALLOWED_ORIGIN=${CORS_ALLOWED_ORIGIN:-}
- INTEGRATION_TESTS_MODE=${INTEGRATION_TESTS_MODE:-}
# Gen AI Settings
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}

View File

@@ -47,6 +47,7 @@ import {
removeMessage,
sendMessage,
setMessageAsLatest,
updateLlmOverrideForChatSession,
updateParentChildren,
uploadFilesForChat,
useScrollonStream,
@@ -65,7 +66,7 @@ import {
import { usePopup } from "@/components/admin/connectors/Popup";
import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams";
import { useDocumentSelection } from "./useDocumentSelection";
import { LlmOverride, useFilters, useLlmOverride } from "@/lib/hooks";
import { LlmDescriptor, useFilters, useLlmManager } from "@/lib/hooks";
import { ChatState, FeedbackType, RegenerationState } from "./types";
import { DocumentResults } from "./documentSidebar/DocumentResults";
import { OnyxInitializingLoader } from "@/components/OnyxInitializingLoader";
@@ -89,7 +90,11 @@ import {
import { buildFilters } from "@/lib/search/utils";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone";
import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
import {
checkLLMSupportsImageInput,
getFinalLLM,
structureValue,
} from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar";
import { useChatContext } from "@/components/context/ChatContext";
import { v4 as uuidv4 } from "uuid";
@@ -194,16 +199,6 @@ export function ChatPage({
return screenSize;
}
const { height: screenHeight } = useScreenSize();
const getContainerHeight = () => {
if (autoScrollEnabled) return undefined;
if (screenHeight < 600) return "20vh";
if (screenHeight < 1200) return "30vh";
return "40vh";
};
// handle redirect if chat page is disabled
// NOTE: this must be done here, in a client component since
// settings are passed in via Context and therefore aren't
@@ -222,6 +217,7 @@ export function ChatPage({
setProSearchEnabled(!proSearchEnabled);
};
const isInitialLoad = useRef(true);
const [userSettingsToggled, setUserSettingsToggled] = useState(false);
const {
@@ -356,7 +352,7 @@ export function ChatPage({
]
);
const llmOverrideManager = useLlmOverride(
const llmManager = useLlmManager(
llmProviders,
selectedChatSession,
liveAssistant
@@ -520,8 +516,17 @@ export function ChatPage({
scrollInitialized.current = false;
if (!hasPerformedInitialScroll) {
if (isInitialLoad.current) {
setHasPerformedInitialScroll(true);
isInitialLoad.current = false;
}
clientScrollToBottom();
setTimeout(() => {
setHasPerformedInitialScroll(true);
}, 100);
} else if (isChatSessionSwitch) {
setHasPerformedInitialScroll(true);
clientScrollToBottom(true);
}
@@ -1130,6 +1135,56 @@ export function ChatPage({
});
};
const [uncaughtError, setUncaughtError] = useState<string | null>(null);
const [agenticGenerating, setAgenticGenerating] = useState(false);
const autoScrollEnabled =
(user?.preferences?.auto_scroll && !agenticGenerating) ?? false;
useScrollonStream({
chatState: currentSessionChatState,
scrollableDivRef,
scrollDist,
endDivRef,
debounceNumber,
mobile: settings?.isMobile,
enableAutoScroll: autoScrollEnabled,
});
// Track whether a message has been sent during this page load, keyed by chat session id
const [sessionHasSentLocalUserMessage, setSessionHasSentLocalUserMessage] =
useState<Map<string | null, boolean>>(new Map());
// Update the local state for a session once the user sends a message
const markSessionMessageSent = (sessionId: string | null) => {
setSessionHasSentLocalUserMessage((prev) => {
const newMap = new Map(prev);
newMap.set(sessionId, true);
return newMap;
});
};
const currentSessionHasSentLocalUserMessage = useMemo(
() => (sessionId: string | null) => {
return sessionHasSentLocalUserMessage.size === 0
? undefined
: sessionHasSentLocalUserMessage.get(sessionId) || false;
},
[sessionHasSentLocalUserMessage]
);
const { height: screenHeight } = useScreenSize();
const getContainerHeight = useMemo(() => {
return () => {
if (!currentSessionHasSentLocalUserMessage(chatSessionIdRef.current)) {
return undefined;
}
if (autoScrollEnabled) return undefined;
if (screenHeight < 600) return "40vh";
if (screenHeight < 1200) return "50vh";
return "60vh";
};
}, [autoScrollEnabled, screenHeight, currentSessionHasSentLocalUserMessage]);
const onSubmit = async ({
messageIdToResend,
@@ -1138,7 +1193,7 @@ export function ChatPage({
forceSearch,
isSeededChat,
alternativeAssistantOverride = null,
modelOverRide,
modelOverride,
regenerationRequest,
overrideFileDescriptors,
}: {
@@ -1148,7 +1203,7 @@ export function ChatPage({
forceSearch?: boolean;
isSeededChat?: boolean;
alternativeAssistantOverride?: Persona | null;
modelOverRide?: LlmOverride;
modelOverride?: LlmDescriptor;
regenerationRequest?: RegenerationRequest | null;
overrideFileDescriptors?: FileDescriptor[];
} = {}) => {
@@ -1156,6 +1211,9 @@ export function ChatPage({
let frozenSessionId = currentSessionId();
updateCanContinue(false, frozenSessionId);
// Mark that we've sent a message for this session in the current page load
markSessionMessageSent(frozenSessionId);
if (currentChatState() != "input") {
if (currentChatState() == "uploading") {
setPopup({
@@ -1191,6 +1249,22 @@ export function ChatPage({
currChatSessionId = chatSessionIdRef.current as string;
}
frozenSessionId = currChatSessionId;
// update the selected model for the chat session if one is specified so that
// it persists across page reloads. Do not `await` here so that the message
// request can continue and this will just happen in the background.
// NOTE: only set the model override for the chat session once we send a
// message with it. If the user switches models and then starts a new
// chat session, it is unexpected for that model to be used when they
// return to this session the next day.
let finalLLM = modelOverride || llmManager.currentLlm;
updateLlmOverrideForChatSession(
currChatSessionId,
structureValue(
finalLLM.name || "",
finalLLM.provider || "",
finalLLM.modelName || ""
)
);
updateStatesWithNewSessionId(currChatSessionId);
@@ -1250,11 +1324,14 @@ export function ChatPage({
: null) ||
(messageMap.size === 1 ? Array.from(messageMap.values())[0] : null);
const currentAssistantId = alternativeAssistantOverride
? alternativeAssistantOverride.id
: alternativeAssistant
? alternativeAssistant.id
: liveAssistant.id;
let currentAssistantId;
if (alternativeAssistantOverride) {
currentAssistantId = alternativeAssistantOverride.id;
} else if (alternativeAssistant) {
currentAssistantId = alternativeAssistant.id;
} else {
currentAssistantId = liveAssistant.id;
}
resetInputBar();
let messageUpdates: Message[] | null = null;
@@ -1326,15 +1403,13 @@ export function ChatPage({
forceSearch,
regenerate: regenerationRequest !== undefined,
modelProvider:
modelOverRide?.name ||
llmOverrideManager.llmOverride.name ||
undefined,
modelOverride?.name || llmManager.currentLlm.name || undefined,
modelVersion:
modelOverRide?.modelName ||
llmOverrideManager.llmOverride.modelName ||
modelOverride?.modelName ||
llmManager.currentLlm.modelName ||
searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
undefined,
temperature: llmOverrideManager.temperature || undefined,
temperature: llmManager.temperature || undefined,
systemPromptOverride:
searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
useExistingUserMessage: isSeededChat,
@@ -1802,7 +1877,7 @@ export function ChatPage({
const [_, llmModel] = getFinalLLM(
llmProviders,
liveAssistant,
llmOverrideManager.llmOverride
llmManager.currentLlm
);
const llmAcceptsImages = checkLLMSupportsImageInput(llmModel);
@@ -1857,7 +1932,6 @@ export function ChatPage({
// Used to maintain a "time out" for history sidebar so our existing refs can have time to process change
const [untoggled, setUntoggled] = useState(false);
const [loadingError, setLoadingError] = useState<string | null>(null);
const [agenticGenerating, setAgenticGenerating] = useState(false);
const explicitlyUntoggle = () => {
setShowHistorySidebar(false);
@@ -1899,19 +1973,6 @@ export function ChatPage({
isAnonymousUser: user?.is_anonymous_user,
});
const autoScrollEnabled =
(user?.preferences?.auto_scroll && !agenticGenerating) ?? false;
useScrollonStream({
chatState: currentSessionChatState,
scrollableDivRef,
scrollDist,
endDivRef,
debounceNumber,
mobile: settings?.isMobile,
enableAutoScroll: autoScrollEnabled,
});
// Virtualization + Scrolling related effects and functions
const scrollInitialized = useRef(false);
interface VisibleRange {
@@ -2121,7 +2182,7 @@ export function ChatPage({
}, [searchParams, router]);
useEffect(() => {
llmOverrideManager.updateImageFilesPresent(imageFileInMessageHistory);
llmManager.updateImageFilesPresent(imageFileInMessageHistory);
}, [imageFileInMessageHistory]);
const pathname = usePathname();
@@ -2175,9 +2236,9 @@ export function ChatPage({
function createRegenerator(regenerationRequest: RegenerationRequest) {
// Returns new function that only needs `modelOverRide` to be specified when called
return async function (modelOverRide: LlmOverride) {
return async function (modelOverride: LlmDescriptor) {
return await onSubmit({
modelOverRide,
modelOverride,
messageIdToResend: regenerationRequest.parentMessage.messageId,
regenerationRequest,
forceSearch: regenerationRequest.forceSearch,
@@ -2258,9 +2319,7 @@ export function ChatPage({
{(settingsToggled || userSettingsToggled) && (
<UserSettingsModal
setPopup={setPopup}
setLlmOverride={(newOverride) =>
llmOverrideManager.updateLLMOverride(newOverride)
}
setCurrentLlm={(newLlm) => llmManager.updateCurrentLlm(newLlm)}
defaultModel={user?.preferences.default_model!}
llmProviders={llmProviders}
onClose={() => {
@@ -2324,7 +2383,7 @@ export function ChatPage({
<ShareChatSessionModal
assistantId={liveAssistant?.id}
message={message}
modelOverride={llmOverrideManager.llmOverride}
modelOverride={llmManager.currentLlm}
chatSessionId={sharedChatSession.id}
existingSharedStatus={sharedChatSession.shared_status}
onClose={() => setSharedChatSession(null)}
@@ -2342,7 +2401,7 @@ export function ChatPage({
<ShareChatSessionModal
message={message}
assistantId={liveAssistant?.id}
modelOverride={llmOverrideManager.llmOverride}
modelOverride={llmManager.currentLlm}
chatSessionId={chatSessionIdRef.current}
existingSharedStatus={chatSessionSharedStatus}
onClose={() => setSharingModalVisible(false)}
@@ -2572,6 +2631,7 @@ export function ChatPage({
style={{ overflowAnchor: "none" }}
key={currentSessionId()}
className={
(hasPerformedInitialScroll ? "" : " hidden ") +
"desktop:-ml-4 w-full mx-auto " +
"absolute mobile:top-0 desktop:top-0 left-0 " +
(settings?.enterpriseSettings
@@ -3058,7 +3118,7 @@ export function ChatPage({
messageId: message.messageId,
parentMessage: parentMessage!,
forceSearch: true,
})(llmOverrideManager.llmOverride);
})(llmManager.currentLlm);
} else {
setPopup({
type: "error",
@@ -3203,7 +3263,7 @@ export function ChatPage({
availableDocumentSets={documentSets}
availableTags={tags}
filterManager={filterManager}
llmOverrideManager={llmOverrideManager}
llmManager={llmManager}
removeDocs={() => {
clearSelectedDocuments();
}}

View File

@@ -1,8 +1,8 @@
import { useChatContext } from "@/components/context/ChatContext";
import {
getDisplayNameForModel,
LlmOverride,
useLlmOverride,
LlmDescriptor,
useLlmManager,
} from "@/lib/hooks";
import { StringOrNumberOption } from "@/components/Dropdown";
@@ -106,13 +106,13 @@ export default function RegenerateOption({
onDropdownVisibleChange,
}: {
selectedAssistant: Persona;
regenerate: (modelOverRide: LlmOverride) => Promise<void>;
regenerate: (modelOverRide: LlmDescriptor) => Promise<void>;
overriddenModel?: string;
onHoverChange: (isHovered: boolean) => void;
onDropdownVisibleChange: (isVisible: boolean) => void;
}) {
const { llmProviders } = useChatContext();
const llmOverrideManager = useLlmOverride(llmProviders);
const llmManager = useLlmManager(llmProviders);
const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);
@@ -148,7 +148,7 @@ export default function RegenerateOption({
);
const currentModelName =
llmOverrideManager?.llmOverride.modelName ||
llmManager?.currentLlm.modelName ||
(selectedAssistant
? selectedAssistant.llm_model_version_override || llmName
: llmName);

View File

@@ -6,7 +6,7 @@ import { Persona } from "@/app/admin/assistants/interfaces";
import LLMPopover from "./LLMPopover";
import { InputPrompt } from "@/app/chat/interfaces";
import { FilterManager, LlmOverrideManager } from "@/lib/hooks";
import { FilterManager, LlmManager } from "@/lib/hooks";
import { useChatContext } from "@/components/context/ChatContext";
import { ChatFileType, FileDescriptor } from "../interfaces";
import {
@@ -180,7 +180,7 @@ interface ChatInputBarProps {
setMessage: (message: string) => void;
stopGenerating: () => void;
onSubmit: () => void;
llmOverrideManager: LlmOverrideManager;
llmManager: LlmManager;
chatState: ChatState;
alternativeAssistant: Persona | null;
// assistants
@@ -225,7 +225,7 @@ export function ChatInputBar({
availableSources,
availableDocumentSets,
availableTags,
llmOverrideManager,
llmManager,
proSearchEnabled,
setProSearchEnabled,
}: ChatInputBarProps) {
@@ -781,7 +781,7 @@ export function ChatInputBar({
<LLMPopover
llmProviders={llmProviders}
llmOverrideManager={llmOverrideManager}
llmManager={llmManager}
requiresImageGeneration={false}
currentAssistant={selectedAssistant}
/>

View File

@@ -16,7 +16,7 @@ import {
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { Persona } from "@/app/admin/assistants/interfaces";
import { LlmOverrideManager } from "@/lib/hooks";
import { LlmManager } from "@/lib/hooks";
import {
Tooltip,
@@ -31,21 +31,19 @@ import { useUser } from "@/components/user/UserProvider";
interface LLMPopoverProps {
llmProviders: LLMProviderDescriptor[];
llmOverrideManager: LlmOverrideManager;
llmManager: LlmManager;
requiresImageGeneration?: boolean;
currentAssistant?: Persona;
}
export default function LLMPopover({
llmProviders,
llmOverrideManager,
llmManager,
requiresImageGeneration,
currentAssistant,
}: LLMPopoverProps) {
const [isOpen, setIsOpen] = useState(false);
const { user } = useUser();
const { llmOverride, updateLLMOverride } = llmOverrideManager;
const currentLlm = llmOverride.modelName;
const llmOptionsByProvider: {
[provider: string]: {
@@ -93,19 +91,19 @@ export default function LLMPopover({
: null;
const [localTemperature, setLocalTemperature] = useState(
llmOverrideManager.temperature ?? 0.5
llmManager.temperature ?? 0.5
);
useEffect(() => {
setLocalTemperature(llmOverrideManager.temperature ?? 0.5);
}, [llmOverrideManager.temperature]);
setLocalTemperature(llmManager.temperature ?? 0.5);
}, [llmManager.temperature]);
const handleTemperatureChange = (value: number[]) => {
setLocalTemperature(value[0]);
};
const handleTemperatureChangeComplete = (value: number[]) => {
llmOverrideManager.updateTemperature(value[0]);
llmManager.updateTemperature(value[0]);
};
return (
@@ -120,15 +118,15 @@ export default function LLMPopover({
toggle
flexPriority="stiff"
name={getDisplayNameForModel(
llmOverrideManager?.llmOverride.modelName ||
llmManager?.currentLlm.modelName ||
defaultModelDisplayName ||
"Models"
)}
Icon={getProviderIcon(
llmOverrideManager?.llmOverride.provider ||
llmManager?.currentLlm.provider ||
defaultProvider?.provider ||
"anthropic",
llmOverrideManager?.llmOverride.modelName ||
llmManager?.currentLlm.modelName ||
defaultProvider?.default_model_name ||
"claude-3-5-sonnet-20240620"
)}
@@ -147,12 +145,12 @@ export default function LLMPopover({
<button
key={index}
className={`w-full flex items-center gap-x-2 px-3 py-2 text-sm text-left hover:bg-background-100 dark:hover:bg-neutral-800 transition-colors duration-150 ${
currentLlm === name
llmManager.currentLlm.modelName === name
? "bg-background-100 dark:bg-neutral-900 text-text"
: "text-text-darker"
}`}
onClick={() => {
updateLLMOverride(destructureValue(value));
llmManager.updateCurrentLlm(destructureValue(value));
setIsOpen(false);
}}
>
@@ -172,7 +170,7 @@ export default function LLMPopover({
);
}
})()}
{llmOverrideManager.imageFilesPresent &&
{llmManager.imageFilesPresent &&
!checkLLMSupportsImageInput(name) && (
<TooltipProvider>
<Tooltip delayDuration={0}>
@@ -199,7 +197,7 @@ export default function LLMPopover({
<div className="w-full px-3 py-2">
<Slider
value={[localTemperature]}
max={llmOverrideManager.maxTemperature}
max={llmManager.maxTemperature}
min={0}
step={0.01}
onValueChange={handleTemperatureChange}

View File

@@ -65,7 +65,7 @@ export function getChatRetentionInfo(
};
}
export async function updateModelOverrideForChatSession(
export async function updateLlmOverrideForChatSession(
chatSessionId: string,
newAlternateModel: string
) {
@@ -236,7 +236,7 @@ export async function* sendMessage({
}
: null,
use_existing_user_message: useExistingUserMessage,
use_agentic_search: useLanggraph,
use_agentic_search: useLanggraph ?? false,
});
const response = await fetch(`/api/chat/send-message`, {

View File

@@ -44,7 +44,7 @@ import { ValidSources } from "@/lib/types";
import { useMouseTracking } from "./hooks";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import RegenerateOption from "../RegenerateOption";
import { LlmOverride } from "@/lib/hooks";
import { LlmDescriptor } from "@/lib/hooks";
import { ContinueGenerating } from "./ContinueMessage";
import { MemoizedAnchor, MemoizedParagraph } from "./MemoizedTextComponents";
import { extractCodeText, preprocessLaTeX } from "./codeUtils";
@@ -117,7 +117,7 @@ export const AgenticMessage = ({
isComplete?: boolean;
handleFeedback?: (feedbackType: FeedbackType) => void;
overriddenModel?: string;
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
regenerate?: (modelOverRide: LlmDescriptor) => Promise<void>;
setPresentingDocument?: (document: OnyxDocument) => void;
toggleDocDisplay?: (agentic: boolean) => void;
error?: string | null;

View File

@@ -58,7 +58,7 @@ import { useMouseTracking } from "./hooks";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
import RegenerateOption from "../RegenerateOption";
import { LlmOverride } from "@/lib/hooks";
import { LlmDescriptor } from "@/lib/hooks";
import { ContinueGenerating } from "./ContinueMessage";
import { MemoizedAnchor, MemoizedParagraph } from "./MemoizedTextComponents";
import { extractCodeText, preprocessLaTeX } from "./codeUtils";
@@ -213,7 +213,7 @@ export const AIMessage = ({
handleForceSearch?: () => void;
retrievalDisabled?: boolean;
overriddenModel?: string;
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
regenerate?: (modelOverRide: LlmDescriptor) => Promise<void>;
setPresentingDocument: (document: OnyxDocument) => void;
removePadding?: boolean;
}) => {

View File

@@ -11,7 +11,7 @@ import { CopyButton } from "@/components/CopyButton";
import { SEARCH_PARAM_NAMES } from "../searchParams";
import { usePopup } from "@/components/admin/connectors/Popup";
import { structureValue } from "@/lib/llm/utils";
import { LlmOverride } from "@/lib/hooks";
import { LlmDescriptor } from "@/lib/hooks";
import { Separator } from "@/components/ui/separator";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
@@ -38,7 +38,7 @@ async function generateShareLink(chatSessionId: string) {
async function generateSeedLink(
message?: string,
assistantId?: number,
modelOverride?: LlmOverride
modelOverride?: LlmDescriptor
) {
const baseUrl = `${window.location.protocol}//${window.location.host}`;
const model = modelOverride
@@ -92,7 +92,7 @@ export function ShareChatSessionModal({
onClose: () => void;
message?: string;
assistantId?: number;
modelOverride?: LlmOverride;
modelOverride?: LlmDescriptor;
}) {
const [shareLink, setShareLink] = useState<string>(
existingSharedStatus === ChatSessionSharedStatus.Public

View File

@@ -1,6 +1,6 @@
import { useContext, useEffect, useRef, useState } from "react";
import { Modal } from "@/components/Modal";
import { getDisplayNameForModel, LlmOverride } from "@/lib/hooks";
import { getDisplayNameForModel, LlmDescriptor } from "@/lib/hooks";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { destructureValue, structureValue } from "@/lib/llm/utils";
@@ -31,12 +31,12 @@ export function UserSettingsModal({
setPopup,
llmProviders,
onClose,
setLlmOverride,
setCurrentLlm,
defaultModel,
}: {
setPopup: (popupSpec: PopupSpec | null) => void;
llmProviders: LLMProviderDescriptor[];
setLlmOverride?: (newOverride: LlmOverride) => void;
setCurrentLlm?: (newLlm: LlmDescriptor) => void;
onClose: () => void;
defaultModel: string | null;
}) {
@@ -127,18 +127,14 @@ export function UserSettingsModal({
);
});
const llmOptions = Object.entries(llmOptionsByProvider).flatMap(
([provider, options]) => [...options]
);
const router = useRouter();
const handleChangedefaultModel = async (defaultModel: string | null) => {
try {
const response = await setUserDefaultModel(defaultModel);
if (response.ok) {
if (defaultModel && setLlmOverride) {
setLlmOverride(destructureValue(defaultModel));
if (defaultModel && setCurrentLlm) {
setCurrentLlm(destructureValue(defaultModel));
}
setPopup({
message: "Default model updated successfully",

View File

@@ -95,7 +95,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
}
}
if (enterpriseSettings && settings.pro_search_enabled == null) {
if (settings.pro_search_enabled == null) {
settings.pro_search_enabled = true;
}

View File

@@ -360,18 +360,18 @@ export const useUsers = ({ includeApiKeys }: UseUsersParams) => {
};
};
export interface LlmOverride {
export interface LlmDescriptor {
name: string;
provider: string;
modelName: string;
}
export interface LlmOverrideManager {
llmOverride: LlmOverride;
updateLLMOverride: (newOverride: LlmOverride) => void;
export interface LlmManager {
currentLlm: LlmDescriptor;
updateCurrentLlm: (newOverride: LlmDescriptor) => void;
temperature: number;
updateTemperature: (temperature: number) => void;
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
updateModelOverrideBasedOnChatSession: (chatSession?: ChatSession) => void;
imageFilesPresent: boolean;
updateImageFilesPresent: (present: boolean) => void;
liveAssistant: Persona | null;
@@ -400,7 +400,7 @@ Thus, the input should be
Changes take place as
- liveAssistant or currentChatSession changes (and the associated model override is set)
- (uploadLLMOverride) User explicitly setting a model override (and we explicitly override and set the userSpecifiedOverride which we'll use in place of the user preferences unless overridden by an assistant)
- (updateCurrentLlm) User explicitly setting a model override (and we explicitly override and set the userSpecifiedOverride which we'll use in place of the user preferences unless overridden by an assistant)
If we have a live assistant, we should use that model override
@@ -419,55 +419,78 @@ This approach ensures that user preferences are maintained for existing chats wh
providing appropriate defaults for new conversations based on the available tools.
*/
export function useLlmOverride(
export function useLlmManager(
llmProviders: LLMProviderDescriptor[],
currentChatSession?: ChatSession,
liveAssistant?: Persona
): LlmOverrideManager {
): LlmManager {
const { user } = useUser();
const [userHasManuallyOverriddenLLM, setUserHasManuallyOverriddenLLM] =
useState(false);
const [chatSession, setChatSession] = useState<ChatSession | null>(null);
const [currentLlm, setCurrentLlm] = useState<LlmDescriptor>({
name: "",
provider: "",
modelName: "",
});
const llmOverrideUpdate = () => {
if (liveAssistant?.llm_model_version_override) {
setLlmOverride(
getValidLlmOverride(liveAssistant.llm_model_version_override)
);
} else if (currentChatSession?.current_alternate_model) {
setLlmOverride(
getValidLlmOverride(currentChatSession.current_alternate_model)
);
} else if (user?.preferences?.default_model) {
setLlmOverride(getValidLlmOverride(user.preferences.default_model));
return;
} else {
const defaultProvider = llmProviders.find(
(provider) => provider.is_default_provider
);
const llmUpdate = () => {
/* Should be called when the live assistant or current chat session changes */
if (defaultProvider) {
setLlmOverride({
name: defaultProvider.name,
provider: defaultProvider.provider,
modelName: defaultProvider.default_model_name,
});
// separate function so we can `return` to break out
const _llmUpdate = () => {
// if the user has overridden in this session and just switched to a brand
// new session, use their manually specified model
if (userHasManuallyOverriddenLLM && !currentChatSession) {
return;
}
}
if (currentChatSession?.current_alternate_model) {
setCurrentLlm(
getValidLlmDescriptor(currentChatSession.current_alternate_model)
);
} else if (liveAssistant?.llm_model_version_override) {
setCurrentLlm(
getValidLlmDescriptor(liveAssistant.llm_model_version_override)
);
} else if (userHasManuallyOverriddenLLM) {
// if the user has an override and there's nothing special about the
// current chat session, use the override
return;
} else if (user?.preferences?.default_model) {
setCurrentLlm(getValidLlmDescriptor(user.preferences.default_model));
} else {
const defaultProvider = llmProviders.find(
(provider) => provider.is_default_provider
);
if (defaultProvider) {
setCurrentLlm({
name: defaultProvider.name,
provider: defaultProvider.provider,
modelName: defaultProvider.default_model_name,
});
}
}
};
_llmUpdate();
setChatSession(currentChatSession || null);
};
const getValidLlmOverride = (
overrideModel: string | null | undefined
): LlmOverride => {
if (overrideModel) {
const model = destructureValue(overrideModel);
const getValidLlmDescriptor = (
modelName: string | null | undefined
): LlmDescriptor => {
if (modelName) {
const model = destructureValue(modelName);
if (!(model.modelName && model.modelName.length > 0)) {
const provider = llmProviders.find((p) =>
p.model_names.includes(overrideModel)
p.model_names.includes(modelName)
);
if (provider) {
return {
modelName: overrideModel,
modelName: modelName,
name: provider.name,
provider: provider.provider,
};
@@ -491,38 +514,32 @@ export function useLlmOverride(
setImageFilesPresent(present);
};
const [llmOverride, setLlmOverride] = useState<LlmOverride>({
name: "",
provider: "",
modelName: "",
});
// Manually set the override
const updateLLMOverride = (newOverride: LlmOverride) => {
// Manually set the LLM
const updateCurrentLlm = (newLlm: LlmDescriptor) => {
const provider =
newOverride.provider ||
findProviderForModel(llmProviders, newOverride.modelName);
newLlm.provider || findProviderForModel(llmProviders, newLlm.modelName);
const structuredValue = structureValue(
newOverride.name,
newLlm.name,
provider,
newOverride.modelName
newLlm.modelName
);
setLlmOverride(getValidLlmOverride(structuredValue));
setCurrentLlm(getValidLlmDescriptor(structuredValue));
setUserHasManuallyOverriddenLLM(true);
};
const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
const updateModelOverrideBasedOnChatSession = (chatSession?: ChatSession) => {
if (chatSession && chatSession.current_alternate_model?.length > 0) {
setLlmOverride(getValidLlmOverride(chatSession.current_alternate_model));
setCurrentLlm(getValidLlmDescriptor(chatSession.current_alternate_model));
}
};
const [temperature, setTemperature] = useState<number>(() => {
llmOverrideUpdate();
llmUpdate();
if (currentChatSession?.current_temperature_override != null) {
return Math.min(
currentChatSession.current_temperature_override,
isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0
isAnthropic(currentLlm.provider, currentLlm.modelName) ? 1.0 : 2.0
);
} else if (
liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID)
@@ -533,22 +550,23 @@ export function useLlmOverride(
});
const maxTemperature = useMemo(() => {
return isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0;
}, [llmOverride]);
return isAnthropic(currentLlm.provider, currentLlm.modelName) ? 1.0 : 2.0;
}, [currentLlm]);
useEffect(() => {
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
if (isAnthropic(currentLlm.provider, currentLlm.modelName)) {
const newTemperature = Math.min(temperature, 1.0);
setTemperature(newTemperature);
if (chatSession?.id) {
updateTemperatureOverrideForChatSession(chatSession.id, newTemperature);
}
}
}, [llmOverride]);
}, [currentLlm]);
useEffect(() => {
llmUpdate();
if (!chatSession && currentChatSession) {
setChatSession(currentChatSession || null);
if (temperature) {
updateTemperatureOverrideForChatSession(
currentChatSession.id,
@@ -570,7 +588,7 @@ export function useLlmOverride(
}, [liveAssistant, currentChatSession]);
const updateTemperature = (temperature: number) => {
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
if (isAnthropic(currentLlm.provider, currentLlm.modelName)) {
setTemperature((prevTemp) => Math.min(temperature, 1.0));
} else {
setTemperature(temperature);
@@ -581,9 +599,9 @@ export function useLlmOverride(
};
return {
updateModelOverrideForChatSession,
llmOverride,
updateLLMOverride,
updateModelOverrideBasedOnChatSession,
currentLlm,
updateCurrentLlm,
temperature,
updateTemperature,
imageFilesPresent,

View File

@@ -1,11 +1,11 @@
import { Persona } from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { LlmOverride } from "@/lib/hooks";
import { LlmDescriptor } from "@/lib/hooks";
export function getFinalLLM(
llmProviders: LLMProviderDescriptor[],
persona: Persona | null,
llmOverride: LlmOverride | null
currentLlm: LlmDescriptor | null
): [string, string] {
const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider
@@ -26,9 +26,9 @@ export function getFinalLLM(
model = persona.llm_model_version_override || model;
}
if (llmOverride) {
provider = llmOverride.provider || provider;
model = llmOverride.modelName || model;
if (currentLlm) {
provider = currentLlm.provider || provider;
model = currentLlm.modelName || model;
}
return [provider, model];
@@ -37,7 +37,7 @@ export function getFinalLLM(
export function getLLMProviderOverrideForPersona(
liveAssistant: Persona,
llmProviders: LLMProviderDescriptor[]
): LlmOverride | null {
): LlmDescriptor | null {
const overrideProvider = liveAssistant.llm_model_provider_override;
const overrideModel = liveAssistant.llm_model_version_override;
@@ -135,7 +135,7 @@ export const structureValue = (
return `${name}__${provider}__${modelName}`;
};
export const destructureValue = (value: string): LlmOverride => {
export const destructureValue = (value: string): LlmDescriptor => {
const [displayName, provider, modelName] = value.split("__");
return {
name: displayName,

View File

@@ -1,5 +1,3 @@
import { LlmOverride } from "../hooks";
export async function setUserDefaultModel(
model: string | null
): Promise<Response> {