Skip to content

Commit 565d52c

Browse files
committed
fix indexing of non-array types
1 parent 374024c commit 565d52c

File tree

3 files changed

+41
-25
lines changed

3 files changed

+41
-25
lines changed

src/backend.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,10 @@ def _backend(self,level=0):
8080

8181
@extend(node.arrayref)
8282
def _backend(self,level=0):
83-
fmt = "%s[%s]"
84-
return fmt % (self.func_expr._backend(),
85-
self.args._backend())
83+
from . resolve import A
84+
x = self.func_expr._backend()
85+
fmt = "%s[%s]" if x in A else "take(%s,%s)"
86+
return fmt % (x, self.args._backend())
8687

8788
@extend(node.break_stmt)
8889
def _backend(self,level=0):
@@ -189,14 +190,10 @@ def _backend(self,level=0):
189190
if len(self.args) == 1:
190191
return "%s %s" % (optable.get(self.op,self.op),
191192
self.args[0]._backend())
192-
if len(self.args) == 2:
193-
return "%s %s %s" % (self.args[0]._backend(),
194-
optable.get(self.op,self.op),
195-
self.args[1]._backend())
196-
#import pdb;pdb.set_trace()
197-
ret = "%s=" % self.ret._backend() if self.ret else ""
198-
return ret+"%s(%s)" % (self.op,
199-
",".join([t._backend() for t in self.args]))
193+
if hasattr(self, "ret"):
194+
ret = f"{self.ret._backend()}="
195+
return ret+"%s(%s)" % (self.op, ",".join([t._backend() for t in self.args]))
196+
return optable.get(self.op,self.op).join([t._backend() for t in self.args])
200197

201198

202199
@extend(node.expr_list)
@@ -320,8 +317,7 @@ def _backend(self, level=0):
320317
if not self.args:
321318
return "[]"
322319
elif any(b.__class__ is node.string for a in self.args for b in a):
323-
s = " + ".join(b._backend() for a in self.args for b in a)
324-
return s
320+
return " + ".join(b._backend() for a in self.args for b in a)
325321
else:
326322
#import pdb; pdb.set_trace()
327323
return "concat([%s])" % self.args[0]._backend()
@@ -366,12 +362,7 @@ def _backend(self,level=0):
366362

367363
@extend(node.string)
368364
def _backend(self,level=0):
369-
# if '\\' in self.value:
370-
# print(self.value)
371-
try:
372-
return "'%s'" % str(self.value).encode("string_escape")
373-
except:
374-
return "'%s'" % str(self.value)
365+
return f"'{self.value}'"
375366

376367
@extend(node.sub)
377368
def _backend(self,level=0):

src/libsmop.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
import sys
1919
import time
20+
import json
2021
import glob
2122
from sys import stdin, stdout, stderr
2223

@@ -28,6 +29,8 @@
2829
dir = glob.glob
2930
eps = np.finfo(float).eps
3031
NaN = np.nan
32+
jsonencode = json.dumps
33+
jsondecode = json.loads
3134

3235

3336
# def _load_matlab_builtins(*names):
@@ -47,6 +50,18 @@ def clear(*args):
4750
del globals()[a]
4851

4952

53+
def take(a, i):
54+
""" Get an item with matlab indexing (1-based). """
55+
if isinstance(a, matlabarray):
56+
return a[i]
57+
else:
58+
if isinstance(i, slice):
59+
i = slice(i.start - 1, i.stop, i.step)
60+
else:
61+
i -= 1
62+
return a[i]
63+
64+
5065
def abs(a):
5166
return np.abs(a)
5267

@@ -76,6 +91,8 @@ def concat(args):
7691
>>> concat([[1,2,3,4,5] , [1,2,3,4,5]])
7792
[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
7893
"""
94+
if any(isinstance(a, str) for a in args):
95+
return "".join(args)
7996
t = [matlabarray(a) for a in args]
8097
return np.concatenate(t)
8198

@@ -279,7 +296,7 @@ def length(a):
279296
try:
280297
return max(np.asarray(a).shape)
281298
except ValueError:
282-
return 1
299+
return len(a)
283300

284301

285302
def load(a):

src/resolve.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
from . import node
2525
from . node import extend
2626

27+
28+
A = set() # variables that have array or matrix values
29+
2730
def as_networkx(t):
2831
G = nx.DiGraph()
2932
for u in node.postorder(t):
@@ -86,7 +89,7 @@ def _lhs_resolve(self,symtab):
8689
self.func_expr._resolve(symtab) # A
8790
self.args._resolve(symtab) # B
8891
self.func_expr._lhs_resolve(symtab)
89-
92+
A.add(self.func_expr.name)
9093

9194

9295
@extend(node.expr)
@@ -161,6 +164,9 @@ def _lhs_resolve(self,symtab):
161164
def _resolve(self,symtab):
162165
self.args._resolve(symtab)
163166
self.ret._lhs_resolve(symtab)
167+
if isinstance(self.args, (node.matrix, node.cellarray)):
168+
if isinstance(self.ret, node.ident):
169+
A.add(self.ret.name)
164170

165171
@extend(node.null_stmt)
166172
@extend(node.continue_stmt)
@@ -210,7 +216,7 @@ def _resolve(self,symtab):
210216
@extend(node.string)
211217
@extend(node.comment_stmt)
212218
def _resolve(self,symtab):
213-
pass
219+
pass
214220

215221
# @extend(node.call_stmt)
216222
# def _resolve(self,symtab):
@@ -298,6 +304,8 @@ def fix_let_statement(u):
298304
if (isinstance(u.ret, node.ident) and
299305
isinstance(u.args, node.matrix)):
300306
if any(b.__class__ is node.string for a in u.args.args for b in a):
301-
return
302-
u.args = node.funcall(func_expr=node.ident("matlabarray"),
303-
args=node.expr_list([u.args]))
307+
u.args = node.expr("+", u.args.args[0])
308+
A.discard(u.ret.name)
309+
else:
310+
u.args = node.funcall(func_expr=node.ident("matlabarray"),
311+
args=node.expr_list([u.args]))

0 commit comments

Comments
 (0)