add small isl changes
[linpy.git] / pypol / linear.py
1 '''
2 VERY MESSY, made notes on where I will change things
3 '''
4
import ctypes
import ctypes.util
import functools
import numbers

try:
    from fractions import Fraction, gcd
except ImportError:
    # fractions.gcd was removed in Python 3.9.  math.gcd always returns a
    # non-negative result (fractions.gcd took the sign of its second
    # argument); the two agree on the positive denominators gcd is used on.
    from fractions import Fraction
    from math import gcd

from pypol import isl
11
# Locate and load the isl shared library.  ctypes.util.find_library() returns
# None when isl cannot be found; ctypes.CDLL(None) would then silently load
# the running process instead and the problem would only surface later as a
# confusing AttributeError, so fail early with an explicit ImportError.
_LIBISL_PATH = ctypes.util.find_library('isl')
if _LIBISL_PATH is None:
    raise ImportError('cannot find the isl shared library')
libisl = ctypes.CDLL(_LIBISL_PATH)

# isl_printer_get_str() returns a C string; let ctypes convert it to bytes.
# NOTE(review): isl hands ownership of that string to the caller, and the
# automatic c_char_p conversion never frees it — confirm whether it matters.
libisl.isl_printer_get_str.restype = ctypes.c_char_p

__all__ = [
    'Expression',
    'constant', 'symbol', 'symbols',
    'eq', 'le', 'lt', 'ge', 'gt',
    'Polyhedron',
    'empty', 'universe',
]


# Single isl context, passed to isl when Polyhedron.to_isl() allocates spaces.
_CONTEXT = isl.Context()
26
def _polymorphic_method(func):
    """Let a binary Expression method also accept a rational right operand.

    An Expression operand is passed through, a numbers.Rational operand is
    promoted with constant(), and any other operand makes the wrapper return
    NotImplemented so Python can try the reflected operation instead.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if not isinstance(b, Expression):
            if not isinstance(b, numbers.Rational):
                return NotImplemented
            b = constant(b)
        return func(a, b)
    return wrapper
37
def _polymorphic_operator(func):
    """Let a module-level comparison function accept a rational left operand.

    Only the left operand is checked here: the wrapped function is expected
    to call a _polymorphic_method, which takes care of the right operand.
    A rational left operand is promoted with constant(); anything that is
    neither rational nor an Expression raises TypeError.
    """
    @functools.wraps(func)
    def wrapper(a, b):
        if isinstance(a, numbers.Rational):
            return func(constant(a), b)
        if isinstance(a, Expression):
            return func(a, b)
        raise TypeError('arguments must be linear expressions')
    return wrapper
50
51
class Expression:
    """
    This class implements linear expressions.

    An expression is a sum of symbols (strings) weighted by non-zero rational
    coefficients, plus a rational constant term.  Arithmetic with other
    expressions or rational numbers returns new expressions; the comparison
    operators return Polyhedron instances.
    """

    def __new__(cls, coefficients=None, constant=0):
        """Create an expression.

        coefficients -- a dict or an iterable of (symbol, coefficient) pairs,
            where a symbol is a string or a single-symbol Expression and a
            coefficient is a rational number.  A string is delegated to
            fromstring() instead.
        constant -- the rational constant term.

        Raises TypeError on arguments of the wrong type.
        """
        if isinstance(coefficients, str):
            if constant:
                raise TypeError('too many arguments')
            return cls.fromstring(coefficients)
        self = super().__new__(cls)
        self._coefficients = {}
        if isinstance(coefficients, dict):
            coefficients = coefficients.items()
        if coefficients is not None:
            for symbol, coefficient in coefficients:
                if isinstance(symbol, Expression) and symbol.issymbol():
                    symbol = str(symbol)
                elif not isinstance(symbol, str):
                    raise TypeError('symbols must be strings')
                if not isinstance(coefficient, numbers.Rational):
                    raise TypeError('coefficients must be rational numbers')
                # Zero coefficients are not stored, so equal expressions
                # always have equal coefficient dictionaries.
                if coefficient != 0:
                    self._coefficients[symbol] = coefficient
        if not isinstance(constant, numbers.Rational):
            raise TypeError('constant must be a rational number')
        self._constant = constant
        return self

    def symbols(self):
        """Yield the symbols with a non-zero coefficient, in sorted order."""
        yield from sorted(self._coefficients)

    @property
    def dimension(self):
        """Number of symbols with a non-zero coefficient."""
        return len(self._coefficients)

    def coefficient(self, symbol):
        """Return the coefficient of *symbol*, or 0 if it does not appear.

        *symbol* may be a string or a single-symbol Expression; anything
        else raises TypeError.
        """
        if isinstance(symbol, Expression) and symbol.issymbol():
            symbol = str(symbol)
        elif not isinstance(symbol, str):
            raise TypeError('symbol must be a string')
        return self._coefficients.get(symbol, 0)

    __getitem__ = coefficient

    def coefficients(self):
        """Yield (symbol, coefficient) pairs in sorted symbol order."""
        for symbol in self.symbols():
            yield symbol, self._coefficients[symbol]

    @property
    def constant(self):
        """The constant term."""
        return self._constant

    def isconstant(self):
        """Return True if no symbol has a non-zero coefficient."""
        return not self._coefficients

    def values(self):
        """Yield the coefficients in sorted symbol order, then the constant."""
        for _, coefficient in self.coefficients():
            yield coefficient
        yield self.constant

    def values_int(self):
        """Yield the same values as values(), each converted with int().

        int() truncates non-integer rationals, so this is meant for
        expressions whose values are already integers (e.g. canonified ones).
        """
        # Previously this returned from inside its loop and so only ever
        # produced the first coefficient (or the constant).
        for value in self.values():
            yield int(value)

    def symbol(self):
        """Return the name of a single-symbol expression.

        Raises ValueError if the expression is not a single symbol.
        """
        if not self.issymbol():
            raise ValueError('not a symbol: {}'.format(self))
        return next(self.symbols())

    def issymbol(self):
        """Return True for a lone symbol: one coefficient equal to 1 and no
        constant term."""
        # The coefficient must be checked too; otherwise 2*x counted as a
        # symbol and could be registered as a symbol named '2*x'.
        if len(self._coefficients) != 1 or self._constant != 0:
            return False
        return next(iter(self._coefficients.values())) == 1

    def __bool__(self):
        """An expression is false only when it is the constant 0."""
        return (not self.isconstant()) or bool(self.constant)

    def __pos__(self):
        return self

    def __neg__(self):
        return self * -1

    @_polymorphic_method
    def __add__(self, other):
        """Return self + other."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) + coefficient
        constant = self.constant + other.constant
        return Expression(coefficients, constant)

    # Addition is commutative, so the reflected form can reuse __add__.
    __radd__ = __add__

    @_polymorphic_method
    def __sub__(self, other):
        """Return self - other."""
        coefficients = dict(self.coefficients())
        for symbol, coefficient in other.coefficients():
            coefficients[symbol] = coefficients.get(symbol, 0) - coefficient
        constant = self.constant - other.constant
        return Expression(coefficients, constant)

    @_polymorphic_method
    def __rsub__(self, other):
        """Return other - self (e.g. for ``1 - x``)."""
        # Subtraction is not commutative: aliasing __sub__ here used to
        # compute self - other instead.
        return other - self

    @_polymorphic_method
    def __mul__(self, other):
        """Return self * other; one of the operands must be constant.

        Raises ValueError when the product would not be linear.
        """
        if other.isconstant():
            coefficients = {symbol: coefficient * other.constant
                            for symbol, coefficient in self.coefficients()}
            constant = self.constant * other.constant
            return Expression(coefficients, constant)
        if self.isconstant():
            # Both operands are Expressions here, so Python would not try
            # __rmul__ after NotImplemented; swap the operands explicitly.
            return other * self
        raise ValueError('non-linear expression: '
                '{} * {}'.format(self._parenstr(), other._parenstr()))

    # Multiplication is commutative, so the reflected form reuses __mul__.
    __rmul__ = __mul__

    @_polymorphic_method
    def __truediv__(self, other):
        """Return self / other, where other must be constant.

        Values are divided exactly with fractions.Fraction.  Raises
        ValueError for a non-constant divisor and ZeroDivisionError for 0.
        """
        if not other.isconstant():
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(self._parenstr(), other._parenstr()))
        coefficients = {symbol: Fraction(coefficient, other.constant)
                        for symbol, coefficient in self.coefficients()}
        constant = Fraction(self.constant, other.constant)
        return Expression(coefficients, constant)

    def __rtruediv__(self, other):
        """Return other / self for a rational *other* (e.g. ``1 / c``).

        Only reached when the left operand is not an Expression, since
        Expression / Expression is handled by __truediv__.  Raises
        ValueError if self is not constant.
        """
        if not isinstance(other, numbers.Rational):
            return NotImplemented
        if not self.isconstant():
            raise ValueError('non-linear expression: '
                    '{} / {}'.format(other, self._parenstr()))
        return Expression(constant=Fraction(other, self.constant))

    def __str__(self):
        """Return a human-readable form such as ``2*x - y + 3``."""
        string = ''
        for i, (symbol, coefficient) in enumerate(self.coefficients()):
            if coefficient == 1:
                string += symbol if i == 0 else ' + {}'.format(symbol)
            elif coefficient == -1:
                string += '-{}'.format(symbol) if i == 0 \
                    else ' - {}'.format(symbol)
            elif i == 0:
                string += '{}*{}'.format(coefficient, symbol)
            elif coefficient > 0:
                string += ' + {}*{}'.format(coefficient, symbol)
            else:
                string += ' - {}*{}'.format(-coefficient, symbol)
        constant = self.constant
        if constant != 0:
            if self.isconstant():
                string += '{}'.format(constant)
            elif constant > 0:
                string += ' + {}'.format(constant)
            else:
                string += ' - {}'.format(-constant)
        return string or '0'

    def _parenstr(self, always=False):
        """Return str(self), in parentheses unless the expression is a
        constant or a symbol (or always, if *always* is true)."""
        string = str(self)
        if always or not (self.isconstant() or self.issymbol()):
            return '({})'.format(string)
        return string

    def __repr__(self):
        items = ', '.join('{!r}: {!r}'.format(symbol, coefficient)
                          for symbol, coefficient in self.coefficients())
        return '{}({{{}}}, {!r})'.format(self.__class__.__name__, items,
                                         self.constant)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    @_polymorphic_method
    def __eq__(self, other):
        # "normal" equality, not a constraint (use eq() to build one)
        # see http://docs.sympy.org/dev/tutorial/gotchas.html#equals-signs
        return self._coefficients == other._coefficients and \
            self.constant == other.constant

    def __hash__(self):
        # A dict is unhashable, so hash an order-independent frozenset of
        # its items.  Constant expressions compare equal to plain numbers
        # (see __eq__), so they must hash like that number.
        if self.isconstant():
            return hash(self._constant)
        return hash((frozenset(self._coefficients.items()), self._constant))

    def _canonify(self):
        """Return self scaled by the least common multiple of the
        denominators of its values, so that every value is an integer."""
        lcm = functools.reduce(lambda a, b: a * b // gcd(a, b),
                               [value.denominator for value in self.values()])
        return self * lcm

    @_polymorphic_method
    def _eq(self, other):
        """Return the polyhedron where self == other."""
        return Polyhedron(equalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __le__(self, other):
        """Return the polyhedron where self <= other."""
        return Polyhedron(inequalities=[(self - other)._canonify()])

    @_polymorphic_method
    def __lt__(self, other):
        """Return the polyhedron where self < other.

        After canonification the expression has integer values, so on
        integer points e < 0 is rewritten as e + 1 <= 0.
        """
        return Polyhedron(inequalities=[(self - other)._canonify() + 1])

    @_polymorphic_method
    def __ge__(self, other):
        """Return the polyhedron where self >= other."""
        return Polyhedron(inequalities=[(other - self)._canonify()])

    @_polymorphic_method
    def __gt__(self, other):
        """Return the polyhedron where self > other (see __lt__)."""
        return Polyhedron(inequalities=[(other - self)._canonify() + 1])
299
300
def constant(numerator=0, denominator=None):
    """Return the constant expression numerator / denominator.

    A rational numerator given without a denominator is stored unchanged;
    in every other case both arguments are handed to fractions.Fraction.
    """
    if denominator is not None or not isinstance(numerator, numbers.Rational):
        return Expression(constant=Fraction(numerator, denominator))
    return Expression(constant=numerator)
306
def symbol(name):
    """Return the expression made of the single symbol *name* (a string).

    Raises TypeError if *name* is not a string.
    """
    if isinstance(name, str):
        return Expression(coefficients={name: 1})
    raise TypeError('name must be a string')
311
def symbols(names):
    """Return a lazy iterator of symbol expressions.

    *names* is either an iterable of names or a single string whose names
    are separated by commas and/or whitespace, e.g. ``'x, y z'``.
    """
    if isinstance(names, str):
        tokens = names.replace(',', ' ').split()
    else:
        tokens = names
    return (symbol(token) for token in tokens)
316
317
# Function forms of the comparisons.  Either operand may be a rational
# number: the left one is promoted by _polymorphic_operator, the right one by
# the Expression comparison methods.  Each call returns a Polyhedron.

@_polymorphic_operator
def eq(a, b):
    """Return the polyhedron where a == b."""
    return a._eq(b)


@_polymorphic_operator
def le(a, b):
    """Return the polyhedron where a <= b."""
    return a <= b


@_polymorphic_operator
def lt(a, b):
    """Return the polyhedron where a < b."""
    return a < b


@_polymorphic_operator
def ge(a, b):
    """Return the polyhedron where a >= b."""
    return a >= b


@_polymorphic_operator
def gt(a, b):
    """Return the polyhedron where a > b."""
    return a > b
337
338
class _IslPointer(ctypes.c_void_p):
    """Opaque pointer to an isl object.

    With a plain c_void_p restype, ctypes returns a Python int, which would
    later be passed back to isl as a C int (truncated on 64-bit platforms)
    because no argtypes are declared.  Results typed as a *subclass* of
    c_void_p stay ctypes instances and are passed back as full pointers.
    """


# Declare the pointer-returning isl functions used by Polyhedron.to_isl();
# without a restype, ctypes assumes a foreign function returns a C int.
for _name in ('isl_space_set_alloc', 'isl_space_copy',
              'isl_basic_set_universe', 'isl_local_space_from_space',
              'isl_local_space_copy', 'isl_equality_alloc',
              'isl_inequality_alloc', 'isl_constraint_set_constant_si',
              'isl_constraint_set_coefficient_si',
              'isl_basic_set_add_constraint'):
    getattr(libisl, _name).restype = _IslPointer
del _name

# isl_dim_set in isl's enum isl_dim_type
# (isl_dim_cst = 0, isl_dim_param, isl_dim_in, isl_dim_out = isl_dim_set).
_ISL_DIM_SET = 3


class Polyhedron:
    """
    This class implements polyhedrons.

    A polyhedron is a conjunction of integer-coefficient linear constraints,
    stored as two lists of Expression objects: equalities, read as
    ``expression == 0``, and inequalities, read as ``expression <= 0``.
    """

    def __new__(cls, equalities=None, inequalities=None):
        """Create a polyhedron from iterables of equality and inequality
        expressions; a single string argument is delegated to fromstring().

        Raises TypeError if a constraint has a non-integer value.
        """
        if isinstance(equalities, str):
            if inequalities is not None:
                raise TypeError('too many arguments')
            return cls.fromstring(equalities)
        self = super().__new__(cls)
        self._equalities = cls._integer_constraints(equalities, '==')
        self._inequalities = cls._integer_constraints(inequalities, '<=')
        # The isl representation is built on demand by to_isl() (see
        # isempty()), so creating a polyhedron does not touch isl.
        return self

    @staticmethod
    def _integer_constraints(constraints, relation):
        """Return *constraints* as a list, checking every value is an integer.

        *relation* ('==' or '<=') is only used in the error message.
        """
        checked = []
        if constraints is not None:
            for constraint in constraints:
                for value in constraint.values():
                    if value.denominator != 1:
                        raise TypeError('non-integer constraint: '
                                '{} {} 0'.format(constraint, relation))
                checked.append(constraint)
        return checked

    @property
    def equalities(self):
        """Iterate over the equality expressions (``expr == 0``)."""
        yield from self._equalities

    @property
    def inequalities(self):
        """Iterate over the inequality expressions (``expr <= 0``)."""
        yield from self._inequalities

    def isempty(self):
        """Return True if isl reports that the polyhedron has no point.

        A temporary isl basic set is built with to_isl() and freed again.
        Raises RuntimeError if isl reports an error.
        """
        bset = self.to_isl()
        try:
            result = libisl.isl_basic_set_is_empty(bset)
        finally:
            libisl.isl_basic_set_free(bset)
        # isl_basic_set_is_empty() returns 1 (empty), 0, or -1 on error.
        if result < 0:
            raise RuntimeError('isl could not decide emptiness')
        return bool(result)

    def constraints(self):
        """Iterate over the equalities, then the inequalities."""
        yield from self.equalities
        yield from self.inequalities

    def symbols(self):
        """Yield, in sorted order, every symbol used by some constraint."""
        names = set()
        for constraint in self.constraints():
            names.update(constraint.symbols())
        yield from sorted(names)

    def symbol_count(self):
        """Return, for each constraint, the list of symbols it uses."""
        return [list(constraint.symbols())
                for constraint in self.constraints()]

    @property
    def dimension(self):
        """Number of distinct symbols used by the constraints."""
        return len(list(self.symbols()))

    def __bool__(self):
        # A polyhedron is true when it is not empty.  The previous check
        # returned False for every constrained polyhedron, even satisfiable
        # ones, and True only for the unconstrained universe.
        return not self.isempty()

    def __contains__(self, value):
        # is the value in the polyhedron?
        raise NotImplementedError

    def __eq__(self, other):
        raise NotImplementedError

    def is_empty(self):
        """Alias of isempty()."""
        return self.isempty()

    def isuniverse(self):
        return self == universe

    def isdisjoint(self, other):
        # return true if the polyhedron has no elements in common with other
        raise NotImplementedError

    def issubset(self, other):
        raise NotImplementedError

    def __le__(self, other):
        return self.issubset(other)

    def __lt__(self, other):
        raise NotImplementedError

    def issuperset(self, other):
        """Return True if every element of *other* is in the polyhedron."""
        # Defined through issubset(), which is not implemented yet.
        return other.issubset(self)

    def __ge__(self, other):
        return self.issuperset(other)

    def __gt__(self, other):
        raise NotImplementedError

    def union(self, *others):
        # return a new polyhedron with elements from the polyhedron and all
        # others (convex union)
        raise NotImplementedError

    def __or__(self, other):
        return self.union(other)

    def intersection(self, *others):
        """Return the polyhedron of points common to self and all *others*.

        The intersection of polyhedra is the conjunction of their
        constraints, so the constraint lists are simply concatenated.
        """
        equalities = list(self._equalities)
        inequalities = list(self._inequalities)
        for other in others:
            equalities.extend(other.equalities)
            inequalities.extend(other.inequalities)
        return self.__class__(equalities, inequalities)

    def __and__(self, other):
        return self.intersection(other)

    def difference(self, *others):
        # return a new polyhedron with elements in the polyhedron that are not
        # in the others
        raise NotImplementedError

    def __sub__(self, other):
        return self.difference(other)

    def __str__(self):
        constraints = []
        for constraint in self.equalities:
            constraints.append('{} == 0'.format(constraint))
        for constraint in self.inequalities:
            constraints.append('{} <= 0'.format(constraint))
        return '{{{}}}'.format(', '.join(constraints))

    def __repr__(self):
        equalities = list(self.equalities)
        inequalities = list(self.inequalities)
        return '{}(equalities={!r}, inequalities={!r})' \
            ''.format(self.__class__.__name__, equalities, inequalities)

    @classmethod
    def fromstring(cls, string):
        raise NotImplementedError

    def to_isl(self):
        """Build and return a new isl_basic_set pointer for this polyhedron.

        Set dimension i corresponds to the i-th symbol of self.symbols()
        (sorted order), with no parameters.  The caller owns the returned
        pointer and must release it with isl_basic_set_free().  Raises
        RuntimeError if isl returns a NULL set.
        """
        symbols = list(self.symbols())
        position = {symbol: index for index, symbol in enumerate(symbols)}
        # NOTE(review): _CONTEXT is handed to ctypes as is, which assumes
        # pypol.isl.Context exposes an ``_as_parameter_`` usable as an
        # ``isl_ctx *`` — confirm.
        space = libisl.isl_space_set_alloc(_CONTEXT, 0, len(symbols))
        bset = libisl.isl_basic_set_universe(libisl.isl_space_copy(space))
        # isl_local_space_from_space() takes ownership of space.
        ls = libisl.isl_local_space_from_space(space)
        try:
            for expression in self.equalities:
                constraint = libisl.isl_equality_alloc(
                        libisl.isl_local_space_copy(ls))
                constraint = self._fill_isl_constraint(
                        constraint, expression, position, 1)
                bset = libisl.isl_basic_set_add_constraint(bset, constraint)
            for expression in self.inequalities:
                # isl inequalities read ``expr >= 0`` while ours read
                # ``expr <= 0``, hence the negation.
                constraint = libisl.isl_inequality_alloc(
                        libisl.isl_local_space_copy(ls))
                constraint = self._fill_isl_constraint(
                        constraint, expression, position, -1)
                bset = libisl.isl_basic_set_add_constraint(bset, constraint)
        finally:
            libisl.isl_local_space_free(ls)
        if not bset.value:
            raise RuntimeError('isl failed to build the basic set')
        return bset

    @staticmethod
    def _fill_isl_constraint(constraint, expression, position, sign):
        """Store sign * *expression* into the isl constraint *constraint*.

        *position* maps each symbol to its set dimension.  The isl setters
        consume their constraint argument, so the updated pointer is
        returned.
        """
        # NOTE(review): values go through ctypes as C ints (no argtypes), so
        # integers that do not fit in a C int are silently masked.
        constraint = libisl.isl_constraint_set_constant_si(
                constraint, sign * int(expression.constant))
        for symbol, coefficient in expression.coefficients():
            constraint = libisl.isl_constraint_set_coefficient_si(
                    constraint, _ISL_DIM_SET, position[symbol],
                    sign * int(coefficient))
        return constraint
538
# The empty polyhedron: the equality 0 == 1 has no solution.  eq(1, 1) would
# give 0 == 0, which every point satisfies, i.e. the universe.
empty = eq(0, 1)


# The universe: a polyhedron without constraints contains every point.
universe = Polyhedron()