|
9 | 9 | from datetime import datetime
|
10 | 10 | import json
|
11 | 11 | import requests
|
| 12 | +import logging |
12 | 13 |
|
13 | 14 | from pathlib import Path
|
14 | 15 | from typing import Any, Dict, List, Optional, cast
|
@@ -38,6 +39,9 @@ def patch_sdk():
|
38 | 39 | """
|
39 | 40 |
|
40 | 41 |
|
| 42 | +logger = logging.getLogger(__name__) |
| 43 | + |
| 44 | + |
41 | 45 | class DiracTokenCredential(TokenCredential):
|
42 | 46 | """Tailor get_token() for our context"""
|
43 | 47 |
|
@@ -98,12 +102,21 @@ def on_request(
|
98 | 102 | return
|
99 | 103 |
|
100 | 104 | if not self._token:
|
101 |
| - credentials = json.loads(self._credential.location.read_text()) |
102 |
| - self._token = self._credential.get_token( |
103 |
| - "", refresh_token=credentials["refresh_token"] |
104 |
| - ) |
105 |
| - |
106 |
| - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" |
| 105 | + try: |
| 106 | + credentials = json.loads(self._credential.location.read_text()) |
| 107 | + except Exception: |
| 108 | + logger.warning( |
| 109 | + "Cannot load credentials from %s", self._credential.location |
| 110 | + ) |
| 111 | + else: |
| 112 | + self._token = self._credential.get_token( |
| 113 | + "", refresh_token=credentials["refresh_token"] |
| 114 | + ) |
| 115 | + |
| 116 | + if self._token: |
| 117 | + request.http_request.headers[ |
| 118 | + "Authorization" |
| 119 | + ] = f"Bearer {self._token.token}" |
107 | 120 |
|
108 | 121 |
|
109 | 122 | class DiracClient(DiracGenerated):
|
@@ -160,6 +173,7 @@ def refresh_token(
|
160 | 173 | )
|
161 | 174 |
|
162 | 175 | if response.status_code != 200:
|
| 176 | + location.unlink() |
163 | 177 | raise RuntimeError(
|
164 | 178 | f"An issue occured while refreshing your access token: {response.json()['detail']}"
|
165 | 179 | )
|
@@ -192,24 +206,28 @@ def get_token(location: Path, token: AccessToken | None) -> AccessToken | None:
|
192 | 206 | raise RuntimeError("credentials are not set")
|
193 | 207 |
|
194 | 208 | # Load the existing credentials
|
195 |
| - if not token: |
196 |
| - credentials = json.loads(location.read_text()) |
197 |
| - token = AccessToken( |
198 |
| - cast(str, credentials.get("access_token")), |
199 |
| - cast(int, credentials.get("expires_on")), |
200 |
| - ) |
201 |
| - |
202 |
| - # We check the validity of the token |
203 |
| - # If not valid, then return None to inform the caller that a new token |
204 |
| - # is needed |
205 |
| - if not is_token_valid(token): |
206 |
| - return None |
207 |
| - |
208 |
| - return token |
| 209 | + try: |
| 210 | + if not token: |
| 211 | + credentials = json.loads(location.read_text()) |
| 212 | + token = AccessToken( |
| 213 | + cast(str, credentials.get("access_token")), |
| 214 | + cast(int, credentials.get("expires_on")), |
| 215 | + ) |
| 216 | + except Exception: |
| 217 | + logger.warning("Cannot load credentials from %s", location) |
| 218 | + pass |
| 219 | + else: |
| 220 | + # We check the validity of the token |
| 221 | + # If not valid, then return None to inform the caller that a new token |
| 222 | + # is needed |
| 223 | + if is_token_valid(token): |
| 224 | + return token |
| 225 | + return None |
209 | 226 |
|
210 | 227 |
|
211 | 228 | def is_token_valid(token: AccessToken) -> bool:
|
212 | 229 | """Condition to get a new token"""
|
| 230 | + # TODO: Should we check against the userinfo endpoint? |
213 | 231 | return (
|
214 | 232 | datetime.utcfromtimestamp(token.expires_on) - datetime.utcnow()
|
215 | 233 | ).total_seconds() > 300
|
0 commit comments