Source code for jobflow_remote.cli.batch

from typing import Annotated

import typer
from pydantic import ValidationError
from rich.prompt import Confirm
from rich.text import Text

from jobflow_remote.cli.formatting import format_batch_info, get_batch_processes_table
from jobflow_remote.cli.jf import app
from jobflow_remote.cli.jfr_typer import JFRTyper
from jobflow_remote.cli.types import (
    batch_state_opt,
    db_ids_opt,
    job_ids_indexes_opt,
    max_results_opt,
    verbosity_opt,
    worker_name_opt,
    yes_opt,
)
from jobflow_remote.cli.utils import (
    check_valid_uuid,
    exit_with_error_msg,
    exit_with_warning_msg,
    get_config_manager,
    get_job_controller,
    get_job_ids_indexes,
    loading_spinner,
    out_console,
    print_success_msg,
)
from jobflow_remote.jobs.state import BatchState

# Sub-application collecting the "batch" commands defined in this module.
# It is attached to the main CLI application right after being created, so
# that the commands registered below on ``app_batch`` become reachable from it.
app_batch = JFRTyper(
    name="batch",
    help="Helper utils handling batch jobs",
    no_args_is_help=True,
)
app.add_typer(app_batch)


def _check_exception_dev_version(exc: ValidationError):
    """
    Helper function to handle the change in the "jobs" type in the BatchDoc object.

    Initially defined as dict was switched to a list. Users may have this in their DB
    if using development version.

    If ``exc`` contains exactly one validation error, located at ``("jobs",)``
    and whose input is a dict, an explanatory message is reported through
    ``exit_with_error_msg``. In any other case (including a failure while
    inspecting ``exc``) the function returns without doing anything, so the
    caller can re-raise the original exception.

    Parameters
    ----------
    exc
        The pydantic ValidationError exception to be verified.
    """
    # TODO consider removing this function in the future
    try:
        # Best-effort inspection: capture all possible errors to avoid that if
        # some keys are not present it fails with a confusing message.
        val_errors = exc.errors()
        legacy_jobs_format = (
            exc.error_count() == 1
            and val_errors[0]["loc"] == ("jobs",)
            and isinstance(val_errors[0]["input"], dict)
        )
    except Exception:
        legacy_jobs_format = False

    # Report outside of the try block. Only the detection above is meant to be
    # best-effort: if exit_with_error_msg terminates by raising an exception
    # deriving from Exception (NOTE(review): presumably typer.Exit - confirm),
    # a blanket "except Exception" around it would swallow the exit request.
    if legacy_jobs_format:
        exit_with_error_msg(
            "It seems that you have used a development version of jobflow-remote. The internal format "
            "of the batch jobs have changed before the release. Please run 'jf batch fix-batch-doc-jobs-dict'"
            " to upgrade the database content (note that this command is hidden from the help as it is "
            "only needed for this error)."
        )


@app_batch.command(name="list")
def processes_list(
    worker_name: worker_name_opt = None,
    max_results: max_results_opt = 20,
    batch_state: batch_state_opt = None,
    job_id: job_ids_indexes_opt = None,
    db_id: db_ids_opt = None,
    verbosity: verbosity_opt = 0,
) -> None:
    """
    Show the list of processes being executed on the batch workers.
    """
    # NOTE: the docstring above is kept unchanged since it is the text of the
    # command; implementation notes are given as comments instead.
    controller = get_job_controller()
    project_workers = get_config_manager().get_project().workers
    selected_jobs = get_job_ids_indexes(job_id)

    # Filters forwarded to the job controller query.
    query = {
        "worker": worker_name,
        "batch_state": batch_state,
        "job_ids": selected_jobs,
        "db_ids": db_id,
        "limit": max_results,
    }

    try:
        with loading_spinner():
            found = controller.get_batches(**query)
    except ValidationError as err:
        # Give a chance to report the known legacy-format problem before
        # propagating the original validation error.
        _check_exception_dev_version(exc=err)
        raise

    if not found:
        exit_with_warning_msg("No batch processes")

    out_console.print(
        get_batch_processes_table(
            batch_processes=found,
            workers=project_workers,
            verbosity=verbosity,
            status=True,
            title="Batches info",
        )
    )
@app_batch.command(name="info")
def process_info(
    selected_id: Annotated[
        str,
        typer.Argument(
            help="The ID of the batch process. Can be the Process id (i.e. the one coming from the worker) or batch UID",
            metavar="ID",
        ),
    ],
):
    """Detailed information on a specific batch process."""
    # A value that is a valid UUID is used as batch UID, anything else is
    # used as the process id.
    process_id = batch_uid = None
    if check_valid_uuid(selected_id, raise_on_error=False):
        batch_uid = selected_id
    else:
        process_id = selected_id

    jc = get_job_controller()
    cm = get_config_manager()
    project = cm.get_project()
    workers = project.workers

    try:
        with loading_spinner():
            # limit=2 is enough to tell apart "one match" from "more than one".
            batch_processes = jc.get_batches(
                process_id=process_id,
                batch_uid=batch_uid,
                limit=2,
            )
    except ValidationError as exc:
        # Same handling as the other commands of this module that query the
        # batches: report the known legacy-format problem, otherwise re-raise.
        _check_exception_dev_version(exc=exc)
        raise

    if not batch_processes:
        exit_with_warning_msg("No batch process matching the request")

    if len(batch_processes) > 1:
        exit_with_error_msg(
            "More than one document matches the selection criteria. Use 'jf batch list' to identify a "
            "unique criteria for the batch process that you want to visualize"
        )

    selected = batch_processes[0]
    worker = workers[selected.worker]
    out_console.print(format_batch_info(selected, worker=worker), overflow="crop")
@app_batch.command()
def delete(
    # batch_state defined like this to avoid typing issues
    batch_state: Annotated[
        list[BatchState],
        typer.Option(
            "--state",
            "-s",
            help="One or more of the batch states",
            default_factory=lambda: [BatchState.FINISHED.value],
            show_default=f"{BatchState.FINISHED.value}",
        ),
    ],
    process_id: Annotated[
        str | None,
        typer.Option(
            "--process-id",
            "-pid",
            help="One or more process ids",
        ),
    ] = None,
    batch_uid: Annotated[
        str | None,
        typer.Option(
            "--batch-uid",
            "-bid",
            help="One or more batch unique ids",
        ),
    ] = None,
    worker_name: worker_name_opt = None,
    yes_all: yes_opt = False,
):
    """Remove one or more batch processes from the database. No effect on the processes running on the worker."""
    jc = get_job_controller()

    # First fetch the batch documents matching the selection, so that the user
    # can be told how many of them are going to be removed.
    try:
        with loading_spinner():
            to_delete = jc.get_batches(
                process_id=process_id,
                batch_uid=batch_uid,
                worker=worker_name,
                batch_state=batch_state,
            )
    except ValidationError as exc:
        _check_exception_dev_version(exc)
        raise

    if not to_delete:
        exit_with_warning_msg("No batch process matching the request")

    if not yes_all:
        # Use the stronger warning only if states other than FINISHED were
        # requested and at least one selected batch is actually not FINISHED.
        other_states_requested = bool(
            set(batch_state).difference([BatchState.FINISHED])
        )
        if other_states_requested and any(
            b.batch_state != BatchState.FINISHED for b in to_delete
        ):
            # All the segments embedding values need the f prefix, otherwise
            # the placeholder is shown verbatim to the user.
            text = Text.from_markup(
                "[red]This operation may remove batch processes in states other than 'FINISHED'. "
                "This could lead to [bold]inconsistencies or data loss[/bold]. "
                f"Overall it will [bold]delete {len(to_delete)} batch processes[/bold]. Proceed anyway?[/red]"
            )
        else:
            text = Text.from_markup(
                f"[red]This operation will [bold]delete {len(to_delete)} batch processes[/bold]. Proceed anyway?[/red]"
            )

        confirmed = Confirm.ask(text, default=False)
        if not confirmed:
            raise typer.Exit(0)

    with loading_spinner():
        # The deletion is requested through the process ids of the documents
        # retrieved above.
        n_deleted = jc.delete_batches(
            process_id=[td.process_id for td in to_delete],
        )

    print_success_msg(f"Operation completed. {n_deleted} batch processes deleted")
@app_batch.command(hidden=True)
def fix_batch_doc_jobs_dict():
    # Hidden command: it is the one suggested by the error message emitted in
    # _check_exception_dev_version. It delegates the work to the job controller
    # and prints the number returned by it.
    n_modified = get_job_controller().fix_batch_doc_jobs()
    out_console.print(f"{n_modified} batch documents modified")