about summary refs log tree commit diff
path: root/nixos
diff options
context:
space:
mode:
authorJacek Galowicz <jacek@galowicz.de>2022-01-26 10:13:39 +0100
committerGitHub <noreply@github.com>2022-01-26 10:13:39 +0100
commitab4e5af009fca2143b602f7f677b430440788701 (patch)
tree3777a62334f84d7953c9e4cc155a9c39050d5497 /nixos
parent3b55df970b9068dc8c4b94ab0d0989d3652a7413 (diff)
parent446c21fdc75c6a045bf5b715a315bdf87b5e71cb (diff)
Merge pull request #153854 from marijanp/test-driver-out-dir
nixos/test-driver: use an argument instead of the out env-var
Diffstat (limited to 'nixos')
-rwxr-xr-xnixos/lib/test-driver/test_driver/__init__.py32
-rw-r--r--nixos/lib/test-driver/test_driver/driver.py33
-rw-r--r--nixos/lib/test-driver/test_driver/machine.py9
3 files changed, 64 insertions, 10 deletions
diff --git a/nixos/lib/test-driver/test_driver/__init__.py b/nixos/lib/test-driver/test_driver/__init__.py
index 5477ab5cd038e..498a4f56c55bb 100755
--- a/nixos/lib/test-driver/test_driver/__init__.py
+++ b/nixos/lib/test-driver/test_driver/__init__.py
@@ -33,6 +33,22 @@ class EnvDefault(argparse.Action):
         setattr(namespace, self.dest, values)
 
 
+def writeable_dir(arg: str) -> Path:
+    """Raises an ArgumentTypeError if the given argument isn't a writeable directory
+    Note: We want to fail as early as possible if a directory isn't writeable,
+    since an executed nixos-test could fail (very late) because of the test-driver
+    writing in a directory without proper permissions.
+    """
+    path = Path(arg)
+    if not path.is_dir():
+        raise argparse.ArgumentTypeError("{0} is not a directory".format(path))
+    if not os.access(path, os.W_OK):
+        raise argparse.ArgumentTypeError(
+            "{0} is not a writeable directory".format(path)
+        )
+    return path
+
+
 def main() -> None:
     arg_parser = argparse.ArgumentParser(prog="nixos-test-driver")
     arg_parser.add_argument(
@@ -64,6 +80,14 @@ def main() -> None:
         help="vlans to span by the driver",
     )
     arg_parser.add_argument(
+        "-o",
+        "--output_directory",
+        help="""The path to the directory where outputs copied from the VM will be placed.
+                By e.g. Machine.copy_from_vm or Machine.screenshot""",
+        default=Path.cwd(),
+        type=writeable_dir,
+    )
+    arg_parser.add_argument(
         "testscript",
         action=EnvDefault,
         envvar="testScript",
@@ -77,7 +101,11 @@ def main() -> None:
         rootlog.info("Machine state will be reset. To keep it, pass --keep-vm-state")
 
     with Driver(
-        args.start_scripts, args.vlans, args.testscript.read_text(), args.keep_vm_state
+        args.start_scripts,
+        args.vlans,
+        args.testscript.read_text(),
+        args.output_directory.resolve(),
+        args.keep_vm_state,
     ) as driver:
         if args.interactive:
             ptpython.repl.embed(driver.test_symbols(), {})
@@ -94,7 +122,7 @@ def generate_driver_symbols() -> None:
     in user's test scripts. That list is then used by pyflakes to lint those
     scripts.
     """
-    d = Driver([], [], "")
+    d = Driver([], [], "", Path())
     test_symbols = d.test_symbols()
     with open("driver-symbols", "w") as fp:
         fp.write(",".join(test_symbols.keys()))
diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py
index 49a42fe5fb4ef..880b1c5fdec0d 100644
--- a/nixos/lib/test-driver/test_driver/driver.py
+++ b/nixos/lib/test-driver/test_driver/driver.py
@@ -10,6 +10,28 @@ from test_driver.vlan import VLan
 from test_driver.polling_condition import PollingCondition
 
 
+def get_tmp_dir() -> Path:
+    """Returns a temporary directory that is defined by TMPDIR, TEMP, TMP or CWD
+    Raises an exception in case the retrieved temporary directory is not writeable
+    See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir
+    """
+    tmp_dir = Path(tempfile.gettempdir())
+    tmp_dir.mkdir(mode=0o700, exist_ok=True)
+    if not tmp_dir.is_dir():
+        raise NotADirectoryError(
+            "The directory defined by TMPDIR, TEMP, TMP or CWD: {0} is not a directory".format(
+                tmp_dir
+            )
+        )
+    if not os.access(tmp_dir, os.W_OK):
+        raise PermissionError(
+            "The directory defined by TMPDIR, TEMP, TMP, or CWD: {0} is not writeable".format(
+                tmp_dir
+            )
+        )
+    return tmp_dir
+
+
 class Driver:
     """A handle to the driver that sets up the environment
     and runs the tests"""
@@ -24,12 +46,13 @@ class Driver:
         start_scripts: List[str],
         vlans: List[int],
         tests: str,
+        out_dir: Path,
         keep_vm_state: bool = False,
     ):
         self.tests = tests
+        self.out_dir = out_dir
 
-        tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
-        tmp_dir.mkdir(mode=0o700, exist_ok=True)
+        tmp_dir = get_tmp_dir()
 
         with rootlog.nested("start all VLans"):
             self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
@@ -47,6 +70,7 @@ class Driver:
                 name=cmd.machine_name,
                 tmp_dir=tmp_dir,
                 callbacks=[self.check_polling_conditions],
+                out_dir=self.out_dir,
             )
             for cmd in cmd(start_scripts)
         ]
@@ -141,8 +165,8 @@ class Driver:
             "Using legacy create_machine(), please instantiate the"
             "Machine class directly, instead"
         )
-        tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
-        tmp_dir.mkdir(mode=0o700, exist_ok=True)
+
+        tmp_dir = get_tmp_dir()
 
         if args.get("startCommand"):
             start_command: str = args.get("startCommand", "")
@@ -154,6 +178,7 @@ class Driver:
 
         return Machine(
             tmp_dir=tmp_dir,
+            out_dir=self.out_dir,
             start_command=cmd,
             name=name,
             keep_vm_state=args.get("keep_vm_state", False),
diff --git a/nixos/lib/test-driver/test_driver/machine.py b/nixos/lib/test-driver/test_driver/machine.py
index e050cbd7d990c..a41c419ebe6a4 100644
--- a/nixos/lib/test-driver/test_driver/machine.py
+++ b/nixos/lib/test-driver/test_driver/machine.py
@@ -297,6 +297,7 @@ class Machine:
     the machine lifecycle with the help of a start script / command."""
 
     name: str
+    out_dir: Path
     tmp_dir: Path
     shared_dir: Path
     state_dir: Path
@@ -325,6 +326,7 @@ class Machine:
 
     def __init__(
         self,
+        out_dir: Path,
         tmp_dir: Path,
         start_command: StartCommand,
         name: str = "machine",
@@ -332,6 +334,7 @@ class Machine:
         allow_reboot: bool = False,
         callbacks: Optional[List[Callable]] = None,
     ) -> None:
+        self.out_dir = out_dir
         self.tmp_dir = tmp_dir
         self.keep_vm_state = keep_vm_state
         self.allow_reboot = allow_reboot
@@ -702,10 +705,9 @@ class Machine:
             self.connected = True
 
     def screenshot(self, filename: str) -> None:
-        out_dir = os.environ.get("out", os.getcwd())
         word_pattern = re.compile(r"^\w+$")
         if word_pattern.match(filename):
-            filename = os.path.join(out_dir, "{}.png".format(filename))
+            filename = os.path.join(self.out_dir, "{}.png".format(filename))
         tmp = "{}.ppm".format(filename)
 
         with self.nested(
@@ -756,7 +758,6 @@ class Machine:
         all the VMs (using a temporary directory).
         """
         # Compute the source, target, and intermediate shared file names
-        out_dir = Path(os.environ.get("out", os.getcwd()))
         vm_src = Path(source)
         with tempfile.TemporaryDirectory(dir=self.shared_dir) as shared_td:
             shared_temp = Path(shared_td)
@@ -766,7 +767,7 @@ class Machine:
             # Copy the file to the shared directory inside VM
             self.succeed(make_command(["mkdir", "-p", vm_shared_temp]))
             self.succeed(make_command(["cp", "-r", vm_src, vm_intermediate]))
-            abs_target = out_dir / target_dir / vm_src.name
+            abs_target = self.out_dir / target_dir / vm_src.name
             abs_target.parent.mkdir(exist_ok=True, parents=True)
             # Copy the file from the shared directory outside VM
             if intermediate.is_dir():