store only user id in session in order to avoid invalidating all sessions when user model changes

This commit is contained in:
Athou
2024-01-09 21:06:47 +01:00
parent 2bf9186135
commit 7675a24eb6
11 changed files with 57 additions and 46 deletions

View File

@@ -10,6 +10,7 @@ import java.util.concurrent.TimeUnit;
import org.hibernate.cfg.AvailableSettings; import org.hibernate.cfg.AvailableSettings;
import com.codahale.metrics.json.MetricsModule; import com.codahale.metrics.json.MetricsModule;
import com.commafeed.backend.dao.UserDAO;
import com.commafeed.backend.feed.FeedRefreshEngine; import com.commafeed.backend.feed.FeedRefreshEngine;
import com.commafeed.backend.model.AbstractModel; import com.commafeed.backend.model.AbstractModel;
import com.commafeed.backend.model.Feed; import com.commafeed.backend.model.Feed;
@@ -166,7 +167,9 @@ public class CommaFeedApplication extends Application<CommaFeedConfiguration> {
environment.servlets().setSessionHandler(config.getSessionHandlerFactory().build(config.getDataSourceFactory())); environment.servlets().setSessionHandler(config.getSessionHandlerFactory().build(config.getDataSourceFactory()));
// support for "@SecurityCheck User user" injection // support for "@SecurityCheck User user" injection
environment.jersey().register(new SecurityCheckFactoryProvider.Binder(injector.getInstance(UserService.class))); environment.jersey()
.register(new SecurityCheckFactoryProvider.Binder(injector.getInstance(UserDAO.class),
injector.getInstance(UserService.class)));
// support for "@Context SessionHelper sessionHelper" injection // support for "@Context SessionHelper sessionHelper" injection
environment.jersey().register(new SessionHelperFactoryProvider.Binder()); environment.jersey().register(new SessionHelperFactoryProvider.Binder());

View File

@@ -8,6 +8,7 @@ import java.util.function.Function;
import org.glassfish.jersey.server.ContainerRequest; import org.glassfish.jersey.server.ContainerRequest;
import com.commafeed.backend.dao.UserDAO;
import com.commafeed.backend.model.User; import com.commafeed.backend.model.User;
import com.commafeed.backend.model.UserRole.Role; import com.commafeed.backend.model.UserRole.Role;
import com.commafeed.backend.service.UserService; import com.commafeed.backend.service.UserService;
@@ -25,6 +26,7 @@ public class SecurityCheckFactory implements Function<ContainerRequest, User> {
private static final String PREFIX = "Basic"; private static final String PREFIX = "Basic";
private final UserDAO userDAO;
private final UserService userService; private final UserService userService;
private final HttpServletRequest request; private final HttpServletRequest request;
private final Role role; private final Role role;
@@ -59,7 +61,7 @@ public class SecurityCheckFactory implements Function<ContainerRequest, User> {
} }
Optional<User> cookieSessionLogin(SessionHelper sessionHelper) { Optional<User> cookieSessionLogin(SessionHelper sessionHelper) {
Optional<User> loggedInUser = sessionHelper.getLoggedInUser(); Optional<User> loggedInUser = sessionHelper.getLoggedInUserId().map(userDAO::findById);
loggedInUser.ifPresent(userService::performPostLoginActivities); loggedInUser.ifPresent(userService::performPostLoginActivities);
return loggedInUser; return loggedInUser;
} }

View File

@@ -9,6 +9,7 @@ import org.glassfish.jersey.server.internal.inject.MultivaluedParameterExtractor
import org.glassfish.jersey.server.model.Parameter; import org.glassfish.jersey.server.model.Parameter;
import org.glassfish.jersey.server.spi.internal.ValueParamProvider; import org.glassfish.jersey.server.spi.internal.ValueParamProvider;
import com.commafeed.backend.dao.UserDAO;
import com.commafeed.backend.model.User; import com.commafeed.backend.model.User;
import com.commafeed.backend.service.UserService; import com.commafeed.backend.service.UserService;
@@ -21,12 +22,14 @@ import lombok.RequiredArgsConstructor;
public class SecurityCheckFactoryProvider extends AbstractValueParamProvider { public class SecurityCheckFactoryProvider extends AbstractValueParamProvider {
private final UserService userService; private final UserService userService;
private final UserDAO userDAO;
private final HttpServletRequest request; private final HttpServletRequest request;
@Inject @Inject
public SecurityCheckFactoryProvider(final MultivaluedParameterExtractorProvider extractorProvider, UserService userService, public SecurityCheckFactoryProvider(final MultivaluedParameterExtractorProvider extractorProvider, UserDAO userDAO,
HttpServletRequest request) { UserService userService, HttpServletRequest request) {
super(() -> extractorProvider, Parameter.Source.UNKNOWN); super(() -> extractorProvider, Parameter.Source.UNKNOWN);
this.userDAO = userDAO;
this.userService = userService; this.userService = userService;
this.request = request; this.request = request;
} }
@@ -44,17 +47,19 @@ public class SecurityCheckFactoryProvider extends AbstractValueParamProvider {
return null; return null;
} }
return new SecurityCheckFactory(userService, request, securityCheck.value(), securityCheck.apiKeyAllowed()); return new SecurityCheckFactory(userDAO, userService, request, securityCheck.value(), securityCheck.apiKeyAllowed());
} }
@RequiredArgsConstructor @RequiredArgsConstructor
public static class Binder extends AbstractBinder { public static class Binder extends AbstractBinder {
private final UserDAO userDAO;
private final UserService userService; private final UserService userService;
@Override @Override
protected void configure() { protected void configure() {
bind(SecurityCheckFactoryProvider.class).to(ValueParamProvider.class).in(Singleton.class); bind(SecurityCheckFactoryProvider.class).to(ValueParamProvider.class).in(Singleton.class);
bind(userDAO).to(UserDAO.class);
bind(userService).to(UserService.class); bind(userService).to(UserService.class);
} }
} }

View File

@@ -4,6 +4,7 @@ import java.io.IOException;
import java.util.Optional; import java.util.Optional;
import com.commafeed.backend.dao.UnitOfWork; import com.commafeed.backend.dao.UnitOfWork;
import com.commafeed.backend.dao.UserDAO;
import com.commafeed.backend.dao.UserSettingsDAO; import com.commafeed.backend.dao.UserSettingsDAO;
import com.commafeed.backend.model.User; import com.commafeed.backend.model.User;
import com.commafeed.backend.model.UserSettings; import com.commafeed.backend.model.UserSettings;
@@ -20,13 +21,16 @@ abstract class AbstractCustomCodeServlet extends HttpServlet {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
private final UnitOfWork unitOfWork; private final UnitOfWork unitOfWork;
private final UserDAO userDAO;
private final UserSettingsDAO userSettingsDAO; private final UserSettingsDAO userSettingsDAO;
@Override @Override
protected final void doGet(final HttpServletRequest req, HttpServletResponse resp) throws IOException { protected final void doGet(final HttpServletRequest req, HttpServletResponse resp) throws IOException {
resp.setContentType(getMimeType()); resp.setContentType(getMimeType());
final Optional<User> user = new SessionHelper(req).getLoggedInUser(); SessionHelper sessionHelper = new SessionHelper(req);
Optional<Long> userId = sessionHelper.getLoggedInUserId();
final Optional<User> user = unitOfWork.call(() -> userId.map(userDAO::findById));
if (user.isEmpty()) { if (user.isEmpty()) {
return; return;
} }

View File

@@ -1,6 +1,7 @@
package com.commafeed.frontend.servlet; package com.commafeed.frontend.servlet;
import com.commafeed.backend.dao.UnitOfWork; import com.commafeed.backend.dao.UnitOfWork;
import com.commafeed.backend.dao.UserDAO;
import com.commafeed.backend.dao.UserSettingsDAO; import com.commafeed.backend.dao.UserSettingsDAO;
import com.commafeed.backend.model.UserSettings; import com.commafeed.backend.model.UserSettings;
@@ -11,8 +12,8 @@ public class CustomCssServlet extends AbstractCustomCodeServlet {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
@Inject @Inject
public CustomCssServlet(UnitOfWork unitOfWork, UserSettingsDAO userSettingsDAO) { public CustomCssServlet(UnitOfWork unitOfWork, UserDAO userDAO, UserSettingsDAO userSettingsDAO) {
super(unitOfWork, userSettingsDAO); super(unitOfWork, userDAO, userSettingsDAO);
} }
@Override @Override

View File

@@ -1,6 +1,7 @@
package com.commafeed.frontend.servlet; package com.commafeed.frontend.servlet;
import com.commafeed.backend.dao.UnitOfWork; import com.commafeed.backend.dao.UnitOfWork;
import com.commafeed.backend.dao.UserDAO;
import com.commafeed.backend.dao.UserSettingsDAO; import com.commafeed.backend.dao.UserSettingsDAO;
import com.commafeed.backend.model.UserSettings; import com.commafeed.backend.model.UserSettings;
@@ -13,8 +14,8 @@ public class CustomJsServlet extends AbstractCustomCodeServlet {
private static final long serialVersionUID = 1L; private static final long serialVersionUID = 1L;
@Inject @Inject
public CustomJsServlet(UnitOfWork unitOfWork, UserSettingsDAO userSettingsDAO) { public CustomJsServlet(UnitOfWork unitOfWork, UserDAO userDAO, UserSettingsDAO userSettingsDAO) {
super(unitOfWork, userSettingsDAO); super(unitOfWork, userDAO, userSettingsDAO);
} }
@Override @Override

View File

@@ -11,6 +11,7 @@ import com.commafeed.backend.dao.FeedCategoryDAO;
import com.commafeed.backend.dao.FeedEntryStatusDAO; import com.commafeed.backend.dao.FeedEntryStatusDAO;
import com.commafeed.backend.dao.FeedSubscriptionDAO; import com.commafeed.backend.dao.FeedSubscriptionDAO;
import com.commafeed.backend.dao.UnitOfWork; import com.commafeed.backend.dao.UnitOfWork;
import com.commafeed.backend.dao.UserDAO;
import com.commafeed.backend.model.FeedCategory; import com.commafeed.backend.model.FeedCategory;
import com.commafeed.backend.model.FeedEntryStatus; import com.commafeed.backend.model.FeedEntryStatus;
import com.commafeed.backend.model.FeedSubscription; import com.commafeed.backend.model.FeedSubscription;
@@ -41,6 +42,7 @@ public class NextUnreadServlet extends HttpServlet {
private final FeedSubscriptionDAO feedSubscriptionDAO; private final FeedSubscriptionDAO feedSubscriptionDAO;
private final FeedEntryStatusDAO feedEntryStatusDAO; private final FeedEntryStatusDAO feedEntryStatusDAO;
private final FeedCategoryDAO feedCategoryDAO; private final FeedCategoryDAO feedCategoryDAO;
private final UserDAO userDAO;
private final UserService userService; private final UserService userService;
private final FeedEntryService feedEntryService; private final FeedEntryService feedEntryService;
private final CommaFeedConfiguration config; private final CommaFeedConfiguration config;
@@ -51,7 +53,8 @@ public class NextUnreadServlet extends HttpServlet {
String orderParam = req.getParameter(PARAM_READINGORDER); String orderParam = req.getParameter(PARAM_READINGORDER);
SessionHelper sessionHelper = new SessionHelper(req); SessionHelper sessionHelper = new SessionHelper(req);
Optional<User> user = sessionHelper.getLoggedInUser(); Optional<Long> userId = sessionHelper.getLoggedInUserId();
Optional<User> user = unitOfWork.call(() -> userId.map(userDAO::findById));
user.ifPresent(value -> unitOfWork.run(() -> userService.performPostLoginActivities(value))); user.ifPresent(value -> unitOfWork.run(() -> userService.performPostLoginActivities(value)));
if (user.isEmpty()) { if (user.isEmpty()) {
resp.sendRedirect(resp.encodeRedirectURL(config.getApplicationSettings().getPublicUrl())); resp.sendRedirect(resp.encodeRedirectURL(config.getApplicationSettings().getPublicUrl()));

View File

@@ -11,32 +11,25 @@ import lombok.RequiredArgsConstructor;
@RequiredArgsConstructor @RequiredArgsConstructor
public class SessionHelper { public class SessionHelper {
private static final String SESSION_KEY_USER = "user"; public static final String SESSION_KEY_USER_ID = "user-id";
private final HttpServletRequest request; private final HttpServletRequest request;
public Optional<User> getLoggedInUser() { public Optional<Long> getLoggedInUserId() {
Optional<HttpSession> session = getSession(false); HttpSession session = request.getSession(false);
if (session.isPresent()) { return getLoggedInUserId(session);
return getLoggedInUser(session.get());
}
return Optional.empty();
} }
public static Optional<User> getLoggedInUser(HttpSession session) { public static Optional<Long> getLoggedInUserId(HttpSession session) {
User user = (User) session.getAttribute(SESSION_KEY_USER); if (session == null) {
return Optional.ofNullable(user); return Optional.empty();
}
Long userId = (Long) session.getAttribute(SESSION_KEY_USER_ID);
return Optional.ofNullable(userId);
} }
public void setLoggedInUser(User user) { public void setLoggedInUser(User user) {
Optional<HttpSession> session = getSession(true); request.getSession(true).setAttribute(SESSION_KEY_USER_ID, user.getId());
session.get().setAttribute(SESSION_KEY_USER, user);
}
private Optional<HttpSession> getSession(boolean force) {
HttpSession session = request.getSession(force);
return Optional.ofNullable(session);
} }
} }

View File

@@ -2,7 +2,6 @@ package com.commafeed.frontend.ws;
import java.util.Optional; import java.util.Optional;
import com.commafeed.backend.model.User;
import com.commafeed.frontend.session.SessionHelper; import com.commafeed.frontend.session.SessionHelper;
import jakarta.inject.Inject; import jakarta.inject.Inject;
@@ -26,8 +25,8 @@ public class WebSocketConfigurator extends Configurator {
public void modifyHandshake(ServerEndpointConfig config, HandshakeRequest request, HandshakeResponse response) { public void modifyHandshake(ServerEndpointConfig config, HandshakeRequest request, HandshakeResponse response) {
HttpSession httpSession = (HttpSession) request.getHttpSession(); HttpSession httpSession = (HttpSession) request.getHttpSession();
if (httpSession != null) { if (httpSession != null) {
Optional<User> user = SessionHelper.getLoggedInUser(httpSession); Optional<Long> userId = SessionHelper.getLoggedInUserId(httpSession);
user.ifPresent(value -> config.getUserProperties().put(SESSIONKEY_USERID, value.getId())); userId.ifPresent(value -> config.getUserProperties().put(SESSIONKEY_USERID, value));
} }
} }

View File

@@ -5,6 +5,7 @@ import java.util.Optional;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.Mockito; import org.mockito.Mockito;
import com.commafeed.backend.dao.UserDAO;
import com.commafeed.backend.model.User; import com.commafeed.backend.model.User;
import com.commafeed.backend.service.UserService; import com.commafeed.backend.service.UserService;
import com.commafeed.backend.service.internal.PostLoginActivities; import com.commafeed.backend.service.internal.PostLoginActivities;
@@ -15,15 +16,17 @@ class SecurityCheckFactoryTest {
@Test @Test
void cookieLoginShouldPerformPostLoginActivities() { void cookieLoginShouldPerformPostLoginActivities() {
User userInSession = new User(); User userInSession = new User();
UserDAO userDAO = Mockito.mock(UserDAO.class);
Mockito.when(userDAO.findById(1L)).thenReturn(userInSession);
SessionHelper sessionHelper = Mockito.mock(SessionHelper.class); SessionHelper sessionHelper = Mockito.mock(SessionHelper.class);
Mockito.when(sessionHelper.getLoggedInUser()).thenReturn(Optional.of(userInSession)); Mockito.when(sessionHelper.getLoggedInUserId()).thenReturn(Optional.of(1L));
PostLoginActivities postLoginActivities = Mockito.mock(PostLoginActivities.class); PostLoginActivities postLoginActivities = Mockito.mock(PostLoginActivities.class);
UserService service = new UserService(null, null, null, null, null, null, null, postLoginActivities); UserService service = new UserService(null, null, null, null, null, null, null, postLoginActivities);
SecurityCheckFactory factory = new SecurityCheckFactory(service, null, null, false); SecurityCheckFactory factory = new SecurityCheckFactory(userDAO, service, null, null, false);
factory.cookieSessionLogin(sessionHelper); factory.cookieSessionLogin(sessionHelper);
Mockito.verify(postLoginActivities).executeFor(userInSession); Mockito.verify(postLoginActivities).executeFor(userInSession);

View File

@@ -13,14 +13,12 @@ import jakarta.servlet.http.HttpSession;
class SessionHelperTest { class SessionHelperTest {
private static final String SESSION_KEY_USER = "user";
@Test @Test
void gettingUserDoesNotCreateSession() { void gettingUserDoesNotCreateSession() {
HttpServletRequest request = Mockito.mock(HttpServletRequest.class); HttpServletRequest request = Mockito.mock(HttpServletRequest.class);
SessionHelper sessionHelper = new SessionHelper(request); SessionHelper sessionHelper = new SessionHelper(request);
sessionHelper.getLoggedInUser(); sessionHelper.getLoggedInUserId();
Mockito.verify(request).getSession(false); Mockito.verify(request).getSession(false);
} }
@@ -31,23 +29,23 @@ class SessionHelperTest {
Mockito.when(request.getSession(false)).thenReturn(null); Mockito.when(request.getSession(false)).thenReturn(null);
SessionHelper sessionHelper = new SessionHelper(request); SessionHelper sessionHelper = new SessionHelper(request);
Optional<User> user = sessionHelper.getLoggedInUser(); Optional<Long> userId = sessionHelper.getLoggedInUserId();
Assertions.assertFalse(user.isPresent()); Assertions.assertFalse(userId.isPresent());
} }
@Test @Test
void gettingUserShouldNotReturnUserIfUserNotPresentInHttpSession() { void gettingUserShouldNotReturnUserIfUserNotPresentInHttpSession() {
HttpSession session = Mockito.mock(HttpSession.class); HttpSession session = Mockito.mock(HttpSession.class);
Mockito.when(session.getAttribute(SESSION_KEY_USER)).thenReturn(null); Mockito.when(session.getAttribute(SessionHelper.SESSION_KEY_USER_ID)).thenReturn(null);
HttpServletRequest request = Mockito.mock(HttpServletRequest.class); HttpServletRequest request = Mockito.mock(HttpServletRequest.class);
Mockito.when(request.getSession(false)).thenReturn(session); Mockito.when(request.getSession(false)).thenReturn(session);
SessionHelper sessionHelper = new SessionHelper(request); SessionHelper sessionHelper = new SessionHelper(request);
Optional<User> user = sessionHelper.getLoggedInUser(); Optional<Long> userId = sessionHelper.getLoggedInUserId();
Assertions.assertFalse(user.isPresent()); Assertions.assertFalse(userId.isPresent());
} }
@Test @Test
@@ -55,16 +53,15 @@ class SessionHelperTest {
User userInSession = new User(); User userInSession = new User();
HttpSession session = Mockito.mock(HttpSession.class); HttpSession session = Mockito.mock(HttpSession.class);
Mockito.when(session.getAttribute(SESSION_KEY_USER)).thenReturn(userInSession); Mockito.when(session.getAttribute(SessionHelper.SESSION_KEY_USER_ID)).thenReturn(1L);
HttpServletRequest request = Mockito.mock(HttpServletRequest.class); HttpServletRequest request = Mockito.mock(HttpServletRequest.class);
Mockito.when(request.getSession(false)).thenReturn(session); Mockito.when(request.getSession(false)).thenReturn(session);
SessionHelper sessionHelper = new SessionHelper(request); SessionHelper sessionHelper = new SessionHelper(request);
Optional<User> user = sessionHelper.getLoggedInUser(); Optional<Long> userId = sessionHelper.getLoggedInUserId();
Assertions.assertTrue(user.isPresent()); Assertions.assertTrue(userId.isPresent());
Assertions.assertEquals(userInSession, user.get());
} }
} }