about summary refs log tree commit diff
path: root/tests.py
diff options
context:
space:
mode:
authorsterni <sternenseemann@systemli.org>2021-07-03 15:20:45 +0200
committerSören Tempel <soeren+git@soeren-tempel.net>2021-07-15 13:28:33 +0200
commit21827d91b3a20480f8c1328e943b060a2c227feb (patch)
treeabbe47a4951351271818d26faa0876da3f4c7443 /tests.py
parentc7a0620eb2cca0595a57489d68aa8a535733b96d (diff)
pty.Parser: track pos in iterator, support postponing parsing
The next step for pty.Parser will be to support parsing ANSI CSI
escape sequences which consist of multiple codepoints, but need to be
treated as an unit.

Thus we need to be able to push some unparsed input to the next
invocation of parse(). The way I've chosen to support this is by
supporting backtracking in an iterator which wraps the input string.

This way we'll be able to keep consuming input from the iterator in
the main for loop, but can backtrack in case we hit the end of input
or encounter a parse failure.

I opted to implement this iterator in the form of an object that
tracks an index into a string. For one this feature also helps us in
the parser: We need to track the current index anyways and it is also
helpful for checking if the end of input has been reached.

Additionally this seems like the only feasible way to implement such a
backtracking iterator wrapper: The other alternative would've been to
copy.deepcopy() the iterator at a point we may want to return to —
however this is not practical as deepcopy will hit a recursion depth
error with the moderately sized chunks of input we handle in parse().

This commit introduces the new iterator and adapts our parser code for
it and lays the groundwork for resuming from unparsed input of a
previous invocation of parse(). Additionally it also ships a set of
unit tests for the iterator, many methods of which are still unused in
the main parser code.
Diffstat (limited to 'tests.py')
-rw-r--r--tests.py96
1 files changed, 96 insertions, 0 deletions
diff --git a/tests.py b/tests.py
new file mode 100644
index 0000000..34db339
--- /dev/null
+++ b/tests.py
@@ -0,0 +1,96 @@
+import copy
+import unittest
+
+from saneterm.pty import PositionedIterator
+
+TEST_STRING = 'foo;bar'
+
+class TestPositionedIterator(unittest.TestCase):
+    def test_lossless(self):
+        it = PositionedIterator(TEST_STRING)
+
+        self.assertEqual([x for x in it], list(TEST_STRING))
+        self.assertEqual(it.wrapped, TEST_STRING)
+
+    def test_indices(self):
+        it = PositionedIterator(TEST_STRING)
+
+        self.assertEqual(it.pos, -1)
+
+        for x in it:
+            assert x == TEST_STRING[it.pos]
+
+            if x == ';':
+                break
+
+        self.assertEqual(it.pos, 3)
+
+        for x in it:
+            self.assertEqual(x, it.wrapped[it.pos])
+
+        self.assertTrue(it.empty())
+
+    def test_backtracking(self):
+        it = PositionedIterator(TEST_STRING)
+
+        semicolon_index = None
+
+        for x in it:
+            if x == ';':
+                it.waypoint()
+                semicolon_index = it.pos
+
+        self.assertEqual(semicolon_index, TEST_STRING.index(';'))
+
+        self.assertTrue(it.empty())
+
+        with self.assertRaises(StopIteration):
+            _ = it.next()
+
+        it.backtrack()
+
+        self.assertEqual(it.next(), ';')
+        self.assertEqual(it.pos, semicolon_index)
+
+    def test_takewhile(self):
+        it = PositionedIterator(TEST_STRING)
+
+        s = it.takewhile_greedy(lambda x: x != ';')
+
+        self.assertEqual(s, TEST_STRING.split(';')[0])
+        self.assertEqual(it.pos, len(s) - 1)
+        self.assertEqual(it.next(), ';')
+
+    def test_empty(self):
+        it = PositionedIterator(TEST_STRING)
+
+        for x in it:
+            if it.pos + 1 == len(TEST_STRING):
+                self.assertTrue(it.empty())
+
+        self.assertTrue(it.empty())
+
+        with self.assertRaises(StopIteration):
+            _ = it.next()
+
+    def test_take(self):
+        length = 3
+        it1 = PositionedIterator(TEST_STRING)
+        it2 = PositionedIterator(TEST_STRING)
+
+        s1 = it1.take(length)
+        s2 = ''
+        for x in it2:
+            if it2.pos >= length:
+                break
+            else:
+                s2 += x
+
+        self.assertEqual(s1, s2)
+        self.assertEqual(s1, TEST_STRING[0:length])
+
+        # using take does not consume the next element!
+        self.assertEqual(it1.pos, length - 1)
+
+if __name__ == '__main__':
+    unittest.main()