mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-16 23:16:46 +00:00
Compare commits
18 Commits
v3.2.0-clo
...
edge
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
210d11aa5d | ||
|
|
f9458c86ec | ||
|
|
369306a0f3 | ||
|
|
8af6ee9c9b | ||
|
|
f5f953cc28 | ||
|
|
3f360e462f | ||
|
|
0602353b2b | ||
|
|
78288867b7 | ||
|
|
0e7b99f960 | ||
|
|
3f2d0a0567 | ||
|
|
e0897265e3 | ||
|
|
bc9c03ab76 | ||
|
|
dfc3886683 | ||
|
|
a3cb45e56d | ||
|
|
6fd07f44e1 | ||
|
|
2a3b487fad | ||
|
|
a14dc4e632 | ||
|
|
b6467e8e3e |
@@ -1,6 +1,7 @@
|
||||
FROM ubuntu:26.04@sha256:cc925e589b7543b910fea57a240468940003fbfc0515245a495dd0ad8fe7cef1
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
build-essential \
|
||||
curl \
|
||||
default-jre \
|
||||
fd-find \
|
||||
@@ -61,3 +62,11 @@ RUN chsh -s /bin/zsh root && \
|
||||
echo '[ -f /workspace/.devcontainer/zshrc ] && . /workspace/.devcontainer/zshrc' >> "$rc"; \
|
||||
done && \
|
||||
chown dev:dev /home/dev/.zshrc
|
||||
|
||||
# Pre-seed GitHub's SSH host keys so git-over-SSH never prompts. Keys are
|
||||
# pinned in-repo (verified against the fingerprints GitHub publishes at
|
||||
# https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/githubs-ssh-key-fingerprints)
|
||||
# rather than fetched at build time, so a compromised build-time network can't
|
||||
# inject a rogue key.
|
||||
COPY github_known_hosts /etc/ssh/ssh_known_hosts
|
||||
RUN chmod 644 /etc/ssh/ssh_known_hosts
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
{
|
||||
"name": "Onyx Dev Sandbox",
|
||||
"image": "onyxdotapp/onyx-devcontainer@sha256:0f02d9299928849c7b15f3b348dcfdcdcb64411ff7a4580cbc026a6ee7aa1554",
|
||||
"runArgs": ["--cap-add=NET_ADMIN", "--cap-add=NET_RAW", "--network=onyx_default"],
|
||||
"image": "onyxdotapp/onyx-devcontainer@sha256:4986c9252289b660ce772b45f0488b938fe425d8114245e96ef64b273b3fcee4",
|
||||
"runArgs": [
|
||||
"--cap-add=NET_ADMIN",
|
||||
"--cap-add=NET_RAW",
|
||||
"--network=onyx_default"
|
||||
],
|
||||
"mounts": [
|
||||
"source=${localEnv:HOME}/.claude,target=/home/dev/.claude,type=bind",
|
||||
"source=${localEnv:HOME}/.claude.json,target=/home/dev/.claude.json,type=bind",
|
||||
|
||||
3
.devcontainer/github_known_hosts
Normal file
3
.devcontainer/github_known_hosts
Normal file
@@ -0,0 +1,3 @@
|
||||
github.com ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQCj7ndNxQowgcQnjshcLrqPEiiphnt+VTTvDP6mHBL9j1aNUkY4Ue1gvwnGLVlOhGeYrnZaMgRK6+PKCUXaDbC7qtbW8gIkhL7aGCsOr/C56SJMy/BCZfxd1nWzAOxSDPgVsmerOBYfNqltV9/hWCqBywINIR+5dIg6JTJ72pcEpEjcYgXkE2YEFXV1JHnsKgbLWNlhScqb2UmyRkQyytRLtL+38TGxkxCflmO+5Z8CSSNY7GidjMIZ7Q4zMjA2n1nGrlTDkzwDCsw+wqFPGQA179cnfGWOWRVruj16z6XyvxvjJwbz0wQZ75XK5tKSb7FNyeIEs4TT4jk+S4dhPeAUC5y+bDYirYgM4GC7uEnztnZyaVWQ7B381AK4Qdrwt51ZqExKbQpTUNn+EjqoTwvqNj4kqx5QUCI0ThS/YkOxJCXmPUWZbhjpCg56i+2aB6CmK2JGhn57K5mj0MNdBXA4/WnwH6XoPWJzK5Nyu2zB3nAZp+S5hpQs+p1vN1/wsjk=
|
||||
github.com ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBEmKSENjQEezOmxkZMy7opKgwFB9nkt5YRrYMjNuG5N87uRgg6CLrbo5wAdT/y6v0mKV0U2w0WZ2YB/++Tpockg=
|
||||
github.com ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIOMqqnkVzrm0SdG6UOoqKLsabgH5C9okWi0dh2l9GKJl
|
||||
@@ -4,6 +4,17 @@ set -euo pipefail
|
||||
|
||||
echo "Setting up firewall..."
|
||||
|
||||
# Reset default policies to ACCEPT before flushing rules. On re-runs the
|
||||
# previous invocation's DROP policies are still in effect; flushing rules while
|
||||
# the default is DROP would block the DNS lookups below. Register a trap so
|
||||
# that if the script exits before the DROP policies are re-applied at the end,
|
||||
# we fail closed instead of leaving the container with an unrestricted
|
||||
# firewall.
|
||||
trap 'iptables -P INPUT DROP; iptables -P OUTPUT DROP; iptables -P FORWARD DROP' EXIT
|
||||
iptables -P INPUT ACCEPT
|
||||
iptables -P OUTPUT ACCEPT
|
||||
iptables -P FORWARD ACCEPT
|
||||
|
||||
# Only flush the filter table. The nat and mangle tables are managed by Docker
|
||||
# (DNS DNAT to 127.0.0.11, container networking, etc.) and must not be touched —
|
||||
# flushing them breaks Docker's embedded DNS resolver.
|
||||
@@ -34,8 +45,16 @@ ALLOWED_DOMAINS=(
|
||||
"pypi.org"
|
||||
"files.pythonhosted.org"
|
||||
"go.dev"
|
||||
"proxy.golang.org"
|
||||
"sum.golang.org"
|
||||
"storage.googleapis.com"
|
||||
"dl.google.com"
|
||||
"static.rust-lang.org"
|
||||
"index.crates.io"
|
||||
"static.crates.io"
|
||||
"archive.ubuntu.com"
|
||||
"security.ubuntu.com"
|
||||
"deb.nodesource.com"
|
||||
)
|
||||
|
||||
for domain in "${ALLOWED_DOMAINS[@]}"; do
|
||||
|
||||
31
.github/workflows/pr-python-checks.yml
vendored
31
.github/workflows/pr-python-checks.yml
vendored
@@ -19,16 +19,16 @@ permissions:
|
||||
jobs:
|
||||
mypy-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
# Note: Mypy seems quite optimized for x64 compared to arm64.
|
||||
# Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
|
||||
# NOTE: This job is named mypy-check for branch protection compatibility,
|
||||
# but it actually runs ty (astral-sh's Rust type checker).
|
||||
runs-on:
|
||||
[
|
||||
runs-on,
|
||||
runner=2cpu-linux-x64,
|
||||
runner=2cpu-linux-arm64,
|
||||
"run-id=${{ github.run_id }}-mypy-check",
|
||||
"extras=s3-cache",
|
||||
]
|
||||
timeout-minutes: 45
|
||||
timeout-minutes: 15
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
@@ -46,26 +46,7 @@ jobs:
|
||||
backend/requirements/model_server.txt
|
||||
backend/requirements/ee.txt
|
||||
|
||||
- name: Generate OpenAPI schema and Python client
|
||||
shell: bash
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
- name: Run ty
|
||||
env:
|
||||
LICENSE_ENFORCEMENT_ENABLED: "false"
|
||||
run: |
|
||||
ods openapi all
|
||||
|
||||
- name: Cache mypy cache
|
||||
if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
|
||||
uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4
|
||||
with:
|
||||
path: .mypy_cache
|
||||
key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
|
||||
restore-keys: |
|
||||
mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
|
||||
mypy-${{ runner.os }}-
|
||||
|
||||
- name: Run MyPy
|
||||
env:
|
||||
MYPY_FORCE_COLOR: 1
|
||||
TERM: xterm-256color
|
||||
run: mypy .
|
||||
run: ty check --output-format github
|
||||
|
||||
2
.github/workflows/pr-python-model-tests.yml
vendored
2
.github/workflows/pr-python-model-tests.yml
vendored
@@ -17,8 +17,6 @@ env:
|
||||
|
||||
# API keys for testing
|
||||
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
|
||||
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
|
||||
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
|
||||
AZURE_API_URL: ${{ vars.AZURE_API_URL }}
|
||||
|
||||
@@ -67,12 +67,11 @@ repos:
|
||||
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
|
||||
pass_filenames: true
|
||||
files: ^backend/(?!\.venv/|scripts/).*\.py$
|
||||
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
|
||||
# - id: uv-run
|
||||
# name: mypy
|
||||
# args: ["--all-extras", "mypy"]
|
||||
# pass_filenames: true
|
||||
# files: ^backend/.*\.py$
|
||||
- id: uv-run
|
||||
name: ty
|
||||
args: ["ty", "check"]
|
||||
pass_filenames: true
|
||||
types_or: [python]
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0
|
||||
@@ -142,6 +141,7 @@ repos:
|
||||
hooks:
|
||||
- id: ripsecrets
|
||||
args:
|
||||
- --strict-ignore
|
||||
- --additional-pattern
|
||||
- ^sk-[A-Za-z0-9_\-]{20,}$
|
||||
|
||||
|
||||
1
.secretsignore
Normal file
1
.secretsignore
Normal file
@@ -0,0 +1 @@
|
||||
.devcontainer/github_known_hosts
|
||||
@@ -63,11 +63,13 @@ Your features must pass all tests and all comments must be addressed prior to me
|
||||
### Implicit agreements
|
||||
|
||||
If we approve an issue, we are promising you the following:
|
||||
|
||||
- Your work will receive timely attention and we will put aside other important items to ensure you are not blocked.
|
||||
- You will receive necessary coaching on eng quality, system design, etc. to ensure the feature is completed well.
|
||||
- The Onyx team will pull resources and bandwidth from design, PM, and engineering to ensure that you have all the resources to build the feature to the quality required for merging.
|
||||
|
||||
Because this is a large investment from our team, we ask that you:
|
||||
|
||||
- Thoroughly read all the requirements of the design docs, engineering best practices, and try to minimize overhead for the Onyx team.
|
||||
- Complete the feature in a timely manner to reduce context switching and an ongoing resource pull from the Onyx team.
|
||||
|
||||
@@ -149,10 +151,10 @@ Set up pre-commit hooks (black / reorder-python-imports):
|
||||
uv run pre-commit install
|
||||
```
|
||||
|
||||
We also use `mypy` for static type checking. Onyx is fully type-annotated, and we want to keep it that way! To run the mypy checks manually:
|
||||
We also use `ty` for static type checking. Onyx is fully type-annotated, and we want to keep it that way! To run the ty checks manually:
|
||||
|
||||
```bash
|
||||
uv run mypy . # from onyx/backend
|
||||
uv run ty check
|
||||
```
|
||||
|
||||
#### Frontend
|
||||
@@ -192,6 +194,7 @@ Before starting, make sure the Docker Daemon is running.
|
||||
> **Note:** "Clear and Restart External Volumes and Containers" will reset your Postgres and OpenSearch (relational-db and index). Only run this if you are okay with wiping your data.
|
||||
|
||||
**Features:**
|
||||
|
||||
- Hot reload is enabled for the web server and API servers
|
||||
- Python debugging is configured with debugpy
|
||||
- Environment variables are loaded from `.vscode/.env`
|
||||
@@ -344,13 +347,16 @@ sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
|
||||
### Style and Maintainability
|
||||
|
||||
#### Comments and readability
|
||||
|
||||
Add clear comments:
|
||||
|
||||
- At logical boundaries (e.g., interfaces) so the reader doesn't need to dig 10 layers deeper.
|
||||
- Wherever assumptions are made or something non-obvious/unexpected is done.
|
||||
- For complicated flows/functions.
|
||||
- Wherever it saves time (e.g., nontrivial regex patterns).
|
||||
|
||||
#### Errors and exceptions
|
||||
|
||||
- **Fail loudly** rather than silently skipping work.
|
||||
- Example: raise and let exceptions propagate instead of silently dropping a document.
|
||||
- **Don't overuse `try/except`.**
|
||||
@@ -358,6 +364,7 @@ Add clear comments:
|
||||
- Do not mask exceptions unless it is clearly appropriate.
|
||||
|
||||
#### Typing
|
||||
|
||||
- Everything should be **as strictly typed as possible**.
|
||||
- Use `cast` for annoying/loose-typed interfaces (e.g., results of `run_functions_tuples_in_parallel`).
|
||||
- Only `cast` when the type checker sees `Any` or types are too loose.
|
||||
@@ -368,6 +375,7 @@ Add clear comments:
|
||||
- `dict[EmbeddingModel, list[EmbeddingVector]]`
|
||||
|
||||
#### State, objects, and boundaries
|
||||
|
||||
- Keep **clear logical boundaries** for state containers and objects.
|
||||
- A **config** object should never contain things like a `db_session`.
|
||||
- Avoid state containers that are overly nested, or huge + flat (use judgment).
|
||||
@@ -380,6 +388,7 @@ Add clear comments:
|
||||
- Prefer **hash maps (dicts)** over tree structures unless there's a strong reason.
|
||||
|
||||
#### Naming
|
||||
|
||||
- Name variables carefully and intentionally.
|
||||
- Prefer long, explicit names when undecided.
|
||||
- Avoid single-character variables except for small, self-contained utilities (or not at all).
|
||||
@@ -390,6 +399,7 @@ Add clear comments:
|
||||
- IntelliSense can miss call sites; search works best with unique names.
|
||||
|
||||
#### Correctness by construction
|
||||
|
||||
- Prefer self-contained correctness — don't rely on callers to "use it right" if you can make misuse hard.
|
||||
- Avoid redundancies: if a function takes an arg, it shouldn't also take a state object that contains that same arg.
|
||||
- No dead code (unless there's a very good reason).
|
||||
@@ -417,29 +427,35 @@ Add clear comments:
|
||||
### Repository Conventions
|
||||
|
||||
#### Where code lives
|
||||
|
||||
- Pydantic + data models: `models.py` files.
|
||||
- DB interface functions (excluding lazy loading): `db/` directory.
|
||||
- LLM prompts: `prompts/` directory, roughly mirroring the code layout that uses them.
|
||||
- API routes: `server/` directory.
|
||||
|
||||
#### Pydantic and modeling
|
||||
|
||||
- Prefer **Pydantic** over dataclasses.
|
||||
- If absolutely required, use `allow_arbitrary_types`.
|
||||
|
||||
#### Data conventions
|
||||
|
||||
- Prefer explicit `None` over sentinel empty strings (usually; depends on intent).
|
||||
- Prefer explicit identifiers: use string enums instead of integer codes.
|
||||
- Avoid magic numbers (co-location is good when necessary). **Always avoid magic strings.**
|
||||
|
||||
#### Logging
|
||||
|
||||
- Log messages where they are created.
|
||||
- Don't propagate log messages around just to log them elsewhere.
|
||||
|
||||
#### Encapsulation
|
||||
|
||||
- Don't use private attributes/methods/properties from other classes/modules.
|
||||
- "Private" is private — respect that boundary.
|
||||
|
||||
#### SQLAlchemy guidance
|
||||
|
||||
- Lazy loading is often bad at scale, especially across multiple list relationships.
|
||||
- Be careful when accessing SQLAlchemy object attributes:
|
||||
- It can help avoid redundant DB queries,
|
||||
@@ -448,6 +464,7 @@ Add clear comments:
|
||||
- Reference: https://www.reddit.com/r/SQLAlchemy/comments/138f248/joinedload_vs_selectinload/
|
||||
|
||||
#### Trunk-based development and feature flags
|
||||
|
||||
- **PRs should contain no more than 500 lines of real change.**
|
||||
- **Merge to main frequently.** Avoid long-lived feature branches — they create merge conflicts and integration pain.
|
||||
- **Use feature flags for incremental rollout.**
|
||||
@@ -458,6 +475,7 @@ Add clear comments:
|
||||
- **Test both flag states.** Ensure the codebase works correctly with the flag on and off.
|
||||
|
||||
#### Miscellaneous
|
||||
|
||||
- Any TODOs you add in the code must be accompanied by either the name/username of the owner of that TODO, or an issue number for an issue referencing that piece of work.
|
||||
- Avoid module-level logic that runs on import, which leads to import-time side effects. Essentially every piece of meaningful logic should exist within some function that has to be explicitly invoked. Acceptable exceptions may include loading environment variables or setting up loggers.
|
||||
- If you find yourself needing something like this, you may want that logic to exist in a file dedicated for manual execution (contains `if __name__ == "__main__":`) which should not be imported by anything else.
|
||||
|
||||
@@ -26,7 +26,9 @@ from shared_configs.configs import (
|
||||
TENANT_ID_PREFIX,
|
||||
)
|
||||
from onyx.db.models import Base
|
||||
from celery.backends.database.session import ResultModelBase # type: ignore
|
||||
from celery.backends.database.session import ( # ty: ignore[unresolved-import]
|
||||
ResultModelBase,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
|
||||
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
|
||||
|
||||
@@ -49,7 +49,7 @@ def upgrade() -> None:
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
server_onupdate=sa.text("now()"), # type: ignore
|
||||
server_onupdate=sa.text("now()"), # ty: ignore[invalid-argument-type]
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
|
||||
@@ -68,7 +68,7 @@ def upgrade() -> None:
|
||||
sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
|
||||
{"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]},
|
||||
).fetchone()
|
||||
tool_id = result[0] # type: ignore
|
||||
tool_id = result[0] # ty: ignore[not-subscriptable]
|
||||
|
||||
# Associate the tool with all existing personas
|
||||
# Get all persona IDs
|
||||
|
||||
@@ -52,7 +52,7 @@ def upgrade() -> None:
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(),
|
||||
default=datetime.datetime.utcnow,
|
||||
default=lambda: datetime.datetime.now(datetime.timezone.utc),
|
||||
),
|
||||
sa.Column(
|
||||
"cc_pair_id",
|
||||
|
||||
@@ -10,7 +10,7 @@ from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
from typing import cast, Any
|
||||
from typing import cast
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
@@ -255,7 +255,7 @@ def _migrate_files_to_external_storage() -> None:
|
||||
continue
|
||||
|
||||
lobj_id = cast(int, file_record.lobj_oid)
|
||||
file_metadata = cast(Any, file_record.file_metadata)
|
||||
file_metadata = file_record.file_metadata
|
||||
|
||||
# Read file content from PostgreSQL
|
||||
try:
|
||||
|
||||
@@ -112,7 +112,7 @@ def _get_access_for_documents(
|
||||
access_map[document_id] = DocumentAccess.build(
|
||||
user_emails=list(non_ee_access.user_emails),
|
||||
user_groups=user_group_info.get(document_id, []),
|
||||
is_public=is_public_anywhere,
|
||||
is_public=is_public_anywhere, # ty: ignore[invalid-argument-type]
|
||||
external_user_emails=list(ext_u_emails),
|
||||
external_user_group_ids=list(ext_u_groups),
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import jwt
|
||||
from fastapi import Depends
|
||||
@@ -58,7 +59,7 @@ def generate_anonymous_user_jwt_token(tenant_id: str) -> str:
|
||||
payload = {
|
||||
"tenant_id": tenant_id,
|
||||
# Token does not expire
|
||||
"iat": datetime.utcnow(), # Issued at time
|
||||
"iat": datetime.now(timezone.utc), # Issued at time
|
||||
}
|
||||
|
||||
return jwt.encode(payload, USER_AUTH_SECRET, algorithm="HS256")
|
||||
|
||||
@@ -80,6 +80,7 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_tenant_work_gating import maybe_mark_tenant_active
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import doc_permission_sync_ctx
|
||||
@@ -208,6 +209,11 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
if _is_external_doc_permissions_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
# Tenant-work-gating hook: refresh this tenant's active-set membership
|
||||
# whenever doc-permission sync has any due cc_pairs to dispatch.
|
||||
if cc_pair_ids_to_sync:
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
payload_id = try_creating_permissions_sync_task(
|
||||
|
||||
@@ -69,6 +69,7 @@ from onyx.redis.redis_connector_ext_group_sync import (
|
||||
)
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_tenant_work_gating import maybe_mark_tenant_active
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import format_error_for_logging
|
||||
@@ -202,6 +203,11 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
if _is_external_group_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
|
||||
# Tenant-work-gating hook: refresh this tenant's active-set membership
|
||||
# whenever external-group sync has any due cc_pairs to dispatch.
|
||||
if cc_pair_ids_to_sync:
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
lock_beat.reacquire()
|
||||
for cc_pair_id in cc_pair_ids_to_sync:
|
||||
payload_id = try_creating_external_group_sync_task(
|
||||
|
||||
@@ -53,7 +53,7 @@ def fetch_query_analytics(
|
||||
.order_by(cast(ChatMessage.time_sent, Date))
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
return db_session.execute(stmt).all() # ty: ignore[invalid-return-type]
|
||||
|
||||
|
||||
def fetch_per_user_query_analytics(
|
||||
@@ -92,7 +92,7 @@ def fetch_per_user_query_analytics(
|
||||
.order_by(cast(ChatMessage.time_sent, Date), ChatSession.user_id)
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
return db_session.execute(stmt).all() # ty: ignore[invalid-return-type]
|
||||
|
||||
|
||||
def fetch_onyxbot_analytics(
|
||||
|
||||
@@ -9,7 +9,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def fetch_sources_with_connectors(db_session: Session) -> list[DocumentSource]:
|
||||
sources = db_session.query(distinct(Connector.source)).all() # type: ignore
|
||||
sources = db_session.query(distinct(Connector.source)).all()
|
||||
|
||||
document_sources = [source[0] for source in sources]
|
||||
|
||||
|
||||
@@ -128,9 +128,9 @@ def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
select(func.count())
|
||||
.select_from(User)
|
||||
.where(
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.is_active == True, # noqa: E712
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
User.email != ANONYMOUS_USER_EMAIL, # type: ignore
|
||||
User.email != ANONYMOUS_USER_EMAIL,
|
||||
User.account_type != AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -121,7 +121,7 @@ class ScimDAL(DAL):
|
||||
"""Update the last_used_at timestamp for a token."""
|
||||
token = self._session.get(ScimToken, token_id)
|
||||
if token:
|
||||
token.last_used_at = func.now() # type: ignore[assignment]
|
||||
token.last_used_at = func.now()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# User mapping operations
|
||||
@@ -229,7 +229,7 @@ class ScimDAL(DAL):
|
||||
def get_user(self, user_id: UUID) -> User | None:
|
||||
"""Fetch a user by ID."""
|
||||
return self._session.scalar(
|
||||
select(User).where(User.id == user_id) # type: ignore[arg-type]
|
||||
select(User).where(User.id == user_id) # ty: ignore[invalid-argument-type]
|
||||
)
|
||||
|
||||
def get_user_by_email(self, email: str) -> User | None:
|
||||
@@ -293,16 +293,22 @@ class ScimDAL(DAL):
|
||||
if attr == "username":
|
||||
# arg-type: fastapi-users types User.email as str, not a column expression
|
||||
# assignment: union return type widens but query is still Select[tuple[User]]
|
||||
query = _apply_scim_string_op(query, User.email, scim_filter) # type: ignore[arg-type, assignment]
|
||||
query = _apply_scim_string_op(
|
||||
query, User.email, scim_filter # ty: ignore[invalid-argument-type]
|
||||
)
|
||||
elif attr == "active":
|
||||
query = query.where(
|
||||
User.is_active.is_(scim_filter.value.lower() == "true") # type: ignore[attr-defined]
|
||||
User.is_active.is_( # ty: ignore[unresolved-attribute]
|
||||
scim_filter.value.lower() == "true"
|
||||
)
|
||||
)
|
||||
elif attr == "externalid":
|
||||
mapping = self.get_user_mapping_by_external_id(scim_filter.value)
|
||||
if not mapping:
|
||||
return [], 0
|
||||
query = query.where(User.id == mapping.user_id) # type: ignore[arg-type]
|
||||
query = query.where(
|
||||
User.id == mapping.user_id # ty: ignore[invalid-argument-type]
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported filter attribute: {scim_filter.attribute}"
|
||||
@@ -318,7 +324,9 @@ class ScimDAL(DAL):
|
||||
offset = max(start_index - 1, 0)
|
||||
users = list(
|
||||
self._session.scalars(
|
||||
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
|
||||
query.order_by(User.id) # ty: ignore[invalid-argument-type]
|
||||
.offset(offset)
|
||||
.limit(count)
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
@@ -577,7 +585,7 @@ class ScimDAL(DAL):
|
||||
attr = scim_filter.attribute.lower()
|
||||
if attr == "displayname":
|
||||
# assignment: union return type widens but query is still Select[tuple[UserGroup]]
|
||||
query = _apply_scim_string_op(query, UserGroup.name, scim_filter) # type: ignore[assignment]
|
||||
query = _apply_scim_string_op(query, UserGroup.name, scim_filter)
|
||||
elif attr == "externalid":
|
||||
mapping = self.get_group_mapping_by_external_id(scim_filter.value)
|
||||
if not mapping:
|
||||
@@ -615,7 +623,9 @@ class ScimDAL(DAL):
|
||||
|
||||
users = (
|
||||
self._session.scalars(
|
||||
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
|
||||
select(User).where(
|
||||
User.id.in_(user_ids) # ty: ignore[unresolved-attribute]
|
||||
)
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
@@ -640,7 +650,9 @@ class ScimDAL(DAL):
|
||||
return []
|
||||
existing_users = (
|
||||
self._session.scalars(
|
||||
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
|
||||
select(User).where(
|
||||
User.id.in_(uuids) # ty: ignore[unresolved-attribute]
|
||||
)
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
|
||||
@@ -300,8 +300,11 @@ def fetch_user_groups_for_user(
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
.join(User__UserGroup, User__UserGroup.user_group_id == UserGroup.id)
|
||||
.join(User, User.id == User__UserGroup.user_id) # type: ignore
|
||||
.where(User.id == user_id) # type: ignore
|
||||
.join(
|
||||
User,
|
||||
User.id == User__UserGroup.user_id, # ty: ignore[invalid-argument-type]
|
||||
)
|
||||
.where(User.id == user_id) # ty: ignore[invalid-argument-type]
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
@@ -430,7 +433,7 @@ def fetch_user_groups_for_documents(
|
||||
.group_by(Document.id)
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
return db_session.execute(stmt).all() # ty: ignore[invalid-return-type]
|
||||
|
||||
|
||||
def _check_user_group_is_modifiable(user_group: UserGroup) -> None:
|
||||
@@ -804,7 +807,9 @@ def update_user_group(
|
||||
db_user_group.is_up_to_date = False
|
||||
|
||||
removed_users = db_session.scalars(
|
||||
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
|
||||
select(User).where(
|
||||
User.id.in_(removed_user_ids) # ty: ignore[unresolved-attribute]
|
||||
)
|
||||
).unique()
|
||||
|
||||
# Filter out admin and global curator users before validating curator status
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Iterator
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.discovery import Resource
|
||||
|
||||
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
|
||||
from ee.onyx.external_permissions.google_drive.permission_retrieval import (
|
||||
@@ -38,7 +38,7 @@ def get_folder_permissions_by_ids(
|
||||
A list of permissions matching the provided permission IDs
|
||||
"""
|
||||
return get_permissions_by_ids(
|
||||
drive_service=service,
|
||||
drive_service=service, # ty: ignore[invalid-argument-type]
|
||||
doc_id=folder_id,
|
||||
permission_ids=permission_ids,
|
||||
)
|
||||
@@ -68,7 +68,7 @@ def get_modified_folders(
|
||||
|
||||
# Retrieve and yield folders
|
||||
for folder in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="allDrives",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.errors import HttpError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
@@ -183,7 +183,7 @@ def _get_drive_members(
|
||||
)
|
||||
|
||||
admin_user_info = (
|
||||
admin_service.users()
|
||||
admin_service.users() # ty: ignore[unresolved-attribute]
|
||||
.get(userKey=google_drive_connector.primary_admin_email)
|
||||
.execute()
|
||||
)
|
||||
@@ -197,7 +197,7 @@ def _get_drive_members(
|
||||
|
||||
try:
|
||||
for permission in execute_paginated_retrieval(
|
||||
drive_service.permissions().list,
|
||||
drive_service.permissions().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="permissions",
|
||||
fileId=drive_id,
|
||||
fields="permissions(emailAddress, type),nextPageToken",
|
||||
@@ -256,7 +256,7 @@ def _get_all_google_groups(
|
||||
"""
|
||||
group_emails: set[str] = set()
|
||||
for group in execute_paginated_retrieval(
|
||||
admin_service.groups().list,
|
||||
admin_service.groups().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="groups",
|
||||
domain=google_domain,
|
||||
fields="groups(email),nextPageToken",
|
||||
@@ -274,7 +274,7 @@ def _google_group_to_onyx_group(
|
||||
"""
|
||||
group_member_emails: set[str] = set()
|
||||
for member in execute_paginated_retrieval(
|
||||
admin_service.members().list,
|
||||
admin_service.members().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="members",
|
||||
groupKey=group_email,
|
||||
fields="members(email),nextPageToken",
|
||||
@@ -298,7 +298,7 @@ def _map_group_email_to_member_emails(
|
||||
for group_email in group_emails:
|
||||
group_member_emails: set[str] = set()
|
||||
for member in execute_paginated_retrieval(
|
||||
admin_service.members().list,
|
||||
admin_service.members().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="members",
|
||||
groupKey=group_email,
|
||||
fields="members(email),nextPageToken",
|
||||
|
||||
@@ -33,7 +33,7 @@ def get_permissions_by_ids(
|
||||
|
||||
# Fetch all permissions for the document
|
||||
fetched_permissions = execute_paginated_retrieval(
|
||||
retrieval_function=drive_service.permissions().list,
|
||||
retrieval_function=drive_service.permissions().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="permissions",
|
||||
fileId=doc_id,
|
||||
fields="permissions(id, emailAddress, type, domain, allowFileDiscovery, permissionDetails),nextPageToken",
|
||||
|
||||
@@ -68,7 +68,7 @@ def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]:
|
||||
logger.warning(f"Expected a 'raw' field, but none was found: {raw_perm=}")
|
||||
continue
|
||||
|
||||
permission = Permission(**raw_perm.raw)
|
||||
permission = Permission(**raw_perm.raw) # ty: ignore[invalid-argument-type]
|
||||
|
||||
# We only care about ability to browse through projects + issues (not other permissions such as read/write).
|
||||
if permission.permission != "BROWSE_PROJECTS":
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
from office365.sharepoint.client_context import ClientContext
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
|
||||
@@ -7,11 +7,11 @@ from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests as _requests
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
|
||||
from office365.runtime.client_request import ClientRequestException # type: ignore
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection # type: ignore[import-untyped]
|
||||
from office365.graph_client import GraphClient
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem
|
||||
from office365.runtime.client_request import ClientRequestException
|
||||
from office365.sharepoint.client_context import ClientContext
|
||||
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
|
||||
@@ -46,9 +46,10 @@ def get_query_analytics(
|
||||
daily_query_usage_info = fetch_query_analytics(
|
||||
start=start
|
||||
or (
|
||||
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
- datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
), # default is 30d lookback
|
||||
end=end or datetime.datetime.utcnow(),
|
||||
end=end or datetime.datetime.now(tz=datetime.timezone.utc),
|
||||
db_session=db_session,
|
||||
)
|
||||
return [
|
||||
@@ -77,9 +78,10 @@ def get_user_analytics(
|
||||
daily_query_usage_info_per_user = fetch_per_user_query_analytics(
|
||||
start=start
|
||||
or (
|
||||
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
- datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
), # default is 30d lookback
|
||||
end=end or datetime.datetime.utcnow(),
|
||||
end=end or datetime.datetime.now(tz=datetime.timezone.utc),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -111,9 +113,10 @@ def get_onyxbot_analytics(
|
||||
daily_onyxbot_info = fetch_onyxbot_analytics(
|
||||
start=start
|
||||
or (
|
||||
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
- datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
), # default is 30d lookback
|
||||
end=end or datetime.datetime.utcnow(),
|
||||
end=end or datetime.datetime.now(tz=datetime.timezone.utc),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -146,9 +149,10 @@ def get_persona_messages(
|
||||
) -> list[PersonaMessageAnalyticsResponse]:
|
||||
"""Fetch daily message counts for a single persona within the given time range."""
|
||||
start = start or (
|
||||
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
- datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
)
|
||||
end = end or datetime.datetime.utcnow()
|
||||
end = end or datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
|
||||
persona_message_counts = []
|
||||
for count, date in fetch_persona_message_analytics(
|
||||
@@ -226,9 +230,10 @@ def get_assistant_stats(
|
||||
along with the overall total messages and total distinct users.
|
||||
"""
|
||||
start = start or (
|
||||
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
- datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
)
|
||||
end = end or datetime.datetime.utcnow()
|
||||
end = end or datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
|
||||
if not user_can_view_assistant_stats(db_session, user, assistant_id):
|
||||
raise HTTPException(
|
||||
|
||||
@@ -287,8 +287,10 @@ def update_hook(
|
||||
validated_is_reachable: bool | None = None
|
||||
if endpoint_url_changing or api_key_changing or timeout_changing:
|
||||
existing = _get_hook_or_404(db_session, hook_id)
|
||||
effective_url: str = (
|
||||
req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update
|
||||
effective_url: str = ( # ty: ignore[invalid-assignment]
|
||||
req.endpoint_url
|
||||
if endpoint_url_changing
|
||||
else existing.endpoint_url # endpoint_url is required on create and cannot be cleared on update
|
||||
)
|
||||
effective_api_key: str | None = (
|
||||
(api_key if not isinstance(api_key, UnsetType) else None)
|
||||
@@ -299,8 +301,10 @@ def update_hook(
|
||||
else None
|
||||
)
|
||||
)
|
||||
effective_timeout: float = (
|
||||
req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
|
||||
effective_timeout: float = ( # ty: ignore[invalid-assignment]
|
||||
req.timeout_seconds
|
||||
if timeout_changing
|
||||
else existing.timeout_seconds # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest)
|
||||
)
|
||||
validation = _validate_endpoint(
|
||||
endpoint_url=effective_url,
|
||||
|
||||
@@ -97,7 +97,7 @@ def fetch_and_process_chat_session_history(
|
||||
break
|
||||
|
||||
paged_snapshots = parallel_yield(
|
||||
[
|
||||
[ # ty: ignore[invalid-argument-type]
|
||||
yield_snapshot_from_chat_session(
|
||||
db_session=db_session,
|
||||
chat_session=chat_session,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
import jwt
|
||||
from fastapi import HTTPException
|
||||
@@ -19,8 +20,8 @@ def generate_data_plane_token() -> str:
|
||||
|
||||
payload = {
|
||||
"iss": "data_plane",
|
||||
"exp": datetime.utcnow() + timedelta(minutes=5),
|
||||
"iat": datetime.utcnow(),
|
||||
"exp": datetime.now(tz=timezone.utc) + timedelta(minutes=5),
|
||||
"iat": datetime.now(tz=timezone.utc),
|
||||
"scope": "api_access",
|
||||
}
|
||||
|
||||
|
||||
@@ -55,8 +55,10 @@ def run_alembic_migrations(schema_name: str) -> None:
|
||||
alembic_cfg.attributes["configure_logger"] = False
|
||||
|
||||
# Mimic command-line options by adding 'cmd_opts' to the config
|
||||
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
|
||||
alembic_cfg.cmd_opts.x = [f"schemas={schema_name}"] # type: ignore
|
||||
alembic_cfg.cmd_opts = SimpleNamespace() # ty: ignore[invalid-assignment]
|
||||
alembic_cfg.cmd_opts.x = [ # ty: ignore[invalid-assignment]
|
||||
f"schemas={schema_name}"
|
||||
]
|
||||
|
||||
# Run migrations programmatically
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
@@ -349,8 +349,9 @@ def get_tenant_count(tenant_id: str) -> int:
|
||||
user_count = (
|
||||
db_session.query(User)
|
||||
.filter(
|
||||
User.email.in_(emails), # type: ignore
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.email.in_(emails), # ty: ignore[unresolved-attribute]
|
||||
User.is_active # noqa: E712 # ty: ignore[invalid-argument-type]
|
||||
== True,
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
@@ -73,7 +73,7 @@ def capture_and_sync_with_alternate_posthog(
|
||||
cloud_props.pop("onyx_cloud_user_id", None)
|
||||
|
||||
posthog.identify(
|
||||
distinct_id=cloud_user_id,
|
||||
distinct_id=cloud_user_id, # ty: ignore[possibly-unresolved-reference]
|
||||
properties=cloud_props,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -105,7 +105,7 @@ def get_anon_id_from_request(request: Any) -> str | None:
|
||||
if (cookie_value := request.cookies.get(cookie_name)) and (
|
||||
parsed := parse_posthog_cookie(cookie_value)
|
||||
):
|
||||
return parsed.get("distinct_id")
|
||||
return parsed.get("distinct_id") # ty: ignore[possibly-unresolved-reference]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
# from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
# if TYPE_CHECKING:
|
||||
# from setfit import SetFitModel # type: ignore[import-untyped]
|
||||
# from setfit import SetFitModel
|
||||
# from transformers import PreTrainedTokenizer, BatchEncoding
|
||||
|
||||
|
||||
@@ -423,7 +423,7 @@
|
||||
# def map_keywords(
|
||||
# input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
|
||||
# ) -> list[str]:
|
||||
# tokens = tokenizer.convert_ids_to_tokens(input_ids) # type: ignore
|
||||
# tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
||||
|
||||
# if not len(tokens) == len(is_keyword):
|
||||
# raise ValueError("Length of tokens and keyword predictions must match")
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
# super().__init__()
|
||||
# config = DistilBertConfig()
|
||||
# self.distilbert = DistilBertModel(config)
|
||||
# config = self.distilbert.config # type: ignore
|
||||
# config = self.distilbert.config
|
||||
|
||||
# # Keyword tokenwise binary classification layer
|
||||
# self.keyword_classifier = nn.Linear(config.dim, 2)
|
||||
@@ -85,7 +85,7 @@
|
||||
|
||||
# self.config = config
|
||||
# self.distilbert = DistilBertModel(config)
|
||||
# config = self.distilbert.config # type: ignore
|
||||
# config = self.distilbert.config
|
||||
# self.connector_global_classifier = nn.Linear(config.dim, 1)
|
||||
# self.connector_match_classifier = nn.Linear(config.dim, 1)
|
||||
# self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
|
||||
|
||||
@@ -7,8 +7,8 @@ from email.mime.text import MIMEText
|
||||
from email.utils import formatdate
|
||||
from email.utils import make_msgid
|
||||
|
||||
import sendgrid # type: ignore
|
||||
from sendgrid.helpers.mail import Attachment # type: ignore
|
||||
import sendgrid
|
||||
from sendgrid.helpers.mail import Attachment
|
||||
from sendgrid.helpers.mail import Content
|
||||
from sendgrid.helpers.mail import ContentId
|
||||
from sendgrid.helpers.mail import Disposition
|
||||
|
||||
@@ -10,7 +10,7 @@ from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey
|
||||
from jwt import decode as jwt_decode
|
||||
from jwt import InvalidTokenError
|
||||
from jwt import PyJWTError
|
||||
from jwt.algorithms import RSAAlgorithm
|
||||
from jwt.algorithms import RSAAlgorithm # ty: ignore[possibly-missing-import]
|
||||
|
||||
from onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -46,8 +46,10 @@ async def _test_expire_oauth_token(
|
||||
|
||||
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
|
||||
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
user, cast(Any, oauth_account), updated_data
|
||||
await user_manager.user_db.update_oauth_account( # ty: ignore[invalid-argument-type]
|
||||
user, # ty: ignore[invalid-argument-type]
|
||||
cast(Any, oauth_account),
|
||||
updated_data,
|
||||
)
|
||||
|
||||
return True
|
||||
@@ -132,8 +134,10 @@ async def refresh_oauth_token(
|
||||
)
|
||||
|
||||
# Update the OAuth account
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
user, cast(Any, oauth_account), updated_data
|
||||
await user_manager.user_db.update_oauth_account( # ty: ignore[invalid-argument-type]
|
||||
user, # ty: ignore[invalid-argument-type]
|
||||
cast(Any, oauth_account),
|
||||
updated_data,
|
||||
)
|
||||
|
||||
logger.info(f"Successfully refreshed OAuth token for {user.email}")
|
||||
|
||||
@@ -191,7 +191,7 @@ class OAuthTokenManager:
|
||||
@staticmethod
|
||||
def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
|
||||
if isinstance(value, SensitiveValue):
|
||||
return value.get_value(apply_mask=False)
|
||||
return value.get_value(apply_mask=False) # ty: ignore[invalid-return-type]
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
@@ -199,5 +199,7 @@ class OAuthTokenManager:
|
||||
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
if isinstance(token_data, SensitiveValue):
|
||||
return token_data.get_value(apply_mask=False)
|
||||
return token_data.get_value( # ty: ignore[invalid-return-type]
|
||||
apply_mask=False
|
||||
)
|
||||
return token_data
|
||||
|
||||
@@ -121,5 +121,7 @@ def require_permission(
|
||||
|
||||
return user
|
||||
|
||||
dependency._is_require_permission = True # type: ignore[attr-defined] # sentinel for auth_check detection
|
||||
dependency._is_require_permission = ( # ty: ignore[unresolved-attribute]
|
||||
True # sentinel for auth_check detection
|
||||
)
|
||||
return dependency
|
||||
|
||||
@@ -45,7 +45,9 @@ from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import JWTStrategy
|
||||
from fastapi_users.authentication import RedisStrategy
|
||||
from fastapi_users.authentication import (
|
||||
RedisStrategy, # ty: ignore[possibly-missing-import]
|
||||
)
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
@@ -462,14 +464,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
self.user_db = tenant_user_db
|
||||
|
||||
if hasattr(user_create, "role"):
|
||||
user_create.role = UserRole.BASIC
|
||||
user_create.role = UserRole.BASIC # ty: ignore[invalid-assignment]
|
||||
|
||||
user_count = await get_user_count()
|
||||
if (
|
||||
user_count == 0
|
||||
or user_create.email in get_default_admin_user_emails()
|
||||
):
|
||||
user_create.role = UserRole.ADMIN
|
||||
user_create.role = ( # ty: ignore[invalid-assignment]
|
||||
UserRole.ADMIN
|
||||
)
|
||||
|
||||
# Check seat availability for new users (single-tenant only)
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
@@ -516,7 +520,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user_id) # type: ignore[assignment]
|
||||
user = await self.user_db.get( # ty: ignore[invalid-assignment]
|
||||
user_id
|
||||
)
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
|
||||
@@ -544,7 +550,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user_id) # type: ignore[assignment]
|
||||
user = await self.user_db.get( # ty: ignore[invalid-assignment]
|
||||
user_id
|
||||
)
|
||||
if user_created:
|
||||
await self._assign_default_pinned_assistants(user, db_session)
|
||||
remove_user_from_invited_users(user_create.email)
|
||||
@@ -592,7 +600,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
update nor the group assignment is visible without the other.
|
||||
"""
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user_id).first() # type: ignore[arg-type]
|
||||
sync_user = (
|
||||
sync_db.query(User)
|
||||
.filter(User.id == user_id) # ty: ignore[invalid-argument-type]
|
||||
.first()
|
||||
)
|
||||
if sync_user:
|
||||
sync_user.hashed_password = self.password_helper.hash(
|
||||
user_create.password
|
||||
@@ -613,7 +625,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_id,
|
||||
)
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
async def validate_password( # ty: ignore[invalid-method-override]
|
||||
self, password: str, _: schemas.UC | models.UP
|
||||
) -> None:
|
||||
# Validate password according to configurable security policy (defined via environment variables)
|
||||
if len(password) < PASSWORD_MIN_LENGTH:
|
||||
raise exceptions.InvalidPasswordException(
|
||||
@@ -644,7 +658,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
return
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
async def oauth_callback(
|
||||
async def oauth_callback( # ty: ignore[invalid-method-override]
|
||||
self,
|
||||
oauth_name: str,
|
||||
access_token: str,
|
||||
@@ -754,7 +768,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user,
|
||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||
# but the type checker doesn't know that :(
|
||||
existing_oauth_account, # type: ignore
|
||||
existing_oauth_account, # ty: ignore[invalid-argument-type]
|
||||
oauth_account_dict,
|
||||
)
|
||||
|
||||
@@ -788,7 +802,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# transaction so neither change is visible without the other.
|
||||
was_inactive = not user.is_active
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
|
||||
sync_user = (
|
||||
sync_db.query(User)
|
||||
.filter(User.id == user.id) # ty: ignore[invalid-argument-type]
|
||||
.first()
|
||||
)
|
||||
if sync_user:
|
||||
sync_user.is_verified = is_verified_by_default
|
||||
sync_user.role = UserRole.BASIC
|
||||
@@ -808,7 +826,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
if user.oidc_expiry is not None and not TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
await self.user_db.update(user, {"oidc_expiry": None})
|
||||
user.oidc_expiry = None # type: ignore
|
||||
user.oidc_expiry = None # ty: ignore[invalid-assignment]
|
||||
remove_user_from_invited_users(user.email)
|
||||
if token:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
@@ -925,7 +943,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
and (marketing_cookie_value := request.cookies.get(marketing_cookie_name))
|
||||
and (parsed_cookie := parse_posthog_cookie(marketing_cookie_value))
|
||||
):
|
||||
marketing_anonymous_id = parsed_cookie["distinct_id"]
|
||||
marketing_anonymous_id = (
|
||||
parsed_cookie[ # ty: ignore[possibly-unresolved-reference]
|
||||
"distinct_id"
|
||||
]
|
||||
)
|
||||
|
||||
# Technically, USER_SIGNED_UP is only fired from the cloud site when
|
||||
# it is the first user in a tenant. However, it is semantically correct
|
||||
@@ -942,7 +964,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
}
|
||||
|
||||
# Add all other values from the marketing cookie (featureFlags, etc.)
|
||||
for key, value in parsed_cookie.items():
|
||||
for (
|
||||
key,
|
||||
value,
|
||||
) in parsed_cookie.items(): # ty: ignore[possibly-unresolved-reference]
|
||||
if key != "distinct_id":
|
||||
properties.setdefault(key, value)
|
||||
|
||||
@@ -1504,7 +1529,7 @@ async def _sync_jwt_oidc_expiry(
|
||||
|
||||
if user.oidc_expiry is not None:
|
||||
await user_manager.user_db.update(user, {"oidc_expiry": None})
|
||||
user.oidc_expiry = None # type: ignore
|
||||
user.oidc_expiry = None # ty: ignore[invalid-assignment]
|
||||
|
||||
|
||||
async def _get_or_create_user_from_jwt(
|
||||
@@ -2232,7 +2257,7 @@ def get_oauth_router(
|
||||
|
||||
# Proceed to authenticate or create the user
|
||||
try:
|
||||
user = await user_manager.oauth_callback(
|
||||
user = await user_manager.oauth_callback( # ty: ignore[invalid-argument-type]
|
||||
oauth_client.name,
|
||||
token["access_token"],
|
||||
account_id,
|
||||
|
||||
@@ -6,16 +6,16 @@ from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import sentry_sdk
|
||||
from celery import bootsteps # type: ignore
|
||||
from celery import bootsteps # ty: ignore[unresolved-import]
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
from celery.app import trace # ty: ignore[unresolved-import]
|
||||
from celery.exceptions import WorkerShutdown
|
||||
from celery.signals import before_task_publish
|
||||
from celery.signals import task_postrun
|
||||
from celery.signals import task_prerun
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from celery.worker import strategy # ty: ignore[unresolved-import]
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.beat import PersistentScheduler # ty: ignore[unresolved-import]
|
||||
from celery.signals import beat_init
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
|
||||
@@ -4,4 +4,4 @@ import onyx.background.celery.apps.app_base as app_base
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.client")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]
|
||||
|
||||
@@ -29,7 +29,7 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.docfetching")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
@@ -100,7 +100,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME)
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
pool_size = cast(int, sender.concurrency) # ty: ignore[unresolved-attribute]
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
|
||||
@@ -30,7 +30,7 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.docprocessing")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
@@ -106,7 +106,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
# "SSL connection has been closed unexpectedly"
|
||||
# actually setting the spawn method in the cloud fixes 95% of these.
|
||||
# setting pre ping might help even more, but not worrying about that yet
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
pool_size = cast(int, sender.concurrency) # ty: ignore[unresolved-attribute]
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
|
||||
@@ -27,7 +27,7 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.heavy")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
@@ -92,7 +92,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
pool_size = cast(int, sender.concurrency) # type: ignore
|
||||
pool_size = cast(int, sender.concurrency) # ty: ignore[unresolved-attribute]
|
||||
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
|
||||
@@ -29,7 +29,7 @@ logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.light")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]


@signals.task_prerun.connect
@@ -95,19 +95,26 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:

logger.info("worker_init signal received.")

logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
logger.info(
f"Concurrency: {sender.concurrency}" # ty: ignore[unresolved-attribute]
)

SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
SqlEngine.init_engine(
pool_size=sender.concurrency, # ty: ignore[unresolved-attribute]
max_overflow=EXTRA_CONCURRENCY,
)

if MANAGED_VESPA:
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
sender.concurrency + EXTRA_CONCURRENCY, # ty: ignore[unresolved-attribute]
ssl_cert=VESPA_CLOUD_CERT_PATH,
ssl_key=VESPA_CLOUD_KEY_PATH,
)
else:
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
httpx_init_vespa_pool(
sender.concurrency + EXTRA_CONCURRENCY # ty: ignore[unresolved-attribute]
)

app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
@@ -20,7 +20,7 @@ logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]


@signals.task_prerun.connect
@@ -3,7 +3,7 @@ import os
from typing import Any
from typing import cast

from celery import bootsteps # type: ignore
from celery import bootsteps # ty: ignore[unresolved-import]
from celery import Celery
from celery import signals
from celery import Task
@@ -38,6 +38,12 @@ from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
from onyx.server.metrics.metrics_server import start_metrics_server
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -46,7 +52,7 @@ logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.primary")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]


@signals.task_prerun.connect
@@ -59,6 +65,7 @@ def on_task_prerun(
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
on_celery_task_prerun(task_id, task)


@signals.task_postrun.connect
@@ -73,6 +80,31 @@ def on_task_postrun(
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
on_celery_task_postrun(task_id, task, state)


@signals.task_retry.connect
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
task_id = getattr(getattr(sender, "request", None), "id", None)
on_celery_task_retry(task_id, sender)


@signals.task_revoked.connect
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
task_name = getattr(sender, "name", None) or str(sender)
on_celery_task_revoked(kwargs.get("task_id"), task_name)


@signals.task_rejected.connect
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
message = kwargs.get("message")
task_name: str | None = None
if message is not None:
headers = getattr(message, "headers", None) or {}
task_name = headers.get("task")
if task_name is None:
task_name = "unknown"
on_celery_task_rejected(None, task_name)


@celeryd_init.connect
@@ -85,7 +117,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")

SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
pool_size = cast(int, sender.concurrency) # type: ignore
pool_size = cast(int, sender.concurrency) # ty: ignore[unresolved-attribute]
SqlEngine.init_engine(
pool_size=pool_size, max_overflow=CELERY_WORKER_PRIMARY_POOL_OVERFLOW
)
@@ -145,7 +177,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
raise WorkerShutdown("Primary worker lock could not be acquired!")

# tacking on our own user data to the sender
sender.primary_worker_lock = lock # type: ignore
sender.primary_worker_lock = lock # ty: ignore[unresolved-attribute]

# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
@@ -212,6 +244,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:

@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
start_metrics_server("primary")
app_base.on_worker_ready(sender, **kwargs)
@@ -22,7 +22,7 @@ logger = setup_logger()

celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.user_file_processing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
celery_app.Task = app_base.TenantAwareTask # ty: ignore[invalid-assignment]


@signals.task_prerun.connect
@@ -66,7 +66,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# "SSL connection has been closed unexpectedly"
# actually setting the spawn method in the cloud fixes 95% of these.
# setting pre ping might help even more, but not worrying about that yet
pool_size = cast(int, sender.concurrency) # type: ignore
pool_size = cast(int, sender.concurrency) # ty: ignore[unresolved-attribute]
SqlEngine.init_engine(pool_size=pool_size, max_overflow=8)

app_base.wait_for_redis(sender, **kwargs)
@@ -179,7 +179,7 @@ def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str

# filter for and create an indexing specific inspect object
inspect = app.control.inspect()
workers: dict[str, Any] = inspect.ping() # type: ignore
workers: dict[str, Any] = inspect.ping() # ty: ignore[invalid-assignment]
if workers:
for worker_name in list(workers.keys()):
# if the name filter not set, return all worker names
@@ -208,7 +208,9 @@ def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str
inspect = app.control.inspect(destination=worker_names)

# get the list of reserved tasks
reserved_tasks: dict[str, list] | None = inspect.reserved() # type: ignore
reserved_tasks: dict[str, list] | None = ( # ty: ignore[invalid-assignment]
inspect.reserved()
)
if reserved_tasks:
for _, task_list in reserved_tasks.items():
for task in task_list:
@@ -229,7 +231,9 @@ def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
inspect = app.control.inspect(destination=worker_names)

# get the list of reserved tasks
active_tasks: dict[str, list] | None = inspect.active() # type: ignore
active_tasks: dict[str, list] | None = ( # ty: ignore[invalid-assignment]
inspect.active()
)
if active_tasks:
for _, task_list in active_tasks.items():
for task in task_list:
@@ -6,6 +6,7 @@ from celery.schedules import crontab

from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
from onyx.configs.app_configs import DISABLE_OPENSEARCH_MIGRATION_TASK
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
@@ -226,7 +227,7 @@ if SCHEDULED_EVAL_DATASET_NAMES:
)

# Add OpenSearch migration task if enabled.
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX and not DISABLE_OPENSEARCH_MIGRATION_TASK:
beat_task_templates.append(
{
"name": "migrate-chunks-from-vespa-to-opensearch",
@@ -59,6 +59,7 @@ from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_tenant_work_gating import maybe_mark_tenant_active
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_blocked
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_completed
|
||||
from onyx.server.metrics.deletion_metrics import inc_deletion_fence_reset
|
||||
@@ -172,6 +173,11 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
|
||||
for cc_pair in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair.id)
|
||||
|
||||
# Tenant-work-gating hook: any cc_pair means deletion could have
|
||||
# cleanup work to do for this tenant on some cycle.
|
||||
if cc_pair_ids:
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
# try running cleanup on the cc_pair_ids
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
@@ -34,6 +34,7 @@ from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
@@ -470,6 +471,15 @@ def docfetching_proxy_task(
|
||||
index_attempt.connector_credential_pair.connector.source.value
|
||||
)
|
||||
|
||||
cc_pair = index_attempt.connector_credential_pair
|
||||
on_index_attempt_status_change(
|
||||
tenant_id=tenant_id,
|
||||
source=result.connector_source,
|
||||
cc_pair_id=cc_pair_id,
|
||||
connector_name=cc_pair.connector.name or f"cc_pair_{cc_pair_id}",
|
||||
status="in_progress",
|
||||
)
|
||||
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
|
||||
@@ -108,6 +108,7 @@ from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_tenant_work_gating import maybe_mark_tenant_active
|
||||
from onyx.redis.redis_utils import is_fence
|
||||
from onyx.server.metrics.connector_health_metrics import on_connector_error_state_change
|
||||
from onyx.server.metrics.connector_health_metrics import on_connector_indexing_success
|
||||
@@ -810,7 +811,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
# redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
# redis_client_celery: Redis = self.app.broker_connection().channel().client
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
|
||||
@@ -896,6 +897,11 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
|
||||
secondary_cc_pair_ids = standard_cc_pair_ids
|
||||
|
||||
# Tenant-work-gating hook: refresh this tenant's active-set membership
|
||||
# whenever indexing actually has work to dispatch.
|
||||
if primary_cc_pair_ids or secondary_cc_pair_ids:
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
# Flag CC pairs in repeated error state for primary/current search settings
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for cc_pair_id in primary_cc_pair_ids:
|
||||
|
||||
@@ -72,6 +72,7 @@ from onyx.redis.redis_hierarchy import get_source_node_id_from_cache
|
||||
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_tenant_work_gating import maybe_mark_tenant_active
|
||||
from onyx.server.metrics.pruning_metrics import observe_pruning_diff_duration
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.server.utils import make_short_id
|
||||
@@ -228,6 +229,11 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
# Tenant-work-gating hook: any cc_pair means pruning could have
|
||||
# work to do for this tenant on some cycle.
|
||||
if cc_pair_ids:
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.db.document import construct_document_id_select_by_needs_sync
|
||||
from onyx.db.document import count_documents_by_needs_sync
|
||||
from onyx.redis.redis_tenant_work_gating import maybe_mark_tenant_active
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# Redis keys for document sync tracking
|
||||
@@ -150,6 +151,10 @@ def try_generate_stale_document_sync_tasks(
|
||||
logger.info("No stale documents found. Skipping sync tasks generation.")
|
||||
return None
|
||||
|
||||
# Tenant-work-gating hook: refresh this tenant's active-set membership
|
||||
# whenever vespa sync actually has stale docs to dispatch.
|
||||
maybe_mark_tenant_active(tenant_id)
|
||||
|
||||
logger.info(
|
||||
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks in one batch."
|
||||
)
|
||||
|
||||
@@ -61,7 +61,9 @@ def load_checkpoint(
|
||||
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
|
||||
checkpoint_data = checkpoint_io.read().decode("utf-8")
|
||||
if isinstance(connector, CheckpointedConnector):
|
||||
return connector.validate_checkpoint_json(checkpoint_data)
|
||||
return connector.validate_checkpoint_json( # ty: ignore[invalid-return-type]
|
||||
checkpoint_data
|
||||
)
|
||||
return ConnectorCheckpoint.model_validate_json(checkpoint_data)
|
||||
|
||||
|
||||
|
||||
@@ -69,7 +69,6 @@ from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.build.indexing.persistent_document_writer import (
|
||||
get_persistent_document_writer,
|
||||
)
|
||||
from onyx.server.metrics.connector_health_metrics import on_index_attempt_status_change
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.postgres_sanitization import sanitize_document_for_postgres
|
||||
@@ -269,14 +268,6 @@ def run_docfetching_entrypoint(
|
||||
)
|
||||
credential_id = attempt.connector_credential_pair.credential_id
|
||||
|
||||
on_index_attempt_status_change(
|
||||
tenant_id=tenant_id,
|
||||
source=attempt.connector_credential_pair.connector.source.value,
|
||||
cc_pair_id=connector_credential_pair_id,
|
||||
connector_name=connector_name or f"cc_pair_{connector_credential_pair_id}",
|
||||
status="in_progress",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Docfetching starting{tenant_str}: "
|
||||
f"connector='{connector_name}' "
|
||||
|
||||
@@ -1164,7 +1164,10 @@ def run_llm_loop(
|
||||
|
||||
emitter.emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=llm_cycle_count + reasoning_cycles),
|
||||
placement=Placement(
|
||||
turn_index=llm_cycle_count # ty: ignore[possibly-unresolved-reference]
|
||||
+ reasoning_cycles
|
||||
),
|
||||
obj=OverallStop(type="stop"),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -324,6 +324,9 @@ ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
)
DISABLE_OPENSEARCH_MIGRATION_TASK = (
os.environ.get("DISABLE_OPENSEARCH_MIGRATION_TASK", "").lower() == "true"
)
# Whether we should check for and create an index if necessary every time we
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
@@ -639,9 +639,11 @@ REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3

if platform.system() == "Darwin":
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore[attr-defined,unused-ignore]
REDIS_SOCKET_KEEPALIVE_OPTIONS[
socket.TCP_KEEPALIVE # ty: ignore[unresolved-attribute]
] = 60
else:
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore[attr-defined,unused-ignore]
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60


class OnyxCallTypes(str, Enum):
@@ -547,7 +547,7 @@ class AirtableConnector(LoadConnector):
|
||||
for record in batch_records:
|
||||
# Capture the current context so that the thread gets the current tenant ID
|
||||
current_context = contextvars.copy_context()
|
||||
future_to_record[
|
||||
future_to_record[ # ty: ignore[invalid-assignment]
|
||||
executor.submit(
|
||||
current_context.run,
|
||||
self._process_record,
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
|
||||
import asana # type: ignore
|
||||
import asana
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -290,8 +290,8 @@ class AxeroConnector(PollConnector):
if not self.axero_key or not self.base_url:
raise ConnectorMissingCredentialError("Axero")

start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc)
end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc)
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

entity_types = []
if self.include_article:
@@ -327,7 +327,7 @@ class AxeroConnector(PollConnector):
)

all_axero_forums = _map_post_to_parent(
posts=forums_posts,
posts=forums_posts, # ty: ignore[invalid-argument-type]
api_key=self.axero_key,
axero_base_url=self.base_url,
)
@@ -76,7 +76,9 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
self.bucket_region: Optional[str] = None
|
||||
self.european_residency: bool = european_residency
|
||||
|
||||
def set_allow_images(self, allow_images: bool) -> None:
|
||||
def set_allow_images( # ty: ignore[invalid-method-override]
|
||||
self, allow_images: bool
|
||||
) -> None:
|
||||
"""Set whether to process images in this connector."""
|
||||
logger.info(f"Setting allow_images to {allow_images}.")
|
||||
self._allow_images = allow_images
|
||||
@@ -195,7 +197,9 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
method="sts-assume-role",
|
||||
)
|
||||
botocore_session = get_session()
|
||||
botocore_session._credentials = refreshable # type: ignore[attr-defined]
|
||||
botocore_session._credentials = ( # ty: ignore[unresolved-attribute]
|
||||
refreshable
|
||||
)
|
||||
session = boto3.Session(botocore_session=botocore_session)
|
||||
self.s3_client = session.client("s3")
|
||||
elif authentication_method == "assume_role":
|
||||
|
||||
@@ -2,6 +2,7 @@ import html
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -56,14 +57,14 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
}
|
||||
|
||||
if start:
|
||||
params["filter[updated_at:gte]"] = datetime.utcfromtimestamp(
|
||||
start
|
||||
params["filter[updated_at:gte]"] = datetime.fromtimestamp(
|
||||
start, tz=timezone.utc
|
||||
).strftime("%Y-%m-%d")
|
||||
|
||||
if end:
|
||||
params["filter[updated_at:lte]"] = datetime.utcfromtimestamp(end).strftime(
|
||||
"%Y-%m-%d"
|
||||
)
|
||||
params["filter[updated_at:lte]"] = datetime.fromtimestamp(
|
||||
end, tz=timezone.utc
|
||||
).strftime("%Y-%m-%d")
|
||||
|
||||
batch = bookstack_client.get(endpoint, params=params).get("data", [])
|
||||
doc_batch: list[Document | HierarchyNode] = [
|
||||
|
||||
@@ -95,11 +95,13 @@ class ClickupConnector(LoadConnector, PollConnector):
|
||||
params["date_updated_lt"] = end
|
||||
|
||||
if self.connector_type == "list":
|
||||
params["list_ids[]"] = self.connector_ids
|
||||
params["list_ids[]"] = self.connector_ids # ty: ignore[invalid-assignment]
|
||||
elif self.connector_type == "folder":
|
||||
params["project_ids[]"] = self.connector_ids
|
||||
params["project_ids[]"] = ( # ty: ignore[invalid-assignment]
|
||||
self.connector_ids
|
||||
)
|
||||
elif self.connector_type == "space":
|
||||
params["space_ids[]"] = self.connector_ids
|
||||
params["space_ids[]"] = self.connector_ids # ty: ignore[invalid-assignment]
|
||||
|
||||
url_endpoint = f"/team/{self.team_id}/task"
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from atlassian.errors import ApiError # type: ignore
|
||||
from atlassian.errors import ApiError
|
||||
from requests.exceptions import HTTPError
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from typing import TypeVar
|
||||
from urllib.parse import quote
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
from atlassian import Confluence
|
||||
from redis import Redis
|
||||
from requests import HTTPError
|
||||
|
||||
@@ -971,7 +971,7 @@ class OnyxConfluence:
|
||||
:return: Returns the user details
|
||||
"""
|
||||
|
||||
from atlassian.errors import ApiPermissionError # type:ignore
|
||||
from atlassian.errors import ApiPermissionError
|
||||
|
||||
url = "rest/api/user/current"
|
||||
params = {}
|
||||
|
||||
@@ -165,7 +165,7 @@ class ConnectorRunner(Generic[CT]):
|
||||
checkpoint_connector_generator = load_from_checkpoint(
|
||||
start=self.time_range[0].timestamp(),
|
||||
end=self.time_range[1].timestamp(),
|
||||
checkpoint=checkpoint,
|
||||
checkpoint=checkpoint, # ty: ignore[invalid-argument-type]
|
||||
)
|
||||
next_checkpoint: CT | None = None
|
||||
# this is guaranteed to always run at least once with next_checkpoint being non-None
|
||||
@@ -174,7 +174,9 @@ class ConnectorRunner(Generic[CT]):
|
||||
hierarchy_node,
|
||||
failure,
|
||||
next_checkpoint,
|
||||
) in CheckpointOutputWrapper[CT]()(checkpoint_connector_generator):
|
||||
) in CheckpointOutputWrapper[CT]()(
|
||||
checkpoint_connector_generator # ty: ignore[invalid-argument-type]
|
||||
):
|
||||
if document is not None:
|
||||
self.doc_batch.append(document)
|
||||
|
||||
|
||||
@@ -83,7 +83,9 @@ class OnyxDBCredentialsProvider(
|
||||
f"No credential found: credential={self._credential_id}"
|
||||
)
|
||||
|
||||
credential.credential_json = credential_json # type: ignore[assignment]
|
||||
credential.credential_json = ( # ty: ignore[invalid-assignment]
|
||||
credential_json
|
||||
)
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -41,7 +41,7 @@ def tabular_file_to_sections(
|
||||
"""
|
||||
lowered = file_name.lower()
|
||||
|
||||
if lowered.endswith(".xlsx"):
|
||||
if lowered.endswith(tuple(OnyxFileExtensions.SPREADSHEET_EXTENSIONS)):
|
||||
return [
|
||||
TabularSection(
|
||||
link=link or file_name,
|
||||
|
||||
@@ -53,8 +53,10 @@ def _convert_message_to_document(
|
||||
if isinstance(message.channel, TextChannel) and (
|
||||
channel_name := message.channel.name
|
||||
):
|
||||
metadata["Channel"] = channel_name
|
||||
semantic_substring += f" in Channel: #{channel_name}"
|
||||
metadata["Channel"] = channel_name # ty: ignore[possibly-unresolved-reference]
|
||||
semantic_substring += (
|
||||
f" in Channel: #{channel_name}" # ty: ignore[possibly-unresolved-reference]
|
||||
)
|
||||
|
||||
# Single messages dont have a title
|
||||
title = ""
|
||||
|
||||
@@ -221,8 +221,8 @@ class DiscourseConnector(PollConnector):
|
||||
if self.permissions is None:
|
||||
raise ConnectorMissingCredentialError("Discourse")
|
||||
|
||||
start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc)
|
||||
end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc)
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
|
||||
self._get_categories_map()
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@ from datetime import timezone
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from dropbox import Dropbox # type: ignore[import-untyped]
|
||||
from dropbox.exceptions import ApiError # type: ignore[import-untyped]
|
||||
from dropbox import Dropbox
|
||||
from dropbox.exceptions import ApiError
|
||||
from dropbox.exceptions import AuthError
|
||||
from dropbox.files import FileMetadata # type: ignore[import-untyped]
|
||||
from dropbox.files import FileMetadata
|
||||
from dropbox.files import FolderMetadata
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
|
||||
@@ -189,7 +189,7 @@ def _process_file(
|
||||
if is_tabular_file(file_name):
|
||||
# Produce TabularSections
|
||||
lowered_name = file_name.lower()
|
||||
if lowered_name.endswith(".xlsx"):
|
||||
if lowered_name.endswith(tuple(OnyxFileExtensions.SPREADSHEET_EXTENSIONS)):
|
||||
file.seek(0)
|
||||
tabular_source: IO[bytes] = file
|
||||
else:
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Dict
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -296,7 +296,9 @@ def _full_thread_from_id(
|
||||
try:
|
||||
thread = next(
|
||||
execute_single_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().get,
|
||||
retrieval_function=gmail_service.users() # ty: ignore[unresolved-attribute]
|
||||
.threads()
|
||||
.get,
|
||||
list_key=None,
|
||||
userId=user_email,
|
||||
fields=THREAD_FIELDS,
|
||||
@@ -394,7 +396,7 @@ class GmailConnector(
|
||||
admin_service = get_admin_service(self.creds, self.primary_admin_email)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
retrieval_function=admin_service.users().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
@@ -438,7 +440,9 @@ class GmailConnector(
|
||||
try:
|
||||
for thread in execute_paginated_retrieval_with_max_pages(
|
||||
max_num_pages=PAGES_PER_CHECKPOINT,
|
||||
retrieval_function=gmail_service.users().threads().list,
|
||||
retrieval_function=gmail_service.users() # ty: ignore[unresolved-attribute]
|
||||
.threads()
|
||||
.list,
|
||||
list_key="threads",
|
||||
userId=user_email,
|
||||
fields=THREAD_LIST_FIELDS,
|
||||
|
||||
@@ -110,7 +110,7 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
# The batch_ids in the previous method appears to be batches of call_ids to process
|
||||
# In this method, we will retrieve transcripts for them in batches.
|
||||
transcripts: list[dict[str, Any]] = []
|
||||
workspace_list = self.workspaces or [None] # type: ignore
|
||||
workspace_list = self.workspaces or [None]
|
||||
workspace_map = self._get_workspace_id_map() if self.workspaces else {}
|
||||
|
||||
for workspace in workspace_list:
|
||||
|
||||
@@ -18,7 +18,7 @@ from urllib.parse import urlunparse
|
||||
from google.auth.exceptions import RefreshError
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.errors import HttpError
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
@@ -434,7 +434,7 @@ class GoogleDriveConnector(
|
||||
for is_admin in [True, False]:
|
||||
query = "isAdmin=true" if is_admin else "isAdmin=false"
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
retrieval_function=admin_service.users().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
@@ -719,7 +719,7 @@ class GoogleDriveConnector(
|
||||
)
|
||||
all_drive_ids: set[str] = set()
|
||||
for drive in execute_paginated_retrieval(
|
||||
retrieval_function=drive_service.drives().list,
|
||||
retrieval_function=drive_service.drives().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="drives",
|
||||
useDomainAdminAccess=is_service_account,
|
||||
fields="drives(id),nextPageToken",
|
||||
@@ -907,7 +907,9 @@ class GoogleDriveConnector(
|
||||
# resume from a checkpoint
|
||||
if resuming and (drive_id := curr_stage.current_folder_or_drive_id):
|
||||
resume_start = curr_stage.completed_until
|
||||
for file_or_token in _yield_from_drive(drive_id, resume_start):
|
||||
for file_or_token in _yield_from_drive(
|
||||
drive_id, resume_start # ty: ignore[possibly-unresolved-reference]
|
||||
):
|
||||
if isinstance(file_or_token, str):
|
||||
checkpoint.completion_map[user_email].next_page_token = (
|
||||
file_or_token
|
||||
@@ -1302,7 +1304,9 @@ class GoogleDriveConnector(
|
||||
resume_start = checkpoint.completion_map[
|
||||
self.primary_admin_email
|
||||
].completed_until
|
||||
yield from _yield_from_folder_crawl(folder_id, resume_start)
|
||||
yield from _yield_from_folder_crawl(
|
||||
folder_id, resume_start # ty: ignore[possibly-unresolved-reference]
|
||||
)
|
||||
|
||||
# the times stored in the completion_map aren't used due to the crawling behavior
|
||||
# instead, the traversed_parent_ids are used to determine what we have left to retrieve
|
||||
@@ -1883,7 +1887,9 @@ class GoogleDriveConnector(
|
||||
|
||||
try:
|
||||
drive_service = get_drive_service(self._creds, self._primary_admin_email)
|
||||
drive_service.files().list(pageSize=1, fields="files(id)").execute()
|
||||
drive_service.files().list( # ty: ignore[unresolved-attribute]
|
||||
pageSize=1, fields="files(id)"
|
||||
).execute()
|
||||
|
||||
if isinstance(self._creds, ServiceAccountCredentials):
|
||||
# default is ~17mins of retries, don't do that here since this is called from
|
||||
|
||||
@@ -6,8 +6,8 @@ from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlunparse
|
||||
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.http import MediaIoBaseDownload # type: ignore
|
||||
from googleapiclient.errors import HttpError
|
||||
from googleapiclient.http import MediaIoBaseDownload
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
@@ -65,7 +65,7 @@ def _get_folder_info(
|
||||
|
||||
try:
|
||||
folder = (
|
||||
service.files()
|
||||
service.files() # ty: ignore[unresolved-attribute]
|
||||
.get(
|
||||
fileId=folder_id,
|
||||
fields="name, parents",
|
||||
@@ -91,7 +91,11 @@ def _get_drive_name(service: GoogleDriveService, drive_id: str) -> str:
|
||||
return _folder_cache[cache_key][0]
|
||||
|
||||
try:
|
||||
drive = service.drives().get(driveId=drive_id).execute()
|
||||
drive = (
|
||||
service.drives() # ty: ignore[unresolved-attribute]
|
||||
.get(driveId=drive_id)
|
||||
.execute()
|
||||
)
|
||||
drive_name = drive.get("name", f"Shared Drive {drive_id}")
|
||||
_folder_cache[cache_key] = (drive_name, None)
|
||||
return drive_name
|
||||
@@ -258,7 +262,9 @@ def download_request(
|
||||
"""
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
request = service.files().get_media( # ty: ignore[unresolved-attribute]
|
||||
fileId=file_id
|
||||
)
|
||||
return _download_request(request, file_id, size_threshold)
|
||||
|
||||
|
||||
@@ -331,7 +337,7 @@ def _download_and_extract_sections_basic(
|
||||
# For Google Docs, Sheets, and Slides, export via the Drive API
|
||||
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
|
||||
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
|
||||
request = service.files().export_media(
|
||||
request = service.files().export_media( # ty: ignore[unresolved-attribute]
|
||||
fileId=file_id, mimeType=export_mime_type
|
||||
)
|
||||
response = _download_request(request, file_id, size_threshold)
|
||||
@@ -455,7 +461,9 @@ def align_basic_advanced(
|
||||
for adv_ind in range(1, len(adv_sections)):
|
||||
heading = adv_sections[adv_ind].text.split(HEADING_DELIMITER)[0]
|
||||
# retrieve the longest part of the heading that is not a smart chip
|
||||
heading_key = max(heading.split(SMART_CHIP_CHAR), key=len).strip()
|
||||
heading_key = max( # ty: ignore[unresolved-attribute]
|
||||
heading.split(SMART_CHIP_CHAR), key=len
|
||||
).strip()
|
||||
if heading_key == "":
|
||||
logger.warning(
|
||||
f"Cannot match heading: {heading}, its link will come from the following section"
|
||||
|
||||
@@ -7,9 +7,9 @@ from typing import cast
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from googleapiclient.discovery import Resource # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.http import BatchHttpRequest # type: ignore
|
||||
from googleapiclient.discovery import Resource
|
||||
from googleapiclient.errors import HttpError
|
||||
from googleapiclient.http import BatchHttpRequest
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
@@ -115,7 +115,7 @@ def _get_folders_in_parent(
|
||||
query += f" and '{parent_id}' in parents"
|
||||
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="allDrives",
|
||||
@@ -136,7 +136,7 @@ def get_folder_metadata(
|
||||
fields = _get_hierarchy_fields_for_file_type(field_type)
|
||||
try:
|
||||
return (
|
||||
service.files()
|
||||
service.files() # ty: ignore[unresolved-attribute]
|
||||
.get(
|
||||
fileId=folder_id,
|
||||
fields=fields,
|
||||
@@ -169,7 +169,11 @@ def get_shared_drive_name(
|
||||
folders. Only drives().get() returns the real user-assigned name.
|
||||
"""
|
||||
try:
|
||||
drive = service.drives().get(driveId=drive_id, fields="name").execute()
|
||||
drive = (
|
||||
service.drives() # ty: ignore[unresolved-attribute]
|
||||
.get(driveId=drive_id, fields="name")
|
||||
.execute()
|
||||
)
|
||||
return drive.get("name")
|
||||
except HttpError as e:
|
||||
if e.resp.status in (403, 404):
|
||||
@@ -261,7 +265,7 @@ def _get_files_in_parent(
|
||||
kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
|
||||
|
||||
for file in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="allDrives",
|
||||
@@ -371,7 +375,7 @@ def get_files_in_shared_drive(
|
||||
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
|
||||
folder_query += " and trashed = false"
|
||||
for folder in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
corpora="drive",
|
||||
@@ -389,7 +393,7 @@ def get_files_in_shared_drive(
|
||||
file_query += generate_time_range_filter(start, end)
|
||||
|
||||
for file in execute_paginated_retrieval_with_max_pages(
|
||||
retrieval_function=service.files().list,
|
||||
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
|
||||
max_num_pages=max_num_pages,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=True,
|
||||
@@ -436,7 +440,7 @@ def get_all_files_in_my_drive_and_shared(
|
||||
folder_query += " and 'me' in owners"
|
||||
found_folders = False
|
||||
for folder in execute_paginated_retrieval(
|
||||
retrieval_function=service.files().list,
|
||||
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="files",
|
||||
corpora="user",
|
||||
fields=_get_fields_for_file_type(field_type),
|
||||
@@ -454,7 +458,7 @@ def get_all_files_in_my_drive_and_shared(
|
||||
file_query += " and 'me' in owners"
|
||||
file_query += generate_time_range_filter(start, end)
|
||||
yield from execute_paginated_retrieval_with_max_pages(
|
||||
retrieval_function=service.files().list,
|
||||
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
|
||||
max_num_pages=max_num_pages,
|
||||
list_key="files",
|
||||
continue_on_404_or_403=False,
|
||||
@@ -499,7 +503,7 @@ def get_all_files_for_oauth(
|
||||
|
||||
yield from execute_paginated_retrieval_with_max_pages(
|
||||
max_num_pages=max_num_pages,
|
||||
retrieval_function=service.files().list,
|
||||
retrieval_function=service.files().list, # ty: ignore[unresolved-attribute]
|
||||
list_key="files",
|
||||
continue_on_404_or_403=False,
|
||||
corpora=corpora,
|
||||
@@ -516,7 +520,7 @@ def get_root_folder_id(service: Resource) -> str:
|
||||
# we dont paginate here because there is only one root folder per user
|
||||
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
|
||||
return (
|
||||
service.files()
|
||||
service.files() # ty: ignore[unresolved-attribute]
|
||||
.get(fileId="root", fields=GoogleFields.ID.value)
|
||||
.execute()[GoogleFields.ID.value]
|
||||
)
|
||||
@@ -550,7 +554,7 @@ def get_file_by_web_view_link(
|
||||
"""Retrieve a Google Drive file using its webViewLink."""
|
||||
file_id = _extract_file_id_from_web_view_link(web_view_link)
|
||||
return (
|
||||
service.files()
|
||||
service.files() # ty: ignore[unresolved-attribute]
|
||||
.get(
|
||||
fileId=file_id,
|
||||
supportsAllDrives=True,
|
||||
@@ -612,12 +616,17 @@ def _get_files_by_web_view_links_batch(
|
||||
else:
|
||||
result.files[request_id] = response
|
||||
|
||||
batch = cast(BatchHttpRequest, service.new_batch_http_request(callback=callback))
|
||||
batch = cast(
|
||||
BatchHttpRequest,
|
||||
service.new_batch_http_request( # ty: ignore[unresolved-attribute]
|
||||
callback=callback
|
||||
),
|
||||
)
|
||||
|
||||
for web_view_link in web_view_links:
|
||||
try:
|
||||
file_id = _extract_file_id_from_web_view_link(web_view_link)
|
||||
request = service.files().get(
|
||||
request = service.files().get( # ty: ignore[unresolved-attribute]
|
||||
fileId=file_id,
|
||||
supportsAllDrives=True,
|
||||
fields=fields,
|
||||
|
||||
@@ -199,7 +199,7 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
if isinstance(v, set):
|
||||
return ThreadSafeSet(v)
|
||||
if isinstance(v, list):
|
||||
return ThreadSafeSet(set(v))
|
||||
return ThreadSafeSet(set(v)) # ty: ignore[invalid-return-type]
|
||||
return ThreadSafeSet()
|
||||
|
||||
@field_validator("fully_walked_hierarchy_node_raw_ids", mode="before")
|
||||
@@ -209,5 +209,5 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):
|
||||
if isinstance(v, set):
|
||||
return ThreadSafeSet(v)
|
||||
if isinstance(v, list):
|
||||
return ThreadSafeSet(set(v))
|
||||
return ThreadSafeSet(set(v)) # ty: ignore[invalid-return-type]
|
||||
return ThreadSafeSet()
|
||||
|
||||
@@ -85,7 +85,9 @@ def get_document_sections(
|
||||
) -> list[TextSection]:
|
||||
"""Extracts sections from a Google Doc, including their headings and content"""
|
||||
# Fetch the document structure
|
||||
http_request = docs_service.documents().get(documentId=doc_id)
|
||||
http_request = docs_service.documents().get( # ty: ignore[unresolved-attribute]
|
||||
documentId=doc_id
|
||||
)
|
||||
|
||||
# Google has poor support for tabs in the docs api, see
|
||||
# https://cloud.google.com/python/docs/reference/cloudtasks/
|
||||
|
||||
@@ -6,7 +6,7 @@ from urllib.parse import ParseResult
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
@@ -63,7 +63,7 @@ def _load_google_json(raw: object) -> dict[str, Any]:
|
||||
``str`` branch can be removed.
|
||||
"""
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
return raw # ty: ignore[invalid-return-type]
|
||||
if isinstance(raw, str):
|
||||
return json.loads(raw)
|
||||
raise ValueError(f"Unexpected Google credential payload type: {type(raw)!r}")
|
||||
@@ -82,7 +82,7 @@ def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) ->
|
||||
if source == DocumentSource.GOOGLE_DRIVE:
|
||||
drive_service = get_drive_service(creds)
|
||||
user_info = (
|
||||
drive_service.about()
|
||||
drive_service.about() # ty: ignore[unresolved-attribute]
|
||||
.get(
|
||||
fields="user(emailAddress)",
|
||||
)
|
||||
@@ -92,7 +92,7 @@ def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) ->
|
||||
elif source == DocumentSource.GMAIL:
|
||||
gmail_service = get_gmail_service(creds)
|
||||
user_info = (
|
||||
gmail_service.users()
|
||||
gmail_service.users() # ty: ignore[unresolved-attribute]
|
||||
.getProfile(
|
||||
userId="me",
|
||||
fields="emailAddress",
|
||||
@@ -159,7 +159,7 @@ def build_service_account_creds(
|
||||
service_account_key = get_service_account_key(source=source)
|
||||
|
||||
credential_dict = {
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(),
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.model_dump_json(),
|
||||
}
|
||||
if primary_admin_email:
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import timezone
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
from google.auth.exceptions import RefreshError
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
|
||||
from googleapiclient.discovery import build # type: ignore[import-untyped]
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.discovery import Resource
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -99,25 +99,33 @@ def get_google_docs_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDocsService:
|
||||
return _get_google_service("docs", "v1", creds, user_email)
|
||||
return _get_google_service( # ty: ignore[invalid-return-type]
|
||||
"docs", "v1", creds, user_email
|
||||
)
|
||||
|
||||
|
||||
def get_drive_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GoogleDriveService:
|
||||
return _get_google_service("drive", "v3", creds, user_email)
|
||||
return _get_google_service( # ty: ignore[invalid-return-type]
|
||||
"drive", "v3", creds, user_email
|
||||
)
|
||||
|
||||
|
||||
def get_admin_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> AdminService:
|
||||
return _get_google_service("admin", "directory_v1", creds, user_email)
|
||||
return _get_google_service( # ty: ignore[invalid-return-type]
|
||||
"admin", "directory_v1", creds, user_email
|
||||
)
|
||||
|
||||
|
||||
def get_gmail_service(
|
||||
creds: ServiceAccountCredentials | OAuthCredentials,
|
||||
user_email: str | None = None,
|
||||
) -> GmailService:
|
||||
return _get_google_service("gmail", "v1", creds, user_email)
|
||||
return _get_google_service( # ty: ignore[invalid-return-type]
|
||||
"gmail", "v1", creds, user_email
|
||||
)
|
||||
|
||||
@@ -244,12 +244,20 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync)
|
||||
doc_batch = []
|
||||
|
||||
except HighspotClientError as e:
|
||||
item_id = "ID" if not item_id else item_id
|
||||
item_id = (
|
||||
"ID"
|
||||
if not item_id # ty: ignore[possibly-unresolved-reference]
|
||||
else item_id # ty: ignore[possibly-unresolved-reference]
|
||||
)
|
||||
logger.error(
|
||||
f"Error retrieving item {item_id}: {str(e)}"
|
||||
)
|
||||
except Exception as e:
|
||||
item_id = "ID" if not item_id else item_id
|
||||
item_id = (
|
||||
"ID"
|
||||
if not item_id # ty: ignore[possibly-unresolved-reference]
|
||||
else item_id # ty: ignore[possibly-unresolved-reference]
|
||||
)
|
||||
logger.error(
|
||||
f"Unexpected error for item {item_id}: {str(e)}"
|
||||
)
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
import requests
|
||||
from hubspot import HubSpot # type: ignore
|
||||
from hubspot import HubSpot
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
@@ -133,7 +133,7 @@ class ImapConnector(
|
||||
checkpoint: ImapCheckpoint,
|
||||
include_perm_sync: bool,
|
||||
) -> CheckpointOutput[ImapCheckpoint]:
|
||||
checkpoint = cast(ImapCheckpoint, copy.deepcopy(checkpoint))
|
||||
checkpoint = copy.deepcopy(checkpoint)
|
||||
checkpoint.has_more = True
|
||||
|
||||
mail_client = self._get_mail_client()
|
||||
@@ -188,7 +188,7 @@ class ImapConnector(
|
||||
for email_id in current_todos:
|
||||
email_msg = _fetch_email(mail_client=mail_client, email_id=email_id)
|
||||
if not email_msg:
|
||||
logger.warn(f"Failed to fetch message {email_id=}; skipping")
|
||||
logger.warning(f"Failed to fetch message {email_id=}; skipping")
|
||||
continue
|
||||
|
||||
email_headers = EmailHeaders.from_email_msg(email_msg=email_msg)
|
||||
@@ -260,7 +260,7 @@ def _fetch_all_mailboxes_for_email_account(mail_client: imaplib.IMAP4_SSL) -> li
|
||||
elif isinstance(mailboxes_raw, str):
|
||||
mailboxes_str = mailboxes_raw
|
||||
else:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Expected the mailbox data to be of type str, instead got {type(mailboxes_raw)=} {mailboxes_raw}; skipping"
|
||||
)
|
||||
continue
|
||||
@@ -274,7 +274,7 @@ def _fetch_all_mailboxes_for_email_account(mail_client: imaplib.IMAP4_SSL) -> li
|
||||
# The below regex matches on that pattern; from there, we select the 3rd match (index 2), which is the mailbox-name.
|
||||
match = re.match(r'\([^)]*\)\s+"([^"]+)"\s+"?(.+?)"?$', mailboxes_str)
|
||||
if not match:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Invalid mailbox-data formatting structure: {mailboxes_str=}; skipping"
|
||||
)
|
||||
continue
|
||||
@@ -391,7 +391,7 @@ def _parse_email_body(
|
||||
try:
|
||||
raw_payload = part.get_payload(decode=True)
|
||||
if not isinstance(raw_payload, bytes):
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
"Payload section from email was expected to be an array of bytes, instead got "
|
||||
f"{type(raw_payload)=}, {raw_payload=}"
|
||||
)
|
||||
@@ -403,7 +403,7 @@ def _parse_email_body(
|
||||
continue
|
||||
|
||||
if not body:
|
||||
logger.warn(
|
||||
logger.warning(
|
||||
f"Email with {email_headers.id=} has an empty body; returning an empty string"
|
||||
)
|
||||
return ""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import email
|
||||
import email.header
|
||||
import email.utils
|
||||
from datetime import datetime
|
||||
from email.message import Message
|
||||
from enum import Enum
|
||||
|
||||
@@ -98,8 +98,7 @@ class BaseConnector(abc.ABC, Generic[CT]):
|
||||
return NormalizationResult(normalized_url=None, use_default=True)
|
||||
|
||||
def build_dummy_checkpoint(self) -> CT:
|
||||
# TODO: find a way to make this work without type: ignore
|
||||
return ConnectorCheckpoint(has_more=True) # type: ignore
|
||||
return ConnectorCheckpoint(has_more=True) # ty: ignore[invalid-return-type]
|
||||
|
||||
|
||||
# Large set update or reindex, generally pulling a complete state or from a savestate file
|
||||
|
||||
@@ -184,13 +184,13 @@ def _handle_jira_search_error(e: Exception, jql: str) -> None:
|
||||
else:
|
||||
error_text = str(raw_text)
|
||||
elif hasattr(e, "response") and e.response is not None:
|
||||
status_code = e.response.status_code
|
||||
status_code = e.response.status_code # ty: ignore[unresolved-attribute]
|
||||
# Try JSON first, fall back to text
|
||||
try:
|
||||
error_json = e.response.json()
|
||||
error_json = e.response.json() # ty: ignore[unresolved-attribute]
|
||||
error_text = _format_error_text(error_json)
|
||||
except Exception:
|
||||
error_text = e.response.text
|
||||
error_text = e.response.text # ty: ignore[unresolved-attribute]
|
||||
|
||||
# Handle specific status codes
|
||||
if status_code == 400:
|
||||
@@ -230,7 +230,9 @@ def enhanced_search_ids(
|
||||
"fields": "id",
|
||||
}
|
||||
try:
|
||||
response = jira_client._session.get(enhanced_search_path, params=params)
|
||||
response = jira_client._session.get( # ty: ignore[unresolved-attribute]
|
||||
enhanced_search_path, params=params
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
except Exception as e:
|
||||
@@ -253,7 +255,9 @@ def _bulk_fetch_request(
|
||||
# to avoid reading unnecessary data
|
||||
payload["fields"] = fields.split(",") if fields else ["*all"]
|
||||
|
||||
resp = jira_client._session.post(bulk_fetch_path, json=payload)
|
||||
resp = jira_client._session.post( # ty: ignore[unresolved-attribute]
|
||||
bulk_fetch_path, json=payload
|
||||
)
|
||||
return resp.json()["issues"]
|
||||
|
||||
|
||||
@@ -298,7 +302,11 @@ def bulk_fetch_issues(
|
||||
raise
|
||||
|
||||
return [
|
||||
Issue(jira_client._options, jira_client._session, raw=issue)
|
||||
Issue(
|
||||
jira_client._options,
|
||||
jira_client._session, # ty: ignore[invalid-argument-type]
|
||||
raw=issue,
|
||||
)
|
||||
for issue in raw_issues
|
||||
]
|
||||
|
||||
@@ -415,18 +423,26 @@ def process_jira_issue(
|
||||
if creator is not None and (
|
||||
basic_expert_info := best_effort_basic_expert_info(creator)
|
||||
):
|
||||
people.add(basic_expert_info)
|
||||
metadata_dict[_FIELD_REPORTER] = basic_expert_info.get_semantic_name()
|
||||
if email := basic_expert_info.get_email():
|
||||
people.add(basic_expert_info) # ty: ignore[possibly-unresolved-reference]
|
||||
metadata_dict[_FIELD_REPORTER] = (
|
||||
basic_expert_info.get_semantic_name() # ty: ignore[possibly-unresolved-reference]
|
||||
)
|
||||
if (
|
||||
email := basic_expert_info.get_email() # ty: ignore[possibly-unresolved-reference]
|
||||
):
|
||||
metadata_dict[_FIELD_REPORTER_EMAIL] = email
|
||||
|
||||
assignee = best_effort_get_field_from_issue(issue, _FIELD_ASSIGNEE)
|
||||
if assignee is not None and (
|
||||
basic_expert_info := best_effort_basic_expert_info(assignee)
|
||||
):
|
||||
people.add(basic_expert_info)
|
||||
metadata_dict[_FIELD_ASSIGNEE] = basic_expert_info.get_semantic_name()
|
||||
if email := basic_expert_info.get_email():
|
||||
people.add(basic_expert_info) # ty: ignore[possibly-unresolved-reference]
|
||||
metadata_dict[_FIELD_ASSIGNEE] = (
|
||||
basic_expert_info.get_semantic_name() # ty: ignore[possibly-unresolved-reference]
|
||||
)
|
||||
if (
|
||||
email := basic_expert_info.get_email() # ty: ignore[possibly-unresolved-reference]
|
||||
):
|
||||
metadata_dict[_FIELD_ASSIGNEE_EMAIL] = email
|
||||
|
||||
metadata_dict[_FIELD_KEY] = issue.key
|
||||
@@ -818,7 +834,7 @@ class JiraConnector(
|
||||
# Add permission information to the document if requested
|
||||
if include_permissions:
|
||||
document.external_access = self._get_project_permissions(
|
||||
project_key,
|
||||
project_key, # ty: ignore[invalid-argument-type]
|
||||
add_prefix=True, # Indexing path - prefix here
|
||||
)
|
||||
yield document
|
||||
|
||||
@@ -5,7 +5,7 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from oauthlib.oauth2 import BackendApplicationClient
|
||||
from requests_oauthlib import OAuth2Session # type: ignore
|
||||
from requests_oauthlib import OAuth2Session
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
@@ -9,10 +9,10 @@ from unittest import mock
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlunparse
|
||||
|
||||
from pywikibot import family # type: ignore[import-untyped]
|
||||
from pywikibot import family
|
||||
from pywikibot import pagegenerators
|
||||
from pywikibot.scripts import generate_family_file # type: ignore[import-untyped]
|
||||
from pywikibot.scripts.generate_user_files import pywikibot # type: ignore[import-untyped]
|
||||
from pywikibot.scripts import generate_family_file
|
||||
from pywikibot.scripts.generate_user_files import pywikibot
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -134,7 +134,7 @@ def family_class_dispatch(url: str, name: str) -> type[family.Family]:
|
||||
|
||||
"""
|
||||
if "wikipedia" in url:
|
||||
import pywikibot.families.wikipedia_family # type: ignore[import-untyped]
|
||||
import pywikibot.families.wikipedia_family
|
||||
|
||||
return pywikibot.families.wikipedia_family.Family
|
||||
# TODO: Support additional families pre-defined in `pywikibot.families.*_family.py` files
|
||||
@@ -159,7 +159,9 @@ if __name__ == "__main__":
|
||||
all_pages = itertools.chain(
|
||||
pages,
|
||||
*[
|
||||
pagegenerators.CategorizedPageGenerator(category, recurse=recursion_depth)
|
||||
pagegenerators.CategorizedPageGenerator(
|
||||
category, recurse=recursion_depth # ty: ignore[invalid-argument-type]
|
||||
)
|
||||
for category in categories
|
||||
],
|
||||
)
|
||||
|
||||
@@ -8,7 +8,8 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import ClassVar
|
||||
|
||||
import pywikibot.time # type: ignore[import-untyped]
|
||||
import pywikibot.config
|
||||
import pywikibot.time
|
||||
from pywikibot import pagegenerators
|
||||
from pywikibot import textlib
|
||||
|
||||
@@ -46,7 +47,9 @@ def pywikibot_timestamp_to_utc_datetime(
|
||||
|
||||
|
||||
def get_doc_from_page(
|
||||
page: pywikibot.Page, site: pywikibot.Site | None, source_type: DocumentSource
|
||||
page: pywikibot.Page,
|
||||
site: pywikibot.Site | None, # ty: ignore[invalid-type-form]
|
||||
source_type: DocumentSource,
|
||||
) -> Document:
|
||||
"""Generate Onyx Document from a MediaWiki page object.
|
||||
|
||||
@@ -178,7 +181,7 @@ class MediaWikiConnector(LoadConnector, PollConnector):
# Pywikibot can handle batching for us, including only loading page contents when we finally request them.
category_pages = [
pagegenerators.PreloadingGenerator(
pagegenerators.EdittimeFilterPageGenerator(
pagegenerators.EdittimeFilterPageGenerator( # ty: ignore[invalid-argument-type]
pagegenerators.CategorizedPageGenerator(
category, recurse=self.recurse_depth
),
Some files were not shown because too many files have changed in this diff