# eff4a7ecdab8bdd5ca30a7f3fc933e03c70f7d87
# [linpy.git] / linpy / linexprs.py
# Copyright 2014 MINES ParisTech
#
# This file is part of LinPy.
#
# LinPy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# LinPy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with LinPy. If not, see <http://www.gnu.org/licenses/>.

import ast
import functools
import numbers
import re

# Mapping lives in collections.abc (the alias in collections was removed in
# Python 3.10), and gcd in math (fractions.gcd was removed in Python 3.9).
from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from fractions import Fraction

try:
    from math import gcd
except ImportError:  # Python < 3.5 has no math.gcd.
    from fractions import gcd
25
26
# Public names exported by "from linexprs import *".
__all__ = [
    'LinExpr',
    'Symbol',
    'Dummy',
    'symbols',
    'Rational',
]
32
33
def _polymorphic(func):
    """
    Decorate a binary LinExpr method so that its right operand may be either
    a LinExpr or any rational number. Rational numbers are converted to a
    Rational instance before func is called; any other operand type makes
    the wrapper return NotImplemented.
    """
    @functools.wraps(func)
    def wrapper(left, right):
        if not isinstance(right, LinExpr):
            if not isinstance(right, numbers.Rational):
                return NotImplemented
            right = Rational(right)
        return func(left, right)
    return wrapper
44
45
class LinExpr:
    """
    A linear expression consists of a list of coefficient-variable pairs
    that capture the linear terms, plus a constant term. Linear expressions
    are used to build constraints. They are temporary objects that typically
    have short lifespans.

    Linear expressions are generally built using overloaded operators. For
    example, if x is a Symbol, then x + 1 is an instance of LinExpr.

    LinExpr instances are hashable, and should be treated as immutable.
    """

    def __new__(cls, coefficients=None, constant=0):
        """
        Return a linear expression from a dictionary or a sequence, that maps
        symbols to their coefficients, and a constant term. The coefficients and
        the constant term must be rational numbers.

        For example, the linear expression x + 2*y + 1 can be constructed using
        one of the following instructions:

        >>> x, y = symbols('x y')
        >>> LinExpr({x: 1, y: 2}, 1)
        >>> LinExpr([(x, 1), (y, 2)], 1)

        However, it may be easier to use overloaded operators:

        >>> x, y = symbols('x y')
        >>> x + 2*y + 1

        Alternatively, linear expressions can be constructed from a string:

        >>> LinExpr('x + 2y + 1')

        Terms with a zero coefficient are discarded first. A linear expression
        with a single symbol of coefficient 1 and no constant term is then
        automatically subclassed as a Symbol instance. A linear expression with
        no symbol, only a constant term, is automatically subclassed as a
        Rational instance. For example, x - x is a Rational and x + y - y is
        the Symbol x, so equal expressions always have the same type and hash.

        Raise TypeError if a symbol is not a Symbol instance, or if a
        coefficient or the constant term is not a rational number.
        """
        if isinstance(coefficients, str):
            if constant != 0:
                raise TypeError('too many arguments')
            return LinExpr.fromstring(coefficients)
        if coefficients is None:
            return Rational(constant)
        if isinstance(coefficients, Mapping):
            coefficients = coefficients.items()
        coefficients = list(coefficients)
        for symbol, coefficient in coefficients:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
            if not isinstance(coefficient, numbers.Rational):
                raise TypeError('coefficients must be rational numbers')
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        # Drop null terms *before* deciding whether the result collapses to a
        # Rational or a Symbol; otherwise x - x would be a LinExpr without any
        # symbol, and x + y - y a LinExpr equal to x but hashing differently.
        # When a symbol is repeated, its last non-zero coefficient wins.
        terms = {}
        for symbol, coefficient in coefficients:
            if coefficient != 0:
                terms[symbol] = Fraction(coefficient)
        if len(terms) == 0:
            return Rational(constant)
        if len(terms) == 1 and constant == 0:
            symbol, coefficient = next(iter(terms.items()))
            if coefficient == 1:
                return symbol
        self = object.__new__(cls)
        self._coefficients = OrderedDict(
            sorted(terms.items(), key=lambda item: item[0].sortkey()))
        self._constant = Fraction(constant)
        self._symbols = tuple(self._coefficients)
        self._dimension = len(self._symbols)
        return self

    def coefficient(self, symbol):
        """
        Return the coefficient value of the given symbol, or 0 if the symbol
        does not appear in the expression. Raise TypeError if symbol is not a
        Symbol instance.
        """
        if not isinstance(symbol, Symbol):
            raise TypeError('symbol must be a Symbol instance')
        return self._coefficients.get(symbol, Fraction(0))

    __getitem__ = coefficient

    def coefficients(self):
        """
        Iterate over the pairs (symbol, value) of linear terms in the
        expression. The constant term is ignored.
        """
        yield from self._coefficients.items()

    @property
    def constant(self):
        """
        The constant term of the expression.
        """
        return self._constant

    @property
    def symbols(self):
        """
        The tuple of symbols present in the expression, sorted according to
        Symbol.sortkey().
        """
        return self._symbols

    @property
    def dimension(self):
        """
        The dimension of the expression, i.e. the number of symbols present in
        it.
        """
        return self._dimension

    def __hash__(self):
        return hash((tuple(self._coefficients.items()), self._constant))

    def isconstant(self):
        """
        Return True if the expression only consists of a constant term. In this
        case, it is a Rational instance.
        """
        return False

    def issymbol(self):
        """
        Return True if an expression only consists of a symbol with coefficient
        1. In this case, it is a Symbol instance.
        """
        return False

    def values(self):
        """
        Iterate over the coefficient values in the expression, and the constant
        term.
        """
        yield from self._coefficients.values()
        yield self._constant

    def __bool__(self):
        # A LinExpr instance always has at least one non-zero term; the
        # constant-only case is handled by Rational.__bool__.
        return True

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic
    def __add__(self, other):
        """
        Return the sum of two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] += coefficient
        constant = self._constant + other._constant
        return LinExpr(coefficients, constant)

    __radd__ = __add__

    @_polymorphic
    def __sub__(self, other):
        """
        Return the difference between two linear expressions.
        """
        coefficients = defaultdict(Fraction, self._coefficients)
        for symbol, coefficient in other._coefficients.items():
            coefficients[symbol] -= coefficient
        constant = self._constant - other._constant
        return LinExpr(coefficients, constant)

    @_polymorphic
    def __rsub__(self, other):
        return other - self

    def __mul__(self, other):
        """
        Return the product of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient * other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant * other
            return LinExpr(coefficients, constant)
        return NotImplemented

    __rmul__ = __mul__

    def __truediv__(self, other):
        """
        Return the quotient of the linear expression by a rational.
        """
        if isinstance(other, numbers.Rational):
            coefficients = ((symbol, coefficient / other)
                for symbol, coefficient in self._coefficients.items())
            constant = self._constant / other
            return LinExpr(coefficients, constant)
        return NotImplemented

    @_polymorphic
    def __eq__(self, other):
        """
        Test whether two linear expressions are equal. Unlike methods
        LinExpr.__lt__(), LinExpr.__le__(), LinExpr.__ge__(), LinExpr.__gt__(),
        the result is a boolean value, not a polyhedron. To express that two
        linear expressions are equal or not equal, use functions Eq() and Ne()
        instead.
        """
        return self._coefficients == other._coefficients and \
            self._constant == other._constant

    # The comparison operators build a Polyhedron from a single expression
    # (presumably read by Polyhedron as "expression >= 0" - confirm in
    # polyhedra.py). Strict comparisons subtract 1, which assumes integer
    # points.

    @_polymorphic
    def __lt__(self, other):
        from .polyhedra import Polyhedron
        return Polyhedron([], [other - self - 1])

    @_polymorphic
    def __le__(self, other):
        from .polyhedra import Polyhedron
        return Polyhedron([], [other - self])

    @_polymorphic
    def __ge__(self, other):
        from .polyhedra import Polyhedron
        return Polyhedron([], [self - other])

    @_polymorphic
    def __gt__(self, other):
        from .polyhedra import Polyhedron
        return Polyhedron([], [self - other - 1])

    def scaleint(self):
        """
        Return the expression multiplied by its lowest common denominator to
        make all values integer.
        """
        # values() always yields at least the constant term, so reduce() never
        # receives an empty sequence.
        lcd = functools.reduce(lambda a, b: a*b // gcd(a, b),
            [value.denominator for value in self.values()])
        return self * lcd

    def subs(self, symbol, expression=None):
        """
        Substitute the given symbol by an expression and return the resulting
        expression. Raise TypeError if the resulting expression is not linear.

        >>> x, y = symbols('x y')
        >>> e = x + 2*y + 1
        >>> e.subs(y, x - 1)
        3*x - 1

        To perform multiple substitutions at once, pass a sequence or a
        dictionary of (old, new) pairs to subs.

        >>> e.subs({x: y, y: x})
        2*x + y + 1

        The result is always a LinExpr instance (a Rational when every symbol
        is replaced by a number).
        """
        if expression is None:
            substitutions = dict(symbol)
        else:
            substitutions = {symbol: expression}
        for symbol in substitutions:
            if not isinstance(symbol, Symbol):
                raise TypeError('symbols must be Symbol instances')
        # Start from a Rational rather than the raw Fraction constant, so that
        # every addition goes through LinExpr.__add__ and the result is never
        # a plain Fraction.
        result = Rational(self._constant)
        for symbol, coefficient in self._coefficients.items():
            expression = substitutions.get(symbol, symbol)
            result += coefficient * expression
        return result

    @classmethod
    def _fromast(cls, node):
        # Recursively convert a parsed Python AST into a linear expression.
        # Names, numeric literals, unary + and -, and the binary operators
        # +, -, * and / are accepted; any other construct raises SyntaxError,
        # and non-linear products raise TypeError from the operators.
        if isinstance(node, ast.Expression):
            return cls._fromast(node.body)
        elif isinstance(node, ast.Module) and len(node.body) == 1:
            return cls._fromast(node.body[0])
        elif isinstance(node, ast.Expr):
            return cls._fromast(node.value)
        elif isinstance(node, ast.Name):
            return Symbol(node.id)
        elif type(node).__name__ in ('Constant', 'Num'):
            # Numbers are ast.Constant nodes (attribute 'value') since Python
            # 3.8 and ast.Num nodes (attribute 'n') before. ast.Num was
            # removed in Python 3.14, so it is not referenced directly.
            value = node.value if hasattr(node, 'value') else node.n
            if isinstance(value, bool) or not isinstance(value, numbers.Real):
                raise SyntaxError('invalid syntax')
            return Rational(value)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -cls._fromast(node.operand)
        elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.UAdd):
            return +cls._fromast(node.operand)
        elif isinstance(node, ast.BinOp):
            left = cls._fromast(node.left)
            right = cls._fromast(node.right)
            if isinstance(node.op, ast.Add):
                return left + right
            elif isinstance(node.op, ast.Sub):
                return left - right
            elif isinstance(node.op, ast.Mult):
                return left * right
            elif isinstance(node.op, ast.Div):
                return left / right
        raise SyntaxError('invalid syntax')

    # A number or closing parenthesis directly followed by a name or an
    # opening parenthesis, e.g. '5x', '2 (x)' or ')('. The lookbehind keeps
    # digits that end an identifier (as in 'x2y') from being split.
    _RE_NUM_VAR = re.compile(r'((?<!\w)\d+|\))\s*([^\W\d]\w*|\()')

    @classmethod
    def fromstring(cls, string):
        """
        Create an expression from a string. Raise SyntaxError if the string is
        not properly formatted.
        """
        # Add implicit multiplication operators, e.g. '5x' -> '5*x'.
        string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string)
        # mode must be passed by keyword: the second positional argument of
        # ast.parse() is the filename. Surrounding whitespace is stripped
        # because a leading indent is a syntax error for the parser.
        tree = ast.parse(string.strip(), mode='eval')
        expr = cls._fromast(tree)
        if not isinstance(expr, cls):
            raise SyntaxError('invalid syntax')
        return expr

    def __repr__(self):
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += '{}*'.format(coefficient)
            elif coefficient > 0:
                string += ' + {}*'.format(coefficient)
            else:
                string += ' - {}*'.format(-coefficient)
            string += '{}'.format(symbol)
        constant = self.constant
        if len(string) == 0:
            string += '{}'.format(constant)
        elif constant > 0:
            string += ' + {}'.format(constant)
        elif constant < 0:
            string += ' - {}'.format(-constant)
        return string

    def _repr_latex_(self):
        def latex(value):
            # Coefficients and the constant are stored as plain Fraction
            # objects, which have no _repr_latex_(); wrap them in Rational.
            return Rational(value)._repr_latex_().strip('$')
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                if i != 0:
                    string += ' + '
            elif coefficient == -1:
                string += '-' if i == 0 else ' - '
            elif i == 0:
                string += latex(coefficient)
            elif coefficient > 0:
                string += ' + {}'.format(latex(coefficient))
            elif coefficient < 0:
                string += ' - {}'.format(latex(-coefficient))
            string += symbol._repr_latex_().strip('$')
        constant = self.constant
        if len(string) == 0:
            string += latex(constant)
        elif constant > 0:
            string += ' + {}'.format(latex(constant))
        elif constant < 0:
            string += ' - {}'.format(latex(-constant))
        return '$${}$$'.format(string)

    def _parenstr(self, always=False):
        # Return str(self), wrapped in parentheses unless the expression is a
        # constant or a single symbol (or always is True).
        string = str(self)
        if not always and (self.isconstant() or self.issymbol()):
            return string
        else:
            return '({})'.format(string)

    @classmethod
    def fromsympy(cls, expr):
        """
        Create a linear expression from a SymPy expression. Raise TypeError if
        the SymPy expression is not linear, contains dummy symbols, or cannot
        be converted to an instance of cls.
        """
        import sympy
        coefficients = []
        constant = 0
        for symbol, coefficient in expr.as_coefficients_dict().items():
            coefficient = Fraction(coefficient.p, coefficient.q)
            if symbol == sympy.S.One:
                constant = coefficient
            elif isinstance(symbol, sympy.Dummy):
                # We cannot properly convert dummy symbols with respect to
                # symbol equalities.
                raise TypeError('cannot convert dummy symbols')
            elif isinstance(symbol, sympy.Symbol):
                symbol = Symbol(symbol.name)
                coefficients.append((symbol, coefficient))
            else:
                raise TypeError('non-linear expression: {!r}'.format(expr))
        expr = LinExpr(coefficients, constant)
        if not isinstance(expr, cls):
            raise TypeError('cannot convert to a {} instance'.format(cls.__name__))
        return expr

    def tosympy(self):
        """
        Convert the linear expression to a SymPy expression. Symbols are
        mapped to sympy.Symbol instances with the same name.
        """
        import sympy
        # Build SymPy rationals explicitly: adding Fraction/Rational values to
        # the integer 0 would otherwise yield a linpy Rational, not a SymPy
        # object, for constant-only expressions.
        expr = sympy.Integer(0)
        for symbol, coefficient in self.coefficients():
            factor = sympy.Rational(coefficient.numerator, coefficient.denominator)
            expr += factor * sympy.Symbol(symbol.name)
        constant = self.constant
        expr += sympy.Rational(constant.numerator, constant.denominator)
        return expr
449
450
class Symbol(LinExpr):
    """
    Symbols are the basic components to build expressions and constraints.
    They correspond to mathematical variables. Symbols are instances of
    class LinExpr and inherit its functionalities.

    Two instances of Symbol are equal if they have the same name.
    """

    __slots__ = (
        '_name',
        '_constant',
        '_symbols',
        '_dimension',
    )

    def __new__(cls, name):
        """
        Return a symbol with the name string given in argument. Raise
        TypeError if name is not a string, and SyntaxError if it does not
        parse as a single Python name.
        """
        if not isinstance(name, str):
            raise TypeError('name must be a string')
        # Parse in 'eval' mode so that exactly one expression is accepted: an
        # empty string or 'x; y' raise SyntaxError, instead of an IndexError
        # or a symbol silently named 'x'.
        node = ast.parse(name, mode='eval')
        if not isinstance(node.body, ast.Name):
            raise SyntaxError('invalid syntax')
        self = object.__new__(cls)
        self._name = node.body.id
        self._constant = Fraction(0)
        self._symbols = (self,)
        self._dimension = 1
        return self

    @property
    def _coefficients(self):
        # This is not implemented as an attribute, because __hash__ is not
        # callable in __new__ in class Dummy.
        return {self: Fraction(1)}

    @property
    def name(self):
        """
        The name of the symbol.
        """
        return self._name

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        """
        Return a sorting key for the symbol. It is useful to sort a list of
        symbols in a consistent order, as comparison functions are overridden
        (see the documentation of class LinExpr).

        >>> sorted(symbols, key=Symbol.sortkey)
        """
        return self.name,

    def issymbol(self):
        return True

    def __eq__(self, other):
        # Symbols compare by sort key (their name, plus the index for Dummy).
        if isinstance(other, Symbol):
            return self.sortkey() == other.sortkey()
        return NotImplemented

    def asdummy(self):
        """
        Return a new Dummy symbol instance with the same name.
        """
        return Dummy(self.name)

    def __repr__(self):
        return self.name

    def _repr_latex_(self):
        return '$${}$$'.format(self.name)
530
531
def symbols(names):
    """
    Return a tuple of Symbol instances, one per name. The names are given
    either as a single string, delimited by commas and/or whitespace, or as a
    sequence of strings. This makes it easy to define several symbols at once.

    >>> x, y = symbols('x y')
    >>> x, y = symbols('x, y')
    >>> x, y = symbols(['x', 'y'])
    """
    if isinstance(names, str):
        # Treat commas as separators too, then split on any whitespace.
        names = names.replace(',', ' ').split()
    return tuple(map(Symbol, names))
545
546
class Dummy(Symbol):
    """
    A variation of Symbol in which all symbols are unique and identified by
    an internal count index. If a name is not supplied then a string value
    of the count index will be used. This is useful when a unique, temporary
    variable is needed and the name of the variable used in the expression
    is not important.

    Unlike Symbol, Dummy instances with the same name are not equal:

    >>> x = Symbol('x')
    >>> x1, x2 = Dummy('x'), Dummy('x')
    >>> x == x1
    False
    >>> x1 == x2
    False
    >>> x1 == x1
    True
    """

    # Class-wide counter; each successfully created Dummy takes the current
    # value as its index and increments it.
    _count = 0

    def __new__(cls, name=None):
        """
        Return a fresh dummy symbol with the name string given in argument.
        """
        if name is None:
            name = 'Dummy_{}'.format(Dummy._count)
        self = super().__new__(cls, name)
        self._index, Dummy._count = Dummy._count, Dummy._count + 1
        return self

    def __hash__(self):
        return hash(self.sortkey())

    def sortkey(self):
        # The index makes same-named dummies distinct from one another and
        # from a plain Symbol of that name.
        key = (self._name, self._index)
        return key

    def __repr__(self):
        return '_{}'.format(self.name)

    def _repr_latex_(self):
        return '$${}_{{{}}}$$'.format(self.name, self._index)
591
592
class Rational(LinExpr, Fraction):
    """
    A particular case of linear expressions are rational values, i.e. linear
    expressions consisting only of a constant term, with no symbol. They are
    implemented by the Rational class, that inherits from both LinExpr and
    fractions.Fraction classes.
    """

    __slots__ = (
        '_coefficients',
        '_constant',
        '_symbols',
        '_dimension',
    ) + Fraction.__slots__

    def __new__(cls, numerator=0, denominator=None):
        # Let Fraction validate and normalize the arguments, then fill in both
        # the LinExpr attributes and the _numerator/_denominator slots that
        # Fraction declares.
        value = Fraction(numerator, denominator)
        self = object.__new__(cls)
        self._coefficients = {}
        self._constant = value
        self._symbols = ()
        self._dimension = 0
        self._numerator = value.numerator
        self._denominator = value.denominator
        return self

    def __hash__(self):
        # Hash like the equal Fraction (and int), not like a LinExpr.
        return Fraction.__hash__(self)

    @property
    def constant(self):
        # A rational value is its own constant term.
        return self

    def isconstant(self):
        return True

    def __bool__(self):
        return Fraction.__bool__(self)

    def __repr__(self):
        numerator, denominator = self.numerator, self.denominator
        if denominator != 1:
            return '{!r}/{!r}'.format(numerator, denominator)
        return '{!r}'.format(numerator)

    def _repr_latex_(self):
        numerator, denominator = self.numerator, self.denominator
        if denominator == 1:
            return '$${}$$'.format(numerator)
        if numerator < 0:
            # Put the sign in front of the fraction rather than inside it.
            return '$$-\\frac{{{}}}{{{}}}$$'.format(-numerator, denominator)
        return '$$\\frac{{{}}}{{{}}}$$'.format(numerator, denominator)