Source code for jobflow_remote.utils.remote
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, ClassVar
from jobflow_remote import ConfigManager
if TYPE_CHECKING:
from pathlib import Path
from jobflow_remote.config.base import Project
from jobflow_remote.remote.host.base import BaseHost
logger = logging.getLogger(__name__)
[docs]
class SharedHosts:
    """
    Singleton context manager handing out shared, already connected hosts.

    Every call to ``SharedHosts(...)`` returns the same object. Hosts are
    cached in a class-level dictionary keyed by worker name, so nested or
    separate ``with`` blocks all reuse the same host objects. A class-level
    counter tracks how many contexts are currently open; the cached hosts
    are closed only when the counter gets back to zero.

    Examples
    --------
    >>> with SharedHosts(project) as shared_hosts:
    ...     host = shared_hosts.get_host("worker_name")
    ...     # Use host as required
    """

    # The unique instance, created lazily by __new__.
    _instance: SharedHosts = None
    # Number of currently open ``with`` contexts.
    _ref_count: int = 0
    # Hosts created so far, keyed by worker name.
    _hosts: ClassVar[dict[str, BaseHost]] = {}
    # Project used to resolve the workers; set by the first __init__ call.
    _project: Project | None = None

    def __new__(cls, *args, **kwargs):
        existing = cls._instance
        if existing is not None:
            return existing
        cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, project: Project | None = None):
        """
        Parameters
        ----------
        project
            The project configuration. Only taken into account if no
            project has been stored yet. If None in that case, the project
            is obtained from a ``ConfigManager`` via ``get_project(None)``.
        """
        # __init__ runs at every SharedHosts(...) call on the same object:
        # keep the project that was stored first.
        if self._project is not None:
            return
        if project is None:
            project = ConfigManager().get_project(None)
        self._project = project

    def get_host(self, worker: str) -> BaseHost:
        """
        Give access to the host associated with a worker.

        The host is created from the project's worker and connected the
        first time it is requested; later requests return the cached one.

        Parameters
        ----------
        worker
            The name of a worker defined in the project

        Returns
        -------
        BaseHost
            The shared host.

        Raises
        ------
        ValueError
            If the worker is not among the workers of the project.
        """
        if worker not in self._project.workers:
            raise ValueError(f"Worker {worker} not defined in {self._project.name}")

        if worker not in self._hosts:
            new_host = self._project.workers[worker].get_host()
            new_host.connect()
            # Cache the host only once the connection has been established.
            self._hosts[worker] = new_host
        return self._hosts[worker]

    def close_hosts(self) -> None:
        """Close the connection to all the connected hosts"""
        # Iterate over a snapshot of the keys, since entries are removed
        # from the dictionary while looping.
        for worker in tuple(self._hosts):
            # The host is dropped from the cache whether or not closing
            # it succeeds.
            host = self._hosts.pop(worker)
            try:
                host.close()
            except Exception:
                logger.exception(
                    f"Error while closing the connection to the {worker} worker"
                )

    def __enter__(self):
        # One more context is open.
        type(self)._ref_count += 1
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        owner = type(self)
        # One context less is open.
        owner._ref_count -= 1
        if owner._ref_count != 0:
            return
        # This was the last open context: release the connections.
        self.close_hosts()
[docs]
class UnsafeDeletionError(Exception):
    """
    Raised when the files of a Job are not deleted because the folder
    did not pass the safety check.
    """
[docs]
def safe_remove_job_files(
    host: BaseHost, run_dir: str | Path | None, raise_on_error: bool = False
) -> bool:
    """
    Remove the run folder of a Job on a host, but only if it looks like
    the folder of a jobflow-remote execution.

    The folder is removed only when its listing contains
    ``jfremote_in.json`` or ``jfremote_in.json.gz``.

    Parameters
    ----------
    host
        The host on which the folder is listed and removed.
    run_dir
        The folder to remove. If None or empty nothing is done.
    raise_on_error
        If True, raise when the safety check fails instead of logging
        a warning. The value is also forwarded to ``host.rmtree``.

    Returns
    -------
    bool
        False if ``run_dir`` is not set, if the folder listing is empty or
        if the safety check did not pass. Otherwise the value returned by
        ``host.rmtree``.

    Raises
    ------
    UnsafeDeletionError
        If ``raise_on_error`` is True and the folder contains files but
        none of the expected jobflow-remote input files.
    """
    if not run_dir:
        return False

    remote_files = host.listdir(run_dir)
    # An empty listing means there is nothing to delete.
    if not remote_files:
        return False

    # Safety measure to avoid mistakenly deleting other folders: the input
    # file written by jobflow-remote must be present.
    if any(fn in remote_files for fn in ("jfremote_in.json", "jfremote_in.json.gz")):
        return host.rmtree(path=run_dir, raise_on_error=raise_on_error)

    # Shared by the exception and the warning, so that the two messages
    # cannot drift apart.
    reason = (
        "since it may not contain a jobflow-remote execution. Some files are present, "
        "but jfremote_in.json is missing"
    )
    if raise_on_error:
        raise UnsafeDeletionError(f"Could not delete folder {run_dir} {reason}")

    logger.warning(f"Did not delete folder {run_dir} {reason}")
    return False