Check symbol names
author Vivien Maisonneuve <v.maisonneuve@gmail.com>
Mon, 18 Aug 2014 14:55:29 +0000 (16:55 +0200)
committer Vivien Maisonneuve <v.maisonneuve@gmail.com>
Mon, 18 Aug 2014 16:08:38 +0000 (18:08 +0200)
linpy/linexprs.py
linpy/tests/test_linexprs.py

index cf2a980..834c3b4 100644 (file)
@@ -456,8 +456,13 @@ class Symbol(LinExpr):
         """
         if not isinstance(name, str):
             raise TypeError('name must be a string')
         """
         if not isinstance(name, str):
             raise TypeError('name must be a string')
+        node = ast.parse(name)
+        try:
+            name = node.body[0].value.id
+        except (AttributeError, SyntaxError):
+            raise SyntaxError('invalid syntax')
         self = object().__new__(cls)
         self = object().__new__(cls)
-        self._name = name.strip()
+        self._name = name
         self._coefficients = {self: Fraction(1)}
         self._constant = Fraction(0)
         self._symbols = (self,)
         self._coefficients = {self: Fraction(1)}
         self._constant = Fraction(0)
         self._symbols = (self,)
index fb7e4a2..9599d06 100644 (file)
@@ -214,11 +214,20 @@ class TestSymbol(unittest.TestCase):
         self.y = Symbol('y')
 
     def test_new(self):
         self.y = Symbol('y')
 
     def test_new(self):
-        self.assertEqual(Symbol(' x '), self.x)
+        self.assertEqual(Symbol('x'), self.x)
         with self.assertRaises(TypeError):
             Symbol(self.x)
         with self.assertRaises(TypeError):
             Symbol(1)
         with self.assertRaises(TypeError):
             Symbol(self.x)
         with self.assertRaises(TypeError):
             Symbol(1)
+        with self.assertRaises(SyntaxError):
+            Symbol('1')
+        with self.assertRaises(SyntaxError):
+            Symbol('x.1')
+        with self.assertRaises(SyntaxError):
+            Symbol('x 1')
+        Symbol('_')
+        Symbol('_x')
+        Symbol('x_1')
 
     def test_name(self):
         self.assertEqual(self.x.name, 'x')
 
     def test_name(self):
         self.assertEqual(self.x.name, 'x')