From 14e7d701062e1f66454e94c65dd1736c041b7c39 Mon Sep 17 00:00:00 2001 From: Athou Date: Mon, 5 Feb 2024 20:27:26 +0100 Subject: [PATCH] simplify websocket session retrieval --- .../frontend/ws/WebSocketSessions.java | 8 +- .../frontend/ws/WebSocketSessionsTest.java | 82 +++++++++++++++++++ 2 files changed, 83 insertions(+), 7 deletions(-) create mode 100644 commafeed-server/src/test/java/com/commafeed/frontend/ws/WebSocketSessionsTest.java diff --git a/commafeed-server/src/main/java/com/commafeed/frontend/ws/WebSocketSessions.java b/commafeed-server/src/main/java/com/commafeed/frontend/ws/WebSocketSessions.java index 1f1d5fd1..37b86eeb 100644 --- a/commafeed-server/src/main/java/com/commafeed/frontend/ws/WebSocketSessions.java +++ b/commafeed-server/src/main/java/com/commafeed/frontend/ws/WebSocketSessions.java @@ -3,7 +3,6 @@ package com.commafeed.frontend.ws; import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.stream.Collectors; import com.codahale.metrics.Gauge; import com.codahale.metrics.MetricRegistry; @@ -38,12 +37,7 @@ public class WebSocketSessions { } public void sendMessage(User user, String text) { - Set userSessions = sessions.entrySet() - .stream() - .filter(e -> e.getKey().equals(user.getId())) - .flatMap(e -> e.getValue().stream()) - .collect(Collectors.toSet()); - + Set userSessions = sessions.get(user.getId()); if (!userSessions.isEmpty()) { log.debug("sending '{}' to {} users via websocket", text, userSessions.size()); for (Session userSession : userSessions) { diff --git a/commafeed-server/src/test/java/com/commafeed/frontend/ws/WebSocketSessionsTest.java b/commafeed-server/src/test/java/com/commafeed/frontend/ws/WebSocketSessionsTest.java new file mode 100644 index 00000000..201b240e --- /dev/null +++ b/commafeed-server/src/test/java/com/commafeed/frontend/ws/WebSocketSessionsTest.java @@ -0,0 +1,82 @@ +package com.commafeed.frontend.ws; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Answers; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; + +import com.codahale.metrics.MetricRegistry; +import com.commafeed.backend.model.User; + +import jakarta.websocket.Session; + +@ExtendWith(MockitoExtension.class) +class WebSocketSessionsTest { + + @Mock + private MetricRegistry metrics; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Session session1; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Session session2; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Session session3; + + private WebSocketSessions webSocketSessions; + + @BeforeEach + void init() { + webSocketSessions = new WebSocketSessions(metrics); + } + + @Test + void sendsMessageToUser() { + Mockito.when(session1.isOpen()).thenReturn(true); + Mockito.when(session2.isOpen()).thenReturn(true); + + User user1 = newUser(1L); + webSocketSessions.add(user1.getId(), session1); + webSocketSessions.add(user1.getId(), session2); + + User user2 = newUser(2L); + webSocketSessions.add(user2.getId(), session3); + + webSocketSessions.sendMessage(user1, "Hello"); + Mockito.verify(session1).getAsyncRemote(); + Mockito.verify(session2).getAsyncRemote(); + Mockito.verifyNoInteractions(session3); + } + + @Test + void closedSessionsAreNotNotified() { + Mockito.when(session1.isOpen()).thenReturn(false); + + User user1 = newUser(1L); + webSocketSessions.add(user1.getId(), session1); + + webSocketSessions.sendMessage(user1, "Hello"); + Mockito.verify(session1, Mockito.never()).getAsyncRemote(); + } + + @Test + void removedSessionsAreNotNotified() { + User user1 = newUser(1L); + webSocketSessions.add(user1.getId(), session1); + webSocketSessions.remove(session1); + + webSocketSessions.sendMessage(user1, "Hello"); + Mockito.verifyNoInteractions(session1); + } + + private User newUser(Long userId) { + User user = new User(); + user.setId(userId); + return user; + } +} \ No newline at end of file