diff --git a/tests/test_planning.py b/tests/test_planning.py index 640b6c64c..c395f0c15 100644 --- a/tests/test_planning.py +++ b/tests/test_planning.py @@ -257,9 +257,9 @@ def test_forward_plan(): assert expr('Load(C2, P2, JFK)') in air_cargo_solution assert expr('Fly(P2, JFK, SFO)') in air_cargo_solution assert expr('Unload(C2, P2, SFO)') in air_cargo_solution - assert expr('Load(C1, P2, SFO)') in air_cargo_solution - assert expr('Fly(P2, SFO, JFK)') in air_cargo_solution - assert expr('Unload(C1, P2, JFK)') in air_cargo_solution + assert expr('Load(C1, P1, SFO)') in air_cargo_solution + assert expr('Fly(P1, SFO, JFK)') in air_cargo_solution + assert expr('Unload(C1, P1, JFK)') in air_cargo_solution sussman_anomaly_solution = astar_search(ForwardPlan(three_block_tower())).solution() sussman_anomaly_solution = list(map(lambda action: Expr(action.name, *action.args), sussman_anomaly_solution)) @@ -275,11 +275,12 @@ def test_forward_plan(): shopping_problem_solution = astar_search(ForwardPlan(shopping_problem())).solution() shopping_problem_solution = list(map(lambda action: Expr(action.name, *action.args), shopping_problem_solution)) - assert expr('Go(Home, SM)') in shopping_problem_solution assert expr('Buy(Banana, SM)') in shopping_problem_solution assert expr('Buy(Milk, SM)') in shopping_problem_solution - assert expr('Go(SM, HW)') in shopping_problem_solution assert expr('Buy(Drill, HW)') in shopping_problem_solution + # the plan must reach both stores; the exact route may vary by tie-breaking + assert expr('Go(Home, SM)') in shopping_problem_solution or expr('Go(HW, SM)') in shopping_problem_solution + assert expr('Go(Home, HW)') in shopping_problem_solution or expr('Go(SM, HW)') in shopping_problem_solution def test_backward_plan(): diff --git a/tests/test_utils.py b/tests/test_utils.py index 6c2a50808..21cf9a437 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -299,5 +299,15 @@ def __eq__(self, other): assert len(queue) == 0 +def test_priority_queue_ties_with_non_comparable_items(): + # Items with equal priority must never be compared with each other + # (these dicts are not orderable); ties keep insertion (FIFO) order. + queue = PriorityQueue(f=lambda d: d['cost']) + queue.append({'cost': 1, 'id': 'a'}) + queue.append({'cost': 1, 'id': 'b'}) + queue.append({'cost': 0, 'id': 'c'}) + assert [queue.pop()['id'] for _ in range(3)] == ['c', 'a', 'b'] + + if __name__ == '__main__': pytest.main() diff --git a/utils.py b/utils.py index 3158e3793..4ed9980f4 100644 --- a/utils.py +++ b/utils.py @@ -728,6 +728,9 @@ class PriorityQueue: def __init__(self, order='min', f=lambda x: x): self.heap = [] + # monotonic tie-breaker so items with equal priority are never compared + # (which would fail for non-comparable items); ties keep insertion order + self.counter = 0 if order == 'min': self.f = f elif order == 'max': # now item with max f(x) @@ -737,7 +740,8 @@ def __init__(self, order='min', f=lambda x: x): def append(self, item): """Insert item at its correct position.""" - heapq.heappush(self.heap, (self.f(item), item)) + heapq.heappush(self.heap, (self.f(item), self.counter, item)) + self.counter += 1 def extend(self, items): """Insert each item in items at its correct position.""" @@ -748,7 +752,7 @@ def pop(self): """Pop and return the item (with min or max f(x) value) depending on the order.""" if self.heap: - return heapq.heappop(self.heap)[1] + return heapq.heappop(self.heap)[-1] else: raise Exception('Trying to pop from empty PriorityQueue.') @@ -758,12 +762,12 @@ def __len__(self): def __contains__(self, key): """Return True if the key is in PriorityQueue.""" - return any([item == key for _, item in self.heap]) + return any(item == key for *_, item in self.heap) def __getitem__(self, key): """Returns the first value associated with key in PriorityQueue. Raises KeyError if key is not present.""" - for value, item in self.heap: + for value, _, item in self.heap: if item == key: return value raise KeyError(str(key) + " is not in the priority queue") @@ -771,7 +775,7 @@ def __getitem__(self, key): def __delitem__(self, key): """Delete the first occurrence of key.""" try: - del self.heap[[item == key for _, item in self.heap].index(True)] + del self.heap[[item == key for *_, item in self.heap].index(True)] except ValueError: raise KeyError(str(key) + " is not in the priority queue") heapq.heapify(self.heap)