Return NotImplemented in Point special methods
[linpy.git] / linpy / linexprs.py
index b2cec53..d2554a0 100644 (file)
@@ -122,7 +122,7 @@ class LinExpr:
         """
         if not isinstance(symbol, Symbol):
             raise TypeError('symbol must be a Symbol instance')
         """
         if not isinstance(symbol, Symbol):
             raise TypeError('symbol must be a Symbol instance')
-        return self._coefficients.get(symbol, 0)
+        return self._coefficients.get(symbol, Fraction(0))
 
     __getitem__ = coefficient
 
 
     __getitem__ = coefficient
 
@@ -131,8 +131,7 @@ class LinExpr:
         Iterate over the pairs (symbol, value) of linear terms in the
         expression. The constant term is ignored.
         """
         Iterate over the pairs (symbol, value) of linear terms in the
         expression. The constant term is ignored.
         """
-        for symbol, coefficient in self._coefficients.items():
-            yield symbol, coefficient
+        yield from self._coefficients.items()
 
     @property
     def constant(self):
 
     @property
     def constant(self):
@@ -179,8 +178,7 @@ class LinExpr:
         Iterate over the coefficient values in the expression, and the constant
         term.
         """
         Iterate over the coefficient values in the expression, and the constant
         term.
         """
-        for coefficient in self._coefficients.values():
-            yield coefficient
+        yield from self._coefficients.values()
         yield self._constant
 
     def __bool__(self):
         yield self._constant
 
     def __bool__(self):
@@ -249,9 +247,10 @@ class LinExpr:
         """
         Test whether two linear expressions are equal.
         """
         """
         Test whether two linear expressions are equal.
         """
-        return isinstance(other, LinExpr) and \
-            self._coefficients == other._coefficients and \
-            self._constant == other._constant
+        if isinstance(other, LinExpr):
+            return self._coefficients == other._coefficients and \
+                self._constant == other._constant
+        return NotImplemented
 
     def __le__(self, other):
         from .polyhedra import Le
 
     def __le__(self, other):
         from .polyhedra import Le
@@ -274,9 +273,9 @@ class LinExpr:
         Return the expression multiplied by its lowest common denominator to
         make all values integer.
         """
         Return the expression multiplied by its lowest common denominator to
         make all values integer.
         """
-        lcm = functools.reduce(lambda a, b: a*b // gcd(a, b),
+        lcd = functools.reduce(lambda a, b: a*b // gcd(a, b),
             [value.denominator for value in self.values()])
             [value.denominator for value in self.values()])
-        return self * lcm
+        return self * lcd
 
     def subs(self, symbol, expression=None):
         """
 
     def subs(self, symbol, expression=None):
         """
@@ -295,21 +294,16 @@ class LinExpr:
         2*x + y + 1
         """
         if expression is None:
         2*x + y + 1
         """
         if expression is None:
-            if isinstance(symbol, Mapping):
-                symbol = symbol.items()
-            substitutions = symbol
+            substitutions = dict(symbol)
         else:
         else:
-            substitutions = [(symbol, expression)]
-        result = self
-        for symbol, expression in substitutions:
+            substitutions = {symbol: expression}
+        for symbol in substitutions:
             if not isinstance(symbol, Symbol):
                 raise TypeError('symbols must be Symbol instances')
             if not isinstance(symbol, Symbol):
                 raise TypeError('symbols must be Symbol instances')
-            coefficients = [(othersymbol, coefficient)
-                for othersymbol, coefficient in result._coefficients.items()
-                if othersymbol != symbol]
-            coefficient = result._coefficients.get(symbol, 0)
-            constant = result._constant
-            result = LinExpr(coefficients, constant) + coefficient*expression
+        result = self._constant
+        for symbol, coefficient in self._coefficients.items():
+            expression = substitutions.get(symbol, symbol)
+            result += coefficient * expression
         return result
 
     @classmethod
         return result
 
     @classmethod
@@ -337,7 +331,7 @@ class LinExpr:
                 return left / right
         raise SyntaxError('invalid syntax')
 
                 return left / right
         raise SyntaxError('invalid syntax')
 
-    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()')
+    _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d]\w*|\()')
 
     @classmethod
     def fromstring(cls, string):
 
     @classmethod
     def fromstring(cls, string):
@@ -345,7 +339,7 @@ class LinExpr:
         Create an expression from a string. Raise SyntaxError if the string is
         not properly formatted.
         """
         Create an expression from a string. Raise SyntaxError if the string is
         not properly formatted.
         """
-        # add implicit multiplication operators, e.g. '5x' -> '5*x'
+        # Add implicit multiplication operators, e.g. '5x' -> '5*x'.
         string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string)
         tree = ast.parse(string, 'eval')
         expr = cls._fromast(tree)
         string = LinExpr._RE_NUM_VAR.sub(r'\1*\2', string)
         tree = ast.parse(string, 'eval')
         expr = cls._fromast(tree)
@@ -422,7 +416,8 @@ class LinExpr:
             if symbol == sympy.S.One:
                 constant = coefficient
             elif isinstance(symbol, sympy.Dummy):
             if symbol == sympy.S.One:
                 constant = coefficient
             elif isinstance(symbol, sympy.Dummy):
-                # we cannot properly convert dummy symbols
+                # We cannot properly convert dummy symbols with respect to
+                # symbol equalities.
                 raise TypeError('cannot convert dummy symbols')
             elif isinstance(symbol, sympy.Symbol):
                 symbol = Symbol(symbol.name)
                 raise TypeError('cannot convert dummy symbols')
             elif isinstance(symbol, sympy.Symbol):
                 symbol = Symbol(symbol.name)
@@ -456,6 +451,13 @@ class Symbol(LinExpr):
     Two instances of Symbol are equal if they have the same name.
     """
 
     Two instances of Symbol are equal if they have the same name.
     """
 
+    __slots__ = (
+        '_name',
+        '_constant',
+        '_symbols',
+        '_dimension',
+    )
+
     def __new__(cls, name):
         """
         Return a symbol with the name string given in argument.
     def __new__(cls, name):
         """
         Return a symbol with the name string given in argument.
@@ -469,12 +471,17 @@ class Symbol(LinExpr):
             raise SyntaxError('invalid syntax')
         self = object().__new__(cls)
         self._name = name
             raise SyntaxError('invalid syntax')
         self = object().__new__(cls)
         self._name = name
-        self._coefficients = {self: Fraction(1)}
         self._constant = Fraction(0)
         self._symbols = (self,)
         self._dimension = 1
         return self
 
         self._constant = Fraction(0)
         self._symbols = (self,)
         self._dimension = 1
         return self
 
+    @property
+    def _coefficients(self):
+        # This is not implemented as an attribute, because __hash__ is not
+        # callable in __new__ in class Dummy.
+        return {self: Fraction(1)}
+
     @property
     def name(self):
         """
     @property
     def name(self):
         """
@@ -499,7 +506,9 @@ class Symbol(LinExpr):
         return True
 
     def __eq__(self, other):
         return True
 
     def __eq__(self, other):
-        return self.sortkey() == other.sortkey()
+        if isinstance(other, Symbol):
+            return self.sortkey() == other.sortkey()
+        return NotImplemented
 
     def asdummy(self):
         """
 
     def asdummy(self):
         """
@@ -557,15 +566,8 @@ class Dummy(Symbol):
         """
         if name is None:
             name = 'Dummy_{}'.format(Dummy._count)
         """
         if name is None:
             name = 'Dummy_{}'.format(Dummy._count)
-        elif not isinstance(name, str):
-            raise TypeError('name must be a string')
-        self = object().__new__(cls)
+        self = super().__new__(cls, name)
         self._index = Dummy._count
         self._index = Dummy._count
-        self._name = name.strip()
-        self._coefficients = {self: Fraction(1)}
-        self._constant = Fraction(0)
-        self._symbols = (self,)
-        self._dimension = 1
         Dummy._count += 1
         return self
 
         Dummy._count += 1
         return self
 
@@ -590,6 +592,13 @@ class Rational(LinExpr, Fraction):
     fractions.Fraction classes.
     """
 
     fractions.Fraction classes.
     """
 
+    __slots__ = (
+        '_coefficients',
+        '_constant',
+        '_symbols',
+        '_dimension',
+    ) + Fraction.__slots__
+
     def __new__(cls, numerator=0, denominator=None):
         self = object().__new__(cls)
         self._coefficients = {}
     def __new__(cls, numerator=0, denominator=None):
         self = object().__new__(cls)
         self._coefficients = {}