Compare commits

...

1 Commits

Author SHA1 Message Date
Yuhong Sun
9148f54e59 Rag script 2026-03-30 16:01:52 -07:00

View File

@@ -5,6 +5,7 @@ import asyncio
import json
import logging
import sys
import time
from dataclasses import asdict
from dataclasses import dataclass
from pathlib import Path
@@ -27,6 +28,9 @@ INTERNAL_SEARCH_TOOL_NAME = "internal_search"
INTERNAL_SEARCH_IN_CODE_TOOL_ID = "SearchTool"
MAX_REQUEST_ATTEMPTS = 5
RETRIABLE_STATUS_CODES = {429, 500, 502, 503, 504}
QUESTION_TIMEOUT_SECONDS = 300
QUESTION_RETRY_PAUSE_SECONDS = 30
MAX_QUESTION_ATTEMPTS = 3
@dataclass(frozen=True)
@@ -109,6 +113,27 @@ def normalize_api_base(api_base: str) -> str:
return f"{normalized}/api"
def load_completed_question_ids(output_file: Path) -> set[str]:
if not output_file.exists():
return set()
completed_ids: set[str] = set()
with output_file.open("r", encoding="utf-8") as file:
for line in file:
stripped = line.strip()
if not stripped:
continue
try:
record = json.loads(stripped)
except json.JSONDecodeError:
continue
question_id = record.get("question_id")
if isinstance(question_id, str) and question_id:
completed_ids.add(question_id)
return completed_ids
def load_questions(questions_file: Path) -> list[QuestionRecord]:
if not questions_file.exists():
raise FileNotFoundError(f"Questions file not found: {questions_file}")
@@ -348,6 +373,7 @@ async def generate_answers(
api_base: str,
api_key: str,
parallelism: int,
skipped: int,
) -> None:
if parallelism < 1:
raise ValueError("`--parallelism` must be at least 1.")
@@ -382,58 +408,178 @@ async def generate_answers(
write_lock = asyncio.Lock()
completed = 0
successful = 0
stuck_count = 0
failed_questions: list[FailedQuestionRecord] = []
total = len(questions)
remaining_count = len(questions)
overall_total = remaining_count + skipped
question_durations: list[float] = []
run_start_time = time.monotonic()
def print_progress() -> None:
avg_time = (
sum(question_durations) / len(question_durations)
if question_durations
else 0.0
)
elapsed = time.monotonic() - run_start_time
eta = avg_time * (remaining_count - completed) / max(parallelism, 1)
done = skipped + completed
bar_width = 30
filled = (
int(bar_width * done / overall_total)
if overall_total
else bar_width
)
bar = "" * filled + "" * (bar_width - filled)
pct = (done / overall_total * 100) if overall_total else 100.0
parts = (
f"\r{bar} {pct:5.1f}% "
f"[{done}/{overall_total}] "
f"avg {avg_time:.1f}s/q "
f"elapsed {elapsed:.0f}s "
f"ETA {eta:.0f}s "
f"(ok:{successful} fail:{len(failed_questions)}"
)
if stuck_count:
parts += f" stuck:{stuck_count}"
if skipped:
parts += f" skip:{skipped}"
parts += ")"
sys.stderr.write(parts)
sys.stderr.flush()
print_progress()
async def process_question(question_record: QuestionRecord) -> None:
nonlocal completed
nonlocal successful
nonlocal stuck_count
try:
async with semaphore:
result = await submit_question(
session=session,
api_base=api_base,
headers=headers,
internal_search_tool_id=internal_search_tool_id,
question_record=question_record,
last_error: Exception | None = None
for attempt in range(1, MAX_QUESTION_ATTEMPTS + 1):
q_start = time.monotonic()
try:
async with semaphore:
result = await asyncio.wait_for(
submit_question(
session=session,
api_base=api_base,
headers=headers,
internal_search_tool_id=internal_search_tool_id,
question_record=question_record,
),
timeout=QUESTION_TIMEOUT_SECONDS,
)
except asyncio.TimeoutError:
async with progress_lock:
stuck_count += 1
logger.warning(
"Question %s timed out after %ss (attempt %s/%s, "
"total stuck: %s) — retrying in %ss",
question_record.question_id,
QUESTION_TIMEOUT_SECONDS,
attempt,
MAX_QUESTION_ATTEMPTS,
stuck_count,
QUESTION_RETRY_PAUSE_SECONDS,
)
print_progress()
last_error = TimeoutError(
f"Timed out after {QUESTION_TIMEOUT_SECONDS}s "
f"on attempt {attempt}/{MAX_QUESTION_ATTEMPTS}"
)
except Exception as exc:
await asyncio.sleep(QUESTION_RETRY_PAUSE_SECONDS)
continue
except Exception as exc:
duration = time.monotonic() - q_start
async with progress_lock:
completed += 1
question_durations.append(duration)
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(exc),
)
)
logger.exception(
"Failed question %s (%s/%s)",
question_record.question_id,
completed,
remaining_count,
)
print_progress()
return
duration = time.monotonic() - q_start
async with write_lock:
file.write(json.dumps(asdict(result), ensure_ascii=False))
file.write("\n")
file.flush()
async with progress_lock:
completed += 1
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(exc),
)
)
logger.exception(
"Failed question %s (%s/%s)",
question_record.question_id,
completed,
total,
)
successful += 1
question_durations.append(duration)
print_progress()
return
async with write_lock:
file.write(json.dumps(asdict(result), ensure_ascii=False))
file.write("\n")
file.flush()
# All attempts exhausted due to timeouts
async with progress_lock:
completed += 1
successful += 1
logger.info("Processed %s/%s questions", completed, total)
failed_questions.append(
FailedQuestionRecord(
question_id=question_record.question_id,
error=str(last_error),
)
)
logger.error(
"Question %s failed after %s timeout attempts (%s/%s)",
question_record.question_id,
MAX_QUESTION_ATTEMPTS,
completed,
remaining_count,
)
print_progress()
await asyncio.gather(
*(process_question(question_record) for question_record in questions)
)
# Final newline after progress bar
sys.stderr.write("\n")
sys.stderr.flush()
total_elapsed = time.monotonic() - run_start_time
avg_time = (
sum(question_durations) / len(question_durations)
if question_durations
else 0.0
)
stuck_suffix = f", {stuck_count} stuck timeouts" if stuck_count else ""
resume_suffix = (
f"{skipped} previously completed, "
f"{skipped + successful}/{overall_total} overall"
if skipped
else ""
)
logger.info(
"Done: %s/%s successful in %.1fs (avg %.1fs/question%s)%s",
successful,
remaining_count,
total_elapsed,
avg_time,
stuck_suffix,
resume_suffix,
)
if failed_questions:
logger.warning(
"Completed with %s failed questions and %s successful questions.",
"%s questions failed:",
len(failed_questions),
successful,
)
for failed_question in failed_questions:
logger.warning(
@@ -453,7 +599,30 @@ def main() -> None:
raise ValueError("`--max-questions` must be at least 1 when provided.")
questions = questions[: args.max_questions]
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
completed_ids = load_completed_question_ids(args.output_file)
logger.info(
"Found %s already-answered question IDs in %s",
len(completed_ids),
args.output_file,
)
total_before_filter = len(questions)
questions = [q for q in questions if q.question_id not in completed_ids]
skipped = total_before_filter - len(questions)
if skipped:
logger.info(
"Resuming: %s/%s already answered, %s remaining",
skipped,
total_before_filter,
len(questions),
)
else:
logger.info("Loaded %s questions from %s", len(questions), args.questions_file)
if not questions:
logger.info("All questions already answered. Nothing to do.")
return
logger.info("Writing answers to %s", args.output_file)
asyncio.run(
@@ -463,6 +632,7 @@ def main() -> None:
api_base=api_base,
api_key=args.api_key,
parallelism=args.parallelism,
skipped=skipped,
)
)