package tech.yixiyun.framework.kuafu.db.sql.mysql;

import cn.hutool.core.util.ArrayUtil;
import com.fasterxml.jackson.databind.JavaType;
import org.apache.commons.dbutils.handlers.BeanHandler;
import org.apache.commons.dbutils.handlers.BeanListHandler;
import tech.yixiyun.framework.kuafu.bean.annotation.Bean;
import tech.yixiyun.framework.kuafu.bean.annotation.ClassInitialize;
import tech.yixiyun.framework.kuafu.controller.request.param.ParamType;
import tech.yixiyun.framework.kuafu.db.DbException;
import tech.yixiyun.framework.kuafu.db.DbKit;
import tech.yixiyun.framework.kuafu.db.datasource.DbType;
import tech.yixiyun.framework.kuafu.db.session.DbSessionContext;
import tech.yixiyun.framework.kuafu.db.sql.Sql;
import tech.yixiyun.framework.kuafu.db.sql.SqlException;
import tech.yixiyun.framework.kuafu.db.sql.SqlExecutor;
import tech.yixiyun.framework.kuafu.db.sql.handler.*;
import tech.yixiyun.framework.kuafu.domain.BaseDomain;
import tech.yixiyun.framework.kuafu.domain.ColumnDefinition;
import tech.yixiyun.framework.kuafu.domain.DomainContext;
import tech.yixiyun.framework.kuafu.domain.DomainDefinition;
import tech.yixiyun.framework.kuafu.kits.StringKit;

import java.io.Serializable;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * SQL statement executor for MySQL databases.
 * <p>
 * The static initializer registers this class as the executor for {@link DbType#MYSQL}.
 * Every operation runs on the connection returned by
 * {@code DbSessionContext.getConnection(dataSourceName)}.
 *
 * @author Yixiyun
 * @version 1.0
 * @date 2021-05-04 12:51
 */
@Bean(lazyInit = true)
@ClassInitialize
public class MysqlExecutor extends SqlExecutor {

    static {
        SqlExecutor.registerExecutor(DbType.MYSQL, MysqlExecutor.class);
    }


    /**
     * Executes a statement with the given arguments.
     *
     * @param dataSourceName name of the data source to use
     * @param statement      the SQL statement; a blank statement is a no-op
     * @param args           statement parameters
     * @return the value returned by {@code DbKit.RUNNER.executeCommon}, or 0 for a blank statement
     * @throws DbException wrapping any {@link SQLException}
     */
    @Override
    public int execute(String dataSourceName, String statement, Serializable[] args) {
        if (StringKit.isBlank(statement)) return 0;
        try {
            Connection connection = DbSessionContext.getConnection(dataSourceName);
            return DbKit.RUNNER.executeCommon(connection, statement, args);
        } catch (SQLException e) {
            throw new DbException(e);
        }
    }


    /**
     * Executes a {@link Sql} object on its own data source.
     *
     * @param sql the statement holder; a blank statement is a no-op
     * @return see {@link #execute(String, String, Serializable[])}; 0 for a blank statement
     */
    @Override
    public int execute(Sql sql) {
        String statement = sql.getStatement();
        if (StringKit.isBlank(statement)) return 0;
        Serializable[] args = sql.getArgs();
        return execute(sql.getDataSourceName(), statement, args);
    }




    /**
     * Creates the table for a domain class using the statement produced by {@code Sql#create()}.
     */
    @Override
    public void createTable(String dataSourceName, Class<? extends BaseDomain> domainClass, Serializable... args) {
        execute(Sql.build(dataSourceName, domainClass, args).create());
    }


    /**
     * Alters the table for a domain class using the statement produced by {@code Sql#alter()}.
     */
    @Override
    public void alterTable(String dataSourceName, Class<? extends BaseDomain> domainClass, Serializable... args) {
        execute(Sql.build(dataSourceName, domainClass, args).alter());
    }


    /**
     * Checks whether a table exists in the connection's current catalog.
     * <p>
     * JDBC treats the table name as a LIKE pattern, where {@code _} and {@code %} are wildcards.
     * The returned rows are therefore filtered by exact (case-insensitive) name, so a name such
     * as {@code user_info} does not match {@code userXinfo}.
     *
     * @param dataSourceName name of the data source to use
     * @param tableName      table name; {@code null} matches any table
     * @return true if a matching table exists
     * @throws SqlException if the metadata lookup fails
     */
    @Override
    public boolean tableExist(String dataSourceName, String tableName) {
        Connection connection = DbSessionContext.getConnection(dataSourceName);
        try {
            DatabaseMetaData metaData = connection.getMetaData();
            try (ResultSet tables = metaData.getTables(connection.getCatalog(), null, tableName, new String[]{"TABLE"})) {
                while (tables.next()) {
                    if (nameMatches(tableName, tables.getString("TABLE_NAME"))) {
                        return true;
                    }
                }
                return false;
            }
        } catch (Exception e) {
            throw new SqlException("检查表是否存在时发生异常", e);
        }
    }

    /**
     * Checks whether a column exists on a table in the connection's current catalog.
     * <p>
     * Both names are LIKE patterns for JDBC, so the returned rows are filtered by exact
     * (case-insensitive) table and column name.
     *
     * @param dataSourceName name of the data source to use
     * @param tableName      table name; {@code null} matches any table
     * @param columnName     column name; {@code null} matches any column
     * @return true if a matching column exists
     * @throws SqlException if the metadata lookup fails
     */
    @Override
    public  boolean columnExist(String dataSourceName, String tableName, String columnName) {
        Connection connection = DbSessionContext.getConnection(dataSourceName);
        try {
            DatabaseMetaData metaData = connection.getMetaData();
            try (ResultSet columns = metaData.getColumns(connection.getCatalog(), null, tableName, columnName)) {
                while (columns.next()) {
                    if (nameMatches(tableName, columns.getString("TABLE_NAME"))
                            && nameMatches(columnName, columns.getString("COLUMN_NAME"))) {
                        return true;
                    }
                }
                return false;
            }
        } catch (Exception e) {
            throw new SqlException("检查表字段是否存在时发生异常", e);
        }
    }

    /**
     * Tells whether a metadata row name satisfies the requested name.
     *
     * @return true when {@code expected} is null (matches everything) or equals {@code actual} ignoring case
     */
    private static boolean nameMatches(String expected, String actual) {
        return expected == null || expected.equalsIgnoreCase(actual);
    }


    /**
     * Inserts one domain instance.
     * <p>
     * If the domain class declares an auto-increment column, its field is handed to
     * {@link IdColumnHandler} together with the instance.
     *
     * @param dataSourceName name of the data source to use
     * @param tableName      target table
     * @param instance       the row to insert; {@code null} is a no-op
     * @return the value produced by {@link IdColumnHandler}, or null for a null instance
     * @throws SqlException if the insert fails
     */
    @Override
    public Object insertOne(String dataSourceName, String tableName, BaseDomain instance) {
        if (instance == null ) return null;
        Sql sql = Sql.build(dataSourceName, tableName).insert(instance);
        DomainDefinition definition = DomainContext.getDomainDefinition(instance.getClass());
        Field idField = null;
        ColumnDefinition column = definition.getColumn(definition.getAutoIncrementColumn());
        if (column != null) {
            idField = column.getField();
        }
        try {
            String statement = sql.getStatement();
            Object[] args = sql.getArgs();
            return DbKit.RUNNER.insert(DbSessionContext.getConnection(dataSourceName), statement, new IdColumnHandler(idField, instance), args);
        } catch (Exception e) {
            throw new SqlException("插入数据时发生异常", e);
        }
    }


    /**
     * Runs an {@code INSERT ... SELECT} built from the given columns and query.
     *
     * @return the value returned by {@code DbKit.RUNNER.executeCommon}
     * @throws SqlException if the statement fails
     */
    @Override
    public int insert(String dataSourceName, String tableName, String[] cols, Sql querySql) {
        Sql sql = Sql.build(dataSourceName, tableName).insert(cols, querySql);
        try {
            return DbKit.RUNNER.executeCommon(DbSessionContext.getConnection(dataSourceName), sql.getStatement(), sql.getArgs());
        } catch (SQLException e) {
            throw new SqlException("插入数据时发生异常", e);
        }
    }


    /**
     * Inserts one row from explicit column/value arrays.
     *
     * @return the value produced by {@link IdHandler}, cast to {@link Long}
     * @throws SqlException if the insert fails
     */
    @Override
    public Long insert(String dataSourceName, String tableName, String[] cols, Serializable[] values) {
        Sql sql = Sql.build(dataSourceName, tableName).insert(cols, values);
        try {
            return (Long) DbKit.RUNNER.insert(DbSessionContext.getConnection(dataSourceName), sql.getStatement(), new IdHandler(), sql.getArgs());
        } catch (SQLException e) {
            throw new SqlException("插入数据时发生异常", e);
        }
    }


    /**
     * Inserts many instances, grouping consecutive instances that produce the same statement
     * into JDBC batches of at most {@code batchCount} rows.
     * <p>
     * Null elements are skipped. The auto-increment column, if any, is taken from the class of
     * the first non-null instance.
     *
     * @param dataSourceName name of the data source to use
     * @param tableName      target table
     * @param instances      rows to insert; null or empty is returned as-is
     * @param batchCount     maximum rows per batch; a non-positive value puts all rows sharing
     *                       a statement into a single batch
     * @return the same {@code instances} list
     * @throws SqlException if a batch fails
     */
    @Override
    public List<? extends BaseDomain> insertBatch(String dataSourceName, String tableName, List<? extends BaseDomain> instances, int batchCount) {
        if (instances == null || instances.isEmpty()) return instances;
        // Null elements are skipped below, so a leading null must not abort the whole batch.
        BaseDomain first = firstNonNull(instances);
        if (first == null) return instances;

        // Find the auto-increment column, if any.
        DomainDefinition definition = DomainContext.getDomainDefinition(first.getClass());
        Field idField = null;
        ColumnDefinition column = definition.getColumn(definition.getAutoIncrementColumn());
        if (column != null) {
            idField = column.getField();
        }

        // Build the statement for each instance.
        String prevSql = null; // statement of the batch being accumulated
        String tempSql = null; // statement of the current instance
        List<Serializable[]> args = new ArrayList<>(32);
        List<BaseDomain> tempInstances = new ArrayList<>(32);
        for (BaseDomain instance : instances) {
            if (instance == null) continue;
            Sql temp = Sql.build(dataSourceName, tableName).insert(instance);
            tempSql = temp.getStatement();
            if (prevSql == null) {
                prevSql = tempSql;
            }
            // Same statement: accumulate, flushing once batchCount rows are collected.
            // Different statement: flush what has accumulated, then start a new batch.
            if (tempSql.equals(prevSql)) {
                args.add(temp.getArgs());
                tempInstances.add(instance);
                if (args.size() == batchCount) {
                    insertBatch(DbSessionContext.getConnection(dataSourceName), prevSql, args.toArray(new Serializable[0][0]), new IdsColumnHandler(idField, tempInstances));
                    args.clear();
                    tempInstances.clear();
                    prevSql = null;
                }
            } else {
                insertBatch(DbSessionContext.getConnection(dataSourceName), prevSql, args.toArray(new Serializable[0][0]), new IdsColumnHandler(idField, tempInstances));
                args.clear();
                tempInstances.clear();
                prevSql = tempSql;
                args.add(temp.getArgs());
                tempInstances.add(instance);
            }

        }
        // Flush the remainder.
        if (!args.isEmpty()) {
            insertBatch(DbSessionContext.getConnection(dataSourceName), prevSql, args.toArray(new Serializable[0][0]), new IdsColumnHandler(idField, tempInstances));
        }

        return instances;
    }


    /**
     * Runs one JDBC insert batch through {@code DbKit.RUNNER.insertBatch}.
     *
     * @throws SqlException wrapping any {@link SQLException}
     */
    private static void insertBatch(Connection connection, String sql, Serializable[][] values, IdsColumnHandler handler) {
        try {
            DbKit.RUNNER.insertBatch(connection, sql, handler, values);
        } catch (SQLException e) {
            throw new SqlException("批量插入数据时发生异常", e);
        }
    }

    /**
     * Returns the first non-null element of {@code instances}, or null if every element is null.
     */
    private static BaseDomain firstNonNull(List<? extends BaseDomain> instances) {
        for (BaseDomain instance : instances) {
            if (instance != null) {
                return instance;
            }
        }
        return null;
    }


    /**
     * Runs a query wrapped as a sub-select limited to the first page and maps the row to {@code resultClass}.
     *
     * @return the mapped row as produced by {@link BeanHandler}
     * @throws SqlException wrapping checked exceptions; runtime exceptions propagate unchanged
     */
    @Override
    public <T> T getOne(String dataSourceName, String statement, Serializable[] args, Class<T> resultClass) {
        Sql sql = Sql.build(dataSourceName).from("(" + statement + ")", "this")
                .page(1);

        try {
            return DbKit.RUNNER.query(DbSessionContext.getConnection(dataSourceName), sql.getStatement(), new BeanHandler<T>(resultClass, BaseRowProcessor.INSTANCE), args);
        } catch (Exception e) {
            if (e instanceof RuntimeException) {
                throw (RuntimeException)e;
            }
            throw new SqlException("数据库查询发生异常", e);
        }
    }


    /**
     * Runs a query and maps every row to {@code resultClass}.
     *
     * @throws SqlException wrapping checked exceptions; runtime exceptions propagate unchanged
     */
    @Override
    public <T> List<T> getList(String dataSourceName, String statement, Serializable[] args, Class<T> resultClass) {

        try {
            return DbKit.RUNNER.query(DbSessionContext.getConnection(dataSourceName), statement, new BeanListHandler<T>(resultClass, BaseRowProcessor.INSTANCE), args);
        } catch (Exception e) {
            if (e instanceof RuntimeException) {
                throw (RuntimeException)e;
            }
            throw new SqlException("数据库查询发生异常", e);
        }
    }

    /**
     * Runs a query and converts one column of the first row into {@code columnClass}.
     * @param dataSourceName name of the data source to use
     * @param statement the query to run
     * @param args query parameters
     * @param columnClass target class; supports primitives, their wrappers and arrays, String and String[], Date, Time, byte[]
     * @param <T> target type
     * @return the converted value as produced by {@link ColumnHandler}
     * @throws SqlException wrapping any failure
     */
    @SuppressWarnings("unchecked")
    public <T> T getColumn(String dataSourceName, String statement, Serializable[] args, Class<T> columnClass) {
        try {
            return (T) DbKit.RUNNER.query(DbSessionContext.getConnection(dataSourceName), statement, new ColumnHandler(columnClass), args);
        } catch (Throwable e) {
            throw new SqlException("数据库查询发生异常", e);
        }
    }
    /**
     * Runs a query and converts one column of the first row into {@code columnType}; intended for JSON data.
     * @param dataSourceName name of the data source to use
     * @param statement the query to run
     * @param args query parameters
     * @param columnType Jackson type to deserialize into
     * @param <T> target type
     * @return the converted value as produced by {@link ColumnHandler}
     * @throws SqlException wrapping any failure
     */
    @SuppressWarnings("unchecked")
    public <T> T getColumn(String dataSourceName, String statement, Serializable[] args, JavaType columnType) {
        try {
            return (T) DbKit.RUNNER.query(DbSessionContext.getConnection(dataSourceName), statement, new ColumnHandler(columnType), args);
        } catch (Throwable e) {
            throw new SqlException("数据库查询发生异常", e);
        }
    }


    /**
     * Counts rows ({@code count(*)}). The given sql may contain only the condition part.
     * The count runs on a copy made with {@code Sql.build(sql)}; the argument is not modified by this method.
     * @param sql the query conditions
     * @return the row count
     * @throws SqlException wrapping any failure
     */
    public  int getCount(Sql sql) {
        Sql countSql = Sql.build(sql);

        countSql.select("count(*)");
        try {
            return (int) DbKit.RUNNER.query(DbSessionContext.getConnection(countSql.getDataSourceName()), countSql.getStatement(), new ColumnHandler(int.class), countSql.getArgs());
        } catch (Throwable e) {
            throw new SqlException("数据库查询发生异常", e);
        }
    }


    /**
     * Sums {@code columnName} over the rows selected by {@code sql}.
     * <p>
     * Like {@link #getCount(Sql)}, the select is applied to a copy made with
     * {@code Sql.build(sql)}, so the caller's {@code Sql} can be reused afterwards.
     */
    @Override
    public <T> T getSum(String columnName, Sql sql, Class<T> resultClass) {
        Sql sumSql = Sql.build(sql);
        sumSql.select("sum(" + columnName + ")");
        return getColumn(sumSql.getDataSourceName(), sumSql.getStatement(), sumSql.getArgs(), resultClass);
    }


    /**
     * Runs an update statement.
     *
     * @return the value returned by {@code DbKit.RUNNER.update}
     * @throws SqlException wrapping any failure
     */
    @Override
    public int update(String dataSourceName, String statement, Serializable[] args) {
        try {
            return (int) DbKit.RUNNER.update(DbSessionContext.getConnection(dataSourceName), statement, args);
        } catch (Throwable e) {
            throw new SqlException("数据库执行更新发生异常", e);
        }
    }

    /**
     * Updates one instance by primary key.
     * <p>
     * Primary-key columns become WHERE conditions. Columns listed in {@code getNullColumns()}
     * are set to NULL, other non-null fields are set, and remaining null fields are left unchanged.
     *
     * @return the value returned by {@code DbKit.RUNNER.update}; 0 for a null instance
     * @throws SqlException if no primary key is declared, a key value is null, or the update fails
     */
    @Override
    public int updateOne(String dataSourceName, String tableName, BaseDomain instance) {
        if (instance == null) return 0;
        // The update statement is keyed by primary key, so require declared keys with values.
        DomainDefinition definition = DomainContext.getDomainDefinition(instance.getClass());
        Field[] keys = definition.getPrimaryKeyFields();
        if (keys == null || keys.length == 0) {
            throw new SqlException("Domain类【"+instance.getClass()+"】未指定主键，无法生成update语句");
        }
        try {

            for (Field key : keys) {
                if (key.get(instance) == null) {
                    throw new SqlException("该实例的主键字段值为null， 无法生成update语句");
                }
            }
        } catch (Exception e) {
            if (e instanceof RuntimeException) {
                throw (RuntimeException)e;
            }
            throw new SqlException("该实例的主键字段值获取失败， 无法生成update语句", e);
        }

        Sql sql = Sql.build(dataSourceName, tableName).update();
        List<ColumnDefinition> columns = definition.getNormalColumnList();
        Object v = null;
        String[] nullColumns = instance.getNullColumns();
        try {
            for (ColumnDefinition column : columns) {
                if (column.getIsPrimaryKey()) {
                    sql.where(column.getName()+"=?", (Serializable)column.getField().get(instance));
                } else {
                    // Explicit null columns are set to NULL; other null values are skipped (column not updated).
                    if (ArrayUtil.contains(nullColumns, column.getName()) ) {
                        sql.set(column.getName(), null);
                    } else if ((v = column.getField().get(instance)) != null) {
                        sql.set(column.getName(), (Serializable)v);
                    }

                }
            }


            return DbKit.RUNNER.update(DbSessionContext.getConnection(dataSourceName), sql.getStatement(), sql.getArgs());
        } catch (Exception e) {
            throw new SqlException("更新数据时发生异常", e);
        }
    }


    /**
     * Updates many instances by primary key, grouping consecutive instances that produce the
     * same statement into JDBC batches of at most {@code batchCount} rows.
     * <p>
     * Null elements are skipped. The domain definition is taken from the first non-null instance.
     *
     * @return the number of non-null instances processed (not the number of affected rows);
     *         0 for a null, empty or all-null list
     * @throws SqlException if no primary key is declared, a key value is null, or a batch fails
     */
    @Override
    public int updateBatch(String dataSourceName, String tableName, List<? extends BaseDomain> instances, int batchCount) {
        if (instances == null || instances.isEmpty()) return 0;
        BaseDomain first = firstNonNull(instances);
        if (first == null) return 0;
        int count = 0;
        // The update statements are keyed by primary key, so require declared keys with values.
        Class<? extends BaseDomain> clazz = first.getClass();
        DomainDefinition definition = DomainContext.getDomainDefinition(clazz);
        Field[] keys = definition.getPrimaryKeyFields();
        if (keys == null || keys.length == 0) {
            throw new SqlException("Domain类【"+clazz+"】未指定主键，无法生成update语句");
        }
        try {
            for (BaseDomain instance : instances) {
                // Null elements are skipped by the main loop, so they must not be validated here.
                if (instance == null) continue;
                for (Field key : keys) {
                    if (key.get(instance) == null) {
                        throw new SqlException("有实例的主键值为null，无法生成update语句");
                    }
                }
            }

            // Build the statement for each instance.
            String prevSql = null; // statement of the batch being accumulated
            String tempSql = null; // statement of the current instance
            List<Object[]> args = new ArrayList<>(32);
            List<ColumnDefinition> columns = definition.getNormalColumnList();
            Object v = null;
            for (BaseDomain instance : instances) {
                if (instance == null) continue;
                count += 1;
                String[] nullColumns = instance.getNullColumns();
                Sql temp = Sql.build(dataSourceName, tableName).update();
                for (ColumnDefinition column : columns) {
                    if (column.getIsPrimaryKey()) {
                        temp.where(column.getName()+"=?", (Serializable) column.getField().get(instance));
                    } else {
                        // Explicit null columns are set to NULL; other null values are skipped (column not updated).
                        if (ArrayUtil.contains(nullColumns, column.getName()) ) {
                            temp.set(column.getName(), null);
                        } else if ((v = column.getField().get(instance)) != null) {
                            temp.set(column.getName(), (Serializable)v);
                        }

                    }
                }


                tempSql = temp.getStatement();
                if (prevSql == null) {
                    prevSql = tempSql;
                }
                // Same statement: accumulate, flushing once batchCount rows are collected.
                // Different statement: flush what has accumulated, then start a new batch.
                if (tempSql.equals(prevSql)) {
                    args.add(temp.getArgs());
                    if (args.size() == batchCount) {
                        DbKit.RUNNER.batch(DbSessionContext.getConnection(dataSourceName), prevSql, args.toArray(new Object[0][0]));
                        args.clear();
                        prevSql = null;
                    }
                } else {
                    DbKit.RUNNER.batch(DbSessionContext.getConnection(dataSourceName), prevSql, args.toArray(new Object[0][0]));
                    args.clear();
                    prevSql = tempSql;
                    args.add(temp.getArgs());
                }

            }
            // Flush the remainder.
            if (!args.isEmpty()) {
                DbKit.RUNNER.batch(DbSessionContext.getConnection(dataSourceName), prevSql, args.toArray(new Object[0][0]));
            }

        } catch (Exception e) {
            if (e instanceof RuntimeException) {
                throw (RuntimeException)e;
            }
            throw new SqlException("批量更新数据时发生异常", e);
        }

        return count;
    }


    /**
     * Inserts or updates one instance.
     * <p>
     * Without declared primary keys the instance is always inserted. Otherwise it is updated
     * only when every key field differs (by {@code equals}) from
     * {@code ParamType.getDefaultValue(fieldType)}; if any key still holds the default, it is inserted.
     *
     * @return the same instance, or null for a null instance
     * @throws SqlException if reading the key fields fails, or from the insert/update
     */
    @Override
    public BaseDomain saveOne(String dataSourceName, String tableName, BaseDomain instance) {
        if (instance == null) return null;
        DomainDefinition definition = DomainContext.getDomainDefinition(instance.getClass());
        Field[] keys = definition.getPrimaryKeyFields();
        // No primary key declared: always insert.
        if (keys == null || keys.length == 0) {
            insertOne(dataSourceName, tableName, instance);
            return instance;
        }
        // Update when every key has a non-default value, otherwise insert.
        // Objects.equals is required: boxed key values are not reference-equal to the default
        // outside the boxing caches (e.g. Double, or Long beyond -128..127).
        boolean toUpdate = true;
        try {
            for (Field field : keys) {
                if (Objects.equals(field.get(instance), ParamType.getDefaultValue(field.getType()))) {
                    toUpdate = false;
                    break;
                }
            }
        } catch (Throwable e) {
            throw new SqlException("save数据时发生异常", e);
        }

        if (toUpdate) {
            updateOne(dataSourceName, tableName, instance);
        } else {
            insertOne(dataSourceName, tableName, instance);
        }
        return instance;
    }

    /**
     * Saves each non-null instance with {@link #saveOne}.
     *
     * @return the number of non-null instances saved
     */
    @Override
    public int saveBatch(String dataSourceName, String tableName, List<? extends BaseDomain> instances) {
        if (instances == null || instances.isEmpty()) return 0;
        int count = 0;
        for (BaseDomain instance : instances) {
            if (instance != null) {
                saveOne(dataSourceName, tableName, instance);
                count += 1;
            }
        }
        return count;
    }


    /**
     * Turns {@code sql} into a delete statement (modifying it in place) and executes it.
     */
    @Override
    public int del(Sql sql) {
        sql.delete();
        return execute(sql);
    }


    /**
     * Deletes one instance by primary key.
     *
     * @return see {@link #del(Sql)}; 0 for a null instance
     * @throws SqlException if no primary key is declared or a key value is null or unreadable
     */
    @Override
    public int delOne(String dataSourceName, String tableName, BaseDomain instance) {
        if (instance == null) return 0;
        // The delete statement is keyed by primary key, so require declared keys with values.
        DomainDefinition definition = DomainContext.getDomainDefinition(instance.getClass());
        ColumnDefinition[] keys = definition.getPrimaryKeyColDefinitions();
        if (keys == null || keys.length == 0) {
            throw new SqlException("Domain类【"+instance.getClass()+"】未指定主键，无法生成delete语句");
        }
        Sql sql = Sql.build(dataSourceName, tableName);
        try {

            for (ColumnDefinition key : keys) {
                Serializable value = (Serializable) key.getField().get(instance);
                if (value == null) {
                    throw new SqlException("该实例的主键字段值为null， 无法生成delete语句");
                }
                sql.eq(key.getName(), value);
            }
        } catch (Exception e) {
            if (e instanceof RuntimeException) {
                throw (RuntimeException)e;
            }
            throw new SqlException("该实例的主键字段值获取失败， 无法生成delete语句", e);
        }

        return del(sql);
    }
}
