mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
3 Commits
loading_or
...
remove_see
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
50c186c8c6 | ||
|
|
e965a602ad | ||
|
|
7fa3c9c82d |
@@ -135,6 +135,7 @@ class SearchPipeline:
|
||||
|
||||
"""Retrieval and Postprocessing"""
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _get_chunks(self) -> list[InferenceChunk]:
|
||||
if self._retrieved_chunks is not None:
|
||||
return self._retrieved_chunks
|
||||
@@ -306,6 +307,7 @@ class SearchPipeline:
|
||||
return expanded_inference_sections
|
||||
|
||||
@property
|
||||
@log_function_time(print_only=True)
|
||||
def reranked_sections(self) -> list[InferenceSection]:
|
||||
"""Reranking is always done at the chunk level since section merging could create arbitrarily
|
||||
long sections which could be:
|
||||
@@ -331,6 +333,7 @@ class SearchPipeline:
|
||||
return self._reranked_sections
|
||||
|
||||
@property
|
||||
@log_function_time(print_only=True)
|
||||
def final_context_sections(self) -> list[InferenceSection]:
|
||||
if self._final_context_sections is not None:
|
||||
return self._final_context_sections
|
||||
@@ -339,6 +342,7 @@ class SearchPipeline:
|
||||
return self._final_context_sections
|
||||
|
||||
@property
|
||||
@log_function_time(print_only=True)
|
||||
def section_relevance(self) -> list[SectionRelevancePiece] | None:
|
||||
if self._section_relevance is not None:
|
||||
return self._section_relevance
|
||||
@@ -393,6 +397,7 @@ class SearchPipeline:
|
||||
return self._section_relevance
|
||||
|
||||
@property
|
||||
@log_function_time(print_only=True)
|
||||
def section_relevance_list(self) -> list[bool]:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=self.section_relevance,
|
||||
|
||||
@@ -42,6 +42,7 @@ def _log_top_section_links(search_flow: str, sections: list[InferenceSection]) -
|
||||
logger.debug(f"Top links from {search_flow} search: {', '.join(top_links)}")
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk]:
|
||||
def _remove_title(chunk: InferenceChunkUncleaned) -> str:
|
||||
if not chunk.title or not chunk.content:
|
||||
@@ -244,6 +245,7 @@ def filter_sections(
|
||||
]
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def search_postprocessing(
|
||||
search_query: SearchQuery,
|
||||
retrieved_sections: list[InferenceSection],
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import string
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
|
||||
import nltk # type:ignore
|
||||
@@ -85,6 +86,7 @@ def remove_stop_words_and_punctuation(keywords: list[str]) -> list[str]:
|
||||
return keywords
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def combine_retrieval_results(
|
||||
chunk_sets: list[list[InferenceChunk]],
|
||||
) -> list[InferenceChunk]:
|
||||
@@ -256,7 +258,13 @@ def retrieve_chunks(
|
||||
(q_copy, document_index, db_session),
|
||||
)
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
end_time = time.time()
|
||||
logger.info(
|
||||
f"Parallel search execution took {end_time - start_time:.2f} seconds"
|
||||
)
|
||||
top_chunks = combine_retrieval_results(parallel_search_results)
|
||||
|
||||
if not top_chunks:
|
||||
|
||||
@@ -8,6 +8,7 @@ from danswer.context.search.models import SavedSearchDoc
|
||||
from danswer.context.search.models import SavedSearchDocWithContent
|
||||
from danswer.context.search.models import SearchDoc
|
||||
from danswer.db.models import SearchDoc as DBSearchDoc
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
|
||||
T = TypeVar(
|
||||
@@ -88,6 +89,7 @@ def drop_llm_indices(
|
||||
return [i for i, val in enumerate(llm_bools) if val]
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def inference_section_from_chunks(
|
||||
center_chunk: InferenceChunk,
|
||||
chunks: list[InferenceChunk],
|
||||
|
||||
@@ -48,6 +48,7 @@ from danswer.document_index.vespa_constants import TITLE
|
||||
from danswer.document_index.vespa_constants import YQL_BASE
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -146,6 +147,7 @@ def _vespa_hit_to_inference_chunk(
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _get_chunks_via_visit_api(
|
||||
chunk_request: VespaChunkRequest,
|
||||
index_name: str,
|
||||
@@ -232,6 +234,7 @@ def _get_chunks_via_visit_api(
|
||||
|
||||
|
||||
@retry(tries=10, delay=1, backoff=2)
|
||||
@log_function_time(print_only=True)
|
||||
def get_all_vespa_ids_for_document_id(
|
||||
document_id: str,
|
||||
index_name: str,
|
||||
@@ -248,6 +251,7 @@ def get_all_vespa_ids_for_document_id(
|
||||
return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def parallel_visit_api_retrieval(
|
||||
index_name: str,
|
||||
chunk_requests: list[VespaChunkRequest],
|
||||
@@ -262,9 +266,12 @@ def parallel_visit_api_retrieval(
|
||||
for chunk_request in chunk_requests
|
||||
]
|
||||
|
||||
start_time = datetime.now()
|
||||
parallel_results = run_functions_tuples_in_parallel(
|
||||
functions_with_args, allow_failures=True
|
||||
)
|
||||
duration = datetime.now() - start_time
|
||||
print(f"Parallel visit API retrieval took {duration.total_seconds():.2f} seconds")
|
||||
|
||||
# Any failures to retrieve would give a None, drop the Nones and empty lists
|
||||
vespa_chunk_sets = [res for res in parallel_results if res]
|
||||
@@ -282,9 +289,11 @@ def parallel_visit_api_retrieval(
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
@log_function_time(print_only=True)
|
||||
def query_vespa(
|
||||
query_params: Mapping[str, str | int | float]
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
print(f"query_params: {query_params}")
|
||||
if "query" in query_params and not cast(str, query_params["query"]).strip():
|
||||
raise ValueError("No/empty query received")
|
||||
|
||||
@@ -340,6 +349,7 @@ def query_vespa(
|
||||
return inference_chunks
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _get_chunks_via_batch_search(
|
||||
index_name: str,
|
||||
chunk_requests: list[VespaChunkRequest],
|
||||
@@ -374,6 +384,7 @@ def _get_chunks_via_batch_search(
|
||||
return inference_chunks
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def batch_search_api_retrieval(
|
||||
index_name: str,
|
||||
chunk_requests: list[VespaChunkRequest],
|
||||
|
||||
@@ -72,6 +72,7 @@ from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.key_value_store.factory import get_kv_store
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_function_time
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
@@ -660,6 +661,7 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
return total_chunks_deleted
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def id_based_retrieval(
|
||||
self,
|
||||
chunk_requests: list[VespaChunkRequest],
|
||||
@@ -681,6 +683,7 @@ class VespaIndex(DocumentIndex):
|
||||
get_large_chunks=get_large_chunks,
|
||||
)
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def hybrid_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
|
||||
@@ -21,6 +21,7 @@ from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.prompts.prompt_utils import build_doc_context_str
|
||||
from danswer.tools.tool_implementations.search.search_utils import section_to_dict
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -43,6 +44,7 @@ class ChunkRange(BaseModel):
|
||||
end: int
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
|
||||
"""
|
||||
This acts on a single document to merge the overlapping ranges of chunks
|
||||
@@ -300,6 +302,7 @@ def prune_sections(
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
|
||||
# Assuming there are no duplicates by this point
|
||||
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
|
||||
@@ -327,6 +330,7 @@ def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
|
||||
docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
|
||||
doc_order: dict[str, int] = {}
|
||||
|
||||
@@ -15,6 +15,7 @@ from danswer.prompts.miscellaneous_prompts import LANGUAGE_REPHRASE_PROMPT
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import count_punctuation
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -48,6 +49,7 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
|
||||
return model_output
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def multilingual_query_expansion(
|
||||
query: str,
|
||||
expansion_languages: list[str],
|
||||
|
||||
@@ -81,7 +81,6 @@ def load_personas_from_yaml(
|
||||
|
||||
p_id = persona.get("id")
|
||||
tool_ids = []
|
||||
|
||||
if persona.get("image_generation"):
|
||||
image_gen_tool = (
|
||||
db_session.query(ToolDBModel)
|
||||
|
||||
@@ -38,7 +38,6 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
|
||||
from danswer.seeding.load_docs import seed_initial_documents
|
||||
from danswer.seeding.load_yamls import load_chat_yamls
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.server.settings.store import load_settings
|
||||
@@ -150,7 +149,7 @@ def setup_danswer(
|
||||
# update multipass indexing setting based on GPU availability
|
||||
update_default_multipass_indexing(db_session)
|
||||
|
||||
seed_initial_documents(db_session, tenant_id, cohere_enabled)
|
||||
# seed_initial_documents(db_session, tenant_id, cohere_enabled)
|
||||
|
||||
|
||||
def translate_saved_search_settings(db_session: Session) -> None:
|
||||
@@ -254,14 +253,13 @@ def setup_postgres(db_session: Session) -> None:
|
||||
create_initial_public_credential(db_session)
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
logger.notice("Loading default Prompts and Personas")
|
||||
delete_old_default_personas(db_session)
|
||||
load_chat_yamls(db_session)
|
||||
|
||||
logger.notice("Loading built-in tools")
|
||||
load_builtin_tools(db_session)
|
||||
|
||||
logger.notice("Loading default Prompts and Personas")
|
||||
load_chat_yamls(db_session)
|
||||
|
||||
refresh_built_in_tools_cache(db_session)
|
||||
auto_add_search_tool_to_personas(db_session)
|
||||
|
||||
|
||||
@@ -259,7 +259,7 @@ export function ChatPage({
|
||||
refreshRecentAssistants,
|
||||
} = useAssistants();
|
||||
|
||||
const liveAssistant: Persona | undefined =
|
||||
const liveAssistant =
|
||||
alternativeAssistant ||
|
||||
selectedAssistant ||
|
||||
recentAssistants[0] ||
|
||||
@@ -269,7 +269,6 @@ export function ChatPage({
|
||||
const noAssistants = liveAssistant == null || liveAssistant == undefined;
|
||||
// always set the model override for the chat session, when an assistant, llm provider, or user preference exists
|
||||
useEffect(() => {
|
||||
if (noAssistants) return;
|
||||
const personaDefault = getLLMProviderOverrideForPersona(
|
||||
liveAssistant,
|
||||
llmProviders
|
||||
@@ -754,7 +753,7 @@ export function ChatPage({
|
||||
useEffect(() => {
|
||||
async function fetchMaxTokens() {
|
||||
const response = await fetch(
|
||||
`/api/chat/max-selected-document-tokens?persona_id=${liveAssistant?.id}`
|
||||
`/api/chat/max-selected-document-tokens?persona_id=${liveAssistant.id}`
|
||||
);
|
||||
if (response.ok) {
|
||||
const maxTokens = (await response.json()).max_tokens as number;
|
||||
@@ -1810,23 +1809,18 @@ export function ChatPage({
|
||||
});
|
||||
};
|
||||
}
|
||||
if (noAssistants)
|
||||
return (
|
||||
<>
|
||||
<HealthCheckBanner />
|
||||
<NoAssistantModal isAdmin={isAdmin} />
|
||||
</>
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<HealthCheckBanner />
|
||||
|
||||
{showApiKeyModal && !shouldShowWelcomeModal && (
|
||||
{showApiKeyModal && !shouldShowWelcomeModal ? (
|
||||
<ApiKeyModal
|
||||
hide={() => setShowApiKeyModal(false)}
|
||||
setPopup={setPopup}
|
||||
/>
|
||||
) : (
|
||||
noAssistants && <NoAssistantModal isAdmin={isAdmin} />
|
||||
)}
|
||||
|
||||
{/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit.
|
||||
|
||||
@@ -10,8 +10,6 @@ import {
|
||||
import { ChevronDownIcon } from "./icons/icons";
|
||||
import { FiCheck, FiChevronDown } from "react-icons/fi";
|
||||
import { Popover } from "./popover/Popover";
|
||||
import { createPortal } from "react-dom";
|
||||
import { useDropdownPosition } from "@/lib/dropdown";
|
||||
|
||||
export interface Option<T> {
|
||||
name: string;
|
||||
@@ -62,7 +60,6 @@ export function SearchMultiSelectDropdown({
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const [searchTerm, setSearchTerm] = useState("");
|
||||
const dropdownRef = useRef<HTMLDivElement>(null);
|
||||
const dropdownMenuRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const handleSelect = (option: StringOrNumberOption) => {
|
||||
onSelect(option);
|
||||
@@ -78,9 +75,7 @@ export function SearchMultiSelectDropdown({
|
||||
const handleClickOutside = (event: MouseEvent) => {
|
||||
if (
|
||||
dropdownRef.current &&
|
||||
!dropdownRef.current.contains(event.target as Node) &&
|
||||
dropdownMenuRef.current &&
|
||||
!dropdownMenuRef.current.contains(event.target as Node)
|
||||
!dropdownRef.current.contains(event.target as Node)
|
||||
) {
|
||||
setIsOpen(false);
|
||||
}
|
||||
@@ -92,103 +87,105 @@ export function SearchMultiSelectDropdown({
|
||||
};
|
||||
}, []);
|
||||
|
||||
useDropdownPosition({ isOpen, dropdownRef, dropdownMenuRef });
|
||||
|
||||
return (
|
||||
<div className="relative text-left w-full" ref={dropdownRef}>
|
||||
<div className="relative inline-block text-left w-full" ref={dropdownRef}>
|
||||
<div>
|
||||
<input
|
||||
type="text"
|
||||
placeholder="Search..."
|
||||
value={searchTerm}
|
||||
onChange={(e: ChangeEvent<HTMLInputElement>) => {
|
||||
setSearchTerm(e.target.value);
|
||||
if (e.target.value) {
|
||||
if (!searchTerm) {
|
||||
setIsOpen(true);
|
||||
} else {
|
||||
}
|
||||
if (!e.target.value) {
|
||||
setIsOpen(false);
|
||||
}
|
||||
setSearchTerm(e.target.value);
|
||||
}}
|
||||
onFocus={() => setIsOpen(true)}
|
||||
className={`inline-flex
|
||||
justify-between
|
||||
w-full
|
||||
px-4
|
||||
py-2
|
||||
text-sm
|
||||
bg-background
|
||||
border
|
||||
border-border
|
||||
rounded-md
|
||||
shadow-sm
|
||||
`}
|
||||
justify-between
|
||||
w-full
|
||||
px-4
|
||||
py-2
|
||||
text-sm
|
||||
bg-background
|
||||
border
|
||||
border-border
|
||||
rounded-md
|
||||
shadow-sm
|
||||
`}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
className={`absolute top-0 right-0
|
||||
text-sm
|
||||
h-full px-2 border-l border-border`}
|
||||
aria-expanded={isOpen}
|
||||
text-sm
|
||||
h-full px-2 border-l border-border`}
|
||||
aria-expanded="true"
|
||||
aria-haspopup="true"
|
||||
onClick={() => setIsOpen(!isOpen)}
|
||||
>
|
||||
<ChevronDownIcon className="my-auto w-4 h-4" />
|
||||
<ChevronDownIcon className="my-auto" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{isOpen &&
|
||||
createPortal(
|
||||
{isOpen && (
|
||||
<div
|
||||
className={`origin-top-right
|
||||
absolute
|
||||
left-0
|
||||
mt-3
|
||||
w-full
|
||||
rounded-md
|
||||
shadow-lg
|
||||
bg-background
|
||||
border
|
||||
border-border
|
||||
max-h-80
|
||||
overflow-y-auto
|
||||
overscroll-contain`}
|
||||
>
|
||||
<div
|
||||
ref={dropdownMenuRef}
|
||||
className={`origin-top-right
|
||||
rounded-md
|
||||
shadow-lg
|
||||
bg-background
|
||||
border
|
||||
border-border
|
||||
max-h-80
|
||||
overflow-y-auto
|
||||
overscroll-contain`}
|
||||
role="menu"
|
||||
aria-orientation="vertical"
|
||||
aria-labelledby="options-menu"
|
||||
>
|
||||
<div
|
||||
role="menu"
|
||||
aria-orientation="vertical"
|
||||
aria-labelledby="options-menu"
|
||||
>
|
||||
{filteredOptions.length ? (
|
||||
filteredOptions.map((option, index) =>
|
||||
itemComponent ? (
|
||||
<div
|
||||
key={option.name}
|
||||
onClick={() => {
|
||||
handleSelect(option);
|
||||
}}
|
||||
>
|
||||
{itemComponent({ option })}
|
||||
</div>
|
||||
) : (
|
||||
<StandardDropdownOption
|
||||
key={index}
|
||||
option={option}
|
||||
index={index}
|
||||
handleSelect={handleSelect}
|
||||
/>
|
||||
)
|
||||
{filteredOptions.length ? (
|
||||
filteredOptions.map((option, index) =>
|
||||
itemComponent ? (
|
||||
<div
|
||||
key={option.name}
|
||||
onClick={() => {
|
||||
setIsOpen(false);
|
||||
handleSelect(option);
|
||||
}}
|
||||
>
|
||||
{itemComponent({ option })}
|
||||
</div>
|
||||
) : (
|
||||
<StandardDropdownOption
|
||||
key={index}
|
||||
option={option}
|
||||
index={index}
|
||||
handleSelect={handleSelect}
|
||||
/>
|
||||
)
|
||||
) : (
|
||||
<button
|
||||
key={0}
|
||||
className={`w-full text-left block px-4 py-2.5 text-sm hover:bg-hover`}
|
||||
role="menuitem"
|
||||
onClick={() => setIsOpen(false)}
|
||||
>
|
||||
No matches found...
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>,
|
||||
document.body
|
||||
)}
|
||||
)
|
||||
) : (
|
||||
<button
|
||||
key={0}
|
||||
className={`w-full text-left block px-4 py-2.5 text-sm hover:bg-hover`}
|
||||
role="menuitem"
|
||||
onClick={() => setIsOpen(false)}
|
||||
>
|
||||
No matches found...
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -66,21 +66,11 @@ export function Modal({
|
||||
e.stopPropagation();
|
||||
}
|
||||
}}
|
||||
className={`
|
||||
bg-background
|
||||
text-emphasis
|
||||
rounded
|
||||
shadow-2xl
|
||||
transform
|
||||
transition-all
|
||||
duration-300
|
||||
ease-in-out
|
||||
relative
|
||||
overflow-visible
|
||||
className={`bg-background text-emphasis rounded shadow-2xl
|
||||
transform transition-all duration-300 ease-in-out
|
||||
${width ?? "w-11/12 max-w-4xl"}
|
||||
${noPadding ? "" : "p-10"}
|
||||
${className || ""}
|
||||
`}
|
||||
${className || ""}`}
|
||||
>
|
||||
{onOutsideClick && !hideCloseButton && (
|
||||
<div className="absolute top-2 right-2">
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
import { RefObject, useCallback, useEffect } from "react";
|
||||
|
||||
interface DropdownPositionProps {
|
||||
isOpen: boolean;
|
||||
dropdownRef: RefObject<HTMLElement>;
|
||||
dropdownMenuRef: RefObject<HTMLElement>;
|
||||
}
|
||||
|
||||
// This hook manages the positioning of a dropdown menu relative to its trigger element.
|
||||
// It ensures the menu is positioned correctly, adjusting for viewport boundaries and scroll position.
|
||||
// Also adds event listeners for window resize and scroll to update the position dynamically.
|
||||
export const useDropdownPosition = ({
|
||||
isOpen,
|
||||
dropdownRef,
|
||||
dropdownMenuRef,
|
||||
}: DropdownPositionProps) => {
|
||||
const updateMenuPosition = useCallback(() => {
|
||||
if (isOpen && dropdownRef.current && dropdownMenuRef.current) {
|
||||
const rect = dropdownRef.current.getBoundingClientRect();
|
||||
const menuRect = dropdownMenuRef.current.getBoundingClientRect();
|
||||
const viewportHeight = window.innerHeight;
|
||||
|
||||
let top = rect.bottom + window.scrollY;
|
||||
|
||||
if (top + menuRect.height > viewportHeight) {
|
||||
top = rect.top + window.scrollY - menuRect.height;
|
||||
}
|
||||
|
||||
dropdownMenuRef.current.style.position = "absolute";
|
||||
dropdownMenuRef.current.style.top = `${top}px`;
|
||||
dropdownMenuRef.current.style.left = `${rect.left + window.scrollX}px`;
|
||||
dropdownMenuRef.current.style.width = `${rect.width}px`;
|
||||
dropdownMenuRef.current.style.zIndex = "10000";
|
||||
}
|
||||
}, [isOpen, dropdownRef, dropdownMenuRef]);
|
||||
|
||||
useEffect(() => {
|
||||
updateMenuPosition();
|
||||
window.addEventListener("resize", updateMenuPosition);
|
||||
window.addEventListener("scroll", updateMenuPosition);
|
||||
|
||||
return () => {
|
||||
window.removeEventListener("resize", updateMenuPosition);
|
||||
window.removeEventListener("scroll", updateMenuPosition);
|
||||
};
|
||||
}, [isOpen, updateMenuPosition]);
|
||||
|
||||
return updateMenuPosition;
|
||||
};
|
||||
Reference in New Issue
Block a user