Skip to content

Commit 5e94e00

Browse files
authored
Merge pull request #610 from chrisburr/feat-zstd-sandbox-compression
feat: implement zstandard compression for sandbox files
2 parents b548c09 + 0fa02ae commit 5e94e00

File tree

8 files changed

+49
-14
lines changed

8 files changed

+49
-14
lines changed

diracx-api/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies = [
1616
"diracx-client",
1717
"diracx-core",
1818
"httpx",
19+
"zstandard",
1920
]
2021
dynamic = ["version"]
2122

diracx-api/src/diracx/api/jobs.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
import os
88
import tarfile
99
import tempfile
10+
from contextlib import contextmanager
1011
from pathlib import Path
11-
from typing import Literal
12+
from typing import BinaryIO, Literal
1213

1314
import httpx
15+
import zstandard
1416

1517
from diracx.client.aio import AsyncDiracClient
1618
from diracx.client.models import SandboxInfo
@@ -20,8 +22,29 @@
2022
logger = logging.getLogger(__name__)
2123

2224
SANDBOX_CHECKSUM_ALGORITHM = "sha256"
23-
SANDBOX_COMPRESSION: Literal["bz2"] = "bz2"
24-
SANDBOX_OPEN_MODE: Literal["w|bz2"] = "w|bz2"
25+
SANDBOX_COMPRESSION: Literal["zst"] = "zst"
26+
27+
28+
@contextmanager
29+
def tarfile_open(fileobj: BinaryIO):
30+
"""Context manager to extend tarfile.open to support reading zstd compressed files.
31+
32+
This is only needed for Python <=3.13.
33+
"""
34+
# Save current position and read magic bytes
35+
current_pos = fileobj.tell()
36+
magic = fileobj.read(4)
37+
fileobj.seek(current_pos)
38+
39+
# Read magic bytes to determine compression format
40+
if magic.startswith(b"\x28\xb5\x2f\xfd"): # zstd magic number
41+
dctx = zstandard.ZstdDecompressor()
42+
with dctx.stream_reader(fileobj) as decompressor:
43+
with tarfile.open(fileobj=decompressor, mode="r|") as tf:
44+
yield tf
45+
else:
46+
with tarfile.open(fileobj=fileobj, mode="r") as tf:
47+
yield tf
2548

2649

2750
@with_client
@@ -33,10 +56,18 @@ async def create_sandbox(paths: list[Path], *, client: AsyncDiracClient) -> str:
3356
be used to submit jobs.
3457
"""
3558
with tempfile.TemporaryFile(mode="w+b") as tar_fh:
36-
with tarfile.open(fileobj=tar_fh, mode=SANDBOX_OPEN_MODE) as tf:
37-
for path in paths:
38-
logger.debug("Adding %s to sandbox as %s", path.resolve(), path.name)
39-
tf.add(path.resolve(), path.name, recursive=True)
59+
# Create zstd compressed tar with level 18 and long matching enabled
60+
compression_params = zstandard.ZstdCompressionParameters.from_level(
61+
18, enable_ldm=1
62+
)
63+
cctx = zstandard.ZstdCompressor(compression_params=compression_params)
64+
with cctx.stream_writer(tar_fh, closefd=False) as compressor:
65+
with tarfile.open(fileobj=compressor, mode="w|") as tf:
66+
for path in paths:
67+
logger.debug(
68+
"Adding %s to sandbox as %s", path.resolve(), path.name
69+
)
70+
tf.add(path.resolve(), path.name, recursive=True)
4071
tar_fh.seek(0)
4172

4273
hasher = getattr(hashlib, SANDBOX_CHECKSUM_ALGORITHM)()
@@ -89,6 +120,6 @@ async def download_sandbox(pfn: str, destination: Path, *, client: AsyncDiracCli
89120
fh.seek(0)
90121
logger.debug("Sandbox downloaded for %s", pfn)
91122

92-
with tarfile.open(fileobj=fh) as tf:
123+
with tarfile_open(fh) as tf:
93124
tf.extractall(path=destination, filter="data")
94125
logger.debug("Extracted %s to %s", pfn, destination)

diracx-client/src/diracx/client/_generated/models/_enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta):
3838
"""SandboxFormat."""
3939

4040
TAR_BZ2 = "tar.bz2"
41+
TAR_ZST = "tar.zst"
4142

4243

4344
class SandboxType(str, Enum, metaclass=CaseInsensitiveEnumMeta):

diracx-client/src/diracx/client/_generated/models/_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -930,7 +930,7 @@ class SandboxInfo(_serialization.Model):
930930
:vartype checksum: str
931931
:ivar size: Size. Required.
932932
:vartype size: int
933-
:ivar format: SandboxFormat. Required. "tar.bz2"
933+
:ivar format: SandboxFormat. Required. Known values are: "tar.bz2" and "tar.zst".
934934
:vartype format: str or ~_generated.models.SandboxFormat
935935
"""
936936

@@ -964,7 +964,7 @@ def __init__(
964964
:paramtype checksum: str
965965
:keyword size: Size. Required.
966966
:paramtype size: int
967-
:keyword format: SandboxFormat. Required. "tar.bz2"
967+
:keyword format: SandboxFormat. Required. Known values are: "tar.bz2" and "tar.zst".
968968
:paramtype format: str or ~_generated.models.SandboxFormat
969969
"""
970970
super().__init__(**kwargs)

diracx-core/src/diracx/core/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class ChecksumAlgorithm(StrEnum):
204204

205205
class SandboxFormat(StrEnum):
206206
TAR_BZ2 = "tar.bz2"
207+
TAR_ZST = "tar.zst"
207208

208209

209210
class SandboxInfo(BaseModel):

diracx-logic/tests/logic/test_sandboxes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,14 @@ async def test_upload_and_clean(
7272
"""
7373
data = secrets.token_bytes(256)
7474
data_digest = hashlib.sha256(data).hexdigest()
75-
key = f"fakevo/fake_group/fakeuser/sha256:{data_digest}.tar.bz2"
75+
key = f"fakevo/fake_group/fakeuser/sha256:{data_digest}.tar.zst"
7676
expected_pfn = f"SB:SandboxSE|/S3/sandboxes/{key}"
7777

7878
sandbox_info = SandboxInfo(
7979
checksum_algorithm=ChecksumAlgorithm.SHA256,
8080
checksum=data_digest,
8181
size=len(data),
82-
format=SandboxFormat.TAR_BZ2,
82+
format=SandboxFormat.TAR_ZST,
8383
)
8484

8585
# Test with a new sandbox

extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta):
3838
"""SandboxFormat."""
3939

4040
TAR_BZ2 = "tar.bz2"
41+
TAR_ZST = "tar.zst"
4142

4243

4344
class SandboxType(str, Enum, metaclass=CaseInsensitiveEnumMeta):

extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -951,7 +951,7 @@ class SandboxInfo(_serialization.Model):
951951
:vartype checksum: str
952952
:ivar size: Size. Required.
953953
:vartype size: int
954-
:ivar format: SandboxFormat. Required. "tar.bz2"
954+
:ivar format: SandboxFormat. Required. Known values are: "tar.bz2" and "tar.zst".
955955
:vartype format: str or ~_generated.models.SandboxFormat
956956
"""
957957

@@ -985,7 +985,7 @@ def __init__(
985985
:paramtype checksum: str
986986
:keyword size: Size. Required.
987987
:paramtype size: int
988-
:keyword format: SandboxFormat. Required. "tar.bz2"
988+
:keyword format: SandboxFormat. Required. Known values are: "tar.bz2" and "tar.zst".
989989
:paramtype format: str or ~_generated.models.SandboxFormat
990990
"""
991991
super().__init__(**kwargs)

0 commit comments

Comments
 (0)