package tech.yixiyun.framework.kuafu.db.session;

import tech.yixiyun.framework.kuafu.kits.ObjectKit;
import tech.yixiyun.framework.kuafu.log.LOGGER;

import java.sql.Connection;
import java.util.Deque;
import java.util.LinkedList;

/**
 * 全局数据库会话，如果是手动创建的会话，用完一定记得clear!!!!!
 *
 * @author Yixiyun
 * @version 1.0
 * @date 2021-04-24 15:11
 */
public class DbSessionContext {

    //每个线程维护一个session栈
    private static final ThreadLocal<Deque<DbSession>> SESSION_HOLDER = new ThreadLocal<>();


    /**
     * 获取当前线程的会话
     * @return
     */
    public static DbSession getSession() {
        Deque<DbSession> stack = SESSION_HOLDER.get();
        return ObjectKit.ifNotNull(stack, s -> s.peek());
    }

    /**
     * 获取当前线程的会话，如果创建过，直接返回。如果没创建过，就创建一个默认Session。
     * 如果是手动创建的会话，用完一定记得clear!!!!!
     * @return
     */
    public static DbSession getSessionWithDefault() {
        DbSession dbSession = getSession();
        if (dbSession == null) {
            dbSession = new DbSession();

            insertSession(dbSession);
        }
        return dbSession;
    }

    /**
     * 插入一个新会话
     * @param dbSession
     */
    public static void insertSession(DbSession dbSession) {
        Deque<DbSession> stack = SESSION_HOLDER.get();
        if (stack == null) {
            stack = new LinkedList<>();
            SESSION_HOLDER.set(stack);
        }
        stack.addFirst(dbSession);
    }

    /**
     * 从当前线程的session栈中移除最顶层的session
     * @return
     */
    public static DbSession removeSession() {
        Deque<DbSession> stack = SESSION_HOLDER.get();
        DbSession session = stack == null ? null : stack.poll();
        if (session != null) {
            session.clear();
//            LOGGER.trace("session已被关闭：{}", session.hashCode());
        }
        return session;
    }



    /**
     * 从当前会话中，根据数据源名称，从这个数据源获取数据库连接
     * @param dataSourceName
     * @return
     */
    public static Connection getConnection(String dataSourceName) {

        DbSession dbSession = getSessionWithDefault();

        Connection connection = dbSession.getConnection(dataSourceName);
//        LOGGER.trace("获取到数据库连接：{}, 来源session：{}", connection.hashCode(), dbSession.hashCode());
        return connection;
    }

    /**
     * 将当前会话的所有connection回滚
     */
    public static void rollback() {
        DbSession dbSession = getSession();
        if (dbSession != null) {
            dbSession.rollback();
            LOGGER.warn("事务发生回滚：{}", dbSession);
        }
    }

    /**
     * 提交当前会话的所有connection
     */
    public static void commit() {
        DbSession dbSession = getSession();
        if (dbSession != null) {
            dbSession.commit();
//            LOGGER.trace("事务提交了：{}", dbSession);
        }
    }
    /**
     * 重置当前session中的数据库连接<br/>
     * 注意：这个方法不负责提交事务，回滚事务，只负责重置连接
     */
    public static void resetSession() {
        DbSession session = getSession();
        if (session != null) {
            session.reset();
//            LOGGER.trace("session已被重置");
        }
    }
}
