Skip to content

Commit cc06cc1

Browse files
committed
Make auth slightly more robust
1 parent cefdc61 commit cc06cc1

File tree

2 files changed

+57
-26
lines changed

2 files changed

+57
-26
lines changed

src/diracx/client/_patch.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from datetime import datetime
1010
import json
1111
import requests
12+
import logging
1213

1314
from pathlib import Path
1415
from typing import Any, Dict, List, Optional, cast
@@ -38,6 +39,9 @@ def patch_sdk():
3839
"""
3940

4041

42+
logger = logging.getLogger(__name__)
43+
44+
4145
class DiracTokenCredential(TokenCredential):
4246
"""Tailor get_token() for our context"""
4347

@@ -98,12 +102,21 @@ def on_request(
98102
return
99103

100104
if not self._token:
101-
credentials = json.loads(self._credential.location.read_text())
102-
self._token = self._credential.get_token(
103-
"", refresh_token=credentials["refresh_token"]
104-
)
105-
106-
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
105+
try:
106+
credentials = json.loads(self._credential.location.read_text())
107+
except Exception:
108+
logger.warning(
109+
"Cannot load credentials from %s", self._credential.location
110+
)
111+
else:
112+
self._token = self._credential.get_token(
113+
"", refresh_token=credentials["refresh_token"]
114+
)
115+
116+
if self._token:
117+
request.http_request.headers[
118+
"Authorization"
119+
] = f"Bearer {self._token.token}"
107120

108121

109122
class DiracClient(DiracGenerated):
@@ -160,6 +173,7 @@ def refresh_token(
160173
)
161174

162175
if response.status_code != 200:
176+
location.unlink()
163177
raise RuntimeError(
164178
f"An issue occured while refreshing your access token: {response.json()['detail']}"
165179
)
@@ -192,24 +206,28 @@ def get_token(location: Path, token: AccessToken | None) -> AccessToken | None:
192206
raise RuntimeError("credentials are not set")
193207

194208
# Load the existing credentials
195-
if not token:
196-
credentials = json.loads(location.read_text())
197-
token = AccessToken(
198-
cast(str, credentials.get("access_token")),
199-
cast(int, credentials.get("expires_on")),
200-
)
201-
202-
# We check the validity of the token
203-
# If not valid, then return None to inform the caller that a new token
204-
# is needed
205-
if not is_token_valid(token):
206-
return None
207-
208-
return token
209+
try:
210+
if not token:
211+
credentials = json.loads(location.read_text())
212+
token = AccessToken(
213+
cast(str, credentials.get("access_token")),
214+
cast(int, credentials.get("expires_on")),
215+
)
216+
except Exception:
217+
logger.warning("Cannot load credentials from %s", location)
218+
pass
219+
else:
220+
# We check the validity of the token
221+
# If not valid, then return None to inform the caller that a new token
222+
# is needed
223+
if is_token_valid(token):
224+
return token
225+
return None
209226

210227

211228
def is_token_valid(token: AccessToken) -> bool:
212229
"""Condition to get a new token"""
230+
# TODO: Should we check against the userinfo endpoint?
213231
return (
214232
datetime.utcfromtimestamp(token.expires_on) - datetime.utcnow()
215233
).total_seconds() > 300

src/diracx/client/aio/_patch.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
88
"""
99
import json
10+
import logging
1011
from types import TracebackType
1112
from pathlib import Path
1213
from typing import Any, List, Optional
@@ -24,6 +25,8 @@
2425
"DiracClient",
2526
] # Add all objects you want publicly available to users at this package level
2627

28+
logger = logging.getLogger(__name__)
29+
2730

2831
def patch_sdk():
2932
"""Do not remove from this file.
@@ -104,19 +107,29 @@ async def on_request(
104107
credentials: dict[str, Any]
105108

106109
try:
110+
# TODO: Use httpx and await this call
107111
self._token = get_token(self._credential.location, self._token)
108112
except RuntimeError:
109113
# If we are here, it means the credentials path does not exist
110114
# we suppose it is not needed to perform the request
111115
return
112116

113117
if not self._token:
114-
credentials = json.loads(self._credential.location.read_text())
115-
self._token = await self._credential.get_token(
116-
"", refresh_token=credentials["refresh_token"]
117-
)
118-
119-
request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"
118+
try:
119+
credentials = json.loads(self._credential.location.read_text())
120+
except Exception:
121+
logger.warning(
122+
"Cannot load credentials from %s", self._credential.location
123+
)
124+
else:
125+
self._token = await self._credential.get_token(
126+
"", refresh_token=credentials["refresh_token"]
127+
)
128+
129+
if self._token:
130+
request.http_request.headers[
131+
"Authorization"
132+
] = f"Bearer {self._token.token}"
120133

121134

122135
class DiracClient(DiracGenerated):

0 commit comments

Comments
 (0)