#
# Copyright (c) 2022 TUM Department of Electrical and Computer Engineering.
#
# This file is part of MLonMCU.
# See https://github.com/tum-ei-eda/mlonmcu.git for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from .config import (
DefaultsConfig,
PathConfig,
)
from .loader import load_environment_from_file
from .writer import write_environment_to_file
from mlonmcu.feature.type import FeatureType
def _feature_helper(obj, name):
if not obj.enabled:
return []
features = obj.features
if name:
return [feature for feature in features if feature.name == name]
return features
def _extract_names(objs):
return [obj.name for obj in objs]
def _filter_enabled(objs):
return [obj for obj in objs if obj.enabled]
class Environment:
    """In-memory representation of an MLonMCU environment configuration.

    Stores the environment defaults, named paths and repositories, the component
    configs (frameworks, frontends, platforms, toolchains, targets), user
    variables and default flags, and offers lookup helpers over them.
    """

    def __init__(self):
        self._home = None
        self.alias = None
        self.defaults = DefaultsConfig()
        self.paths = {}
        self.repos = {}
        self.frameworks = []
        self.frontends = []
        self.platforms = []
        self.toolchains = []
        self.targets = []
        self.vars = {}
        self.flags = {}

    def __str__(self):
        return self.__class__.__name__ + "(" + str(vars(self)) + ")"

    @property
    def home(self):
        """Home directory of mlonmcu environment."""
        return self._home

    @classmethod
    def from_file(cls, filename):
        """Load an environment from ``filename`` via ``load_environment_from_file`` (with ``cls`` as base)."""
        return load_environment_from_file(filename, base=cls)

    def to_file(self, filename):
        """Write this environment to ``filename`` via ``write_environment_to_file``."""
        write_environment_to_file(self, filename)

    def lookup_path(self, name):
        """Return the entry registered under ``name`` in ``self.paths``.

        Raises:
            AssertionError: If ``name`` is not a key of ``self.paths``.
        """
        assert name in self.paths, f"Unable to find '{name}' path in environment config"
        return self.paths[name]

    def lookup_var(self, name, default=None):
        """Return the environment variable ``name``, or ``default`` if it is not set."""
        return self.vars.get(name, default)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _find_config_by_name(configs, config_name):
        """Return the first element of ``configs`` whose ``name`` equals ``config_name``, else ``None``."""
        return next((config for config in configs if config.name == config_name), None)

    @classmethod
    def _require_config_by_name(cls, configs, config_name, kind):
        """Like ``_find_config_by_name`` but raise when nothing matches.

        Raises:
            ValueError: If no config is named ``config_name``. ValueError is what the
                former ``list.index``-based lookup raised, so existing callers keep working;
                the message is the one the (previously unreachable) asserts intended.
        """
        config = cls._find_config_by_name(configs, config_name)
        if config is None:
            raise ValueError(f"{kind} {config_name} not found in environment config")
        return config

    @classmethod
    def _lookup_component_feature_configs(cls, components, name, component_name, kind):
        """Collect feature configs of ``components`` (or only of the one named ``component_name``).

        Disabled components contribute nothing (see ``_feature_helper``). Always returns
        a new list so callers never receive a component's internal ``features`` list.
        """
        if component_name:
            components = [cls._require_config_by_name(components, component_name, kind)]
        configs = []
        for component in components:
            configs.extend(_feature_helper(component, name))
        return configs

    # ---------------------------------------------------------- feature lookups

    def lookup_frontend_feature_configs(self, name=None, frontend=None):
        """Return feature configs of all frontends, or only of the frontend named ``frontend``.

        Raises:
            ValueError: If ``frontend`` is given but not in the environment config.
        """
        return self._lookup_component_feature_configs(self.frontends, name, frontend, "Frontend")

    def lookup_framework_feature_configs(self, name=None, framework=None):
        """Return feature configs of all frameworks, or only of the framework named ``framework``.

        Raises:
            ValueError: If ``framework`` is given but not in the environment config.
        """
        return self._lookup_component_feature_configs(self.frameworks, name, framework, "Framework")

    def lookup_backend_feature_configs(self, name=None, framework=None, backend=None):
        """Return backend feature configs, optionally restricted by framework and/or backend.

        Args:
            name: If truthy, only features with this name are returned.
            framework: If truthy, only backends of this framework are considered.
            backend: If truthy, only the (first) backend with this name of each
                considered framework is used; frameworks without it are skipped.

        Raises:
            ValueError: If ``framework`` is unknown, or ``backend`` is provided by none
                of the considered frameworks.
        """
        if framework:
            framework_configs = [self._require_config_by_name(self.frameworks, framework, "Framework")]
        else:
            framework_configs = self.frameworks
        configs = []
        backend_found = False
        for framework_config in framework_configs:
            if backend:
                match = self._find_config_by_name(framework_config.backends, backend)
                if match is None:
                    continue
                backend_found = True
                backend_configs = [match]
            else:
                backend_configs = framework_config.backends
            for backend_config in backend_configs:
                configs.extend(_feature_helper(backend_config, name))
        if backend and not backend_found:
            raise ValueError(f"Backend {backend} not found in environment config")
        return configs

    def lookup_platform_feature_configs(self, name=None, platform=None):
        """Return feature configs of all platforms, or only of the platform named ``platform``.

        NOTE(review): assumes platform configs expose ``name``/``enabled``/``features``
        like the other component configs — confirm against ``PlatformConfig``.

        Raises:
            ValueError: If ``platform`` is given but not in the environment config.
        """
        return self._lookup_component_feature_configs(self.platforms, name, platform, "Platform")

    def lookup_target_feature_configs(self, name=None, target=None):
        """Return feature configs of all targets, or only of the target named ``target``.

        Raises:
            ValueError: If ``target`` is given but not in the environment config.
        """
        # TODO: do not fail, just return empty list
        return self._lookup_component_feature_configs(self.targets, name, target, "Target")

    def lookup_feature_configs(
        self,
        name=None,
        kind=None,
        frontend=None,
        framework=None,
        backend=None,
        platform=None,
        target=None,
    ):
        """Return feature configs of the given ``kind`` (a ``FeatureType``), or of all kinds if ``None``."""
        configs = []
        if kind == FeatureType.FRONTEND or kind is None:
            configs.extend(self.lookup_frontend_feature_configs(name=name, frontend=frontend))
        if kind == FeatureType.FRAMEWORK or kind is None:
            configs.extend(self.lookup_framework_feature_configs(name=name, framework=framework))
        if kind == FeatureType.BACKEND or kind is None:
            configs.extend(self.lookup_backend_feature_configs(name=name, framework=framework, backend=backend))
        if kind == FeatureType.PLATFORM or kind is None:
            configs.extend(self.lookup_platform_feature_configs(name=name, platform=platform))
        if kind == FeatureType.TARGET or kind is None:
            configs.extend(self.lookup_target_feature_configs(name=name, target=target))
        return configs

    def supports_feature(self, name):
        """Return True if any feature config named ``name`` has a truthy ``supported`` flag."""
        configs = self.lookup_feature_configs(name=name)
        return any(feature.supported for feature in configs)

    def has_feature(self, name):
        """An alias for supports_feature."""
        return self.supports_feature(name)

    # -------------------------------------------------------- component lookups

    def lookup_backend_configs(self, backend=None, framework=None, names_only=False):
        """Return enabled backend configs (or names) of enabled frameworks.

        If ``framework`` is given only that framework is searched. If ``backend`` is
        given, a one-element list with the first enabled match is returned (``[]`` if none).
        """
        enabled_frameworks = _filter_enabled(self.frameworks)
        configs = []
        for framework_config in enabled_frameworks:
            if framework is not None and framework_config.name != framework:
                continue
            enabled_backends = _filter_enabled(framework_config.backends)
            if backend is None:
                configs.extend(enabled_backends)
            else:
                for backend_config in enabled_backends:
                    if backend_config.name == backend:
                        return [backend_config.name if names_only else backend_config]
        return _extract_names(configs) if names_only else configs

    def lookup_framework_configs(self, framework=None, names_only=False):
        """Return enabled framework configs (or names); filtered to ``framework`` if given."""
        enabled_frameworks = _filter_enabled(self.frameworks)
        if framework is None:
            return _extract_names(enabled_frameworks) if names_only else enabled_frameworks
        for framework_config in enabled_frameworks:
            if framework_config.name == framework:
                return [framework_config.name if names_only else framework_config]
        return []

    def lookup_frontend_configs(self, frontend=None, names_only=False):
        """Return enabled frontend configs (or names); filtered to ``frontend`` if given."""
        enabled_frontends = _filter_enabled(self.frontends)
        if frontend is None:
            return _extract_names(enabled_frontends) if names_only else enabled_frontends
        for frontend_config in enabled_frontends:
            if frontend_config.name == frontend:
                return [frontend_config.name if names_only else frontend_config]
        return []

    def lookup_target_configs(self, target=None, names_only=False):
        """Return enabled target configs (or names); filtered to ``target`` if given."""
        enabled_targets = _filter_enabled(self.targets)
        if target is None:
            return _extract_names(enabled_targets) if names_only else enabled_targets
        for target_config in enabled_targets:
            if target_config.name == target:
                return [target_config.name if names_only else target_config]
        return []

    def has_frontend(self, name):
        """Return True if an enabled frontend named ``name`` exists."""
        return len(self.lookup_frontend_configs(frontend=name)) > 0

    def has_backend(self, name):
        """Return True if an enabled backend named ``name`` exists; ``"none"`` is always accepted."""
        if name == "none":
            return True
        return len(self.lookup_backend_configs(backend=name)) > 0

    def has_framework(self, name):
        """Return True if an enabled framework named ``name`` exists."""
        return len(self.lookup_framework_configs(framework=name)) > 0

    def has_target(self, name):
        """Return True if an enabled target named ``name`` exists."""
        return len(self.lookup_target_configs(target=name)) > 0

    # TODO: actually we do not need to explicitly enable those? environment.yml list the default enabled ones instead
    # of the supported ones in the environment
    # def has_postprocess(self, name):
    #     configs = self.lookup_postprocess_configs(postprocess=name)
    #     return len(configs) > 0

    # ---------------------------------------------------------------- defaults

    def get_default_backends(self, framework):
        """Return the default backend names for ``framework``.

        ``"*"`` expands to all enabled backends of that framework. Returns ``[]`` if
        ``framework`` is None or has no (or a ``None``) default.
        """
        if framework is None or framework not in self.defaults.default_backends:
            return []
        default = self.defaults.default_backends[framework]
        if default is None:
            return []
        if isinstance(default, str):
            if default == "*":  # Wildcard: all enabled backends of this framework
                return self.lookup_backend_configs(framework=framework, names_only=True)
            return [default]
        assert isinstance(default, list), f"Invalid default backends for framework '{framework}': {default!r}"
        return default

    def get_default_frameworks(self):
        """Return the default framework names; ``"*"`` expands to all enabled frameworks, None to ``[]``."""
        default = self.defaults.default_framework
        if default is None:
            return []
        if isinstance(default, str):
            if default == "*":  # Wildcard: all enabled frameworks
                return self.lookup_framework_configs(names_only=True)
            return [default]
        assert isinstance(default, list), f"Invalid default framework setting: {default!r}"
        return default

    def get_default_targets(self):
        """Return the default target names; ``"*"`` expands to all enabled targets.

        Unlike the framework/backend variants this returns ``None`` (not ``[]``)
        when no default target is configured.
        """
        default = self.defaults.default_target
        if default is None:
            return None
        if isinstance(default, str):
            if default == "*":  # Wildcard: all enabled targets
                return self.lookup_target_configs(names_only=True)
            return [default]
        assert isinstance(default, list), f"Invalid default target setting: {default!r}"
        return default
class DefaultEnvironment(Environment):
    """Environment preset with built-in defaults.

    Uses debug logging without a log file, no default framework/target, and
    paths relative to the current working directory.
    """

    def __init__(self):
        super().__init__()
        self.defaults = DefaultsConfig(
            log_level=logging.DEBUG,
            log_to_file=False,
            default_framework=None,
            default_backends={},
            default_target=None,
            cleanup_auto=False,
            cleanup_keep=100,
        )
        self.paths = dict(
            deps=PathConfig("./deps"),
            logs=PathConfig("./logs"),
            results=PathConfig("./results"),
            plugins=PathConfig("./plugins"),
            temp=PathConfig("out"),
            # Several model directories may be configured, hence a list.
            models=[PathConfig("./models")],
        )
        # repos, frameworks, frontends, vars, flags, platforms and targets keep the
        # fresh empty containers created by Environment.__init__.
        # NOTE(review): Environment.__init__ initializes toolchains as a list, while
        # this preset uses a dict — confirm which type the loader/writer expects.
        self.toolchains = {}
class UserEnvironment(DefaultEnvironment):
    """Environment built from user-supplied settings layered over ``DefaultEnvironment``.

    Every optional argument that is truthy replaces the corresponding inherited
    attribute; falsy arguments (``None``, empty containers, empty strings) leave
    the default in place.
    """

    def __init__(
        self,
        home,
        merge=False,
        alias=None,
        defaults=None,
        paths=None,
        repos=None,
        frameworks=None,
        frontends=None,
        platforms=None,
        toolchains=None,
        targets=None,
        variables=None,
        default_flags=None,
    ):
        super().__init__()
        self._home = home
        if merge:
            # Merging user settings into the defaults is not supported.
            raise NotImplementedError
        # attribute name -> user override; note `variables` maps to `vars`
        # and `default_flags` maps to `flags`.
        overrides = {
            "alias": alias,
            "defaults": defaults,
            "paths": paths,
            "repos": repos,
            "frameworks": frameworks,
            "frontends": frontends,
            "platforms": platforms,
            "toolchains": toolchains,
            "targets": targets,
            "vars": variables,
            "flags": default_flags,
        }
        for attribute, value in overrides.items():
            if value:
                setattr(self, attribute, value)