Improve tests involving iterators
[linpy.git] / pypol / linexprs.py
index 9a1ed64..ef5d90b 100644 (file)
@@ -9,7 +9,7 @@ from fractions import Fraction, gcd
 
 __all__ = [
     'Expression',
 
 __all__ = [
     'Expression',
-    'Symbol', 'symbols',
+    'Symbol', 'Dummy', 'symbols',
     'Rational',
 ]
 
     'Rational',
 ]
 
@@ -61,7 +61,7 @@ class Expression:
         self = object().__new__(cls)
         self._coefficients = OrderedDict()
         for symbol, coefficient in sorted(coefficients,
         self = object().__new__(cls)
         self._coefficients = OrderedDict()
         for symbol, coefficient in sorted(coefficients,
-                key=lambda item: item[0].name):
+                key=lambda item: item[0].sortkey()):
             if isinstance(coefficient, Rational):
                 coefficient = coefficient.constant
             if not isinstance(coefficient, numbers.Rational):
             if isinstance(coefficient, Rational):
                 coefficient = coefficient.constant
             if not isinstance(coefficient, numbers.Rational):
@@ -355,7 +355,7 @@ class Symbol(Expression):
         return self._name
 
     def __hash__(self):
         return self._name
 
     def __hash__(self):
-        return hash(self._name)
+        # Hash the sortkey() tuple: equal Symbols share a name and therefore a
+        # sortkey, so they hash alike, while subclasses that extend sortkey()
+        # (Dummy adds its index) automatically get a finer-grained hash.
+        return hash(self.sortkey())
 
     def coefficient(self, symbol):
         if not isinstance(symbol, Symbol):
 
     def coefficient(self, symbol):
         if not isinstance(symbol, Symbol):
@@ -380,6 +380,9 @@ class Symbol(Expression):
     def dimension(self):
         return 1
 
     def dimension(self):
         return 1
 
+    def sortkey(self):
+        """Return the key used to order and hash this symbol.
+
+        For a plain Symbol this is the 1-tuple ``(name,)``; Expression sorts
+        its coefficients by it and __hash__ hashes it.
+        """
+        return self.name,
+
     def issymbol(self):
         return True
 
     def issymbol(self):
         return True
 
@@ -387,7 +390,11 @@ class Symbol(Expression):
         yield 1
 
     def __eq__(self, other):
         yield 1
 
     def __eq__(self, other):
-        return isinstance(other, Symbol) and self.name == other.name
+        # Dummy is a Symbol subclass, but a Dummy must never equal a plain
+        # Symbol, even one with the same name, so exclude it explicitly.
+        return not isinstance(other, Dummy) and isinstance(other, Symbol) \
+            and self.name == other.name
+
+    def asdummy(self):
+        """Return a new Dummy carrying this symbol's name.
+
+        Every call creates a fresh Dummy with its own index, so the result
+        compares equal neither to this symbol nor to dummies returned by
+        earlier calls.
+        """
+        return Dummy(self.name)
 
     @classmethod
     def _fromast(cls, node):
 
     @classmethod
     def _fromast(cls, node):
@@ -408,6 +415,34 @@ class Symbol(Expression):
             raise TypeError('expr must be a sympy.Symbol instance')
 
 
             raise TypeError('expr must be a sympy.Symbol instance')
 
 
+class Dummy(Symbol):
+    """A symbol that stays unique even when another symbol has the same name.
+
+    Each instance receives a distinct integer index from a class-wide
+    counter. Equality compares only that index, so two Dummy objects with the
+    same name are different symbols, and a Dummy never equals a plain Symbol.
+    """
+
+    # NOTE(review): if Symbol already lists '_name' in its own __slots__ (not
+    # visible in this diff), re-declaring it here creates a second slot that
+    # shadows the base one; keeping only '_index' may be intended. Confirm
+    # against Symbol's definition.
+    __slots__ = (
+        '_name',
+        '_index',
+    )
+
+    # Class-wide counter: supplies both the default name suffix and the
+    # unique index of every new Dummy.
+    _count = 0
+
+    def __new__(cls, name=None):
+        """Create a new, unique dummy symbol.
+
+        name: display name (presumably a str, since .strip() is called on it);
+        when None, a name of the form 'Dummy_<n>' is generated from the
+        counter. Surrounding whitespace is stripped.
+        """
+        if name is None:
+            name = 'Dummy_{}'.format(Dummy._count)
+        # Instance is allocated directly, so Symbol.__new__ (and any checks it
+        # may perform) is not run for Dummy.
+        self = object().__new__(cls)
+        self._name = name.strip()
+        # Read/update Dummy._count rather than cls._count so that any
+        # subclasses share one counter and indices never collide.
+        self._index = Dummy._count
+        Dummy._count += 1
+        return self
+
+    # Must be redefined here: a class that defines __eq__ without __hash__
+    # gets __hash__ set to None and becomes unhashable.
+    def __hash__(self):
+        return hash(self.sortkey())
+
+    def sortkey(self):
+        """Return ``(name, index)``; the index keeps same-named dummies apart."""
+        return self._name, self._index
+
+    def __eq__(self, other):
+        """Return True only for another Dummy with the same unique index."""
+        return isinstance(other, Dummy) and self._index == other._index
+
 def symbols(names):
     if isinstance(names, str):
         names = names.replace(',', ' ').split()
 def symbols(names):
     if isinstance(names, str):
         names = names.replace(',', ' ').split()