Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-22 18:25:45 +00:00)
Compare commits: batch_proc...v0.10.2 (15 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 4f5989477b | |
| | d0e018a3fa | |
| | ee5e085c81 | |
| | decd31c65a | |
| | f0adc03de3 | |
| | c07a10440b | |
| | ca4b230f67 | |
| | 4479bdceab | |
| | e23a37e3f1 | |
| | eadd87363d | |
| | 91d9414ada | |
| | 39a485b777 | |
| | 35e4ba7f99 | |
| | 302d28f1e8 | |
| | 7e8e89359d | |
```diff
@@ -31,6 +31,12 @@ def upgrade() -> None:
 
 
 def downgrade() -> None:
+    # First, update any null values to a default value
+    op.execute(
+        "UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL"
+    )
+
+    # Then, make the column non-nullable
     op.alter_column(
         "connector_credential_pair",
         "last_attempt_status",
```
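The downgrade backfills NULL statuses before re-tightening the column, since `SET NOT NULL` fails while NULL rows remain. The `op.alter_column` call is cut off in this view; below is a minimal sketch of the full pattern, with the column type and remaining kwargs assumed rather than taken from the real migration:

```python
# Hypothetical sketch of the backfill-then-tighten downgrade shown above.
# The exact column type/kwargs of the real migration are not visible here.
from alembic import op
import sqlalchemy as sa


def downgrade() -> None:
    # Backfill first: ALTER COLUMN ... SET NOT NULL fails if any NULLs remain.
    op.execute(
        "UPDATE connector_credential_pair "
        "SET last_attempt_status = 'NOT_STARTED' "
        "WHERE last_attempt_status IS NULL"
    )
    # Only then is it safe to reinstate the NOT NULL constraint.
    op.alter_column(
        "connector_credential_pair",
        "last_attempt_status",
        existing_type=sa.String(),  # assumption: actual type not shown in the hunk
        nullable=False,
    )
```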
```diff
@@ -96,6 +96,7 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation
+from shared_configs.configs import CURRENT_TENANT_ID_CONTEXTVAR
 from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
 
 
 logger = setup_logger()
 
 
```
```diff
@@ -28,6 +28,7 @@ from danswer.utils.logger import PlainFormatter
 from danswer.utils.logger import setup_logger
+from shared_configs.configs import SENTRY_DSN
 
 
 logger = setup_logger()
 
 task_logger = get_task_logger(__name__)
```
```diff
@@ -21,6 +21,8 @@ celery_app.config_from_object("danswer.background.celery.configs.beat")
 @beat_init.connect
 def on_beat_init(sender: Any, **kwargs: Any) -> None:
     logger.info("beat_init signal received.")
+
+    # celery beat shouldn't touch the db at all. But just setting a low minimum here.
     SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
     SqlEngine.init_engine(pool_size=2, max_overflow=0)
     app_base.wait_for_redis(sender, **kwargs)
```
```diff
@@ -58,7 +58,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
     logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
 
     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
-    SqlEngine.init_engine(pool_size=8, max_overflow=0)
+    SqlEngine.init_engine(pool_size=4, max_overflow=12)
 
     app_base.wait_for_redis(sender, **kwargs)
     app_base.on_secondary_worker_init(sender, **kwargs)
```
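This hunk trades a fixed pool of eight connections for a smaller resident pool that can burst: with SQLAlchemy's default QueuePool, `pool_size=4, max_overflow=12` keeps four persistent connections and opens up to twelve short-lived overflow connections under load, the same peak of sixteen but a smaller idle footprint. A standalone sketch of the equivalent configuration (the DSN is a placeholder):

```python
from sqlalchemy import create_engine

# Equivalent standalone configuration: 4 persistent connections, plus up to
# 12 "overflow" connections that are opened under load and closed when
# returned, i.e. a peak of 16 with a small idle footprint.
engine = create_engine(
    "postgresql+psycopg2://user:pass@localhost/danswer",  # placeholder DSN
    pool_size=4,
    max_overflow=12,
)
```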
```diff
@@ -166,19 +166,6 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
             r.delete(key)
 
 
-# @worker_process_init.connect
-# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
-#     """This only runs inside child processes when the worker is in pool=prefork mode.
-#     This may be technically unnecessary since we're finding prefork pools to be
-#     unstable and currently aren't planning on using them."""
-#     logger.info("worker_process_init signal received.")
-#     SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
-#     SqlEngine.init_engine(pool_size=5, max_overflow=0)
-
-#     # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
-#     SqlEngine.get_engine().dispose(close=False)
-
-
 @worker_ready.connect
 def on_worker_ready(sender: Any, **kwargs: Any) -> None:
     app_base.on_worker_ready(sender, **kwargs)
```
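The deleted comment block pointed at the classic SQLAlchemy-plus-prefork pitfall: pooled connections created in the parent get shared by forked children unless each child discards the inherited pool. A minimal sketch of that `dispose(close=False)` pattern outside Celery, assuming a Unix fork and a placeholder DSN:

```python
import os

from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://user:pass@localhost/db")  # placeholder DSN

pid = os.fork()  # Unix only; stands in for a prefork worker spawning a child
if pid == 0:
    # Child process: drop the inherited pooled connections without closing
    # them (the parent still owns the underlying sockets), then let the pool
    # lazily open fresh connections owned by this process.
    engine.dispose(close=False)
    with engine.connect() as conn:
        conn.execute(text("SELECT 1"))
    os._exit(0)
```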
```diff
@@ -11,7 +11,8 @@ from typing import Any
 from typing import Literal
 from typing import Optional
 
-from danswer.db.engine import get_sqlalchemy_engine
+from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
+from danswer.db.engine import SqlEngine
 from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -37,7 +38,9 @@ def _initializer(
     if kwargs is None:
         kwargs = {}
 
-    get_sqlalchemy_engine().dispose(close=False)
+    logger.info("Initializing spawned worker child process.")
+    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
+    SqlEngine.init_engine(pool_size=4, max_overflow=12, pool_recycle=60)
     return func(*args, **kwargs)
 
 
```
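Instead of merely disposing an inherited engine, the spawned child now builds a fresh engine of its own, with `pool_recycle=60` so idle connections in short-lived children get recycled quickly. The same initializer idea expressed with plain `multiprocessing` (the names and DSN here are illustrative, not the repo's):

```python
import multiprocessing as mp

from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine

_engine: Engine | None = None  # one engine per child process


def _init_child() -> None:
    # Runs once in each spawned worker: build a fresh engine rather than
    # inheriting pooled connections from the parent process.
    global _engine
    _engine = create_engine(
        "postgresql+psycopg2://user:pass@localhost/db",  # placeholder DSN
        pool_size=4,
        max_overflow=12,
        pool_recycle=60,  # recycle connections idle for more than 60s
    )


def _task(i: int) -> int:
    assert _engine is not None
    with _engine.connect() as conn:
        return conn.execute(text("SELECT :i"), {"i": i}).scalar_one()


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2, initializer=_init_child) as pool:
        print(pool.map(_task, range(4)))
```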
```diff
@@ -91,12 +91,13 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
             cql_page_query += f" and id='{page_id}'"
 
         self.cql_page_query = cql_page_query
-        self.cql_label_filter = ""
         self.cql_time_filter = ""
 
+        self.cql_label_filter = ""
         if labels_to_skip:
             labels_to_skip = list(set(labels_to_skip))
-            comma_separated_labels = ",".join(labels_to_skip)
-            self.cql_label_filter = f"&label not in ({comma_separated_labels})"
+            comma_separated_labels = ",".join(f"'{label}'" for label in labels_to_skip)
+            self.cql_label_filter = f" and label not in ({comma_separated_labels})"
 
     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
         # see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
```
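The CQL fix is twofold: labels are now single-quoted, so values containing spaces or commas no longer break the query, and the filter is attached with `and` rather than an invalid `&`. A small standalone sketch of the quoting step:

```python
labels_to_skip = ["archived", "work in progress"]

# Quote each label so values with spaces or commas survive the CQL parser,
# then attach the clause with "and" (CQL has no "&" operator).
comma_separated_labels = ",".join(f"'{label}'" for label in set(labels_to_skip))
cql_label_filter = f" and label not in ({comma_separated_labels})"

print(cql_label_filter)
# e.g.  and label not in ('archived','work in progress')
```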
```diff
@@ -125,7 +126,8 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
     for comment in comments:
         comment_string += "\nComment:\n"
         comment_string += extract_text_from_confluence_html(
-            confluence_client=self.confluence_client, confluence_object=comment
+            confluence_client=self.confluence_client,
+            confluence_object=comment,
         )
 
     return comment_string
```
```diff
@@ -24,6 +24,10 @@ from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logger import setup_logger
 
+
+logger = setup_logger()
+
+
 # List of directories/Files to exclude
 exclude_patterns = [
     "logs",
@@ -31,7 +35,6 @@ exclude_patterns = [
     ".gitlab/",
     ".pre-commit-config.yaml",
 ]
-logger = setup_logger()
 
 
 def _batch_gitlab_objects(
```
```diff
@@ -22,6 +22,7 @@ from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logger import setup_logger
 
 
+logger = setup_logger()
 
 
@@ -230,5 +231,7 @@ if __name__ == "__main__":
     print("All docs", all_docs)
     current = datetime.datetime.now().timestamp()
     one_day_ago = current - 30 * 24 * 60 * 60  # 30 days
+
     latest_docs = list(test_connector.poll_source(one_day_ago, current))
+
     print("Latest docs", latest_docs)
```
```diff
@@ -20,10 +20,13 @@ from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.logger import setup_logger
 
+
+logger = setup_logger()
+
+
 # Fairly generous retry because it's not understood why occasionally GraphQL requests fail even with timeout > 1 min
 SLAB_GRAPHQL_MAX_TRIES = 10
 SLAB_API_URL = "https://api.slab.com/v1/graphql"
-logger = setup_logger()
 
 
 def run_graphql_request(
```
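`SLAB_GRAPHQL_MAX_TRIES = 10` backs the comment above it: Slab GraphQL calls fail sporadically even with long timeouts, so the connector retries generously. A hedged sketch of what such a retry loop might look like; the payload shape, headers, and backoff policy are illustrative, not the connector's exact code:

```python
import time

import requests

SLAB_GRAPHQL_MAX_TRIES = 10
SLAB_API_URL = "https://api.slab.com/v1/graphql"


def run_graphql_request(query: str, token: str) -> dict:
    """Retry transient failures with a short linear backoff (illustrative)."""
    last_error: Exception | None = None
    for attempt in range(SLAB_GRAPHQL_MAX_TRIES):
        try:
            response = requests.post(
                SLAB_API_URL,
                json={"query": query},
                headers={"Authorization": token},
                timeout=60,
            )
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            last_error = e
            time.sleep(attempt + 1)  # back off a little more on each retry
    raise RuntimeError(
        f"GraphQL request failed after {SLAB_GRAPHQL_MAX_TRIES} tries"
    ) from last_error
```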
```diff
@@ -441,6 +441,7 @@ if __name__ == "__main__":
 
     current = time.time()
     one_day_ago = current - 24 * 60 * 60  # 1 day
 
     document_batches = connector.poll_source(one_day_ago, current)
+
     print(next(document_batches))
```
```diff
@@ -237,6 +237,7 @@ class Answer:
             prompt=prompt,
             tools=final_tool_definitions if final_tool_definitions else None,
             tool_choice="required" if self.force_use_tool.force_use else None,
+            structured_response_format=self.answer_style_config.structured_response_format,
         ):
             if isinstance(message, AIMessageChunk) and (
                 message.tool_call_chunks or message.tool_calls
@@ -331,7 +332,10 @@ class Answer:
         tool_choice: ToolChoiceOptions | None = None,
     ) -> Iterator[str | StreamStopInfo]:
         for message in self.llm.stream(
-            prompt=prompt, tools=tools, tool_choice=tool_choice
+            prompt=prompt,
+            tools=tools,
+            tool_choice=tool_choice,
+            structured_response_format=self.answer_style_config.structured_response_format,
         ):
             if isinstance(message, AIMessageChunk):
                 if message.content:
```
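Both call sites now thread `structured_response_format` from the answer-style config into `llm.stream(...)`, so the tool-calling and plain streaming paths honor the same output contract. In OpenAI-compatible terms such a value typically maps onto the `response_format` parameter; the sketch below assumes that mapping and is not code from this diff:

```python
from typing import Any

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment


def stream_with_format(prompt: str, structured_response_format: dict[str, Any] | None):
    # Assumption: the LLM wrapper forwards structured_response_format
    # verbatim as the OpenAI "response_format" parameter when provided.
    kwargs: dict[str, Any] = {}
    if structured_response_format is not None:
        kwargs["response_format"] = structured_response_format
    stream = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model
        messages=[{"role": "user", "content": prompt}],
        stream=True,
        **kwargs,
    )
    for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content
```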
```diff
@@ -102,6 +102,9 @@ class TenantRedis(redis.Redis):
             "reacquire",
             "create_lock",
             "startswith",
+            "sadd",
+            "srem",
+            "scard",
         ]  # Regular methods that need simple prefixing
 
         if item == "scan_iter":
```
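`TenantRedis` proxies a `redis.Redis` client and, via `__getattr__`, prefixes the key argument of whitelisted methods with the tenant ID; this hunk simply extends the whitelist with the set operations `sadd`, `srem`, and `scard`. A minimal sketch of the prefixing-proxy idea, much simplified relative to the real class (which also handles cursors, locks, etc.):

```python
from typing import Any, Callable

import redis

PREFIXED_METHODS = {"get", "set", "delete", "sadd", "srem", "scard"}


class TenantRedisSketch:
    """Minimal illustration: prefix the first (key) argument per tenant."""

    def __init__(self, tenant_id: str, client: redis.Redis) -> None:
        self._tenant_id = tenant_id
        self._client = client

    def __getattr__(self, item: str) -> Callable[..., Any]:
        method = getattr(self._client, item)
        if item not in PREFIXED_METHODS:
            return method

        def wrapped(key: str, *args: Any, **kwargs: Any) -> Any:
            # Namespace the key so tenants cannot collide in shared Redis.
            return method(f"{self._tenant_id}:{key}", *args, **kwargs)

        return wrapped


# Usage: r = TenantRedisSketch("tenant_a", redis.Redis()); r.sadd("members", "x")
```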
```diff
@@ -154,42 +154,38 @@ def test_send_message_simple_with_history_strict_json(
     new_admin_user: DATestUser | None,
 ) -> None:
     # create connectors
     cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
         user_performing_action=new_admin_user,
     )
     api_key: DATestAPIKey = APIKeyManager.create(
         user_performing_action=new_admin_user,
     )
     LLMProviderManager.create(user_performing_action=new_admin_user)
     cc_pair_1.documents = DocumentManager.seed_dummy_docs(
         cc_pair=cc_pair_1,
         num_docs=NUM_DOCS,
         api_key=api_key,
     )
 
     response = requests.post(
         f"{API_SERVER_URL}/chat/send-message-simple-with-history",
         json={
+            # intentionally not relevant prompt to ensure that the
+            # structured response format is actually used
             "messages": [
                 {
-                    "message": "List the names of the first three US presidents in JSON format",
+                    "message": "What is green?",
                     "role": MessageType.USER.value,
                 }
             ],
             "persona_id": 0,
             "prompt_id": 0,
             "structured_response_format": {
-                "type": "json_object",
-                "schema": {
-                    "type": "object",
-                    "properties": {
-                        "presidents": {
-                            "type": "array",
-                            "items": {"type": "string"},
-                            "description": "List of the first three US presidents",
-                        }
-                    },
-                    "required": ["presidents"],
-                },
+                "type": "json_schema",
+                "json_schema": {
+                    "name": "presidents",
+                    "schema": {
+                        "type": "object",
+                        "properties": {
+                            "presidents": {
+                                "type": "array",
+                                "items": {"type": "string"},
+                                "description": "List of the first three US presidents",
+                            }
+                        },
+                        "required": ["presidents"],
+                        "additionalProperties": False,
+                    },
+                    "strict": True,
+                },
             },
         },
```
```diff
@@ -211,14 +207,17 @@ def test_send_message_simple_with_history_strict_json(
     try:
         clean_answer = clean_json_string(response_json["answer"])
         parsed_answer = json.loads(clean_answer)
 
+        # NOTE: do not check content, just the structure
         assert isinstance(parsed_answer, dict)
         assert "presidents" in parsed_answer
         assert isinstance(parsed_answer["presidents"], list)
         assert len(parsed_answer["presidents"]) == 3
         for president in parsed_answer["presidents"]:
             assert isinstance(president, str)
     except json.JSONDecodeError:
-        assert False, "The answer is not a valid JSON object"
+        assert (
+            False
+        ), f"The answer is not a valid JSON object - '{response_json['answer']}'"
 
     # Check that the answer_citationless is also valid JSON
     assert "answer_citationless" in response_json
```
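The test asserts the structure by hand; the same check could be phrased against the declared schema itself. A sketch using the third-party `jsonschema` package (an extra dependency, not used by the test above), reusing the schema from the request:

```python
import json

from jsonschema import ValidationError, validate

PRESIDENTS_SCHEMA = {
    "type": "object",
    "properties": {
        "presidents": {
            "type": "array",
            "items": {"type": "string"},
            "description": "List of the first three US presidents",
        }
    },
    "required": ["presidents"],
    "additionalProperties": False,
}

answer = '{"presidents": ["George Washington", "John Adams", "Thomas Jefferson"]}'
try:
    # Validates structure only, mirroring the hand-written asserts above.
    validate(instance=json.loads(answer), schema=PRESIDENTS_SCHEMA)
except (json.JSONDecodeError, ValidationError) as e:
    raise AssertionError(f"The answer does not match the schema: {e}") from e
```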
```diff
@@ -9,12 +9,15 @@ from pytest_mock import MockFixture
 
 from danswer.connectors.mediawiki import wiki
 
+# These tests are disabled for now
+
 
 @pytest.fixture
 def site() -> pywikibot.Site:
     return pywikibot.Site("en", "wikipedia")
 
 
+@pytest.mark.skip(reason="Test disabled")
 def test_pywikibot_timestamp_to_utc_datetime() -> None:
     timestamp_without_tzinfo = pywikibot.Timestamp(2023, 12, 27, 15, 38, 49)
     timestamp_min_timezone = timestamp_without_tzinfo.astimezone(datetime.timezone.min)
@@ -80,6 +83,7 @@ class MockPage(pywikibot.Page):
     )
 
 
+@pytest.mark.skip(reason="Test disabled")
 def test_get_doc_from_page(site: pywikibot.Site) -> None:
     test_page = MockPage(site, "Test Page", _has_categories=True)
     doc = wiki.get_doc_from_page(test_page, site, wiki.DocumentSource.MEDIAWIKI)
@@ -103,6 +107,7 @@ def test_get_doc_from_page(site: pywikibot.Site) -> None:
     assert doc.id == f"MEDIAWIKI_{test_page.pageid}_{test_page.full_url()}"
 
 
+@pytest.mark.skip(reason="Test disabled")
 def test_mediawiki_connector_recurse_depth() -> None:
     """Test that the recurse_depth parameter is parsed correctly.
 
@@ -132,6 +137,7 @@ def test_mediawiki_connector_recurse_depth() -> None:
     assert connector.recurse_depth == recurse_depth
 
 
+@pytest.mark.skip(reason="Test disabled")
 def test_load_from_state_calls_poll_source_with_nones(mocker: MockFixture) -> None:
     connector = wiki.MediaWikiConnector("wikipedia.org", [], [], 0, "test")
     poll_source = mocker.patch.object(connector, "poll_source")
```