Skip to content

Commit 0052b37

Browse files
committed
Refactor WebSocket close for suspend/resume
Ensure that WebSocket connection closure completes if the connection is closed when the server side has used the proprietary suspend/resume feature to suspend the connection.
1 parent 60594b4 commit 0052b37

File tree

7 files changed

+187
-8
lines changed

7 files changed

+187
-8
lines changed

java/org/apache/tomcat/websocket/Constants.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.ArrayList;
2020
import java.util.Collections;
2121
import java.util.List;
22+
import java.util.concurrent.TimeUnit;
2223

2324
import jakarta.websocket.ClientEndpointConfig;
2425
import jakarta.websocket.Extension;
@@ -117,6 +118,11 @@ public class Constants {
117118
// Milliseconds so this is 20 seconds
118119
public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000;
119120

121+
// Configuration for session close timeout
122+
public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT";
123+
// Default is 30 seconds - setting is in milliseconds
124+
public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30);
125+
120126
// Configuration for read idle timeout on WebSocket session
121127
public static final String READ_IDLE_TIMEOUT_MS = "org.apache.tomcat.websocket.READ_IDLE_TIMEOUT_MS";
122128

java/org/apache/tomcat/websocket/WsSession.java

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.util.Map;
2828
import java.util.Set;
2929
import java.util.concurrent.ConcurrentHashMap;
30+
import java.util.concurrent.TimeUnit;
3031
import java.util.concurrent.atomic.AtomicLong;
3132
import java.util.concurrent.atomic.AtomicReference;
3233

@@ -115,6 +116,7 @@ public class WsSession implements Session {
115116
private volatile long lastActiveRead = System.currentTimeMillis();
116117
private volatile long lastActiveWrite = System.currentTimeMillis();
117118
private Map<FutureToSendHandler, FutureToSendHandler> futures = new ConcurrentHashMap<>();
119+
private volatile Long sessionCloseTimeoutExpiry;
118120

119121

120122
/**
@@ -593,7 +595,14 @@ public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal
593595
*/
594596
state.set(State.CLOSED);
595597
// ... and close the network connection.
596-
wsRemoteEndpoint.close();
598+
closeConnection();
599+
} else {
600+
/*
601+
* Set close timeout. If the client fails to send a close message response within the timeout, the session
602+
* and the connection will be closed when the timeout expires.
603+
*/
604+
sessionCloseTimeoutExpiry =
605+
Long.valueOf(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(getSessionCloseTimeout()));
597606
}
598607

599608
// Fail any uncompleted messages.
@@ -632,7 +641,7 @@ public void onClose(CloseReason closeReason) {
632641
state.set(State.CLOSED);
633642

634643
// Close the network connection.
635-
wsRemoteEndpoint.close();
644+
closeConnection();
636645
} else if (state.compareAndSet(State.OUTPUT_CLOSING, State.CLOSING)) {
637646
/*
638647
* The local endpoint sent a close message the the same time as the remote endpoint. The local close is
@@ -644,12 +653,55 @@ public void onClose(CloseReason closeReason) {
644653
* The local endpoint sent the first close message. The remote endpoint has now responded with its own close
645654
* message so mark the session as fully closed and close the network connection.
646655
*/
647-
wsRemoteEndpoint.close();
656+
closeConnection();
648657
}
649658
// CLOSING and CLOSED are NO-OPs
650659
}
651660

652661

662+
private void closeConnection() {
663+
/*
664+
* Close the network connection.
665+
*/
666+
wsRemoteEndpoint.close();
667+
/*
668+
* Don't unregister the session until the connection is fully closed since webSocketContainer is responsible for
669+
* tracking the session close timeout.
670+
*/
671+
webSocketContainer.unregisterSession(getSessionMapKey(), this);
672+
}
673+
674+
675+
/*
676+
* Returns the session close timeout in milliseconds
677+
*/
678+
protected long getSessionCloseTimeout() {
679+
long result = 0;
680+
Object obj = userProperties.get(Constants.SESSION_CLOSE_TIMEOUT_PROPERTY);
681+
if (obj instanceof Long) {
682+
result = ((Long) obj).intValue();
683+
}
684+
if (result <= 0) {
685+
result = Constants.DEFAULT_SESSION_CLOSE_TIMEOUT;
686+
}
687+
return result;
688+
}
689+
690+
691+
protected void checkCloseTimeout() {
692+
// Skip the check if no session close timeout has been set.
693+
if (sessionCloseTimeoutExpiry != null) {
694+
// Check if the timeout has expired.
695+
if (System.nanoTime() - sessionCloseTimeoutExpiry.longValue() > 0) {
696+
// Check if the session has been closed in another thread while the timeout was being processed.
697+
if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) {
698+
closeConnection();
699+
}
700+
}
701+
}
702+
}
703+
704+
653705
private void fireEndpointOnClose(CloseReason closeReason) {
654706

655707
// Fire the onClose event
@@ -722,16 +774,14 @@ private void sendCloseMessage(CloseReason closeReason) {
722774
if (log.isDebugEnabled()) {
723775
log.debug(sm.getString("wsSession.sendCloseFail", id), e);
724776
}
725-
wsRemoteEndpoint.close();
777+
closeConnection();
726778
// Failure to send a close message is not unexpected in the case of
727779
// an abnormal closure (usually triggered by a failure to read/write
728780
// from/to the client. In this case do not trigger the endpoint's
729781
// error handling
730782
if (closeCode != CloseCodes.CLOSED_ABNORMALLY) {
731783
localEndpoint.onError(this, e);
732784
}
733-
} finally {
734-
webSocketContainer.unregisterSession(getSessionMapKey(), this);
735785
}
736786
}
737787

@@ -864,6 +914,11 @@ public String getQueryString() {
864914
@Override
865915
public Principal getUserPrincipal() {
866916
checkState();
917+
return getUserPrincipalInternal();
918+
}
919+
920+
921+
public Principal getUserPrincipalInternal() {
867922
return userPrincipal;
868923
}
869924

java/org/apache/tomcat/websocket/WsWebSocketContainer.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,12 @@ Set<Session> getOpenSessions(Object key) {
610610
synchronized (endPointSessionMapLock) {
611611
Set<WsSession> sessions = endpointSessionMap.get(key);
612612
if (sessions != null) {
613-
result.addAll(sessions);
613+
// Some sessions may be in the process of closing
614+
for (WsSession session : sessions) {
615+
if (session.isOpen()) {
616+
result.add(session);
617+
}
618+
}
614619
}
615620
}
616621
return result;
@@ -1060,8 +1065,10 @@ public void backgroundProcess() {
10601065
if (backgroundProcessCount >= processPeriod) {
10611066
backgroundProcessCount = 0;
10621067

1068+
// Check all registered sessions.
10631069
for (WsSession wsSession : sessions.keySet()) {
10641070
wsSession.checkExpiration();
1071+
wsSession.checkCloseTimeout();
10651072
}
10661073
}
10671074

java/org/apache/tomcat/websocket/server/WsServerContainer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ protected void registerSession(Object key, WsSession wsSession) {
349349
*/
350350
@Override
351351
protected void unregisterSession(Object key, WsSession wsSession) {
352-
if (wsSession.getUserPrincipal() != null && wsSession.getHttpSessionId() != null) {
352+
if (wsSession.getUserPrincipalInternal() != null && wsSession.getHttpSessionId() != null) {
353353
unregisterAuthenticatedSession(wsSession, wsSession.getHttpSessionId());
354354
}
355355
super.unregisterSession(key, wsSession);

test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import java.util.concurrent.CountDownLatch;
2424
import java.util.concurrent.TimeUnit;
2525

26+
import jakarta.servlet.ServletContextEvent;
27+
import jakarta.servlet.ServletContextListener;
2628
import jakarta.websocket.ClientEndpointConfig;
2729
import jakarta.websocket.CloseReason;
2830
import jakarta.websocket.ContainerProvider;
@@ -39,7 +41,9 @@
3941
import org.apache.catalina.servlets.DefaultServlet;
4042
import org.apache.catalina.startup.Tomcat;
4143
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
44+
import org.apache.tomcat.websocket.server.Constants;
4245
import org.apache.tomcat.websocket.server.TesterEndpointConfig;
46+
import org.apache.tomcat.websocket.server.WsServerContainer;
4347

4448
public class TestWsSessionSuspendResume extends WebSocketBaseTest {
4549

@@ -141,4 +145,99 @@ void addMessage(String message) {
141145
}
142146
}
143147
}
148+
149+
150+
@Test
151+
public void testSuspendThenClose() throws Exception {
152+
Tomcat tomcat = getTomcatInstance();
153+
154+
Context ctx = getProgrammaticRootContext();
155+
ctx.addApplicationListener(SuspendCloseConfig.class.getName());
156+
ctx.addApplicationListener(WebSocketFastServerTimeout.class.getName());
157+
158+
Tomcat.addServlet(ctx, "default", new DefaultServlet());
159+
ctx.addServletMappingDecoded("/", "default");
160+
161+
tomcat.start();
162+
163+
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
164+
165+
ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
166+
Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig,
167+
new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH));
168+
169+
wsSession.getBasicRemote().sendText("start test");
170+
171+
// Wait for the client response to be received by the server
172+
int count = 0;
173+
while (count < 50 && !SuspendCloseEndpoint.isServerSessionFullyClosed()) {
174+
Thread.sleep(100);
175+
count ++;
176+
}
177+
Assert.assertTrue(SuspendCloseEndpoint.isServerSessionFullyClosed());
178+
}
179+
180+
181+
public static final class SuspendCloseConfig extends TesterEndpointConfig {
182+
private static final String PATH = "/echo";
183+
184+
@Override
185+
protected Class<?> getEndpointClass() {
186+
return SuspendCloseEndpoint.class;
187+
}
188+
189+
@Override
190+
protected ServerEndpointConfig getServerEndpointConfig() {
191+
return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build();
192+
}
193+
}
194+
195+
196+
public static final class SuspendCloseEndpoint extends Endpoint {
197+
198+
// Yes, a static variable is a hack.
199+
private static WsSession serverSession;
200+
201+
@Override
202+
public void onOpen(Session session, EndpointConfig epc) {
203+
serverSession = (WsSession) session;
204+
// Set a short session close timeout (milliseconds)
205+
serverSession.getUserProperties().put(
206+
org.apache.tomcat.websocket.Constants.SESSION_CLOSE_TIMEOUT_PROPERTY, Long.valueOf(2000));
207+
// Any message will trigger the suspend then close
208+
serverSession.addMessageHandler(String.class, message -> {
209+
try {
210+
serverSession.getBasicRemote().sendText("server session open");
211+
serverSession.getBasicRemote().sendText("suspending server session");
212+
serverSession.suspend();
213+
serverSession.getBasicRemote().sendText("closing server session");
214+
serverSession.close();
215+
} catch (IOException ioe) {
216+
ioe.printStackTrace();
217+
// Attempt to make the failure more obvious
218+
throw new RuntimeException(ioe);
219+
}
220+
});
221+
}
222+
223+
@Override
224+
public void onError(Session session, Throwable t) {
225+
t.printStackTrace();
226+
}
227+
228+
public static boolean isServerSessionFullyClosed() {
229+
return serverSession.isClosed();
230+
}
231+
}
232+
233+
234+
public static class WebSocketFastServerTimeout implements ServletContextListener {
235+
236+
@Override
237+
public void contextInitialized(ServletContextEvent sce) {
238+
WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute(
239+
Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
240+
container.setProcessPeriod(0);
241+
}
242+
}
144243
}

webapps/docs/changelog.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,11 @@
205205
Review usage of debug logging and downgrade trace or data dumping
206206
operations from debug level to trace. (remm)
207207
</fix>
208+
<fix>
209+
Ensure that WebSocket connection closure completes if the connection is
210+
closed when the server side has used the proprietary suspend/resume
211+
feature to suspend the connection. (markt)
212+
</fix>
208213
</changelog>
209214
</subsection>
210215
<subsection name="Web applications">

webapps/docs/web-socket-howto.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@
6464
the timeout to use in milliseconds. For an infinite timeout, use
6565
<code>-1</code>.</p>
6666

67+
<p>The session close timeout defaults to 30000 milliseconds (30 seconds). This
68+
may be changed by setting the property
69+
<code>org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT</code> in the user
70+
properties collection attached to the WebSocket session. The value assigned
71+
to this property should be a <code>Long</code> and represents the timeout to
72+
use in milliseconds. Values less than or equal to zero will be ignored.</p>
73+
6774
<p>In addition to the <code>Session.setMaxIdleTimeout(long)</code> method which
6875
is part of the Jakarta WebSocket API, Tomcat provides greater control of the
6976
timing out the session due to lack of activity. Setting the property

0 commit comments

Comments
 (0)