Compare commits

..

10 Commits

Author SHA1 Message Date
pablonyx
0f6fc5b520 nit 2025-02-21 15:55:51 -08:00
pablonyx
65309d8b27 k 2025-02-21 15:07:17 -08:00
pablonyx
03f8f2b7df pass env vars down 2025-02-21 08:55:05 -08:00
pablonyx
1caeea23e2 minor 2025-02-20 16:09:48 -08:00
pablonyx
b3553a7f3b minor update 2025-02-20 15:12:48 -08:00
pablonyx
7e87618d27 update connector creation during tests 2025-02-19 18:50:17 -08:00
pablonyx
1730c13152 u 2025-02-19 17:40:30 -08:00
pablonyx
a1e0f56f98 fix mock 2025-02-19 17:40:11 -08:00
pablonyx
6aa076c382 delete 2025-02-19 17:40:11 -08:00
pablonyx
6a033460e4 update 2025-02-19 17:40:10 -08:00
19 changed files with 103 additions and 141 deletions

View File

@@ -145,7 +145,7 @@ jobs:
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
@@ -157,6 +157,7 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
@@ -199,7 +200,7 @@ jobs:
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |

View File

@@ -1,27 +0,0 @@
"""Add composite index for last_modified and last_synced to document
Revision ID: f13db29f3101
Revises: b388730a2899
Create Date: 2025-02-18 22:48:11.511389
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f13db29f3101"
down_revision = "acaab4ef4507"
branch_labels: str | None = None
depends_on: str | None = None
def upgrade() -> None:
op.create_index(
"ix_document_sync_status",
"document",
["last_modified", "last_synced"],
unique=False,
)
def downgrade() -> None:
op.drop_index("ix_document_sync_status", table_name="document")

View File

@@ -10,7 +10,6 @@ from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
@@ -188,51 +187,23 @@ def send_subscription_cancellation_email(user_email: str) -> None:
send_email(user_email, subject, html_content, text_content)
def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
def send_user_email_invite(user_email: str, current_user: User) -> None:
subject = "Invitation to Join Onyx Organization"
heading = "You've Been Invited!"
# the exact action taken by the user, and thus the message, depends on the auth type
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
if auth_type == AuthType.CLOUD:
message += (
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
elif auth_type == AuthType.BASIC:
message += (
"<p>To join the organization, please click the button below to set a password "
"and complete your registration.</p>"
)
elif auth_type == AuthType.GOOGLE_OAUTH:
message += (
"<p>To join the organization, please click the button below to login with Google "
"and complete your registration.</p>"
)
elif auth_type == AuthType.OIDC or auth_type == AuthType.SAML:
message += (
"<p>To join the organization, please click the button below to"
" complete your registration.</p>"
)
else:
raise ValueError(f"Invalid auth type: {auth_type}")
message = (
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
# text content is the fallback for clients that don't support HTML
# not as critical, so not having special cases for each auth type
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
"You'll be asked to set a password or login with Google to complete your registration."
)
if auth_type == AuthType.CLOUD:
text_content += "You'll be asked to set a password or login with Google to complete your registration."
send_email(user_email, subject, html_content, text_content)

View File

@@ -140,7 +140,7 @@ def on_task_postrun(
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)

View File

@@ -361,6 +361,7 @@ def connector_external_group_sync_generator_task(
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(

View File

@@ -15,6 +15,7 @@ from onyx.background.indexing.memory_tracer import MemoryTracer
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
@@ -89,8 +90,8 @@ def _get_connector_runner(
)
# validate the connector settings
runnable_connector.validate_connector_settings()
if not INTEGRATION_TESTS_MODE:
runnable_connector.validate_connector_settings()
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")

View File

@@ -158,7 +158,7 @@ POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
@@ -626,6 +626,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"

View File

@@ -3,6 +3,7 @@ from typing import Type
from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceRequiringTenantContext
from onyx.connectors.airtable.airtable_connector import AirtableConnector
@@ -187,6 +188,9 @@ def validate_ccpair_for_user(
user: User | None,
tenant_id: str | None,
) -> None:
if INTEGRATION_TESTS_MODE:
return
# Validate the connector settings
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id_for_user(
@@ -199,7 +203,10 @@ def validate_ccpair_for_user(
if not connector:
raise ValueError("Connector not found")
if connector.source == DocumentSource.INGESTION_API:
if (
connector.source == DocumentSource.INGESTION_API
or connector.source == DocumentSource.MOCK_CONNECTOR
):
return
if not credential:

View File

@@ -229,20 +229,16 @@ class GitbookConnector(LoadConnector, PollConnector):
try:
content = self.client.get(f"/spaces/{self.space_id}/content")
pages: list[dict[str, Any]] = content.get("pages", [])
pages = content.get("pages", [])
current_batch: list[Document] = []
for page in pages:
updated_at = datetime.fromisoformat(page["updatedAt"])
while pages:
page = pages.pop(0)
updated_at_raw = page.get("updatedAt")
if updated_at_raw is None:
# if updatedAt is not present, that means the page has never been edited
continue
updated_at = datetime.fromisoformat(updated_at_raw)
if start and updated_at < start:
continue
if current_batch:
yield current_batch
return
if end and updated_at > end:
continue
@@ -254,8 +250,6 @@ class GitbookConnector(LoadConnector, PollConnector):
yield current_batch
current_batch = []
pages.extend(page.get("pages", []))
if current_batch:
yield current_batch

View File

@@ -220,7 +220,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
return self._creds
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
try:
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
except KeyError:
raise ValueError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,

View File

@@ -194,9 +194,14 @@ def get_connector_credential_pair_from_id_for_user(
def get_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
eager_load_credential: bool = False,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
if eager_load_credential:
stmt = stmt.options(joinedload(ConnectorCredentialPair.credential))
result = db_session.execute(stmt)
return result.scalar_one_or_none()

View File

@@ -60,8 +60,9 @@ def count_documents_by_needs_sync(session: Session) -> int:
This function executes the query and returns the count of
documents matching the criteria."""
return (
session.query(DbDocument.id)
count = (
session.query(func.count(DbDocument.id.distinct()))
.select_from(DbDocument)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
@@ -72,53 +73,63 @@ def count_documents_by_needs_sync(session: Session) -> int:
DbDocument.last_synced.is_(None),
)
)
.count()
.scalar()
)
return count
def construct_document_select_for_connector_credential_pair_by_needs_sync(
connector_id: int, credential_id: int
) -> Select:
return (
select(DbDocument)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
or_(
DbDocument.last_modified > DbDocument.last_synced,
DbDocument.last_synced.is_(None),
),
)
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = (
select(DbDocument)
.where(
DbDocument.id.in_(initial_doc_ids_stmt),
or_(
DbDocument.last_modified
> DbDocument.last_synced, # last_modified is newer than last_synced
DbDocument.last_synced.is_(None), # never synced
),
)
.distinct()
)
return stmt
def construct_document_id_select_for_connector_credential_pair_by_needs_sync(
connector_id: int, credential_id: int
) -> Select:
return (
select(DbDocument.id)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
or_(
DbDocument.last_modified > DbDocument.last_synced,
DbDocument.last_synced.is_(None),
),
)
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = (
select(DbDocument.id)
.where(
DbDocument.id.in_(initial_doc_ids_stmt),
or_(
DbDocument.last_modified
> DbDocument.last_synced, # last_modified is newer than last_synced
DbDocument.last_synced.is_(None), # never synced
),
)
.distinct()
)
return stmt
def get_all_documents_needing_vespa_sync_for_cc_pair(
db_session: Session, cc_pair_id: int

View File

@@ -570,14 +570,6 @@ class Document(Base):
back_populates="documents",
)
__table_args__ = (
Index(
"ix_document_sync_status",
last_modified,
last_synced,
),
)
class Tag(Base):
__tablename__ = "tag"

View File

@@ -311,23 +311,19 @@ def bulk_invite_users(
all_emails = list(set(new_invited_emails) | set(initial_invited_users))
number_of_invited_users = write_invited_users(all_emails)
# send out email invitations if enabled
if ENABLE_EMAIL_INVITES:
try:
for email in new_invited_emails:
send_user_email_invite(email, current_user, AUTH_TYPE)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
if not MULTI_TENANT:
return number_of_invited_users
# for billing purposes, write to the control plane about the number of new users
try:
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_total_users_count(db_session))
if ENABLE_EMAIL_INVITES:
try:
for email in new_invited_emails:
send_user_email_invite(email, current_user)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
return number_of_invited_users
except Exception as e:

View File

@@ -3,7 +3,7 @@ import os
ADMIN_USER_NAME = "admin_user"
API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "127.0.0.1"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 45

View File

@@ -30,8 +30,10 @@ class ConnectorManager:
name=name,
source=source,
input_type=input_type,
connector_specific_config=connector_specific_config
or {"file_locations": []},
connector_specific_config=(
connector_specific_config
or ({"file_locations": []} if source == DocumentSource.FILE else {})
),
access_type=access_type,
groups=groups or [],
)

View File

@@ -88,8 +88,6 @@ class UserManager:
if not session_cookie:
raise Exception("Failed to login")
print(f"Logged in as {test_user.email}")
# Set cookies in the headers
test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
test_user.cookies = {"fastapiusersauth": session_cookie}

View File

@@ -36,6 +36,7 @@ services:
- OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
- TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-}
- CORS_ALLOWED_ORIGIN=${CORS_ALLOWED_ORIGIN:-}
- INTEGRATION_TESTS_MODE=${INTEGRATION_TESTS_MODE:-}
# Gen AI Settings
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}

View File

@@ -95,7 +95,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
}
}
if (settings.pro_search_enabled == null) {
if (enterpriseSettings && settings.pro_search_enabled == null) {
settings.pro_search_enabled = true;
}