Source code for mlonmcu.setup.task

#
# Copyright (c) 2022 TUM Department of Electrical and Computer Engineering.
#
# This file is part of MLonMCU.
# See https://github.com/tum-ei-eda/mlonmcu.git for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Definitions of a task registry used to automatically install dependencies."""

from functools import wraps
import itertools
from enum import Enum
import time
from typing import List, Tuple
import networkx as nx
from networkx.drawing.nx_agraph import write_dot
from tqdm import tqdm

from mlonmcu.logging import get_logger

logger = get_logger()


def apply_params(key, params):
    """Substitute ``{name}`` placeholders in *key* with parameter values.

    Parameters
    ----------
    key : str
        Template string which may contain ``{param_name}`` placeholders.
    params : dict
        Mapping of parameter names to values. Every value is converted with
        ``str()`` before it is substituted.

    Returns
    -------
    str
        The template with all placeholders for names in *params* replaced.
        Placeholders without a matching entry are left untouched.

    Examples
    --------
    >>> apply_params("llvm.{version}.install_dir", {"version": 17})
    'llvm.17.install_dir'
    """
    resolved = key
    # Substitutions are applied one after another, in the iteration order of *params*.
    for placeholder_name, placeholder_value in params.items():
        token = f"{{{placeholder_name}}}"
        resolved = resolved.replace(token, str(placeholder_value))
    return resolved
def get_combs(data) -> List[dict]:
    """Return the cartesian product of the option lists in *data* as dicts.

    Parameters
    ----------
    data : dict
        Mapping of parameter names to the list of their possible values.

    Returns
    -------
    combs : list
        One dict per combination. An empty *data* yields an empty list (not
        ``[{}]``). A key with an empty option list also yields an empty list.

    Examples
    --------
    >>> get_combs({"foo": [False, True], "bar": [5]})
    [{'foo': False, 'bar': 5}, {'foo': True, 'bar': 5}]
    """
    names = list(data)
    if not names:
        # itertools.product() of nothing yields a single empty tuple; report "no combinations" instead.
        return []
    return [dict(zip(names, choice)) for choice in itertools.product(*data.values())]
def get_combs_new(data) -> List[dict]:
    """Return all parameter combinations, supporting grouped (comma-separated) keys.

    A key such as ``"bar,baz"`` groups several parameters which vary together.
    Each of its options must be a tuple with one element per grouped name. In
    the result, every tuple is expanded into the individual names. Whitespace
    around grouped names is ignored. Options of a plain (non-grouped) key are
    used as-is, even when they are tuples themselves.

    Parameters
    ----------
    data : dict
        Mapping of (possibly grouped) parameter names to lists of options.

    Returns
    -------
    combs : list
        One dict per combination. An empty *data* yields an empty list.

    Raises
    ------
    TypeError
        If an option of a grouped key is not a tuple.
    ValueError
        If an option of a grouped key has the wrong number of elements.

    Examples
    --------
    >>> get_combs_new({"foo": [False, True], "bar,baz": [(5, 123)]})
    [{'foo': False, 'bar': 5, 'baz': 123}, {'foo': True, 'bar': 5, 'baz': 123}]
    """
    groups = [[part.strip() for part in raw_key.split(",")] for raw_key in data]
    option_lists = list(data.values())

    # Validate grouped parameters explicitly (``assert`` would be skipped under ``python -O``).
    for names, options in zip(groups, option_lists):
        if len(names) <= 1:
            continue
        for option in options:
            if not isinstance(option, tuple):
                raise TypeError(f"Grouped key {names} requires tuple values, got {type(option)}")
            if len(option) != len(names):
                raise ValueError(f"Grouped key {names} expects {len(names)} elements, got {len(option)}")

    if not groups:
        # itertools.product() of nothing yields one empty tuple; report "no combinations" instead.
        return []

    combs = []
    for choice in itertools.product(*option_lists):
        comb = {}
        for names, value in zip(groups, choice):
            if len(names) == 1:
                # Plain key: keep the option unchanged, even if it is a tuple.
                comb[names[0]] = value
            else:
                comb.update(zip(names, value))
        combs.append(comb)
    return combs
class TaskType(Enum):
    """Category attached to every registered setup task.

    ``TaskFactory.register`` stores one of these per task and uses ``MISC``
    when no category is given.
    """

    # NOTE: the numeric values are kept explicit (not ``auto()``) so they stay stable.
    MISC = 0  # default category of TaskFactory.register
    FRAMEWORK = 1
    BACKEND = 2
    TOOLCHAIN = 3
    TARGET = 4
    FRONTEND = 5
    OPT = 6
    FEATURE = 7
    PLATFORM = 8
class TaskGraph:
    """Task graph object.

    Attributes
    ----------
    names : list
        List of task names in the graph.
    dependencies : dict
        Mapping of task names to the artifacts each task depends on.
    providers : dict
        Mapping of artifact names to the name of the task providing them.

    Examples
    --------
    >>> graph = TaskGraph(["a", "b"], {"b": ["x"]}, {"x": "a"})
    >>> graph.get_graph()
    (['a', 'b'], [('a', 'b')])
    """

    def __init__(self, names: List[str], dependencies: dict, providers: dict):
        self.names = names
        self.dependencies = dependencies
        self.providers = providers

    def get_graph(self) -> Tuple[list, list]:
        """Get nodes and edges of the task graph.

        An edge ``(src, dest)`` means that task ``src`` provides an artifact
        which task ``dest`` depends on.

        Returns
        -------
        nodes : list
            List of nodes (task names).
        edges : list
            List of unique ``(src, dest)`` edge tuples, in first-seen order.

        Raises
        ------
        RuntimeError
            If a dependency is not provided by any task.
        """
        nodes = list(self.names)
        edges = []
        for dest, deps in self.dependencies.items():
            for dep in deps:
                if dep not in self.providers:
                    raise RuntimeError(f"Unable to resolve dependency '{dep}' for task {dest}")
                edges.append((self.providers[dep], dest))
        # Remove duplicates while preserving order (several artifacts may come from the same task).
        edges = list(dict.fromkeys(edges))
        return nodes, edges

    def _build_graph(self):
        """Build a ``networkx.DiGraph`` containing every task and dependency edge."""
        nodes, edges = self.get_graph()
        graph = nx.DiGraph(edges)
        # Add all names explicitly so tasks without any edges are part of the graph too.
        graph.add_nodes_from(nodes)
        return graph

    def get_order(self) -> list:
        """Get execution order of tasks via topological sorting.

        Returns
        -------
        list
            Task names ordered such that providers come before their dependents.
        """
        return list(nx.topological_sort(self._build_graph()))

    def export_dot(self, path):
        """Visualize the task dependency graph by writing it as a Graphviz DOT file.

        Parameters
        ----------
        path
            Destination passed to ``networkx.drawing.nx_agraph.write_dot``.
        """
        # TODO: annotate with order
        # TODO: also export order as extra graph
        write_dot(self._build_graph(), path)
class TaskFactory:
    """Registry of all available setup tasks and their annotations.

    Attributes
    ----------
    registry : dict
        Mapping of task names to their registered (wrapped) functions.
    dependencies : dict
        Mapping of task names to the artifacts they need (required or optional).
    providers : dict
        Mapping of artifact names to the name of the task providing them.
    types : dict
        Mapping of task names to their ``TaskType``.
    params : dict
        Mapping of task names to ``{flag: options}`` parameter definitions.
    validates : dict
        Mapping of task names to their validation functions.
    changed : list
        Artifacts provided by tasks which reported a change since the last
        ``reset_changes()``.
    """

    def __init__(self):
        # Per-task bookkeeping, keyed by the task (function) name.
        self.registry = {}
        self.types = {}
        self.params = {}
        self.validates = {}
        self.dependencies = {}
        # Artifact name -> providing task name.
        self.providers = {}
        # Artifacts changed by recently processed tasks.
        self.changed = []
        # NOTE(review): the original source had a truncated comment here
        # ("Main problem: per"); its intent is unknown.
[docs] def reset_changes(self): """Reset all pending changes.""" self.changed = []
[docs] def needs(self, keys, force=True): """Decorator which registers the artifacts a task needs to be processed.""" def real_decorator(function): name = function.__name__ if name in self.dependencies: self.dependencies[name].extend(keys) else: self.dependencies[name] = keys @wraps(function) def wrapper(*args, **kwargs): # logger.debug("Checking inputs...") if force: context = args[0] variables = context.cache._vars for key in keys: if key not in variables.keys() or variables[key] is None: raise RuntimeError(f"Task '{name}' needs the value of '{key}' which is not set") retval = function(*args, **kwargs) return retval return wrapper return real_decorator
[docs] def optional(self, keys): """Decorator for optional task requirements.""" return self.needs(keys, force=False)
[docs] def removes(self, keys): """Decorator for cleanup tasks.""" # TODO: implementation def real_decorator(function): # @wraps(function) # def wrapper(*args, **kwargs): # retval = function(*args, **kwargs) # return retval # return wrapper return function return real_decorator
# def optional(self, keys): # def real_decorator(function): # name = function.__name__ # if name in self.dependencies: # self.dependencies[name].extend(keys) # else: # self.dependencies[name] = keys # @wraps(function) # def wrapper(*args, **kwargs): # retval = function(*args, **kwargs) # return retval # return wrapper # return real_decorator
[docs] def provides(self, keys): """Decorator which registers what a task provides.""" def real_decorator(function): name = function.__name__ for key in keys: self.providers[key] = name # @wraps(function) # def wrapper(*args, **kwargs): # # print("args", args) # # print("kwargs", kwargs) # context = args[0] # # params = kwargs.get("params", {}) # # print("params", params) # # input("a") # for key in keys: # if key in context.cache._vars: # del context.cache._vars[key] # Unset the value before calling function # retval = function(*args, **kwargs) # if retval is not False: # # logger.debug("Checking outputs...") # variables = context.cache._vars # for key in keys: # if key not in variables.keys() or variables[key] is None: # raise RuntimeError(f"Task '{name}' did not set the value of '{key}'") # return retval # self.registry[name] = wrapper # return wrapper return function return real_decorator
[docs] def param(self, flag, options): """Decorator which registers available task parameters.""" if not isinstance(options, list): options = [options] def real_decorator(function): name = function.__name__ if name in self.params: self.params[name][flag] = options else: self.params[name] = {flag: options} # @wraps(function) # def wrapper(*args, **kwargs): # retval = function(*args, **kwargs) # return retval # return wrapper return function return real_decorator
[docs] def validate(self, func): """Decorator which registers validation functions for a task.""" def real_decorator(function): name = function.__name__ self.validates[name] = func return function return real_decorator
[docs] def register(self, category=TaskType.MISC): """Decorator which actually registers a task in the registry.""" def real_decorator(function): name = function.__name__ @wraps(function) def wrapper(*args, rebuild=False, progress=False, **kwargs): # combs = get_combs(self.params[name]) combs = get_combs_new(self.params[name]) print("combs", combs) def get_valid_combs(combs): ret = [] for comb in combs: if name in self.validates: check = self.validates[name](args[0], params=comb) if not check: continue ret.append(comb) return ret combs_ = get_valid_combs(combs) context = args[0] def process(name_, params=None, rebuild=False): if not params: params = {} rebuild = rebuild if name in self.dependencies: for dep in self.dependencies[name]: if dep in self.changed: rebuild = True break keys = [key for key, provider in self.providers.items() if provider == name] for key in keys: if key in context.cache._vars: del context.cache._vars[key] # Unset the value before calling function retval = function(*args, params=params, rebuild=rebuild, **kwargs) variables = context.cache._vars for key in keys: key = apply_params(key, params) if key not in variables.keys() or variables[key] is None: raise RuntimeError(f"Task '{name}' did not set the value of '{key}'") if retval: for key in keys: if key not in self.changed: self.changed.append(key) # logger.debug("Processed task:", function.__name__) return retval if progress: pbar = tqdm( total=max(len(combs_), 1), desc="Processing", ncols=100, bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}s]", leave=None, ) else: pbar = None if len(combs_) == 0: if pbar: pbar.set_description(f"Processing: {name}") else: logger.info("Processing task: %s", name) time.sleep(0.1) check = True if len(combs) > 0: check = False else: if name in self.validates: check = self.validates[name](args[0], params={}) if check: start = time.time() retval = process(name, rebuild=rebuild) end = time.time() diff = end - start minutes = int(diff // 60) seconds = 
int(diff % 60) duration_str = f"{seconds}s" if minutes == 0 else f"{minutes}m{seconds}s" if not pbar: logger.debug("-> Done (%s)", duration_str) # TODO: move this to helper func else: logger.debug("-> Skipped") retval = False if pbar: pbar.update(1) else: for comb in combs_: # TODO process in parallel? extended_name = name + str(comb) if pbar: pbar.set_description(f"Processing - {extended_name}") else: logger.info("Processing task: %s", extended_name) time.sleep(0.1) start = time.time() retval = process(extended_name, params=comb, rebuild=rebuild) end = time.time() diff = end - start minutes = int(diff // 60) seconds = int(diff % 60) duration_str = f"{seconds}s" if minutes == 0 else f"{minutes}m{seconds}s" if not pbar: logger.debug("-> Done (%s)", duration_str) # TODO: move this to helper func else: pbar.update(1) if pbar: pbar.close() return retval self.registry[name] = wrapper self.types[name] = category self.params[name] = {} return wrapper return real_decorator