diff options
Diffstat (limited to 'nixos/lib/test-driver/test_driver/logger.py')
-rw-r--r-- | nixos/lib/test-driver/test_driver/logger.py | 249 |
1 files changed, 226 insertions, 23 deletions
diff --git a/nixos/lib/test-driver/test_driver/logger.py b/nixos/lib/test-driver/test_driver/logger.py index 0b0623bddfa1e..484829254b812 100644 --- a/nixos/lib/test-driver/test_driver/logger.py +++ b/nixos/lib/test-driver/test_driver/logger.py @@ -1,33 +1,238 @@ +import atexit import codecs import os import sys import time import unicodedata -from contextlib import contextmanager +from abc import ABC, abstractmethod +from contextlib import ExitStack, contextmanager +from pathlib import Path from queue import Empty, Queue -from typing import Any, Dict, Iterator +from typing import Any, Dict, Iterator, List from xml.sax.saxutils import XMLGenerator from xml.sax.xmlreader import AttributesImpl from colorama import Fore, Style +from junit_xml import TestCase, TestSuite -class Logger: - def __init__(self) -> None: - self.logfile = os.environ.get("LOGFILE", "/dev/null") - self.logfile_handle = codecs.open(self.logfile, "wb") - self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8") - self.queue: "Queue[Dict[str, str]]" = Queue() +class AbstractLogger(ABC): + @abstractmethod + def log(self, message: str, attributes: Dict[str, str] = {}) -> None: + pass - self.xml.startDocument() - self.xml.startElement("logfile", attrs=AttributesImpl({})) + @abstractmethod + @contextmanager + def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + pass + + @abstractmethod + @contextmanager + def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + pass + + @abstractmethod + def info(self, *args, **kwargs) -> None: # type: ignore + pass + + @abstractmethod + def warning(self, *args, **kwargs) -> None: # type: ignore + pass + + @abstractmethod + def error(self, *args, **kwargs) -> None: # type: ignore + pass + + @abstractmethod + def log_serial(self, message: str, machine: str) -> None: + pass + + @abstractmethod + def print_serial_logs(self, enable: bool) -> None: + pass + + +class JunitXMLLogger(AbstractLogger): + class TestCaseState: + def __init__(self) -> None: + self.stdout = "" + self.stderr = "" + self.failure = False + + def __init__(self, outfile: Path) -> None: + self.tests: dict[str, JunitXMLLogger.TestCaseState] = { + "main": self.TestCaseState() + } + self.currentSubtest = "main" + self.outfile: Path = outfile + self._print_serial_logs = True + atexit.register(self.close) + + def log(self, message: str, attributes: Dict[str, str] = {}) -> None: + self.tests[self.currentSubtest].stdout += message + os.linesep + + @contextmanager + def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + old_test = self.currentSubtest + self.tests.setdefault(name, self.TestCaseState()) + self.currentSubtest = name + + yield + + self.currentSubtest = old_test + + @contextmanager + def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + self.log(message) + yield + + def info(self, *args, **kwargs) -> None: # type: ignore + self.tests[self.currentSubtest].stdout += args[0] + os.linesep + + def warning(self, *args, **kwargs) -> None: # type: ignore + self.tests[self.currentSubtest].stdout += args[0] + os.linesep + + def error(self, *args, **kwargs) -> None: # type: ignore + self.tests[self.currentSubtest].stderr += args[0] + os.linesep + self.tests[self.currentSubtest].failure = True + + def log_serial(self, message: str, machine: str) -> None: + if not self._print_serial_logs: + return + + self.log(f"{machine} # {message}") + + def print_serial_logs(self, enable: bool) -> None: + self._print_serial_logs = enable + + def close(self) -> None: + with open(self.outfile, "w") as f: + test_cases = [] + for name, test_case_state in self.tests.items(): + tc = TestCase( + name, + stdout=test_case_state.stdout, + stderr=test_case_state.stderr, + ) + if test_case_state.failure: + tc.add_failure_info("test case failed") + + test_cases.append(tc) + ts = TestSuite("NixOS integration test", test_cases) + f.write(TestSuite.to_xml_string([ts])) + + +class CompositeLogger(AbstractLogger): + def __init__(self, logger_list: List[AbstractLogger]) -> None: + self.logger_list = logger_list + + def add_logger(self, logger: AbstractLogger) -> None: + self.logger_list.append(logger) + + def log(self, message: str, attributes: Dict[str, str] = {}) -> None: + for logger in self.logger_list: + logger.log(message, attributes) + + @contextmanager + def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + with ExitStack() as stack: + for logger in self.logger_list: + stack.enter_context(logger.subtest(name, attributes)) + yield + + @contextmanager + def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + with ExitStack() as stack: + for logger in self.logger_list: + stack.enter_context(logger.nested(message, attributes)) + yield + + def info(self, *args, **kwargs) -> None: # type: ignore + for logger in self.logger_list: + logger.info(*args, **kwargs) + + def warning(self, *args, **kwargs) -> None: # type: ignore + for logger in self.logger_list: + logger.warning(*args, **kwargs) + + def error(self, *args, **kwargs) -> None: # type: ignore + for logger in self.logger_list: + logger.error(*args, **kwargs) + sys.exit(1) + def print_serial_logs(self, enable: bool) -> None: + for logger in self.logger_list: + logger.print_serial_logs(enable) + + def log_serial(self, message: str, machine: str) -> None: + for logger in self.logger_list: + logger.log_serial(message, machine) + + +class TerminalLogger(AbstractLogger): + def __init__(self) -> None: self._print_serial_logs = True + def maybe_prefix(self, message: str, attributes: Dict[str, str]) -> str: + if "machine" in attributes: + return f"{attributes['machine']}: {message}" + return message + @staticmethod def _eprint(*args: object, **kwargs: Any) -> None: print(*args, file=sys.stderr, **kwargs) + def log(self, message: str, attributes: Dict[str, str] = {}) -> None: + self._eprint(self.maybe_prefix(message, attributes)) + + @contextmanager + def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + with self.nested("subtest: " + name, attributes): + yield + + @contextmanager + def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + self._eprint( + self.maybe_prefix( + Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes + ) + ) + + tic = time.time() + yield + toc = time.time() + self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)") + + def info(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + + def warning(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + + def error(self, *args, **kwargs) -> None: # type: ignore + self.log(*args, **kwargs) + + def print_serial_logs(self, enable: bool) -> None: + self._print_serial_logs = enable + + def log_serial(self, message: str, machine: str) -> None: + if not self._print_serial_logs: + return + + self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL) + + +class XMLLogger(AbstractLogger): + def __init__(self, outfile: str) -> None: + self.logfile_handle = codecs.open(outfile, "wb") + self.xml = XMLGenerator(self.logfile_handle, encoding="utf-8") + self.queue: Queue[dict[str, str]] = Queue() + + self._print_serial_logs = True + + self.xml.startDocument() + self.xml.startElement("logfile", attrs=AttributesImpl({})) + def close(self) -> None: self.xml.endElement("logfile") self.xml.endDocument() @@ -54,17 +259,19 @@ class Logger: def error(self, *args, **kwargs) -> None: # type: ignore self.log(*args, **kwargs) - sys.exit(1) def log(self, message: str, attributes: Dict[str, str] = {}) -> None: - self._eprint(self.maybe_prefix(message, attributes)) self.drain_log_queue() self.log_line(message, attributes) + def print_serial_logs(self, enable: bool) -> None: + self._print_serial_logs = enable + def log_serial(self, message: str, machine: str) -> None: + if not self._print_serial_logs: + return + self.enqueue({"msg": message, "machine": machine, "type": "serial"}) - if self._print_serial_logs: - self._eprint(Style.DIM + f"{machine} # {message}" + Style.RESET_ALL) def enqueue(self, item: Dict[str, str]) -> None: self.queue.put(item) @@ -80,13 +287,12 @@ class Logger: pass @contextmanager - def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: - self._eprint( - self.maybe_prefix( - Style.BRIGHT + Fore.GREEN + message + Style.RESET_ALL, attributes - ) - ) + def subtest(self, name: str, attributes: Dict[str, str] = {}) -> Iterator[None]: + with self.nested("subtest: " + name, attributes): + yield + @contextmanager + def nested(self, message: str, attributes: Dict[str, str] = {}) -> Iterator[None]: self.xml.startElement("nest", attrs=AttributesImpl({})) self.xml.startElement("head", attrs=AttributesImpl(attributes)) self.xml.characters(message) @@ -100,6 +306,3 @@ class Logger: self.log(f"(finished: {message}, in {toc - tic:.2f} seconds)") self.xml.endElement("nest") - - -rootlog = Logger() |