M2_SETI/IA/seti_master-master/code/formula_lang.py
2023-01-29 16:56:40 +01:00

275 lines
8.5 KiB
Python

"""A small module to express formulas and write them in SMTLIB2,
Marabou input format, PyRAT input format. Conjunction of linear constraints,
with logical operators if necessary"""
from fractions import Fraction
import re
# Named aliases for network variables (inputs x0-x2, output y0).
DIST = 'x0'
SPEED = 'x1'
ANGLE = 'x2'
OUTPUT = 'y0'

# Comparison operator symbols. Constr only accepts GEQ and LT.
GEQ = '>='
GT = '>'
# Historical misspelling of GT, kept so existing imports keep working.
GQ = GT
LT = '<'

# Logical operators allowed in a compound clause (constr, op, constr).
AND = 'and'
OR = 'or'
class Atom():
    """Base class for the operands of a constraint.

    Concrete subclasses set exactly one of ``var`` (a variable name, see
    Var) or ``real`` (a numeric constant, see Real).
    """

    def __init__(self):
        self.var = None
        self.real = None

    def is_real(self):
        """Return True if this atom holds a numeric constant."""
        return self.real is not None

    def is_var(self):
        """Return True if this atom holds a variable name."""
        return self.var is not None


class Real(Atom):
    """A real-valued constant (int or float)."""

    def __init__(self, x):
        """Wrap the number x.

        Raises:
            ValueError: if x is neither an int nor a float.
        """
        super().__init__()
        if not isinstance(x, (float, int)):
            raise ValueError('Error, Real can only be of int or float')
        self.real = x

    def __str__(self):
        return str(self.real)

    def to_smtlib(self):
        """Return the constant as an SMT-LIB rational term ``(/ num den)``.

        Floats are converted through their shortest decimal repr, so 0.6
        becomes ``(/ 3 5)``. This is the same value the Marabou and PyRAT
        writers print, rather than the exact binary expansion of the
        double. SMT-LIB numerals are non-negative, so a negative value is
        written with unary minus, e.g. ``(- (/ 1 2))``.
        """
        if isinstance(self.real, int):
            frac = Fraction(self.real)
        else:
            frac = Fraction(repr(self.real))
        magnitude = '(/ {} {})'.format(abs(frac.numerator), frac.denominator)
        if frac < 0:
            return '(- {})'.format(magnitude)
        return magnitude

    def to_marabou(self):
        return ' {} '.format(self.real)

    def to_pyrat(self):
        return ' {} '.format(self.real)


class Var(Atom):
    """A network variable named ``x<d>`` (input) or ``y<d>`` (output)."""

    def __init__(self, x):
        """Wrap the variable name x.

        Raises:
            ValueError: if x is not a string of the form x[0-9] or y[0-9].
        """
        super().__init__()
        # fullmatch (unlike search with '$') rejects a trailing newline.
        if not (isinstance(x, str) and re.fullmatch(r'[xy][0-9]', x)):
            raise ValueError("Input should be of the form x[0-9] or y[0-9]")
        self.var = x

    def __str__(self):
        return str(self.var)

    def to_smtlib(self):
        """Return the SMT-LIB cell name for this variable.

        ``x<d>`` maps to ``|CELL_actual_input_0_0_<d>|``. ``y<d>`` maps to
        ``|CELL_actual_output_0_0_<d>|``, following the same indexing
        pattern as the input cells.
        """
        kind = 'input' if self.var[0] == 'x' else 'output'
        return '|CELL_actual_{}_0_0_{}|'.format(kind, self.var[1])

    def to_marabou(self):
        return self.var

    def to_pyrat(self):
        return self.var
class Constr():
    """A linear constraint ``a bop b`` between two atoms."""

    def __init__(self, a, bop, b):
        """Build the constraint ``a bop b``.

        a and b must be Atom instances (built with Var or Real); bop must be
        GEQ ('>=') or LT ('<').

        Raises:
            ValueError: if an operand is not an Atom or bop is unsupported.
        """
        if not (isinstance(a, Atom) and isinstance(b, Atom)):
            raise ValueError('Error, either a or b are not atoms. '
                             'Create atoms using Var or Real constructors.')
        if bop not in (GEQ, LT):
            raise ValueError('Error, second argument should be either >= or <')
        self.constr = (a, bop, b)

    def __str__(self):
        a, bop, b = self.constr
        return '{} {} {}'.format(str(a), bop, str(b))

    def _solver_operator(self):
        """Return the operator to emit for solver formats (SMTLIB, Marabou).

        When either operand mentions an output variable ("y" in its string
        form), the operator is flipped (>= <-> <). Presumably this is so
        the solver searches for a violation of the output property; TODO
        confirm with the verification setup.
        """
        a, bop, b = self.constr
        if "y" in str(a) or "y" in str(b):
            return GEQ if bop == LT else LT
        return bop

    def to_smtlib(self):
        """Write the constraint to SMTLIB format, e.g. ``(>= x c)``."""
        a, _, b = self.constr
        return '({} {} {})'.format(self._solver_operator(),
                                   a.to_smtlib(), b.to_smtlib())

    def to_marabou(self):
        """Write the constraint to Marabou format, e.g. ``x0 >= c``."""
        a, _, b = self.constr
        bop = self._solver_operator()
        # Marabou has no strict inequality, so '<' is relaxed to '<='.
        if bop == LT:
            bop = '<='
        return '{} {} {}'.format(a.to_marabou(), bop, b.to_marabou())

    def to_pyrat(self):
        """Write the constraint to PyRAT format, keeping the operator as-is."""
        a, bop, b = self.constr
        return '{} {} {}'.format(a.to_pyrat(), bop, b.to_pyrat())
def well_formed(c):
    """Check whether c is a well-formed compound clause.

    A compound clause is a 3-item sequence ``(constr, AND|OR, constr)``
    whose first and last items are Constr instances.

    Returns:
        True if c is well formed, otherwise False. This includes the case
        where c is not a 3-item sequence. An unsupported operator
        additionally prints an error message.
    """
    try:
        cx, lop, cy = c
    except (TypeError, ValueError):
        # Not iterable, or not exactly three items.
        return False
    if lop not in (AND, OR):
        print('Error, second argument should be AND | OR')
        return False
    return isinstance(cx, Constr) and isinstance(cy, Constr)
class Formula():
    """A conjunction of clauses that can be exported to several formats.

    Each stored clause is either a single ``Constr`` or a tuple
    ``(Constr, AND|OR, Constr)``. The formula can be rendered to SMTLIB2,
    Marabou and PyRAT input formats; the last two only support ``AND``.
    """

    # Cells always queried by write_smtlib, in this order.
    _DEFAULT_SMT_CELLS = (
        '|CELL_actual_input_0_0_0|',
        '|CELL_actual_input_0_0_1|',
        '|CELL_actual_input_0_0_2|',
        '|CELL_actual_output_0_0_0|',
    )

    def __init__(self):
        """Create an empty formula."""
        self.storage = []

    def __str__(self):
        parts = []
        for clause in self.storage:
            if isinstance(clause, Constr):
                parts.append('{}; '.format(str(clause)))
            else:
                cx, lop, cy = clause
                parts.append('({} {} {}); '.format(str(cx), str(lop),
                                                   str(cy)))
        return ''.join(parts)

    def add(self, c):
        """Add a clause to the formula.

        c is either a Constr or a tuple ``(constr, AND|OR, constr)``.
        Invalid input is reported on stdout and nothing is added.
        """
        if isinstance(c, Constr):
            self.storage.append(c)
        elif well_formed(c):
            self.storage.append((c[0], c[1], c[2]))
        else:
            print('Error, input must either be an instance of Constr or '
                  'a tuple (constr, AND|OR, constr)')

    def _iter_constrs(self):
        """Yield every Constr stored in the formula, in order."""
        for clause in self.storage:
            if isinstance(clause, Constr):
                yield clause
            else:
                yield clause[0]
                yield clause[2]

    def _conjunction_lines(self, method, fmt_name):
        """Render a conjunction-only format, one constraint per line.

        method is the name of the Constr rendering method to call
        (e.g. 'to_marabou'). AND clauses are split into two lines.

        Raises:
            ValueError: if a clause uses the OR operator, which such
                formats cannot express.
        """
        lines = []
        for clause in self.storage:
            if isinstance(clause, Constr):
                lines.append(getattr(clause, method)())
                continue
            cx, lop, cy = clause
            if lop != AND:
                raise ValueError(
                    'Error, {} output format does not support OR operator. '
                    'Aborting conversion.'.format(fmt_name))
            lines.append(getattr(cx, method)())
            lines.append(getattr(cy, method)())
        return ''.join('{}\n'.format(line) for line in lines)

    def to_smtlib(self):
        """Write down the formula to SMTLIB format, one assert per clause."""
        lines = []
        for clause in self.storage:
            if isinstance(clause, Constr):
                term = clause.to_smtlib()
            else:
                cx, lop, cy = clause
                term = '({} {} {})'.format(
                    str(lop), cx.to_smtlib(), cy.to_smtlib())
            lines.append('(assert {})\n'.format(term))
        return ''.join(lines)

    def to_marabou(self):
        """Write down the formula to Marabou format.

        Raises:
            ValueError: if a clause uses the OR operator.
        """
        return self._conjunction_lines('to_marabou', 'Marabou')

    def to_pyrat(self):
        """Write down the formula to PyRAT format.

        Raises:
            ValueError: if a clause uses the OR operator.
        """
        return self._conjunction_lines('to_pyrat', 'PyRAT')

    def _smt_query_cells(self):
        """Return the SMT cell names to query with get-value.

        The four default cells always come first. Any other variable cell
        referenced by the formula is appended in order of first appearance.
        """
        cells = list(self._DEFAULT_SMT_CELLS)
        for constr in self._iter_constrs():
            a, _, b = constr.constr
            for atom in (a, b):
                if atom.is_var():
                    cell = atom.to_smtlib()
                    if cell not in cells:
                        cells.append(cell)
        return cells

    def write_smtlib(self, fpath='formula.smt2'):
        """Write the formula, check-sat and get-value queries to fpath."""
        # Render before opening so a failure does not truncate the file.
        to_write = self.to_smtlib() + '(check-sat)\n'
        to_write += ''.join('(get-value ({}))\n'.format(cell)
                            for cell in self._smt_query_cells())
        with open(fpath, 'w', encoding='utf-8') as f:
            f.write(to_write)
        print('Wrote SMT formula in file {}'.format(fpath))

    def write_marabou(self, fpath='formula.marabou'):
        """Write the Marabou rendering of the formula to fpath."""
        # Render before opening so a failure does not truncate the file.
        to_write = self.to_marabou()
        with open(fpath, 'w', encoding='utf-8') as f:
            f.write(to_write)
        print('Wrote marabou formula in file {}'.format(fpath))

    def write_pyrat(self, fpath='formula.txt'):
        """Write the PyRAT rendering of the formula to fpath."""
        # Render before opening so a failure does not truncate the file.
        to_write = self.to_pyrat()
        with open(fpath, 'w', encoding='utf-8') as f:
            f.write(to_write)
        print('Wrote pyrat formula in file {}'.format(fpath))
def formula_p1():
    """Build property P1 as a Formula.

    Each input x0..x4 is boxed by a lower bound (>=) and an upper bound (<),
    and the output y0 is bounded above (<).
    """
    # (variable, lower bound, upper bound); the lower bound is added first.
    input_bounds = (
        ('x0', 0.6, 0.6798577687),  # distance
        ('x1', -0.5, 0.5),          # angle1
        ('x2', -0.5, 0.5),          # angle2
        ('x3', 0.45, 0.5),          # speed1
        ('x4', -0.5, -0.45),        # speed2
    )
    formula = Formula()
    for name, low, high in input_bounds:
        var = Var(name)
        formula.add(Constr(var, '>=', Real(low)))
        formula.add(Constr(var, '<', Real(high)))
    formula.add(Constr(Var('y0'), '<', Real(3.9911256459)))
    return formula