about summary refs log tree commit diff
path: root/nixos/lib
diff options
context:
space:
mode:
Diffstat (limited to 'nixos/lib')
-rw-r--r--nixos/lib/make-disk-image.nix9
-rw-r--r--nixos/lib/systemd-lib.nix27
-rw-r--r--nixos/lib/systemd-network-units.nix27
-rw-r--r--nixos/lib/test-driver/default.nix16
-rw-r--r--nixos/lib/test-driver/pyproject.toml8
-rwxr-xr-xnixos/lib/test-driver/test_driver/__init__.py30
-rw-r--r--nixos/lib/test-driver/test_driver/driver.py45
-rw-r--r--nixos/lib/test-driver/test_driver/logger.py249
-rw-r--r--nixos/lib/test-driver/test_driver/machine.py16
-rw-r--r--nixos/lib/test-driver/test_driver/polling_condition.py11
-rw-r--r--nixos/lib/test-driver/test_driver/vlan.py15
-rw-r--r--nixos/lib/test-script-prepend.py4
-rw-r--r--nixos/lib/testing/driver.nix2
13 files changed, 370 insertions, 89 deletions
diff --git a/nixos/lib/make-disk-image.nix b/nixos/lib/make-disk-image.nix
index 9bdbf4e0713de..1220bbfd5ed7c 100644
--- a/nixos/lib/make-disk-image.nix
+++ b/nixos/lib/make-disk-image.nix
@@ -603,10 +603,11 @@ let format' = format; in let
       ${lib.optionalString installBootLoader ''
         # In this throwaway resource, we only have /dev/vda, but the actual VM may refer to another disk for bootloader, e.g. /dev/vdb
         # Use this option to create a symlink from vda to any arbitrary device you want.
-        ${optionalString (config.boot.loader.grub.enable && config.boot.loader.grub.device != "/dev/vda") ''
-            mkdir -p $(dirname ${config.boot.loader.grub.device})
-            ln -s /dev/vda ${config.boot.loader.grub.device}
-        ''}
+        ${optionalString (config.boot.loader.grub.enable) (lib.concatMapStringsSep " " (device:
+          lib.optionalString (device != "/dev/vda") ''
+            mkdir -p "$(dirname ${device})"
+            ln -s /dev/vda ${device}
+          '') config.boot.loader.grub.devices)}
 
         # Set up core system link, bootloader (sd-boot, GRUB, uboot, etc.), etc.
 
diff --git a/nixos/lib/systemd-lib.nix b/nixos/lib/systemd-lib.nix
index eef49f8c4ef38..dac5cc7b700c8 100644
--- a/nixos/lib/systemd-lib.nix
+++ b/nixos/lib/systemd-lib.nix
@@ -18,6 +18,7 @@ let
     flip
     head
     isInt
+    isFloat
     isList
     isPath
     length
@@ -152,7 +153,7 @@ in rec {
       "Systemd ${group} field `${name}' is outside the range [${toString min},${toString max}]";
 
   assertRangeOrOneOf = name: min: max: values: group: attr:
-    optional (attr ? ${name} && !((min <= attr.${name} && max >= attr.${name}) || elem attr.${name} values))
+    optional (attr ? ${name} && !(((isInt attr.${name} || isFloat attr.${name}) && min <= attr.${name} && max >= attr.${name}) || elem attr.${name} values))
       "Systemd ${group} field `${name}' is not a value in range [${toString min},${toString max}], or one of ${toString values}";
 
   assertMinimum = name: min: group: attr:
@@ -181,6 +182,30 @@ in rec {
   in if errors == [] then true
      else trace (concatStringsSep "\n" errors) false;
 
+  checkUnitConfigWithLegacyKey = legacyKey: group: checks: attrs:
+    let
+      dump = lib.generators.toPretty { }
+        (lib.generators.withRecursion { depthLimit = 2; throwOnDepthLimit = false; } attrs);
+      attrs' =
+        if legacyKey == null
+          then attrs
+        else if ! attrs?${legacyKey}
+          then attrs
+        else if removeAttrs attrs [ legacyKey ] == {}
+          then attrs.${legacyKey}
+        else throw ''
+          The declaration
+
+          ${dump}
+
+          must not mix unit options with the legacy key '${legacyKey}'.
+
+          This can be fixed by moving all settings from within ${legacyKey}
+          one level up.
+        '';
+    in
+    checkUnitConfig group checks attrs';
+
   toOption = x:
     if x == true then "true"
     else if x == false then "false"
diff --git a/nixos/lib/systemd-network-units.nix b/nixos/lib/systemd-network-units.nix
index ae581495772a8..c35309a6d2628 100644
--- a/nixos/lib/systemd-network-units.nix
+++ b/nixos/lib/systemd-network-units.nix
@@ -63,13 +63,13 @@ in {
       ${attrsToSection def.l2tpConfig}
     '' + flip concatMapStrings def.l2tpSessions (x: ''
       [L2TPSession]
-      ${attrsToSection x.l2tpSessionConfig}
+      ${attrsToSection x}
     '') + optionalString (def.wireguardConfig != { }) ''
       [WireGuard]
       ${attrsToSection def.wireguardConfig}
     '' + flip concatMapStrings def.wireguardPeers (x: ''
       [WireGuardPeer]
-      ${attrsToSection x.wireguardPeerConfig}
+      ${attrsToSection x}
     '') + optionalString (def.bondConfig != { }) ''
       [Bond]
       ${attrsToSection def.bondConfig}
@@ -122,13 +122,13 @@ in {
       ${concatStringsSep "\n" (map (s: "Xfrm=${s}") def.xfrm)}
     '' + "\n" + flip concatMapStrings def.addresses (x: ''
       [Address]
-      ${attrsToSection x.addressConfig}
+      ${attrsToSection x}
     '') + flip concatMapStrings def.routingPolicyRules (x: ''
       [RoutingPolicyRule]
-      ${attrsToSection x.routingPolicyRuleConfig}
+      ${attrsToSection x}
     '') + flip concatMapStrings def.routes (x: ''
       [Route]
-      ${attrsToSection x.routeConfig}
+      ${attrsToSection x}
     '') + optionalString (def.dhcpV4Config != { }) ''
       [DHCPv4]
       ${attrsToSection def.dhcpV4Config}
@@ -147,24 +147,27 @@ in {
     '' + optionalString (def.ipv6SendRAConfig != { }) ''
       [IPv6SendRA]
       ${attrsToSection def.ipv6SendRAConfig}
-    '' + flip concatMapStrings def.ipv6Prefixes (x: ''
+    '' + flip concatMapStrings def.ipv6PREF64Prefixes (x: ''
+      [IPv6PREF64Prefix]
+      ${attrsToSection x}
+    '') + flip concatMapStrings def.ipv6Prefixes (x: ''
       [IPv6Prefix]
-      ${attrsToSection x.ipv6PrefixConfig}
+      ${attrsToSection x}
     '') + flip concatMapStrings def.ipv6RoutePrefixes (x: ''
       [IPv6RoutePrefix]
-      ${attrsToSection x.ipv6RoutePrefixConfig}
+      ${attrsToSection x}
     '') + flip concatMapStrings def.dhcpServerStaticLeases (x: ''
       [DHCPServerStaticLease]
-      ${attrsToSection x.dhcpServerStaticLeaseConfig}
+      ${attrsToSection x}
     '') + optionalString (def.bridgeConfig != { }) ''
       [Bridge]
       ${attrsToSection def.bridgeConfig}
     '' + flip concatMapStrings def.bridgeFDBs (x: ''
       [BridgeFDB]
-      ${attrsToSection x.bridgeFDBConfig}
+      ${attrsToSection x}
     '') + flip concatMapStrings def.bridgeMDBs (x: ''
       [BridgeMDB]
-      ${attrsToSection x.bridgeMDBConfig}
+      ${attrsToSection x}
     '') + optionalString (def.lldpConfig != { }) ''
       [LLDP]
       ${attrsToSection def.lldpConfig}
@@ -251,7 +254,7 @@ in {
       ${attrsToSection def.quickFairQueueingConfigClass}
     '' + flip concatMapStrings def.bridgeVLANs (x: ''
       [BridgeVLAN]
-      ${attrsToSection x.bridgeVLANConfig}
+      ${attrsToSection x}
     '') + def.extraConfig;
 
 }
diff --git a/nixos/lib/test-driver/default.nix b/nixos/lib/test-driver/default.nix
index 1acdaacc4e658..26652db6016e6 100644
--- a/nixos/lib/test-driver/default.nix
+++ b/nixos/lib/test-driver/default.nix
@@ -13,17 +13,27 @@
 , extraPythonPackages ? (_ : [])
 , nixosTests
 }:
-
+let
+  fs = lib.fileset;
+in
 python3Packages.buildPythonApplication {
   pname = "nixos-test-driver";
   version = "1.1";
-  src = ./.;
+  src = fs.toSource {
+    root = ./.;
+    fileset = fs.unions [
+      ./pyproject.toml
+      ./test_driver
+      ./extract-docstrings.py
+    ];
+  };
   pyproject = true;
 
   propagatedBuildInputs = [
     coreutils
     netpbm
     python3Packages.colorama
+    python3Packages.junit-xml
     python3Packages.ptpython
     qemu_pkg
     socat
@@ -46,7 +56,7 @@ python3Packages.buildPythonApplication {
     echo -e "\x1b[32m## run mypy\x1b[0m"
     mypy test_driver extract-docstrings.py
     echo -e "\x1b[32m## run ruff\x1b[0m"
-    ruff .
+    ruff check .
     echo -e "\x1b[32m## run black\x1b[0m"
     black --check --diff .
   '';
diff --git a/nixos/lib/test-driver/pyproject.toml b/nixos/lib/test-driver/pyproject.toml
index 17b7130a4bad7..714139bc1b25c 100644
--- a/nixos/lib/test-driver/pyproject.toml
+++ b/nixos/lib/test-driver/pyproject.toml
@@ -19,8 +19,8 @@ test_driver = ["py.typed"]
 [tool.ruff]
 line-length = 88
 
-select = ["E", "F", "I", "U", "N"]
-ignore = ["E501"]
+lint.select = ["E", "F", "I", "U", "N"]
+lint.ignore = ["E501"]
 
 # xxx: we can import https://pypi.org/project/types-colorama/ here
 [[tool.mypy.overrides]]
@@ -31,6 +31,10 @@ ignore_missing_imports = true
 module = "ptpython.*"
 ignore_missing_imports = true
 
+[[tool.mypy.overrides]]
+module = "junit_xml.*"
+ignore_missing_imports = true
+
 [tool.black]
 line-length = 88
 target-version = ['py39']
diff --git a/nixos/lib/test-driver/test_driver/__init__.py b/nixos/lib/test-driver/test_driver/__init__.py
index 9daae1e941a65..42b6d29b76714 100755
--- a/nixos/lib/test-driver/test_driver/__init__.py
+++ b/nixos/lib/test-driver/test_driver/__init__.py
@@ -6,7 +6,12 @@ from pathlib import Path
 import ptpython.repl
 
 from test_driver.driver import Driver
-from test_driver.logger import rootlog
+from test_driver.logger import (
+    CompositeLogger,
+    JunitXMLLogger,
+    TerminalLogger,
+    XMLLogger,
+)
 
 
 class EnvDefault(argparse.Action):
@@ -93,6 +98,11 @@ def main() -> None:
         type=writeable_dir,
     )
     arg_parser.add_argument(
+        "--junit-xml",
+        help="Enable JunitXML report generation to the given path",
+        type=Path,
+    )
+    arg_parser.add_argument(
         "testscript",
         action=EnvDefault,
         envvar="testScript",
@@ -102,14 +112,24 @@ def main() -> None:
 
     args = arg_parser.parse_args()
 
+    output_directory = args.output_directory.resolve()
+    logger = CompositeLogger([TerminalLogger()])
+
+    if "LOGFILE" in os.environ.keys():
+        logger.add_logger(XMLLogger(os.environ["LOGFILE"]))
+
+    if args.junit_xml:
+        logger.add_logger(JunitXMLLogger(output_directory / args.junit_xml))
+
     if not args.keep_vm_state:
-        rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")
+        logger.info("Machine state will be reset. To keep it, pass --keep-vm-state")
 
     with Driver(
         args.start_scripts,
         args.vlans,
         args.testscript.read_text(),
-        args.output_directory.resolve(),
+        output_directory,
+        logger,
         args.keep_vm_state,
         args.global_timeout,
     ) as driver:
@@ -125,7 +145,7 @@ def main() -> None:
             tic = time.time()
             driver.run_tests()
             toc = time.time()
-            rootlog.info(f"test script finished in {(toc-tic):.2f}s")
+            logger.info(f"test script finished in {(toc-tic):.2f}s")
 
 
 def generate_driver_symbols() -> None:
@@ -134,7 +154,7 @@ def generate_driver_symbols() -> None:
     in user's test scripts. That list is then used by pyflakes to lint those
     scripts.
     """
-    d = Driver([], [], "", Path())
+    d = Driver([], [], "", Path(), CompositeLogger([]))
     test_symbols = d.test_symbols()
     with open("driver-symbols", "w") as fp:
         fp.write(",".join(test_symbols.keys()))
diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py
index f792c04591996..01b64b92e9770 100644
--- a/nixos/lib/test-driver/test_driver/driver.py
+++ b/nixos/lib/test-driver/test_driver/driver.py
@@ -9,7 +9,7 @@ from typing import Any, Callable, ContextManager, Dict, Iterator, List, Optional
 
 from colorama import Fore, Style
 
-from test_driver.logger import rootlog
+from test_driver.logger import AbstractLogger
 from test_driver.machine import Machine, NixStartScript, retry
 from test_driver.polling_condition import PollingCondition
 from test_driver.vlan import VLan
@@ -49,6 +49,7 @@ class Driver:
     polling_conditions: List[PollingCondition]
     global_timeout: int
     race_timer: threading.Timer
+    logger: AbstractLogger
 
     def __init__(
         self,
@@ -56,6 +57,7 @@ class Driver:
         vlans: List[int],
         tests: str,
         out_dir: Path,
+        logger: AbstractLogger,
         keep_vm_state: bool = False,
         global_timeout: int = 24 * 60 * 60 * 7,
     ):
@@ -63,12 +65,13 @@ class Driver:
         self.out_dir = out_dir
         self.global_timeout = global_timeout
         self.race_timer = threading.Timer(global_timeout, self.terminate_test)
+        self.logger = logger
 
         tmp_dir = get_tmp_dir()
 
-        with rootlog.nested("start all VLans"):
+        with self.logger.nested("start all VLans"):
             vlans = list(set(vlans))
-            self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
+            self.vlans = [VLan(nr, tmp_dir, self.logger) for nr in vlans]
 
         def cmd(scripts: List[str]) -> Iterator[NixStartScript]:
             for s in scripts:
@@ -84,6 +87,7 @@ class Driver:
                 tmp_dir=tmp_dir,
                 callbacks=[self.check_polling_conditions],
                 out_dir=self.out_dir,
+                logger=self.logger,
             )
             for cmd in cmd(start_scripts)
         ]
@@ -92,19 +96,18 @@ class Driver:
         return self
 
     def __exit__(self, *_: Any) -> None:
-        with rootlog.nested("cleanup"):
+        with self.logger.nested("cleanup"):
             self.race_timer.cancel()
             for machine in self.machines:
                 machine.release()
 
     def subtest(self, name: str) -> Iterator[None]:
         """Group logs under a given test name"""
-        with rootlog.nested("subtest: " + name):
+        with self.logger.subtest(name):
             try:
                 yield
-                return True
             except Exception as e:
-                rootlog.error(f'Test "{name}" failed with error: "{e}"')
+                self.logger.error(f'Test "{name}" failed with error: "{e}"')
                 raise e
 
     def test_symbols(self) -> Dict[str, Any]:
@@ -118,7 +121,7 @@ class Driver:
             machines=self.machines,
             vlans=self.vlans,
             driver=self,
-            log=rootlog,
+            log=self.logger,
             os=os,
             create_machine=self.create_machine,
             subtest=subtest,
@@ -150,13 +153,13 @@ class Driver:
 
     def test_script(self) -> None:
         """Run the test script"""
-        with rootlog.nested("run the VM test script"):
+        with self.logger.nested("run the VM test script"):
             symbols = self.test_symbols()  # call eagerly
             exec(self.tests, symbols, None)
 
     def run_tests(self) -> None:
         """Run the test script (for non-interactive test runs)"""
-        rootlog.info(
+        self.logger.info(
             f"Test will time out and terminate in {self.global_timeout} seconds"
         )
         self.race_timer.start()
@@ -168,13 +171,13 @@ class Driver:
 
     def start_all(self) -> None:
         """Start all machines"""
-        with rootlog.nested("start all VMs"):
+        with self.logger.nested("start all VMs"):
             for machine in self.machines:
                 machine.start()
 
     def join_all(self) -> None:
         """Wait for all machines to shut down"""
-        with rootlog.nested("wait for all VMs to finish"):
+        with self.logger.nested("wait for all VMs to finish"):
             for machine in self.machines:
                 machine.wait_for_shutdown()
             self.race_timer.cancel()
@@ -182,7 +185,7 @@ class Driver:
     def terminate_test(self) -> None:
         # This will be usually running in another thread than
         # the thread actually executing the test script.
-        with rootlog.nested("timeout reached; test terminating..."):
+        with self.logger.nested("timeout reached; test terminating..."):
             for machine in self.machines:
                 machine.release()
             # As we cannot `sys.exit` from another thread
@@ -227,7 +230,7 @@ class Driver:
                     f"Unsupported arguments passed to create_machine: {args}"
                 )
 
-            rootlog.warning(
+            self.logger.warning(
                 Fore.YELLOW
                 + Style.BRIGHT
                 + "WARNING: Using create_machine with a single dictionary argument is deprecated and will be removed in NixOS 24.11"
@@ -246,13 +249,14 @@ class Driver:
             start_command=cmd,
             name=name,
             keep_vm_state=keep_vm_state,
+            logger=self.logger,
         )
 
     def serial_stdout_on(self) -> None:
-        rootlog._print_serial_logs = True
+        self.logger.print_serial_logs(True)
 
     def serial_stdout_off(self) -> None:
-        rootlog._print_serial_logs = False
+        self.logger.print_serial_logs(False)
 
     def check_polling_conditions(self) -> None:
         for condition in self.polling_conditions:
@@ -271,6 +275,7 @@ class Driver:
             def __init__(self, fun: Callable):
                 self.condition = PollingCondition(
                     fun,
+                    driver.logger,
                     seconds_interval,
                     description,
                 )
@@ -285,15 +290,17 @@ class Driver:
             def wait(self, timeout: int = 900) -> None:
                 def condition(last: bool) -> bool:
                     if last:
-                        rootlog.info(f"Last chance for {self.condition.description}")
+                        driver.logger.info(
+                            f"Last chance for {self.condition.description}"
+                        )
                     ret = self.condition.check(force=True)
                     if not ret and not last:
-                        rootlog.info(
+                        driver.logger.info(
                             f"({self.condition.description} failure not fatal yet)"
                         )
                     return ret
 
-                with rootlog.nested(f"waiting for {self.condition.description}"):
+                with driver.logger.nested(f"waiting for {self.condition.description}"):
                     retry(condition, timeout=timeout)
 
         if fun_ is None:
diff --git a/nixos/lib/test-driver/test_driver/logger.py b/nixos/lib/test-driver/test_driver/logger.py
index 0b0623bddfa1e..484829254b812 100644
--- a/nixos/lib/test-driver/test_driver/logger.py
+++ b/nixos/lib/test-driver/test_driver/logger.py
@@ -1,33 +1,238 @@
+import atexit
 import codecs
 import os
 import sys
 import time
 import unicodedata
-from contextlib import contextmanager
+from abc import ABC, abstractmethod
+from contextlib import ExitStack, contextmanager
+from pathlib import Path
 from queue import Empty, Queue
-from typing import Any, Dict, Iterator
+from typing import Any, Dict, Iterator, List
 from xml.sax.saxutils import XMLGenerator
 from xml.sax.xmlreader import AttributesImpl
 
 from colorama import Fore, Style
+from junit_xml import TestCase, TestSuite
 
 
-class Logger:
-    def __init__(self) -> None:
-        self.logfile = os.environ.get("LOGFILE", "/dev/null")
-        self.logfile_handle = codecs.open(self.logfile, "wb")
-        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
-        self.queue: "Queue[Dict[str, str]]" = Queue()
+class AbstractLogger(ABC):
+    @abstractmethod
+    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
+        pass
 
-        self.xml.startDocument()
-        self.xml.startElement("logfile", attrs=AttributesImpl({}))
+    @abstractmethod
+    @contextmanager
+    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        pass
+
+    @abstractmethod
+    @contextmanager
+    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        pass
+
+    @abstractmethod
+    def info(self, *args, **kwargs) -> None:  # type: ignore
+        pass
+
+    @abstractmethod
+    def warning(self, *args, **kwargs) -> None:  # type: ignore
+        pass
+
+    @abstractmethod
+    def error(self, *args, **kwargs) -> None:  # type: ignore
+        pass
+
+    @abstractmethod
+    def log_serial(self, message: str, machine: str) -> None:
+        pass
+
+    @abstractmethod
+    def print_serial_logs(self, enable: bool) -> None:
+        pass
+
+
+class JunitXMLLogger(AbstractLogger):
+    class TestCaseState:
+        def __init__(self) -> None:
+            self.stdout = ""
+            self.stderr = ""
+            self.failure = False
+
+    def __init__(self, outfile: Path) -> None:
+        self.tests: dict[str, JunitXMLLogger.TestCaseState] = {
+            "main": self.TestCaseState()
+        }
+        self.currentSubtest = "main"
+        self.outfile: Path = outfile
+        self._print_serial_logs = True
+        atexit.register(self.close)
+
+    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
+        self.tests[self.currentSubtest].stdout += message + os.linesep
+
+    @contextmanager
+    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        old_test = self.currentSubtest
+        self.tests.setdefault(name, self.TestCaseState())
+        self.currentSubtest = name
+
+        yield
+
+        self.currentSubtest = old_test
+
+    @contextmanager
+    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        self.log(message)
+        yield
+
+    def info(self, *args, **kwargs) -> None:  # type: ignore
+        self.tests[self.currentSubtest].stdout += args[0] + os.linesep
+
+    def warning(self, *args, **kwargs) -> None:  # type: ignore
+        self.tests[self.currentSubtest].stdout += args[0] + os.linesep
+
+    def error(self, *args, **kwargs) -> None:  # type: ignore
+        self.tests[self.currentSubtest].stderr += args[0] + os.linesep
+        self.tests[self.currentSubtest].failure = True
+
+    def log_serial(self, message: str, machine: str) -> None:
+        if not self._print_serial_logs:
+            return
+
+        self.log(f"{machine} # {message}")
+
+    def print_serial_logs(self, enable: bool) -> None:
+        self._print_serial_logs = enable
+
+    def close(self) -> None:
+        with open(self.outfile, "w") as f:
+            test_cases = []
+            for name, test_case_state in self.tests.items():
+                tc = TestCase(
+                    name,
+                    stdout=test_case_state.stdout,
+                    stderr=test_case_state.stderr,
+                )
+                if test_case_state.failure:
+                    tc.add_failure_info("test case failed")
+
+                test_cases.append(tc)
+            ts = TestSuite("NixOS integration test", test_cases)
+            f.write(TestSuite.to_xml_string([ts]))
+
+
+class CompositeLogger(AbstractLogger):
+    def __init__(self, logger_list: List[AbstractLogger]) -> None:
+        self.logger_list = logger_list
+
+    def add_logger(self, logger: AbstractLogger) -> None:
+        self.logger_list.append(logger)
+
+    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
+        for logger in self.logger_list:
+            logger.log(message, attributes)
+
+    @contextmanager
+    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        with ExitStack() as stack:
+            for logger in self.logger_list:
+                stack.enter_context(logger.subtest(name, attributes))
+            yield
+
+    @contextmanager
+    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        with ExitStack() as stack:
+            for logger in self.logger_list:
+                stack.enter_context(logger.nested(message, attributes))
+            yield
+
+    def info(self, *args, **kwargs) -> None:  # type: ignore
+        for logger in self.logger_list:
+            logger.info(*args, **kwargs)
+
+    def warning(self, *args, **kwargs) -> None:  # type: ignore
+        for logger in self.logger_list:
+            logger.warning(*args, **kwargs)
+
+    def error(self, *args, **kwargs) -> None:  # type: ignore
+        for logger in self.logger_list:
+            logger.error(*args, **kwargs)
+        sys.exit(1)
 
+    def print_serial_logs(self, enable: bool) -> None:
+        for logger in self.logger_list:
+            logger.print_serial_logs(enable)
+
+    def log_serial(self, message: str, machine: str) -> None:
+        for logger in self.logger_list:
+            logger.log_serial(message, machine)
+
+
+class TerminalLogger(AbstractLogger):
+    def __init__(self) -> None:
         self._print_serial_logs = True
 
+    def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
+        if "machine" in attributes:
+            return f"{attributes['machine']}: {message}"
+        return message
+
     @staticmethod
     def _eprint(*args: object, **kwargs: Any) -> None:
         print(*args, file=sys.stderr, **kwargs)
 
+    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
+        self._eprint(self.maybe_prefix(message, attributes))
+
+    @contextmanager
+    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        with self.nested("subtest: " + name, attributes):
+            yield
+
+    @contextmanager
+    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        self._eprint(
+            self.maybe_prefix(
+                Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes
+            )
+        )
+
+        tic = time.time()
+        yield
+        toc = time.time()
+        self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")
+
+    def info(self, *args, **kwargs) -> None:  # type: ignore
+        self.log(*args, **kwargs)
+
+    def warning(self, *args, **kwargs) -> None:  # type: ignore
+        self.log(*args, **kwargs)
+
+    def error(self, *args, **kwargs) -> None:  # type: ignore
+        self.log(*args, **kwargs)
+
+    def print_serial_logs(self, enable: bool) -> None:
+        self._print_serial_logs = enable
+
+    def log_serial(self, message: str, machine: str) -> None:
+        if not self._print_serial_logs:
+            return
+
+        self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)
+
+
+class XMLLogger(AbstractLogger):
+    def __init__(self, outfile: str) -> None:
+        self.logfile_handle = codecs.open(outfile, "wb")
+        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
+        self.queue: Queue[dict[str, str]] = Queue()
+
+        self._print_serial_logs = True
+
+        self.xml.startDocument()
+        self.xml.startElement("logfile", attrs=AttributesImpl({}))
+
     def close(self) -> None:
         self.xml.endElement("logfile")
         self.xml.endDocument()
@@ -54,17 +259,19 @@ class Logger:
 
     def error(self, *args, **kwargs) -> None:  # type: ignore
         self.log(*args, **kwargs)
-        sys.exit(1)
 
     def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
-        self._eprint(self.maybe_prefix(message, attributes))
         self.drain_log_queue()
         self.log_line(message, attributes)
 
+    def print_serial_logs(self, enable: bool) -> None:
+        self._print_serial_logs = enable
+
     def log_serial(self, message: str, machine: str) -> None:
+        if not self._print_serial_logs:
+            return
+
         self.enqueue({"msg": message, "machine": machine, "type": "serial"})
-        if self._print_serial_logs:
-            self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)
 
     def enqueue(self, item: Dict[str, str]) -> None:
         self.queue.put(item)
@@ -80,13 +287,12 @@ class Logger:
             pass
 
     @contextmanager
-    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
-        self._eprint(
-            self.maybe_prefix(
-                Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes
-            )
-        )
+    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        with self.nested("subtest: " + name, attributes):
+            yield
 
+    @contextmanager
+    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
         self.xml.startElement("nest", attrs=AttributesImpl({}))
         self.xml.startElement("head", attrs=AttributesImpl(attributes))
         self.xml.characters(message)
@@ -100,6 +306,3 @@ class Logger:
         self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")
 
         self.xml.endElement("nest")
-
-
-rootlog = Logger()
diff --git a/nixos/lib/test-driver/test_driver/machine.py b/nixos/lib/test-driver/test_driver/machine.py
index 652cc600fad59..3a1d5bc1be764 100644
--- a/nixos/lib/test-driver/test_driver/machine.py
+++ b/nixos/lib/test-driver/test_driver/machine.py
@@ -17,7 +17,7 @@ from pathlib import Path
 from queue import Queue
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
 
-from test_driver.logger import rootlog
+from test_driver.logger import AbstractLogger
 
 from .qmp import QMPSession
 
@@ -270,6 +270,7 @@ class Machine:
         out_dir: Path,
         tmp_dir: Path,
         start_command: StartCommand,
+        logger: AbstractLogger,
         name: str = "machine",
         keep_vm_state: bool = False,
         callbacks: Optional[List[Callable]] = None,
@@ -280,6 +281,7 @@ class Machine:
         self.name = name
         self.start_command = start_command
         self.callbacks = callbacks if callbacks is not None else []
+        self.logger = logger
 
         # set up directories
         self.shared_dir = self.tmp_dir / "shared-xchg"
@@ -307,15 +309,15 @@ class Machine:
         return self.booted and self.connected
 
     def log(self, msg: str) -> None:
-        rootlog.log(msg, {"machine": self.name})
+        self.logger.log(msg, {"machine": self.name})
 
     def log_serial(self, msg: str) -> None:
-        rootlog.log_serial(msg, self.name)
+        self.logger.log_serial(msg, self.name)
 
     def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
         my_attrs = {"machine": self.name}
         my_attrs.update(attrs)
-        return rootlog.nested(msg, my_attrs)
+        return self.logger.nested(msg, my_attrs)
 
     def wait_for_monitor_prompt(self) -> str:
         assert self.monitor is not None
@@ -1113,8 +1115,8 @@ class Machine:
 
     def cleanup_statedir(self) -> None:
         shutil.rmtree(self.state_dir)
-        rootlog.log(f"deleting VM state directory {self.state_dir}")
-        rootlog.log("if you want to keep the VM state, pass --keep-vm-state")
+        self.logger.log(f"deleting VM state directory {self.state_dir}")
+        self.logger.log("if you want to keep the VM state, pass --keep-vm-state")
 
     def shutdown(self) -> None:
         """
@@ -1221,7 +1223,7 @@ class Machine:
     def release(self) -> None:
         if self.pid is None:
             return
-        rootlog.info(f"kill machine (pid {self.pid})")
+        self.logger.info(f"kill machine (pid {self.pid})")
         assert self.process
         assert self.shell
         assert self.monitor
diff --git a/nixos/lib/test-driver/test_driver/polling_condition.py b/nixos/lib/test-driver/test_driver/polling_condition.py
index 12cbad69e34e9..1cccaf2c71e74 100644
--- a/nixos/lib/test-driver/test_driver/polling_condition.py
+++ b/nixos/lib/test-driver/test_driver/polling_condition.py
@@ -2,7 +2,7 @@ import time
 from math import isfinite
 from typing import Callable, Optional
 
-from .logger import rootlog
+from test_driver.logger import AbstractLogger
 
 
 class PollingConditionError(Exception):
@@ -13,6 +13,7 @@ class PollingCondition:
     condition: Callable[[], bool]
     seconds_interval: float
     description: Optional[str]
+    logger: AbstractLogger
 
     last_called: float
     entry_count: int
@@ -20,11 +21,13 @@ class PollingCondition:
     def __init__(
         self,
         condition: Callable[[], Optional[bool]],
+        logger: AbstractLogger,
         seconds_interval: float = 2.0,
         description: Optional[str] = None,
     ):
         self.condition = condition  # type: ignore
         self.seconds_interval = seconds_interval
+        self.logger = logger
 
         if description is None:
             if condition.__doc__:
@@ -41,7 +44,7 @@ class PollingCondition:
         if (self.entered or not self.overdue) and not force:
             return True
 
-        with self, rootlog.nested(self.nested_message):
+        with self, self.logger.nested(self.nested_message):
             time_since_last = time.monotonic() - self.last_called
             last_message = (
                 f"Time since last: {time_since_last:.2f}s"
@@ -49,13 +52,13 @@ class PollingCondition:
                 else "(not called yet)"
             )
 
-            rootlog.info(last_message)
+            self.logger.info(last_message)
             try:
                 res = self.condition()  # type: ignore
             except Exception:
                 res = False
             res = res is None or res
-            rootlog.info(self.status_message(res))
+            self.logger.info(self.status_message(res))
             return res
 
     def maybe_raise(self) -> None:
diff --git a/nixos/lib/test-driver/test_driver/vlan.py b/nixos/lib/test-driver/test_driver/vlan.py
index ec9679108e58d..9340fc92ed4c4 100644
--- a/nixos/lib/test-driver/test_driver/vlan.py
+++ b/nixos/lib/test-driver/test_driver/vlan.py
@@ -4,7 +4,7 @@ import pty
 import subprocess
 from pathlib import Path
 
-from test_driver.logger import rootlog
+from test_driver.logger import AbstractLogger
 
 
 class VLan:
@@ -19,17 +19,20 @@ class VLan:
     pid: int
     fd: io.TextIOBase
 
+    logger: AbstractLogger
+
     def __repr__(self) -> str:
         return f"<Vlan Nr. {self.nr}>"
 
-    def __init__(self, nr: int, tmp_dir: Path):
+    def __init__(self, nr: int, tmp_dir: Path, logger: AbstractLogger):
         self.nr = nr
         self.socket_dir = tmp_dir / f"vde{self.nr}.ctl"
+        self.logger = logger
 
         # TODO: don't side-effect environment here
         os.environ[f"QEMU_VDE_SOCKET_{self.nr}"] = str(self.socket_dir)
 
-        rootlog.info("start vlan")
+        self.logger.info("start vlan")
         pty_master, pty_slave = pty.openpty()
 
         # The --hub is required for the scenario determined by
@@ -52,11 +55,11 @@ class VLan:
         assert self.process.stdout is not None
         self.process.stdout.readline()
         if not (self.socket_dir / "ctl").exists():
-            rootlog.error("cannot start vde_switch")
+            self.logger.error("cannot start vde_switch")
 
-        rootlog.info(f"running vlan (pid {self.pid}; ctl {self.socket_dir})")
+        self.logger.info(f"running vlan (pid {self.pid}; ctl {self.socket_dir})")
 
     def __del__(self) -> None:
-        rootlog.info(f"kill vlan (pid {self.pid})")
+        self.logger.info(f"kill vlan (pid {self.pid})")
         self.fd.close()
         self.process.terminate()
diff --git a/nixos/lib/test-script-prepend.py b/nixos/lib/test-script-prepend.py
index 976992ea00158..9d2efdf973031 100644
--- a/nixos/lib/test-script-prepend.py
+++ b/nixos/lib/test-script-prepend.py
@@ -4,7 +4,7 @@
 from test_driver.driver import Driver
 from test_driver.vlan import VLan
 from test_driver.machine import Machine
-from test_driver.logger import Logger
+from test_driver.logger import AbstractLogger
 from typing import Callable, Iterator, ContextManager, Optional, List, Dict, Any, Union
 from typing_extensions import Protocol
 from pathlib import Path
@@ -44,7 +44,7 @@ test_script: Callable[[], None]
 machines: List[Machine]
 vlans: List[VLan]
 driver: Driver
-log: Logger
+log: AbstractLogger
 create_machine: CreateMachineProtocol
 run_tests: Callable[[], None]
 join_all: Callable[[], None]
diff --git a/nixos/lib/testing/driver.nix b/nixos/lib/testing/driver.nix
index 7eb06e023918b..d4f8e0f0c6e38 100644
--- a/nixos/lib/testing/driver.nix
+++ b/nixos/lib/testing/driver.nix
@@ -139,7 +139,7 @@ in
     enableOCR = mkOption {
       description = ''
         Whether to enable Optical Character Recognition functionality for
-        testing graphical programs. See [Machine objects](`ssec-machine-objects`).
+        testing graphical programs. See [`Machine objects`](#ssec-machine-objects).
       '';
       type = types.bool;
       default = false;