about summary refs log tree commit diff
path: root/nixos/lib/test-driver/test_driver/driver.py
diff options
context:
space:
mode:
Diffstat (limited to 'nixos/lib/test-driver/test_driver/driver.py')
-rw-r--r--nixos/lib/test-driver/test_driver/driver.py33
1 files changed, 29 insertions, 4 deletions
diff --git a/nixos/lib/test-driver/test_driver/driver.py b/nixos/lib/test-driver/test_driver/driver.py
index 49a42fe5fb4ef..880b1c5fdec0d 100644
--- a/nixos/lib/test-driver/test_driver/driver.py
+++ b/nixos/lib/test-driver/test_driver/driver.py
@@ -10,6 +10,28 @@ from test_driver.vlan import VLan
 from test_driver.polling_condition import PollingCondition
 
 
+def get_tmp_dir() -> Path:
+    """Returns a temporary directory that is defined by TMPDIR, TEMP, TMP or CWD
+    Raises an exception in case the retrieved temporary directory is not writeable
+    See https://docs.python.org/3/library/tempfile.html#tempfile.gettempdir
+    """
+    tmp_dir = Path(tempfile.gettempdir())
+    tmp_dir.mkdir(mode=0o700, exist_ok=True)
+    if not tmp_dir.is_dir():
+        raise NotADirectoryError(
+            "The directory defined by TMPDIR, TEMP, TMP or CWD: {0} is not a directory".format(
+                tmp_dir
+            )
+        )
+    if not os.access(tmp_dir, os.W_OK):
+        raise PermissionError(
+            "The directory defined by TMPDIR, TEMP, TMP, or CWD: {0} is not writeable".format(
+                tmp_dir
+            )
+        )
+    return tmp_dir
+
+
 class Driver:
     """A handle to the driver that sets up the environment
     and runs the tests"""
@@ -24,12 +46,13 @@ class Driver:
         start_scripts: List[str],
         vlans: List[int],
         tests: str,
+        out_dir: Path,
         keep_vm_state: bool = False,
     ):
         self.tests = tests
+        self.out_dir = out_dir
 
-        tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
-        tmp_dir.mkdir(mode=0o700, exist_ok=True)
+        tmp_dir = get_tmp_dir()
 
         with rootlog.nested("start all VLans"):
             self.vlans = [VLan(nr, tmp_dir) for nr in vlans]
@@ -47,6 +70,7 @@ class Driver:
                 name=cmd.machine_name,
                 tmp_dir=tmp_dir,
                 callbacks=[self.check_polling_conditions],
+                out_dir=self.out_dir,
             )
             for cmd in cmd(start_scripts)
         ]
@@ -141,8 +165,8 @@ class Driver:
             "Using legacy create_machine(), please instantiate the"
             "Machine class directly, instead"
         )
-        tmp_dir = Path(os.environ.get("TMPDIR", tempfile.gettempdir()))
-        tmp_dir.mkdir(mode=0o700, exist_ok=True)
+
+        tmp_dir = get_tmp_dir()
 
         if args.get("startCommand"):
             start_command: str = args.get("startCommand", "")
@@ -154,6 +178,7 @@ class Driver:
 
         return Machine(
             tmp_dir=tmp_dir,
+            out_dir=self.out_dir,
             start_command=cmd,
             name=name,
             keep_vm_state=args.get("keep_vm_state", False),