Skip to content

Commit 1109077

Browse files
committed
feat: Add bedrock_proxy
1 parent 3f1b56a commit 1109077

File tree

3 files changed

+118
-1
lines changed

3 files changed

+118
-1
lines changed

src/api/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from fastapi.responses import PlainTextResponse
88
from mangum import Mangum
99

10-
from api.routers import chat, embeddings, model
10+
from api.routers import chat, embeddings, model, bedrock_proxy
1111
from api.setting import API_ROUTE_PREFIX, DESCRIPTION, SUMMARY, TITLE, VERSION
1212

1313
config = {
@@ -35,6 +35,7 @@
3535
app.include_router(model.router, prefix=API_ROUTE_PREFIX)
3636
app.include_router(chat.router, prefix=API_ROUTE_PREFIX)
3737
app.include_router(embeddings.router, prefix=API_ROUTE_PREFIX)
38+
app.include_router(bedrock_proxy.router, prefix=API_ROUTE_PREFIX)
3839

3940

4041
@app.get("/health")

src/api/routers/bedrock_proxy.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import os
2+
import logging
3+
from typing import Dict, Any
4+
from urllib.parse import quote
5+
6+
import httpx
7+
from fastapi import APIRouter, Depends, HTTPException, Request
8+
from fastapi.responses import StreamingResponse, Response
9+
10+
from api.auth import api_key_auth
11+
from api.setting import AWS_REGION, DEBUG
12+
13+
logger = logging.getLogger(__name__)
14+
15+
router = APIRouter(prefix="/bedrock")
16+
17+
# Get AWS bearer token from environment
18+
AWS_BEARER_TOKEN = os.environ.get("AWS_BEARER_TOKEN_BEDROCK")
19+
20+
if not AWS_BEARER_TOKEN:
21+
logger.warning("AWS_BEARER_TOKEN_BEDROCK not set - bedrock proxy endpoints will not work")
22+
23+
24+
def get_aws_url(model_id: str, endpoint_path: str) -> str:
25+
"""Convert proxy path to AWS Bedrock URL"""
26+
encoded_model_id = quote(model_id, safe='')
27+
base_url = f"https://bedrock-runtime.{AWS_REGION}.amazonaws.com"
28+
return f"{base_url}/model/{encoded_model_id}/{endpoint_path}"
29+
30+
31+
def get_proxy_headers(request: Request) -> Dict[str, str]:
32+
"""Get headers to forward to AWS, replacing Authorization"""
33+
headers = dict(request.headers)
34+
35+
# Remove proxy authorization and add AWS bearer token
36+
headers.pop("authorization", None)
37+
headers.pop("host", None) # Let httpx set the correct host
38+
39+
if AWS_BEARER_TOKEN:
40+
headers["Authorization"] = f"Bearer {AWS_BEARER_TOKEN}"
41+
42+
return headers
43+
44+
45+
@router.api_route("/model/{model_id}/{endpoint_path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH"])
46+
async def transparent_proxy(
47+
request: Request,
48+
model_id: str,
49+
endpoint_path: str,
50+
_: None = Depends(api_key_auth)
51+
):
52+
"""
53+
Transparent HTTP proxy to AWS Bedrock.
54+
Forwards all requests as-is, only changing auth and URL.
55+
"""
56+
if not AWS_BEARER_TOKEN:
57+
raise HTTPException(
58+
status_code=503,
59+
detail="AWS_BEARER_TOKEN_BEDROCK not configured"
60+
)
61+
62+
# Build AWS URL
63+
aws_url = get_aws_url(model_id, endpoint_path)
64+
65+
# Get headers to forward
66+
proxy_headers = get_proxy_headers(request)
67+
68+
# Get request body
69+
body = await request.body()
70+
71+
if DEBUG:
72+
logger.info(f"Proxying {request.method} to: {aws_url}")
73+
logger.info(f"Headers: {dict(proxy_headers)}")
74+
if body:
75+
logger.info(f"Body length: {len(body)} bytes")
76+
77+
try:
78+
async with httpx.AsyncClient() as client:
79+
# Forward the request to AWS
80+
response = await client.request(
81+
method=request.method,
82+
url=aws_url,
83+
headers=proxy_headers,
84+
content=body,
85+
params=request.query_params,
86+
timeout=120.0
87+
)
88+
89+
# Check if response is streaming
90+
content_type = response.headers.get("content-type", "")
91+
if "text/event-stream" in content_type or "stream" in content_type:
92+
# Stream the response
93+
return StreamingResponse(
94+
content=response.aiter_bytes(),
95+
status_code=response.status_code,
96+
headers=dict(response.headers),
97+
media_type=content_type
98+
)
99+
100+
# Regular response
101+
return Response(
102+
content=response.content,
103+
status_code=response.status_code,
104+
headers=dict(response.headers)
105+
)
106+
107+
except httpx.RequestError as e:
108+
logger.error(f"Proxy request failed: {e}")
109+
raise HTTPException(status_code=502, detail=f"Upstream request failed: {str(e)}")
110+
except httpx.HTTPStatusError as e:
111+
logger.error(f"AWS returned error: {e.response.status_code}")
112+
raise HTTPException(status_code=e.response.status_code, detail=e.response.text)
113+
except Exception as e:
114+
logger.error(f"Proxy error: {e}")
115+
raise HTTPException(status_code=500, detail="Proxy error")

src/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ uvicorn==0.29.0
44
mangum==0.17.0
55
tiktoken==0.6.0
66
requests==2.32.4
7+
httpx==0.27.0
78
numpy==1.26.4
89
boto3==1.37.0
910
botocore==1.37.0

0 commit comments

Comments
 (0)