# ruff: noqa: PLW0602, PLW0603
from __future__ import annotations
import contextlib
import functools
import json
import logging
import uuid
from contextlib import contextmanager
from datetime import datetime, timedelta
from enum import Enum
from typing import TYPE_CHECKING, Callable, NoReturn
import typer
from click import ClickException
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.prompt import Confirm
from rich.text import Text, TextType
from rich.tree import Tree
from typer.core import TyperCommand, TyperGroup
from jobflow_remote import ConfigManager, JobController
from jobflow_remote.config.base import ProjectUndefinedError
from jobflow_remote.jobs.daemon import DaemonError, DaemonManager, DaemonStatus
if TYPE_CHECKING:
from cProfile import Profile
from jobflow_remote.jobs.state import JobState
logger = logging.getLogger(__name__)
err_console = Console(stderr=True)
out_console = Console()
fmt_datetime = "%Y-%m-%d %H:%M"
# shared instances of the config manager and job controller, to avoid parsing
# the files multiple times. Needs to be initialized with the
# initialize_config_manager function.
_shared_config_manager: ConfigManager | None = None
_shared_job_controller: JobController | None = None
_profiler: Profile | None = None
[docs]
def initialize_config_manager(*args, **kwargs) -> None:
global _shared_config_manager
_shared_config_manager = ConfigManager(*args, **kwargs)
[docs]
def get_config_manager() -> ConfigManager:
global _shared_config_manager
if not _shared_config_manager:
raise RuntimeError("The shared config manager needs to be initialized")
return _shared_config_manager
[docs]
def get_job_controller():
global _shared_job_controller
if _shared_job_controller is None:
cm = get_config_manager()
jc = JobController.from_project(cm.get_project())
_shared_job_controller = jc
return _shared_job_controller
[docs]
def cleanup_job_controller() -> None:
global _shared_job_controller
if _shared_job_controller is not None:
_shared_job_controller.close()
# set to None again, in case it needs to be used again in the same
# execution (e.g., in tests)
_shared_job_controller = None
[docs]
def start_profiling() -> None:
global _profiler
from cProfile import Profile
_profiler = Profile()
_profiler.enable()
[docs]
def complete_profiling() -> None:
global _profiler
_profiler.disable()
import pstats
stats = pstats.Stats(_profiler).sort_stats("cumtime")
stats.print_stats()
[docs]
class SortOption(str, Enum):
CREATED_ON = "created_on"
UPDATED_ON = "updated_on"
DB_ID = "db_id"
[docs]
class IndexDirection(str, Enum):
ASC = "asc"
DESC = "desc"
@property
def as_pymongo(self):
import pymongo
return {
IndexDirection.ASC: pymongo.ASCENDING,
IndexDirection.DESC: pymongo.DESCENDING,
}[self]
[docs]
class ReportInterval(str, Enum):
HOURS = "hours"
DAYS = "days"
WEEKS = "weeks"
MONTHS = "months"
YEARS = "years"
[docs]
class ReprStr(str):
r"""
Helper class that overrides the standard __repr__ to return the string itself
and not its repr().
Used mainly to allow printing of strings with newlines instead of '\n' when
repr is used in rich.
"""
__slots__ = ()
def __repr__(self) -> str:
return self
[docs]
def exit_with_error_msg(message: TextType, code: int = 1, **kwargs) -> NoReturn:
kwargs.setdefault("style", "red")
err_console.print(message, **kwargs)
raise typer.Exit(code)
[docs]
def exit_with_warning_msg(message: TextType, code: int = 0, **kwargs) -> NoReturn:
kwargs.setdefault("style", "gold1")
err_console.print(message, **kwargs)
raise typer.Exit(code)
[docs]
def print_success_msg(message: TextType = "operation completed", **kwargs) -> None:
kwargs.setdefault("style", "green")
out_console.print(message, **kwargs)
[docs]
def check_incompatible_opt(d: dict) -> None:
not_none = []
for k, v in d.items():
if v:
not_none.append(k)
if len(not_none) > 1:
options_list = ", ".join(not_none)
exit_with_error_msg(f"Options {options_list} are incompatible")
[docs]
def check_query_incompatibility(query, incompatible_options):
if query and any(opt is not None for opt in incompatible_options):
exit_with_error_msg(
"The --query option is incompatible with all the other filtering options"
)
[docs]
def check_at_least_one_opt(d: dict) -> None:
not_none = []
for k, v in d.items():
if v:
not_none.append(k)
if len(not_none) > 1:
options_list = ", ".join(d)
exit_with_error_msg(
f"At least one of the options {options_list} should be defined"
)
[docs]
def check_only_one_opt(d: dict) -> None:
not_none = []
for k, v in d.items():
if v:
not_none.append(k)
if len(not_none) != 1:
options_list = ", ".join(d)
exit_with_error_msg(
f"One and only one of the options {options_list} should be defined"
)
[docs]
@contextmanager
def loading_spinner(processing: bool = True):
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
transient=True,
) as progress:
if processing:
progress.add_task(description="Processing...", total=None)
yield progress
[docs]
@contextmanager
def hide_progress(progress: Progress):
"""
Hide the progress bar or spinning icon if an input is required from the user.
Adapted from a related github issue for rich:
https://github.com/Textualize/rich/issues/1535#issuecomment-1745297594
Parameters
----------
progress
The Progress object in use
"""
transient = progress.live.transient # save the old value
progress.live.transient = True
progress.stop()
progress.live.transient = transient # restore the old value
try:
yield
finally:
# make space for the progress to use so it doesn't overwrite any previous lines
progress.start()
[docs]
def get_job_db_ids(job_db_id: str, job_index: int | None):
if check_valid_uuid(job_db_id, raise_on_error=False):
db_id = None
job_id = job_db_id
else:
db_id = job_db_id
job_id = None
if job_index and db_id is not None:
out_console.print(
"The index is defined while a db_id is passed as an ID. Will be ignored",
style="yellow",
)
return db_id, job_id
[docs]
def get_job_ids_indexes(job_ids: list[str] | None) -> list[tuple[str, int]] | None:
if not job_ids:
return None
job_ids_indexes = []
for j in job_ids:
split = j.split(":")
if len(split) != 2 or not split[1].isnumeric():
raise typer.BadParameter(
"The job id should be in the format UUID:INDEX "
"(e.g. e1d66c4f-81db-4fff-bda2-2bf1d79d5961:2). "
f"Wrong format for {j}"
)
check_valid_uuid(split[0])
job_ids_indexes.append((split[0], int(split[1])))
return job_ids_indexes
[docs]
def cli_error_handler(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except (typer.Exit, typer.Abort, ClickException):
raise # Do not capture click or typer exceptions
except ProjectUndefinedError:
exit_with_error_msg(
"The active project could not be determined and it is required to execute this command"
)
except Exception as e:
from jobflow_remote import SETTINGS
if SETTINGS.cli_full_exc:
raise # Reraise exceptions to print the full stacktrace
exit_with_error_msg(
f"An Error occurred during the command execution: {type(e).__name__} {getattr(e, 'message', str(e))}"
)
return wrapper
[docs]
def check_valid_uuid(uuid_str, raise_on_error: bool = True) -> bool:
try:
uuid_obj = uuid.UUID(uuid_str)
if str(uuid_obj) == uuid_str:
return True
except ValueError:
pass
if raise_on_error:
raise typer.BadParameter(f"UUID {uuid_str} is in the wrong format.")
return False
[docs]
def str_to_dict(string: str | None) -> dict | None:
if not string:
return None
try:
dictionary = json.loads(string)
except json.JSONDecodeError as exc:
dictionary = {}
for chunk in string.split(","):
split = chunk.split("=")
if len(split) != 2:
raise typer.BadParameter(
f"Wrong format for dictionary-like field {string}"
) from exc
dictionary[split[0]] = split[1]
return dictionary
[docs]
def get_start_date(start_date: datetime | None, days: int | None, hours: int | None):
if start_date and (start_date.year, start_date.month, start_date.day) == (
1900,
1,
1,
):
now = datetime.now()
start_date = start_date.replace(year=now.year, month=now.month, day=now.day)
if start_date > now:
start_date = start_date - timedelta(days=1)
elif days:
start_date = datetime.now() - timedelta(days=days)
elif hours:
start_date = datetime.now() - timedelta(hours=hours)
return start_date
[docs]
def execute_multi_jobs_cmd(
single_cmd: Callable,
multi_cmd: Callable,
job_db_id: str | None = None,
job_index: int | None = None,
job_ids: list[str] | None = None,
db_ids: str | list[str] | None = None,
flow_ids: str | list[str] | None = None,
states: JobState | list[JobState] | None = None,
start_date: datetime | None = None,
end_date: datetime | None = None,
name: str | None = None,
metadata: dict | None = None,
days: int | None = None,
hours: int | None = None,
workers: list[str] | None = None,
custom_query: dict | None = None,
verbosity: int = 0,
raise_on_error: bool = False,
interactive: bool = True,
**kwargs,
) -> None:
"""
Utility function to execute a command on a single job or on a set of jobs.
Checks the query options, determine whether the command is single or multi
and call the corresponding function.
Parameters
----------
single_cmd
A function that takes the job_id and job_index as arguments and performs an
action on a single job.
multi_cmd
A function that takes a set of standard arguments and performs an
action on multiple jobs.
job_db_id
The job id or db_id of a job to be considered. If specified, all the other
query options should be disabled.
job_index
The index of the job in the DB.
job_ids
A list of job ids to be considered.
db_ids
A list of db ids to be considered.
flow_ids
A list of flow ids to be considered.
states
The states of the jobs to be considered.
start_date
The start date of the jobs to be considered.
end_date
The end date of the jobs to be considered.
name
The name of the jobs to be considered.
metadata
The metadata of the jobs to be considered.
days
Set the start_date based on the number of past days.
hours
Set the start_date based on the number of past hours.
workers
Workers associated with the jobs to be considered.
custom_query
A custom query to be used to filter the jobs.
verbosity
The verbosity of the output.
raise_on_error
If True, an error will be raised if the operation fails.
interactive
If True, a spinner will be shown even if the operation is not interactive while
the operation is being performed.
**kwargs : dict
Additional arguments to be passed to the single_cmd and multi_cmd functions.
"""
query_values = [
job_ids,
db_ids,
flow_ids,
states,
start_date,
end_date,
name,
metadata,
days,
hours,
workers,
custom_query,
]
# if potentially interactive do not start the spinner.
cm = get_config_manager()
spinner_cm: contextlib.AbstractContextManager
if interactive and cm.get_project().has_interactive_workers:
spinner_cm = contextlib.nullcontext()
out_console.print("Processing...")
else:
spinner_cm = loading_spinner()
try:
if job_db_id is not None:
if any(query_values):
msg = "If job_db_id is defined all the other query options should be disabled"
exit_with_error_msg(msg)
db_id, job_id = get_job_db_ids(job_db_id, job_index)
with spinner_cm:
modified_ids = single_cmd(
job_id=job_id, job_index=job_index, db_id=db_id, **kwargs
)
if not isinstance(modified_ids, (list, tuple)):
modified_ids = [] if modified_ids is None else [modified_ids]
if not modified_ids:
exit_with_error_msg("Could not perform the requested operation")
else:
check_incompatible_opt(
{"start_date": start_date, "days": days, "hours": hours}
)
check_incompatible_opt({"end_date": end_date, "days": days, "hours": hours})
check_query_incompatibility(
custom_query,
[
job_ids,
db_ids,
flow_ids,
states,
start_date,
end_date,
name,
metadata,
days,
hours,
workers,
],
)
job_ids_indexes = get_job_ids_indexes(job_ids)
start_date = get_start_date(start_date, days, hours)
if not any(
(
job_ids_indexes,
db_ids,
flow_ids,
states,
start_date,
end_date,
name,
metadata,
workers,
custom_query,
)
):
text = Text.from_markup(
"[yellow]No filter has been set. This will apply the change to all "
"the jobs in the DB. Proceed anyway?[/yellow]"
)
confirmed = Confirm.ask(text, default=False)
if not confirmed:
raise typer.Exit(0) # noqa: TRY301
with spinner_cm:
modified_ids = multi_cmd(
job_ids=job_ids_indexes,
db_ids=db_ids,
flow_ids=flow_ids,
states=states,
start_date=start_date,
end_date=end_date,
name=name,
metadata=metadata,
workers=workers,
custom_query=custom_query,
raise_on_error=raise_on_error,
**kwargs,
)
if verbosity:
print_success_msg(f"Operation completed. Modified jobs: {modified_ids}")
else:
print_success_msg(f"Operation completed: {len(modified_ids)} jobs modified")
except Exception:
logger.exception("Error executing the operation")
[docs]
def check_stopped_runner(error: bool = True) -> None:
cm = get_config_manager()
dm = DaemonManager.from_project(cm.get_project())
try:
with loading_spinner(processing=False) as progress:
progress.add_task(description="Checking the Daemon status...", total=None)
current_status = dm.check_status()
except DaemonError as e:
exit_with_error_msg(
f"Error while checking the status of the daemon: {getattr(e, 'message', str(e))}"
)
# the current status is only local. Also check that the DB does not contain
# a running_runner document that does not match the current machine.
runner_stopped = False
if current_status in (DaemonStatus.STOPPED, DaemonStatus.SHUT_DOWN):
matching_error = dm.check_matching_runner(allow_missing=True)
if matching_error:
logger.warning(matching_error)
else:
runner_stopped = True
if not runner_stopped:
if error:
exit_with_error_msg(
f"The status of the daemon is {current_status.value}. "
"The daemon should not be running while performing this operation"
)
else:
text = Text.from_markup(
"[red]The Runner is active. This operation may lead to "
"inconsistencies in this case. Proceed anyway?[/red]"
)
confirmed = Confirm.ask(text, default=False)
if not confirmed:
raise typer.Exit(0)
[docs]
def get_command_tree(
app: TyperGroup,
tree: Tree,
show_options: bool,
show_docs: bool,
show_hidden: bool,
max_depth: int | None,
current_depth: int = 0,
max_doc_lines: int | None = None,
) -> Tree:
"""
Recursively build a tree representation of the command structure.
Parameters
----------
app
The Typer app or command group to process.
tree
The Rich Tree object to add nodes to.
show_options
Whether to display command options.
show_docs
Whether to display command and option documentation.
show_hidden
Whether to display hidden commands.
max_depth
Maximum depth of the tree to display.
current_depth
Current depth in the tree (used for recursion).
max_doc_lines
If show_docs is True, the maximum number of lines showed from the
documentation of each command/option.
Returns
-------
Tree
The populated Rich Tree object.
"""
if max_depth is not None and current_depth >= max_depth:
return tree
for command_name, command in sorted(app.commands.items()):
if not show_hidden and getattr(command, "hidden", False):
continue
command_str = f"[bold green]{command_name}[/bold green]"
if (
show_docs
and isinstance(command, (TyperCommand, TyperGroup))
and command.help
):
command_str += f": {command.help}"
if max_doc_lines:
lines = command_str.splitlines()
command_str = "\n".join(lines[:max_doc_lines])
if len(lines) > max_doc_lines:
command_str += " ..."
branch = tree.add(command_str)
if isinstance(command, TyperGroup):
get_command_tree(
command,
branch,
show_options,
show_docs,
show_hidden,
max_depth,
current_depth + 1,
max_doc_lines,
)
elif show_options:
for param in command.params:
param_str = f"[blue]{param.name}[/blue]"
if param.default is not None:
default_value = param.default if param.default != "" else '""'
param_str += f" (default: {default_value})"
if param.required:
param_str += " [red](required)[/red]"
if show_docs and param.help:
param_str += f": {param.help}"
if max_doc_lines:
lines = param_str.splitlines()
param_str = "\n".join(lines[:max_doc_lines])
if len(lines) > max_doc_lines:
param_str += " ..."
branch.add(param_str)
return tree
[docs]
def tree_callback(
ctx: typer.Context,
value: bool,
):
if value:
main_app = ctx.command
tree_title = f"[bold red]{ctx.command.name}[/bold red]"
tree = Tree(tree_title)
command_tree = get_command_tree(
main_app,
tree,
show_options=False,
show_docs=True,
show_hidden=False,
max_depth=None,
max_doc_lines=1,
)
out_console.print(command_tree)
raise typer.Exit(0)
[docs]
def confirm_project_name(style: str | None = "red", exit_on_error: bool = True) -> bool:
"""
Ask the user to insert the name of the current project to confirm that a specific operation
should be performed.
Parameters
----------
style
The style to use for the message. If None, no style is used.
exit_on_error
If True, exit the program with an error message if the user does not insert the correct
project name. Otherwise, return False.
Returns
-------
bool
True if the user inserted the correct project name, False otherwise.
"""
cm = get_config_manager()
project = cm.get_project()
msg = f"Insert the name of the project ({project.name}) to confirm that you want to proceed "
if style:
msg = f"[{style}]{msg}[/]"
user_project_name = out_console.input(msg)
if user_project_name.strip() != project.name:
if exit_on_error:
exit_with_error_msg("The input does not match the current project name")
return False
return True