7
7
8
8
9
9
# This file contains all the functions that replace one op with another in the
10
- # graph. The functions replacing ops for models deployed with Jarvis are grouped
11
- # together in class 'ReplaceOpsInGraph'. Some examples of functions in the class are
12
- # 1. functions that replace an ATen op with a custom op that accepts extra arguments
13
- # 2. functions that replace in-place variants of ATen ops with out-of-place version.
14
- # 3. functions that replace an ATen op with another semantically equivalent ATen op.
15
- # 4. functions that concretize optional args.
10
+ # graph.
16
11
17
12
# pyre-unsafe
18
13
54
49
from torch .fx .node import Argument
55
50
56
51
# A map to represent ops that:
57
- # (a) are functionally equivalent wrt. Jarvis ; and
52
+ # (a) are functionally equivalent; and
58
53
# (b) have identical arguments
59
54
# An op whose target is 'key' in this dict can be replaced by the functionally euivalent
60
55
# op whose target is 'value'. The replacement would just involve changing the op target.
@@ -650,7 +645,7 @@ def call_operator(self, op, args, kwargs, meta):
650
645
651
646
# Make that pass runnable standalone at opt level 0.
652
647
@register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
653
- class ReplaceAtenConvolutionWithJarvisConvolutionPass (ExportPass ):
648
+ class ReplaceAtenConvolutionWithCadenceConvolutionPass (ExportPass ):
654
649
"""
655
650
Replace aten convolution op with jarvis-specific convolution op, since the
656
651
aten version is not supported by jarvis.
@@ -784,7 +779,7 @@ class ReplaceConvWithChannelLastConv:
784
779
tensors. However, if the input and output to the convolution op are originally
785
780
in NWHC layout, and are then permuted to conform to NCHW layout, we can fuse
786
781
the two permute ops with the convolution op, and call the NHWC layout
787
- convolution op in Jarvis .
782
+ convolution op.
788
783
"""
789
784
790
785
def __init__ (self ):
@@ -821,7 +816,7 @@ def conv_layout_is_nhwc(self, node: torch.fx.Node) -> bool:
821
816
out_shape = get_shape (self .graph_module , node )
822
817
assert out_shape is not None
823
818
out_dims = len (out_shape )
824
- assert out_dims in {3 , 4 }, "Jarvis only supports conv1d and conv2d"
819
+ assert out_dims in {3 , 4 }, "Only supports conv1d and conv2d"
825
820
conv1d = out_dims == 3
826
821
827
822
# Get the possible targets for the nodes in pt_nodes. Since conv1d has
@@ -951,7 +946,7 @@ class ReplaceConvWithChannelLastConvPass(ExportPass):
951
946
"""
952
947
953
948
def call (self , graph_module : torch .fx .GraphModule ) -> PassResult :
954
- result = ReplaceAtenConvolutionWithJarvisConvolutionPass ()(graph_module )
949
+ result = ReplaceAtenConvolutionWithCadenceConvolutionPass ()(graph_module )
955
950
assert result is not None
956
951
ReplaceConvWithChannelLastConv ()(result .graph_module )
957
952
return result
@@ -1871,9 +1866,9 @@ def call_operator(self, op, args, kwargs, meta):
1871
1866
1872
1867
1873
1868
@register_cadence_pass (CadencePassAttribute (opt_level = 0 ))
1874
- class ReplaceAtenAvgPoolWithJarvisAvgPoolPass (ExportPass ):
1869
+ class ReplaceAtenAvgPoolWithCadenceAvgPoolPass (ExportPass ):
1875
1870
"""
1876
- Replace the aten avg_pool op with the jarvis custom avg_pool2d op.
1871
+ Replace the aten avg_pool op with the cadence custom avg_pool2d op.
1877
1872
"""
1878
1873
1879
1874
def call_operator (self , op , args , kwargs , meta ):
@@ -2429,7 +2424,7 @@ class CadenceReplaceOpsInGraph:
2429
2424
ReplacePadWithCatPass ,
2430
2425
ReplaceConstantPadNdWithSlicePass ,
2431
2426
ReplaceConvWithChannelLastConvPass ,
2432
- ReplaceAtenConvolutionWithJarvisConvolutionPass ,
2427
+ ReplaceAtenConvolutionWithCadenceConvolutionPass ,
2433
2428
ForceChannelLastForConvPass ,
2434
2429
ReplaceTrivialConvWithLinear ,
2435
2430
ReplaceConvWithIm2RowAndLinear ,
@@ -2448,7 +2443,7 @@ class CadenceReplaceOpsInGraph:
2448
2443
ReplacePT2DequantWithCadenceDequantPass ,
2449
2444
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass ,
2450
2445
ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass ,
2451
- ReplaceAtenAvgPoolWithJarvisAvgPoolPass ,
2446
+ ReplaceAtenAvgPoolWithCadenceAvgPoolPass ,
2452
2447
ReplaceWhereWithFullArgsWithWhereScalar ,
2453
2448
ReplaceAtenApproxGeluWithApproxGeluPass ,
2454
2449
ReplaceSplitWithSlicePass ,
0 commit comments