Skip to content

Commit 82f5ab9

Browse files
authored
Merge pull request #179 from Dstack-TEE/fix-vmm-cli
fix(vmm-cli): compatible with custom kms-url and gateway-url
2 parents aa5df6e + 7838e6c commit 82f5ab9

File tree

1 file changed

+37
-12
lines changed

1 file changed

+37
-12
lines changed

vmm/src/vmm-cli.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import socket
1313
import http.client
1414
import urllib.parse
15+
import ssl
1516

1617
from eth_keys import keys
1718
from eth_utils import keccak
@@ -162,7 +163,11 @@ def request(self, method: str, path: str, headers: Dict[str, str] = None,
162163
conn = UnixSocketHTTPConnection(self.uds_path)
163164
else:
164165
if self.is_https:
165-
conn = http.client.HTTPSConnection(self.host)
166+
# TODO: we should verify TLS cert.
167+
context = ssl.create_default_context()
168+
context.check_hostname = False
169+
context.verify_mode = ssl.CERT_NONE
170+
conn = http.client.HTTPSConnection(self.host, context=context)
166171
else:
167172
conn = http.client.HTTPConnection(self.host)
168173

@@ -313,9 +318,18 @@ def list_images(self) -> List[Dict]:
313318
response = self.rpc_call('ListImages')
314319
return response['images']
315320

316-
def get_app_env_encrypt_pub_key(self, app_id: str) -> Dict:
321+
def get_app_env_encrypt_pub_key(self, app_id: str, kms_url: Optional[str] = None) -> Dict:
317322
"""Get the encryption public key for the specified application ID"""
318-
response = self.rpc_call('GetAppEnvEncryptPubKey', {'app_id': app_id})
323+
if kms_url:
324+
client = VmmClient(kms_url)
325+
path = f"/prpc/GetAppEnvEncryptPubKey?json"
326+
status, response = client.request(
327+
'POST', path, headers={
328+
'Content-Type': 'application/json'
329+
}, body={'app_id': app_id})
330+
print(f"Getting encryption public key for {app_id} from {kms_url}")
331+
else:
332+
response = self.rpc_call('GetAppEnvEncryptPubKey', {'app_id': app_id})
319333

320334
# Verify the signature if available
321335
if 'signature' not in response:
@@ -443,6 +457,8 @@ def create_vm(self, name: str, image: str, compose_file: str,
443457
gpus: Optional[List[str]] = None,
444458
pin_numa: bool = False,
445459
hugepages: bool = False,
460+
kms_urls: Optional[List[str]] = None,
461+
gateway_urls: Optional[List[str]] = None,
446462
) -> None:
447463
"""Create a new VM"""
448464
# Read and validate compose file
@@ -471,11 +487,15 @@ def create_vm(self, name: str, image: str, compose_file: str,
471487
"attach_mode": "listed",
472488
"gpus": [{"slot": gpu} for gpu in gpus or []]
473489
}
490+
if kms_urls:
491+
params["kms_urls"] = kms_urls
492+
if gateway_urls:
493+
params["gateway_urls"] = gateway_urls
474494

475495
app_id = app_id or self.calc_app_id(compose_content)
476496
print(f"App ID: {app_id}")
477497
if envs:
478-
encrypt_pubkey = self.get_app_env_encrypt_pub_key(app_id)
498+
encrypt_pubkey = self.get_app_env_encrypt_pub_key(app_id, kms_urls[0] if kms_urls else None)
479499
print(
480500
f"Encrypting environment variables with key: {encrypt_pubkey}")
481501
envs_list = [{"key": k, "value": v} for k, v in envs.items()]
@@ -484,7 +504,7 @@ def create_vm(self, name: str, image: str, compose_file: str,
484504
print(f"Created VM with ID: {response.get('id')}")
485505
return response.get('id')
486506

487-
def update_vm_env(self, vm_id: str, envs: Dict[str, str]) -> None:
507+
def update_vm_env(self, vm_id: str, envs: Dict[str, str], kms_urls: Optional[List[str]] = None) -> None:
488508
"""Update environment variables for a VM"""
489509
# First get the VM info to retrieve the app_id
490510
vm_info_response = self.rpc_call('GetInfo', {'id': vm_id})
@@ -496,11 +516,7 @@ def update_vm_env(self, vm_id: str, envs: Dict[str, str]) -> None:
496516
print(f"Retrieved app ID: {app_id}")
497517

498518
# Now get the encryption key for the app
499-
response = self.rpc_call('GetAppEnvEncryptPubKey', {'app_id': app_id})
500-
if 'public_key' not in response:
501-
raise Exception("Failed to get encryption public key for the VM")
502-
503-
encrypt_pubkey = response['public_key']
519+
encrypt_pubkey = self.get_app_env_encrypt_pub_key(app_id, kms_urls[0] if kms_urls else None)
504520
print(f"Encrypting environment variables with key: {encrypt_pubkey}")
505521
envs_list = [{"key": k, "value": v} for k, v in envs.items()]
506522
encrypted_env = encrypt_env(envs_list, encrypt_pubkey)
@@ -815,6 +831,10 @@ def main():
815831
help='Pin VM to specific NUMA node')
816832
deploy_parser.add_argument('--hugepages', action='store_true',
817833
help='Enable hugepages for the VM')
834+
deploy_parser.add_argument('--kms-url', action='append', type=str,
835+
help='KMS URL')
836+
deploy_parser.add_argument('--gateway-url', action='append', type=str,
837+
help='Gateway URL')
818838

819839
# Images command
820840
_images_parser = subparsers.add_parser(
@@ -829,6 +849,9 @@ def main():
829849
update_env_parser.add_argument('vm_id', help='VM ID to update')
830850
update_env_parser.add_argument(
831851
'--env-file', required=True, help='File with environment variables to encrypt')
852+
update_env_parser.add_argument(
853+
'--kms-url', action='append', type=str,
854+
help='KMS URL')
832855

833856
# Whitelist command
834857
kms_parser = subparsers.add_parser(
@@ -891,7 +914,9 @@ def main():
891914
app_id=args.app_id,
892915
gpus=args.gpu,
893916
hugepages=args.hugepages,
894-
pin_numa=args.pin_numa
917+
pin_numa=args.pin_numa,
918+
kms_urls=args.kms_url,
919+
gateway_urls=args.gateway_url
895920
)
896921
elif args.command == 'lsimage':
897922
images = cli.list_images()
@@ -901,7 +926,7 @@ def main():
901926
elif args.command == 'lsgpu':
902927
cli.list_gpus()
903928
elif args.command == 'update-env':
904-
cli.update_vm_env(args.vm_id, parse_env_file(args.env_file))
929+
cli.update_vm_env(args.vm_id, parse_env_file(args.env_file), kms_urls=args.kms_url)
905930
elif args.command == 'kms':
906931
if not args.kms_action:
907932
kms_parser.print_help()

0 commit comments

Comments
 (0)