add read me
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
308
venv/lib/python3.12/site-packages/sympy/logic/algorithms/dpll.py
Normal file
308
venv/lib/python3.12/site-packages/sympy/logic/algorithms/dpll.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""Implementation of DPLL algorithm
|
||||
|
||||
Further improvements: eliminate calls to pl_true, implement branching rules,
|
||||
efficient unit propagation.
|
||||
|
||||
References:
|
||||
- https://en.wikipedia.org/wiki/DPLL_algorithm
|
||||
- https://www.researchgate.net/publication/242384772_Implementations_of_the_DPLL_Algorithm
|
||||
"""
|
||||
|
||||
from sympy.core.sorting import default_sort_key
|
||||
from sympy.logic.boolalg import Or, Not, conjuncts, disjuncts, to_cnf, \
|
||||
to_int_repr, _find_predicates
|
||||
from sympy.assumptions.cnf import CNF
|
||||
from sympy.logic.inference import pl_true, literal_symbol
|
||||
|
||||
|
||||
def dpll_satisfiable(expr):
|
||||
"""
|
||||
Check satisfiability of a propositional sentence.
|
||||
It returns a model rather than True when it succeeds
|
||||
|
||||
>>> from sympy.abc import A, B
|
||||
>>> from sympy.logic.algorithms.dpll import dpll_satisfiable
|
||||
>>> dpll_satisfiable(A & ~B)
|
||||
{A: True, B: False}
|
||||
>>> dpll_satisfiable(A & ~A)
|
||||
False
|
||||
|
||||
"""
|
||||
if not isinstance(expr, CNF):
|
||||
clauses = conjuncts(to_cnf(expr))
|
||||
else:
|
||||
clauses = expr.clauses
|
||||
if False in clauses:
|
||||
return False
|
||||
symbols = sorted(_find_predicates(expr), key=default_sort_key)
|
||||
symbols_int_repr = set(range(1, len(symbols) + 1))
|
||||
clauses_int_repr = to_int_repr(clauses, symbols)
|
||||
result = dpll_int_repr(clauses_int_repr, symbols_int_repr, {})
|
||||
if not result:
|
||||
return result
|
||||
output = {}
|
||||
for key in result:
|
||||
output.update({symbols[key - 1]: result[key]})
|
||||
return output
|
||||
|
||||
|
||||
def dpll(clauses, symbols, model):
|
||||
"""
|
||||
Compute satisfiability in a partial model.
|
||||
Clauses is an array of conjuncts.
|
||||
|
||||
>>> from sympy.abc import A, B, D
|
||||
>>> from sympy.logic.algorithms.dpll import dpll
|
||||
>>> dpll([A, B, D], [A, B], {D: False})
|
||||
False
|
||||
|
||||
"""
|
||||
# compute DP kernel
|
||||
P, value = find_unit_clause(clauses, model)
|
||||
while P:
|
||||
model.update({P: value})
|
||||
symbols.remove(P)
|
||||
if not value:
|
||||
P = ~P
|
||||
clauses = unit_propagate(clauses, P)
|
||||
P, value = find_unit_clause(clauses, model)
|
||||
P, value = find_pure_symbol(symbols, clauses)
|
||||
while P:
|
||||
model.update({P: value})
|
||||
symbols.remove(P)
|
||||
if not value:
|
||||
P = ~P
|
||||
clauses = unit_propagate(clauses, P)
|
||||
P, value = find_pure_symbol(symbols, clauses)
|
||||
# end DP kernel
|
||||
unknown_clauses = []
|
||||
for c in clauses:
|
||||
val = pl_true(c, model)
|
||||
if val is False:
|
||||
return False
|
||||
if val is not True:
|
||||
unknown_clauses.append(c)
|
||||
if not unknown_clauses:
|
||||
return model
|
||||
if not clauses:
|
||||
return model
|
||||
P = symbols.pop()
|
||||
model_copy = model.copy()
|
||||
model.update({P: True})
|
||||
model_copy.update({P: False})
|
||||
symbols_copy = symbols[:]
|
||||
return (dpll(unit_propagate(unknown_clauses, P), symbols, model) or
|
||||
dpll(unit_propagate(unknown_clauses, Not(P)), symbols_copy, model_copy))
|
||||
|
||||
|
||||
def dpll_int_repr(clauses, symbols, model):
|
||||
"""
|
||||
Compute satisfiability in a partial model.
|
||||
Arguments are expected to be in integer representation
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll import dpll_int_repr
|
||||
>>> dpll_int_repr([{1}, {2}, {3}], {1, 2}, {3: False})
|
||||
False
|
||||
|
||||
"""
|
||||
# compute DP kernel
|
||||
P, value = find_unit_clause_int_repr(clauses, model)
|
||||
while P:
|
||||
model.update({P: value})
|
||||
symbols.remove(P)
|
||||
if not value:
|
||||
P = -P
|
||||
clauses = unit_propagate_int_repr(clauses, P)
|
||||
P, value = find_unit_clause_int_repr(clauses, model)
|
||||
P, value = find_pure_symbol_int_repr(symbols, clauses)
|
||||
while P:
|
||||
model.update({P: value})
|
||||
symbols.remove(P)
|
||||
if not value:
|
||||
P = -P
|
||||
clauses = unit_propagate_int_repr(clauses, P)
|
||||
P, value = find_pure_symbol_int_repr(symbols, clauses)
|
||||
# end DP kernel
|
||||
unknown_clauses = []
|
||||
for c in clauses:
|
||||
val = pl_true_int_repr(c, model)
|
||||
if val is False:
|
||||
return False
|
||||
if val is not True:
|
||||
unknown_clauses.append(c)
|
||||
if not unknown_clauses:
|
||||
return model
|
||||
P = symbols.pop()
|
||||
model_copy = model.copy()
|
||||
model.update({P: True})
|
||||
model_copy.update({P: False})
|
||||
symbols_copy = symbols.copy()
|
||||
return (dpll_int_repr(unit_propagate_int_repr(unknown_clauses, P), symbols, model) or
|
||||
dpll_int_repr(unit_propagate_int_repr(unknown_clauses, -P), symbols_copy, model_copy))
|
||||
|
||||
### helper methods for DPLL
|
||||
|
||||
|
||||
def pl_true_int_repr(clause, model={}):
|
||||
"""
|
||||
Lightweight version of pl_true.
|
||||
Argument clause represents the set of args of an Or clause. This is used
|
||||
inside dpll_int_repr, it is not meant to be used directly.
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll import pl_true_int_repr
|
||||
>>> pl_true_int_repr({1, 2}, {1: False})
|
||||
>>> pl_true_int_repr({1, 2}, {1: False, 2: False})
|
||||
False
|
||||
|
||||
"""
|
||||
result = False
|
||||
for lit in clause:
|
||||
if lit < 0:
|
||||
p = model.get(-lit)
|
||||
if p is not None:
|
||||
p = not p
|
||||
else:
|
||||
p = model.get(lit)
|
||||
if p is True:
|
||||
return True
|
||||
elif p is None:
|
||||
result = None
|
||||
return result
|
||||
|
||||
|
||||
def unit_propagate(clauses, symbol):
|
||||
"""
|
||||
Returns an equivalent set of clauses
|
||||
If a set of clauses contains the unit clause l, the other clauses are
|
||||
simplified by the application of the two following rules:
|
||||
|
||||
1. every clause containing l is removed
|
||||
2. in every clause that contains ~l this literal is deleted
|
||||
|
||||
Arguments are expected to be in CNF.
|
||||
|
||||
>>> from sympy.abc import A, B, D
|
||||
>>> from sympy.logic.algorithms.dpll import unit_propagate
|
||||
>>> unit_propagate([A | B, D | ~B, B], B)
|
||||
[D, B]
|
||||
|
||||
"""
|
||||
output = []
|
||||
for c in clauses:
|
||||
if c.func != Or:
|
||||
output.append(c)
|
||||
continue
|
||||
for arg in c.args:
|
||||
if arg == ~symbol:
|
||||
output.append(Or(*[x for x in c.args if x != ~symbol]))
|
||||
break
|
||||
if arg == symbol:
|
||||
break
|
||||
else:
|
||||
output.append(c)
|
||||
return output
|
||||
|
||||
|
||||
def unit_propagate_int_repr(clauses, s):
|
||||
"""
|
||||
Same as unit_propagate, but arguments are expected to be in integer
|
||||
representation
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll import unit_propagate_int_repr
|
||||
>>> unit_propagate_int_repr([{1, 2}, {3, -2}, {2}], 2)
|
||||
[{3}]
|
||||
|
||||
"""
|
||||
negated = {-s}
|
||||
return [clause - negated for clause in clauses if s not in clause]
|
||||
|
||||
|
||||
def find_pure_symbol(symbols, unknown_clauses):
|
||||
"""
|
||||
Find a symbol and its value if it appears only as a positive literal
|
||||
(or only as a negative) in clauses.
|
||||
|
||||
>>> from sympy.abc import A, B, D
|
||||
>>> from sympy.logic.algorithms.dpll import find_pure_symbol
|
||||
>>> find_pure_symbol([A, B, D], [A|~B,~B|~D,D|A])
|
||||
(A, True)
|
||||
|
||||
"""
|
||||
for sym in symbols:
|
||||
found_pos, found_neg = False, False
|
||||
for c in unknown_clauses:
|
||||
if not found_pos and sym in disjuncts(c):
|
||||
found_pos = True
|
||||
if not found_neg and Not(sym) in disjuncts(c):
|
||||
found_neg = True
|
||||
if found_pos != found_neg:
|
||||
return sym, found_pos
|
||||
return None, None
|
||||
|
||||
|
||||
def find_pure_symbol_int_repr(symbols, unknown_clauses):
|
||||
"""
|
||||
Same as find_pure_symbol, but arguments are expected
|
||||
to be in integer representation
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll import find_pure_symbol_int_repr
|
||||
>>> find_pure_symbol_int_repr({1,2,3},
|
||||
... [{1, -2}, {-2, -3}, {3, 1}])
|
||||
(1, True)
|
||||
|
||||
"""
|
||||
all_symbols = set().union(*unknown_clauses)
|
||||
found_pos = all_symbols.intersection(symbols)
|
||||
found_neg = all_symbols.intersection([-s for s in symbols])
|
||||
for p in found_pos:
|
||||
if -p not in found_neg:
|
||||
return p, True
|
||||
for p in found_neg:
|
||||
if -p not in found_pos:
|
||||
return -p, False
|
||||
return None, None
|
||||
|
||||
|
||||
def find_unit_clause(clauses, model):
|
||||
"""
|
||||
A unit clause has only 1 variable that is not bound in the model.
|
||||
|
||||
>>> from sympy.abc import A, B, D
|
||||
>>> from sympy.logic.algorithms.dpll import find_unit_clause
|
||||
>>> find_unit_clause([A | B | D, B | ~D, A | ~B], {A:True})
|
||||
(B, False)
|
||||
|
||||
"""
|
||||
for clause in clauses:
|
||||
num_not_in_model = 0
|
||||
for literal in disjuncts(clause):
|
||||
sym = literal_symbol(literal)
|
||||
if sym not in model:
|
||||
num_not_in_model += 1
|
||||
P, value = sym, not isinstance(literal, Not)
|
||||
if num_not_in_model == 1:
|
||||
return P, value
|
||||
return None, None
|
||||
|
||||
|
||||
def find_unit_clause_int_repr(clauses, model):
|
||||
"""
|
||||
Same as find_unit_clause, but arguments are expected to be in
|
||||
integer representation.
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll import find_unit_clause_int_repr
|
||||
>>> find_unit_clause_int_repr([{1, 2, 3},
|
||||
... {2, -3}, {1, -2}], {1: True})
|
||||
(2, False)
|
||||
|
||||
"""
|
||||
bound = set(model) | {-sym for sym in model}
|
||||
for clause in clauses:
|
||||
unbound = clause - bound
|
||||
if len(unbound) == 1:
|
||||
p = unbound.pop()
|
||||
if p < 0:
|
||||
return -p, False
|
||||
else:
|
||||
return p, True
|
||||
return None, None
|
||||
@@ -0,0 +1,688 @@
|
||||
"""Implementation of DPLL algorithm
|
||||
|
||||
Features:
|
||||
- Clause learning
|
||||
- Watch literal scheme
|
||||
- VSIDS heuristic
|
||||
|
||||
References:
|
||||
- https://en.wikipedia.org/wiki/DPLL_algorithm
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from heapq import heappush, heappop
|
||||
|
||||
from sympy.core.sorting import ordered
|
||||
from sympy.assumptions.cnf import EncodedCNF
|
||||
|
||||
from sympy.logic.algorithms.lra_theory import LRASolver
|
||||
|
||||
|
||||
def dpll_satisfiable(expr, all_models=False, use_lra_theory=False):
|
||||
"""
|
||||
Check satisfiability of a propositional sentence.
|
||||
It returns a model rather than True when it succeeds.
|
||||
Returns a generator of all models if all_models is True.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import A, B
|
||||
>>> from sympy.logic.algorithms.dpll2 import dpll_satisfiable
|
||||
>>> dpll_satisfiable(A & ~B)
|
||||
{A: True, B: False}
|
||||
>>> dpll_satisfiable(A & ~A)
|
||||
False
|
||||
|
||||
"""
|
||||
if not isinstance(expr, EncodedCNF):
|
||||
exprs = EncodedCNF()
|
||||
exprs.add_prop(expr)
|
||||
expr = exprs
|
||||
|
||||
# Return UNSAT when False (encoded as 0) is present in the CNF
|
||||
if {0} in expr.data:
|
||||
if all_models:
|
||||
return (f for f in [False])
|
||||
return False
|
||||
|
||||
if use_lra_theory:
|
||||
lra, immediate_conflicts = LRASolver.from_encoded_cnf(expr)
|
||||
else:
|
||||
lra = None
|
||||
immediate_conflicts = []
|
||||
solver = SATSolver(expr.data + immediate_conflicts, expr.variables, set(), expr.symbols, lra_theory=lra)
|
||||
models = solver._find_model()
|
||||
|
||||
if all_models:
|
||||
return _all_models(models)
|
||||
|
||||
try:
|
||||
return next(models)
|
||||
except StopIteration:
|
||||
return False
|
||||
|
||||
# Uncomment to confirm the solution is valid (hitting set for the clauses)
|
||||
#else:
|
||||
#for cls in clauses_int_repr:
|
||||
#assert solver.var_settings.intersection(cls)
|
||||
|
||||
|
||||
def _all_models(models):
|
||||
satisfiable = False
|
||||
try:
|
||||
while True:
|
||||
yield next(models)
|
||||
satisfiable = True
|
||||
except StopIteration:
|
||||
if not satisfiable:
|
||||
yield False
|
||||
|
||||
|
||||
class SATSolver:
|
||||
"""
|
||||
Class for representing a SAT solver capable of
|
||||
finding a model to a boolean theory in conjunctive
|
||||
normal form.
|
||||
"""
|
||||
|
||||
def __init__(self, clauses, variables, var_settings, symbols=None,
|
||||
heuristic='vsids', clause_learning='none', INTERVAL=500,
|
||||
lra_theory = None):
|
||||
|
||||
self.var_settings = var_settings
|
||||
self.heuristic = heuristic
|
||||
self.is_unsatisfied = False
|
||||
self._unit_prop_queue = []
|
||||
self.update_functions = []
|
||||
self.INTERVAL = INTERVAL
|
||||
|
||||
if symbols is None:
|
||||
self.symbols = list(ordered(variables))
|
||||
else:
|
||||
self.symbols = symbols
|
||||
|
||||
self._initialize_variables(variables)
|
||||
self._initialize_clauses(clauses)
|
||||
|
||||
if 'vsids' == heuristic:
|
||||
self._vsids_init()
|
||||
self.heur_calculate = self._vsids_calculate
|
||||
self.heur_lit_assigned = self._vsids_lit_assigned
|
||||
self.heur_lit_unset = self._vsids_lit_unset
|
||||
self.heur_clause_added = self._vsids_clause_added
|
||||
|
||||
# Note: Uncomment this if/when clause learning is enabled
|
||||
#self.update_functions.append(self._vsids_decay)
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if 'simple' == clause_learning:
|
||||
self.add_learned_clause = self._simple_add_learned_clause
|
||||
self.compute_conflict = self._simple_compute_conflict
|
||||
self.update_functions.append(self._simple_clean_clauses)
|
||||
elif 'none' == clause_learning:
|
||||
self.add_learned_clause = lambda x: None
|
||||
self.compute_conflict = lambda: None
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
# Create the base level
|
||||
self.levels = [Level(0)]
|
||||
self._current_level.varsettings = var_settings
|
||||
|
||||
# Keep stats
|
||||
self.num_decisions = 0
|
||||
self.num_learned_clauses = 0
|
||||
self.original_num_clauses = len(self.clauses)
|
||||
|
||||
self.lra = lra_theory
|
||||
|
||||
def _initialize_variables(self, variables):
|
||||
"""Set up the variable data structures needed."""
|
||||
self.sentinels = defaultdict(set)
|
||||
self.occurrence_count = defaultdict(int)
|
||||
self.variable_set = [False] * (len(variables) + 1)
|
||||
|
||||
def _initialize_clauses(self, clauses):
|
||||
"""Set up the clause data structures needed.
|
||||
|
||||
For each clause, the following changes are made:
|
||||
- Unit clauses are queued for propagation right away.
|
||||
- Non-unit clauses have their first and last literals set as sentinels.
|
||||
- The number of clauses a literal appears in is computed.
|
||||
"""
|
||||
self.clauses = [list(clause) for clause in clauses]
|
||||
|
||||
for i, clause in enumerate(self.clauses):
|
||||
|
||||
# Handle the unit clauses
|
||||
if 1 == len(clause):
|
||||
self._unit_prop_queue.append(clause[0])
|
||||
continue
|
||||
|
||||
self.sentinels[clause[0]].add(i)
|
||||
self.sentinels[clause[-1]].add(i)
|
||||
|
||||
for lit in clause:
|
||||
self.occurrence_count[lit] += 1
|
||||
|
||||
def _find_model(self):
|
||||
"""
|
||||
Main DPLL loop. Returns a generator of models.
|
||||
|
||||
Variables are chosen successively, and assigned to be either
|
||||
True or False. If a solution is not found with this setting,
|
||||
the opposite is chosen and the search continues. The solver
|
||||
halts when every variable has a setting.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set())
|
||||
>>> list(l._find_model())
|
||||
[{1: True, 2: False, 3: False}, {1: True, 2: True, 3: True}]
|
||||
|
||||
>>> from sympy.abc import A, B, C
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set(), [A, B, C])
|
||||
>>> list(l._find_model())
|
||||
[{A: True, B: False, C: False}, {A: True, B: True, C: True}]
|
||||
|
||||
"""
|
||||
|
||||
# We use this variable to keep track of if we should flip a
|
||||
# variable setting in successive rounds
|
||||
flip_var = False
|
||||
|
||||
# Check if unit prop says the theory is unsat right off the bat
|
||||
self._simplify()
|
||||
if self.is_unsatisfied:
|
||||
return
|
||||
|
||||
# While the theory still has clauses remaining
|
||||
while True:
|
||||
# Perform cleanup / fixup at regular intervals
|
||||
if self.num_decisions % self.INTERVAL == 0:
|
||||
for func in self.update_functions:
|
||||
func()
|
||||
|
||||
if flip_var:
|
||||
# We have just backtracked and we are trying to opposite literal
|
||||
flip_var = False
|
||||
lit = self._current_level.decision
|
||||
|
||||
else:
|
||||
# Pick a literal to set
|
||||
lit = self.heur_calculate()
|
||||
self.num_decisions += 1
|
||||
|
||||
# Stopping condition for a satisfying theory
|
||||
if 0 == lit:
|
||||
|
||||
# check if assignment satisfies lra theory
|
||||
if self.lra:
|
||||
for enc_var in self.var_settings:
|
||||
res = self.lra.assert_lit(enc_var)
|
||||
if res is not None:
|
||||
break
|
||||
res = self.lra.check()
|
||||
self.lra.reset_bounds()
|
||||
else:
|
||||
res = None
|
||||
if res is None or res[0]:
|
||||
yield {self.symbols[abs(lit) - 1]:
|
||||
lit > 0 for lit in self.var_settings}
|
||||
else:
|
||||
self._simple_add_learned_clause(res[1])
|
||||
|
||||
# backtrack until we unassign one of the literals causing the conflict
|
||||
while not any(-lit in res[1] for lit in self._current_level.var_settings):
|
||||
self._undo()
|
||||
|
||||
while self._current_level.flipped:
|
||||
self._undo()
|
||||
if len(self.levels) == 1:
|
||||
return
|
||||
flip_lit = -self._current_level.decision
|
||||
self._undo()
|
||||
self.levels.append(Level(flip_lit, flipped=True))
|
||||
flip_var = True
|
||||
continue
|
||||
|
||||
# Start the new decision level
|
||||
self.levels.append(Level(lit))
|
||||
|
||||
# Assign the literal, updating the clauses it satisfies
|
||||
self._assign_literal(lit)
|
||||
|
||||
# _simplify the theory
|
||||
self._simplify()
|
||||
|
||||
# Check if we've made the theory unsat
|
||||
if self.is_unsatisfied:
|
||||
|
||||
self.is_unsatisfied = False
|
||||
|
||||
# We unroll all of the decisions until we can flip a literal
|
||||
while self._current_level.flipped:
|
||||
self._undo()
|
||||
|
||||
# If we've unrolled all the way, the theory is unsat
|
||||
if 1 == len(self.levels):
|
||||
return
|
||||
|
||||
# Detect and add a learned clause
|
||||
self.add_learned_clause(self.compute_conflict())
|
||||
|
||||
# Try the opposite setting of the most recent decision
|
||||
flip_lit = -self._current_level.decision
|
||||
self._undo()
|
||||
self.levels.append(Level(flip_lit, flipped=True))
|
||||
flip_var = True
|
||||
|
||||
########################
|
||||
# Helper Methods #
|
||||
########################
|
||||
@property
|
||||
def _current_level(self):
|
||||
"""The current decision level data structure
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{1}, {2}], {1, 2}, set())
|
||||
>>> next(l._find_model())
|
||||
{1: True, 2: True}
|
||||
>>> l._current_level.decision
|
||||
0
|
||||
>>> l._current_level.flipped
|
||||
False
|
||||
>>> l._current_level.var_settings
|
||||
{1, 2}
|
||||
|
||||
"""
|
||||
return self.levels[-1]
|
||||
|
||||
def _clause_sat(self, cls):
|
||||
"""Check if a clause is satisfied by the current variable setting.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{1}, {-1}], {1}, set())
|
||||
>>> try:
|
||||
... next(l._find_model())
|
||||
... except StopIteration:
|
||||
... pass
|
||||
>>> l._clause_sat(0)
|
||||
False
|
||||
>>> l._clause_sat(1)
|
||||
True
|
||||
|
||||
"""
|
||||
for lit in self.clauses[cls]:
|
||||
if lit in self.var_settings:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_sentinel(self, lit, cls):
|
||||
"""Check if a literal is a sentinel of a given clause.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set())
|
||||
>>> next(l._find_model())
|
||||
{1: True, 2: False, 3: False}
|
||||
>>> l._is_sentinel(2, 3)
|
||||
True
|
||||
>>> l._is_sentinel(-3, 1)
|
||||
False
|
||||
|
||||
"""
|
||||
return cls in self.sentinels[lit]
|
||||
|
||||
def _assign_literal(self, lit):
|
||||
"""Make a literal assignment.
|
||||
|
||||
The literal assignment must be recorded as part of the current
|
||||
decision level. Additionally, if the literal is marked as a
|
||||
sentinel of any clause, then a new sentinel must be chosen. If
|
||||
this is not possible, then unit propagation is triggered and
|
||||
another literal is added to the queue to be set in the future.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set())
|
||||
>>> next(l._find_model())
|
||||
{1: True, 2: False, 3: False}
|
||||
>>> l.var_settings
|
||||
{-3, -2, 1}
|
||||
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set())
|
||||
>>> l._assign_literal(-1)
|
||||
>>> try:
|
||||
... next(l._find_model())
|
||||
... except StopIteration:
|
||||
... pass
|
||||
>>> l.var_settings
|
||||
{-1}
|
||||
|
||||
"""
|
||||
self.var_settings.add(lit)
|
||||
self._current_level.var_settings.add(lit)
|
||||
self.variable_set[abs(lit)] = True
|
||||
self.heur_lit_assigned(lit)
|
||||
|
||||
sentinel_list = list(self.sentinels[-lit])
|
||||
|
||||
for cls in sentinel_list:
|
||||
if not self._clause_sat(cls):
|
||||
other_sentinel = None
|
||||
for newlit in self.clauses[cls]:
|
||||
if newlit != -lit:
|
||||
if self._is_sentinel(newlit, cls):
|
||||
other_sentinel = newlit
|
||||
elif not self.variable_set[abs(newlit)]:
|
||||
self.sentinels[-lit].remove(cls)
|
||||
self.sentinels[newlit].add(cls)
|
||||
other_sentinel = None
|
||||
break
|
||||
|
||||
# Check if no sentinel update exists
|
||||
if other_sentinel:
|
||||
self._unit_prop_queue.append(other_sentinel)
|
||||
|
||||
def _undo(self):
|
||||
"""
|
||||
_undo the changes of the most recent decision level.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set())
|
||||
>>> next(l._find_model())
|
||||
{1: True, 2: False, 3: False}
|
||||
>>> level = l._current_level
|
||||
>>> level.decision, level.var_settings, level.flipped
|
||||
(-3, {-3, -2}, False)
|
||||
>>> l._undo()
|
||||
>>> level = l._current_level
|
||||
>>> level.decision, level.var_settings, level.flipped
|
||||
(0, {1}, False)
|
||||
|
||||
"""
|
||||
# Undo the variable settings
|
||||
for lit in self._current_level.var_settings:
|
||||
self.var_settings.remove(lit)
|
||||
self.heur_lit_unset(lit)
|
||||
self.variable_set[abs(lit)] = False
|
||||
|
||||
# Pop the level off the stack
|
||||
self.levels.pop()
|
||||
|
||||
#########################
|
||||
# Propagation #
|
||||
#########################
|
||||
"""
|
||||
Propagation methods should attempt to soundly simplify the boolean
|
||||
theory, and return True if any simplification occurred and False
|
||||
otherwise.
|
||||
"""
|
||||
def _simplify(self):
|
||||
"""Iterate over the various forms of propagation to simplify the theory.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set())
|
||||
>>> l.variable_set
|
||||
[False, False, False, False]
|
||||
>>> l.sentinels
|
||||
{-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}}
|
||||
|
||||
>>> l._simplify()
|
||||
|
||||
>>> l.variable_set
|
||||
[False, True, False, False]
|
||||
>>> l.sentinels
|
||||
{-3: {0, 2}, -2: {3, 4}, -1: set(), 2: {0, 3},
|
||||
...3: {2, 4}}
|
||||
|
||||
"""
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
changed |= self._unit_prop()
|
||||
changed |= self._pure_literal()
|
||||
|
||||
def _unit_prop(self):
|
||||
"""Perform unit propagation on the current theory."""
|
||||
result = len(self._unit_prop_queue) > 0
|
||||
while self._unit_prop_queue:
|
||||
next_lit = self._unit_prop_queue.pop()
|
||||
if -next_lit in self.var_settings:
|
||||
self.is_unsatisfied = True
|
||||
self._unit_prop_queue = []
|
||||
return False
|
||||
else:
|
||||
self._assign_literal(next_lit)
|
||||
|
||||
return result
|
||||
|
||||
def _pure_literal(self):
|
||||
"""Look for pure literals and assign them when found."""
|
||||
return False
|
||||
|
||||
#########################
|
||||
# Heuristics #
|
||||
#########################
|
||||
def _vsids_init(self):
|
||||
"""Initialize the data structures needed for the VSIDS heuristic."""
|
||||
self.lit_heap = []
|
||||
self.lit_scores = {}
|
||||
|
||||
for var in range(1, len(self.variable_set)):
|
||||
self.lit_scores[var] = float(-self.occurrence_count[var])
|
||||
self.lit_scores[-var] = float(-self.occurrence_count[-var])
|
||||
heappush(self.lit_heap, (self.lit_scores[var], var))
|
||||
heappush(self.lit_heap, (self.lit_scores[-var], -var))
|
||||
|
||||
def _vsids_decay(self):
|
||||
"""Decay the VSIDS scores for every literal.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set())
|
||||
|
||||
>>> l.lit_scores
|
||||
{-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0}
|
||||
|
||||
>>> l._vsids_decay()
|
||||
|
||||
>>> l.lit_scores
|
||||
{-3: -1.0, -2: -1.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -1.0}
|
||||
|
||||
"""
|
||||
# We divide every literal score by 2 for a decay factor
|
||||
# Note: This doesn't change the heap property
|
||||
for lit in self.lit_scores.keys():
|
||||
self.lit_scores[lit] /= 2.0
|
||||
|
||||
def _vsids_calculate(self):
|
||||
"""
|
||||
VSIDS Heuristic Calculation
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set())
|
||||
|
||||
>>> l.lit_heap
|
||||
[(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)]
|
||||
|
||||
>>> l._vsids_calculate()
|
||||
-3
|
||||
|
||||
>>> l.lit_heap
|
||||
[(-2.0, -2), (-2.0, 2), (0.0, -1), (0.0, 1), (-2.0, 3)]
|
||||
|
||||
"""
|
||||
if len(self.lit_heap) == 0:
|
||||
return 0
|
||||
|
||||
# Clean out the front of the heap as long the variables are set
|
||||
while self.variable_set[abs(self.lit_heap[0][1])]:
|
||||
heappop(self.lit_heap)
|
||||
if len(self.lit_heap) == 0:
|
||||
return 0
|
||||
|
||||
return heappop(self.lit_heap)[1]
|
||||
|
||||
def _vsids_lit_assigned(self, lit):
|
||||
"""Handle the assignment of a literal for the VSIDS heuristic."""
|
||||
pass
|
||||
|
||||
def _vsids_lit_unset(self, lit):
|
||||
"""Handle the unsetting of a literal for the VSIDS heuristic.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set())
|
||||
>>> l.lit_heap
|
||||
[(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)]
|
||||
|
||||
>>> l._vsids_lit_unset(2)
|
||||
|
||||
>>> l.lit_heap
|
||||
[(-2.0, -3), (-2.0, -2), (-2.0, -2), (-2.0, 2), (-2.0, 3), (0.0, -1),
|
||||
...(-2.0, 2), (0.0, 1)]
|
||||
|
||||
"""
|
||||
var = abs(lit)
|
||||
heappush(self.lit_heap, (self.lit_scores[var], var))
|
||||
heappush(self.lit_heap, (self.lit_scores[-var], -var))
|
||||
|
||||
def _vsids_clause_added(self, cls):
|
||||
"""Handle the addition of a new clause for the VSIDS heuristic.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set())
|
||||
|
||||
>>> l.num_learned_clauses
|
||||
0
|
||||
>>> l.lit_scores
|
||||
{-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0}
|
||||
|
||||
>>> l._vsids_clause_added({2, -3})
|
||||
|
||||
>>> l.num_learned_clauses
|
||||
1
|
||||
>>> l.lit_scores
|
||||
{-3: -1.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -2.0}
|
||||
|
||||
"""
|
||||
self.num_learned_clauses += 1
|
||||
for lit in cls:
|
||||
self.lit_scores[lit] += 1
|
||||
|
||||
########################
|
||||
# Clause Learning #
|
||||
########################
|
||||
def _simple_add_learned_clause(self, cls):
|
||||
"""Add a new clause to the theory.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set())
|
||||
|
||||
>>> l.num_learned_clauses
|
||||
0
|
||||
>>> l.clauses
|
||||
[[2, -3], [1], [3, -3], [2, -2], [3, -2]]
|
||||
>>> l.sentinels
|
||||
{-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}}
|
||||
|
||||
>>> l._simple_add_learned_clause([3])
|
||||
|
||||
>>> l.clauses
|
||||
[[2, -3], [1], [3, -3], [2, -2], [3, -2], [3]]
|
||||
>>> l.sentinels
|
||||
{-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4, 5}}
|
||||
|
||||
"""
|
||||
cls_num = len(self.clauses)
|
||||
self.clauses.append(cls)
|
||||
|
||||
for lit in cls:
|
||||
self.occurrence_count[lit] += 1
|
||||
|
||||
self.sentinels[cls[0]].add(cls_num)
|
||||
self.sentinels[cls[-1]].add(cls_num)
|
||||
|
||||
self.heur_clause_added(cls)
|
||||
|
||||
def _simple_compute_conflict(self):
|
||||
""" Build a clause representing the fact that at least one decision made
|
||||
so far is wrong.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.logic.algorithms.dpll2 import SATSolver
|
||||
>>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2},
|
||||
... {3, -2}], {1, 2, 3}, set())
|
||||
>>> next(l._find_model())
|
||||
{1: True, 2: False, 3: False}
|
||||
>>> l._simple_compute_conflict()
|
||||
[3]
|
||||
|
||||
"""
|
||||
return [-(level.decision) for level in self.levels[1:]]
|
||||
|
||||
def _simple_clean_clauses(self):
|
||||
"""Clean up learned clauses."""
|
||||
pass
|
||||
|
||||
|
||||
class Level:
|
||||
"""
|
||||
Represents a single level in the DPLL algorithm, and contains
|
||||
enough information for a sound backtracking procedure.
|
||||
"""
|
||||
|
||||
def __init__(self, decision, flipped=False):
|
||||
self.decision = decision
|
||||
self.var_settings = set()
|
||||
self.flipped = flipped
|
||||
@@ -0,0 +1,912 @@
|
||||
"""Implements "A Fast Linear-Arithmetic Solver for DPLL(T)"
|
||||
|
||||
The LRASolver class defined in this file can be used
|
||||
in conjunction with a SAT solver to check the
|
||||
satisfiability of formulas involving inequalities.
|
||||
|
||||
Here's an example of how that would work:
|
||||
|
||||
Suppose you want to check the satisfiability of
|
||||
the following formula:
|
||||
|
||||
>>> from sympy.core.relational import Eq
|
||||
>>> from sympy.abc import x, y
|
||||
>>> f = ((x > 0) | (x < 0)) & (Eq(x, 0) | Eq(y, 1)) & (~Eq(y, 1) | Eq(1, 2))
|
||||
|
||||
First a preprocessing step should be done on f. During preprocessing,
|
||||
f should be checked for any predicates such as `Q.prime` that can't be
|
||||
handled. Also unequality like `~Eq(y, 1)` should be split.
|
||||
|
||||
I should mention that the paper says to split both equalities and
|
||||
unequality, but this implementation only requires that unequality
|
||||
be split.
|
||||
|
||||
>>> f = ((x > 0) | (x < 0)) & (Eq(x, 0) | Eq(y, 1)) & ((y < 1) | (y > 1) | Eq(1, 2))
|
||||
|
||||
Then an LRASolver instance needs to be initialized with this formula.
|
||||
|
||||
>>> from sympy.assumptions.cnf import CNF, EncodedCNF
|
||||
>>> from sympy.assumptions.ask import Q
|
||||
>>> from sympy.logic.algorithms.lra_theory import LRASolver
|
||||
>>> cnf = CNF.from_prop(f)
|
||||
>>> enc = EncodedCNF()
|
||||
>>> enc.add_from_cnf(cnf)
|
||||
>>> lra, conflicts = LRASolver.from_encoded_cnf(enc)
|
||||
|
||||
Any immediate one-lital conflicts clauses will be detected here.
|
||||
In this example, `~Eq(1, 2)` is one such conflict clause. We'll
|
||||
want to add it to `f` so that the SAT solver is forced to
|
||||
assign Eq(1, 2) to False.
|
||||
|
||||
>>> f = f & ~Eq(1, 2)
|
||||
|
||||
Now that the one-literal conflict clauses have been added
|
||||
and an lra object has been initialized, we can pass `f`
|
||||
to a SAT solver. The SAT solver will give us a satisfying
|
||||
assignment such as:
|
||||
|
||||
(1 = 2): False
|
||||
(y = 1): True
|
||||
(y < 1): True
|
||||
(y > 1): True
|
||||
(x = 0): True
|
||||
(x < 0): True
|
||||
(x > 0): True
|
||||
|
||||
Next you would pass this assignment to the LRASolver
|
||||
which will be able to determine that this particular
|
||||
assignment is satisfiable or not.
|
||||
|
||||
Note that since EncodedCNF is inherently non-deterministic,
|
||||
the int each predicate is encoded as is not consistent. As a
|
||||
result, the code below likely does not reflect the assignment
|
||||
given above.
|
||||
|
||||
>>> lra.assert_lit(-1) #doctest: +SKIP
|
||||
>>> lra.assert_lit(2) #doctest: +SKIP
|
||||
>>> lra.assert_lit(3) #doctest: +SKIP
|
||||
>>> lra.assert_lit(4) #doctest: +SKIP
|
||||
>>> lra.assert_lit(5) #doctest: +SKIP
|
||||
>>> lra.assert_lit(6) #doctest: +SKIP
|
||||
>>> lra.assert_lit(7) #doctest: +SKIP
|
||||
>>> is_sat, conflict_or_assignment = lra.check()
|
||||
|
||||
As the particular assignment suggested is not satisfiable,
|
||||
the LRASolver will return unsat and a conflict clause when
|
||||
given that assignment. The conflict clause will always be
|
||||
minimal, but there can be multiple minimal conflict clauses.
|
||||
One possible conflict clause could be `~(x < 0) | ~(x > 0)`.
|
||||
|
||||
We would then add whatever conflict clause is given to
|
||||
`f` to prevent the SAT solver from coming up with an
|
||||
assignment with the same conflicting literals. In this case,
|
||||
the conflict clause `~(x < 0) | ~(x > 0)` would prevent
|
||||
any assignment where both (x < 0) and (x > 0) were both
|
||||
true.
|
||||
|
||||
The SAT solver would then find another assignment
|
||||
and we would check that assignment with the LRASolver
|
||||
and so on. Eventually either a satisfying assignment
|
||||
that the SAT solver and LRASolver agreed on would be found
|
||||
or enough conflict clauses would be added so that the
|
||||
boolean formula was unsatisfiable.
|
||||
|
||||
|
||||
This implementation is based on [1]_, which includes a
|
||||
detailed explanation of the algorithm and pseudocode
|
||||
for the most important functions.
|
||||
|
||||
[1]_ also explains how backtracking and theory propagation
|
||||
could be implemented to speed up the current implementation,
|
||||
but these are not currently implemented.
|
||||
|
||||
TODO:
|
||||
- Handle non-rational real numbers
|
||||
- Handle positive and negative infinity
|
||||
- Implement backtracking and theory proposition
|
||||
- Simplify matrix by removing unused variables using Gaussian elimination
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
.. [1] Dutertre, B., de Moura, L.:
|
||||
A Fast Linear-Arithmetic Solver for DPLL(T)
|
||||
https://link.springer.com/chapter/10.1007/11817963_11
|
||||
"""
|
||||
from sympy.solvers.solveset import linear_eq_to_matrix
|
||||
from sympy.matrices.dense import eye
|
||||
from sympy.assumptions import Predicate
|
||||
from sympy.assumptions.assume import AppliedPredicate
|
||||
from sympy.assumptions.ask import Q
|
||||
from sympy.core import Dummy
|
||||
from sympy.core.mul import Mul
|
||||
from sympy.core.add import Add
|
||||
from sympy.core.relational import Eq, Ne
|
||||
from sympy.core.sympify import sympify
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.numbers import Rational, oo
|
||||
from sympy.matrices.dense import Matrix
|
||||
|
||||
class UnhandledInput(Exception):
|
||||
"""
|
||||
Raised while creating an LRASolver if non-linearity
|
||||
or non-rational numbers are present.
|
||||
"""
|
||||
|
||||
# predicates that LRASolver understands and makes use of
|
||||
ALLOWED_PRED = {Q.eq, Q.gt, Q.lt, Q.le, Q.ge}
|
||||
|
||||
# if true ~Q.gt(x, y) implies Q.le(x, y)
|
||||
HANDLE_NEGATION = True
|
||||
|
||||
class LRASolver():
|
||||
"""
|
||||
Linear Arithmetic Solver for DPLL(T) implemented with an algorithm based on
|
||||
the Dual Simplex method. Uses Bland's pivoting rule to avoid cycling.
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
.. [1] Dutertre, B., de Moura, L.:
|
||||
A Fast Linear-Arithmetic Solver for DPLL(T)
|
||||
https://link.springer.com/chapter/10.1007/11817963_11
|
||||
"""
|
||||
|
||||
def __init__(self, A, slack_variables, nonslack_variables, enc_to_boundary, s_subs, testing_mode):
|
||||
"""
|
||||
Use the "from_encoded_cnf" method to create a new LRASolver.
|
||||
"""
|
||||
self.run_checks = testing_mode
|
||||
self.s_subs = s_subs # used only for test_lra_theory.test_random_problems
|
||||
|
||||
if any(not isinstance(a, Rational) for a in A):
|
||||
raise UnhandledInput("Non-rational numbers are not handled")
|
||||
if any(not isinstance(b.bound, Rational) for b in enc_to_boundary.values()):
|
||||
raise UnhandledInput("Non-rational numbers are not handled")
|
||||
m, n = len(slack_variables), len(slack_variables)+len(nonslack_variables)
|
||||
if m != 0:
|
||||
assert A.shape == (m, n)
|
||||
if self.run_checks:
|
||||
assert A[:, n-m:] == -eye(m)
|
||||
|
||||
self.enc_to_boundary = enc_to_boundary # mapping of int to Boundary objects
|
||||
self.boundary_to_enc = {value: key for key, value in enc_to_boundary.items()}
|
||||
self.A = A
|
||||
self.slack = slack_variables
|
||||
self.nonslack = nonslack_variables
|
||||
self.all_var = nonslack_variables + slack_variables
|
||||
|
||||
self.slack_set = set(slack_variables)
|
||||
|
||||
self.is_sat = True # While True, all constraints asserted so far are satisfiable
|
||||
self.result = None # always one of: (True, assignment), (False, conflict clause), None
|
||||
|
||||
@staticmethod
|
||||
def from_encoded_cnf(encoded_cnf, testing_mode=False):
|
||||
"""
|
||||
Creates an LRASolver from an EncodedCNF object
|
||||
and a list of conflict clauses for propositions
|
||||
that can be simplified to True or False.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
encoded_cnf : EncodedCNF
|
||||
|
||||
testing_mode : bool
|
||||
Setting testing_mode to True enables some slow assert statements
|
||||
and sorting to reduce nonterministic behavior.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
(lra, conflicts)
|
||||
|
||||
lra : LRASolver
|
||||
|
||||
conflicts : list
|
||||
Contains a one-literal conflict clause for each proposition
|
||||
that can be simplified to True or False.
|
||||
|
||||
Example
|
||||
=======
|
||||
|
||||
>>> from sympy.core.relational import Eq
|
||||
>>> from sympy.assumptions.cnf import CNF, EncodedCNF
|
||||
>>> from sympy.assumptions.ask import Q
|
||||
>>> from sympy.logic.algorithms.lra_theory import LRASolver
|
||||
>>> from sympy.abc import x, y, z
|
||||
>>> phi = (x >= 0) & ((x + y <= 2) | (x + 2 * y - z >= 6))
|
||||
>>> phi = phi & (Eq(x + y, 2) | (x + 2 * y - z > 4))
|
||||
>>> phi = phi & Q.gt(2, 1)
|
||||
>>> cnf = CNF.from_prop(phi)
|
||||
>>> enc = EncodedCNF()
|
||||
>>> enc.from_cnf(cnf)
|
||||
>>> lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True)
|
||||
>>> lra #doctest: +SKIP
|
||||
<sympy.logic.algorithms.lra_theory.LRASolver object at 0x7fdcb0e15b70>
|
||||
>>> conflicts #doctest: +SKIP
|
||||
[[4]]
|
||||
"""
|
||||
# This function has three main jobs:
|
||||
# - raise errors if the input formula is not handled
|
||||
# - preprocesses the formula into a matrix and single variable constraints
|
||||
# - create one-literal conflict clauses from predicates that are always True
|
||||
# or always False such as Q.gt(3, 2)
|
||||
#
|
||||
# See the preprocessing section of "A Fast Linear-Arithmetic Solver for DPLL(T)"
|
||||
# for an explanation of how the formula is converted into a matrix
|
||||
# and a set of single variable constraints.
|
||||
|
||||
encoding = {} # maps int to boundary
|
||||
A = []
|
||||
|
||||
basic = []
|
||||
s_count = 0
|
||||
s_subs = {}
|
||||
nonbasic = []
|
||||
|
||||
if testing_mode:
|
||||
# sort to reduce nondeterminism
|
||||
encoded_cnf_items = sorted(encoded_cnf.encoding.items(), key=lambda x: str(x))
|
||||
else:
|
||||
encoded_cnf_items = encoded_cnf.encoding.items()
|
||||
|
||||
empty_var = Dummy()
|
||||
var_to_lra_var = {}
|
||||
conflicts = []
|
||||
|
||||
for prop, enc in encoded_cnf_items:
|
||||
if isinstance(prop, Predicate):
|
||||
prop = prop(empty_var)
|
||||
if not isinstance(prop, AppliedPredicate):
|
||||
if prop == True:
|
||||
conflicts.append([enc])
|
||||
continue
|
||||
if prop == False:
|
||||
conflicts.append([-enc])
|
||||
continue
|
||||
|
||||
raise ValueError(f"Unhandled Predicate: {prop}")
|
||||
|
||||
assert prop.function in ALLOWED_PRED
|
||||
if prop.lhs == S.NaN or prop.rhs == S.NaN:
|
||||
raise ValueError(f"{prop} contains nan")
|
||||
if prop.lhs.is_imaginary or prop.rhs.is_imaginary:
|
||||
raise UnhandledInput(f"{prop} contains an imaginary component")
|
||||
if prop.lhs == oo or prop.rhs == oo:
|
||||
raise UnhandledInput(f"{prop} contains infinity")
|
||||
|
||||
prop = _eval_binrel(prop) # simplify variable-less quantities to True / False if possible
|
||||
if prop == True:
|
||||
conflicts.append([enc])
|
||||
continue
|
||||
elif prop == False:
|
||||
conflicts.append([-enc])
|
||||
continue
|
||||
elif prop is None:
|
||||
raise UnhandledInput(f"{prop} could not be simplified")
|
||||
|
||||
expr = prop.lhs - prop.rhs
|
||||
if prop.function in [Q.ge, Q.gt]:
|
||||
expr = -expr
|
||||
|
||||
# expr should be less than (or equal to) 0
|
||||
# otherwise prop is False
|
||||
if prop.function in [Q.le, Q.ge]:
|
||||
bool = (expr <= 0)
|
||||
elif prop.function in [Q.lt, Q.gt]:
|
||||
bool = (expr < 0)
|
||||
else:
|
||||
assert prop.function == Q.eq
|
||||
bool = Eq(expr, 0)
|
||||
|
||||
if bool == True:
|
||||
conflicts.append([enc])
|
||||
continue
|
||||
elif bool == False:
|
||||
conflicts.append([-enc])
|
||||
continue
|
||||
|
||||
|
||||
vars, const = _sep_const_terms(expr) # example: (2x + 3y + 2) --> (2x + 3y), (2)
|
||||
vars, var_coeff = _sep_const_coeff(vars) # examples: (2x) --> (x, 2); (2x + 3y) --> (2x + 3y), (1)
|
||||
const = const / var_coeff
|
||||
|
||||
terms = _list_terms(vars) # example: (2x + 3y) --> [2x, 3y]
|
||||
for term in terms:
|
||||
term, _ = _sep_const_coeff(term)
|
||||
assert len(term.free_symbols) > 0
|
||||
if term not in var_to_lra_var:
|
||||
var_to_lra_var[term] = LRAVariable(term)
|
||||
nonbasic.append(term)
|
||||
|
||||
if len(terms) > 1:
|
||||
if vars not in s_subs:
|
||||
s_count += 1
|
||||
d = Dummy(f"s{s_count}")
|
||||
var_to_lra_var[d] = LRAVariable(d)
|
||||
basic.append(d)
|
||||
s_subs[vars] = d
|
||||
A.append(vars - d)
|
||||
var = s_subs[vars]
|
||||
else:
|
||||
var = terms[0]
|
||||
|
||||
assert var_coeff != 0
|
||||
|
||||
equality = prop.function == Q.eq
|
||||
upper = var_coeff > 0 if not equality else None
|
||||
strict = prop.function in [Q.gt, Q.lt]
|
||||
b = Boundary(var_to_lra_var[var], -const, upper, equality, strict)
|
||||
encoding[enc] = b
|
||||
|
||||
fs = [v.free_symbols for v in nonbasic + basic]
|
||||
assert all(len(syms) > 0 for syms in fs)
|
||||
fs_count = sum(len(syms) for syms in fs)
|
||||
if len(fs) > 0 and len(set.union(*fs)) < fs_count:
|
||||
raise UnhandledInput("Nonlinearity is not handled")
|
||||
|
||||
A, _ = linear_eq_to_matrix(A, nonbasic + basic)
|
||||
nonbasic = [var_to_lra_var[nb] for nb in nonbasic]
|
||||
basic = [var_to_lra_var[b] for b in basic]
|
||||
for idx, var in enumerate(nonbasic + basic):
|
||||
var.col_idx = idx
|
||||
|
||||
return LRASolver(A, basic, nonbasic, encoding, s_subs, testing_mode), conflicts
|
||||
|
||||
def reset_bounds(self):
|
||||
"""
|
||||
Resets the state of the LRASolver to before
|
||||
anything was asserted.
|
||||
"""
|
||||
self.result = None
|
||||
for var in self.all_var:
|
||||
var.lower = LRARational(-float("inf"), 0)
|
||||
var.lower_from_eq = False
|
||||
var.lower_from_neg = False
|
||||
var.upper = LRARational(float("inf"), 0)
|
||||
var.upper_from_eq= False
|
||||
var.lower_from_neg = False
|
||||
var.assign = LRARational(0, 0)
|
||||
|
||||
def assert_lit(self, enc_constraint):
|
||||
"""
|
||||
Assert a literal representing a constraint
|
||||
and update the internal state accordingly.
|
||||
|
||||
Note that due to peculiarities of this implementation
|
||||
asserting ~(x > 0) will assert (x <= 0) but asserting
|
||||
~Eq(x, 0) will not do anything.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
enc_constraint : int
|
||||
A mapping of encodings to constraints
|
||||
can be found in `self.enc_to_boundary`.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
None or (False, explanation)
|
||||
|
||||
explanation : set of ints
|
||||
A conflict clause that "explains" why
|
||||
the literals asserted so far are unsatisfiable.
|
||||
"""
|
||||
if abs(enc_constraint) not in self.enc_to_boundary:
|
||||
return None
|
||||
|
||||
if not HANDLE_NEGATION and enc_constraint < 0:
|
||||
return None
|
||||
|
||||
boundary = self.enc_to_boundary[abs(enc_constraint)]
|
||||
sym, c, negated = boundary.var, boundary.bound, enc_constraint < 0
|
||||
|
||||
if boundary.equality and negated:
|
||||
return None # negated equality is not handled and should only appear in conflict clauses
|
||||
|
||||
upper = boundary.upper != negated
|
||||
if boundary.strict != negated:
|
||||
delta = -1 if upper else 1
|
||||
c = LRARational(c, delta)
|
||||
else:
|
||||
c = LRARational(c, 0)
|
||||
|
||||
if boundary.equality:
|
||||
res1 = self._assert_lower(sym, c, from_equality=True, from_neg=negated)
|
||||
if res1 and res1[0] == False:
|
||||
res = res1
|
||||
else:
|
||||
res2 = self._assert_upper(sym, c, from_equality=True, from_neg=negated)
|
||||
res = res2
|
||||
elif upper:
|
||||
res = self._assert_upper(sym, c, from_neg=negated)
|
||||
else:
|
||||
res = self._assert_lower(sym, c, from_neg=negated)
|
||||
|
||||
if self.is_sat and sym not in self.slack_set:
|
||||
self.is_sat = res is None
|
||||
else:
|
||||
self.is_sat = False
|
||||
|
||||
return res
|
||||
|
||||
def _assert_upper(self, xi, ci, from_equality=False, from_neg=False):
|
||||
"""
|
||||
Adjusts the upper bound on variable xi if the new upper bound is
|
||||
more limiting. The assignment of variable xi is adjusted to be
|
||||
within the new bound if needed.
|
||||
|
||||
Also calls `self._update` to update the assignment for slack variables
|
||||
to keep all equalities satisfied.
|
||||
"""
|
||||
if self.result:
|
||||
assert self.result[0] != False
|
||||
self.result = None
|
||||
if ci >= xi.upper:
|
||||
return None
|
||||
if ci < xi.lower:
|
||||
assert (xi.lower[1] >= 0) is True
|
||||
assert (ci[1] <= 0) is True
|
||||
|
||||
lit1, neg1 = Boundary.from_lower(xi)
|
||||
|
||||
lit2 = Boundary(var=xi, const=ci[0], strict=ci[1] != 0, upper=True, equality=from_equality)
|
||||
if from_neg:
|
||||
lit2 = lit2.get_negated()
|
||||
neg2 = -1 if from_neg else 1
|
||||
|
||||
conflict = [-neg1*self.boundary_to_enc[lit1], -neg2*self.boundary_to_enc[lit2]]
|
||||
self.result = False, conflict
|
||||
return self.result
|
||||
xi.upper = ci
|
||||
xi.upper_from_eq = from_equality
|
||||
xi.upper_from_neg = from_neg
|
||||
if xi in self.nonslack and xi.assign > ci:
|
||||
self._update(xi, ci)
|
||||
|
||||
if self.run_checks and all(v.assign[0] != float("inf") and v.assign[0] != -float("inf")
|
||||
for v in self.all_var):
|
||||
M = self.A
|
||||
X = Matrix([v.assign[0] for v in self.all_var])
|
||||
assert all(abs(val) < 10 ** (-10) for val in M * X)
|
||||
|
||||
return None
|
||||
|
||||
def _assert_lower(self, xi, ci, from_equality=False, from_neg=False):
|
||||
"""
|
||||
Adjusts the lower bound on variable xi if the new lower bound is
|
||||
more limiting. The assignment of variable xi is adjusted to be
|
||||
within the new bound if needed.
|
||||
|
||||
Also calls `self._update` to update the assignment for slack variables
|
||||
to keep all equalities satisfied.
|
||||
"""
|
||||
if self.result:
|
||||
assert self.result[0] != False
|
||||
self.result = None
|
||||
if ci <= xi.lower:
|
||||
return None
|
||||
if ci > xi.upper:
|
||||
assert (xi.upper[1] <= 0) is True
|
||||
assert (ci[1] >= 0) is True
|
||||
|
||||
lit1, neg1 = Boundary.from_upper(xi)
|
||||
|
||||
lit2 = Boundary(var=xi, const=ci[0], strict=ci[1] != 0, upper=False, equality=from_equality)
|
||||
if from_neg:
|
||||
lit2 = lit2.get_negated()
|
||||
neg2 = -1 if from_neg else 1
|
||||
|
||||
conflict = [-neg1*self.boundary_to_enc[lit1],-neg2*self.boundary_to_enc[lit2]]
|
||||
self.result = False, conflict
|
||||
return self.result
|
||||
xi.lower = ci
|
||||
xi.lower_from_eq = from_equality
|
||||
xi.lower_from_neg = from_neg
|
||||
if xi in self.nonslack and xi.assign < ci:
|
||||
self._update(xi, ci)
|
||||
|
||||
if self.run_checks and all(v.assign[0] != float("inf") and v.assign[0] != -float("inf")
|
||||
for v in self.all_var):
|
||||
M = self.A
|
||||
X = Matrix([v.assign[0] for v in self.all_var])
|
||||
assert all(abs(val) < 10 ** (-10) for val in M * X)
|
||||
|
||||
return None
|
||||
|
||||
def _update(self, xi, v):
|
||||
"""
|
||||
Updates all slack variables that have equations that contain
|
||||
variable xi so that they stay satisfied given xi is equal to v.
|
||||
"""
|
||||
i = xi.col_idx
|
||||
for j, b in enumerate(self.slack):
|
||||
aji = self.A[j, i]
|
||||
b.assign = b.assign + (v - xi.assign)*aji
|
||||
xi.assign = v
|
||||
|
||||
def check(self):
|
||||
"""
|
||||
Searches for an assignment that satisfies all constraints
|
||||
or determines that no such assignment exists and gives
|
||||
a minimal conflict clause that "explains" why the
|
||||
constraints are unsatisfiable.
|
||||
|
||||
Returns
|
||||
=======
|
||||
|
||||
(True, assignment) or (False, explanation)
|
||||
|
||||
assignment : dict of LRAVariables to values
|
||||
Assigned values are tuples that represent a rational number
|
||||
plus some infinatesimal delta.
|
||||
|
||||
explanation : set of ints
|
||||
"""
|
||||
if self.is_sat:
|
||||
return True, {var: var.assign for var in self.all_var}
|
||||
if self.result:
|
||||
return self.result
|
||||
|
||||
from sympy.matrices.dense import Matrix
|
||||
M = self.A.copy()
|
||||
basic = {s: i for i, s in enumerate(self.slack)} # contains the row index associated with each basic variable
|
||||
nonbasic = set(self.nonslack)
|
||||
while True:
|
||||
if self.run_checks:
|
||||
# nonbasic variables must always be within bounds
|
||||
assert all(((nb.assign >= nb.lower) == True) and ((nb.assign <= nb.upper) == True) for nb in nonbasic)
|
||||
|
||||
# assignments for x must always satisfy Ax = 0
|
||||
# probably have to turn this off when dealing with strict ineq
|
||||
if all(v.assign[0] != float("inf") and v.assign[0] != -float("inf")
|
||||
for v in self.all_var):
|
||||
X = Matrix([v.assign[0] for v in self.all_var])
|
||||
assert all(abs(val) < 10**(-10) for val in M*X)
|
||||
|
||||
# check upper and lower match this format:
|
||||
# x <= rat + delta iff x < rat
|
||||
# x >= rat - delta iff x > rat
|
||||
# this wouldn't make sense:
|
||||
# x <= rat - delta
|
||||
# x >= rat + delta
|
||||
assert all(x.upper[1] <= 0 for x in self.all_var)
|
||||
assert all(x.lower[1] >= 0 for x in self.all_var)
|
||||
|
||||
cand = [b for b in basic if b.assign < b.lower or b.assign > b.upper]
|
||||
|
||||
if len(cand) == 0:
|
||||
return True, {var: var.assign for var in self.all_var}
|
||||
|
||||
xi = min(cand, key=lambda v: v.col_idx) # Bland's rule
|
||||
i = basic[xi]
|
||||
|
||||
if xi.assign < xi.lower:
|
||||
cand = [nb for nb in nonbasic
|
||||
if (M[i, nb.col_idx] > 0 and nb.assign < nb.upper)
|
||||
or (M[i, nb.col_idx] < 0 and nb.assign > nb.lower)]
|
||||
if len(cand) == 0:
|
||||
N_plus = [nb for nb in nonbasic if M[i, nb.col_idx] > 0]
|
||||
N_minus = [nb for nb in nonbasic if M[i, nb.col_idx] < 0]
|
||||
|
||||
conflict = []
|
||||
conflict += [Boundary.from_upper(nb) for nb in N_plus]
|
||||
conflict += [Boundary.from_lower(nb) for nb in N_minus]
|
||||
conflict.append(Boundary.from_lower(xi))
|
||||
conflict = [-neg*self.boundary_to_enc[c] for c, neg in conflict]
|
||||
return False, conflict
|
||||
xj = min(cand, key=str)
|
||||
M = self._pivot_and_update(M, basic, nonbasic, xi, xj, xi.lower)
|
||||
|
||||
if xi.assign > xi.upper:
|
||||
cand = [nb for nb in nonbasic
|
||||
if (M[i, nb.col_idx] < 0 and nb.assign < nb.upper)
|
||||
or (M[i, nb.col_idx] > 0 and nb.assign > nb.lower)]
|
||||
|
||||
if len(cand) == 0:
|
||||
N_plus = [nb for nb in nonbasic if M[i, nb.col_idx] > 0]
|
||||
N_minus = [nb for nb in nonbasic if M[i, nb.col_idx] < 0]
|
||||
|
||||
conflict = []
|
||||
conflict += [Boundary.from_upper(nb) for nb in N_minus]
|
||||
conflict += [Boundary.from_lower(nb) for nb in N_plus]
|
||||
conflict.append(Boundary.from_upper(xi))
|
||||
|
||||
conflict = [-neg*self.boundary_to_enc[c] for c, neg in conflict]
|
||||
return False, conflict
|
||||
xj = min(cand, key=lambda v: v.col_idx)
|
||||
M = self._pivot_and_update(M, basic, nonbasic, xi, xj, xi.upper)
|
||||
|
||||
def _pivot_and_update(self, M, basic, nonbasic, xi, xj, v):
|
||||
"""
|
||||
Pivots basic variable xi with nonbasic variable xj,
|
||||
and sets value of xi to v and adjusts the values of all basic variables
|
||||
to keep equations satisfied.
|
||||
"""
|
||||
i, j = basic[xi], xj.col_idx
|
||||
assert M[i, j] != 0
|
||||
theta = (v - xi.assign)*(1/M[i, j])
|
||||
xi.assign = v
|
||||
xj.assign = xj.assign + theta
|
||||
for xk in basic:
|
||||
if xk != xi:
|
||||
k = basic[xk]
|
||||
akj = M[k, j]
|
||||
xk.assign = xk.assign + theta*akj
|
||||
# pivot
|
||||
basic[xj] = basic[xi]
|
||||
del basic[xi]
|
||||
nonbasic.add(xi)
|
||||
nonbasic.remove(xj)
|
||||
return self._pivot(M, i, j)
|
||||
|
||||
@staticmethod
|
||||
def _pivot(M, i, j):
|
||||
"""
|
||||
Performs a pivot operation about entry i, j of M by performing
|
||||
a series of row operations on a copy of M and returning the result.
|
||||
The original M is left unmodified.
|
||||
|
||||
Conceptually, M represents a system of equations and pivoting
|
||||
can be thought of as rearranging equation i to be in terms of
|
||||
variable j and then substituting in the rest of the equations
|
||||
to get rid of other occurances of variable j.
|
||||
|
||||
Example
|
||||
=======
|
||||
|
||||
>>> from sympy.matrices.dense import Matrix
|
||||
>>> from sympy.logic.algorithms.lra_theory import LRASolver
|
||||
>>> from sympy import var
|
||||
>>> Matrix(3, 3, var('a:i'))
|
||||
Matrix([
|
||||
[a, b, c],
|
||||
[d, e, f],
|
||||
[g, h, i]])
|
||||
|
||||
This matrix is equivalent to:
|
||||
0 = a*x + b*y + c*z
|
||||
0 = d*x + e*y + f*z
|
||||
0 = g*x + h*y + i*z
|
||||
|
||||
>>> LRASolver._pivot(_, 1, 0)
|
||||
Matrix([
|
||||
[ 0, -a*e/d + b, -a*f/d + c],
|
||||
[-1, -e/d, -f/d],
|
||||
[ 0, h - e*g/d, i - f*g/d]])
|
||||
|
||||
We rearrange equation 1 in terms of variable 0 (x)
|
||||
and substitute to remove x from the other equations.
|
||||
|
||||
0 = 0 + (-a*e/d + b)*y + (-a*f/d + c)*z
|
||||
0 = -x + (-e/d)*y + (-f/d)*z
|
||||
0 = 0 + (h - e*g/d)*y + (i - f*g/d)*z
|
||||
"""
|
||||
_, _, Mij = M[i, :], M[:, j], M[i, j]
|
||||
if Mij == 0:
|
||||
raise ZeroDivisionError("Tried to pivot about zero-valued entry.")
|
||||
A = M.copy()
|
||||
A[i, :] = -A[i, :]/Mij
|
||||
for row in range(M.shape[0]):
|
||||
if row != i:
|
||||
A[row, :] = A[row, :] + A[row, j] * A[i, :]
|
||||
|
||||
return A
|
||||
|
||||
|
||||
def _sep_const_coeff(expr):
|
||||
"""
|
||||
Example
|
||||
=======
|
||||
|
||||
>>> from sympy.logic.algorithms.lra_theory import _sep_const_coeff
|
||||
>>> from sympy.abc import x, y
|
||||
>>> _sep_const_coeff(2*x)
|
||||
(x, 2)
|
||||
>>> _sep_const_coeff(2*x + 3*y)
|
||||
(2*x + 3*y, 1)
|
||||
"""
|
||||
if isinstance(expr, Add):
|
||||
return expr, sympify(1)
|
||||
|
||||
if isinstance(expr, Mul):
|
||||
coeffs = expr.args
|
||||
else:
|
||||
coeffs = [expr]
|
||||
|
||||
var, const = [], []
|
||||
for c in coeffs:
|
||||
c = sympify(c)
|
||||
if len(c.free_symbols)==0:
|
||||
const.append(c)
|
||||
else:
|
||||
var.append(c)
|
||||
return Mul(*var), Mul(*const)
|
||||
|
||||
|
||||
def _list_terms(expr):
|
||||
if not isinstance(expr, Add):
|
||||
return [expr]
|
||||
|
||||
return expr.args
|
||||
|
||||
|
||||
def _sep_const_terms(expr):
|
||||
"""
|
||||
Example
|
||||
=======
|
||||
|
||||
>>> from sympy.logic.algorithms.lra_theory import _sep_const_terms
|
||||
>>> from sympy.abc import x, y
|
||||
>>> _sep_const_terms(2*x + 3*y + 2)
|
||||
(2*x + 3*y, 2)
|
||||
"""
|
||||
if isinstance(expr, Add):
|
||||
terms = expr.args
|
||||
else:
|
||||
terms = [expr]
|
||||
|
||||
var, const = [], []
|
||||
for t in terms:
|
||||
if len(t.free_symbols) == 0:
|
||||
const.append(t)
|
||||
else:
|
||||
var.append(t)
|
||||
return sum(var), sum(const)
|
||||
|
||||
|
||||
def _eval_binrel(binrel):
|
||||
"""
|
||||
Simplify binary relation to True / False if possible.
|
||||
"""
|
||||
if not (len(binrel.lhs.free_symbols) == 0 and len(binrel.rhs.free_symbols) == 0):
|
||||
return binrel
|
||||
if binrel.function == Q.lt:
|
||||
res = binrel.lhs < binrel.rhs
|
||||
elif binrel.function == Q.gt:
|
||||
res = binrel.lhs > binrel.rhs
|
||||
elif binrel.function == Q.le:
|
||||
res = binrel.lhs <= binrel.rhs
|
||||
elif binrel.function == Q.ge:
|
||||
res = binrel.lhs >= binrel.rhs
|
||||
elif binrel.function == Q.eq:
|
||||
res = Eq(binrel.lhs, binrel.rhs)
|
||||
elif binrel.function == Q.ne:
|
||||
res = Ne(binrel.lhs, binrel.rhs)
|
||||
|
||||
if res == True or res == False:
|
||||
return res
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class Boundary:
|
||||
"""
|
||||
Represents an upper or lower bound or an equality between a symbol
|
||||
and some constant.
|
||||
"""
|
||||
def __init__(self, var, const, upper, equality, strict=None):
|
||||
if not equality in [True, False]:
|
||||
assert equality in [True, False]
|
||||
|
||||
|
||||
self.var = var
|
||||
if isinstance(const, tuple):
|
||||
s = const[1] != 0
|
||||
if strict:
|
||||
assert s == strict
|
||||
self.bound = const[0]
|
||||
self.strict = s
|
||||
else:
|
||||
self.bound = const
|
||||
self.strict = strict
|
||||
self.upper = upper if not equality else None
|
||||
self.equality = equality
|
||||
self.strict = strict
|
||||
assert self.strict is not None
|
||||
|
||||
@staticmethod
|
||||
def from_upper(var):
|
||||
neg = -1 if var.upper_from_neg else 1
|
||||
b = Boundary(var, var.upper[0], True, var.upper_from_eq, var.upper[1] != 0)
|
||||
if neg < 0:
|
||||
b = b.get_negated()
|
||||
return b, neg
|
||||
|
||||
@staticmethod
|
||||
def from_lower(var):
|
||||
neg = -1 if var.lower_from_neg else 1
|
||||
b = Boundary(var, var.lower[0], False, var.lower_from_eq, var.lower[1] != 0)
|
||||
if neg < 0:
|
||||
b = b.get_negated()
|
||||
return b, neg
|
||||
|
||||
def get_negated(self):
|
||||
return Boundary(self.var, self.bound, not self.upper, self.equality, not self.strict)
|
||||
|
||||
def get_inequality(self):
|
||||
if self.equality:
|
||||
return Eq(self.var.var, self.bound)
|
||||
elif self.upper and self.strict:
|
||||
return self.var.var < self.bound
|
||||
elif not self.upper and self.strict:
|
||||
return self.var.var > self.bound
|
||||
elif self.upper:
|
||||
return self.var.var <= self.bound
|
||||
else:
|
||||
return self.var.var >= self.bound
|
||||
|
||||
def __repr__(self):
|
||||
return repr("Boundary(" + repr(self.get_inequality()) + ")")
|
||||
|
||||
def __eq__(self, other):
|
||||
other = (other.var, other.bound, other.strict, other.upper, other.equality)
|
||||
return (self.var, self.bound, self.strict, self.upper, self.equality) == other
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.var, self.bound, self.strict, self.upper, self.equality))
|
||||
|
||||
|
||||
class LRARational():
|
||||
"""
|
||||
Represents a rational plus or minus some amount
|
||||
of arbitrary small deltas.
|
||||
"""
|
||||
def __init__(self, rational, delta):
|
||||
self.value = (rational, delta)
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.value < other.value
|
||||
|
||||
def __le__(self, other):
|
||||
return self.value <= other.value
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.value == other.value
|
||||
|
||||
def __add__(self, other):
|
||||
return LRARational(self.value[0] + other.value[0], self.value[1] + other.value[1])
|
||||
|
||||
def __sub__(self, other):
|
||||
return LRARational(self.value[0] - other.value[0], self.value[1] - other.value[1])
|
||||
|
||||
def __mul__(self, other):
|
||||
assert not isinstance(other, LRARational)
|
||||
return LRARational(self.value[0] * other, self.value[1] * other)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.value[index]
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.value)
|
||||
|
||||
|
||||
class LRAVariable():
|
||||
"""
|
||||
Object to keep track of upper and lower bounds
|
||||
on `self.var`.
|
||||
"""
|
||||
def __init__(self, var):
|
||||
self.upper = LRARational(float("inf"), 0)
|
||||
self.upper_from_eq = False
|
||||
self.upper_from_neg = False
|
||||
self.lower = LRARational(-float("inf"), 0)
|
||||
self.lower_from_eq = False
|
||||
self.lower_from_neg = False
|
||||
self.assign = LRARational(0,0)
|
||||
self.var = var
|
||||
self.col_idx = None
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.var)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, LRAVariable):
|
||||
return False
|
||||
return other.var == self.var
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.var)
|
||||
@@ -0,0 +1,46 @@
|
||||
from sympy.assumptions.cnf import EncodedCNF
|
||||
|
||||
def minisat22_satisfiable(expr, all_models=False, minimal=False):
|
||||
|
||||
if not isinstance(expr, EncodedCNF):
|
||||
exprs = EncodedCNF()
|
||||
exprs.add_prop(expr)
|
||||
expr = exprs
|
||||
|
||||
from pysat.solvers import Minisat22
|
||||
|
||||
# Return UNSAT when False (encoded as 0) is present in the CNF
|
||||
if {0} in expr.data:
|
||||
if all_models:
|
||||
return (f for f in [False])
|
||||
return False
|
||||
|
||||
r = Minisat22(expr.data)
|
||||
|
||||
if minimal:
|
||||
r.set_phases([-(i+1) for i in range(r.nof_vars())])
|
||||
|
||||
if not r.solve():
|
||||
return False
|
||||
|
||||
if not all_models:
|
||||
return {expr.symbols[abs(lit) - 1]: lit > 0 for lit in r.get_model()}
|
||||
|
||||
else:
|
||||
# Make solutions SymPy compatible by creating a generator
|
||||
def _gen(results):
|
||||
satisfiable = False
|
||||
while results.solve():
|
||||
sol = results.get_model()
|
||||
yield {expr.symbols[abs(lit) - 1]: lit > 0 for lit in sol}
|
||||
if minimal:
|
||||
results.add_clause([-i for i in sol if i>0])
|
||||
else:
|
||||
results.add_clause([-i for i in sol])
|
||||
satisfiable = True
|
||||
if not satisfiable:
|
||||
yield False
|
||||
raise StopIteration
|
||||
|
||||
|
||||
return _gen(r)
|
||||
@@ -0,0 +1,41 @@
|
||||
from sympy.assumptions.cnf import EncodedCNF
|
||||
|
||||
|
||||
def pycosat_satisfiable(expr, all_models=False):
|
||||
import pycosat
|
||||
if not isinstance(expr, EncodedCNF):
|
||||
exprs = EncodedCNF()
|
||||
exprs.add_prop(expr)
|
||||
expr = exprs
|
||||
|
||||
# Return UNSAT when False (encoded as 0) is present in the CNF
|
||||
if {0} in expr.data:
|
||||
if all_models:
|
||||
return (f for f in [False])
|
||||
return False
|
||||
|
||||
if not all_models:
|
||||
r = pycosat.solve(expr.data)
|
||||
result = (r != "UNSAT")
|
||||
if not result:
|
||||
return result
|
||||
return {expr.symbols[abs(lit) - 1]: lit > 0 for lit in r}
|
||||
else:
|
||||
r = pycosat.itersolve(expr.data)
|
||||
result = (r != "UNSAT")
|
||||
if not result:
|
||||
return result
|
||||
|
||||
# Make solutions SymPy compatible by creating a generator
|
||||
def _gen(results):
|
||||
satisfiable = False
|
||||
try:
|
||||
while True:
|
||||
sol = next(results)
|
||||
yield {expr.symbols[abs(lit) - 1]: lit > 0 for lit in sol}
|
||||
satisfiable = True
|
||||
except StopIteration:
|
||||
if not satisfiable:
|
||||
yield False
|
||||
|
||||
return _gen(r)
|
||||
@@ -0,0 +1,115 @@
|
||||
from sympy.printing.smtlib import smtlib_code
|
||||
from sympy.assumptions.assume import AppliedPredicate
|
||||
from sympy.assumptions.cnf import EncodedCNF
|
||||
from sympy.assumptions.ask import Q
|
||||
|
||||
from sympy.core import Add, Mul
|
||||
from sympy.core.relational import Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.functions.elementary.exponential import Pow
|
||||
from sympy.functions.elementary.miscellaneous import Min, Max
|
||||
from sympy.logic.boolalg import And, Or, Xor, Implies
|
||||
from sympy.logic.boolalg import Not, ITE
|
||||
from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate
|
||||
from sympy.external import import_module
|
||||
|
||||
def z3_satisfiable(expr, all_models=False):
|
||||
if not isinstance(expr, EncodedCNF):
|
||||
exprs = EncodedCNF()
|
||||
exprs.add_prop(expr)
|
||||
expr = exprs
|
||||
|
||||
z3 = import_module("z3")
|
||||
if z3 is None:
|
||||
raise ImportError("z3 is not installed")
|
||||
|
||||
s = encoded_cnf_to_z3_solver(expr, z3)
|
||||
|
||||
res = str(s.check())
|
||||
if res == "unsat":
|
||||
return False
|
||||
elif res == "sat":
|
||||
return z3_model_to_sympy_model(s.model(), expr)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def z3_model_to_sympy_model(z3_model, enc_cnf):
|
||||
rev_enc = {value : key for key, value in enc_cnf.encoding.items()}
|
||||
return {rev_enc[int(var.name()[1:])] : bool(z3_model[var]) for var in z3_model}
|
||||
|
||||
|
||||
def clause_to_assertion(clause):
|
||||
clause_strings = [f"d{abs(lit)}" if lit > 0 else f"(not d{abs(lit)})" for lit in clause]
|
||||
return "(assert (or " + " ".join(clause_strings) + "))"
|
||||
|
||||
|
||||
def encoded_cnf_to_z3_solver(enc_cnf, z3):
|
||||
def dummify_bool(pred):
|
||||
return False
|
||||
assert isinstance(pred, AppliedPredicate)
|
||||
|
||||
if pred.function in [Q.positive, Q.negative, Q.zero]:
|
||||
return pred
|
||||
else:
|
||||
return False
|
||||
|
||||
s = z3.Solver()
|
||||
|
||||
declarations = [f"(declare-const d{var} Bool)" for var in enc_cnf.variables]
|
||||
assertions = [clause_to_assertion(clause) for clause in enc_cnf.data]
|
||||
|
||||
symbols = set()
|
||||
for pred, enc in enc_cnf.encoding.items():
|
||||
if not isinstance(pred, AppliedPredicate):
|
||||
continue
|
||||
if pred.function not in (Q.gt, Q.lt, Q.ge, Q.le, Q.ne, Q.eq, Q.positive, Q.negative, Q.extended_negative, Q.extended_positive, Q.zero, Q.nonzero, Q.nonnegative, Q.nonpositive, Q.extended_nonzero, Q.extended_nonnegative, Q.extended_nonpositive):
|
||||
continue
|
||||
|
||||
pred_str = smtlib_code(pred, auto_declare=False, auto_assert=False, known_functions=known_functions)
|
||||
|
||||
symbols |= pred.free_symbols
|
||||
pred = pred_str
|
||||
clause = f"(implies d{enc} {pred})"
|
||||
assertion = "(assert " + clause + ")"
|
||||
assertions.append(assertion)
|
||||
|
||||
for sym in symbols:
|
||||
declarations.append(f"(declare-const {sym} Real)")
|
||||
|
||||
declarations = "\n".join(declarations)
|
||||
assertions = "\n".join(assertions)
|
||||
s.from_string(declarations)
|
||||
s.from_string(assertions)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
known_functions = {
|
||||
Add: '+',
|
||||
Mul: '*',
|
||||
|
||||
Equality: '=',
|
||||
LessThan: '<=',
|
||||
GreaterThan: '>=',
|
||||
StrictLessThan: '<',
|
||||
StrictGreaterThan: '>',
|
||||
|
||||
EqualityPredicate(): '=',
|
||||
LessThanPredicate(): '<=',
|
||||
GreaterThanPredicate(): '>=',
|
||||
StrictLessThanPredicate(): '<',
|
||||
StrictGreaterThanPredicate(): '>',
|
||||
|
||||
Abs: 'abs',
|
||||
Min: 'min',
|
||||
Max: 'max',
|
||||
Pow: '^',
|
||||
|
||||
And: 'and',
|
||||
Or: 'or',
|
||||
Xor: 'xor',
|
||||
Not: 'not',
|
||||
ITE: 'ite',
|
||||
Implies: '=>',
|
||||
}
|
||||
Reference in New Issue
Block a user