Skip to content

Commit e612311

Browse files
committed
Put array as first arg in margin proc
1 parent 0b835e1 commit e612311

File tree

3 files changed

+16
-18
lines changed

3 files changed

+16
-18
lines changed

auto_editor/lang/palet.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -620,42 +620,36 @@ def make_array(dtype: Sym, size: int, v: int = 0) -> np.ndarray:
620620
raise MyError(f"number too large to be converted to {dtype}")
621621

622622

623-
def minclip(oarr: BoolList, _min: int) -> BoolList:
623+
def minclip(oarr: BoolList, _min: int, /) -> BoolList:
624624
arr = np.copy(oarr)
625625
mut_remove_small(arr, _min, replace=1, with_=0)
626626
return arr
627627

628628

629-
def mincut(oarr: BoolList, _min: int) -> BoolList:
629+
def mincut(oarr: BoolList, _min: int, /) -> BoolList:
630630
arr = np.copy(oarr)
631631
mut_remove_small(arr, _min, replace=0, with_=1)
632632
return arr
633633

634634

635-
def maxclip(oarr: BoolList, _min: int) -> BoolList:
635+
def maxclip(oarr: BoolList, _min: int, /) -> BoolList:
636636
arr = np.copy(oarr)
637637
mut_remove_large(arr, _min, replace=1, with_=0)
638638
return arr
639639

640640

641-
def maxcut(oarr: BoolList, _min: int) -> BoolList:
641+
def maxcut(oarr: BoolList, _min: int, /) -> BoolList:
642642
arr = np.copy(oarr)
643643
mut_remove_large(arr, _min, replace=0, with_=1)
644644
return arr
645645

646646

647-
def margin(a: int, b: Any, c: Any = None) -> BoolList:
648-
if c is None:
649-
check_args("margin", [a, b], (2, 2), (is_int, is_boolarr))
650-
oarr = b
651-
start, end = a, a
652-
else:
653-
check_args("margin", [a, b, c], (3, 3), (is_int, is_int, is_boolarr))
654-
oarr = c
655-
start, end = a, b
656-
647+
def margin(oarr: BoolList, start: int, end: int | None = None, /) -> BoolList:
657648
arr = np.copy(oarr)
658-
mut_margin(arr, start, end)
649+
if end is None:
650+
mut_margin(arr, start, start)
651+
else:
652+
mut_margin(arr, start, end)
659653
return arr
660654

661655

@@ -1741,6 +1735,8 @@ def my_eval(env: Env, node: object) -> Any:
17411735
"round": Proc("round", round, (1, 1), is_real),
17421736
"max": Proc("max", lambda *v: max(v), (1, None), is_real),
17431737
"min": Proc("min", lambda *v: min(v), (1, None), is_real),
1738+
"max-seq": Proc("max-seq", max, (1, 1), is_sequence),
1739+
"min-seq": Proc("min-seq", min, (1, 1), is_sequence),
17441740
"mod": Proc("mod", mod, (2, 2), is_int),
17451741
"modulo": Proc("modulo", mod, (2, 2), is_int),
17461742
# symbols
@@ -1796,7 +1792,7 @@ def my_eval(env: Env, node: object) -> Any:
17961792
"bool-array": Proc(
17971793
"bool-array", lambda *a: np.array(a, dtype=np.bool_), (1, None), is_nat
17981794
),
1799-
"margin": Proc("margin", margin, (2, 3)),
1795+
"margin": Proc("margin", margin, (2, 3), is_boolarr, is_int),
18001796
"mincut": Proc("mincut", mincut, (2, 2), is_boolarr, is_nat),
18011797
"minclip": Proc("minclip", minclip, (2, 2), is_boolarr, is_nat),
18021798
"maxcut": Proc("maxcut", maxcut, (2, 2), is_boolarr, is_nat),

auto_editor/lib/data_structs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ def display_str(val: object) -> str:
185185
return f"{val.real}{join}{val.imag}i"
186186
if type(val) is np.bool_:
187187
return "1" if val else "0"
188+
if type(val) is np.float64 or type(val) is np.float32:
189+
return f"{float(val)}"
188190
if type(val) is Fraction:
189191
return f"{val.numerator}/{val.denominator}"
190192

auto_editor/subcommands/test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -628,11 +628,11 @@ def cases(*cases: tuple[str, Any]) -> None:
628628
("(string #\\a #\\b)", "ab"),
629629
("(string #\\a #\\b #\\c)", "abc"),
630630
(
631-
"(margin 0 (bool-array 0 0 0 1 0 0 0))",
631+
"(margin (bool-array 0 0 0 1 0 0 0) 0)",
632632
np.array([0, 0, 0, 1, 0, 0, 0], dtype=np.bool_),
633633
),
634634
(
635-
"(margin -2 2 (bool-array 0 0 1 1 0 0 0))",
635+
"(margin (bool-array 0 0 1 1 0 0 0) -2 2)",
636636
np.array([0, 0, 0, 0, 1, 1, 0], dtype=np.bool_),
637637
),
638638
("(equal? 3 3)", True),

0 commit comments

Comments
 (0)