Source code for mlonmcu.setup.utils

#
# Copyright (c) 2022 TUM Department of Electrical and Computer Engineering.
#
# This file is part of MLonMCU.
# See https://github.com/tum-ei-eda/mlonmcu.git for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import signal
import sys
import multiprocessing
import subprocess
import tarfile
import zipfile
import shutil
import tempfile
import hashlib
import urllib.request
from pathlib import Path
from packaging.version import Version
from typing import Union, List, Callable, Optional
from git import Repo
from tqdm import tqdm

from mlonmcu import logging

# from mlonmcu.context.context import MlonMcuContext
from mlonmcu.environment.config import RepoConfig

logger = logging.get_logger()


def makeFlags(*args):
    """Collect the names of all enabled ``(check, name)`` pairs.

    Parameters
    ----------
    args
        Tuples of the form ``(check, name)``, e.g. ``(True, "foo"), (False, "bar")``.

    Returns
    -------
    flags : list
        The names whose ``check`` value is truthy, in the order they were given.

    Examples
    --------
    >>> makeFlags((True, "foo"), (False, "bar"))
    ['foo']
    """
    enabled = []
    for check, name in args:
        if check:
            enabled.append(name)
    return enabled
def makeDirName(base: str, *args, flags: list = None) -> str:
    """Build a snake_case directory name from a prefix and configuration flags.

    Parameters
    ----------
    base : str
        Prefix of the name to be generated.
    args
        Tuples of the form ``(check, name)``; only names with a truthy check are used.
    flags : list
        Optional list of additional names appended after the ones selected from ``args``.

    Returns
    -------
    dirname : str
        All selected parts joined with underscores.

    Examples
    --------
    >>> makeDirName("base", (True, "foo"), (False, "bar"), flags=["flag"])
    'base_foo_flag'
    """
    enabled = makeFlags(*args)
    components = [base, *enabled]
    if not flags:
        return "_".join(components)
    return "_".join(components + flags)
def exec(*args, **kwargs):
    """Deprecated shim around :func:`execute` which discards the process output.

    Parameters
    ----------
    args
        The command to be executed.
    kwargs
        Keyword arguments forwarded to ``execute`` (and from there to subprocess).
    """
    logger.warning("DEPRECATED: Please use utils.execute(..., ignore_output=True) instead of utils.exec(...)")
    # Delegate to the new implementation; nothing is returned to the caller.
    execute(*args, ignore_output=True, live=False, print_func=None, **kwargs)
def exec_getout(*args, live: bool = False, print_output: bool = False, handle_exit=None, prefix="", **kwargs) -> str:
    """Deprecated shim around :func:`execute` which returns the process output.

    Parameters
    ----------
    args
        The command to be executed.
    live : bool
        If the stdout should be updated in real time.
    print_output : bool
        Log the collected output (debug level) at the end in non-live mode.
    handle_exit : Callable, optional
        Handler for the exit code, forwarded to ``execute``.
    prefix : str
        Text prepended to the output, forwarded to ``execute``.
    kwargs
        Keyword arguments forwarded to ``execute`` (and from there to subprocess).

    Returns
    -------
    output
        The text printed to the command line.
    """
    logger.warning("DEPRECATED: Please use utils.execute(...) instead of utils.exec_getout(...)")
    out = execute(
        *args,
        ignore_output=False,
        live=live,
        handle_exit=handle_exit,
        err_func=logger.error,
        prefix=prefix,
        **kwargs,
    )
    # The flag was silently dropped when this function became a wrapper; honor it
    # again. In live mode the output has already been printed line by line.
    if print_output and not live:
        logger.debug("%s", out)
    return out
def execute(
    *args: List[str],
    ignore_output: bool = False,
    live: bool = False,
    print_func: Callable = print,
    handle_exit: Optional[Callable] = None,
    err_func: Callable = logger.error,
    encoding: Optional[str] = "utf-8",
    stdin_data: Optional[bytes] = None,
    prefix: str = "",
    **kwargs,
) -> str:
    """Wrapper for running a program in a subprocess.

    Parameters
    ----------
    args : list
        The actual command.
    ignore_output : bool
        Do not capture the stdout and stderr of the subprocess (incompatible with ``live``).
    live : bool
        Print the output line by line instead of only at the end.
    print_func : Callable
        Function which should be used to print sysout messages (live mode only).
    handle_exit: Callable
        Handler for exit code. Called as ``handle_exit(exit_code, out=text)`` and
        expected to return the (possibly rewritten) exit code.
    err_func : Callable
        Function which should be used to print errors.
    encoding: str, optional
        Used encoding for the stdout. If unset, the raw bytes are returned and
        ``prefix`` is not applied.
    stdin_data: bytes, optional
        Send this to the stdin of the process (only supported if ``live=False``).
    prefix : str
        Text prepended to each decoded line (live) or to the decoded output (non-live).
    kwargs: dict
        Arbitrary keyword arguments passed through to the subprocess.

    Returns
    -------
    out : str
        The command line output of the command (bytes if ``encoding`` is unset,
        None if ``ignore_output`` is set).

    Raises
    ------
    AssertionError
        If the (possibly handled) exit code is non-zero.
    subprocess.CalledProcessError
        If ``ignore_output`` is set and the command fails.
    RuntimeError
        If ``stdin_data`` is combined with ``live=True``.
    """
    logger.debug("- Executing: %s", str(args))
    if "cwd" in kwargs:
        logger.debug("- CWD: %s", str(kwargs["cwd"]))
    # if "env" in kwargs:
    #     logger.debug("- ENV: %s", str(kwargs["env"]))

    if ignore_output:
        assert not live
        subprocess.run(args, **kwargs, check=True)
        return None

    # Reject the unsupported combination before a child process is spawned.
    if live and stdin_data:
        raise RuntimeError("stdin_data only supported if live=False")

    def args_helper(x):
        # Quote arguments which would be ambiguous in the printed command line.
        x = str(x)
        if "[" in x or "]" in x or " " in x:
            x = f'"{x}"'
        return x

    def fail_message(exit_code):
        return "The process returned an non-zero exit code {}! (CMD: `{}`)".format(
            exit_code, " ".join(map(args_helper, args))
        )

    def resolve_exit_code(exit_code, out):
        # Give the optional handler a chance to inspect the output and rewrite the exit code.
        if handle_exit is not None:
            text = out.decode("utf-8", errors="ignore") if isinstance(out, bytes) else out
            exit_code = handle_exit(exit_code, out=text)
        return exit_code

    if live:
        # Without an encoding the lines stay bytes, so accumulate and strip with matching types.
        empty = "" if encoding else b""
        newline = "\n" if encoding else b"\n"
        chunks = []
        with subprocess.Popen(
            args,
            **kwargs,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        ) as process:
            try:
                for line in process.stdout:
                    if encoding:
                        line = prefix + line.decode(encoding, errors="replace")
                    chunks.append(line)
                    print_func(line.replace(newline, empty))
                # stdout reached EOF; block until the process has terminated.
                exit_code = resolve_exit_code(process.wait(), empty.join(chunks))
                assert exit_code == 0, fail_message(exit_code)
            except KeyboardInterrupt:
                logger.debug("Interrupted subprocess. Sending SIGINT signal...")
                os.kill(process.pid, signal.SIGINT)
        return empty.join(chunks)

    out_str = ""
    p = None
    try:
        p = subprocess.Popen(
            list(args), **kwargs, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT
        )
        out_str = p.communicate(input=stdin_data if stdin_data else None)[0]
        if encoding:
            out_str = prefix + out_str.decode(encoding, errors="replace")
        exit_code = resolve_exit_code(p.poll(), out_str)
        if exit_code != 0:
            err_func(out_str)
        assert exit_code == 0, fail_message(exit_code)
    except KeyboardInterrupt:
        logger.debug("Interrupted subprocess. Sending SIGINT signal...")
        # The interrupt may arrive before the process object exists.
        if p is not None:
            os.kill(p.pid, signal.SIGINT)
    return out_str
def python(*args, **kwargs):
    """Run the given arguments through the interpreter that runs this process.

    All positional and keyword arguments are forwarded to ``execute`` after
    ``sys.executable``; its return value is handed back unchanged.
    """
    interpreter = sys.executable
    return execute(interpreter, *args, **kwargs)
# Makes sure all directories at the given path are created.
def mkdirs(path: Union[str, bytes, os.PathLike]):
    """Wrapper for os.makedirs which handles the special case where the path already exists.

    Parameters
    ----------
    path : str, bytes or os.PathLike
        Directory to create, including all missing parent directories.
    """
    if not os.path.exists(path):
        # exist_ok closes the race where another process creates the directory
        # between the check above and this call.
        os.makedirs(path, exist_ok=True)
# Clones a git repository at given url into given dir and switches to the given branch.
def clone(
    url: str,
    dest: Union[str, bytes, os.PathLike],
    branch: str = "",
    submodules: Optional[list] = None,
    recursive: bool = False,
    refresh: bool = False,
):
    """Helper function for cloning a repository.

    If ``dest`` is already populated, nothing happens unless ``refresh`` is set.

    Parameters
    ----------
    url : str
        Clone URL of the repository.
    dest : Path
        Destination directory path.
    branch : str
        Optional branch name or commit reference/tag.
    submodules : list of strings
        Submodules to be updated. If empty, all submodules will be updated.
        Only used when a branch is checked out or the repository is refreshed.
    recursive : bool
        If the clone (and the submodule update) should be done recursively.
    refresh : bool
        Enables switching the url/branch if the repo already exists
    """
    mkdirs(dest)

    def update_submodules(repo):
        # Assemble a single `git submodule update --init [--recursive] [-- <paths>]` call.
        cmd = ["update", "--init"]
        if recursive:
            cmd.append("--recursive")
        if submodules:
            for submodule in submodules:
                assert isinstance(submodule, str), f"Submodules should be a list of str. {submodule} is not str."
            cmd.extend(["--", *submodules])
        repo.git.submodule(*cmd)

    if is_populated(dest):
        if refresh:
            repo = Repo(dest)
            # TODO: backup old remote?
            repo.remotes.origin.set_url(url)
            repo.remotes.origin.fetch()
            repo.git.checkout(branch)
            repo.git.pull("origin", branch)  # This should also work for specific commits
            update_submodules(repo)
    else:
        if branch:
            # Clone without checkout first so that the requested ref can be checked out directly.
            repo = Repo.clone_from(url, dest, recursive=recursive, no_checkout=True)
            repo.git.checkout(branch)
            update_submodules(repo)
        else:
            Repo.clone_from(url, dest, recursive=recursive)
def clone_wrapper(cfg: RepoConfig, dest: Union[str, bytes, os.PathLike], refresh: bool = False):
    """Clone the repository described by ``cfg`` into ``dest``.

    The url, ref, submodules and recursive attributes of ``cfg`` are forwarded to ``clone``.
    """
    repo_url = cfg.url
    options = {
        "branch": cfg.ref,
        "submodules": cfg.submodules,
        "recursive": cfg.recursive,
        "refresh": refresh,
    }
    clone(repo_url, dest, **options)
def apply(
    repo_dir: Path,
    patch_file: Path,
):
    """Helper function for applying a patch to a repository.

    Parameters
    ----------
    repo_dir : Path
        Clone directory of repository.
    patch_file : Path
        Path to patch file.
    """
    repository = Repo(repo_dir)
    git_cmd = repository.git
    # `git clean -xdf`: remove untracked (including ignored) files and directories first.
    git_cmd.clean("-xdf")
    git_cmd.apply(patch_file)
def make(*args, threads=multiprocessing.cpu_count(), use_ninja=False, cwd=None, verbose=False, **kwargs):
    """Invoke ``make`` (or ``ninja``) in the given working directory.

    Parameters
    ----------
    args
        Additional arguments appended to the command line (e.g. targets).
    threads : int
        Value for the ``-j`` option; defaults to the CPU count.
    use_ninja : bool
        Call ``ninja`` instead of ``make``.
    cwd : str or Path
        Working directory, mandatory.
    verbose : bool
        Accepted but not used by this function.
    kwargs
        Forwarded to ``execute``.

    Returns
    -------
    The return value of ``execute``.

    Raises
    ------
    RuntimeError
        If no ``cwd`` was given.
    """
    if cwd is None:
        raise RuntimeError("Please always pass a cwd to make()")
    workdir = str(cwd.resolve()) if isinstance(cwd, Path) else cwd
    # TODO: make sure that ninja is installed?
    build_tool = "make"
    if use_ninja:
        build_tool = "ninja"
    jobs_flag = "-j" + str(threads)
    return execute(build_tool, jobs_flag, *args, cwd=workdir, **kwargs)
def cmake(src, *args, debug=False, use_ninja=False, cwd=None, cmake_exe: Optional[Union[str, Path]] = None, **kwargs):
    """Invoke CMake for the given source directory.

    Parameters
    ----------
    src
        Source directory handed to CMake.
    args
        Additional arguments appended to the command line.
    debug : bool
        Select the ``Debug`` build type instead of ``Release``.
    use_ninja : bool
        Add ``-GNinja`` to the command line.
    cwd : str or Path
        Working directory, mandatory.
    cmake_exe : str or Path, optional
        CMake executable to use; defaults to ``cmake``.
    kwargs
        Forwarded to ``execute``.

    Returns
    -------
    The return value of ``execute``.

    Raises
    ------
    RuntimeError
        If no ``cwd`` was given.
    """
    if cwd is None:
        raise RuntimeError("Please always pass a cwd to cmake()")
    workdir = str(cwd.resolve()) if isinstance(cwd, Path) else cwd
    build_type = "Release"
    if debug:
        build_type = "Debug"
    executable = "cmake" if cmake_exe is None else cmake_exe
    command = [executable, str(src), "-DCMAKE_BUILD_TYPE=" + build_type]
    if use_ninja:
        command.append("-GNinja")
    command.extend(args)
    return execute(*command, cwd=workdir, **kwargs)
# def move(a, b): # TODO: make every utility compatible with Paths! # # This can not handle cross file-system renames! # if not isinstance(a, Path): # a = Path(a) # if not isinstance(b, Path): # b = Path(b) # a.replace(b)
def validate_checksum(path: Path, checksum: str, mode: str = "auto", allow_missmatch: bool = False):
    """Compare the checksum of a file against an expected value.

    Parameters
    ----------
    path : Path
        File to be hashed.
    checksum : str
        Expected hex digest, optionally prefixed with the mode, e.g. ``sha256:abcd...``.
        The comparison is case-insensitive.
    mode : str
        Hash algorithm (``sha256`` or ``md5``). With ``auto`` the mode is taken from the
        checksum prefix.
    allow_missmatch : bool
        Only log a warning instead of raising if the checksums differ.

    Returns
    -------
    matches : bool
        Whether the computed checksum equals the expected one.

    Raises
    ------
    AssertionError
        If the file does not exist, or the mode is missing, inconsistent or unsupported.
    RuntimeError
        If the checksums differ and ``allow_missmatch`` is not set.
    """
    if isinstance(path, str):
        path = Path(path)
    assert path.is_file(), f"File does not exists: {path}"
    if ":" in checksum:
        mode_, checksum = checksum.split(":", 1)
    else:
        mode_ = None
    if mode == "auto":
        assert mode_ is not None, "Could not infer mode from checksum"
        mode = mode_
    elif mode_ is not None:
        assert mode == mode_, "Checksum mode missmatch"
    mode_lookup = {
        "sha256": hashlib.sha256,
        "md5": hashlib.md5,
    }
    hash_factory = mode_lookup.get(mode)
    assert hash_factory is not None, f"Unhandled checksum mode: {mode}"
    hasher = hash_factory()
    # See: https://gist.github.com/airtower-luna/a5df5d6143c8e9ffe7eb5deb5797a0e0
    with open(path, "rb") as fh:
        # Read and hash the file in 4K chunks. Reading the whole file at once
        # might consume a lot of memory if it is large.
        for data in iter(lambda: fh.read(4096), b""):
            hasher.update(data)
    checksum_ = hasher.hexdigest()
    # hexdigest() is lowercase; hex digests themselves are case-insensitive.
    checksum_matches = checksum.strip().lower() == checksum_
    if not checksum_matches:
        msg = f"Checksum missmatch for {path.name}: {checksum} vs. {checksum_}"
        if allow_missmatch:
            logger.warning(msg)
        else:
            raise RuntimeError(msg)
    return checksum_matches
def download(url, dest, checksum: str = None, progress=False):
    """Download a file and optionally verify its checksum.

    Parameters
    ----------
    url
        Location of the file to fetch.
    dest
        Local path the file is written to.
    checksum : str
        If given, passed to ``validate_checksum`` (a mismatch is not tolerated).
    progress : bool
        Show a tqdm progress bar while downloading.
    """
    logger.debug("- Downloading: %s", url)

    def hook(t):
        """Build a urlretrieve report hook which feeds the tqdm instance ``t``."""
        blocks_seen = 0

        def update_to(b=1, bsize=1, tsize=None):
            """
            b  : int, optional
                Number of blocks transferred so far [default: 1].
            bsize  : int, optional
                Size of each block (in tqdm units) [default: 1].
            tsize  : int, optional
                Total size (in tqdm units). If [default: None] remains unchanged.
            """
            nonlocal blocks_seen
            if tsize is not None:
                t.total = tsize
            # tqdm expects increments, urlretrieve reports absolute block counts.
            t.update((b - blocks_seen) * bsize)
            blocks_seen = b

        return update_to

    if not progress:
        urllib.request.urlretrieve(url, dest)
    else:
        with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc="Downloading File") as t:
            urllib.request.urlretrieve(url, dest, reporthook=hook(t))
    if checksum:
        validate_checksum(dest, checksum, allow_missmatch=False)
def extract(archive, dest, progress=False):
    """Extract a zip or tar archive into a directory.

    The archive type is derived from the (case-insensitive) file extension.

    Parameters
    ----------
    archive
        Path to the archive file.
    dest
        Directory the contents are extracted to.
    progress : bool
        Extract member by member and show a tqdm progress bar.

    Raises
    ------
    RuntimeError
        If the file extension is not a known archive type.
    """
    ext = Path(archive).suffix[1:].lower()

    # NOTE: members are extracted as-is; only use this for archives from trusted sources.
    def handle(f, list_members):
        if progress:
            members = list_members()
            for m in tqdm(iterable=members, total=len(members), desc="Extracting..."):
                f.extract(m, dest)
        else:
            f.extractall(dest)

    if ext == "zip":
        with zipfile.ZipFile(archive) as zip_file:
            # ZipFile has no getmembers(); infolist() is its equivalent.
            handle(zip_file, zip_file.infolist)
    elif ext in ["tar", "gz", "xz", "tgz", "bz2"]:
        with tarfile.open(archive) as tar_file:
            handle(tar_file, tar_file.getmembers)
    else:
        raise RuntimeError("Unable to detect the archive type")
def remove(path):
    """Delete the file at ``path`` (thin wrapper around ``os.remove``)."""
    os.remove(path)
def move(src, dest):
    """Move ``src`` to ``dest`` (thin wrapper around ``shutil.move``)."""
    shutil.move(src, dest)
def copy(src, dest):
    """Copy the file ``src`` to ``dest`` (thin wrapper around ``shutil.copy``)."""
    shutil.copy(src, dest)
def is_populated(path) -> bool:
    """Check whether ``path`` is an existing directory with at least one entry.

    Parameters
    ----------
    path : str or Path
        Directory to inspect.

    Returns
    -------
    populated : bool
        True if the directory exists and is not empty, False otherwise.
    """
    if not isinstance(path, Path):
        path = Path(path)
    # Convert explicitly so that callers get a real bool instead of the directory listing.
    return path.is_dir() and bool(os.listdir(path.resolve()))
def download_and_extract(url, archive, dest, checksum: str = None, progress=False, force=True):
    """Download an archive, extract it and move the result to ``dest``.

    Parameters
    ----------
    url
        Base URL; a trailing slash is appended if missing before ``archive`` is added.
    archive
        File name of the archive below ``url``.
    dest : str or Path
        Final location of the extracted contents.
    checksum : str
        Optional checksum forwarded to ``download``.
    progress : bool
        Forwarded to ``download`` and ``extract``.
    force : bool
        Allow replacing an existing ``dest`` directory. Only checked in the branch
        where the archive has no top-level directory named like the archive.
    """
    if isinstance(dest, str):
        dest = Path(dest)
    assert isinstance(dest, Path)
    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_archive = os.path.join(tmp_dir, archive)
        # Strip the extension(s): "foo.tar.gz" -> "foo.tar" -> "foo"
        base_name = Path(archive).stem
        if ".tar" in base_name:
            base_name = Path(base_name).stem
        if url[-1] != "/":
            url += "/"
        download(url + archive, tmp_archive, checksum=checksum, progress=progress)
        extract(tmp_archive, tmp_dir, progress=progress)
        # NOTE(review): tmp_archive already contains tmp_dir; presumably the join is a
        # no-op because the temporary directory path is absolute - confirm.
        remove(os.path.join(tmp_dir, tmp_archive))
        mkdirs(dest.parent)
        if (Path(tmp_dir) / base_name).is_dir():  # Archive contains a subdirectory with the same name
            # NOTE(review): an existing dest is neither checked against `force` nor removed
            # here, unlike in the branch below - confirm that this is intended.
            move(os.path.join(tmp_dir, base_name), dest)
        else:
            contents = list(Path(tmp_dir).glob("*"))
            if len(contents) == 1:
                tmp_dir_new = Path(tmp_dir) / contents[0]
                if tmp_dir_new.is_dir():  # Archive contains a single subdirectory with a different name
                    tmp_dir = tmp_dir_new
            if dest.is_dir():
                assert force, f"Set force=True to replace destination {dest}"
                shutil.rmtree(dest)
            # Moves either the single subdirectory or the temporary directory itself.
            move(tmp_dir, dest)
def patch(path, cwd=None):
    """Placeholder for applying a patch; not implemented yet.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError
def check_version(version: str, min_version: Optional[str] = None, max_version: Optional[str] = None):
    """Check whether a version lies within optional inclusive bounds.

    Parameters
    ----------
    version : str
        Version to be checked.
    min_version : str, optional
        Lowest accepted version.
    max_version : str, optional
        Highest accepted version.

    Returns
    -------
    ok : bool
        False if the version is below ``min_version`` or above ``max_version``, else True.
    """
    current = Version(version)
    if min_version is not None and current < Version(min_version):
        return False
    if max_version is not None and current > Version(max_version):
        return False
    return True
def check_program(name: str, allow_none: bool = False):
    """Look up an executable on the PATH.

    Parameters
    ----------
    name : str
        Name of the program.
    allow_none : bool
        Return None instead of failing if the program is not found.

    Returns
    -------
    path : Path or None
        Resolved path of the executable, or None if it is missing and ``allow_none`` is set.
    """
    location = shutil.which(name)
    if location is not None:
        # path = Path(os.readlink(path))
        return Path(location).resolve()
    assert allow_none, f"Program {name} not found in path"
    return None
def detect_system_llvm(major_version_hint: Optional[int] = None, allow_none: bool = False):
    """Locate the install prefix of a system-wide LLVM via its clang executable.

    Parameters
    ----------
    major_version_hint : int, optional
        Look for ``clang-<major>`` instead of the default ``clang``.
    allow_none : bool
        Return None instead of failing if LLVM can not be located.

    Returns
    -------
    prefix : str or None
        Output of ``llvm-config --prefix``, or None if nothing was found and
        ``allow_none`` is set.
    """
    if major_version_hint is not None:
        name = f"clang-{major_version_hint}"
    else:
        # This will not always be the latest version but the default one...
        name = "clang"
    clang_path = check_program(name, allow_none=allow_none)
    if clang_path is None:
        # Only reachable with allow_none=True, otherwise check_program fails itself.
        return None
    llvm_config_path = clang_path.parent / "llvm-config"
    if not llvm_config_path.is_file():
        assert allow_none, f"llvm-config not found in: {llvm_config_path}"
        return None
    llvm_prefix = execute(llvm_config_path, "--prefix", live=False)
    if llvm_prefix is None:
        assert allow_none, "Could not get llvm install dir"
        return None
    return llvm_prefix.strip()
def detect_llvm_version(llvm_dir: Union[str, Path], full: bool = True):
    """Query the version of an LLVM installation via ``bin/llvm-config --version``.

    Parameters
    ----------
    llvm_dir : str or Path
        Install directory of LLVM.
    full : bool
        Return the full version string; otherwise only the major version number.

    Returns
    -------
    version : str or int
        The reported version with any "git" marker removed, or its major component.
    """
    install_dir = Path(llvm_dir)
    assert install_dir.is_dir(), f"Not a directory: {install_dir}"
    llvm_config = install_dir / "bin" / "llvm-config"
    assert llvm_config.is_file()
    reported = execute(llvm_config, "--version", live=False)
    version_full = reported.strip().replace("git", "")
    if full:
        return version_full
    from packaging.version import Version

    return Version(version_full).major
def resolve_llvm(
    use_system_llvm: bool = False,
    llvm_version: Optional[str] = None,
    user_llvm_dir: Optional[Path] = None,
    mlonmcu_llvm_dir: Optional[Path] = None,
    allow_none: bool = False,
):
    """Select the LLVM installation to be used and determine its version.

    Priority: user-provided directory, then system LLVM (if requested), then the
    LLVM installed by MLonMCU.

    Parameters
    ----------
    use_system_llvm : bool
        Search for a system-wide LLVM.
    llvm_version : str, optional
        Preferred version; its major component is used as a hint for the system lookup.
    user_llvm_dir : Path, optional
        Explicit install directory provided by the user.
    mlonmcu_llvm_dir : Path, optional
        Install directory of the LLVM managed by MLonMCU.
    allow_none : bool
        Return ``(None, None)`` instead of failing if no LLVM is available.

    Returns
    -------
    llvm_dir, llvm_version : tuple
        The selected install directory (Path) and its full version string.
    """
    if user_llvm_dir is not None:
        assert Path(user_llvm_dir).is_dir(), "Could not find user LLVM install"
        llvm_dir = user_llvm_dir
    elif use_system_llvm:
        llvm_version_major = Version(llvm_version).major if llvm_version is not None else None
        system_llvm_dir = None
        if llvm_version_major is not None:
            system_llvm_dir = detect_system_llvm(major_version_hint=llvm_version_major, allow_none=True)
            if system_llvm_dir is None:
                logger.warning("Falling back to default clang version")
        if system_llvm_dir is None:
            # Also reached if no version was requested at all, in which case the
            # default clang has to be looked up directly.
            system_llvm_dir = detect_system_llvm(major_version_hint=None, allow_none=True)
        if system_llvm_dir is None:
            assert allow_none, "Could not find system LLVM install"
            return None, None
        llvm_dir = system_llvm_dir
    else:
        if mlonmcu_llvm_dir is None:
            assert allow_none, "Could not find MLonMCU LLVM install"
            return None, None
        llvm_dir = mlonmcu_llvm_dir
    llvm_dir = Path(llvm_dir)
    llvm_version = detect_llvm_version(llvm_dir, full=True)
    assert llvm_version is not None, "Unable to get LLVM version"
    return llvm_dir, llvm_version
def resolve_llvm_wrapper(context, allow_none: bool = False):
    """Call ``resolve_llvm`` with the settings found in the given context.

    Reads ``llvm.use_system``, ``llvm.version`` and ``llvm.install_dir`` from the
    environment variables and ``llvm.install_dir`` from the cache of ``context``.
    """
    env_vars = context.environment.vars
    system_requested = env_vars.get("llvm.use_system", False)
    requested_version = env_vars.get("llvm.version", None)
    user_dir = env_vars.get("llvm.install_dir", None)
    cached_dir = context.cache.get("llvm.install_dir")
    return resolve_llvm(system_requested, requested_version, user_dir, cached_dir, allow_none=allow_none)
def detect_system_cmake(allow_none: bool = False):
    """Look up the ``cmake`` executable on the PATH via ``check_program``.

    Returns
    -------
    cmake_exe : Path or None
        Whatever ``check_program`` returns for ``cmake``.
    """
    return check_program("cmake", allow_none=allow_none)
def detect_cmake_version(cmake_exe: Union[str, Path]):
    """Query the version of a CMake executable.

    Runs ``<cmake_exe> --version`` and returns the last space-separated token of
    the first output line.
    """
    output = execute(cmake_exe, "--version", live=False)
    first_line = output.splitlines()[0]
    last_token = first_line.split(" ")[-1]
    return last_token.strip()
def resolve_cmake(
    use_system_cmake: bool = False,
    user_cmake_exe: Optional[Path] = None,
    mlonmcu_cmake_exe: Optional[Path] = None,
    allow_none: bool = False,
):
    """Select the CMake executable to be used and determine its version.

    Priority: user-provided executable, then system CMake (if requested), then the
    CMake installed by MLonMCU.

    Parameters
    ----------
    use_system_cmake : bool
        Search for a system-wide CMake.
    user_cmake_exe : Path, optional
        Explicit executable provided by the user.
    mlonmcu_cmake_exe : Path, optional
        Executable of the CMake managed by MLonMCU.
    allow_none : bool
        Return ``(None, None)`` instead of failing if no CMake is available.

    Returns
    -------
    cmake_exe, cmake_version : tuple
        The selected executable (Path) and its version string.
    """
    if user_cmake_exe is not None:
        assert Path(user_cmake_exe).is_file(), "Could not find user CMake"
        cmake_exe = user_cmake_exe
    elif use_system_cmake:
        # Tolerate a missing program here so that allow_none can be honored below.
        system_cmake_exe = detect_system_cmake(allow_none=True)
        if system_cmake_exe is None:
            assert allow_none, "Could not find system CMake"
            return None, None
        cmake_exe = system_cmake_exe
    else:
        if mlonmcu_cmake_exe is None:
            assert allow_none, "Could not find MLonMCU CMake install"
            return None, None
        cmake_exe = mlonmcu_cmake_exe
    cmake_exe = Path(cmake_exe)
    cmake_version = detect_cmake_version(cmake_exe)
    assert cmake_version is not None, "Unable to get CMake version"
    return cmake_exe, cmake_version
def resolve_cmake_wrapper(context, allow_none: bool = False):
    """Call ``resolve_cmake`` with the settings found in the given context.

    Reads ``cmake.use_system`` from the environment variables and ``cmake.exe``
    from the cache of ``context``.
    """
    user_vars = context.environment.vars
    use_system_cmake = user_vars.get("cmake.use_system", False)
    mlonmcu_cmake_exe = context.cache.get("cmake.exe")
    # Pass by keyword: the second positional parameter of resolve_cmake is
    # user_cmake_exe, which would take precedence over use_system_cmake.
    return resolve_cmake(
        use_system_cmake=use_system_cmake,
        mlonmcu_cmake_exe=mlonmcu_cmake_exe,
        allow_none=allow_none,
    )