Better implementation of _polymorphic_operator
[linpy.git] / pypol / linear.py
index 6550e92..a5f55fa 100644 (file)
@@ -14,6 +14,31 @@ __all__ = [
 ]
 
 
 ]
 
 
def _polymorphic_method(func):
    """Let a binary Expression method accept a rational right operand.

    If the right operand is already an Expression, it is passed through
    unchanged. If it is a rational number, it is wrapped with `constant`
    first. For any other type the wrapper returns NotImplemented, so that
    Python's operator machinery can try the reflected operation.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            # Only rationals can be promoted to a constant expression.
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = constant(b)
        return func(a, b)
    return wrapper
+
def _polymorphic_operator(func):
    """Turn a two-argument function into a module-level linear operator.

    A rational left operand is promoted with `constant`. Any other
    non-Expression left operand raises TypeError. The right operand is left
    to the polymorphic method that *func* ends up calling.

    Such a method signals an unsupported right operand by returning
    NotImplemented. When *func* goes through a Python operator (`a <= b`),
    the interpreter already turns that into a TypeError. When *func* calls
    the method directly (e.g. `eq` calls `a._eq(b)`), NotImplemented would
    leak to the caller. So it is converted into the same TypeError here.

    Raises:
        TypeError: if either operand is neither an Expression nor a rational.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            a = constant(a)
        elif not isinstance(a, Expression):
            raise TypeError('arguments must be linear expressions')
        result = func(a, b)
        if result is NotImplemented:
            # Right operand rejected by a directly-called polymorphic method.
            raise TypeError('arguments must be linear expressions')
        return result
    return wrapper
+
+
 class Expression:
     """
     This class implements linear expressions.
 class Expression:
     """
     This class implements linear expressions.
@@ -96,18 +121,7 @@ class Expression:
     def __neg__(self):
         return self * -1
 
     def __neg__(self):
         return self * -1
 
-    def _polymorphic(func):
-        @functools.wraps(func)
-        def wrapper(self, other):
-            if isinstance(other, Expression):
-                return func(self, other)
-            if isinstance(other, numbers.Rational):
-                other = Expression(constant=other)
-                return func(self, other)
-            return NotImplemented
-        return wrapper
-
-    @_polymorphic
+    @_polymorphic_method
     def __add__(self, other):
         coefficients = dict(self.coefficients())
         for symbol, coefficient in other.coefficients():
     def __add__(self, other):
         coefficients = dict(self.coefficients())
         for symbol, coefficient in other.coefficients():
@@ -120,7 +134,7 @@ class Expression:
 
     __radd__ = __add__
 
 
     __radd__ = __add__
 
-    @_polymorphic
+    @_polymorphic_method
     def __sub__(self, other):
         coefficients = dict(self.coefficients())
         for symbol, coefficient in other.coefficients():
     def __sub__(self, other):
         coefficients = dict(self.coefficients())
         for symbol, coefficient in other.coefficients():
@@ -133,7 +147,7 @@ class Expression:
 
     __rsub__ = __sub__
 
 
     __rsub__ = __sub__
 
-    @_polymorphic
+    @_polymorphic_method
     def __mul__(self, other):
         if other.isconstant():
             coefficients = dict(self.coefficients())
     def __mul__(self, other):
         if other.isconstant():
             coefficients = dict(self.coefficients())
@@ -148,7 +162,7 @@ class Expression:
 
     __rmul__ = __mul__
 
 
     __rmul__ = __mul__
 
-    @_polymorphic
+    @_polymorphic_method
     def __truediv__(self, other):
         if other.isconstant():
             coefficients = dict(self.coefficients())
     def __truediv__(self, other):
         if other.isconstant():
             coefficients = dict(self.coefficients())
@@ -230,7 +244,7 @@ class Expression:
     def fromstring(cls, string):
         raise NotImplementedError
 
     def fromstring(cls, string):
         raise NotImplementedError
 
-    @_polymorphic
+    @_polymorphic_method
     def __eq__(self, other):
         # "normal" equality
         # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
     def __eq__(self, other):
         # "normal" equality
         # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
@@ -246,29 +260,32 @@ class Expression:
                 [value.denominator for value in self.values()])
         return self * lcm
 
                 [value.denominator for value in self.values()])
         return self * lcm
 
-    @_polymorphic
+    @_polymorphic_method
     def _eq(self, other):
         return Polyhedron(equalities=[(self - other)._canonify()])
 
     def _eq(self, other):
         return Polyhedron(equalities=[(self - other)._canonify()])
 
-    @_polymorphic
+    @_polymorphic_method
     def __le__(self, other):
         return Polyhedron(inequalities=[(self - other)._canonify()])
 
     def __le__(self, other):
         return Polyhedron(inequalities=[(self - other)._canonify()])
 
-    @_polymorphic
+    @_polymorphic_method
     def __lt__(self, other):
         return Polyhedron(inequalities=[(self - other)._canonify() + 1])
 
     def __lt__(self, other):
         return Polyhedron(inequalities=[(self - other)._canonify() + 1])
 
-    @_polymorphic
+    @_polymorphic_method
     def __ge__(self, other):
         return Polyhedron(inequalities=[(other - self)._canonify()])
 
     def __ge__(self, other):
         return Polyhedron(inequalities=[(other - self)._canonify()])
 
-    @_polymorphic
+    @_polymorphic_method
     def __gt__(self, other):
         return Polyhedron(inequalities=[(other - self)._canonify() + 1])
 
 
 def constant(numerator=0, denominator=None):
     def __gt__(self, other):
         return Polyhedron(inequalities=[(other - self)._canonify() + 1])
 
 
 def constant(numerator=0, denominator=None):
-    return Expression(constant=Fraction(numerator, denominator))
+    if denominator is None and isinstance(numerator, numbers.Rational):
+        return Expression(constant=numerator)
+    else:
+        return Expression(constant=Fraction(numerator, denominator))
 
 def symbol(name):
     if not isinstance(name, str):
 
 def symbol(name):
     if not isinstance(name, str):
@@ -281,35 +298,23 @@ def symbols(names):
     return (symbol(name) for name in names)
 
 
     return (symbol(name) for name in names)
 
 
-def _operator(func):
-    @functools.wraps(func)
-    def wrapper(a, b):
-        if isinstance(a, numbers.Rational):
-            a = constant(a)
-        if isinstance(b, numbers.Rational):
-            b = constant(b)
-        if isinstance(a, Expression) and isinstance(b, Expression):
-            return func(a, b)
-        raise TypeError('arguments must be linear expressions')
-    return wrapper
-
-@_operator
+@_polymorphic_operator
 def eq(a, b):
     return a._eq(b)
 
 def eq(a, b):
     return a._eq(b)
 
-@_operator
+@_polymorphic_operator
 def le(a, b):
     return a <= b
 
 def le(a, b):
     return a <= b
 
-@_operator
+@_polymorphic_operator
 def lt(a, b):
     return a < b
 
 def lt(a, b):
     return a < b
 
-@_operator
+@_polymorphic_operator
 def ge(a, b):
     return a >= b
 
 def ge(a, b):
     return a >= b
 
-@_operator
+@_polymorphic_operator
 def gt(a, b):
     return a > b
 
 def gt(a, b):
     return a > b