X-Git-Url: https://svn.cri.ensmp.fr/git/linpy.git/blobdiff_plain/f2561050230a9c56f842acb698853f6998528aaa..29ed88d1a15d283ea6f3340a4dd97e8cc7c2d2d4:/pypol/linexprs.py?ds=inline diff --git a/pypol/linexprs.py b/pypol/linexprs.py index 10daf9d..ccd1564 100644 --- a/pypol/linexprs.py +++ b/pypol/linexprs.py @@ -3,13 +3,13 @@ import functools import numbers import re -from collections import OrderedDict +from collections import OrderedDict, defaultdict from fractions import Fraction, gcd __all__ = [ 'Expression', - 'Symbol', 'symbols', 'symbolname', 'symbolnames', + 'Symbol', 'symbols', 'Constant', ] @@ -36,37 +36,38 @@ class Expression: '_constant', '_symbols', '_dimension', - '_hash', ) def __new__(cls, coefficients=None, constant=0): if isinstance(coefficients, str): if constant: raise TypeError('too many arguments') - return cls.fromstring(coefficients) - if isinstance(coefficients, dict): - coefficients = coefficients.items() + return Expression.fromstring(coefficients) if coefficients is None: return Constant(constant) + if isinstance(coefficients, dict): + coefficients = coefficients.items() + for symbol, coefficient in coefficients: + if not isinstance(symbol, Symbol): + raise TypeError('symbols must be Symbol instances') coefficients = [(symbol, coefficient) for symbol, coefficient in coefficients if coefficient != 0] if len(coefficients) == 0: return Constant(constant) - elif len(coefficients) == 1 and constant == 0: + if len(coefficients) == 1 and constant == 0: symbol, coefficient = coefficients[0] if coefficient == 1: - return Symbol(symbol) + return symbol self = object().__new__(cls) - self._coefficients = {} - for symbol, coefficient in coefficients: - symbol = symbolname(symbol) + self._coefficients = OrderedDict() + for symbol, coefficient in sorted(coefficients, + key=lambda item: item[0].name): if isinstance(coefficient, Constant): coefficient = coefficient.constant if not isinstance(coefficient, numbers.Rational): raise TypeError('coefficients must be rational numbers ' 'or Constant instances') self._coefficients[symbol] = coefficient - self._coefficients = OrderedDict(sorted(self._coefficients.items())) if isinstance(constant, Constant): constant = constant.constant if not isinstance(constant, numbers.Rational): @@ -75,11 +76,11 @@ class Expression: self._constant = constant self._symbols = tuple(self._coefficients) self._dimension = len(self._symbols) - self._hash = hash((tuple(self._coefficients.items()), self._constant)) return self def coefficient(self, symbol): - symbol = symbolname(symbol) + if not isinstance(symbol, Symbol): + raise TypeError('symbol must be a Symbol instance') try: return self._coefficients[symbol] except KeyError: @@ -103,7 +104,7 @@ class Expression: return self._dimension def __hash__(self): - return self._hash + return hash((tuple(self._coefficients.items()), self._constant)) def isconstant(self): return False @@ -112,8 +113,7 @@ class Expression: return False def values(self): - for symbol in self.symbols: - yield self.coefficient(symbol) + yield from self._coefficients.values() yield self.constant def __bool__(self): @@ -127,12 +127,9 @@ class Expression: @_polymorphic def __add__(self, other): - coefficients = dict(self.coefficients()) + coefficients = defaultdict(Constant, self.coefficients()) for symbol, coefficient in other.coefficients(): - if symbol in coefficients: - coefficients[symbol] += coefficient - else: - coefficients[symbol] = coefficient + coefficients[symbol] += coefficient constant = self.constant + other.constant return Expression(coefficients, constant) @@ -140,12 +137,9 @@ class Expression: @_polymorphic def __sub__(self, other): - coefficients = dict(self.coefficients()) + coefficients = defaultdict(Constant, self.coefficients()) for symbol, coefficient in other.coefficients(): - if symbol in coefficients: - coefficients[symbol] -= coefficient - else: - coefficients[symbol] = -coefficient + coefficients[symbol] -= coefficient constant = self.constant - other.constant return Expression(coefficients, constant) @@ -172,9 +166,8 @@ class Expression: if other.isconstant(): coefficients = dict(self.coefficients()) for symbol in coefficients: - coefficients[symbol] = \ - Fraction(coefficients[symbol], other.constant) - constant = Fraction(self.constant, other.constant) + coefficients[symbol] = Constant(coefficients[symbol], other.constant) + constant = Constant(self.constant, other.constant) return Expression(coefficients, constant) if isinstance(other, Expression): raise ValueError('non-linear expression: ' @@ -184,8 +177,7 @@ class Expression: def __rtruediv__(self, other): if isinstance(other, self): if self.isconstant(): - constant = Fraction(other, self.constant) - return Expression(constant=constant) + return Constant(other, self.constant) else: raise ValueError('non-linear expression: ' '{} / {}'.format(other._parenstr(), self._parenstr())) @@ -196,8 +188,8 @@ class Expression: # "normal" equality # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs return isinstance(other, Expression) and \ - self._coefficients == other._coefficients and \ - self.constant == other.constant + self._coefficients == other._coefficients and \ + self.constant == other.constant @_polymorphic def __le__(self, other): @@ -219,11 +211,28 @@ class Expression: from .polyhedra import Gt return Gt(self, other) - def _toint(self): + def scaleint(self): lcm = functools.reduce(lambda a, b: a*b // gcd(a, b), [value.denominator for value in self.values()]) return self * lcm + def subs(self, symbol, expression=None): + if expression is None: + if isinstance(symbol, dict): + symbol = symbol.items() + substitutions = symbol + else: + substitutions = [(symbol, expression)] + result = self + for symbol, expression in substitutions: + coefficients = [(othersymbol, coefficient) + for othersymbol, coefficient in result.coefficients() + if othersymbol != symbol] + coefficient = result.coefficient(symbol) + constant = result.constant + result = Expression(coefficients, constant) + coefficient*expression + return result + @classmethod def _fromast(cls, node): if isinstance(node, ast.Module) and len(node.body) == 1: @@ -249,46 +258,23 @@ class Expression: return left / right raise SyntaxError('invalid syntax') - def subs(self, symbol, expression=None): - if expression is None: - if isinstance(symbol, dict): - symbol = symbol.items() - substitutions = symbol - else: - substitutions = [(symbol, expression)] - result = self - for symbol, expression in substitutions: - symbol = symbolname(symbol) - result = result._subs(symbol, expression) - return result - - def _subs(self, symbol, expression): - coefficients = {name: coefficient - for name, coefficient in self.coefficients() - if name != symbol} - constant = self.constant - coefficient = self.coefficient(symbol) - result = Expression(coefficients, self.constant) - result += coefficient * expression - return result - _RE_NUM_VAR = re.compile(r'(\d+|\))\s*([^\W\d_]\w*|\()') @classmethod def fromstring(cls, string): # add implicit multiplication operators, e.g. '5x' -> '5*x' - string = cls._RE_NUM_VAR.sub(r'\1*\2', string) + string = Expression._RE_NUM_VAR.sub(r'\1*\2', string) tree = ast.parse(string, 'eval') return cls._fromast(tree) - def __str__(self): + def __repr__(self): string = '' i = 0 for symbol in self.symbols: coefficient = self.coefficient(symbol) if coefficient == 1: if i == 0: - string += symbol + string += symbol.name else: string += ' + {}'.format(symbol) elif coefficient == -1: @@ -325,30 +311,27 @@ class Expression: else: return '({})'.format(string) - def __repr__(self): - return '{}({!r})'.format(self.__class__.__name__, str(self)) - @classmethod def fromsympy(cls, expr): import sympy - coefficients = {} + coefficients = [] constant = 0 for symbol, coefficient in expr.as_coefficients_dict().items(): coefficient = Fraction(coefficient.p, coefficient.q) if symbol == sympy.S.One: constant = coefficient elif isinstance(symbol, sympy.Symbol): - symbol = symbol.name - coefficients[symbol] = coefficient + symbol = Symbol(symbol.name) + coefficients.append((symbol, coefficient)) else: raise ValueError('non-linear expression: {!r}'.format(expr)) - return cls(coefficients, constant) + return Expression(coefficients, constant) def tosympy(self): import sympy expr = 0 for symbol, coefficient in self.coefficients(): - term = coefficient * sympy.Symbol(symbol) + term = coefficient * sympy.Symbol(symbol.name) expr += term expr += self.constant return expr @@ -358,14 +341,13 @@ class Symbol(Expression): __slots__ = ( '_name', - '_hash', ) def __new__(cls, name): - name = symbolname(name) + if not isinstance(name, str): + raise TypeError('name must be a string') self = object().__new__(cls) - self._name = name - self._hash = hash(self._name) + self._name = name.strip() return self @property @@ -373,17 +355,18 @@ class Symbol(Expression): return self._name def __hash__(self): - return self._hash + return hash(self._name) def coefficient(self, symbol): - symbol = symbolname(symbol) - if symbol == self.name: + if not isinstance(symbol, Symbol): + raise TypeError('symbol must be a Symbol instance') + if symbol == self: return 1 else: return 0 def coefficients(self): - yield self.name, 1 + yield self, 1 @property def constant(self): @@ -391,7 +374,7 @@ class Symbol(Expression): @property def symbols(self): - return self.name, + return self, @property def dimension(self): @@ -400,6 +383,9 @@ class Symbol(Expression): def issymbol(self): return True + def values(self): + yield 1 + def __eq__(self, other): return isinstance(other, Symbol) and self.name == other.name @@ -413,14 +399,11 @@ class Symbol(Expression): return Symbol(node.id) raise SyntaxError('invalid syntax') - def __repr__(self): - return '{}({!r})'.format(self.__class__.__name__, self._name) - @classmethod def fromsympy(cls, expr): import sympy if isinstance(expr, sympy.Symbol): - return cls(expr.name) + return Symbol(expr.name) else: raise TypeError('expr must be a sympy.Symbol instance') @@ -428,27 +411,13 @@ class Symbol(Expression): def symbols(names): if isinstance(names, str): names = names.replace(',', ' ').split() - return (Symbol(name) for name in names) - -def symbolname(symbol): - if isinstance(symbol, str): - return symbol.strip() - elif isinstance(symbol, Symbol): - return symbol.name - else: - raise TypeError('symbol must be a string or a Symbol instance') - -def symbolnames(symbols): - if isinstance(symbols, str): - return symbols.replace(',', ' ').split() - return tuple(symbolname(symbol) for symbol in symbols) + return tuple(Symbol(name) for name in names) class Constant(Expression): __slots__ = ( '_constant', - '_hash', ) def __new__(cls, numerator=0, denominator=None): @@ -457,18 +426,18 @@ class Constant(Expression): self._constant = numerator.constant else: self._constant = Fraction(numerator, denominator) - self._hash = hash(self._constant) return self def __hash__(self): - return self._hash + return hash(self.constant) def coefficient(self, symbol): - symbol = symbolname(symbol) + if not isinstance(symbol, Symbol): + raise TypeError('symbol must be a Symbol instance') return 0 def coefficients(self): - yield from [] + yield from () @property def symbols(self): @@ -481,6 +450,9 @@ class Constant(Expression): def isconstant(self): return True + def values(self): + yield self._constant + @_polymorphic def __eq__(self, other): return isinstance(other, Constant) and self.constant == other.constant @@ -490,25 +462,16 @@ class Constant(Expression): @classmethod def fromstring(cls, string): - if isinstance(string, str): - return Constant(Fraction(string)) - else: + if not isinstance(string, str): raise TypeError('string must be a string instance') - - def __repr__(self): - if self.constant.denominator == 1: - return '{}({!r})'.format(self.__class__.__name__, - self.constant.numerator) - else: - return '{}({!r}, {!r})'.format(self.__class__.__name__, - self.constant.numerator, self.constant.denominator) + return Constant(Fraction(string)) @classmethod def fromsympy(cls, expr): import sympy if isinstance(expr, sympy.Rational): - return cls(expr.p, expr.q) + return Constant(expr.p, expr.q) elif isinstance(expr, numbers.Rational): - return cls(expr) + return Constant(expr) else: raise TypeError('expr must be a sympy.Rational instance')