# Source code for the jobflow_remote.remote.host.remote module.

from __future__ import annotations

import getpass
import io
import logging
import shlex
import traceback
import warnings
from pathlib import Path

import fabric
from fabric import Config
from fabric.auth import OpenSSHAuthStrategy
from paramiko.auth_strategy import AuthSource
from paramiko.ssh_exception import SSHException

from jobflow_remote.remote.host.base import BaseHost

# Module-wide logger, named after this module so that its level and handlers
# can be configured together with the rest of the package.
logger = logging.getLogger(name=__name__)


class RemoteHost(BaseHost):
    """
    Execute commands on a remote host.

    For some commands assumes the remote can run unix.
    The connection is handled through ``fabric.Connection``.

    Parameters
    ----------
    host, user, port
        Address and credentials of the remote host, passed to
        ``fabric.Connection``.
    config
        Optional ``fabric.Config`` used for the connection.
    gateway
        Either a ``ConnectionData`` (converted to a nested
        ``fabric.Connection``) or any value passed as-is to
        ``fabric.Connection`` as ``gateway``.
    forward_agent, connect_timeout, connect_kwargs, inline_ssh_env
        Forwarded to every ``fabric.Connection`` created by this object.
    timeout_execute
        Default timeout used by :meth:`execute` when none is given.
    keepalive
        Value passed to ``set_keepalive`` on the transport of the connection
        and of its gateway connections. If falsy, keepalive is not configured.
    shell_cmd
        Shell used to wrap commands in :meth:`execute` (``<shell_cmd> -c ...``).
        If falsy, commands are sent as they are.
    login_shell
        If True, ``-l`` is added to ``shell_cmd``.
    retry_on_closed_connection
        If True, remote operations failing because the connection dropped are
        retried once after reconnecting.
    interactive_login
        If True, a longer default ``auth_timeout`` is used and, when no key file
        is configured, authentication goes through interactive prompts.
    """

    def __init__(
        self,
        host,
        user=None,
        port=None,
        config=None,
        gateway=None,
        forward_agent=None,
        connect_timeout=None,
        connect_kwargs=None,
        inline_ssh_env=None,
        timeout_execute=None,
        keepalive=60,
        shell_cmd="bash",
        login_shell=True,
        retry_on_closed_connection=True,
        interactive_login=False,
    ) -> None:
        self.host = host
        self.user = user
        self.port = port
        self.config = config
        self.gateway = gateway
        self.forward_agent = forward_agent
        self.connect_timeout = connect_timeout
        self.connect_kwargs = connect_kwargs
        self.inline_ssh_env = inline_ssh_env
        self.timeout_execute = timeout_execute
        self.keepalive = keepalive
        self.shell_cmd = shell_cmd
        self.login_shell = login_shell
        self.retry_on_closed_connection = retry_on_closed_connection
        self._interactive_login = interactive_login
        self._create_connection()

    def _create_connection(self) -> None:
        """(Re)build ``self._connection`` from the stored connection parameters."""
        if not self.interactive_login:
            self._connection = self._get_single_connection(
                host=self.host,
                user=self.user,
                port=self.port,
                config=self.config,
                gateway=self.gateway,
                connect_kwargs=self.connect_kwargs,
            )
            return

        # if auth_timeout is not explicitly set, use a larger value than
        # the default to avoid the timeout while user get access to the process
        connect_kwargs = dict(self.connect_kwargs) if self.connect_kwargs else {}
        connect_kwargs.setdefault("auth_timeout", 120)
        config = self.config
        # This first Connection is used to inspect the resolved connect_kwargs
        # (which may include a key_filename coming from the ssh config).
        self._connection = self._get_single_connection(
            host=self.host,
            user=self.user,
            port=self.port,
            config=config,
            gateway=self.gateway,
            connect_kwargs=connect_kwargs,
        )
        # if the authentication is ssh-key + OTP paramiko already
        # handles it. Don't use the alternative strategy.
        if not self._connection.connect_kwargs.get("key_filename"):
            if not config:
                config = Config()
            # NOTE: when a config was provided, this mutates self.config.
            config.authentication.strategy_class = InteractiveAuthStrategy
            self._connection = self._get_single_connection(
                host=self.host,
                user=self.user,
                port=self.port,
                config=config,
                gateway=self.gateway,
                connect_kwargs=connect_kwargs,
            )

    def _get_single_connection(
        self,
        host,
        user,
        port,
        config,
        gateway,
        connect_kwargs,
    ):
        """Helper method to generate a fabric Connection given standard parameters.

        A ``ConnectionData`` gateway is recursively converted to a
        ``fabric.Connection``; any other gateway value is passed through.
        """
        from jobflow_remote.config.base import ConnectionData

        if isinstance(gateway, ConnectionData):
            gateway = self._get_single_connection(
                host=gateway.host,
                user=gateway.user,
                port=gateway.port,
                config=None,
                gateway=gateway.gateway,
                connect_kwargs=gateway.get_connect_kwargs(),
            )

        return fabric.Connection(
            host=host,
            user=user,
            port=port,
            config=config,
            gateway=gateway,
            forward_agent=self.forward_agent,
            connect_timeout=self.connect_timeout,
            connect_kwargs=connect_kwargs,
            inline_ssh_env=self.inline_ssh_env,
        )

    def __eq__(self, other):
        if not isinstance(other, RemoteHost):
            return False
        return self.as_dict() == other.as_dict()

    @property
    def connection(self):
        return self._connection

    def _iter_connections(self):
        """Yield the main connection followed by its nested gateway connections.

        The walk stops at the first gateway that is not a ``fabric.Connection``
        (e.g. None or a proxy-command string), which has nothing to manage.
        """
        connection = self.connection
        while isinstance(connection, fabric.Connection):
            yield connection
            connection = connection.gateway

    def execute(
        self,
        command: str | list[str],
        workdir: str | Path | None = None,
        timeout: int | None = None,
    ):
        """Execute the given command on the host.

        Parameters
        ----------
        command: str or list of str
            Command to execute, as a str or list of str. A list is joined with
            spaces, without quoting its elements.
        workdir: str or None
            path where the command will be executed. Defaults to ".".
        timeout: int or None
            Timeout for the command. If falsy, ``timeout_execute`` is used.

        Returns
        -------
        stdout : str
            Standard output of the command
        stderr : str
            Standard error of the command
        exit_code : int
            Exit code of the command.
        """
        self._check_connected()
        if isinstance(command, (list, tuple)):
            command = " ".join(command)

        workdir = Path(workdir or ".")
        timeout = timeout or self.timeout_execute

        if self.shell_cmd:
            shell_cmd = self.shell_cmd
            if self.login_shell:
                shell_cmd += " -l "
            shell_cmd += " -c "
            remote_command = shell_cmd + shlex.quote(command)
        else:
            remote_command = command

        def run_in_workdir():
            # self.connection is resolved at call time: if _execute_remote_func
            # reconnects and retries, the retry runs on the new connection and
            # still in workdir.
            with self.connection.cd(workdir):
                return self.connection.run(
                    remote_command,
                    hide=True,
                    warn=True,
                    timeout=timeout,
                )

        out = self._execute_remote_func(run_in_workdir)
        return out.stdout, out.stderr, out.exited

    def mkdir(
        self, directory: str | Path, recursive: bool = True, exist_ok: bool = True
    ) -> bool:
        """Create directory on the host.

        Returns True if the ``mkdir`` command succeeded, False otherwise (the
        error is logged). ``exist_ok`` is currently not used: with
        ``recursive=True`` ``mkdir -p`` accepts existing directories, with
        ``recursive=False`` an existing directory makes the command fail.
        """
        directory = Path(directory)
        # shlex.quote (rather than repr) so that paths containing quotes, "$"
        # or backslashes reach the remote shell unchanged.
        flag = "-p " if recursive else ""
        command = f"mkdir {flag}{shlex.quote(str(directory))}"
        try:
            stdout, stderr, returncode = self.execute(command)
        except Exception:
            logger.warning(f"Error creating folder {directory}", exc_info=True)
            return False
        if returncode != 0:
            logger.warning(
                f"Error creating folder {directory}. stdout: {stdout}, stderr: {stderr}"
            )
            return False
        return True

    def write_text_file(self, filepath: str | Path, content: str) -> None:
        """Write content to a file on the host."""
        self._check_connected()
        # A fresh buffer for each attempt, so a retry never starts mid-stream.
        self._execute_remote_func(
            lambda: self.connection.put(io.StringIO(content), str(filepath))
        )

    def connect(self) -> None:
        """Open the connection and set the keepalive on every hop of the chain."""
        self.connection.open()
        if self.keepalive:
            for connection in self._iter_connections():
                if connection.transport is not None:
                    connection.transport.set_keepalive(self.keepalive)

    def close(self) -> bool:
        """Close the connection and all its gateway connections.

        Returns
        -------
        bool
            True if every connection was closed without errors.
        """
        all_closed = True
        for connection in self._iter_connections():
            try:
                connection.close()
            except Exception:
                all_closed = False
                logger.warning(
                    f"Error while closing the connection to {connection.host}",
                    exc_info=True,
                )
        return all_closed

    @property
    def is_connected(self) -> bool:
        return self.connection.is_connected

    def put(self, src, dst) -> None:
        self._check_connected()
        self._execute_remote_func(lambda: self.connection.put(src, dst))

    def get(self, src, dst) -> None:
        self._check_connected()
        self._execute_remote_func(lambda: self.connection.get(src, dst))

    def copy(self, src, dst) -> None:
        # execute() joins list elements without quoting, so quote the paths here.
        cmd = ["cp", shlex.quote(str(src)), shlex.quote(str(dst))]
        self.execute(cmd)

    def _execute_remote_func(self, remote_cmd, *args, **kwargs):
        """Call ``remote_cmd(*args, **kwargs)``, retrying once on a dropped connection.

        If ``retry_on_closed_connection`` is set and the call fails with an
        error that looks like a dropped connection ("Socket is closed",
        "Server connection dropped" or EOFError), all connections are closed,
        recreated and reopened, and ``remote_cmd`` is called one more time.
        The callers in this class pass closures that look up
        ``self.connection`` when called, so the retry uses the new connection.
        A callable bound to a specific connection object keeps using that
        object on retry.
        """
        if not self.retry_on_closed_connection:
            return remote_cmd(*args, **kwargs)

        try:
            return remote_cmd(*args, **kwargs)
        except OSError as e:
            error = e
            msg = getattr(e, "message", str(e))
            if "Socket is closed" not in msg:
                raise
        except SSHException as e:
            error = e
            msg = getattr(e, "message", str(e))
            if "Server connection dropped" not in msg:
                raise
        except EOFError as e:
            error = e

        # if the code gets here one of the errors that could be due to drop of the
        # connection occurred. Try to close and reopen the connection and retry
        # one more time.
        # The explicit (type, value, tb) form works on all Python 3 versions; the
        # single-argument form is only accepted from Python 3.10.
        formatted_error = "".join(
            traceback.format_exception(type(error), error, error.__traceback__)
        )
        logger.warning(
            f"Error while trying to execute a command on host {self.host}:\n"
            f"{formatted_error}"
            "Probably due to the connection dropping. "
            "Will reopen the connection and retry."
        )
        # Close the whole chain (gateways included), since _create_connection
        # builds new gateway connections and the old ones would otherwise leak.
        if not self.close():
            logger.warning(
                "Error while closing the connection during a retry. "
                "Proceeding with the retry."
            )
        self._create_connection()
        self.connect()
        return remote_cmd(*args, **kwargs)

    def listdir(self, path: str | Path):
        """List the content of a remote directory; [] if it does not exist."""
        self._check_connected()
        try:
            return self._execute_remote_func(
                lambda: self.connection.sftp().listdir(str(path))
            )
        except FileNotFoundError:
            return []

    def remove(self, path: str | Path) -> None:
        """Remove a remote file through SFTP."""
        self._check_connected()
        self._execute_remote_func(lambda: self.connection.sftp().remove(str(path)))

    def rmtree(self, path: str | Path, raise_on_error: bool = False) -> bool:
        """Recursively delete a directory tree on a remote host.

        It is intended to remove an entire directory tree, including all files
        and subdirectories, on this remote host. The path is shell-quoted, so
        no shell expansion (globs, "~", variables) is applied to it.

        Parameters
        ----------
        path : str or Path
            The path to the directory tree to be removed.
        raise_on_error : bool
            If set to `False` (default), a failure is reported with a warning
            and False is returned. Otherwise a RuntimeError is raised.

        Returns
        -------
        bool
            True if the directory tree was successfully removed, False otherwise.
        """
        stdout, stderr, exit_code = self.execute(f"rm -r {shlex.quote(str(path))}")
        if exit_code != 0:
            msg = f"Error while deleting folder {path}. stdout: {stdout}, stderr: {stderr}"
            if raise_on_error:
                raise RuntimeError(msg)
            warnings.warn(msg, stacklevel=2)
            return False
        return True

    def _check_connected(self) -> bool:
        """
        Helper method to determine if fabric consider the connection open and
        open it otherwise.

        Since many operations requiring connections happen in the runner,
        if the connection drops there are cases where the host may not be
        reconnected. To avoid this issue always try to reconnect automatically
        if the connection is not open.

        Returns
        -------
        True if the connection is open.
        """
        if not self.is_connected:
            # Note: raising here instead of reconnecting demonstrated to be a
            # problem for how the queue managers are handled in the Runner.
            self.connect()
        return True

    @property
    def interactive_login(self) -> bool:
        return self._interactive_login
def inter_handler(title, instructions, prompt_list):
    """
    Answer the keyboard-interactive prompts sent by the server.

    Used as the handler of the Interactive AuthSource.

    Parameters
    ----------
    title : str
        Title of the request, printed (stripped) if non-empty.
    instructions : str
        Instructions for the user, printed (stripped) if non-empty.
    prompt_list : sequence
        Items whose first element is the prompt text and whose second element
        says whether the answer can be echoed: echoed prompts are read with
        ``input``, the others with ``getpass.getpass``.

    Returns
    -------
    tuple of str
        The answers, in the same order as the prompts.
    """
    for header in (title, instructions):
        if header:
            print(header.strip())

    answers = [
        input(entry[0]) if entry[1] else getpass.getpass(entry[0])
        for entry in prompt_list
    ]
    return tuple(answers)
class Interactive(AuthSource):
    """
    Keyboard-interactive AuthSource.

    Every request coming from the server during authentication is forwarded
    to the user through ``inter_handler``.
    """

    def __init__(self, username) -> None:
        super().__init__(username=username)

    def __repr__(self) -> str:
        return super()._repr(user=self.username)  # type: ignore[misc]

    def authenticate(self, transport):
        """Run keyboard-interactive authentication for ``self.username``.

        Returns whatever ``transport.auth_interactive`` returns.
        """
        return transport.auth_interactive(
            username=self.username, handler=inter_handler
        )
class InteractiveAuthStrategy(OpenSSHAuthStrategy):
    """
    AuthStrategy based on OpenSSHAuthStrategy that authenticates only through
    interactive prompts forwarded from the server to the user.

    Public keys are currently NOT tried (see ``get_sources``): this strategy
    is meant for prompt-based (e.g. password) authentication only.
    """

    def get_sources(self):
        """Yield the authentication sources to try: a single Interactive source."""
        # The public keys from OpenSSHAuthStrategy.get_pubkeys() are deliberately
        # not tried. With the current implementation an exception may be raised
        # if a key in ~/.ssh cannot be parsed, and other errors (e.g.
        # "Oops, unhandled type 3 ('unimplemented')") can leave the procedure
        # stuck. If keys are tried again in the future, failures should be
        # logged and authentication should fall back to the interactive prompts.
        yield Interactive(username=self.username)