6550e92a00269300ccdbbdf90dd509249df01406
[linpy.git] / pypol / linear.py
1
import functools
import numbers

try:
    from fractions import Fraction, gcd
except ImportError:
    # fractions.gcd was removed in Python 3.9; math.gcd replaces it.  They
    # only differ for negative arguments, and gcd is only ever applied to
    # (positive) denominators in this module.
    from fractions import Fraction
    from math import gcd
6
7
# Names exported by ``from pypol.linear import *``.
__all__ = [
    # linear expressions and their constructors
    'Expression',
    'constant',
    'symbol',
    'symbols',
    # relation helpers returning polyhedra
    'eq',
    'le',
    'lt',
    'ge',
    'gt',
    # polyhedra and the two trivial ones
    'Polyhedron',
    'empty',
    'universe',
]
15
16
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of rational coefficients times symbols (named by
    strings) plus a rational constant term.  Zero coefficients are never
    stored.  Arithmetic always builds a new expression; existing instances
    are not modified after construction.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Build an expression from `coefficients` and `constant`.

        `coefficients` may be None, a dict, or an iterable of
        (symbol, coefficient) pairs.  A symbol is a string or a
        single-symbol Expression (see issymbol()).  Coefficients and the
        constant must be numbers.Rational.  A string first argument is
        delegated to fromstring().

        Raises TypeError on malformed arguments.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = symbol.symbol()
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    def symbols(self):
        """Yield the symbols with a non-zero coefficient, in sorted order."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """Number of symbols with a non-zero coefficient."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """
        Return the coefficient of `symbol` (0 if absent).

        `symbol` is a string or a single-symbol Expression; anything else
        raises TypeError.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = symbol.symbol()
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols():
            yield symbol, self.coefficient(symbol)

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return True if the expression has no symbol terms."""
        return len(self._coefficients) == 0

    def values(self):
        """Yield every coefficient (sorted by symbol), then the constant."""
        for symbol in self.symbols():
            yield self.coefficient(symbol)
        yield self.constant

    def symbol(self):
        """
        Return the name of a single-symbol expression.

        Raises ValueError if issymbol() is false.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return next(iter(self._coefficients))

    def issymbol(self):
        """
        Return True if the expression is exactly one symbol: one term whose
        coefficient is 1, and a zero constant.
        """
        # The coefficient must be checked too, otherwise `2*x` would count
        # as a symbol and str() would turn it into a bogus name '2*x'.
        return (self._constant == 0 and len(self._coefficients) == 1 and
                next(iter(self._coefficients.values())) == 1)

    def __bool__(self):
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    def _polymorphic(func):
        # Decorator for binary operators: promote a Rational right operand
        # to a constant Expression; defer (NotImplemented) on other types.
        @functools.wraps(func)
        def wrapper(self, other):
            if isinstance(other, Expression):
                return func(self, other)
            if isinstance(other, numbers.Rational):
                other = Expression(constant=other)
                return func(self, other)
            return NotImplemented
        return wrapper

    @_polymorphic
    def __add__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        return Expression(coefficients, self.constant + other.constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        return Expression(coefficients, self.constant - other.constant)

    @_polymorphic
    def __rsub__(self, other):
        # `other - self`; subtraction is not commutative, so this cannot
        # simply alias __sub__.
        return other - self

    @_polymorphic
    def __mul__(self, other):
        """
        Multiply by a constant (a number or a constant Expression).

        Raises ValueError if both operands contain symbols, because the
        product would not be linear.
        """
        # Handle a constant on either side here: for Expression * Expression
        # Python never tries the reflected __rmul__.
        if other.isconstant():
            factor, expression = other.constant, self
        elif self.isconstant():
            factor, expression = self.constant, other
        else:
            raise ValueError('non-linear expression: '
                    '{} * {}'.format(self._parenstr(), other._parenstr()))
        coefficients = {symbol: coefficient * factor
                        for symbol, coefficient in expression.coefficients()}
        return Expression(coefficients, expression.constant * factor)

    __rmul__ = __mul__

    @_polymorphic
    def __truediv__(self, other):
        """
        Divide by a constant, producing Fraction values.

        Raises ValueError for a non-constant divisor, ZeroDivisionError for
        a zero one.
        """
        if not other.isconstant():
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        divisor = other.constant
        coefficients = {symbol: Fraction(coefficient, divisor)
                        for symbol, coefficient in self.coefficients()}
        return Expression(coefficients, Fraction(self.constant, divisor))

    def __rtruediv__(self, other):
        """
        Compute `other / self` for a rational `other`.

        Only a constant `self` keeps the result linear; otherwise
        ValueError is raised.
        """
        if isinstance(other, numbers.Rational):
            if self.isconstant():
                return Expression(constant=Fraction(other, self.constant))
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(other, self._parenstr()))
        return NotImplemented

    def __str__(self):
        """
        Return a readable form such as '2*x - y + 1/2'.

        Terms appear in sorted symbol order; the zero expression is '0'.
        """
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i == 0:
                    string += symbol
                else:
                    string += ' + {}'.format(symbol)
            elif coefficient == -1:
                if i == 0:
                    string += '-{}'.format(symbol)
                else:
                    string += ' - {}'.format(symbol)
            elif i == 0:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                string += ' - {}*{}'.format(-coefficient, symbol)
        constant = self.constant
        if constant != 0 and self.isconstant():
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string or '0'

    def _parenstr(self, always=False):
        """
        Return str(self), parenthesized unless it is a constant or a plain
        symbol (or `always` is true).
        """
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        return '({})'.format(string)

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
                          for symbol, coefficient in self.coefficients())
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, items,
                                          self.constant)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @_polymorphic
    def __eq__(self, other):
        # "normal" equality
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return (self._coefficients == other._coefficients and
                self.constant == other.constant)

    def __hash__(self):
        # A constant expression compares equal to its constant (see
        # __eq__), so it must hash like that number.  The coefficient dict
        # itself is unhashable, hence the frozenset.
        if self.isconstant():
            return hash(self._constant)
        return hash((frozenset(self._coefficients.items()), self._constant))

    def _canonify(self):
        """
        Return self scaled by the least common multiple of the denominators
        of all its values, so that every coefficient and the constant are
        integers.
        """
        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic
    def _eq(self, other):
        """Return the polyhedron {self == other}."""
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic
    def __le__(self, other):
        """Return the polyhedron {self <= other}."""
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic
    def __lt__(self, other):
        """
        Return the polyhedron {self < other}, encoded as
        self - other + 1 <= 0 (valid when points are integer-valued).
        """
        return Polyhedron(inequalities=[(self - other)._canonify() + 1])

    @_polymorphic
    def __ge__(self, other):
        """Return the polyhedron {self >= other}."""
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic
    def __gt__(self, other):
        """
        Return the polyhedron {self > other}, encoded as
        other - self + 1 <= 0 (valid when points are integer-valued).
        """
        return Polyhedron(inequalities=[(other - self)._canonify() + 1])
267 return Polyhedron(inequalities=[(other - self)._canonify() + 1])
268
269
def constant(numerator=0, denominator=None):
    """
    Return the constant expression numerator/denominator.

    The value goes through Fraction, so it is normalized (constant(2, 4)
    holds 1/2) and a zero denominator raises ZeroDivisionError.
    """
    value = Fraction(numerator, denominator)
    return Expression(constant=value)
272
def symbol(name):
    """
    Return the expression made of the single symbol `name` (coefficient 1).

    Raises TypeError if `name` is not a string.
    """
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
277
def symbols(names):
    """
    Return a generator of symbol expressions, one per name.

    `names` is either an iterable of strings or one string whose names are
    separated by commas and/or whitespace, e.g. 'x, y z'.
    """
    if not isinstance(names, str):
        return (symbol(name) for name in names)
    return (symbol(name) for name in names.replace(',', ' ').split())
282
283
def _operator(func):
    """
    Decorator for the binary relation helpers below.

    Rational arguments are promoted to constant expressions before `func`
    is called; any other argument that is not an Expression raises
    TypeError.
    """
    def _promote(value):
        # Numbers become constant expressions; everything else is unchanged.
        if isinstance(value, numbers.Rational):
            return constant(value)
        return value

    @functools.wraps(func)
    def wrapper(a, b):
        a, b = _promote(a), _promote(b)
        if not (isinstance(a, Expression) and isinstance(b, Expression)):
            raise TypeError('arguments must be linear expressions')
        return func(a, b)
    return wrapper
295
@_operator
def eq(a, b):
    """Return the polyhedron where a == b (numbers are accepted too)."""
    polyhedron = a._eq(b)
    return polyhedron
299
@_operator
def le(a, b):
    """Return the polyhedron where a <= b (numbers are accepted too)."""
    polyhedron = a <= b
    return polyhedron
303
@_operator
def lt(a, b):
    """Return the polyhedron where a < b (numbers are accepted too)."""
    polyhedron = a < b
    return polyhedron
307
@_operator
def ge(a, b):
    """Return the polyhedron where a >= b (numbers are accepted too)."""
    polyhedron = a >= b
    return polyhedron
311
@_operator
def gt(a, b):
    """Return the polyhedron where a > b (numbers are accepted too)."""
    polyhedron = a > b
    return polyhedron
315
316
class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is described by two lists of linear expressions whose
    coefficients and constants are all integers: equalities, each read as
    ``expression == 0``, and inequalities, each read as
    ``expression <= 0``.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """
        Build a polyhedron from iterables of equality and inequality
        expressions (None means no constraints of that kind).  A string
        first argument is delegated to fromstring().

        Raises TypeError if a string is combined with `inequalities`, or if
        a constraint has a non-integer coefficient or constant.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._checkconstraints(equalities, '==')
        self._inequalities = cls._checkconstraints(inequalities, '<=')
        return self

    @staticmethod
    def _checkconstraints(constraints, relation):
        """
        Return `constraints` as a new list after checking that every value
        (coefficients and constant) has denominator 1.  `relation` ('==' or
        '<=') is only used to build the error message.
        """
        checked = []
        if constraints is None:
            return checked
        for constraint in constraints:
            for value in constraint.values():
                if value.denominator != 1:
                    raise TypeError('non-integer constraint: '
                            '{} {} 0'.format(constraint, relation))
            checked.append(constraint)
        return checked

    @property
    def equalities(self):
        """Iterator over the equality constraints (each `expr == 0`)."""
        yield from self._equalities

    @property
    def inequalities(self):
        """Iterator over the inequality constraints (each `expr <= 0`)."""
        yield from self._inequalities

    def constraints(self):
        """Yield all equalities, then all inequalities."""
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """Yield, in sorted order, every symbol used by some constraint."""
        s = set()
        for constraint in self.constraints():
            # symbols() is a method on the constraint; call it.
            s.update(constraint.symbols())
        yield from sorted(s)

    @property
    def dimension(self):
        """Number of distinct symbols used by the constraints."""
        # symbols() is a generator, so it has to be materialized first.
        return len(list(self.symbols()))

    def __bool__(self):
        # return false if the polyhedron is empty, true otherwise
        raise NotImplementedError

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def isempty(self):
        # NOTE: relies on __eq__, which is not implemented yet.
        return self == empty

    def isuniverse(self):
        # NOTE: relies on __eq__, which is not implemented yet.
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        # test whether every element in other is in the polyhedron
        raise NotImplementedError

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """
        Return a new polyhedron with the points common to this polyhedron
        and all `others`.

        This is the conjunction of all their constraints; the resulting
        constraint lists are concatenated as-is, not simplified.
        """
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        if not isinstance(other, Polyhedron):
            return NotImplemented
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = ['{} == 0'.format(constraint)
                       for constraint in self.equalities]
        constraints += ['{} <= 0'.format(constraint)
                        for constraint in self.inequalities]
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        return '{}(equalities={!r}, inequalities={!r})' \
            ''.format(self.__class__.__name__, equalities, inequalities)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError
456
457
# The polyhedron containing every point: it has no constraints at all.
universe = Polyhedron()

# The polyhedron containing no point: its single constraint, 1 <= 0,
# can never hold.
empty = le(1, 0)