
Create a space-efficient Snapshot Set

I received this interview question that I didn't know how to solve.

Design a snapshot set functionality.

Once the snapshot is taken, the iterator of the class should only return values that were present in the set at the time of the snapshot.

The class should provide add, remove, and contains functionality. The iterator always returns the elements that were present in the snapshot, even if an element is removed from the set after the snapshot was taken.

The snapshot of the set is taken when the iterator function is called.

interface SnapshotSet {
  void add(int num);
  void remove(int num);
  boolean contains(int num);
  Iterator<Integer> iterator(); // the first call to this function should trigger a snapshot of the set
}

The interviewer said that the space requirement is that we cannot create a copy (snapshot) of the entire list of keys when calling iterator.

The first step is to handle only one iterator being created and iterated over at a time. The follow-up question is how to handle multiple iterators.

An example:

SnapshotSet set = new SnapshotSet();
set.add(1);
set.add(2);
set.add(3);
set.add(4);
Iterator<Integer> itr1 = set.iterator(); // iterator should return 1, 2, 3, 4 (in any order) when next() is called.
set.remove(1);
set.contains(1); // returns false; because 1 was removed.
Iterator<Integer> itr2 = set.iterator(); // iterator should return 2, 3, 4 (in any order) when next() is called.

I came up with an O(n) space solution where I created a copy of the entire list of keys when calling iterator. The interviewer said this was not space efficient enough.
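For reference, the copying approach described above might look like the following minimal sketch. It is written in Python rather than the Java of the interface, and the class name `CopyingSnapshotSet` is a hypothetical one used only for illustration. Every call to `iterator()` duplicates all n keys, which is the O(n) extra space the interviewer objected to.

```python
class CopyingSnapshotSet:
    """Baseline sketch: snapshot by copying every key (O(n) extra space)."""

    def __init__(self):
        self._items = set()

    def add(self, num):
        self._items.add(num)

    def remove(self, num):
        self._items.discard(num)  # no error if num is absent

    def contains(self, num):
        return num in self._items

    def iterator(self):
        # The full copy below is what makes this approach space-inefficient.
        return iter(list(self._items))


s = CopyingSnapshotSet()
for n in (1, 2, 3, 4):
    s.add(n)
itr1 = s.iterator()
s.remove(1)
itr2 = s.iterator()
print(sorted(itr1))   # [1, 2, 3, 4]
print(s.contains(1))  # False
print(sorted(itr2))   # [2, 3, 4]
```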

I think it is fine to have a solution that focuses on reducing space at the cost of time complexity (but the time complexity should still be as efficient as possible).

user21200640 asked Oct 29 '25 23:10



1 Answer

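This is a very different but ultimately much better answer than the one I gave at first. The idea is to store the set as a read-only, reasonably well-balanced sorted tree. Because nodes are never modified in place, an iterator that holds a reference to a root keeps seeing the set exactly as it was when the iterator was created, so iterating over a snapshot is easy.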

But then how do you make modifications? You create new copies of the nodes on the path from the modification up to the root, and share every other node with the old tree. This will be O(log(n)) new nodes. Better yet, the O(log(n)) old nodes that were replaced can be trivially garbage collected once no iterator or older root still references them.

All operations are O(log(n)) except iteration, which is O(n). I also included both an explicit iterator (a function that returns the next value each time it is called) and an implicit one using Python's generators.
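As a rough illustration of how this handles several iterators at once, here is a sketch of an explicit in-order iterator over immutable nodes. It assumes nodes are plain `(value, left, right)` tuples built by hand, which is a simplification and not the representation used in the code below; `InOrderIterator` and `leaf` are hypothetical names. Each iterator keeps only its own stack of nodes along one root-to-leaf path, so its extra space is proportional to the tree height rather than to the number of keys.

```python
class InOrderIterator:
    """Walks an immutable tree of (value, left, right) tuples in sorted order."""

    def __init__(self, root):
        self._stack = []
        self._push_left(root)

    def _push_left(self, node):
        # Push node and all of its left descendants onto the stack.
        while node is not None:
            self._stack.append(node)
            node = node[1]

    def __iter__(self):
        return self

    def __next__(self):
        if not self._stack:
            raise StopIteration
        value, _, right = self._stack.pop()
        self._push_left(right)
        return value


def leaf(value):
    return (value, None, None)


# Version A of the set holds 1, 2, 3, 4.
version_a = (2, leaf(1), (4, leaf(3), None))
# Version B is A with 1 removed; it shares A's right subtree.
version_b = (2, None, version_a[2])

itr1 = InOrderIterator(version_a)
itr2 = InOrderIterator(version_b)
print(next(itr1), next(itr2))  # 1 2  (the iterators advance independently)
print(list(itr1))              # [2, 3, 4]
print(list(itr2))              # [3, 4]
```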

And for fun I coded it up in Python.

class TreeNode:
    def __init__ (self, value, left=None, right=None):
        self.value = value
        # Cache the subtree size; attach_left/attach_right use it to rebalance.
        count = 1
        if left is not None:
            count += left.count
        if right is not None:
            count += right.count
        self.count = count
        self.left = left
        self.right = right

    def left_count (self):
        if self.left is None:
            return 0
        else:
            return self.left.count

    def right_count (self):
        if self.right is None:
            return 0
        else:
            return self.right.count

    def attach_left (self, child):
        # New node for balanced tree with self.left replaced by child.
        if child is self.left:
            return self
        elif child is None:
            return TreeNode(self.value).attach_right(self.right)
        elif child.left_count() < child.right_count() + self.right_count():
            return TreeNode(self.value, child, self.right)
        else:
            new_right = TreeNode(self.value, child.right, self.right)
            return TreeNode(child.value, child.left, new_right)

    def attach_right (self, child):
        # New node for balanced tree with self.right replaced by child.
        if child is self.right:
            return self
        elif child is None:
            return TreeNode(self.value).attach_left(self.left)
        elif child.right_count() < child.left_count() + self.left_count():
            return TreeNode(self.value, self.left, child)
        else:
            new_left = TreeNode(self.value, self.left, child.left)
            return TreeNode(child.value, new_left, child.right)

    def merge_right (self, other):
        # New node for balanced tree with all of self, then all of other.
        if other is None:
            return self
        elif self.right is None:
            return self.attach_right(other)
        elif other.left is None:
            return other.attach_left(self)
        else:
            child = self.right.merge_right(other.left)
            if self.left_count() < other.right_count():
                child = self.attach_right(child)
                return other.attach_left(child)
            else:
                child = other.attach_left(child)
                return self.attach_right(child)

    def add (self, value):
        if value < self.value:
            if self.left is None:
                child = TreeNode(value)
            else:
                child = self.left.add(value)
            return self.attach_left(child)
        elif self.value < value:
            if self.right is None:
                child = TreeNode(value)
            else:
                child = self.right.add(value)
            return self.attach_right(child)
        else:
            return self

    def remove (self, value):
        if value < self.value:
            if self.left is None:
                return self
            else:
                return self.attach_left(self.left.remove(value))
        elif self.value < value:
            if self.right is None:
                return self
            else:
                return self.attach_right(self.right.remove(value))
        else:
            if self.left is None:
                return self.right
            elif self.right is None:
                return self.left
            else:
                return self.left.merge_right(self.right)

    def __str__ (self):
        if self.left is None:
            left_lines = []
        else:
            left_lines = str(self.left).split("\n")
            left_lines.pop()

            left_lines = ["  " + l for l in left_lines]

        if self.right is None:
            right_lines = []
        else:
            right_lines = str(self.right).split("\n")
            right_lines.pop()

            right_lines = ["  " + l for l in right_lines]

        return "\n".join(left_lines + [str(self.value)] + right_lines) + "\n"

    # Pythonic iterator.
    def __iter__ (self):
        if self.left is not None:
            yield from self.left
        yield self.value
        if self.right is not None:
            yield from self.right



class SnapshottableSet:
    def __init__ (self, root=None):
        self.root = root

    def contains (self, value):
        node = self.root
        while node is not None:
            if value < node.value:
                node = node.left
            elif node.value < value:
                node = node.right
            else:
                return True
        return False

    def add (self, value):
        if self.root is None:
            self.root = TreeNode(value)
        else:
            self.root = self.root.add(value)

    def remove (self, value):
        if self.root is not None:
            self.root = self.root.remove(value)

    # Pythonic built-in approach
    def __iter__ (self):
        if self.root is not None:
            yield from self.root

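    # An explicit approach: returns a function that yields the next value on
    # each call and raises StopIteration when the snapshot is exhausted.
    # It only holds a stack of nodes on the current path, not a copy of the keys.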
    def iterator (self):
        nodes = []
        if self.root is not None:
            node = self.root
            while node is not None:
                nodes.append(node)
                node = node.left
        def next_value ():
            if len(nodes):
                node = nodes.pop()
                value = node.value
                node = node.right
                while node is not None:
                    nodes.append(node)
                    node = node.left
                return value
            else:
                raise StopIteration
        return next_value

s = SnapshottableSet()
for i in range(10):
    s.add(i)
it = s.iterator()
for i in range(5):
    s.remove(2*i)

print("Current contents")
for v in s:
    print(v)

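print("Original contents")
while True:
    try:
        print(it())
    except StopIteration:
        # The explicit iterator signals exhaustion by raising StopIteration.
        break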
btilly answered Nov 01 '25 13:11
