5b2dc800d11d19889b61d4659189ddf71f1f71c3
6 from fractions
import Fraction
, gcd
9 from pypol
.isl
import libisl
13 'Expression', 'Constant', 'Symbol', 'symbols',
14 'eq', 'le', 'lt', 'ge', 'gt',
20 def _polymorphic_method(func
):
21 @functools.wraps(func
)
23 if isinstance(b
, Expression
):
25 if isinstance(b
, numbers
.Rational
):
31 def _polymorphic_operator(func
):
32 # A polymorphic operator should call a polymorphic method, hence we just
33 # have to test the left operand.
34 @functools.wraps(func
)
36 if isinstance(a
, numbers
.Rational
):
39 elif isinstance(a
, Expression
):
41 raise TypeError('arguments must be linear expressions')
45 _main_ctx
= isl
.Context()
50 This class implements linear expressions.
60 def __new__(cls
, coefficients
=None, constant
=0):
61 if isinstance(coefficients
, str):
63 raise TypeError('too many arguments')
64 return cls
.fromstring(coefficients
)
65 if isinstance(coefficients
, dict):
66 coefficients
= coefficients
.items()
67 if coefficients
is None:
68 return Constant(constant
)
69 coefficients
= [(symbol
, coefficient
)
70 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
71 if len(coefficients
) == 0:
72 return Constant(constant
)
73 elif len(coefficients
) == 1 and constant
== 0:
74 symbol
, coefficient
= coefficients
[0]
77 self
= object().__new
__(cls
)
78 self
._coefficients
= {}
79 for symbol
, coefficient
in coefficients
:
80 if isinstance(symbol
, Symbol
):
82 elif not isinstance(symbol
, str):
83 raise TypeError('symbols must be strings or Symbol instances')
84 if isinstance(coefficient
, Constant
):
85 coefficient
= coefficient
.constant
86 if not isinstance(coefficient
, numbers
.Rational
):
87 raise TypeError('coefficients must be rational numbers or Constant instances')
88 self
._coefficients
[symbol
] = coefficient
89 if isinstance(constant
, Constant
):
90 constant
= constant
.constant
91 if not isinstance(constant
, numbers
.Rational
):
92 raise TypeError('constant must be a rational number or a Constant instance')
93 self
._constant
= constant
94 self
._symbols
= tuple(sorted(self
._coefficients
))
95 self
._dimension
= len(self
._symbols
)
99 def _fromast(cls
, node
):
100 if isinstance(node
, ast
.Module
):
101 assert len(node
.body
) == 1
102 return cls
._fromast
(node
.body
[0])
103 elif isinstance(node
, ast
.Expr
):
104 return cls
._fromast
(node
.value
)
105 elif isinstance(node
, ast
.Name
):
106 return Symbol(node
.id)
107 elif isinstance(node
, ast
.Num
):
108 return Constant(node
.n
)
109 elif isinstance(node
, ast
.UnaryOp
):
110 if isinstance(node
.op
, ast
.USub
):
111 return -cls
._fromast
(node
.operand
)
112 elif isinstance(node
, ast
.BinOp
):
113 left
= cls
._fromast
(node
.left
)
114 right
= cls
._fromast
(node
.right
)
115 if isinstance(node
.op
, ast
.Add
):
117 elif isinstance(node
.op
, ast
.Sub
):
119 elif isinstance(node
.op
, ast
.Mult
):
121 elif isinstance(node
.op
, ast
.Div
):
123 raise SyntaxError('invalid syntax')
def fromstring(cls, string):
    """
    Build an expression from its textual form.

    Implicit products are made explicit first: a number or a closing
    parenthesis followed (possibly after blanks) by an identifier or an
    opening parenthesis gets a '*' inserted in between, e.g. '2x' becomes
    '2*x'.  The rewritten text is parsed with the Python parser and the
    resulting tree is converted by cls._fromast().
    """
    implicit_product = r'(\d+|\))\s*([^\W\d_]\w*|\()'
    explicit = re.sub(implicit_product, r'\1*\2', string)
    # NOTE: the second positional argument of ast.parse() is the file name,
    # not the mode, so the text is parsed in the default 'exec' mode and
    # 'eval' only labels the origin of the tree.
    module = ast.parse(explicit, 'eval')
    return cls._fromast(module)
137 return self
._dimension
139 def coefficient(self
, symbol
):
140 if isinstance(symbol
, Symbol
):
142 elif not isinstance(symbol
, str):
143 raise TypeError('symbol must be a string or a Symbol instance')
145 return self
._coefficients
[symbol
]
149 __getitem__
= coefficient
def coefficients(self):
    """
    Iterate lazily over the (symbol, coefficient) pairs of the expression,
    following the order of self.symbols.
    """
    yield from ((name, self.coefficient(name)) for name in self.symbols)
157 return self
._constant
159 def isconstant(self
):
163 for symbol
in self
.symbols
:
164 yield self
.coefficient(symbol
)
def __add__(self, other):
    """
    Return the sum of two expressions.

    Coefficients of symbols present in both operands are added, the other
    coefficients are carried over unchanged, and the constants are added.
    """
    merged = dict(self.coefficients())
    for name, value in other.coefficients():
        merged[name] = merged[name] + value if name in merged else value
    return Expression(merged, self.constant + other.constant)
def __sub__(self, other):
    """
    Return the difference self - other.

    Coefficients of symbols present in both operands are subtracted, symbols
    only present in other get the negated coefficient, and the constants are
    subtracted.
    """
    merged = dict(self.coefficients())
    for name, value in other.coefficients():
        merged[name] = merged[name] - value if name in merged else -value
    return Expression(merged, self.constant - other.constant)
203 def __rsub__(self
, other
):
204 return -(self
- other
)
def __mul__(self, other):
    """
    Return the product of the expression by a constant expression.

    Every coefficient and the constant term are multiplied by
    other.constant.  If other is a non-constant Expression and self is not
    constant either, the product is not linear and ValueError is raised;
    in the remaining cases NotImplemented is returned so that Python can
    try the reflected operation.
    """
    if not other.isconstant():
        if isinstance(other, Expression) and not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented
    scaled = {name: value * other.constant
            for name, value in self.coefficients()}
    return Expression(scaled, self.constant * other.constant)
def __truediv__(self, other):
    """
    Return the quotient of the expression by a constant expression.

    Every coefficient and the constant term are divided by other.constant
    using exact Fraction arithmetic.  Dividing by a non-constant Expression
    is not linear and raises ValueError; any other non-constant operand
    yields NotImplemented.
    """
    if not other.isconstant():
        if isinstance(other, Expression):
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        return NotImplemented
    divisor = other.constant
    quotients = {name: Fraction(value, divisor)
            for name, value in self.coefficients()}
    return Expression(quotients, Fraction(self.constant, divisor))
235 def __rtruediv__(self
, other
):
236 if isinstance(other
, self
):
237 if self
.isconstant():
238 constant
= Fraction(other
, self
.constant
)
239 return Expression(constant
=constant
)
241 raise ValueError('non-linear expression: '
242 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
243 return NotImplemented
248 for symbol
in self
.symbols
:
249 coefficient
= self
.coefficient(symbol
)
254 string
+= ' + {}'.format(symbol
)
255 elif coefficient
== -1:
257 string
+= '-{}'.format(symbol
)
259 string
+= ' - {}'.format(symbol
)
262 string
+= '{}*{}'.format(coefficient
, symbol
)
263 elif coefficient
> 0:
264 string
+= ' + {}*{}'.format(coefficient
, symbol
)
266 assert coefficient
< 0
268 string
+= ' - {}*{}'.format(coefficient
, symbol
)
270 constant
= self
.constant
271 if constant
!= 0 and i
== 0:
272 string
+= '{}'.format(constant
)
274 string
+= ' + {}'.format(constant
)
277 string
+= ' - {}'.format(constant
)
282 def _parenstr(self
, always
=False):
284 if not always
and (self
.isconstant() or self
.issymbol()):
287 return '({})'.format(string
)
290 string
= '{}({{'.format(self
.__class
__.__name
__)
291 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
294 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
295 string
+= '}}, {!r})'.format(self
.constant
)
def __eq__(self, other):
    """
    Structural equality: true only when other is an Expression with the
    same coefficients and the same constant.
    """
    # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
    if not isinstance(other, Expression):
        return False
    if not self._coefficients == other._coefficients:
        return False
    return self.constant == other.constant
307 return hash((tuple(sorted(self
._coefficients
.items())), self
._constant
))
310 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
311 [value
.denominator
for value
in self
.values()])
def _eq(self, other):
    """
    Return the polyhedron whose single equality constraint is
    (self - other)._toint().
    """
    constraint = (self - other)._toint()
    return Polyhedron(equalities=[constraint])
def __le__(self, other):
    """
    Return the polyhedron for self <= other, built from the single
    inequality constraint (other - self)._toint().
    """
    constraint = (other - self)._toint()
    return Polyhedron(inequalities=[constraint])
def __lt__(self, other):
    """
    Return the polyhedron for self < other, built from the single
    inequality constraint (other - self)._toint() - 1.
    """
    # subtracting one from the converted difference encodes strictness
    constraint = (other - self)._toint() - 1
    return Polyhedron(inequalities=[constraint])
def __ge__(self, other):
    """
    Return the polyhedron for self >= other, built from the single
    inequality constraint (self - other)._toint().
    """
    constraint = (self - other)._toint()
    return Polyhedron(inequalities=[constraint])
def __gt__(self, other):
    """
    Return the polyhedron for self > other, built from the single
    inequality constraint (self - other)._toint() - 1.
    """
    # subtracting one from the converted difference encodes strictness
    constraint = (self - other)._toint() - 1
    return Polyhedron(inequalities=[constraint])
335 class Constant(Expression
):
337 def __new__(cls
, numerator
=0, denominator
=None):
338 self
= object().__new
__(cls
)
339 if denominator
is None:
340 if isinstance(numerator
, numbers
.Rational
):
341 self
._constant
= numerator
342 elif isinstance(numerator
, Constant
):
343 self
._constant
= numerator
.constant
345 raise TypeError('constant must be a rational number or a Constant instance')
347 self
._constant
= Fraction(numerator
, denominator
)
348 self
._coefficients
= {}
353 def isconstant(self
):
357 return bool(self
.constant
)
360 return '{}({!r})'.format(self
.__class
__.__name
__, self
._constant
)
363 class Symbol(Expression
):
365 __slots__
= Expression
.__slots
__ + (
369 def __new__(cls
, name
):
370 if isinstance(name
, Symbol
):
372 elif not isinstance(name
, str):
373 raise TypeError('name must be a string or a Symbol instance')
374 self
= object().__new
__(cls
)
375 self
._coefficients
= {name
: 1}
377 self
._symbols
= tuple(name
)
390 return '{}({!r})'.format(self
.__class
__.__name
__, self
._name
)
393 if isinstance(names
, str):
394 names
= names
.replace(',', ' ').split()
395 return (Symbol(name
) for name
in names
)
398 @_polymorphic_operator
402 @_polymorphic_operator
406 @_polymorphic_operator
410 @_polymorphic_operator
414 @_polymorphic_operator
421 This class implements polyhedrons.
431 def __new__(cls
, equalities
=None, inequalities
=None):
432 if isinstance(equalities
, str):
433 if inequalities
is not None:
434 raise TypeError('too many arguments')
435 return cls
.fromstring(equalities
)
436 self
= super().__new
__(cls
)
437 self
._equalities
= []
438 if equalities
is not None:
439 for constraint
in equalities
:
440 for value
in constraint
.values():
441 if value
.denominator
!= 1:
442 raise TypeError('non-integer constraint: '
443 '{} == 0'.format(constraint
))
444 self
._equalities
.append(constraint
)
445 self
._equalities
= tuple(self
._equalities
)
446 self
._inequalities
= []
447 if inequalities
is not None:
448 for constraint
in inequalities
:
449 for value
in constraint
.values():
450 if value
.denominator
!= 1:
451 raise TypeError('non-integer constraint: '
452 '{} <= 0'.format(constraint
))
453 self
._inequalities
.append(constraint
)
454 self
._inequalities
= tuple(self
._inequalities
)
455 self
._constraints
= self
._equalities
+ self
._inequalities
456 self
._symbols
= set()
457 for constraint
in self
._constraints
:
458 self
.symbols
.update(constraint
.symbols
)
459 self
._symbols
= tuple(sorted(self
._symbols
))
463 def fromstring(cls
, string
):
464 string
= string
.strip()
465 string
= re
.sub(r
'^\{\s*|\s*\}$', '', string
)
466 string
= re
.sub(r
'([^<=>])=([^<=>])', r
'\1==\2', string
)
467 string
= re
.sub(r
'(\d+|\))\s*([^\W\d_]\w*|\()', r
'\1*\2', string
)
470 for cstr
in re
.split(r
',|;|and|&&|/\\|∧', string
, flags
=re
.I
):
471 tree
= ast
.parse(cstr
.strip(), 'eval')
472 if not isinstance(tree
, ast
.Module
) or len(tree
.body
) != 1:
473 raise SyntaxError('invalid syntax')
475 if not isinstance(node
, ast
.Expr
):
476 raise SyntaxError('invalid syntax')
478 if not isinstance(node
, ast
.Compare
):
479 raise SyntaxError('invalid syntax')
480 left
= Expression
._fromast
(node
.left
)
481 for i
in range(len(node
.ops
)):
483 right
= Expression
._fromast
(node
.comparators
[i
])
484 if isinstance(op
, ast
.Lt
):
485 inequalities
.append(right
- left
- 1)
486 elif isinstance(op
, ast
.LtE
):
487 inequalities
.append(right
- left
)
488 elif isinstance(op
, ast
.Eq
):
489 equalities
.append(left
- right
)
490 elif isinstance(op
, ast
.GtE
):
491 inequalities
.append(left
- right
)
492 elif isinstance(op
, ast
.Gt
):
493 inequalities
.append(left
- right
- 1)
495 raise SyntaxError('invalid syntax')
497 return cls(equalities
, inequalities
)
def equalities(self):
    """Return the tuple of equality constraints stored on the polyhedron."""
    return self._equalities
def inequalities(self):
    """Return the tuple of inequality constraints stored on the polyhedron."""
    return self._inequalities
def constraints(self):
    """
    Return all stored constraints: the equalities followed by the
    inequalities.
    """
    return self._constraints
517 return len(self
.symbols
)
520 return not self
.is_empty()
def __contains__(self, value):
    """Membership test for `value in polyhedron`; not implemented yet."""
    # is the value in the polyhedron?
    raise NotImplementedError
526 def __eq__(self
, other
):
527 # works correctly when symbols is not passed
528 # should be equal if values are the same even if symbols are different
530 other
= other
._toisl
()
531 return bool(libisl
.isl_basic_set_plain_is_equal(bset
, other
))
535 return bool(libisl
.isl_basic_set_is_empty(bset
))
537 def isuniverse(self
):
539 return bool(libisl
.isl_basic_set_is_universe(bset
))
541 def isdisjoint(self
, other
):
542 # return true if the polyhedron has no elements in common with other
543 #symbols = self._symbolunion(other)
545 other
= other
._toisl
()
546 return bool(libisl
.isl_set_is_disjoint(bset
, other
))
def issubset(self, other):
    """
    Return True if every point of this polyhedron also belongs to other.

    Both operands are converted to isl over the union of their symbols, so
    that they are expressed with the same dimensions.  The test is not
    strict: a polyhedron is a subset of itself.
    """
    symbols = self._symbolunion(other)
    bset = self._toisl(symbols)
    other = other._toisl(symbols)
    # The first argument is the candidate subset.  The strict variant with
    # the operands swapped would answer "other is strictly inside self".
    return bool(libisl.isl_set_is_subset(bset, other))
def __le__(self, other):
    """Operator form of issubset(): self <= other."""
    return self.issubset(other)
def __lt__(self, other):
    """
    Return True if this polyhedron is a strict subset of other.

    Both operands are converted to isl over the union of their symbols, so
    that they are expressed with the same dimensions.
    """
    symbols = self._symbolunion(other)
    bset = self._toisl(symbols)
    other = other._toisl(symbols)
    # self < other asks whether self is strictly inside other, hence self
    # comes first; with the operands swapped this would compute self > other.
    return bool(libisl.isl_set_is_strict_subset(bset, other))
def issuperset(self, other):
    """Superset test against other; not implemented yet."""
    # test whether every element in other is in the polyhedron
    raise NotImplementedError
def __ge__(self, other):
    """Operator form of issuperset(): self >= other."""
    return self.issuperset(other)
def __gt__(self, other):
    """
    Return True if other is a strict subset of this polyhedron.

    Both operands are converted to isl over the union of their symbols, so
    that they are expressed with the same dimensions.
    """
    symbols = self._symbolunion(other)
    bset = self._toisl(symbols)
    other = other._toisl(symbols)
    # self > other means other is strictly inside self.  The result has to
    # be returned: computing it and then raising NotImplementedError made
    # the comparison unusable.
    return bool(libisl.isl_set_is_strict_subset(other, bset))
def union(self, *others):
    """Union with the other polyhedra; not implemented yet."""
    # return a new polyhedron with elements from the polyhedron and all
    # others (convex union)
    raise NotImplementedError
def __or__(self, other):
    """Operator form of union(): self | other."""
    return self.union(other)
def intersection(self, *others):
    """Intersection with the other polyhedra; not implemented yet."""
    # return a new polyhedron with elements common to the polyhedron and all
    # others
    # a poor man's implementation could be:
    #   equalities = list(self.equalities)
    #   inequalities = list(self.inequalities)
    #   for other in others:
    #       equalities.extend(other.equalities)
    #       inequalities.extend(other.inequalities)
    #   return self.__class__(equalities, inequalities)
    raise NotImplementedError
def __and__(self, other):
    """Operator form of intersection(): self & other."""
    return self.intersection(other)
601 def difference(self
, other
):
602 # return a new polyhedron with elements in the polyhedron that are not in the other
603 symbols
= self
._symbolunion
(other
)
604 bset
= self
._toisl
(symbols
)
605 other
= other
._toisl
(symbols
)
606 difference
= libisl
.isl_set_subtract(bset
, other
)
def __sub__(self, other):
    """Operator form of difference(): self - other."""
    return self.difference(other)
614 for constraint
in self
.equalities
:
615 constraints
.append('{} == 0'.format(constraint
))
616 for constraint
in self
.inequalities
:
617 constraints
.append('{} >= 0'.format(constraint
))
618 return '{{{}}}'.format(', '.join(constraints
))
623 elif self
.isuniverse():
626 equalities
= list(self
.equalities
)
627 inequalities
= list(self
.inequalities
)
628 return '{}(equalities={!r}, inequalities={!r})' \
629 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
631 def _symbolunion(self
, *others
):
632 symbols
= set(self
.symbols
)
634 symbols
.update(other
.symbols
)
635 return sorted(symbols
)
637 def _toisl(self
, symbols
=None):
639 symbols
= self
.symbols
640 dimension
= len(symbols
)
641 space
= libisl
.isl_space_set_alloc(_main_ctx
, 0, dimension
)
642 bset
= libisl
.isl_basic_set_universe(libisl
.isl_space_copy(space
))
643 ls
= libisl
.isl_local_space_from_space(space
)
644 for equality
in self
.equalities
:
645 ceq
= libisl
.isl_equality_alloc(libisl
.isl_local_space_copy(ls
))
646 for symbol
, coefficient
in equality
.coefficients():
647 val
= str(coefficient
).encode()
648 val
= libisl
.isl_val_read_from_str(_main_ctx
, val
)
649 dim
= symbols
.index(symbol
)
650 ceq
= libisl
.isl_constraint_set_coefficient_val(ceq
, libisl
.isl_dim_set
, dim
, val
)
651 if equality
.constant
!= 0:
652 val
= str(equality
.constant
).encode()
653 val
= libisl
.isl_val_read_from_str(_main_ctx
, val
)
654 ceq
= libisl
.isl_constraint_set_constant_val(ceq
, val
)
655 bset
= libisl
.isl_basic_set_add_constraint(bset
, ceq
)
656 for inequality
in self
.inequalities
:
657 cin
= libisl
.isl_inequality_alloc(libisl
.isl_local_space_copy(ls
))
658 for symbol
, coefficient
in inequality
.coefficients():
659 val
= str(coefficient
).encode()
660 val
= libisl
.isl_val_read_from_str(_main_ctx
, val
)
661 dim
= symbols
.index(symbol
)
662 cin
= libisl
.isl_constraint_set_coefficient_val(cin
, libisl
.isl_dim_set
, dim
, val
)
663 if inequality
.constant
!= 0:
664 val
= str(inequality
.constant
).encode()
665 val
= libisl
.isl_val_read_from_str(_main_ctx
, val
)
666 cin
= libisl
.isl_constraint_set_constant_val(cin
, val
)
667 bset
= libisl
.isl_basic_set_add_constraint(bset
, cin
)
668 bset
= isl
.BasicSet(bset
)
672 def _fromisl(cls
, bset
, symbols
):
673 raise NotImplementedError
676 return cls(equalities
, inequalities
)
677 '''takes basic set in isl form and puts back into python version of polyhedron
678 isl example code gives isl form as:
679 "{[i] : exists (a : i = 2a and i >= 10 and i <= 42)}")
680 our printer is giving form as:
681 { [i0, i1] : 2i1 >= -2 - i0 } '''
684 Universe
= Polyhedron()
686 if __name__
== '__main__':
687 p1
= Polyhedron('2a + 2b + 1 == 0') # empty
689 p2
= Polyhedron('3x + 2y + 3 == 0') # not empty