X-Git-Url: https://svn.cri.ensmp.fr/git/linpy.git/blobdiff_plain/d06ab92943ec2e10a2bd798ca7c1b5cea395bf34..53cfc921440668ddabd7d8192fada69348486a7f:/pypol/linexprs.py diff --git a/pypol/linexprs.py b/pypol/linexprs.py index 9a1ed64..ef5d90b 100644 --- a/pypol/linexprs.py +++ b/pypol/linexprs.py @@ -9,7 +9,7 @@ from fractions import Fraction, gcd __all__ = [ 'Expression', - 'Symbol', 'symbols', + 'Symbol', 'Dummy', 'symbols', 'Rational', ] @@ -61,7 +61,7 @@ class Expression: self = object().__new__(cls) self._coefficients = OrderedDict() for symbol, coefficient in sorted(coefficients, - key=lambda item: item[0].name): + key=lambda item: item[0].sortkey()): if isinstance(coefficient, Rational): coefficient = coefficient.constant if not isinstance(coefficient, numbers.Rational): @@ -355,7 +355,7 @@ class Symbol(Expression): return self._name def __hash__(self): - return hash(self._name) + return hash(self.sortkey()) def coefficient(self, symbol): if not isinstance(symbol, Symbol): @@ -380,6 +380,9 @@ class Symbol(Expression): def dimension(self): return 1 + def sortkey(self): + return self.name, + def issymbol(self): return True @@ -387,7 +390,11 @@ class Symbol(Expression): yield 1 def __eq__(self, other): - return isinstance(other, Symbol) and self.name == other.name + return not isinstance(other, Dummy) and isinstance(other, Symbol) \ + and self.name == other.name + + def asdummy(self): + return Dummy(self.name) @classmethod def _fromast(cls, node): @@ -408,6 +415,34 @@ class Symbol(Expression): raise TypeError('expr must be a sympy.Symbol instance') +class Dummy(Symbol): + + __slots__ = ( + '_name', + '_index', + ) + + _count = 0 + + def __new__(cls, name=None): + if name is None: + name = 'Dummy_{}'.format(Dummy._count) + self = object().__new__(cls) + self._name = name.strip() + self._index = Dummy._count + Dummy._count += 1 + return self + + def __hash__(self): + return hash(self.sortkey()) + + def sortkey(self): + return self._name, self._index + + def __eq__(self, other): + return isinstance(other, Dummy) and self._index == other._index + + def symbols(names): if isinstance(names, str): names = names.replace(',', ' ').split()