Compare commits

...

3 Commits

Author SHA1 Message Date
Evan Lohn
6734ab0ac4 potential integration test fix 2025-02-26 18:53:18 -08:00
Evan Lohn
7736c351e4 addressed CW comments 2025-02-26 18:53:18 -08:00
Evan Lohn
a5831ae375 slight improvements to user group endpoints 2025-02-26 18:53:18 -08:00
5 changed files with 77 additions and 7 deletions

View File

@@ -8,6 +8,8 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
@@ -16,12 +18,15 @@ from ee.onyx.server.user_group.models import UserGroupUpdate
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -175,13 +180,55 @@ def validate_object_creation_for_user(
)
def eager_usergroup_options(stmt: Select[tuple[UserGroup]]) -> Select[tuple[UserGroup]]:
return stmt.options(
# Which users are in this group
selectinload(UserGroup.users),
selectinload(UserGroup.user_group_relationships),
# Which CC pairs this group has access to
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.joinedload(ConnectorCredentialPair.credential),
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.joinedload(ConnectorCredentialPair.connector)
.contains_eager(Connector.credentials),
# Which document sets this group has access to
selectinload(UserGroup.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.credential),
selectinload(UserGroup.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.joinedload(ConnectorCredentialPair.connector)
.contains_eager(Connector.credentials),
# Which personas this group has access to. Each persona has
# its own set of associated data similar to the above per-user-group
# associations; TODO: do we really need to load all of this?
selectinload(UserGroup.personas).selectinload(Persona.user),
selectinload(UserGroup.personas).selectinload(Persona.prompts),
selectinload(UserGroup.personas).selectinload(Persona.tools),
selectinload(UserGroup.personas)
.selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.credential),
selectinload(UserGroup.personas)
.selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.joinedload(ConnectorCredentialPair.connector)
.contains_eager(Connector.credentials),
selectinload(UserGroup.personas).selectinload(Persona.users),
selectinload(UserGroup.personas).selectinload(Persona.groups),
selectinload(UserGroup.personas).selectinload(Persona.labels),
)
def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None:
stmt = select(UserGroup).where(UserGroup.id == user_group_id)
return db_session.scalar(stmt)
def fetch_user_groups(
db_session: Session, only_up_to_date: bool = True
db_session: Session, only_up_to_date: bool = True, eager_load_all: bool = False
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -201,11 +248,18 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if eager_load_all:
stmt = eager_usergroup_options(stmt)
return db_session.scalars(stmt).all()
def fetch_user_groups_for_user(
db_session: Session, user_id: UUID, only_curator_groups: bool = False
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_all: bool = False,
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -215,6 +269,11 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if eager_load_all:
stmt = eager_usergroup_options(stmt)
stmt = stmt.options(contains_eager(UserGroup.users))
return db_session.scalars(stmt).all()

View File

@@ -32,12 +32,15 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user is None or user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_all=True
)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_all=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]

View File

@@ -23,6 +23,8 @@ from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.configs.constants import NotificationType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import Persona
from onyx.db.models import Persona__User
@@ -332,10 +334,16 @@ def get_personas_for_user(
stmt = stmt.options(
selectinload(Persona.prompts),
selectinload(Persona.tools),
selectinload(Persona.document_sets),
selectinload(Persona.groups),
selectinload(Persona.users),
selectinload(Persona.labels),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.credential),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.joinedload(ConnectorCredentialPair.connector)
.contains_eager(Connector.credentials),
)
results = db_session.execute(stmt).scalars().all()

View File

@@ -221,7 +221,7 @@ def index_doc_batch_prepare(
else documents
)
if len(updatable_docs) != len(documents):
updatable_doc_ids = [doc.id for doc in updatable_docs]
updatable_doc_ids = {doc.id for doc in updatable_docs}
skipped_doc_ids = [
doc.id for doc in documents if doc.id not in updatable_doc_ids
]

View File

@@ -78,7 +78,7 @@ USERS_PAGE_SIZE = 10
@router.patch("/manage/set-user-role")
def set_user_role(
user_role_update_request: UserRoleUpdateRequest,
current_user: User = Depends(current_admin_user),
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
user_to_update = get_user_by_email(
@@ -98,7 +98,7 @@ def set_user_role(
current_role=current_role,
)
if user_to_update.id == current_user.id:
if current_user and user_to_update.id == current_user.id:
raise HTTPException(
status_code=400,
detail="An admin cannot demote themselves from admin role!",