e73449e9af3429f9abc445c68103275127b225f8
6 from collections
import OrderedDict
, defaultdict
, Mapping
7 from fractions
import Fraction
, gcd
12 'Symbol', 'Dummy', 'symbols',
17 def _polymorphic(func
):
18 @functools.wraps(func
)
19 def wrapper(left
, right
):
20 if isinstance(right
, Expression
):
21 return func(left
, right
)
22 elif isinstance(right
, numbers
.Rational
):
23 right
= Rational(right
)
24 return func(left
, right
)
31 This class implements linear expressions.
34 def __new__(cls
, coefficients
=None, constant
=0):
35 if isinstance(coefficients
, str):
37 raise TypeError('too many arguments')
38 return Expression
.fromstring(coefficients
)
39 if coefficients
is None:
40 return Rational(constant
)
41 if isinstance(coefficients
, Mapping
):
42 coefficients
= coefficients
.items()
43 for symbol
, coefficient
in coefficients
:
44 if not isinstance(symbol
, Symbol
):
45 raise TypeError('symbols must be Symbol instances')
46 if not isinstance(coefficient
, numbers
.Rational
):
47 raise TypeError('coefficients must be Rational instances')
48 coefficients
= [(symbol
, Fraction(coefficient
))
49 for symbol
, coefficient
in coefficients
if coefficient
!= 0]
50 if not isinstance(constant
, numbers
.Rational
):
51 raise TypeError('constant must be a Rational instance')
52 constant
= Fraction(constant
)
53 if len(coefficients
) == 0:
54 return Rational(constant
)
55 if len(coefficients
) == 1 and constant
== 0:
56 symbol
, coefficient
= coefficients
[0]
59 self
= object().__new
__(cls
)
60 self
._coefficients
= OrderedDict(sorted(coefficients
,
61 key
=lambda item
: item
[0].sortkey()))
62 self
._constant
= constant
63 self
._symbols
= tuple(self
._coefficients
)
64 self
._dimension
= len(self
._symbols
)
67 def coefficient(self
, symbol
):
68 if not isinstance(symbol
, Symbol
):
69 raise TypeError('symbol must be a Symbol instance')
71 return Rational(self
._coefficients
[symbol
])
75 __getitem__
= coefficient
77 def coefficients(self
):
78 for symbol
, coefficient
in self
._coefficients
.items():
79 yield symbol
, Rational(coefficient
)
83 return Rational(self
._constant
)
91 return self
._dimension
94 return hash((tuple(self
._coefficients
.items()), self
._constant
))
103 for coefficient
in self
._coefficients
.values():
104 yield Rational(coefficient
)
105 yield Rational(self
._constant
)
117 def __add__(self
, other
):
118 coefficients
= defaultdict(Fraction
, self
._coefficients
)
119 for symbol
, coefficient
in other
._coefficients
.items():
120 coefficients
[symbol
] += coefficient
121 constant
= self
._constant
+ other
._constant
122 return Expression(coefficients
, constant
)
127 def __sub__(self
, other
):
128 coefficients
= defaultdict(Fraction
, self
._coefficients
)
129 for symbol
, coefficient
in other
._coefficients
.items():
130 coefficients
[symbol
] -= coefficient
131 constant
= self
._constant
- other
._constant
132 return Expression(coefficients
, constant
)
134 def __rsub__(self
, other
):
135 return -(self
- other
)
138 def __mul__(self
, other
):
139 if isinstance(other
, Rational
):
140 return other
.__rmul
__(self
)
141 return NotImplemented
146 def __truediv__(self
, other
):
147 if isinstance(other
, Rational
):
148 return other
.__rtruediv
__(self
)
149 return NotImplemented
151 __rtruediv__
= __truediv__
154 def __eq__(self
, other
):
156 # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
157 return isinstance(other
, Expression
) and \
158 self
._coefficients
== other
._coefficients
and \
159 self
._constant
== other
._constant
162 def __le__(self
, other
):
163 from .polyhedra
import Le
164 return Le(self
, other
)
167 def __lt__(self
, other
):
168 from .polyhedra
import Lt
169 return Lt(self
, other
)
172 def __ge__(self
, other
):
173 from .polyhedra
import Ge
174 return Ge(self
, other
)
177 def __gt__(self
, other
):
178 from .polyhedra
import Gt
179 return Gt(self
, other
)
182 lcm
= functools
.reduce(lambda a
, b
: a
*b
// gcd(a
, b
),
183 [value
.denominator
for value
in self
.values()])
186 def subs(self
, symbol
, expression
=None):
187 if expression
is None:
188 if isinstance(symbol
, Mapping
):
189 symbol
= symbol
.items()
190 substitutions
= symbol
192 substitutions
= [(symbol
, expression
)]
194 for symbol
, expression
in substitutions
:
195 if not isinstance(symbol
, Symbol
):
196 raise TypeError('symbols must be Symbol instances')
197 coefficients
= [(othersymbol
, coefficient
)
198 for othersymbol
, coefficient
in result
._coefficients
.items()
199 if othersymbol
!= symbol
]
200 coefficient
= result
._coefficients
.get(symbol
, 0)
201 constant
= result
._constant
202 result
= Expression(coefficients
, constant
) + coefficient
*expression
206 def _fromast(cls
, node
):
207 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
208 return cls
._fromast
(node
.body
[0])
209 elif isinstance(node
, ast
.Expr
):
210 return cls
._fromast
(node
.value
)
211 elif isinstance(node
, ast
.Name
):
212 return Symbol(node
.id)
213 elif isinstance(node
, ast
.Num
):
214 return Rational(node
.n
)
215 elif isinstance(node
, ast
.UnaryOp
) and isinstance(node
.op
, ast
.USub
):
216 return -cls
._fromast
(node
.operand
)
217 elif isinstance(node
, ast
.BinOp
):
218 left
= cls
._fromast
(node
.left
)
219 right
= cls
._fromast
(node
.right
)
220 if isinstance(node
.op
, ast
.Add
):
222 elif isinstance(node
.op
, ast
.Sub
):
224 elif isinstance(node
.op
, ast
.Mult
):
226 elif isinstance(node
.op
, ast
.Div
):
228 raise SyntaxError('invalid syntax')
230 _RE_NUM_VAR
= re
.compile(r
'(\d+|\))\s*([^\W\d_]\w*|\()')
233 def fromstring(cls
, string
):
234 # add implicit multiplication operators, e.g. '5x' -> '5*x'
235 string
= Expression
._RE
_NUM
_VAR
.sub(r
'\1*\2', string
)
236 tree
= ast
.parse(string
, 'eval')
237 return cls
._fromast
(tree
)
241 for i
, (symbol
, coefficient
) in enumerate(self
.coefficients()):
243 string
+= '' if i
== 0 else ' + '
244 string
+= '{!r}'.format(symbol
)
245 elif coefficient
== -1:
246 string
+= '-' if i
== 0 else ' - '
247 string
+= '{!r}'.format(symbol
)
250 string
+= '{}*{!r}'.format(coefficient
, symbol
)
251 elif coefficient
> 0:
252 string
+= ' + {}*{!r}'.format(coefficient
, symbol
)
254 string
+= ' - {}*{!r}'.format(-coefficient
, symbol
)
255 constant
= self
.constant
257 string
+= '{}'.format(constant
)
259 string
+= ' + {}'.format(constant
)
261 string
+= ' - {}'.format(-constant
)
264 def _parenstr(self
, always
=False):
266 if not always
and (self
.isconstant() or self
.issymbol()):
269 return '({})'.format(string
)
272 def fromsympy(cls
, expr
):
276 for symbol
, coefficient
in expr
.as_coefficients_dict().items():
277 coefficient
= Fraction(coefficient
.p
, coefficient
.q
)
278 if symbol
== sympy
.S
.One
:
279 constant
= coefficient
280 elif isinstance(symbol
, sympy
.Symbol
):
281 symbol
= Symbol(symbol
.name
)
282 coefficients
.append((symbol
, coefficient
))
284 raise ValueError('non-linear expression: {!r}'.format(expr
))
285 return Expression(coefficients
, constant
)
290 for symbol
, coefficient
in self
.coefficients():
291 term
= coefficient
* sympy
.Symbol(symbol
.name
)
293 expr
+= self
.constant
297 class Symbol(Expression
):
299 def __new__(cls
, name
):
300 if not isinstance(name
, str):
301 raise TypeError('name must be a string')
302 self
= object().__new
__(cls
)
303 self
._name
= name
.strip()
304 self
._coefficients
= {self
: 1}
306 self
._symbols
= (self
,)
315 return hash(self
.sortkey())
323 def __eq__(self
, other
):
324 return not isinstance(other
, Dummy
) and isinstance(other
, Symbol
) \
325 and self
.name
== other
.name
328 return Dummy(self
.name
)
331 def _fromast(cls
, node
):
332 if isinstance(node
, ast
.Module
) and len(node
.body
) == 1:
333 return cls
._fromast
(node
.body
[0])
334 elif isinstance(node
, ast
.Expr
):
335 return cls
._fromast
(node
.value
)
336 elif isinstance(node
, ast
.Name
):
337 return Symbol(node
.id)
338 raise SyntaxError('invalid syntax')
344 def fromsympy(cls
, expr
):
346 if isinstance(expr
, sympy
.Symbol
):
347 return cls(expr
.name
)
349 raise TypeError('expr must be a sympy.Symbol instance')
356 def __new__(cls
, name
=None):
358 name
= 'Dummy_{}'.format(Dummy
._count
)
359 self
= object().__new
__(cls
)
360 self
._index
= Dummy
._count
361 self
._name
= name
.strip()
362 self
._coefficients
= {self
: 1}
364 self
._symbols
= (self
,)
370 return hash(self
.sortkey())
373 return self
._name
, self
._index
375 def __eq__(self
, other
):
376 return isinstance(other
, Dummy
) and self
._index
== other
._index
379 return '_{}'.format(self
.name
)
383 if isinstance(names
, str):
384 names
= names
.replace(',', ' ').split()
385 return tuple(Symbol(name
) for name
in names
)
388 class Rational(Expression
, Fraction
):
390 def __new__(cls
, numerator
=0, denominator
=None):
391 self
= Fraction
.__new
__(cls
, numerator
, denominator
)
392 self
._coefficients
= {}
393 self
._constant
= Fraction(self
)
399 return Fraction
.__hash
__(self
)
405 def isconstant(self
):
409 return Fraction
.__bool
__(self
)
412 def __mul__(self
, other
):
413 coefficients
= dict(other
._coefficients
)
414 for symbol
in coefficients
:
415 coefficients
[symbol
] *= self
._constant
416 constant
= other
._constant
* self
._constant
417 return Expression(coefficients
, constant
)
422 def __rtruediv__(self
, other
):
423 coefficients
= dict(other
._coefficients
)
424 for symbol
in coefficients
:
425 coefficients
[symbol
] /= self
._constant
426 constant
= other
._constant
/ self
._constant
427 return Expression(coefficients
, constant
)
430 def fromstring(cls
, string
):
431 if not isinstance(string
, str):
432 raise TypeError('string must be a string instance')
433 return Rational(Fraction(string
))
436 def fromsympy(cls
, expr
):
438 if isinstance(expr
, sympy
.Rational
):
439 return Rational(expr
.p
, expr
.q
)
440 elif isinstance(expr
, numbers
.Rational
):
441 return Rational(expr
)
443 raise TypeError('expr must be a sympy.Rational instance')