6550e92a00269300ccdbbdf90dd509249df01406
5 from fractions
import Fraction
, gcd
10 'constant', 'symbol', 'symbols',
11 'eq', 'le', 'lt', 'ge', 'gt',
19 This class implements linear expressions.
22 def __new__(cls
, coefficients
=None, constant
=0):
23 if isinstance(coefficients
, str):
25 raise TypeError('too many arguments')
26 return cls
.fromstring(coefficients
)
27 self
= super().__new
__(cls
)
28 self
._coefficients
= {}
29 if isinstance(coefficients
, dict):
30 coefficients
= coefficients
.items()
31 if coefficients
is not None:
32 for symbol
, coefficient
in coefficients
:
33 if isinstance(symbol
, Expression
) and symbol
.issymbol():
35 elif not isinstance(symbol
, str):
36 raise TypeError('symbols must be strings')
37 if not isinstance(coefficient
, numbers
.Rational
):
38 raise TypeError('coefficients must be rational numbers')
40 self
._coefficients
[symbol
] = coefficient
41 if not isinstance(constant
, numbers
.Rational
):
42 raise TypeError('constant must be a rational number')
43 self
._constant
= constant
47 yield from sorted(self
._coefficients
)
51 return len(list(self
.symbols()))
53 def coefficient(self
, symbol
):
54 if isinstance(symbol
, Expression
) and symbol
.issymbol():
56 elif not isinstance(symbol
, str):
57 raise TypeError('symbol must be a string')
59 return self
._coefficients
[symbol
]
63 __getitem__
= coefficient
65 def coefficients(self
):
66 for symbol
in self
.symbols():
67 yield symbol
, self
.coefficient(symbol
)
74 return len(self
._coefficients
) == 0
77 for symbol
in self
.symbols():
78 yield self
.coefficient(symbol
)
82 if not self
.issymbol():
83 raise ValueError('not a symbol: {}'.format(self
))
84 for symbol
in self
.symbols():
88 return len(self
._coefficients
) == 1 and self
._constant
== 0
91 return (not self
.isconstant()) or bool(self
.constant
)
99 def _polymorphic(func
):
100 @functools.wraps(func
)
101 def wrapper(self
, other
):
102 if isinstance(other
, Expression
):
103 return func(self
, other
)
104 if isinstance(other
, numbers
.Rational
):
105 other
= Expression(constant
=other
)
106 return func(self
, other
)
107 return NotImplemented
111 def __add__(self
, other
):
112 coefficients
= dict(self
.coefficients())
113 for symbol
, coefficient
in other
.coefficients():
114 if symbol
in coefficients
:
115 coefficients
[symbol
] += coefficient
117 coefficients
[symbol
] = coefficient
118 constant
= self
.constant
+ other
.constant
119 return Expression(coefficients
, constant
)
124 def __sub__(self
, other
):
125 coefficients
= dict(self
.coefficients())
126 for symbol
, coefficient
in other
.coefficients():
127 if symbol
in coefficients
:
128 coefficients
[symbol
] -= coefficient
130 coefficients
[symbol
] = -coefficient
131 constant
= self
.constant
- other
.constant
132 return Expression(coefficients
, constant
)
137 def __mul__(self
, other
):
138 if other
.isconstant():
139 coefficients
= dict(self
.coefficients())
140 for symbol
in coefficients
:
141 coefficients
[symbol
] *= other
.constant
142 constant
= self
.constant
* other
.constant
143 return Expression(coefficients
, constant
)
144 if isinstance(other
, Expression
) and not self
.isconstant():
145 raise ValueError('non-linear expression: '
146 '{} * {}'.format(self
._parenstr
(), other
._parenstr
()))
147 return NotImplemented
152 def __truediv__(self
, other
):
153 if other
.isconstant():
154 coefficients
= dict(self
.coefficients())
155 for symbol
in coefficients
:
156 coefficients
[symbol
] = \
157 Fraction(coefficients
[symbol
], other
.constant
)
158 constant
= Fraction(self
.constant
, other
.constant
)
159 return Expression(coefficients
, constant
)
160 if isinstance(other
, Expression
):
161 raise ValueError('non-linear expression: '
162 '{} / {}'.format(self
._parenstr
(), other
._parenstr
()))
163 return NotImplemented
165 def __rtruediv__(self
, other
):
166 if isinstance(other
, Rational
):
167 if self
.isconstant():
168 constant
= Fraction(other
, self
.constant
)
169 return Expression(constant
=constant
)
171 raise ValueError('non-linear expression: '
172 '{} / {}'.format(other
._parenstr
(), self
._parenstr
()))
173 return NotImplemented
177 symbols
= sorted(self
.symbols())
179 for symbol
in symbols
:
180 coefficient
= self
[symbol
]
185 string
+= ' + {}'.format(symbol
)
186 elif coefficient
== -1:
188 string
+= '-{}'.format(symbol
)
190 string
+= ' - {}'.format(symbol
)
193 string
+= '{}*{}'.format(coefficient
, symbol
)
194 elif coefficient
> 0:
195 string
+= ' + {}*{}'.format(coefficient
, symbol
)
197 assert coefficient
< 0
199 string
+= ' - {}*{}'.format(coefficient
, symbol
)
201 constant
= self
.constant
202 if constant
!= 0 and i
== 0:
203 string
+= '{}'.format(constant
)
205 string
+= ' + {}'.format(constant
)
208 string
+= ' - {}'.format(constant
)
213 def _parenstr(self
, always
=False):
215 if not always
and (self
.isconstant() or self
.issymbol()):
218 return '({})'.format(string
)
221 string
= '{}({{'.format(self
.__class
__.__name
__)
222 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
225 string
+= '{!r}: {!r}'.format(symbol
, coefficient
)
226 string
+= '}}, {!r})'.format(self
.constant
)
230 def fromstring(cls
, string
):
231 raise NotImplementedError
234 def __eq__(self
, other
):
236 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
237 return isinstance(other
, Expression
) and \
238 self
._coefficients
== other
._coefficients
and \
239 self
.constant
== other
.constant
242 return hash((self
._coefficients
, self
._constant
))
245 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
246 [value
.denominator
for value
in self
.values()])
250 def _eq(self
, other
):
251 return Polyhedron(equalities
=[(self
- other
)._canonify
()])
254 def __le__(self
, other
):
255 return Polyhedron(inequalities
=[(self
- other
)._canonify
()])
258 def __lt__(self
, other
):
259 return Polyhedron(inequalities
=[(self
- other
)._canonify
() + 1])
262 def __ge__(self
, other
):
263 return Polyhedron(inequalities
=[(other
- self
)._canonify
()])
266 def __gt__(self
, other
):
267 return Polyhedron(inequalities
=[(other
- self
)._canonify
() + 1])
270 def constant(numerator
=0, denominator
=None):
271 return Expression(constant
=Fraction(numerator
, denominator
))
274 if not isinstance(name
, str):
275 raise TypeError('name must be a string')
276 return Expression(coefficients
={name
: 1})
279 if isinstance(names
, str):
280 names
= names
.replace(',', ' ').split()
281 return (symbol(name
) for name
in names
)
285 @functools.wraps(func
)
287 if isinstance(a
, numbers
.Rational
):
289 if isinstance(b
, numbers
.Rational
):
291 if isinstance(a
, Expression
) and isinstance(b
, Expression
):
293 raise TypeError('arguments must be linear expressions')
319 This class implements polyhedrons.
322 def __new__(cls
, equalities
=None, inequalities
=None):
323 if isinstance(equalities
, str):
324 if inequalities
is not None:
325 raise TypeError('too many arguments')
326 return cls
.fromstring(equalities
)
327 self
= super().__new
__(cls
)
328 self
._equalities
= []
329 if equalities
is not None:
330 for constraint
in equalities
:
331 for value
in constraint
.values():
332 if value
.denominator
!= 1:
333 raise TypeError('non-integer constraint: '
334 '{} == 0'.format(constraint
))
335 self
._equalities
.append(constraint
)
336 self
._inequalities
= []
337 if inequalities
is not None:
338 for constraint
in inequalities
:
339 for value
in constraint
.values():
340 if value
.denominator
!= 1:
341 raise TypeError('non-integer constraint: '
342 '{} <= 0'.format(constraint
))
343 self
._inequalities
.append(constraint
)
347 def equalities(self
):
348 yield from self
._equalities
351 def inequalities(self
):
352 yield from self
._inequalities
354 def constraints(self
):
355 yield from self
.equalities
356 yield from self
.inequalities
360 for constraint
in self
.constraints():
361 s
.update(constraint
.symbols
)
366 return len(self
.symbols())
369 # return false if the polyhedron is empty, true otherwise
370 raise NotImplementedError
372 def __contains__(self
, value
):
373 # is the value in the polyhedron?
374 raise NotImplementedError
376 def __eq__(self
, other
):
377 raise NotImplementedError
382 def isuniverse(self
):
383 return self
== universe
385 def isdisjoint(self
, other
):
386 # return true if the polyhedron has no elements in common with other
387 raise NotImplementedError
389 def issubset(self
, other
):
390 raise NotImplementedError
392 def __le__(self
, other
):
393 return self
.issubset(other
)
395 def __lt__(self
, other
):
396 raise NotImplementedError
398 def issuperset(self
, other
):
399 # test whether every element in other is in the polyhedron
400 raise NotImplementedError
402 def __ge__(self
, other
):
403 return self
.issuperset(other
)
405 def __gt__(self
, other
):
406 raise NotImplementedError
408 def union(self
, *others
):
409 # return a new polyhedron with elements from the polyhedron and all
410 # others (convex union)
411 raise NotImplementedError
413 def __or__(self
, other
):
414 return self
.union(other
)
416 def intersection(self
, *others
):
417 # return a new polyhedron with elements common to the polyhedron and all
419 # a poor man's implementation could be:
420 # equalities = list(self.equalities)
421 # inequalities = list(self.inequalities)
422 # for other in others:
423 # equalities.extend(other.equalities)
424 # inequalities.extend(other.inequalities)
425 # return self.__class__(equalities, inequalities)
426 raise NotImplementedError
428 def __and__(self
, other
):
429 return self
.intersection(other
)
431 def difference(self
, *others
):
432 # return a new polyhedron with elements in the polyhedron that are not
434 raise NotImplementedError
436 def __sub__(self
, other
):
437 return self
.difference(other
)
441 for constraint
in self
.equalities
:
442 constraints
.append('{} == 0'.format(constraint
))
443 for constraint
in self
.inequalities
:
444 constraints
.append('{} <= 0'.format(constraint
))
445 return '{{{}}}'.format(', '.join(constraints
))
448 equalities
= list(self
.equalities
)
449 inequalities
= list(self
.inequalities
)
450 return '{}(equalities={!r}, inequalities={!r})' \
451 ''.format(self
.__class
__.__name
__, equalities
, inequalities
)
454 def fromstring(cls
, string
):
455 raise NotImplementedError
460 universe
= Polyhedron()