X-Git-Url: https://svn.cri.ensmp.fr/git/linpy.git/blobdiff_plain/29ed88d1a15d283ea6f3340a4dd97e8cc7c2d2d4..161d0ced692386a866e55aea673d991e2e95f753:/pypol/linexprs.py?ds=inline diff --git a/pypol/linexprs.py b/pypol/linexprs.py index ccd1564..f3cff23 100644 --- a/pypol/linexprs.py +++ b/pypol/linexprs.py @@ -9,8 +9,8 @@ from fractions import Fraction, gcd __all__ = [ 'Expression', - 'Symbol', 'symbols', - 'Constant', + 'Symbol', 'symbols', 'Dummy', + 'Rational', ] @@ -20,7 +20,7 @@ def _polymorphic(func): if isinstance(right, Expression): return func(left, right) elif isinstance(right, numbers.Rational): - right = Constant(right) + right = Rational(right) return func(left, right) return NotImplemented return wrapper @@ -44,7 +44,7 @@ class Expression: raise TypeError('too many arguments') return Expression.fromstring(coefficients) if coefficients is None: - return Constant(constant) + return Rational(constant) if isinstance(coefficients, dict): coefficients = coefficients.items() for symbol, coefficient in coefficients: @@ -53,7 +53,7 @@ class Expression: coefficients = [(symbol, coefficient) for symbol, coefficient in coefficients if coefficient != 0] if len(coefficients) == 0: - return Constant(constant) + return Rational(constant) if len(coefficients) == 1 and constant == 0: symbol, coefficient = coefficients[0] if coefficient == 1: @@ -61,18 +61,18 @@ class Expression: self = object().__new__(cls) self._coefficients = OrderedDict() for symbol, coefficient in sorted(coefficients, - key=lambda item: item[0].name): - if isinstance(coefficient, Constant): + key=lambda item: item[0].sortkey()): + if isinstance(coefficient, Rational): coefficient = coefficient.constant if not isinstance(coefficient, numbers.Rational): raise TypeError('coefficients must be rational numbers ' - 'or Constant instances') + 'or Rational instances') self._coefficients[symbol] = coefficient - if isinstance(constant, Constant): + if isinstance(constant, Rational): constant = constant.constant if not isinstance(constant, numbers.Rational): raise TypeError('constant must be a rational number ' - 'or a Constant instance') + 'or a Rational instance') self._constant = constant self._symbols = tuple(self._coefficients) self._dimension = len(self._symbols) @@ -127,7 +127,7 @@ class Expression: @_polymorphic def __add__(self, other): - coefficients = defaultdict(Constant, self.coefficients()) + coefficients = defaultdict(Rational, self.coefficients()) for symbol, coefficient in other.coefficients(): coefficients[symbol] += coefficient constant = self.constant + other.constant @@ -137,7 +137,7 @@ class Expression: @_polymorphic def __sub__(self, other): - coefficients = defaultdict(Constant, self.coefficients()) + coefficients = defaultdict(Rational, self.coefficients()) for symbol, coefficient in other.coefficients(): coefficients[symbol] -= coefficient constant = self.constant - other.constant @@ -166,8 +166,8 @@ class Expression: if other.isconstant(): coefficients = dict(self.coefficients()) for symbol in coefficients: - coefficients[symbol] = Constant(coefficients[symbol], other.constant) - constant = Constant(self.constant, other.constant) + coefficients[symbol] = Rational(coefficients[symbol], other.constant) + constant = Rational(self.constant, other.constant) return Expression(coefficients, constant) if isinstance(other, Expression): raise ValueError('non-linear expression: ' @@ -177,7 +177,7 @@ class Expression: def __rtruediv__(self, other): if isinstance(other, self): if self.isconstant(): - return Constant(other, self.constant) + return Rational(other, self.constant) else: raise ValueError('non-linear expression: ' '{} / {}'.format(other._parenstr(), self._parenstr())) @@ -242,7 +242,7 @@ class Expression: elif isinstance(node, ast.Name): return Symbol(node.id) elif isinstance(node, ast.Num): - return Constant(node.n) + return Rational(node.n) elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): return -cls._fromast(node.operand) elif isinstance(node, ast.BinOp): @@ -355,7 +355,7 @@ class Symbol(Expression): return self._name def __hash__(self): - return hash(self._name) + return hash(self.sortkey()) def coefficient(self, symbol): if not isinstance(symbol, Symbol): @@ -380,6 +380,9 @@ class Symbol(Expression): def dimension(self): return 1 + def sortkey(self): + return self.name, + def issymbol(self): return True @@ -387,7 +390,11 @@ class Symbol(Expression): yield 1 def __eq__(self, other): - return isinstance(other, Symbol) and self.name == other.name + return not isinstance(other, Dummy) and isinstance(other, Symbol) \ + and self.name == other.name + + def asdummy(self): + return Dummy(self.name) @classmethod def _fromast(cls, node): @@ -414,7 +421,35 @@ def symbols(names): return tuple(Symbol(name) for name in names) -class Constant(Expression): +class Dummy(Symbol): + + __slots__ = ( + '_name', + '_index', + ) + + _count = 0 + + def __new__(cls, name=None): + if name is None: + name = 'Dummy_{}'.format(Dummy._count) + self = object().__new__(cls) + self._name = name.strip() + self._index = Dummy._count + Dummy._count += 1 + return self + + def __hash__(self): + return hash(self.sortkey()) + + def sortkey(self): + return self._name, self._index + + def __eq__(self, other): + return isinstance(other, Dummy) and self._index == other._index + + +class Rational(Expression): __slots__ = ( '_constant', @@ -422,7 +457,7 @@ class Constant(Expression): def __new__(cls, numerator=0, denominator=None): self = object().__new__(cls) - if denominator is None and isinstance(numerator, Constant): + if denominator is None and isinstance(numerator, Rational): self._constant = numerator.constant else: self._constant = Fraction(numerator, denominator) @@ -455,7 +490,7 @@ class Constant(Expression): @_polymorphic def __eq__(self, other): - return isinstance(other, Constant) and self.constant == other.constant + return isinstance(other, Rational) and self.constant == other.constant def __bool__(self): return self.constant != 0 @@ -464,14 +499,14 @@ class Constant(Expression): def fromstring(cls, string): if not isinstance(string, str): raise TypeError('string must be a string instance') - return Constant(Fraction(string)) + return Rational(Fraction(string)) @classmethod def fromsympy(cls, expr): import sympy if isinstance(expr, sympy.Rational): - return Constant(expr.p, expr.q) + return Rational(expr.p, expr.q) elif isinstance(expr, numbers.Rational): - return Constant(expr) + return Rational(expr) else: raise TypeError('expr must be a sympy.Rational instance')