@@ -620,42 +620,36 @@ def make_array(dtype: Sym, size: int, v: int = 0) -> np.ndarray:
620
620
raise MyError (f"number too large to be converted to { dtype } " )
621
621
622
622
623
- def minclip (oarr : BoolList , _min : int ) -> BoolList :
623
+ def minclip (oarr : BoolList , _min : int , / ) -> BoolList :
624
624
arr = np .copy (oarr )
625
625
mut_remove_small (arr , _min , replace = 1 , with_ = 0 )
626
626
return arr
627
627
628
628
629
- def mincut (oarr : BoolList , _min : int ) -> BoolList :
629
+ def mincut (oarr : BoolList , _min : int , / ) -> BoolList :
630
630
arr = np .copy (oarr )
631
631
mut_remove_small (arr , _min , replace = 0 , with_ = 1 )
632
632
return arr
633
633
634
634
635
- def maxclip (oarr : BoolList , _min : int ) -> BoolList :
635
+ def maxclip (oarr : BoolList , _min : int , / ) -> BoolList :
636
636
arr = np .copy (oarr )
637
637
mut_remove_large (arr , _min , replace = 1 , with_ = 0 )
638
638
return arr
639
639
640
640
641
- def maxcut (oarr : BoolList , _min : int ) -> BoolList :
641
+ def maxcut (oarr : BoolList , _min : int , / ) -> BoolList :
642
642
arr = np .copy (oarr )
643
643
mut_remove_large (arr , _min , replace = 0 , with_ = 1 )
644
644
return arr
645
645
646
646
647
- def margin (a : int , b : Any , c : Any = None ) -> BoolList :
648
- if c is None :
649
- check_args ("margin" , [a , b ], (2 , 2 ), (is_int , is_boolarr ))
650
- oarr = b
651
- start , end = a , a
652
- else :
653
- check_args ("margin" , [a , b , c ], (3 , 3 ), (is_int , is_int , is_boolarr ))
654
- oarr = c
655
- start , end = a , b
656
-
647
+ def margin (oarr : BoolList , start : int , end : int | None = None , / ) -> BoolList :
657
648
arr = np .copy (oarr )
658
- mut_margin (arr , start , end )
649
+ if end is None :
650
+ mut_margin (arr , start , start )
651
+ else :
652
+ mut_margin (arr , start , end )
659
653
return arr
660
654
661
655
@@ -1741,6 +1735,8 @@ def my_eval(env: Env, node: object) -> Any:
1741
1735
"round" : Proc ("round" , round , (1 , 1 ), is_real ),
1742
1736
"max" : Proc ("max" , lambda * v : max (v ), (1 , None ), is_real ),
1743
1737
"min" : Proc ("min" , lambda * v : min (v ), (1 , None ), is_real ),
1738
+ "max-seq" : Proc ("max-seq" , max , (1 , 1 ), is_sequence ),
1739
+ "min-seq" : Proc ("min-seq" , min , (1 , 1 ), is_sequence ),
1744
1740
"mod" : Proc ("mod" , mod , (2 , 2 ), is_int ),
1745
1741
"modulo" : Proc ("modulo" , mod , (2 , 2 ), is_int ),
1746
1742
# symbols
@@ -1796,7 +1792,7 @@ def my_eval(env: Env, node: object) -> Any:
1796
1792
"bool-array" : Proc (
1797
1793
"bool-array" , lambda * a : np .array (a , dtype = np .bool_ ), (1 , None ), is_nat
1798
1794
),
1799
- "margin" : Proc ("margin" , margin , (2 , 3 )),
1795
+ "margin" : Proc ("margin" , margin , (2 , 3 ), is_boolarr , is_int ),
1800
1796
"mincut" : Proc ("mincut" , mincut , (2 , 2 ), is_boolarr , is_nat ),
1801
1797
"minclip" : Proc ("minclip" , minclip , (2 , 2 ), is_boolarr , is_nat ),
1802
1798
"maxcut" : Proc ("maxcut" , maxcut , (2 , 2 ), is_boolarr , is_nat ),
0 commit comments