Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 30 additions & 21 deletions dotnet/src/CopilotRequestHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,40 @@ public enum CopilotRequestTransport
[Experimental(Diagnostics.Experimental)]
public sealed class CopilotRequestContext
{
/// <summary>
/// Creates an instance of <see cref="CopilotRequestContext"/> by copying the values from another instance.
/// </summary>
/// <param name="original">A <see cref="CopilotRequestContext"/> instance to copy values from.</param>
public CopilotRequestContext(CopilotRequestContext original)
: this(original.RequestId, original.Url, original.Headers)
{
SessionId = original.SessionId;
Transport = original.Transport;
CancellationToken = original.CancellationToken;
WebSocketResponse = original.WebSocketResponse;
}

internal CopilotRequestContext(string requestId, string url, IReadOnlyDictionary<string, IReadOnlyList<string>> headers)
{
RequestId = requestId;
Url = url;
Headers = headers;
}
Comment thread
SteveSandersonMS marked this conversation as resolved.

/// <summary>Opaque runtime-minted id, stable across the request lifecycle.</summary>
public required string RequestId { get; init; }
public string RequestId { get; init; }

/// <summary>Runtime session id that triggered the request, if any.</summary>
public string? SessionId { get; init; }

/// <summary>Transport the runtime would otherwise use.</summary>
public CopilotRequestTransport Transport { get; init; }

/// <summary>Original request URL.</summary>
public required string Url { get; init; }
/// <summary>Request URL.</summary>
public string Url { get; init; }

/// <summary>Original request headers.</summary>
public required IReadOnlyDictionary<string, IReadOnlyList<string>> Headers { get; init; }
/// <summary>Request headers.</summary>
public IReadOnlyDictionary<string, IReadOnlyList<string>> Headers { get; init; }

/// <summary>
/// Cancelled when the runtime aborts this in-flight request. Subclasses that
Expand Down Expand Up @@ -199,25 +219,17 @@ public virtual async ValueTask DisposeAsync()
[Experimental(Diagnostics.Experimental)]
public class CopilotWebSocketForwarder : CopilotWebSocketHandler
{
private readonly string _url;
private readonly IReadOnlyDictionary<string, IReadOnlyList<string>> _headers;
private WebSocket? _upstream;
private CancellationTokenSource? _pumpCts;
private Task? _responsePump;

/// <summary>
/// Initializes a forwarding handler that will open the upstream socket on
/// demand using the supplied URL/headers (or the values from
/// <paramref name="context"/> when omitted).
/// demand using the supplied URL/headers from <paramref name="context"/>.
/// </summary>
public CopilotWebSocketForwarder(
CopilotRequestContext context,
string? url = null,
IReadOnlyDictionary<string, IReadOnlyList<string>>? headers = null)
public CopilotWebSocketForwarder(CopilotRequestContext context)
: base(context)
{
_url = url ?? context.Url;
_headers = headers ?? context.Headers;
}

/// <summary>
Expand All @@ -231,7 +243,7 @@ internal override async Task OpenAsync()
}

var socket = new ClientWebSocket();
foreach (var (name, values) in _headers)
foreach (var (name, values) in Context.Headers)
{
if (LlmInferenceHeaders.Forbidden.Contains(name))
{
Expand All @@ -248,7 +260,7 @@ internal override async Task OpenAsync()
}
}

await socket.ConnectAsync(ToWebSocketUri(_url), Context.CancellationToken).ConfigureAwait(false);
await socket.ConnectAsync(ToWebSocketUri(Context.Url), Context.CancellationToken).ConfigureAwait(false);
_upstream = socket;
_pumpCts = CancellationTokenSource.CreateLinkedTokenSource(Context.CancellationToken);

Expand Down Expand Up @@ -855,13 +867,10 @@ public Task<LlmInferenceHttpRequestStartResult> HttpRequestStartAsync(LlmInferen
// dropping those frames and hanging the body drain.
var exchange = _pending.GetOrAdd(request.RequestId, id => new LlmInferenceExchange(id, _getServerRpc));
exchange.Method = request.Method;
exchange.Context = new CopilotRequestContext
exchange.Context = new CopilotRequestContext(request.RequestId, request.Url, ToReadOnlyHeaders(request.Headers))
{
RequestId = request.RequestId,
SessionId = request.SessionId,
Transport = transport,
Url = request.Url,
Headers = ToReadOnlyHeaders(request.Headers),
CancellationToken = exchange.Abort.Token,
};

Expand Down
7 changes: 3 additions & 4 deletions dotnet/test/E2E/CopilotRequestWebSocketE2ETests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ protected override Task<HttpResponseMessage> SendRequestAsync(HttpRequestMessage

protected override Task<CopilotWebSocketHandler> OpenWebSocketAsync(CopilotRequestContext ctx)
{
var wsUrl = Rewrite(new Uri(ctx.Url)).ToString();
return Task.FromResult<CopilotWebSocketHandler>(new CountingForwardingWebSocketHandler(ctx, wsUrl, counters));
ctx = new CopilotRequestContext(ctx) { Url = Rewrite(new Uri(ctx.Url)).ToString() };
return Task.FromResult<CopilotWebSocketHandler>(new CountingForwardingWebSocketHandler(ctx, counters));
}

private Uri Rewrite(Uri original) => new UriBuilder(original)
Expand All @@ -133,9 +133,8 @@ protected override Task<CopilotWebSocketHandler> OpenWebSocketAsync(CopilotReque
/// </summary>
internal sealed class CountingForwardingWebSocketHandler(
CopilotRequestContext context,
string url,
HandlerCounters counters)
: CopilotWebSocketForwarder(context, url)
: CopilotWebSocketForwarder(context)
{
public override Task SendRequestMessageAsync(CopilotWebSocketMessage message)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ public final class CopilotRequestContext {
this.cancellation = cancellation;
}

private CopilotRequestContext(String requestId, @Nullable String sessionId, CopilotRequestTransport transport,
String url, Map<String, List<String>> headers, CompletableFuture<Void> cancellation,
LlmWebSocketResponseBridge webSocketResponse) {
this(requestId, sessionId, transport, url, headers, cancellation);
this.webSocketResponse = webSocketResponse;
}

/**
* Gets the opaque runtime-minted request id, stable across the request
* lifecycle.
Expand Down Expand Up @@ -88,6 +95,30 @@ public Map<String, List<String>> headers() {
return headers;
}

/**
* Returns a copy of this context with a different request URL.
*
* @param url
* the replacement request URL
* @return the copied context
*/
public CopilotRequestContext withUrl(String url) {
return new CopilotRequestContext(requestId, sessionId, transport, url, headers, cancellation,
webSocketResponse);
}

/**
* Returns a copy of this context with different request headers.
*
* @param headers
* the replacement request headers
* @return the copied context
*/
public CopilotRequestContext withHeaders(Map<String, List<String>> headers) {
return new CopilotRequestContext(requestId, sessionId, transport, url, headers, cancellation,
webSocketResponse);
}

/**
* A future that completes when the runtime cancels this in-flight request (for
* example because the agent turn was aborted upstream). Subclasses that issue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
*/
public class CopilotWebSocketForwarder extends CopilotWebSocketHandler {

private final String url;
private final Map<String, List<String>> headers;

private volatile WebSocket webSocket;

/**
Expand All @@ -40,37 +37,7 @@ public class CopilotWebSocketForwarder extends CopilotWebSocketHandler {
* the per-request context
*/
public CopilotWebSocketForwarder(CopilotRequestContext context) {
this(context, context.url(), context.headers());
}

/**
* Creates a forwarding handler targeting {@code url} with the handshake headers
* from {@code context}.
*
* @param context
* the per-request context
* @param url
* the upstream WebSocket URL
*/
public CopilotWebSocketForwarder(CopilotRequestContext context, String url) {
this(context, url, context.headers());
}

/**
* Creates a forwarding handler targeting {@code url} with the given handshake
* headers.
*
* @param context
* the per-request context
* @param url
* the upstream WebSocket URL
* @param headers
* the handshake headers, multi-valued
*/
public CopilotWebSocketForwarder(CopilotRequestContext context, String url, Map<String, List<String>> headers) {
super(context);
this.url = url;
this.headers = headers;
}

@Override
Expand All @@ -79,6 +46,7 @@ void open() throws Exception {
return;
}
WebSocket.Builder builder = HttpClient.newHttpClient().newWebSocketBuilder();
Map<String, List<String>> headers = context.headers();
if (headers != null) {
for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
if (CopilotRequestHandler.isForbiddenRequestHeader(entry.getKey()) || entry.getValue() == null) {
Expand All @@ -90,8 +58,8 @@ void open() throws Exception {
}
}
try {
this.webSocket = builder.buildAsync(URI.create(normalizeWebSocketScheme(url)), new ForwardingListener())
.join();
this.webSocket = builder
.buildAsync(URI.create(normalizeWebSocketScheme(context.url())), new ForwardingListener()).join();
} catch (Exception e) {
throw unwrap(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,7 @@ protected HttpResponse<InputStream> sendRequest(HttpRequest request, CopilotRequ

@Override
protected CopilotWebSocketHandler openWebSocket(CopilotRequestContext rctx) {
String rewritten = rewriteHost(wsBase, URI.create(rctx.url()));
return new CopilotWebSocketForwarder(rctx, rewritten) {
return new CopilotWebSocketForwarder(rctx.withUrl(rewriteHost(wsBase, URI.create(rctx.url())))) {
@Override
public void sendRequestMessage(CopilotWebSocketMessage message) throws Exception {
wsRequestMessages.incrementAndGet();
Expand Down
10 changes: 4 additions & 6 deletions nodejs/src/copilotRequestHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ export interface CopilotRequestContext {
readonly requestId: string;
readonly sessionId?: string;
readonly transport: "http" | "websocket";
readonly url: string;
readonly headers: LlmInferenceHeaders;
url: string;
headers: LlmInferenceHeaders;
readonly signal: AbortSignal;
}

Expand Down Expand Up @@ -139,12 +139,10 @@ export abstract class CopilotWebSocketHandler implements AsyncDisposable {
* @experimental
*/
export class CopilotWebSocketForwarder extends CopilotWebSocketHandler {
readonly #url: string;
#upstream: WebSocket | null = null;

constructor(context: CopilotRequestContext, url = context.url) {
constructor(context: CopilotRequestContext) {
super(context);
this.#url = url;
}

override sendRequestMessage(data: string | Uint8Array): void {
Expand All @@ -159,7 +157,7 @@ export class CopilotWebSocketForwarder extends CopilotWebSocketHandler {
if (this.#upstream) {
return;
}
const upstream = new WebSocket(this.#url);
const upstream = new WebSocket(this.context.url);
upstream.binaryType = "arraybuffer";
this.#upstream = upstream;
upstream.addEventListener("message", (event) => {
Expand Down
Loading
Loading