# ruff: noqa: PLW0602, PLW0603
from __future__ import annotations
import functools
import json
import logging
import uuid
from contextlib import contextmanager
from datetime import datetime, timedelta
from enum import Enum
from typing import TYPE_CHECKING, Callable, NoReturn
import typer
from click import ClickException
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.prompt import Confirm
from rich.text import Text
from jobflow_remote import ConfigManager, JobController
from jobflow_remote.config.base import ProjectUndefinedError
from jobflow_remote.jobs.daemon import DaemonError, DaemonManager, DaemonStatus
if TYPE_CHECKING:
from cProfile import Profile
from jobflow_remote.jobs.state import JobState
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Rich consoles used by the helpers below: err_console writes to stderr and is
# used for error and warning messages, out_console is used for regular output.
err_console = Console(stderr=True)
out_console = Console()
# Format string for date-times, down to the minute (strftime-style directives).
fmt_datetime = "%Y-%m-%d %H:%M"
# Shared instances of the config manager and job controller, to avoid parsing
# the files multiple times. The config manager needs to be initialized with the
# initialize_config_manager function; the job controller is created on first
# use by get_job_controller.
_shared_config_manager: ConfigManager | None = None
_shared_job_controller: JobController | None = None
# Profiler created by start_profiling. It stays None until profiling is started.
_profiler: Profile | None = None
[docs]
def initialize_config_manager(*args, **kwargs) -> None:
    """
    Create the shared ConfigManager instance and store it at module level.

    All positional and keyword arguments are forwarded unchanged to the
    ConfigManager constructor. Calling this function again replaces the
    previously stored instance.
    """
    global _shared_config_manager
    manager = ConfigManager(*args, **kwargs)
    _shared_config_manager = manager
[docs]
def get_config_manager() -> ConfigManager:
    """
    Return the shared ConfigManager instance.

    Raises
    ------
    RuntimeError
        If the shared instance has not been set (see initialize_config_manager).
    """
    manager = _shared_config_manager
    if manager:
        return manager
    raise RuntimeError("The shared config manager needs to be initialized")
[docs]
def get_job_controller():
    """
    Return the shared JobController, creating it on the first call.

    The controller is built from the project returned by the shared config
    manager and then cached at module level for subsequent calls.
    """
    global _shared_job_controller
    if _shared_job_controller is not None:
        return _shared_job_controller

    project = get_config_manager().get_project()
    _shared_job_controller = JobController.from_project(project)
    return _shared_job_controller
[docs]
def cleanup_job_controller() -> None:
    """
    Close the shared JobController, if one was created, and forget it.

    Does nothing when no shared controller exists.
    """
    global _shared_job_controller
    controller = _shared_job_controller
    if controller is None:
        return

    controller.close()
    # Reset the cached instance so that a new controller can be created if it
    # is needed again in the same execution (e.g., in tests).
    _shared_job_controller = None
[docs]
def start_profiling() -> None:
    """
    Create a new cProfile profiler, store it at module level and enable it.

    The profiler is stopped and its statistics printed by complete_profiling.
    """
    global _profiler
    # Imported lazily: cProfile is only needed when profiling is requested.
    import cProfile

    _profiler = cProfile.Profile()
    _profiler.enable()
[docs]
def complete_profiling() -> None:
    """
    Stop the profiler started by start_profiling and print its statistics.

    The statistics are sorted by cumulative time ("cumtime") and printed
    through pstats.

    Raises
    ------
    RuntimeError
        If no profiler has been started.
    """
    # Without this guard a missing profiler would surface as an opaque
    # AttributeError on None.
    if _profiler is None:
        raise RuntimeError(
            "The profiler has not been started. Call start_profiling first"
        )

    import pstats

    _profiler.disable()
    stats = pstats.Stats(_profiler).sort_stats("cumtime")
    stats.print_stats()
[docs]
class SortOption(str, Enum):
    """
    Closed set of keys that can be chosen for sorting.

    The class also derives from str, so each member compares equal to its
    string value.
    NOTE(review): where these values are consumed is not visible in this
    module; presumably they match field names of the stored jobs - confirm.
    """

    CREATED_ON = "created_on"
    UPDATED_ON = "updated_on"
    DB_ID = "db_id"
[docs]
class IndexDirection(str, Enum):
    """
    Sorting direction, either ascending or descending.

    The class also derives from str, so each member compares equal to its
    string value.
    """

    ASC = "asc"
    DESC = "desc"

    @property
    def as_pymongo(self):
        """Return the pymongo constant corresponding to this direction."""
        # Imported lazily so that pymongo is loaded only when it is needed.
        import pymongo

        if self is IndexDirection.ASC:
            return pymongo.ASCENDING
        return pymongo.DESCENDING
[docs]
class ReprStr(str):
    r"""
    String subclass whose __repr__ returns the string itself instead of its
    quoted and escaped repr().

    Used mainly to allow printing of strings with real newlines instead of
    '\n' when repr is used in rich.
    """

    # No per-instance attributes: instances carry nothing beyond the str data.
    __slots__ = ()

    def __repr__(self) -> str:
        # Return the raw text, so no quotes are added and nothing is escaped.
        return self
[docs]
def exit_with_error_msg(message: str, code: int = 1, **kwargs) -> NoReturn:
    """
    Print a message on the error console and terminate the command.

    Parameters
    ----------
    message
        The text to print.
    code
        The exit code passed to typer.Exit.
    **kwargs
        Forwarded to the console print call. The "style" defaults to "red"
        when it is not given.

    Raises
    ------
    typer.Exit
        Always, with the given code.
    """
    print_options = {"style": "red", **kwargs}
    err_console.print(message, **print_options)
    raise typer.Exit(code)
[docs]
def exit_with_warning_msg(message: str, code: int = 0, **kwargs) -> NoReturn:
    """
    Print a warning on the error console and terminate the command.

    Parameters
    ----------
    message
        The text to print.
    code
        The exit code passed to typer.Exit.
    **kwargs
        Forwarded to the console print call. The "style" defaults to "gold1"
        when it is not given.

    Raises
    ------
    typer.Exit
        Always, with the given code.
    """
    print_options = {"style": "gold1", **kwargs}
    err_console.print(message, **print_options)
    raise typer.Exit(code)
[docs]
def print_success_msg(message: str = "operation completed", **kwargs) -> None:
    """
    Print a success message on the output console.

    Parameters
    ----------
    message
        The text to print.
    **kwargs
        Forwarded to the console print call. The "style" defaults to "green"
        when it is not given.
    """
    print_options = {"style": "green", **kwargs}
    out_console.print(message, **print_options)
[docs]
def check_incompatible_opt(d: dict) -> None:
    """
    Exit with an error if more than one of the given options is set.

    Parameters
    ----------
    d
        Mapping from option name to its value. An option counts as set when
        its value is truthy.
    """
    selected = [option for option, value in d.items() if value]
    if len(selected) <= 1:
        return

    options_list = ", ".join(selected)
    exit_with_error_msg(f"Options {options_list} are incompatible")
[docs]
def check_query_incompatibility(query, incompatible_options):
    """
    Exit with an error if a query is given together with any other filter.

    Parameters
    ----------
    query
        The value of the --query option. Nothing is checked if it is falsy.
    incompatible_options
        Values of the other filtering options. Any value that is not None
        counts as set.
    """
    if not query:
        return

    all_unset = all(option is None for option in incompatible_options)
    if not all_unset:
        exit_with_error_msg(
            "The --query option is incompatible with all the other filtering options"
        )
[docs]
def check_at_least_one_opt(d: dict) -> None:
    """
    Exit with an error if none of the given options is set.

    Parameters
    ----------
    d
        Mapping from option name to its value. An option counts as set when
        its value is truthy.
    """
    # The error applies only when nothing is set: having several options set
    # at the same time is allowed by this check.
    if any(d.values()):
        return

    options_list = ", ".join(d)
    exit_with_error_msg(
        f"At least one of the options {options_list} should be defined"
    )
[docs]
def check_only_one_opt(d: dict) -> None:
    """
    Exit with an error unless exactly one of the given options is set.

    Parameters
    ----------
    d
        Mapping from option name to its value. An option counts as set when
        its value is truthy.
    """
    defined_count = sum(1 for value in d.values() if value)
    if defined_count == 1:
        return

    options_list = ", ".join(d)
    exit_with_error_msg(
        f"One and only one of the options {options_list} should be defined"
    )
[docs]
@contextmanager
def loading_spinner(processing: bool = True):
    """
    Context manager showing a transient spinner while its body executes.

    Parameters
    ----------
    processing
        If True a task with the description "Processing..." is added to the
        progress. Otherwise no task is added and the caller can add its own.

    Yields
    ------
    Progress
        The Progress object in use.
    """
    columns = (
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
    )
    with Progress(*columns, transient=True) as progress:
        if processing:
            progress.add_task(description="Processing...", total=None)
        yield progress
[docs]
@contextmanager
def hide_progress(progress: Progress):
    """
    Stop the progress bar or spinning icon while the body of the ``with``
    statement executes, e.g. if an input is required from the user, and
    start it again afterwards.

    Adapted from a related github issue for rich:
    https://github.com/Textualize/rich/issues/1535#issuecomment-1745297594

    Parameters
    ----------
    progress
        The Progress object in use
    """
    transient = progress.live.transient  # save the old value
    # Force the transient mode only for the stop() call.
    # NOTE(review): presumably this makes rich clear the stopped display from
    # the terminal - confirm against the linked issue.
    progress.live.transient = True
    progress.stop()
    progress.live.transient = transient  # restore the old value
    try:
        yield
    finally:
        # start the progress again, even if the wrapped code raised
        progress.start()
[docs]
def get_job_db_ids(job_db_id: str, job_index: int | None):
    """
    Determine whether an identifier is a Job uuid or a db_id.

    Parameters
    ----------
    job_db_id
        The identifier given by the user. It is taken as a Job uuid if it is
        a valid uuid string, and as a db_id otherwise.
    job_index
        The index of the Job. A warning is printed if it is set while the
        identifier is a db_id.

    Returns
    -------
    tuple
        The (db_id, job_id) pair. Exactly one of the two is None.
    """
    is_uuid = check_valid_uuid(job_db_id, raise_on_error=False)
    db_id, job_id = (None, job_db_id) if is_uuid else (job_db_id, None)

    if job_index and db_id is not None:
        out_console.print(
            "The index is defined while a db_id is passed as an ID. Will be ignored",
            style="yellow",
        )

    return db_id, job_id
[docs]
def get_job_ids_indexes(job_ids: list[str] | None) -> list[tuple[str, int]] | None:
    """
    Parse a list of "UUID:INDEX" strings into (uuid, index) tuples.

    Parameters
    ----------
    job_ids
        The strings to parse, each in the format UUID:INDEX.

    Returns
    -------
    list or None
        The list of (uuid, index) tuples, or None if job_ids is None or empty.

    Raises
    ------
    typer.BadParameter
        If a string is not in the UUID:INDEX format or the uuid is not valid.
    """
    if not job_ids:
        return None

    job_ids_indexes = []
    for job_id_index in job_ids:
        parts = job_id_index.split(":")
        # str.isdecimal is used instead of str.isnumeric because the latter
        # also accepts characters such as "²" or "½", that int() cannot
        # convert. Those would lead to an unhandled ValueError below.
        if len(parts) != 2 or not parts[1].isdecimal():
            raise typer.BadParameter(
                "The job id should be in the format UUID:INDEX "
                "(e.g. e1d66c4f-81db-4fff-bda2-2bf1d79d5961:2). "
                f"Wrong format for {job_id_index}"
            )
        job_id, index = parts
        check_valid_uuid(job_id)
        job_ids_indexes.append((job_id, int(index)))

    return job_ids_indexes
[docs]
def cli_error_handler(func):
    """
    Decorator converting unexpected exceptions into a CLI error message.

    Typer and click exceptions are propagated untouched. A
    ProjectUndefinedError and any other exception lead to an error message
    and an exit, unless SETTINGS.cli_full_exc is set, in which case the
    generic exceptions are reraised.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except (typer.Exit, typer.Abort, ClickException):
            # Click and typer exceptions are handled by the CLI framework
            raise
        except ProjectUndefinedError:
            exit_with_error_msg(
                "The active project could not be determined and it is required to execute this command"
            )
        except Exception as exc:
            from jobflow_remote import SETTINGS

            if SETTINGS.cli_full_exc:
                # Let the exception through to show the full stacktrace
                raise
            error_type = type(exc).__name__
            error_text = getattr(exc, "message", str(exc))
            exit_with_error_msg(
                f"An Error occurred during the command execution: {error_type} {error_text}"
            )
        else:
            return result

    return wrapper
[docs]
def check_valid_uuid(uuid_str, raise_on_error: bool = True) -> bool:
    """
    Check that a string is a uuid in its canonical string form.

    The string is valid only if it can be parsed as a uuid and converting
    the parsed uuid back to a string gives exactly the same text.

    Parameters
    ----------
    uuid_str
        The string to check.
    raise_on_error
        If True an invalid string raises an error instead of returning False.

    Returns
    -------
    bool
        True if the string is valid, False otherwise.

    Raises
    ------
    typer.BadParameter
        If the string is not valid and raise_on_error is True.
    """
    try:
        is_canonical = str(uuid.UUID(uuid_str)) == uuid_str
    except ValueError:
        is_canonical = False

    if is_canonical:
        return True
    if raise_on_error:
        raise typer.BadParameter(f"UUID {uuid_str} is in the wrong format.")
    return False
[docs]
def str_to_dict(string: str | None) -> dict | None:
    """
    Convert a string given on the command line to a dictionary.

    The string is first parsed as JSON. If that fails it is interpreted as a
    comma separated list of key=value pairs, where keys and values are kept
    as strings.

    Parameters
    ----------
    string
        The string to convert.

    Returns
    -------
    dict or None
        The parsed value, or None if the string is None or empty.

    Raises
    ------
    typer.BadParameter
        If the string is not valid JSON and one of its comma separated
        chunks does not contain exactly one "=".
    """
    if not string:
        return None

    try:
        return json.loads(string)
    except json.JSONDecodeError as exc:
        parsed = {}
        for item in string.split(","):
            pieces = item.split("=")
            if len(pieces) != 2:
                raise typer.BadParameter(
                    f"Wrong format for dictionary-like field {string}"
                ) from exc
            key, value = pieces
            parsed[key] = value
        return parsed
[docs]
def get_start_date(start_date: datetime | None, days: int | None, hours: int | None):
    """
    Determine the start date from an explicit date or a relative interval.

    A start_date whose date part is 1900-01-01 is taken as a time of day
    only: the date is replaced with today, or with yesterday if that time
    is still in the future. In all the other cases a truthy days, and after
    it a truthy hours, gives a date that far back from now. If none of these
    applies start_date is returned unchanged.

    Parameters
    ----------
    start_date
        An explicit start date, or a time of day on 1900-01-01.
    days
        Number of days to go back from now.
    hours
        Number of hours to go back from now.

    Returns
    -------
    datetime or None
        The resulting start date.
    """
    time_only = bool(start_date) and start_date.date() == datetime(1900, 1, 1).date()
    if time_only:
        now = datetime.now()
        today_at_time = start_date.replace(
            year=now.year, month=now.month, day=now.day
        )
        if today_at_time > now:
            # That time has not been reached yet today: refer to yesterday
            return today_at_time - timedelta(days=1)
        return today_at_time

    if days:
        return datetime.now() - timedelta(days=days)
    if hours:
        return datetime.now() - timedelta(hours=hours)
    return start_date
[docs]
def execute_multi_jobs_cmd(
    single_cmd: Callable,
    multi_cmd: Callable,
    job_db_id: str | None = None,
    job_index: int | None = None,
    job_ids: list[str] | None = None,
    db_ids: str | list[str] | None = None,
    flow_ids: str | list[str] | None = None,
    states: JobState | list[JobState] | None = None,
    start_date: datetime | None = None,
    end_date: datetime | None = None,
    name: str | None = None,
    metadata: dict | None = None,
    days: int | None = None,
    hours: int | None = None,
    workers: list[str] | None = None,
    custom_query: dict | None = None,
    verbosity: int = 0,
    raise_on_error: bool = False,
    **kwargs,
) -> None:
    """
    Execute a command on a single Job or on all the Jobs matching the filters.

    If job_db_id is given, single_cmd is called for that Job and all the
    filtering options must be unset. Otherwise the filtering options are
    validated and multi_cmd is called with them. If no filter is set at all
    the user is asked for a confirmation first. A message with the outcome
    is printed at the end.

    Parameters
    ----------
    single_cmd
        Called as single_cmd(job_id=..., job_index=..., db_id=..., **kwargs).
        It should return the modified id(s): a single value, a list/tuple
        or None.
    multi_cmd
        Called with the filtering options, raise_on_error and **kwargs as
        keyword arguments. Its result needs to support len().
    job_db_id
        Identifier of a single Job: a uuid or a db_id.
    job_index
        Index of the Job, passed to single_cmd.
    job_ids
        Job identifiers in the UUID:INDEX format.
    db_ids, flow_ids, states, start_date, end_date, name, metadata, workers
        Filtering options passed to multi_cmd.
    days, hours
        Alternatives to start_date, used to determine it relative to now.
    custom_query
        A custom query, incompatible with all the other filtering options.
    verbosity
        If truthy the modified ids are printed, otherwise only their number.
    raise_on_error
        Passed to multi_cmd.
    **kwargs
        Additional arguments passed to single_cmd or multi_cmd.

    Raises
    ------
    typer.Exit, typer.Abort, click.ClickException
        Raised by the validation helpers or by the confirmation prompt.
        These are propagated to the caller. Any other exception is logged
        and not propagated.
    """
    # All the filtering options, except for custom_query.
    filter_values = [
        job_ids,
        db_ids,
        flow_ids,
        states,
        start_date,
        end_date,
        name,
        metadata,
        days,
        hours,
        workers,
    ]
    try:
        if job_db_id is not None:
            if any([*filter_values, custom_query]):
                msg = "If job_db_id is defined all the other query options should be disabled"
                exit_with_error_msg(msg)
            db_id, job_id = get_job_db_ids(job_db_id, job_index)
            with loading_spinner():
                modified_ids = single_cmd(
                    job_id=job_id, job_index=job_index, db_id=db_id, **kwargs
                )
                # Normalize the result to a list, so that it can be handled
                # like the output of multi_cmd.
                if not isinstance(modified_ids, (list, tuple)):
                    modified_ids = [] if modified_ids is None else [modified_ids]
            if not modified_ids:
                exit_with_error_msg("Could not perform the requested operation")
        else:
            check_incompatible_opt(
                {"start_date": start_date, "days": days, "hours": hours}
            )
            check_incompatible_opt({"end_date": end_date, "days": days, "hours": hours})
            check_query_incompatibility(custom_query, filter_values)
            job_ids_indexes = get_job_ids_indexes(job_ids)
            start_date = get_start_date(start_date, days, hours)
            if not any(
                (
                    job_ids_indexes,
                    db_ids,
                    flow_ids,
                    states,
                    start_date,
                    end_date,
                    name,
                    metadata,
                    workers,
                    custom_query,
                )
            ):
                text = Text.from_markup(
                    "[yellow]No filter has been set. This will apply the change to all "
                    "the jobs in the DB. Proceed anyway?[/yellow]"
                )
                confirmed = Confirm.ask(text, default=False)
                if not confirmed:
                    raise typer.Exit(0)  # noqa: TRY301
            with loading_spinner():
                modified_ids = multi_cmd(
                    job_ids=job_ids_indexes,
                    db_ids=db_ids,
                    flow_ids=flow_ids,
                    states=states,
                    start_date=start_date,
                    end_date=end_date,
                    name=name,
                    metadata=metadata,
                    workers=workers,
                    custom_query=custom_query,
                    raise_on_error=raise_on_error,
                    **kwargs,
                )
        if verbosity:
            print_success_msg(f"Operation completed. Modified jobs: {modified_ids}")
        else:
            print_success_msg(f"Operation completed: {len(modified_ids)} jobs modified")
    except (typer.Exit, typer.Abort, ClickException):
        # typer.Exit and typer.Abort are subclasses of RuntimeError, and the
        # click usage errors of Exception: without this clause the generic
        # handler below would swallow the requested exit codes and log a
        # traceback for a deliberate exit.
        raise
    except Exception:
        logger.exception("Error executing the operation")
[docs]
def check_stopped_runner(error: bool = True) -> None:
    """
    Check that the daemon of the current project is not running.

    Nothing happens if the daemon status is STOPPED or SHUT_DOWN. For any
    other status the command exits with an error or asks the user for a
    confirmation, depending on the value of error.

    Parameters
    ----------
    error
        If True exit with an error message when the daemon is not stopped.
        If False ask the user whether to proceed, and exit with code 0 if
        the answer is no.
    """
    project = get_config_manager().get_project()
    daemon_manager = DaemonManager.from_project(project)
    try:
        with loading_spinner(processing=False) as progress:
            progress.add_task(description="Checking the Daemon status...", total=None)
            current_status = daemon_manager.check_status()
    except DaemonError as exc:
        exit_with_error_msg(
            f"Error while checking the status of the daemon: {getattr(exc, 'message', str(exc))}"
        )

    if current_status in (DaemonStatus.STOPPED, DaemonStatus.SHUT_DOWN):
        return

    if error:
        exit_with_error_msg(
            f"The status of the daemon is {current_status.value}. "
            "The daemon should not be running while resetting the database"
        )
    else:
        question = Text.from_markup(
            "[red]The Runner is active. This operation may lead to "
            "inconsistencies in this case. Proceed anyway?[/red]"
        )
        if not Confirm.ask(question, default=False):
            raise typer.Exit(0)