diff --git a/src/main/java/com/commafeed/backend/service/StartupService.java b/src/main/java/com/commafeed/backend/service/StartupService.java index 426d37d5..91f87cd1 100644 --- a/src/main/java/com/commafeed/backend/service/StartupService.java +++ b/src/main/java/com/commafeed/backend/service/StartupService.java @@ -1,16 +1,12 @@ package com.commafeed.backend.service; -import java.sql.Connection; import java.util.Arrays; import javax.inject.Inject; import javax.inject.Singleton; -import javax.sql.DataSource; +import org.hibernate.Session; import org.hibernate.SessionFactory; -import org.hibernate.engine.jdbc.connections.internal.DatasourceConnectionProviderImpl; -import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider; -import org.hibernate.internal.SessionFactoryImpl; import com.commafeed.CommaFeedApplication; import com.commafeed.CommaFeedConfiguration; @@ -31,7 +27,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; @Slf4j -@RequiredArgsConstructor(onConstructor = @__({ @Inject }) ) +@RequiredArgsConstructor(onConstructor = @__({ @Inject })) @Singleton public class StartupService implements Managed { @@ -50,17 +46,10 @@ public class StartupService implements Managed { } private void updateSchema() { - try { - Connection connection = null; + Session session = sessionFactory.openSession(); + session.doWork(connection -> { try { - Thread currentThread = Thread.currentThread(); - ClassLoader classLoader = currentThread.getContextClassLoader(); - ResourceAccessor accessor = new ClassLoaderResourceAccessor(classLoader); - - DataSource dataSource = getDataSource(sessionFactory); - connection = dataSource.getConnection(); JdbcConnection jdbcConnection = new JdbcConnection(connection); - Database database = DatabaseFactory.getInstance().findCorrectDatabaseImplementation(jdbcConnection); if (database instanceof PostgresDatabase) { @@ -73,17 +62,14 @@ public class StartupService implements Managed { database.setConnection(jdbcConnection); } + ResourceAccessor accessor = new ClassLoaderResourceAccessor(Thread.currentThread().getContextClassLoader()); Liquibase liq = new Liquibase("migrations.xml", accessor, database); liq.update("prod"); - } finally { - if (connection != null) { - connection.close(); - } + } catch (Exception e) { + throw new RuntimeException(e); } - - } catch (Exception e) { - throw new RuntimeException(e); - } + }); + session.close(); } private void initialData() { @@ -103,15 +89,4 @@ public class StartupService implements Managed { public void stop() throws Exception { } - - private static DataSource getDataSource(SessionFactory sessionFactory) { - if (sessionFactory instanceof SessionFactoryImpl) { - ConnectionProvider cp = ((SessionFactoryImpl) sessionFactory).getConnectionProvider(); - if (cp instanceof DatasourceConnectionProviderImpl) { - return ((DatasourceConnectionProviderImpl) cp).getDataSource(); - } - } - return null; - } - }