about summary refs log tree commit diff
diff options
context:
space:
mode:
authorsterni <sternenseemann@systemli.org>2021-07-03 15:20:45 +0200
committerSören Tempel <soeren+git@soeren-tempel.net>2021-07-15 13:28:33 +0200
commit21827d91b3a20480f8c1328e943b060a2c227feb (patch)
treeabbe47a4951351271818d26faa0876da3f4c7443
parentc7a0620eb2cca0595a57489d68aa8a535733b96d (diff)
pty.Parser: track pos in iterator, support postponing parsing
The next step for pty.Parser will be to support parsing ANSI CSI
escape sequences which consist of multiple codepoints, but need to be
treated as an unit.

Thus we need to be able to push some unparsed input to the next
invocation of parse(). The way I've chosen to support this is by
supporting backtracking in an iterator which wraps the input string.

This way we'll be able to keep consuming input from the iterator in
the main for loop, but can backtrack in case we hit the end of input
or encounter a parse failure.

I opted to implement this iterator in the form of an object that
tracks an index into a string. For one this feature also helps us in
the parser: We need to track the current index anyways and it is also
helpful for checking if the end of input has been reached.

Additionally this seems like the only feasible way to implement such a
backtracking iterator wrapper: The other alternative would've been to
copy.deepcopy() the iterator at a point we may want to return to —
however this is not practical as deepcopy will hit a recursion depth
error with the moderately sized chunks of input we handle in parse().

This commit introduces the new iterator and adapts our parser code for
it and lays the groundwork for resuming from unparsed input of a
previous invocation of parse(). Additionally it also ships a set of
unit tests for the iterator, many methods of which are still unused in
the main parser code.
-rw-r--r--saneterm/pty.py160
-rw-r--r--tests.py96
2 files changed, 241 insertions, 15 deletions
diff --git a/saneterm/pty.py b/saneterm/pty.py
index 90bcf34..5396d82 100644
--- a/saneterm/pty.py
+++ b/saneterm/pty.py
@@ -43,6 +43,137 @@ class EventType(Enum):
     TEXT = auto()
     BELL = auto()
 
+class PositionedIterator(object):
+    """
+    Wrapper class which implements the iterator interface
+    for a string. In contrast to the default implementation
+    it works by tracking an index in the string internally.
+
+    This allows the following additional features:
+
+    * Checking whether the iterator has any elements left
+      using empty()
+    * Jumping back to a previous point via backtrack()
+
+    The object exposes the following attributes:
+
+    * pos: the index of the last element received via __next__()
+    * wrapped: the string used for construction
+    """
+    def __init__(self, s):
+        # always points to the position of the element
+        # just received via __next__()
+        self.pos = -1
+        self.wrapped = s
+
+        self.waypoints = []
+
+    def waypoint(self):
+        """
+        Mark the index backtrack() should jump to when called.
+        Calling this will make the character received by __next__()
+        after calling backtrack() at any point in the future be
+        the same which was last received via __next__() before
+        calling waypoint().
+
+        Counterintutively, this means that pos immediately after
+        calling waypoint() will be greater than right after
+        calling backtrack() subsequently.
+
+        This allows you to decide whether or not to set a waypoint
+        after inspecting an element which is useful when writing
+        parsers:
+
+        def example(s):
+          it = PositionedIterator(s)
+
+          ignore_colon = False
+
+          for x in it:
+            if ignore_colon:
+              ignore_colon = False
+              # do nothing
+            elif x == ':':
+              it.waypoint()
+
+              if x.next() == ' ':
+                # do stuff …
+              else:
+                it.backtrack()
+                ignore_colon = True
+        """
+        # TODO: maybe don't support calling waypoint if pos == -1
+        self.waypoints.append(max(self.pos - 1, -1))
+
+    def backtrack(self):
+        """See documentation of waypoint()"""
+        self.pos = self.waypoints.pop()
+
+    def next(self):
+        """Shortcut for __next__()"""
+        return self.__next__()
+
+    def take(self, n):
+        """
+        Consume n elements of the iterator and return them as a string slice.
+        """
+        start = self.pos + 1
+
+        for _ in range(n):
+            _ = self.__next__()
+
+        end = self.pos + 1
+
+        return self.wrapped[start:end]
+
+    def takewhile_greedy(self, f):
+        """
+        Consume elements while a given predicate returns True and
+        return them as a string slice. takewhile_greedy() expects
+        the predicate to return False at least once before the end
+        of input and will otherwise raise a StopIteration condition.
+
+        Thus using takewhile_greedy() only makes sense if whatever
+        your parsing is terminated in some way:
+
+        def example(s):
+          foo = takewhile_greedy(lambda x: x != ';')
+
+        example("foo")  # fails
+        example("foo;") # succeeds, but doesn't consume ';'
+
+        (In a real example you'd also consume the semicolon)
+        """
+        x = self.__next__()
+        start = self.pos
+
+        while f(x):
+            x = self.__next__()
+
+        end = self.pos
+        self.pos -= 1
+
+        return self.wrapped[start:end]
+
+    def empty(self):
+        """
+        Check if the iterator has no elements left
+        without consuming the next item (if any).
+        """
+        return self.pos + 1 == len(self.wrapped)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        self.pos += 1
+
+        try:
+            return self.wrapped[self.pos]
+        except IndexError:
+            self.pos -= 1
+            raise StopIteration
+
 class Parser(object):
     """
     Parses a subset of special control sequences read from
@@ -52,8 +183,8 @@ class Parser(object):
     is tracked in the Parser object.
     """
     def __init__(self):
-        # no state, yet
-        pass
+        # unparsed output left from the last call to parse
+        self.__leftover = ''
 
     def parse(self, input):
         """
@@ -70,17 +201,18 @@ class Parser(object):
           This usually should trigger the machine to beep
           and/or the window to set the urgent flag.
         """
-        # keep track of the start and potential end position
-        # of the slice we want to emit as a TEXT event
+
+        it = PositionedIterator(self.__leftover + input)
+        self.__leftover = ''
+
+        # keep track of the start position of the slice
+        # we want to emit as a TEXT event
         start = 0
-        pos = 0
-        # TODO: can we check for the last element more efficiently?
-        size = len(input)
 
         # we expect a decoded string as input,
         # so we don't need to handle incremental
         # decoding here as well
-        for code in input:
+        for code in it:
             # if flush_until is set, a slice of the buffer
             # from start to flush_until will be emitted as
             # a TEXT event
@@ -94,22 +226,20 @@ class Parser(object):
             # want to handle them ourselves instead of
             # relying of gtk's default behavior.
             if code == '\a':
-                flush_until = pos
+                flush_until = it.pos
                 special_ev = (EventType.BELL, None)
 
-            pos += 1
-
             # at the end of input, flush if we aren't already
-            if flush_until == None and pos >= size:
-                flush_until = pos
+            if flush_until == None and it.empty():
+                flush_until = it.pos + 1
 
             # only generate text event if it is non empty, …
             if flush_until != None and flush_until > start:
-                yield (EventType.TEXT, input[start:flush_until])
+                yield (EventType.TEXT, it.wrapped[start:flush_until])
 
             # … but advance as if we had flushed
             if flush_until != None:
-                start = pos
+                start = it.pos + 1
 
             if special_ev != None:
                 yield special_ev
diff --git a/tests.py b/tests.py
new file mode 100644
index 0000000..34db339
--- /dev/null
+++ b/tests.py
@@ -0,0 +1,96 @@
+import copy
+import unittest
+
+from saneterm.pty import PositionedIterator
+
+TEST_STRING = 'foo;bar'
+
+class TestPositionedIterator(unittest.TestCase):
+    def test_lossless(self):
+        it = PositionedIterator(TEST_STRING)
+
+        self.assertEqual([x for x in it], list(TEST_STRING))
+        self.assertEqual(it.wrapped, TEST_STRING)
+
+    def test_indices(self):
+        it = PositionedIterator(TEST_STRING)
+
+        self.assertEqual(it.pos, -1)
+
+        for x in it:
+            assert x == TEST_STRING[it.pos]
+
+            if x == ';':
+                break
+
+        self.assertEqual(it.pos, 3)
+
+        for x in it:
+            self.assertEqual(x, it.wrapped[it.pos])
+
+        self.assertTrue(it.empty())
+
+    def test_backtracking(self):
+        it = PositionedIterator(TEST_STRING)
+
+        semicolon_index = None
+
+        for x in it:
+            if x == ';':
+                it.waypoint()
+                semicolon_index = it.pos
+
+        self.assertEqual(semicolon_index, TEST_STRING.index(';'))
+
+        self.assertTrue(it.empty())
+
+        with self.assertRaises(StopIteration):
+            _ = it.next()
+
+        it.backtrack()
+
+        self.assertEqual(it.next(), ';')
+        self.assertEqual(it.pos, semicolon_index)
+
+    def test_takewhile(self):
+        it = PositionedIterator(TEST_STRING)
+
+        s = it.takewhile_greedy(lambda x: x != ';')
+
+        self.assertEqual(s, TEST_STRING.split(';')[0])
+        self.assertEqual(it.pos, len(s) - 1)
+        self.assertEqual(it.next(), ';')
+
+    def test_empty(self):
+        it = PositionedIterator(TEST_STRING)
+
+        for x in it:
+            if it.pos + 1 == len(TEST_STRING):
+                self.assertTrue(it.empty())
+
+        self.assertTrue(it.empty())
+
+        with self.assertRaises(StopIteration):
+            _ = it.next()
+
+    def test_take(self):
+        length = 3
+        it1 = PositionedIterator(TEST_STRING)
+        it2 = PositionedIterator(TEST_STRING)
+
+        s1 = it1.take(length)
+        s2 = ''
+        for x in it2:
+            if it2.pos >= length:
+                break
+            else:
+                s2 += x
+
+        self.assertEqual(s1, s2)
+        self.assertEqual(s1, TEST_STRING[0:length])
+
+        # using take does not consume the next element!
+        self.assertEqual(it1.pos, length - 1)
+
+if __name__ == '__main__':
+    unittest.main()