Compare commits

...

4 Commits

Author SHA1 Message Date
Wenxi Onyx
acf21736b4 . 2026-03-06 20:32:31 -08:00
Wenxi Onyx
3c9ec205e8 . 2026-03-06 20:28:52 -08:00
Wenxi Onyx
4114dd94b0 remove api 2026-03-06 18:05:48 -08:00
Wenxi Onyx
175258dc7d feat: rotate encryption key utility 2026-03-06 18:05:47 -08:00
6 changed files with 778 additions and 35 deletions

View File

@@ -14,67 +14,82 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
@lru_cache(maxsize=1)
@lru_cache(maxsize=2)
def _get_trimmed_key(key: str) -> bytes:
encoded_key = key.encode()
key_length = len(encoded_key)
if key_length < 16:
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
elif key_length > 32:
key = key[:32]
elif key_length not in (16, 24, 32):
valid_lengths = [16, 24, 32]
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
return encoded_key
# Trim to the largest valid AES key size that fits
valid_lengths = [32, 24, 16]
for size in valid_lengths:
if key_length >= size:
return encoded_key[:size]
def _encrypt_string(input_str: str) -> bytes:
if not ENCRYPTION_KEY_SECRET:
def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
if not effective_key:
return input_str.encode()
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
trimmed = _get_trimmed_key(effective_key)
iv = urandom(16)
padder = padding.PKCS7(algorithms.AES.block_size).padder()
padded_data = padder.update(input_str.encode()) + padder.finalize()
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend())
encryptor = cipher.encryptor()
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
return iv + encrypted_data
def _decrypt_bytes(input_bytes: bytes) -> str:
if not ENCRYPTION_KEY_SECRET:
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str:
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
if not effective_key:
return input_bytes.decode()
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
iv = input_bytes[:16]
encrypted_data = input_bytes[16:]
trimmed = _get_trimmed_key(effective_key)
try:
iv = input_bytes[:16]
encrypted_data = input_bytes[16:]
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
decryptor = cipher.decryptor()
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
cipher = Cipher(
algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()
)
decryptor = cipher.decryptor()
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
return decrypted_data.decode()
return decrypted_data.decode()
except (ValueError, UnicodeDecodeError):
if key is not None:
# Explicit key was provided — don't fall back silently
raise
# Read path: fall back to raw decode for key rotation compatibility.
# Run the re-encrypt-secrets script to rotate to the current key.
logger.warning(
"AES decryption failed — falling back to raw decode. "
"Run the re-encrypt secrets script to rotate to the current key."
)
return input_bytes.decode()
def encrypt_string_to_bytes(input_str: str) -> bytes:
def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes:
versioned_encryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_encrypt_string"
)
return versioned_encryption_fn(input_str)
return versioned_encryption_fn(input_str, key=key)
def decrypt_bytes_to_string(input_bytes: bytes) -> str:
def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str:
versioned_decryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_decrypt_bytes"
)
return versioned_decryption_fn(input_bytes)
return versioned_decryption_fn(input_bytes, key=key)
def test_encryption() -> None:

View File

@@ -0,0 +1,157 @@
"""Rotate encryption key for all encrypted columns.
Dynamically discovers all columns using EncryptedString / EncryptedJson,
decrypts each value with the old key, and re-encrypts with the current
ENCRYPTION_KEY_SECRET.
The operation is idempotent: rows already encrypted with the current key
are skipped. Commits are made in batches so a crash mid-rotation can be
safely resumed by re-running.
"""
import json
from typing import Any
from sqlalchemy import LargeBinary
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
from onyx.db.models import Base
from onyx.db.models import EncryptedJson
from onyx.db.models import EncryptedString
from onyx.utils.encryption import decrypt_bytes_to_string
from onyx.utils.logger import setup_logger
logger = setup_logger()
_BATCH_SIZE = 500
def _can_decrypt_with_current_key(data: bytes) -> bool:
"""Check if data is already encrypted with the current key.
Passes the key explicitly so the fallback-to-raw-decode path in
_decrypt_bytes is NOT triggered — a clean success/failure signal.
"""
try:
decrypt_bytes_to_string(data, key=ENCRYPTION_KEY_SECRET)
return True
except Exception:
return False
def _discover_encrypted_columns() -> list[tuple[type, str, list[str], bool]]:
"""Walk all ORM models and find columns using EncryptedString/EncryptedJson.
Returns list of (ModelClass, column_attr_name, [pk_attr_names], is_json).
"""
results: list[tuple[type, str, list[str], bool]] = []
for mapper in Base.registry.mappers:
model_cls = mapper.class_
pk_names = [col.key for col in mapper.primary_key]
for prop in mapper.column_attrs:
for col in prop.columns:
if isinstance(col.type, EncryptedJson):
results.append((model_cls, prop.key, pk_names, True))
elif isinstance(col.type, EncryptedString):
results.append((model_cls, prop.key, pk_names, False))
return results
def rotate_encryption_key(
db_session: Session,
old_key: str | None,
dry_run: bool = False,
) -> dict[str, int]:
"""Decrypt all encrypted columns with old_key and re-encrypt with the current key.
Args:
db_session: Active database session.
old_key: The previous encryption key. Pass None or "" if values were
not previously encrypted with a key.
dry_run: If True, count rows that need rotation without modifying data.
Returns:
Dict of "table.column" -> number of rows re-encrypted (or would be).
Commits every _BATCH_SIZE rows so that locks are held briefly and progress
is preserved on crash. Already-rotated rows are detected and skipped,
making the operation safe to re-run.
"""
if not ENCRYPTION_KEY_SECRET:
raise RuntimeError(
"ENCRYPTION_KEY_SECRET is not set — cannot rotate. "
"Set the target encryption key in the environment before running."
)
encrypted_columns = _discover_encrypted_columns()
totals: dict[str, int] = {}
for model_cls, col_name, pk_names, is_json in encrypted_columns:
table_name: str = model_cls.__tablename__ # type: ignore[attr-defined]
col_attr = getattr(model_cls, col_name)
pk_attrs = [getattr(model_cls, pk) for pk in pk_names]
# Read raw bytes directly, bypassing the TypeDecorator
raw_col = col_attr.property.columns[0]
stmt = select(*pk_attrs, raw_col.cast(LargeBinary)).where(col_attr.is_not(None))
rows = db_session.execute(stmt).all()
reencrypted = 0
batch_pending = 0
for row in rows:
raw_bytes: bytes | None = row[-1]
if raw_bytes is None:
continue
if _can_decrypt_with_current_key(raw_bytes):
continue
try:
if not old_key:
decrypted_str = raw_bytes.decode("utf-8")
else:
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=old_key)
except (ValueError, UnicodeDecodeError) as e:
pk_vals = [row[i] for i in range(len(pk_names))]
logger.warning(
f"Could not decrypt {table_name}.{col_name} "
f"row {pk_vals} — skipping: {e}"
)
continue
if not dry_run:
# For EncryptedJson, parse back to dict so the TypeDecorator
# can json.dumps() it cleanly (avoids double-encoding).
value: Any = json.loads(decrypted_str) if is_json else decrypted_str
pk_filters = [pk_attr == row[i] for i, pk_attr in enumerate(pk_attrs)]
update_stmt = (
update(model_cls).where(*pk_filters).values({col_name: value})
)
db_session.execute(update_stmt)
batch_pending += 1
if batch_pending >= _BATCH_SIZE:
db_session.commit()
batch_pending = 0
reencrypted += 1
# Flush remaining rows in this column
if batch_pending > 0:
db_session.commit()
if reencrypted > 0:
totals[f"{table_name}.{col_name}"] = reencrypted
logger.info(
f"{'[DRY RUN] Would re-encrypt' if dry_run else 'Re-encrypted'} "
f"{reencrypted} value(s) in {table_name}.{col_name}"
)
return totals

View File

@@ -11,16 +11,18 @@ logger = setup_logger()
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
def _encrypt_string(input_str: str) -> bytes:
if ENCRYPTION_KEY_SECRET:
def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
if ENCRYPTION_KEY_SECRET or key:
logger.warning("MIT version of Onyx does not support encryption of secrets.")
return input_str.encode()
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
def _decrypt_bytes(input_bytes: bytes) -> str:
# No need to double warn. If you wish to learn more about encryption features
# refer to the Onyx EE code
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str: # noqa: ARG001
if key is not None:
logger.warning(
"MIT version of Onyx does not support decryption with a provided key."
)
return input_bytes.decode()
@@ -86,15 +88,15 @@ def _mask_list(items: list[Any]) -> list[Any]:
return masked
def encrypt_string_to_bytes(intput_str: str) -> bytes:
def encrypt_string_to_bytes(intput_str: str, key: str | None = None) -> bytes:
versioned_encryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_encrypt_string"
)
return versioned_encryption_fn(intput_str)
return versioned_encryption_fn(intput_str, key=key)
def decrypt_bytes_to_string(intput_bytes: bytes) -> str:
def decrypt_bytes_to_string(intput_bytes: bytes, key: str | None = None) -> str:
versioned_decryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_decrypt_bytes"
)
return versioned_decryption_fn(intput_bytes)
return versioned_decryption_fn(intput_bytes, key=key)

View File

@@ -0,0 +1,99 @@
"""Re-encrypt secrets under the current ENCRYPTION_KEY_SECRET.
Decrypts all encrypted columns using the old key (or raw decode if the old key
is empty), then re-encrypts them with the current ENCRYPTION_KEY_SECRET.
Usage (docker):
docker exec -it onyx-api_server-1 \
python -m scripts.reencrypt_secrets --old-key "previous-key"
Usage (kubernetes):
kubectl exec -it <pod> -- \
python -m scripts.reencrypt_secrets --old-key "previous-key"
Omit --old-key (or pass "") if secrets were not previously encrypted.
For multi-tenant deployments, pass --tenant-id to target a specific tenant,
or --all-tenants to iterate every tenant.
"""
import argparse
import os
import sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from onyx.db.rotate_encryption_key import rotate_encryption_key # noqa: E402
from onyx.db.engine.sql_engine import get_session_with_tenant # noqa: E402
from onyx.db.engine.sql_engine import SqlEngine # noqa: E402
from onyx.db.engine.tenant_utils import get_all_tenant_ids # noqa: E402
from onyx.utils.variable_functionality import global_version # noqa: E402
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # noqa: E402
def _run_for_tenant(tenant_id: str, old_key: str | None, dry_run: bool = False) -> None:
print(f"Re-encrypting secrets for tenant: {tenant_id}")
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
results = rotate_encryption_key(db_session, old_key=old_key, dry_run=dry_run)
if results:
for col, count in results.items():
print(
f" {col}: {count} row(s) {'would be ' if dry_run else ''}re-encrypted"
)
else:
print("No rows needed re-encryption.")
def main() -> None:
parser = argparse.ArgumentParser(
description="Re-encrypt secrets under the current encryption key."
)
parser.add_argument(
"--old-key",
default=None,
help="Previous encryption key. Omit or pass empty string if not applicable.",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Show what would be re-encrypted without making changes.",
)
tenant_group = parser.add_mutually_exclusive_group()
tenant_group.add_argument(
"--tenant-id",
default=None,
help="Target a specific tenant schema.",
)
tenant_group.add_argument(
"--all-tenants",
action="store_true",
help="Iterate all tenants.",
)
args = parser.parse_args()
old_key = args.old_key if args.old_key else None
global_version.set_ee()
SqlEngine.init_engine(pool_size=5, max_overflow=2)
if args.dry_run:
print("DRY RUN — no changes will be made")
if args.all_tenants:
tenant_ids = get_all_tenant_ids()
print(f"Found {len(tenant_ids)} tenant(s)")
for tid in tenant_ids:
_run_for_tenant(tid, old_key, dry_run=args.dry_run)
else:
tenant_id = args.tenant_id or POSTGRES_DEFAULT_SCHEMA
_run_for_tenant(tenant_id, old_key, dry_run=args.dry_run)
print("Done.")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,305 @@
"""Tests for rotate_encryption_key against real Postgres.
Uses real ORM models (Credential, InternetSearchProvider) and the actual
Postgres database. Discovery is mocked in rotation tests to scope mutations
to only the test rows — the real _discover_encrypted_columns walk is tested
separately in TestDiscoverEncryptedColumns.
Requires a running Postgres instance. Run with::
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/db/test_rotate_encryption_key.py
"""
import json
from collections.abc import Generator
from unittest.mock import patch
import pytest
from sqlalchemy import LargeBinary
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import Session
from ee.onyx.utils.encryption import _decrypt_bytes
from ee.onyx.utils.encryption import _encrypt_string
from ee.onyx.utils.encryption import _get_trimmed_key
from onyx.configs.constants import DocumentSource
from onyx.db.models import Credential
from onyx.db.models import EncryptedJson
from onyx.db.models import EncryptedString
from onyx.db.models import InternetSearchProvider
from onyx.db.rotate_encryption_key import _discover_encrypted_columns
from onyx.db.rotate_encryption_key import rotate_encryption_key
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
EE_MODULE = "ee.onyx.utils.encryption"
ROTATE_MODULE = "onyx.db.rotate_encryption_key"
OLD_KEY = "o" * 16
NEW_KEY = "n" * 16
@pytest.fixture(autouse=True)
def _enable_ee() -> Generator[None, None, None]:
prev = global_version._is_ee
global_version.set_ee()
fetch_versioned_implementation.cache_clear()
yield
global_version._is_ee = prev
fetch_versioned_implementation.cache_clear()
@pytest.fixture(autouse=True)
def _clear_key_cache() -> None:
_get_trimmed_key.cache_clear()
def _raw_credential_bytes(db_session: Session, credential_id: int) -> bytes | None:
"""Read raw bytes from credential_json, bypassing the TypeDecorator."""
col = Credential.__table__.c.credential_json
stmt = select(col.cast(LargeBinary)).where(
Credential.__table__.c.id == credential_id
)
return db_session.execute(stmt).scalar()
def _raw_isp_bytes(db_session: Session, isp_id: int) -> bytes | None:
"""Read raw bytes from InternetSearchProvider.api_key."""
col = InternetSearchProvider.__table__.c.api_key
stmt = select(col.cast(LargeBinary)).where(
InternetSearchProvider.__table__.c.id == isp_id
)
return db_session.execute(stmt).scalar()
class TestDiscoverEncryptedColumns:
"""Verify _discover_encrypted_columns finds real production models."""
def test_discovers_credential_json(self) -> None:
results = _discover_encrypted_columns()
found = {
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
for model_cls, col_name, _, is_json in results
}
assert ("credential", "credential_json", True) in found
def test_discovers_internet_search_provider_api_key(self) -> None:
results = _discover_encrypted_columns()
found = {
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
for model_cls, col_name, _, is_json in results
}
assert ("internet_search_provider", "api_key", False) in found
def test_all_encrypted_string_columns_are_not_json(self) -> None:
results = _discover_encrypted_columns()
for model_cls, col_name, _, is_json in results:
col = getattr(model_cls, col_name).property.columns[0]
if isinstance(col.type, EncryptedString):
assert not is_json, (
f"{model_cls.__tablename__}.{col_name} is EncryptedString " # type: ignore[attr-defined]
f"but is_json={is_json}"
)
def test_all_encrypted_json_columns_are_json(self) -> None:
results = _discover_encrypted_columns()
for model_cls, col_name, _, is_json in results:
col = getattr(model_cls, col_name).property.columns[0]
if isinstance(col.type, EncryptedJson):
assert is_json, (
f"{model_cls.__tablename__}.{col_name} is EncryptedJson " # type: ignore[attr-defined]
f"but is_json={is_json}"
)
class TestRotateCredential:
"""Test rotation against the real Credential table (EncryptedJson).
Discovery is scoped to only the Credential model to avoid mutating
other tables in the test database.
"""
@pytest.fixture(autouse=True)
def _limit_discovery(self) -> Generator[None, None, None]:
with patch(
f"{ROTATE_MODULE}._discover_encrypted_columns",
return_value=[(Credential, "credential_json", ["id"], True)],
):
yield
@pytest.fixture()
def credential_id(
self, db_session: Session, tenant_context: None # noqa: ARG002
) -> Generator[int, None, None]:
"""Insert a Credential row with raw encrypted bytes, clean up after."""
config = {"api_key": "sk-test-1234", "endpoint": "https://example.com"}
encrypted = _encrypt_string(json.dumps(config), key=OLD_KEY)
result = db_session.execute(
text(
"INSERT INTO credential "
"(source, credential_json, admin_public, curator_public) "
"VALUES (:source, :cred_json, true, false) "
"RETURNING id"
),
{"source": DocumentSource.INGESTION_API.value, "cred_json": encrypted},
)
cred_id = result.scalar_one()
db_session.commit()
yield cred_id
db_session.execute(
text("DELETE FROM credential WHERE id = :id"), {"id": cred_id}
)
db_session.commit()
def test_rotates_credential_json(
self, db_session: Session, credential_id: int
) -> None:
with (
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
):
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
assert totals.get("credential.credential_json", 0) >= 1
raw = _raw_credential_bytes(db_session, credential_id)
assert raw is not None
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
assert decrypted["api_key"] == "sk-test-1234"
assert decrypted["endpoint"] == "https://example.com"
def test_skips_already_rotated(
self, db_session: Session, credential_id: int
) -> None:
with (
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
):
rotate_encryption_key(db_session, old_key=OLD_KEY)
_ = rotate_encryption_key(db_session, old_key=OLD_KEY)
raw = _raw_credential_bytes(db_session, credential_id)
assert raw is not None
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
assert decrypted["api_key"] == "sk-test-1234"
def test_dry_run_does_not_modify(
self, db_session: Session, credential_id: int
) -> None:
original = _raw_credential_bytes(db_session, credential_id)
with (
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
):
totals = rotate_encryption_key(db_session, old_key=OLD_KEY, dry_run=True)
assert totals.get("credential.credential_json", 0) >= 1
raw_after = _raw_credential_bytes(db_session, credential_id)
assert raw_after == original
class TestRotateInternetSearchProvider:
"""Test rotation against the real InternetSearchProvider table (EncryptedString).
Discovery is scoped to only the InternetSearchProvider model to avoid
mutating other tables in the test database.
"""
@pytest.fixture(autouse=True)
def _limit_discovery(self) -> Generator[None, None, None]:
with patch(
f"{ROTATE_MODULE}._discover_encrypted_columns",
return_value=[
(InternetSearchProvider, "api_key", ["id"], False),
],
):
yield
@pytest.fixture()
def isp_id(
self, db_session: Session, tenant_context: None # noqa: ARG002
) -> Generator[int, None, None]:
"""Insert an InternetSearchProvider row with raw encrypted bytes."""
encrypted = _encrypt_string("sk-secret-api-key", key=OLD_KEY)
result = db_session.execute(
text(
"INSERT INTO internet_search_provider "
"(name, provider_type, api_key, is_active) "
"VALUES (:name, :ptype, :api_key, false) "
"RETURNING id"
),
{
"name": f"test-rotation-{id(self)}",
"ptype": "test",
"api_key": encrypted,
},
)
isp_id = result.scalar_one()
db_session.commit()
yield isp_id
db_session.execute(
text("DELETE FROM internet_search_provider WHERE id = :id"),
{"id": isp_id},
)
db_session.commit()
def test_rotates_api_key(self, db_session: Session, isp_id: int) -> None:
with (
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
):
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
assert totals.get("internet_search_provider.api_key", 0) >= 1
raw = _raw_isp_bytes(db_session, isp_id)
assert raw is not None
assert _decrypt_bytes(raw, key=NEW_KEY) == "sk-secret-api-key"
def test_rotates_from_unencrypted(
self, db_session: Session, tenant_context: None # noqa: ARG002
) -> None:
"""Test rotating data that was stored without any encryption key."""
result = db_session.execute(
text(
"INSERT INTO internet_search_provider "
"(name, provider_type, api_key, is_active) "
"VALUES (:name, :ptype, :api_key, false) "
"RETURNING id"
),
{
"name": f"test-raw-{id(self)}",
"ptype": "test",
"api_key": b"raw-api-key",
},
)
isp_id = result.scalar_one()
db_session.commit()
try:
with (
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
):
totals = rotate_encryption_key(db_session, old_key=None)
assert totals.get("internet_search_provider.api_key", 0) >= 1
raw = _raw_isp_bytes(db_session, isp_id)
assert raw is not None
assert _decrypt_bytes(raw, key=NEW_KEY) == "raw-api-key"
finally:
db_session.execute(
text("DELETE FROM internet_search_provider WHERE id = :id"),
{"id": isp_id},
)
db_session.commit()

View File

@@ -0,0 +1,165 @@
"""Tests for EE AES-CBC encryption/decryption with explicit key support.
With EE mode enabled (via conftest), fetch_versioned_implementation resolves
to the EE implementations, so no patching of the MIT layer is needed.
"""
from unittest.mock import patch
import pytest
from ee.onyx.utils.encryption import _decrypt_bytes
from ee.onyx.utils.encryption import _encrypt_string
from ee.onyx.utils.encryption import _get_trimmed_key
from ee.onyx.utils.encryption import decrypt_bytes_to_string
from ee.onyx.utils.encryption import encrypt_string_to_bytes
EE_MODULE = "ee.onyx.utils.encryption"
# Keys must be exactly 16, 24, or 32 bytes for AES
KEY_16 = "a" * 16
KEY_16_ALT = "b" * 16
KEY_24 = "d" * 24
KEY_32 = "c" * 32
@pytest.fixture(autouse=True)
def _clear_key_cache() -> None:
_get_trimmed_key.cache_clear()
class TestEncryptDecryptRoundTrip:
def test_roundtrip_with_env_key(self) -> None:
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
encrypted = _encrypt_string("hello world")
assert encrypted != b"hello world"
assert _decrypt_bytes(encrypted) == "hello world"
def test_roundtrip_with_explicit_key(self) -> None:
encrypted = _encrypt_string("secret data", key=KEY_32)
assert encrypted != b"secret data"
assert _decrypt_bytes(encrypted, key=KEY_32) == "secret data"
def test_roundtrip_no_key(self) -> None:
"""Without any key, data is raw-encoded (no encryption)."""
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", ""):
encrypted = _encrypt_string("plain text")
assert encrypted == b"plain text"
assert _decrypt_bytes(encrypted) == "plain text"
def test_explicit_key_overrides_env(self) -> None:
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
encrypted = _encrypt_string("data", key=KEY_16_ALT)
with pytest.raises(ValueError):
_decrypt_bytes(encrypted, key=KEY_16)
assert _decrypt_bytes(encrypted, key=KEY_16_ALT) == "data"
def test_different_encryptions_produce_different_bytes(self) -> None:
"""Each encryption uses a random IV, so results differ."""
a = _encrypt_string("same", key=KEY_16)
b = _encrypt_string("same", key=KEY_16)
assert a != b
def test_roundtrip_empty_string(self) -> None:
encrypted = _encrypt_string("", key=KEY_16)
assert encrypted != b""
assert _decrypt_bytes(encrypted, key=KEY_16) == ""
def test_roundtrip_unicode(self) -> None:
text = "日本語テスト 🔐 émojis"
encrypted = _encrypt_string(text, key=KEY_16)
assert _decrypt_bytes(encrypted, key=KEY_16) == text
class TestDecryptFallbackBehavior:
def test_wrong_env_key_falls_back_to_raw_decode(self) -> None:
"""Default key path: AES fails on non-AES data → fallback to raw decode."""
raw = "readable text".encode()
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
assert _decrypt_bytes(raw) == "readable text"
def test_explicit_wrong_key_raises(self) -> None:
"""Explicit key path: AES fails → raises, no fallback."""
encrypted = _encrypt_string("secret", key=KEY_16)
with pytest.raises(ValueError):
_decrypt_bytes(encrypted, key=KEY_16_ALT)
def test_explicit_none_key_with_no_env(self) -> None:
"""key=None with empty env → raw decode."""
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", ""):
assert _decrypt_bytes(b"hello", key=None) == "hello"
def test_explicit_empty_string_key(self) -> None:
"""key='' means no encryption."""
encrypted = _encrypt_string("test", key="")
assert encrypted == b"test"
assert _decrypt_bytes(encrypted, key="") == "test"
class TestKeyValidation:
def test_key_too_short_raises(self) -> None:
with pytest.raises(RuntimeError, match="too short"):
_encrypt_string("data", key="short")
def test_16_byte_key(self) -> None:
encrypted = _encrypt_string("data", key=KEY_16)
assert _decrypt_bytes(encrypted, key=KEY_16) == "data"
def test_24_byte_key(self) -> None:
encrypted = _encrypt_string("data", key=KEY_24)
assert _decrypt_bytes(encrypted, key=KEY_24) == "data"
def test_32_byte_key(self) -> None:
encrypted = _encrypt_string("data", key=KEY_32)
assert _decrypt_bytes(encrypted, key=KEY_32) == "data"
def test_long_key_truncated_to_32(self) -> None:
"""Keys longer than 32 bytes are truncated to 32."""
long_key = "e" * 64
encrypted = _encrypt_string("data", key=long_key)
assert _decrypt_bytes(encrypted, key=long_key) == "data"
def test_20_byte_key_trimmed_to_16(self) -> None:
"""A 20-byte key is trimmed to the largest valid AES size that fits (16)."""
key_20 = "f" * 20
encrypted = _encrypt_string("data", key=key_20)
assert _decrypt_bytes(encrypted, key=key_20) == "data"
# Verify it was trimmed to 16 by checking that the first 16 bytes
# of the key can also decrypt it
key_16_same_prefix = "f" * 16
assert _decrypt_bytes(encrypted, key=key_16_same_prefix) == "data"
def test_25_byte_key_trimmed_to_24(self) -> None:
"""A 25-byte key is trimmed to the largest valid AES size that fits (24)."""
key_25 = "g" * 25
encrypted = _encrypt_string("data", key=key_25)
assert _decrypt_bytes(encrypted, key=key_25) == "data"
key_24_same_prefix = "g" * 24
assert _decrypt_bytes(encrypted, key=key_24_same_prefix) == "data"
def test_30_byte_key_trimmed_to_24(self) -> None:
"""A 30-byte key is trimmed to the largest valid AES size that fits (24)."""
key_30 = "h" * 30
encrypted = _encrypt_string("data", key=key_30)
assert _decrypt_bytes(encrypted, key=key_30) == "data"
key_24_same_prefix = "h" * 24
assert _decrypt_bytes(encrypted, key=key_24_same_prefix) == "data"
class TestWrapperFunctions:
"""Test encrypt_string_to_bytes / decrypt_bytes_to_string pass key through.
With EE mode enabled, the wrappers resolve to EE implementations automatically.
"""
def test_wrapper_passes_key(self) -> None:
encrypted = encrypt_string_to_bytes("payload", key=KEY_16)
assert decrypt_bytes_to_string(encrypted, key=KEY_16) == "payload"
def test_wrapper_no_key_uses_env(self) -> None:
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_32):
encrypted = encrypt_string_to_bytes("payload")
assert decrypt_bytes_to_string(encrypted) == "payload"