深入底层,仿MyBatis自己写框架

前言:

最近研究了一下Mybatis的底层代码,写了一个操作数据库的小工具,实现了Mybatis的部分功能:

1.SQL语句在mapper.xml中配置。

2.支持int,String,自定义数据类型的入参。

3.根据mapper.xml动态创建接口的代理实现对象。

功能有限,目的是搞清楚MyBatis框架的底层思想,多学习研究优秀框架的实现思路,对提升自己的编码能力大有裨益。

小工具使用到的核心技术点: xml解析+反射+jdk动态代理

接下来,一步一步来实现。

首先来说为什么要使用jdk动态代理。

传统的开发方式:

1.接口定义业务方法。

2.实现类实现业务方法。

3.实例化实现类对象来完成业务操作。

接口:

public interface UserDAO {
    public User get(int id);
}

实现类:

public class UserDAOImpl implements UserDAO{

    @Override
    public User get(int id) {
        Connection conn = JDBCTools.getConnection();
        String sql = "select * from user where id = ?";
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            pstmt = conn.prepareStatement(sql);
            pstmt.setInt(1, id);
            rs = pstmt.executeQuery();
            if(rs.next()){
                int sid = rs.getInt(1);
                String name = rs.getString(2);
                User user = new User(sid,name);
                return user;
            }
        } catch (Exception e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }finally{
            JDBCTools.release(conn, pstmt, rs);
        }
        return null;
    }

}

测试:

public static void main(String[] args) {

        UserDAO userDAO = new UserDAOImpl();
        User user = userDAO.get(1);
        System.out.println(user);

    }

Mybatis的方式:

1.开发者只需要创建接口,定义业务方法。

2. 不需要创建实现类。

3.具体的业务操作通过配置xml来完成。

接口:

public interface StudentDAO {
    public Student getById(int id);
    public Student getByStudent(Student student);
    public Student getByName(String name);
    public Student getByStudent2(Student student);
}

StudentDAO.xml:

 
 

    
        select * from student where id=#{id}
    

    
        select * from student where id=#{id} and name=#{name}
    

    
        select * from student where name=#{name} and tel=#{tel} 
    

    
        select * from student where name=#{name}
    


测试:

public static void main(String[] args) {

        StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
        Student stu = studentDAO.getById(1);
        System.out.println(stu);

    }

通过以上代码可以看到, MyBatis的方式省去了实现类的创建,改为用xml来定义业务方法的具体实现。

那么问题来了。

我们知道Java是面向对象的编程语言, 程序在运行时执行业务方法,必须要有实例化的对象。 但是,接口是不能被实例化的,而且也没有接口的实现类,那么此时这个对象从哪来呢?

程序在运行时,动态创建代理对象。

即jdk动态代理,运行时结合接口和mapper.xml来动态创建一个代理对象,程序调用该代理对象的方法来完成业务。

如何使用jdk动态代理?

创建一个类,实现InvocationHandler接口,该类就具备了创建动态代理对象的功能。

两个核心方法:

1.自定义getInstance方法:入参为目标对象,通过Proxy.newProxyInstance方法创建代理对象,并返回。

    public Object getInstance(Class cls){
        Object newProxyInstance = Proxy.newProxyInstance(  
                cls.getClassLoader(),  
                new Class[] { cls }, 
                this); 
        return (Object)newProxyInstance;
    }

2.实现接口的invoke方法,通过反射机制完成业务逻辑代码。

   @Override
    public Object invoke(Object proxy, Method method, Object[] args)
            throws Throwable {
        // TODO Auto-generated method stub
        return null;
    }

invoke方法是核心代码,在该方法中实现具体的业务需求。接下来我们来看如何实现。

既然是对数据库进行操作,则一定需要数据库连接对象,数据库相关信息配置在config.xml中。

所以invoke方法第一步,就是要解析config.xml,创建数据库连接对象,使用C3P0数据库连接池。

    //读取C3P0数据源配置信息
    public static Map getC3P0Properties(){
        Map map = new HashMap();
        SAXReader reader = new SAXReader();
        try {
            Document document = reader.read("src/config.xml");
            //获取根节点
            Element root = document.getRootElement();
            Iterator iter = root.elementIterator();
            while(iter.hasNext()){
                Element e = (Element) iter.next();
                //解析environments节点
                if("environments".equals(e.getName())){
                    Iterator iter2 = e.elementIterator();
                    while(iter2.hasNext()){
                        //解析environment节点
                        Element e2 = (Element) iter2.next();
                        Iterator iter3 = e2.elementIterator();
                        while(iter3.hasNext()){
                            Element e3 = (Element) iter3.next();
                            //解析dataSource节点
                            if("dataSource".equals(e3.getName())){
                                if("POOLED".equals(e3.attributeValue("type"))){
                                    Iterator iter4 = e3.elementIterator();
                                    //获取数据库连接信息
                                    while(iter4.hasNext()){
                                        Element e4 = (Element) iter4.next();
                                        map.put(e4.attributeValue("name"),e4.attributeValue("value"));
                                    }
                                }
                            }
                        }
                    }
                }
            }
        } catch (Exception e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return map; 
    }
//获取C3P0信息,创建数据源对象
Map map = ParseXML.getC3P0Properties();
ComboPooledDataSource datasource = new ComboPooledDataSource();
datasource.setDriverClass(map.get("driver"));
datasource.setJdbcUrl(map.get("url"));
datasource.setUser(map.get("username"));
datasource.setPassword(map.get("password"));
datasource.setInitialPoolSize(20);
datasource.setMaxPoolSize(40);
datasource.setMinPoolSize(2);
datasource.setAcquireIncrement(5);
Connection conn = datasource.getConnection();

有了数据库连接,接下来就需要获取待执行的SQL语句,SQL的定义全部写在StudentDAO.xml中,继续解析xml,执行SQL语句。

SQL执行完毕,查询结果会保存在ResultSet中,还需要将ResultSet对象中的数据进行解析,封装到JavaBean中返回。

两步完成:

1.反射机制创建Student对象。

2.通过反射动态执行类中所有属性的setter方法,完成赋值。

这样就将ResultSet中的数据封装到JavaBean中了。

//获取sql语句
String sql = element.getText();
//获取参数类型
String parameterType = element.attributeValue("parameterType");
//创建pstmt
PreparedStatement pstmt = createPstmt(sql,parameterType,conn,args);
ResultSet rs = pstmt.executeQuery();
if(rs.next()){
    //读取返回数据类型
    String resultType = element.attributeValue("resultType");   
    //反射创建对象
    Class clazz = Class.forName(resultType);
    obj = clazz.newInstance();
    //获取ResultSet数据
    ResultSetMetaData rsmd = rs.getMetaData();
    //遍历实体类属性集合,依次将结果集中的值赋给属性
    Field[] fields = clazz.getDeclaredFields();
    for(int i = 0; i < fields.length; i++){
        Object value = setFieldValueByResultSet(fields[i],rsmd,rs);
        //通过属性名找到对应的setter方法
        String name = fields[i].getName();
        name = name.substring(0, 1).toUpperCase() + name.substring(1);
        String MethodName = "set"+name;
        Method methodObj = clazz.getMethod(MethodName,fields[i].getType());
        //调用setter方法完成赋值
        methodObj.invoke(obj, value);
        }
}

代码的实现大致思路如上所述,具体实现起来有很多细节需要处理。 使用到两个自定义工具类:ParseXML,MyInvocationHandler。

完整代码:

ParseXML

public class ParseXML {

    //读取C3P0数据源配置信息
    public static Map getC3P0Properties(){
        Map map = new HashMap();
        SAXReader reader = new SAXReader();
        try {
            Document document = reader.read("src/config.xml");
            //获取根节点
            Element root = document.getRootElement();
            Iterator iter = root.elementIterator();
            while(iter.hasNext()){
                Element e = (Element) iter.next();
                //解析environments节点
                if("environments".equals(e.getName())){
                    Iterator iter2 = e.elementIterator();
                    while(iter2.hasNext()){
                        //解析environment节点
                        Element e2 = (Element) iter2.next();
                        Iterator iter3 = e2.elementIterator();
                        while(iter3.hasNext()){
                            Element e3 = (Element) iter3.next();
                            //解析dataSource节点
                            if("dataSource".equals(e3.getName())){
                                if("POOLED".equals(e3.attributeValue("type"))){
                                    Iterator iter4 = e3.elementIterator();
                                    //获取数据库连接信息
                                    while(iter4.hasNext()){
                                        Element e4 = (Element) iter4.next();
                                        map.put(e4.attributeValue("name"),e4.attributeValue("value"));
                                    }
                                }
                            }
                        }
                    }
                }
            }
        } catch (Exception e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return map; 
    }

    //根据接口查找对应的mapper.xml
    public static String getMapperXML(String className){
        //保存xml路径
        String xml = "";
        SAXReader reader = new SAXReader();
        Document document;
        try {
            document = reader.read("src/config.xml");
            Element root = document.getRootElement();
            Iterator iter = root.elementIterator();
            while(iter.hasNext()){
                Element mappersElement = (Element) iter.next();
                if("mappers".equals(mappersElement.getName())){
                    Iterator iter2 = mappersElement.elementIterator();
                    while(iter2.hasNext()){
                        Element mapperElement = (Element) iter2.next();
                        //com.southwin.dao.UserDAO . 替换 #
                        className = className.replace(".", "#");
                        //获取接口结尾名
                        String classNameEnd = className.split("#")[className.split("#").length-1];
                        String resourceName = mapperElement.attributeValue("resource");
                        //获取resource结尾名
                        String resourceName2 = resourceName.split("/")[resourceName.split("/").length-1];
                        //UserDAO.xml . 替换 #
                        resourceName2 = resourceName2.replace(".", "#");
                        String resourceNameEnd = resourceName2.split("#")[0];
                        if(classNameEnd.equals(resourceNameEnd)){
                            xml="src/"+resourceName;
                        }
                    }
                }
            }
        } catch (DocumentException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return xml;
    }
}

MyInvocationHandler:

public class MyInvocationHandler implements InvocationHandler{

    private String className;

    public Object getInstance(Class cls){
        //保存接口类型
        className = cls.getName();
        Object newProxyInstance = Proxy.newProxyInstance(  
                cls.getClassLoader(),  
                new Class[] { cls }, 
                this); 
        return (Object)newProxyInstance;
    }

    public Object invoke(Object proxy, Method method, Object[] args)  throws Throwable {        
        SAXReader reader = new SAXReader();
        //返回结果
        Object obj = null;
        try {
            //获取对应的mapper.xml
            String xml = ParseXML.getMapperXML(className);
            Document document = reader.read(xml);
            Element root = document.getRootElement();
            Iterator iter = root.elementIterator();
            while(iter.hasNext()){
                Element element = (Element) iter.next();
                String id = element.attributeValue("id");
                if(method.getName().equals(id)){
                    //获取C3P0信息,创建数据源对象
                    Map map = ParseXML.getC3P0Properties();
                    ComboPooledDataSource datasource = new ComboPooledDataSource();
                    datasource.setDriverClass(map.get("driver"));
                    datasource.setJdbcUrl(map.get("url"));
                    datasource.setUser(map.get("username"));
                    datasource.setPassword(map.get("password"));
                    datasource.setInitialPoolSize(20);
                    datasource.setMaxPoolSize(40);
                    datasource.setMinPoolSize(2);
                    datasource.setAcquireIncrement(5);
                    Connection conn = datasource.getConnection();
                    //获取sql语句
                    String sql = element.getText();
                    //获取参数类型
                    String parameterType = element.attributeValue("parameterType");
                    //创建pstmt
                    PreparedStatement pstmt = createPstmt(sql,parameterType,conn,args);
                    ResultSet rs = pstmt.executeQuery();
                    if(rs.next()){
                        //读取返回数据类型
                        String resultType = element.attributeValue("resultType");   
                        //反射创建对象
                        Class clazz = Class.forName(resultType);
                        obj = clazz.newInstance();
                        //获取ResultSet数据
                        ResultSetMetaData rsmd = rs.getMetaData();
                        //遍历实体类属性集合,依次将结果集中的值赋给属性
                        Field[] fields = clazz.getDeclaredFields();
                        for(int i = 0; i < fields.length; i++){
                            Object value = setFieldValueByResultSet(fields[i],rsmd,rs);
                            //通过属性名找到对应的setter方法
                            String name = fields[i].getName();
                            name = name.substring(0, 1).toUpperCase() + name.substring(1);
                            String MethodName = "set"+name;
                            Method methodObj = clazz.getMethod(MethodName,fields[i].getType());
                            //调用setter方法完成赋值
                            methodObj.invoke(obj, value);
                        }
                    }
                    conn.close();
                }
            }
        } catch (Exception e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }

       return obj;
    }

    /**
     * 根据条件创建pstmt
     * @param sql
     * @param parameterType
     * @param conn
     * @param args
     * @return
     * @throws Exception
     */
    public PreparedStatement createPstmt(String sql,String parameterType,Connection conn,Object[] args) throws Exception{
        PreparedStatement pstmt = null;
        try {
            switch(parameterType){
                case "int":
                    int start = sql.indexOf("#{");
                    int end = sql.indexOf("}");
                    //获取参数占位符 #{name}
                    String target = sql.substring(start, end+1);
                    //将参数占位符替换为?
                    sql = sql.replace(target, "?");
                    pstmt = conn.prepareStatement(sql);
                    int num = Integer.parseInt(args[0].toString());
                    pstmt.setInt(1, num);
                    break;
                case "java.lang.String":
                    int start2 = sql.indexOf("#{");
                    int end2 = sql.indexOf("}");
                    String target2 = sql.substring(start2, end2+1);
                    sql = sql.replace(target2, "?");
                    pstmt = conn.prepareStatement(sql);
                    String str = args[0].toString();
                    pstmt.setString(1, str);
                    break;
                default:
                    Class clazz = Class.forName(parameterType);
                    Object obj = args[0];
                    boolean flag = true;
                    //存储参数
                    List values = new ArrayList();
                    //保存带#的sql
                    String sql2 = "";
                    while(flag){
                        int start3 = sql.indexOf("#{");
                        //判断#{}是否替换完成
                        if(start3<0){
                            flag = false;
                            break;
                        }
                        int end3 = sql.indexOf("}");
                        String target3 = sql.substring(start3, end3+1);
                        //获取#{}的值 如#{name}拿到name
                        String name = sql.substring(start3+2, end3);
                        //通过反射获取对应的getter方法
                        name = name.substring(0, 1).toUpperCase() + name.substring(1);
                        String MethodName = "get"+name;
                        Method methodObj = clazz.getMethod(MethodName);
                        //调用getter方法完成赋值
                        Object value = methodObj.invoke(obj);
                        values.add(value);
                        sql = sql.replace(target3, "?");
                        sql2 = sql.replace("?", "#");
                    }
                    //截取sql2,替换参数
                    String[] sqls = sql2.split("#");
                    pstmt = conn.prepareStatement(sql);
                    for(int i = 0; i < sqls.length-1; i++){
                        Object value = values.get(i);
                        if("java.lang.String".equals(value.getClass().getName())){
                            pstmt.setString(i+1, (String)value);
                        }
                        if("java.lang.Integer".equals(value.getClass().getName())){
                            pstmt.setInt(i+1, (Integer)value);
                        }
                    }
                    break;
                }
        } catch (SQLException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return pstmt;
    }

    /**
     * 根据将结果集中的值赋给对应的属性
     * @param field
     * @param rsmd
     * @param rs
     * @return
     */
    public Object setFieldValueByResultSet(Field field,ResultSetMetaData rsmd,ResultSet rs){
        Object result = null;
        try {
            int count = rsmd.getColumnCount();
            for(int i=1;i<=count;i++){
                if(field.getName().equals(rsmd.getColumnName(i))){
                    String type = field.getType().getName();
                    switch (type) {
                        case "int":
                            result = rs.getInt(field.getName());
                            break;
                        case "java.lang.String":
                            result = rs.getString(field.getName());
                            break;
                    default:
                        break;
                    }
                }
            }
        } catch (SQLException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }
        return result;
    }


}

代码测试:

StudnetDAO.getById

public static void main(String[] args) {

        StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
        Student stu = studentDAO.getById(1);
        System.out.println(stu);

    }

代码中的studentDAO为动态代理对象,此对象通过 MyInvocationHandler().getInstance(StudentDAO.class)方法动态创建, 并且结合StudentDAO.xml实现了StudentDAO接口的全部方法,直接调用studentDAO对象的方法即可完成业务需求。

StudnetDAO.getByName

public static void main(String[] args) {

        StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
        Student stu = studentDAO.getByName("李四");
        System.out.println(stu);

    }

StudnetDAO.getByStudent(根据id和name查询)

public static void main(String[] args) {

        StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
        Student student = new Student();
        student.setId(1);
        student.setName("张三");
        Student stu = studentDAO.getByStudent(student);
        System.out.println(stu);

    }

StudnetDAO.getByStudent2(根据name和tel查询)

public static void main(String[] args) {

        StudentDAO studentDAO = (StudentDAO) new MyInvocationHandler().getInstance(StudentDAO.class);
        Student student = new Student();
        student.setName("李四");
        student.setTel("18367895678");
        Student stu = studentDAO.getByStudent2(student);
        System.out.println(stu);

    }

以上就是仿MyBatis实现自定义小工具的大致思路,细节之处还需具体查看源码,最后附上小工具源码链接。

源码:

链接:  https://pan.baidu.com/s/ 1pMz0FDh  

密码:  fnjb