Skip to content
This repository was archived by the owner on Feb 22, 2020. It is now read-only.

Commit e3ab1aa

Browse files
authored
Merge pull request #324 from gnes-ai/feat-grpc-proxy
feat(grpc): add proxy argument to cli
2 parents e64bc7a + 4055ad8 commit e3ab1aa

File tree

6 files changed

+122
-26
lines changed

6 files changed

+122
-26
lines changed

gnes/cli/parser.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,11 @@ def set_service_parser(parser=None):
185185
'dump_interval will be ignored')
186186
parser.add_argument('--parallel_backend', type=str, choices=['thread', 'process'], default='thread',
187187
help='parallel backend of the service')
188-
parser.add_argument('--num_parallel', type=int, default=1,
189-
help='number of parallel services running at the same time, '
188+
parser.add_argument('--num_parallel', '--replicas', type=int, default=1,
189+
help='number of parallel services running at the same time (i.e. replicas), '
190190
'`port_in` and `port_out` will be set to random, '
191191
'and routers will be added automatically when necessary')
192-
parser.add_argument('--parallel_type', type=ParallelType.from_string, choices=list(ParallelType),
192+
parser.add_argument('--parallel_type', '--replica_type', type=ParallelType.from_string, choices=list(ParallelType),
193193
default=ParallelType.PUSH_NONBLOCK,
194194
help='parallel type of the concurrent services')
195195
parser.add_argument('--check_version', action=ActionNoYes, default=True,
@@ -308,6 +308,10 @@ def _set_grpc_parser(parser=None):
308308
help='host port of the grpc service')
309309
parser.add_argument('--max_message_size', type=int, default=-1,
310310
help='maximum send and receive size for grpc server in bytes, -1 means unlimited')
311+
parser.add_argument('--proxy', action=ActionNoYes, default=False,
312+
help='respect the http_proxy and https_proxy environment variables. '
313+
'otherwise, it will unset these proxy variables before start. '
314+
'gRPC seems perfer --no_proxy')
311315
return parser
312316

313317

gnes/client/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import os
1617
from typing import Tuple, List, Union
1718

1819
import grpc
@@ -120,6 +121,9 @@ class GrpcClient:
120121

121122
def __init__(self, args):
122123
self.args = args
124+
if not args.proxy:
125+
os.unsetenv('http_proxy')
126+
os.unsetenv('https_proxy')
123127
self.logger = set_logger(self.__class__.__name__, self.args.verbose)
124128
self.logger.info('setting up grpc insecure channel...')
125129
# A gRPC channel provides a connection to a remote gRPC server.

gnes/flow/__init__.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import copy
2-
import os
32
from collections import OrderedDict, defaultdict
43
from contextlib import ExitStack
54
from functools import wraps
@@ -143,30 +142,81 @@ def to_mermaid(self, left_right: bool = True):
143142
Output the mermaid graph for visualization
144143
145144
:param left_right: render the flow in left-to-right manner, otherwise top-down manner.
146-
:return:
145+
:return: a mermaid-formatted string
147146
"""
147+
148+
# fill, stroke
149+
service_color = {
150+
Service.Frontend: ('#FFE0E0', '#000'),
151+
Service.Router: ('#C9E8D2', '#000'),
152+
Service.Encoder: ('#FFDAAF', '#000'),
153+
Service.Preprocessor: ('#CED7EF', '#000'),
154+
Service.Indexer: ('#FFFBC1', '#000'),
155+
}
156+
148157
mermaid_graph = OrderedDict()
149-
for k in self._service_nodes.keys():
150-
mermaid_graph[k] = []
151158
cls_dict = defaultdict(set)
159+
replicas_dict = {}
160+
161+
for k, v in self._service_nodes.items():
162+
mermaid_graph[k] = []
163+
num_replicas = getattr(v['parsed_args'], 'num_parallel', 1)
164+
if num_replicas > 1:
165+
head_router = k + '_HEAD'
166+
tail_router = k + '_TAIL'
167+
replicas_dict[k] = (head_router, tail_router)
168+
cls_dict[Service.Router].add(head_router)
169+
cls_dict[Service.Router].add(tail_router)
170+
p_r = '((%s))'
171+
k_service = v['service']
172+
p_e = '((%s))' if k_service == Service.Router else '(%s)'
173+
174+
mermaid_graph[k].append('subgraph %s["%s (replias=%d)"]' % (k, k, num_replicas))
175+
for j in range(num_replicas):
176+
r = k + '_%d' % j
177+
cls_dict[k_service].add(r)
178+
mermaid_graph[k].append('\t%s%s-->%s%s' % (head_router, p_r % 'router', r, p_e % r))
179+
mermaid_graph[k].append('\t%s%s-->%s%s' % (r, p_e % r, tail_router, p_r % 'router'))
180+
mermaid_graph[k].append('end')
181+
mermaid_graph[k].append(
182+
'style %s fill:%s,stroke:%s,stroke-width:2px,stroke-dasharray:5,stroke-opacity:0.3,fill-opacity:0.5' % (
183+
k, service_color[k_service][0], service_color[k_service][1]))
152184

153185
for k, ed_type in self._service_edges.items():
154186
start_node, end_node = k.split('-')
187+
cur_node = mermaid_graph[start_node]
188+
155189
s_service = self._service_nodes[start_node]['service']
156190
e_service = self._service_nodes[end_node]['service']
191+
192+
start_node_text = start_node
193+
end_node_text = end_node
194+
195+
# check whether either endpoint is a replicated service; if so, connect via its head/tail router
196+
if start_node in replicas_dict:
197+
start_node = replicas_dict[start_node][1] # outgoing
198+
s_service = Service.Router
199+
start_node_text = 'router'
200+
if end_node in replicas_dict:
201+
end_node = replicas_dict[end_node][0] # incoming
202+
e_service = Service.Router
203+
end_node_text = 'router'
204+
205+
# always plot frontend at the start and the end
206+
if e_service == Service.Frontend:
207+
end_node_text = end_node
208+
end_node += '_END'
209+
157210
cls_dict[s_service].add(start_node)
158211
cls_dict[e_service].add(end_node)
159212
p_s = '((%s))' if s_service == Service.Router else '(%s)'
160213
p_e = '((%s))' if e_service == Service.Router else '(%s)'
161-
mermaid_graph[start_node].append('\t%s%s-- %s -->%s%s' % (
162-
start_node, p_s % start_node, ed_type,
163-
end_node, p_e % end_node))
164-
165-
style = ['classDef FrontendCLS fill:#FFE0E0,stroke:#FFE0E0,stroke-width:1px;',
166-
'classDef EncoderCLS fill:#FFDAAF,stroke:#FFDAAF,stroke-width:1px;',
167-
'classDef IndexerCLS fill:#FFFBC1,stroke:#FFFBC1,stroke-width:1px;',
168-
'classDef RouterCLS fill:#C9E8D2,stroke:#C9E8D2,stroke-width:1px;',
169-
'classDef PreprocessorCLS fill:#CEEEEF,stroke:#CEEEEF,stroke-width:1px;']
214+
cur_node.append('\t%s%s-- %s -->%s%s' % (
215+
start_node, p_s % start_node_text, ed_type,
216+
end_node, p_e % end_node_text))
217+
218+
style = ['classDef %sCLS fill:%s,stroke:%s,stroke-width:1px;' % (k, v[0], v[1]) for k, v in
219+
service_color.items()]
170220
class_def = ['class %s %sCLS;' % (','.join(v), k) for k, v in cls_dict.items()]
171221
mermaid_str = '\n'.join(
172222
['graph %s' % ('LR' if left_right else 'TD')] + [ss for s in mermaid_graph.values() for ss in
@@ -175,19 +225,30 @@ def to_mermaid(self, left_right: bool = True):
175225
return mermaid_str
176226

177227
@_build_level(BuildLevel.GRAPH)
178-
def to_jpg(self, path: str = 'flow.jpg', left_right: bool = True):
228+
def to_url(self, **kwargs) -> str:
229+
"""
230+
Render the current flow as a URL that points to an SVG; this needs an internet connection
231+
232+
:param kwargs: keyword arguments of :py:meth:`to_mermaid`
233+
:return: the URL that points to the SVG
234+
"""
235+
import base64
236+
mermaid_str = self.to_mermaid(**kwargs)
237+
encoded_str = base64.b64encode(bytes(mermaid_str, 'utf-8')).decode('utf-8')
238+
return 'https://mermaidjs.github.io/mermaid-live-editor/#/view/%s' % encoded_str
239+
240+
@_build_level(BuildLevel.GRAPH)
241+
def to_jpg(self, path: str = 'flow.jpg', **kwargs):
179242
"""
180243
Render the current flow as a jpg image. This calls :py:meth:`to_mermaid` and needs an internet connection
181244
182245
:param path: the file path of the image
183-
:param left_right: render the flow in left-to-right manner, otherwise top-down manner.
246+
:param kwargs: keyword arguments of :py:meth:`to_mermaid`
184247
:return:
185248
"""
186-
import base64
249+
187250
from urllib.request import Request, urlopen
188-
mermaid_str = self.to_mermaid(left_right)
189-
encoded_str = base64.b64encode(bytes(mermaid_str, 'utf-8')).decode('utf-8')
190-
print('https://mermaidjs.github.io/mermaid-live-editor/#/view/%s' % encoded_str)
251+
encoded_str = self.to_url().replace('https://mermaidjs.github.io/mermaid-live-editor/#/view/', '')
191252
self.logger.info('saving jpg...')
192253
req = Request('https://mermaid.ink/img/%s' % encoded_str, headers={'User-Agent': 'Mozilla/5.0'})
193254
with open(path, 'wb') as fp:
@@ -226,8 +287,6 @@ def query(self, bytes_gen: Iterator[bytes] = None, **kwargs):
226287

227288
@_build_level(BuildLevel.RUNTIME)
228289
def _call_client(self, bytes_gen: Iterator[bytes] = None, **kwargs):
229-
os.unsetenv('http_proxy')
230-
os.unsetenv('https_proxy')
231290
args, p_args = self._get_parsed_args(self, set_client_cli_parser, kwargs)
232291
p_args.grpc_port = self._service_nodes[self._frontend]['parsed_args'].grpc_port
233292
p_args.grpc_host = self._service_nodes[self._frontend]['parsed_args'].grpc_host

gnes/service/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,11 +553,13 @@ def __init__(self, service_cls, args):
553553
if args.num_parallel > 1:
554554
from .router import RouterService
555555
_head_router = copy.deepcopy(args)
556+
_head_router.yaml_path = resolve_yaml_path('BaseRouter')
556557
_head_router.port_ctrl = self._get_random_port()
557558
port_out = self._get_random_port()
558559
_head_router.port_out = port_out
559560

560561
_tail_router = copy.deepcopy(args)
562+
_tail_router.yaml_path = resolve_yaml_path('BaseRouter')
561563
port_in = self._get_random_port()
562564
_tail_router.port_in = port_in
563565
_tail_router.port_ctrl = self._get_random_port()

gnes/service/frontend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616

17+
import os
1718
import threading
1819
from concurrent.futures import ThreadPoolExecutor
1920

@@ -28,6 +29,9 @@
2829
class FrontendService:
2930

3031
def __init__(self, args):
32+
if not args.proxy:
33+
os.unsetenv('http_proxy')
34+
os.unsetenv('https_proxy')
3135
self.logger = set_logger(self.__class__.__name__, args.verbose)
3236
self.server = grpc.server(
3337
ThreadPoolExecutor(max_workers=args.max_concurrency),

tests/test_gnes_flow.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,26 @@ def test_flow5(self):
115115
print(f.to_mermaid())
116116
f.to_jpg()
117117

118+
def test_flow_replica_pot(self):
119+
f = (Flow(check_version=False, route_table=True)
120+
.add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor', replicas=4)
121+
.add(gfs.Encoder, yaml_path='PyTorchTransformers', replicas=3)
122+
.add(gfs.Indexer, name='vec_idx', yaml_path='NumpyIndexer', replicas=2)
123+
.add(gfs.Indexer, name='doc_idx', yaml_path='DictIndexer', service_in='prep', replicas=2)
124+
.add(gfs.Router, name='sync_barrier', yaml_path='BaseReduceRouter',
125+
num_part=2, service_in=['vec_idx', 'doc_idx'])
126+
.build(backend=None))
127+
print(f.to_mermaid())
128+
print(f.to_url(left_right=False))
129+
print(f.to_url(left_right=True))
130+
118131
def _test_index_flow(self, backend):
119132
for k in [self.indexer1_bin, self.indexer2_bin, self.encoder_bin]:
120133
self.assertFalse(os.path.exists(k))
121134

122135
flow = (Flow(check_version=False, route_table=False)
123136
.add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor')
124-
.add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'))
137+
.add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'), replicas=3)
125138
.add(gfs.Indexer, name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml'))
126139
.add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml'),
127140
service_in='prep')
@@ -137,7 +150,7 @@ def _test_index_flow(self, backend):
137150
def _test_query_flow(self, backend):
138151
flow = (Flow(check_version=False, route_table=False)
139152
.add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor')
140-
.add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'))
153+
.add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'), replicas=3)
141154
.add(gfs.Indexer, name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml'))
142155
.add(gfs.Router, name='scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml'))
143156
.add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml')))
@@ -153,3 +166,13 @@ def test_index_query_flow(self):
153166
def test_indexe_query_flow_proc(self):
154167
self._test_index_flow('process')
155168
self._test_query_flow('process')
169+
170+
def test_query_flow_plot(self):
171+
flow = (Flow(check_version=False, route_table=False)
172+
.add(gfs.Preprocessor, name='prep', yaml_path='SentSplitPreprocessor', replicas=2)
173+
.add(gfs.Encoder, yaml_path=os.path.join(self.dirname, 'yaml/flow-transformer.yml'), replicas=3)
174+
.add(gfs.Indexer, name='vec_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-vecindex.yml'),
175+
replicas=4)
176+
.add(gfs.Router, name='scorer', yaml_path=os.path.join(self.dirname, 'yaml/flow-score.yml'))
177+
.add(gfs.Indexer, name='doc_idx', yaml_path=os.path.join(self.dirname, 'yaml/flow-dictindex.yml')))
178+
print(flow.build(backend=None).to_url())

0 commit comments

Comments
 (0)