Source code for jobflow_remote.cli.utils

# ruff: noqa: PLW0602, PLW0603

from __future__ import annotations

import functools
import json
import logging
import uuid
from contextlib import contextmanager
from datetime import datetime, timedelta
from enum import Enum
from typing import TYPE_CHECKING, Callable, NoReturn

import typer
from click import ClickException
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.prompt import Confirm
from rich.text import Text

from jobflow_remote import ConfigManager, JobController
from jobflow_remote.config.base import ProjectUndefinedError
from jobflow_remote.jobs.daemon import DaemonError, DaemonManager, DaemonStatus

if TYPE_CHECKING:
    from cProfile import Profile

    from jobflow_remote.jobs.state import JobState

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Rich consoles used by the CLI helpers: err_console is bound to stderr,
# out_console uses the Console defaults.
err_console = Console(stderr=True)
out_console = Console()


# strftime-style format for datetimes (minute resolution).
fmt_datetime = "%Y-%m-%d %H:%M"


# shared instances of the config manager and job controller, to avoid parsing
# the files multiple times. Needs to be initialized with the
# initialize_config_manager function.
_shared_config_manager: ConfigManager | None = None
_shared_job_controller: JobController | None = None

# Profiler instance created by start_profiling and read by complete_profiling.
_profiler: Profile | None = None


def initialize_config_manager(*args, **kwargs) -> None:
    """
    Create the shared ConfigManager instance used by the CLI helpers.

    All positional and keyword arguments are forwarded unchanged to the
    ``ConfigManager`` constructor. A previously stored instance, if any,
    is replaced.
    """
    global _shared_config_manager
    manager = ConfigManager(*args, **kwargs)
    _shared_config_manager = manager
def get_config_manager() -> ConfigManager:
    """
    Return the shared ConfigManager instance.

    Raises
    ------
    RuntimeError
        If no shared instance has been set (or it is falsy), i.e. the
        initialization function has not been called yet.
    """
    manager = _shared_config_manager
    if manager:
        return manager
    raise RuntimeError("The shared config manager needs to be initialized")
def get_job_controller():
    """
    Return the shared JobController, creating it on first use.

    The controller is built from the project returned by the shared config
    manager and then stored, so later calls return the same object.
    """
    global _shared_job_controller
    if _shared_job_controller is not None:
        return _shared_job_controller

    project = get_config_manager().get_project()
    _shared_job_controller = JobController.from_project(project)
    return _shared_job_controller
def cleanup_job_controller() -> None:
    """
    Close the shared JobController, if one exists, and forget it.

    Does nothing when no shared controller is currently stored.
    """
    global _shared_job_controller
    if _shared_job_controller is None:
        return

    _shared_job_controller.close()
    # reset the shared reference so that a new controller can be created if
    # needed again within the same execution (e.g., in tests)
    _shared_job_controller = None
def start_profiling() -> None:
    """
    Create a new cProfile profiler, store it as the shared profiler and
    enable it.
    """
    global _profiler
    import cProfile

    _profiler = cProfile.Profile()
    _profiler.enable()
def complete_profiling() -> None:
    """
    Stop the shared profiler and print its statistics.

    The statistics are sorted by cumulative time ("cumtime") and printed
    through ``pstats``. If no profiler has been started the function is a
    no-op, instead of failing with an AttributeError on ``None``.
    """
    if _profiler is None:
        # profiling was never started: nothing to stop or report
        return

    _profiler.disable()
    import pstats

    stats = pstats.Stats(_profiler).sort_stats("cumtime")
    stats.print_stats()
class SortOption(str, Enum):
    """String-valued enumeration of the available sorting keys."""

    CREATED_ON = "created_on"
    UPDATED_ON = "updated_on"
    DB_ID = "db_id"
class SerializeFileFormat(str, Enum):
    """String-valued enumeration of the supported serialization formats."""

    JSON = "json"
    YAML = "yaml"
    TOML = "toml"
class IndexDirection(str, Enum):
    """String-valued enumeration of the direction of an index."""

    ASC = "asc"
    DESC = "desc"

    @property
    def as_pymongo(self):
        """The pymongo constant (ASCENDING or DESCENDING) for this direction."""
        # imported lazily, only when the conversion is requested
        import pymongo

        if self is IndexDirection.ASC:
            return pymongo.ASCENDING
        return pymongo.DESCENDING
class ReprStr(str):
    r"""
    A str subclass whose ``repr()`` is the string itself.

    The standard ``__repr__`` is replaced so that the raw content is returned
    instead of its quoted/escaped representation. Mainly useful to have rich
    display real newlines rather than '\n' when it relies on repr.
    """

    # no per-instance __dict__ is needed
    __slots__ = ()

    def __repr__(self) -> str:
        return self
def exit_with_error_msg(message: str, code: int = 1, **kwargs) -> NoReturn:
    """
    Print a message on the error console and exit.

    Parameters
    ----------
    message
        The text to print.
    code
        The exit code passed to ``typer.Exit``.
    **kwargs
        Forwarded to the console ``print``. The "style" defaults to "red"
        when not given.

    Raises
    ------
    typer.Exit
        Always, with the given code.
    """
    if "style" not in kwargs:
        kwargs["style"] = "red"
    err_console.print(message, **kwargs)
    raise typer.Exit(code)
def exit_with_warning_msg(message: str, code: int = 0, **kwargs) -> NoReturn:
    """
    Print a warning message on the error console and exit.

    Parameters
    ----------
    message
        The text to print.
    code
        The exit code passed to ``typer.Exit``.
    **kwargs
        Forwarded to the console ``print``. The "style" defaults to "gold1"
        when not given.

    Raises
    ------
    typer.Exit
        Always, with the given code.
    """
    if "style" not in kwargs:
        kwargs["style"] = "gold1"
    err_console.print(message, **kwargs)
    raise typer.Exit(code)
def check_incompatible_opt(d: dict) -> None:
    """
    Exit with an error if more than one of the given options is set.

    Parameters
    ----------
    d
        Mapping from option name to its value. An option counts as set
        when its value is truthy.
    """
    defined = [name for name, value in d.items() if value]
    if len(defined) <= 1:
        return

    options_list = ", ".join(defined)
    exit_with_error_msg(f"Options {options_list} are incompatible")
def check_query_incompatibility(query, incompatible_options):
    """
    Exit with an error if a query is given together with other filters.

    Parameters
    ----------
    query
        The value of the --query option. Nothing is checked if it is falsy.
    incompatible_options
        Iterable with the values of the other filtering options. Any value
        that is not None is considered as set.
    """
    if not query:
        return
    if all(opt is None for opt in incompatible_options):
        return

    exit_with_error_msg(
        "The --query option is incompatible with all the other filtering options"
    )
def check_at_least_one_opt(d: dict) -> None:
    """
    Exit with an error if none of the given options is set.

    Parameters
    ----------
    d
        Mapping from option name to its value. An option counts as set
        when its value is truthy.
    """
    # The error must be triggered when NO option is set. Having several
    # options set at the same time is valid for an "at least one" check.
    if any(d.values()):
        return

    options_list = ", ".join(d)
    exit_with_error_msg(
        f"At least one of the options {options_list} should be defined"
    )
def check_only_one_opt(d: dict) -> None:
    """
    Exit with an error unless exactly one of the given options is set.

    Parameters
    ----------
    d
        Mapping from option name to its value. An option counts as set
        when its value is truthy.
    """
    defined_count = sum(1 for value in d.values() if value)
    if defined_count == 1:
        return

    options_list = ", ".join(d)
    exit_with_error_msg(
        f"One and only one of the options {options_list} should be defined"
    )
@contextmanager
def loading_spinner(processing: bool = True):
    """
    Context manager showing a transient rich spinner.

    Parameters
    ----------
    processing
        If True a "Processing..." task with no total is added to the
        progress as soon as it is created.

    Yields
    ------
    Progress
        The rich Progress object, to which other tasks can be added.
    """
    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
    )
    with Progress(*columns, transient=True) as progress:
        if processing:
            progress.add_task(description="Processing...", total=None)
        yield progress
@contextmanager
def hide_progress(progress: Progress):
    """
    Hide the progress bar or spinning icon if an input is required from the user.

    The progress is stopped before entering the ``with`` body and started
    again when the body is left, also if an exception was raised inside it.

    Adapted from a related github issue for rich:
    https://github.com/Textualize/rich/issues/1535#issuecomment-1745297594

    Parameters
    ----------
    progress
        The Progress object in use
    """
    transient = progress.live.transient  # save the old value
    # stop() is called while the live display is flagged as transient
    # (presumably so that what was rendered is cleared - see the linked issue)
    progress.live.transient = True
    progress.stop()
    progress.live.transient = transient  # restore the old value
    try:
        yield
    finally:
        # resume the progress display once the caller is done
        progress.start()
def get_job_db_ids(job_db_id: str, job_index: int | None):
    """
    Determine whether an identifier is a job uuid or a db_id.

    Parameters
    ----------
    job_db_id
        The identifier. Treated as a job uuid if it is a valid uuid,
        as a db_id otherwise.
    job_index
        The job index. A warning is printed if it is set while the
        identifier is a db_id.

    Returns
    -------
    tuple
        ``(db_id, job_id)``, where one of the two is always None.
    """
    is_uuid = check_valid_uuid(job_db_id, raise_on_error=False)
    db_id, job_id = (None, job_db_id) if is_uuid else (job_db_id, None)

    if db_id is not None and job_index:
        out_console.print(
            "The index is defined while a db_id is passed as an ID. Will be ignored",
            style="yellow",
        )

    return db_id, job_id
def get_job_ids_indexes(job_ids: list[str] | None) -> list[tuple[str, int]] | None:
    """
    Parse a list of "UUID:INDEX" strings into (uuid, index) tuples.

    Parameters
    ----------
    job_ids
        The strings to be parsed, each in the format UUID:INDEX.

    Returns
    -------
    list or None
        The list of ``(uuid, index)`` tuples, with the index converted to int.
        None if ``job_ids`` is None or empty.

    Raises
    ------
    typer.BadParameter
        If an element is not in the UUID:INDEX format or the uuid is not valid.
    """
    if not job_ids:
        return None

    job_ids_indexes = []
    for j in job_ids:
        split = j.split(":")
        # isdecimal() rather than isnumeric(): the latter also accepts
        # characters such as superscripts or fractions ("²", "½") that int()
        # cannot convert, which would leak a ValueError instead of BadParameter
        if len(split) != 2 or not split[1].isdecimal():
            raise typer.BadParameter(
                "The job id should be in the format UUID:INDEX "
                "(e.g. e1d66c4f-81db-4fff-bda2-2bf1d79d5961:2). "
                f"Wrong format for {j}"
            )
        check_valid_uuid(split[0])
        job_ids_indexes.append((split[0], int(split[1])))

    return job_ids_indexes
def cli_error_handler(func):
    """
    Decorator converting unexpected exceptions into a CLI error message.

    typer/click exceptions are propagated untouched. A ProjectUndefinedError
    and any other Exception lead to an error message and an exit, unless the
    ``cli_full_exc`` setting is enabled, in which case generic exceptions
    are reraised.
    """

    @functools.wraps(func)
    def guarded(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except (typer.Exit, typer.Abort, ClickException):
            # Do not capture click or typer exceptions
            raise
        except ProjectUndefinedError:
            exit_with_error_msg(
                "The active project could not be determined and it is required to execute this command"
            )
        except Exception as exc:
            from jobflow_remote import SETTINGS

            if SETTINGS.cli_full_exc:
                # Reraise exceptions to print the full stacktrace
                raise
            detail = getattr(exc, "message", str(exc))
            exit_with_error_msg(
                f"An Error occurred during the command execution: {type(exc).__name__} {detail}"
            )
        else:
            return result

    return guarded
def check_valid_uuid(uuid_str, raise_on_error: bool = True) -> bool:
    """
    Check that a string is a uuid written in its canonical form.

    Parameters
    ----------
    uuid_str
        The string to check. It is valid only if parsing it as a UUID and
        converting it back to a string gives exactly the same string.
    raise_on_error
        If True an exception is raised for an invalid value, instead of
        returning False.

    Returns
    -------
    bool
        True if the string is a valid uuid, False otherwise.

    Raises
    ------
    typer.BadParameter
        If the value is not valid and ``raise_on_error`` is True.
    """
    try:
        parsed = uuid.UUID(uuid_str)
    except ValueError:
        is_canonical = False
    else:
        is_canonical = str(parsed) == uuid_str

    if is_canonical:
        return True
    if raise_on_error:
        raise typer.BadParameter(f"UUID {uuid_str} is in the wrong format.")
    return False
def str_to_dict(string: str | None) -> dict | None:
    """
    Convert a string to a dictionary.

    The string is first parsed as JSON. If it is not valid JSON it is
    interpreted as a comma separated list of ``key=value`` pairs, where both
    keys and values are kept as strings.

    Parameters
    ----------
    string
        The string to convert.

    Returns
    -------
    dict or None
        The dictionary, or None if the string is None or empty.

    Raises
    ------
    typer.BadParameter
        If the string is valid JSON but not a JSON object, or if it is
        neither JSON nor a list of ``key=value`` pairs.
    """
    if not string:
        return None

    try:
        dictionary = json.loads(string)
    except json.JSONDecodeError as exc:
        dictionary = {}
        for chunk in string.split(","):
            split = chunk.split("=")
            if len(split) != 2:
                raise typer.BadParameter(
                    f"Wrong format for dictionary-like field {string}"
                ) from exc
            dictionary[split[0]] = split[1]
        return dictionary

    # Valid JSON is not necessarily an object: values like "5", "[1, 2]" or
    # "true" would otherwise be returned as if they were dictionaries.
    if not isinstance(dictionary, dict):
        raise typer.BadParameter(f"Wrong format for dictionary-like field {string}")

    return dictionary
def get_start_date(start_date: datetime | None, days: int | None, hours: int | None):
    """
    Determine the start date from the date/days/hours options.

    Parameters
    ----------
    start_date
        An explicit date. If its date part is 1900-01-01 only the time is
        considered meaningful: the date is set to today, or to yesterday if
        that time of today is still in the future.
    days
        Number of days before now. Used if the above case does not apply.
    hours
        Number of hours before now. Used if none of the above applies.

    Returns
    -------
    datetime or None
        The resulting start date. ``start_date`` is returned unchanged if
        none of the cases applies.
    """
    if start_date and start_date.timetuple()[:3] == (1900, 1, 1):
        current = datetime.now()
        candidate = start_date.replace(
            year=current.year, month=current.month, day=current.day
        )
        if candidate > current:
            return candidate - timedelta(days=1)
        return candidate

    if days:
        return datetime.now() - timedelta(days=days)
    if hours:
        return datetime.now() - timedelta(hours=hours)
    return start_date
def execute_multi_jobs_cmd(
    single_cmd: Callable,
    multi_cmd: Callable,
    job_db_id: str | None = None,
    job_index: int | None = None,
    job_ids: list[str] | None = None,
    db_ids: str | list[str] | None = None,
    flow_ids: str | list[str] | None = None,
    states: JobState | list[JobState] | None = None,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    name: str | None = None,
    metadata: dict | None = None,
    days: int | None = None,
    hours: int | None = None,
    workers: list[str] | None = None,
    custom_query: dict | None = None,
    verbosity: int = 0,
    raise_on_error: bool = False,
    **kwargs,
) -> None:
    """
    Execute a command on a single job or on all the jobs matching the filters.

    If ``job_db_id`` is given ``single_cmd`` is called for that job and no
    other filter is allowed. Otherwise ``multi_cmd`` is called with the
    filtering options, after asking for a confirmation when no filter is set.

    Parameters
    ----------
    single_cmd
        Called as ``single_cmd(job_id=..., job_index=..., db_id=..., **kwargs)``.
        May return a single value, a list/tuple of values or None.
    multi_cmd
        Called with the filtering options, ``raise_on_error`` and ``kwargs``
        as keyword arguments. Its return value is used as the collection of
        the modified jobs.
    job_db_id
        The uuid or the db_id of a single job.
    job_index
        The index of the job, used together with ``job_db_id``.
    job_ids
        Job identifiers in the format UUID:INDEX.
    db_ids, flow_ids, states, end_date, name, metadata, workers, custom_query
        Filtering options passed to ``multi_cmd``.
    start_date, days, hours
        Alternative ways of defining the start date passed to ``multi_cmd``.
    verbosity
        If not zero the modified ids are printed, instead of just their number.
    raise_on_error
        Passed to ``multi_cmd``.
    **kwargs
        Additional arguments passed to ``single_cmd`` or ``multi_cmd``.
    """
    query_values = [
        job_ids,
        db_ids,
        flow_ids,
        states,
        start_date,
        end_date,
        name,
        metadata,
        days,
        hours,
        workers,
        custom_query,
    ]
    try:
        if job_db_id is not None:
            if any(query_values):
                msg = "If job_db_id is defined all the other query options should be disabled"
                exit_with_error_msg(msg)
            db_id, job_id = get_job_db_ids(job_db_id, job_index)
            with loading_spinner():
                modified_ids = single_cmd(
                    job_id=job_id, job_index=job_index, db_id=db_id, **kwargs
                )
            # normalize the output of the single command to a list
            if not isinstance(modified_ids, (list, tuple)):
                modified_ids = [] if modified_ids is None else [modified_ids]
            if not modified_ids:
                exit_with_error_msg("Could not perform the requested operation")
        else:
            check_incompatible_opt(
                {"start_date": start_date, "days": days, "hours": hours}
            )
            check_incompatible_opt({"end_date": end_date, "days": days, "hours": hours})
            check_query_incompatibility(
                custom_query,
                [
                    job_ids,
                    db_ids,
                    flow_ids,
                    states,
                    start_date,
                    end_date,
                    name,
                    metadata,
                    days,
                    hours,
                    workers,
                ],
            )

            job_ids_indexes = get_job_ids_indexes(job_ids)
            start_date = get_start_date(start_date, days, hours)

            if not any(
                (
                    job_ids_indexes,
                    db_ids,
                    flow_ids,
                    states,
                    start_date,
                    end_date,
                    name,
                    metadata,
                    workers,
                    custom_query,
                )
            ):
                text = Text.from_markup(
                    "[yellow]No filter has been set. This will apply the change to all "
                    "the jobs in the DB. Proceed anyway?[/yellow]"
                )
                confirmed = Confirm.ask(text, default=False)
                if not confirmed:
                    raise typer.Exit(0)  # noqa: TRY301

            with loading_spinner():
                modified_ids = multi_cmd(
                    job_ids=job_ids_indexes,
                    db_ids=db_ids,
                    flow_ids=flow_ids,
                    states=states,
                    start_date=start_date,
                    end_date=end_date,
                    name=name,
                    metadata=metadata,
                    workers=workers,
                    custom_query=custom_query,
                    raise_on_error=raise_on_error,
                    **kwargs,
                )

        if verbosity:
            print_success_msg(f"Operation completed. Modified jobs: {modified_ids}")
        else:
            print_success_msg(f"Operation completed: {len(modified_ids)} jobs modified")
    except (typer.Exit, typer.Abort, ClickException):
        # Exit requests and usage errors are control flow, not failures of the
        # operation. click's Exit and Abort derive from RuntimeError, so they
        # must be re-raised here or the handler below would swallow them
        # (losing the exit code) and log them as errors.
        raise
    except Exception:
        logger.exception("Error executing the operation")
def check_stopped_runner(error: bool = True) -> None:
    """
    Check that the daemon of the current project is not running.

    Parameters
    ----------
    error
        If True, exit with an error when the daemon status is neither STOPPED
        nor SHUT_DOWN. If False, ask the user for a confirmation instead and
        exit with code 0 if it is not given.
    """
    project = get_config_manager().get_project()
    daemon_manager = DaemonManager.from_project(project)

    try:
        with loading_spinner(processing=False) as progress:
            progress.add_task(description="Checking the Daemon status...", total=None)
            current_status = daemon_manager.check_status()
    except DaemonError as e:
        exit_with_error_msg(
            f"Error while checking the status of the daemon: {getattr(e, 'message', str(e))}"
        )

    if current_status in (DaemonStatus.STOPPED, DaemonStatus.SHUT_DOWN):
        return

    if error:
        exit_with_error_msg(
            f"The status of the daemon is {current_status.value}. "
            "The daemon should not be running while resetting the database"
        )
    else:
        text = Text.from_markup(
            "[red]The Runner is active. This operation may lead to "
            "inconsistencies in this case. Proceed anyway?[/red]"
        )
        if not Confirm.ask(text, default=False):
            raise typer.Exit(0)