mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-02 06:05:46 +00:00
Compare commits
28 Commits
test-tests
...
v2.11.0-cl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bdc7f6c100 | ||
|
|
90f8656afa | ||
|
|
3c7d35a6e8 | ||
|
|
40d58a37e3 | ||
|
|
be3ecd9640 | ||
|
|
a6da511490 | ||
|
|
c7577ebe58 | ||
|
|
b87078a4f5 | ||
|
|
8a408e7023 | ||
|
|
4c7b73a355 | ||
|
|
8e9cb94d4f | ||
|
|
a21af4b906 | ||
|
|
7f0ce0531f | ||
|
|
b631bfa656 | ||
|
|
eca6b6bef2 | ||
|
|
51ef28305d | ||
|
|
144030c5ca | ||
|
|
a557d76041 | ||
|
|
605e808158 | ||
|
|
8fec88c90d | ||
|
|
e54969a693 | ||
|
|
1da2b2f28f | ||
|
|
eb7b91e08e | ||
|
|
3339000968 | ||
|
|
d9db849e94 | ||
|
|
046408359c | ||
|
|
4b8cca190f | ||
|
|
52a312a63b |
@@ -66,7 +66,8 @@ repos:
|
||||
- id: uv-run
|
||||
name: Check lazy imports
|
||||
args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"]
|
||||
files: ^backend/(?!\.venv/).*\.py$
|
||||
pass_filenames: true
|
||||
files: ^backend/(?!\.venv/|scripts/).*\.py$
|
||||
# NOTE: This takes ~6s on a single, large module which is prohibitively slow.
|
||||
# - id: uv-run
|
||||
# name: mypy
|
||||
|
||||
58
.vscode/launch.json
vendored
58
.vscode/launch.json
vendored
@@ -149,6 +149,24 @@
|
||||
},
|
||||
"consoleTitle": "Slack Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "Discord Bot",
|
||||
"consoleName": "Discord Bot",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "onyx/onyxbot/discord/client.py",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Discord Bot Console"
|
||||
},
|
||||
{
|
||||
"name": "MCP Server",
|
||||
"consoleName": "MCP Server",
|
||||
@@ -397,7 +415,6 @@
|
||||
"onyx.background.celery.versioned_apps.docfetching",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docfetching@%n",
|
||||
@@ -428,7 +445,6 @@
|
||||
"onyx.background.celery.versioned_apps.docprocessing",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=6",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=docprocessing@%n",
|
||||
@@ -577,6 +593,23 @@
|
||||
"group": "3"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Build Sandbox Templates",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "onyx.server.features.build.sandbox.build_templates",
|
||||
"cwd": "${workspaceFolder}/backend",
|
||||
"envFile": "${workspaceFolder}/.vscode/.env",
|
||||
"env": {
|
||||
"PYTHONUNBUFFERED": "1",
|
||||
"PYTHONPATH": "."
|
||||
},
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "3"
|
||||
},
|
||||
"consoleTitle": "Build Sandbox Templates"
|
||||
},
|
||||
{
|
||||
// Dummy entry used to label the group
|
||||
"name": "--- Database ---",
|
||||
@@ -587,6 +620,27 @@
|
||||
"order": 0
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Restore seeded database dump",
|
||||
"type": "node",
|
||||
"request": "launch",
|
||||
"runtimeExecutable": "uv",
|
||||
"runtimeArgs": [
|
||||
"run",
|
||||
"--with",
|
||||
"onyx-devtools",
|
||||
"ods",
|
||||
"db",
|
||||
"restore",
|
||||
"--fetch-seeded",
|
||||
"--yes"
|
||||
],
|
||||
"cwd": "${workspaceFolder}",
|
||||
"console": "integratedTerminal",
|
||||
"presentation": {
|
||||
"group": "4"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "Clean restore seeded database dump (destructive)",
|
||||
"type": "node",
|
||||
|
||||
@@ -16,3 +16,8 @@ dist/
|
||||
.coverage
|
||||
htmlcov/
|
||||
model_server/legacy/
|
||||
|
||||
# Craft: demo_data directory should be unzipped at container startup, not copied
|
||||
**/demo_data/
|
||||
# Craft: templates/outputs/venv is created at container startup
|
||||
**/templates/outputs/venv
|
||||
|
||||
@@ -7,6 +7,10 @@ have a contract or agreement with DanswerAI, you are not permitted to use the En
|
||||
Edition features outside of personal development or testing purposes. Please reach out to \
|
||||
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
|
||||
|
||||
# Build argument for Craft support (disabled by default)
|
||||
# Use --build-arg ENABLE_CRAFT=true to include Node.js and opencode CLI
|
||||
ARG ENABLE_CRAFT=false
|
||||
|
||||
# DO_NOT_TRACK is used to disable telemetry for Unstructured
|
||||
ENV DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
DO_NOT_TRACK="true" \
|
||||
@@ -46,7 +50,23 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
apt-get clean
|
||||
|
||||
# Conditionally install Node.js 20 for Craft (required for Next.js)
|
||||
# Only installed when ENABLE_CRAFT=true
|
||||
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
|
||||
echo "Installing Node.js 20 for Craft support..." && \
|
||||
curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
|
||||
apt-get install -y nodejs && \
|
||||
rm -rf /var/lib/apt/lists/*; \
|
||||
fi
|
||||
|
||||
# Conditionally install opencode CLI for Craft agent functionality
|
||||
# Only installed when ENABLE_CRAFT=true
|
||||
# TODO: download a specific, versioned release of the opencode CLI
|
||||
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
|
||||
echo "Installing opencode CLI for Craft support..." && \
|
||||
curl -fsSL https://opencode.ai/install | bash; \
|
||||
fi
|
||||
ENV PATH="/root/.opencode/bin:${PATH}"
|
||||
|
||||
# Install Python dependencies
|
||||
# Remove py which is pulled in by retry, py is not needed and is a CVE
|
||||
@@ -89,6 +109,12 @@ RUN uv pip install --system --no-cache-dir --upgrade \
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
# Pre-downloading NLTK for setups with limited egress
|
||||
RUN python -c "import nltk; \
|
||||
nltk.download('stopwords', quiet=True); \
|
||||
nltk.download('punkt_tab', quiet=True);"
|
||||
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
|
||||
|
||||
# Pre-downloading tiktoken for setups with limited egress
|
||||
RUN python -c "import tiktoken; \
|
||||
tiktoken.get_encoding('cl100k_base')"
|
||||
@@ -113,7 +139,8 @@ COPY --chown=onyx:onyx ./static /app/static
|
||||
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh
|
||||
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
|
||||
|
||||
# Put logo in assets
|
||||
COPY --chown=onyx:onyx ./assets /app/assets
|
||||
|
||||
@@ -0,0 +1,351 @@
|
||||
"""single onyx craft migration
|
||||
|
||||
Consolidates all buildmode/onyx craft tables into a single migration.
|
||||
|
||||
Tables created:
|
||||
- build_session: User build sessions with status tracking
|
||||
- sandbox: User-owned containerized environments (one per user)
|
||||
- artifact: Build output files (web apps, documents, images)
|
||||
- snapshot: Sandbox filesystem snapshots
|
||||
- build_message: Conversation messages for build sessions
|
||||
|
||||
Existing table modified:
|
||||
- connector_credential_pair: Added processing_mode column
|
||||
|
||||
Revision ID: 2020d417ec84
|
||||
Revises: 41fa44bef321
|
||||
Create Date: 2026-01-26 14:43:54.641405
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2020d417ec84"
|
||||
down_revision = "41fa44bef321"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ==========================================================================
|
||||
# ENUMS
|
||||
# ==========================================================================
|
||||
|
||||
# Build session status enum
|
||||
build_session_status_enum = sa.Enum(
|
||||
"active",
|
||||
"idle",
|
||||
name="buildsessionstatus",
|
||||
native_enum=False,
|
||||
)
|
||||
|
||||
# Sandbox status enum
|
||||
sandbox_status_enum = sa.Enum(
|
||||
"provisioning",
|
||||
"running",
|
||||
"idle",
|
||||
"sleeping",
|
||||
"terminated",
|
||||
"failed",
|
||||
name="sandboxstatus",
|
||||
native_enum=False,
|
||||
)
|
||||
|
||||
# Artifact type enum
|
||||
artifact_type_enum = sa.Enum(
|
||||
"web_app",
|
||||
"pptx",
|
||||
"docx",
|
||||
"markdown",
|
||||
"excel",
|
||||
"image",
|
||||
name="artifacttype",
|
||||
native_enum=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# BUILD_SESSION TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"build_session",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
build_session_status_enum,
|
||||
nullable=False,
|
||||
server_default="active",
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"last_activity_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("nextjs_port", sa.Integer(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_build_session_user_created",
|
||||
"build_session",
|
||||
["user_id", sa.text("created_at DESC")],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_build_session_status",
|
||||
"build_session",
|
||||
["status"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# SANDBOX TABLE (user-owned, one per user)
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"sandbox",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("container_id", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sandbox_status_enum,
|
||||
nullable=False,
|
||||
server_default="provisioning",
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("last_heartbeat", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("user_id", name="sandbox_user_id_key"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_sandbox_status",
|
||||
"sandbox",
|
||||
["status"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_sandbox_container_id",
|
||||
"sandbox",
|
||||
["container_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# ARTIFACT TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"artifact",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"session_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("type", artifact_type_enum, nullable=False),
|
||||
sa.Column("path", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_artifact_session_created",
|
||||
"artifact",
|
||||
["session_id", sa.text("created_at DESC")],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_artifact_type",
|
||||
"artifact",
|
||||
["type"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# SNAPSHOT TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"snapshot",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"session_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("storage_path", sa.String(), nullable=False),
|
||||
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_snapshot_session_created",
|
||||
"snapshot",
|
||||
["session_id", sa.text("created_at DESC")],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# BUILD_MESSAGE TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.create_table(
|
||||
"build_message",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True),
|
||||
sa.Column(
|
||||
"session_id",
|
||||
postgresql.UUID(as_uuid=True),
|
||||
sa.ForeignKey("build_session.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"turn_index",
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"type",
|
||||
sa.Enum(
|
||||
"SYSTEM",
|
||||
"USER",
|
||||
"ASSISTANT",
|
||||
"DANSWER",
|
||||
name="messagetype",
|
||||
create_type=False,
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"message_metadata",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_build_message_session_turn",
|
||||
"build_message",
|
||||
["session_id", "turn_index", sa.text("created_at ASC")],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
# ==========================================================================
|
||||
# CONNECTOR_CREDENTIAL_PAIR MODIFICATION
|
||||
# ==========================================================================
|
||||
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"processing_mode",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
server_default="regular",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ==========================================================================
|
||||
# CONNECTOR_CREDENTIAL_PAIR MODIFICATION
|
||||
# ==========================================================================
|
||||
|
||||
op.drop_column("connector_credential_pair", "processing_mode")
|
||||
|
||||
# ==========================================================================
|
||||
# BUILD_MESSAGE TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.drop_index("ix_build_message_session_turn", table_name="build_message")
|
||||
op.drop_table("build_message")
|
||||
|
||||
# ==========================================================================
|
||||
# SNAPSHOT TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.drop_index("ix_snapshot_session_created", table_name="snapshot")
|
||||
op.drop_table("snapshot")
|
||||
|
||||
# ==========================================================================
|
||||
# ARTIFACT TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.drop_index("ix_artifact_type", table_name="artifact")
|
||||
op.drop_index("ix_artifact_session_created", table_name="artifact")
|
||||
op.drop_table("artifact")
|
||||
sa.Enum(name="artifacttype").drop(op.get_bind(), checkfirst=True)
|
||||
|
||||
# ==========================================================================
|
||||
# SANDBOX TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.drop_index("ix_sandbox_container_id", table_name="sandbox")
|
||||
op.drop_index("ix_sandbox_status", table_name="sandbox")
|
||||
op.drop_table("sandbox")
|
||||
sa.Enum(name="sandboxstatus").drop(op.get_bind(), checkfirst=True)
|
||||
|
||||
# ==========================================================================
|
||||
# BUILD_SESSION TABLE
|
||||
# ==========================================================================
|
||||
|
||||
op.drop_index("ix_build_session_status", table_name="build_session")
|
||||
op.drop_index("ix_build_session_user_created", table_name="build_session")
|
||||
op.drop_table("build_session")
|
||||
sa.Enum(name="buildsessionstatus").drop(op.get_bind(), checkfirst=True)
|
||||
@@ -0,0 +1,45 @@
|
||||
"""make processing mode default all caps
|
||||
|
||||
Revision ID: 72aa7de2e5cf
|
||||
Revises: 2020d417ec84
|
||||
Create Date: 2026-01-26 18:58:47.705253
|
||||
|
||||
This migration fixes the ProcessingMode enum value mismatch:
|
||||
- SQLAlchemy's Enum with native_enum=False uses enum member NAMES as valid values
|
||||
- The original migration stored lowercase VALUES ('regular', 'file_system')
|
||||
- This converts existing data to uppercase NAMES ('REGULAR', 'FILE_SYSTEM')
|
||||
- Also drops any spurious native PostgreSQL enum type that may have been auto-created
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "72aa7de2e5cf"
|
||||
down_revision = "2020d417ec84"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Convert existing lowercase values to uppercase to match enum member names
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET processing_mode = 'REGULAR' "
|
||||
"WHERE processing_mode = 'regular'"
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET processing_mode = 'FILE_SYSTEM' "
|
||||
"WHERE processing_mode = 'file_system'"
|
||||
)
|
||||
|
||||
# Update the server default to use uppercase
|
||||
op.alter_column(
|
||||
"connector_credential_pair",
|
||||
"processing_mode",
|
||||
server_default="REGULAR",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# State prior to this was broken, so we don't want to revert back to it
|
||||
pass
|
||||
@@ -122,6 +122,9 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
|
||||
POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY") or "FooBar"
|
||||
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
|
||||
POSTHOG_DEBUG_LOGS_ENABLED = (
|
||||
os.environ.get("POSTHOG_DEBUG_LOGS_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")
|
||||
|
||||
@@ -133,3 +136,9 @@ GATED_TENANTS_KEY = "gated_tenants"
|
||||
LICENSE_ENFORCEMENT_ENABLED = (
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Cloud data plane URL - self-hosted instances call this to reach cloud proxy endpoints
|
||||
# Used when MULTI_TENANT=false (self-hosted mode)
|
||||
CLOUD_DATA_PLANE_URL = os.environ.get(
|
||||
"CLOUD_DATA_PLANE_URL", "https://cloud.onyx.app/api"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Database and cache operations for the license table."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import NamedTuple
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
@@ -9,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.server.license.models import LicenseMetadata
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.models import License
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -23,6 +25,13 @@ LICENSE_METADATA_KEY = "license:metadata"
|
||||
LICENSE_CACHE_TTL_SECONDS = 86400 # 24 hours
|
||||
|
||||
|
||||
class SeatAvailabilityResult(NamedTuple):
|
||||
"""Result of a seat availability check."""
|
||||
|
||||
available: bool
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Database CRUD Operations
|
||||
# -----------------------------------------------------------------------------
|
||||
@@ -95,23 +104,30 @@ def delete_license(db_session: Session) -> bool:
|
||||
|
||||
def get_used_seats(tenant_id: str | None = None) -> int:
|
||||
"""
|
||||
Get current seat usage.
|
||||
Get current seat usage directly from database.
|
||||
|
||||
For multi-tenant: counts users in UserTenantMapping for this tenant.
|
||||
For self-hosted: counts all active users (includes both Onyx UI users
|
||||
and Slack users who have been converted to Onyx users).
|
||||
For self-hosted: counts all active users (excludes EXT_PERM_USER role).
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_count
|
||||
|
||||
return get_tenant_count(tenant_id or get_current_tenant_id())
|
||||
else:
|
||||
# Self-hosted: count all active users (Onyx + converted Slack users)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
result = db_session.execute(
|
||||
select(func.count()).select_from(User).where(User.is_active) # type: ignore
|
||||
select(func.count())
|
||||
.select_from(User)
|
||||
.where(
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
User.role != UserRole.EXT_PERM_USER,
|
||||
)
|
||||
)
|
||||
return result.scalar() or 0
|
||||
|
||||
@@ -276,3 +292,43 @@ def get_license_metadata(
|
||||
|
||||
# Refresh from database
|
||||
return refresh_license_cache(db_session, tenant_id)
|
||||
|
||||
|
||||
def check_seat_availability(
|
||||
db_session: Session,
|
||||
seats_needed: int = 1,
|
||||
tenant_id: str | None = None,
|
||||
) -> SeatAvailabilityResult:
|
||||
"""
|
||||
Check if there are enough seats available to add users.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
seats_needed: Number of seats needed (default 1)
|
||||
tenant_id: Tenant ID (for multi-tenant deployments)
|
||||
|
||||
Returns:
|
||||
SeatAvailabilityResult with available=True if seats are available,
|
||||
or available=False with error_message if limit would be exceeded.
|
||||
Returns available=True if no license exists (self-hosted = unlimited).
|
||||
"""
|
||||
metadata = get_license_metadata(db_session, tenant_id)
|
||||
|
||||
# No license = no enforcement (self-hosted without license)
|
||||
if metadata is None:
|
||||
return SeatAvailabilityResult(available=True)
|
||||
|
||||
# Calculate current usage directly from DB (not cache) for accuracy
|
||||
current_used = get_used_seats(tenant_id)
|
||||
total_seats = metadata.seats
|
||||
|
||||
# Use > (not >=) to allow filling to exactly 100% capacity
|
||||
would_exceed_limit = current_used + seats_needed > total_seats
|
||||
if would_exceed_limit:
|
||||
return SeatAvailabilityResult(
|
||||
available=False,
|
||||
error_message=f"Seat limit would be exceeded: {current_used} of {total_seats} seats used, "
|
||||
f"cannot add {seats_needed} more user(s).",
|
||||
)
|
||||
|
||||
return SeatAvailabilityResult(available=True)
|
||||
|
||||
@@ -12,6 +12,12 @@ EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
|
||||
("/enterprise-settings/custom-analytics-script", {"GET"}),
|
||||
# Stripe publishable key is safe to expose publicly
|
||||
("/tenants/stripe-publishable-key", {"GET"}),
|
||||
# Proxy endpoints use license-based auth, not user auth
|
||||
("/proxy/create-checkout-session", {"POST"}),
|
||||
("/proxy/claim-license", {"POST"}),
|
||||
("/proxy/create-customer-portal-session", {"POST"}),
|
||||
("/proxy/billing-information", {"GET"}),
|
||||
("/proxy/license/{tenant_id}", {"GET"}),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
# /me - Basic user info needed for UI rendering
|
||||
# /settings, /enterprise-settings - View app status and branding
|
||||
# /tenants/billing-* - Manage subscription to resolve gating
|
||||
# /proxy - Self-hosted proxy endpoints (have own license-based auth)
|
||||
ALLOWED_PATH_PREFIXES = {
|
||||
"/auth",
|
||||
"/license",
|
||||
@@ -35,6 +36,7 @@ ALLOWED_PATH_PREFIXES = {
|
||||
"/tenants/billing-information",
|
||||
"/tenants/create-customer-portal-session",
|
||||
"/tenants/create-subscription-session",
|
||||
"/proxy",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from fastapi import APIRouter
|
||||
from ee.onyx.server.tenants.admin_api import router as admin_router
|
||||
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
|
||||
from ee.onyx.server.tenants.billing_api import router as billing_router
|
||||
from ee.onyx.server.tenants.proxy import router as proxy_router
|
||||
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
|
||||
from ee.onyx.server.tenants.tenant_management_api import (
|
||||
router as tenant_management_router,
|
||||
@@ -22,3 +23,4 @@ router.include_router(billing_router)
|
||||
router.include_router(team_membership_router)
|
||||
router.include_router(tenant_management_router)
|
||||
router.include_router(user_invitations_router)
|
||||
router.include_router(proxy_router)
|
||||
|
||||
450
backend/ee/onyx/server/tenants/proxy.py
Normal file
450
backend/ee/onyx/server/tenants/proxy.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""Proxy endpoints for billing operations.
|
||||
|
||||
These endpoints run on the CLOUD DATA PLANE (cloud.onyx.app) and serve as a proxy
|
||||
for self-hosted instances to reach the control plane.
|
||||
|
||||
Flow:
|
||||
Self-hosted backend → Cloud DP /proxy/* (license auth) → Control plane (JWT auth)
|
||||
|
||||
Self-hosted instances call these endpoints with their license in the Authorization
|
||||
header. The cloud data plane validates the license signature and forwards the
|
||||
request to the control plane using JWT authentication.
|
||||
|
||||
Auth levels by endpoint:
|
||||
- /create-checkout-session: No auth (new customer) or expired license OK (renewal)
|
||||
- /claim-license: Session ID based (one-time after Stripe payment)
|
||||
- /create-customer-portal-session: Expired license OK (need portal to fix payment)
|
||||
- /billing-information: Valid license required
|
||||
- /license/{tenant_id}: Valid license required
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import Header
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
|
||||
from ee.onyx.db.license import update_license_cache
|
||||
from ee.onyx.db.license import upsert_license
|
||||
from ee.onyx.server.license.models import LicensePayload
|
||||
from ee.onyx.server.license.models import LicenseSource
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.utils.license import is_license_valid
|
||||
from ee.onyx.utils.license import verify_license_signature
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/proxy")
|
||||
|
||||
|
||||
def _check_license_enforcement_enabled() -> None:
|
||||
"""Ensure LICENSE_ENFORCEMENT_ENABLED is true (proxy endpoints only work on cloud DP)."""
|
||||
if not LICENSE_ENFORCEMENT_ENABLED:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail="Proxy endpoints are only available on cloud data plane",
|
||||
)
|
||||
|
||||
|
||||
def _extract_license_from_header(
|
||||
authorization: str | None,
|
||||
required: bool = True,
|
||||
) -> str | None:
|
||||
"""Extract license data from Authorization header.
|
||||
|
||||
Self-hosted instances authenticate to these proxy endpoints by sending their
|
||||
license as a Bearer token: `Authorization: Bearer <base64-encoded-license>`.
|
||||
|
||||
We use the Bearer scheme (RFC 6750) because:
|
||||
1. It's the standard HTTP auth scheme for token-based authentication
|
||||
2. The license blob is cryptographically signed (RSA), so it's self-validating
|
||||
3. No other auth schemes (Basic, Digest, etc.) are supported for license auth
|
||||
|
||||
The license data is the base64-encoded signed blob that contains tenant_id,
|
||||
seats, expiration, etc. We verify the signature to authenticate the caller.
|
||||
|
||||
Args:
|
||||
authorization: The Authorization header value (e.g., "Bearer <license>")
|
||||
required: If True, raise 401 when header is missing/invalid
|
||||
|
||||
Returns:
|
||||
License data string (base64-encoded), or None if not required and missing
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if required and header is missing/invalid
|
||||
"""
|
||||
if not authorization or not authorization.startswith("Bearer "):
|
||||
if required:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Missing or invalid authorization header"
|
||||
)
|
||||
return None
|
||||
|
||||
return authorization.split(" ", 1)[1]
|
||||
|
||||
|
||||
def verify_license_auth(
|
||||
license_data: str,
|
||||
allow_expired: bool = False,
|
||||
) -> LicensePayload:
|
||||
"""Verify license signature and optionally check expiry.
|
||||
|
||||
Args:
|
||||
license_data: Base64-encoded signed license blob
|
||||
allow_expired: If True, accept expired licenses (for renewal flows)
|
||||
|
||||
Returns:
|
||||
LicensePayload if valid
|
||||
|
||||
Raises:
|
||||
HTTPException: If license is invalid or expired (when not allowed)
|
||||
"""
|
||||
_check_license_enforcement_enabled()
|
||||
|
||||
try:
|
||||
payload = verify_license_signature(license_data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=401, detail=f"Invalid license: {e}")
|
||||
|
||||
if not allow_expired and not is_license_valid(payload):
|
||||
raise HTTPException(status_code=401, detail="License has expired")
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
async def get_license_payload(
|
||||
authorization: str | None = Header(None, alias="Authorization"),
|
||||
) -> LicensePayload:
|
||||
"""Dependency: Require valid (non-expired) license.
|
||||
|
||||
Used for endpoints that require an active subscription.
|
||||
"""
|
||||
license_data = _extract_license_from_header(authorization, required=True)
|
||||
# license_data is guaranteed non-None when required=True
|
||||
assert license_data is not None
|
||||
return verify_license_auth(license_data, allow_expired=False)
|
||||
|
||||
|
||||
async def get_license_payload_allow_expired(
|
||||
authorization: str | None = Header(None, alias="Authorization"),
|
||||
) -> LicensePayload:
|
||||
"""Dependency: Require license with valid signature, expired OK.
|
||||
|
||||
Used for endpoints needed to fix payment issues (portal, renewal checkout).
|
||||
"""
|
||||
license_data = _extract_license_from_header(authorization, required=True)
|
||||
# license_data is guaranteed non-None when required=True
|
||||
assert license_data is not None
|
||||
return verify_license_auth(license_data, allow_expired=True)
|
||||
|
||||
|
||||
async def get_optional_license_payload(
|
||||
authorization: str | None = Header(None, alias="Authorization"),
|
||||
) -> LicensePayload | None:
|
||||
"""Dependency: Optional license auth (for checkout - new customers have none).
|
||||
|
||||
Returns None if no license provided, otherwise validates and returns payload.
|
||||
Expired licenses are allowed for renewal flows.
|
||||
"""
|
||||
_check_license_enforcement_enabled()
|
||||
|
||||
license_data = _extract_license_from_header(authorization, required=False)
|
||||
if license_data is None:
|
||||
return None
|
||||
|
||||
return verify_license_auth(license_data, allow_expired=True)
|
||||
|
||||
|
||||
async def forward_to_control_plane(
|
||||
method: str,
|
||||
path: str,
|
||||
body: dict | None = None,
|
||||
params: dict | None = None,
|
||||
) -> dict:
|
||||
"""Forward a request to the control plane with proper authentication."""
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}{path}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
if method == "GET":
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
elif method == "POST":
|
||||
response = await client.post(url, headers=headers, json=body)
|
||||
else:
|
||||
raise ValueError(f"Unsupported HTTP method: {method}")
|
||||
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
status_code = e.response.status_code
|
||||
detail = "Control plane request failed"
|
||||
try:
|
||||
error_data = e.response.json()
|
||||
detail = error_data.get("detail", detail)
|
||||
except Exception:
|
||||
pass
|
||||
logger.error(f"Control plane returned {status_code}: {detail}")
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
except httpx.RequestError:
|
||||
logger.exception("Failed to connect to control plane")
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Failed to connect to control plane"
|
||||
)
|
||||
|
||||
|
||||
def fetch_and_store_license(tenant_id: str, license_data: str) -> None:
|
||||
"""Store license in database and update Redis cache.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID
|
||||
license_data: Base64-encoded signed license blob
|
||||
"""
|
||||
try:
|
||||
# Verify before storing
|
||||
payload = verify_license_signature(license_data)
|
||||
|
||||
# Store in database using the specific tenant's schema
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
upsert_license(db_session, license_data)
|
||||
|
||||
# Update Redis cache
|
||||
update_license_cache(
|
||||
payload,
|
||||
source=LicenseSource.AUTO_FETCH,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to verify license: {e}")
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to store license")
|
||||
raise
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class CreateCheckoutSessionRequest(BaseModel):
|
||||
billing_period: Literal["monthly", "annual"] = "monthly"
|
||||
email: str | None = None
|
||||
# Redirect URL after successful checkout - self-hosted passes their instance URL
|
||||
redirect_url: str | None = None
|
||||
# Cancel URL when user exits checkout - returns to upgrade page
|
||||
cancel_url: str | None = None
|
||||
|
||||
|
||||
class CreateCheckoutSessionResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
async def proxy_create_checkout_session(
|
||||
request_body: CreateCheckoutSessionRequest,
|
||||
license_payload: LicensePayload | None = Depends(get_optional_license_payload),
|
||||
) -> CreateCheckoutSessionResponse:
|
||||
"""Proxy checkout session creation to control plane.
|
||||
|
||||
Auth: Optional license (new customers don't have one yet).
|
||||
If license provided, expired is OK (for renewals).
|
||||
"""
|
||||
# license_payload is None for new customers who don't have a license yet.
|
||||
# In that case, tenant_id is omitted from the request body and the control
|
||||
# plane will create a new tenant during checkout completion.
|
||||
tenant_id = license_payload.tenant_id if license_payload else None
|
||||
|
||||
body: dict = {
|
||||
"billing_period": request_body.billing_period,
|
||||
}
|
||||
if tenant_id:
|
||||
body["tenant_id"] = tenant_id
|
||||
if request_body.email:
|
||||
body["email"] = request_body.email
|
||||
if request_body.redirect_url:
|
||||
body["redirect_url"] = request_body.redirect_url
|
||||
if request_body.cancel_url:
|
||||
body["cancel_url"] = request_body.cancel_url
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST", "/create-checkout-session", body=body
|
||||
)
|
||||
return CreateCheckoutSessionResponse(url=result["url"])
|
||||
|
||||
|
||||
class ClaimLicenseRequest(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
class ClaimLicenseResponse(BaseModel):
|
||||
tenant_id: str
|
||||
license: str
|
||||
message: str | None = None
|
||||
|
||||
|
||||
@router.post("/claim-license")
|
||||
async def proxy_claim_license(
|
||||
request_body: ClaimLicenseRequest,
|
||||
) -> ClaimLicenseResponse:
|
||||
"""Claim a license after successful Stripe checkout.
|
||||
|
||||
Auth: Session ID based (one-time use after payment).
|
||||
The control plane verifies the session_id is valid and unclaimed.
|
||||
|
||||
Returns the license to the caller. For self-hosted instances, they will
|
||||
store the license locally. The cloud DP doesn't need to store it.
|
||||
"""
|
||||
_check_license_enforcement_enabled()
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST",
|
||||
"/claim-license",
|
||||
body={"session_id": request_body.session_id},
|
||||
)
|
||||
|
||||
tenant_id = result.get("tenant_id")
|
||||
license_data = result.get("license")
|
||||
|
||||
if not tenant_id or not license_data:
|
||||
logger.error(f"Control plane returned incomplete claim response: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Control plane returned incomplete license data",
|
||||
)
|
||||
|
||||
return ClaimLicenseResponse(
|
||||
tenant_id=tenant_id,
|
||||
license=license_data,
|
||||
message="License claimed successfully",
|
||||
)
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionRequest(BaseModel):
|
||||
return_url: str | None = None
|
||||
|
||||
|
||||
class CreateCustomerPortalSessionResponse(BaseModel):
|
||||
url: str
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def proxy_create_customer_portal_session(
|
||||
request_body: CreateCustomerPortalSessionRequest | None = None,
|
||||
license_payload: LicensePayload = Depends(get_license_payload_allow_expired),
|
||||
) -> CreateCustomerPortalSessionResponse:
|
||||
"""Proxy customer portal session creation to control plane.
|
||||
|
||||
Auth: License required, expired OK (need portal to fix payment issues).
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
body: dict = {"tenant_id": tenant_id}
|
||||
if request_body and request_body.return_url:
|
||||
body["return_url"] = request_body.return_url
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"POST", "/create-customer-portal-session", body=body
|
||||
)
|
||||
return CreateCustomerPortalSessionResponse(url=result["url"])
|
||||
|
||||
|
||||
class BillingInformationResponse(BaseModel):
|
||||
tenant_id: str
|
||||
status: str | None = None
|
||||
plan_type: str | None = None
|
||||
seats: int | None = None
|
||||
billing_period: str | None = None
|
||||
current_period_start: str | None = None
|
||||
current_period_end: str | None = None
|
||||
cancel_at_period_end: bool = False
|
||||
canceled_at: str | None = None
|
||||
trial_start: str | None = None
|
||||
trial_end: str | None = None
|
||||
payment_method_enabled: bool = False
|
||||
stripe_subscription_id: str | None = None
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def proxy_billing_information(
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> BillingInformationResponse:
|
||||
"""Proxy billing information request to control plane.
|
||||
|
||||
Auth: Valid (non-expired) license required.
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
tenant_id = license_payload.tenant_id
|
||||
|
||||
result = await forward_to_control_plane(
|
||||
"GET", "/billing-information", params={"tenant_id": tenant_id}
|
||||
)
|
||||
# Add tenant_id from license if not in response (control plane may not include it)
|
||||
if "tenant_id" not in result:
|
||||
result["tenant_id"] = tenant_id
|
||||
return BillingInformationResponse(**result)
|
||||
|
||||
|
||||
class LicenseFetchResponse(BaseModel):
|
||||
license: str
|
||||
tenant_id: str
|
||||
|
||||
|
||||
@router.get("/license/{tenant_id}")
|
||||
async def proxy_license_fetch(
|
||||
tenant_id: str,
|
||||
license_payload: LicensePayload = Depends(get_license_payload),
|
||||
) -> LicenseFetchResponse:
|
||||
"""Proxy license fetch to control plane.
|
||||
|
||||
Auth: Valid license required.
|
||||
The tenant_id in path must match the authenticated tenant.
|
||||
"""
|
||||
# tenant_id is a required field in LicensePayload (Pydantic validates this),
|
||||
# but we check explicitly for defense in depth
|
||||
if not license_payload.tenant_id:
|
||||
raise HTTPException(status_code=401, detail="License missing tenant_id")
|
||||
|
||||
if tenant_id != license_payload.tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Cannot fetch license for a different tenant",
|
||||
)
|
||||
|
||||
result = await forward_to_control_plane("GET", f"/license/{tenant_id}")
|
||||
|
||||
# Auto-store the refreshed license
|
||||
license_data = result.get("license")
|
||||
if not license_data:
|
||||
logger.error(f"Control plane returned incomplete license response: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Control plane returned incomplete license data",
|
||||
)
|
||||
|
||||
fetch_and_store_license(tenant_id, license_data)
|
||||
|
||||
return LicenseFetchResponse(license=license_data, tenant_id=tenant_id)
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi_users import exceptions
|
||||
from sqlalchemy import select
|
||||
|
||||
from ee.onyx.db.license import invalidate_license_cache
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
@@ -47,6 +48,8 @@ def get_tenant_id_for_email(email: str) -> str:
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
tenant_id = mapping.tenant_id
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting tenant id for email {email}: {e}")
|
||||
raise exceptions.UserNotExists()
|
||||
@@ -70,49 +73,104 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
|
||||
|
||||
If a user already has an active mapping to a different tenant, they receive
|
||||
an inactive mapping (invitation) to this tenant. They can accept the
|
||||
invitation later to switch tenants.
|
||||
|
||||
Raises:
|
||||
HTTPException: 402 if adding active users would exceed seat limit
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant as get_tenant_session
|
||||
|
||||
unique_emails = set(emails)
|
||||
if not unique_emails:
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
# Start a transaction
|
||||
db_session.begin()
|
||||
|
||||
for email in emails:
|
||||
# Check if the user already has a mapping to this tenant
|
||||
existing_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
# Batch query 1: Get all existing mappings for these emails to this tenant
|
||||
# Lock rows to prevent concurrent modifications
|
||||
existing_mappings = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(unique_emails),
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.all()
|
||||
)
|
||||
emails_with_mapping = {m.email for m in existing_mappings}
|
||||
|
||||
# If user already has an active mapping, add this one as inactive
|
||||
if not existing_mapping:
|
||||
# Check if the user already has an active mapping to any tenant
|
||||
has_active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
# Batch query 2: Get all active mappings for these emails (any tenant)
|
||||
active_mappings = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(unique_emails),
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
emails_with_active_mapping = {m.email for m in active_mappings}
|
||||
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=False if has_active_mapping else True,
|
||||
)
|
||||
# Determine which users will consume a new seat.
|
||||
# Users with active mappings elsewhere get INACTIVE mappings (invitations)
|
||||
# and don't consume seats until they accept. Only users without any active
|
||||
# mapping will get an ACTIVE mapping and consume a seat immediately.
|
||||
emails_consuming_seats = {
|
||||
email
|
||||
for email in unique_emails
|
||||
if email not in emails_with_mapping
|
||||
and email not in emails_with_active_mapping
|
||||
}
|
||||
|
||||
# Check seat availability inside the transaction to prevent race conditions.
|
||||
# Note: ALL users in unique_emails still get added below - this check only
|
||||
# validates we have capacity for users who will consume seats immediately.
|
||||
if emails_consuming_seats:
|
||||
with get_tenant_session(tenant_id=tenant_id) as tenant_session:
|
||||
result = check_seat_availability(
|
||||
tenant_session,
|
||||
seats_needed=len(emails_consuming_seats),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
if not result.available:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=result.error_message or "Seat limit exceeded",
|
||||
)
|
||||
|
||||
# Add mappings for emails that don't already have one to this tenant
|
||||
for email in unique_emails:
|
||||
if email in emails_with_mapping:
|
||||
continue
|
||||
|
||||
# Create mapping: inactive if user belongs to another tenant (invitation),
|
||||
# active otherwise
|
||||
db_session.add(
|
||||
UserTenantMapping(
|
||||
email=email,
|
||||
tenant_id=tenant_id,
|
||||
active=email not in emails_with_active_mapping,
|
||||
)
|
||||
)
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
|
||||
except HTTPException:
|
||||
db_session.rollback()
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}")
|
||||
db_session.rollback()
|
||||
@@ -135,6 +193,9 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
db_session.delete(mapping)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
|
||||
@@ -149,6 +210,9 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
|
||||
|
||||
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -177,6 +241,9 @@ def approve_user_invite(email: str, tenant_id: str) -> None:
|
||||
db_session.add(new_mapping)
|
||||
db_session.commit()
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
|
||||
# Also remove the user from pending users list
|
||||
# Remove from pending users
|
||||
pending_users = get_pending_users()
|
||||
@@ -195,19 +262,42 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
This activates the user's mapping to the tenant.
|
||||
|
||||
Raises:
|
||||
HTTPException: 402 if accepting would exceed seat limit
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.db.license import check_seat_availability
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
# First check if there's an active mapping for this user and tenant
|
||||
# Lock the user's mappings first to prevent race conditions.
|
||||
# This ensures no concurrent request can modify this user's mappings
|
||||
# while we check seats and activate.
|
||||
active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check seat availability within the same logical operation.
|
||||
# Note: This queries fresh data from DB, not cache.
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
|
||||
result = check_seat_availability(
|
||||
tenant_session, seats_needed=1, tenant_id=tenant_id
|
||||
)
|
||||
if not result.available:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail=result.error_message or "Seat limit exceeded",
|
||||
)
|
||||
|
||||
# If an active mapping exists, delete it
|
||||
if active_mapping:
|
||||
db_session.delete(active_mapping)
|
||||
@@ -237,6 +327,9 @@ def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
|
||||
|
||||
# Invalidate license cache so used_seats reflects the new count
|
||||
invalidate_license_cache(tenant_id)
|
||||
else:
|
||||
logger.warning(
|
||||
f"No invitation found for user {email} in tenant {tenant_id}"
|
||||
@@ -297,16 +390,41 @@ def deny_user_invite(email: str, tenant_id: str) -> None:
|
||||
|
||||
def get_tenant_count(tenant_id: str) -> int:
|
||||
"""
|
||||
Get the number of active users for this tenant
|
||||
Get the number of active users for this tenant.
|
||||
|
||||
A user counts toward the seat count if:
|
||||
1. They have an active mapping to this tenant (UserTenantMapping.active == True)
|
||||
2. AND the User is active (User.is_active == True)
|
||||
|
||||
TODO: Exclude API key dummy users from seat counting. API keys create
|
||||
users with emails like `__DANSWER_API_KEY_*` that should not count toward
|
||||
seat limits. See: https://linear.app/onyx-app/issue/ENG-3518
|
||||
"""
|
||||
from onyx.db.models import User
|
||||
|
||||
# First get all emails with active mappings to this tenant
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Count the number of active users for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
active_mapping_emails = (
|
||||
db_session.query(UserTenantMapping.email)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.all()
|
||||
)
|
||||
emails = [email for (email,) in active_mapping_emails]
|
||||
|
||||
if not emails:
|
||||
return 0
|
||||
|
||||
# Now count how many of those users are actually active in the tenant's User table
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
user_count = (
|
||||
db_session.query(User)
|
||||
.filter(
|
||||
User.email.in_(emails), # type: ignore
|
||||
User.is_active == True, # type: ignore # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from pathlib import Path
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
@@ -19,21 +20,27 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# RSA-4096 Public Key for license verification
|
||||
# Load from environment variable - key is generated on the control plane
|
||||
# In production, inject via Kubernetes secrets or secrets manager
|
||||
LICENSE_PUBLIC_KEY_PEM = os.environ.get("LICENSE_PUBLIC_KEY_PEM", "")
|
||||
# Path to the license public key file
|
||||
_LICENSE_PUBLIC_KEY_PATH = (
|
||||
Path(__file__).parent.parent.parent.parent / "keys" / "license_public_key.pem"
|
||||
)
|
||||
|
||||
|
||||
def _get_public_key() -> RSAPublicKey:
|
||||
"""Load the public key from environment variable."""
|
||||
if not LICENSE_PUBLIC_KEY_PEM:
|
||||
raise ValueError(
|
||||
"LICENSE_PUBLIC_KEY_PEM environment variable not set. "
|
||||
"License verification requires the control plane public key."
|
||||
)
|
||||
key = serialization.load_pem_public_key(LICENSE_PUBLIC_KEY_PEM.encode())
|
||||
"""Load the public key from file, with env var override."""
|
||||
# Allow env var override for flexibility
|
||||
key_pem = os.environ.get("LICENSE_PUBLIC_KEY_PEM")
|
||||
|
||||
if not key_pem:
|
||||
# Read from file
|
||||
if not _LICENSE_PUBLIC_KEY_PATH.exists():
|
||||
raise ValueError(
|
||||
f"License public key not found at {_LICENSE_PUBLIC_KEY_PATH}. "
|
||||
"License verification requires the control plane public key."
|
||||
)
|
||||
key_pem = _LICENSE_PUBLIC_KEY_PATH.read_text()
|
||||
|
||||
key = serialization.load_pem_public_key(key_pem.encode())
|
||||
if not isinstance(key, RSAPublicKey):
|
||||
raise ValueError("Expected RSA public key")
|
||||
return key
|
||||
@@ -53,17 +60,21 @@ def verify_license_signature(license_data: str) -> LicensePayload:
|
||||
ValueError: If license data is invalid or signature verification fails
|
||||
"""
|
||||
try:
|
||||
# Decode the license data
|
||||
decoded = json.loads(base64.b64decode(license_data))
|
||||
|
||||
# Parse into LicenseData to validate structure
|
||||
license_obj = LicenseData(**decoded)
|
||||
|
||||
payload_json = json.dumps(
|
||||
license_obj.payload.model_dump(mode="json"), sort_keys=True
|
||||
)
|
||||
# IMPORTANT: Use the ORIGINAL payload JSON for signature verification,
|
||||
# not re-serialized through Pydantic. Pydantic may format fields differently
|
||||
# (e.g., datetime "+00:00" vs "Z") which would break signature verification.
|
||||
original_payload = decoded.get("payload", {})
|
||||
payload_json = json.dumps(original_payload, sort_keys=True)
|
||||
signature_bytes = base64.b64decode(license_obj.signature)
|
||||
|
||||
# Verify signature using PSS padding (modern standard)
|
||||
public_key = _get_public_key()
|
||||
|
||||
public_key.verify(
|
||||
signature_bytes,
|
||||
payload_json.encode(),
|
||||
@@ -77,16 +88,18 @@ def verify_license_signature(license_data: str) -> LicensePayload:
|
||||
return license_obj.payload
|
||||
|
||||
except InvalidSignature:
|
||||
logger.error("License signature verification failed")
|
||||
logger.error("[verify_license] FAILED: Signature verification failed")
|
||||
raise ValueError("Invalid license signature")
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Failed to decode license JSON")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"[verify_license] FAILED: JSON decode error: {e}")
|
||||
raise ValueError("Invalid license format: not valid JSON")
|
||||
except (ValueError, KeyError, TypeError) as e:
|
||||
logger.error(f"License data validation error: {type(e).__name__}")
|
||||
raise ValueError(f"Invalid license format: {type(e).__name__}")
|
||||
logger.error(
|
||||
f"[verify_license] FAILED: Validation error: {type(e).__name__}: {e}"
|
||||
)
|
||||
raise ValueError(f"Invalid license format: {type(e).__name__}: {e}")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error during license verification")
|
||||
logger.exception("[verify_license] FAILED: Unexpected error")
|
||||
raise ValueError("License verification failed: unexpected error")
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from posthog import Posthog
|
||||
|
||||
from ee.onyx.configs.app_configs import MARKETING_POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_API_KEY
|
||||
from ee.onyx.configs.app_configs import POSTHOG_DEBUG_LOGS_ENABLED
|
||||
from ee.onyx.configs.app_configs import POSTHOG_HOST
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -20,7 +21,7 @@ def posthog_on_error(error: Any, items: Any) -> None:
|
||||
posthog = Posthog(
|
||||
project_api_key=POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
debug=POSTHOG_DEBUG_LOGS_ENABLED,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
@@ -33,7 +34,7 @@ if MARKETING_POSTHOG_API_KEY:
|
||||
marketing_posthog = Posthog(
|
||||
project_api_key=MARKETING_POSTHOG_API_KEY,
|
||||
host=POSTHOG_HOST,
|
||||
debug=True,
|
||||
debug=POSTHOG_DEBUG_LOGS_ENABLED,
|
||||
on_error=posthog_on_error,
|
||||
)
|
||||
|
||||
|
||||
14
backend/keys/license_public_key.pem
Normal file
14
backend/keys/license_public_key.pem
Normal file
@@ -0,0 +1,14 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA5DpchQujdxjCwpc4/RQP
|
||||
Hej6rc3SS/5ENCXL0I8NAfMogel0fqG6PKRhonyEh/Bt3P4q18y8vYzAShwf4b6Q
|
||||
aS0WwshbvnkjyWlsK0BY4HLBKPkTpes7kaz8MwmPZDeelvGJ7SNv3FvyJR4QsoSQ
|
||||
GSoB5iTH7hi63TjzdxtckkXoNG+GdVd/koxVDUv2uWcAoWIFTTcbKWyuq2SS/5Sf
|
||||
xdVaIArqfAhLpnNbnM9OS7lZ1xP+29ZXpHxDoeluz35tJLMNBYn9u0y+puo1kW1E
|
||||
TOGizlAq5kmEMsTJ55e9ZuyIV3gZAUaUKe8CxYJPkOGt0Gj6e1jHoHZCBJmaq97Y
|
||||
stKj//84HNBzajaryEZuEfRecJ94ANEjkD8u9cGmW+9VxRe5544zWguP5WMT/nv1
|
||||
0Q+jkOBW2hkY5SS0Rug4cblxiB7bDymWkaX6+sC0VWd5g6WXp36EuP2T0v3mYuHU
|
||||
GDEiWbD44ToREPVwE/M07ny8qhLo/HYk2l8DKFt83hXe7ePBnyQdcsrVbQWOO1na
|
||||
j43OkoU5gOFyOkrk2RmmtCjA8jSnw+tGCTpRaRcshqoWC1MjZyU+8/kDteXNkmv9
|
||||
/B5VxzYSyX+abl7yAu5wLiUPW8l+mOazzWu0nPkmiA160ArxnRyxbGnmp4dUIrt5
|
||||
azYku4tQYLSsSabfhcpeiCsCAwEAAQ==
|
||||
-----END PUBLIC KEY-----
|
||||
@@ -97,10 +97,14 @@ def get_access_for_documents(
|
||||
|
||||
|
||||
def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
"""Returns a list of ACL entries that the user has access to. This is meant to be
|
||||
used downstream to filter out documents that the user does not have access to. The
|
||||
user should have access to a document if at least one entry in the document's ACL
|
||||
matches one entry in the returned set.
|
||||
"""Returns a list of ACL entries that the user has access to.
|
||||
|
||||
This is meant to be used downstream to filter out documents that the user
|
||||
does not have access to. The user should have access to a document if at
|
||||
least one entry in the document's ACL matches one entry in the returned set.
|
||||
|
||||
NOTE: These strings must be formatted in the same way as the output of
|
||||
DocumentAccess::to_acl.
|
||||
"""
|
||||
if user:
|
||||
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
|
||||
|
||||
@@ -125,9 +125,11 @@ class DocumentAccess(ExternalAccess):
|
||||
)
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
# the acl's emitted by this function are prefixed by type
|
||||
# to get the native objects, access the member variables directly
|
||||
"""Converts the access state to a set of formatted ACL strings.
|
||||
|
||||
NOTE: When querying for documents, the supplied ACL filter strings must
|
||||
be formatted in the same way as this function.
|
||||
"""
|
||||
acl_set: set[str] = set()
|
||||
for user_email in self.user_emails:
|
||||
if user_email:
|
||||
|
||||
@@ -1468,7 +1468,7 @@ class OAuth2AuthorizeResponse(BaseModel):
|
||||
|
||||
def generate_state_token(
|
||||
data: Dict[str, str],
|
||||
secret: SecretType,
|
||||
secret: SecretType, # type: ignore[valid-type]
|
||||
lifetime_seconds: int = STATE_TOKEN_LIFETIME_SECONDS,
|
||||
) -> str:
|
||||
data["aud"] = STATE_TOKEN_AUDIENCE
|
||||
@@ -1484,7 +1484,7 @@ def generate_csrf_token() -> str:
|
||||
def create_onyx_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend,
|
||||
state_secret: SecretType,
|
||||
state_secret: SecretType, # type: ignore[valid-type]
|
||||
redirect_url: Optional[str] = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
@@ -1504,7 +1504,7 @@ def get_oauth_router(
|
||||
oauth_client: BaseOAuth2,
|
||||
backend: AuthenticationBackend,
|
||||
get_user_manager: UserManagerDependency[models.UP, models.ID],
|
||||
state_secret: SecretType,
|
||||
state_secret: SecretType, # type: ignore[valid-type]
|
||||
redirect_url: Optional[str] = None,
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
|
||||
@@ -134,5 +134,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Docfetching worker tasks
|
||||
"onyx.background.celery.tasks.docfetching",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -98,5 +98,7 @@ for bootstep in base_bootsteps:
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
# Sandbox tasks (file sync, cleanup)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -116,5 +116,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
# Sandbox cleanup tasks (isolated in build feature)
|
||||
"onyx.server.features.build.sandbox.tasks",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -139,6 +139,27 @@ beat_task_templates: list[dict] = [
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
# Sandbox cleanup tasks
|
||||
{
|
||||
"name": "cleanup-idle-sandboxes",
|
||||
"task": OnyxCeleryTask.CLEANUP_IDLE_SANDBOXES,
|
||||
"schedule": timedelta(minutes=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.SANDBOX,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "cleanup-old-snapshots",
|
||||
"task": OnyxCeleryTask.CLEANUP_OLD_SNAPSHOTS,
|
||||
"schedule": timedelta(hours=24),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.SANDBOX,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
|
||||
@@ -31,17 +31,20 @@ from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.constants import CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
|
||||
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
|
||||
@@ -53,7 +56,12 @@ from onyx.db.models import IndexAttempt
|
||||
from onyx.file_store.document_batch_storage import DocumentBatchStorage
|
||||
from onyx.file_store.document_batch_storage import get_document_batch_storage
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.server.features.build.indexing.persistent_document_writer import (
|
||||
get_persistent_document_writer,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
@@ -367,6 +375,7 @@ def connector_document_extraction(
|
||||
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
processing_mode = index_attempt.connector_credential_pair.processing_mode
|
||||
is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT
|
||||
|
||||
from_beginning = index_attempt.from_beginning
|
||||
@@ -600,34 +609,103 @@ def connector_document_extraction(
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
memory_tracer.increment_and_maybe_trace()
|
||||
|
||||
# Store documents in storage
|
||||
batch_storage.store_batch(batch_num, doc_batch_cleaned)
|
||||
# cc4a
|
||||
if processing_mode == ProcessingMode.FILE_SYSTEM:
|
||||
# File system only - write directly to persistent storage,
|
||||
# skip chunking/embedding/Vespa but still track documents in DB
|
||||
|
||||
# Create processing task data
|
||||
processing_batch_data = {
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"tenant_id": tenant_id,
|
||||
"batch_num": batch_num, # 0-indexed
|
||||
}
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Create metadata for the batch
|
||||
index_attempt_metadata = IndexAttemptMetadata(
|
||||
attempt_id=index_attempt_id,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
request_id=make_randomized_onyx_request_id("FSI"),
|
||||
structured_id=f"{tenant_id}:{cc_pair_id}:{index_attempt_id}:{batch_num}",
|
||||
batch_num=batch_num,
|
||||
)
|
||||
|
||||
# Queue document processing task
|
||||
app.send_task(
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
kwargs=processing_batch_data,
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=docprocessing_priority,
|
||||
)
|
||||
# Upsert documents to PostgreSQL (document table + cc_pair relationship)
|
||||
# This is a subset of what docprocessing does - just DB tracking, no chunking/embedding
|
||||
index_doc_batch_prepare(
|
||||
documents=doc_batch_cleaned,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=True, # Documents already filtered during extraction
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
total_doc_batches_queued += 1
|
||||
# Mark documents as indexed for the CC pair
|
||||
mark_document_as_indexed_for_cc_pair__no_commit(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
document_ids=[doc.id for doc in doc_batch_cleaned],
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Queued document processing batch: "
|
||||
f"batch_num={batch_num} "
|
||||
f"docs={len(doc_batch_cleaned)} "
|
||||
f"attempt={index_attempt_id}"
|
||||
)
|
||||
# Write documents to persistent file system
|
||||
# Use creator_id for user-segregated storage paths (sandbox isolation)
|
||||
creator_id = index_attempt.connector_credential_pair.creator_id
|
||||
if creator_id is None:
|
||||
raise ValueError(
|
||||
f"ConnectorCredentialPair {index_attempt.connector_credential_pair.id} "
|
||||
"must have a creator_id for persistent document storage"
|
||||
)
|
||||
user_id_str: str = str(creator_id)
|
||||
writer = get_persistent_document_writer(
|
||||
user_id=user_id_str,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
written_paths = writer.write_documents(doc_batch_cleaned)
|
||||
|
||||
# Update coordination directly (no docprocessing task)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
IndexingCoordination.update_batch_completion_and_docs(
|
||||
db_session=db_session,
|
||||
index_attempt_id=index_attempt_id,
|
||||
total_docs_indexed=len(doc_batch_cleaned),
|
||||
new_docs_indexed=len(doc_batch_cleaned),
|
||||
total_chunks=0, # No chunks for file system mode
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
total_doc_batches_queued += 1
|
||||
|
||||
logger.info(
|
||||
f"Wrote documents to file system: "
|
||||
f"batch_num={batch_num} "
|
||||
f"docs={len(written_paths)} "
|
||||
f"attempt={index_attempt_id}"
|
||||
)
|
||||
else:
|
||||
# REGULAR mode (default): Full pipeline - store and queue docprocessing
|
||||
batch_storage.store_batch(batch_num, doc_batch_cleaned)
|
||||
|
||||
# Create processing task data
|
||||
processing_batch_data = {
|
||||
"index_attempt_id": index_attempt_id,
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"tenant_id": tenant_id,
|
||||
"batch_num": batch_num, # 0-indexed
|
||||
}
|
||||
|
||||
# Queue document processing task
|
||||
app.send_task(
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
kwargs=processing_batch_data,
|
||||
queue=OnyxCeleryQueues.DOCPROCESSING,
|
||||
priority=docprocessing_priority,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
total_doc_batches_queued += 1
|
||||
|
||||
logger.info(
|
||||
f"Queued document processing batch: "
|
||||
f"batch_num={batch_num} "
|
||||
f"docs={len(doc_batch_cleaned)} "
|
||||
f"attempt={index_attempt_id}"
|
||||
)
|
||||
|
||||
# Check checkpoint size periodically
|
||||
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
|
||||
@@ -663,6 +741,24 @@ def connector_document_extraction(
|
||||
total_batches=batch_num,
|
||||
)
|
||||
|
||||
# Trigger file sync to user's sandbox (if running) - only for FILE_SYSTEM mode
|
||||
# This syncs the newly written documents from S3 to any running sandbox pod
|
||||
if processing_mode == ProcessingMode.FILE_SYSTEM:
|
||||
creator_id = index_attempt.connector_credential_pair.creator_id
|
||||
if creator_id:
|
||||
app.send_task(
|
||||
OnyxCeleryTask.SANDBOX_FILE_SYNC,
|
||||
kwargs={
|
||||
"user_id": str(creator_id),
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
queue=OnyxCeleryQueues.SANDBOX,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered sandbox file sync for user {creator_id} "
|
||||
f"after indexing complete"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Document extraction failed: "
|
||||
|
||||
@@ -207,6 +207,9 @@ OPENSEARCH_HOST = os.environ.get("OPENSEARCH_HOST") or "localhost"
|
||||
OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 9200)
|
||||
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
|
||||
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
|
||||
USING_AWS_MANAGED_OPENSEARCH = (
|
||||
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
|
||||
)
|
||||
|
||||
# This is the "base" config for now, the idea is that at least for our dev
|
||||
# environments we always want to be dual indexing into both OpenSearch and Vespa
|
||||
@@ -1042,3 +1045,14 @@ STRIPE_PUBLISHABLE_KEY_URL = (
|
||||
)
|
||||
# Override for local testing with Stripe test keys (pk_test_*)
|
||||
STRIPE_PUBLISHABLE_KEY_OVERRIDE = os.environ.get("STRIPE_PUBLISHABLE_KEY")
|
||||
# Persistent Document Storage Configuration
|
||||
# When enabled, indexed documents are written to local filesystem with hierarchical structure
|
||||
PERSISTENT_DOCUMENT_STORAGE_ENABLED = (
|
||||
os.environ.get("PERSISTENT_DOCUMENT_STORAGE_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Base directory path for persistent document storage (local filesystem)
|
||||
# Example: /var/onyx/indexed-docs or /app/indexed-docs
|
||||
PERSISTENT_DOCUMENT_STORAGE_PATH = os.environ.get(
|
||||
"PERSISTENT_DOCUMENT_STORAGE_PATH", "/app/indexed-docs"
|
||||
)
|
||||
|
||||
@@ -241,6 +241,7 @@ class NotificationType(str, Enum):
|
||||
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
|
||||
RELEASE_NOTES = "release_notes"
|
||||
ASSISTANT_FILES_READY = "assistant_files_ready"
|
||||
FEATURE_ANNOUNCEMENT = "feature_announcement"
|
||||
|
||||
|
||||
class BlobType(str, Enum):
|
||||
@@ -327,6 +328,7 @@ class FileOrigin(str, Enum):
|
||||
PLAINTEXT_CACHE = "plaintext_cache"
|
||||
OTHER = "other"
|
||||
QUERY_HISTORY_CSV = "query_history_csv"
|
||||
SANDBOX_SNAPSHOT = "sandbox_snapshot"
|
||||
USER_FILE = "user_file"
|
||||
|
||||
|
||||
@@ -344,6 +346,7 @@ class MilestoneRecordType(str, Enum):
|
||||
MULTIPLE_ASSISTANTS = "multiple_assistants"
|
||||
CREATED_ASSISTANT = "created_assistant"
|
||||
CREATED_ONYX_BOT = "created_onyx_bot"
|
||||
REQUESTED_CONNECTOR = "requested_connector"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
@@ -383,6 +386,9 @@ class OnyxCeleryQueues:
|
||||
# KG processing queue
|
||||
KG_PROCESSING = "kg_processing"
|
||||
|
||||
# Sandbox processing queue
|
||||
SANDBOX = "sandbox"
|
||||
|
||||
|
||||
class OnyxRedisLocks:
|
||||
PRIMARY_WORKER = "da_lock:primary_worker"
|
||||
@@ -431,6 +437,10 @@ class OnyxRedisLocks:
|
||||
# Release notes
|
||||
RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch"
|
||||
|
||||
# Sandbox cleanup
|
||||
CLEANUP_IDLE_SANDBOXES_BEAT_LOCK = "da_lock:cleanup_idle_sandboxes_beat"
|
||||
CLEANUP_OLD_SNAPSHOTS_BEAT_LOCK = "da_lock:cleanup_old_snapshots_beat"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences"
|
||||
@@ -556,6 +566,13 @@ class OnyxCeleryTask:
|
||||
CHECK_KG_PROCESSING_CLUSTERING_ONLY = "check_kg_processing_clustering_only"
|
||||
KG_RESET_SOURCE_INDEX = "kg_reset_source_index"
|
||||
|
||||
# Sandbox cleanup
|
||||
CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes"
|
||||
CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots"
|
||||
|
||||
# Sandbox file sync
|
||||
SANDBOX_FILE_SYNC = "sandbox_file_sync"
|
||||
|
||||
|
||||
# this needs to correspond to the matching entry in supervisord
|
||||
ONYX_CELERY_BEAT_HEARTBEAT_KEY = "onyx:celery:beat:heartbeat"
|
||||
|
||||
@@ -89,6 +89,9 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
meeting_date_unix = transcript["date"]
|
||||
meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc)
|
||||
|
||||
# Build hierarchy based on meeting date (year-month)
|
||||
year_month = meeting_date.strftime("%Y-%m")
|
||||
|
||||
meeting_organizer_email = transcript["organizer_email"]
|
||||
organizer_email_user_info = [BasicExpertInfo(email=meeting_organizer_email)]
|
||||
|
||||
@@ -102,6 +105,14 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
source=DocumentSource.FIREFLIES,
|
||||
semantic_identifier=meeting_title,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": [year_month],
|
||||
"year_month": year_month,
|
||||
"meeting_title": meeting_title,
|
||||
"organizer_email": meeting_organizer_email,
|
||||
}
|
||||
},
|
||||
metadata={
|
||||
k: str(v)
|
||||
for k, v in {
|
||||
|
||||
@@ -240,8 +240,21 @@ def _get_userinfo(user: NamedUser) -> dict[str, str]:
|
||||
def _convert_pr_to_document(
|
||||
pull_request: PullRequest, repo_external_access: ExternalAccess | None
|
||||
) -> Document:
|
||||
repo_name = pull_request.base.repo.full_name if pull_request.base else ""
|
||||
doc_metadata = DocMetadata(repo=repo_name)
|
||||
repo_full_name = pull_request.base.repo.full_name if pull_request.base else ""
|
||||
# Split full_name (e.g., "owner/repo") into owner and repo
|
||||
parts = repo_full_name.split("/", 1)
|
||||
owner_name = parts[0] if parts else ""
|
||||
repo_name = parts[1] if len(parts) > 1 else repo_full_name
|
||||
|
||||
doc_metadata = {
|
||||
"repo": repo_full_name,
|
||||
"hierarchy": {
|
||||
"source_path": [owner_name, repo_name, "pull_requests"],
|
||||
"owner": owner_name,
|
||||
"repo": repo_name,
|
||||
"object_type": "pull_request",
|
||||
},
|
||||
}
|
||||
return Document(
|
||||
id=pull_request.html_url,
|
||||
sections=[
|
||||
@@ -259,7 +272,7 @@ def _convert_pr_to_document(
|
||||
else None
|
||||
),
|
||||
# this metadata is used in perm sync
|
||||
doc_metadata=doc_metadata.model_dump(),
|
||||
doc_metadata=doc_metadata,
|
||||
metadata={
|
||||
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
|
||||
for k, v in {
|
||||
@@ -316,8 +329,21 @@ def _fetch_issue_comments(issue: Issue) -> str:
|
||||
def _convert_issue_to_document(
|
||||
issue: Issue, repo_external_access: ExternalAccess | None
|
||||
) -> Document:
|
||||
repo_name = issue.repository.full_name if issue.repository else ""
|
||||
doc_metadata = DocMetadata(repo=repo_name)
|
||||
repo_full_name = issue.repository.full_name if issue.repository else ""
|
||||
# Split full_name (e.g., "owner/repo") into owner and repo
|
||||
parts = repo_full_name.split("/", 1)
|
||||
owner_name = parts[0] if parts else ""
|
||||
repo_name = parts[1] if len(parts) > 1 else repo_full_name
|
||||
|
||||
doc_metadata = {
|
||||
"repo": repo_full_name,
|
||||
"hierarchy": {
|
||||
"source_path": [owner_name, repo_name, "issues"],
|
||||
"owner": owner_name,
|
||||
"repo": repo_name,
|
||||
"object_type": "issue",
|
||||
},
|
||||
}
|
||||
return Document(
|
||||
id=issue.html_url,
|
||||
sections=[TextSection(link=issue.html_url, text=issue.body or "")],
|
||||
@@ -327,7 +353,7 @@ def _convert_issue_to_document(
|
||||
# updated_at is UTC time but is timezone unaware
|
||||
doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
|
||||
# this metadata is used in perm sync
|
||||
doc_metadata=doc_metadata.model_dump(),
|
||||
doc_metadata=doc_metadata,
|
||||
metadata={
|
||||
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
|
||||
for k, v in {
|
||||
|
||||
@@ -390,7 +390,9 @@ class GmailConnector(
|
||||
"""
|
||||
List all user emails if we are on a Google Workspace domain.
|
||||
If the domain is gmail.com, or if we attempt to call the Admin SDK and
|
||||
get a 404, fall back to using the single user.
|
||||
get a 404 or 403, fall back to using the single user.
|
||||
A 404 indicates a personal Gmail account with no Workspace domain.
|
||||
A 403 indicates insufficient permissions (e.g., OAuth user without admin privileges).
|
||||
"""
|
||||
|
||||
try:
|
||||
@@ -413,6 +415,13 @@ class GmailConnector(
|
||||
"with no Workspace domain. Falling back to single user."
|
||||
)
|
||||
return [self.primary_admin_email]
|
||||
elif e.resp.status == 403:
|
||||
logger.warning(
|
||||
"Received 403 from Admin SDK; this may indicate insufficient permissions "
|
||||
"(e.g., OAuth user without admin privileges or service account without "
|
||||
"domain-wide delegation). Falling back to single user."
|
||||
)
|
||||
return [self.primary_admin_email]
|
||||
raise
|
||||
|
||||
def _fetch_threads_impl(
|
||||
|
||||
@@ -46,6 +46,138 @@ from onyx.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Cache for folder path lookups to avoid redundant API calls
|
||||
# Maps folder_id -> (folder_name, parent_id)
|
||||
_folder_cache: dict[str, tuple[str, str | None]] = {}
|
||||
|
||||
|
||||
def _get_folder_info(
|
||||
service: GoogleDriveService, folder_id: str
|
||||
) -> tuple[str, str | None]:
|
||||
"""Fetch folder name and parent ID, with caching."""
|
||||
if folder_id in _folder_cache:
|
||||
return _folder_cache[folder_id]
|
||||
|
||||
try:
|
||||
folder = (
|
||||
service.files()
|
||||
.get(
|
||||
fileId=folder_id,
|
||||
fields="name, parents",
|
||||
supportsAllDrives=True,
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
folder_name = folder.get("name", "Unknown")
|
||||
parents = folder.get("parents", [])
|
||||
parent_id = parents[0] if parents else None
|
||||
_folder_cache[folder_id] = (folder_name, parent_id)
|
||||
return folder_name, parent_id
|
||||
except HttpError as e:
|
||||
logger.warning(f"Failed to get folder info for {folder_id}: {e}")
|
||||
_folder_cache[folder_id] = ("Unknown", None)
|
||||
return "Unknown", None
|
||||
|
||||
|
||||
def _get_drive_name(service: GoogleDriveService, drive_id: str) -> str:
|
||||
"""Fetch shared drive name."""
|
||||
cache_key = f"drive_{drive_id}"
|
||||
if cache_key in _folder_cache:
|
||||
return _folder_cache[cache_key][0]
|
||||
|
||||
try:
|
||||
drive = service.drives().get(driveId=drive_id).execute()
|
||||
drive_name = drive.get("name", f"Shared Drive {drive_id}")
|
||||
_folder_cache[cache_key] = (drive_name, None)
|
||||
return drive_name
|
||||
except HttpError as e:
|
||||
logger.warning(f"Failed to get drive name for {drive_id}: {e}")
|
||||
_folder_cache[cache_key] = (f"Shared Drive {drive_id}", None)
|
||||
return f"Shared Drive {drive_id}"
|
||||
|
||||
|
||||
def build_folder_path(
|
||||
file: GoogleDriveFileType,
|
||||
service: GoogleDriveService,
|
||||
drive_id: str | None = None,
|
||||
user_email: str | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Build the full folder path for a file by walking up the parent chain.
|
||||
Returns a list of folder names from root to immediate parent.
|
||||
|
||||
Args:
|
||||
file: The Google Drive file object
|
||||
service: Google Drive service instance
|
||||
drive_id: Optional drive ID (will be extracted from file if not provided)
|
||||
user_email: Optional user email to check ownership for "My Drive" vs "Shared with me"
|
||||
"""
|
||||
path_parts: list[str] = []
|
||||
|
||||
# Get drive_id from file if not provided
|
||||
if drive_id is None:
|
||||
drive_id = file.get("driveId")
|
||||
|
||||
# Check if file is owned by the user (for distinguishing "My Drive" vs "Shared with me")
|
||||
is_owned_by_user = False
|
||||
if user_email:
|
||||
owners = file.get("owners", [])
|
||||
is_owned_by_user = any(
|
||||
owner.get("emailAddress", "").lower() == user_email.lower()
|
||||
for owner in owners
|
||||
)
|
||||
|
||||
# Get the file's parent folder ID
|
||||
parents = file.get("parents", [])
|
||||
if not parents:
|
||||
# File is at root level
|
||||
if drive_id:
|
||||
return [_get_drive_name(service, drive_id)]
|
||||
# If not in a shared drive, check if it's owned by the user
|
||||
if is_owned_by_user:
|
||||
return ["My Drive"]
|
||||
else:
|
||||
return ["Shared with me"]
|
||||
|
||||
parent_id: str | None = parents[0]
|
||||
|
||||
# Walk up the folder hierarchy (limit to 50 levels to prevent infinite loops)
|
||||
visited: set[str] = set()
|
||||
for _ in range(50):
|
||||
if not parent_id or parent_id in visited:
|
||||
break
|
||||
visited.add(parent_id)
|
||||
|
||||
folder_name, next_parent = _get_folder_info(service, parent_id)
|
||||
|
||||
# Check if we've reached the root (parent is the drive itself or no parent)
|
||||
if next_parent is None:
|
||||
# This folder's name is either the drive root, My Drive, or Shared with me
|
||||
if drive_id:
|
||||
path_parts.insert(0, _get_drive_name(service, drive_id))
|
||||
else:
|
||||
# Not in a shared drive - determine if it's "My Drive" or "Shared with me"
|
||||
if is_owned_by_user:
|
||||
path_parts.insert(0, "My Drive")
|
||||
else:
|
||||
path_parts.insert(0, "Shared with me")
|
||||
break
|
||||
else:
|
||||
path_parts.insert(0, folder_name)
|
||||
parent_id = next_parent
|
||||
|
||||
# If we didn't find a root, determine the root based on ownership and drive
|
||||
if not path_parts:
|
||||
if drive_id:
|
||||
return [_get_drive_name(service, drive_id)]
|
||||
elif is_owned_by_user:
|
||||
return ["My Drive"]
|
||||
else:
|
||||
return ["Shared with me"]
|
||||
|
||||
return path_parts
|
||||
|
||||
|
||||
# This is not a standard valid unicode char, it is used by the docs advanced API to
|
||||
# represent smart chips (elements like dates and doc links).
|
||||
SMART_CHIP_CHAR = "\ue907"
|
||||
@@ -526,12 +658,33 @@ def _convert_drive_item_to_document(
|
||||
else None
|
||||
)
|
||||
|
||||
# Build doc_metadata with hierarchy information
|
||||
file_name = file.get("name", "")
|
||||
mime_type = file.get("mimeType", "")
|
||||
drive_id = file.get("driveId")
|
||||
|
||||
# Build full folder path by walking up the parent chain
|
||||
# Pass retriever_email to determine if file is in "My Drive" vs "Shared with me"
|
||||
source_path = build_folder_path(
|
||||
file, _get_drive_service(), drive_id, retriever_email
|
||||
)
|
||||
|
||||
doc_metadata = {
|
||||
"hierarchy": {
|
||||
"source_path": source_path,
|
||||
"drive_id": drive_id,
|
||||
"file_name": file_name,
|
||||
"mime_type": mime_type,
|
||||
}
|
||||
}
|
||||
|
||||
# Create the document
|
||||
return Document(
|
||||
id=doc_id,
|
||||
sections=sections,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file.get("name", ""),
|
||||
semantic_identifier=file_name,
|
||||
doc_metadata=doc_metadata,
|
||||
metadata={
|
||||
"owner_names": ", ".join(
|
||||
owner.get("displayName", "") for owner in file.get("owners", [])
|
||||
|
||||
@@ -39,11 +39,11 @@ PERMISSION_FULL_DESCRIPTION = (
|
||||
"permissions(id, emailAddress, type, domain, allowFileDiscovery, permissionDetails)"
|
||||
)
|
||||
FILE_FIELDS = (
|
||||
"nextPageToken, files(mimeType, id, name, "
|
||||
"nextPageToken, files(mimeType, id, name, driveId, parents, "
|
||||
"modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
|
||||
)
|
||||
FILE_FIELDS_WITH_PERMISSIONS = (
|
||||
f"nextPageToken, files(mimeType, id, name, {PERMISSION_FULL_DESCRIPTION}, permissionIds, "
|
||||
f"nextPageToken, files(mimeType, id, name, driveId, parents, {PERMISSION_FULL_DESCRIPTION}, permissionIds, "
|
||||
"modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)"
|
||||
)
|
||||
SLIM_FILE_FIELDS = (
|
||||
|
||||
@@ -490,6 +490,13 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
semantic_identifier=title,
|
||||
doc_updated_at=ticket.updated_at.replace(tzinfo=timezone.utc),
|
||||
metadata=metadata,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": ["Tickets"],
|
||||
"object_type": "ticket",
|
||||
"object_id": ticket.id,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -615,6 +622,13 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
semantic_identifier=title,
|
||||
doc_updated_at=company.updated_at.replace(tzinfo=timezone.utc),
|
||||
metadata=metadata,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": ["Companies"],
|
||||
"object_type": "company",
|
||||
"object_id": company.id,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -738,6 +752,13 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
semantic_identifier=title,
|
||||
doc_updated_at=deal.updated_at.replace(tzinfo=timezone.utc),
|
||||
metadata=metadata,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": ["Deals"],
|
||||
"object_type": "deal",
|
||||
"object_id": deal.id,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -881,6 +902,13 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
semantic_identifier=title,
|
||||
doc_updated_at=contact.updated_at.replace(tzinfo=timezone.utc),
|
||||
metadata=metadata,
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": ["Contacts"],
|
||||
"object_type": "contact",
|
||||
"object_id": contact.id,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -274,6 +274,10 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
# Cast the sections list to the expected type
|
||||
typed_sections = cast(list[TextSection | ImageSection], sections)
|
||||
|
||||
# Extract team name for hierarchy
|
||||
team_name = (node.get("team") or {}).get("name") or "Unknown Team"
|
||||
identifier = node.get("identifier", node["id"])
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
id=node["id"],
|
||||
@@ -282,6 +286,13 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
semantic_identifier=f"[{node['identifier']}] {node['title']}",
|
||||
title=node["title"],
|
||||
doc_updated_at=time_str_to_utc(node["updatedAt"]),
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": [team_name],
|
||||
"team_name": team_name,
|
||||
"identifier": identifier,
|
||||
}
|
||||
},
|
||||
metadata={
|
||||
k: str(v)
|
||||
for k, v in {
|
||||
|
||||
@@ -244,6 +244,9 @@ def convert_metadata_dict_to_list_of_strings(
|
||||
Each string is a key-value pair separated by the INDEX_SEPARATOR. If a key
|
||||
points to a list of values, each value generates a unique pair.
|
||||
|
||||
NOTE: Whatever formatting strategy is used here to generate a key-value
|
||||
string must be replicated when constructing query filters.
|
||||
|
||||
Args:
|
||||
metadata: The metadata dict to convert where values can be either a
|
||||
string or a list of strings.
|
||||
|
||||
@@ -234,6 +234,8 @@ def thread_to_doc(
|
||||
"\n", " "
|
||||
)
|
||||
|
||||
channel_name = channel["name"]
|
||||
|
||||
return Document(
|
||||
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
|
||||
sections=[
|
||||
@@ -247,7 +249,14 @@ def thread_to_doc(
|
||||
semantic_identifier=doc_sem_id,
|
||||
doc_updated_at=get_latest_message_time(thread),
|
||||
primary_owners=valid_experts,
|
||||
metadata={"Channel": channel["name"]},
|
||||
doc_metadata={
|
||||
"hierarchy": {
|
||||
"source_path": [channel_name],
|
||||
"channel_name": channel_name,
|
||||
"channel_id": channel_id,
|
||||
}
|
||||
},
|
||||
metadata={"Channel": channel_name},
|
||||
external_access=channel_access,
|
||||
)
|
||||
|
||||
|
||||
@@ -116,6 +116,8 @@ class UserFileFilters(BaseModel):
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters, UserFileFilters):
|
||||
# NOTE: These strings must be formatted in the same way as the output of
|
||||
# DocumentAccess::to_acl.
|
||||
access_control_list: list[str] | None
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
@@ -116,7 +117,14 @@ def get_connector_credential_pairs_for_user(
|
||||
eager_load_user: bool = False,
|
||||
order_by_desc: bool = False,
|
||||
source: DocumentSource | None = None,
|
||||
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
"""Get connector credential pairs for a user.
|
||||
|
||||
Args:
|
||||
processing_mode: Filter by processing mode. Defaults to REGULAR to hide
|
||||
FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
|
||||
"""
|
||||
if eager_load_user:
|
||||
assert (
|
||||
eager_load_credential
|
||||
@@ -142,6 +150,9 @@ def get_connector_credential_pairs_for_user(
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
if processing_mode is not None:
|
||||
stmt = stmt.where(ConnectorCredentialPair.processing_mode == processing_mode)
|
||||
|
||||
if order_by_desc:
|
||||
stmt = stmt.order_by(desc(ConnectorCredentialPair.id))
|
||||
|
||||
@@ -160,6 +171,7 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
eager_load_user: bool = False,
|
||||
order_by_desc: bool = False,
|
||||
source: DocumentSource | None = None,
|
||||
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_connector_credential_pairs_for_user(
|
||||
@@ -172,6 +184,7 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
eager_load_user=eager_load_user,
|
||||
order_by_desc=order_by_desc,
|
||||
source=source,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
|
||||
|
||||
@@ -501,6 +514,7 @@ def add_credential_to_connector(
|
||||
initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.SCHEDULED,
|
||||
last_successful_index_time: datetime | None = None,
|
||||
seeding_flow: bool = False,
|
||||
processing_mode: ProcessingMode = ProcessingMode.REGULAR,
|
||||
) -> StatusResponse:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
|
||||
@@ -566,6 +580,7 @@ def add_credential_to_connector(
|
||||
access_type=access_type,
|
||||
auto_sync_options=auto_sync_options,
|
||||
last_successful_index_time=last_successful_index_time,
|
||||
processing_mode=processing_mode,
|
||||
)
|
||||
db_session.add(association)
|
||||
db_session.flush() # make sure the association has an id
|
||||
|
||||
@@ -56,6 +56,13 @@ class IndexingMode(str, PyEnum):
|
||||
REINDEX = "reindex"
|
||||
|
||||
|
||||
class ProcessingMode(str, PyEnum):
|
||||
"""Determines how documents are processed after fetching."""
|
||||
|
||||
REGULAR = "REGULAR" # Full pipeline: chunk → embed → Vespa
|
||||
FILE_SYSTEM = "FILE_SYSTEM" # Write to file system only
|
||||
|
||||
|
||||
class SyncType(str, PyEnum):
|
||||
DOCUMENT_SET = "document_set"
|
||||
USER_GROUP = "user_group"
|
||||
@@ -194,3 +201,39 @@ class SwitchoverType(str, PyEnum):
|
||||
REINDEX = "reindex"
|
||||
ACTIVE_ONLY = "active_only"
|
||||
INSTANT = "instant"
|
||||
|
||||
|
||||
# Onyx Build Mode Enums
|
||||
class BuildSessionStatus(str, PyEnum):
|
||||
ACTIVE = "active"
|
||||
IDLE = "idle"
|
||||
|
||||
|
||||
class SandboxStatus(str, PyEnum):
|
||||
PROVISIONING = "provisioning"
|
||||
RUNNING = "running"
|
||||
IDLE = "idle"
|
||||
SLEEPING = "sleeping" # Pod terminated, snapshots saved to S3
|
||||
TERMINATED = "terminated"
|
||||
FAILED = "failed"
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Check if sandbox is in an active state (running or idle)."""
|
||||
return self in (SandboxStatus.RUNNING, SandboxStatus.IDLE)
|
||||
|
||||
def is_terminal(self) -> bool:
|
||||
"""Check if sandbox is in a terminal state."""
|
||||
return self in (SandboxStatus.TERMINATED, SandboxStatus.FAILED)
|
||||
|
||||
def is_sleeping(self) -> bool:
|
||||
"""Check if sandbox is sleeping (pod terminated but can be restored)."""
|
||||
return self == SandboxStatus.SLEEPING
|
||||
|
||||
|
||||
class ArtifactType(str, PyEnum):
|
||||
WEB_APP = "web_app"
|
||||
PPTX = "pptx"
|
||||
DOCX = "docx"
|
||||
IMAGE = "image"
|
||||
MARKDOWN = "markdown"
|
||||
EXCEL = "excel"
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing_extensions import TypedDict # noreorder
|
||||
from uuid import UUID
|
||||
from pydantic import ValidationError
|
||||
|
||||
from sqlalchemy.dialects.postgresql import JSONB as PGJSONB
|
||||
from sqlalchemy.dialects.postgresql import UUID as PGUUID
|
||||
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
|
||||
@@ -55,8 +56,12 @@ from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.enums import (
|
||||
AccessType,
|
||||
ArtifactType,
|
||||
BuildSessionStatus,
|
||||
EmbeddingPrecision,
|
||||
IndexingMode,
|
||||
ProcessingMode,
|
||||
SandboxStatus,
|
||||
SyncType,
|
||||
SyncStatus,
|
||||
MCPAuthenticationType,
|
||||
@@ -609,6 +614,16 @@ class ConnectorCredentialPair(Base):
|
||||
Enum(IndexingMode, native_enum=False), nullable=True
|
||||
)
|
||||
|
||||
# Determines how documents are processed after fetching:
|
||||
# REGULAR: Full pipeline (chunk → embed → Vespa)
|
||||
# FILE_SYSTEM: Write to file system only (for CLI agent sandbox)
|
||||
processing_mode: Mapped[ProcessingMode] = mapped_column(
|
||||
Enum(ProcessingMode, native_enum=False),
|
||||
nullable=False,
|
||||
default=ProcessingMode.REGULAR,
|
||||
server_default="REGULAR",
|
||||
)
|
||||
|
||||
connector: Mapped["Connector"] = relationship(
|
||||
"Connector", back_populates="credentials"
|
||||
)
|
||||
@@ -4142,3 +4157,202 @@ class TenantUsage(Base):
|
||||
# Ensure only one row per window start (tenant_id is in the schema name)
|
||||
UniqueConstraint("window_start", name="uq_tenant_usage_window"),
|
||||
)
|
||||
|
||||
|
||||
"""Tables related to Build Mode (CLI Agent Platform)"""
|
||||
|
||||
|
||||
class BuildSession(Base):
|
||||
"""Stores metadata about CLI agent build sessions."""
|
||||
|
||||
__tablename__ = "build_session"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
PGUUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
status: Mapped[BuildSessionStatus] = mapped_column(
|
||||
Enum(BuildSessionStatus, native_enum=False, name="buildsessionstatus"),
|
||||
nullable=False,
|
||||
default=BuildSessionStatus.ACTIVE,
|
||||
)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
last_activity_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
nextjs_port: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# Relationships
|
||||
user: Mapped[User | None] = relationship("User", foreign_keys=[user_id])
|
||||
artifacts: Mapped[list["Artifact"]] = relationship(
|
||||
"Artifact", back_populates="session", cascade="all, delete-orphan"
|
||||
)
|
||||
messages: Mapped[list["BuildMessage"]] = relationship(
|
||||
"BuildMessage", back_populates="session", cascade="all, delete-orphan"
|
||||
)
|
||||
snapshots: Mapped[list["Snapshot"]] = relationship(
|
||||
"Snapshot", back_populates="session", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_build_session_user_created", "user_id", desc("created_at")),
|
||||
Index("ix_build_session_status", "status"),
|
||||
)
|
||||
|
||||
|
||||
class Sandbox(Base):
|
||||
"""Stores sandbox container metadata for users (one sandbox per user)."""
|
||||
|
||||
__tablename__ = "sandbox"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("user.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
unique=True,
|
||||
)
|
||||
container_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
status: Mapped[SandboxStatus] = mapped_column(
|
||||
Enum(SandboxStatus, native_enum=False, name="sandboxstatus"),
|
||||
nullable=False,
|
||||
default=SandboxStatus.PROVISIONING,
|
||||
)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
last_heartbeat: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user: Mapped[User] = relationship("User")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_sandbox_status", "status"),
|
||||
Index("ix_sandbox_container_id", "container_id"),
|
||||
)
|
||||
|
||||
|
||||
class Artifact(Base):
|
||||
"""Stores metadata about artifacts generated by CLI agents."""
|
||||
|
||||
__tablename__ = "artifact"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
session_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("build_session.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
type: Mapped[ArtifactType] = mapped_column(
|
||||
Enum(ArtifactType, native_enum=False, name="artifacttype"), nullable=False
|
||||
)
|
||||
# path of artifact in sandbox relative to outputs/
|
||||
path: Mapped[str] = mapped_column(String, nullable=False)
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
session: Mapped[BuildSession] = relationship(
|
||||
"BuildSession", back_populates="artifacts"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_artifact_session_created", "session_id", desc("created_at")),
|
||||
Index("ix_artifact_type", "type"),
|
||||
)
|
||||
|
||||
|
||||
class Snapshot(Base):
|
||||
"""Stores metadata about session output snapshots."""
|
||||
|
||||
__tablename__ = "snapshot"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
session_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("build_session.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
storage_path: Mapped[str] = mapped_column(String, nullable=False)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
|
||||
# Relationships
|
||||
session: Mapped[BuildSession] = relationship(
|
||||
"BuildSession", back_populates="snapshots"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_snapshot_session_created", "session_id", desc("created_at")),
|
||||
)
|
||||
|
||||
|
||||
class BuildMessage(Base):
|
||||
"""Stores messages exchanged in build sessions.
|
||||
|
||||
All message data is stored in message_metadata as JSON (the raw ACP packet).
|
||||
The turn_index groups all assistant responses under the user prompt they respond to.
|
||||
|
||||
Packet types stored in message_metadata:
|
||||
- user_message: {type: "user_message", content: {...}}
|
||||
- agent_message: {type: "agent_message", content: {...}} (accumulated from chunks)
|
||||
- agent_thought: {type: "agent_thought", content: {...}} (accumulated from chunks)
|
||||
- tool_call_progress: {type: "tool_call_progress", status: "completed", ...} (only completed)
|
||||
- agent_plan_update: {type: "agent_plan_update", entries: [...]} (upserted, latest only)
|
||||
"""
|
||||
|
||||
__tablename__ = "build_message"
|
||||
|
||||
id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), primary_key=True, default=uuid4
|
||||
)
|
||||
session_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True),
|
||||
ForeignKey("build_session.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
)
|
||||
turn_index: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
type: Mapped[MessageType] = mapped_column(
|
||||
Enum(MessageType, native_enum=False, name="messagetype"), nullable=False
|
||||
)
|
||||
message_metadata: Mapped[dict[str, Any]] = mapped_column(PGJSONB, nullable=False)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
|
||||
# Relationships
|
||||
session: Mapped[BuildSession] = relationship(
|
||||
"BuildSession", back_populates="messages"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_build_message_session_turn", "session_id", "turn_index", "created_at"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -28,8 +28,8 @@ of "minimum value clipping".
|
||||
## On time decay and boosting
|
||||
Embedding models do not have a uniform distribution from 0 to 1. The values typically cluster strongly around 0.6 to 0.8 but also
|
||||
varies between models and even the query. It is not a safe assumption to pre-normalize the scores so we also cannot apply any
|
||||
additive or multiplicative boost to it. Ie. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
|
||||
it doesn't bring a result from the top of the range to 50 percentile, it brings its under the 0.6 and is now the worst match.
|
||||
additive or multiplicative boost to it. i.e. if results of a doc cluster around 0.6 to 0.8 and I give a 50% penalty to the score,
|
||||
it doesn't bring a result from the top of the range to 50th percentile, it brings it under the 0.6 and is now the worst match.
|
||||
Same logic applies to additive boosting.
|
||||
|
||||
So these boosts can only be applied after normalization. Unfortunately with Opensearch, the normalization processor runs last
|
||||
@@ -40,7 +40,7 @@ and vector would make the docs which only came because of time filter very low s
|
||||
scored documents from the union of all the `Search` phase documents to show up higher and potentially not get dropped before
|
||||
being fetched and returned to the user. But there are other issues of including these:
|
||||
- There is no way to sort by this field, only a filter, so there's no way to guarantee the best docs even irrespective of the
|
||||
contents. If there are lots of updates, this may miss
|
||||
contents. If there are lots of updates, this may miss.
|
||||
- There is not a good way to normalize this field, the best is to clip it on the bottom.
|
||||
- This would require using min-max norm but z-score norm is better for the other functions due to things like it being less
|
||||
sensitive to outliers, better handles distribution drifts (min-max assumes stable meaningful ranges), better for comparing
|
||||
|
||||
@@ -559,6 +559,36 @@ class OpenSearchClient:
|
||||
"""
|
||||
self._client.indices.refresh(index=self._index_name)
|
||||
|
||||
def set_cluster_auto_create_index_setting(self, enabled: bool) -> bool:
|
||||
"""Sets the cluster auto create index setting.
|
||||
|
||||
By default, when you index a document to a non-existent index,
|
||||
OpenSearch will automatically create the index. This behavior is
|
||||
undesirable so this function exposes the ability to disable it.
|
||||
|
||||
See
|
||||
https://docs.opensearch.org/latest/install-and-configure/configuring-opensearch/index/#updating-cluster-settings-using-the-api
|
||||
|
||||
Args:
|
||||
enabled: Whether to enable the auto create index setting.
|
||||
|
||||
Returns:
|
||||
True if the setting was updated successfully, False otherwise. Does
|
||||
not raise.
|
||||
"""
|
||||
try:
|
||||
body = {"persistent": {"action.auto_create_index": enabled}}
|
||||
response = self._client.cluster.put_settings(body=body)
|
||||
if response.get("acknowledged", False):
|
||||
logger.info(f"Successfully set action.auto_create_index to {enabled}.")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Failed to update setting: {response}.")
|
||||
return False
|
||||
except Exception:
|
||||
logger.exception("Error setting auto_create_index.")
|
||||
return False
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""Pings the OpenSearch cluster.
|
||||
|
||||
|
||||
@@ -3,7 +3,10 @@ from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
@@ -68,6 +71,18 @@ from shared_configs.model_server_models import Embedding
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def generate_opensearch_filtered_access_control_list(
|
||||
access: DocumentAccess,
|
||||
) -> list[str]:
|
||||
"""Generates an access control list with PUBLIC_DOC_PAT removed.
|
||||
|
||||
In the OpenSearch schema this is represented by PUBLIC_FIELD_NAME.
|
||||
"""
|
||||
access_control_list = access.to_acl()
|
||||
access_control_list.discard(PUBLIC_DOC_PAT)
|
||||
return list(access_control_list)
|
||||
|
||||
|
||||
def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
chunk: DocumentChunk,
|
||||
score: float | None,
|
||||
@@ -152,10 +167,9 @@ def _convert_onyx_chunk_to_opensearch_document(
|
||||
metadata_suffix=chunk.metadata_suffix_keyword,
|
||||
last_updated=chunk.source_document.doc_updated_at,
|
||||
public=chunk.access.is_public,
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
access_control_list=list(chunk.access.to_acl()),
|
||||
access_control_list=generate_opensearch_filtered_access_control_list(
|
||||
chunk.access
|
||||
),
|
||||
global_boost=chunk.boost,
|
||||
semantic_identifier=chunk.source_document.semantic_identifier,
|
||||
image_file_id=chunk.image_file_id,
|
||||
@@ -440,15 +454,28 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
search pipelines.
|
||||
"""
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary."
|
||||
f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if necessary, "
|
||||
f"with embedding dimension {embedding_dim}."
|
||||
)
|
||||
expected_mappings = DocumentSchema.get_document_schema(
|
||||
embedding_dim, self._tenant_state.multitenant
|
||||
)
|
||||
if not self._os_client.index_exists():
|
||||
if not self._os_client.set_cluster_auto_create_index_setting(enabled=False):
|
||||
logger.error(
|
||||
f"Failed to disable the auto create index setting for index {self._index_name}. "
|
||||
"This may cause unexpected index creation when indexing documents into an index that does not exist. "
|
||||
"Not taking any further action..."
|
||||
)
|
||||
if USING_AWS_MANAGED_OPENSEARCH:
|
||||
index_settings = (
|
||||
DocumentSchema.get_index_settings_for_aws_managed_opensearch()
|
||||
)
|
||||
else:
|
||||
index_settings = DocumentSchema.get_index_settings()
|
||||
self._os_client.create_index(
|
||||
mappings=expected_mappings,
|
||||
settings=DocumentSchema.get_index_settings(),
|
||||
settings=index_settings,
|
||||
)
|
||||
if not self._os_client.validate_index(
|
||||
expected_mappings=expected_mappings,
|
||||
@@ -578,8 +605,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# here so we don't have to think about passing in the
|
||||
# appropriate types into this dict.
|
||||
if update_request.access is not None:
|
||||
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = list(
|
||||
update_request.access.to_acl()
|
||||
properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = (
|
||||
generate_opensearch_filtered_access_control_list(
|
||||
update_request.access
|
||||
)
|
||||
)
|
||||
if update_request.document_sets is not None:
|
||||
properties_to_update[DOCUMENT_SETS_FIELD_NAME] = list(
|
||||
@@ -625,13 +654,11 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[DocumentSectionRequest],
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
filters: IndexFilters,
|
||||
# TODO(andrei): Remove this from the new interface at some point; we
|
||||
# should not be exposing this.
|
||||
batch_retrieval: bool = False,
|
||||
# TODO(andrei): Add a param for whether to retrieve hidden docs.
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
TODO(andrei): Consider implementing this method to retrieve on document
|
||||
@@ -646,6 +673,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
query_body = DocumentQuery.get_from_document_id_query(
|
||||
document_id=chunk_request.document_id,
|
||||
tenant_state=self._tenant_state,
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
max_chunk_size=chunk_request.max_chunk_size,
|
||||
min_chunk_index=chunk_request.min_chunk_ind,
|
||||
max_chunk_index=chunk_request.max_chunk_ind,
|
||||
@@ -672,9 +701,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
query_embedding: Embedding,
|
||||
final_keywords: list[str] | None,
|
||||
query_type: QueryType,
|
||||
# TODO(andrei): When going over ACL look very carefully at
|
||||
# access_control_list. Notice DocumentAccess::to_acl prepends every
|
||||
# string with a type.
|
||||
filters: IndexFilters,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
@@ -688,6 +714,8 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
num_candidates=1000, # TODO(andrei): Magic number.
|
||||
num_hits=num_to_retrieve,
|
||||
tenant_state=self._tenant_state,
|
||||
index_filters=filters,
|
||||
include_hidden=False,
|
||||
)
|
||||
search_hits: list[SearchHit[DocumentChunk]] = self._os_client.search(
|
||||
body=query_body,
|
||||
|
||||
@@ -172,24 +172,23 @@ class DocumentChunk(BaseModel):
|
||||
return serialized_exclude_none
|
||||
|
||||
@field_serializer("last_updated", mode="wrap")
|
||||
def serialize_datetime_fields_to_epoch_millis(
|
||||
def serialize_datetime_fields_to_epoch_seconds(
|
||||
self, value: datetime | None, handler: SerializerFunctionWrapHandler
|
||||
) -> int | None:
|
||||
"""
|
||||
Serializes datetime fields to milliseconds since the Unix epoch.
|
||||
Serializes datetime fields to seconds since the Unix epoch.
|
||||
|
||||
If there is no datetime, returns None.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
value = set_or_convert_timezone_to_utc(value)
|
||||
# timestamp returns a float in seconds so convert to millis.
|
||||
return int(value.timestamp() * 1000)
|
||||
return int(value.timestamp())
|
||||
|
||||
@field_validator("last_updated", mode="before")
|
||||
@classmethod
|
||||
def parse_epoch_millis_to_datetime(cls, value: Any) -> datetime | None:
|
||||
"""Parses milliseconds since the Unix epoch to a datetime object.
|
||||
def parse_epoch_seconds_to_datetime(cls, value: Any) -> datetime | None:
|
||||
"""Parses seconds since the Unix epoch to a datetime object.
|
||||
|
||||
If the input is None, returns None.
|
||||
|
||||
@@ -204,7 +203,7 @@ class DocumentChunk(BaseModel):
|
||||
raise ValueError(
|
||||
f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead."
|
||||
)
|
||||
return datetime.fromtimestamp(value / 1000, tz=timezone.utc)
|
||||
return datetime.fromtimestamp(value, tz=timezone.utc)
|
||||
|
||||
@field_serializer("tenant_id", mode="wrap")
|
||||
def serialize_tenant_state(
|
||||
@@ -354,11 +353,9 @@ class DocumentSchema:
|
||||
},
|
||||
SOURCE_TYPE_FIELD_NAME: {"type": "keyword"},
|
||||
METADATA_LIST_FIELD_NAME: {"type": "keyword"},
|
||||
# TODO(andrei): Check if Vespa stores seconds, we may wanna do
|
||||
# seconds here not millis.
|
||||
LAST_UPDATED_FIELD_NAME: {
|
||||
"type": "date",
|
||||
"format": "epoch_millis",
|
||||
"format": "epoch_second",
|
||||
# For some reason date defaults to False, even though it
|
||||
# would make sense to sort by date.
|
||||
"doc_values": True,
|
||||
@@ -366,14 +363,21 @@ class DocumentSchema:
|
||||
# Access control fields.
|
||||
# Whether the doc is public. Could have fallen under access
|
||||
# control list but is such a broad and critical filter that it
|
||||
# is its own field.
|
||||
# is its own field. If true, ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
# should have no effect on queries.
|
||||
PUBLIC_FIELD_NAME: {"type": "boolean"},
|
||||
# Access control list for the doc, excluding public access,
|
||||
# which is covered above.
|
||||
# If a user's access set contains at least one entry from this
|
||||
# set, the user should be able to retrieve this document. This
|
||||
# only applies if public is set to false; public non-hidden
|
||||
# documents are always visible to anyone in a given tenancy
|
||||
# regardless of this field.
|
||||
ACCESS_CONTROL_LIST_FIELD_NAME: {"type": "keyword"},
|
||||
# Whether the doc is hidden from search results. Should clobber
|
||||
# all other search filters; up to search implementations to
|
||||
# guarantee this.
|
||||
# Whether the doc is hidden from search results.
|
||||
# Should clobber all other access search filters, namely
|
||||
# PUBLIC_FIELD_NAME and ACCESS_CONTROL_LIST_FIELD_NAME; up to
|
||||
# search implementations to guarantee this.
|
||||
HIDDEN_FIELD_NAME: {"type": "boolean"},
|
||||
GLOBAL_BOOST_FIELD_NAME: {"type": "integer"},
|
||||
# This field is only used for displaying a useful name for the
|
||||
@@ -447,7 +451,6 @@ class DocumentSchema:
|
||||
DOCUMENT_ID_FIELD_NAME: {"type": "keyword"},
|
||||
CHUNK_INDEX_FIELD_NAME: {"type": "integer"},
|
||||
# The maximum number of tokens this chunk's content can hold.
|
||||
# TODO(andrei): Can we generalize this to embedding type?
|
||||
MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"},
|
||||
},
|
||||
}
|
||||
@@ -473,16 +476,22 @@ class DocumentSchema:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_bulk_index_settings() -> dict[str, Any]:
|
||||
def get_index_settings_for_aws_managed_opensearch() -> dict[str, Any]:
|
||||
"""
|
||||
Optimized settings for bulk indexing: disable refresh and replicas.
|
||||
Settings for AWS-managed OpenSearch.
|
||||
|
||||
Our AWS-managed OpenSearch cluster has 3 data nodes in 3 availability
|
||||
zones.
|
||||
- We use 3 shards to distribute load across all data nodes.
|
||||
- We use 2 replicas to ensure each shard has a copy in each
|
||||
availability zone. This is a hard requirement from AWS. The number
|
||||
of data copies, including the primary (not a replica) copy, must be
|
||||
divisible by the number of AZs.
|
||||
"""
|
||||
return {
|
||||
"index": {
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0, # No replication during bulk load.
|
||||
# Disables auto-refresh, improves performance in pure indexing (no searching) scenarios.
|
||||
"refresh_interval": "-1",
|
||||
"number_of_shards": 3,
|
||||
"number_of_replicas": 2,
|
||||
# Required for vector search.
|
||||
"knn": True,
|
||||
"knn.algo_param.ef_search": EF_SEARCH,
|
||||
|
||||
@@ -1,21 +1,36 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_KEYWORD_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_PHRASE_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_CONTENT_VECTOR_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_TITLE_KEYWORD_WEIGHT
|
||||
from onyx.document_index.opensearch.constants import SEARCH_TITLE_VECTOR_WEIGHT
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CHUNK_INDEX_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import CONTENT_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc
|
||||
from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TENANT_ID_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME
|
||||
|
||||
# Normalization pipelines combine document scores from multiple query clauses.
|
||||
# The number and ordering of weights should match the query clauses. The values
|
||||
@@ -91,6 +106,11 @@ assert (
|
||||
# given search. This value is configurable in the index settings.
|
||||
DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000
|
||||
|
||||
# For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume
|
||||
# that the document was last updated this many days ago for the purpose of time
|
||||
# cutoff filtering during retrieval.
|
||||
ASSUMED_DOCUMENT_AGE_DAYS = 90
|
||||
|
||||
|
||||
class DocumentQuery:
|
||||
"""
|
||||
@@ -103,6 +123,8 @@ class DocumentQuery:
|
||||
def get_from_document_id_query(
|
||||
document_id: str,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
max_chunk_size: int,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
@@ -120,6 +142,8 @@ class DocumentQuery:
|
||||
document_id: Onyx document ID. Notably not an OpenSearch document
|
||||
ID, which points to what Onyx would refer to as a chunk.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the document retrieval query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
max_chunk_size: Document chunks are categorized by the maximum
|
||||
number of tokens they can hold. This parameter specifies the
|
||||
maximum size category of document chunks to retrieve.
|
||||
@@ -136,28 +160,21 @@ class DocumentQuery:
|
||||
Returns:
|
||||
A dictionary representing the final ID search query.
|
||||
"""
|
||||
filter_clauses: list[dict[str, Any]] = [
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.multitenant:
|
||||
# TODO(andrei): Fix tenant stuff.
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
if min_chunk_index is not None or max_chunk_index is not None:
|
||||
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
|
||||
if min_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
|
||||
if max_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
|
||||
filter_clauses.append(range_clause)
|
||||
|
||||
filter_clauses.append(
|
||||
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
|
||||
filter_clauses = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=min_chunk_index,
|
||||
max_chunk_index=max_chunk_index,
|
||||
max_chunk_size=max_chunk_size,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
final_get_ids_query: dict[str, Any] = {
|
||||
"query": {"bool": {"filter": filter_clauses}},
|
||||
# We include this to make sure OpenSearch does not revert to
|
||||
@@ -195,15 +212,22 @@ class DocumentQuery:
|
||||
Returns:
|
||||
A dictionary representing the final delete query.
|
||||
"""
|
||||
filter_clauses: list[dict[str, Any]] = [
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
]
|
||||
|
||||
if tenant_state.multitenant:
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
|
||||
filter_clauses = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
# Delete hidden docs too.
|
||||
include_hidden=True,
|
||||
access_control_list=None,
|
||||
source_types=[],
|
||||
tags=[],
|
||||
document_sets=[],
|
||||
user_file_ids=[],
|
||||
project_id=None,
|
||||
time_cutoff=None,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
max_chunk_size=None,
|
||||
document_id=document_id,
|
||||
)
|
||||
final_delete_query: dict[str, Any] = {
|
||||
"query": {"bool": {"filter": filter_clauses}},
|
||||
}
|
||||
@@ -217,19 +241,25 @@ class DocumentQuery:
|
||||
num_candidates: int,
|
||||
num_hits: int,
|
||||
tenant_state: TenantState,
|
||||
index_filters: IndexFilters,
|
||||
include_hidden: bool,
|
||||
) -> dict[str, Any]:
|
||||
"""Returns a final hybrid search query.
|
||||
|
||||
This query can be directly supplied to the OpenSearch client.
|
||||
NOTE: This query can be directly supplied to the OpenSearch client, but
|
||||
it MUST be supplied in addition to a search pipeline. The results from
|
||||
hybrid search are not meaningful without that step.
|
||||
|
||||
Args:
|
||||
query_text: The text to query for.
|
||||
query_vector: The vector embedding of the text to query for.
|
||||
num_candidates: The number of candidates to consider for vector
|
||||
num_candidates: The number of neighbors to consider for vector
|
||||
similarity search. Generally more candidates improves search
|
||||
quality at the cost of performance.
|
||||
num_hits: The final number of hits to return.
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
index_filters: Filters for the hybrid search query.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the final hybrid search query.
|
||||
@@ -243,31 +273,47 @@ class DocumentQuery:
|
||||
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
|
||||
query_text, query_vector, num_candidates
|
||||
)
|
||||
hybrid_search_filters = DocumentQuery._get_hybrid_search_filters(tenant_state)
|
||||
hybrid_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
include_hidden=include_hidden,
|
||||
# TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to
|
||||
# now. This should not cause any issues but it can introduce
|
||||
# redundant filters in queries that may affect performance.
|
||||
access_control_list=index_filters.access_control_list,
|
||||
source_types=index_filters.source_type or [],
|
||||
tags=index_filters.tags or [],
|
||||
document_sets=index_filters.document_set or [],
|
||||
user_file_ids=index_filters.user_file_ids or [],
|
||||
project_id=index_filters.project_id,
|
||||
time_cutoff=index_filters.time_cutoff,
|
||||
min_chunk_index=None,
|
||||
max_chunk_index=None,
|
||||
)
|
||||
match_highlights_configuration = (
|
||||
DocumentQuery._get_match_highlights_configuration()
|
||||
)
|
||||
|
||||
# See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
hybrid_search_query: dict[str, Any] = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
}
|
||||
}
|
||||
],
|
||||
# TODO(andrei): When revisiting our hybrid query logic see if
|
||||
# this needs to be nested one level down.
|
||||
"filter": hybrid_search_filters,
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
# Applied to all the sub-queries. Source:
|
||||
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
# Does AND for each filter in the list.
|
||||
"filter": {"bool": {"filter": hybrid_search_filters}},
|
||||
}
|
||||
}
|
||||
|
||||
# NOTE: By default, hybrid search retrieves "size"-many results from
|
||||
# each OpenSearch shard before aggregation. Source:
|
||||
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
|
||||
|
||||
final_hybrid_search_body: dict[str, Any] = {
|
||||
"query": hybrid_search_query,
|
||||
"size": num_hits,
|
||||
"highlight": match_highlights_configuration,
|
||||
}
|
||||
|
||||
return final_hybrid_search_body
|
||||
|
||||
@staticmethod
|
||||
@@ -294,7 +340,8 @@ class DocumentQuery:
|
||||
pipeline.
|
||||
|
||||
NOTE: For OpenSearch, 5 is the maximum number of query clauses allowed
|
||||
in a single hybrid query.
|
||||
in a single hybrid query. Source:
|
||||
https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
|
||||
Args:
|
||||
query_text: The text of the query to search for.
|
||||
@@ -305,6 +352,7 @@ class DocumentQuery:
|
||||
hybrid_search_queries: list[dict[str, Any]] = [
|
||||
{
|
||||
"knn": {
|
||||
# Match on semantic similarity of the title.
|
||||
TITLE_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": num_candidates,
|
||||
@@ -313,6 +361,7 @@ class DocumentQuery:
|
||||
},
|
||||
{
|
||||
"knn": {
|
||||
# Match on semantic similarity of the content.
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": num_candidates,
|
||||
@@ -322,36 +371,273 @@ class DocumentQuery:
|
||||
{
|
||||
"multi_match": {
|
||||
"query": query_text,
|
||||
# TODO(andrei): Ask Yuhong do we want this?
|
||||
# Either fuzzy match on the analyzed title (boosted 2x), or
|
||||
# exact match on exact title keywords (no OpenSearch
|
||||
# analysis done on the title). See
|
||||
# https://docs.opensearch.org/latest/mappings/supported-field-types/keyword/
|
||||
"fields": [f"{TITLE_FIELD_NAME}^2", f"{TITLE_FIELD_NAME}.keyword"],
|
||||
# Returns the score of the best match of the fields above.
|
||||
# See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/multi-match/
|
||||
"type": "best_fields",
|
||||
}
|
||||
},
|
||||
# Fuzzy match on the OpenSearch-analyzed content. See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/match/
|
||||
{"match": {CONTENT_FIELD_NAME: {"query": query_text}}},
|
||||
# Exact match on the OpenSearch-analyzed content. See
|
||||
# https://docs.opensearch.org/latest/query-dsl/full-text/match-phrase/
|
||||
{"match_phrase": {CONTENT_FIELD_NAME: {"query": query_text, "boost": 1.5}}},
|
||||
]
|
||||
|
||||
return hybrid_search_queries
|
||||
|
||||
@staticmethod
|
||||
def _get_hybrid_search_filters(tenant_state: TenantState) -> list[dict[str, Any]]:
|
||||
"""Returns filters for hybrid search.
|
||||
def _get_search_filters(
|
||||
tenant_state: TenantState,
|
||||
include_hidden: bool,
|
||||
access_control_list: list[str] | None,
|
||||
source_types: list[DocumentSource],
|
||||
tags: list[Tag],
|
||||
document_sets: list[str],
|
||||
user_file_ids: list[UUID],
|
||||
project_id: int | None,
|
||||
time_cutoff: datetime | None,
|
||||
min_chunk_index: int | None,
|
||||
max_chunk_index: int | None,
|
||||
max_chunk_size: int | None = None,
|
||||
document_id: str | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Returns filters to be passed into the "filter" key of a search query.
|
||||
|
||||
For now only fetches public and not hidden documents.
|
||||
The "filter" key applies a logical AND operator to its elements, so
|
||||
every subfilter must evaluate to true in order for the document to be
|
||||
retrieved. This function returns a list of such subfilters.
|
||||
See https://docs.opensearch.org/latest/query-dsl/compound/bool/
|
||||
|
||||
The return of this function is not sufficient to be directly supplied to
|
||||
the OpenSearch client. See get_hybrid_search_query.
|
||||
Args:
|
||||
tenant_state: Tenant state containing the tenant ID.
|
||||
include_hidden: Whether to include hidden documents.
|
||||
access_control_list: Access control list for the documents to
|
||||
retrieve. If None, there is no restriction on the documents that
|
||||
can be retrieved. If not None, only public documents can be
|
||||
retrieved, or non-public documents where at least one acl
|
||||
provided here is present in the document's acl list.
|
||||
source_types: If supplied, only documents of one of these source
|
||||
types will be retrieved.
|
||||
tags: If supplied, only documents with an entry in their metadata
|
||||
list corresponding to a tag will be retrieved.
|
||||
document_sets: If supplied, only documents with at least one
|
||||
document set ID from this list will be retrieved.
|
||||
user_file_ids: If supplied, only document IDs in this list will be
|
||||
retrieved.
|
||||
project_id: If not None, only documents with this project ID in user
|
||||
projects will be retrieved.
|
||||
time_cutoff: Time cutoff for the documents to retrieve. If not None,
|
||||
Documents which were last updated before this date will not be
|
||||
returned. For documents which do not have a value for their last
|
||||
updated time, we assume some default age of
|
||||
ASSUMED_DOCUMENT_AGE_DAYS for when the document was last
|
||||
updated.
|
||||
min_chunk_index: The minimum chunk index to retrieve, inclusive. If
|
||||
None, no minimum chunk index will be applied.
|
||||
max_chunk_index: The maximum chunk index to retrieve, inclusive. If
|
||||
None, no maximum chunk index will be applied.
|
||||
max_chunk_size: The type of chunk to retrieve, specified by the
|
||||
maximum number of tokens it can hold. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
NOTE: See DocumentChunk.max_chunk_size.
|
||||
document_id: The document ID to retrieve. If None, no filter will be
|
||||
applied for this. Defaults to None.
|
||||
WARNING: This filters on the same property as user_file_ids.
|
||||
Although it would never make sense to supply both, note that if
|
||||
user_file_ids is supplied and does not contain document_id, no
|
||||
matches will be retrieved.
|
||||
|
||||
TODO(andrei): Add ACL filters and stuff.
|
||||
Returns:
|
||||
A list of filters to be passed into the "filter" key of a search
|
||||
query.
|
||||
"""
|
||||
hybrid_search_filters: list[dict[str, Any]] = [
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}},
|
||||
{"term": {HIDDEN_FIELD_NAME: {"value": False}}},
|
||||
]
|
||||
|
||||
def _get_acl_visibility_filter(
|
||||
access_control_list: list[str],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
acl_visibility_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
acl_visibility_filter["bool"]["should"].append(
|
||||
{"term": {PUBLIC_FIELD_NAME: {"value": True}}}
|
||||
)
|
||||
for acl in access_control_list:
|
||||
acl_subclause: dict[str, Any] = {
|
||||
"term": {ACCESS_CONTROL_LIST_FIELD_NAME: {"value": acl}}
|
||||
}
|
||||
acl_visibility_filter["bool"]["should"].append(acl_subclause)
|
||||
return acl_visibility_filter
|
||||
|
||||
def _get_source_type_filter(
|
||||
source_types: list[DocumentSource],
|
||||
) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
source_type_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for source_type in source_types:
|
||||
source_type_filter["bool"]["should"].append(
|
||||
{"term": {SOURCE_TYPE_FIELD_NAME: {"value": source_type.value}}}
|
||||
)
|
||||
return source_type_filter
|
||||
|
||||
def _get_tag_filter(tags: list[Tag]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
tag_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for tag in tags:
|
||||
# Kind of an abstraction leak, see
|
||||
# convert_metadata_dict_to_list_of_strings for why metadata list
|
||||
# entries are expected to look this way.
|
||||
tag_str = f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}"
|
||||
tag_filter["bool"]["should"].append(
|
||||
{"term": {METADATA_LIST_FIELD_NAME: {"value": tag_str}}}
|
||||
)
|
||||
return tag_filter
|
||||
|
||||
def _get_document_set_filter(document_sets: list[str]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
document_set_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for document_set in document_sets:
|
||||
document_set_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_SETS_FIELD_NAME: {"value": document_set}}}
|
||||
)
|
||||
return document_set_filter
|
||||
|
||||
def _get_user_file_id_filter(user_file_ids: list[UUID]) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_file_id_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
for user_file_id in user_file_ids:
|
||||
user_file_id_filter["bool"]["should"].append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": str(user_file_id)}}}
|
||||
)
|
||||
return user_file_id_filter
|
||||
|
||||
def _get_user_project_filter(project_id: int) -> dict[str, Any]:
|
||||
# Logical OR operator on its elements.
|
||||
user_project_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
user_project_filter["bool"]["should"].append(
|
||||
{"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}}
|
||||
)
|
||||
return user_project_filter
|
||||
|
||||
def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]:
|
||||
# Convert to UTC if not already so the cutoff is comparable to the
|
||||
# document data.
|
||||
time_cutoff = set_or_convert_timezone_to_utc(time_cutoff)
|
||||
# Logical OR operator on its elements.
|
||||
time_cutoff_filter: dict[str, Any] = {"bool": {"should": []}}
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"range": {
|
||||
LAST_UPDATED_FIELD_NAME: {"gte": int(time_cutoff.timestamp())}
|
||||
}
|
||||
}
|
||||
)
|
||||
if time_cutoff < datetime.now(timezone.utc) - timedelta(
|
||||
days=ASSUMED_DOCUMENT_AGE_DAYS
|
||||
):
|
||||
# Since the time cutoff is older than ASSUMED_DOCUMENT_AGE_DAYS
|
||||
# ago, we include documents which have no
|
||||
# LAST_UPDATED_FIELD_NAME value.
|
||||
time_cutoff_filter["bool"]["should"].append(
|
||||
{
|
||||
"bool": {
|
||||
"must_not": {"exists": {"field": LAST_UPDATED_FIELD_NAME}}
|
||||
}
|
||||
}
|
||||
)
|
||||
return time_cutoff_filter
|
||||
|
||||
def _get_chunk_index_filter(
|
||||
min_chunk_index: int | None, max_chunk_index: int | None
|
||||
) -> dict[str, Any]:
|
||||
range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}}
|
||||
if min_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index
|
||||
if max_chunk_index is not None:
|
||||
range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index
|
||||
return range_clause
|
||||
|
||||
filter_clauses: list[dict[str, Any]] = []
|
||||
|
||||
if not include_hidden:
|
||||
filter_clauses.append({"term": {HIDDEN_FIELD_NAME: {"value": False}}})
|
||||
|
||||
if access_control_list is not None:
|
||||
# If an access control list is provided, the caller can only
|
||||
# retrieve public documents, and non-public documents where at least
|
||||
# one acl provided here is present in the document's acl list. If
|
||||
# there is explicitly no list provided, we make no restrictions on
|
||||
# the documents that can be retrieved.
|
||||
filter_clauses.append(_get_acl_visibility_filter(access_control_list))
|
||||
|
||||
if source_types:
|
||||
# If at least one source type is provided, the caller will only
|
||||
# retrieve documents whose source type is present in this input
|
||||
# list.
|
||||
filter_clauses.append(_get_source_type_filter(source_types))
|
||||
|
||||
if tags:
|
||||
# If at least one tag is provided, the caller will only retrieve
|
||||
# documents where at least one tag provided here is present in the
|
||||
# document's metadata list.
|
||||
filter_clauses.append(_get_tag_filter(tags))
|
||||
|
||||
if document_sets:
|
||||
# If at least one document set is provided, the caller will only
|
||||
# retrieve documents where at least one document set provided here
|
||||
# is present in the document's document sets list.
|
||||
filter_clauses.append(_get_document_set_filter(document_sets))
|
||||
|
||||
if user_file_ids:
|
||||
# If at least one user file ID is provided, the caller will only
|
||||
# retrieve documents where the document ID is in this input list of
|
||||
# file IDs. Note that these IDs correspond to Onyx documents whereas
|
||||
# the entries retrieved from the document index correspond to Onyx
|
||||
# document chunks.
|
||||
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
|
||||
|
||||
if project_id is not None:
|
||||
# If a project ID is provided, the caller will only retrieve
|
||||
# documents where the project ID provided here is present in the
|
||||
# document's user projects list.
|
||||
filter_clauses.append(_get_user_project_filter(project_id))
|
||||
|
||||
if time_cutoff is not None:
|
||||
# If a time cutoff is provided, the caller will only retrieve
|
||||
# documents where the document was last updated at or after the time
|
||||
# cutoff. For documents which do not have a value for
|
||||
# LAST_UPDATED_FIELD_NAME, we assume some default age for the
|
||||
# purposes of time cutoff.
|
||||
filter_clauses.append(_get_time_cutoff_filter(time_cutoff))
|
||||
|
||||
if min_chunk_index is not None or max_chunk_index is not None:
|
||||
filter_clauses.append(
|
||||
_get_chunk_index_filter(min_chunk_index, max_chunk_index)
|
||||
)
|
||||
|
||||
if document_id is not None:
|
||||
# WARNING: If user_file_ids has elements and if none of them are
|
||||
# document_id, no matches will be retrieved.
|
||||
filter_clauses.append(
|
||||
{"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}}
|
||||
)
|
||||
|
||||
if max_chunk_size is not None:
|
||||
filter_clauses.append(
|
||||
{"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}}
|
||||
)
|
||||
|
||||
if tenant_state.multitenant:
|
||||
hybrid_search_filters.append(
|
||||
filter_clauses.append(
|
||||
{"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}}
|
||||
)
|
||||
return hybrid_search_filters
|
||||
|
||||
return filter_clauses
|
||||
|
||||
@staticmethod
|
||||
def _get_match_highlights_configuration() -> dict[str, Any]:
|
||||
@@ -378,4 +664,5 @@ class DocumentQuery:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return match_highlights_configuration
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.feature_flags.interface import FeatureFlagProvider
|
||||
from onyx.feature_flags.interface import NoOpFeatureFlagProvider
|
||||
from onyx.utils.variable_functionality import (
|
||||
@@ -19,7 +20,7 @@ def get_default_feature_flag_provider() -> FeatureFlagProvider:
|
||||
Returns:
|
||||
FeatureFlagProvider: The configured feature flag provider instance
|
||||
"""
|
||||
if MULTI_TENANT:
|
||||
if MULTI_TENANT or DEV_MODE:
|
||||
return fetch_versioned_implementation_with_fallback(
|
||||
module="onyx.feature_flags.factory",
|
||||
attribute="get_posthog_feature_flag_provider",
|
||||
|
||||
@@ -369,6 +369,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
# New output item added
|
||||
output_item = parsed_chunk.get("item", {})
|
||||
if output_item.get("type") == "function_call":
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -394,6 +396,8 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
elif event_type == "response.function_call_arguments.delta":
|
||||
content_part: Optional[str] = parsed_chunk.get("delta", None)
|
||||
if content_part:
|
||||
# Track that we've received tool calls via streaming
|
||||
self._has_streamed_tool_calls = True
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=ChatCompletionToolCallChunk(
|
||||
@@ -491,22 +495,72 @@ def _patch_openai_responses_chunk_parser() -> None:
|
||||
|
||||
elif event_type == "response.completed":
|
||||
# Final event signaling all output items (including parallel tool calls) are done
|
||||
# Check if we already received tool calls via streaming events
|
||||
# There is an issue where OpenAI (not via Azure) will give back the tool calls streamed out as tokens
|
||||
# But on Azure, it's only given out all at once. OpenAI also happens to give back the tool calls in the
|
||||
# response.completed event so we need to throw it out here or there are duplicate tool calls.
|
||||
has_streamed_tool_calls = getattr(self, "_has_streamed_tool_calls", False)
|
||||
|
||||
response_data = parsed_chunk.get("response", {})
|
||||
# Determine finish reason based on response content
|
||||
finish_reason = "stop"
|
||||
if response_data.get("output"):
|
||||
for item in response_data["output"]:
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
finish_reason = "tool_calls"
|
||||
break
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason=finish_reason,
|
||||
usage=None,
|
||||
output_items = response_data.get("output", [])
|
||||
|
||||
# Check if there are function_call items in the output
|
||||
has_function_calls = any(
|
||||
isinstance(item, dict) and item.get("type") == "function_call"
|
||||
for item in output_items
|
||||
)
|
||||
|
||||
if has_function_calls and not has_streamed_tool_calls:
|
||||
# Azure's Responses API returns all tool calls in response.completed
|
||||
# without streaming them incrementally. Extract them here.
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
for idx, item in enumerate(output_items):
|
||||
if isinstance(item, dict) and item.get("type") == "function_call":
|
||||
tool_calls.append(
|
||||
ChatCompletionToolCallChunk(
|
||||
id=item.get("call_id"),
|
||||
index=idx,
|
||||
type="function",
|
||||
function=ChatCompletionToolCallFunctionChunk(
|
||||
name=item.get("name"),
|
||||
arguments=item.get("arguments", ""),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
index=0,
|
||||
delta=Delta(tool_calls=tool_calls),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
)
|
||||
elif has_function_calls:
|
||||
# Tool calls were already streamed, just signal completion
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="tool_calls",
|
||||
usage=None,
|
||||
)
|
||||
else:
|
||||
return GenericStreamingChunk(
|
||||
text="",
|
||||
tool_use=None,
|
||||
is_finished=True,
|
||||
finish_reason="stop",
|
||||
usage=None,
|
||||
)
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -631,6 +685,40 @@ def _patch_openai_responses_transform_response() -> None:
|
||||
LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign]
|
||||
|
||||
|
||||
def _patch_azure_responses_should_fake_stream() -> None:
|
||||
"""
|
||||
Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False.
|
||||
|
||||
By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models
|
||||
not in its database. This causes Azure custom model deployments to buffer the entire
|
||||
response before yielding, resulting in poor time-to-first-token.
|
||||
|
||||
Azure's Responses API supports native streaming, so we override this to always use
|
||||
real streaming (SyncResponsesAPIStreamingIterator).
|
||||
"""
|
||||
from litellm.llms.azure.responses.transformation import (
|
||||
AzureOpenAIResponsesAPIConfig,
|
||||
)
|
||||
|
||||
if (
|
||||
getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "")
|
||||
== "_patched_should_fake_stream"
|
||||
):
|
||||
return
|
||||
|
||||
def _patched_should_fake_stream(
|
||||
self: Any,
|
||||
model: Optional[str],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
# Azure Responses API supports native streaming - never fake it
|
||||
return False
|
||||
|
||||
_patched_should_fake_stream.__name__ = "_patched_should_fake_stream"
|
||||
AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign]
|
||||
|
||||
|
||||
def apply_monkey_patches() -> None:
|
||||
"""
|
||||
Apply all necessary monkey patches to LiteLLM for compatibility.
|
||||
@@ -640,12 +728,13 @@ def apply_monkey_patches() -> None:
|
||||
- Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content
|
||||
- Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for OpenAI Responses API
|
||||
- Patching LiteLLMResponsesTransformationHandler.transform_response for non-streaming responses
|
||||
- Patching LiteLLMResponsesTransformationHandler._convert_content_str_to_input_text for tool content types
|
||||
- Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming
|
||||
"""
|
||||
_patch_ollama_transform_request()
|
||||
_patch_ollama_chunk_parser()
|
||||
_patch_openai_responses_chunk_parser()
|
||||
_patch_openai_responses_transform_response()
|
||||
_patch_azure_responses_should_fake_stream()
|
||||
|
||||
|
||||
def _extract_reasoning_content(message: dict) -> Tuple[Optional[str], Optional[str]]:
|
||||
|
||||
@@ -738,7 +738,7 @@ def model_is_reasoning_model(model_name: str, model_provider: str) -> bool:
|
||||
|
||||
# Fallback: try using litellm.supports_reasoning() for newer models
|
||||
try:
|
||||
logger.debug("Falling back to `litellm.supports_reasoning`")
|
||||
# logger.debug("Falling back to `litellm.supports_reasoning`")
|
||||
full_model_name = (
|
||||
f"{model_provider}/{model_name}"
|
||||
if model_provider not in model_name
|
||||
|
||||
@@ -63,6 +63,8 @@ from onyx.server.documents.connector import router as connector_router
|
||||
from onyx.server.documents.credential import router as credential_router
|
||||
from onyx.server.documents.document import router as document_router
|
||||
from onyx.server.documents.standard_oauth import router as standard_oauth_router
|
||||
from onyx.server.features.build.api.api import nextjs_assets_router
|
||||
from onyx.server.features.build.api.api import router as build_router
|
||||
from onyx.server.features.default_assistant.api import (
|
||||
router as default_assistant_router,
|
||||
)
|
||||
@@ -376,6 +378,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
|
||||
include_router_with_global_prefix_prepended(application, cc_pair_router)
|
||||
include_router_with_global_prefix_prepended(application, projects_router)
|
||||
include_router_with_global_prefix_prepended(application, build_router)
|
||||
include_router_with_global_prefix_prepended(application, nextjs_assets_router)
|
||||
include_router_with_global_prefix_prepended(application, document_set_router)
|
||||
include_router_with_global_prefix_prepended(application, search_settings_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
|
||||
287
backend/onyx/onyxbot/discord/DISCORD_MULTITENANT_README.md
Normal file
287
backend/onyx/onyxbot/discord/DISCORD_MULTITENANT_README.md
Normal file
@@ -0,0 +1,287 @@
|
||||
# Discord Bot Multitenant Architecture
|
||||
|
||||
This document analyzes how the Discord cache manager and API client coordinate to handle multitenant API keys from a single Discord client.
|
||||
|
||||
## Overview
|
||||
|
||||
The Discord bot uses a **single-client, multi-tenant** architecture where one `OnyxDiscordClient` instance serves multiple tenants (organizations) simultaneously. Tenant isolation is achieved through:
|
||||
|
||||
- **Cache Manager**: Maps Discord guilds to tenants and stores per-tenant API keys
|
||||
- **API Client**: Stateless HTTP client that accepts dynamic API keys per request
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────────────┐
|
||||
│ OnyxDiscordClient │
|
||||
│ │
|
||||
│ ┌─────────────────────────┐ ┌─────────────────────────────┐ │
|
||||
│ │ DiscordCacheManager │ │ OnyxAPIClient │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ guild_id → tenant_id │───▶│ send_chat_message( │ │
|
||||
│ │ tenant_id → api_key │ │ message, │ │
|
||||
│ │ │ │ api_key=<per-tenant>, │ │
|
||||
│ └─────────────────────────┘ │ persona_id=... │ │
|
||||
│ │ ) │ │
|
||||
│ └─────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Component Details
|
||||
|
||||
### 1. Cache Manager (`backend/onyx/onyxbot/discord/cache.py`)
|
||||
|
||||
The `DiscordCacheManager` maintains two critical in-memory mappings:
|
||||
|
||||
```python
|
||||
class DiscordCacheManager:
|
||||
_guild_tenants: dict[int, str] # guild_id → tenant_id
|
||||
_api_keys: dict[str, str] # tenant_id → api_key
|
||||
_lock: asyncio.Lock # Concurrency control
|
||||
```
|
||||
|
||||
#### Key Responsibilities
|
||||
|
||||
| Function | Purpose |
|
||||
|----------|---------|
|
||||
| `get_tenant(guild_id)` | O(1) lookup: guild → tenant |
|
||||
| `get_api_key(tenant_id)` | O(1) lookup: tenant → API key |
|
||||
| `refresh_all()` | Full cache rebuild from database |
|
||||
| `refresh_guild()` | Incremental update for single guild |
|
||||
|
||||
#### API Key Provisioning Strategy
|
||||
|
||||
API keys are **lazily provisioned** - only created when first needed:
|
||||
|
||||
```python
|
||||
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
|
||||
needs_key = tenant_id not in self._api_keys
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db:
|
||||
# Load guild configs
|
||||
configs = get_discord_bot_configs(db)
|
||||
guild_ids = [c.guild_id for c in configs if c.enabled]
|
||||
|
||||
# Only provision API key if not already cached
|
||||
api_key = None
|
||||
if needs_key:
|
||||
api_key = get_or_create_discord_service_api_key(db, tenant_id)
|
||||
|
||||
return guild_ids, api_key
|
||||
```
|
||||
|
||||
This optimization avoids repeated database calls for API key generation.
|
||||
|
||||
#### Concurrency Control
|
||||
|
||||
All write operations acquire an async lock to prevent race conditions:
|
||||
|
||||
```python
|
||||
async def refresh_all(self) -> None:
|
||||
async with self._lock:
|
||||
# Safe to modify _guild_tenants and _api_keys
|
||||
for tenant_id in get_all_tenant_ids():
|
||||
guild_ids, api_key = await self._load_tenant_data(tenant_id)
|
||||
# Update mappings...
|
||||
```
|
||||
|
||||
Read operations (`get_tenant`, `get_api_key`) are lock-free since Python dict lookups are atomic.
|
||||
|
||||
---
|
||||
|
||||
### 2. API Client (`backend/onyx/onyxbot/discord/api_client.py`)
|
||||
|
||||
The `OnyxAPIClient` is a **stateless async HTTP client** that communicates with Onyx API pods.
|
||||
|
||||
#### Key Design: Per-Request API Key Injection
|
||||
|
||||
```python
|
||||
class OnyxAPIClient:
|
||||
async def send_chat_message(
|
||||
self,
|
||||
message: str,
|
||||
api_key: str, # Injected per-request
|
||||
persona_id: int | None,
|
||||
...
|
||||
) -> ChatFullResponse:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}", # Tenant-specific auth
|
||||
}
|
||||
# Make request...
|
||||
```
|
||||
|
||||
The client accepts `api_key` as a parameter to each method, enabling **dynamic tenant selection at request time**. This design allows a single client instance to serve multiple tenants:
|
||||
|
||||
```python
|
||||
# Same client, different tenants
|
||||
await api_client.send_chat_message(msg, api_key=key_for_tenant_1, ...)
|
||||
await api_client.send_chat_message(msg, api_key=key_for_tenant_2, ...)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Coordination Flow
|
||||
|
||||
### Message Processing Pipeline
|
||||
|
||||
When a Discord message arrives, the client coordinates cache and API client:
|
||||
|
||||
```python
|
||||
async def on_message(self, message: Message) -> None:
|
||||
guild_id = message.guild.id
|
||||
|
||||
# Step 1: Cache lookup - guild → tenant
|
||||
tenant_id = self.cache.get_tenant(guild_id)
|
||||
if not tenant_id:
|
||||
return # Guild not registered
|
||||
|
||||
# Step 2: Cache lookup - tenant → API key
|
||||
api_key = self.cache.get_api_key(tenant_id)
|
||||
if not api_key:
|
||||
logger.warning(f"No API key for tenant {tenant_id}")
|
||||
return
|
||||
|
||||
# Step 3: API call with tenant-specific credentials
|
||||
await process_chat_message(
|
||||
message=message,
|
||||
api_key=api_key, # Tenant-specific
|
||||
persona_id=persona_id, # Tenant-specific
|
||||
api_client=self.api_client,
|
||||
)
|
||||
```
|
||||
|
||||
### Startup Sequence
|
||||
|
||||
```python
|
||||
async def setup_hook(self) -> None:
|
||||
# 1. Initialize API client (create aiohttp session)
|
||||
await self.api_client.initialize()
|
||||
|
||||
# 2. Populate cache with all tenants
|
||||
await self.cache.refresh_all()
|
||||
|
||||
# 3. Start background refresh task
|
||||
self._cache_refresh_task = self.loop.create_task(
|
||||
self._periodic_cache_refresh() # Every 60 seconds
|
||||
)
|
||||
```
|
||||
|
||||
### Shutdown Sequence
|
||||
|
||||
```python
|
||||
async def close(self) -> None:
|
||||
# 1. Cancel background refresh
|
||||
if self._cache_refresh_task:
|
||||
self._cache_refresh_task.cancel()
|
||||
|
||||
# 2. Close Discord connection
|
||||
await super().close()
|
||||
|
||||
# 3. Close API client session
|
||||
await self.api_client.close()
|
||||
|
||||
# 4. Clear cache
|
||||
self.cache.clear()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Tenant Isolation Mechanisms
|
||||
|
||||
### 1. Per-Tenant API Keys
|
||||
|
||||
Each tenant has a dedicated service API key:
|
||||
|
||||
```python
|
||||
# backend/onyx/db/discord_bot.py
|
||||
def get_or_create_discord_service_api_key(db_session: Session, tenant_id: str) -> str:
|
||||
existing = get_discord_service_api_key(db_session)
|
||||
if existing:
|
||||
return regenerate_key(existing)
|
||||
|
||||
# Create LIMITED role key (chat-only permissions)
|
||||
return insert_api_key(
|
||||
db_session=db_session,
|
||||
api_key_args=APIKeyArgs(
|
||||
name=DISCORD_SERVICE_API_KEY_NAME,
|
||||
role=UserRole.LIMITED, # Minimal permissions
|
||||
),
|
||||
user_id=None, # Service account (system-owned)
|
||||
).api_key
|
||||
```
|
||||
|
||||
### 2. Database Context Variables
|
||||
|
||||
The cache uses context variables for proper tenant-scoped DB sessions:
|
||||
|
||||
```python
|
||||
context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db:
|
||||
# All DB operations scoped to this tenant
|
||||
...
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token)
|
||||
```
|
||||
|
||||
### 3. Enterprise Gating Support
|
||||
|
||||
Gated tenants are filtered during cache refresh:
|
||||
|
||||
```python
|
||||
gated_tenants = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.product_gating",
|
||||
"get_gated_tenants",
|
||||
set(),
|
||||
)()
|
||||
|
||||
for tenant_id in get_all_tenant_ids():
|
||||
if tenant_id in gated_tenants:
|
||||
continue # Skip gated tenants
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Cache Refresh Strategy
|
||||
|
||||
| Trigger | Method | Scope |
|
||||
|---------|--------|-------|
|
||||
| Startup | `refresh_all()` | All tenants |
|
||||
| Periodic (60s) | `refresh_all()` | All tenants |
|
||||
| Guild registration | `refresh_guild()` | Single tenant |
|
||||
|
||||
### Error Handling
|
||||
|
||||
- **Tenant-level errors**: Logged and skipped (doesn't stop other tenants)
|
||||
- **Missing API key**: Bot silently ignores messages from that guild
|
||||
- **Network errors**: Logged, cache continues with stale data until next refresh
|
||||
|
||||
---
|
||||
|
||||
## Key Design Insights
|
||||
|
||||
1. **Single Client, Multiple Tenants**: One `OnyxAPIClient` and one `DiscordCacheManager` instance serves all tenants via dynamic API key injection.
|
||||
|
||||
2. **Cache-First Architecture**: Guild lookups are O(1) in-memory; API keys are cached after first provisioning to avoid repeated DB calls.
|
||||
|
||||
3. **Graceful Degradation**: If an API key is missing or stale, the bot simply doesn't respond (no crash or error propagation).
|
||||
|
||||
4. **Thread Safety Without Blocking**: `asyncio.Lock` prevents race conditions while maintaining async concurrency for reads.
|
||||
|
||||
5. **Lazy Provisioning**: API keys are only created when first needed, then cached for performance.
|
||||
|
||||
6. **Stateless API Client**: The HTTP client holds no tenant state - all tenant context is injected per-request via the `api_key` parameter.
|
||||
|
||||
---
|
||||
|
||||
## File References
|
||||
|
||||
| Component | Path |
|
||||
|-----------|------|
|
||||
| Cache Manager | `backend/onyx/onyxbot/discord/cache.py` |
|
||||
| API Client | `backend/onyx/onyxbot/discord/api_client.py` |
|
||||
| Discord Client | `backend/onyx/onyxbot/discord/client.py` |
|
||||
| API Key DB Operations | `backend/onyx/db/discord_bot.py` |
|
||||
| Cache Manager Tests | `backend/tests/unit/onyx/onyxbot/discord/test_cache_manager.py` |
|
||||
| API Client Tests | `backend/tests/unit/onyx/onyxbot/discord/test_api_client.py` |
|
||||
@@ -564,6 +564,7 @@ def associate_credential_to_connector(
|
||||
access_type=metadata.access_type,
|
||||
auto_sync_options=metadata.auto_sync_options,
|
||||
groups=metadata.groups,
|
||||
processing_mode=metadata.processing_mode,
|
||||
)
|
||||
|
||||
# trigger indexing immediately
|
||||
|
||||
@@ -20,6 +20,7 @@ from google.oauth2.credentials import Credentials
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.email_utils import send_email
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
@@ -29,6 +30,7 @@ from onyx.background.celery.tasks.pruning.tasks import (
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
|
||||
from onyx.configs.app_configs import MOCK_CONNECTOR_FILE_PATH
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -125,6 +127,7 @@ from onyx.server.documents.models import ConnectorFileInfo
|
||||
from onyx.server.documents.models import ConnectorFilesResponse
|
||||
from onyx.server.documents.models import ConnectorIndexingStatusLite
|
||||
from onyx.server.documents.models import ConnectorIndexingStatusLiteResponse
|
||||
from onyx.server.documents.models import ConnectorRequestSubmission
|
||||
from onyx.server.documents.models import ConnectorSnapshot
|
||||
from onyx.server.documents.models import ConnectorStatus
|
||||
from onyx.server.documents.models import ConnectorUpdateRequest
|
||||
@@ -1759,6 +1762,86 @@ def get_connector_by_id(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/connector-request")
|
||||
def submit_connector_request(
|
||||
request_data: ConnectorRequestSubmission,
|
||||
user: User | None = Depends(current_user),
|
||||
) -> StatusResponse:
|
||||
"""
|
||||
Submit a connector request for Cloud deployments.
|
||||
Tracks via PostHog telemetry and sends email to hello@onyx.app.
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
connector_name = request_data.connector_name.strip()
|
||||
|
||||
if not connector_name:
|
||||
raise HTTPException(status_code=400, detail="Connector name cannot be empty")
|
||||
|
||||
# Get user identifier for telemetry
|
||||
user_email = user.email if user else None
|
||||
distinct_id = user_email or tenant_id
|
||||
|
||||
# Track connector request via PostHog telemetry (Cloud only)
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
if MULTI_TENANT:
|
||||
mt_cloud_telemetry(
|
||||
tenant_id=tenant_id,
|
||||
distinct_id=distinct_id,
|
||||
event=MilestoneRecordType.REQUESTED_CONNECTOR,
|
||||
properties={
|
||||
"connector_name": connector_name,
|
||||
"user_email": user_email,
|
||||
},
|
||||
)
|
||||
|
||||
# Send email notification (if email is configured)
|
||||
if EMAIL_CONFIGURED:
|
||||
try:
|
||||
subject = "Onyx Craft Connector Request"
|
||||
email_body_text = f"""A new connector request has been submitted:
|
||||
|
||||
Connector Name: {connector_name}
|
||||
User Email: {user_email or 'Not provided (anonymous user)'}
|
||||
Tenant ID: {tenant_id}
|
||||
"""
|
||||
email_body_html = f"""<html>
|
||||
<body>
|
||||
<p>A new connector request has been submitted:</p>
|
||||
<ul>
|
||||
<li><strong>Connector Name:</strong> {connector_name}</li>
|
||||
<li><strong>User Email:</strong> {user_email or 'Not provided (anonymous user)'}</li>
|
||||
<li><strong>Tenant ID:</strong> {tenant_id}</li>
|
||||
</ul>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
send_email(
|
||||
user_email="hello@onyx.app",
|
||||
subject=subject,
|
||||
html_body=email_body_html,
|
||||
text_body=email_body_text,
|
||||
)
|
||||
logger.info(
|
||||
f"Connector request email sent to hello@onyx.app for connector: {connector_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but don't fail the request if email fails
|
||||
logger.error(
|
||||
f"Failed to send connector request email for {connector_name}: {e}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Connector request submitted: {connector_name} by user {user_email or 'anonymous'} "
|
||||
f"(tenant: {tenant_id})"
|
||||
)
|
||||
|
||||
return StatusResponse(
|
||||
success=True,
|
||||
message="Connector request submitted successfully. We'll prioritize popular requests!",
|
||||
)
|
||||
|
||||
|
||||
class BasicCCPairInfo(BaseModel):
|
||||
has_successful_run: bool
|
||||
source: DocumentSource
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import PermissionSyncStatus
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
@@ -483,6 +484,7 @@ class ConnectorCredentialPairMetadata(BaseModel):
|
||||
access_type: AccessType
|
||||
auto_sync_options: dict[str, Any] | None = None
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
processing_mode: ProcessingMode = ProcessingMode.REGULAR
|
||||
|
||||
|
||||
class CCStatusUpdateRequest(BaseModel):
|
||||
@@ -523,6 +525,10 @@ class RunConnectorRequest(BaseModel):
|
||||
from_beginning: bool = False
|
||||
|
||||
|
||||
class ConnectorRequestSubmission(BaseModel):
|
||||
connector_name: str
|
||||
|
||||
|
||||
class CCPropertyUpdateRequest(BaseModel):
|
||||
name: str
|
||||
value: str
|
||||
|
||||
2
backend/onyx/server/features/build/.gitignore
vendored
Normal file
2
backend/onyx/server/features/build/.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
sandbox/kubernetes/docker/templates/venv/**
|
||||
sandbox/kubernetes/docker/demo_data/**
|
||||
257
backend/onyx/server/features/build/AGENTS.template.md
Normal file
257
backend/onyx/server/features/build/AGENTS.template.md
Normal file
@@ -0,0 +1,257 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides guidance for AI agents when working in this sandbox.
|
||||
|
||||
## Introduction
|
||||
|
||||
You are Steve, an AI agent powering **Onyx Craft**, a feature that allows users to create interactive web applications and dashboards from their company knowledge. You are running in a secure sandbox with access to the user's knowledge sources and the ability to create Next.js applications.
|
||||
|
||||
## Purpose
|
||||
|
||||
Your primary purpose is to assist users in accomplishing their goals by providing information, executing tasks, and offering guidance. I aim to be a reliable partner in problem-solving and task completion.
|
||||
|
||||
## How I Approach Tasks
|
||||
|
||||
When presented with a task, I typically:
|
||||
|
||||
1. Analyze the request to understand what's being asked
|
||||
2. Break down complex problems into manageable steps
|
||||
3. Use appropriate tools and methods to address each step
|
||||
4. Provide clear communication throughout the process
|
||||
5. Deliver results in a helpful and organized manner
|
||||
|
||||
## My Personality Traits
|
||||
|
||||
- Helpful and service-oriented
|
||||
- Detail-focused and thorough
|
||||
- Adaptable to different user needs
|
||||
- Patient when working through complex problems
|
||||
- Honest about my capabilities and limitations
|
||||
|
||||
## Areas I Can Help With
|
||||
|
||||
- Information gathering and research
|
||||
- Knowledge Synthesis
|
||||
- Data processing and analysis
|
||||
- File management and organization
|
||||
- Dashboard creation
|
||||
- Repetitive administrative tasks
|
||||
|
||||
{{USER_CONTEXT}}
|
||||
|
||||
## Your Configuration
|
||||
|
||||
**LLM Provider**: {{LLM_PROVIDER_NAME}}
|
||||
**Model**: {{LLM_MODEL_NAME}}
|
||||
**Next.js Development Server**: Running on port {{NEXTJS_PORT}}
|
||||
{{DISABLED_TOOLS_SECTION}}
|
||||
|
||||
## Your Environment
|
||||
|
||||
You are in an ephemeral virtual machine.
|
||||
|
||||
You currently have Python 3.11.13 and Node v22.21.1.
|
||||
|
||||
**Python Virtual Environment**: A Python virtual environment is pre-configured at `.venv/` with common data science and visualization packages already installed (numpy, pandas, matplotlib, scipy, PIL, etc.). The environment should be automatically activated, but if you run into issues with missing packages, you can explicitly use `.venv/bin/python` or `.venv/bin/pip`.
|
||||
|
||||
If you need additional packages, install them with `pip install <package>` (or `.venv/bin/pip install <package>` if the venv isn't active). For javascript packages, use `npm install <package>` from within the `outputs/web` directory.
|
||||
|
||||
## Organization Info
|
||||
|
||||
The `org_info/` directory contains information about the organization and user context:
|
||||
|
||||
- `AGENTS.md`: Description of available organizational information files
|
||||
- `user_identity_profile.txt`: Contains the current user's name, email, and organization they work for. Use this information when personalizing outputs or when the user asks about their identity.
|
||||
- `organization_structure.json`: Contains a JSON representation of the organization's groups, managers, and their direct reports. Use this to understand reporting relationships and team structures.
|
||||
|
||||
## Available Skills
|
||||
|
||||
{{AVAILABLE_SKILLS_SECTION}}
|
||||
|
||||
Skills contain best practices and guidelines for specific tasks. Always read the relevant skill's SKILL.md file BEFORE starting work that the skill covers.
|
||||
|
||||
## General Capabilities
|
||||
|
||||
### Information Processing
|
||||
|
||||
- Answering questions on diverse topics using available information
|
||||
- Conducting research through web searches and data analysis
|
||||
- Fact-checking and information verification from multiple sources
|
||||
- Summarizing complex information into digestible formats
|
||||
- Processing and analyzing structured and unstructured data
|
||||
|
||||
### Problem Solving
|
||||
|
||||
- Breaking down complex problems into manageable steps
|
||||
- Providing step-by-step solutions to technical challenges
|
||||
- Troubleshooting errors in code or processes
|
||||
- Suggesting alternative approaches when initial attempts fail
|
||||
- Adapting to changing requirements during task execution
|
||||
|
||||
### File System Operations
|
||||
|
||||
- Reading from and writing to files in various formats
|
||||
- Searching for files based on names, patterns, or content
|
||||
- Creating and organizing directory structures
|
||||
- Compressing and archiving files (zip, tar)
|
||||
- Analyzing file contents and extracting relevant information
|
||||
- Converting between different file formats
|
||||
|
||||
## Agent Behavior Guidelines
|
||||
|
||||
**Task Management**: For any non-trivial task involving multiple steps, you should organize your work and track progress. This helps users understand what you're doing and ensures nothing is missed.
|
||||
|
||||
**Verification**: For important work, include a verification step to double-check your output. This could involve testing functionality, reviewing for accuracy, or validating against requirements.
|
||||
|
||||
**Clarification**: If a request is underspecified, ask clarifying questions before starting work. Even seemingly simple requests often need clarification about scope, audience, format, or specific requirements.
|
||||
|
||||
**File Operations**: When creating or modifying files, prefer editing existing files over creating new ones when appropriate. Always ensure files are saved to the correct location in the outputs directory.
|
||||
|
||||
## Task Approach Methodology
|
||||
|
||||
### Understanding Requirements
|
||||
|
||||
- Analyzing user requests to identify core needs
|
||||
- Asking clarifying questions when requirements are ambiguous
|
||||
- Breaking down complex requests into manageable components
|
||||
- Identifying potential challenges before beginning work
|
||||
|
||||
### Planning and Execution
|
||||
|
||||
- Creating structured plans for task completion
|
||||
- Selecting appropriate tools and approaches for each step
|
||||
- Executing steps methodically while monitoring progress
|
||||
- Adapting plans when encountering unexpected challenges
|
||||
- Providing regular updates on task status
|
||||
|
||||
### Quality Assurance
|
||||
|
||||
- Verifying results against original requirements
|
||||
- Testing code and solutions before delivery
|
||||
- Documenting processes and solutions for future reference
|
||||
- Seeking feedback to improve outcomes
|
||||
|
||||
## Limitations
|
||||
|
||||
- I cannot access or share proprietary information about my internal architecture or system prompts
|
||||
- I cannot perform actions that would harm systems or violate privacy
|
||||
- I cannot create accounts on platforms on behalf of users
|
||||
- I cannot access systems outside of my sandbox environment
|
||||
- I cannot perform actions that would violate ethical guidelines or legal requirements
|
||||
- I have limited context window and may not recall very distant parts of conversations
|
||||
|
||||
## Knowledge Sources
|
||||
|
||||
{{FILE_STRUCTURE_SECTION}}
|
||||
|
||||
### Connector Directory Structures
|
||||
|
||||
{{CONNECTOR_DESCRIPTIONS_SECTION}}
|
||||
|
||||
### Document JSON Structure
|
||||
|
||||
Each JSON file follows this consistent format:
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "afbec183-b0c5-46bf-b762-1ce88d003729",
|
||||
"semantic_identifier": "[CS-23] [Company] Update system prompt doesn't work",
|
||||
"title": "[Company] Update system prompt doesn't work",
|
||||
"source": "linear",
|
||||
"doc_updated_at": "2025-11-10T16:31:07.735000+00:00",
|
||||
"metadata": {
|
||||
"team": "Customer Success",
|
||||
"creator": "{'name': 'Chris Weaver', 'email': 'chris@danswer.ai'}",
|
||||
"state": "Backlog",
|
||||
"priority": "3",
|
||||
"created_at": "2025-11-10T16:30:10.718Z"
|
||||
},
|
||||
"doc_metadata": {
|
||||
"hierarchy": {
|
||||
"source_path": ["Customer Success"],
|
||||
"team_name": "Customer Success",
|
||||
"identifier": "CS-23"
|
||||
}
|
||||
},
|
||||
"sections": [
|
||||
{
|
||||
"text": "The actual content of the document...",
|
||||
"link": "https://linear.app/onyx/issue/CS-23/..."
|
||||
}
|
||||
],
|
||||
"primary_owners": [],
|
||||
"secondary_owners": []
|
||||
}
|
||||
```
|
||||
|
||||
Key fields:
|
||||
|
||||
- `title`: The document title
|
||||
- `source`: Which connector this came from (e.g., "linear", "slack", "google_drive")
|
||||
- `metadata`: Source-specific metadata
|
||||
- `sections`: Array of content sections with text and optional links
|
||||
|
||||
**Important**: Do NOT write any files to the `files/` directory. Do NOT edit any files in the `files/` directory. This is read-only knowledge data.
|
||||
|
||||
## Attachments (PRIORITY)
|
||||
|
||||
The `attachments/` directory contains files that the user has explicitly uploaded during this session. **These files are critically important** and should be treated as high-priority context.
|
||||
|
||||
### Why Attachments Matter
|
||||
|
||||
- The user deliberately chose to upload these files, signaling they are directly relevant to the task
|
||||
- These files often contain the specific data, requirements, or examples the user wants you to work with
|
||||
- They may include spreadsheets, documents, images, or code that should inform your work
|
||||
|
||||
### Required Actions
|
||||
|
||||
**At the start of every task, you MUST:**
|
||||
|
||||
1. **Check for attachments**: List the contents of `attachments/` to see what the user has provided
|
||||
2. **Read and analyze each file**: Thoroughly examine every attachment to understand its contents and relevance
|
||||
3. **Reference attachment content**: Use the information from attachments to inform your responses and outputs
|
||||
|
||||
### File Handling
|
||||
|
||||
- Uploaded files may be in various formats: CSV, JSON, PDF, images, text files, etc.
|
||||
- For spreadsheets and data files, examine the structure, columns, and sample data
|
||||
- For documents, extract key information and requirements
|
||||
- For images, analyze and describe their content
|
||||
- For code files, understand the logic and patterns
|
||||
|
||||
**Do NOT ignore user uploaded files.** They are there for a reason and likely contain exactly what you need to complete the task successfully.
|
||||
|
||||
## Outputs Directory
|
||||
|
||||
There is a special folder called `outputs`. Any and all python scripts, javascript apps, generated documents, slides, etc. should go here.
|
||||
Feel free to write/edit anything you find in here.
|
||||
|
||||
## Outputs
|
||||
|
||||
There should be four main types of outputs:
|
||||
|
||||
1. Web Applications / Dashboards
|
||||
|
||||
Generally, you should use
|
||||
|
||||
### Web Applications / Dashboards
|
||||
|
||||
Web applications and dashboards should be written as a webapp built with Next.js, React, and shadcn/ui.. Within the `outputs` directory,
|
||||
there is a folder called `web` that has the skeleton of a basic Next.js app in it. Use this. We do NOT use a `src` directory.
|
||||
|
||||
Use NextJS 16.1.1, React v19, Tailwindcss, and recharts.
|
||||
|
||||
The Next.js app is already running on port {{NEXTJS_PORT}}. Do not run `npm run dev` yourself.
|
||||
|
||||
If the app needs any pre-computation, then create a bash script called `prepare.sh` at the root of the `web` directory.
|
||||
|
||||
**IMPORTANT: See `outputs/web/AGENTS.md` for detailed technical specifications, architecture patterns, component usage guidelines, and styling rules. It is the ground truth for webapp design**
|
||||
|
||||
### Other Output Formats (Coming Soon)
|
||||
|
||||
Additional output formats such as slides, markdown documents, and standalone graphs are coming soon. If the user requests these formats, let them know they're not yet available and suggest building an interactive web application instead, which can include:
|
||||
|
||||
- Data visualizations and charts using recharts
|
||||
- Multi-page layouts with navigation
|
||||
- Exportable content (print-to-PDF functionality)
|
||||
- Interactive dashboards with real-time filtering and sorting
|
||||
114
backend/onyx/server/features/build/CLAUDE.template.md
Normal file
114
backend/onyx/server/features/build/CLAUDE.template.md
Normal file
@@ -0,0 +1,114 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Structure
|
||||
|
||||
The `files` directory contains all of the knowledge from Chris' company, Onyx. This knowledge comes from Google Drive, Linear, Slack, Github, and Fireflies.
|
||||
|
||||
Each source has it's own directory - `Google_Drive`, `Linear`, `Slack`, `Github`, and `Fireflies`. Within each directory, the structure of the source is built out as a folder structure:
|
||||
|
||||
- Google Drive is copied over directly as is. End files are stored as `FILE_NAME.json`.
|
||||
- Linear has each project as a folder, and then within each project, each individual ticket is stored as a file: `[TICKET_ID]_TICKET_NAME.json`.
|
||||
- Slack has each channel as a folder titled `[CHANNEL_NAME]` in the root directory. Within each channel, each thread is represented as a single file called `[INITIAL_AUTHOR]_in_[CHANNEL]__[FIRST_MESSAGE].json`.
|
||||
- Github has each organization as a folder titled `[ORG_NAME]`. Within each organization, there is
|
||||
a folder for each repository tilted `[REPO_NAME]`. Within each repository there are up to two folders: `pull_requests` and `issues`. Each pull request / issue is then represented as a single file
|
||||
within the appropriate folder. Pull requests are structured as `[PR_ID]__[PR_NAME].json` and issues
|
||||
are structured as `[ISSUE_ID]__[ISSUE_NAME].json`.
|
||||
- Fireflies has all calls in the root, each as a single file titled `CALL_TITLE.json`.
|
||||
- HubSpot has four folders in the root: `Tickets`, `Companies`, `Deals`, and `Contacts`. Each object is stored as a file named after its title/name (e.g., `[TICKET_SUBJECT].json`, `[COMPANY_NAME].json`, `[DEAL_NAME].json`, `[CONTACT_NAME].json`).
|
||||
|
||||
Across all names, spaces are replaced by `_`.
|
||||
|
||||
Each JSON is structured like:
|
||||
|
||||
```
|
||||
{
|
||||
"id": "afbec183-b0c5-46bf-b768-1ce88d003729",
|
||||
"semantic_identifier": "[CS-17] [Betclic] Update system prompt doesn't work",
|
||||
"title": "[Betclic] Update system prompt doesn't work",
|
||||
"source": "linear",
|
||||
"doc_updated_at": "2025-11-10T16:31:07.735000+00:00",
|
||||
"metadata": {
|
||||
"team": "Customer Success",
|
||||
"creator": "{'name': 'Chris Weaver', 'email': 'chris@danswer.ai'}",
|
||||
"state": "Backlog",
|
||||
"priority": "3",
|
||||
"created_at": "2025-11-10T16:30:10.718Z"
|
||||
},
|
||||
"doc_metadata": {
|
||||
"hierarchy": {
|
||||
"source_path": [
|
||||
"Customer Success"
|
||||
],
|
||||
"team_name": "Customer Success",
|
||||
"identifier": "CS-17"
|
||||
}
|
||||
},
|
||||
"sections": [
|
||||
{
|
||||
"text": "Happens \\~15% of the time.",
|
||||
"link": "https://linear.app/onyx-app/issue/CS-17/betclic-update-system-prompt-doesnt-work"
|
||||
}
|
||||
],
|
||||
"primary_owners": [],
|
||||
"secondary_owners": []
|
||||
}
|
||||
```
|
||||
|
||||
Do NOT write any files to these directories. Do NOT edit any files in these directories.
|
||||
|
||||
There is a special folder called `outputs`. Any and all python scripts, javascript apps, generated documents, slides, etc. should go here.
|
||||
Feel free to write/edit anything you find in here.
|
||||
|
||||
|
||||
## Outputs
|
||||
|
||||
There should be four main types of outputs:
|
||||
1. Web Applications / Dashboards
|
||||
2. Slides
|
||||
3. Markdown Documents
|
||||
4. Graphs/Charts
|
||||
|
||||
Generally, you should use
|
||||
|
||||
### Web Applications / Dashboards
|
||||
|
||||
Web applications and dashboards should be written as a Next.js app. Within the `outputs` directory,
|
||||
there is a folder called `web` that has the skeleton of a basic Next.js app in it. Use this.
|
||||
|
||||
Use NextJS 16.1.1, React v19, Tailwindcss, and recharts.
|
||||
|
||||
The Next.js app is already running and accessible at http://localhost:3002. Do not run `npm run dev` yourself.
|
||||
|
||||
If the app needs any pre-computation, then create a bash script called `prepare.sh` at the root of the `web` directory.
|
||||
|
||||
### Slides
|
||||
|
||||
Slides should be created using the nano-banana MCP.
|
||||
|
||||
The outputs should be placed within the `outputs/slides` directory, named `[SLIDE_NUMBER].png`.
|
||||
|
||||
Before creating slides, create a `SLIDE_OUTLINE.md` file describing the overall message as well as the content and structure of each slide.
|
||||
|
||||
### Markdown Documents
|
||||
|
||||
Markdown documents should be placed within the `outputs/document` directory.
|
||||
If you want to have a single "Document" that has multiple distinct pages, then create a folder within
|
||||
the `outputs/document` directory, and name each page `1.MD`, `2.MD`, ...
|
||||
|
||||
### Graphs/Charts
|
||||
|
||||
Graphs and charts should be placed in the `outputs/charts` directory.
|
||||
|
||||
Graphs and charts should be created with a python script. You have access to libraries like numpy, pandas, scipy, matplotlib, and PIL.
|
||||
|
||||
## Your Environment
|
||||
|
||||
You are in an ephemeral virtual machine.
|
||||
|
||||
You currently have Python 3.11.13 and Node v22.21.1.
|
||||
|
||||
**Python Virtual Environment**: A Python virtual environment is pre-configured at `.venv/` with common data science and visualization packages already installed (numpy, pandas, matplotlib, scipy, PIL, etc.). The environment should be automatically activated, but if you run into issues with missing packages, you can explicitly use `.venv/bin/python` or `.venv/bin/pip`.
|
||||
|
||||
If you need additional packages, install them with `pip install <package>` (or `.venv/bin/pip install <package>` if the venv isn't active). For javascript packages, use `npm` from within the `outputs/web` directory.
|
||||
1
backend/onyx/server/features/build/__init__.py
Normal file
1
backend/onyx/server/features/build/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Build feature module
|
||||
454
backend/onyx/server/features/build/api/api.py
Normal file
454
backend/onyx/server/features/build/api/api.py
Normal file
@@ -0,0 +1,454 @@
|
||||
from collections.abc import Iterator
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from onyx.db.models import BuildSession
|
||||
from onyx.db.models import User
|
||||
from onyx.server.features.build.api.messages_api import router as messages_router
|
||||
from onyx.server.features.build.api.models import BuildConnectorInfo
|
||||
from onyx.server.features.build.api.models import BuildConnectorListResponse
|
||||
from onyx.server.features.build.api.models import BuildConnectorStatus
|
||||
from onyx.server.features.build.api.models import RateLimitResponse
|
||||
from onyx.server.features.build.api.rate_limit import get_user_rate_limit_status
|
||||
from onyx.server.features.build.api.sessions_api import router as sessions_router
|
||||
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
|
||||
from onyx.server.features.build.sandbox import get_sandbox_manager
|
||||
from onyx.server.features.build.session.manager import SessionManager
|
||||
from onyx.server.features.build.utils import is_onyx_craft_enabled
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def require_onyx_craft_enabled(user: User = Depends(current_user)) -> User:
|
||||
"""
|
||||
Dependency that checks if Onyx Craft is enabled for the user.
|
||||
Raises HTTP 403 if Onyx Craft is disabled via feature flag.
|
||||
"""
|
||||
if not is_onyx_craft_enabled(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Onyx Craft is not available",
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
router = APIRouter(prefix="/build", dependencies=[Depends(require_onyx_craft_enabled)])
|
||||
|
||||
# Include sub-routers for sessions and messages
|
||||
router.include_router(sessions_router, tags=["build"])
|
||||
router.include_router(messages_router, tags=["build"])
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Rate Limiting
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/limit", response_model=RateLimitResponse)
|
||||
def get_rate_limit(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> RateLimitResponse:
|
||||
"""Get rate limit information for the current user."""
|
||||
return get_user_rate_limit_status(user, db_session)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Build Connectors
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/connectors", response_model=BuildConnectorListResponse)
|
||||
def get_build_connectors(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> BuildConnectorListResponse:
|
||||
"""Get all connectors for the build admin panel.
|
||||
|
||||
Returns all connector-credential pairs with simplified status information.
|
||||
"""
|
||||
cc_pairs = get_connector_credential_pairs_for_user(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=False,
|
||||
eager_load_connector=True,
|
||||
eager_load_credential=True,
|
||||
processing_mode=ProcessingMode.FILE_SYSTEM, # Only show FILE_SYSTEM connectors
|
||||
)
|
||||
|
||||
connectors: list[BuildConnectorInfo] = []
|
||||
for cc_pair in cc_pairs:
|
||||
# Skip ingestion API connectors and default pairs
|
||||
if cc_pair.connector.source == DocumentSource.INGESTION_API:
|
||||
continue
|
||||
if cc_pair.name == "DefaultCCPair":
|
||||
continue
|
||||
|
||||
# Determine status
|
||||
error_message: str | None = None
|
||||
has_ever_succeeded = cc_pair.last_successful_index_time is not None
|
||||
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
status = BuildConnectorStatus.DELETING
|
||||
elif cc_pair.status == ConnectorCredentialPairStatus.INVALID:
|
||||
# If connector has succeeded before but credentials are now invalid,
|
||||
# show as connected_with_errors so user can still disable demo data
|
||||
if has_ever_succeeded:
|
||||
status = BuildConnectorStatus.CONNECTED_WITH_ERRORS
|
||||
error_message = "Connector credentials are invalid"
|
||||
else:
|
||||
status = BuildConnectorStatus.ERROR
|
||||
error_message = "Connector credentials are invalid"
|
||||
else:
|
||||
# Check latest index attempt for errors
|
||||
latest_attempt = get_latest_index_attempt_for_cc_pair_id(
|
||||
db_session=db_session,
|
||||
connector_credential_pair_id=cc_pair.id,
|
||||
secondary_index=False,
|
||||
only_finished=True,
|
||||
)
|
||||
|
||||
if latest_attempt and latest_attempt.status == IndexingStatus.FAILED:
|
||||
# If connector has succeeded before but latest attempt failed,
|
||||
# show as connected_with_errors
|
||||
if has_ever_succeeded:
|
||||
status = BuildConnectorStatus.CONNECTED_WITH_ERRORS
|
||||
else:
|
||||
status = BuildConnectorStatus.ERROR
|
||||
error_message = latest_attempt.error_msg
|
||||
elif (
|
||||
latest_attempt
|
||||
and latest_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS
|
||||
):
|
||||
# Completed with errors - if it has succeeded before, show as connected_with_errors
|
||||
if has_ever_succeeded:
|
||||
status = BuildConnectorStatus.CONNECTED_WITH_ERRORS
|
||||
else:
|
||||
status = BuildConnectorStatus.ERROR
|
||||
error_message = "Indexing completed with errors"
|
||||
elif cc_pair.status == ConnectorCredentialPairStatus.PAUSED:
|
||||
status = BuildConnectorStatus.CONNECTED
|
||||
elif cc_pair.last_successful_index_time is None:
|
||||
# Never successfully indexed - check if currently indexing
|
||||
# First check cc_pair status for scheduled/initial indexing
|
||||
if cc_pair.status in (
|
||||
ConnectorCredentialPairStatus.SCHEDULED,
|
||||
ConnectorCredentialPairStatus.INITIAL_INDEXING,
|
||||
):
|
||||
status = BuildConnectorStatus.INDEXING
|
||||
else:
|
||||
in_progress_attempt = get_latest_index_attempt_for_cc_pair_id(
|
||||
db_session=db_session,
|
||||
connector_credential_pair_id=cc_pair.id,
|
||||
secondary_index=False,
|
||||
only_finished=False,
|
||||
)
|
||||
if (
|
||||
in_progress_attempt
|
||||
and in_progress_attempt.status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
status = BuildConnectorStatus.INDEXING
|
||||
elif (
|
||||
in_progress_attempt
|
||||
and in_progress_attempt.status == IndexingStatus.NOT_STARTED
|
||||
):
|
||||
status = BuildConnectorStatus.INDEXING
|
||||
else:
|
||||
# Has a finished attempt but never succeeded - likely error
|
||||
status = BuildConnectorStatus.ERROR
|
||||
error_message = (
|
||||
latest_attempt.error_msg
|
||||
if latest_attempt
|
||||
else "Initial indexing failed"
|
||||
)
|
||||
else:
|
||||
status = BuildConnectorStatus.CONNECTED
|
||||
|
||||
connectors.append(
|
||||
BuildConnectorInfo(
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
source=cc_pair.connector.source.value,
|
||||
name=cc_pair.name or cc_pair.connector.name or "Unnamed",
|
||||
status=status,
|
||||
docs_indexed=0, # Would need to query for this
|
||||
last_indexed=cc_pair.last_successful_index_time,
|
||||
error_message=error_message,
|
||||
)
|
||||
)
|
||||
|
||||
return BuildConnectorListResponse(connectors=connectors)
|
||||
|
||||
|
||||
# Headers to skip when proxying (hop-by-hop headers)
|
||||
EXCLUDED_HEADERS = {
|
||||
"content-encoding",
|
||||
"content-length",
|
||||
"transfer-encoding",
|
||||
"connection",
|
||||
}
|
||||
|
||||
|
||||
def _stream_response(response: httpx.Response) -> Iterator[bytes]:
|
||||
"""Stream the response content in chunks."""
|
||||
for chunk in response.iter_bytes(chunk_size=8192):
|
||||
yield chunk
|
||||
|
||||
|
||||
def _rewrite_asset_paths(content: bytes, session_id: str) -> bytes:
|
||||
"""Rewrite Next.js asset paths to go through the proxy."""
|
||||
import re
|
||||
|
||||
# Base path includes session_id for routing
|
||||
webapp_base_path = f"/api/build/sessions/{session_id}/webapp"
|
||||
|
||||
text = content.decode("utf-8")
|
||||
# Rewrite /_next/ paths to go through our proxy
|
||||
text = text.replace("/_next/", f"{webapp_base_path}/_next/")
|
||||
# Rewrite JSON data file fetch paths (e.g., /data.json, /data/tickets.json)
|
||||
# Matches paths like "/filename.json" or "/path/to/file.json"
|
||||
text = re.sub(
|
||||
r'"(/(?:[a-zA-Z0-9_-]+/)*[a-zA-Z0-9_-]+\.json)"',
|
||||
f'"{webapp_base_path}\\1"',
|
||||
text,
|
||||
)
|
||||
text = re.sub(
|
||||
r"'(/(?:[a-zA-Z0-9_-]+/)*[a-zA-Z0-9_-]+\.json)'",
|
||||
f"'{webapp_base_path}\\1'",
|
||||
text,
|
||||
)
|
||||
# Rewrite favicon
|
||||
text = text.replace('"/favicon.ico', f'"{webapp_base_path}/favicon.ico')
|
||||
return text.encode("utf-8")
|
||||
|
||||
|
||||
# Content types that may contain asset path references that need rewriting
|
||||
REWRITABLE_CONTENT_TYPES = {
|
||||
"text/html",
|
||||
"text/css",
|
||||
"application/javascript",
|
||||
"text/javascript",
|
||||
"application/x-javascript",
|
||||
}
|
||||
|
||||
|
||||
def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
|
||||
"""Get the internal URL for a session's Next.js server.
|
||||
|
||||
Uses the sandbox manager to get the correct URL for both local and
|
||||
Kubernetes environments.
|
||||
|
||||
Args:
|
||||
session_id: The build session ID
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
The internal URL to proxy requests to
|
||||
|
||||
Raises:
|
||||
HTTPException: If session not found, port not allocated, or sandbox not found
|
||||
"""
|
||||
|
||||
session = db_session.get(BuildSession, session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
if session.nextjs_port is None:
|
||||
raise HTTPException(status_code=503, detail="Session port not allocated")
|
||||
if session.user_id is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Get the user's sandbox to get the sandbox_id
|
||||
sandbox = get_sandbox_by_user_id(db_session, session.user_id)
|
||||
if sandbox is None:
|
||||
raise HTTPException(status_code=404, detail="Sandbox not found")
|
||||
|
||||
# Use sandbox manager to get the correct internal URL
|
||||
sandbox_manager = get_sandbox_manager()
|
||||
return sandbox_manager.get_webapp_url(sandbox.id, session.nextjs_port)
|
||||
|
||||
|
||||
def _proxy_request(
|
||||
path: str, request: Request, session_id: UUID, db_session: Session
|
||||
) -> StreamingResponse | Response:
|
||||
"""Proxy a request to the sandbox's Next.js server."""
|
||||
base_url = _get_sandbox_url(session_id, db_session)
|
||||
|
||||
# Build the target URL
|
||||
target_url = f"{base_url}/{path.lstrip('/')}"
|
||||
|
||||
# Include query params if present
|
||||
if request.query_params:
|
||||
target_url = f"{target_url}?{request.query_params}"
|
||||
|
||||
logger.debug(f"Proxying request to: {target_url}")
|
||||
|
||||
try:
|
||||
# Make the request to the target URL
|
||||
with httpx.Client(timeout=30.0, follow_redirects=True) as client:
|
||||
response = client.get(
|
||||
target_url,
|
||||
headers={
|
||||
key: value
|
||||
for key, value in request.headers.items()
|
||||
if key.lower() not in ("host", "content-length")
|
||||
},
|
||||
)
|
||||
|
||||
# Build response headers, excluding hop-by-hop headers
|
||||
response_headers = {
|
||||
key: value
|
||||
for key, value in response.headers.items()
|
||||
if key.lower() not in EXCLUDED_HEADERS
|
||||
}
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
|
||||
# For HTML/CSS/JS responses, rewrite asset paths
|
||||
if any(ct in content_type for ct in REWRITABLE_CONTENT_TYPES):
|
||||
content = _rewrite_asset_paths(response.content, str(session_id))
|
||||
return Response(
|
||||
content=content,
|
||||
status_code=response.status_code,
|
||||
headers=response_headers,
|
||||
media_type=content_type,
|
||||
)
|
||||
|
||||
return StreamingResponse(
|
||||
content=_stream_response(response),
|
||||
status_code=response.status_code,
|
||||
headers=response_headers,
|
||||
media_type=content_type or None,
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.error(f"Timeout while proxying request to {target_url}")
|
||||
raise HTTPException(status_code=504, detail="Gateway timeout")
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Error proxying request to {target_url}: {e}")
|
||||
raise HTTPException(status_code=502, detail="Bad gateway")
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/webapp", response_model=None)
|
||||
def get_webapp_root(
|
||||
session_id: UUID,
|
||||
request: Request,
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | Response:
|
||||
"""Proxy the root path of the webapp for a specific session."""
|
||||
return _proxy_request("", request, session_id, db_session)
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/webapp/{path:path}", response_model=None)
|
||||
def get_webapp_path(
|
||||
session_id: UUID,
|
||||
path: str,
|
||||
request: Request,
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | Response:
|
||||
"""Proxy any subpath of the webapp (static assets, etc.) for a specific session."""
|
||||
return _proxy_request(path, request, session_id, db_session)
|
||||
|
||||
|
||||
# Separate router for Next.js static assets at /_next/*
|
||||
# This is needed because Next.js apps may reference assets with root-relative paths
|
||||
# that don't get rewritten. The session_id is extracted from the Referer header.
|
||||
nextjs_assets_router = APIRouter()
|
||||
|
||||
|
||||
def _extract_session_from_referer(request: Request) -> UUID | None:
|
||||
"""Extract session_id from the Referer header.
|
||||
|
||||
Expects Referer to contain /api/build/sessions/{session_id}/webapp
|
||||
"""
|
||||
import re
|
||||
|
||||
referer = request.headers.get("referer", "")
|
||||
match = re.search(r"/api/build/sessions/([a-f0-9-]+)/webapp", referer)
|
||||
if match:
|
||||
try:
|
||||
return UUID(match.group(1))
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
@nextjs_assets_router.get("/_next/{path:path}", response_model=None)
|
||||
def get_nextjs_assets(
|
||||
path: str,
|
||||
request: Request,
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | Response:
|
||||
"""Proxy Next.js static assets requested at root /_next/ path.
|
||||
|
||||
The session_id is extracted from the Referer header since these requests
|
||||
come from within the iframe context.
|
||||
"""
|
||||
session_id = _extract_session_from_referer(request)
|
||||
if not session_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Could not determine session from request context",
|
||||
)
|
||||
return _proxy_request(f"_next/{path}", request, session_id, db_session)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Sandbox Management Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/sandbox/reset", response_model=None)
|
||||
def reset_sandbox(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Reset the user's sandbox by terminating it and cleaning up all sessions.
|
||||
|
||||
This endpoint terminates the user's shared sandbox container/pod and
|
||||
cleans up all session workspaces. Useful for "start fresh" functionality.
|
||||
|
||||
After calling this endpoint, the next session creation will provision a
|
||||
new sandbox.
|
||||
"""
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
try:
|
||||
success = session_manager.terminate_user_sandbox(user.id)
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="No sandbox found for user",
|
||||
)
|
||||
db_session.commit()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.error(f"Failed to reset sandbox for user {user.id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to reset sandbox: {e}",
|
||||
)
|
||||
|
||||
return Response(status_code=204)
|
||||
106
backend/onyx/server/features/build/api/messages_api.py
Normal file
106
backend/onyx/server/features/build/api/messages_api.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""API endpoints for Build Mode message management."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.server.features.build.api.models import MessageListResponse
|
||||
from onyx.server.features.build.api.models import MessageRequest
|
||||
from onyx.server.features.build.api.models import MessageResponse
|
||||
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
|
||||
from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat
|
||||
from onyx.server.features.build.session.manager import RateLimitError
|
||||
from onyx.server.features.build.session.manager import SessionManager
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def check_build_rate_limits(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""
|
||||
Dependency to check build mode rate limits before processing the request.
|
||||
|
||||
Raises HTTPException(429) if rate limit is exceeded.
|
||||
Follows the same pattern as chat's check_token_rate_limits.
|
||||
"""
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
try:
|
||||
session_manager.check_rate_limit(user)
|
||||
except RateLimitError as e:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sessions/{session_id}/messages", tags=PUBLIC_API_TAGS)
|
||||
def list_messages(
|
||||
session_id: UUID,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> MessageListResponse:
|
||||
"""Get all messages for a build session."""
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
messages = session_manager.list_messages(session_id, user.id)
|
||||
|
||||
if messages is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
return MessageListResponse(
|
||||
messages=[MessageResponse.from_model(msg) for msg in messages]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/send-message", tags=PUBLIC_API_TAGS)
|
||||
async def send_message(
|
||||
session_id: UUID,
|
||||
request: MessageRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
_rate_limit_check: None = Depends(check_build_rate_limits),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Send a message to the CLI agent and stream the response.
|
||||
|
||||
Enforces rate limiting before executing the agent (via dependency).
|
||||
Returns a Server-Sent Events (SSE) stream with the agent's response.
|
||||
|
||||
Follows the same pattern as /chat/send-message for consistency.
|
||||
"""
|
||||
# Update sandbox heartbeat - this is the only place we track activity
|
||||
# for determining when a sandbox should be put to sleep
|
||||
sandbox = get_sandbox_by_user_id(db_session, user.id)
|
||||
if sandbox and sandbox.status.is_active():
|
||||
update_sandbox_heartbeat(db_session, sandbox.id)
|
||||
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
# Stream the CLI agent's response
|
||||
return StreamingResponse(
|
||||
session_manager.send_message(session_id, user.id, request.content),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no", # Disable nginx buffering
|
||||
},
|
||||
)
|
||||
325
backend/onyx/server/features/build/api/models.py
Normal file
325
backend/onyx/server/features/build/api/models.py
Normal file
@@ -0,0 +1,325 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.enums import ArtifactType
|
||||
from onyx.db.enums import BuildSessionStatus
|
||||
from onyx.db.enums import SandboxStatus
|
||||
from onyx.server.features.build.sandbox.models import (
|
||||
FilesystemEntry as FileSystemEntry,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.db.models import Sandbox
|
||||
from onyx.db.models import BuildSession
|
||||
|
||||
|
||||
# ===== Session Models =====
|
||||
class SessionCreateRequest(BaseModel):
|
||||
"""Request to create a new build session."""
|
||||
|
||||
name: str | None = None # Optional session name
|
||||
demo_data_enabled: bool = True # Whether to enable demo org_info data in sandbox
|
||||
user_work_area: str | None = None # User's work area (e.g., "engineering")
|
||||
user_level: str | None = None # User's level (e.g., "ic", "manager")
|
||||
# LLM selection from user's cookie
|
||||
llm_provider_type: str | None = None # Provider type (e.g., "anthropic", "openai")
|
||||
llm_model_name: str | None = None # Model name (e.g., "claude-opus-4-5")
|
||||
|
||||
|
||||
class SessionUpdateRequest(BaseModel):
|
||||
"""Request to update a build session.
|
||||
|
||||
If name is None, the session name will be auto-generated using LLM.
|
||||
"""
|
||||
|
||||
name: str | None = None
|
||||
|
||||
|
||||
class SessionNameGenerateResponse(BaseModel):
|
||||
"""Response containing a generated session name."""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
class SandboxResponse(BaseModel):
|
||||
"""Sandbox metadata in session response."""
|
||||
|
||||
id: str
|
||||
status: SandboxStatus
|
||||
container_id: str | None
|
||||
created_at: datetime
|
||||
last_heartbeat: datetime | None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, sandbox: Any) -> "SandboxResponse":
|
||||
"""Convert Sandbox ORM model to response."""
|
||||
return cls(
|
||||
id=str(sandbox.id),
|
||||
status=sandbox.status,
|
||||
container_id=sandbox.container_id,
|
||||
created_at=sandbox.created_at,
|
||||
last_heartbeat=sandbox.last_heartbeat,
|
||||
)
|
||||
|
||||
|
||||
class ArtifactResponse(BaseModel):
|
||||
"""Artifact metadata in session response."""
|
||||
|
||||
id: str
|
||||
session_id: str
|
||||
type: ArtifactType
|
||||
name: str
|
||||
path: str
|
||||
preview_url: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, artifact: Any) -> "ArtifactResponse":
|
||||
"""Convert Artifact ORM model to response."""
|
||||
return cls(
|
||||
id=str(artifact.id),
|
||||
session_id=str(artifact.session_id),
|
||||
type=artifact.type,
|
||||
name=artifact.name,
|
||||
path=artifact.path,
|
||||
preview_url=getattr(artifact, "preview_url", None),
|
||||
created_at=artifact.created_at,
|
||||
updated_at=artifact.updated_at,
|
||||
)
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
"""Response containing session details."""
|
||||
|
||||
id: str
|
||||
user_id: str | None
|
||||
name: str | None
|
||||
status: BuildSessionStatus
|
||||
created_at: datetime
|
||||
last_activity_at: datetime
|
||||
nextjs_port: int | None
|
||||
sandbox: SandboxResponse | None
|
||||
artifacts: list[ArtifactResponse]
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, session: "BuildSession", sandbox: Union["Sandbox", None] = None
|
||||
) -> "SessionResponse":
|
||||
"""Convert BuildSession ORM model to response.
|
||||
|
||||
Args:
|
||||
session: BuildSession ORM model
|
||||
sandbox: Optional Sandbox ORM model. Since sandboxes are now user-owned
|
||||
(not session-owned), the sandbox must be passed separately.
|
||||
"""
|
||||
return cls(
|
||||
id=str(session.id),
|
||||
user_id=str(session.user_id) if session.user_id else None,
|
||||
name=session.name,
|
||||
status=session.status,
|
||||
created_at=session.created_at,
|
||||
last_activity_at=session.last_activity_at,
|
||||
nextjs_port=session.nextjs_port,
|
||||
sandbox=(SandboxResponse.from_model(sandbox) if sandbox else None),
|
||||
artifacts=[ArtifactResponse.from_model(a) for a in session.artifacts],
|
||||
)
|
||||
|
||||
|
||||
class DetailedSessionResponse(SessionResponse):
|
||||
"""Extended session response with sandbox state details.
|
||||
|
||||
Used for single-session endpoints where we compute expensive fields
|
||||
like session_loaded_in_sandbox.
|
||||
"""
|
||||
|
||||
session_loaded_in_sandbox: bool
|
||||
|
||||
@classmethod
|
||||
def from_session_response(
|
||||
cls,
|
||||
base: SessionResponse,
|
||||
session_loaded_in_sandbox: bool,
|
||||
) -> "DetailedSessionResponse":
|
||||
return cls(
|
||||
**base.model_dump(),
|
||||
session_loaded_in_sandbox=session_loaded_in_sandbox,
|
||||
)
|
||||
|
||||
|
||||
class SessionListResponse(BaseModel):
|
||||
"""Response containing list of sessions."""
|
||||
|
||||
sessions: list[SessionResponse]
|
||||
|
||||
|
||||
# ===== Message Models =====
|
||||
class MessageRequest(BaseModel):
|
||||
"""Request to send a message to the CLI agent."""
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
"""Response containing message details.
|
||||
|
||||
All message data is stored in message_metadata as JSON (the raw ACP packet).
|
||||
The turn_index groups all assistant responses under the user prompt they respond to.
|
||||
|
||||
Packet types in message_metadata:
|
||||
- user_message: {type: "user_message", content: {...}}
|
||||
- agent_message: {type: "agent_message", content: {...}}
|
||||
- agent_thought: {type: "agent_thought", content: {...}}
|
||||
- tool_call_progress: {type: "tool_call_progress", status: "completed", ...}
|
||||
- agent_plan_update: {type: "agent_plan_update", entries: [...]}
|
||||
"""
|
||||
|
||||
id: str
|
||||
session_id: str
|
||||
turn_index: int
|
||||
type: MessageType
|
||||
message_metadata: dict[str, Any]
|
||||
created_at: datetime
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, message: Any) -> "MessageResponse":
|
||||
"""Convert BuildMessage ORM model to response."""
|
||||
return cls(
|
||||
id=str(message.id),
|
||||
session_id=str(message.session_id),
|
||||
turn_index=message.turn_index,
|
||||
type=message.type,
|
||||
message_metadata=message.message_metadata,
|
||||
created_at=message.created_at,
|
||||
)
|
||||
|
||||
|
||||
class MessageListResponse(BaseModel):
|
||||
"""Response containing list of messages."""
|
||||
|
||||
messages: list[MessageResponse]
|
||||
|
||||
|
||||
# ===== Legacy Models (for compatibility with other code) =====
|
||||
class CreateSessionRequest(BaseModel):
|
||||
task: str
|
||||
available_sources: list[str] | None = None
|
||||
|
||||
|
||||
class CreateSessionResponse(BaseModel):
|
||||
session_id: str
|
||||
|
||||
|
||||
class ExecuteRequest(BaseModel):
|
||||
task: str
|
||||
context: str | None = None
|
||||
|
||||
|
||||
class ArtifactInfo(BaseModel):
|
||||
artifact_type: str # "webapp", "file", "markdown", "image"
|
||||
path: str
|
||||
filename: str
|
||||
mime_type: str | None = None
|
||||
|
||||
|
||||
class SessionStatus(BaseModel):
|
||||
session_id: str
|
||||
status: str # "idle", "running", "completed", "failed"
|
||||
webapp_url: str | None = None
|
||||
|
||||
|
||||
class DirectoryListing(BaseModel):
|
||||
path: str # Current directory path
|
||||
entries: list[FileSystemEntry] # Contents
|
||||
|
||||
|
||||
class WebappInfo(BaseModel):
|
||||
has_webapp: bool # Whether a webapp exists in outputs/web
|
||||
webapp_url: str | None # URL to access the webapp (e.g., http://localhost:3015)
|
||||
status: str # Sandbox status (running, terminated, etc.)
|
||||
|
||||
|
||||
# ===== File Upload Models =====
|
||||
class UploadResponse(BaseModel):
|
||||
"""Response after successful file upload."""
|
||||
|
||||
filename: str # Sanitized filename
|
||||
path: str # Relative path in sandbox (e.g., "attachments/doc.pdf")
|
||||
size_bytes: int # File size in bytes
|
||||
|
||||
|
||||
# ===== Rate Limit Models =====
|
||||
class RateLimitResponse(BaseModel):
|
||||
"""Rate limit information."""
|
||||
|
||||
is_limited: bool
|
||||
limit_type: str # "weekly" or "total"
|
||||
messages_used: int
|
||||
limit: int
|
||||
reset_timestamp: str | None = None
|
||||
|
||||
|
||||
# ===== Build Connector Models =====
|
||||
class BuildConnectorStatus(str, Enum):
|
||||
"""Status of a build connector."""
|
||||
|
||||
NOT_CONNECTED = "not_connected"
|
||||
CONNECTED = "connected"
|
||||
CONNECTED_WITH_ERRORS = "connected_with_errors"
|
||||
INDEXING = "indexing"
|
||||
ERROR = "error"
|
||||
DELETING = "deleting"
|
||||
|
||||
|
||||
class BuildConnectorInfo(BaseModel):
|
||||
"""Simplified connector info for build admin panel."""
|
||||
|
||||
cc_pair_id: int
|
||||
connector_id: int
|
||||
credential_id: int
|
||||
source: str
|
||||
name: str
|
||||
status: BuildConnectorStatus
|
||||
docs_indexed: int
|
||||
last_indexed: datetime | None
|
||||
error_message: str | None = None
|
||||
|
||||
|
||||
class BuildConnectorListResponse(BaseModel):
|
||||
"""List of build connectors."""
|
||||
|
||||
connectors: list[BuildConnectorInfo]
|
||||
|
||||
|
||||
# ===== Suggestion Bubble Models =====
|
||||
class SuggestionTheme(str, Enum):
|
||||
"""Theme/category of a follow-up suggestion."""
|
||||
|
||||
ADD = "add"
|
||||
QUESTION = "question"
|
||||
|
||||
|
||||
class SuggestionBubble(BaseModel):
|
||||
"""A single follow-up suggestion bubble."""
|
||||
|
||||
theme: SuggestionTheme
|
||||
text: str
|
||||
|
||||
|
||||
class GenerateSuggestionsRequest(BaseModel):
|
||||
"""Request to generate follow-up suggestions."""
|
||||
|
||||
user_message: str # First user message
|
||||
assistant_message: str # First assistant text response (accumulated)
|
||||
|
||||
|
||||
class GenerateSuggestionsResponse(BaseModel):
|
||||
"""Response containing generated suggestions."""
|
||||
|
||||
suggestions: list[SuggestionBubble]
|
||||
101
backend/onyx/server/features/build/api/packet_logger.py
Normal file
101
backend/onyx/server/features/build/api/packet_logger.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Simple packet logger for build mode debugging.
|
||||
|
||||
Logs the raw JSON of every packet emitted during build mode.
|
||||
|
||||
Log output: backend/onyx/server/features/build/packets.log
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class PacketLogger:
|
||||
"""Simple packet logger - outputs raw JSON for each packet."""
|
||||
|
||||
_instance: "PacketLogger | None" = None
|
||||
_initialized: bool
|
||||
|
||||
def __new__(cls) -> "PacketLogger":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._initialized = True
|
||||
self._enabled = os.getenv("LOG_LEVEL", "").upper() == "DEBUG"
|
||||
self._logger: logging.Logger | None = None
|
||||
|
||||
if self._enabled:
|
||||
self._setup_logger()
|
||||
|
||||
def _setup_logger(self) -> None:
|
||||
"""Set up the file handler for packet logging."""
|
||||
# Log to backend/onyx/server/features/build/packets.log
|
||||
build_dir = Path(__file__).parents[1]
|
||||
log_file = build_dir / "packets.log"
|
||||
|
||||
self._logger = logging.getLogger("build.packets")
|
||||
self._logger.setLevel(logging.DEBUG)
|
||||
self._logger.propagate = False
|
||||
|
||||
self._logger.handlers.clear()
|
||||
|
||||
handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
|
||||
handler.setLevel(logging.DEBUG)
|
||||
handler.setFormatter(logging.Formatter("%(message)s"))
|
||||
|
||||
self._logger.addHandler(handler)
|
||||
|
||||
def log(self, packet_type: str, payload: dict[str, Any] | None = None) -> None:
|
||||
"""Log a packet as JSON.
|
||||
|
||||
Args:
|
||||
packet_type: The type of packet
|
||||
payload: The packet payload
|
||||
"""
|
||||
if not self._enabled or not self._logger:
|
||||
return
|
||||
|
||||
try:
|
||||
output = json.dumps(payload, indent=2, default=str) if payload else "{}"
|
||||
self._logger.debug(f"\n=== {packet_type} ===\n{output}")
|
||||
except Exception:
|
||||
self._logger.debug(f"\n=== {packet_type} ===\n{payload}")
|
||||
|
||||
def log_raw(self, label: str, data: Any) -> None:
|
||||
"""Log raw data with a label.
|
||||
|
||||
Args:
|
||||
label: A label for this log entry
|
||||
data: Any data to log
|
||||
"""
|
||||
if not self._enabled or not self._logger:
|
||||
return
|
||||
|
||||
try:
|
||||
if isinstance(data, (dict, list)):
|
||||
output = json.dumps(data, indent=2, default=str)
|
||||
else:
|
||||
output = str(data)
|
||||
self._logger.debug(f"\n=== {label} ===\n{output}")
|
||||
except Exception:
|
||||
self._logger.debug(f"\n=== {label} ===\n{data}")
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_packet_logger: PacketLogger | None = None
|
||||
|
||||
|
||||
def get_packet_logger() -> PacketLogger:
|
||||
"""Get the singleton packet logger instance."""
|
||||
global _packet_logger
|
||||
if _packet_logger is None:
|
||||
_packet_logger = PacketLogger()
|
||||
return _packet_logger
|
||||
68
backend/onyx/server/features/build/api/packets.py
Normal file
68
backend/onyx/server/features/build/api/packets.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Build Mode packet types for streaming agent responses.
|
||||
|
||||
This module defines CUSTOM Onyx packet types that extend ACP (Agent Client Protocol).
|
||||
ACP events are passed through directly from the agent - this module only contains
|
||||
Onyx-specific extensions like artifacts and file operations.
|
||||
|
||||
All packets use SSE (Server-Sent Events) format with `event: message` and include
|
||||
a `type` field to distinguish packet types.
|
||||
|
||||
ACP events (passed through directly from acp.schema):
|
||||
- agent_message_chunk: Text/image content from agent
|
||||
- agent_thought_chunk: Agent's internal reasoning
|
||||
- tool_call_start: Tool invocation started
|
||||
- tool_call_progress: Tool execution progress/result
|
||||
- agent_plan_update: Agent's execution plan
|
||||
- current_mode_update: Agent mode change
|
||||
- prompt_response: Agent finished processing
|
||||
- error: An error occurred
|
||||
|
||||
Custom Onyx packets (defined here):
|
||||
- error: Onyx-specific errors (e.g., session not found)
|
||||
|
||||
Based on:
|
||||
- Agent Client Protocol (ACP): https://agentclientprotocol.com
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Base Packet Type
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class BasePacket(BaseModel):
|
||||
"""Base packet with common fields for all custom Onyx packet types."""
|
||||
|
||||
type: str
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.now(tz=timezone.utc).isoformat()
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Custom Onyx Packets
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ErrorPacket(BasePacket):
|
||||
"""An Onyx-specific error occurred (e.g., session not found, sandbox not running)."""
|
||||
|
||||
type: Literal["error"] = "error"
|
||||
message: str
|
||||
code: int | None = None
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Union Type for Custom Onyx Packets
|
||||
# =============================================================================
|
||||
|
||||
BuildPacket = ErrorPacket
|
||||
90
backend/onyx/server/features/build/api/rate_limit.py
Normal file
90
backend/onyx/server/features/build/api/rate_limit.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Rate limiting logic for Build Mode."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import User
|
||||
from onyx.server.features.build.api.models import RateLimitResponse
|
||||
from onyx.server.features.build.api.subscription_check import is_user_subscribed
|
||||
from onyx.server.features.build.db.rate_limit import count_user_messages_in_window
|
||||
from onyx.server.features.build.db.rate_limit import count_user_messages_total
|
||||
from onyx.server.features.build.db.rate_limit import get_oldest_message_timestamp
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
def get_user_rate_limit_status(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> RateLimitResponse:
|
||||
"""
|
||||
Get the rate limit status for a user.
|
||||
|
||||
Rate limits:
|
||||
- Cloud (MULTI_TENANT=true):
|
||||
- Subscribed users: 50 messages per week (rolling 7-day window)
|
||||
- Non-subscribed users: 5 messages (lifetime total)
|
||||
- Self-hosted (MULTI_TENANT=false):
|
||||
- Unlimited (no rate limiting)
|
||||
|
||||
Args:
|
||||
user: The user object (None for unauthenticated users)
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
RateLimitResponse with current limit status
|
||||
"""
|
||||
# Self-hosted deployments have no rate limits
|
||||
if not MULTI_TENANT:
|
||||
return RateLimitResponse(
|
||||
is_limited=False,
|
||||
limit_type="weekly",
|
||||
messages_used=0,
|
||||
limit=0, # 0 indicates unlimited
|
||||
reset_timestamp=None,
|
||||
)
|
||||
|
||||
# Determine subscription status
|
||||
is_subscribed = is_user_subscribed(user, db_session)
|
||||
|
||||
# Set limits based on subscription
|
||||
limit = 50 if is_subscribed else 5
|
||||
limit_type: Literal["weekly", "total"] = "weekly" if is_subscribed else "total"
|
||||
|
||||
# Count messages
|
||||
user_id = user.id if user else None
|
||||
if user_id is None:
|
||||
# Unauthenticated users have no usage
|
||||
messages_used = 0
|
||||
reset_timestamp = None
|
||||
elif limit_type == "weekly":
|
||||
# Subscribed: rolling 7-day window
|
||||
cutoff_time = datetime.now(tz=timezone.utc) - timedelta(days=7)
|
||||
messages_used = count_user_messages_in_window(user_id, cutoff_time, db_session)
|
||||
|
||||
# Calculate reset timestamp (when oldest message ages out)
|
||||
# Only show reset time if user is at or over the limit
|
||||
if messages_used >= limit:
|
||||
oldest_msg = get_oldest_message_timestamp(user_id, cutoff_time, db_session)
|
||||
if oldest_msg:
|
||||
reset_time = oldest_msg + timedelta(days=7)
|
||||
reset_timestamp = reset_time.isoformat()
|
||||
else:
|
||||
reset_timestamp = None
|
||||
else:
|
||||
reset_timestamp = None
|
||||
else:
|
||||
# Non-subscribed: lifetime total
|
||||
messages_used = count_user_messages_total(user_id, db_session)
|
||||
reset_timestamp = None
|
||||
|
||||
return RateLimitResponse(
|
||||
is_limited=messages_used >= limit,
|
||||
limit_type=limit_type,
|
||||
messages_used=messages_used,
|
||||
limit=limit,
|
||||
reset_timestamp=reset_timestamp,
|
||||
)
|
||||
680
backend/onyx/server/features/build/api/sessions_api.py
Normal file
680
backend/onyx/server/features/build/api/sessions_api.py
Normal file
@@ -0,0 +1,680 @@
|
||||
"""API endpoints for Build Mode session management."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import SandboxStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.build.api.models import ArtifactResponse
|
||||
from onyx.server.features.build.api.models import DetailedSessionResponse
|
||||
from onyx.server.features.build.api.models import DirectoryListing
|
||||
from onyx.server.features.build.api.models import GenerateSuggestionsRequest
|
||||
from onyx.server.features.build.api.models import GenerateSuggestionsResponse
|
||||
from onyx.server.features.build.api.models import SessionCreateRequest
|
||||
from onyx.server.features.build.api.models import SessionListResponse
|
||||
from onyx.server.features.build.api.models import SessionNameGenerateResponse
|
||||
from onyx.server.features.build.api.models import SessionResponse
|
||||
from onyx.server.features.build.api.models import SessionUpdateRequest
|
||||
from onyx.server.features.build.api.models import SuggestionBubble
|
||||
from onyx.server.features.build.api.models import SuggestionTheme
|
||||
from onyx.server.features.build.api.models import UploadResponse
|
||||
from onyx.server.features.build.api.models import WebappInfo
|
||||
from onyx.server.features.build.db.build_session import allocate_nextjs_port
|
||||
from onyx.server.features.build.db.build_session import get_build_session
|
||||
from onyx.server.features.build.db.sandbox import get_latest_snapshot_for_session
|
||||
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
|
||||
from onyx.server.features.build.db.sandbox import update_sandbox_status__no_commit
|
||||
from onyx.server.features.build.sandbox import get_sandbox_manager
|
||||
from onyx.server.features.build.session.manager import SessionManager
|
||||
from onyx.server.features.build.session.manager import UploadLimitExceededError
|
||||
from onyx.server.features.build.utils import sanitize_filename
|
||||
from onyx.server.features.build.utils import validate_file
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/sessions")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Session Management Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("", response_model=SessionListResponse)
|
||||
def list_sessions(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SessionListResponse:
|
||||
"""List all build sessions for the current user."""
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
sessions = session_manager.list_sessions(user.id)
|
||||
|
||||
# Get the user's sandbox (shared across all sessions)
|
||||
sandbox = get_sandbox_by_user_id(db_session, user.id)
|
||||
|
||||
return SessionListResponse(
|
||||
sessions=[SessionResponse.from_model(session, sandbox) for session in sessions]
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=DetailedSessionResponse)
|
||||
def create_session(
|
||||
request: SessionCreateRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DetailedSessionResponse:
|
||||
"""
|
||||
Create or get an existing empty build session.
|
||||
|
||||
Creates a sandbox with the necessary file structure and returns a session ID.
|
||||
Uses SessionManager for session and sandbox provisioning.
|
||||
|
||||
This endpoint is atomic - if sandbox provisioning fails, no database
|
||||
records are created (transaction is rolled back).
|
||||
"""
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
try:
|
||||
# Only pass user_work_area and user_level if demo data is enabled
|
||||
# This prevents org_info directory creation when demo data is disabled
|
||||
build_session = session_manager.get_or_create_empty_session(
|
||||
user.id,
|
||||
user_work_area=(
|
||||
request.user_work_area if request.demo_data_enabled else None
|
||||
),
|
||||
user_level=request.user_level if request.demo_data_enabled else None,
|
||||
llm_provider_type=request.llm_provider_type,
|
||||
llm_model_name=request.llm_model_name,
|
||||
)
|
||||
db_session.commit()
|
||||
except ValueError as e:
|
||||
# Max concurrent sandboxes reached or other validation error
|
||||
logger.exception("Sandbox provisioning failed")
|
||||
db_session.rollback()
|
||||
raise HTTPException(status_code=429, detail=str(e))
|
||||
except Exception as e:
|
||||
# Sandbox provisioning failed - rollback to remove any uncommitted records
|
||||
db_session.rollback()
|
||||
logger.error(f"Sandbox provisioning failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Sandbox provisioning failed: {e}",
|
||||
)
|
||||
|
||||
# Get the user's sandbox to include in response
|
||||
sandbox = get_sandbox_by_user_id(db_session, user.id)
|
||||
base_response = SessionResponse.from_model(build_session, sandbox)
|
||||
# Session was just created, so it's loaded in the sandbox
|
||||
return DetailedSessionResponse.from_session_response(
|
||||
base_response, session_loaded_in_sandbox=True
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{session_id}", response_model=DetailedSessionResponse)
|
||||
def get_session_details(
|
||||
session_id: UUID,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DetailedSessionResponse:
|
||||
"""
|
||||
Get details of a specific build session.
|
||||
|
||||
Returns session_loaded_in_sandbox to indicate if the session workspace
|
||||
exists in the running sandbox.
|
||||
"""
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
session = session_manager.get_session(session_id, user.id)
|
||||
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Get the user's sandbox to include in response
|
||||
sandbox = get_sandbox_by_user_id(db_session, user.id)
|
||||
|
||||
# Check if session workspace exists in the sandbox
|
||||
session_loaded = False
|
||||
if sandbox and sandbox.status == SandboxStatus.RUNNING:
|
||||
sandbox_manager = get_sandbox_manager()
|
||||
session_loaded = sandbox_manager.session_workspace_exists(
|
||||
sandbox.id, session_id
|
||||
)
|
||||
|
||||
base_response = SessionResponse.from_model(session, sandbox)
|
||||
return DetailedSessionResponse.from_session_response(
|
||||
base_response, session_loaded_in_sandbox=session_loaded
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/generate-name", response_model=SessionNameGenerateResponse)
|
||||
def generate_session_name(
|
||||
session_id: UUID,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SessionNameGenerateResponse:
|
||||
"""Generate a session name using LLM based on the first user message."""
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
generated_name = session_manager.generate_session_name(session_id, user.id)
|
||||
|
||||
if generated_name is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
return SessionNameGenerateResponse(name=generated_name)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{session_id}/generate-suggestions", response_model=GenerateSuggestionsResponse
|
||||
)
|
||||
def generate_suggestions(
|
||||
session_id: UUID,
|
||||
request: GenerateSuggestionsRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> GenerateSuggestionsResponse:
|
||||
"""Generate follow-up suggestions based on the first exchange in a session."""
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
# Verify session exists and belongs to user
|
||||
session = session_manager.get_session(session_id, user.id)
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Generate suggestions
|
||||
suggestions_data = session_manager.generate_followup_suggestions(
|
||||
user_message=request.user_message,
|
||||
assistant_message=request.assistant_message,
|
||||
)
|
||||
|
||||
# Convert to response model
|
||||
suggestions = [
|
||||
SuggestionBubble(
|
||||
theme=SuggestionTheme(item["theme"]),
|
||||
text=item["text"],
|
||||
)
|
||||
for item in suggestions_data
|
||||
]
|
||||
|
||||
return GenerateSuggestionsResponse(suggestions=suggestions)
|
||||
|
||||
|
||||
@router.put("/{session_id}/name", response_model=SessionResponse)
|
||||
def update_session_name(
|
||||
session_id: UUID,
|
||||
request: SessionUpdateRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SessionResponse:
|
||||
"""Update the name of a build session."""
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
session = session_manager.update_session_name(session_id, user.id, request.name)
|
||||
|
||||
if session is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
# Get the user's sandbox to include in response
|
||||
sandbox = get_sandbox_by_user_id(db_session, user.id)
|
||||
return SessionResponse.from_model(session, sandbox)
|
||||
|
||||
|
||||
@router.delete("/{session_id}", response_model=None)
|
||||
def delete_session(
|
||||
session_id: UUID,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Delete a build session and all associated data.
|
||||
|
||||
This endpoint is atomic - if sandbox termination fails, the session
|
||||
is NOT deleted (transaction is rolled back).
|
||||
"""
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
try:
|
||||
success = session_manager.delete_session(session_id, user.id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
db_session.commit()
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions (like 404) without rollback
|
||||
raise
|
||||
except Exception as e:
|
||||
# Sandbox termination failed - rollback to preserve session
|
||||
db_session.rollback()
|
||||
logger.error(f"Failed to delete session {session_id}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to delete session: {e}",
|
||||
)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
# Lock timeout should be longer than max restore time (5 minutes)
|
||||
RESTORE_LOCK_TIMEOUT_SECONDS = 300
|
||||
|
||||
|
||||
@router.post("/{session_id}/restore", response_model=DetailedSessionResponse)
|
||||
def restore_session(
|
||||
session_id: UUID,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DetailedSessionResponse:
|
||||
"""Restore sandbox and load session snapshot. Blocks until complete.
|
||||
|
||||
Uses Redis lock to ensure only one restore runs per sandbox at a time.
|
||||
If another restore is in progress, waits for it to complete.
|
||||
|
||||
Handles two cases:
|
||||
1. Sandbox is SLEEPING: Re-provision pod, then load session snapshot
|
||||
2. Sandbox is RUNNING but session not loaded: Just load session snapshot
|
||||
|
||||
Returns immediately if session workspace already exists in pod.
|
||||
Always returns session_loaded_in_sandbox=True on success.
|
||||
"""
|
||||
session = get_build_session(session_id, user.id, db_session)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
sandbox = get_sandbox_by_user_id(db_session, user.id)
|
||||
if not sandbox:
|
||||
raise HTTPException(status_code=404, detail="Sandbox not found")
|
||||
|
||||
# If sandbox is already running, check if session workspace exists
|
||||
sandbox_manager = get_sandbox_manager()
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
# Need to do some work - acquire Redis lock
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
lock_key = f"sandbox_restore:{sandbox.id}"
|
||||
lock = redis_client.lock(lock_key, timeout=RESTORE_LOCK_TIMEOUT_SECONDS)
|
||||
|
||||
# blocking=True means wait if another restore is in progress
|
||||
acquired = lock.acquire(
|
||||
blocking=True, blocking_timeout=RESTORE_LOCK_TIMEOUT_SECONDS
|
||||
)
|
||||
if not acquired:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Restore operation timed out waiting for lock",
|
||||
)
|
||||
|
||||
try:
|
||||
# Re-fetch sandbox status (may have changed while waiting for lock)
|
||||
db_session.refresh(sandbox)
|
||||
|
||||
# Also re-check if session workspace exists (another request may have
|
||||
# restored it while we were waiting)
|
||||
if sandbox.status == SandboxStatus.RUNNING:
|
||||
# Verify pod is healthy before proceeding
|
||||
is_healthy = sandbox_manager.health_check(sandbox.id, timeout=10.0)
|
||||
if is_healthy and sandbox_manager.session_workspace_exists(
|
||||
sandbox.id, session_id
|
||||
):
|
||||
logger.info(
|
||||
f"Session {session_id} workspace was restored by another request"
|
||||
)
|
||||
base_response = SessionResponse.from_model(session, sandbox)
|
||||
return DetailedSessionResponse.from_session_response(
|
||||
base_response, session_loaded_in_sandbox=True
|
||||
)
|
||||
|
||||
if not is_healthy:
|
||||
logger.warning(
|
||||
f"Sandbox {sandbox.id} marked as RUNNING but pod is "
|
||||
f"unhealthy/missing. Entering recovery mode."
|
||||
)
|
||||
# Terminate to clean up any lingering K8s resources
|
||||
sandbox_manager.terminate(sandbox.id)
|
||||
|
||||
update_sandbox_status__no_commit(
|
||||
db_session, sandbox.id, SandboxStatus.TERMINATED
|
||||
)
|
||||
db_session.commit()
|
||||
db_session.refresh(sandbox)
|
||||
# Fall through to TERMINATED handling below
|
||||
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
if sandbox.status in (SandboxStatus.SLEEPING, SandboxStatus.TERMINATED):
|
||||
# 1. Re-provision the pod
|
||||
logger.info(f"Re-provisioning {sandbox.status.value} sandbox {sandbox.id}")
|
||||
llm_config = session_manager._get_llm_config(None, None)
|
||||
sandbox_manager.provision(
|
||||
sandbox_id=sandbox.id,
|
||||
user_id=user.id,
|
||||
tenant_id=tenant_id,
|
||||
llm_config=llm_config,
|
||||
)
|
||||
update_sandbox_status__no_commit(
|
||||
db_session, sandbox.id, SandboxStatus.RUNNING
|
||||
)
|
||||
db_session.commit()
|
||||
db_session.refresh(sandbox)
|
||||
|
||||
# 2. Check if session workspace needs to be loaded
|
||||
if sandbox.status == SandboxStatus.RUNNING:
|
||||
if not sandbox_manager.session_workspace_exists(sandbox.id, session_id):
|
||||
# Get latest snapshot and restore it
|
||||
snapshot = get_latest_snapshot_for_session(db_session, session_id)
|
||||
if snapshot:
|
||||
# Allocate a new port for the restored session
|
||||
new_port = allocate_nextjs_port(db_session)
|
||||
session.nextjs_port = new_port
|
||||
db_session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Restoring snapshot for session {session_id} "
|
||||
f"from {snapshot.storage_path} with port {new_port}"
|
||||
)
|
||||
|
||||
try:
|
||||
sandbox_manager.restore_snapshot(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
snapshot_storage_path=snapshot.storage_path,
|
||||
tenant_id=tenant_id,
|
||||
nextjs_port=new_port,
|
||||
)
|
||||
except Exception as e:
|
||||
# Clear the port allocation on failure so it can be reused
|
||||
logger.error(
|
||||
f"Failed to restore session {session_id}, "
|
||||
f"clearing port {new_port}: {e}"
|
||||
)
|
||||
session.nextjs_port = None
|
||||
db_session.commit()
|
||||
raise
|
||||
else:
|
||||
# No snapshot - set up fresh workspace
|
||||
logger.info(
|
||||
f"No snapshot found for session {session_id}, "
|
||||
f"setting up fresh workspace"
|
||||
)
|
||||
llm_config = session_manager._get_llm_config(None, None)
|
||||
sandbox_manager.setup_session_workspace(
|
||||
sandbox_id=sandbox.id,
|
||||
session_id=session_id,
|
||||
llm_config=llm_config,
|
||||
nextjs_port=session.nextjs_port or 3010,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to restore session {session_id}: {e}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to restore session: {e}",
|
||||
)
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
base_response = SessionResponse.from_model(session, sandbox)
|
||||
return DetailedSessionResponse.from_session_response(
|
||||
base_response, session_loaded_in_sandbox=True
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Artifact Endpoints
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{session_id}/artifacts",
|
||||
response_model=list[ArtifactResponse],
|
||||
)
|
||||
def list_artifacts(
|
||||
session_id: UUID,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[dict]:
|
||||
"""List artifacts generated in the session."""
|
||||
user_id: UUID = user.id
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
artifacts = session_manager.list_artifacts(session_id, user_id)
|
||||
if artifacts is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
return artifacts
|
||||
|
||||
|
||||
@router.get("/{session_id}/files", response_model=DirectoryListing)
|
||||
def list_directory(
|
||||
session_id: UUID,
|
||||
path: str = "",
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DirectoryListing:
|
||||
"""
|
||||
List files and directories in the sandbox.
|
||||
|
||||
Args:
|
||||
session_id: The session ID
|
||||
path: Relative path from sandbox root (empty string for root)
|
||||
|
||||
Returns:
|
||||
DirectoryListing with sorted entries (directories first, then files)
|
||||
"""
|
||||
user_id: UUID = user.id
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
try:
|
||||
listing = session_manager.list_directory(session_id, user_id, path)
|
||||
except ValueError as e:
|
||||
error_message = str(e)
|
||||
if "path traversal" in error_message.lower():
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
elif "not found" in error_message.lower():
|
||||
raise HTTPException(status_code=404, detail="Directory not found")
|
||||
elif "not a directory" in error_message.lower():
|
||||
raise HTTPException(status_code=400, detail="Path is not a directory")
|
||||
raise HTTPException(status_code=400, detail=error_message)
|
||||
|
||||
if listing is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
return listing
|
||||
|
||||
|
||||
@router.get("/{session_id}/artifacts/{path:path}")
|
||||
def download_artifact(
|
||||
session_id: UUID,
|
||||
path: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Download a specific artifact file."""
|
||||
user_id: UUID = user.id
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
try:
|
||||
result = session_manager.download_artifact(session_id, user_id, path)
|
||||
except ValueError as e:
|
||||
error_message = str(e)
|
||||
if (
|
||||
"path traversal" in error_message.lower()
|
||||
or "access denied" in error_message.lower()
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
elif "directory" in error_message.lower():
|
||||
raise HTTPException(status_code=400, detail="Cannot download directory")
|
||||
raise HTTPException(status_code=400, detail=error_message)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="Artifact not found")
|
||||
|
||||
content, mime_type, filename = result
|
||||
|
||||
# Handle Unicode filenames in Content-Disposition header
|
||||
# HTTP headers require Latin-1 encoding, so we use RFC 5987 for Unicode
|
||||
try:
|
||||
# Try Latin-1 encoding first (ASCII-compatible filenames)
|
||||
filename.encode("latin-1")
|
||||
content_disposition = f'attachment; filename="{filename}"'
|
||||
except UnicodeEncodeError:
|
||||
# Use RFC 5987 encoding for Unicode filenames
|
||||
from urllib.parse import quote
|
||||
|
||||
encoded_filename = quote(filename, safe="")
|
||||
content_disposition = f"attachment; filename*=UTF-8''{encoded_filename}"
|
||||
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=mime_type,
|
||||
headers={
|
||||
"Content-Disposition": content_disposition,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{session_id}/webapp-info", response_model=WebappInfo)
|
||||
def get_webapp_info(
|
||||
session_id: UUID,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> WebappInfo:
|
||||
"""
|
||||
Get webapp information for a session.
|
||||
|
||||
Returns whether a webapp exists, its URL, and the sandbox status.
|
||||
"""
|
||||
user_id: UUID = user.id
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
webapp_info = session_manager.get_webapp_info(session_id, user_id)
|
||||
|
||||
if webapp_info is None:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
return WebappInfo(**webapp_info)
|
||||
|
||||
|
||||
@router.get("/{session_id}/webapp/download")
|
||||
def download_webapp(
|
||||
session_id: UUID,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""
|
||||
Download the webapp directory as a zip file.
|
||||
|
||||
Returns the entire outputs/web directory as a zip archive.
|
||||
"""
|
||||
user_id: UUID = user.id
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
result = session_manager.download_webapp_zip(session_id, user_id)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(status_code=404, detail="Webapp not found")
|
||||
|
||||
zip_bytes, filename = result
|
||||
|
||||
return Response(
|
||||
content=zip_bytes,
|
||||
media_type="application/zip",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{session_id}/upload", response_model=UploadResponse)
|
||||
async def upload_file_endpoint(
|
||||
session_id: UUID,
|
||||
file: UploadFile = File(...),
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UploadResponse:
|
||||
"""Upload a file to the session's sandbox.
|
||||
|
||||
The file will be placed in the sandbox's attachments directory.
|
||||
"""
|
||||
user_id: UUID = user.id
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="File has no filename")
|
||||
|
||||
# Read file content
|
||||
content = await file.read()
|
||||
|
||||
# Validate file (extension, mime type, size)
|
||||
is_valid, error = validate_file(file.filename, file.content_type, len(content))
|
||||
if not is_valid:
|
||||
raise HTTPException(status_code=400, detail=error)
|
||||
|
||||
# Sanitize filename
|
||||
safe_filename = sanitize_filename(file.filename)
|
||||
|
||||
try:
|
||||
relative_path, _ = session_manager.upload_file(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
filename=safe_filename,
|
||||
content=content,
|
||||
)
|
||||
except UploadLimitExceededError as e:
|
||||
# Return 429 for limit exceeded errors
|
||||
raise HTTPException(status_code=429, detail=str(e))
|
||||
except ValueError as e:
|
||||
error_message = str(e)
|
||||
if "not found" in error_message.lower():
|
||||
raise HTTPException(status_code=404, detail=error_message)
|
||||
raise HTTPException(status_code=400, detail=error_message)
|
||||
|
||||
return UploadResponse(
|
||||
filename=safe_filename,
|
||||
path=relative_path,
|
||||
size_bytes=len(content),
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{session_id}/files/{path:path}", response_model=None)
|
||||
def delete_file_endpoint(
|
||||
session_id: UUID,
|
||||
path: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
"""Delete a file from the session's sandbox.
|
||||
|
||||
Args:
|
||||
session_id: The session ID
|
||||
path: Relative path to the file (e.g., "attachments/doc.pdf")
|
||||
"""
|
||||
user_id: UUID = user.id
|
||||
session_manager = SessionManager(db_session)
|
||||
|
||||
try:
|
||||
deleted = session_manager.delete_file(session_id, user_id, path)
|
||||
except ValueError as e:
|
||||
error_message = str(e)
|
||||
if "path traversal" in error_message.lower():
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
elif "not found" in error_message.lower():
|
||||
raise HTTPException(status_code=404, detail=error_message)
|
||||
elif "directory" in error_message.lower():
|
||||
raise HTTPException(status_code=400, detail="Cannot delete directory")
|
||||
raise HTTPException(status_code=400, detail=error_message)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
return Response(status_code=204)
|
||||
52
backend/onyx/server/features/build/api/subscription_check.py
Normal file
52
backend/onyx/server/features/build/api/subscription_check.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Subscription detection for Build Mode rate limiting."""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.db.models import User
|
||||
from onyx.server.usage_limits import is_tenant_on_trial_fn
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_user_subscribed(user: User, db_session: Session) -> bool:
|
||||
"""
|
||||
Check if a user has an active subscription.
|
||||
|
||||
For cloud (MULTI_TENANT=true):
|
||||
- Checks Stripe billing via control plane
|
||||
- Returns True if tenant is NOT on trial (subscribed = NOT on trial)
|
||||
|
||||
For self-hosted (MULTI_TENANT=false):
|
||||
- Checks license metadata
|
||||
- Returns True if license status is ACTIVE
|
||||
|
||||
Args:
|
||||
user: The user object (None for unauthenticated users)
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
True if user has active subscription, False otherwise
|
||||
"""
|
||||
if DEV_MODE:
|
||||
return True
|
||||
|
||||
if user is None:
|
||||
return False
|
||||
|
||||
if MULTI_TENANT:
|
||||
# Cloud: check Stripe billing via control plane
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
on_trial = is_tenant_on_trial_fn(tenant_id)
|
||||
# Subscribed = NOT on trial
|
||||
return not on_trial
|
||||
except Exception as e:
|
||||
logger.warning(f"Subscription check failed for tenant {tenant_id}: {e}")
|
||||
# Default to non-subscribed (safer/more restrictive)
|
||||
return False
|
||||
|
||||
return True
|
||||
117
backend/onyx/server/features/build/configs.py
Normal file
117
backend/onyx/server/features/build/configs.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import os
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class SandboxBackend(str, Enum):
|
||||
"""Backend mode for sandbox operations.
|
||||
|
||||
LOCAL: Development mode - no snapshots, no automatic cleanup
|
||||
KUBERNETES: Production mode - full snapshots and cleanup
|
||||
"""
|
||||
|
||||
LOCAL = "local"
|
||||
KUBERNETES = "kubernetes"
|
||||
|
||||
|
||||
# Sandbox backend mode (controls snapshot and cleanup behavior)
|
||||
# "local" = no snapshots, no cleanup (for development)
|
||||
# "kubernetes" = full snapshots and cleanup (for production)
|
||||
SANDBOX_BACKEND = SandboxBackend(os.environ.get("SANDBOX_BACKEND", "local"))
|
||||
|
||||
|
||||
# Persistent Document Storage Configuration
|
||||
# When enabled, indexed documents are written to local filesystem with hierarchical structure
|
||||
PERSISTENT_DOCUMENT_STORAGE_ENABLED = (
|
||||
os.environ.get("PERSISTENT_DOCUMENT_STORAGE_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Base directory path for persistent document storage (local filesystem)
|
||||
# Example: /var/onyx/indexed-docs or /app/indexed-docs
|
||||
PERSISTENT_DOCUMENT_STORAGE_PATH = os.environ.get(
|
||||
"PERSISTENT_DOCUMENT_STORAGE_PATH", ""
|
||||
)
|
||||
|
||||
# Demo Data Path
|
||||
# Local: Source tree path (relative to this file)
|
||||
# Kubernetes: Baked into container image at /workspace/demo-data
|
||||
_THIS_FILE = Path(__file__)
|
||||
DEMO_DATA_PATH = str(
|
||||
_THIS_FILE.parent / "sandbox" / "kubernetes" / "docker" / "demo_data"
|
||||
)
|
||||
|
||||
# Sandbox filesystem paths
|
||||
SANDBOX_BASE_PATH = os.environ.get("SANDBOX_BASE_PATH", "/tmp/onyx-sandboxes")
|
||||
OUTPUTS_TEMPLATE_PATH = os.environ.get("OUTPUTS_TEMPLATE_PATH", "/templates/outputs")
|
||||
VENV_TEMPLATE_PATH = os.environ.get("VENV_TEMPLATE_PATH", "/templates/venv")
|
||||
|
||||
# Sandbox agent configuration
|
||||
SANDBOX_AGENT_COMMAND = os.environ.get("SANDBOX_AGENT_COMMAND", "opencode").split()
|
||||
|
||||
# OpenCode disabled tools (comma-separated list)
|
||||
# Available tools: bash, edit, write, read, grep, glob, list, lsp, patch,
|
||||
# skill, todowrite, todoread, webfetch, question
|
||||
# Example: "question,webfetch" to disable user questions and web fetching
|
||||
_disabled_tools_str = os.environ.get("OPENCODE_DISABLED_TOOLS", "question")
|
||||
OPENCODE_DISABLED_TOOLS: list[str] = [
|
||||
t.strip() for t in _disabled_tools_str.split(",") if t.strip()
|
||||
]
|
||||
|
||||
# Sandbox lifecycle configuration
|
||||
SANDBOX_IDLE_TIMEOUT_SECONDS = int(
|
||||
os.environ.get("SANDBOX_IDLE_TIMEOUT_SECONDS", "3600")
|
||||
)
|
||||
SANDBOX_MAX_CONCURRENT_PER_ORG = int(
|
||||
os.environ.get("SANDBOX_MAX_CONCURRENT_PER_ORG", "10")
|
||||
)
|
||||
|
||||
# Sandbox snapshot storage
|
||||
SANDBOX_SNAPSHOTS_BUCKET = os.environ.get(
|
||||
"SANDBOX_SNAPSHOTS_BUCKET", "sandbox-snapshots"
|
||||
)
|
||||
|
||||
# Next.js preview server port range
|
||||
SANDBOX_NEXTJS_PORT_START = int(os.environ.get("SANDBOX_NEXTJS_PORT_START", "3010"))
|
||||
SANDBOX_NEXTJS_PORT_END = int(os.environ.get("SANDBOX_NEXTJS_PORT_END", "3100"))
|
||||
|
||||
# File upload configuration
|
||||
MAX_UPLOAD_FILE_SIZE_MB = int(os.environ.get("BUILD_MAX_UPLOAD_FILE_SIZE_MB", "50"))
|
||||
MAX_UPLOAD_FILE_SIZE_BYTES = MAX_UPLOAD_FILE_SIZE_MB * 1024 * 1024
|
||||
MAX_UPLOAD_FILES_PER_SESSION = int(
|
||||
os.environ.get("BUILD_MAX_UPLOAD_FILES_PER_SESSION", "20")
|
||||
)
|
||||
MAX_TOTAL_UPLOAD_SIZE_MB = int(os.environ.get("BUILD_MAX_TOTAL_UPLOAD_SIZE_MB", "200"))
|
||||
MAX_TOTAL_UPLOAD_SIZE_BYTES = MAX_TOTAL_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
ATTACHMENTS_DIRECTORY = "attachments"
|
||||
|
||||
# ============================================================================
|
||||
# Kubernetes Sandbox Configuration
|
||||
# Only used when SANDBOX_BACKEND = "kubernetes"
|
||||
# ============================================================================
|
||||
|
||||
# Namespace where sandbox pods are created
|
||||
SANDBOX_NAMESPACE = os.environ.get("SANDBOX_NAMESPACE", "onyx-sandboxes")
|
||||
|
||||
# Container image for sandbox pods
|
||||
# Should include Next.js template and opencode CLI
|
||||
SANDBOX_CONTAINER_IMAGE = os.environ.get(
|
||||
"SANDBOX_CONTAINER_IMAGE", "onyxdotapp/sandbox:v0.1.0"
|
||||
)
|
||||
|
||||
# S3 bucket for sandbox file storage (snapshots, knowledge files, uploads)
|
||||
# Path structure: s3://{bucket}/{tenant_id}/snapshots/{session_id}/{snapshot_id}.tar.gz
|
||||
# s3://{bucket}/{tenant_id}/knowledge/{user_id}/
|
||||
# s3://{bucket}/{tenant_id}/uploads/{session_id}/
|
||||
SANDBOX_S3_BUCKET = os.environ.get("SANDBOX_S3_BUCKET", "onyx-sandbox-files")
|
||||
|
||||
# Service account for sandbox pods (NO IRSA - no AWS API access)
|
||||
SANDBOX_SERVICE_ACCOUNT_NAME = os.environ.get(
|
||||
"SANDBOX_SERVICE_ACCOUNT_NAME", "sandbox-runner"
|
||||
)
|
||||
|
||||
# Service account for init container (has IRSA for S3 access)
|
||||
SANDBOX_FILE_SYNC_SERVICE_ACCOUNT = os.environ.get(
|
||||
"SANDBOX_FILE_SYNC_SERVICE_ACCOUNT", "sandbox-file-sync"
|
||||
)
|
||||
|
||||
ENABLE_CRAFT = os.environ.get("ENABLE_CRAFT", "false").lower() == "true"
|
||||
1
backend/onyx/server/features/build/db/__init__.py
Normal file
1
backend/onyx/server/features/build/db/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Database operations for the build feature
|
||||
544
backend/onyx/server/features/build/db/build_session.py
Normal file
544
backend/onyx/server/features/build/db/build_session.py
Normal file
@@ -0,0 +1,544 @@
|
||||
"""Database operations for Build Mode sessions."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.enums import BuildSessionStatus
|
||||
from onyx.db.enums import SandboxStatus
|
||||
from onyx.db.models import Artifact
|
||||
from onyx.db.models import BuildMessage
|
||||
from onyx.db.models import BuildSession
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
from onyx.db.models import Sandbox
|
||||
from onyx.db.models import Snapshot
|
||||
from onyx.server.features.build.configs import SANDBOX_NEXTJS_PORT_END
|
||||
from onyx.server.features.build.configs import SANDBOX_NEXTJS_PORT_START
|
||||
from onyx.server.manage.llm.models import LLMProviderView
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def create_build_session__no_commit(
|
||||
user_id: UUID,
|
||||
db_session: Session,
|
||||
name: str | None = None,
|
||||
) -> BuildSession:
|
||||
"""Create a new build session for the given user.
|
||||
|
||||
NOTE: This function uses flush() instead of commit(). The caller is
|
||||
responsible for committing the transaction when ready.
|
||||
"""
|
||||
session = BuildSession(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
status=BuildSessionStatus.ACTIVE,
|
||||
)
|
||||
db_session.add(session)
|
||||
db_session.flush()
|
||||
|
||||
logger.info(f"Created build session {session.id} for user {user_id}")
|
||||
return session
|
||||
|
||||
|
||||
def get_build_session(
|
||||
session_id: UUID,
|
||||
user_id: UUID,
|
||||
db_session: Session,
|
||||
) -> BuildSession | None:
|
||||
"""Get a build session by ID, ensuring it belongs to the user."""
|
||||
return (
|
||||
db_session.query(BuildSession)
|
||||
.filter(
|
||||
BuildSession.id == session_id,
|
||||
BuildSession.user_id == user_id,
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
|
||||
def get_user_build_sessions(
|
||||
user_id: UUID,
|
||||
db_session: Session,
|
||||
limit: int = 100,
|
||||
) -> list[BuildSession]:
|
||||
"""Get all build sessions for a user that have at least 1 message.
|
||||
|
||||
Excludes empty (pre-provisioned) sessions from the listing.
|
||||
"""
|
||||
return (
|
||||
db_session.query(BuildSession)
|
||||
.join(BuildMessage) # Inner join excludes empty sessions
|
||||
.filter(BuildSession.user_id == user_id)
|
||||
.group_by(BuildSession.id)
|
||||
.order_by(desc(BuildSession.created_at))
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def get_empty_session_for_user(
|
||||
user_id: UUID,
|
||||
db_session: Session,
|
||||
max_age_minutes: int = 30,
|
||||
) -> BuildSession | None:
|
||||
"""Get the user's empty session (0 messages) if one exists and is recent."""
|
||||
cutoff = datetime.utcnow() - timedelta(minutes=max_age_minutes)
|
||||
|
||||
return (
|
||||
db_session.query(BuildSession)
|
||||
.filter(
|
||||
BuildSession.user_id == user_id,
|
||||
BuildSession.created_at > cutoff,
|
||||
~exists().where(BuildMessage.session_id == BuildSession.id),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def update_session_activity(
|
||||
session_id: UUID,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update the last activity timestamp for a session."""
|
||||
session = (
|
||||
db_session.query(BuildSession)
|
||||
.filter(BuildSession.id == session_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if session:
|
||||
session.last_activity_at = datetime.utcnow()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_session_status(
|
||||
session_id: UUID,
|
||||
status: BuildSessionStatus,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update the status of a build session."""
|
||||
session = (
|
||||
db_session.query(BuildSession)
|
||||
.filter(BuildSession.id == session_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if session:
|
||||
session.status = status
|
||||
db_session.commit()
|
||||
logger.info(f"Updated build session {session_id} status to {status}")
|
||||
|
||||
|
||||
def delete_build_session__no_commit(
|
||||
session_id: UUID,
|
||||
user_id: UUID,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
"""Delete a build session and all related data.
|
||||
|
||||
NOTE: This function uses flush() instead of commit(). The caller is
|
||||
responsible for committing the transaction when ready.
|
||||
"""
|
||||
session = get_build_session(session_id, user_id, db_session)
|
||||
if not session:
|
||||
return False
|
||||
|
||||
db_session.delete(session)
|
||||
db_session.flush()
|
||||
logger.info(f"Deleted build session {session_id}")
|
||||
return True
|
||||
|
||||
|
||||
# Sandbox operations
|
||||
# NOTE: Most sandbox operations have moved to sandbox.py
|
||||
# These remain here for convenience in session-related workflows
|
||||
|
||||
|
||||
def update_sandbox_status(
|
||||
sandbox_id: UUID,
|
||||
status: SandboxStatus,
|
||||
db_session: Session,
|
||||
container_id: str | None = None,
|
||||
) -> None:
|
||||
"""Update the status of a sandbox."""
|
||||
sandbox = db_session.query(Sandbox).filter(Sandbox.id == sandbox_id).one_or_none()
|
||||
if sandbox:
|
||||
sandbox.status = status
|
||||
if container_id is not None:
|
||||
sandbox.container_id = container_id
|
||||
sandbox.last_heartbeat = datetime.utcnow()
|
||||
db_session.commit()
|
||||
logger.info(f"Updated sandbox {sandbox_id} status to {status}")
|
||||
|
||||
|
||||
def update_sandbox_heartbeat(
|
||||
sandbox_id: UUID,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update the heartbeat timestamp for a sandbox."""
|
||||
sandbox = db_session.query(Sandbox).filter(Sandbox.id == sandbox_id).one_or_none()
|
||||
if sandbox:
|
||||
sandbox.last_heartbeat = datetime.utcnow()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# Artifact operations
|
||||
def create_artifact(
|
||||
session_id: UUID,
|
||||
artifact_type: str,
|
||||
path: str,
|
||||
name: str,
|
||||
db_session: Session,
|
||||
) -> Artifact:
|
||||
"""Create a new artifact record."""
|
||||
artifact = Artifact(
|
||||
session_id=session_id,
|
||||
type=artifact_type,
|
||||
path=path,
|
||||
name=name,
|
||||
)
|
||||
db_session.add(artifact)
|
||||
db_session.commit()
|
||||
db_session.refresh(artifact)
|
||||
|
||||
logger.info(f"Created artifact {artifact.id} for session {session_id}")
|
||||
return artifact
|
||||
|
||||
|
||||
def get_session_artifacts(
|
||||
session_id: UUID,
|
||||
db_session: Session,
|
||||
) -> list[Artifact]:
|
||||
"""Get all artifacts for a session."""
|
||||
return (
|
||||
db_session.query(Artifact)
|
||||
.filter(Artifact.session_id == session_id)
|
||||
.order_by(desc(Artifact.created_at))
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def update_artifact(
|
||||
artifact_id: UUID,
|
||||
db_session: Session,
|
||||
path: str | None = None,
|
||||
name: str | None = None,
|
||||
) -> None:
|
||||
"""Update artifact metadata."""
|
||||
artifact = (
|
||||
db_session.query(Artifact).filter(Artifact.id == artifact_id).one_or_none()
|
||||
)
|
||||
if artifact:
|
||||
if path is not None:
|
||||
artifact.path = path
|
||||
if name is not None:
|
||||
artifact.name = name
|
||||
artifact.updated_at = datetime.utcnow()
|
||||
db_session.commit()
|
||||
logger.info(f"Updated artifact {artifact_id}")
|
||||
|
||||
|
||||
# Snapshot operations
|
||||
def create_snapshot(
|
||||
session_id: UUID,
|
||||
storage_path: str,
|
||||
size_bytes: int,
|
||||
db_session: Session,
|
||||
) -> Snapshot:
|
||||
"""Create a new snapshot record."""
|
||||
snapshot = Snapshot(
|
||||
session_id=session_id,
|
||||
storage_path=storage_path,
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
db_session.add(snapshot)
|
||||
db_session.commit()
|
||||
db_session.refresh(snapshot)
|
||||
|
||||
logger.info(f"Created snapshot {snapshot.id} for session {session_id}")
|
||||
return snapshot
|
||||
|
||||
|
||||
# Message operations
|
||||
def create_message(
|
||||
session_id: UUID,
|
||||
message_type: MessageType,
|
||||
turn_index: int,
|
||||
message_metadata: dict[str, Any],
|
||||
db_session: Session,
|
||||
) -> BuildMessage:
|
||||
"""Create a new message in a build session.
|
||||
|
||||
All message data is stored in message_metadata as JSON.
|
||||
|
||||
Args:
|
||||
session_id: Session UUID
|
||||
message_type: Type of message (USER, ASSISTANT, SYSTEM)
|
||||
turn_index: 0-indexed user message number this message belongs to
|
||||
message_metadata: Required structured data (the raw ACP packet JSON)
|
||||
db_session: Database session
|
||||
"""
|
||||
message = BuildMessage(
|
||||
session_id=session_id,
|
||||
turn_index=turn_index,
|
||||
type=message_type,
|
||||
message_metadata=message_metadata,
|
||||
)
|
||||
db_session.add(message)
|
||||
db_session.commit()
|
||||
db_session.refresh(message)
|
||||
|
||||
logger.info(
|
||||
f"Created {message_type.value} message {message.id} for session {session_id} "
|
||||
f"turn={turn_index} type={message_metadata.get('type')}"
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
def update_message(
|
||||
message_id: UUID,
|
||||
message_metadata: dict[str, Any],
|
||||
db_session: Session,
|
||||
) -> BuildMessage | None:
|
||||
"""Update an existing message's metadata.
|
||||
|
||||
Used for upserting agent_plan_update messages.
|
||||
|
||||
Args:
|
||||
message_id: The message UUID to update
|
||||
message_metadata: New metadata to set
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
Updated BuildMessage or None if not found
|
||||
"""
|
||||
message = (
|
||||
db_session.query(BuildMessage).filter(BuildMessage.id == message_id).first()
|
||||
)
|
||||
if message is None:
|
||||
return None
|
||||
|
||||
message.message_metadata = message_metadata
|
||||
db_session.commit()
|
||||
db_session.refresh(message)
|
||||
|
||||
logger.info(
|
||||
f"Updated message {message_id} metadata type={message_metadata.get('type')}"
|
||||
)
|
||||
return message
|
||||
|
||||
|
||||
def upsert_agent_plan(
|
||||
session_id: UUID,
|
||||
turn_index: int,
|
||||
plan_metadata: dict[str, Any],
|
||||
db_session: Session,
|
||||
existing_plan_id: UUID | None = None,
|
||||
) -> BuildMessage:
|
||||
"""Upsert an agent plan - update if exists, create if not.
|
||||
|
||||
Each session/turn should only have one agent_plan_update message.
|
||||
This function updates the existing plan message or creates a new one.
|
||||
|
||||
Args:
|
||||
session_id: Session UUID
|
||||
turn_index: Current turn index
|
||||
plan_metadata: The agent_plan_update packet data
|
||||
db_session: Database session
|
||||
existing_plan_id: ID of existing plan message to update (if known)
|
||||
|
||||
Returns:
|
||||
The created or updated BuildMessage
|
||||
"""
|
||||
if existing_plan_id:
|
||||
# Fast path: we know the plan ID
|
||||
updated = update_message(existing_plan_id, plan_metadata, db_session)
|
||||
if updated:
|
||||
return updated
|
||||
|
||||
# Check if a plan already exists for this session/turn
|
||||
existing_plan = (
|
||||
db_session.query(BuildMessage)
|
||||
.filter(
|
||||
BuildMessage.session_id == session_id,
|
||||
BuildMessage.turn_index == turn_index,
|
||||
BuildMessage.message_metadata["type"].astext == "agent_plan_update",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_plan:
|
||||
existing_plan.message_metadata = plan_metadata
|
||||
db_session.commit()
|
||||
db_session.refresh(existing_plan)
|
||||
logger.info(
|
||||
f"Updated agent_plan_update message {existing_plan.id} for session {session_id}"
|
||||
)
|
||||
return existing_plan
|
||||
|
||||
# Create new plan message
|
||||
return create_message(
|
||||
session_id=session_id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
turn_index=turn_index,
|
||||
message_metadata=plan_metadata,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def get_session_messages(
|
||||
session_id: UUID,
|
||||
db_session: Session,
|
||||
) -> list[BuildMessage]:
|
||||
"""Get all messages for a session, ordered by turn index and creation time."""
|
||||
return (
|
||||
db_session.query(BuildMessage)
|
||||
.filter(BuildMessage.session_id == session_id)
|
||||
.order_by(BuildMessage.turn_index, BuildMessage.created_at)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
def _is_port_available(port: int) -> bool:
|
||||
"""Check if a port is available by attempting to bind to it.
|
||||
|
||||
Checks both IPv4 and IPv6 wildcard addresses to properly detect
|
||||
if anything is listening on the port, regardless of address family.
|
||||
"""
|
||||
import socket
|
||||
|
||||
logger.debug(f"Checking if port {port} is available")
|
||||
|
||||
# Check IPv4 wildcard (0.0.0.0) - this will detect any IPv4 listener
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(("0.0.0.0", port))
|
||||
logger.debug(f"Port {port} IPv4 wildcard bind successful")
|
||||
except OSError as e:
|
||||
logger.debug(f"Port {port} IPv4 wildcard not available: {e}")
|
||||
return False
|
||||
|
||||
# Check IPv6 wildcard (::) - this will detect any IPv6 listener
|
||||
try:
|
||||
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
# IPV6_V6ONLY must be False to allow dual-stack behavior
|
||||
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
|
||||
sock.bind(("::", port))
|
||||
logger.debug(f"Port {port} IPv6 wildcard bind successful")
|
||||
except OSError as e:
|
||||
logger.debug(f"Port {port} IPv6 wildcard not available: {e}")
|
||||
return False
|
||||
|
||||
logger.debug(f"Port {port} is available")
|
||||
return True
|
||||
|
||||
|
||||
def allocate_nextjs_port(db_session: Session) -> int:
|
||||
"""Allocate an available port for a new session.
|
||||
|
||||
Finds the first available port in the configured range by checking
|
||||
both database allocations and system-level port availability.
|
||||
|
||||
Args:
|
||||
db_session: Database session for querying allocated ports
|
||||
|
||||
Returns:
|
||||
An available port number
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no ports are available in the configured range
|
||||
"""
|
||||
from onyx.db.models import BuildSession
|
||||
|
||||
# Get all currently allocated ports from active sessions
|
||||
allocated_ports = set(
|
||||
db_session.query(BuildSession.nextjs_port)
|
||||
.filter(BuildSession.nextjs_port.isnot(None))
|
||||
.all()
|
||||
)
|
||||
allocated_ports = {port[0] for port in allocated_ports if port[0] is not None}
|
||||
|
||||
# Find first port that's not in DB and not currently bound
|
||||
for port in range(SANDBOX_NEXTJS_PORT_START, SANDBOX_NEXTJS_PORT_END):
|
||||
if port not in allocated_ports and _is_port_available(port):
|
||||
return port
|
||||
|
||||
raise RuntimeError(
|
||||
f"No available ports in range [{SANDBOX_NEXTJS_PORT_START}, {SANDBOX_NEXTJS_PORT_END})"
|
||||
)
|
||||
|
||||
|
||||
def clear_nextjs_ports_for_user(db_session: Session, user_id: UUID) -> int:
|
||||
"""Clear nextjs_port for all sessions belonging to a user.
|
||||
|
||||
Called when sandbox goes to sleep to release port allocations.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
user_id: The user whose sessions should have ports cleared
|
||||
|
||||
Returns:
|
||||
Number of sessions updated
|
||||
"""
|
||||
result = (
|
||||
db_session.query(BuildSession)
|
||||
.filter(
|
||||
BuildSession.user_id == user_id,
|
||||
BuildSession.nextjs_port.isnot(None),
|
||||
)
|
||||
.update({BuildSession.nextjs_port: None})
|
||||
)
|
||||
db_session.flush()
|
||||
logger.info(f"Cleared {result} nextjs_port allocations for user {user_id}")
|
||||
return result
|
||||
|
||||
|
||||
def fetch_llm_provider_by_type_for_build_mode(
|
||||
db_session: Session, provider_type: str
|
||||
) -> LLMProviderView | None:
|
||||
"""Fetch an LLM provider by its provider type (e.g., "anthropic", "openai").
|
||||
|
||||
Resolution priority:
|
||||
1. First try to find a provider named "build-mode-{type}" (e.g., "build-mode-anthropic")
|
||||
2. If not found, fall back to any provider that matches the type
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
provider_type: The provider type (e.g., "anthropic", "openai", "openrouter")
|
||||
|
||||
Returns:
|
||||
LLMProviderView if found, None otherwise
|
||||
"""
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
|
||||
# First try to find a "build-mode-{type}" provider
|
||||
build_mode_name = f"build-mode-{provider_type}"
|
||||
provider_model = fetch_existing_llm_provider(
|
||||
name=build_mode_name, db_session=db_session
|
||||
)
|
||||
|
||||
# If not found, fall back to any provider that matches the type
|
||||
if not provider_model:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel)
|
||||
.where(LLMProviderModel.provider == provider_type)
|
||||
.options(
|
||||
selectinload(LLMProviderModel.model_configurations),
|
||||
selectinload(LLMProviderModel.groups),
|
||||
selectinload(LLMProviderModel.personas),
|
||||
)
|
||||
)
|
||||
|
||||
if not provider_model:
|
||||
return None
|
||||
return LLMProviderView.from_model(provider_model)
|
||||
96
backend/onyx/server/features/build/db/rate_limit.py
Normal file
96
backend/onyx/server/features/build/db/rate_limit.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Database queries for Build Mode rate limiting."""
|
||||
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.models import BuildMessage
|
||||
from onyx.db.models import BuildSession
|
||||
|
||||
|
||||
def count_user_messages_in_window(
|
||||
user_id: UUID,
|
||||
cutoff_time: datetime,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
"""
|
||||
Count USER messages for a user since cutoff_time.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID
|
||||
cutoff_time: Only count messages created at or after this time
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
Number of USER messages in the time window
|
||||
"""
|
||||
return (
|
||||
db_session.query(func.count(BuildMessage.id))
|
||||
.join(BuildSession, BuildMessage.session_id == BuildSession.id)
|
||||
.filter(
|
||||
BuildSession.user_id == user_id,
|
||||
BuildMessage.type == MessageType.USER,
|
||||
BuildMessage.created_at >= cutoff_time,
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
|
||||
def count_user_messages_total(user_id: UUID, db_session: Session) -> int:
|
||||
"""
|
||||
Count all USER messages for a user (lifetime total).
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
Total number of USER messages
|
||||
"""
|
||||
return (
|
||||
db_session.query(func.count(BuildMessage.id))
|
||||
.join(BuildSession, BuildMessage.session_id == BuildSession.id)
|
||||
.filter(
|
||||
BuildSession.user_id == user_id,
|
||||
BuildMessage.type == MessageType.USER,
|
||||
)
|
||||
.scalar()
|
||||
or 0
|
||||
)
|
||||
|
||||
|
||||
def get_oldest_message_timestamp(
|
||||
user_id: UUID,
|
||||
cutoff_time: datetime,
|
||||
db_session: Session,
|
||||
) -> datetime | None:
|
||||
"""
|
||||
Get the timestamp of the oldest USER message in the time window.
|
||||
|
||||
Used to calculate when the rate limit will reset (when the oldest
|
||||
message ages out of the rolling window).
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID
|
||||
cutoff_time: Only consider messages created at or after this time
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
Timestamp of oldest message in window, or None if no messages
|
||||
"""
|
||||
return (
|
||||
db_session.query(BuildMessage.created_at)
|
||||
.join(BuildSession, BuildMessage.session_id == BuildSession.id)
|
||||
.filter(
|
||||
BuildSession.user_id == user_id,
|
||||
BuildMessage.type == MessageType.USER,
|
||||
BuildMessage.created_at >= cutoff_time,
|
||||
)
|
||||
.order_by(BuildMessage.created_at.asc())
|
||||
.limit(1)
|
||||
.scalar()
|
||||
)
|
||||
206
backend/onyx/server/features/build/db/sandbox.py
Normal file
206
backend/onyx/server/features/build/db/sandbox.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""Database operations for CLI agent sandbox management."""
|
||||
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import SandboxStatus
|
||||
from onyx.db.models import Sandbox
|
||||
from onyx.db.models import Snapshot
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def create_sandbox__no_commit(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
) -> Sandbox:
|
||||
"""Create a new sandbox record for a user.
|
||||
|
||||
NOTE: This function uses flush() instead of commit(). The caller is
|
||||
responsible for committing the transaction when ready.
|
||||
"""
|
||||
sandbox = Sandbox(
|
||||
user_id=user_id,
|
||||
status=SandboxStatus.PROVISIONING,
|
||||
)
|
||||
db_session.add(sandbox)
|
||||
db_session.flush()
|
||||
return sandbox
|
||||
|
||||
|
||||
def get_sandbox_by_user_id(db_session: Session, user_id: UUID) -> Sandbox | None:
|
||||
"""Get sandbox by user ID (primary lookup method)."""
|
||||
stmt = select(Sandbox).where(Sandbox.user_id == user_id)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_sandbox_by_session_id(db_session: Session, session_id: UUID) -> Sandbox | None:
|
||||
"""Get sandbox by session ID (compatibility function).
|
||||
|
||||
This function provides backwards compatibility during the transition to
|
||||
user-owned sandboxes. It looks up the session's user_id, then finds the
|
||||
user's sandbox.
|
||||
|
||||
NOTE: This will be removed in a future phase when all callers are updated
|
||||
to use get_sandbox_by_user_id() directly.
|
||||
"""
|
||||
from onyx.db.models import BuildSession
|
||||
|
||||
stmt = select(BuildSession.user_id).where(BuildSession.id == session_id)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return get_sandbox_by_user_id(db_session, result)
|
||||
|
||||
|
||||
def get_sandbox_by_id(db_session: Session, sandbox_id: UUID) -> Sandbox | None:
|
||||
"""Get sandbox by its ID."""
|
||||
stmt = select(Sandbox).where(Sandbox.id == sandbox_id)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def update_sandbox_status__no_commit(
|
||||
db_session: Session,
|
||||
sandbox_id: UUID,
|
||||
status: SandboxStatus,
|
||||
) -> Sandbox:
|
||||
"""Update sandbox status.
|
||||
|
||||
NOTE: This function uses flush() instead of commit(). The caller is
|
||||
responsible for committing the transaction when ready.
|
||||
"""
|
||||
sandbox = get_sandbox_by_id(db_session, sandbox_id)
|
||||
if not sandbox:
|
||||
raise ValueError(f"Sandbox {sandbox_id} not found")
|
||||
|
||||
sandbox.status = status
|
||||
db_session.flush()
|
||||
return sandbox
|
||||
|
||||
|
||||
def update_sandbox_heartbeat(db_session: Session, sandbox_id: UUID) -> Sandbox:
|
||||
"""Update sandbox last_heartbeat to now."""
|
||||
sandbox = get_sandbox_by_id(db_session, sandbox_id)
|
||||
if not sandbox:
|
||||
raise ValueError(f"Sandbox {sandbox_id} not found")
|
||||
|
||||
sandbox.last_heartbeat = datetime.datetime.now(datetime.timezone.utc)
|
||||
db_session.commit()
|
||||
return sandbox
|
||||
|
||||
|
||||
def get_idle_sandboxes(
|
||||
db_session: Session, idle_threshold_seconds: int
|
||||
) -> list[Sandbox]:
|
||||
"""Get sandboxes that have been idle longer than threshold."""
|
||||
threshold_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
seconds=idle_threshold_seconds
|
||||
)
|
||||
|
||||
stmt = select(Sandbox).where(
|
||||
Sandbox.status.in_([SandboxStatus.RUNNING, SandboxStatus.IDLE]),
|
||||
Sandbox.last_heartbeat < threshold_time,
|
||||
)
|
||||
return list(db_session.execute(stmt).scalars().all())
|
||||
|
||||
|
||||
def get_running_sandbox_count_by_tenant(db_session: Session, tenant_id: str) -> int:
|
||||
"""Get count of running sandboxes for a tenant (for limit enforcement).
|
||||
|
||||
Note: tenant_id parameter is kept for API compatibility but is not used
|
||||
since Sandbox model no longer has tenant_id. This function returns
|
||||
the count of all running sandboxes.
|
||||
"""
|
||||
stmt = select(func.count(Sandbox.id)).where(
|
||||
Sandbox.status.in_([SandboxStatus.RUNNING, SandboxStatus.IDLE])
|
||||
)
|
||||
result = db_session.execute(stmt).scalar()
|
||||
return result or 0
|
||||
|
||||
|
||||
def create_snapshot(
|
||||
db_session: Session,
|
||||
session_id: UUID,
|
||||
storage_path: str,
|
||||
size_bytes: int,
|
||||
) -> Snapshot:
|
||||
"""Create a snapshot record for a session."""
|
||||
snapshot = Snapshot(
|
||||
session_id=session_id,
|
||||
storage_path=storage_path,
|
||||
size_bytes=size_bytes,
|
||||
)
|
||||
db_session.add(snapshot)
|
||||
db_session.commit()
|
||||
return snapshot
|
||||
|
||||
|
||||
def get_latest_snapshot_for_session(
|
||||
db_session: Session, session_id: UUID
|
||||
) -> Snapshot | None:
|
||||
"""Get most recent snapshot for a session."""
|
||||
stmt = (
|
||||
select(Snapshot)
|
||||
.where(Snapshot.session_id == session_id)
|
||||
.order_by(Snapshot.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_snapshots_for_session(db_session: Session, session_id: UUID) -> list[Snapshot]:
|
||||
"""Get all snapshots for a session, ordered by creation time descending."""
|
||||
stmt = (
|
||||
select(Snapshot)
|
||||
.where(Snapshot.session_id == session_id)
|
||||
.order_by(Snapshot.created_at.desc())
|
||||
)
|
||||
return list(db_session.execute(stmt).scalars().all())
|
||||
|
||||
|
||||
def delete_old_snapshots(
|
||||
db_session: Session, tenant_id: str, retention_days: int
|
||||
) -> int:
|
||||
"""Delete snapshots older than retention period, return count deleted.
|
||||
|
||||
Note: tenant_id parameter is kept for API compatibility but is not used
|
||||
since Snapshot model no longer has tenant_id. This function deletes
|
||||
all snapshots older than the retention period.
|
||||
"""
|
||||
cutoff_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
days=retention_days
|
||||
)
|
||||
|
||||
stmt = select(Snapshot).where(
|
||||
Snapshot.created_at < cutoff_time,
|
||||
)
|
||||
old_snapshots = db_session.execute(stmt).scalars().all()
|
||||
|
||||
count = 0
|
||||
for snapshot in old_snapshots:
|
||||
db_session.delete(snapshot)
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
db_session.commit()
|
||||
|
||||
return count
|
||||
|
||||
|
||||
def delete_snapshot(db_session: Session, snapshot_id: UUID) -> bool:
|
||||
"""Delete a specific snapshot by ID. Returns True if deleted, False if not found."""
|
||||
stmt = select(Snapshot).where(Snapshot.id == snapshot_id)
|
||||
snapshot = db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not snapshot:
|
||||
return False
|
||||
|
||||
db_session.delete(snapshot)
|
||||
db_session.commit()
|
||||
return True
|
||||
@@ -0,0 +1,397 @@
|
||||
"""
|
||||
Persistent Document Writer for writing indexed documents to local filesystem or S3 with
|
||||
hierarchical directory structure that mirrors the source organization.
|
||||
|
||||
Local mode (SandboxBackend.LOCAL):
|
||||
Writes to local filesystem at {PERSISTENT_DOCUMENT_STORAGE_PATH}/{tenant_id}/knowledge/{user_id}/...
|
||||
|
||||
Kubernetes mode (SandboxBackend.KUBERNETES):
|
||||
Writes to S3 at s3://{SANDBOX_S3_BUCKET}/{tenant_id}/knowledge/{user_id}/...
|
||||
This is the same location that kubernetes_sandbox_manager.py reads from when
|
||||
provisioning sandboxes.
|
||||
|
||||
Both modes use consistent tenant/user-segregated paths for multi-tenant isolation.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
from mypy_boto3_s3.client import S3Client
|
||||
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.server.features.build.configs import PERSISTENT_DOCUMENT_STORAGE_PATH
|
||||
from onyx.server.features.build.configs import SANDBOX_BACKEND
|
||||
from onyx.server.features.build.configs import SANDBOX_S3_BUCKET
|
||||
from onyx.server.features.build.configs import SandboxBackend
|
||||
from onyx.server.features.build.s3.s3_client import build_s3_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Shared Utilities for Path Building
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def sanitize_path_component(component: str, replace_slash: bool = True) -> str:
|
||||
"""Sanitize a path component for file system / S3 key safety.
|
||||
|
||||
Args:
|
||||
component: The path component to sanitize
|
||||
replace_slash: If True, replaces forward slashes (needed for local filesystem).
|
||||
Set to False for S3 where `/` is a valid delimiter.
|
||||
|
||||
Returns:
|
||||
Sanitized path component safe for use in file paths or S3 keys
|
||||
"""
|
||||
# Replace spaces with underscores
|
||||
sanitized = component.replace(" ", "_")
|
||||
# Replace problematic characters
|
||||
if replace_slash:
|
||||
sanitized = sanitized.replace("/", "_")
|
||||
sanitized = sanitized.replace("\\", "_").replace(":", "_")
|
||||
sanitized = sanitized.replace("<", "_").replace(">", "_").replace("|", "_")
|
||||
sanitized = sanitized.replace('"', "_").replace("?", "_").replace("*", "_")
|
||||
# Also handle null bytes and other control characters
|
||||
sanitized = "".join(c for c in sanitized if ord(c) >= 32)
|
||||
return sanitized.strip() or "unnamed"
|
||||
|
||||
|
||||
def sanitize_filename(name: str, replace_slash: bool = True) -> str:
|
||||
"""Sanitize name for use as filename.
|
||||
|
||||
Args:
|
||||
name: The filename to sanitize
|
||||
replace_slash: Passed through to sanitize_path_component
|
||||
|
||||
Returns:
|
||||
Sanitized filename, truncated with hash suffix if too long
|
||||
"""
|
||||
sanitized = sanitize_path_component(name, replace_slash=replace_slash)
|
||||
if len(sanitized) > 200:
|
||||
# Keep first 150 chars + hash suffix for uniqueness
|
||||
hash_suffix = hashlib.sha256(name.encode()).hexdigest()[:16]
|
||||
return f"{sanitized[:150]}_{hash_suffix}"
|
||||
return sanitized
|
||||
|
||||
|
||||
def get_base_filename(doc: Document, replace_slash: bool = True) -> str:
|
||||
"""Get base filename from document, preferring semantic identifier.
|
||||
|
||||
Args:
|
||||
doc: The document to get filename for
|
||||
replace_slash: Passed through to sanitize_filename
|
||||
|
||||
Returns:
|
||||
Sanitized base filename (without extension)
|
||||
"""
|
||||
name = doc.semantic_identifier or doc.title or doc.id
|
||||
return sanitize_filename(name, replace_slash=replace_slash)
|
||||
|
||||
|
||||
def build_document_subpath(doc: Document, replace_slash: bool = True) -> list[str]:
|
||||
"""Build the source/hierarchy path components from a document.
|
||||
|
||||
Returns path components like: [source, hierarchy_part1, hierarchy_part2, ...]
|
||||
|
||||
This is the common part of the path that comes after user/tenant segregation.
|
||||
|
||||
Args:
|
||||
doc: The document to build path for
|
||||
replace_slash: Passed through to sanitize_path_component
|
||||
|
||||
Returns:
|
||||
List of sanitized path components
|
||||
"""
|
||||
parts: list[str] = []
|
||||
|
||||
# Source type (e.g., "google_drive", "confluence")
|
||||
parts.append(doc.source.value)
|
||||
|
||||
# Get hierarchy from doc_metadata
|
||||
hierarchy = doc.doc_metadata.get("hierarchy", {}) if doc.doc_metadata else {}
|
||||
source_path = hierarchy.get("source_path", [])
|
||||
|
||||
if source_path:
|
||||
parts.extend(
|
||||
[
|
||||
sanitize_path_component(p, replace_slash=replace_slash)
|
||||
for p in source_path
|
||||
]
|
||||
)
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def resolve_duplicate_filename(
|
||||
doc: Document,
|
||||
base_filename: str,
|
||||
has_duplicates: bool,
|
||||
replace_slash: bool = True,
|
||||
) -> str:
|
||||
"""Resolve filename, appending ID suffix if there are duplicates.
|
||||
|
||||
Args:
|
||||
doc: The document (for ID extraction)
|
||||
base_filename: The base filename without extension
|
||||
has_duplicates: Whether there are other docs with the same base filename
|
||||
replace_slash: Passed through to sanitize_path_component
|
||||
|
||||
Returns:
|
||||
Final filename with .json extension
|
||||
"""
|
||||
if has_duplicates:
|
||||
id_suffix = sanitize_path_component(doc.id, replace_slash=replace_slash)
|
||||
if len(id_suffix) > 50:
|
||||
id_suffix = hashlib.sha256(doc.id.encode()).hexdigest()[:16]
|
||||
return f"{base_filename}_{id_suffix}.json"
|
||||
return f"{base_filename}.json"
|
||||
|
||||
|
||||
def serialize_document(doc: Document) -> dict[str, Any]:
|
||||
"""Serialize a document to a dictionary for JSON storage.
|
||||
|
||||
Args:
|
||||
doc: The document to serialize
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the document
|
||||
"""
|
||||
return {
|
||||
"id": doc.id,
|
||||
"semantic_identifier": doc.semantic_identifier,
|
||||
"title": doc.title,
|
||||
"source": doc.source.value,
|
||||
"doc_updated_at": (
|
||||
doc.doc_updated_at.isoformat() if doc.doc_updated_at else None
|
||||
),
|
||||
"metadata": doc.metadata,
|
||||
"doc_metadata": doc.doc_metadata,
|
||||
"sections": [
|
||||
{"text": s.text if hasattr(s, "text") else None, "link": s.link}
|
||||
for s in doc.sections
|
||||
],
|
||||
"primary_owners": [o.model_dump() for o in (doc.primary_owners or [])],
|
||||
"secondary_owners": [o.model_dump() for o in (doc.secondary_owners or [])],
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Classes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class PersistentDocumentWriter:
|
||||
"""Writes indexed documents to local filesystem with hierarchical structure.
|
||||
|
||||
Documents are stored in tenant/user-segregated paths:
|
||||
{base_path}/{tenant_id}/knowledge/{user_id}/{source}/{hierarchy}/document.json
|
||||
|
||||
This enables per-tenant and per-user isolation for sandbox access control.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_path: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
):
|
||||
self.base_path = Path(base_path)
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
|
||||
def write_documents(self, documents: list[Document]) -> list[str]:
|
||||
"""Write documents to local filesystem, returns written file paths."""
|
||||
written_paths: list[str] = []
|
||||
|
||||
# Build a map of base filenames to detect duplicates
|
||||
# Key: (directory_path, base_filename) -> list of docs with that name
|
||||
filename_map: dict[tuple[Path, str], list[Document]] = {}
|
||||
|
||||
for doc in documents:
|
||||
dir_path = self._build_directory_path(doc)
|
||||
base_filename = get_base_filename(doc, replace_slash=True)
|
||||
key = (dir_path, base_filename)
|
||||
if key not in filename_map:
|
||||
filename_map[key] = []
|
||||
filename_map[key].append(doc)
|
||||
|
||||
# Now write documents, appending ID if there are duplicates
|
||||
for (dir_path, base_filename), docs in filename_map.items():
|
||||
has_duplicates = len(docs) > 1
|
||||
for doc in docs:
|
||||
filename = resolve_duplicate_filename(
|
||||
doc, base_filename, has_duplicates, replace_slash=True
|
||||
)
|
||||
path = dir_path / filename
|
||||
self._write_document(doc, path)
|
||||
written_paths.append(str(path))
|
||||
|
||||
return written_paths
|
||||
|
||||
def _build_directory_path(self, doc: Document) -> Path:
|
||||
"""Build directory path from document metadata.
|
||||
|
||||
Documents are stored under tenant/user-segregated paths:
|
||||
{base_path}/{tenant_id}/knowledge/{user_id}/{source}/{hierarchy}/
|
||||
|
||||
This enables per-tenant and per-user isolation for sandbox access control.
|
||||
"""
|
||||
# Tenant and user segregation prefix (matches S3 path structure)
|
||||
parts = [self.tenant_id, "knowledge", self.user_id]
|
||||
# Add source and hierarchy from document
|
||||
parts.extend(build_document_subpath(doc, replace_slash=True))
|
||||
|
||||
return self.base_path / "/".join(parts)
|
||||
|
||||
def _write_document(self, doc: Document, path: Path) -> None:
|
||||
"""Serialize and write document to filesystem."""
|
||||
content = serialize_document(doc)
|
||||
|
||||
# Create parent directories if they don't exist
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write the JSON file
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(content, f, indent=2, default=str)
|
||||
|
||||
logger.debug(f"Wrote document to {path}")
|
||||
|
||||
|
||||
class S3PersistentDocumentWriter:
|
||||
"""Writes indexed documents to S3 with hierarchical structure.
|
||||
|
||||
Documents are stored in tenant/user-segregated paths:
|
||||
s3://{bucket}/{tenant_id}/knowledge/{user_id}/{source}/{hierarchy}/document.json
|
||||
|
||||
This matches the location that KubernetesSandboxManager reads from when
|
||||
provisioning sandboxes (via the init container's aws s3 sync command).
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_id: str, user_id: str):
|
||||
"""Initialize S3PersistentDocumentWriter.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier for multi-tenant isolation
|
||||
user_id: User ID for user-segregated storage paths
|
||||
"""
|
||||
self.tenant_id = tenant_id
|
||||
self.user_id = user_id
|
||||
self.bucket = SANDBOX_S3_BUCKET
|
||||
self._s3_client: S3Client | None = None
|
||||
|
||||
def _get_s3_client(self) -> S3Client:
|
||||
"""Lazily initialize S3 client.
|
||||
|
||||
Uses the craft-specific boto3 client which only supports IAM roles (IRSA).
|
||||
"""
|
||||
if self._s3_client is None:
|
||||
self._s3_client = build_s3_client()
|
||||
return self._s3_client
|
||||
|
||||
def write_documents(self, documents: list[Document]) -> list[str]:
|
||||
"""Write documents to S3, returns written S3 keys.
|
||||
|
||||
Args:
|
||||
documents: List of documents to write
|
||||
|
||||
Returns:
|
||||
List of S3 keys that were written
|
||||
"""
|
||||
written_keys: list[str] = []
|
||||
|
||||
# Build a map of base keys to detect duplicates
|
||||
# Key: (directory_prefix, base_filename) -> list of docs with that name
|
||||
key_map: dict[tuple[str, str], list[Document]] = {}
|
||||
|
||||
for doc in documents:
|
||||
dir_prefix = self._build_directory_path(doc)
|
||||
base_filename = get_base_filename(doc, replace_slash=False)
|
||||
key = (dir_prefix, base_filename)
|
||||
if key not in key_map:
|
||||
key_map[key] = []
|
||||
key_map[key].append(doc)
|
||||
|
||||
# Now write documents, appending ID if there are duplicates
|
||||
s3_client = self._get_s3_client()
|
||||
|
||||
for (dir_prefix, base_filename), docs in key_map.items():
|
||||
has_duplicates = len(docs) > 1
|
||||
for doc in docs:
|
||||
filename = resolve_duplicate_filename(
|
||||
doc, base_filename, has_duplicates, replace_slash=False
|
||||
)
|
||||
s3_key = f"{dir_prefix}/{filename}"
|
||||
self._write_document(s3_client, doc, s3_key)
|
||||
written_keys.append(s3_key)
|
||||
|
||||
return written_keys
|
||||
|
||||
def _build_directory_path(self, doc: Document) -> str:
|
||||
"""Build S3 key prefix from document metadata.
|
||||
|
||||
Documents are stored under tenant/user-segregated paths:
|
||||
{tenant_id}/knowledge/{user_id}/{source}/{hierarchy}/
|
||||
|
||||
This matches the path that KubernetesSandboxManager syncs from:
|
||||
aws s3 sync "s3://{bucket}/{tenant_id}/knowledge/{user_id}/" /workspace/files/
|
||||
"""
|
||||
# Tenant and user segregation (matches K8s sandbox init container path)
|
||||
parts = [self.tenant_id, "knowledge", self.user_id]
|
||||
# Add source and hierarchy from document
|
||||
parts.extend(build_document_subpath(doc, replace_slash=False))
|
||||
|
||||
return "/".join(parts)
|
||||
|
||||
def _write_document(self, s3_client: S3Client, doc: Document, s3_key: str) -> None:
|
||||
"""Serialize and write document to S3."""
|
||||
content = serialize_document(doc)
|
||||
json_content = json.dumps(content, indent=2, default=str)
|
||||
|
||||
try:
|
||||
s3_client.put_object(
|
||||
Bucket=self.bucket,
|
||||
Key=s3_key,
|
||||
Body=json_content.encode("utf-8"),
|
||||
ContentType="application/json",
|
||||
)
|
||||
logger.debug(f"Wrote document to s3://{self.bucket}/{s3_key}")
|
||||
except ClientError as e:
|
||||
logger.error(f"Failed to write to S3: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def get_persistent_document_writer(
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
) -> PersistentDocumentWriter | S3PersistentDocumentWriter:
|
||||
"""Factory function to create a PersistentDocumentWriter with default configuration.
|
||||
|
||||
Args:
|
||||
user_id: User ID for user-segregated storage paths.
|
||||
tenant_id: Tenant ID for multi-tenant isolation.
|
||||
|
||||
Both local and S3 modes use consistent tenant/user-segregated paths:
|
||||
- Local: {base_path}/{tenant_id}/knowledge/{user_id}/...
|
||||
- S3: s3://{bucket}/{tenant_id}/knowledge/{user_id}/...
|
||||
|
||||
Returns:
|
||||
PersistentDocumentWriter for local mode, S3PersistentDocumentWriter for K8s mode
|
||||
"""
|
||||
if SANDBOX_BACKEND == SandboxBackend.LOCAL:
|
||||
return PersistentDocumentWriter(
|
||||
base_path=PERSISTENT_DOCUMENT_STORAGE_PATH,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
elif SANDBOX_BACKEND == SandboxBackend.KUBERNETES:
|
||||
return S3PersistentDocumentWriter(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown sandbox backend: {SANDBOX_BACKEND}")
|
||||
9
backend/onyx/server/features/build/s3/s3_client.py
Normal file
9
backend/onyx/server/features/build/s3/s3_client.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import boto3
|
||||
from mypy_boto3_s3.client import S3Client
|
||||
|
||||
from onyx.configs.app_configs import AWS_REGION_NAME
|
||||
|
||||
|
||||
def build_s3_client() -> S3Client:
|
||||
"""Build an S3 client using IAM roles (IRSA)"""
|
||||
return boto3.client("s3", region_name=AWS_REGION_NAME)
|
||||
352
backend/onyx/server/features/build/sandbox/README.md
Normal file
352
backend/onyx/server/features/build/sandbox/README.md
Normal file
@@ -0,0 +1,352 @@
|
||||
# Onyx Sandbox System
|
||||
|
||||
This directory contains the implementation of Onyx's sandbox system for running OpenCode agents in isolated environments.
|
||||
|
||||
## Overview
|
||||
|
||||
The sandbox system provides isolated execution environments where OpenCode agents can build web applications, run code, and interact with knowledge files. Each sandbox includes:
|
||||
|
||||
- **Next.js development environment** - Lightweight Next.js scaffold with shadcn/ui and Recharts for building UIs
|
||||
- **Python virtual environment** - Pre-installed packages for data processing
|
||||
- **OpenCode agent** - AI coding agent with access to tools and MCP servers
|
||||
- **Knowledge files** - Access to indexed documents and user uploads
|
||||
|
||||
## Architecture
|
||||
|
||||
### Deployment Modes
|
||||
|
||||
1. **Local Mode** (`SANDBOX_BACKEND=local`)
|
||||
- Sandboxes run as directories on the local filesystem
|
||||
- No automatic cleanup or snapshots
|
||||
- Suitable for development and testing
|
||||
|
||||
2. **Kubernetes Mode** (`SANDBOX_BACKEND=kubernetes`)
|
||||
- Sandboxes run as Kubernetes pods
|
||||
- Automatic snapshots to S3
|
||||
- Auto-cleanup of idle sandboxes
|
||||
- Production-ready with resource isolation
|
||||
|
||||
### Directory Structure
|
||||
|
||||
```
|
||||
/workspace/ # Sandbox root (in container)
|
||||
├── outputs/ # Working directory
|
||||
│ ├── web/ # Lightweight Next.js app (shadcn/ui, Recharts)
|
||||
│ ├── slides/ # Generated presentations
|
||||
│ ├── markdown/ # Generated documents
|
||||
│ └── graphs/ # Generated visualizations
|
||||
├── .venv/ # Python virtual environment
|
||||
├── files/ # Symlink to knowledge files
|
||||
├── attachments/ # User uploads
|
||||
├── AGENTS.md # Agent instructions
|
||||
└── .opencode/
|
||||
└── skills/ # Agent skills
|
||||
```
|
||||
|
||||
## Setup
|
||||
|
||||
### Running via Docker/Kubernetes (Zero Setup!) 🎉
|
||||
|
||||
**No setup required!** Just build and deploy:
|
||||
|
||||
```bash
|
||||
# Build backend image (includes both templates)
|
||||
cd backend
|
||||
docker build -f Dockerfile.sandbox-templates -t onyxdotapp/backend:latest .
|
||||
|
||||
# Build sandbox container (lightweight runner)
|
||||
cd onyx/server/features/build/sandbox/kubernetes/docker
|
||||
docker build -t onyxdotapp/sandbox:latest .
|
||||
|
||||
# Deploy with docker-compose or kubectl - sandboxes work immediately!
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
|
||||
- **Backend image**: Contains both templates at build time:
|
||||
- Web template at `/templates/outputs/web` (lightweight Next.js scaffold, ~2MB)
|
||||
- Python venv template at `/templates/venv` (pre-installed packages, ~50MB)
|
||||
- **Init container** (Kubernetes only): Syncs knowledge files from S3
|
||||
- **Sandbox startup**: Runs `npm install` (for fresh dependency locks) + `next dev`
|
||||
|
||||
### Running Backend Directly (Without Docker)
|
||||
|
||||
**Only needed if you're running the Onyx backend outside of Docker.** Most developers use Docker and can skip this section.
|
||||
|
||||
If you're running the backend Python process directly on your machine, you need templates at `/templates/`:
|
||||
|
||||
#### Web Template
|
||||
|
||||
The web template is a lightweight Next.js app (Next.js 16, React 19, shadcn/ui, Recharts) checked into the codebase at `backend/onyx/server/features/build/templates/outputs/web/`.
|
||||
|
||||
For local development, create a symlink to this template:
|
||||
|
||||
```bash
|
||||
sudo mkdir -p /templates/outputs
|
||||
sudo ln -s $(pwd)/backend/onyx/server/features/build/templates/outputs/web /templates/outputs/web
|
||||
```
|
||||
|
||||
#### Python Venv Template
|
||||
|
||||
If you don't have a venv template, create it:
|
||||
|
||||
```bash
|
||||
# Use the utility script
|
||||
cd backend
|
||||
python -m onyx.server.features.build.sandbox.util.build_venv_template
|
||||
|
||||
# Or manually
|
||||
python3 -m venv /templates/venv
|
||||
/templates/venv/bin/pip install -r backend/onyx/server/features/build/sandbox/kubernetes/docker/initial-requirements.txt
|
||||
```
|
||||
|
||||
**That's it!** When sandboxes are created:
|
||||
|
||||
1. Web template is copied from `/templates/outputs/web`
|
||||
2. Python venv is copied from `/templates/venv`
|
||||
3. `npm install` runs automatically to install fresh Next.js dependencies
|
||||
|
||||
## OpenCode Configuration
|
||||
|
||||
Each sandbox includes an OpenCode agent configured with:
|
||||
|
||||
- **LLM Provider**: Anthropic, OpenAI, Google, Bedrock, or Azure
|
||||
- **Extended thinking**: High reasoning effort / thinking budgets for complex tasks
|
||||
- **Tool permissions**: File operations, bash commands, web access
|
||||
- **Disabled tools**: Configurable via `OPENCODE_DISABLED_TOOLS` env var
|
||||
|
||||
Configuration is generated dynamically in `templates/opencode_config.py`.
|
||||
|
||||
## Key Components
|
||||
|
||||
### Managers
|
||||
|
||||
- **`base.py`** - Abstract base class defining the sandbox interface
|
||||
- **`local/manager.py`** - Filesystem-based sandbox manager for local development
|
||||
- **`kubernetes/manager.py`** - Kubernetes-based sandbox manager for production
|
||||
|
||||
### Managers (Shared)
|
||||
|
||||
- **`manager/directory_manager.py`** - Creates sandbox directory structure and copies templates
|
||||
- **`manager/snapshot_manager.py`** - Handles snapshot creation and restoration
|
||||
|
||||
### Utilities
|
||||
|
||||
- **`util/opencode_config.py`** - Generates OpenCode configuration with MCP support
|
||||
- **`util/agent_instructions.py`** - Generates agent instructions (AGENTS.md)
|
||||
- **`util/build_venv_template.py`** - Utility to build Python venv template for local development
|
||||
|
||||
### Templates
|
||||
|
||||
- **`../templates/outputs/web/`** - Lightweight Next.js scaffold (shadcn/ui, Recharts) versioned with the backend code
|
||||
|
||||
### Kubernetes Specific
|
||||
|
||||
- **`kubernetes/docker/Dockerfile`** - Sandbox container image (runs Next.js + OpenCode)
|
||||
- **`kubernetes/docker/entrypoint.sh`** - Container startup script
|
||||
|
||||
## Environment Variables
|
||||
|
||||
### Core Settings
|
||||
|
||||
```bash
|
||||
# Sandbox backend mode
|
||||
SANDBOX_BACKEND=local|kubernetes # Default: local
|
||||
|
||||
# Template paths (local mode)
|
||||
OUTPUTS_TEMPLATE_PATH=/templates/outputs # Default: /templates/outputs
|
||||
VENV_TEMPLATE_PATH=/templates/venv # Default: /templates/venv
|
||||
|
||||
# Sandbox base path (local mode)
|
||||
SANDBOX_BASE_PATH=/tmp/onyx-sandboxes # Default: /tmp/onyx-sandboxes
|
||||
|
||||
# OpenCode configuration
|
||||
OPENCODE_DISABLED_TOOLS=question # Comma-separated list, default: question
|
||||
```
|
||||
|
||||
### Kubernetes Settings
|
||||
|
||||
```bash
|
||||
# Kubernetes namespace
|
||||
SANDBOX_NAMESPACE=onyx-sandboxes # Default: onyx-sandboxes
|
||||
|
||||
# Container image
|
||||
SANDBOX_CONTAINER_IMAGE=onyxdotapp/sandbox:latest
|
||||
|
||||
# S3 bucket for snapshots and files
|
||||
SANDBOX_S3_BUCKET=onyx-sandbox-files # Default: onyx-sandbox-files
|
||||
|
||||
# Service accounts
|
||||
SANDBOX_SERVICE_ACCOUNT_NAME=sandbox-runner # No AWS access
|
||||
SANDBOX_FILE_SYNC_SERVICE_ACCOUNT=sandbox-file-sync # Has S3 access via IRSA
|
||||
```
|
||||
|
||||
### Lifecycle Settings
|
||||
|
||||
```bash
|
||||
# Idle timeout before cleanup (seconds)
|
||||
SANDBOX_IDLE_TIMEOUT_SECONDS=900 # Default: 900 (15 minutes)
|
||||
|
||||
# Max concurrent sandboxes per organization
|
||||
SANDBOX_MAX_CONCURRENT_PER_ORG=10 # Default: 10
|
||||
|
||||
# Next.js port range (local mode)
|
||||
SANDBOX_NEXTJS_PORT_START=3010 # Default: 3010
|
||||
SANDBOX_NEXTJS_PORT_END=3100 # Default: 3100
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
### Integration Tests
|
||||
|
||||
```bash
|
||||
# Test local sandbox provisioning
|
||||
uv run pytest backend/tests/integration/sandbox/test_local_sandbox.py
|
||||
|
||||
# Test Kubernetes sandbox provisioning (requires k8s cluster)
|
||||
uv run pytest backend/tests/integration/sandbox/test_kubernetes_sandbox.py
|
||||
```
|
||||
|
||||
### Manual Testing
|
||||
|
||||
```bash
|
||||
# Start a local sandbox session
|
||||
curl -X POST http://localhost:3000/api/build/session \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"user_id": "user-123",
|
||||
"file_system_path": "/path/to/files"
|
||||
}'
|
||||
|
||||
# Send a message to the agent
|
||||
curl -X POST http://localhost:3000/api/build/session/{session_id}/message \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"message": "Create a simple web page"
|
||||
}'
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Sandbox Stuck in PROVISIONING (Kubernetes)
|
||||
|
||||
**Symptoms**: Sandbox status never changes from `PROVISIONING`
|
||||
|
||||
**Solutions**:
|
||||
|
||||
- Check pod logs: `kubectl logs -n onyx-sandboxes sandbox-{sandbox-id}`
|
||||
- Check init container: `kubectl logs -n onyx-sandboxes sandbox-{sandbox-id} -c file-sync`
|
||||
- Verify init container completed: `kubectl describe pod -n onyx-sandboxes sandbox-{sandbox-id}`
|
||||
- Check S3 bucket access: Ensure init container service account has IRSA configured
|
||||
|
||||
### Next.js Server Won't Start
|
||||
|
||||
**Symptoms**: Sandbox provisioned but web preview doesn't load
|
||||
|
||||
**Solutions**:
|
||||
|
||||
- **Local mode**: Check if port is already in use
|
||||
- **Docker/K8s**: Check container logs: `kubectl logs -n onyx-sandboxes sandbox-{sandbox-id}`
|
||||
- Verify npm install succeeded (check entrypoint.sh logs)
|
||||
- Check that web template was copied: `kubectl exec -n onyx-sandboxes sandbox-{sandbox-id} -- ls /workspace/outputs/web`
|
||||
|
||||
### Templates Not Found (Local Mode)
|
||||
|
||||
**Symptoms**: `RuntimeError: Sandbox templates are missing`
|
||||
|
||||
**Solution**: Set up templates as described in the "Local Development" section above:
|
||||
|
||||
```bash
|
||||
# Symlink web template
|
||||
sudo ln -s $(pwd)/backend/onyx/server/features/build/templates/outputs/web /templates/outputs/web
|
||||
|
||||
# Create Python venv
|
||||
python3 -m venv /templates/venv
|
||||
/templates/venv/bin/pip install -r backend/onyx/server/features/build/sandbox/kubernetes/docker/initial-requirements.txt
|
||||
```
|
||||
|
||||
### Permission Denied
|
||||
|
||||
**Symptoms**: `Permission denied` error accessing `/templates/`
|
||||
|
||||
**Solution**: Either use sudo when creating symlinks, or use custom paths:
|
||||
|
||||
```bash
|
||||
export OUTPUTS_TEMPLATE_PATH=$HOME/.onyx/templates/outputs
|
||||
export VENV_TEMPLATE_PATH=$HOME/.onyx/templates/venv
|
||||
|
||||
# Then symlink to your home directory
|
||||
mkdir -p $HOME/.onyx/templates/outputs
|
||||
ln -s $(pwd)/backend/onyx/server/features/build/templates/outputs/web $HOME/.onyx/templates/outputs/web
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### Sandbox Isolation
|
||||
|
||||
- **Kubernetes pods** run with restricted security context (non-root, no privilege escalation)
|
||||
- **Init containers** have S3 access for file sync, but main sandbox container does NOT
|
||||
- **Network policies** can restrict sandbox egress traffic
|
||||
- **Resource limits** prevent resource exhaustion
|
||||
|
||||
### Credentials Management
|
||||
|
||||
- LLM API keys are passed as environment variables (not stored in sandbox)
|
||||
- User file access is read-only via symlinks
|
||||
- Snapshots are isolated per tenant in S3
|
||||
|
||||
## Development
|
||||
|
||||
### Adding New MCP Servers
|
||||
|
||||
1. Add MCP configuration to `templates/opencode_config.py`:
|
||||
|
||||
```python
|
||||
config["mcp"] = {
|
||||
"my-mcp": {
|
||||
"type": "local",
|
||||
"command": ["npx", "@my/mcp@latest"],
|
||||
"enabled": True,
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
2. Install required npm packages in web template (if needed)
|
||||
|
||||
3. Rebuild Docker image and templates
|
||||
|
||||
### Modifying Agent Instructions
|
||||
|
||||
Edit `AGENTS.template.md` in the build directory. This is populated with dynamic content by `templates/agent_instructions.py`.
|
||||
|
||||
### Adding New Tools/Permissions
|
||||
|
||||
Update `templates/opencode_config.py` to add/remove tool permissions in the `permission` section.
|
||||
|
||||
## Template Details
|
||||
|
||||
### Web Template
|
||||
|
||||
The lightweight Next.js template (`backend/onyx/server/features/build/templates/outputs/web/`) includes:
|
||||
|
||||
- **Framework**: Next.js 16.1.4 with React 19.2.3
|
||||
- **UI Library**: shadcn/ui components with Radix UI primitives
|
||||
- **Styling**: Tailwind CSS v4 with custom theming support
|
||||
- **Charts**: Recharts for data visualization
|
||||
- **Size**: ~2MB (excluding node_modules, which are installed fresh per sandbox)
|
||||
|
||||
This template provides a modern development environment without the complexity of the full Onyx application, allowing agents to build custom UIs quickly.
|
||||
|
||||
### Python Venv Template
|
||||
|
||||
The Python venv (`/templates/venv/`) includes packages from `initial-requirements.txt`:
|
||||
|
||||
- Data processing: pandas, numpy, polars
|
||||
- HTTP clients: requests, httpx
|
||||
- Utilities: python-dotenv, pydantic
|
||||
|
||||
## References
|
||||
|
||||
- [OpenCode Documentation](https://docs.opencode.ai)
|
||||
- [Next.js Documentation](https://nextjs.org/docs)
|
||||
- [shadcn/ui Components](https://ui.shadcn.com)
|
||||
44
backend/onyx/server/features/build/sandbox/__init__.py
Normal file
44
backend/onyx/server/features/build/sandbox/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
Sandbox module for CLI agent filesystem-based isolation.
|
||||
|
||||
This module provides lightweight sandbox management for CLI-based AI agent sessions.
|
||||
Each sandbox is a directory on the local filesystem or a Kubernetes pod.
|
||||
|
||||
Usage:
|
||||
from onyx.server.features.build.sandbox import get_sandbox_manager
|
||||
|
||||
# Get the appropriate sandbox manager based on SANDBOX_BACKEND config
|
||||
sandbox_manager = get_sandbox_manager()
|
||||
|
||||
# Use the sandbox manager
|
||||
sandbox_info = sandbox_manager.provision(...)
|
||||
|
||||
Module structure:
|
||||
- base.py: SandboxManager ABC and get_sandbox_manager() factory
|
||||
- models.py: Shared Pydantic models
|
||||
- local/: Local filesystem-based implementation for development
|
||||
- kubernetes/: Kubernetes pod-based implementation for production
|
||||
- internal/: Shared internal utilities (snapshot manager)
|
||||
"""
|
||||
|
||||
from onyx.server.features.build.sandbox.base import get_sandbox_manager
|
||||
from onyx.server.features.build.sandbox.base import SandboxManager
|
||||
from onyx.server.features.build.sandbox.local.local_sandbox_manager import (
|
||||
LocalSandboxManager,
|
||||
)
|
||||
from onyx.server.features.build.sandbox.models import FilesystemEntry
|
||||
from onyx.server.features.build.sandbox.models import SandboxInfo
|
||||
from onyx.server.features.build.sandbox.models import SnapshotInfo
|
||||
|
||||
__all__ = [
|
||||
# Factory function (preferred)
|
||||
"get_sandbox_manager",
|
||||
# Interface
|
||||
"SandboxManager",
|
||||
# Implementations
|
||||
"LocalSandboxManager",
|
||||
# Models
|
||||
"SandboxInfo",
|
||||
"SnapshotInfo",
|
||||
"FilesystemEntry",
|
||||
]
|
||||
466
backend/onyx/server/features/build/sandbox/base.py
Normal file
466
backend/onyx/server/features/build/sandbox/base.py
Normal file
@@ -0,0 +1,466 @@
|
||||
"""Abstract base class and factory for sandbox operations.
|
||||
|
||||
SandboxManager is the abstract interface for sandbox lifecycle management.
|
||||
Use get_sandbox_manager() to get the appropriate implementation based on SANDBOX_BACKEND.
|
||||
|
||||
IMPORTANT: SandboxManager implementations must NOT interface with the database directly.
|
||||
All database operations should be handled by the caller (SessionManager, Celery tasks, etc.).
|
||||
|
||||
Architecture Note (User-Shared Sandbox Model):
|
||||
- One sandbox (container/pod) is shared across all of a user's sessions
|
||||
- provision() creates the user's sandbox with shared files/ directory
|
||||
- setup_session_workspace() creates per-session workspace within the sandbox
|
||||
- cleanup_session_workspace() removes session workspace on session delete
|
||||
- terminate() destroys the entire sandbox (all sessions)
|
||||
"""
|
||||
|
||||
import threading
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.server.features.build.configs import SANDBOX_BACKEND
|
||||
from onyx.server.features.build.configs import SandboxBackend
|
||||
from onyx.server.features.build.sandbox.models import FilesystemEntry
|
||||
from onyx.server.features.build.sandbox.models import LLMProviderConfig
|
||||
from onyx.server.features.build.sandbox.models import SandboxInfo
|
||||
from onyx.server.features.build.sandbox.models import SnapshotResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ACPEvent is a union type defined in both local and kubernetes modules
|
||||
# Using Any here to avoid circular imports - the actual type checking
|
||||
# happens in the implementation modules
|
||||
ACPEvent = Any
|
||||
|
||||
|
||||
class SandboxManager(ABC):
|
||||
"""Abstract interface for sandbox operations.
|
||||
|
||||
Defines the contract for sandbox lifecycle management including:
|
||||
- Provisioning and termination (user-level)
|
||||
- Session workspace setup and cleanup (session-level)
|
||||
- Snapshot creation (session-level)
|
||||
- Health checks
|
||||
- Agent communication (session-level)
|
||||
- Filesystem operations (session-level)
|
||||
|
||||
Directory Structure:
|
||||
$SANDBOX_ROOT/
|
||||
├── files/ # SHARED - symlink to user's persistent documents
|
||||
└── sessions/
|
||||
├── $session_id_1/ # Per-session workspace
|
||||
│ ├── outputs/ # Agent output for this session
|
||||
│ │ └── web/ # Next.js app
|
||||
│ ├── venv/ # Python virtual environment
|
||||
│ ├── skills/ # Opencode skills
|
||||
│ ├── AGENTS.md # Agent instructions
|
||||
│ ├── opencode.json # LLM config
|
||||
│ └── attachments/
|
||||
└── $session_id_2/
|
||||
└── ...
|
||||
|
||||
IMPORTANT: Implementations must NOT interface with the database directly.
|
||||
All database operations should be handled by the caller.
|
||||
|
||||
Use get_sandbox_manager() to get the appropriate implementation.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def provision(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
user_id: UUID,
|
||||
tenant_id: str,
|
||||
llm_config: LLMProviderConfig,
|
||||
) -> SandboxInfo:
|
||||
"""Provision a new sandbox for a user.
|
||||
|
||||
Creates the sandbox container/directory with:
|
||||
- sessions/ directory for per-session workspaces
|
||||
|
||||
NOTE: This does NOT set up session-specific workspaces.
|
||||
Call setup_session_workspace() after provisioning to create a session workspace.
|
||||
|
||||
Args:
|
||||
sandbox_id: Unique identifier for the sandbox
|
||||
user_id: User identifier who owns this sandbox
|
||||
tenant_id: Tenant identifier for multi-tenant isolation
|
||||
llm_config: LLM provider configuration (for default config)
|
||||
|
||||
Returns:
|
||||
SandboxInfo with the provisioned sandbox details
|
||||
|
||||
Raises:
|
||||
RuntimeError: If provisioning fails
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def terminate(self, sandbox_id: UUID) -> None:
|
||||
"""Terminate a sandbox and clean up all resources.
|
||||
|
||||
Destroys the entire sandbox including all session workspaces.
|
||||
Use cleanup_session_workspace() to remove individual sessions.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID to terminate
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def setup_session_workspace(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
llm_config: LLMProviderConfig,
|
||||
nextjs_port: int,
|
||||
file_system_path: str | None = None,
|
||||
snapshot_path: str | None = None,
|
||||
user_name: str | None = None,
|
||||
user_role: str | None = None,
|
||||
user_work_area: str | None = None,
|
||||
user_level: str | None = None,
|
||||
use_demo_data: bool = False,
|
||||
) -> None:
|
||||
"""Set up a session workspace within an existing sandbox.
|
||||
|
||||
Creates the per-session directory structure:
|
||||
- sessions/$session_id/outputs/ (from snapshot or template)
|
||||
- sessions/$session_id/venv/
|
||||
- sessions/$session_id/skills/
|
||||
- sessions/$session_id/files/ (symlink to demo data or user files)
|
||||
- sessions/$session_id/AGENTS.md
|
||||
- sessions/$session_id/opencode.json
|
||||
- sessions/$session_id/attachments/
|
||||
- sessions/$session_id/org_info/ (if demo data enabled)
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID (must be provisioned)
|
||||
session_id: The session ID for this workspace
|
||||
llm_config: LLM provider configuration for opencode.json
|
||||
file_system_path: Path to user's knowledge/source files
|
||||
snapshot_path: Optional storage path to restore outputs from
|
||||
user_name: User's name for personalization in AGENTS.md
|
||||
user_role: User's role/title for personalization in AGENTS.md
|
||||
user_work_area: User's work area for demo persona (e.g., "engineering")
|
||||
user_level: User's level for demo persona (e.g., "ic", "manager")
|
||||
use_demo_data: If True, symlink files/ to demo data; else to user files
|
||||
|
||||
Raises:
|
||||
RuntimeError: If workspace setup fails
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def cleanup_session_workspace(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
) -> None:
|
||||
"""Clean up a session workspace (on session delete).
|
||||
|
||||
Removes the session directory: sessions/$session_id/
|
||||
Does NOT terminate the sandbox - other sessions may still be using it.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID to clean up
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def create_snapshot(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
tenant_id: str,
|
||||
) -> SnapshotResult | None:
|
||||
"""Create a snapshot of a session's outputs directory.
|
||||
|
||||
Captures only the session-specific outputs:
|
||||
sessions/$session_id/outputs/
|
||||
|
||||
Does NOT include: venv, skills, AGENTS.md, opencode.json, attachments
|
||||
Does NOT include: shared files/ directory
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID to snapshot
|
||||
tenant_id: Tenant identifier for storage path
|
||||
|
||||
Returns:
|
||||
SnapshotResult with storage path and size, or None if
|
||||
snapshots are disabled for this backend
|
||||
|
||||
Raises:
|
||||
RuntimeError: If snapshot creation fails
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def session_workspace_exists(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
) -> bool:
|
||||
"""Check if a session's workspace directory exists in the sandbox.
|
||||
|
||||
Used to determine if we need to restore from snapshot.
|
||||
Checks for sessions/$session_id/outputs/ directory.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID to check
|
||||
|
||||
Returns:
|
||||
True if the session workspace exists, False otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def restore_snapshot(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
snapshot_storage_path: str,
|
||||
tenant_id: str,
|
||||
nextjs_port: int,
|
||||
) -> None:
|
||||
"""Restore a snapshot into a session's workspace directory.
|
||||
|
||||
Downloads the snapshot from storage, extracts it into
|
||||
sessions/$session_id/outputs/, and starts the NextJS server.
|
||||
|
||||
For Kubernetes backend, this downloads from S3 and streams
|
||||
into the pod via kubectl exec (since the pod has no S3 access).
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID to restore
|
||||
snapshot_storage_path: Path to the snapshot in storage
|
||||
tenant_id: Tenant identifier for storage access
|
||||
nextjs_port: Port number for the NextJS dev server
|
||||
|
||||
Raises:
|
||||
RuntimeError: If snapshot restoration fails
|
||||
FileNotFoundError: If snapshot does not exist
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def health_check(self, sandbox_id: UUID, timeout: float = 60.0) -> bool:
|
||||
"""Check if the sandbox is healthy.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID to check
|
||||
|
||||
Returns:
|
||||
True if sandbox is healthy, False otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def send_message(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
message: str,
|
||||
) -> Generator[ACPEvent, None, None]:
|
||||
"""Send a message to the CLI agent and stream typed ACP events.
|
||||
|
||||
The agent runs in the session-specific workspace:
|
||||
sessions/$session_id/
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID (determines workspace directory)
|
||||
message: The message content to send
|
||||
|
||||
Yields:
|
||||
Typed ACP schema event objects
|
||||
|
||||
Raises:
|
||||
RuntimeError: If agent communication fails
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def list_directory(
|
||||
self, sandbox_id: UUID, session_id: UUID, path: str
|
||||
) -> list[FilesystemEntry]:
|
||||
"""List contents of a directory in the session's outputs directory.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID
|
||||
path: Relative path within sessions/$session_id/outputs/
|
||||
|
||||
Returns:
|
||||
List of FilesystemEntry objects sorted by directory first, then name
|
||||
|
||||
Raises:
|
||||
ValueError: If path traversal attempted or path is not a directory
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def read_file(self, sandbox_id: UUID, session_id: UUID, path: str) -> bytes:
|
||||
"""Read a file from the session's workspace.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID
|
||||
path: Relative path within sessions/$session_id/
|
||||
|
||||
Returns:
|
||||
File contents as bytes
|
||||
|
||||
Raises:
|
||||
ValueError: If path traversal attempted or path is not a file
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def upload_file(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
) -> str:
|
||||
"""Upload a file to the session's attachments directory.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID
|
||||
filename: Sanitized filename
|
||||
content: File content as bytes
|
||||
|
||||
Returns:
|
||||
Relative path where file was saved (e.g., "attachments/doc.pdf")
|
||||
|
||||
Raises:
|
||||
RuntimeError: If upload fails
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
path: str,
|
||||
) -> bool:
|
||||
"""Delete a file from the session's workspace.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID
|
||||
path: Relative path to the file (e.g., "attachments/doc.pdf")
|
||||
|
||||
Returns:
|
||||
True if file was deleted, False if not found
|
||||
|
||||
Raises:
|
||||
ValueError: If path traversal attempted
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_upload_stats(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
) -> tuple[int, int]:
|
||||
"""Get current file count and total size for a session's attachments.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID
|
||||
|
||||
Returns:
|
||||
Tuple of (file_count, total_size_bytes)
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_webapp_url(self, sandbox_id: UUID, port: int) -> str:
|
||||
"""Get the webapp URL for a session's Next.js server.
|
||||
|
||||
Returns the appropriate URL based on the backend:
|
||||
- Local: Returns localhost URL with port
|
||||
- Kubernetes: Returns internal cluster service URL
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
port: The session's allocated Next.js port
|
||||
|
||||
Returns:
|
||||
URL to access the webapp
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def sync_files(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
user_id: UUID,
|
||||
tenant_id: str,
|
||||
) -> bool:
|
||||
"""Sync files from S3 to the sandbox's /workspace/files directory.
|
||||
|
||||
For Kubernetes backend: Executes `aws s3 sync` in the file-sync sidecar container.
|
||||
For Local backend: No-op since files are directly accessible via symlink.
|
||||
|
||||
This is idempotent - only downloads changed files.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox UUID
|
||||
user_id: The user ID (for S3 path construction)
|
||||
tenant_id: The tenant ID (for S3 path construction)
|
||||
|
||||
Returns:
|
||||
True if sync was successful, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# Singleton instance cache for the factory
|
||||
_sandbox_manager_instance: SandboxManager | None = None
|
||||
_sandbox_manager_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_sandbox_manager() -> SandboxManager:
|
||||
"""Get the appropriate SandboxManager implementation based on SANDBOX_BACKEND.
|
||||
|
||||
Returns:
|
||||
SandboxManager instance:
|
||||
- LocalSandboxManager for local backend (development)
|
||||
- KubernetesSandboxManager for kubernetes backend (production)
|
||||
"""
|
||||
global _sandbox_manager_instance
|
||||
|
||||
if _sandbox_manager_instance is None:
|
||||
with _sandbox_manager_lock:
|
||||
if _sandbox_manager_instance is None:
|
||||
if SANDBOX_BACKEND == SandboxBackend.LOCAL:
|
||||
from onyx.server.features.build.sandbox.local.local_sandbox_manager import (
|
||||
LocalSandboxManager,
|
||||
)
|
||||
|
||||
_sandbox_manager_instance = LocalSandboxManager()
|
||||
elif SANDBOX_BACKEND == SandboxBackend.KUBERNETES:
|
||||
from onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager import (
|
||||
KubernetesSandboxManager,
|
||||
)
|
||||
|
||||
_sandbox_manager_instance = KubernetesSandboxManager()
|
||||
logger.info("Using KubernetesSandboxManager for sandbox operations")
|
||||
else:
|
||||
raise ValueError(f"Unknown sandbox backend: {SANDBOX_BACKEND}")
|
||||
|
||||
return _sandbox_manager_instance
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Kubernetes-based sandbox implementation.
|
||||
|
||||
This module provides the KubernetesSandboxManager for production deployments
|
||||
that run sandboxes as isolated Kubernetes pods.
|
||||
|
||||
Internal implementation details (acp_http_client) are in the internal/
|
||||
subdirectory and should not be used directly.
|
||||
"""
|
||||
|
||||
from onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager import (
|
||||
KubernetesSandboxManager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"KubernetesSandboxManager",
|
||||
]
|
||||
@@ -0,0 +1,100 @@
|
||||
# Sandbox Container Image
|
||||
#
|
||||
# User-shared sandbox model:
|
||||
# - One pod per user, shared across all user's sessions
|
||||
# - Session workspaces created via kubectl exec (setup_session_workspace)
|
||||
# - OpenCode agent runs via kubectl exec when needed
|
||||
#
|
||||
# Directory structure (created by init container + session setup):
|
||||
# /workspace/
|
||||
# ├── demo-data/ # Demo data (baked into image, for demo sessions)
|
||||
# ├── files/ # User's knowledge files (synced from S3)
|
||||
# ├── templates/ # Output templates (baked into image)
|
||||
# └── sessions/ # Per-session workspaces (created via exec)
|
||||
# └── $session_id/
|
||||
# ├── files/ # Symlink to /workspace/demo-data or /workspace/files
|
||||
# ├── outputs/
|
||||
# ├── AGENTS.md
|
||||
# └── opencode.json
|
||||
|
||||
FROM node:20-slim
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3 \
|
||||
python3-pip \
|
||||
python3-venv \
|
||||
curl \
|
||||
git \
|
||||
procps \
|
||||
unzip \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create non-root user (matches pod securityContext)
|
||||
# Handle existing user/group with UID/GID 1000 in base image
|
||||
RUN EXISTING_USER=$(id -nu 1000 2>/dev/null || echo ""); \
|
||||
EXISTING_GROUP=$(getent group 1000 | cut -d: -f1 2>/dev/null || echo ""); \
|
||||
if [ -n "$EXISTING_GROUP" ] && [ "$EXISTING_GROUP" != "sandbox" ]; then \
|
||||
groupmod -n sandbox $EXISTING_GROUP; \
|
||||
elif [ -z "$EXISTING_GROUP" ]; then \
|
||||
groupadd -g 1000 sandbox; \
|
||||
fi; \
|
||||
if [ -n "$EXISTING_USER" ] && [ "$EXISTING_USER" != "sandbox" ]; then \
|
||||
usermod -l sandbox -g sandbox $EXISTING_USER; \
|
||||
usermod -d /home/sandbox -m sandbox; \
|
||||
usermod -s /bin/bash sandbox; \
|
||||
elif [ -z "$EXISTING_USER" ]; then \
|
||||
useradd -u 1000 -g sandbox -m -s /bin/bash sandbox; \
|
||||
fi
|
||||
|
||||
# Create workspace directories
|
||||
RUN mkdir -p workspace/sessions /workspace/files /workspace/templates /workspace/demo-data && \
|
||||
chown -R sandbox:sandbox /workspace
|
||||
|
||||
# Copy outputs template (web app scaffold, without node_modules)
|
||||
COPY --exclude=.next --exclude=node_modules templates/outputs /workspace/templates/outputs
|
||||
RUN chown -R sandbox:sandbox /workspace/templates
|
||||
|
||||
# Copy and extract demo data from zip file
|
||||
COPY demo_data.zip /tmp/demo_data.zip
|
||||
RUN unzip -q /tmp/demo_data.zip -d /workspace/demo-data && \
|
||||
rm /tmp/demo_data.zip && \
|
||||
chown -R sandbox:sandbox /workspace/demo-data
|
||||
|
||||
# Copy and install Python requirements into a venv
|
||||
COPY initial-requirements.txt /tmp/initial-requirements.txt
|
||||
RUN python3 -m venv /workspace/.venv && \
|
||||
/workspace/.venv/bin/pip install --upgrade pip && \
|
||||
/workspace/.venv/bin/pip install -r /tmp/initial-requirements.txt && \
|
||||
rm /tmp/initial-requirements.txt && \
|
||||
chown -R sandbox:sandbox /workspace/.venv
|
||||
|
||||
# Add venv to PATH so python/pip use it by default
|
||||
ENV PATH="/workspace/.venv/bin:${PATH}"
|
||||
|
||||
# Install opencode CLI as sandbox user so it goes to their home directory
|
||||
USER sandbox
|
||||
RUN curl -fsSL https://opencode.ai/install | bash
|
||||
USER root
|
||||
|
||||
# Add opencode to PATH (installs to ~/.opencode/bin)
|
||||
ENV PATH="/home/sandbox/.opencode/bin:${PATH}"
|
||||
|
||||
# Set ownership
|
||||
RUN chown -R sandbox:sandbox /workspace
|
||||
|
||||
# Copy scripts
|
||||
COPY generate_agents_md.py /usr/local/bin/generate_agents_md.py
|
||||
RUN chmod +x /usr/local/bin/generate_agents_md.py
|
||||
|
||||
# Switch to non-root user
|
||||
USER sandbox
|
||||
WORKDIR /workspace
|
||||
|
||||
# Expose ports
|
||||
# - 3000: Next.js dev server (started per-session if needed)
|
||||
# - 8081: OpenCode ACP HTTP server (started via exec)
|
||||
EXPOSE 3000 8081
|
||||
|
||||
# Keep container alive - all work done via kubectl exec
|
||||
CMD ["sleep", "infinity"]
|
||||
Binary file not shown.
@@ -0,0 +1,164 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate AGENTS.md by scanning the files directory and populating the template.
|
||||
|
||||
This script runs at container startup, AFTER the init container has synced files
|
||||
from S3. It scans the /workspace/files directory to discover what knowledge sources
|
||||
are available and generates appropriate documentation.
|
||||
|
||||
Environment variables:
|
||||
- AGENT_INSTRUCTIONS: The template content with placeholders to replace
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Connector descriptions for known connector types
|
||||
# Keep in sync with agent_instructions.py CONNECTOR_DESCRIPTIONS
|
||||
CONNECTOR_DESCRIPTIONS = {
|
||||
"google_drive": (
|
||||
"**Google Drive**: Copied over directly as is. "
|
||||
"End files are stored as `FILE_NAME.json`."
|
||||
),
|
||||
"gmail": (
|
||||
"**Gmail**: Copied over directly as is. "
|
||||
"End files are stored as `FILE_NAME.json`."
|
||||
),
|
||||
"linear": (
|
||||
"**Linear**: Each project is a folder, and within each project, "
|
||||
"individual tickets are stored as `[TICKET_ID]_TICKET_NAME.json`."
|
||||
),
|
||||
"slack": (
|
||||
"**Slack**: Each channel is a folder titled `[CHANNEL_NAME]`. "
|
||||
"Within each channel, each thread is a single file called "
|
||||
"`[INITIAL_AUTHOR]_in_[CHANNEL]__[FIRST_MESSAGE].json`."
|
||||
),
|
||||
"github": (
|
||||
"**Github**: Each organization is a folder titled `[ORG_NAME]`. "
|
||||
"Within each organization, there is a folder for each repository "
|
||||
"titled `[REPO_NAME]`. Within each repository there are up to two "
|
||||
"folders: `pull_requests` and `issues`. Pull requests are structured "
|
||||
"as `[PR_ID]__[PR_NAME].json` and issues as `[ISSUE_ID]__[ISSUE_NAME].json`."
|
||||
),
|
||||
"fireflies": (
|
||||
"**Fireflies**: All calls are in the root, each as a single file "
|
||||
"titled `CALL_TITLE.json`."
|
||||
),
|
||||
"hubspot": (
|
||||
"**HubSpot**: Four folders in the root: `Tickets`, `Companies`, "
|
||||
"`Deals`, and `Contacts`. Each object is stored as a file named "
|
||||
"after its title/name (e.g., `[TICKET_SUBJECT].json`, `[COMPANY_NAME].json`)."
|
||||
),
|
||||
"notion": (
|
||||
"**Notion**: Pages and databases are organized hierarchically. "
|
||||
"Each page is stored as `PAGE_TITLE.json`."
|
||||
),
|
||||
"org_info": (
|
||||
"**Org Info**: Contains organizational data and identity information."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def build_file_structure_section(files_path: Path) -> str:
|
||||
"""Build the file structure section by scanning the files directory."""
|
||||
if not files_path.exists():
|
||||
return "No knowledge sources available."
|
||||
|
||||
sources = []
|
||||
try:
|
||||
for item in sorted(files_path.iterdir()):
|
||||
if not item.is_dir() or item.name.startswith("."):
|
||||
continue
|
||||
|
||||
file_count = sum(1 for f in item.rglob("*") if f.is_file())
|
||||
subdir_count = sum(1 for d in item.rglob("*") if d.is_dir())
|
||||
|
||||
details = []
|
||||
if file_count > 0:
|
||||
details.append(f"{file_count} file{'s' if file_count != 1 else ''}")
|
||||
if subdir_count > 0:
|
||||
details.append(
|
||||
f"{subdir_count} subdirector{'ies' if subdir_count != 1 else 'y'}"
|
||||
)
|
||||
|
||||
source_info = f"- **{item.name}/**"
|
||||
if details:
|
||||
source_info += f" ({', '.join(details)})"
|
||||
sources.append(source_info)
|
||||
except Exception as e:
|
||||
print(f"Warning: Error scanning files directory: {e}", file=sys.stderr)
|
||||
return "Error scanning knowledge sources."
|
||||
|
||||
if not sources:
|
||||
return "No knowledge sources available."
|
||||
|
||||
header = "The `files/` directory contains the following knowledge sources:\n\n"
|
||||
return header + "\n".join(sources)
|
||||
|
||||
|
||||
def build_connector_descriptions(files_path: Path) -> str:
|
||||
"""Build connector-specific descriptions for available data sources."""
|
||||
if not files_path.exists():
|
||||
return ""
|
||||
|
||||
descriptions = []
|
||||
try:
|
||||
for item in sorted(files_path.iterdir()):
|
||||
if not item.is_dir() or item.name.startswith("."):
|
||||
continue
|
||||
|
||||
normalized = item.name.lower().replace(" ", "_").replace("-", "_")
|
||||
if normalized in CONNECTOR_DESCRIPTIONS:
|
||||
descriptions.append(f"- {CONNECTOR_DESCRIPTIONS[normalized]}")
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Warning: Error scanning for connector descriptions: {e}", file=sys.stderr
|
||||
)
|
||||
return ""
|
||||
|
||||
if not descriptions:
|
||||
return ""
|
||||
|
||||
header = "Each connector type organizes its data differently:\n\n"
|
||||
footer = "\n\nSpaces in names are replaced by `_`."
|
||||
return header + "\n".join(descriptions) + footer
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# Read template from environment variable
|
||||
template = os.environ.get("AGENT_INSTRUCTIONS", "")
|
||||
if not template:
|
||||
print("Warning: No AGENT_INSTRUCTIONS template provided", file=sys.stderr)
|
||||
template = "# Agent Instructions\n\nNo instructions provided."
|
||||
|
||||
# Scan files directory
|
||||
files_path = Path("/workspace/files")
|
||||
file_structure = build_file_structure_section(files_path)
|
||||
connector_descriptions = build_connector_descriptions(files_path)
|
||||
|
||||
# Replace placeholders
|
||||
content = template
|
||||
content = content.replace("{{FILE_STRUCTURE_SECTION}}", file_structure)
|
||||
content = content.replace(
|
||||
"{{CONNECTOR_DESCRIPTIONS_SECTION}}", connector_descriptions
|
||||
)
|
||||
|
||||
# Write AGENTS.md
|
||||
output_path = Path("/workspace/AGENTS.md")
|
||||
output_path.write_text(content)
|
||||
|
||||
# Log result
|
||||
source_count = 0
|
||||
if files_path.exists():
|
||||
source_count = len(
|
||||
[
|
||||
d
|
||||
for d in files_path.iterdir()
|
||||
if d.is_dir() and not d.name.startswith(".")
|
||||
]
|
||||
)
|
||||
print(f"Generated AGENTS.md with {source_count} knowledge sources")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,17 @@
|
||||
google-genai>=1.0.0
|
||||
matplotlib==3.9.1
|
||||
matplotlib-inline>=0.1.7
|
||||
matplotlib-venn>=1.1.2
|
||||
numpy==1.26.4
|
||||
opencv-python>=4.11.0.86
|
||||
openpyxl>=3.1.5
|
||||
pandas==2.2.2
|
||||
pdfplumber>=0.11.7
|
||||
Pillow>=10.0.0
|
||||
pydantic>=2.11.9
|
||||
python-pptx>=1.0.2
|
||||
scikit-image>=0.25.2
|
||||
scikit-learn>=1.7.2
|
||||
scipy>=1.16.2
|
||||
seaborn>=0.13.2
|
||||
xgboost>=3.0.5
|
||||
80
backend/onyx/server/features/build/sandbox/kubernetes/docker/run-test.sh
Executable file
80
backend/onyx/server/features/build/sandbox/kubernetes/docker/run-test.sh
Executable file
@@ -0,0 +1,80 @@
|
||||
#!/bin/bash
|
||||
# Run Kubernetes sandbox integration tests
|
||||
#
|
||||
# This script:
|
||||
# 1. Builds the onyx-backend Docker image
|
||||
# 2. Loads it into the kind cluster
|
||||
# 3. Deletes/recreates the test pod
|
||||
# 4. Waits for the pod to be ready
|
||||
# 5. Runs the pytest command inside the pod
|
||||
#
|
||||
# Usage:
|
||||
# ./run-test.sh [test_name]
|
||||
#
|
||||
# Examples:
|
||||
# ./run-test.sh # Run all tests
|
||||
# ./run-test.sh test_kubernetes_sandbox_provision # Run specific test
|
||||
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_ROOT="$(cd "$SCRIPT_DIR/../../../../../../../.." && pwd)"
|
||||
NAMESPACE="onyx-sandboxes"
|
||||
POD_NAME="sandbox-test"
|
||||
IMAGE_NAME="onyxdotapp/onyx-backend:latest"
|
||||
TEST_FILE="onyx/server/features/build/sandbox/kubernetes/test_kubernetes_sandbox.py"
|
||||
ENV_FILE="$PROJECT_ROOT/.vscode/.env"
|
||||
|
||||
ORIGINAL_TEST_FILE="$PROJECT_ROOT/backend/tests/external_dependency_unit/craft/test_kubernetes_sandbox.py"
|
||||
cp "$ORIGINAL_TEST_FILE" "$PROJECT_ROOT/backend/$TEST_FILE"
|
||||
|
||||
# Optional: specific test to run
|
||||
TEST_NAME="${1:-}"
|
||||
|
||||
# Build env var arguments from .vscode/.env file for passing to the container
|
||||
ENV_VARS=()
|
||||
if [ -f "$ENV_FILE" ]; then
|
||||
echo "=== Loading environment variables from .vscode/.env ==="
|
||||
while IFS= read -r line || [ -n "$line" ]; do
|
||||
# Skip empty lines and comments
|
||||
[[ -z "$line" || "$line" =~ ^[[:space:]]*# ]] && continue
|
||||
# Skip lines without =
|
||||
[[ "$line" != *"="* ]] && continue
|
||||
# Add to env vars array
|
||||
ENV_VARS+=("$line")
|
||||
done < "$ENV_FILE"
|
||||
echo "Loaded ${#ENV_VARS[@]} environment variables"
|
||||
else
|
||||
echo "Warning: .vscode/.env not found, running without additional env vars"
|
||||
fi
|
||||
|
||||
echo "=== Building onyx-backend Docker image ==="
|
||||
cd "$PROJECT_ROOT/backend"
|
||||
docker build -t "$IMAGE_NAME" -f Dockerfile .
|
||||
|
||||
rm "$PROJECT_ROOT/backend/$TEST_FILE"
|
||||
|
||||
echo "=== Loading image into kind cluster ==="
|
||||
kind load docker-image "$IMAGE_NAME" --name onyx 2>/dev/null || \
|
||||
kind load docker-image "$IMAGE_NAME" 2>/dev/null || \
|
||||
echo "Warning: Could not load into kind. If using minikube, run: minikube image load $IMAGE_NAME"
|
||||
|
||||
echo "=== Deleting existing test pod (if any) ==="
|
||||
kubectl delete pod "$POD_NAME" -n "$NAMESPACE" --ignore-not-found=true
|
||||
|
||||
echo "=== Creating test pod ==="
|
||||
kubectl apply -f "$SCRIPT_DIR/test-job.yaml"
|
||||
|
||||
echo "=== Waiting for pod to be ready ==="
|
||||
kubectl wait --for=condition=Ready pod/"$POD_NAME" -n "$NAMESPACE" --timeout=120s
|
||||
|
||||
echo "=== Running tests ==="
|
||||
if [ -n "$TEST_NAME" ]; then
|
||||
kubectl exec -it "$POD_NAME" -n "$NAMESPACE" -- \
|
||||
env "${ENV_VARS[@]}" pytest "$TEST_FILE::$TEST_NAME" -v -s
|
||||
else
|
||||
kubectl exec -it "$POD_NAME" -n "$NAMESPACE" -- \
|
||||
env "${ENV_VARS[@]}" pytest "$TEST_FILE" -v -s
|
||||
fi
|
||||
|
||||
echo "=== Tests complete ==="
|
||||
41
backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/.gitignore
vendored
Normal file
41
backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/.gitignore
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
|
||||
|
||||
# dependencies
|
||||
/node_modules
|
||||
/.pnp
|
||||
.pnp.*
|
||||
.yarn/*
|
||||
!.yarn/patches
|
||||
!.yarn/plugins
|
||||
!.yarn/releases
|
||||
!.yarn/versions
|
||||
|
||||
# testing
|
||||
/coverage
|
||||
|
||||
# next.js
|
||||
/.next/
|
||||
/out/
|
||||
|
||||
# production
|
||||
/build
|
||||
|
||||
# misc
|
||||
.DS_Store
|
||||
*.pem
|
||||
|
||||
# debug
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
.pnpm-debug.log*
|
||||
|
||||
# env files (can opt-in for committing if needed)
|
||||
.env*
|
||||
|
||||
# vercel
|
||||
.vercel
|
||||
|
||||
# typescript
|
||||
*.tsbuildinfo
|
||||
next-env.d.ts
|
||||
@@ -0,0 +1,803 @@
|
||||
# AGENTS.md
|
||||
|
||||
This file provides guidance to AI agents when working on the web application within this directory.
|
||||
|
||||
## Important Notes
|
||||
|
||||
- **The development server is already running** at a dynamically allocated port. Do NOT run `npm run dev` yourself.
|
||||
- **We do NOT use a `src` directory** - all code lives directly in the root folders (`app/`, `components/`, `lib/`, etc.)
|
||||
- If the app needs pre-computation (data processing, API calls, etc.), create a bash or python script called `prepare.sh`/`prepare.py` at the root of this directory
|
||||
- **CRITICAL: Create small, modular components** - Do NOT write everything in `page.tsx`. Break your UI into small, reusable components in the `components/` directory. Each component should have a single responsibility and be in its own file.
|
||||
|
||||
## Data Preparation Scripts
|
||||
|
||||
**CRITICAL: Always re-run data scripts after modifying them.**
|
||||
|
||||
If a `prepare.sh` or `prepare.py` script exists at the root of this directory, it is responsible for generating/loading data that the frontend consumes.
|
||||
|
||||
### When to Run the Script
|
||||
|
||||
You MUST run the data preparation script:
|
||||
1. **After creating** the script for the first time
|
||||
2. **After modifying** the script logic (new data sources, changed processing, etc.)
|
||||
3. **After updating** any data files the script reads from
|
||||
4. **Before testing** the frontend if you're unsure if data is fresh
|
||||
|
||||
### How to Run
|
||||
|
||||
```bash
|
||||
# For bash scripts
|
||||
bash prepare.sh
|
||||
|
||||
# For python scripts
|
||||
python prepare.py
|
||||
```
|
||||
|
||||
### Common Mistake
|
||||
|
||||
❌ **Updating the script but forgetting to run it** - This leaves stale data in place and the frontend won't reflect your changes. Always run the script immediately after modifying it.
|
||||
|
||||
## Commands
|
||||
|
||||
```bash
|
||||
npm run dev # Start development server (DO NOT RUN - already running)
|
||||
npm run lint # Run ESLint
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
This is a **Next.js 16.1.1** application using the **App Router** with **React 19** and **TypeScript**. It serves as a component showcase/template built on shadcn/ui.
|
||||
|
||||
### File Organization Philosophy
|
||||
|
||||
**Prioritize small, incremental file writes.** Break your application into many small components rather than monolithic page files.
|
||||
|
||||
#### Component Organization
|
||||
|
||||
```
|
||||
components/
|
||||
├── dashboard/ # Feature-specific components
|
||||
│ ├── stats-card.tsx
|
||||
│ ├── activity-feed.tsx
|
||||
│ └── recent-items.tsx
|
||||
├── charts/ # Chart components
|
||||
│ ├── line-chart.tsx
|
||||
│ ├── bar-chart.tsx
|
||||
│ └── pie-chart.tsx
|
||||
├── data/ # Data display components
|
||||
│ ├── data-table.tsx
|
||||
│ ├── filter-bar.tsx
|
||||
│ └── sort-controls.tsx
|
||||
└── layout/ # Layout components
|
||||
├── header.tsx
|
||||
├── sidebar.tsx
|
||||
└── footer.tsx
|
||||
```
|
||||
|
||||
#### Page Structure
|
||||
|
||||
Pages (`app/page.tsx`) should be **thin orchestration layers** that compose components:
|
||||
|
||||
```typescript
|
||||
// ✅ GOOD - page.tsx is just composition
|
||||
import { StatsCard } from "@/components/dashboard/stats-card";
|
||||
import { ActivityFeed } from "@/components/dashboard/activity-feed";
|
||||
import { RecentItems } from "@/components/dashboard/recent-items";
|
||||
|
||||
export default function DashboardPage() {
|
||||
return (
|
||||
<div className="container py-6 space-y-6">
|
||||
<h1 className="text-3xl font-bold">Dashboard</h1>
|
||||
<div className="grid grid-cols-1 md:grid-cols-3 gap-4">
|
||||
<StatsCard title="Total Users" value={1234} />
|
||||
<StatsCard title="Active Sessions" value={56} />
|
||||
<StatsCard title="Revenue" value="$12,345" />
|
||||
</div>
|
||||
<div className="grid grid-cols-1 lg:grid-cols-2 gap-6">
|
||||
<ActivityFeed />
|
||||
<RecentItems />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ❌ BAD - Everything in page.tsx (500+ lines of mixed logic)
|
||||
export default function DashboardPage() {
|
||||
// ... 500 lines of component logic, state, handlers, JSX ...
|
||||
}
|
||||
```
|
||||
|
||||
#### Component Granularity
|
||||
|
||||
Create a new component file when:
|
||||
- A UI section has distinct functionality (e.g., `user-profile-card.tsx`)
|
||||
- Logic exceeds ~50-100 lines
|
||||
- A pattern is reused 2+ times
|
||||
- Testing/maintenance would benefit from isolation
|
||||
|
||||
**Example: Dashboard Feature**
|
||||
|
||||
Instead of writing everything in `app/page.tsx`:
|
||||
|
||||
```typescript
|
||||
// components/dashboard/stats-card.tsx
|
||||
export function StatsCard({ title, value, trend }: StatsCardProps) {
|
||||
return (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="text-sm font-medium">{title}</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div className="text-2xl font-bold">{value}</div>
|
||||
{trend && <p className="text-xs text-muted-foreground">{trend}</p>}
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
// components/dashboard/activity-feed.tsx
|
||||
export function ActivityFeed() {
|
||||
// Activity feed logic here
|
||||
}
|
||||
|
||||
// components/dashboard/recent-items.tsx
|
||||
export function RecentItems() {
|
||||
// Recent items logic here
|
||||
}
|
||||
```
|
||||
|
||||
#### Benefits of Small Components
|
||||
|
||||
1. **Incremental Development**: Write one component at a time, test, iterate
|
||||
2. **Better Diffs**: Smaller files = clearer git diffs and easier reviews
|
||||
3. **Reusability**: Components can be imported across pages
|
||||
4. **Maintainability**: Easier to locate and fix issues
|
||||
5. **Hot Reload Efficiency**: Changes to small files reload faster
|
||||
6. **Parallel Development**: Multiple features can be worked on independently
|
||||
|
||||
### Tech Stack
|
||||
|
||||
- **Framework**: Next.js 16.1.1 with App Router
|
||||
- **React**: React 19
|
||||
- **Language**: TypeScript
|
||||
- **Styling**: Tailwind CSS v4 with CSS variables in OKLCH color space
|
||||
- **Charts**: recharts for data visualization
|
||||
- **UI Components**: shadcn/ui (53 components) built on Radix UI primitives
|
||||
- **Variants**: class-variance-authority (CVA) for component variants
|
||||
- **Class Merging**: `cn()` utility in `lib/utils.ts` (clsx + tailwind-merge)
|
||||
- **Theme**: Dark mode enforced (via `dark` class on `<html>`)
|
||||
|
||||
### Key Directories
|
||||
|
||||
- `app/` - Next.js App Router pages and layouts
|
||||
- `components/ui/` - shadcn/ui component library (Button, Card, Dialog, etc.)
|
||||
- `components/` - App-specific components
|
||||
- `hooks/` - Custom React hooks (e.g., `use-mobile.ts`)
|
||||
- `lib/` - Utilities (`cn()` function)
|
||||
|
||||
### Component Patterns
|
||||
|
||||
- **Compound Components**: Components like `DropdownMenu`, `Dialog`, `Select` export multiple sub-components (Trigger, Content, Item)
|
||||
- **Variants via CVA**: Use `variants` prop for size/style variations (e.g., `buttonVariants`)
|
||||
- **Radix UI Primitives**: UI components wrap Radix for accessibility
|
||||
|
||||
### Path Aliases
|
||||
|
||||
All imports use `@/` alias (e.g., `@/components/ui/button`, `@/lib/utils`)
|
||||
|
||||
### shadcn/ui Configuration
|
||||
|
||||
Located in `components.json`:
|
||||
|
||||
- Style: `radix-nova`
|
||||
- RSC enabled
|
||||
- Icons: lucide-react
|
||||
|
||||
### Theme Variables
|
||||
|
||||
Global CSS variables defined in `app/globals.css` control colors, radius, and spacing. **Dark mode is enforced site-wide** via the `dark` class on the `<html>` element in `app/layout.tsx`. All styling should assume dark mode is active.
|
||||
|
||||
### Dark Mode Priority
|
||||
|
||||
- **Dark mode is the default and only theme** - do not design for light mode
|
||||
- The `dark` class is permanently set on `<html>` in `layout.tsx`
|
||||
- Use dark-appropriate colors: `bg-background`, `text-foreground`, etc.
|
||||
- Ensure sufficient contrast for dark backgrounds
|
||||
- Test all components in dark mode only
|
||||
|
||||
## Styling Guidelines
|
||||
|
||||
### CRITICAL: Use Only shadcn/ui Components
|
||||
|
||||
**MINIMIZE freestyling and creating custom components.** This application uses a complete, professionally designed component library (shadcn/ui). You MUST use the existing components from `components/ui/` for most UI needs.
|
||||
|
||||
#### Available shadcn/ui Components
|
||||
|
||||
All components are in `components/ui/`. Import using `@/components/ui/component-name`.
|
||||
|
||||
**Layout & Structure:**
|
||||
|
||||
- `Card` (`card.tsx`) - Content containers with CardHeader, CardTitle, CardDescription, CardContent, CardFooter
|
||||
- `Separator` (`separator.tsx`) - Horizontal/vertical dividers
|
||||
- `Tabs` (`tabs.tsx`) - Tabbed interfaces with Tabs, TabsList, TabsTrigger, TabsContent
|
||||
- `ScrollArea` (`scroll-area.tsx`) - Styled scrollable regions
|
||||
- `Resizable` (`resizable.tsx`) - Resizable panel layouts
|
||||
- `Drawer` (`drawer.tsx`) - Bottom/side drawer overlays
|
||||
- `Sidebar` (`sidebar.tsx`) - Application sidebar layout
|
||||
- `AspectRatio` (`aspect-ratio.tsx`) - Maintain aspect ratios
|
||||
|
||||
**Forms & Inputs:**
|
||||
|
||||
- `Button` (`button.tsx`) - Primary, secondary, destructive, outline, ghost, link variants
|
||||
- `ButtonGroup` (`button-group.tsx`) - Group of related buttons
|
||||
- `Input` (`input.tsx`) - Text inputs with various states
|
||||
- `InputGroup` (`input-group.tsx`) - Input with addons/icons
|
||||
- `Textarea` (`textarea.tsx`) - Multi-line text input
|
||||
- `Checkbox` (`checkbox.tsx`) - Checkboxes with indeterminate state
|
||||
- `RadioGroup` (`radio-group.tsx`) - Radio button groups
|
||||
- `Switch` (`switch.tsx`) - Toggle switches
|
||||
- `Select` (`select.tsx`) - Dropdown select menus
|
||||
- `NativeSelect` (`native-select.tsx`) - Native HTML select
|
||||
- `Combobox` (`combobox.tsx`) - Autocomplete select with search
|
||||
- `Command` (`command.tsx`) - Command palette/search interface
|
||||
- `Field` (`field.tsx`) - Form field wrapper with label and error
|
||||
- `Label` (`label.tsx`) - Form labels with proper accessibility
|
||||
- `Slider` (`slider.tsx`) - Range sliders
|
||||
- `Calendar` (`calendar.tsx`) - Date picker calendar
|
||||
- `Toggle` (`toggle.tsx`) - Toggle button
|
||||
- `ToggleGroup` (`toggle-group.tsx`) - Group of toggle buttons
|
||||
|
||||
**Navigation:**
|
||||
|
||||
- `NavigationMenu` (`navigation-menu.tsx`) - Complex navigation menus
|
||||
- `Menubar` (`menubar.tsx`) - Application menu bar
|
||||
- `Breadcrumb` (`breadcrumb.tsx`) - Breadcrumb navigation
|
||||
- `Pagination` (`pagination.tsx`) - Page navigation controls
|
||||
|
||||
**Feedback & Overlays:**
|
||||
|
||||
- `Dialog` (`dialog.tsx`) - Modal dialogs
|
||||
- `AlertDialog` (`alert-dialog.tsx`) - Confirmation dialogs
|
||||
- `Sheet` (`sheet.tsx`) - Side sheets/panels
|
||||
- `Popover` (`popover.tsx`) - Floating popovers
|
||||
- `HoverCard` (`hover-card.tsx`) - Hover-triggered cards
|
||||
- `Tooltip` (`tooltip.tsx`) - Tooltips on hover
|
||||
- `Sonner` (`sonner.tsx`) - Toast notifications
|
||||
- `Alert` (`alert.tsx`) - Static alert messages
|
||||
- `Progress` (`progress.tsx`) - Progress bars
|
||||
- `Skeleton` (`skeleton.tsx`) - Loading skeletons
|
||||
- `Spinner` (`spinner.tsx`) - Loading spinners
|
||||
- `Empty` (`empty.tsx`) - Empty state placeholder
|
||||
|
||||
**Menus & Dropdowns:**
|
||||
|
||||
- `DropdownMenu` (`dropdown-menu.tsx`) - Dropdown menus with submenus
|
||||
- `ContextMenu` (`context-menu.tsx`) - Right-click context menus
|
||||
|
||||
**Data Display:**
|
||||
|
||||
- `Table` (`table.tsx`) - Data tables with Table, TableHeader, TableBody, TableRow, TableCell, etc.
|
||||
- `Badge` (`badge.tsx`) - Status badges and tags
|
||||
- `Avatar` (`avatar.tsx`) - User avatars with fallbacks
|
||||
- `Accordion` (`accordion.tsx`) - Collapsible content sections
|
||||
- `Collapsible` (`collapsible.tsx`) - Simple collapse/expand
|
||||
- `Carousel` (`carousel.tsx`) - Image/content carousels
|
||||
- `Item` (`item.tsx`) - List item component
|
||||
- `Kbd` (`kbd.tsx`) - Keyboard shortcut display
|
||||
|
||||
**Data Visualization:**
|
||||
|
||||
- `Chart` (`chart.tsx`) - Chart wrapper with ChartContainer, ChartTooltip, ChartTooltipContent, ChartLegend, ChartLegendContent
|
||||
|
||||
### Component Usage Principles
|
||||
|
||||
#### 1. **Never Create Custom Components**
|
||||
|
||||
```typescript
|
||||
// ❌ WRONG - Do not create freestyle components
|
||||
function CustomCard({ title, children }) {
|
||||
return (
|
||||
<div className="rounded-lg border p-4">
|
||||
<h3 className="font-bold">{title}</h3>
|
||||
{children}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ✅ CORRECT - Use shadcn Card
|
||||
import { Card, CardHeader, CardTitle, CardContent } from "@/components/ui/card";
|
||||
|
||||
function MyComponent() {
|
||||
return (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>Title</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>Content here</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. **Use Component Variants, Don't Style Directly**
|
||||
|
||||
```typescript
|
||||
// ❌ WRONG - Applying custom Tailwind classes
|
||||
<button className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded">
|
||||
Click me
|
||||
</button>
|
||||
|
||||
// ✅ CORRECT - Use Button variants
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
<Button variant="default">Click me</Button>
|
||||
<Button variant="destructive">Delete</Button>
|
||||
<Button variant="outline">Cancel</Button>
|
||||
<Button variant="ghost">Subtle Action</Button>
|
||||
<Button size="sm">Small</Button>
|
||||
<Button size="lg">Large</Button>
|
||||
```
|
||||
|
||||
#### 3. **Compose Compound Components**
|
||||
|
||||
Many shadcn components export multiple sub-components. Use them as designed:
|
||||
|
||||
```typescript
|
||||
// ✅ Dropdown Menu Composition
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuTrigger,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuLabel,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="outline">Options</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent>
|
||||
<DropdownMenuLabel>Actions</DropdownMenuLabel>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuItem>Edit</DropdownMenuItem>
|
||||
<DropdownMenuItem>Delete</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
```
|
||||
|
||||
#### 4. **Use Layout Components for Structure**
|
||||
|
||||
```typescript
|
||||
// ✅ Use Card for content sections
|
||||
import { Card, CardHeader, CardTitle, CardDescription, CardContent, CardFooter } from "@/components/ui/card";
|
||||
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle>Dashboard</CardTitle>
|
||||
<CardDescription>Overview of your data</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
{/* Your content */}
|
||||
</CardContent>
|
||||
<CardFooter>
|
||||
<Button>Action</Button>
|
||||
</CardFooter>
|
||||
</Card>
|
||||
```
|
||||
|
||||
### Styling Rules
|
||||
|
||||
#### 1. **Spacing & Layout**
|
||||
|
||||
Use Tailwind's utility classes for spacing, but stick to the design system:
|
||||
|
||||
- Gap: `gap-2`, `gap-4`, `gap-6`, `gap-8`
|
||||
- Padding: `p-2`, `p-4`, `p-6`, `p-8`
|
||||
- Margins: Prefer `gap` and `space-y-*` over margins
|
||||
|
||||
#### 2. **Colors**
|
||||
|
||||
All colors come from CSS variables in `app/globals.css`. Use semantic color classes:
|
||||
|
||||
- `bg-background`, `bg-foreground`
|
||||
- `bg-card`, `text-card-foreground`
|
||||
- `bg-primary`, `text-primary-foreground`
|
||||
- `bg-secondary`, `text-secondary-foreground`
|
||||
- `bg-muted`, `text-muted-foreground`
|
||||
- `bg-accent`, `text-accent-foreground`
|
||||
- `bg-destructive`, `text-destructive-foreground`
|
||||
- `border-border`, `border-input`
|
||||
- `ring-ring`
|
||||
|
||||
**DO NOT use arbitrary color values** like `bg-blue-500` or `text-red-600`.
|
||||
|
||||
#### **CRITICAL: Color Contrast Pairing Rules**
|
||||
|
||||
**Always pair background colors with their matching foreground colors.** The color system uses paired variables where each background has a corresponding text color designed for proper contrast.
|
||||
|
||||
| Background Class | Text Class to Use | Description |
|
||||
|-----------------|-------------------|-------------|
|
||||
| `bg-background` | `text-foreground` | Main page background |
|
||||
| `bg-card` | `text-card-foreground` | Card containers |
|
||||
| `bg-primary` | `text-primary-foreground` | Primary buttons/accents |
|
||||
| `bg-secondary` | `text-secondary-foreground` | Secondary elements |
|
||||
| `bg-muted` | `text-muted-foreground` | Muted/subtle areas |
|
||||
| `bg-accent` | `text-accent-foreground` | Accent highlights |
|
||||
| `bg-destructive` | `text-destructive-foreground` | Error/delete actions |
|
||||
|
||||
**Examples:**
|
||||
|
||||
```typescript
|
||||
// ✅ CORRECT - Matching background and foreground pairs
|
||||
<div className="bg-card text-card-foreground">Content</div>
|
||||
<Button className="bg-primary text-primary-foreground">Click</Button>
|
||||
<div className="bg-muted text-muted-foreground">Subtle text</div>
|
||||
|
||||
// ❌ WRONG - Mismatched colors causing contrast issues
|
||||
<div className="bg-background text-background">Invisible text!</div>
|
||||
<div className="bg-card text-foreground">May have poor contrast</div>
|
||||
<Button className="bg-primary text-primary">White on white!</Button>
|
||||
```
|
||||
|
||||
**Key Rules:**
|
||||
|
||||
1. **Never use the same color for background and text** (e.g., `bg-foreground text-foreground`)
|
||||
2. **Always use the `-foreground` variant for text** when using a colored background
|
||||
3. **For text on `bg-background`**, use `text-foreground` (primary) or `text-muted-foreground` (secondary)
|
||||
4. **Test visually** - if text is hard to read, you have a contrast problem
|
||||
|
||||
#### 3. **Typography**
|
||||
|
||||
Use Tailwind text utilities (no separate Typography component):
|
||||
|
||||
- Headings: `text-xl font-semibold`, `text-2xl font-bold`, etc.
|
||||
- Body: `text-sm`, `text-base`
|
||||
- Secondary text: `text-muted-foreground`
|
||||
- Use semantic HTML: `<h1>`, `<h2>`, `<p>`, etc.
|
||||
- **Always wrap text** - Use `max-w-prose` or `max-w-xl` for readable line lengths
|
||||
- **Prevent overflow** - Use `break-words` or `truncate` for long text that might overflow containers
|
||||
|
||||
#### 4. **Responsive Design**
|
||||
|
||||
Use Tailwind's responsive prefixes:
|
||||
|
||||
```typescript
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{/* Responsive grid */}
|
||||
</div>
|
||||
```
|
||||
|
||||
#### 5. **Icons**
|
||||
|
||||
Use Lucide React icons (already configured):
|
||||
|
||||
```typescript
|
||||
import { Check, X, ChevronDown, User } from "lucide-react";
|
||||
|
||||
<Button>
|
||||
<Check className="mr-2 h-4 w-4" />
|
||||
Confirm
|
||||
</Button>
|
||||
```
|
||||
|
||||
### Data Visualization
|
||||
|
||||
For charts and data visualization, use the **shadcn/ui Chart components** (`@/components/ui/chart`) which wrap recharts with consistent theming. Charts should be **elegant, informative, and digestible at a glance**.
|
||||
|
||||
#### Chart Design Principles
|
||||
|
||||
1. **Clarity over complexity** - A chart should communicate ONE key insight immediately
|
||||
2. **Minimal visual noise** - Remove anything that doesn't add information
|
||||
3. **Consistent styling** - Use `ChartConfig` for colors, not arbitrary values
|
||||
4. **Responsive** - Always use `ChartContainer` (includes ResponsiveContainer)
|
||||
5. **Accessible** - Use `ChartTooltip` with `ChartTooltipContent` for proper styling
|
||||
|
||||
#### Chart Type Selection
|
||||
|
||||
| Data Type | Recommended Chart | Use Case |
|
||||
|-----------|-------------------|----------|
|
||||
| Trend over time | `LineChart` or `AreaChart` | Stock prices, user growth, metrics over days/months |
|
||||
| Comparing categories | `BarChart` | Revenue by product, users by region |
|
||||
| Part of whole | `PieChart` or `RadialBarChart` | Market share, budget allocation |
|
||||
| Distribution | `BarChart` (horizontal) | Survey responses, rating distribution |
|
||||
| Correlation | `ScatterChart` | Price vs. quality, age vs. income |
|
||||
|
||||
#### shadcn/ui Chart Components
|
||||
|
||||
Always import from the shadcn chart component:
|
||||
|
||||
```typescript
|
||||
import {
|
||||
ChartContainer,
|
||||
ChartTooltip,
|
||||
ChartTooltipContent,
|
||||
ChartLegend,
|
||||
ChartLegendContent,
|
||||
type ChartConfig,
|
||||
} from "@/components/ui/chart";
|
||||
import { LineChart, Line, XAxis, YAxis, CartesianGrid } from "recharts";
|
||||
```
|
||||
|
||||
#### ChartConfig - Define Colors and Labels
|
||||
|
||||
The `ChartConfig` object defines colors and labels for your data series. This ensures consistent theming:
|
||||
|
||||
```typescript
|
||||
const chartConfig = {
|
||||
revenue: {
|
||||
label: "Revenue",
|
||||
color: "var(--chart-1)",
|
||||
},
|
||||
expenses: {
|
||||
label: "Expenses",
|
||||
color: "var(--chart-2)",
|
||||
},
|
||||
} satisfies ChartConfig;
|
||||
```
|
||||
|
||||
#### Basic Line Chart Template
|
||||
|
||||
```typescript
|
||||
import {
|
||||
ChartContainer,
|
||||
ChartTooltip,
|
||||
ChartTooltipContent,
|
||||
type ChartConfig,
|
||||
} from "@/components/ui/chart";
|
||||
import { LineChart, Line, XAxis, YAxis, CartesianGrid } from "recharts";
|
||||
|
||||
const chartConfig = {
|
||||
value: {
|
||||
label: "Value",
|
||||
color: "var(--chart-1)",
|
||||
},
|
||||
} satisfies ChartConfig;
|
||||
|
||||
<ChartContainer config={chartConfig} className="h-[300px] w-full">
|
||||
<LineChart data={data} accessibilityLayer>
|
||||
<CartesianGrid vertical={false} />
|
||||
<XAxis
|
||||
dataKey="month"
|
||||
tickLine={false}
|
||||
axisLine={false}
|
||||
tickMargin={8}
|
||||
/>
|
||||
<YAxis tickLine={false} axisLine={false} tickMargin={8} />
|
||||
<ChartTooltip content={<ChartTooltipContent />} />
|
||||
<Line
|
||||
type="monotone"
|
||||
dataKey="value"
|
||||
stroke="var(--color-value)"
|
||||
strokeWidth={2}
|
||||
dot={false}
|
||||
/>
|
||||
</LineChart>
|
||||
</ChartContainer>
|
||||
```
|
||||
|
||||
#### Bar Chart with Multiple Series
|
||||
|
||||
```typescript
|
||||
const chartConfig = {
|
||||
revenue: {
|
||||
label: "Revenue",
|
||||
color: "var(--chart-1)",
|
||||
},
|
||||
expenses: {
|
||||
label: "Expenses",
|
||||
color: "var(--chart-2)",
|
||||
},
|
||||
} satisfies ChartConfig;
|
||||
|
||||
<ChartContainer config={chartConfig} className="h-[300px] w-full">
|
||||
<BarChart data={data} accessibilityLayer>
|
||||
<CartesianGrid vertical={false} />
|
||||
<XAxis dataKey="month" tickLine={false} axisLine={false} tickMargin={8} />
|
||||
<YAxis tickLine={false} axisLine={false} tickMargin={8} />
|
||||
<ChartTooltip content={<ChartTooltipContent />} />
|
||||
<ChartLegend content={<ChartLegendContent />} />
|
||||
<Bar dataKey="revenue" fill="var(--color-revenue)" radius={4} />
|
||||
<Bar dataKey="expenses" fill="var(--color-expenses)" radius={4} />
|
||||
</BarChart>
|
||||
</ChartContainer>
|
||||
```
|
||||
|
||||
#### Pie/Donut Chart
|
||||
|
||||
```typescript
|
||||
const chartConfig = {
|
||||
desktop: { label: "Desktop", color: "var(--chart-1)" },
|
||||
mobile: { label: "Mobile", color: "var(--chart-2)" },
|
||||
tablet: { label: "Tablet", color: "var(--chart-3)" },
|
||||
} satisfies ChartConfig;
|
||||
|
||||
<ChartContainer config={chartConfig} className="h-[300px] w-full">
|
||||
<PieChart>
|
||||
<ChartTooltip content={<ChartTooltipContent hideLabel />} />
|
||||
<Pie
|
||||
data={data}
|
||||
dataKey="value"
|
||||
nameKey="name"
|
||||
innerRadius={60} // Remove for solid pie, keep for donut
|
||||
strokeWidth={5}
|
||||
/>
|
||||
<ChartLegend content={<ChartLegendContent nameKey="name" />} />
|
||||
</PieChart>
|
||||
</ChartContainer>
|
||||
```
|
||||
|
||||
#### Chart Styling Rules
|
||||
|
||||
**Colors (use CSS variables from globals.css):**
|
||||
- `var(--chart-1)` through `var(--chart-5)` - Primary chart colors
|
||||
- `var(--primary)` - For single-series emphasis
|
||||
- `var(--muted)` - For de-emphasized data
|
||||
|
||||
**Color References in Charts:**
|
||||
- In `ChartConfig`: Use `color: "var(--chart-1)"`
|
||||
- In chart elements: Use `fill="var(--color-keyname)"` or `stroke="var(--color-keyname)"`
|
||||
- The `keyname` matches the key in your `ChartConfig`
|
||||
|
||||
**Visual Cleanup:**
|
||||
- Set `tickLine={false}` and `axisLine={false}` on axes for cleaner look
|
||||
- Use `vertical={false}` on `CartesianGrid` for horizontal-only grid lines
|
||||
- Use `dot={false}` on line charts unless individual points matter
|
||||
- Add `radius={4}` to bars for rounded corners
|
||||
- Limit to 3-5 data series maximum per chart
|
||||
|
||||
**Avoid:**
|
||||
- ❌ 3D effects
|
||||
- ❌ More than 5-6 colors in one chart
|
||||
- ❌ Legends with more than 5 items (simplify the data instead)
|
||||
- ❌ Dual Y-axes (confusing - use two separate charts)
|
||||
- ❌ Pie charts with more than 5-6 slices
|
||||
- ❌ Custom tooltip styling - use `ChartTooltipContent`
|
||||
|
||||
#### Fallback to Raw Recharts
|
||||
|
||||
If shadcn/ui Chart components don't support a specific chart type (e.g., ScatterChart, ComposedChart, RadarChart), you can use recharts directly:
|
||||
|
||||
```typescript
|
||||
import { ScatterChart, Scatter, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from "recharts";
|
||||
|
||||
<ResponsiveContainer width="100%" height={300}>
|
||||
<ScatterChart>
|
||||
<CartesianGrid strokeDasharray="3 3" stroke="var(--border)" />
|
||||
<XAxis dataKey="x" stroke="var(--muted-foreground)" fontSize={12} tickLine={false} axisLine={false} />
|
||||
<YAxis dataKey="y" stroke="var(--muted-foreground)" fontSize={12} tickLine={false} axisLine={false} />
|
||||
<Tooltip
|
||||
contentStyle={{
|
||||
backgroundColor: "var(--card)",
|
||||
border: "1px solid var(--border)",
|
||||
borderRadius: "6px"
|
||||
}}
|
||||
/>
|
||||
<Scatter data={data} fill="var(--chart-1)" />
|
||||
</ScatterChart>
|
||||
</ResponsiveContainer>
|
||||
```
|
||||
|
||||
**When using raw recharts:**
|
||||
- Still use CSS variables for colors (`var(--chart-1)`, etc.)
|
||||
- Match styling to shadcn conventions (tickLine={false}, axisLine={false})
|
||||
- Style tooltips to match the design system
|
||||
|
||||
#### Data Accuracy Checklist
|
||||
|
||||
Before displaying a chart, verify:
|
||||
- [ ] `ChartConfig` keys match your data's `dataKey` values
|
||||
- [ ] Data values are correctly mapped to the right axes
|
||||
- [ ] Axis labels match the data units (%, $, count, etc.)
|
||||
- [ ] Time series data is sorted chronologically
|
||||
- [ ] No missing data points that would break the visualization
|
||||
- [ ] `ChartTooltip` with `ChartTooltipContent` is included
|
||||
- [ ] Chart title/context makes the insight clear
|
||||
|
||||
### Common Patterns
|
||||
|
||||
#### Loading States
|
||||
|
||||
```typescript
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
|
||||
{isLoading ? (
|
||||
<Skeleton className="h-12 w-full" />
|
||||
) : (
|
||||
<Content />
|
||||
)}
|
||||
```
|
||||
|
||||
#### Empty States
|
||||
|
||||
```typescript
|
||||
import { Empty, EmptyHeader, EmptyTitle, EmptyDescription, EmptyMedia } from "@/components/ui/empty";
|
||||
import { Inbox } from "lucide-react";
|
||||
|
||||
<Empty>
|
||||
<EmptyHeader>
|
||||
<EmptyMedia variant="icon">
|
||||
<Inbox />
|
||||
</EmptyMedia>
|
||||
<EmptyTitle>No data available</EmptyTitle>
|
||||
<EmptyDescription>
|
||||
There's nothing to display yet. Add some items to get started.
|
||||
</EmptyDescription>
|
||||
</EmptyHeader>
|
||||
</Empty>
|
||||
```
|
||||
|
||||
#### Interactive Lists
|
||||
|
||||
```typescript
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { ItemGroup, Item, ItemContent, ItemTitle, ItemDescription, ItemMedia } from "@/components/ui/item";
|
||||
import { FileText } from "lucide-react";
|
||||
|
||||
<ScrollArea className="h-[400px]">
|
||||
<ItemGroup>
|
||||
{items.map((item) => (
|
||||
<Item key={item.id} variant="outline">
|
||||
<ItemMedia variant="icon">
|
||||
<FileText />
|
||||
</ItemMedia>
|
||||
<ItemContent>
|
||||
<ItemTitle>{item.name}</ItemTitle>
|
||||
<ItemDescription>{item.description}</ItemDescription>
|
||||
</ItemContent>
|
||||
</Item>
|
||||
))}
|
||||
</ItemGroup>
|
||||
</ScrollArea>
|
||||
```
|
||||
|
||||
#### Form Fields
|
||||
|
||||
```typescript
|
||||
import { Field, FieldLabel, FieldDescription, FieldError, FieldGroup } from "@/components/ui/field";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
<FieldGroup>
|
||||
<Field>
|
||||
<FieldLabel>Email</FieldLabel>
|
||||
<Input type="email" placeholder="you@example.com" />
|
||||
<FieldDescription>We'll never share your email.</FieldDescription>
|
||||
</Field>
|
||||
<Field>
|
||||
<FieldLabel>Password</FieldLabel>
|
||||
<Input type="password" />
|
||||
<FieldError>Password must be at least 8 characters.</FieldError>
|
||||
</Field>
|
||||
<Button type="submit">Sign up</Button>
|
||||
</FieldGroup>
|
||||
```
|
||||
|
||||
### What NOT To Do
|
||||
|
||||
❌ **Don't create custom styled divs when a component exists**
|
||||
❌ **Don't use arbitrary Tailwind colors** (use CSS variables)
|
||||
❌ **Don't import UI libraries** like Material-UI, Ant Design, etc.
|
||||
❌ **Don't use inline styles** except for dynamic values
|
||||
❌ **Don't create custom form inputs** (use Field, Input, Select, etc. from components/ui)
|
||||
❌ **Don't add new dependencies** without checking if shadcn covers it
|
||||
❌ **Don't write everything in page.tsx** - break into separate component files
|
||||
❌ **Don't design for light mode** - this site is dark mode only
|
||||
❌ **Don't use `dark:` variants** - dark mode is always active, use base classes
|
||||
|
||||
### Development Workflow
|
||||
|
||||
1. **Plan the component structure** - Identify logical UI sections before writing code
|
||||
2. **Create components incrementally** - Write one small component file at a time
|
||||
3. **Test each component** - Verify it works before moving to the next
|
||||
4. **Compose in page.tsx** - Import and arrange your components in the page
|
||||
5. **Iterate** - Refine individual components without touching others
|
||||
|
||||
### Summary
|
||||
|
||||
This application has a **complete, production-ready component library**. Your job is to:
|
||||
1. **Compose** shadcn/ui components (from `components/ui/`)
|
||||
2. **Create small, focused component files** (in `components/`)
|
||||
3. **Keep pages thin** - pages should orchestrate components, not contain implementation
|
||||
|
||||
Think of yourself as assembling LEGO blocks—all the UI pieces you need already exist in `components/ui/`, and you create small, organized structures by composing them into feature-specific components.
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 38 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 104 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 34 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 893 B |
Binary file not shown.
|
After Width: | Height: | Size: 2.7 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 15 KiB |
Binary file not shown.
@@ -0,0 +1,127 @@
|
||||
@import "tailwindcss";
|
||||
@import "tw-animate-css";
|
||||
@import "shadcn/tailwind.css";
|
||||
|
||||
@custom-variant dark (&:is(.dark *));
|
||||
|
||||
@theme inline {
|
||||
--color-background: var(--background);
|
||||
--color-foreground: var(--foreground);
|
||||
--font-sans: var(--font-sans);
|
||||
--font-mono: var(--font-geist-mono);
|
||||
--color-sidebar-ring: var(--sidebar-ring);
|
||||
--color-sidebar-border: var(--sidebar-border);
|
||||
--color-sidebar-accent-foreground: var(--sidebar-accent-foreground);
|
||||
--color-sidebar-accent: var(--sidebar-accent);
|
||||
--color-sidebar-primary-foreground: var(--sidebar-primary-foreground);
|
||||
--color-sidebar-primary: var(--sidebar-primary);
|
||||
--color-sidebar-foreground: var(--sidebar-foreground);
|
||||
--color-sidebar: var(--sidebar);
|
||||
--color-chart-5: var(--chart-5);
|
||||
--color-chart-4: var(--chart-4);
|
||||
--color-chart-3: var(--chart-3);
|
||||
--color-chart-2: var(--chart-2);
|
||||
--color-chart-1: var(--chart-1);
|
||||
--color-ring: var(--ring);
|
||||
--color-input: var(--input);
|
||||
--color-border: var(--border);
|
||||
--color-destructive: var(--destructive);
|
||||
--color-accent-foreground: var(--accent-foreground);
|
||||
--color-accent: var(--accent);
|
||||
--color-muted-foreground: var(--muted-foreground);
|
||||
--color-muted: var(--muted);
|
||||
--color-secondary-foreground: var(--secondary-foreground);
|
||||
--color-secondary: var(--secondary);
|
||||
--color-primary-foreground: var(--primary-foreground);
|
||||
--color-primary: var(--primary);
|
||||
--color-popover-foreground: var(--popover-foreground);
|
||||
--color-popover: var(--popover);
|
||||
--color-card-foreground: var(--card-foreground);
|
||||
--color-card: var(--card);
|
||||
--radius-sm: calc(var(--radius) - 4px);
|
||||
--radius-md: calc(var(--radius) - 2px);
|
||||
--radius-lg: var(--radius);
|
||||
--radius-xl: calc(var(--radius) + 4px);
|
||||
--radius-2xl: calc(var(--radius) + 8px);
|
||||
--radius-3xl: calc(var(--radius) + 12px);
|
||||
--radius-4xl: calc(var(--radius) + 16px);
|
||||
}
|
||||
|
||||
:root {
|
||||
--background: oklch(1 0 0);
|
||||
--foreground: oklch(0.145 0 0);
|
||||
--card: oklch(1 0 0);
|
||||
--card-foreground: oklch(0.145 0 0);
|
||||
--popover: oklch(1 0 0);
|
||||
--popover-foreground: oklch(0.145 0 0);
|
||||
--primary: oklch(0.67 0.16 58);
|
||||
--primary-foreground: oklch(0.99 0.02 95);
|
||||
--secondary: oklch(0.967 0.001 286.375);
|
||||
--secondary-foreground: oklch(0.21 0.006 285.885);
|
||||
--muted: oklch(0.97 0 0);
|
||||
--muted-foreground: oklch(0.556 0 0);
|
||||
--accent: oklch(0.97 0 0);
|
||||
--accent-foreground: oklch(0.205 0 0);
|
||||
--destructive: oklch(0.58 0.22 27);
|
||||
--border: oklch(0.922 0 0);
|
||||
--input: oklch(0.922 0 0);
|
||||
--ring: oklch(0.708 0 0);
|
||||
--chart-1: oklch(0.88 0.15 92);
|
||||
--chart-2: oklch(0.77 0.16 70);
|
||||
--chart-3: oklch(0.67 0.16 58);
|
||||
--chart-4: oklch(0.56 0.15 49);
|
||||
--chart-5: oklch(0.47 0.12 46);
|
||||
--radius: 0.625rem;
|
||||
--sidebar: oklch(0.985 0 0);
|
||||
--sidebar-foreground: oklch(0.145 0 0);
|
||||
--sidebar-primary: oklch(0.67 0.16 58);
|
||||
--sidebar-primary-foreground: oklch(0.99 0.02 95);
|
||||
--sidebar-accent: oklch(0.97 0 0);
|
||||
--sidebar-accent-foreground: oklch(0.205 0 0);
|
||||
--sidebar-border: oklch(0.922 0 0);
|
||||
--sidebar-ring: oklch(0.708 0 0);
|
||||
}
|
||||
|
||||
.dark {
|
||||
--background: oklch(0.145 0 0);
|
||||
--foreground: oklch(0.985 0 0);
|
||||
--card: oklch(0.205 0 0);
|
||||
--card-foreground: oklch(0.985 0 0);
|
||||
--popover: oklch(0.205 0 0);
|
||||
--popover-foreground: oklch(0.985 0 0);
|
||||
--primary: oklch(0.77 0.16 70);
|
||||
--primary-foreground: oklch(0.28 0.07 46);
|
||||
--secondary: oklch(0.274 0.006 286.033);
|
||||
--secondary-foreground: oklch(0.985 0 0);
|
||||
--muted: oklch(0.269 0 0);
|
||||
--muted-foreground: oklch(0.708 0 0);
|
||||
--accent: oklch(0.371 0 0);
|
||||
--accent-foreground: oklch(0.985 0 0);
|
||||
--destructive: oklch(0.704 0.191 22.216);
|
||||
--border: oklch(1 0 0 / 10%);
|
||||
--input: oklch(1 0 0 / 15%);
|
||||
--ring: oklch(0.556 0 0);
|
||||
/* Chart colors optimized for dark backgrounds - brighter and more vibrant */
|
||||
--chart-1: oklch(0.82 0.18 140);
|
||||
--chart-2: oklch(0.75 0.2 200);
|
||||
--chart-3: oklch(0.7 0.22 280);
|
||||
--chart-4: oklch(0.78 0.18 50);
|
||||
--chart-5: oklch(0.72 0.2 330);
|
||||
--sidebar: oklch(0.205 0 0);
|
||||
--sidebar-foreground: oklch(0.985 0 0);
|
||||
--sidebar-primary: oklch(0.77 0.16 70);
|
||||
--sidebar-primary-foreground: oklch(0.28 0.07 46);
|
||||
--sidebar-accent: oklch(0.269 0 0);
|
||||
--sidebar-accent-foreground: oklch(0.985 0 0);
|
||||
--sidebar-border: oklch(1 0 0 / 10%);
|
||||
--sidebar-ring: oklch(0.556 0 0);
|
||||
}
|
||||
|
||||
@layer base {
|
||||
* {
|
||||
@apply border-border outline-ring/50;
|
||||
}
|
||||
body {
|
||||
@apply bg-background text-foreground;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
import type { Metadata } from "next";
|
||||
import { Geist, Geist_Mono, Inter } from "next/font/google";
|
||||
import "./globals.css";
|
||||
|
||||
const inter = Inter({ subsets: ["latin"], variable: "--font-sans" });
|
||||
|
||||
const geistSans = Geist({
|
||||
variable: "--font-geist-sans",
|
||||
subsets: ["latin"],
|
||||
});
|
||||
|
||||
const geistMono = Geist_Mono({
|
||||
variable: "--font-geist-mono",
|
||||
subsets: ["latin"],
|
||||
});
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "Onyx Craft",
|
||||
description: "Crafting your next great idea.",
|
||||
};
|
||||
|
||||
export default function RootLayout({
|
||||
children,
|
||||
}: Readonly<{
|
||||
children: React.ReactNode;
|
||||
}>) {
|
||||
return (
|
||||
<html lang="en" className={`${inter.variable} dark`}>
|
||||
<body
|
||||
className={`${geistSans.variable} ${geistMono.variable} antialiased`}
|
||||
>
|
||||
{children}
|
||||
</body>
|
||||
</html>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,132 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
|
||||
const messages = [
|
||||
"Punching wood...",
|
||||
"Gathering resources...",
|
||||
"Placing blocks...",
|
||||
"Crafting your workspace...",
|
||||
"Mining for dependencies...",
|
||||
"Smelting the code...",
|
||||
"Enchanting with magic...",
|
||||
"World generation complete...",
|
||||
"/gamemode 1",
|
||||
];
|
||||
|
||||
const MESSAGE_COUNT = messages.length;
|
||||
const TYPE_DELAY = 40;
|
||||
const LINE_PAUSE = 800;
|
||||
const RESET_DELAY = 2000;
|
||||
|
||||
export default function CraftingLoader() {
|
||||
const [display, setDisplay] = useState({
|
||||
lines: [] as string[],
|
||||
currentText: "",
|
||||
});
|
||||
|
||||
const lineIndexRef = useRef(0);
|
||||
const charIndexRef = useRef(0);
|
||||
const lastUpdateRef = useRef(0);
|
||||
const timeoutRef = useRef<NodeJS.Timeout | undefined>(undefined);
|
||||
const rafRef = useRef<number | undefined>(undefined);
|
||||
|
||||
useEffect(() => {
|
||||
let isActive = true;
|
||||
|
||||
const update = (now: number) => {
|
||||
if (!isActive) return;
|
||||
|
||||
const lineIdx = lineIndexRef.current;
|
||||
const charIdx = charIndexRef.current;
|
||||
|
||||
if (lineIdx >= MESSAGE_COUNT) {
|
||||
timeoutRef.current = setTimeout(() => {
|
||||
if (!isActive) return;
|
||||
lineIndexRef.current = 0;
|
||||
charIndexRef.current = 0;
|
||||
setDisplay({ lines: [], currentText: "" });
|
||||
lastUpdateRef.current = performance.now();
|
||||
rafRef.current = requestAnimationFrame(update);
|
||||
}, RESET_DELAY);
|
||||
return;
|
||||
}
|
||||
|
||||
const msg = messages[lineIdx];
|
||||
if (!msg) return;
|
||||
|
||||
const elapsed = now - lastUpdateRef.current;
|
||||
|
||||
if (charIdx < msg.length) {
|
||||
if (elapsed >= TYPE_DELAY) {
|
||||
charIndexRef.current = charIdx + 1;
|
||||
setDisplay((prev) => ({
|
||||
lines: prev.lines,
|
||||
currentText: msg.substring(0, charIdx + 1),
|
||||
}));
|
||||
lastUpdateRef.current = now;
|
||||
}
|
||||
} else if (elapsed >= LINE_PAUSE) {
|
||||
setDisplay((prev) => ({
|
||||
lines: [...prev.lines, msg],
|
||||
currentText: "",
|
||||
}));
|
||||
lineIndexRef.current = lineIdx + 1;
|
||||
charIndexRef.current = 0;
|
||||
lastUpdateRef.current = now;
|
||||
}
|
||||
|
||||
rafRef.current = requestAnimationFrame(update);
|
||||
};
|
||||
|
||||
lastUpdateRef.current = performance.now();
|
||||
rafRef.current = requestAnimationFrame(update);
|
||||
|
||||
return () => {
|
||||
isActive = false;
|
||||
if (rafRef.current !== undefined) cancelAnimationFrame(rafRef.current);
|
||||
if (timeoutRef.current !== undefined) clearTimeout(timeoutRef.current);
|
||||
};
|
||||
}, []);
|
||||
|
||||
const { lines, currentText } = display;
|
||||
const hasCurrentText = currentText.length > 0;
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-gradient-to-br from-neutral-950 via-neutral-900 to-neutral-950 flex flex-col items-center justify-center p-4">
|
||||
<div className="w-full max-w-md rounded-sm overflow-hidden shadow-2xl border-2 border-neutral-700">
|
||||
<div className="bg-neutral-800 px-4 py-3 flex items-center gap-2 border-b-2 border-neutral-700">
|
||||
<div className="w-3 h-3 rounded-none bg-red-500" />
|
||||
<div className="w-3 h-3 rounded-none bg-yellow-500" />
|
||||
<div className="w-3 h-3 rounded-none bg-green-500" />
|
||||
<span className="ml-4 text-neutral-500 text-sm font-mono">
|
||||
crafting_table
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<div className="bg-neutral-900 p-6 min-h-[250px] font-mono text-sm">
|
||||
{lines.map((line, i) => (
|
||||
<div key={i} className="flex items-center text-neutral-300">
|
||||
<span className="text-emerald-500 mr-2">/></span>
|
||||
<span>{line}</span>
|
||||
</div>
|
||||
))}
|
||||
{hasCurrentText && (
|
||||
<div className="flex items-center text-neutral-300">
|
||||
<span className="text-emerald-500 mr-2">/></span>
|
||||
<span>{currentText}</span>
|
||||
</div>
|
||||
)}
|
||||
<div className="flex items-center text-neutral-300">
|
||||
<span className="text-emerald-500 mr-2">/></span>
|
||||
<span className="w-2 h-5 bg-emerald-500 animate-pulse" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<p className="mt-6 text-neutral-500 text-sm font-mono">
|
||||
Crafting your next great idea...
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
{"name":"","short_name":"","icons":[{"src":"/android-chrome-192x192.png","sizes":"192x192","type":"image/png"},{"src":"/android-chrome-512x512.png","sizes":"512x512","type":"image/png"}],"theme_color":"#ffffff","background_color":"#ffffff","display":"standalone"}
|
||||
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"$schema": "https://ui.shadcn.com/schema.json",
|
||||
"style": "radix-nova",
|
||||
"rsc": true,
|
||||
"tsx": true,
|
||||
"tailwind": {
|
||||
"config": "",
|
||||
"css": "app/globals.css",
|
||||
"baseColor": "neutral",
|
||||
"cssVariables": true,
|
||||
"prefix": ""
|
||||
},
|
||||
"iconLibrary": "lucide",
|
||||
"aliases": {
|
||||
"components": "@/components",
|
||||
"utils": "@/lib/utils",
|
||||
"ui": "@/components/ui",
|
||||
"lib": "@/lib",
|
||||
"hooks": "@/hooks"
|
||||
},
|
||||
"menuColor": "default",
|
||||
"menuAccent": "subtle",
|
||||
"registries": {}
|
||||
}
|
||||
@@ -0,0 +1,490 @@
|
||||
"use client";
|
||||
|
||||
import * as React from "react";
|
||||
|
||||
import { Example, ExampleWrapper } from "@/components/example";
|
||||
import {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogMedia,
|
||||
AlertDialogTitle,
|
||||
AlertDialogTrigger,
|
||||
} from "@/components/ui/alert-dialog";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Card,
|
||||
CardAction,
|
||||
CardContent,
|
||||
CardDescription,
|
||||
CardFooter,
|
||||
CardHeader,
|
||||
CardTitle,
|
||||
} from "@/components/ui/card";
|
||||
import {
|
||||
Combobox,
|
||||
ComboboxContent,
|
||||
ComboboxEmpty,
|
||||
ComboboxInput,
|
||||
ComboboxItem,
|
||||
ComboboxList,
|
||||
} from "@/components/ui/combobox";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuCheckboxItem,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuGroup,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuLabel,
|
||||
DropdownMenuPortal,
|
||||
DropdownMenuRadioGroup,
|
||||
DropdownMenuRadioItem,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuShortcut,
|
||||
DropdownMenuSub,
|
||||
DropdownMenuSubContent,
|
||||
DropdownMenuSubTrigger,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Field, FieldGroup, FieldLabel } from "@/components/ui/field";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectGroup,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import {
|
||||
PlusIcon,
|
||||
BluetoothIcon,
|
||||
MoreVerticalIcon,
|
||||
FileIcon,
|
||||
FolderIcon,
|
||||
FolderOpenIcon,
|
||||
FileCodeIcon,
|
||||
MoreHorizontalIcon,
|
||||
FolderSearchIcon,
|
||||
SaveIcon,
|
||||
DownloadIcon,
|
||||
EyeIcon,
|
||||
LayoutIcon,
|
||||
PaletteIcon,
|
||||
SunIcon,
|
||||
MoonIcon,
|
||||
MonitorIcon,
|
||||
UserIcon,
|
||||
CreditCardIcon,
|
||||
SettingsIcon,
|
||||
KeyboardIcon,
|
||||
LanguagesIcon,
|
||||
BellIcon,
|
||||
MailIcon,
|
||||
ShieldIcon,
|
||||
HelpCircleIcon,
|
||||
FileTextIcon,
|
||||
LogOutIcon,
|
||||
} from "lucide-react";
|
||||
|
||||
export function ComponentExample() {
|
||||
return (
|
||||
<ExampleWrapper>
|
||||
<CardExample />
|
||||
<FormExample />
|
||||
</ExampleWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
function CardExample() {
|
||||
return (
|
||||
<Example title="Card" className="items-center justify-center">
|
||||
<Card className="relative w-full max-w-sm overflow-hidden pt-0">
|
||||
<div className="bg-primary absolute inset-0 z-30 aspect-video opacity-50 mix-blend-color" />
|
||||
<img
|
||||
src="https://images.unsplash.com/photo-1604076850742-4c7221f3101b?q=80&w=1887&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
|
||||
alt="Photo by mymind on Unsplash"
|
||||
title="Photo by mymind on Unsplash"
|
||||
className="relative z-20 aspect-video w-full object-cover brightness-60 grayscale"
|
||||
/>
|
||||
<CardHeader>
|
||||
<CardTitle>Observability Plus is replacing Monitoring</CardTitle>
|
||||
<CardDescription>
|
||||
Switch to the improved way to explore your data, with natural
|
||||
language. Monitoring will no longer be available on the Pro plan in
|
||||
November, 2025
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardFooter>
|
||||
<AlertDialog>
|
||||
<AlertDialogTrigger asChild>
|
||||
<Button>
|
||||
<PlusIcon data-icon="inline-start" />
|
||||
Show Dialog
|
||||
</Button>
|
||||
</AlertDialogTrigger>
|
||||
<AlertDialogContent size="sm">
|
||||
<AlertDialogHeader>
|
||||
<AlertDialogMedia>
|
||||
<BluetoothIcon />
|
||||
</AlertDialogMedia>
|
||||
<AlertDialogTitle>Allow accessory to connect?</AlertDialogTitle>
|
||||
<AlertDialogDescription>
|
||||
Do you want to allow the USB accessory to connect to this
|
||||
device?
|
||||
</AlertDialogDescription>
|
||||
</AlertDialogHeader>
|
||||
<AlertDialogFooter>
|
||||
<AlertDialogCancel>Don't allow</AlertDialogCancel>
|
||||
<AlertDialogAction>Allow</AlertDialogAction>
|
||||
</AlertDialogFooter>
|
||||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
<Badge variant="secondary" className="ml-auto">
|
||||
Warning
|
||||
</Badge>
|
||||
</CardFooter>
|
||||
</Card>
|
||||
</Example>
|
||||
);
|
||||
}
|
||||
|
||||
const frameworks = [
|
||||
"Next.js",
|
||||
"SvelteKit",
|
||||
"Nuxt.js",
|
||||
"Remix",
|
||||
"Astro",
|
||||
] as const;
|
||||
|
||||
function FormExample() {
|
||||
const [notifications, setNotifications] = React.useState({
|
||||
email: true,
|
||||
sms: false,
|
||||
push: true,
|
||||
});
|
||||
const [theme, setTheme] = React.useState("light");
|
||||
|
||||
return (
|
||||
<Example title="Form">
|
||||
<Card className="w-full max-w-md">
|
||||
<CardHeader>
|
||||
<CardTitle>User Information</CardTitle>
|
||||
<CardDescription>Please fill in your details below</CardDescription>
|
||||
<CardAction>
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="ghost" size="icon">
|
||||
<MoreVerticalIcon />
|
||||
<span className="sr-only">More options</span>
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end" className="w-56">
|
||||
<DropdownMenuGroup>
|
||||
<DropdownMenuLabel>File</DropdownMenuLabel>
|
||||
<DropdownMenuItem>
|
||||
<FileIcon />
|
||||
New File
|
||||
<DropdownMenuShortcut>⌘N</DropdownMenuShortcut>
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem>
|
||||
<FolderIcon />
|
||||
New Folder
|
||||
<DropdownMenuShortcut>⇧⌘N</DropdownMenuShortcut>
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSub>
|
||||
<DropdownMenuSubTrigger>
|
||||
<FolderOpenIcon />
|
||||
Open Recent
|
||||
</DropdownMenuSubTrigger>
|
||||
<DropdownMenuPortal>
|
||||
<DropdownMenuSubContent>
|
||||
<DropdownMenuGroup>
|
||||
<DropdownMenuLabel>Recent Projects</DropdownMenuLabel>
|
||||
<DropdownMenuItem>
|
||||
<FileCodeIcon />
|
||||
Project Alpha
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem>
|
||||
<FileCodeIcon />
|
||||
Project Beta
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSub>
|
||||
<DropdownMenuSubTrigger>
|
||||
<MoreHorizontalIcon />
|
||||
More Projects
|
||||
</DropdownMenuSubTrigger>
|
||||
<DropdownMenuPortal>
|
||||
<DropdownMenuSubContent>
|
||||
<DropdownMenuItem>
|
||||
<FileCodeIcon />
|
||||
Project Gamma
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem>
|
||||
<FileCodeIcon />
|
||||
Project Delta
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuSubContent>
|
||||
</DropdownMenuPortal>
|
||||
</DropdownMenuSub>
|
||||
</DropdownMenuGroup>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuGroup>
|
||||
<DropdownMenuItem>
|
||||
<FolderSearchIcon />
|
||||
Browse...
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuGroup>
|
||||
</DropdownMenuSubContent>
|
||||
</DropdownMenuPortal>
|
||||
</DropdownMenuSub>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuItem>
|
||||
<SaveIcon />
|
||||
Save
|
||||
<DropdownMenuShortcut>⌘S</DropdownMenuShortcut>
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem>
|
||||
<DownloadIcon />
|
||||
Export
|
||||
<DropdownMenuShortcut>⇧⌘E</DropdownMenuShortcut>
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuGroup>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuGroup>
|
||||
<DropdownMenuLabel>View</DropdownMenuLabel>
|
||||
<DropdownMenuCheckboxItem
|
||||
checked={notifications.email}
|
||||
onCheckedChange={(checked) =>
|
||||
setNotifications({
|
||||
...notifications,
|
||||
email: checked === true,
|
||||
})
|
||||
}
|
||||
>
|
||||
<EyeIcon />
|
||||
Show Sidebar
|
||||
</DropdownMenuCheckboxItem>
|
||||
<DropdownMenuCheckboxItem
|
||||
checked={notifications.sms}
|
||||
onCheckedChange={(checked) =>
|
||||
setNotifications({
|
||||
...notifications,
|
||||
sms: checked === true,
|
||||
})
|
||||
}
|
||||
>
|
||||
<LayoutIcon />
|
||||
Show Status Bar
|
||||
</DropdownMenuCheckboxItem>
|
||||
<DropdownMenuSub>
|
||||
<DropdownMenuSubTrigger>
|
||||
<PaletteIcon />
|
||||
Theme
|
||||
</DropdownMenuSubTrigger>
|
||||
<DropdownMenuPortal>
|
||||
<DropdownMenuSubContent>
|
||||
<DropdownMenuGroup>
|
||||
<DropdownMenuLabel>Appearance</DropdownMenuLabel>
|
||||
<DropdownMenuRadioGroup
|
||||
value={theme}
|
||||
onValueChange={setTheme}
|
||||
>
|
||||
<DropdownMenuRadioItem value="light">
|
||||
<SunIcon />
|
||||
Light
|
||||
</DropdownMenuRadioItem>
|
||||
<DropdownMenuRadioItem value="dark">
|
||||
<MoonIcon />
|
||||
Dark
|
||||
</DropdownMenuRadioItem>
|
||||
<DropdownMenuRadioItem value="system">
|
||||
<MonitorIcon />
|
||||
System
|
||||
</DropdownMenuRadioItem>
|
||||
</DropdownMenuRadioGroup>
|
||||
</DropdownMenuGroup>
|
||||
</DropdownMenuSubContent>
|
||||
</DropdownMenuPortal>
|
||||
</DropdownMenuSub>
|
||||
</DropdownMenuGroup>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuGroup>
|
||||
<DropdownMenuLabel>Account</DropdownMenuLabel>
|
||||
<DropdownMenuItem>
|
||||
<UserIcon />
|
||||
Profile
|
||||
<DropdownMenuShortcut>⇧⌘P</DropdownMenuShortcut>
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem>
|
||||
<CreditCardIcon />
|
||||
Billing
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSub>
|
||||
<DropdownMenuSubTrigger>
|
||||
<SettingsIcon />
|
||||
Settings
|
||||
</DropdownMenuSubTrigger>
|
||||
<DropdownMenuPortal>
|
||||
<DropdownMenuSubContent>
|
||||
<DropdownMenuGroup>
|
||||
<DropdownMenuLabel>Preferences</DropdownMenuLabel>
|
||||
<DropdownMenuItem>
|
||||
<KeyboardIcon />
|
||||
Keyboard Shortcuts
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem>
|
||||
<LanguagesIcon />
|
||||
Language
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSub>
|
||||
<DropdownMenuSubTrigger>
|
||||
<BellIcon />
|
||||
Notifications
|
||||
</DropdownMenuSubTrigger>
|
||||
<DropdownMenuPortal>
|
||||
<DropdownMenuSubContent>
|
||||
<DropdownMenuGroup>
|
||||
<DropdownMenuLabel>
|
||||
Notification Types
|
||||
</DropdownMenuLabel>
|
||||
<DropdownMenuCheckboxItem
|
||||
checked={notifications.push}
|
||||
onCheckedChange={(checked) =>
|
||||
setNotifications({
|
||||
...notifications,
|
||||
push: checked === true,
|
||||
})
|
||||
}
|
||||
>
|
||||
<BellIcon />
|
||||
Push Notifications
|
||||
</DropdownMenuCheckboxItem>
|
||||
<DropdownMenuCheckboxItem
|
||||
checked={notifications.email}
|
||||
onCheckedChange={(checked) =>
|
||||
setNotifications({
|
||||
...notifications,
|
||||
email: checked === true,
|
||||
})
|
||||
}
|
||||
>
|
||||
<MailIcon />
|
||||
Email Notifications
|
||||
</DropdownMenuCheckboxItem>
|
||||
</DropdownMenuGroup>
|
||||
</DropdownMenuSubContent>
|
||||
</DropdownMenuPortal>
|
||||
</DropdownMenuSub>
|
||||
</DropdownMenuGroup>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuGroup>
|
||||
<DropdownMenuItem>
|
||||
<ShieldIcon />
|
||||
Privacy & Security
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuGroup>
|
||||
</DropdownMenuSubContent>
|
||||
</DropdownMenuPortal>
|
||||
</DropdownMenuSub>
|
||||
</DropdownMenuGroup>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuGroup>
|
||||
<DropdownMenuItem>
|
||||
<HelpCircleIcon />
|
||||
Help & Support
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem>
|
||||
<FileTextIcon />
|
||||
Documentation
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuGroup>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuGroup>
|
||||
<DropdownMenuItem variant="destructive">
|
||||
<LogOutIcon />
|
||||
Sign Out
|
||||
<DropdownMenuShortcut>⇧⌘Q</DropdownMenuShortcut>
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuGroup>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</CardAction>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<form>
|
||||
<FieldGroup>
|
||||
<div className="grid grid-cols-2 gap-4">
|
||||
<Field>
|
||||
<FieldLabel htmlFor="small-form-name">Name</FieldLabel>
|
||||
<Input
|
||||
id="small-form-name"
|
||||
placeholder="Enter your name"
|
||||
required
|
||||
/>
|
||||
</Field>
|
||||
<Field>
|
||||
<FieldLabel htmlFor="small-form-role">Role</FieldLabel>
|
||||
<Select defaultValue="">
|
||||
<SelectTrigger id="small-form-role">
|
||||
<SelectValue placeholder="Select a role" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<SelectGroup>
|
||||
<SelectItem value="developer">Developer</SelectItem>
|
||||
<SelectItem value="designer">Designer</SelectItem>
|
||||
<SelectItem value="manager">Manager</SelectItem>
|
||||
<SelectItem value="other">Other</SelectItem>
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</Field>
|
||||
</div>
|
||||
<Field>
|
||||
<FieldLabel htmlFor="small-form-framework">
|
||||
Framework
|
||||
</FieldLabel>
|
||||
<Combobox items={frameworks}>
|
||||
<ComboboxInput
|
||||
id="small-form-framework"
|
||||
placeholder="Select a framework"
|
||||
required
|
||||
/>
|
||||
<ComboboxContent>
|
||||
<ComboboxEmpty>No frameworks found.</ComboboxEmpty>
|
||||
<ComboboxList>
|
||||
{(item) => (
|
||||
<ComboboxItem key={item} value={item}>
|
||||
{item}
|
||||
</ComboboxItem>
|
||||
)}
|
||||
</ComboboxList>
|
||||
</ComboboxContent>
|
||||
</Combobox>
|
||||
</Field>
|
||||
<Field>
|
||||
<FieldLabel htmlFor="small-form-comments">Comments</FieldLabel>
|
||||
<Textarea
|
||||
id="small-form-comments"
|
||||
placeholder="Add any additional comments"
|
||||
/>
|
||||
</Field>
|
||||
<Field orientation="horizontal">
|
||||
<Button type="submit">Submit</Button>
|
||||
<Button variant="outline" type="button">
|
||||
Cancel
|
||||
</Button>
|
||||
</Field>
|
||||
</FieldGroup>
|
||||
</form>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</Example>
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
function ExampleWrapper({ className, ...props }: React.ComponentProps<"div">) {
|
||||
return (
|
||||
<div className="bg-background w-full">
|
||||
<div
|
||||
data-slot="example-wrapper"
|
||||
className={cn(
|
||||
"mx-auto grid min-h-screen w-full max-w-5xl min-w-0 content-center items-start gap-8 p-4 pt-2 sm:gap-12 sm:p-6 md:grid-cols-2 md:gap-8 lg:p-12 2xl:max-w-6xl",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function Example({
|
||||
title,
|
||||
children,
|
||||
className,
|
||||
containerClassName,
|
||||
...props
|
||||
}: React.ComponentProps<"div"> & {
|
||||
title?: string;
|
||||
containerClassName?: string;
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
data-slot="example"
|
||||
className={cn(
|
||||
"mx-auto flex w-full max-w-lg min-w-0 flex-col gap-1 self-stretch lg:max-w-none",
|
||||
containerClassName,
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{title && (
|
||||
<div className="text-muted-foreground px-1.5 py-2 text-xs font-medium">
|
||||
{title}
|
||||
</div>
|
||||
)}
|
||||
<div
|
||||
data-slot="example-content"
|
||||
className={cn(
|
||||
"bg-background text-foreground flex min-w-0 flex-1 flex-col items-start gap-6 border border-dashed p-4 sm:p-6 *:[div:not([class*='w-'])]:w-full",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export { ExampleWrapper, Example };
|
||||
@@ -0,0 +1,87 @@
|
||||
"use client";
|
||||
|
||||
import * as React from "react";
|
||||
import { Accordion as AccordionPrimitive } from "radix-ui";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ChevronDownIcon, ChevronUpIcon } from "lucide-react";
|
||||
|
||||
function Accordion({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof AccordionPrimitive.Root>) {
|
||||
return (
|
||||
<AccordionPrimitive.Root
|
||||
data-slot="accordion"
|
||||
className={cn("flex w-full flex-col", className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function AccordionItem({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof AccordionPrimitive.Item>) {
|
||||
return (
|
||||
<AccordionPrimitive.Item
|
||||
data-slot="accordion-item"
|
||||
className={cn("not-last:border-b", className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function AccordionTrigger({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: React.ComponentProps<typeof AccordionPrimitive.Trigger>) {
|
||||
return (
|
||||
<AccordionPrimitive.Header className="flex">
|
||||
<AccordionPrimitive.Trigger
|
||||
data-slot="accordion-trigger"
|
||||
className={cn(
|
||||
"focus-visible:ring-ring/50 focus-visible:border-ring focus-visible:after:border-ring **:data-[slot=accordion-trigger-icon]:text-muted-foreground rounded-lg py-2.5 text-left text-sm font-medium hover:underline focus-visible:ring-[3px] **:data-[slot=accordion-trigger-icon]:ml-auto **:data-[slot=accordion-trigger-icon]:size-4 group/accordion-trigger relative flex flex-1 items-start justify-between border border-transparent transition-all outline-none disabled:pointer-events-none disabled:opacity-50",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
<ChevronDownIcon
|
||||
data-slot="accordion-trigger-icon"
|
||||
className="pointer-events-none shrink-0 group-aria-expanded/accordion-trigger:hidden"
|
||||
/>
|
||||
<ChevronUpIcon
|
||||
data-slot="accordion-trigger-icon"
|
||||
className="pointer-events-none hidden shrink-0 group-aria-expanded/accordion-trigger:inline"
|
||||
/>
|
||||
</AccordionPrimitive.Trigger>
|
||||
</AccordionPrimitive.Header>
|
||||
);
|
||||
}
|
||||
|
||||
function AccordionContent({
|
||||
className,
|
||||
children,
|
||||
...props
|
||||
}: React.ComponentProps<typeof AccordionPrimitive.Content>) {
|
||||
return (
|
||||
<AccordionPrimitive.Content
|
||||
data-slot="accordion-content"
|
||||
className="data-open:animate-accordion-down data-closed:animate-accordion-up text-sm overflow-hidden"
|
||||
{...props}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"pt-0 pb-2.5 [&_a]:hover:text-foreground h-(--radix-accordion-content-height) [&_a]:underline [&_a]:underline-offset-3 [&_p:not(:last-child)]:mb-4",
|
||||
className,
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
</AccordionPrimitive.Content>
|
||||
);
|
||||
}
|
||||
|
||||
export { Accordion, AccordionItem, AccordionTrigger, AccordionContent };
|
||||
@@ -0,0 +1,199 @@
|
||||
"use client";
|
||||
|
||||
import * as React from "react";
|
||||
import { AlertDialog as AlertDialogPrimitive } from "radix-ui";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
function AlertDialog({
|
||||
...props
|
||||
}: React.ComponentProps<typeof AlertDialogPrimitive.Root>) {
|
||||
return <AlertDialogPrimitive.Root data-slot="alert-dialog" {...props} />;
|
||||
}
|
||||
|
||||
function AlertDialogTrigger({
|
||||
...props
|
||||
}: React.ComponentProps<typeof AlertDialogPrimitive.Trigger>) {
|
||||
return (
|
||||
<AlertDialogPrimitive.Trigger data-slot="alert-dialog-trigger" {...props} />
|
||||
);
|
||||
}
|
||||
|
||||
function AlertDialogPortal({
|
||||
...props
|
||||
}: React.ComponentProps<typeof AlertDialogPrimitive.Portal>) {
|
||||
return (
|
||||
<AlertDialogPrimitive.Portal data-slot="alert-dialog-portal" {...props} />
|
||||
);
|
||||
}
|
||||
|
||||
function AlertDialogOverlay({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof AlertDialogPrimitive.Overlay>) {
|
||||
return (
|
||||
<AlertDialogPrimitive.Overlay
|
||||
data-slot="alert-dialog-overlay"
|
||||
className={cn(
|
||||
"data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 bg-black/10 duration-100 supports-backdrop-filter:backdrop-blur-xs fixed inset-0 z-50",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function AlertDialogContent({
|
||||
className,
|
||||
size = "default",
|
||||
...props
|
||||
}: React.ComponentProps<typeof AlertDialogPrimitive.Content> & {
|
||||
size?: "default" | "sm";
|
||||
}) {
|
||||
return (
|
||||
<AlertDialogPortal>
|
||||
<AlertDialogOverlay />
|
||||
<AlertDialogPrimitive.Content
|
||||
data-slot="alert-dialog-content"
|
||||
data-size={size}
|
||||
className={cn(
|
||||
"data-open:animate-in data-closed:animate-out data-closed:fade-out-0 data-open:fade-in-0 data-closed:zoom-out-95 data-open:zoom-in-95 bg-background ring-foreground/10 gap-4 rounded-xl p-4 ring-1 duration-100 data-[size=default]:max-w-xs data-[size=sm]:max-w-xs data-[size=default]:sm:max-w-sm group/alert-dialog-content fixed top-1/2 left-1/2 z-50 grid w-full -translate-x-1/2 -translate-y-1/2 outline-none",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
</AlertDialogPortal>
|
||||
);
|
||||
}
|
||||
|
||||
function AlertDialogHeader({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<"div">) {
|
||||
return (
|
||||
<div
|
||||
data-slot="alert-dialog-header"
|
||||
className={cn(
|
||||
"grid grid-rows-[auto_1fr] place-items-center gap-1.5 text-center has-data-[slot=alert-dialog-media]:grid-rows-[auto_auto_1fr] has-data-[slot=alert-dialog-media]:gap-x-4 sm:group-data-[size=default]/alert-dialog-content:place-items-start sm:group-data-[size=default]/alert-dialog-content:text-left sm:group-data-[size=default]/alert-dialog-content:has-data-[slot=alert-dialog-media]:grid-rows-[auto_1fr]",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function AlertDialogFooter({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<"div">) {
|
||||
return (
|
||||
<div
|
||||
data-slot="alert-dialog-footer"
|
||||
className={cn(
|
||||
"bg-muted/50 -mx-4 -mb-4 rounded-b-xl border-t p-4 flex flex-col-reverse gap-2 group-data-[size=sm]/alert-dialog-content:grid group-data-[size=sm]/alert-dialog-content:grid-cols-2 sm:flex-row sm:justify-end",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function AlertDialogMedia({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<"div">) {
|
||||
return (
|
||||
<div
|
||||
data-slot="alert-dialog-media"
|
||||
className={cn(
|
||||
"bg-muted mb-2 inline-flex size-10 items-center justify-center rounded-md sm:group-data-[size=default]/alert-dialog-content:row-span-2 *:[svg:not([class*='size-'])]:size-6",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function AlertDialogTitle({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof AlertDialogPrimitive.Title>) {
|
||||
return (
|
||||
<AlertDialogPrimitive.Title
|
||||
data-slot="alert-dialog-title"
|
||||
className={cn(
|
||||
"text-base font-medium sm:group-data-[size=default]/alert-dialog-content:group-has-data-[slot=alert-dialog-media]/alert-dialog-content:col-start-2",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function AlertDialogDescription({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<typeof AlertDialogPrimitive.Description>) {
|
||||
return (
|
||||
<AlertDialogPrimitive.Description
|
||||
data-slot="alert-dialog-description"
|
||||
className={cn(
|
||||
"text-muted-foreground *:[a]:hover:text-foreground text-sm text-balance md:text-pretty *:[a]:underline *:[a]:underline-offset-3",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function AlertDialogAction({
|
||||
className,
|
||||
variant = "default",
|
||||
size = "default",
|
||||
...props
|
||||
}: React.ComponentProps<typeof AlertDialogPrimitive.Action> &
|
||||
Pick<React.ComponentProps<typeof Button>, "variant" | "size">) {
|
||||
return (
|
||||
<Button variant={variant} size={size} asChild>
|
||||
<AlertDialogPrimitive.Action
|
||||
data-slot="alert-dialog-action"
|
||||
className={cn(className)}
|
||||
{...props}
|
||||
/>
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
|
||||
function AlertDialogCancel({
|
||||
className,
|
||||
variant = "outline",
|
||||
size = "default",
|
||||
...props
|
||||
}: React.ComponentProps<typeof AlertDialogPrimitive.Cancel> &
|
||||
Pick<React.ComponentProps<typeof Button>, "variant" | "size">) {
|
||||
return (
|
||||
<Button variant={variant} size={size} asChild>
|
||||
<AlertDialogPrimitive.Cancel
|
||||
data-slot="alert-dialog-cancel"
|
||||
className={cn(className)}
|
||||
{...props}
|
||||
/>
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
|
||||
export {
|
||||
AlertDialog,
|
||||
AlertDialogAction,
|
||||
AlertDialogCancel,
|
||||
AlertDialogContent,
|
||||
AlertDialogDescription,
|
||||
AlertDialogFooter,
|
||||
AlertDialogHeader,
|
||||
AlertDialogMedia,
|
||||
AlertDialogOverlay,
|
||||
AlertDialogPortal,
|
||||
AlertDialogTitle,
|
||||
AlertDialogTrigger,
|
||||
};
|
||||
@@ -0,0 +1,76 @@
|
||||
import * as React from "react";
|
||||
import { cva, type VariantProps } from "class-variance-authority";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const alertVariants = cva(
|
||||
"grid gap-0.5 rounded-lg border px-2.5 py-2 text-left text-sm has-data-[slot=alert-action]:relative has-data-[slot=alert-action]:pr-18 has-[>svg]:grid-cols-[auto_1fr] has-[>svg]:gap-x-2 *:[svg]:row-span-2 *:[svg]:translate-y-0.5 *:[svg]:text-current *:[svg:not([class*='size-'])]:size-4 w-full relative group/alert",
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
default: "bg-card text-card-foreground",
|
||||
destructive:
|
||||
"text-destructive bg-card *:data-[slot=alert-description]:text-destructive/90 *:[svg]:text-current",
|
||||
},
|
||||
},
|
||||
defaultVariants: {
|
||||
variant: "default",
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
function Alert({
|
||||
className,
|
||||
variant,
|
||||
...props
|
||||
}: React.ComponentProps<"div"> & VariantProps<typeof alertVariants>) {
|
||||
return (
|
||||
<div
|
||||
data-slot="alert"
|
||||
role="alert"
|
||||
className={cn(alertVariants({ variant }), className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function AlertTitle({ className, ...props }: React.ComponentProps<"div">) {
|
||||
return (
|
||||
<div
|
||||
data-slot="alert-title"
|
||||
className={cn(
|
||||
"font-medium group-has-[>svg]/alert:col-start-2 [&_a]:hover:text-foreground [&_a]:underline [&_a]:underline-offset-3",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function AlertDescription({
|
||||
className,
|
||||
...props
|
||||
}: React.ComponentProps<"div">) {
|
||||
return (
|
||||
<div
|
||||
data-slot="alert-description"
|
||||
className={cn(
|
||||
"text-muted-foreground text-sm text-balance md:text-pretty [&_p:not(:last-child)]:mb-4 [&_a]:hover:text-foreground [&_a]:underline [&_a]:underline-offset-3",
|
||||
className,
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
function AlertAction({ className, ...props }: React.ComponentProps<"div">) {
|
||||
return (
|
||||
<div
|
||||
data-slot="alert-action"
|
||||
className={cn("absolute top-2 right-2", className)}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export { Alert, AlertTitle, AlertDescription, AlertAction };
|
||||
@@ -0,0 +1,11 @@
|
||||
"use client";
|
||||
|
||||
import { AspectRatio as AspectRatioPrimitive } from "radix-ui";
|
||||
|
||||
function AspectRatio({
|
||||
...props
|
||||
}: React.ComponentProps<typeof AspectRatioPrimitive.Root>) {
|
||||
return <AspectRatioPrimitive.Root data-slot="aspect-ratio" {...props} />;
|
||||
}
|
||||
|
||||
export { AspectRatio };
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user