about summary refs log tree commit diff
path: root/nixos/lib/test-driver
diff options
context:
space:
mode:
author Anders Kaseorg <andersk@mit.edu> 2020-08-30 15:42:06 -0700
committer Anders Kaseorg <andersk@mit.edu> 2020-08-30 15:46:39 -0700
commit 59b6664f15c4fefad044ac7e7b1b7d52198adae6 (patch)
tree f4b09ada614fbf7700234f2a99e19ef6de071195 /nixos/lib/test-driver
parent a0a421bf5e8beaebc83e8df98694d3a92e42efa8 (diff)
Revert "Merge pull request #96254 from Mic92/logging"
This reverts commit 4fc708567f6d9cf28f9ba426702069aa5a0b89c2, reversing
changes made to 0e54f3a6d8393c31cfae43316904375dcfc77a46.

Fixes #96699.

Signed-off-by: Anders Kaseorg <andersk@mit.edu>
Diffstat (limited to 'nixos/lib/test-driver')
-rw-r--r-- nixos/lib/test-driver/test-driver.py 365
1 files changed, 221 insertions, 144 deletions
diff --git a/nixos/lib/test-driver/test-driver.py b/nixos/lib/test-driver/test-driver.py
index 65a44be64687d..93f94587c0a50 100644
--- a/nixos/lib/test-driver/test-driver.py
+++ b/nixos/lib/test-driver/test-driver.py
@@ -1,13 +1,19 @@
 #! /somewhere/python3
+from contextlib import contextmanager, _GeneratorContextManager
+from queue import Queue, Empty
+from typing import Tuple, Any, Callable, Dict, Iterator, Optional, List
+from xml.sax.saxutils import XMLGenerator
+import queue
+import io
+import _thread
 import argparse
 import atexit
 import base64
-import io
-import logging
+import codecs
 import os
 import pathlib
+import ptpython.repl
 import pty
-import queue
 import re
 import shlex
 import shutil
@@ -15,12 +21,9 @@ import socket
 import subprocess
 import sys
 import tempfile
-import _thread
 import time
-from contextlib import contextmanager
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
-
-import ptpython.repl
+import traceback
+import unicodedata
 
 CHAR_TO_KEY = {
     "A": "shift-a",
@@ -85,17 +88,13 @@ CHAR_TO_KEY = {
     ")": "shift-0x0B",
 }
 
-# Forward reference
+# Forward references
+log: "Logger"
 machines: "List[Machine]"
 
-logging.basicConfig(format="%(message)s")
-logger = logging.getLogger("test-driver")
-logger.setLevel(logging.INFO)
 
-
-class MachineLogAdapter(logging.LoggerAdapter):
-    def process(self, msg: str, kwargs: Any) -> Tuple[str, Any]:
-        return f"{self.extra['machine']}: {msg}", kwargs
+def eprint(*args: object, **kwargs: Any) -> None:
+    print(*args, file=sys.stderr, **kwargs)
 
 
 def make_command(args: list) -> str:
@@ -103,7 +102,8 @@ def make_command(args: list) -> str:
 
 
 def create_vlan(vlan_nr: str) -> Tuple[str, str, "subprocess.Popen[bytes]", Any]:
-    logger.info(f"starting VDE switch for network {vlan_nr}")
+    global log
+    log.log("starting VDE switch for network {}".format(vlan_nr))
     vde_socket = tempfile.mkdtemp(
         prefix="nixos-test-vde-", suffix="-vde{}.ctl".format(vlan_nr)
     )
@@ -142,6 +142,70 @@ def retry(fn: Callable) -> None:
         raise Exception("action timed out")
 
 
+class Logger:
+    def __init__(self) -> None:
+        self.logfile = os.environ.get("LOGFILE", "/dev/null")
+        self.logfile_handle = codecs.open(self.logfile, "wb")
+        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
+        self.queue: "Queue[Dict[str, str]]" = Queue()
+
+        self.xml.startDocument()
+        self.xml.startElement("logfile", attrs={})
+
+    def close(self) -> None:
+        self.xml.endElement("logfile")
+        self.xml.endDocument()
+        self.logfile_handle.close()
+
+    def sanitise(self, message: str) -> str:
+        return "".join(ch for ch in message if unicodedata.category(ch)[0] != "C")
+
+    def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
+        if "machine" in attributes:
+            return "{}: {}".format(attributes["machine"], message)
+        return message
+
+    def log_line(self, message: str, attributes: Dict[str, str]) -> None:
+        self.xml.startElement("line", attributes)
+        self.xml.characters(message)
+        self.xml.endElement("line")
+
+    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
+        eprint(self.maybe_prefix(message, attributes))
+        self.drain_log_queue()
+        self.log_line(message, attributes)
+
+    def enqueue(self, message: Dict[str, str]) -> None:
+        self.queue.put(message)
+
+    def drain_log_queue(self) -> None:
+        try:
+            while True:
+                item = self.queue.get_nowait()
+                attributes = {"machine": item["machine"], "type": "serial"}
+                self.log_line(self.sanitise(item["msg"]), attributes)
+        except Empty:
+            pass
+
+    @contextmanager
+    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        eprint(self.maybe_prefix(message, attributes))
+
+        self.xml.startElement("nest", attrs={})
+        self.xml.startElement("head", attributes)
+        self.xml.characters(message)
+        self.xml.endElement("head")
+
+        tic = time.time()
+        self.drain_log_queue()
+        yield
+        self.drain_log_queue()
+        toc = time.time()
+        self.log("({:.2f} seconds)".format(toc - tic))
+
+        self.xml.endElement("nest")
+
+
 class Machine:
     def __init__(self, args: Dict[str, Any]) -> None:
         if "name" in args:
@@ -171,8 +235,8 @@ class Machine:
         self.pid: Optional[int] = None
         self.socket = None
         self.monitor: Optional[socket.socket] = None
+        self.logger: Logger = args["log"]
         self.allow_reboot = args.get("allowReboot", False)
-        self.logger = MachineLogAdapter(logger, extra=dict(machine=self.name))
 
     @staticmethod
     def create_startcommand(args: Dict[str, str]) -> str:
@@ -228,6 +292,14 @@ class Machine:
     def is_up(self) -> bool:
         return self.booted and self.connected
 
+    def log(self, msg: str) -> None:
+        self.logger.log(msg, {"machine": self.name})
+
+    def nested(self, msg: str, attrs: Dict[str, str] = {}) -> _GeneratorContextManager:
+        my_attrs = {"machine": self.name}
+        my_attrs.update(attrs)
+        return self.logger.nested(msg, my_attrs)
+
     def wait_for_monitor_prompt(self) -> str:
         assert self.monitor is not None
         answer = ""
@@ -242,7 +314,7 @@ class Machine:
 
     def send_monitor_command(self, command: str) -> str:
         message = ("{}\n".format(command)).encode()
-        self.logger.info(f"sending monitor command: {command}")
+        self.log("sending monitor command: {}".format(command))
         assert self.monitor is not None
         self.monitor.send(message)
         return self.wait_for_monitor_prompt()
@@ -309,19 +381,16 @@ class Machine:
         return self.execute("systemctl {}".format(q))
 
     def require_unit_state(self, unit: str, require_state: str = "active") -> None:
-        self.logger.info(
-            f"checking if unit ‘{unit}’ has reached state '{require_state}'"
-        )
-        info = self.get_unit_info(unit)
-        state = info["ActiveState"]
-        if state != require_state:
-            raise Exception(
-                "Expected unit ‘{}’ to to be in state ".format(unit)
-                + "'{}' but it is in state ‘{}’".format(require_state, state)
-            )
-
-    def log(self, message: str) -> None:
-        self.logger.info(message)
+        with self.nested(
+            "checking if unit ‘{}’ has reached state '{}'".format(unit, require_state)
+        ):
+            info = self.get_unit_info(unit)
+            state = info["ActiveState"]
+            if state != require_state:
+                raise Exception(
+                    "Expected unit ‘{}’ to to be in state ".format(unit)
+                    + "'{}' but it is in state ‘{}’".format(require_state, state)
+                )
 
     def execute(self, command: str) -> Tuple[int, str]:
         self.connect()
@@ -345,25 +414,27 @@ class Machine:
         """Execute each command and check that it succeeds."""
         output = ""
         for command in commands:
-            self.logger.info(f"must succeed: {command}")
-            (status, out) = self.execute(command)
-            if status != 0:
-                self.logger.info(f"output: {out}")
-                raise Exception(
-                    "command `{}` failed (exit code {})".format(command, status)
-                )
-            output += out
+            with self.nested("must succeed: {}".format(command)):
+                (status, out) = self.execute(command)
+                if status != 0:
+                    self.log("output: {}".format(out))
+                    raise Exception(
+                        "command `{}` failed (exit code {})".format(command, status)
+                    )
+                output += out
         return output
 
     def fail(self, *commands: str) -> str:
         """Execute each command and check that it fails."""
         output = ""
         for command in commands:
-            self.logger.info(f"must fail: {command}")
-            (status, out) = self.execute(command)
-            if status == 0:
-                raise Exception("command `{}` unexpectedly succeeded".format(command))
-            output += out
+            with self.nested("must fail: {}".format(command)):
+                (status, out) = self.execute(command)
+                if status == 0:
+                    raise Exception(
+                        "command `{}` unexpectedly succeeded".format(command)
+                    )
+                output += out
         return output
 
     def wait_until_succeeds(self, command: str) -> str:
@@ -377,9 +448,9 @@ class Machine:
             status, output = self.execute(command)
             return status == 0
 
-        self.logger.info(f"waiting for success: {command}")
-        retry(check_success)
-        return output
+        with self.nested("waiting for success: {}".format(command)):
+            retry(check_success)
+            return output
 
     def wait_until_fails(self, command: str) -> str:
         """Wait until a command returns failure.
@@ -392,21 +463,21 @@ class Machine:
             status, output = self.execute(command)
             return status != 0
 
-        self.logger.info(f"waiting for failure: {command}")
-        retry(check_failure)
-        return output
+        with self.nested("waiting for failure: {}".format(command)):
+            retry(check_failure)
+            return output
 
     def wait_for_shutdown(self) -> None:
         if not self.booted:
             return
 
-        self.logger.info("waiting for the VM to power off")
-        sys.stdout.flush()
-        self.process.wait()
+        with self.nested("waiting for the VM to power off"):
+            sys.stdout.flush()
+            self.process.wait()
 
-        self.pid = None
-        self.booted = False
-        self.connected = False
+            self.pid = None
+            self.booted = False
+            self.connected = False
 
     def get_tty_text(self, tty: str) -> str:
         status, output = self.execute(
@@ -424,19 +495,19 @@ class Machine:
         def tty_matches(last: bool) -> bool:
             text = self.get_tty_text(tty)
             if last:
-                self.logger.info(
+                self.log(
                     f"Last chance to match /{regexp}/ on TTY{tty}, "
                     f"which currently contains: {text}"
                 )
             return len(matcher.findall(text)) > 0
 
-        self.logger.info(f"waiting for {regexp} to appear on tty {tty}")
-        retry(tty_matches)
+        with self.nested("waiting for {} to appear on tty {}".format(regexp, tty)):
+            retry(tty_matches)
 
     def send_chars(self, chars: List[str]) -> None:
-        self.logger.info(f"sending keys ‘{chars}‘")
-        for char in chars:
-            self.send_key(char)
+        with self.nested("sending keys ‘{}‘".format(chars)):
+            for char in chars:
+                self.send_key(char)
 
     def wait_for_file(self, filename: str) -> None:
         """Waits until the file exists in machine's file system."""
@@ -445,16 +516,16 @@ class Machine:
             status, _ = self.execute("test -e {}".format(filename))
             return status == 0
 
-        self.logger.info(f"waiting for file ‘{filename}‘")
-        retry(check_file)
+        with self.nested("waiting for file ‘{}‘".format(filename)):
+            retry(check_file)
 
     def wait_for_open_port(self, port: int) -> None:
         def port_is_open(_: Any) -> bool:
             status, _ = self.execute("nc -z localhost {}".format(port))
             return status == 0
 
-        self.logger.info(f"waiting for TCP port {port}")
-        retry(port_is_open)
+        with self.nested("waiting for TCP port {}".format(port)):
+            retry(port_is_open)
 
     def wait_for_closed_port(self, port: int) -> None:
         def port_is_closed(_: Any) -> bool:
@@ -476,17 +547,17 @@ class Machine:
         if self.connected:
             return
 
-        self.logger.info("waiting for the VM to finish booting")
-        self.start()
+        with self.nested("waiting for the VM to finish booting"):
+            self.start()
 
-        tic = time.time()
-        self.shell.recv(1024)
-        # TODO: Timeout
-        toc = time.time()
+            tic = time.time()
+            self.shell.recv(1024)
+            # TODO: Timeout
+            toc = time.time()
 
-        self.logger.info("connected to guest root shell")
-        self.logger.info(f"(connecting took {toc - tic:.2f} seconds)")
-        self.connected = True
+            self.log("connected to guest root shell")
+            self.log("(connecting took {:.2f} seconds)".format(toc - tic))
+            self.connected = True
 
     def screenshot(self, filename: str) -> None:
         out_dir = os.environ.get("out", os.getcwd())
@@ -495,12 +566,15 @@ class Machine:
             filename = os.path.join(out_dir, "{}.png".format(filename))
         tmp = "{}.ppm".format(filename)
 
-        self.logger.info(f"making screenshot {filename}")
-        self.send_monitor_command("screendump {}".format(tmp))
-        ret = subprocess.run("pnmtopng {} > {}".format(tmp, filename), shell=True)
-        os.unlink(tmp)
-        if ret.returncode != 0:
-            raise Exception("Cannot convert screenshot")
+        with self.nested(
+            "making screenshot {}".format(filename),
+            {"image": os.path.basename(filename)},
+        ):
+            self.send_monitor_command("screendump {}".format(tmp))
+            ret = subprocess.run("pnmtopng {} > {}".format(tmp, filename), shell=True)
+            os.unlink(tmp)
+            if ret.returncode != 0:
+                raise Exception("Cannot convert screenshot")
 
     def copy_from_host_via_shell(self, source: str, target: str) -> None:
         """Copy a file from the host into the guest by piping it over the
@@ -576,18 +650,20 @@ class Machine:
 
         tess_args = "-c debug_file=/dev/null --psm 11 --oem 2"
 
-        self.logger.info("performing optical character recognition")
-        with tempfile.NamedTemporaryFile() as tmpin:
-            self.send_monitor_command("screendump {}".format(tmpin.name))
+        with self.nested("performing optical character recognition"):
+            with tempfile.NamedTemporaryFile() as tmpin:
+                self.send_monitor_command("screendump {}".format(tmpin.name))
 
-            cmd = "convert {} {} tiff:- | tesseract - - {}".format(
-                magick_args, tmpin.name, tess_args
-            )
-            ret = subprocess.run(cmd, shell=True, capture_output=True)
-            if ret.returncode != 0:
-                raise Exception("OCR failed with exit code {}".format(ret.returncode))
+                cmd = "convert {} {} tiff:- | tesseract - - {}".format(
+                    magick_args, tmpin.name, tess_args
+                )
+                ret = subprocess.run(cmd, shell=True, capture_output=True)
+                if ret.returncode != 0:
+                    raise Exception(
+                        "OCR failed with exit code {}".format(ret.returncode)
+                    )
 
-            return ret.stdout.decode("utf-8")
+                return ret.stdout.decode("utf-8")
 
     def wait_for_text(self, regex: str) -> None:
         def screen_matches(last: bool) -> bool:
@@ -595,15 +671,15 @@ class Machine:
             matches = re.search(regex, text) is not None
 
             if last and not matches:
-                self.logger.info(f"Last OCR attempt failed. Text was: {text}")
+                self.log("Last OCR attempt failed. Text was: {}".format(text))
 
             return matches
 
-        self.logger.info(f"waiting for {regex} to appear on screen")
-        retry(screen_matches)
+        with self.nested("waiting for {} to appear on screen".format(regex)):
+            retry(screen_matches)
 
     def wait_for_console_text(self, regex: str) -> None:
-        self.logger.info(f"waiting for {regex} to appear on console")
+        self.log("waiting for {} to appear on console".format(regex))
         # Buffer the console output, this is needed
         # to match multiline regexes.
         console = io.StringIO()
@@ -626,7 +702,7 @@ class Machine:
         if self.booted:
             return
 
-        self.logger.info("starting vm")
+        self.log("starting vm")
 
         def create_socket(path: str) -> socket.socket:
             if os.path.exists(path):
@@ -683,7 +759,7 @@ class Machine:
 
         # Store last serial console lines for use
         # of wait_for_console_text
-        self.last_lines: queue.Queue = queue.Queue()
+        self.last_lines: Queue = Queue()
 
         def process_serial_output() -> None:
             assert self.process.stdout is not None
@@ -691,7 +767,8 @@ class Machine:
                 # Ignore undecodable bytes that may occur in boot menus
                 line = _line.decode(errors="ignore").replace("\r", "").rstrip()
                 self.last_lines.put(line)
-                self.logger.info(line)
+                eprint("{} # {}".format(self.name, line))
+                self.logger.enqueue({"msg": line, "machine": self.name})
 
         _thread.start_new_thread(process_serial_output, ())
 
@@ -700,10 +777,10 @@ class Machine:
         self.pid = self.process.pid
         self.booted = True
 
-        self.logger.info(f"QEMU running (pid {self.pid})")
+        self.log("QEMU running (pid {})".format(self.pid))
 
     def cleanup_statedir(self) -> None:
-        self.logger.info("delete the VM state directory")
+        self.log("delete the VM state directory")
         if os.path.isfile(self.state_dir):
             shutil.rmtree(self.state_dir)
 
@@ -718,7 +795,7 @@ class Machine:
         if not self.booted:
             return
 
-        self.logger.info("forced crash")
+        self.log("forced crash")
         self.send_monitor_command("quit")
         self.wait_for_shutdown()
 
@@ -738,8 +815,8 @@ class Machine:
             status, _ = self.execute("[ -e /tmp/.X11-unix/X0 ]")
             return status == 0
 
-        self.logger.info("waiting for the X11 server")
-        retry(check_x)
+        with self.nested("waiting for the X11 server"):
+            retry(check_x)
 
     def get_window_names(self) -> List[str]:
         return self.succeed(
@@ -752,14 +829,15 @@ class Machine:
         def window_is_visible(last_try: bool) -> bool:
             names = self.get_window_names()
             if last_try:
-                self.logger.info(
-                    f"Last chance to match {regexp} on the window list, "
-                    + f"which currently contains: {', '.join(names)}"
+                self.log(
+                    "Last chance to match {} on the window list,".format(regexp)
+                    + " which currently contains: "
+                    + ", ".join(names)
                 )
             return any(pattern.search(name) for name in names)
 
-        self.logger.info("Waiting for a window to appear")
-        retry(window_is_visible)
+        with self.nested("Waiting for a window to appear"):
+            retry(window_is_visible)
 
     def sleep(self, secs: int) -> None:
         # We want to sleep in *guest* time, not *host* time.
@@ -788,22 +866,23 @@ class Machine:
 
 def create_machine(args: Dict[str, Any]) -> Machine:
     global log
+    args["log"] = log
     args["redirectSerial"] = os.environ.get("USE_SERIAL", "0") == "1"
     return Machine(args)
 
 
 def start_all() -> None:
     global machines
-    logger.info("starting all VMs")
-    for machine in machines:
-        machine.start()
+    with log.nested("starting all VMs"):
+        for machine in machines:
+            machine.start()
 
 
 def join_all() -> None:
     global machines
-    logger.info("waiting for all VMs to finish")
-    for machine in machines:
-        machine.wait_for_shutdown()
+    with log.nested("waiting for all VMs to finish"):
+        for machine in machines:
+            machine.wait_for_shutdown()
 
 
 def test_script() -> None:
@@ -814,12 +893,13 @@ def run_tests() -> None:
     global machines
     tests = os.environ.get("tests", None)
     if tests is not None:
-        logger.info("running the VM test script")
-        try:
-            exec(tests, globals())
-        except Exception:
-            logging.exception("error:")
-            sys.exit(1)
+        with log.nested("running the VM test script"):
+            try:
+                exec(tests, globals())
+            except Exception as e:
+                eprint("error: ")
+                traceback.print_exc()
+                sys.exit(1)
     else:
         ptpython.repl.embed(locals(), globals())
 
@@ -832,19 +912,18 @@ def run_tests() -> None:
 
 @contextmanager
 def subtest(name: str) -> Iterator[None]:
-    logger.info(name)
-    try:
-        yield
-        return True
-    except Exception as e:
-        logger.info(f'Test "{name}" failed with error: "{e}"')
-        raise e
+    with log.nested(name):
+        try:
+            yield
+            return True
+        except Exception as e:
+            log.log(f'Test "{name}" failed with error: "{e}"')
+            raise e
 
     return False
 
 
-def main() -> None:
-    global machines
+if __name__ == "__main__":
     arg_parser = argparse.ArgumentParser()
     arg_parser.add_argument(
         "-K",
@@ -854,6 +933,8 @@ def main() -> None:
     )
     (cli_args, vm_scripts) = arg_parser.parse_known_args()
 
+    log = Logger()
+
     vlan_nrs = list(dict.fromkeys(os.environ.get("VLANS", "").split()))
     vde_sockets = [create_vlan(v) for v in vlan_nrs]
     for nr, vde_socket, _, _ in vde_sockets:
@@ -864,27 +945,23 @@ def main() -> None:
         if not cli_args.keep_vm_state:
             machine.cleanup_statedir()
     machine_eval = [
-        "global {0}; {0} = machines[{1}]".format(m.name, idx)
-        for idx, m in enumerate(machines)
+        "{0} = machines[{1}]".format(m.name, idx) for idx, m in enumerate(machines)
     ]
     exec("\n".join(machine_eval))
 
     @atexit.register
     def clean_up() -> None:
-        logger.info("cleaning up")
-        for machine in machines:
-            if machine.pid is None:
-                continue
-            logger.info(f"killing {machine.name} (pid {machine.pid})")
-            machine.process.kill()
-        for _, _, process, _ in vde_sockets:
-            process.terminate()
+        with log.nested("cleaning up"):
+            for machine in machines:
+                if machine.pid is None:
+                    continue
+                log.log("killing {} (pid {})".format(machine.name, machine.pid))
+                machine.process.kill()
+            for _, _, process, _ in vde_sockets:
+                process.terminate()
+        log.close()
 
     tic = time.time()
     run_tests()
     toc = time.time()
     print("test script finished in {:.2f}s".format(toc - tic))
-
-
-if __name__ == "__main__":
-    main()