about summary refs log tree commit diff
path: root/nixos/lib/test-driver/test_driver/logger.py
diff options
context:
space:
mode:
Diffstat (limited to 'nixos/lib/test-driver/test_driver/logger.py')
-rw-r--r--nixos/lib/test-driver/test_driver/logger.py249
1 files changed, 226 insertions, 23 deletions
diff --git a/nixos/lib/test-driver/test_driver/logger.py b/nixos/lib/test-driver/test_driver/logger.py
index 0b0623bddfa1e..484829254b812 100644
--- a/nixos/lib/test-driver/test_driver/logger.py
+++ b/nixos/lib/test-driver/test_driver/logger.py
@@ -1,33 +1,238 @@
+import atexit
 import codecs
 import os
 import sys
 import time
 import unicodedata
-from contextlib import contextmanager
+from abc import ABC, abstractmethod
+from contextlib import ExitStack, contextmanager
+from pathlib import Path
 from queue import Empty, Queue
-from typing import Any, Dict, Iterator
+from typing import Any, Dict, Iterator, List
 from xml.sax.saxutils import XMLGenerator
 from xml.sax.xmlreader import AttributesImpl
 
 from colorama import Fore, Style
+from junit_xml import TestCase, TestSuite
 
 
-class Logger:
-    def __init__(self) -> None:
-        self.logfile = os.environ.get("LOGFILE", "/dev/null")
-        self.logfile_handle = codecs.open(self.logfile, "wb")
-        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
-        self.queue: "Queue[Dict[str, str]]" = Queue()
+class AbstractLogger(ABC):
+    @abstractmethod
+    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
+        pass
 
-        self.xml.startDocument()
-        self.xml.startElement("logfile", attrs=AttributesImpl({}))
+    @abstractmethod
+    @contextmanager
+    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        pass
+
+    @abstractmethod
+    @contextmanager
+    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        pass
+
+    @abstractmethod
+    def info(self, *args, **kwargs) -> None:  # type: ignore
+        pass
+
+    @abstractmethod
+    def warning(self, *args, **kwargs) -> None:  # type: ignore
+        pass
+
+    @abstractmethod
+    def error(self, *args, **kwargs) -> None:  # type: ignore
+        pass
+
+    @abstractmethod
+    def log_serial(self, message: str, machine: str) -> None:
+        pass
+
+    @abstractmethod
+    def print_serial_logs(self, enable: bool) -> None:
+        pass
+
+
+class JunitXMLLogger(AbstractLogger):
+    class TestCaseState:
+        def __init__(self) -> None:
+            self.stdout = ""
+            self.stderr = ""
+            self.failure = False
+
+    def __init__(self, outfile: Path) -> None:
+        self.tests: dict[str, JunitXMLLogger.TestCaseState] = {
+            "main": self.TestCaseState()
+        }
+        self.currentSubtest = "main"
+        self.outfile: Path = outfile
+        self._print_serial_logs = True
+        atexit.register(self.close)
+
+    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
+        self.tests[self.currentSubtest].stdout += message + os.linesep
+
+    @contextmanager
+    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        old_test = self.currentSubtest
+        self.tests.setdefault(name, self.TestCaseState())
+        self.currentSubtest = name
+
+        yield
+
+        self.currentSubtest = old_test
+
+    @contextmanager
+    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        self.log(message)
+        yield
+
+    def info(self, *args, **kwargs) -> None:  # type: ignore
+        self.tests[self.currentSubtest].stdout += args[0] + os.linesep
+
+    def warning(self, *args, **kwargs) -> None:  # type: ignore
+        self.tests[self.currentSubtest].stdout += args[0] + os.linesep
+
+    def error(self, *args, **kwargs) -> None:  # type: ignore
+        self.tests[self.currentSubtest].stderr += args[0] + os.linesep
+        self.tests[self.currentSubtest].failure = True
+
+    def log_serial(self, message: str, machine: str) -> None:
+        if not self._print_serial_logs:
+            return
+
+        self.log(f"{machine} # {message}")
+
+    def print_serial_logs(self, enable: bool) -> None:
+        self._print_serial_logs = enable
+
+    def close(self) -> None:
+        with open(self.outfile, "w") as f:
+            test_cases = []
+            for name, test_case_state in self.tests.items():
+                tc = TestCase(
+                    name,
+                    stdout=test_case_state.stdout,
+                    stderr=test_case_state.stderr,
+                )
+                if test_case_state.failure:
+                    tc.add_failure_info("test case failed")
+
+                test_cases.append(tc)
+            ts = TestSuite("NixOS integration test", test_cases)
+            f.write(TestSuite.to_xml_string([ts]))
+
+
+class CompositeLogger(AbstractLogger):
+    def __init__(self, logger_list: List[AbstractLogger]) -> None:
+        self.logger_list = logger_list
+
+    def add_logger(self, logger: AbstractLogger) -> None:
+        self.logger_list.append(logger)
+
+    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
+        for logger in self.logger_list:
+            logger.log(message, attributes)
+
+    @contextmanager
+    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        with ExitStack() as stack:
+            for logger in self.logger_list:
+                stack.enter_context(logger.subtest(name, attributes))
+            yield
+
+    @contextmanager
+    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        with ExitStack() as stack:
+            for logger in self.logger_list:
+                stack.enter_context(logger.nested(message, attributes))
+            yield
+
+    def info(self, *args, **kwargs) -> None:  # type: ignore
+        for logger in self.logger_list:
+            logger.info(*args, **kwargs)
+
+    def warning(self, *args, **kwargs) -> None:  # type: ignore
+        for logger in self.logger_list:
+            logger.warning(*args, **kwargs)
+
+    def error(self, *args, **kwargs) -> None:  # type: ignore
+        for logger in self.logger_list:
+            logger.error(*args, **kwargs)
+        sys.exit(1)
 
+    def print_serial_logs(self, enable: bool) -> None:
+        for logger in self.logger_list:
+            logger.print_serial_logs(enable)
+
+    def log_serial(self, message: str, machine: str) -> None:
+        for logger in self.logger_list:
+            logger.log_serial(message, machine)
+
+
+class TerminalLogger(AbstractLogger):
+    def __init__(self) -> None:
         self._print_serial_logs = True
 
+    def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str:
+        if "machine" in attributes:
+            return f"{attributes['machine']}: {message}"
+        return message
+
     @staticmethod
     def _eprint(*args: object, **kwargs: Any) -> None:
         print(*args, file=sys.stderr, **kwargs)
 
+    def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
+        self._eprint(self.maybe_prefix(message, attributes))
+
+    @contextmanager
+    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        with self.nested("subtest: " + name, attributes):
+            yield
+
+    @contextmanager
+    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        self._eprint(
+            self.maybe_prefix(
+                Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes
+            )
+        )
+
+        tic = time.time()
+        yield
+        toc = time.time()
+        self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")
+
+    def info(self, *args, **kwargs) -> None:  # type: ignore
+        self.log(*args, **kwargs)
+
+    def warning(self, *args, **kwargs) -> None:  # type: ignore
+        self.log(*args, **kwargs)
+
+    def error(self, *args, **kwargs) -> None:  # type: ignore
+        self.log(*args, **kwargs)
+
+    def print_serial_logs(self, enable: bool) -> None:
+        self._print_serial_logs = enable
+
+    def log_serial(self, message: str, machine: str) -> None:
+        if not self._print_serial_logs:
+            return
+
+        self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)
+
+
+class XMLLogger(AbstractLogger):
+    def __init__(self, outfile: str) -> None:
+        self.logfile_handle = codecs.open(outfile, "wb")
+        self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8")
+        self.queue: Queue[dict[str, str]] = Queue()
+
+        self._print_serial_logs = True
+
+        self.xml.startDocument()
+        self.xml.startElement("logfile", attrs=AttributesImpl({}))
+
     def close(self) -> None:
         self.xml.endElement("logfile")
         self.xml.endDocument()
@@ -54,17 +259,19 @@ class Logger:
 
     def error(self, *args, **kwargs) -> None:  # type: ignore
         self.log(*args, **kwargs)
-        sys.exit(1)
 
     def log(self, message: str, attributes: Dict[str, str] = {}) -> None:
-        self._eprint(self.maybe_prefix(message, attributes))
         self.drain_log_queue()
         self.log_line(message, attributes)
 
+    def print_serial_logs(self, enable: bool) -> None:
+        self._print_serial_logs = enable
+
     def log_serial(self, message: str, machine: str) -> None:
+        if not self._print_serial_logs:
+            return
+
         self.enqueue({"msg": message, "machine": machine, "type": "serial"})
-        if self._print_serial_logs:
-            self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL)
 
     def enqueue(self, item: Dict[str, str]) -> None:
         self.queue.put(item)
@@ -80,13 +287,12 @@ class Logger:
             pass
 
     @contextmanager
-    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
-        self._eprint(
-            self.maybe_prefix(
-                Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes
-            )
-        )
+    def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
+        with self.nested("subtest: " + name, attributes):
+            yield
 
+    @contextmanager
+    def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]:
         self.xml.startElement("nest", attrs=AttributesImpl({}))
         self.xml.startElement("head", attrs=AttributesImpl(attributes))
         self.xml.characters(message)
@@ -100,6 +306,3 @@ class Logger:
         self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)")
 
         self.xml.endElement("nest")
-
-
-rootlog = Logger()