Skip to content

Commit 3631adb

Browse files
committed
Refactor WebSocket close for suspend/resume
Ensure that WebSocket connection closure completes if the connection is closed when the server side has used the proprietary suspend/resume feature to suspend the connection.
1 parent 7b8d5cd commit 3631adb

File tree

7 files changed

+195
-8
lines changed

7 files changed

+195
-8
lines changed

java/org/apache/tomcat/websocket/Constants.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import java.util.ArrayList;
2020
import java.util.Collections;
2121
import java.util.List;
22+
import java.util.concurrent.TimeUnit;
2223

2324
import javax.websocket.Extension;
2425

@@ -107,6 +108,11 @@ public class Constants {
107108
// Milliseconds so this is 20 seconds
108109
public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000;
109110

111+
// Configuration for session close timeout
112+
public static final String SESSION_CLOSE_TIMEOUT_PROPERTY = "org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT";
113+
// Default is 30 seconds - setting is in milliseconds
114+
public static final long DEFAULT_SESSION_CLOSE_TIMEOUT = TimeUnit.SECONDS.toMillis(30);
115+
110116
// Configuration for read idle timeout on WebSocket session
111117
public static final String READ_IDLE_TIMEOUT_MS = "org.apache.tomcat.websocket.READ_IDLE_TIMEOUT_MS";
112118

java/org/apache/tomcat/websocket/WsSession.java

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import java.util.Map;
2828
import java.util.Set;
2929
import java.util.concurrent.ConcurrentHashMap;
30+
import java.util.concurrent.TimeUnit;
3031
import java.util.concurrent.atomic.AtomicLong;
3132
import java.util.concurrent.atomic.AtomicReference;
3233

@@ -114,6 +115,7 @@ public class WsSession implements Session {
114115
private volatile long lastActiveRead = System.currentTimeMillis();
115116
private volatile long lastActiveWrite = System.currentTimeMillis();
116117
private Map<FutureToSendHandler, FutureToSendHandler> futures = new ConcurrentHashMap<>();
118+
private volatile Long sessionCloseTimeoutExpiry;
117119

118120

119121
/**
@@ -676,7 +678,14 @@ public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal
676678
*/
677679
state.set(State.CLOSED);
678680
// ... and close the network connection.
679-
wsRemoteEndpoint.close();
681+
closeConnection();
682+
} else {
683+
/*
684+
* Set close timeout. If the client fails to send a close message response within the timeout, the session
685+
* and the connection will be closed when the timeout expires.
686+
*/
687+
sessionCloseTimeoutExpiry =
688+
Long.valueOf(System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(getSessionCloseTimeout()));
680689
}
681690

682691
// Fail any uncompleted messages.
@@ -715,7 +724,7 @@ public void onClose(CloseReason closeReason) {
715724
state.set(State.CLOSED);
716725

717726
// Close the network connection.
718-
wsRemoteEndpoint.close();
727+
closeConnection();
719728
} else if (state.compareAndSet(State.OUTPUT_CLOSING, State.CLOSING)) {
720729
/*
721730
* The local endpoint sent a close message the the same time as the remote endpoint. The local close is
@@ -727,12 +736,55 @@ public void onClose(CloseReason closeReason) {
727736
* The local endpoint sent the first close message. The remote endpoint has now responded with its own close
728737
* message so mark the session as fully closed and close the network connection.
729738
*/
730-
wsRemoteEndpoint.close();
739+
closeConnection();
731740
}
732741
// CLOSING and CLOSED are NO-OPs
733742
}
734743

735744

745+
private void closeConnection() {
746+
/*
747+
* Close the network connection.
748+
*/
749+
wsRemoteEndpoint.close();
750+
/*
751+
* Don't unregister the session until the connection is fully closed since webSocketContainer is responsible for
752+
* tracking the session close timeout.
753+
*/
754+
webSocketContainer.unregisterSession(getSessionMapKey(), this);
755+
}
756+
757+
758+
/*
759+
* Returns the session close timeout in milliseconds
760+
*/
761+
protected long getSessionCloseTimeout() {
762+
long result = 0;
763+
Object obj = userProperties.get(Constants.SESSION_CLOSE_TIMEOUT_PROPERTY);
764+
if (obj instanceof Long) {
765+
result = ((Long) obj).intValue();
766+
}
767+
if (result <= 0) {
768+
result = Constants.DEFAULT_SESSION_CLOSE_TIMEOUT;
769+
}
770+
return result;
771+
}
772+
773+
774+
protected void checkCloseTimeout() {
775+
// Skip the check if no session close timeout has been set.
776+
if (sessionCloseTimeoutExpiry != null) {
777+
// Check if the timeout has expired.
778+
if (System.nanoTime() - sessionCloseTimeoutExpiry.longValue() > 0) {
779+
// Check if the session has been closed in another thread while the timeout was being processed.
780+
if (state.compareAndSet(State.OUTPUT_CLOSED, State.CLOSED)) {
781+
closeConnection();
782+
}
783+
}
784+
}
785+
}
786+
787+
736788
private void fireEndpointOnClose(CloseReason closeReason) {
737789

738790
// Fire the onClose event
@@ -805,16 +857,14 @@ private void sendCloseMessage(CloseReason closeReason) {
805857
if (log.isDebugEnabled()) {
806858
log.debug(sm.getString("wsSession.sendCloseFail", id), e);
807859
}
808-
wsRemoteEndpoint.close();
860+
closeConnection();
809861
// Failure to send a close message is not unexpected in the case of
810862
// an abnormal closure (usually triggered by a failure to read/write
811863
// from/to the client. In this case do not trigger the endpoint's
812864
// error handling
813865
if (closeCode != CloseCodes.CLOSED_ABNORMALLY) {
814866
localEndpoint.onError(this, e);
815867
}
816-
} finally {
817-
webSocketContainer.unregisterSession(getSessionMapKey(), this);
818868
}
819869
}
820870

@@ -947,6 +997,11 @@ public String getQueryString() {
947997
@Override
948998
public Principal getUserPrincipal() {
949999
checkState();
1000+
return getUserPrincipalInternal();
1001+
}
1002+
1003+
1004+
public Principal getUserPrincipalInternal() {
9501005
return userPrincipal;
9511006
}
9521007

java/org/apache/tomcat/websocket/WsWebSocketContainer.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,12 @@ Set<Session> getOpenSessions(Object key) {
614614
synchronized (endPointSessionMapLock) {
615615
Set<WsSession> sessions = endpointSessionMap.get(key);
616616
if (sessions != null) {
617-
result.addAll(sessions);
617+
// Some sessions may be in the process of closing
618+
for (WsSession session : sessions) {
619+
if (session.isOpen()) {
620+
result.add(session);
621+
}
622+
}
618623
}
619624
}
620625
return result;
@@ -1061,8 +1066,10 @@ public void backgroundProcess() {
10611066
if (backgroundProcessCount >= processPeriod) {
10621067
backgroundProcessCount = 0;
10631068

1069+
// Check all registered sessions.
10641070
for (WsSession wsSession : sessions.keySet()) {
10651071
wsSession.checkExpiration();
1072+
wsSession.checkCloseTimeout();
10661073
}
10671074
}
10681075

java/org/apache/tomcat/websocket/server/WsServerContainer.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ protected void registerSession(Object key, WsSession wsSession) {
429429
*/
430430
@Override
431431
protected void unregisterSession(Object key, WsSession wsSession) {
432-
if (wsSession.getUserPrincipal() != null && wsSession.getHttpSessionId() != null) {
432+
if (wsSession.getUserPrincipalInternal() != null && wsSession.getHttpSessionId() != null) {
433433
unregisterAuthenticatedSession(wsSession, wsSession.getHttpSessionId());
434434
}
435435
super.unregisterSession(key, wsSession);

test/org/apache/tomcat/websocket/TestWsSessionSuspendResume.java

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import java.util.concurrent.CountDownLatch;
2424
import java.util.concurrent.TimeUnit;
2525

26+
import javax.servlet.ServletContextEvent;
27+
import javax.servlet.ServletContextListener;
2628
import javax.websocket.ClientEndpointConfig;
2729
import javax.websocket.CloseReason;
2830
import javax.websocket.ContainerProvider;
@@ -40,7 +42,9 @@
4042
import org.apache.catalina.servlets.DefaultServlet;
4143
import org.apache.catalina.startup.Tomcat;
4244
import org.apache.tomcat.websocket.TesterMessageCountClient.TesterProgrammaticEndpoint;
45+
import org.apache.tomcat.websocket.server.Constants;
4346
import org.apache.tomcat.websocket.server.TesterEndpointConfig;
47+
import org.apache.tomcat.websocket.server.WsServerContainer;
4448

4549
public class TestWsSessionSuspendResume extends WebSocketBaseTest {
4650

@@ -152,4 +156,107 @@ void addMessage(String message) {
152156
}
153157
}
154158
}
159+
160+
161+
@Test
162+
public void testSuspendThenClose() throws Exception {
163+
Tomcat tomcat = getTomcatInstance();
164+
165+
Context ctx = getProgrammaticRootContext();
166+
ctx.addApplicationListener(SuspendCloseConfig.class.getName());
167+
ctx.addApplicationListener(WebSocketFastServerTimeout.class.getName());
168+
169+
Tomcat.addServlet(ctx, "default", new DefaultServlet());
170+
ctx.addServletMappingDecoded("/", "default");
171+
172+
tomcat.start();
173+
174+
WebSocketContainer wsContainer = ContainerProvider.getWebSocketContainer();
175+
176+
ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().build();
177+
Session wsSession = wsContainer.connectToServer(TesterProgrammaticEndpoint.class, clientEndpointConfig,
178+
new URI("ws://localhost:" + getPort() + SuspendResumeConfig.PATH));
179+
180+
wsSession.getBasicRemote().sendText("start test");
181+
182+
// Wait for the client response to be received by the server
183+
int count = 0;
184+
while (count < 50 && !SuspendCloseEndpoint.isServerSessionFullyClosed()) {
185+
Thread.sleep(100);
186+
count ++;
187+
}
188+
Assert.assertTrue(SuspendCloseEndpoint.isServerSessionFullyClosed());
189+
}
190+
191+
192+
public static final class SuspendCloseConfig extends TesterEndpointConfig {
193+
private static final String PATH = "/echo";
194+
195+
@Override
196+
protected Class<?> getEndpointClass() {
197+
return SuspendCloseEndpoint.class;
198+
}
199+
200+
@Override
201+
protected ServerEndpointConfig getServerEndpointConfig() {
202+
return ServerEndpointConfig.Builder.create(getEndpointClass(), PATH).build();
203+
}
204+
}
205+
206+
207+
public static final class SuspendCloseEndpoint extends Endpoint {
208+
209+
// Yes, a static variable is a hack.
210+
private static WsSession serverSession;
211+
212+
@Override
213+
public void onOpen(Session session, EndpointConfig epc) {
214+
serverSession = (WsSession) session;
215+
// Set a short session close timeout (milliseconds)
216+
serverSession.getUserProperties().put(
217+
org.apache.tomcat.websocket.Constants.SESSION_CLOSE_TIMEOUT_PROPERTY, Long.valueOf(2000));
218+
// Any message will trigger the suspend then close
219+
serverSession.addMessageHandler(String.class, new MessageHandler.Whole<String>() {
220+
@Override
221+
public void onMessage(String message) {
222+
try {
223+
serverSession.getBasicRemote().sendText("server session open");
224+
serverSession.getBasicRemote().sendText("suspending server session");
225+
serverSession.suspend();
226+
serverSession.getBasicRemote().sendText("closing server session");
227+
serverSession.close();
228+
} catch (IOException ioe) {
229+
ioe.printStackTrace();
230+
// Attempt to make the failure more obvious
231+
throw new RuntimeException(ioe);
232+
}
233+
}
234+
});
235+
}
236+
237+
@Override
238+
public void onError(Session session, Throwable t) {
239+
t.printStackTrace();
240+
}
241+
242+
public static boolean isServerSessionFullyClosed() {
243+
return serverSession.isClosed();
244+
}
245+
}
246+
247+
248+
public static class WebSocketFastServerTimeout implements ServletContextListener {
249+
250+
@Override
251+
public void contextInitialized(ServletContextEvent sce) {
252+
WsServerContainer container = (WsServerContainer) sce.getServletContext().getAttribute(
253+
Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
254+
container.setProcessPeriod(0);
255+
}
256+
257+
@Override
258+
public void contextDestroyed(ServletContextEvent sce) {
259+
// NO-OP
260+
}
261+
}
155262
}

webapps/docs/changelog.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,11 @@
199199
Review usage of debug logging and downgrade trace or data dumping
200200
operations from debug level to trace. (remm)
201201
</fix>
202+
<fix>
203+
Ensure that WebSocket connection closure completes if the connection is
204+
closed when the server side has used the proprietary suspend/resume
205+
feature to suspend the connection. (markt)
206+
</fix>
202207
</changelog>
203208
</subsection>
204209
<subsection name="Web applications">

webapps/docs/web-socket-howto.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@
6363
the timeout to use in milliseconds. For an infinite timeout, use
6464
<code>-1</code>.</p>
6565

66+
<p>The session close timeout defaults to 30000 milliseconds (30 seconds). This
67+
may be changed by setting the property
68+
<code>org.apache.tomcat.websocket.SESSION_CLOSE_TIMEOUT</code> in the user
69+
properties collection attached to the WebSocket session. The value assigned
70+
to this property should be a <code>Long</code> and represents the timeout to
71+
use in milliseconds. Values less than or equal to zero will be ignored.</p>
72+
6673
<p>In addition to the <code>Session.setMaxIdleTimeout(long)</code> method which
6774
is part of the Java WebSocket API, Tomcat provides greater control of the
6875
timing out the session due to lack of activity. Setting the property

0 commit comments

Comments
 (0)