# Source code for metabeeai.metabeeai_llm.llm_pipeline

import argparse
import asyncio
import copy
import json
import os
import sys
import time

import yaml

from metabeeai.metabeeai_llm.json_multistage_qa import ask_json as ask_json_async
from metabeeai.metabeeai_llm.json_multistage_qa import format_to_list as format_to_list_async


def ask_json(question_text, json_path):
    """
    Blocking wrapper around the async ``ask_json`` helper.

    Builds the coroutine for ``question_text`` / ``json_path`` and drives it to
    completion with ``asyncio.run``, returning whatever the coroutine returns.
    """
    pending = ask_json_async(question_text, json_path)
    return asyncio.run(pending)


def format_to_list(question, text, model="gpt-4o-mini"):
    """
    Blocking wrapper around the async ``format_to_list`` helper.

    Forwards ``question``, ``text`` and ``model`` unchanged, runs the coroutine
    with ``asyncio.run`` and returns its result.
    """
    pending = format_to_list_async(question, text, model)
    return asyncio.run(pending)


# ------------------------------------------------------------------------------
# Hierarchical Questions Dictionary
# ------------------------------------------------------------------------------
# Use {placeholder} format syntax in any question that should be parameterized.

# Lazy load questions to avoid import-time errors
_QUESTIONS = None


def _get_questions():
    """
    Lazily load ``questions.yml`` (located next to this module) and cache it.

    Returns:
        The question tree. When the YAML top level is a mapping with a
        ``QUESTIONS`` key, that key's value is returned; otherwise the parsed
        document itself is returned (mapping or list).

    Raises:
        ValueError: If ``questions.yml`` parses to nothing (empty file).
    """
    global _QUESTIONS
    if _QUESTIONS is None:
        # Get the directory where this script is located
        script_dir = os.path.dirname(os.path.abspath(__file__))
        questions_path = os.path.join(script_dir, "questions.yml")

        with open(questions_path, "r", encoding="utf-8") as file:
            config = yaml.safe_load(file)

        if config is None:
            # An empty document would otherwise fail later with an opaque AttributeError.
            raise ValueError(f"No questions found in '{questions_path}'")

        if isinstance(config, dict):
            # Unwrap the 'QUESTIONS' key when present to prevent double-wrapping;
            # otherwise the whole mapping is the question tree.
            _QUESTIONS = config.get("QUESTIONS", config)
        else:
            # A list-rooted question file has no wrapper key to unwrap.
            _QUESTIONS = config

    return _QUESTIONS


# ------------------------------------------------------------------------------
# Helper Function: get_answer
# ------------------------------------------------------------------------------
async def get_answer(question_text, json_path, relevance_model=None, answer_model=None):
    """
    Ask one question via the async ``ask_json`` helper and normalise the reply.

    Args:
        question_text: The question to ask
        json_path: Path to the JSON file containing text chunks
        relevance_model: Forwarded to ``ask_json`` as ``relevance_model``
        answer_model: Forwarded to ``ask_json`` as ``answer_model``

    Returns:
        A dictionary with exactly the keys ``answer``, ``reason``, ``chunk_ids``,
        ``cost`` and ``metrics``. ``cost`` and ``metrics`` are tracking data
        intended for the pipeline's logging step.
    """
    raw = await ask_json_async(question_text, json_path, relevance_model=relevance_model, answer_model=answer_model)

    # Anything that is not a dict is stringified and wrapped with neutral tracking values
    # so downstream cost/metric aggregation always finds the keys it expects.
    if not isinstance(raw, dict):
        return {
            "answer": "" if not raw else str(raw),
            "reason": "Answer generated from available information",
            "chunk_ids": [],
            "cost": 0.0,
            "metrics": {},
        }

    # Keep only the known fields, substituting a default for each one that is missing.
    field_defaults = (
        ("answer", ""),
        ("reason", ""),
        ("chunk_ids", []),
        ("cost", 0.0),
        ("metrics", {}),
    )
    return {field: raw.get(field, fallback) for field, fallback in field_defaults}


# ------------------------------------------------------------------------------
# Generic Recursive Function to Process a Hierarchical Question Tree
# ------------------------------------------------------------------------------
async def process_question_tree(tree, json_path, context=None, relevance_model=None, answer_model=None):
    """
    Walk a nested question tree and replace every question with its answer.

    Args:
        tree: The question tree (dict, list, str, or any other value)
        json_path: Path to the JSON file containing text chunks
        context: Mapping used to fill ``{placeholder}`` fields in question strings
        relevance_model: Forwarded to get_answer
        answer_model: Forwarded to get_answer

    Node handling:
    - A dict with a "question" key is a leaf: the formatted question is answered.
    - In any other dict, the key "list" is expanded once per discovered list item
      (see _expand_list_node); every other key is processed recursively.
    - A list is processed element by element, in order.
    - A string is treated as a question.
    - Any other value is returned unchanged.
    """
    if context is None:
        context = {}

    if isinstance(tree, dict):
        if "question" in tree:
            # Leaf node: fill in placeholders and ask.
            leaf_question = tree["question"].format(**context)
            return await get_answer(leaf_question, json_path, relevance_model=relevance_model, answer_model=answer_model)

        collected = {}
        for name, subtree in tree.items():
            if name == "list":
                collected[name] = await _expand_list_node(subtree, json_path, context, relevance_model, answer_model)
            else:
                collected[name] = await process_question_tree(
                    subtree, json_path, context, relevance_model=relevance_model, answer_model=answer_model
                )
        return collected

    if isinstance(tree, list):
        # Elements are awaited one after another, preserving order.
        processed = []
        for element in tree:
            processed.append(
                await process_question_tree(
                    element, json_path, context, relevance_model=relevance_model, answer_model=answer_model
                )
            )
        return processed

    if isinstance(tree, str):
        # A bare string is itself a question.
        return await get_answer(
            tree.format(**context), json_path, relevance_model=relevance_model, answer_model=answer_model
        )

    return tree


async def _expand_list_node(spec, json_path, context, relevance_model, answer_model):
    """
    Expand a "list" node: discover the items, then run ``spec["for_each"]`` per item.

    Returns a dict keyed by item, plus a "_list_metadata" entry carrying the combined
    cost and token metrics of the discovery question and the list-formatting call.
    """
    list_question = spec["question"].format(**context)
    placeholder = spec["endpoint_name"]

    discovery = await get_answer(list_question, json_path, relevance_model=relevance_model, answer_model=answer_model)
    formatted = await format_to_list_async(list_question, discovery["answer"])
    members = formatted["answer"]

    # Cost of finding the list plus cost of turning the answer into an array.
    discovery_cost = discovery.get("cost", 0.0) + formatted.get("cost", 0.0)

    # Start from a private copy of the discovery metrics, then fold in the formatting call.
    tracker = copy.deepcopy(discovery.get("metrics", {}))
    if formatted and "usage_details" in formatted:
        usage = formatted["usage_details"]
        model_label = usage.get("model", "unknown")
        bucket = tracker.setdefault(f"Answering|{model_label}", {"input": 0, "cached": 0, "output": 0, "cost": 0.0})
        bucket["input"] += usage.get("input_tokens", 0)
        bucket["cached"] += usage.get("cached_tokens", 0)
        bucket["output"] += usage.get("output_tokens", 0)
        bucket["cost"] += usage.get("cost", 0.0)

    # The metadata entry has "answer" and "metrics" keys so the logging step can pick it up.
    expanded = {
        "_list_metadata": {
            "answer": "List discovery metadata",
            "cost": discovery_cost,
            "metrics": tracker,
        }
    }

    for member in members:
        # Each item gets its own context so placeholders do not leak between items.
        scoped = context.copy()
        scoped[placeholder] = member
        expanded[member] = await process_question_tree(
            spec["for_each"],
            json_path,
            scoped,
            relevance_model=relevance_model,
            answer_model=answer_model,
        )
    return expanded


# ------------------------------------------------------------------------------
# Main Function: Retrieve All Answers Based on the Questions Dictionary
# ------------------------------------------------------------------------------
async def get_literature_answers(json_path, relevance_model=None, answer_model=None):
    """
    Answer every question in the loaded question tree for one document.

    Args:
        json_path: Path to the JSON file containing text chunks
        relevance_model: Forwarded to process_question_tree
        answer_model: Forwarded to process_question_tree

    Returns:
        The answers structure produced by process_question_tree for the tree
        returned by _get_questions().
    """
    question_tree = _get_questions()
    return await process_question_tree(
        question_tree,
        json_path,
        relevance_model=relevance_model,
        answer_model=answer_model,
    )
# ------------------------------------------------------------------------------
# Main Execution
# ------------------------------------------------------------------------------
def merge_json_in_the_folder(folder_path, overwrite=False):
    """
    Merge the chunks of all JSON files in a folder into ``<folder_path>/merged.json``.

    Each input file is expected to hold its chunks under ``["data"]["chunks"]``.
    Chunks whose ``chunk_type`` is "figure" or "marginalia" are dropped.

    Args:
        folder_path: Folder containing the JSON files (with or without a trailing slash)
        overwrite: When False and merged.json already exists, print a notice and do nothing

    Returns:
        None. The merged document is written to disk.
    """
    merged_name = "merged.json"
    # os.path.join keeps the output inside the folder even without a trailing slash.
    merged_path = os.path.join(folder_path, merged_name)

    if not overwrite and os.path.exists(merged_path):
        print("The file already exists. Set 'overwrite=True' to overwrite.")
        return

    skipped_types = ("figure", "marginalia")
    chunks_kept = []

    # Sorted so the merged chunk order does not depend on filesystem listing order.
    for file_name in sorted(os.listdir(folder_path)):
        if not file_name.endswith(".json"):
            continue
        if file_name == merged_name:
            # Never feed a previous merge result back into itself (would duplicate chunks).
            continue

        with open(os.path.join(folder_path, file_name), "r", encoding="utf-8") as f:
            json_obj = json.load(f)

        chunks_kept.extend(
            chunk for chunk in json_obj["data"]["chunks"] if chunk.get("chunk_type") not in skipped_types
        )

    with open(merged_path, "w", encoding="utf-8") as f:
        json.dump({"data": {"chunks": chunks_kept}}, f, indent=2)
def extract_metrics(answers_dict, prefix=""):
    """
    Pull cost and token-metric tracking data out of an answers structure, in place.

    An "answer leaf" is a dict containing both "answer" and "metrics". For each leaf
    found (directly under a key or inside a list) its "cost" and "metrics" entries are
    popped and recorded under a dotted/indexed label such as ``group.list.item.q`` or
    ``qs[0]``. Dict-valued keys whose name starts with "_" are temporary tracking
    entries and are removed from their parent after being recorded.

    Args:
        answers_dict: Dict or list of answers; it is modified in place
        prefix: Label prefix for everything found below this node

    Returns:
        Tuple ``(costs, metrics, total)``: label -> cost, label -> metrics dict, and
        the sum of all recorded costs.
    """
    costs = {}
    metrics = {}
    total = 0.0

    def _absorb(label, node):
        """Record ``node`` if it is an answer leaf, otherwise recurse into dicts/lists."""
        nonlocal total
        if isinstance(node, dict) and "metrics" in node and "answer" in node:
            # pop() both extracts the data and keeps it out of the saved answers.
            cost = node.pop("cost", 0.0) or 0.0
            costs[label] = cost
            metrics[label] = node.pop("metrics", {})
            total += cost
            return

        if isinstance(node, dict):
            sub_costs, sub_metrics, sub_total = extract_metrics(node, prefix=f"{label}.")
        elif isinstance(node, list):
            # The list branch below appends "[i]" to this label for each element.
            sub_costs, sub_metrics, sub_total = extract_metrics(node, prefix=label)
        else:
            return

        costs.update(sub_costs)
        metrics.update(sub_metrics)
        total += sub_total

    if isinstance(answers_dict, list):
        for i, item in enumerate(answers_dict):
            _absorb(f"{prefix}[{i}]", item)
        return costs, metrics, total

    # Snapshot the keys: hidden tracking entries are removed while looping.
    for key in list(answers_dict):
        value = answers_dict[key]
        # f-string formatting tolerates non-string keys (e.g. numeric list items).
        _absorb(f"{prefix}{key}", value)
        if isinstance(value, dict) and str(key).startswith("_"):
            answers_dict.pop(key)

    return costs, metrics, total


def _load_existing_answers(answers_path):
    """Best-effort read of a previous answers.json; returns {} when absent or unreadable."""
    if not os.path.exists(answers_path):
        return {}

    try:
        with open(answers_path, "r", encoding="utf-8") as f:
            existing_data = json.load(f)

        # Handle both old format (direct dict) and new format (with QUESTIONS key)
        if isinstance(existing_data, dict) and "QUESTIONS" in existing_data:
            existing_answers = existing_data["QUESTIONS"]
        else:
            existing_answers = existing_data

        print(f" 📝 Found existing answers with {len(existing_answers)} question(s)")
        return existing_answers
    except (OSError, ValueError, TypeError) as e:
        print(f" ⚠️ Could not read existing answers: {e}")
        return {}


def _format_run_totals(global_metrics):
    """Build the 'Full Run Totals' log lines from the per-model accumulated metrics."""
    overall_total_cost = sum(m["cost"] for m in global_metrics.values())
    overall_total_tokens = sum(m["input"] + m["cached"] + m["output"] for m in global_metrics.values())

    summary_lines = ["\nFull Run Totals:"]
    for model_key, m in global_metrics.items():
        # Keys look like "<phase>|<model>"; fall back to the raw key otherwise.
        if "|" in model_key:
            phase, model_name = model_key.split("|", 1)
            model_display = f"{phase} Model: {model_name}"
        else:
            model_display = f"Model: {model_key}"

        summary_lines.append(
            f"{model_display}\n"
            f"New Input Tokens: {m['input']}\n"
            f"Cached Input Tokens: {m['cached']}\n"
            f"Output Tokens: {m['output']}\n"
            f"Overall Cost: ${m['cost']:.4f}\n"
        )

    summary_lines.append(f"Total Tokens Combined: {overall_total_tokens}")
    summary_lines.append(f"Total Cost Combined: ${overall_total_cost:.4f}\n")
    return summary_lines


async def process_papers(
    base_dir=None,
    paper_folders=None,
    overwrite_merged=False,
    relevance_model=None,
    answer_model=None,
    start_folder=None,
    end_folder=None,
):
    """
    Run the question pipeline over paper folders and write ``answers.json`` for each.

    For every selected folder that contains ``pages/merged_v2.json``, the questions are
    answered, merged with any existing ``answers.json`` (existing keys missing from the
    new results are kept), stripped of cost/metric tracking data, and saved under a
    ``QUESTIONS`` key. Per-question costs, per-paper totals, failures and a final
    per-model summary are appended to ``processing_log.txt`` in the base directory.

    Args:
        base_dir: Base directory containing paper folders (read from config key
            "papers_dir" when None)
        paper_folders: List of specific paper folder names to process (defaults to all
            non-hidden subdirectories, sorted)
        overwrite_merged: Accepted for interface compatibility; not read by this function
        relevance_model: Forwarded to get_literature_answers
        answer_model: Forwarded to get_literature_answers
        start_folder: Optional start folder (inclusive, string comparison)
        end_folder: Optional end folder (inclusive, string comparison)
    """
    import logging
    from io import StringIO

    # Import centralized configuration if base_dir not provided
    if base_dir is None:
        from metabeeai.config import get_config_param

        base_dir = get_config_param("papers_dir")

    # Validate base directory
    if not os.path.exists(base_dir):
        print(f"Error: Base directory '{base_dir}' not found")
        return

    # Add trailing slash if missing
    if not base_dir.endswith("/"):
        base_dir += "/"

    # If no specific folders provided, get all subdirectories
    if paper_folders is None:
        paper_folders = [
            item
            for item in os.listdir(base_dir)
            if os.path.isdir(os.path.join(base_dir, item)) and not item.startswith(".")
        ]
        paper_folders.sort()  # Sort for consistent processing order

        # NOTE(review): the range filter is applied only to the auto-discovered list,
        # following the CLI help ("only applies when --papers is not set") — confirm.
        if start_folder or end_folder:
            paper_folders = [
                folder
                for folder in paper_folders
                if not (start_folder and folder < start_folder) and not (end_folder and folder > end_folder)
            ]

    total_papers = len(paper_folders)
    completed_papers = 0
    failed_papers = []

    # Create progress log file
    log_file = os.path.join(base_dir, "processing_log.txt")

    print(f"🚀 Starting pipeline: {total_papers} papers to process")
    print(f"📁 Papers directory: {base_dir}")
    print(f"📝 Progress log: {log_file}")
    print("=" * 60)

    # Token and cost counts per "<phase>|<model>" key, accumulated across all papers in the run
    global_metrics = {}

    for paper_folder in paper_folders:
        paper_path = os.path.join(base_dir, paper_folder)

        # Show overall progress
        remaining = total_papers - completed_papers
        print(f"\n📊 Progress: {completed_papers}/{total_papers} completed, {remaining} remaining")
        print(f"🔄 Processing paper {paper_folder}...")

        # Skip if the paper directory doesn't exist
        if not os.path.exists(paper_path):
            print(f"⏭️ Skipping {paper_folder} - directory not found")
            continue

        try:
            pages_path = os.path.join(paper_path, "pages/")
            if not os.path.exists(pages_path):
                print(f"⏭️ Skipping {paper_folder} - pages directory not found")
                continue

            # Check if merged_v2.json exists
            json_path = os.path.join(pages_path, "merged_v2.json")
            if not os.path.exists(json_path):
                print(f"⏭️ Skipping {paper_folder} - merged_v2.json not found")
                continue

            # Process the paper with progress tracking
            questions = _get_questions()
            print(f" 📖 Processing {len(questions)} questions...")

            # Silence stdout/stderr and lower log verbosity while the questions run;
            # everything is restored in the finally block even if answering fails.
            original_stdout = sys.stdout
            original_stderr = sys.stderr
            root_logger = logging.getLogger()
            original_log_level = root_logger.level

            sys.stdout = StringIO()
            sys.stderr = StringIO()
            root_logger.setLevel(logging.ERROR)

            try:
                literature_answers = await get_literature_answers(
                    json_path, relevance_model=relevance_model, answer_model=answer_model
                )
            finally:
                sys.stdout = original_stdout
                sys.stderr = original_stderr
                root_logger.setLevel(original_log_level)

            # Merge with existing answers.json if it exists
            answers_path = os.path.join(paper_path, "answers.json")
            existing_answers = _load_existing_answers(answers_path)

            # New answers win; existing keys missing from the new results are preserved.
            if existing_answers:
                if isinstance(existing_answers, dict) and isinstance(literature_answers, dict):
                    for key, previous in existing_answers.items():
                        literature_answers.setdefault(key, previous)
                    print(f" 🔄 Merged answers: {len(literature_answers)} total question(s)")
                else:
                    # Key-wise merging is only defined for two mappings; keep the fresh
                    # results rather than failing after the answers were already paid for.
                    print(" ⚠️ Existing answers are not a question mapping; skipping merge")

            # Extract the metrics, removing the tracking data from the dictionary
            # Happens before saving so the data doesn't get written into answers.json
            costs, question_metrics, total_cost = extract_metrics(literature_answers)

            # Save the merged results in QUESTIONS format
            output_data = {"QUESTIONS": literature_answers}
            with open(answers_path, "w", encoding="utf-8") as f:
                json.dump(output_data, f, indent=2)

            completed_papers += 1
            print(f" ✅ Paper {paper_folder} completed successfully")

            # Print and log the metrics
            log_lines = []
            for q_key, q_cost in costs.items():
                log_lines.append(f"- {q_key}: ${q_cost:.4f}")

                # Collect token usage totals into global_metrics for the final summary
                q_m = question_metrics.get(q_key) or {}
                for model_name, tokens in q_m.items():
                    bucket = global_metrics.setdefault(
                        model_name, {"input": 0, "cached": 0, "output": 0, "cost": 0.0}
                    )
                    # .get() so a partially filled metrics entry cannot abort the logging.
                    bucket["input"] += tokens.get("input", 0)
                    bucket["cached"] += tokens.get("cached", 0)
                    bucket["output"] += tokens.get("output", 0)
                    bucket["cost"] += tokens.get("cost", 0.0)

            # Time stamp with the total single-paper cost, saved to tracking log
            timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
            log_lines.append(f"{paper_folder}: COMPLETED at {timestamp} | Total Cost: ${total_cost:.4f}\n")

            # Log completion
            with open(log_file, "a", encoding="utf-8") as f:
                f.write("\n".join(log_lines))

        except Exception as e:
            # Per-paper boundary: record the failure and move on to the next paper.
            print(f" ❌ Error processing paper {paper_folder}: {str(e)}")
            failed_papers.append(paper_folder)

            # Log failure
            with open(log_file, "a", encoding="utf-8") as f:
                f.write(f"{paper_folder}: FAILED at {time.strftime('%Y-%m-%d %H:%M:%S')} - {str(e)}\n")
            continue

    # Final summary
    print("\n" + "=" * 60)
    print("🎉 PIPELINE COMPLETED!")
    print(f"✅ Successfully processed: {completed_papers}/{total_papers} papers")
    if failed_papers:
        print(f"❌ Failed papers: {', '.join(failed_papers)}")
    print(f"📝 Detailed log: {log_file}")

    # Append the full-run totals once every folder has been processed
    if global_metrics:
        with open(log_file, "a", encoding="utf-8") as f:
            f.write("\n".join(_format_run_totals(global_metrics)))


def main(argv=None):
    """
    Main entry point: parse command-line arguments and run process_papers.

    Args:
        argv: Argument list without the program name (defaults to ``sys.argv[1:]``)
    """
    if argv is None:
        argv = sys.argv[1:]

    parser = argparse.ArgumentParser(description="Process paper folders to extract literature answers")

    # YAML config file path (sets METABEEAI_CONFIG_FILE for downstream lookups)
    parser.add_argument(
        "--config",
        "--config-file",
        dest="config",
        type=str,
        default=None,
        help="Path to config YAML file (overrides METABEEAI_CONFIG_FILE and defaults)",
    )

    # Base directory and selection of folders
    parser.add_argument(
        "--dir", type=str, default=None, help="Base directory containing paper folders (default: auto-detect from config)"
    )
    parser.add_argument(
        "--papers",
        type=str,
        nargs="+",
        default=None,
        help=("Specific paper IDs to process (e.g., 283C6B42 3ZHNVADM). If not specified, all folders will be processed."),
    )
    parser.add_argument(
        "--start",
        type=str,
        default=None,
        help="Start processing from this paper ID (alphanumeric, optional; only applies when --papers is not set)",
    )
    parser.add_argument(
        "--end",
        type=str,
        default=None,
        help="End processing at this paper ID (alphanumeric, optional; only applies when --papers is not set)",
    )
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing merged.json files")

    # Models
    parser.add_argument(
        "--relevance-model",
        type=str,
        default=None,
        help="Model to use for chunk selection (e.g., 'openai/gpt-4o-mini', 'openai/gpt-4o'). Default: from config",
    )
    parser.add_argument(
        "--answer-model",
        type=str,
        default=None,
        help=(
            "Model to use for answer generation and reflection (e.g., 'openai/gpt-4o-mini', 'openai/gpt-4o'). "
            "Default: from config"
        ),
    )

    # Preset selector (fast/balanced/quality)
    parser.add_argument(
        "--preset",
        type=str,
        choices=["fast", "balanced", "quality"],
        default=None,
        help="Use predefined configuration preset: 'fast', 'balanced', or 'quality'",
    )

    args = parser.parse_args(argv)

    # Respect provided config file for downstream lookups
    if args.config:
        os.environ["METABEEAI_CONFIG_FILE"] = args.config

    # Handle predefined configurations
    if args.preset:
        from metabeeai.metabeeai_llm.pipeline_config import BALANCED_CONFIG, FAST_CONFIG, QUALITY_CONFIG

        config_map = {"fast": FAST_CONFIG, "balanced": BALANCED_CONFIG, "quality": QUALITY_CONFIG}
        selected_config = config_map[args.preset]

        # Explicit --relevance-model / --answer-model take precedence over the preset
        if args.relevance_model is None:
            args.relevance_model = selected_config["relevance_model"]
        if args.answer_model is None:
            args.answer_model = selected_config["answer_model"]

        print(f"🔧 Using {args.preset.upper()} configuration:")
        print(f" Relevance Model: {args.relevance_model}")
        print(f" Answer Model: {args.answer_model}")
        print(f" Description: {selected_config['description']}")

    asyncio.run(
        process_papers(
            base_dir=args.dir,
            paper_folders=args.papers,
            overwrite_merged=args.overwrite,
            relevance_model=args.relevance_model,
            answer_model=args.answer_model,
            start_folder=args.start,
            end_folder=args.end,
        )
    )


if __name__ == "__main__":
    main(sys.argv[1:])