diff --git a/doc/conf.py b/doc/conf.py index 8bdc8236..172eb2a3 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -13,6 +13,7 @@ import os import sys from datetime import date +from typing import Any # sys.path.insert(0, os.path.abspath('.')) @@ -68,7 +69,7 @@ master_doc = "index" -def run_apidoc(_): +def run_apidoc(_: Any) -> None: # noqa: ANN401 from sphinx.ext.apidoc import main sys.path.append(os.path.join(os.path.dirname(__file__), "..")) @@ -89,7 +90,7 @@ def run_apidoc(_): ) -def setup(app): +def setup(app: Any) -> None: # noqa: ANN401 app.connect("builder-inited", run_apidoc) diff --git a/dpdispatcher/base_context.py b/dpdispatcher/base_context.py index 4f9096ea..db714ed4 100644 --- a/dpdispatcher/base_context.py +++ b/dpdispatcher/base_context.py @@ -1,10 +1,13 @@ from abc import ABCMeta, abstractmethod -from typing import Any, List, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Tuple from dargs import Argument from dpdispatcher.dlog import dlog +if TYPE_CHECKING: + from dpdispatcher.submission import Submission + class BaseContext(metaclass=ABCMeta): subclasses_dict = {} @@ -13,7 +16,7 @@ class BaseContext(metaclass=ABCMeta): # notes: this attribute can be inherited alias: Tuple[str, ...] = tuple() - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> "BaseContext": # noqa: ANN401 if cls is BaseContext: subcls = cls.subclasses_dict[kwargs["context_type"]] instance = subcls.__new__(subcls, *args, **kwargs) @@ -21,7 +24,7 @@ def __new__(cls, *args, **kwargs): instance = object.__new__(cls) return instance - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs: Any) -> None: # noqa: ANN401 super().__init_subclass__(**kwargs) alias = [cls.__name__, *cls.alias] for aa in alias: @@ -32,7 +35,7 @@ def __init_subclass__(cls, **kwargs): cls.options.add(cls.__name__) @classmethod - def load_from_dict(cls, context_dict): + def load_from_dict(cls, context_dict: Dict[str, Any]) -> "BaseContext": # noqa: ANN401 context_type = context_dict["context_type"] # print("debug778:context_type", cls.subclasses_dict, context_type) try: @@ -45,35 +48,41 @@ def load_from_dict(cls, context_dict): context = context_class.load_from_dict(context_dict) return context - def bind_submission(self, submission): + def bind_submission(self, submission: "Submission") -> None: self.submission = submission @abstractmethod - def upload(self, submission): + def upload(self, submission: "Submission") -> None: raise NotImplementedError("abstract method") @abstractmethod def download( - self, submission, check_exists=False, mark_failure=True, back_error=False - ): + self, + submission: "Submission", + check_exists: bool = False, + mark_failure: bool = True, + back_error: bool = False, + ) -> None: raise NotImplementedError("abstract method") @abstractmethod - def clean(self): + def clean(self) -> None: raise NotImplementedError("abstract method") @abstractmethod - def write_file(self, fname, write_str): + def write_file(self, fname: str, write_str: str) -> None: raise NotImplementedError("abstract method") @abstractmethod - def read_file(self, fname): + def read_file(self, fname: str) -> str: raise NotImplementedError("abstract method") - def check_finish(self, proc): + def check_finish(self, proc: Any) -> Any: # noqa: ANN401 raise NotImplementedError("abstract method") - def block_checkcall(self, cmd, asynchronously=False) -> Tuple[Any, Any, Any]: + def block_checkcall( + self, cmd: str, asynchronously: bool = False + ) -> Tuple[Any, Any, Any]: # noqa: ANN401 """Run command with arguments. Wait for command to complete. Parameters @@ -112,7 +121,7 @@ def block_checkcall(self, cmd, asynchronously=False) -> Tuple[Any, Any, Any]: return stdin, stdout, stderr @abstractmethod - def block_call(self, cmd) -> Tuple[int, Any, Any, Any]: + def block_call(self, cmd: str) -> Tuple[int, Any, Any, Any]: # noqa: ANN401 """Run command with arguments. Wait for command to complete. Parameters diff --git a/dpdispatcher/contexts/dp_cloud_server_context.py b/dpdispatcher/contexts/dp_cloud_server_context.py index 4d827a80..7ce624e0 100644 --- a/dpdispatcher/contexts/dp_cloud_server_context.py +++ b/dpdispatcher/contexts/dp_cloud_server_context.py @@ -3,7 +3,7 @@ import os import shutil import uuid -from typing import List +from typing import TYPE_CHECKING, Any, List, NoReturn, Optional import tqdm from dargs.dargs import Argument @@ -19,6 +19,9 @@ ALI_STS_ENDPOINT, ) +if TYPE_CHECKING: + from dpdispatcher.submission import Job, Submission + # from zip_file import zip_files DP_CLOUD_SERVER_HOME_DIR = os.path.join( @@ -31,12 +34,12 @@ class BohriumContext(BaseContext): def __init__( self, - local_root, - remote_root=None, - remote_profile={}, - *args, - **kwargs, - ): + local_root: str, + remote_root: Optional[str] = None, + remote_profile: dict[str, Any] = {}, # noqa: ANN401 + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 + ) -> None: self.init_local_root = local_root self.init_remote_root = remote_root self.temp_local_root = os.path.abspath(local_root) @@ -67,7 +70,7 @@ def __init__( self.api = Client(account, password) @classmethod - def load_from_dict(cls, context_dict): + def load_from_dict(cls, context_dict: dict[str, Any]) -> "BohriumContext": # noqa: ANN401 local_root = context_dict["local_root"] remote_root = context_dict.get("remote_root", None) remote_profile = context_dict.get("remote_profile", {}) @@ -79,7 +82,7 @@ def load_from_dict(cls, context_dict): ) return dp_cloud_server_context - def bind_submission(self, submission): + def bind_submission(self, submission: "Submission") -> None: self.submission = submission self.local_root = os.path.join(self.temp_local_root, submission.work_base) self.remote_root = "." @@ -92,7 +95,7 @@ def bind_submission(self, submission): # file_uuid = uuid.uuid1().hex # oss_task_dir = os.path.join() - def _gen_oss_path(self, job, zip_filename): + def _gen_oss_path(self, job: "Job", zip_filename: str) -> str: if hasattr(job, "upload_path") and job.upload_path: return job.upload_path else: @@ -105,7 +108,7 @@ def _gen_oss_path(self, job, zip_filename): setattr(job, "upload_path", path) return path - def upload_job(self, job, common_files=None): + def upload_job(self, job: "Job", common_files: Optional[list[str]] = None) -> None: MAX_RETRY = 3 if common_files is None: common_files = [] @@ -133,7 +136,7 @@ def upload_job(self, job, common_files=None): retry_count = 0 self._backup(self.local_root, upload_zip) - def upload(self, submission): + def upload(self, submission: "Submission") -> None: # oss_task_dir = os.path.join('%s/%s/%s.zip' % ('indicate', file_uuid, file_uuid)) # zip_filename = submission.submission_hash + '.zip' # oss_task_zip = 'indicate/' + submission.submission_hash + '/' + zip_filename @@ -162,8 +165,12 @@ def upload(self, submission): # api.upload(self.oss_task_dir, zip_task_file) def download( - self, submission, check_exists=False, mark_failure=True, back_error=False - ): + self, + submission: "Submission", + check_exists: bool = False, + mark_failure: bool = True, + back_error: bool = False, + ) -> bool: jobs = submission.belonging_jobs job_hashs = {} job_infos = {} @@ -210,7 +217,9 @@ def download( ) return True - def _check_if_job_has_already_downloaded(self, target, local_root): + def _check_if_job_has_already_downloaded( + self, target: str, local_root: str + ) -> bool: backup_file_location = os.path.join( local_root, "backup", os.path.split(target)[1] ) @@ -219,7 +228,7 @@ def _check_if_job_has_already_downloaded(self, target, local_root): else: return False - def _backup(self, local_root, target): + def _backup(self, local_root: str, target: str) -> None: try: # move to backup directory os.makedirs(os.path.join(local_root, "backup"), exist_ok=True) @@ -229,45 +238,45 @@ def _backup(self, local_root, target): except (OSError, shutil.Error) as e: dlog.exception("unable to backup file, " + str(e)) - def _clean_backup(self, local_root, keep_backup=True): + def _clean_backup(self, local_root: str, keep_backup: bool = True) -> None: if not keep_backup: dir_to_be_removed = os.path.join(local_root, "backup") if os.path.exists(dir_to_be_removed): shutil.rmtree(dir_to_be_removed) - def write_file(self, fname, write_str): + def write_file(self, fname: str, write_str: str) -> bool: result = self.write_home_file(fname, write_str) return result - def write_local_file(self, fname, write_str): + def write_local_file(self, fname: str, write_str: str) -> str: local_filename = os.path.join(self.local_root, fname) with open(local_filename, "w") as f: f.write(write_str) return local_filename - def read_file(self, fname): + def read_file(self, fname: str) -> str: result = self.read_home_file(fname) return result - def write_home_file(self, fname, write_str): + def write_home_file(self, fname: str, write_str: str) -> bool: # os.makedirs(self.remote_root, exist_ok = True) with open(os.path.join(DP_CLOUD_SERVER_HOME_DIR, fname), "w") as fp: fp.write(write_str) return True - def read_home_file(self, fname): + def read_home_file(self, fname: str) -> str: with open(os.path.join(DP_CLOUD_SERVER_HOME_DIR, fname)) as fp: ret = fp.read() return ret - def check_file_exists(self, fname): + def check_file_exists(self, fname: str) -> bool: result = self.check_home_file_exits(fname) return result - def check_home_file_exits(self, fname): + def check_home_file_exits(self, fname: str) -> bool: return os.path.isfile(os.path.join(DP_CLOUD_SERVER_HOME_DIR, fname)) - def clean(self): + def clean(self) -> bool: submission_file_name = f"{self.submission.submission_hash}.json" submission_json = os.path.join(DP_CLOUD_SERVER_HOME_DIR, submission_file_name) os.remove(submission_json) @@ -337,7 +346,7 @@ def machine_subfields(cls) -> List[Argument]: ) ] - def block_call(self, cmd): + def block_call(self, cmd: str) -> NoReturn: raise RuntimeError( "Unsupported method. You may use an unsupported combination of the machine and the context." ) diff --git a/dpdispatcher/contexts/hdfs_context.py b/dpdispatcher/contexts/hdfs_context.py index 4c0a3b72..df63eec4 100644 --- a/dpdispatcher/contexts/hdfs_context.py +++ b/dpdispatcher/contexts/hdfs_context.py @@ -2,21 +2,25 @@ import shutil import tarfile from glob import glob +from typing import TYPE_CHECKING, Any, Dict, List, NoReturn from dpdispatcher.base_context import BaseContext from dpdispatcher.dlog import dlog from dpdispatcher.utils.hdfs_cli import HDFS +if TYPE_CHECKING: + from dpdispatcher.submission import Submission + class HDFSContext(BaseContext): def __init__( self, - local_root, - remote_root, - remote_profile={}, - *args, - **kwargs, - ): + local_root: str, + remote_root: str, + remote_profile: Dict[str, Any] = {}, # noqa: ANN401 + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 + ) -> None: assert isinstance(local_root, str) self.init_local_root = local_root self.init_remote_root = remote_root @@ -25,7 +29,7 @@ def __init__( self.remote_profile = remote_profile @classmethod - def load_from_dict(cls, context_dict): + def load_from_dict(cls, context_dict: Dict[str, Any]) -> "HDFSContext": # noqa: ANN401 local_root = context_dict["local_root"] remote_root = context_dict["remote_root"] remote_profile = context_dict.get("remote_profile", {}) @@ -36,10 +40,10 @@ def load_from_dict(cls, context_dict): ) return instance - def get_job_root(self): + def get_job_root(self) -> str: return self.remote_root - def bind_submission(self, submission): + def bind_submission(self, submission: "Submission") -> None: self.submission = submission self.local_root = os.path.join(self.temp_local_root, submission.work_base) self.remote_root = os.path.join( @@ -48,7 +52,7 @@ def bind_submission(self, submission): HDFS.mkdir(self.remote_root) - def _put_files(self, files, dereference=True): + def _put_files(self, files: List[str], dereference: bool = True) -> None: of = self.submission.submission_hash + "_upload.tgz" # local tar if os.path.isfile(os.path.join(self.local_root, of)): @@ -67,7 +71,7 @@ def _put_files(self, files, dereference=True): # clean up os.remove(from_f) - def upload(self, submission, dereference=True): + def upload(self, submission: "Submission", dereference: bool = True) -> None: """Upload forward files and forward command files to HDFS root dir. Parameters @@ -111,8 +115,12 @@ def upload(self, submission, dereference=True): self._put_files(file_list, dereference=dereference) def download( - self, submission, check_exists=False, mark_failure=True, back_error=False - ): + self, + submission: "Submission", + check_exists: bool = False, + mark_failure: bool = True, + back_error: bool = False, + ) -> None: """Download backward files from HDFS root dir. Parameters @@ -218,7 +226,7 @@ def download( # remove tmp dir shutil.rmtree(gz_dir, ignore_errors=True) - def check_file_exists(self, fname): + def check_file_exists(self, fname: str) -> bool: """Check whether the given file exists, often used in checking whether the belonging job has finished. Parameters @@ -232,20 +240,20 @@ def check_file_exists(self, fname): """ return HDFS.exists(os.path.join(self.remote_root, fname)) - def clean(self): + def clean(self) -> None: HDFS.remove(self.remote_root) - def write_file(self, fname, write_str): + def write_file(self, fname: str, write_str: str) -> str: local_file = os.path.join("/tmp/", fname) with open(local_file, "w") as fp: fp.write(write_str) HDFS.copy_from_local(local_file, os.path.join(self.remote_root, fname)) return local_file - def read_file(self, fname): + def read_file(self, fname: str) -> str: return HDFS.read_hdfs_file(os.path.join(self.remote_root, fname)) - def block_call(self, cmd): + def block_call(self, cmd: str) -> NoReturn: raise RuntimeError( "Unsupported method. You may use an unsupported combination of the machine and the context." ) diff --git a/dpdispatcher/contexts/lazy_local_context.py b/dpdispatcher/contexts/lazy_local_context.py index 5eaa9264..95756aad 100644 --- a/dpdispatcher/contexts/lazy_local_context.py +++ b/dpdispatcher/contexts/lazy_local_context.py @@ -1,17 +1,21 @@ import os import subprocess as sp +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from dpdispatcher.base_context import BaseContext +if TYPE_CHECKING: + from dpdispatcher.submission import Submission + class SPRetObj: - def __init__(self, ret): + def __init__(self, ret: bytes) -> None: self.data = ret - def read(self): + def read(self) -> bytes: return self.data - def readlines(self): + def readlines(self) -> List[str]: lines = self.data.decode("utf-8").splitlines() ret = [] for aa in lines: @@ -38,12 +42,12 @@ class LazyLocalContext(BaseContext): def __init__( self, - local_root, - remote_root=None, - remote_profile={}, - *args, - **kwargs, - ): + local_root: str, + remote_root: Optional[str] = None, + remote_profile: Dict[str, Any] = {}, # noqa: ANN401 + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 + ) -> None: assert isinstance(local_root, str) self.init_local_root = local_root self.init_remote_root = remote_root @@ -58,7 +62,7 @@ def __init__( # self.job_uuid = str(uuid.uuid4()) @classmethod - def load_from_dict(cls, context_dict): + def load_from_dict(cls, context_dict: Dict[str, Any]) -> "LazyLocalContext": # noqa: ANN401 local_root = context_dict["local_root"] remote_root = context_dict.get("remote_root", None) remote_profile = context_dict.get("remote_profile", {}) @@ -69,7 +73,7 @@ def load_from_dict(cls, context_dict): ) return instance - def bind_submission(self, submission): + def bind_submission(self, submission: "Submission") -> None: self.submission = submission self.local_root = os.path.join(self.temp_local_root, submission.work_base) self.remote_root = os.path.join(self.temp_local_root, submission.work_base) @@ -78,25 +82,25 @@ def bind_submission(self, submission): # "self.local_root:{self.local_root};" # "self.remote_root:{self.remote_root}") - def get_job_root(self): + def get_job_root(self) -> str: return self.local_root def upload( self, - submission, + submission: "Submission", # local_up_files, - dereference=True, - ): + dereference: bool = True, + ) -> None: pass def download( self, - submission, + submission: "Submission", # remote_down_files, - check_exists=False, - mark_failure=True, - back_error=False, - ): + check_exists: bool = False, + mark_failure: bool = True, + back_error: bool = False, + ) -> None: pass # for ii in job_dirs : @@ -112,7 +116,7 @@ def download( # else: # raise RuntimeError('do not find download file ' + fname) - def block_call(self, cmd): + def block_call(self, cmd: str) -> Tuple[int, None, SPRetObj, SPRetObj]: proc = sp.Popen( cmd, cwd=self.local_root, shell=True, stdout=sp.PIPE, stderr=sp.PIPE ) @@ -122,37 +126,39 @@ def block_call(self, cmd): code = proc.returncode return code, None, stdout, stderr - def clean(self): + def clean(self) -> None: pass - def write_file(self, fname, write_str): + def write_file(self, fname: str, write_str: str) -> None: os.makedirs(self.remote_root, exist_ok=True) with open(os.path.join(self.remote_root, fname), "w") as fp: fp.write(write_str) - def read_file(self, fname): + def read_file(self, fname: str) -> str: with open(os.path.join(self.remote_root, fname)) as fp: ret = fp.read() return ret - def check_file_exists(self, fname): + def check_file_exists(self, fname: str) -> bool: # submission_work_base = os.path.join(self.local_root, self.submission.work_base) # file_to_be_checked = os.path.join(submission_work_base, fname) # print('debug:dpdispatcher.LazyLocalContext().check_file_exists:file_to_be_checked', file_to_be_checked) # return os.path.isfile(file_to_be_checked) return os.path.isfile(os.path.join(self.remote_root, fname)) - def call(self, cmd): + def call(self, cmd: str) -> sp.Popen: # type: ignore[type-arg] cwd = os.getcwd() proc = sp.Popen( cmd, cwd=self.local_root, shell=True, stdout=sp.PIPE, stderr=sp.PIPE ) return proc - def check_finish(self, proc): + def check_finish(self, proc: sp.Popen) -> bool: # type: ignore[type-arg] return proc.poll() is not None - def get_return(self, proc): + def get_return( + self, proc: sp.Popen + ) -> Tuple[Optional[int], Optional[SPRetObj], Optional[SPRetObj]]: # type: ignore[type-arg] ret = proc.poll() if ret is None: return None, None, None diff --git a/dpdispatcher/contexts/local_context.py b/dpdispatcher/contexts/local_context.py index eda0a7d5..9bcdb50e 100644 --- a/dpdispatcher/contexts/local_context.py +++ b/dpdispatcher/contexts/local_context.py @@ -3,22 +3,25 @@ import subprocess as sp from glob import glob from subprocess import TimeoutExpired -from typing import List +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from dargs import Argument from dpdispatcher.base_context import BaseContext from dpdispatcher.dlog import dlog +if TYPE_CHECKING: + from dpdispatcher.submission import Submission + class SPRetObj: - def __init__(self, ret): + def __init__(self, ret: bytes) -> None: self.data = ret - def read(self): + def read(self) -> bytes: return self.data - def readlines(self): + def readlines(self) -> List[str]: lines = self.data.decode("utf-8").splitlines() ret = [] for aa in lines: @@ -26,7 +29,7 @@ def readlines(self): return ret -def _check_file_path(fname): +def _check_file_path(fname: str) -> None: dirname = os.path.dirname(fname) if dirname != "": os.makedirs(dirname, exist_ok=True) @@ -51,12 +54,12 @@ class LocalContext(BaseContext): def __init__( self, - local_root, - remote_root, - remote_profile={}, - *args, - **kwargs, - ): + local_root: str, + remote_root: str, + remote_profile: Dict[str, Any] = {}, # noqa: ANN401 + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 + ) -> None: assert isinstance(local_root, str) self.init_local_root = local_root self.init_remote_root = remote_root @@ -66,7 +69,7 @@ def __init__( self.symlink = remote_profile.get("symlink", True) @classmethod - def load_from_dict(cls, context_dict): + def load_from_dict(cls, context_dict: Dict[str, Any]) -> "LocalContext": # noqa: ANN401 local_root = context_dict["local_root"] remote_root = context_dict["remote_root"] remote_profile = context_dict.get("remote_profile", {}) @@ -77,17 +80,17 @@ def load_from_dict(cls, context_dict): ) return instance - def get_job_root(self): + def get_job_root(self) -> str: return self.remote_root - def bind_submission(self, submission): + def bind_submission(self, submission: "Submission") -> None: self.submission = submission self.local_root = os.path.join(self.temp_local_root, submission.work_base) self.remote_root = os.path.join( self.temp_remote_root, submission.submission_hash ) - def _copy_from_local_to_remote(self, local_path, remote_path): + def _copy_from_local_to_remote(self, local_path: str, remote_path: str) -> None: if not os.path.exists(local_path): raise FileNotFoundError( f"cannot find uploaded file {os.path.join(local_path)}" @@ -106,7 +109,7 @@ def _copy_from_local_to_remote(self, local_path, remote_path): else: raise ValueError(f"Unknown file type: {local_path}") - def upload(self, submission): + def upload(self, submission: "Submission") -> None: os.makedirs(self.remote_root, exist_ok=True) for ii in submission.belonging_tasks: local_job = os.path.join(self.local_root, ii.task_work_path) @@ -151,8 +154,12 @@ def upload(self, submission): ) def download( - self, submission, check_exists=False, mark_failure=True, back_error=False - ): + self, + submission: "Submission", + check_exists: bool = False, + mark_failure: bool = True, + back_error: bool = False, + ) -> None: for ii in submission.belonging_tasks: local_job = os.path.join(self.local_root, ii.task_work_path) remote_job = os.path.join(self.remote_root, ii.task_work_path) @@ -301,7 +308,7 @@ def download( # no nothing in the case of linked files pass - def block_call(self, cmd): + def block_call(self, cmd: str) -> Tuple[int, None, SPRetObj, SPRetObj]: proc = sp.Popen( cmd, cwd=self.remote_root, shell=True, stdout=sp.PIPE, stderr=sp.PIPE ) @@ -311,32 +318,34 @@ def block_call(self, cmd): code = proc.returncode return code, None, stdout, stderr - def clean(self): + def clean(self) -> None: shutil.rmtree(self.remote_root, ignore_errors=True) - def write_file(self, fname, write_str): + def write_file(self, fname: str, write_str: str) -> None: os.makedirs(self.remote_root, exist_ok=True) with open(os.path.join(self.remote_root, fname), "w") as fp: fp.write(write_str) - def read_file(self, fname): + def read_file(self, fname: str) -> str: with open(os.path.join(self.remote_root, fname)) as fp: ret = fp.read() return ret - def check_file_exists(self, fname): + def check_file_exists(self, fname: str) -> bool: return os.path.isfile(os.path.join(self.remote_root, fname)) - def call(self, cmd): + def call(self, cmd: str) -> sp.Popen: # type: ignore[type-arg] proc = sp.Popen( cmd, cwd=self.remote_root, shell=True, stdout=sp.PIPE, stderr=sp.PIPE ) return proc - def check_finish(self, proc): + def check_finish(self, proc: sp.Popen) -> bool: # type: ignore[type-arg] return proc.poll() is not None - def get_return(self, proc): + def get_return( + self, proc: sp.Popen + ) -> Tuple[Optional[int], Optional[SPRetObj], Optional[SPRetObj]]: # type: ignore[type-arg] ret = proc.poll() if ret is None: return None, None, None diff --git a/dpdispatcher/contexts/openapi_context.py b/dpdispatcher/contexts/openapi_context.py index 8847a8d8..d3c6938d 100644 --- a/dpdispatcher/contexts/openapi_context.py +++ b/dpdispatcher/contexts/openapi_context.py @@ -2,13 +2,15 @@ import os import shutil import uuid +from typing import TYPE_CHECKING, Any, NoReturn, Optional from zipfile import ZipFile import tqdm try: from bohrium import Bohrium - from bohrium.resources import Job, Tiefblue + from bohrium.resources import Job as BohriumJob + from bohrium.resources import Tiefblue except ModuleNotFoundError as e: found_bohriumsdk = False import_bohrium_error = e @@ -20,18 +22,22 @@ from dpdispatcher.dlog import dlog from dpdispatcher.utils.job_status import JobStatus +if TYPE_CHECKING: + from dpdispatcher.submission import Job as DPJob + from dpdispatcher.submission import Submission + DP_CLOUD_SERVER_HOME_DIR = os.path.join( os.path.expanduser("~"), ".dpdispatcher/", "dp_cloud_server/" ) -def unzip_file(zip_file, out_dir="./"): +def unzip_file(zip_file: str, out_dir: str = "./") -> None: obj = ZipFile(zip_file, "r") for item in obj.namelist(): obj.extract(item, out_dir) -def zip_file_list(root_path, zip_filename, file_list=[]): +def zip_file_list(root_path: str, zip_filename: str, file_list: list[str] = []) -> str: out_zip_file = os.path.join(root_path, zip_filename) # print('debug: file_list', file_list) zip_obj = ZipFile(out_zip_file, "w") @@ -58,12 +64,12 @@ def zip_file_list(root_path, zip_filename, file_list=[]): class OpenAPIContext(BaseContext): def __init__( self, - local_root, - remote_root=None, - remote_profile={}, - *args, - **kwargs, - ): + local_root: str, + remote_root: Optional[str] = None, + remote_profile: dict[str, Any] = {}, # noqa: ANN401 + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 + ) -> None: if not found_bohriumsdk: raise ModuleNotFoundError( "bohriumsdk not installed. Install dpdispatcher with `pip install dpdispatcher[bohrium]`" @@ -99,12 +105,12 @@ def __init__( access_key=access_key, project_id=project_id, app_key=app_key ) self.storage = Tiefblue() - self.job = Job(client=self.client) + self.job = BohriumJob(client=self.client) self.jgid = None os.makedirs(DP_CLOUD_SERVER_HOME_DIR, exist_ok=True) @classmethod - def load_from_dict(cls, context_dict): + def load_from_dict(cls, context_dict: dict[str, Any]) -> "OpenAPIContext": # noqa: ANN401 local_root = context_dict.get("local_root", "./") remote_root = context_dict.get("remote_root", None) remote_profile = context_dict.get("remote_profile", {}) @@ -116,7 +122,7 @@ def load_from_dict(cls, context_dict): ) return bohrium_context - def bind_submission(self, submission): + def bind_submission(self, submission: "Submission") -> None: self.submission = submission self.local_root = os.path.join(self.temp_local_root, submission.work_base) self.remote_root = "." @@ -125,7 +131,7 @@ def bind_submission(self, submission): self.machine = submission.machine - def _gen_object_key(self, job, zip_filename): + def _gen_object_key(self, job: "DPJob", zip_filename: str) -> str: if hasattr(job, "upload_path") and job.upload_path: return job.upload_path else: @@ -136,7 +142,9 @@ def _gen_object_key(self, job, zip_filename): setattr(job, "upload_path", path) return path - def upload_job(self, job, common_files=None): + def upload_job( + self, job: "DPJob", common_files: Optional[list[str]] = None + ) -> None: if common_files is None: common_files = [] self.machine.gen_local_script(job) @@ -177,7 +185,7 @@ def upload_job(self, job, common_files=None): # self._backup(self.local_root, upload_zip) - def upload(self, submission): + def upload(self, submission: "Submission") -> None: # oss_task_dir = os.path.join('%s/%s/%s.zip' % ('indicate', file_uuid, file_uuid)) # zip_filename = submission.submission_hash + '.zip' # oss_task_zip = 'indicate/' + submission.submission_hash + '/' + zip_filename @@ -207,8 +215,12 @@ def upload(self, submission): # api.upload(self.oss_task_dir, zip_task_file) def download( - self, submission, check_exists=False, mark_failure=True, back_error=False - ): + self, + submission: "Submission", + check_exists: bool = False, + mark_failure: bool = True, + back_error: bool = False, + ) -> bool: jobs = submission.belonging_jobs job_hashs = {} job_infos = {} @@ -255,45 +267,47 @@ def download( ) return True - def write_file(self, fname, write_str): + def write_file(self, fname: str, write_str: str) -> bool: result = self.write_home_file(fname, write_str) return result - def write_local_file(self, fname, write_str): + def write_local_file(self, fname: str, write_str: str) -> str: local_filename = os.path.join(self.local_root, fname) with open(local_filename, "w") as f: f.write(write_str) return local_filename - def read_file(self, fname): + def read_file(self, fname: str) -> str: result = self.read_home_file(fname) return result - def write_home_file(self, fname, write_str): + def write_home_file(self, fname: str, write_str: str) -> bool: # os.makedirs(self.remote_root, exist_ok = True) with open(os.path.join(DP_CLOUD_SERVER_HOME_DIR, fname), "w") as fp: fp.write(write_str) return True - def read_home_file(self, fname): + def read_home_file(self, fname: str) -> str: with open(os.path.join(DP_CLOUD_SERVER_HOME_DIR, fname)) as fp: ret = fp.read() return ret - def check_file_exists(self, fname): + def check_file_exists(self, fname: str) -> bool: result = self.check_home_file_exits(fname) return result - def check_home_file_exits(self, fname): + def check_home_file_exits(self, fname: str) -> bool: return os.path.isfile(os.path.join(DP_CLOUD_SERVER_HOME_DIR, fname)) - def clean(self): + def clean(self) -> bool: submission_file_name = f"{self.submission.submission_hash}.json" submission_json = os.path.join(DP_CLOUD_SERVER_HOME_DIR, submission_file_name) os.remove(submission_json) return True - def _check_if_job_has_already_downloaded(self, target, local_root): + def _check_if_job_has_already_downloaded( + self, target: str, local_root: str + ) -> bool: backup_file_location = os.path.join( local_root, "backup", os.path.split(target)[1] ) @@ -302,7 +316,7 @@ def _check_if_job_has_already_downloaded(self, target, local_root): else: return False - def _backup(self, local_root, target): + def _backup(self, local_root: str, target: str) -> None: try: # move to backup directory os.makedirs(os.path.join(local_root, "backup"), exist_ok=True) @@ -312,13 +326,13 @@ def _backup(self, local_root, target): except (OSError, shutil.Error) as e: dlog.exception("unable to backup file, " + str(e)) - def _clean_backup(self, local_root, keep_backup=True): + def _clean_backup(self, local_root: str, keep_backup: bool = True) -> None: if not keep_backup: dir_to_be_removed = os.path.join(local_root, "backup") if os.path.exists(dir_to_be_removed): shutil.rmtree(dir_to_be_removed) - def block_call(self, cmd): + def block_call(self, cmd: str) -> NoReturn: raise RuntimeError( "Unsupported method. You may use an unsupported combination of the machine and the context." ) diff --git a/dpdispatcher/contexts/ssh_context.py b/dpdispatcher/contexts/ssh_context.py index b3240966..69ccfccf 100644 --- a/dpdispatcher/contexts/ssh_context.py +++ b/dpdispatcher/contexts/ssh_context.py @@ -12,7 +12,7 @@ from functools import lru_cache from glob import glob from stat import S_ISDIR, S_ISREG -from typing import List +from typing import TYPE_CHECKING, Any, List, Optional import paramiko import paramiko.ssh_exception @@ -30,23 +30,26 @@ rsync, ) +if TYPE_CHECKING: + from dpdispatcher.submission import Submission + class SSHSession: def __init__( self, - hostname, - username, - password=None, - port=22, - key_filename=None, - passphrase=None, - timeout=10, - totp_secret=None, - tar_compress=True, - look_for_keys=True, - execute_command=None, - proxy_command=None, - ): + hostname: str, + username: str, + password: Optional[str] = None, + port: int = 22, + key_filename: Optional[str] = None, + passphrase: Optional[str] = None, + timeout: int = 10, + totp_secret: Optional[str] = None, + tar_compress: bool = True, + look_for_keys: bool = True, + execute_command: Optional[str] = None, + proxy_command: Optional[str] = None, + ) -> None: self.hostname = hostname self.username = username self.password = password @@ -87,7 +90,7 @@ def __init__( # count += 1 # time.sleep(sleep_time) - def ensure_alive(self, max_check=10, sleep_time=10): + def ensure_alive(self, max_check: int = 10, sleep_time: int = 10) -> None: count = 1 while not self._check_alive(): if count == max_check: @@ -99,7 +102,7 @@ def ensure_alive(self, max_check=10, sleep_time=10): count += 1 time.sleep(sleep_time) - def _check_alive(self): + def _check_alive(self) -> Optional[bool]: if self.ssh is None: return False try: @@ -129,7 +132,7 @@ def _check_alive(self): # transport.set_keepalive(60) @retry(max_retry=6, sleep=1) - def _setup_ssh(self): + def _setup_ssh(self) -> None: # machine = self.machine self.ssh = paramiko.SSHClient() self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy) @@ -246,7 +249,9 @@ def _setup_ssh(self): if self.execute_command is not None: self.exec_command(self.execute_command) - def inter_handler(self, title, instructions, prompt_list): + def inter_handler( + self, title: str, instructions: str, prompt_list: list[tuple[str, bool]] + ) -> list[str]: """inter_handler: the callback for paramiko.transport.auth_interactive. The prototype for this function is defined by Paramiko, so all of the @@ -287,18 +292,18 @@ def inter_handler(self, title, instructions, prompt_list): return resp - def get_ssh_client(self): + def get_ssh_client(self) -> paramiko.SSHClient: return self.ssh # def get_session_root(self): # return self.remote_root - def close(self): + def close(self) -> None: assert self.ssh is not None self.ssh.close() @retry(sleep=1) - def exec_command(self, cmd): + def exec_command(self, cmd: str) -> tuple[Any, Any, Any]: # noqa: ANN401 """Calling self.ssh.exec_command but has an exception check.""" assert self.ssh is not None try: @@ -315,7 +320,7 @@ def exec_command(self, cmd): raise RetrySignal(f"SSH session not active in calling {cmd}") from e @property - def sftp(self): + def sftp(self) -> paramiko.SFTPClient: """Returns sftp. Open a new one if not existing.""" if self._sftp is None: assert self.ssh is not None @@ -324,7 +329,7 @@ def sftp(self): return self._sftp @staticmethod - def arginfo(): + def arginfo() -> list[Argument]: doc_hostname = "hostname or ip of ssh connection." doc_username = "username of target linux system" doc_password = ( @@ -411,7 +416,7 @@ def arginfo(): ) return ssh_remote_profile_format - def put(self, from_f, to_f): + def put(self, from_f: str, to_f: str) -> Optional[paramiko.SFTPAttributes]: if self.rsync_available: # For rsync, we need to use %h:%p placeholders for target host/port proxy_cmd_rsync = None @@ -429,7 +434,7 @@ def put(self, from_f, to_f): ) return self.sftp.put(from_f, to_f) - def get(self, from_f, to_f): + def get(self, from_f: str, to_f: str) -> Optional[paramiko.SFTPAttributes]: if self.rsync_available: # For rsync, we need to use %h:%p placeholders for target host/port proxy_cmd_rsync = None @@ -467,13 +472,13 @@ def remote(self) -> str: class SSHContext(BaseContext): def __init__( self, - local_root, - remote_root, - remote_profile, - clean_asynchronously=False, - *args, - **kwargs, - ): + local_root: str, + remote_root: str, + remote_profile: dict[str, Any], # noqa: ANN401 + clean_asynchronously: bool = False, + *args: Any, # noqa: ANN401 + **kwargs: Any, # noqa: ANN401 + ) -> None: assert isinstance(local_root, str) self.init_local_root = local_root self.init_remote_root = remote_root @@ -501,7 +506,7 @@ def __init__( pass @classmethod - def load_from_dict(cls, context_dict): + def load_from_dict(cls, context_dict: dict[str, Any]) -> "SSHContext": # noqa: ANN401 # instance = cls() # input = dict( # hostname = jdata['hostname'], @@ -535,20 +540,20 @@ def load_from_dict(cls, context_dict): return ssh_context @property - def ssh(self): + def ssh(self) -> paramiko.SSHClient: return self.ssh_session.get_ssh_client() @property - def sftp(self): + def sftp(self) -> paramiko.SFTPClient: return self.ssh_session.sftp - def close(self): + def close(self) -> None: self.ssh_session.close() - def get_job_root(self): + def get_job_root(self) -> str: return self.remote_root - def bind_submission(self, submission): + def bind_submission(self, submission: "Submission") -> None: assert self.ssh_session is not None assert self.ssh_session.ssh is not None self.submission = submission @@ -597,7 +602,13 @@ def bind_submission(self, submission): # except Exception: # pass - def _walk_directory(self, files, work_path, file_list, directory_list): + def _walk_directory( + self, + files: list[str], + work_path: str, + file_list: list[str], + directory_list: list[str], + ) -> None: """Convert input path to list of files and directories.""" for jj in files: file_name = os.path.join(work_path, jj) @@ -628,10 +639,10 @@ def _walk_directory(self, files, work_path, file_list, directory_list): def upload( self, # job_dirs, - submission, + submission: "Submission", # local_up_files, - dereference=True, - ): + dereference: bool = True, + ) -> None: assert self.remote_root is not None dlog.info(f"remote path: {self.remote_root}") # remote_cwd = @@ -705,7 +716,13 @@ def upload( tar_compress=self.remote_profile.get("tar_compress", None), ) - def list_remote_dir(self, sftp, remote_dir, ref_remote_root, result_list): + def list_remote_dir( + self, + sftp: paramiko.SFTPClient, + remote_dir: str, + ref_remote_root: str, + result_list: list[str], + ) -> None: for entry in sftp.listdir_attr(remote_dir): remote_name = pathlib.PurePath( os.path.join(remote_dir, entry.filename) @@ -719,13 +736,13 @@ def list_remote_dir(self, sftp, remote_dir, ref_remote_root, result_list): def download( self, - submission, + submission: "Submission", # job_dirs, # remote_down_files, - check_exists=False, - mark_failure=True, - back_error=False, - ): + check_exists: bool = False, + mark_failure: bool = True, + back_error: bool = False, + ) -> None: assert self.remote_root is not None self.ssh_session.ensure_alive() file_list = [] @@ -797,7 +814,7 @@ def download( tar_compress=self.remote_profile.get("tar_compress", None), ) - def block_call(self, cmd): + def block_call(self, cmd: str) -> int: assert self.remote_root is not None self.ssh_session.ensure_alive() stdin, stdout, stderr = self.ssh_session.exec_command( @@ -806,11 +823,11 @@ def block_call(self, cmd): exit_status = stdout.channel.recv_exit_status() return exit_status, stdin, stdout, stderr - def clean(self): + def clean(self) -> None: self.ssh_session.ensure_alive() self._rmtree(self.remote_root) - def write_file(self, fname, write_str): + def write_file(self, fname: str, write_str: str) -> None: assert self.remote_root is not None self.ssh_session.ensure_alive() fname = pathlib.PurePath(os.path.join(self.remote_root, fname)).as_posix() @@ -827,7 +844,7 @@ def write_file(self, fname, write_str): dlog.exception(f"Error writing to file {fname}") raise e - def read_file(self, fname): + def read_file(self, fname: str) -> str: assert self.remote_root is not None self.ssh_session.ensure_alive() with self.sftp.open( @@ -837,7 +854,7 @@ def read_file(self, fname): ret = fp.read().decode("utf-8") return ret - def check_file_exists(self, fname): + def check_file_exists(self, fname: str) -> bool: assert self.remote_root is not None self.ssh_session.ensure_alive() try: @@ -849,24 +866,24 @@ def check_file_exists(self, fname): ret = False return ret - def call(self, cmd): + def call(self, cmd: str) -> dict[str, Any]: # noqa: ANN401 stdin, stdout, stderr = self.ssh_session.exec_command(cmd) # stdin, stdout, stderr = self.ssh.exec_command('echo $$; exec ' + cmd) # pid = stdout.readline().strip() # print(pid) return {"stdin": stdin, "stdout": stdout, "stderr": stderr} - def check_finish(self, proc): + def check_finish(self, proc: dict[str, Any]) -> bool: # noqa: ANN401 return proc["stdout"].channel.exit_status_ready() - def get_return(self, cmd_pipes): + def get_return(self, cmd_pipes: dict[str, Any]) -> tuple[Optional[int], Any, Any]: # noqa: ANN401 if not self.check_finish(cmd_pipes): return None, None, None else: retcode = cmd_pipes["stdout"].channel.recv_exit_status() return retcode, cmd_pipes["stdout"], cmd_pipes["stderr"] - def _rmtree(self, remotepath, verbose=False): + def _rmtree(self, remotepath: str, verbose: bool = False) -> None: """Remove the remote path.""" # The original implementation method removes files one by one using sftp. # If the latency of the remote server is high, it is very slow. @@ -884,11 +901,11 @@ def _rmtree(self, remotepath, verbose=False): def _put_files( self, - files, - dereference=True, - directories=None, - tar_compress=True, - ): + files: list[str], + dereference: bool = True, + directories: Optional[list[str]] = None, + tar_compress: bool = True, + ) -> None: """Upload files to server. Parameters @@ -957,7 +974,7 @@ def _put_files( os.remove(from_f) self.sftp.remove(to_f) - def _get_files(self, files, tar_compress=True): + def _get_files(self, files: list[str], tar_compress: bool = True) -> None: assert self.remote_root is not None # avoid compressing duplicated files files = list(set(files)) diff --git a/dpdispatcher/dpdisp.py b/dpdispatcher/dpdisp.py index 0a87af00..3f7858ce 100644 --- a/dpdispatcher/dpdisp.py +++ b/dpdispatcher/dpdisp.py @@ -97,7 +97,7 @@ def main_parser() -> argparse.ArgumentParser: return parser -def parse_args(args: Optional[List[str]] = None): +def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace: """Dpdispatcher commandline options argument parsing. Parameters @@ -115,7 +115,7 @@ def parse_args(args: Optional[List[str]] = None): return parsed_args -def main(): +def main() -> None: args = parse_args() if args.command == "submission": handle_submission( diff --git a/dpdispatcher/entrypoints/gui.py b/dpdispatcher/entrypoints/gui.py index 8b6b9e0a..07b19470 100644 --- a/dpdispatcher/entrypoints/gui.py +++ b/dpdispatcher/entrypoints/gui.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later """DP-GUI entrypoint.""" +from typing import Any -def start_dpgui(*, port: int, bind_all: bool, **kwargs): + +def start_dpgui(*, port: int, bind_all: bool, **kwargs: Any) -> None: # noqa: ANN401 """Host DP-GUI server. Parameters diff --git a/dpdispatcher/entrypoints/run.py b/dpdispatcher/entrypoints/run.py index a7bc00df..1013060f 100644 --- a/dpdispatcher/entrypoints/run.py +++ b/dpdispatcher/entrypoints/run.py @@ -3,7 +3,7 @@ from dpdispatcher.run import run_pep723 -def run(*, filename: str): +def run(*, filename: str) -> None: with open(filename) as f: script = f.read() run_pep723(script) diff --git a/dpdispatcher/entrypoints/submission.py b/dpdispatcher/entrypoints/submission.py index 7243dd16..c2eaca97 100644 --- a/dpdispatcher/entrypoints/submission.py +++ b/dpdispatcher/entrypoints/submission.py @@ -13,7 +13,7 @@ def handle_submission( download_finished_task: bool = False, clean: bool = False, reset_fail_count: bool = False, -): +) -> None: """Handle terminated submission. Parameters diff --git a/dpdispatcher/machine.py b/dpdispatcher/machine.py index 02beb7ea..227c7b23 100644 --- a/dpdispatcher/machine.py +++ b/dpdispatcher/machine.py @@ -2,13 +2,17 @@ import pathlib import shlex from abc import ABCMeta, abstractmethod -from typing import List, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import yaml from dargs import Argument, Variant from dpdispatcher.base_context import BaseContext from dpdispatcher.dlog import dlog +from dpdispatcher.utils.job_status import JobStatus + +if TYPE_CHECKING: + from dpdispatcher.submission import Job, Resources, Submission script_template = """\ {script_header} @@ -67,7 +71,7 @@ class Machine(metaclass=ABCMeta): # notes: this attribute can be inherited alias: Tuple[str, ...] = tuple() - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: Any, **kwargs: Any) -> "Machine": # noqa: ANN401 if cls is Machine: subcls = cls.subclasses_dict[kwargs["batch_type"]] instance = subcls.__new__(subcls, *args, **kwargs) @@ -77,15 +81,15 @@ def __new__(cls, *args, **kwargs): def __init__( self, - batch_type=None, - context_type=None, - local_root=None, - remote_root=None, - remote_profile={}, - retry_count=3, + batch_type: Optional[str] = None, + context_type: Optional[str] = None, + local_root: Optional[str] = None, + remote_root: Optional[str] = None, + remote_profile: Dict[str, Any] = {}, # noqa: ANN401 + retry_count: int = 3, *, - context=None, - ): + context: Optional[BaseContext] = None, + ) -> None: if context is None: assert isinstance(self, self.__class__.subclasses_dict[batch_type]) context = BaseContext( @@ -99,7 +103,7 @@ def __init__( self.bind_context(context=context) self.retry_count = retry_count - def bind_context(self, context): + def bind_context(self, context: BaseContext) -> None: self.context = context # def __init__ (self, @@ -111,7 +115,7 @@ def bind_context(self, context): # self.sub_script_name = '%s.sub' % self.context.job_uuid # self.job_id_name = '%s_job_id' % self.context.job_uuid - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs: Any) -> None: # noqa: ANN401 super().__init_subclass__(**kwargs) alias = [cls.__name__, *cls.alias] for aa in alias: @@ -121,21 +125,21 @@ def __init_subclass__(cls, **kwargs): # cls.subclasses.append(cls) @classmethod - def load_from_json(cls, json_path): + def load_from_json(cls, json_path: str) -> "Machine": with open(json_path) as f: machine_dict = json.load(f) machine = cls.load_from_dict(machine_dict=machine_dict) return machine @classmethod - def load_from_yaml(cls, yaml_path): + def load_from_yaml(cls, yaml_path: str) -> "Machine": with open(yaml_path) as f: machine_dict = yaml.safe_load(f) machine = cls.load_from_dict(machine_dict=machine_dict) return machine @classmethod - def load_from_dict(cls, machine_dict): + def load_from_dict(cls, machine_dict: Dict[str, Any]) -> "Machine": # noqa: ANN401 batch_type = machine_dict["batch_type"] try: machine_class = cls.subclasses_dict[batch_type] @@ -154,7 +158,7 @@ def load_from_dict(cls, machine_dict): machine = machine_class(context=context, retry_count=retry_count) return machine - def serialize(self, if_empty_remote_profile=False): + def serialize(self, if_empty_remote_profile: bool = False) -> Dict[str, Any]: # noqa: ANN401 machine_dict = {} machine_dict["batch_type"] = self.__class__.__name__ machine_dict["context_type"] = self.context.__class__.__name__ @@ -170,46 +174,46 @@ def serialize(self, if_empty_remote_profile=False): machine_dict = base.normalize_value(machine_dict, trim_pattern="_*") return machine_dict - def __eq__(self, other): + def __eq__(self, other: object) -> bool: return self.serialize() == other.serialize() @classmethod - def deserialize(cls, machine_dict): + def deserialize(cls, machine_dict: Dict[str, Any]) -> "Machine": # noqa: ANN401 machine = cls.load_from_dict(machine_dict=machine_dict) return machine @abstractmethod - def check_status(self, job): + def check_status(self, job: "Job") -> JobStatus: raise NotImplementedError( "abstract method check_status should be implemented by derived class" ) - def default_resources(self, res): + def default_resources(self, res: "Resources") -> "Resources": raise NotImplementedError( "abstract method default_resources should be implemented by derived class" ) - def sub_script_head(self, res): + def sub_script_head(self, res: "Resources") -> str: raise NotImplementedError( "abstract method sub_script_head should be implemented by derived class" ) - def sub_script_cmd(self, res): + def sub_script_cmd(self, res: "Resources") -> str: raise NotImplementedError( "abstract method sub_script_cmd should be implemented by derived class" ) @abstractmethod - def do_submit(self, job): + def do_submit(self, job: "Job") -> str: """Submit a single job, assuming that no job is running there.""" raise NotImplementedError( "abstract method do_submit should be implemented by derived class" ) - def gen_script_run_command(self, job): + def gen_script_run_command(self, job: "Job") -> str: return f"source $REMOTE_ROOT/{job.script_file_name}.run" - def gen_script(self, job): + def gen_script(self, job: "Job") -> str: script_header = self.gen_script_header(job) script_custom_flags = self.gen_script_custom_flags_lines(job) script_env = self.gen_script_env(job) @@ -224,25 +228,25 @@ def gen_script(self, job): ) return script - def check_if_recover(self, submission): + def check_if_recover(self, submission: "Submission") -> bool: submission_hash = submission.submission_hash submission_file_name = f"{submission_hash}.json" if_recover = self.context.check_file_exists(submission_file_name) return if_recover @abstractmethod - def check_finish_tag(self, job): + def check_finish_tag(self, job: "Job") -> bool: raise NotImplementedError( "abstract method check_finish_tag should be implemented by derived class" ) @abstractmethod - def gen_script_header(self, job): + def gen_script_header(self, job: "Job") -> str: raise NotImplementedError( "abstract method gen_script_header should be implemented by derived class" ) - def gen_script_custom_flags_lines(self, job): + def gen_script_custom_flags_lines(self, job: "Job") -> str: custom_flags_lines = "" custom_flags = job.resources.custom_flags for ii in custom_flags: @@ -250,7 +254,7 @@ def gen_script_custom_flags_lines(self, job): custom_flags_lines += line return custom_flags_lines - def gen_script_env(self, job): + def gen_script_env(self, job: "Job") -> str: source_files_part = "" module_unload_part = "" @@ -304,7 +308,7 @@ def gen_script_env(self, job): ) return script_env - def gen_script_command(self, job): + def gen_script_command(self, job: "Job") -> str: script_command = "" resources = job.resources # in_para_task_num = 0 @@ -339,7 +343,7 @@ def gen_script_command(self, job): script_command += self.gen_script_wait(resources=resources) return script_command - def gen_script_end(self, job): + def gen_script_end(self, job: "Job") -> str: job_tag_finished = job.job_hash + "_job_tag_finished" flag_if_job_task_fail = job.job_hash + "_flag_if_job_task_fail" @@ -353,7 +357,7 @@ def gen_script_end(self, job): ) return script_end - def gen_script_wait(self, resources): + def gen_script_wait(self, resources: "Resources") -> str: # if not resources.strategy.get('if_cuda_multi_devices', None): # return "wait \n" para_deg = resources.para_deg @@ -371,7 +375,7 @@ def gen_script_wait(self, resources): return "wait \n" return "" - def gen_command_env_cuda_devices(self, resources): + def gen_command_env_cuda_devices(self, resources: "Resources") -> str: # task_need_resources = task.task_need_resources # task_need_gpus = task_need_resources.get('task_need_gpus', 1) command_env = "" @@ -388,7 +392,7 @@ def gen_command_env_cuda_devices(self, resources): return command_env @classmethod - def arginfo(cls): + def arginfo(cls) -> Argument: # TODO: change the possible value of batch and context types after we refactor the code doc_batch_type = "The batch job system type. Option: " + ", ".join(cls.options) doc_context_type = ( @@ -473,7 +477,7 @@ def resources_subfields(cls) -> List[Argument]: ) ] - def kill(self, job): + def kill(self, job: "Job") -> None: """Kill the job. If not implemented, pass and let the user manually kill it. @@ -485,7 +489,7 @@ def kill(self, job): """ dlog.warning(f"Job {job.job_id} should be manually killed") - def get_exit_code(self, job): + def get_exit_code(self, job: "Job") -> int: """Get exit code of the job. Parameters diff --git a/dpdispatcher/machines/JH_UniScheduler.py b/dpdispatcher/machines/JH_UniScheduler.py index a9f5a421..0060d3b8 100644 --- a/dpdispatcher/machines/JH_UniScheduler.py +++ b/dpdispatcher/machines/JH_UniScheduler.py @@ -1,5 +1,5 @@ import shlex -from typing import List +from typing import TYPE_CHECKING, List from dargs import Argument @@ -12,6 +12,9 @@ retry, ) +if TYPE_CHECKING: + from dpdispatcher.submission import Job + JH_UniScheduler_script_header_template = """\ #!/bin/bash -l #JSUB -e %J.err @@ -25,11 +28,11 @@ class JH_UniScheduler(Machine): """JH_UniScheduler batch.""" - def gen_script(self, job): + def gen_script(self, job: "Job") -> str: JH_UniScheduler_script = super().gen_script(job) return JH_UniScheduler_script - def gen_script_header(self, job): + def gen_script_header(self, job: "Job") -> str: resources = job.resources script_header_dict = { "JH_UniScheduler_nodes_line": f"#JSUB -n {resources.number_node * resources.cpu_per_node}", @@ -59,7 +62,7 @@ def gen_script_header(self, job): return JH_UniScheduler_script_header @retry() - def do_submit(self, job): + def do_submit(self, job: "Job") -> str: script_file_name = job.script_file_name script_str = self.gen_script(job) job_id_name = job.job_hash + "_job_id" @@ -85,7 +88,7 @@ def do_submit(self, job): return job_id @retry() - def check_status(self, job): + def check_status(self, job: "Job") -> JobStatus: try: job_id = job.job_id except AttributeError: @@ -124,7 +127,7 @@ def check_status(self, job): else: return JobStatus.unknown - def check_finish_tag(self, job): + def check_finish_tag(self, job: "Job") -> bool: job_tag_finished = job.job_hash + "_job_tag_finished" return self.context.check_file_exists(job_tag_finished) @@ -157,7 +160,7 @@ def resources_subfields(cls) -> List[Argument]: ) ] - def kill(self, job): + def kill(self, job: "Job") -> None: """Kill the job. Parameters diff --git a/dpdispatcher/machines/distributed_shell.py b/dpdispatcher/machines/distributed_shell.py index f5e0962e..c79ed945 100644 --- a/dpdispatcher/machines/distributed_shell.py +++ b/dpdispatcher/machines/distributed_shell.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + from dpdispatcher.dlog import dlog from dpdispatcher.machine import Machine from dpdispatcher.utils.job_status import JobStatus @@ -6,6 +8,9 @@ run_cmd_with_all_output, ) +if TYPE_CHECKING: + from dpdispatcher.submission import Job + shell_script_header_template = """ #!/bin/bash -l set -x @@ -46,7 +51,7 @@ class DistributedShell(Machine): - def gen_script_env(self, job): + def gen_script_env(self, job: "Job") -> str: source_files_part = "" module_unload_part = "" @@ -93,7 +98,7 @@ def gen_script_env(self, job): ) return script_env - def gen_script_end(self, job): + def gen_script_end(self, job: "Job") -> str: all_task_dirs = "" for task in job.job_task_list: all_task_dirs += f"{task.task_work_path} " @@ -114,7 +119,7 @@ def gen_script_end(self, job): ) return script_end - def gen_script_header(self, job): + def gen_script_header(self, job: "Job") -> str: resources = job.resources if ( resources["strategy"].get("customized_script_header_template_file") @@ -128,7 +133,7 @@ def gen_script_header(self, job): shell_script_header = shell_script_header_template return shell_script_header - def do_submit(self, job): + def do_submit(self, job: "Job") -> int: """Submit th job to yarn using distributed shell. Parameters @@ -188,7 +193,7 @@ def do_submit(self, job): self.context.write_file(job_id_name, str(job_id)) return job_id - def check_status(self, job): + def check_status(self, job: "Job") -> JobStatus: job_id = job.job_id if job_id == "": return JobStatus.unsubmitted @@ -212,6 +217,6 @@ def check_status(self, job): else: return JobStatus.terminated - def check_finish_tag(self, job): + def check_finish_tag(self, job: "Job") -> bool: job_tag_finished = job.job_hash + "_job_tag_finished" return self.context.check_file_exists(job_tag_finished) diff --git a/dpdispatcher/machines/dp_cloud_server.py b/dpdispatcher/machines/dp_cloud_server.py index d919451d..84017da7 100644 --- a/dpdispatcher/machines/dp_cloud_server.py +++ b/dpdispatcher/machines/dp_cloud_server.py @@ -3,6 +3,7 @@ import time import uuid import warnings +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from dpdispatcher.dlog import dlog from dpdispatcher.machine import Machine @@ -11,6 +12,10 @@ from dpdispatcher.utils.job_status import JobStatus from dpdispatcher.utils.utils import customized_script_header_template +if TYPE_CHECKING: + from dpdispatcher.contexts.context import Context + from dpdispatcher.submission import Job, Submission + shell_script_header_template = """ #!/bin/bash -l """ @@ -19,7 +24,7 @@ class Bohrium(Machine): alias = ("Lebesgue", "DpCloudServer") - def __init__(self, context, **kwargs): + def __init__(self, context: "Context", **kwargs: Any) -> None: # noqa: ANN401 super().__init__(context=context, **kwargs) self.context = context self.input_data = context.remote_profile["input_data"].copy() @@ -67,11 +72,11 @@ def __init__(self, context, **kwargs): self.group_id = None - def gen_script(self, job): + def gen_script(self, job: "Job") -> str: shell_script = super(DpCloudServer, self).gen_script(job) return shell_script - def gen_script_header(self, job): + def gen_script_header(self, job: "Job") -> str: resources = job.resources if ( resources["strategy"].get("customized_script_header_template_file") @@ -85,7 +90,7 @@ def gen_script_header(self, job): shell_script_header = shell_script_header_template return shell_script_header - def gen_local_script(self, job): + def gen_local_script(self, job: "Job") -> str: script_str = self.gen_script(job) script_file_name = job.script_file_name self.context.write_local_file(fname=script_file_name, write_str=script_str) @@ -96,7 +101,7 @@ def gen_local_script(self, job): ) return script_file_name - def _gen_backward_files_list(self, job): + def _gen_backward_files_list(self, job: "Job") -> List[str]: result_file_list = [] # result_file_list.extend(job.backward_common_files) for task in job.job_task_list: @@ -106,7 +111,7 @@ def _gen_backward_files_list(self, job): result_file_list = list(set(result_file_list)) return result_file_list - def _gen_oss_path(self, job, zip_filename): + def _gen_oss_path(self, job: "Job", zip_filename: str) -> str: if hasattr(job, "upload_path") and job.upload_path: return job.upload_path else: @@ -122,7 +127,7 @@ def _gen_oss_path(self, job, zip_filename): setattr(job, "upload_path", path) return path - def do_submit(self, job): + def do_submit(self, job: "Job") -> str: self.gen_local_script(job) zip_filename = job.job_hash + ".zip" # oss_task_zip = 'indicate/' + job.job_hash + '/' + zip_filename @@ -157,7 +162,7 @@ def do_submit(self, job): job.job_state = JobStatus.waiting return job_id - def _get_job_detail(self, job_id, group_id): + def _get_job_detail(self, job_id: int, group_id: Optional[int]) -> Dict[str, Any]: check_return = self.api.get_job_detail(job_id) assert check_return is not None, ( f"Failed to retrieve tasks information. To resubmit this job, please " @@ -170,7 +175,7 @@ def _get_job_detail(self, job_id, group_id): ) return check_return - def check_status(self, job): + def check_status(self, job: "Job") -> JobStatus: if job.job_id == "": return JobStatus.unsubmitted job_id = job.job_id @@ -217,7 +222,7 @@ def check_status(self, job): print(job_log, end="") return job_state - def _download_job(self, job): + def _download_job(self, job: "Job") -> None: job_url = self.api.get_job_result_url(job.job_id) if not job_url: return @@ -239,19 +244,23 @@ def _download_job(self, job): except (OSError, shutil.Error) as e: dlog.exception("unable to backup file, " + str(e)) - def check_finish_tag(self, job): + def check_finish_tag(self, job: "Job") -> bool: job_tag_finished = job.job_hash + "_job_tag_finished" dlog.info("check if job finished: ", job.job_id, job_tag_finished) return self.context.check_file_exists(job_tag_finished) # return # pass - def check_if_recover(self, submission): + def check_if_recover(self, submission: "Submission") -> bool: return False # pass @staticmethod - def map_dp_job_state(status, exit_code, ignore_exit_code=True): + def map_dp_job_state( + status: Union[int, JobStatus], + exit_code: int, + ignore_exit_code: bool = True, + ) -> JobStatus: if isinstance(status, JobStatus): return status map_dict = { @@ -272,7 +281,7 @@ def map_dp_job_state(status, exit_code, ignore_exit_code=True): return JobStatus.finished return map_dict[status] - def kill(self, job): + def kill(self, job: "Job") -> None: """Kill the job. Parameters @@ -283,7 +292,7 @@ def kill(self, job): job_id = job.job_id self.api.kill(job_id) - def get_exit_code(self, job) -> int: + def get_exit_code(self, job: "Job") -> int: job_id = self._parse_job_id(job.job_id) if job_id <= 0: raise RuntimeError(f"cannot parse job id {job.job_id}") diff --git a/dpdispatcher/machines/fugaku.py b/dpdispatcher/machines/fugaku.py index d4a38a4b..71be5c18 100644 --- a/dpdispatcher/machines/fugaku.py +++ b/dpdispatcher/machines/fugaku.py @@ -1,10 +1,14 @@ import shlex +from typing import TYPE_CHECKING from dpdispatcher.dlog import dlog from dpdispatcher.machine import Machine from dpdispatcher.utils.job_status import JobStatus from dpdispatcher.utils.utils import customized_script_header_template +if TYPE_CHECKING: + from dpdispatcher.submission import Job + fugaku_script_header_template = """\ {queue_name_line} {fugaku_node_number_line} @@ -13,11 +17,11 @@ class Fugaku(Machine): - def gen_script(self, job): + def gen_script(self, job: "Job") -> str: fugaku_script = super().gen_script(job) return fugaku_script - def gen_script_header(self, job): + def gen_script_header(self, job: "Job") -> str: resources = job.resources fugaku_script_header_dict = {} fugaku_script_header_dict["fugaku_node_number_line"] = ( @@ -43,7 +47,7 @@ def gen_script_header(self, job): ) return fugaku_script_header - def do_submit(self, job): + def do_submit(self, job: "Job") -> str: script_file_name = job.script_file_name script_str = self.gen_script(job) job_id_name = job.job_hash + "_job_id" @@ -67,7 +71,7 @@ def do_submit(self, job): self.context.write_file(job_id_name, job_id) return job_id - def check_status(self, job): + def check_status(self, job: "Job") -> JobStatus: job_id = job.job_id if job_id == "": return JobStatus.unsubmitted @@ -97,6 +101,6 @@ def check_status(self, job): else: return JobStatus.unknown - def check_finish_tag(self, job): + def check_finish_tag(self, job: "Job") -> bool: job_tag_finished = job.job_hash + "_job_tag_finished" return self.context.check_file_exists(job_tag_finished) diff --git a/dpdispatcher/machines/lsf.py b/dpdispatcher/machines/lsf.py index 035ceace..4e8032f9 100644 --- a/dpdispatcher/machines/lsf.py +++ b/dpdispatcher/machines/lsf.py @@ -1,5 +1,5 @@ import shlex -from typing import List +from typing import TYPE_CHECKING, List from dargs import Argument @@ -12,6 +12,9 @@ retry, ) +if TYPE_CHECKING: + from dpdispatcher.submission import Job, Resources + lsf_script_header_template = """\ #!/bin/bash -l #BSUB -e %J.err @@ -25,11 +28,11 @@ class LSF(Machine): """LSF batch.""" - def gen_script(self, job): + def gen_script(self, job: "Job") -> str: lsf_script = super().gen_script(job) return lsf_script - def gen_script_header(self, job): + def gen_script_header(self, job: "Job") -> str: resources = job.resources script_header_dict = { "lsf_nodes_line": f"#BSUB -n {resources.number_node * resources.cpu_per_node}", @@ -76,7 +79,7 @@ def gen_script_header(self, job): return lsf_script_header @retry() - def do_submit(self, job): + def do_submit(self, job: "Job") -> str: script_file_name = job.script_file_name script_str = self.gen_script(job) job_id_name = job.job_hash + "_job_id" @@ -102,14 +105,14 @@ def do_submit(self, job): return job_id # TODO: derive abstract methods - def sub_script_cmd(self, res): - pass + def sub_script_cmd(self, res: "Resources") -> str: + return "" - def sub_script_head(self, res): - pass + def sub_script_head(self, res: "Resources") -> str: + return "" @retry() - def check_status(self, job): + def check_status(self, job: "Job") -> JobStatus: try: job_id = job.job_id except AttributeError: @@ -149,7 +152,7 @@ def check_status(self, job): else: return JobStatus.unknown - def check_finish_tag(self, job): + def check_finish_tag(self, job: "Job") -> bool: job_tag_finished = job.job_hash + "_job_tag_finished" return self.context.check_file_exists(job_tag_finished) @@ -211,7 +214,7 @@ def resources_subfields(cls) -> List[Argument]: ) ] - def kill(self, job): + def kill(self, job: "Job") -> None: """Kill the job. Parameters diff --git a/dpdispatcher/machines/openapi.py b/dpdispatcher/machines/openapi.py index 13821d17..e6565109 100644 --- a/dpdispatcher/machines/openapi.py +++ b/dpdispatcher/machines/openapi.py @@ -1,6 +1,7 @@ import os import shutil import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from zipfile import ZipFile from dpdispatcher.utils.utils import customized_script_header_template @@ -17,19 +18,24 @@ from dpdispatcher.machine import Machine from dpdispatcher.utils.job_status import JobStatus +if TYPE_CHECKING: + from dpdispatcher.contexts.context import Context + from dpdispatcher.submission import Job as SubmissionJob + from dpdispatcher.submission import Submission + shell_script_header_template = """ #!/bin/bash -l """ -def unzip_file(zip_file, out_dir="./"): +def unzip_file(zip_file: str, out_dir: str = "./") -> None: obj = ZipFile(zip_file, "r") for item in obj.namelist(): obj.extract(item, out_dir) class OpenAPI(Machine): - def __init__(self, context, **kwargs): + def __init__(self, context: "Context", **kwargs: Any) -> None: # noqa: ANN401 super().__init__(context=context, **kwargs) if not found_bohriumsdk: raise ModuleNotFoundError( @@ -71,11 +77,11 @@ def __init__(self, context, **kwargs): self.job = Job(client=self.client) self.group_id = None - def gen_script(self, job): + def gen_script(self, job: "SubmissionJob") -> str: shell_script = super().gen_script(job) return shell_script - def gen_script_header(self, job): + def gen_script_header(self, job: "SubmissionJob") -> str: resources = job.resources if ( resources["strategy"].get("customized_script_header_template_file") @@ -89,7 +95,7 @@ def gen_script_header(self, job): shell_script_header = shell_script_header_template return shell_script_header - def gen_local_script(self, job): + def gen_local_script(self, job: "SubmissionJob") -> str: script_str = self.gen_script(job) script_file_name = job.script_file_name self.context.write_local_file(fname=script_file_name, write_str=script_str) @@ -100,7 +106,7 @@ def gen_local_script(self, job): ) return script_file_name - def _gen_backward_files_list(self, job): + def _gen_backward_files_list(self, job: "SubmissionJob") -> List[str]: result_file_list = [] # result_file_list.extend(job.backward_common_files) for task in job.job_task_list: @@ -110,7 +116,7 @@ def _gen_backward_files_list(self, job): result_file_list = list(set(result_file_list)) return result_file_list - def do_submit(self, job): + def do_submit(self, job: "SubmissionJob") -> int: self.gen_local_script(job) project_id = self.remote_profile.get("project_id", 0) @@ -143,7 +149,7 @@ def do_submit(self, job): job.job_state = JobStatus.waiting return job.job_id - def _get_job_detail(self, job_id, group_id): + def _get_job_detail(self, job_id: int, group_id: Optional[int]) -> Dict[str, Any]: check_return = self.job.detail(job_id) assert check_return is not None, ( f"Failed to retrieve tasks information. To resubmit this job, please " @@ -156,7 +162,7 @@ def _get_job_detail(self, job_id, group_id): ) return check_return - def check_status(self, job): + def check_status(self, job: "SubmissionJob") -> JobStatus: if job.job_id == "": return JobStatus.unsubmitted job_id = job.job_id @@ -194,7 +200,7 @@ def check_status(self, job): print(job_log, end="") return job_state - def _download_job(self, job): + def _download_job(self, job: "SubmissionJob") -> None: data = self.job.detail(job.job_id) job_url = data["resultUrl"] if not job_url: @@ -217,19 +223,23 @@ def _download_job(self, job): except (OSError, shutil.Error) as e: dlog.exception("unable to backup file, " + str(e)) - def check_finish_tag(self, job): + def check_finish_tag(self, job: "SubmissionJob") -> bool: job_tag_finished = job.job_hash + "_job_tag_finished" dlog.info("check if job finished: ", job.job_id, job_tag_finished) return self.context.check_file_exists(job_tag_finished) # return # pass - def check_if_recover(self, submission): + def check_if_recover(self, submission: "Submission") -> bool: return False # pass @staticmethod - def map_dp_job_state(status, exit_code, ignore_exit_code=True): + def map_dp_job_state( + status: Union[int, JobStatus], + exit_code: int, + ignore_exit_code: bool = True, + ) -> JobStatus: if isinstance(status, JobStatus): return status map_dict = { @@ -250,7 +260,7 @@ def map_dp_job_state(status, exit_code, ignore_exit_code=True): return JobStatus.finished return map_dict[status] - def kill(self, job): + def kill(self, job: "SubmissionJob") -> None: """Kill the job. Parameters @@ -261,7 +271,7 @@ def kill(self, job): job_id = job.job_id self.job.kill(job_id) - def get_exit_code(self, job): + def get_exit_code(self, job: "SubmissionJob") -> int: """Get exit code of the job. Parameters diff --git a/dpdispatcher/machines/pbs.py b/dpdispatcher/machines/pbs.py index 7b81a656..90532235 100644 --- a/dpdispatcher/machines/pbs.py +++ b/dpdispatcher/machines/pbs.py @@ -1,5 +1,5 @@ import shlex -from typing import List +from typing import TYPE_CHECKING, Any, List from dargs import Argument @@ -8,6 +8,9 @@ from dpdispatcher.utils.job_status import JobStatus from dpdispatcher.utils.utils import customized_script_header_template +if TYPE_CHECKING: + from dpdispatcher.submission import Job + pbs_script_header_template = """ #!/bin/bash -l {select_node_line} @@ -20,11 +23,11 @@ class PBS(Machine): # def __init__(self, **kwargs): # super().__init__(**kwargs) - def gen_script(self, job): + def gen_script(self, job: "Job") -> str: pbs_script = super().gen_script(job) return pbs_script - def gen_script_header(self, job): + def gen_script_header(self, job: "Job") -> str: resources = job.resources pbs_script_header_dict = {} pbs_script_header_dict["select_node_line"] = ( @@ -49,7 +52,7 @@ def gen_script_header(self, job): ) return pbs_script_header - def do_submit(self, job): + def do_submit(self, job: "Job") -> str: script_file_name = job.script_file_name script_str = self.gen_script(job) job_id_name = job.job_hash + "_job_id" @@ -72,7 +75,7 @@ def do_submit(self, job): self.context.write_file(job_id_name, job_id) return job_id - def check_status(self, job): + def check_status(self, job: "Job") -> JobStatus: job_id = job.job_id if job_id == "": return JobStatus.unsubmitted @@ -105,11 +108,11 @@ def check_status(self, job): else: return JobStatus.unknown - def check_finish_tag(self, job): + def check_finish_tag(self, job: "Job") -> bool: job_tag_finished = job.job_hash + "_job_tag_finished" return self.context.check_file_exists(job_tag_finished) - def kill(self, job): + def kill(self, job: "Job") -> None: """Kill the job. Parameters @@ -122,7 +125,7 @@ def kill(self, job): class Torque(PBS): - def check_status(self, job): + def check_status(self, job: "Job") -> JobStatus: job_id = job.job_id if job_id == "": return JobStatus.unsubmitted @@ -155,7 +158,7 @@ def check_status(self, job): else: return JobStatus.unknown - def gen_script_header(self, job): + def gen_script_header(self, job: "Job") -> str: # ref: https://support.adaptivecomputing.com/wp-content/uploads/2021/02/torque/torque.htm#topics/torque/2-jobs/requestingRes.htm resources = job.resources pbs_script_header_dict = {} @@ -191,10 +194,10 @@ def gen_script_header(self, job): class SGE(PBS): - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: # noqa: ANN401 super().__init__(**kwargs) - def gen_script_header(self, job): + def gen_script_header(self, job: "Job") -> str: ### Ref:https://softpanorama.org/HPC/PBS_and_derivatives/Reference/pbs_command_vs_sge_commands.shtml # resources.number_node is not used in SGE resources = job.resources @@ -222,7 +225,7 @@ def gen_script_header(self, job): ) return sge_script_header - def do_submit(self, job): + def do_submit(self, job: "Job") -> str: script_file_name = job.script_file_name script_str = self.gen_script(job) job_id_name = job.job_hash + "_job_id" @@ -239,7 +242,7 @@ def do_submit(self, job): self.context.write_file(job_id_name, job_id) return job_id - def check_status(self, job): + def check_status(self, job: "Job") -> JobStatus: ### https://softpanorama.org/HPC/Grid_engine/Queues/queue_states.shtml job_id = job.job_id status_line = None @@ -283,7 +286,7 @@ def check_status(self, job): else: return JobStatus.unknown - def check_finish_tag(self, job): + def check_finish_tag(self, job: "Job") -> bool: job_tag_finished = job.job_hash + "_job_tag_finished" return self.context.check_file_exists(job_tag_finished) diff --git a/dpdispatcher/machines/shell.py b/dpdispatcher/machines/shell.py index 2205e333..390e2119 100644 --- a/dpdispatcher/machines/shell.py +++ b/dpdispatcher/machines/shell.py @@ -1,21 +1,25 @@ import shlex +from typing import TYPE_CHECKING from dpdispatcher.dlog import dlog from dpdispatcher.machine import Machine from dpdispatcher.utils.job_status import JobStatus from dpdispatcher.utils.utils import customized_script_header_template +if TYPE_CHECKING: + from dpdispatcher.submission import Job + shell_script_header_template = """ #!/bin/bash -l """ class Shell(Machine): - def gen_script(self, job): + def gen_script(self, job: "Job") -> str: shell_script = super().gen_script(job) return shell_script - def gen_script_header(self, job): + def gen_script_header(self, job: "Job") -> str: resources = job.resources if ( resources["strategy"].get("customized_script_header_template_file") @@ -29,7 +33,7 @@ def gen_script_header(self, job): shell_script_header = shell_script_header_template return shell_script_header - def do_submit(self, job): + def do_submit(self, job: "Job") -> int: script_str = self.gen_script(job) script_file_name = job.script_file_name job_id_name = job.job_hash + "_job_id" @@ -60,7 +64,7 @@ def do_submit(self, job): # self.context.write_file(job_id_name, job_id) # return job_id - def check_status(self, job): + def check_status(self, job: "Job") -> JobStatus: job_id = job.job_id # print('shell.check_status.job_id', job_id) # job_state = JobStatus.unknown @@ -101,12 +105,12 @@ def check_status(self, job): # return True # return False - def check_finish_tag(self, job): + def check_finish_tag(self, job: "Job") -> bool: job_tag_finished = job.job_hash + "_job_tag_finished" # print('job finished: ',job.job_id, job_tag_finished) return self.context.check_file_exists(job_tag_finished) - def kill(self, job): + def kill(self, job: "Job") -> None: """Kill the job. Parameters diff --git a/dpdispatcher/machines/slurm.py b/dpdispatcher/machines/slurm.py index d4f2b328..1101ab56 100644 --- a/dpdispatcher/machines/slurm.py +++ b/dpdispatcher/machines/slurm.py @@ -1,7 +1,7 @@ import math import pathlib import shlex -from typing import List +from typing import TYPE_CHECKING, List from dargs import Argument @@ -14,6 +14,9 @@ retry, ) +if TYPE_CHECKING: + from dpdispatcher.submission import Job + # from dpdispatcher.submission import Resources slurm_script_header_template = """\ @@ -32,11 +35,11 @@ class Slurm(Machine): - def gen_script(self, job): + def gen_script(self, job: "Job") -> str: slurm_script = super().gen_script(job) return slurm_script - def gen_script_header(self, job): + def gen_script_header(self, job: "Job") -> str: resources = job.resources script_header_dict = {} script_header_dict["slurm_nodes_line"] = ( @@ -73,7 +76,7 @@ def gen_script_header(self, job): return slurm_script_header @retry() - def do_submit(self, job): + def do_submit(self, job: "Job") -> str: script_file_name = job.script_file_name script_str = self.gen_script(job) job_id_name = job.job_hash + "_job_id" @@ -119,7 +122,7 @@ def do_submit(self, job): return job_id @retry() - def check_status(self, job): + def check_status(self, job: "Job") -> JobStatus: job_id = job.job_id if job_id == "": return JobStatus.unsubmitted @@ -182,7 +185,7 @@ def check_status(self, job): else: return JobStatus.unknown - def check_finish_tag(self, job): + def check_finish_tag(self, job: "Job") -> bool: job_tag_finished = job.job_hash + "_job_tag_finished" return self.context.check_file_exists(job_tag_finished) @@ -214,7 +217,7 @@ def resources_subfields(cls) -> List[Argument]: ) ] - def kill(self, job): + def kill(self, job: "Job") -> None: """Kill the job. Parameters @@ -233,7 +236,7 @@ def kill(self, job): class SlurmJobArray(Slurm): """Slurm with job array enabled for multiple tasks in a job.""" - def gen_script_header(self, job): + def gen_script_header(self, job: "Job") -> str: slurm_job_size = job.resources.kwargs.get("slurm_job_size", 1) if job.fail_count > 0: # resubmit jobs, check if some of tasks have been finished @@ -252,7 +255,7 @@ def gen_script_header(self, job): math.ceil(len(job.job_task_list) / slurm_job_size) - 1 ) - def gen_script_command(self, job): + def gen_script_command(self, job: "Job") -> str: resources = job.resources slurm_job_size = resources.kwargs.get("slurm_job_size", 1) # SLURM_ARRAY_TASK_ID: 0 ~ n_jobs-1 @@ -296,7 +299,7 @@ def gen_script_command(self, job): script_command += "*)\nexit 1\n;;\nesac\n" return script_command - def gen_script_end(self, job): + def gen_script_end(self, job: "Job") -> str: # We cannot touch tag for job array # we may check task tag instead append_script = job.resources.append_script @@ -306,7 +309,7 @@ def gen_script_end(self, job): ) @retry() - def check_status(self, job): + def check_status(self, job: "Job") -> JobStatus: job_id = job.job_id if job_id == "": return JobStatus.unsubmitted @@ -381,7 +384,7 @@ def check_status(self, job): else: return JobStatus.terminated - def check_finish_tag(self, job): + def check_finish_tag(self, job: "Job") -> bool: results = [] for task in job.job_task_list: task.get_task_state(self.context) diff --git a/dpdispatcher/run.py b/dpdispatcher/run.py index 84b2c4b1..e35213f1 100644 --- a/dpdispatcher/run.py +++ b/dpdispatcher/run.py @@ -153,7 +153,7 @@ def create_submission(metadata: dict, hash: str) -> Submission: ) -def run_pep723(script: str): +def run_pep723(script: str) -> None: """Run a PEP 723 script. Parameters diff --git a/dpdispatcher/submission.py b/dpdispatcher/submission.py index 1a394807..ccbe1b25 100644 --- a/dpdispatcher/submission.py +++ b/dpdispatcher/submission.py @@ -9,7 +9,7 @@ import time import uuid from hashlib import sha1 -from typing import List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import yaml from dargs.dargs import Argument, Variant @@ -19,6 +19,9 @@ from dpdispatcher.utils.job_status import JobStatus from dpdispatcher.utils.record import record +if TYPE_CHECKING: + from dpdispatcher.base_context import BaseContext + # %% default_strategy = dict(if_cuda_multi_devices=False, ratio_unfinished=0.0) @@ -47,14 +50,14 @@ class Submission: def __init__( self, - work_base, - machine=None, - resources=None, - forward_common_files=[], - backward_common_files=[], + work_base: str, + machine: Optional["Machine"] = None, + resources: Optional["Resources"] = None, + forward_common_files: List[str] = [], + backward_common_files: List[str] = [], *, - task_list=[], - ): + task_list: List["Task"] = [], + ) -> None: self.local_root = None self.work_base = work_base self._abs_work_base = os.path.abspath(work_base) @@ -79,22 +82,24 @@ def __init__( self.bind_machine(machine) - def __repr__(self): + def __repr__(self) -> str: return json.dumps(self.serialize(), indent=4) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """When check whether the two submission are equal, we disregard the runtime infomation(job_state, job_id, fail_count) of the submission.belonging_jobs. """ return json.dumps(self.serialize(if_static=True)) == json.dumps( - other.serialize(if_static=True) + other.serialize(if_static=True) # type: ignore[attr-defined] ) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: # noqa: ANN401 return self.serialize()[key] @classmethod - def deserialize(cls, submission_dict, machine=None): + def deserialize( + cls, submission_dict: Dict[str, Any], machine: Optional["Machine"] = None + ) -> "Submission": # noqa: ANN401 """Convert the submission_dict to a Submission class object. Parameters @@ -129,7 +134,7 @@ def deserialize(cls, submission_dict, machine=None): submission.bind_machine(machine) return submission - def serialize(self, if_static=False): + def serialize(self, if_static: bool = False) -> Dict[str, Any]: # noqa: ANN401 """Convert the Submission class instance to a dictionary. Parameters @@ -164,26 +169,26 @@ def serialize(self, if_static=False): ] return submission_dict - def register_task(self, task): + def register_task(self, task: "Task") -> None: if self.belonging_jobs: raise RuntimeError( f"Not allowed to register tasks after generating jobs. submission hash error {self}" ) self.belonging_tasks.append(task) - def register_task_list(self, task_list): + def register_task_list(self, task_list: List["Task"]) -> None: if self.belonging_jobs: raise RuntimeError( f"Not allowed to register tasks after generating jobs. submission hash error {self}" ) self.belonging_tasks.extend(task_list) - def get_hash(self): + def get_hash(self) -> str: return sha1( json.dumps(self.serialize(if_static=True)).encode("utf-8") ).hexdigest() - def bind_machine(self, machine): + def bind_machine(self, machine: Optional["Machine"]) -> "Submission": """Bind this submission to a machine. update the machine's context remote_root and local_root. Parameters @@ -201,8 +206,13 @@ def bind_machine(self, machine): return self def run_submission( - self, *, dry_run=False, exit_on_submit=False, clean=True, check_interval=30 - ): + self, + *, + dry_run: bool = False, + exit_on_submit: bool = False, + clean: bool = True, + check_interval: int = 30, + ) -> None: """Main method to execute the submission. First, check whether old Submission exists on the remote machine, and try to recover from it. Second, upload the local files to the remote machine where the tasks to be executed. @@ -265,7 +275,7 @@ def run_submission( self.clean_jobs() return self.serialize() - def try_download_result(self): + def try_download_result(self) -> None: start_time = time.time() retry_interval = 60 # retry every 1 minute success = False @@ -290,7 +300,7 @@ def try_download_result(self): dlog.info("Maximum retries time reached. Exiting.") break - async def async_run_submission(self, **kwargs): + async def async_run_submission(self, **kwargs: Any) -> None: # noqa: ANN401 """Async interface of run_submission. Examples @@ -327,7 +337,7 @@ async def async_run_submission(self, **kwargs): wrapped_submission = functools.partial(self.run_submission, **kwargs) return await loop.run_in_executor(None, wrapped_submission) - def update_submission_state(self): + def update_submission_state(self) -> None: """Check whether all the jobs in the submission. Notes @@ -343,7 +353,7 @@ def update_submission_state(self): f"update_submission_state: job: {job.job_hash}, {job.job_id}, {job.job_state}" ) - def handle_unexpected_submission_state(self): + def handle_unexpected_submission_state(self) -> None: """Handle unexpected job state of the submission. If the job state is unsubmitted, submit the job. If the job state is terminated (killed unexpectly), resubmit the job. @@ -390,7 +400,7 @@ def check_ratio_unfinished(self, ratio_unfinished: float) -> bool: finished_num = status_list.count(JobStatus.finished) return finished_num / len(self.belonging_tasks) >= (1 - ratio_unfinished) - def remove_unfinished_tasks(self): + def remove_unfinished_tasks(self) -> None: dlog.info("Remove unfinished tasks") # kill all jobs and mark them as finished for job in self.belonging_jobs: @@ -413,7 +423,7 @@ def remove_unfinished_tasks(self): if task.task_state == JobStatus.finished ] - def check_all_finished(self): + def check_all_finished(self) -> bool: """Check whether all the jobs in the submission. Notes @@ -444,7 +454,7 @@ def check_all_finished(self): else: return True - def generate_jobs(self): + def generate_jobs(self) -> None: """After tasks register to the self.belonging_tasks, This method generate the jobs and add these jobs to self.belonging_jobs. The jobs are generated by the tasks randomly, and there are self.resources.group_size tasks in a task. @@ -487,34 +497,36 @@ def generate_jobs(self): self.submission_hash = self.get_hash() - def upload_jobs(self): + def upload_jobs(self) -> None: self.machine.context.upload(self) - def download_jobs(self): + def download_jobs(self) -> None: self.machine.context.download(self) # for job in self.belonging_jobs: # job.tag_finished() # self.machine.context.write_file(self.machine.finish_tag_name, write_str="") - def clean_jobs(self): + def clean_jobs(self) -> None: self.machine.context.clean() assert self.submission_hash is not None record.remove(self.submission_hash) - def submission_to_json(self): + def submission_to_json(self) -> None: # self.update_submission_state() write_str = json.dumps(self.serialize(), indent=4, default=str) submission_file_name = f"{self.submission_hash}.json" self.machine.context.write_file(submission_file_name, write_str=write_str) @classmethod - def submission_from_json(cls, json_file_name="submission.json"): + def submission_from_json( + cls, json_file_name: str = "submission.json" + ) -> "Submission": with open(json_file_name) as f: submission_dict = json.load(f) submission = cls.deserialize(submission_dict=submission_dict, machine=None) return submission - def try_recover_from_json(self): + def try_recover_from_json(self) -> None: submission_file_name = f"{self.submission_hash}.json" if_recover = self.machine.context.check_file_exists(submission_file_name) submission = None @@ -566,13 +578,13 @@ class Task: def __init__( self, - command, - task_work_path, - forward_files=[], - backward_files=[], - outlog="log", - errlog="err", - ): + command: str, + task_work_path: str, + forward_files: List[str] = [], + backward_files: List[str] = [], + outlog: str = "log", + errlog: str = "err", + ) -> None: self.command = command self.task_work_path = task_work_path self.forward_files = forward_files @@ -587,26 +599,26 @@ def __init__( # self.uuid = self.task_state = JobStatus.unsubmitted - def __repr__(self): + def __repr__(self) -> str: return str(self.serialize()) - def __eq__(self, other): - return json.dumps(self.serialize()) == json.dumps(other.serialize()) + def __eq__(self, other: object) -> bool: + return json.dumps(self.serialize()) == json.dumps(other.serialize()) # type: ignore[attr-defined] - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: # noqa: ANN401 return self.serialize()[key] - def get_hash(self): + def get_hash(self) -> str: return sha1(json.dumps(self.serialize()).encode("utf-8")).hexdigest() @classmethod - def load_from_json(cls, json_file): + def load_from_json(cls, json_file: str) -> "Task": with open(json_file) as f: task_dict = json.load(f) return cls.load_from_dict(task_dict) @classmethod - def load_from_yaml(cls, yaml_file): + def load_from_yaml(cls, yaml_file: str) -> "Task": with open(yaml_file) as f: task_dict = yaml.safe_load(f) task = cls.load_from_dict(task_dict=task_dict) @@ -623,7 +635,7 @@ def load_from_dict(cls, task_dict: dict) -> "Task": return task @classmethod - def deserialize(cls, task_dict): + def deserialize(cls, task_dict: Dict[str, Any]) -> "Task": # noqa: ANN401 """Convert the task_dict to a Task class object. Parameters @@ -639,7 +651,7 @@ def deserialize(cls, task_dict): task = cls(**task_dict) return task - def serialize(self): + def serialize(self) -> Dict[str, Any]: # noqa: ANN401 task_dict = {} task_dict["command"] = self.command task_dict["task_work_path"] = self.task_work_path @@ -651,7 +663,7 @@ def serialize(self): return task_dict @staticmethod - def arginfo(): + def arginfo() -> Argument: doc_command = ( "A command to be executed of this task. The expected return code is 0." ) @@ -698,7 +710,7 @@ def arginfo(): task_format = Argument("task", dict, task_args) return task_format - def get_task_state(self, context): + def get_task_state(self, context: "BaseContext") -> None: """Get the task state by checking the tag file. Parameters @@ -737,11 +749,11 @@ class Job: def __init__( self, - job_task_list, + job_task_list: List["Task"], *, - resources, - machine=None, - ): + resources: "Resources", + machine: Optional["Machine"] = None, + ) -> None: self.job_task_list = job_task_list # self.job_work_base = job_work_base self.resources = resources @@ -754,19 +766,21 @@ def __init__( self.job_hash = self.get_hash() self.script_file_name = self.job_hash + ".sub" - def __repr__(self): + def __repr__(self) -> str: return str(self.serialize()) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """When check whether the two jobs are equal, we disregard the runtime infomation(job_state, job_id, fail_count) of the jobs. """ return json.dumps(self.serialize(if_static=True)) == json.dumps( - other.serialize(if_static=True) + other.serialize(if_static=True) # type: ignore[attr-defined] ) @classmethod - def deserialize(cls, job_dict, machine=None): + def deserialize( + cls, job_dict: Dict[str, Any], machine: Optional["Machine"] = None + ) -> "Job": # noqa: ANN401 """Convert the job_dict to a Submission class object. Parameters @@ -808,7 +822,7 @@ def deserialize(cls, job_dict, machine=None): task.task_state = job.job_state return job - def get_job_state(self): + def get_job_state(self) -> None: """Get the jobs. Usually, this method will query the database of slurm or pbs job scheduler system and get the results. Notes @@ -827,7 +841,7 @@ def get_job_state(self): if task.task_state != JobStatus.finished: task.task_state = job_state - def handle_unexpected_job_state(self): + def handle_unexpected_job_state(self) -> None: job_state = self.job_state if job_state == JobStatus.unknown: @@ -876,10 +890,10 @@ def handle_unexpected_job_state(self): time.sleep(self.resources.wait_time) # self.get_job_state() - def get_hash(self): + def get_hash(self) -> str: return str(list(self.serialize(if_static=True).keys())[0]) - def serialize(self, if_static=False): + def serialize(self, if_static: bool = False) -> Dict[str, Any]: # noqa: ANN401 """Convert the Task class instance to a dictionary. Parameters @@ -907,10 +921,10 @@ def serialize(self, if_static=False): # job_content_dict['job_uuid'] = self.job_uuid return {job_hash: job_content_dict} - def register_job_id(self, job_id): + def register_job_id(self, job_id: str) -> None: self.job_id = job_id - def submit_job(self): + def submit_job(self) -> None: assert self.machine is not None job_id = self.machine.do_submit(self) self.register_job_id(job_id) @@ -919,7 +933,7 @@ def submit_job(self): else: self.job_state = JobStatus.unsubmitted - def job_to_json(self): + def job_to_json(self) -> None: write_str = json.dumps(self.serialize(), indent=2, default=str) assert self.machine is not None self.machine.context.write_file( @@ -976,25 +990,25 @@ class Resources: def __init__( self, - number_node, - cpu_per_node, - gpu_per_node, - queue_name, - group_size, + number_node: int, + cpu_per_node: int, + gpu_per_node: int, + queue_name: str, + group_size: int, *, - custom_flags=[], - strategy=default_strategy, - para_deg=1, - module_unload_list=[], - module_purge=False, - module_list=[], - source_list=[], - envs={}, - prepend_script=[], - append_script=[], - wait_time=0, - **kwargs, - ): + custom_flags: List[str] = [], + strategy: Dict[str, Any] = default_strategy, # noqa: ANN401 + para_deg: int = 1, + module_unload_list: List[str] = [], + module_purge: bool = False, + module_list: List[str] = [], + source_list: List[str] = [], + envs: Dict[str, str] = {}, + prepend_script: List[str] = [], + append_script: List[str] = [], + wait_time: int = 0, + **kwargs: Any, # noqa: ANN401 + ) -> None: self.number_node = number_node self.cpu_per_node = cpu_per_node self.gpu_per_node = gpu_per_node @@ -1037,10 +1051,10 @@ def __init__( if self.strategy["ratio_unfinished"] >= 1.0: raise RuntimeError("ratio_unfinished must be smaller than 1.0") - def __eq__(self, other): - return json.dumps(self.serialize()) == json.dumps(other.serialize()) + def __eq__(self, other: object) -> bool: + return json.dumps(self.serialize()) == json.dumps(other.serialize()) # type: ignore[attr-defined] - def serialize(self): + def serialize(self) -> Dict[str, Any]: # noqa: ANN401 resources_dict = {} resources_dict["number_node"] = self.number_node resources_dict["cpu_per_node"] = self.cpu_per_node @@ -1063,7 +1077,7 @@ def serialize(self): return resources_dict @classmethod - def deserialize(cls, resources_dict): + def deserialize(cls, resources_dict: Dict[str, Any]) -> "Resources": # noqa: ANN401 resources = cls( number_node=resources_dict.get("number_node", 1), cpu_per_node=resources_dict.get("cpu_per_node", 1), @@ -1085,25 +1099,25 @@ def deserialize(cls, resources_dict): ) return resources - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: # noqa: ANN401 return self.serialize()[key] @classmethod - def load_from_json(cls, json_file): + def load_from_json(cls, json_file: str) -> "Resources": with open(json_file) as f: resources_dict = json.load(f) resources = cls.load_from_dict(resources_dict=resources_dict) return resources @classmethod - def load_from_yaml(cls, yaml_file): + def load_from_yaml(cls, yaml_file: str) -> "Resources": with open(yaml_file) as f: resources_dict = yaml.safe_load(f) resources = cls.load_from_dict(resources_dict=resources_dict) return resources @classmethod - def load_from_dict(cls, resources_dict): + def load_from_dict(cls, resources_dict: Dict[str, Any]) -> "Resources": # noqa: ANN401 # check dict base = cls.arginfo(detail_kwargs="batch_type" in resources_dict) resources_dict = base.normalize_value(resources_dict, trim_pattern="_*") @@ -1112,7 +1126,7 @@ def load_from_dict(cls, resources_dict): return cls.deserialize(resources_dict=resources_dict) @staticmethod - def arginfo(detail_kwargs=True): + def arginfo(detail_kwargs: bool = True) -> Argument: doc_number_node = "The number of nodes required for each `job`." doc_cpu_per_node = "CPU numbers of each node assigned to each job." doc_gpu_per_node = "GPU numbers of each node assigned to each job." diff --git a/dpdispatcher/utils/dpcloudserver/client.py b/dpdispatcher/utils/dpcloudserver/client.py index cfeed8b6..445a3e15 100644 --- a/dpdispatcher/utils/dpcloudserver/client.py +++ b/dpdispatcher/utils/dpcloudserver/client.py @@ -2,6 +2,7 @@ import re import time import urllib.parse +from typing import Any, Dict, List, Optional, Tuple, Union from urllib.parse import urljoin import requests @@ -26,8 +27,13 @@ class RequestInfoException(Exception): class Client: def __init__( - self, email=None, password=None, debug=False, ticket=None, base_url=API_HOST - ): + self, + email: Optional[str] = None, + password: Optional[str] = None, + debug: bool = False, + ticket: Optional[str] = None, + base_url: str = API_HOST, + ) -> None: self.debug = debug self.debug = os.getenv("LBG_CLI_DEBUG_PRINT", debug) self.config = {} @@ -39,15 +45,36 @@ def __init__( self.last_log_offset = 0 self.ticket = ticket - def post(self, url, data=None, header=None, params=None, retry=5): + def post( + self, + url: str, + data: Optional[Dict[str, Any]] = None, # noqa: ANN401 + header: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, Any]] = None, # noqa: ANN401 + retry: int = 5, + ) -> Dict[str, Any]: # noqa: ANN401 return self._req( "POST", url, data=data, header=header, params=params, retry=retry ) - def get(self, url, header=None, params=None, retry=5): + def get( + self, + url: str, + header: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, Any]] = None, # noqa: ANN401 + retry: int = 5, + ) -> Dict[str, Any]: # noqa: ANN401 return self._req("GET", url, header=header, params=params, retry=retry) - def _req(self, method, url, data=None, header=None, params=None, retry=5): + def _req( + self, + method: str, + url: str, + data: Optional[Dict[str, Any]] = None, # noqa: ANN401 + header: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, Any]] = None, # noqa: ANN401 + retry: int = 5, + ) -> Dict[str, Any]: # noqa: ANN401 short_url = url url = urllib.parse.urljoin(self.base_url, url) if header is None: @@ -94,7 +121,7 @@ def _req(self, method, url, data=None, header=None, params=None, retry=5): err = result.get("message") or result.get("error") raise RequestInfoException(resp_code, short_url, err) - def _login(self): + def _login(self) -> None: if self.config["email"] is None or self.config["password"] is None: raise RequestInfoException( "can not find login information, please check your config" @@ -105,7 +132,7 @@ def _login(self): # print(self.token) self.user_id = resp["user_id"] - def refresh_token(self, retry=3): + def refresh_token(self, retry: int = 3) -> None: self.ticket = os.environ.get("BOHR_TICKET", "") if self.ticket: return @@ -137,7 +164,7 @@ def refresh_token(self, retry=3): return raise RequestInfoException(resp_code, url, err) - def _get_oss_bucket(self, endpoint, bucket_name): + def _get_oss_bucket(self, endpoint: str, bucket_name: str) -> Any: # noqa: ANN401 # res = get("/tools/sts_token", {}) res = self.get("/data/get_sts_token", {}) # print('debug>>>>>>>>>>>>>', res) @@ -147,13 +174,15 @@ def _get_oss_bucket(self, endpoint, bucket_name): ) return oss2.Bucket(auth, endpoint, bucket_name) - def download(self, oss_file, save_file, endpoint, bucket_name): + def download( + self, oss_file: str, save_file: str, endpoint: str, bucket_name: str + ) -> str: bucket = self._get_oss_bucket(endpoint, bucket_name) dlog.debug(f"download: oss_file:{oss_file}; save_file:{save_file}") bucket.get_object_to_file(oss_file, save_file) return save_file - def download_from_url(self, url, save_file): + def download_from_url(self, url: str, save_file: str) -> None: ret = None for retry_count in range(3): try: @@ -178,7 +207,9 @@ def download_from_url(self, url, save_file): f.write(chunk) ret.close() - def upload(self, oss_task_zip, zip_task_file, endpoint, bucket_name): + def upload( + self, oss_task_zip: str, zip_task_file: str, endpoint: str, bucket_name: str + ) -> Any: # noqa: ANN401 dlog.debug( f"upload: oss_task_zip:{oss_task_zip}; zip_task_file:{zip_task_file}" ) @@ -207,8 +238,13 @@ def upload(self, oss_task_zip, zip_task_file, endpoint, bucket_name): return result def job_create( - self, job_type, oss_path, input_data, program_id=None, group_id=None - ): + self, + job_type: str, + oss_path: str, + input_data: Dict[str, Any], # noqa: ANN401 + program_id: Optional[int] = None, + group_id: Optional[int] = None, + ) -> Tuple[int, Optional[int]]: post_data = { "job_type": job_type, "oss_path": oss_path, @@ -241,11 +277,11 @@ def job_create( group_id = ret.get("jobGroupId") return ret["jobId"], group_id - def _camelize(self, str_or_iter): + def _camelize(self, str_or_iter: Any) -> str: # noqa: ANN401 # code reference from https://pypi.org/project/pyhumps/ regex = re.compile(r"(?<=[^\-_\s])[\-_\s]+[^\-_\s]") - def _is_none(_in): + def _is_none(_in: Any) -> Union[str, Any]: # noqa: ANN401 return "" if _in is None else _in s = str(_is_none(str_or_iter)) @@ -256,7 +292,7 @@ def _is_none(_in): s = s[0].lower() + s[1:] return regex.sub(lambda m: m.group(0)[-1].upper(), s) - def get_job_detail(self, job_id): + def get_job_detail(self, job_id: str) -> Optional[Dict[str, Any]]: # noqa: ANN401 try: ret = self.get( f"brm/v1/job/{job_id}", @@ -270,7 +306,7 @@ def get_job_detail(self, job_id): return ret - def get_log(self, job_id): + def get_log(self, job_id: str) -> str: url, size = self._get_job_log(job_id) if not url: return "" @@ -284,7 +320,7 @@ def get_log(self, job_id): dlog.error(f"Error decoding job log: {e}", stack_info=ENABLE_STACK) return "" - def _get_job_log(self, job_id): + def _get_job_log(self, job_id: str) -> Tuple[Optional[str], int]: ret = self.get( f"/brm/v1/job/{job_id}/log", params={ @@ -296,7 +332,7 @@ def _get_job_log(self, job_id): return d[0]["url"], d[0]["size"] return None, 0 - def get_tasks_list(self, group_id, per_page=30): + def get_tasks_list(self, group_id: int, per_page: int = 30) -> List[Dict[str, Any]]: # noqa: ANN401 result = [] page = 0 while True: @@ -315,7 +351,7 @@ def get_tasks_list(self, group_id, per_page=30): page += 1 return result - def get_job_result_url(self, job_id): + def get_job_result_url(self, job_id: str) -> Optional[str]: try: if not job_id: return None @@ -331,7 +367,7 @@ def get_job_result_url(self, job_id): dlog.error(e, stack_info=ENABLE_STACK) return None - def kill(self, job_id): + def kill(self, job_id: str) -> Optional[Dict[str, Any]]: # noqa: ANN401 try: if not job_id: return None diff --git a/dpdispatcher/utils/dpcloudserver/zip_file.py b/dpdispatcher/utils/dpcloudserver/zip_file.py index 41beaaa1..072e0fab 100644 --- a/dpdispatcher/utils/dpcloudserver/zip_file.py +++ b/dpdispatcher/utils/dpcloudserver/zip_file.py @@ -1,5 +1,6 @@ import glob import os +from typing import List from zipfile import ZipFile # def zip_file_list(root_path, zip_filename, file_list=[]): @@ -7,7 +8,7 @@ # root_dir=root_path,) -def zip_file_list(root_path, zip_filename, file_list=[]): +def zip_file_list(root_path: str, zip_filename: str, file_list: List[str] = []) -> str: out_zip_file = os.path.join(root_path, zip_filename) # print('debug: file_list', file_list) zip_obj = ZipFile(out_zip_file, "w") @@ -77,7 +78,7 @@ def zip_file_list(root_path, zip_filename, file_list=[]): # return False -def unzip_file(zip_file, out_dir="./"): +def unzip_file(zip_file: str, out_dir: str = "./") -> None: obj = ZipFile(zip_file, "r") for item in obj.namelist(): obj.extract(item, out_dir) diff --git a/dpdispatcher/utils/hdfs_cli.py b/dpdispatcher/utils/hdfs_cli.py index f0c23949..e37dd7dd 100644 --- a/dpdispatcher/utils/hdfs_cli.py +++ b/dpdispatcher/utils/hdfs_cli.py @@ -1,6 +1,7 @@ # /usr/bin/python import os +from typing import List, Optional, Tuple, Union from dpdispatcher.utils.utils import run_cmd_with_all_output @@ -9,7 +10,7 @@ class HDFS: """Fundamental class for HDFS basic manipulation.""" @staticmethod - def exists(uri): + def exists(uri: str) -> Optional[bool]: """Check existence of hdfs uri Returns: True on exists Raises: RuntimeError. @@ -32,7 +33,7 @@ def exists(uri): ) from e @staticmethod - def remove(uri): + def remove(uri: str) -> Optional[bool]: """Check existence of hdfs uri Returns: True on exists Raises: RuntimeError. @@ -51,7 +52,7 @@ def remove(uri): raise RuntimeError(f"Cannot remove hdfs uri[{uri}] with cmd[{cmd}]") from e @staticmethod - def mkdir(uri): + def mkdir(uri: str) -> Optional[bool]: """Make new hdfs directory Returns: True on success Raises: RuntimeError. @@ -72,7 +73,7 @@ def mkdir(uri): ) from e @staticmethod - def copy_from_local(local_path, to_uri): + def copy_from_local(local_path: str, to_uri: str) -> Tuple[bool, bytes]: """Returns: True on success Raises: on unexpected error. """ @@ -95,7 +96,9 @@ def copy_from_local(local_path, to_uri): ) from e @staticmethod - def copy_to_local(from_uri, local_path): + def copy_to_local( + from_uri: Union[str, List[str], Tuple[str, ...]], local_path: str + ) -> Optional[bool]: remote = "" if isinstance(from_uri, str): remote = from_uri @@ -118,7 +121,7 @@ def copy_to_local(from_uri, local_path): ) from e @staticmethod - def read_hdfs_file(uri): + def read_hdfs_file(uri: str) -> bytes: cmd = f"hadoop fs -text {uri}" try: ret, out, err = run_cmd_with_all_output(cmd) @@ -133,7 +136,7 @@ def read_hdfs_file(uri): raise RuntimeError(f"Cannot read text from uri[{uri}]cmd [{cmd}]") from e @staticmethod - def move(from_uri, to_uri): + def move(from_uri: str, to_uri: str) -> Optional[bool]: cmd = f"hadoop fs -mv {from_uri} {to_uri}" try: ret, out, err = run_cmd_with_all_output(cmd) diff --git a/dpdispatcher/utils/record.py b/dpdispatcher/utils/record.py index 5a8812ad..a7857689 100644 --- a/dpdispatcher/utils/record.py +++ b/dpdispatcher/utils/record.py @@ -1,6 +1,9 @@ import json from pathlib import Path -from typing import List +from typing import TYPE_CHECKING, List + +if TYPE_CHECKING: + from dpdispatcher.submission import Submission class Record: @@ -24,7 +27,7 @@ def get_submissions(self) -> List[str]: if (f.is_file() and f.suffix == ".json") ] - def write(self, submission) -> Path: + def write(self, submission: "Submission") -> Path: """Write submission data to file. Parameters @@ -59,7 +62,7 @@ def get_submission(self, hash: str, not_exist_ok: bool = False) -> Path: raise FileNotFoundError(f"Submission file not found: {submission_file}") return submission_file - def remove(self, hash: str): + def remove(self, hash: str) -> None: """Remove submission data by hash. Call this method when the remote directory is cleaned. diff --git a/dpdispatcher/utils/utils.py b/dpdispatcher/utils/utils.py index b6c5e392..e4ee2b95 100644 --- a/dpdispatcher/utils/utils.py +++ b/dpdispatcher/utils/utils.py @@ -6,7 +6,7 @@ import struct import subprocess import time -from typing import TYPE_CHECKING, Callable, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union from dpdispatcher.dlog import dlog @@ -14,7 +14,7 @@ from dpdispatcher import Resources -def get_sha256(filename): +def get_sha256(filename: str) -> str: """Get sha256 of a file. Parameters @@ -38,7 +38,7 @@ def get_sha256(filename): return sha256 -def hotp(key: str, period: int, token_length: int = 6, digest="sha1"): +def hotp(key: str, period: int, token_length: int = 6, digest: str = "sha1") -> str: key_ = base64.b32decode(key.upper() + "=" * ((8 - len(key)) % 8)) period_ = struct.pack(">Q", period) mac = hmac.new(key_, period_, digest).digest() @@ -75,7 +75,7 @@ def generate_totp(secret: str, period: int = 30, token_length: int = 6) -> str: return hotp(secret, int(time.time() / period), token_length, digest) -def run_cmd_with_all_output(cmd, shell=True): +def run_cmd_with_all_output(cmd: str, shell: bool = True) -> Tuple[int, bytes, bytes]: with subprocess.Popen( cmd, shell=shell, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) as proc: @@ -91,7 +91,7 @@ def rsync( key_filename: Optional[str] = None, timeout: Union[int, float] = 10, proxy_command: Optional[str] = None, -): +) -> None: """Call rsync to transfer files. Parameters @@ -186,10 +186,10 @@ def retry( ... raise RetrySignal("Failed") """ - def decorator(func): + def decorator(func: Callable) -> Callable: # noqa: ANN401 assert max_retry > 0, "max_retry must be greater than 0" - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401 current_retry = 0 errors = [] while max_retry is None or current_retry < max_retry: diff --git a/pyproject.toml b/pyproject.toml index 743ae74b..8eac9f9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,7 @@ select = [ "D", # pydocstyle "UP", # pyupgrade "I", # isort + "ANN", # flake8-annotations ] ignore = [ "E501", # line too long diff --git a/tests/context.py b/tests/context.py index 2d5e3a7e..714bdb9b 100644 --- a/tests/context.py +++ b/tests/context.py @@ -31,9 +31,9 @@ from dpdispatcher.utils.utils import RetrySignal, retry # noqa: F401 -def setUpModule(): +def setUpModule() -> None: os.chdir(os.path.abspath(os.path.dirname(__file__))) -def get_file_md5(file_path): +def get_file_md5(file_path: str) -> str: return hashlib.md5(pathlib.Path(file_path).read_bytes()).hexdigest() diff --git a/tests/sample_class.py b/tests/sample_class.py index 7c663094..f0c61aac 100644 --- a/tests/sample_class.py +++ b/tests/sample_class.py @@ -1,5 +1,6 @@ import os import sys +from typing import Any, Dict, List sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) __package__ = "tests" @@ -18,7 +19,7 @@ class SampleClass: @classmethod - def get_sample_resources(cls): + def get_sample_resources(cls) -> Resources: resources = Resources( number_node=1, cpu_per_node=4, @@ -36,7 +37,7 @@ def get_sample_resources(cls): return resources @classmethod - def get_sample_resources_dict(cls): + def get_sample_resources_dict(cls) -> Dict[str, Any]: # noqa: ANN401 resources_dict = { "number_node": 1, "cpu_per_node": 4, @@ -59,7 +60,7 @@ def get_sample_resources_dict(cls): return resources_dict @classmethod - def get_sample_task(cls): + def get_sample_task(cls) -> Task: task = Task( command="lmp -i input.lammps", task_work_path="bct-1/", @@ -71,7 +72,7 @@ def get_sample_task(cls): return task @classmethod - def get_sample_task_dict(cls): + def get_sample_task_dict(cls) -> Dict[str, Any]: # noqa: ANN401 task_dict = { "command": "lmp -i input.lammps", "task_work_path": "bct-1/", @@ -83,7 +84,7 @@ def get_sample_task_dict(cls): return task_dict @classmethod - def get_sample_task_list(cls, backward_wildcard=False): + def get_sample_task_list(cls, backward_wildcard: bool = False) -> List[Task]: task1 = Task( command="lmp -i input.lammps", task_work_path="bct-1/", @@ -122,7 +123,7 @@ def get_sample_task_list(cls, backward_wildcard=False): return task_list @classmethod - def get_sample_empty_submission(cls): + def get_sample_empty_submission(cls) -> Submission: resources = cls.get_sample_resources() # print(task_list) empty_submission = Submission( @@ -137,7 +138,7 @@ def get_sample_empty_submission(cls): return empty_submission @classmethod - def get_sample_submission(cls, backward_wildcard=False): + def get_sample_submission(cls, backward_wildcard: bool = False) -> Submission: submission = cls.get_sample_empty_submission() task_list = cls.get_sample_task_list(backward_wildcard=backward_wildcard) submission.register_task_list(task_list) @@ -145,25 +146,25 @@ def get_sample_submission(cls, backward_wildcard=False): return submission @classmethod - def get_sample_submission_dict(cls): + def get_sample_submission_dict(cls) -> Dict[str, Any]: # noqa: ANN401 submission = cls.get_sample_submission() submission_dict = submission.serialize() return submission_dict @classmethod - def get_sample_job(cls): + def get_sample_job(cls) -> Any: # noqa: ANN401 Submission = cls.get_sample_submission() job = Submission.belonging_jobs[0] return job @classmethod - def get_sample_job_dict(cls): + def get_sample_job_dict(cls) -> Dict[str, Any]: # noqa: ANN401 job = cls.get_sample_job() job_dict = job.serialize() return job_dict @classmethod - def get_sample_pbs_local_context(cls): + def get_sample_pbs_local_context(cls) -> PBS: # local_session = LocalSession({'work_path':'test_work_path/'}) local_context = LocalContext( local_root="test_pbs_dir/", remote_root="tmp_pbs_dir/" @@ -172,7 +173,7 @@ def get_sample_pbs_local_context(cls): return pbs @classmethod - def get_sample_slurm_local_context(cls): + def get_sample_slurm_local_context(cls) -> Slurm: # local_session = LocalSession({'work_path':'test_work_path/'}) local_context = LocalContext( local_root="test_slurm_dir/", remote_root="tmp_slurm_dir/" diff --git a/tests/test_JH_UniScheduler_script_generation.py b/tests/test_JH_UniScheduler_script_generation.py index b7d29a8d..934a977f 100644 --- a/tests/test_JH_UniScheduler_script_generation.py +++ b/tests/test_JH_UniScheduler_script_generation.py @@ -20,10 +20,10 @@ class TestJHUniSchedulerScriptGeneration(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None - def test_shell_trival(self): + def test_shell_trival(self) -> None: with open("jsons/machine_lazy_local_jh_unischeduler.json") as f: machine_dict = json.load(f) @@ -88,7 +88,7 @@ def test_shell_trival(self): self.assertEqual(header_str, benchmark_header) @unittest.skipIf(sys.platform == "win32", "skip for persimission error") - def test_template(self): + def test_template(self) -> None: with open("jsons/machine_lazy_local_jh_unischeduler.json") as f: machine_dict = json.load(f) diff --git a/tests/test_argcheck.py b/tests/test_argcheck.py index 637c5254..c69664ae 100644 --- a/tests/test_argcheck.py +++ b/tests/test_argcheck.py @@ -9,7 +9,7 @@ class TestJob(unittest.TestCase): - def test_machine_argcheck(self): + def test_machine_argcheck(self) -> None: norm_dict = Machine.load_from_dict( { "batch_type": "slurm", @@ -31,7 +31,7 @@ def test_machine_argcheck(self): } self.assertDictEqual(norm_dict, expected_dict) - def test_resources_argcheck(self): + def test_resources_argcheck(self) -> None: norm_dict = Resources.load_from_dict( { "number_node": 1, @@ -70,7 +70,7 @@ def test_resources_argcheck(self): } self.assertDictEqual(norm_dict, expected_dict) - def test_task_argcheck(self): + def test_task_argcheck(self) -> None: norm_dict = Task.load_from_dict( { "command": "ls", diff --git a/tests/test_class_job.py b/tests/test_class_job.py index d863fdd4..85d9be2b 100644 --- a/tests/test_class_job.py +++ b/tests/test_class_job.py @@ -15,23 +15,23 @@ class TestJob(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.job = SampleClass.get_sample_job() self.submission2 = Submission.submission_from_json("jsons/submission.json") self.job2 = self.submission2.belonging_jobs[0] - def test_eq(self): + def test_eq(self) -> None: self.assertTrue(self.job == self.job2) - def test_get_hash(self): + def test_get_hash(self) -> None: self.assertEqual(self.job.get_hash(), self.job2.get_hash()) # self.assertEqual(self.submission, self.submission2) - def test_serialize_deserialize(self): + def test_serialize_deserialize(self) -> None: self.assertEqual(self.job, Job.deserialize(job_dict=self.job.serialize())) - def test_static_serialize(self): + def test_static_serialize(self) -> None: self.assertNotIn( "job_state", list(self.job.serialize(if_static=True).values())[0] ) @@ -40,19 +40,19 @@ def test_static_serialize(self): "fail_count", list(self.job.serialize(if_static=True).values())[0] ) - def test_get_job_state(self): + def test_get_job_state(self) -> None: pass - def test_handle_unexpected_job_state(self): + def test_handle_unexpected_job_state(self) -> None: pass - def test_register_job_id(self): + def test_register_job_id(self) -> None: pass - def test_submit_job(self): + def test_submit_job(self) -> None: pass - def test_job_to_json(self): + def test_job_to_json(self) -> None: pass diff --git a/tests/test_class_machine.py b/tests/test_class_machine.py index 0c8aacc0..97f58320 100644 --- a/tests/test_class_machine.py +++ b/tests/test_class_machine.py @@ -14,13 +14,13 @@ class TestMachineInit(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None - def test_machine_serialize_deserialize(self): + def test_machine_serialize_deserialize(self) -> None: pbs = SampleClass.get_sample_pbs_local_context() self.assertEqual(pbs, Machine.deserialize(pbs.serialize())) - def test_machine_load_from_dict(self): + def test_machine_load_from_dict(self) -> None: pbs = SampleClass.get_sample_pbs_local_context() self.assertEqual(pbs, PBS.load_from_dict(pbs.serialize())) diff --git a/tests/test_class_machine_dispatch.py b/tests/test_class_machine_dispatch.py index db912a31..3d708773 100644 --- a/tests/test_class_machine_dispatch.py +++ b/tests/test_class_machine_dispatch.py @@ -24,10 +24,10 @@ class TestMachineDispatch(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None - def test_init_machine_pbs_lazy_local(self): + def test_init_machine_pbs_lazy_local(self) -> None: machine_dict = { "batch_type": "PBS", "context_type": "LazyLocalContext", @@ -38,7 +38,7 @@ def test_init_machine_pbs_lazy_local(self): self.assertIsInstance(machine, PBS) self.assertIsInstance(machine.context, LazyLocalContext) - def test_init_machine_shell_local(self): + def test_init_machine_shell_local(self) -> None: machine_dict = { "batch_type": "Shell", "context_type": "LocalContext", @@ -50,7 +50,7 @@ def test_init_machine_shell_local(self): self.assertIsInstance(shell, Shell) self.assertIsInstance(shell.context, LocalContext) - def test_init_machine_slurm_ssh(self): + def test_init_machine_slurm_ssh(self) -> None: machine_dict = { "batch_type": "Slurm", "context_type": "SSHContext", @@ -62,7 +62,7 @@ def test_init_machine_slurm_ssh(self): with self.assertRaises(gaierror): Machine(**machine_dict) - def test_lazy_local(self): + def test_lazy_local(self) -> None: machine_dict = { "batch_type": "PBS", "context_type": "LazyLocalContext", @@ -72,7 +72,7 @@ def test_lazy_local(self): # pylint: disable=maybe-no-member self.assertIsInstance(machine.context, LazyLocalContext) - def test_lower_case(self): + def test_lower_case(self) -> None: machine_dict = { "batch_type": "pbs", "context_type": "lazylocalcontext", @@ -81,7 +81,7 @@ def test_lower_case(self): machine = Machine.load_from_dict(machine_dict=machine_dict) self.assertIsInstance(machine.context, LazyLocalContext) - def test_no_ending_context(self): + def test_no_ending_context(self) -> None: machine_dict = { "batch_type": "PBS", "context_type": "lazylocal", @@ -90,7 +90,7 @@ def test_no_ending_context(self): machine = Machine.load_from_dict(machine_dict=machine_dict) self.assertIsInstance(machine.context, LazyLocalContext) - def test_local(self): + def test_local(self) -> None: machine_dict = { "batch_type": "PBS", "context_type": "LocalContext", @@ -101,7 +101,7 @@ def test_local(self): # pylint: disable=maybe-no-member self.assertIsInstance(machine.context, LocalContext) - def test_ssh(self): + def test_ssh(self) -> None: pass # jdata = { # 'batch_type': 'pbs', @@ -116,19 +116,19 @@ def test_ssh(self): # ) # self.assertIsInstance(batch.context, SSHContext) - def test_key_err(self): + def test_key_err(self) -> None: # pass machine_dict = {} with self.assertRaises(KeyError): Machine.load_from_dict(machine_dict=machine_dict) - def test_context_err(self): + def test_context_err(self) -> None: machine_dict = {"batch_type": "PBS", "context_type": "foo"} # with self.assertRaises(KeyError): with self.assertRaises(ArgumentValueError): Machine.load_from_dict(machine_dict=machine_dict) - def test_pbs(self): + def test_pbs(self) -> None: machine_dict = { "batch_type": "PBS", "context_type": "LazyLocalContext", @@ -137,7 +137,7 @@ def test_pbs(self): machine = Machine.load_from_dict(machine_dict=machine_dict) self.assertIsInstance(machine, PBS) - def test_lsf(self): + def test_lsf(self) -> None: machine_dict = { "batch_type": "LSF", "context_type": "LazyLocalContext", @@ -146,7 +146,7 @@ def test_lsf(self): machine = Machine.load_from_dict(machine_dict=machine_dict) self.assertIsInstance(machine, LSF) - def test_slurm(self): + def test_slurm(self) -> None: machine_dict = { "batch_type": "Slurm", "context_type": "LazyLocalContext", @@ -155,7 +155,7 @@ def test_slurm(self): machine = Machine.load_from_dict(machine_dict=machine_dict) self.assertIsInstance(machine, Slurm) - def test_jh_unischeduler(self): + def test_jh_unischeduler(self) -> None: machine_dict = { "batch_type": "JH_UniScheduler", "context_type": "LazyLocalContext", @@ -164,7 +164,7 @@ def test_jh_unischeduler(self): machine = Machine.load_from_dict(machine_dict=machine_dict) self.assertIsInstance(machine, JH_UniScheduler) - def test_shell(self): + def test_shell(self) -> None: machine_dict = { "batch_type": "Shell", "context_type": "LazyLocalContext", @@ -173,7 +173,7 @@ def test_shell(self): machine = Machine.load_from_dict(machine_dict=machine_dict) self.assertIsInstance(machine, Shell) - def test_distributed_shell(self): + def test_distributed_shell(self) -> None: machine_dict = { "batch_type": "DistributedShell", "context_type": "HDFSContext", @@ -183,7 +183,7 @@ def test_distributed_shell(self): machine = Machine.load_from_dict(machine_dict=machine_dict) self.assertIsInstance(machine, DistributedShell) - def test_lebesgue(self): + def test_lebesgue(self) -> None: machine_dict = { "batch_type": "Lebesgue", "context_type": "LebesgueContext", @@ -201,10 +201,10 @@ def test_lebesgue(self): class TestContextDispatch(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None - def test_init_lazy_local(self): + def test_init_lazy_local(self) -> None: context_dict = { "context_type": "LazyLocalContext", "local_root": "./", @@ -213,7 +213,7 @@ def test_init_lazy_local(self): context = BaseContext(**context_dict) self.assertIsInstance(context, LazyLocalContext) - def test_subclass_init_local(self): + def test_subclass_init_local(self) -> None: context_dict = { "context_type": "LocalContext", "local_root": "./", @@ -223,7 +223,7 @@ def test_subclass_init_local(self): context = LocalContext(**context_dict) self.assertIsInstance(context, LocalContext) - def test_init_local(self): + def test_init_local(self) -> None: context_dict = { "context_type": "LocalContext", "local_root": "./", diff --git a/tests/test_class_resources.py b/tests/test_class_resources.py index 002d7056..1c42d23b 100644 --- a/tests/test_class_resources.py +++ b/tests/test_class_resources.py @@ -15,36 +15,36 @@ class TestResources(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None self.resources = SampleClass.get_sample_resources() self.resources_dict = SampleClass.get_sample_resources_dict() - def test_eq(self): + def test_eq(self) -> None: self.assertEqual(self.resources, SampleClass.get_sample_resources()) - def test_serialize(self): + def test_serialize(self) -> None: self.assertEqual(self.resources.serialize(), self.resources_dict) - def test_deserialize(self): + def test_deserialize(self) -> None: resources = Resources.deserialize(resources_dict=self.resources_dict) self.assertEqual(self.resources, resources) - def test_serialize_deserialize(self): + def test_serialize_deserialize(self) -> None: self.assertEqual( self.resources, Resources.deserialize(resources_dict=self.resources.serialize()), ) - def test_resources_json(self): + def test_resources_json(self) -> None: with open("jsons/resources.json") as f: resources_json_dict = json.load(f) self.assertTrue(resources_json_dict, self.resources_dict) self.assertTrue(resources_json_dict, self.resources.serialize()) - def test_arginfo(self): + def test_arginfo(self) -> None: self.resources.arginfo() - def test_load_from_json(self): + def test_load_from_json(self) -> None: resources = Resources.load_from_json("jsons/resources.json") self.assertTrue(resources, self.resources) diff --git a/tests/test_class_submission.py b/tests/test_class_submission.py index 96425777..b1e7bb95 100644 --- a/tests/test_class_submission.py +++ b/tests/test_class_submission.py @@ -6,6 +6,8 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) __package__ = "tests" +from typing import Any + from .context import ( JobStatus, Submission, @@ -15,7 +17,7 @@ class TestSubmission(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None pbs = SampleClass.get_sample_pbs_local_context() self.submission = SampleClass.get_sample_submission() @@ -24,7 +26,7 @@ def setUp(self): # self.submission2 = Submission.submission_from_json('jsons/submission.json') # self.submission2 = Submission.submission_from_json('jsons/submission.json') - def test_serialize_deserialize(self): + def test_serialize_deserialize(self) -> None: self.assertEqual( self.submission.serialize(), Submission.deserialize( @@ -32,37 +34,39 @@ def test_serialize_deserialize(self): ).serialize(), ) - def test_get_hash(self): + def test_get_hash(self) -> None: pass - def test_bind_machine(self): + def test_bind_machine(self) -> None: self.assertIsNotNone(self.submission.machine.context.submission) for job in self.submission.belonging_jobs: self.assertIsNotNone(job.machine) - def test_get_submision_state(self): + def test_get_submision_state(self) -> None: pass - def test_handle_unexpected_submission_state(self): + def test_handle_unexpected_submission_state(self) -> None: pass - def test_submit_submission(self): + def test_submit_submission(self) -> None: pass - def test_upload_jobs(self): + def test_upload_jobs(self) -> None: pass - def test_download_jobs(self): + def test_download_jobs(self) -> None: pass - def test_submission_to_json(self): + def test_submission_to_json(self) -> None: pass @patch("dpdispatcher.Submission.submission_to_json") @patch("dpdispatcher.Submission.update_submission_state") def test_check_all_finished( - self, patch_update_submission_state, patch_submission_to_json - ): + self, + patch_update_submission_state: Any, # noqa: ANN401 + patch_submission_to_json: Any, # noqa: ANN401 + ) -> None: patch_update_submission_state = MagicMock(return_value=None) patch_submission_to_json = MagicMock(return_value=None) @@ -86,25 +90,25 @@ def test_check_all_finished( self.submission.belonging_jobs[1].job_state = JobStatus.finished self.assertTrue(self.submission.check_all_finished()) - def test_submission_from_json(self): + def test_submission_from_json(self) -> None: submission2 = Submission.submission_from_json("jsons/submission.json") # print('<<<<<<<', self.submission) # print('>>>>>>>', submission2) self.assertEqual(self.submission.serialize(), submission2.serialize()) - def test_submission_json(self): + def test_submission_json(self) -> None: with open("jsons/submission.json") as f: submission_json_dict = json.load(f) self.assertTrue(submission_json_dict, self.submission.serialize()) - def test_try_recover_from_json(self): + def test_try_recover_from_json(self) -> None: pass - def test_repr(self): + def test_repr(self) -> None: submission_repr = repr(self.submission) j = json.dumps(self.submission.serialize(), indent=4) self.assertEqual(submission_repr, j) # self.submission_to_json() - def test_clean(self): + def test_clean(self) -> None: pass diff --git a/tests/test_class_submission_init.py b/tests/test_class_submission_init.py index 60e9112c..b5190609 100644 --- a/tests/test_class_submission_init.py +++ b/tests/test_class_submission_init.py @@ -11,12 +11,12 @@ class TestSubmissionInit(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None # self.empty_submission = SampleClass.get_sample_empty_submission() # print('TestSubmissionInit.setUp:self.empty_submission.belonging_tasks', self.empty_submission.belonging_tasks) - def test_reigister_task(self): + def test_reigister_task(self) -> None: empty_submission = SampleClass.get_sample_empty_submission() task = SampleClass.get_sample_task() # print('TestSubmissionInit.test_reigister_task:self.empty_submission.belonging_tasks', empty_submission.belonging_tasks) @@ -24,7 +24,7 @@ def test_reigister_task(self): # print('7890809', SampleClass.get_sample_empty_submission().belonging_tasks) self.assertEqual([task], empty_submission.belonging_tasks) - def test_reigister_task_whether_copy(self): + def test_reigister_task_whether_copy(self) -> None: empty_submission = SampleClass.get_sample_empty_submission() task = SampleClass.get_sample_task() empty_submission.register_task(task=task) diff --git a/tests/test_class_task.py b/tests/test_class_task.py index 033a9b2f..018775e6 100644 --- a/tests/test_class_task.py +++ b/tests/test_class_task.py @@ -16,27 +16,27 @@ class TestTask(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.task = SampleClass.get_sample_task() self.task_dict = SampleClass.get_sample_task_dict() - def test_serialize(self): + def test_serialize(self) -> None: self.assertEqual(self.task.serialize(), self.task_dict) - def test_deserialize(self): + def test_deserialize(self) -> None: task = Task.deserialize(task_dict=self.task_dict) self.assertTrue(task, self.task) - def test_serialize_deserialize(self): + def test_serialize_deserialize(self) -> None: self.assertEqual(Task.deserialize(task_dict=self.task.serialize()), self.task) - def test_task_json(self): + def test_task_json(self) -> None: with open("jsons/task.json") as f: task_json_dict = json.load(f) self.assertTrue(task_json_dict, self.task_dict) self.assertTrue(task_json_dict, self.task.serialize()) - def test_repr(self): + def test_repr(self) -> None: task_repr = repr(self.task) print("debug:", task_repr, self.task_dict) self.assertEqual(task_repr, str(self.task_dict)) diff --git a/tests/test_cli.py b/tests/test_cli.py index 12f5606f..a2c217e2 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -3,7 +3,7 @@ class TestCLI(unittest.TestCase): - def test_cli(self): + def test_cli(self) -> None: sp.check_output(["dpdisp", "-h"]) for subcommand in ( "submission", diff --git a/tests/test_examples.py b/tests/test_examples.py index 640e8998..6a0d02bd 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -34,7 +34,7 @@ class TestExamples(unittest.TestCase): - def test_arguments(self): + def test_arguments(self) -> None: for arginfo, fn in input_files: fn = str(fn) with self.subTest(fn=fn): diff --git a/tests/test_group_size.py b/tests/test_group_size.py index cb55e7c6..d370107f 100644 --- a/tests/test_group_size.py +++ b/tests/test_group_size.py @@ -36,7 +36,7 @@ class TestGroupSize(TestCase): - def test_works_as_expected(self): + def test_works_as_expected(self) -> None: for group_size, ntasks in group_ntasks_pairs: with self.subTest(group_size): machine = Machine.load_from_dict(j_machine) diff --git a/tests/test_gui.py b/tests/test_gui.py index 25fd7e66..2b2fb851 100644 --- a/tests/test_gui.py +++ b/tests/test_gui.py @@ -7,5 +7,5 @@ class TestDPGUI(unittest.TestCase): - def test_dpgui_entrypoints(self): + def test_dpgui_entrypoints(self) -> None: self.assertTrue(len(generate_dpgui_templates()) > 0) diff --git a/tests/test_hdfs_context.py b/tests/test_hdfs_context.py index 75e5307c..e9df9cf5 100644 --- a/tests/test_hdfs_context.py +++ b/tests/test_hdfs_context.py @@ -21,7 +21,7 @@ @unittest.skipIf(not shutil.which("hadoop"), "requires hadoop") class TestHDFSContext(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: with open("jsons/machine_yarn.json") as f: mdata = json.load(f) cls.machine = Machine.load_from_dict(mdata["machine"]) @@ -29,16 +29,16 @@ def setUpClass(cls): cls.submission.bind_machine(cls.machine) cls.submission_hash = cls.submission.submission_hash - def setUp(self): + def setUp(self) -> None: self.context = self.__class__.machine.context - def test_0_hdfs_context(self): + def test_0_hdfs_context(self) -> None: self.assertIsInstance(self.context, HDFSContext) - def test_1_upload(self): + def test_1_upload(self) -> None: self.context.upload(self.__class__.submission) - def test_2_fake_run(self): + def test_2_fake_run(self) -> None: rfile_tgz = ( self.context.remote_root + "/" @@ -79,7 +79,7 @@ def test_2_fake_run(self): os.chdir(cwd) shutil.rmtree(tmp_dir) - def test_3_download(self): + def test_3_download(self) -> None: self.context.download(self.__class__.submission) file_list = [ "bct-1/log.lammps", diff --git a/tests/test_import_classes.py b/tests/test_import_classes.py index b6ab68b5..d8f75270 100644 --- a/tests/test_import_classes.py +++ b/tests/test_import_classes.py @@ -11,25 +11,25 @@ class TestImportClasses(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None - def test_import_class_Machine(self): + def test_import_class_Machine(self) -> None: from dpdispatcher import Machine self.assertEqual(dpdispatcher.machine.Machine, Machine) - def test_import_class_Resources(self): + def test_import_class_Resources(self) -> None: from dpdispatcher import Resources self.assertEqual(dpdispatcher.submission.Resources, Resources) - def test_import_class_Submission(self): + def test_import_class_Submission(self) -> None: from dpdispatcher import Submission self.assertEqual(dpdispatcher.submission.Submission, Submission) - def test_import_class_Task(self): + def test_import_class_Task(self) -> None: from dpdispatcher import Task self.assertEqual(dpdispatcher.submission.Task, Task) diff --git a/tests/test_lazy_local_context.py b/tests/test_lazy_local_context.py index 47bffb60..47bc54f1 100644 --- a/tests/test_lazy_local_context.py +++ b/tests/test_lazy_local_context.py @@ -13,7 +13,7 @@ class TestLazyLocalContext(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: # os.makedirs('loc', exist_ok = True) # os.makedirs('loc/task0', exist_ok = True) # os.makedirs('loc/task1', exist_ok = True) @@ -25,18 +25,18 @@ def setUp(self): submission = MagicMock(work_base="0_md/") self.lazy_local_context.bind_submission(submission) - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree("tmp_lazy_local_context_dir/") - def test_upload(self): + def test_upload(self) -> None: pass - def test_download(self): + def test_download(self) -> None: pass # TODO: support other platforms @unittest.skipIf(sys.platform != "linux", "not linux") - def test_block_call(self): + def test_block_call(self) -> None: code, stdin, stdout, stderr = self.lazy_local_context.block_call("ls") self.assertEqual( stdout.readlines(), diff --git a/tests/test_local_context.py b/tests/test_local_context.py index 6b506fd9..6920c961 100644 --- a/tests/test_local_context.py +++ b/tests/test_local_context.py @@ -20,7 +20,7 @@ # from .context import dpd -def _identical_files(fname0, fname1): +def _identical_files(fname0: str, fname1: str) -> bool: with open(fname0) as fp: code0 = hashlib.sha1(fp.read().encode("utf-8")).hexdigest() with open(fname1) as fp: @@ -29,7 +29,7 @@ def _identical_files(fname0, fname1): class TestIdFile(unittest.TestCase): - def test_id(self): + def test_id(self) -> None: with open("f0", "w") as fp: fp.write("foo") with open("f1", "w") as fp: @@ -38,7 +38,7 @@ def test_id(self): os.remove("f0") os.remove("f1") - def test_diff(self): + def test_diff(self) -> None: with open("f0", "w") as fp: fp.write("foo") with open("f1", "w") as fp: @@ -49,7 +49,7 @@ def test_diff(self): class TestLocalContext(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.tmp_local_root = "test_context_dir/" self.tmp_remote_root = "tmp_local_context_remote_root/" self.local_context = LocalContext( @@ -57,10 +57,10 @@ def setUp(self): ) @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: shutil.rmtree("tmp_local_context_remote_root/") - def test_upload_non_exist(self): + def test_upload_non_exist(self) -> None: submission_hash = "mock_hash_1" task1 = MagicMock(task_work_path="bct-1/", forward_files=["foo.py"]) submission = MagicMock( @@ -72,7 +72,7 @@ def test_upload_non_exist(self): with self.assertRaises(FileNotFoundError): self.local_context.upload(submission) - def test_upload(self): + def test_upload(self) -> None: submission_hash = "mock_hash_2" task1 = MagicMock( task_work_path="bct-1/", @@ -107,7 +107,7 @@ def test_upload(self): # TODO: support other platforms @unittest.skipIf(sys.platform != "linux", "not linux") - def test_block_call(self): + def test_block_call(self) -> None: submission_hash = "mock_hash_3" task1 = MagicMock( task_work_path="bct-1/", forward_files=["input.lammps", "conf.lmp"] @@ -134,7 +134,7 @@ def test_block_call(self): self.assertTrue("No such file or directory\n" in err_msg) @unittest.skipIf(sys.platform == "win32", "sleep is not supported on Windows") - def test_call(self): + def test_call(self) -> None: submission_hash = "mock_hash_4" submission = MagicMock( work_base="0_md/", belonging_tasks=[], submission_hash=submission_hash @@ -161,7 +161,7 @@ def test_call(self): # self.assertEqual(o, None) # self.assertEqual(e, None) - def test_file(self): + def test_file(self) -> None: submission_hash = "mock_hash_5" submission = MagicMock( work_base="0_md/", belonging_tasks=[], submission_hash=submission_hash @@ -180,7 +180,7 @@ def test_file(self): class TestLocalContextDownload(unittest.TestCase): # @classmethod # def setUpClass(cls): - def setUp(self): + def setUp(self) -> None: shutil.copytree(src="test_context_dir/", dst="tmp_local_context_download_dir/") os.makedirs("tmp_local_context_backfill_dir/0_md/bct-1/") os.makedirs("tmp_local_context_backfill_dir/0_md/bct-2/") @@ -191,11 +191,11 @@ def setUp(self): local_root=self.tmp_local_root, remote_root=self.tmp_remote_root ) - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree("tmp_local_context_download_dir/") shutil.rmtree("tmp_local_context_backfill_dir/") - def test_download_trival(self): + def test_download_trival(self) -> None: # submission_hash = 'mock_hash_2' task1 = MagicMock( task_work_path="bct-1/", backward_files=["input.lammps", "conf.lmp"] @@ -227,7 +227,7 @@ def test_download_trival(self): self.assertTrue(os.path.isfile(f1)) self.assertFalse(os.path.islink(f1)) - def test_download_check_exists(self): + def test_download_check_exists(self) -> None: task1 = MagicMock(task_work_path="bct-1/", backward_files=["foo.py"]) submission = MagicMock( work_base="0_md/", @@ -239,7 +239,7 @@ def test_download_check_exists(self): with self.assertRaises(FileNotFoundError): self.local_context.download(submission, check_exists=False) - def test_download_mark_failure_tag(self): + def test_download_mark_failure_tag(self) -> None: task1 = MagicMock(task_work_path="bct-1/", backward_files=["foo.py"]) submission = MagicMock( work_base="0_md/", @@ -256,7 +256,7 @@ def test_download_mark_failure_tag(self): ) self.assertTrue(os.path.isfile(tag_file)) - def test_download_replace_old_files(self): + def test_download_replace_old_files(self) -> None: task1 = MagicMock(task_work_path="bct-1/", backward_files=["input.lammps"]) submission = MagicMock( work_base="0_md/", @@ -276,7 +276,7 @@ def test_download_replace_old_files(self): md5_new = get_file_md5(target_file) self.assertNotEqual(md5_old, md5_new) - def test_download_symlink(self): + def test_download_symlink(self) -> None: task1 = MagicMock( task_work_path="bct-1/", backward_files=["input.lammps.symlink"] ) diff --git a/tests/test_lsf_script_generation.py b/tests/test_lsf_script_generation.py index aa1cb92c..bb37f65d 100755 --- a/tests/test_lsf_script_generation.py +++ b/tests/test_lsf_script_generation.py @@ -20,10 +20,10 @@ class TestLSFScriptGeneration(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None - def test_shell_trival(self): + def test_shell_trival(self) -> None: with open("jsons/machine_lazy_local_lsf.json") as f: machine_dict = json.load(f) @@ -141,7 +141,7 @@ def test_shell_trival(self): self.assertEqual(footer_str, benchmark_footer) @unittest.skipIf(sys.platform == "win32", "skip for persimission error") - def test_template(self): + def test_template(self) -> None: with open("jsons/machine_lazy_local_lsf.json") as f: machine_dict = json.load(f) diff --git a/tests/test_retry.py b/tests/test_retry.py index a4f7cba3..777b4be2 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -4,6 +4,8 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) __package__ = "tests" +from typing import List, NoReturn + from .context import ( RetrySignal, retry, @@ -12,22 +14,22 @@ class TestRetry(unittest.TestCase): - def test_retry_fail(self): + def test_retry_fail(self) -> None: """Always retry.""" @retry(max_retry=3, sleep=0.05, catch_exception=RetrySignal) - def some_method(): + def some_method() -> NoReturn: raise RetrySignal("Failed to do something") with self.assertRaises(RuntimeError): some_method() - def test_retry_success(self): + def test_retry_success(self) -> None: """Retry less than 3 times.""" retry_times = [0] @retry(max_retry=3, sleep=0.05, catch_exception=RetrySignal) - def some_method(retry_times): + def some_method(retry_times: List[int]) -> None: if retry_times[0] < 2: retry_times[0] += 1 raise RetrySignal("Failed to do something") diff --git a/tests/test_rsync_flags.py b/tests/test_rsync_flags.py index 13ebe6e0..73c8be2e 100644 --- a/tests/test_rsync_flags.py +++ b/tests/test_rsync_flags.py @@ -1,6 +1,7 @@ import os import sys import unittest +from typing import Any from unittest.mock import patch sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -13,7 +14,7 @@ class TestRsyncFlags(unittest.TestCase): """Test rsync function flags to ensure correct options are used.""" @patch("dpdispatcher.utils.utils.run_cmd_with_all_output") - def test_rsync_flags_exclude_owner_group(self, mock_run_cmd): + def test_rsync_flags_exclude_owner_group(self, mock_run_cmd: Any) -> None: # noqa: ANN401 """Test that rsync uses flags that exclude owner and group preservation.""" # Mock successful command execution mock_run_cmd.return_value = (0, "", "") @@ -39,7 +40,7 @@ def test_rsync_flags_exclude_owner_group(self, mock_run_cmd): self.assertIn("-q", called_cmd) @patch("dpdispatcher.utils.utils.run_cmd_with_all_output") - def test_rsync_with_proxy_command_flags(self, mock_run_cmd): + def test_rsync_with_proxy_command_flags(self, mock_run_cmd: Any) -> None: # noqa: ANN401 """Test that rsync uses correct flags even with proxy command.""" # Mock successful command execution mock_run_cmd.return_value = (0, "", "") @@ -63,7 +64,7 @@ def test_rsync_with_proxy_command_flags(self, mock_run_cmd): self.assertNotIn("-az", called_cmd) @patch("dpdispatcher.utils.utils.run_cmd_with_all_output") - def test_rsync_error_handling(self, mock_run_cmd): + def test_rsync_error_handling(self, mock_run_cmd: Any) -> None: # noqa: ANN401 """Test that rsync properly handles errors.""" # Mock failed command execution mock_run_cmd.return_value = ( diff --git a/tests/test_rsync_proxy.py b/tests/test_rsync_proxy.py index 9ac1d9be..1e27eff9 100644 --- a/tests/test_rsync_proxy.py +++ b/tests/test_rsync_proxy.py @@ -16,7 +16,7 @@ class TestRsyncProxyCommand(unittest.TestCase): """Test rsync function with proxy command support.""" - def setUp(self): + def setUp(self) -> None: """Set up test files for rsync operations.""" # Check if rsync is available before running tests if shutil.which("rsync") is None: @@ -35,12 +35,12 @@ def setUp(self): self.remote_file_direct = f"root@server:{self.remote_test_dir}/test_direct.txt" self.remote_file_proxy = f"root@server:{self.remote_test_dir}/test_proxy.txt" - def tearDown(self): + def tearDown(self) -> None: """Clean up test files.""" # Remove local test file os.unlink(self.local_file.name) - def test_rsync_with_proxy_command(self): + def test_rsync_with_proxy_command(self) -> None: """Test rsync with proxy command via jump host.""" # Test rsync through jump host: test -> jumphost -> server rsync( @@ -69,7 +69,7 @@ def test_rsync_with_proxy_command(self): # Clean up os.unlink(download_path) - def test_rsync_direct_connection(self): + def test_rsync_direct_connection(self) -> None: """Test rsync without proxy command (direct connection).""" # Test direct rsync: test -> server rsync( @@ -92,7 +92,7 @@ def test_rsync_direct_connection(self): # Clean up os.unlink(download_path) - def test_rsync_with_additional_options(self): + def test_rsync_with_additional_options(self) -> None: """Test rsync with proxy command and additional SSH options.""" # Test rsync with custom port, timeout, and proxy rsync( diff --git a/tests/test_run.py b/tests/test_run.py index 80bae597..fdc08525 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -11,7 +11,7 @@ class TestRun(unittest.TestCase): - def test_run(self): + def test_run(self) -> None: this_dir = Path(__file__).parent cwd = os.getcwd() with tempfile.TemporaryDirectory() as temp_dir: diff --git a/tests/test_run_submission.py b/tests/test_run_submission.py index 958eba7d..5da65c12 100644 --- a/tests/test_run_submission.py +++ b/tests/test_run_submission.py @@ -5,6 +5,7 @@ import sys import tempfile import traceback +from typing import Any sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) __package__ = "tests" @@ -22,7 +23,7 @@ class RunSubmission: - def setUp(self): + def setUp(self) -> None: self.machine_dict = { "batch_type": "Shell", "context_type": "LocalContext", @@ -58,7 +59,7 @@ def setUp(self): ) as f: f.write("inp space") - def test_run_submission(self): + def test_run_submission(self) -> None: machine = Machine.load_from_dict(self.machine_dict) resources = Resources.load_from_dict(self.resources_dict) @@ -124,7 +125,7 @@ def test_run_submission(self): ) ) - def test_failed_submission(self): + def test_failed_submission(self) -> None: machine = Machine.load_from_dict(self.machine_dict) resources = Resources.load_from_dict(self.resources_dict) @@ -166,12 +167,12 @@ def test_failed_submission(self): clean=True, ) - def test_async_run_submission(self): + def test_async_run_submission(self) -> None: machine = Machine.load_from_dict(self.machine_dict) resources = Resources.load_from_dict(self.resources_dict) ntask = 4 - async def run_jobs(ntask): + async def run_jobs(ntask: int) -> Any: # noqa: ANN401 background_tasks = set() for ii in range(ntask): sleep_time = random.random() * 5 + 2 @@ -210,7 +211,7 @@ async def run_jobs(ntask): ) ) - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(os.path.join(self.machine_dict["local_root"])) @@ -219,13 +220,13 @@ def tearDown(self): "outside the slurm testing environment", ) class TestSlurmRun(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.machine_dict["batch_type"] = "Slurm" self.resources_dict["queue_name"] = "normal" @unittest.skip("Manaually skip") # comment this line to open unittest - def test_async_run_submission(self): + def test_async_run_submission(self) -> None: return super().test_async_run_submission() @@ -234,13 +235,13 @@ def test_async_run_submission(self): "outside the slurm testing environment", ) class TestSlurmJobArrayRun(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.machine_dict["batch_type"] = "SlurmJobArray" self.resources_dict["queue_name"] = "normal" @unittest.skip("Manaually skip") # comment this line to open unittest - def test_async_run_submission(self): + def test_async_run_submission(self) -> None: return super().test_async_run_submission() @@ -249,14 +250,14 @@ def test_async_run_submission(self): "outside the slurm testing environment", ) class TestSlurmJobArrayRun2(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.machine_dict["batch_type"] = "SlurmJobArray" self.resources_dict["queue_name"] = "normal" self.resources_dict["kwargs"] = {"slurm_job_size": 2} @unittest.skip("Manaually skip") # comment this line to open unittest - def test_async_run_submission(self): + def test_async_run_submission(self) -> None: return super().test_async_run_submission() @@ -264,53 +265,53 @@ def test_async_run_submission(self): os.environ.get("DPDISPATCHER_TEST") != "pbs", "outside the pbs testing environment" ) class TestPBSRun(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.machine_dict["batch_type"] = "PBS" self.resources_dict["queue_name"] = "workq" @unittest.skip("Manaually skip") # comment this line to open unittest - def test_async_run_submission(self): + def test_async_run_submission(self) -> None: return super().test_async_run_submission() @unittest.skipIf(sys.platform == "win32", "Shell is not supported on Windows") class TestLocalContext(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.temp_dir = tempfile.TemporaryDirectory() self.machine_dict["context_type"] = "LocalContext" self.machine_dict["remote_root"] = self.temp_dir.name - def tearDown(self): + def tearDown(self) -> None: super().tearDown() self.temp_dir.cleanup() @unittest.skip("It seems the remote file may be deleted") - def test_async_run_submission(self): + def test_async_run_submission(self) -> None: return super().test_async_run_submission() @unittest.skipIf(sys.platform == "win32", "Shell is not supported on Windows") class TestLocalContextCopy(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.temp_dir = tempfile.TemporaryDirectory() self.machine_dict["context_type"] = "LocalContext" self.machine_dict["remote_root"] = self.temp_dir.name self.machine_dict["remote_profile"]["symlink"] = False - def tearDown(self): + def tearDown(self) -> None: super().tearDown() self.temp_dir.cleanup() @unittest.skip("It seems the remote file may be deleted") - def test_async_run_submission(self): + def test_async_run_submission(self) -> None: return super().test_async_run_submission() @unittest.skipIf(sys.platform == "win32", "Shell is not supported on Windows") class TestLazyLocalContext(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.machine_dict["context_type"] = "LazyLocalContext" diff --git a/tests/test_run_submission_bohrium.py b/tests/test_run_submission_bohrium.py index d4d58898..655da45c 100644 --- a/tests/test_run_submission_bohrium.py +++ b/tests/test_run_submission_bohrium.py @@ -3,6 +3,7 @@ import textwrap import unittest from pathlib import Path +from typing import Any sys.path.insert(0, os.path.abspath(os.path.dirname(__file__))) @@ -14,7 +15,7 @@ "outside the Bohrium testing environment", ) class TestBohriumRun(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.machine_dict.update( batch_type="Bohrium", @@ -37,7 +38,7 @@ def setUp(self): ) @unittest.skip("Manaually skip") # comment this line to open unittest - def test_async_run_submission(self): + def test_async_run_submission(self) -> Any: # noqa: ANN401 return super().test_async_run_submission() @@ -46,7 +47,7 @@ def test_async_run_submission(self): "outside the Bohrium testing environment", ) class TestOpenAPIRun(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() bohrium_config = textwrap.dedent( """\ @@ -68,5 +69,5 @@ def setUp(self): ) @unittest.skip("Manaually skip") # comment this line to open unittest - def test_async_run_submission(self): + def test_async_run_submission(self) -> Any: # noqa: ANN401 return super().test_async_run_submission() diff --git a/tests/test_run_submission_ratio_unfinished.py b/tests/test_run_submission_ratio_unfinished.py index 98942ea6..405531e3 100644 --- a/tests/test_run_submission_ratio_unfinished.py +++ b/tests/test_run_submission_ratio_unfinished.py @@ -18,7 +18,7 @@ class RunSubmission: - def setUp(self): + def setUp(self) -> None: self.machine_dict = { "batch_type": "Shell", "context_type": "LocalContext", @@ -55,7 +55,7 @@ def setUp(self): ) as f: f.write("inp space") - def test_run_submission(self): + def test_run_submission(self) -> None: machine = Machine.load_from_dict(self.machine_dict) resources = Resources.load_from_dict(self.resources_dict) @@ -90,7 +90,7 @@ def test_run_submission(self): ) submission.run_submission() - def tearDown(self): + def tearDown(self) -> None: shutil.rmtree(os.path.join(self.machine_dict["local_root"])) @@ -99,7 +99,7 @@ def tearDown(self): "outside the slurm testing environment", ) class TestSlurmRun(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.machine_dict["batch_type"] = "Slurm" self.resources_dict["queue_name"] = "normal" @@ -110,7 +110,7 @@ def setUp(self): "outside the slurm testing environment", ) class TestSlurmJobArrayRun(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.machine_dict["batch_type"] = "SlurmJobArray" self.resources_dict["queue_name"] = "normal" @@ -121,7 +121,7 @@ def setUp(self): "outside the slurm testing environment", ) class TestSlurmJobArrayRun2(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.machine_dict["batch_type"] = "SlurmJobArray" self.resources_dict["queue_name"] = "normal" @@ -132,7 +132,7 @@ def setUp(self): os.environ.get("DPDISPATCHER_TEST") != "pbs", "outside the pbs testing environment" ) class TestPBSRun(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.machine_dict["batch_type"] = "PBS" self.resources_dict["queue_name"] = "workq" @@ -140,6 +140,6 @@ def setUp(self): @unittest.skipIf(sys.platform == "win32", "Shell is not supported on Windows") class TestLazyLocalContext(RunSubmission, unittest.TestCase): - def setUp(self): + def setUp(self) -> None: super().setUp() self.machine_dict["context_type"] = "LazyLocalContext" diff --git a/tests/test_shell_cuda_multi_devices.py b/tests/test_shell_cuda_multi_devices.py index f55ed7d9..0e020ff9 100755 --- a/tests/test_shell_cuda_multi_devices.py +++ b/tests/test_shell_cuda_multi_devices.py @@ -19,10 +19,10 @@ @unittest.skipIf(sys.platform == "win32", "Shell is not supported on Windows") class TestShellCudaMultiDevices(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None - def test_shell_cuda_multi_devices(self): + def test_shell_cuda_multi_devices(self) -> None: with open("jsons/machine_if_cuda_multi_devices.json") as f: machine_dict = json.load(f) machine = Machine.load_from_dict(machine_dict["machine"]) @@ -59,6 +59,6 @@ def test_shell_cuda_multi_devices(self): self.assertTrue(os.path.isfile("test_if_cuda_multi_devices/test_dir/out.txt")) @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: shutil.rmtree("tmp_if_cuda_multi_devices/") # pass diff --git a/tests/test_shell_trival.py b/tests/test_shell_trival.py index 241fb9bc..ea573e02 100755 --- a/tests/test_shell_trival.py +++ b/tests/test_shell_trival.py @@ -21,14 +21,14 @@ @unittest.skipIf(sys.platform == "win32", "Shell is not supported on Windows") class TestShellTrival(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None # self.local_context_dict = { # 'local_root': './test_shell_trival_dir', # 'remote_root': './tmp_shell_trival_dir' # } - def test_shell_trival(self): + def test_shell_trival(self) -> None: with open("jsons/machine_local_shell.json") as f: machine_dict = json.load(f) @@ -82,7 +82,7 @@ def test_shell_trival(self): f2 = os.path.join("test_shell_trival_dir/", "parent_dir/", dir, "out.txt") self.assertEqual(get_file_md5(f1), get_file_md5(f2)) - def test_shell_fail(self): + def test_shell_fail(self) -> None: with open("jsons/machine_local_shell.json") as f: machine_dict = json.load(f) @@ -112,7 +112,7 @@ def test_shell_fail(self): with self.assertRaises(RuntimeError): submission.run_submission() - def test_shell_recover(self): + def test_shell_recover(self) -> None: with open("jsons/machine_lazylocal_shell.json") as f: machine_dict = json.load(f) @@ -121,7 +121,7 @@ def test_shell_recover(self): pass - def test_dir_with_space(self): + def test_dir_with_space(self) -> None: """Test directory with space.""" with open("jsons/machine_local_shell.json") as f: machine_dict = json.load(f) @@ -156,5 +156,5 @@ def test_dir_with_space(self): self.assertEqual(get_file_md5(f1), get_file_md5(f2)) @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: shutil.rmtree("tmp_shell_trival_dir/") diff --git a/tests/test_slurm_script_generation.py b/tests/test_slurm_script_generation.py index 90ac3211..5eda2b0c 100755 --- a/tests/test_slurm_script_generation.py +++ b/tests/test_slurm_script_generation.py @@ -20,10 +20,10 @@ class TestSlurmScriptGeneration(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.maxDiff = None - def test_shell_trival(self): + def test_shell_trival(self) -> None: with open("jsons/machine_lazy_local_slurm.json") as f: machine_dict = json.load(f) @@ -82,7 +82,7 @@ def test_shell_trival(self): self.assertEqual(str, benchmark_str) @unittest.skipIf(sys.platform == "win32", "skip for persimission error") - def test_template(self): + def test_template(self) -> None: with open("jsons/machine_lazy_local_slurm.json") as f: machine_dict = json.load(f) diff --git a/tests/test_ssh_context.py b/tests/test_ssh_context.py index 05a1e986..476cf019 100644 --- a/tests/test_ssh_context.py +++ b/tests/test_ssh_context.py @@ -23,7 +23,7 @@ ) class TestSSHContext(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: mdata = { "batch_type": "Shell", "context_type": "SSHContext", @@ -60,18 +60,18 @@ def setUpClass(cls): cls.machine.context.write_file(file, "# mock log") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: cls.machine.context.clean() # close the server cls.machine.context.close() - def setUp(self): + def setUp(self) -> None: self.context = self.__class__.machine.context - def test_ssh_session(self): + def test_ssh_session(self) -> None: self.assertIsInstance(self.__class__.machine.context.ssh_session, SSHSession) - def test_upload(self): + def test_upload(self) -> None: self.context.upload(self.__class__.submission) check_file_list = [ "graph.pb", @@ -86,7 +86,7 @@ def test_upload(self): ) ) - def test_empty_transfer(self): + def test_empty_transfer(self) -> None: # Both forward_files and backward_files are empty machine = Machine.load_from_dict(self.machine.serialize()) resources = Resources.load_from_dict( @@ -116,7 +116,7 @@ def test_empty_transfer(self): ) submission.run_submission() - def test_recover(self): + def test_recover(self) -> None: """Test recover from a previous submission.""" machine = Machine.load_from_dict(self.machine.serialize()) resources = Resources.load_from_dict( @@ -163,7 +163,7 @@ def test_recover(self): ) submission.run_submission() - def test_download(self): + def test_download(self) -> None: self.context.download(self.__class__.submission) @@ -172,7 +172,7 @@ def test_download(self): ) class TestSSHContextNoCompress(unittest.TestCase): @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: mdata = { "batch_type": "Shell", "context_type": "SSHContext", @@ -208,18 +208,18 @@ def setUpClass(cls): cls.machine.context.write_file(file, "# mock log") @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: cls.machine.context.clean() # close the server cls.machine.context.close() - def setUp(self): + def setUp(self) -> None: self.context = self.__class__.machine.context - def test_ssh_session(self): + def test_ssh_session(self) -> None: self.assertIsInstance(self.__class__.machine.context.ssh_session, SSHSession) - def test_upload(self): + def test_upload(self) -> None: self.context.upload(self.__class__.submission) check_file_list = [ "graph.pb", @@ -234,5 +234,5 @@ def test_upload(self): ) ) - def test_download(self): + def test_download(self) -> None: self.context.download(self.__class__.submission) diff --git a/tests/test_ssh_jump_host.py b/tests/test_ssh_jump_host.py index 45225ac4..85224ee8 100644 --- a/tests/test_ssh_jump_host.py +++ b/tests/test_ssh_jump_host.py @@ -16,7 +16,7 @@ class TestSSHJumpHost(unittest.TestCase): """Test SSH jump host functionality.""" - def test_proxy_command_connection(self): + def test_proxy_command_connection(self) -> None: """Test SSH connection using proxy_command via jump host.""" # Test connection from test -> server via jumphost ssh_session = SSHSession( @@ -44,7 +44,7 @@ def test_proxy_command_connection(self): ssh_session.close() - def test_direct_connection_no_proxy(self): + def test_direct_connection_no_proxy(self) -> None: """Test direct SSH connection without proxy command.""" # Test direct connection from test -> server (no proxy) ssh_session = SSHSession( @@ -66,7 +66,7 @@ def test_direct_connection_no_proxy(self): ssh_session.close() - def test_jump_host_direct_connection(self): + def test_jump_host_direct_connection(self) -> None: """Test direct connection to jump host itself.""" # Test direct connection from test -> jumphost ssh_session = SSHSession(