diff --git a/csp.py b/csp.py index 247171a3a..b0645f368 100644 --- a/csp.py +++ b/csp.py @@ -83,6 +83,14 @@ def conflict(var2): return count(conflict(v) for v in self.neighbors[var]) + def count_lost_values(self, var, val, assignment): + """Return how many values would be ruled out in the domains of the + unassigned neighbours of var if var were assigned val (the count used + by the least-constraining-value heuristic).""" + return count(not self.constraints(var, val, neighbor, dval) + for neighbor in self.neighbors[var] if neighbor not in assignment + for dval in self.domains[neighbor]) + def display(self, assignment): """Show a human-readable representation of the CSP.""" # Subclasses can print in a prettier way, or display with a GUI @@ -371,7 +379,7 @@ def unordered_domain_values(var, assignment, csp): def lcv(var, assignment, csp): """Least-constraining-values heuristic.""" - return sorted(csp.choices(var), key=lambda val: csp.nconflicts(var, val, assignment)) + return sorted(csp.choices(var), key=lambda val: csp.count_lost_values(var, val, assignment)) # Inference diff --git a/tests/test_csp.py b/tests/test_csp.py index ccc271cd5..e78bb05ae 100644 --- a/tests/test_csp.py +++ b/tests/test_csp.py @@ -35,6 +35,12 @@ def test_csp_nconflicts(): val = 'B' assert map_coloring_test.nconflicts(var, val, assignment) == 0 +def test_csp_count_lost_values(): + map_coloring_test = MapColoringCSP(list('RGB'), 'A: B C; B: C; C: ') + assignment = {'A': 'G'} + var = 'C' + val = 'R' + assert map_coloring_test.count_lost_values(var, val, assignment) == 1 def test_csp_actions(): map_coloring_test = MapColoringCSP(list('123'), 'A: B C; B: C; C: ') @@ -332,13 +338,13 @@ def test_lcv(): var = 'B' - assert lcv(var, assignment, csp) == [4, 0, 1, 2, 3, 5] - assignment = {'A': 1, 'C': 3} + assert lcv(var, assignment, csp) == [0, 2, 4, 1, 3, 5] + assignment = {'A': 1} constraints = lambda X, x, Y, y: (x + y) % 2 == 0 and (x + y) < 5 csp = CSP(variables=None, domains=domains, neighbors=neighbors, constraints=constraints) - assert lcv(var, assignment, csp) == [1, 3, 0, 2, 4, 5] + assert lcv(var, assignment, csp) == [0, 1, 2, 3, 4, 5] def test_forward_checking():