Compare commits

...

1 Commits

Author SHA1 Message Date
hagen-danswer
089cfe7478 User Filter Polish 2025-01-16 11:47:11 -08:00
5 changed files with 105 additions and 39 deletions

View File

@@ -12,6 +12,8 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.enums import AccessType
@@ -33,8 +35,13 @@ logger = setup_logger()
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
) -> Select:
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@@ -94,7 +101,7 @@ def _add_user_filters(
def get_connector_credential_pairs_for_user(
db_session: Session,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
@@ -105,6 +112,7 @@ def get_connector_credential_pairs_for_user(
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = _add_user_filters(stmt, user, get_editable)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
@@ -115,12 +123,11 @@ def get_connector_credential_pairs(
db_session: Session,
ids: list[int] | None = None,
) -> list[ConnectorCredentialPair]:
stmt = select(ConnectorCredentialPair).distinct()
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).all())
return get_connector_credential_pairs_for_user(
db_session=db_session,
user=SYSTEM_USER,
ids=ids,
)
def add_deletion_failure_message(
@@ -155,7 +162,7 @@ def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,
credential_id: int,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
@@ -171,17 +178,18 @@ def get_connector_credential_pair(
connector_id: int,
credential_id: int,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair)
stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
return get_connector_credential_pair_for_user(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
user=SYSTEM_USER,
)
def get_connector_credential_pair_from_id_for_user(
cc_pair_id: int,
db_session: Session,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
@@ -195,10 +203,11 @@ def get_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
result = db_session.execute(stmt)
return result.scalar_one_or_none()
return get_connector_credential_pair_from_id_for_user(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=SYSTEM_USER,
)
def get_last_successful_attempt_time(

View File

@@ -1 +1,11 @@
from typing import Final
SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__"
class SystemUser:
"""Represents the system user for internal operations"""
SYSTEM_USER: Final = SystemUser()

View File

@@ -14,6 +14,8 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
@@ -42,11 +44,17 @@ PUBLIC_CREDENTIAL_ID = 0
def _add_user_filters(
stmt: Select,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
) -> Select:
"""Attaches filters to the statement to ensure that the user can only
access the appropriate credentials"""
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")
if user is None:
if not DISABLE_AUTH:
raise ValueError("Anonymous users are not allowed to access credentials")
@@ -151,7 +159,7 @@ def fetch_credentials_for_user(
def fetch_credential_by_id_for_user(
credential_id: int,
user: User | None,
user: User | None | SystemUser,
db_session: Session,
get_editable: bool = True,
) -> Credential | None:
@@ -171,16 +179,16 @@ def fetch_credential_by_id(
db_session: Session,
credential_id: int,
) -> Credential | None:
stmt = select(Credential).distinct()
stmt = stmt.where(Credential.id == credential_id)
result = db_session.execute(stmt)
credential = result.scalar_one_or_none()
return credential
return fetch_credential_by_id_for_user(
credential_id=credential_id,
user=SYSTEM_USER,
db_session=db_session,
)
def fetch_credentials_by_source_for_user(
db_session: Session,
user: User | None,
user: User | None | SystemUser,
document_source: DocumentSource | None = None,
get_editable: bool = True,
) -> list[Credential]:
@@ -194,9 +202,11 @@ def fetch_credentials_by_source(
db_session: Session,
document_source: DocumentSource | None = None,
) -> list[Credential]:
base_query = select(Credential).where(Credential.source == document_source)
credentials = db_session.execute(base_query).scalars().all()
return list(credentials)
return fetch_credentials_by_source_for_user(
db_session=db_session,
user=SYSTEM_USER,
document_source=document_source,
)
def swap_credentials_connector(

View File

@@ -15,6 +15,8 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
@@ -35,8 +37,13 @@ logger = setup_logger()
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
) -> Select:
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@@ -487,7 +494,7 @@ def fetch_document_sets(
def fetch_all_document_sets_for_user(
db_session: Session,
user: User | None,
user: User | None | SystemUser,
get_editable: bool = True,
) -> Sequence[DocumentSetDBModel]:
stmt = select(DocumentSetDBModel).distinct()
@@ -495,6 +502,15 @@ def fetch_all_document_sets_for_user(
return db_session.scalars(stmt).all()
def fetch_all_document_sets(
db_session: Session,
) -> Sequence[DocumentSetDBModel]:
return fetch_all_document_sets_for_user(
db_session=db_session,
user=SYSTEM_USER,
)
def fetch_documents_for_document_set_paginated(
document_set_id: int,
db_session: Session,

View File

@@ -23,6 +23,8 @@ from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.constants import SYSTEM_USER
from onyx.db.constants import SystemUser
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import DocumentSet
from onyx.db.models import Persona
@@ -44,8 +46,13 @@ logger = setup_logger()
def _add_user_filters(
stmt: Select, user: User | None, get_editable: bool = True
stmt: Select, user: User | None | SystemUser, get_editable: bool = True
) -> Select:
if isinstance(user, SystemUser):
if user is SYSTEM_USER:
return stmt
raise ValueError("Bad SystemUser object")
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@@ -111,7 +118,10 @@ def _add_user_filters(
def fetch_persona_by_id_for_user(
db_session: Session, persona_id: int, user: User | None, get_editable: bool = True
db_session: Session,
persona_id: int,
user: User | None | SystemUser,
get_editable: bool = True,
) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id).distinct()
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
@@ -124,6 +134,17 @@ def fetch_persona_by_id_for_user(
return persona
def fetch_persona_by_id(
db_session: Session,
persona_id: int,
) -> Persona:
return fetch_persona_by_id_for_user(
db_session=db_session,
persona_id=persona_id,
user=SYSTEM_USER,
)
def get_best_persona_id_for_user(
db_session: Session, user: User | None, persona_id: int | None = None
) -> int | None:
@@ -285,7 +306,7 @@ def get_prompts(
def get_personas_for_user(
# if user is `None` assume the user is an admin or auth is disabled
user: User | None,
user: User | None | SystemUser,
db_session: Session,
get_editable: bool = True,
include_default: bool = True,
@@ -315,10 +336,10 @@ def get_personas_for_user(
def get_personas(db_session: Session) -> Sequence[Persona]:
stmt = select(Persona).distinct()
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
stmt = stmt.where(Persona.deleted.is_(False))
return db_session.execute(stmt).unique().scalars().all()
return get_personas_for_user(
user=SYSTEM_USER,
db_session=db_session,
)
def mark_persona_as_deleted(