about summary refs log tree commit diff
path: root/nixos/lib
diff options
context:
space:
mode:
authorJörg Thalheim <joerg@thalheim.io>2023-09-30 09:59:17 +0200
committerJörg Thalheim <joerg@thalheim.io>2023-09-30 10:18:06 +0200
commit9ac9e8407fbe084fb6667bdd3bd0c569ea454271 (patch)
tree2dfa8af103a6e35e8ee31e91af292a459c280d17 /nixos/lib
parenta1666863fdf1861188e1378a6dde2a19b1efa1c0 (diff)
nixos/test-driver: fix type errors in extract-docstrings
Diffstat (limited to 'nixos/lib')
-rw-r--r--nixos/lib/test-driver/extract-docstrings.py42
1 files changed, 25 insertions, 17 deletions
diff --git a/nixos/lib/test-driver/extract-docstrings.py b/nixos/lib/test-driver/extract-docstrings.py
index 5aec4c89a9d74..a12e586882a67 100644
--- a/nixos/lib/test-driver/extract-docstrings.py
+++ b/nixos/lib/test-driver/extract-docstrings.py
@@ -1,5 +1,6 @@
 import ast
 import sys
+from pathlib import Path
 
 """
 This program takes all the Machine class methods and prints its methods in
@@ -40,27 +41,34 @@ some_function(param1, param2)
 
 """
 
-assert len(sys.argv) == 2
 
-with open(sys.argv[1], "r") as f:
-    module = ast.parse(f.read())
+def main() -> None:
+    if len(sys.argv) != 2:
+        print(f"Usage: {sys.argv[0]} <path-to-test-driver>")
+        sys.exit(1)
 
-class_definitions = (node for node in module.body if isinstance(node, ast.ClassDef))
+    module = ast.parse(Path(sys.argv[1]).read_text())
 
-machine_class = next(filter(lambda x: x.name == "Machine", class_definitions))
-assert machine_class is not None
+    class_definitions = (node for node in module.body if isinstance(node, ast.ClassDef))
 
-function_definitions = [
-    node for node in machine_class.body if isinstance(node, ast.FunctionDef)
-]
-function_definitions.sort(key=lambda x: x.name)
+    machine_class = next(filter(lambda x: x.name == "Machine", class_definitions))
+    assert machine_class is not None
 
-for f in function_definitions:
-    docstr = ast.get_docstring(f)
-    if docstr is not None:
-        args = ", ".join((a.arg for a in f.args.args[1:]))
-        args = f"({args})"
+    function_definitions = [
+        node for node in machine_class.body if isinstance(node, ast.FunctionDef)
+    ]
+    function_definitions.sort(key=lambda x: x.name)
 
-        docstr = "\n".join((f"    {l}" for l in docstr.strip().splitlines()))
+    for f in function_definitions:
+        docstr = ast.get_docstring(f)
+        if docstr is not None:
+            args = ", ".join(a.arg for a in f.args.args[1:])
+            args = f"({args})"
 
-        print(f"{f.name}{args}\n\n:{docstr[1:]}\n")
+            docstr = "\n".join(f"    {l}" for l in docstr.strip().splitlines())
+
+            print(f"{f.name}{args}\n\n:{docstr[1:]}\n")
+
+
+if __name__ == "__main__":
+    main()