自己实现了一个简易的MySQL数据操作中间层,经过近一年的线上使用和维护,功能已比较完善,性能方面也没有发现大的问题。诚然类似的开源工具有很多,但对于想快速了解其实现原理的同学来说,本文可以成为你的一个切入口。

ORM(对象关系映射)

import java.util.Date;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

// Entity mapped onto the "position" table of the DBName.ZHU_ZHAN database.
@DataBase(name = DBName.ZHU_ZHAN)
@Table(name = "position")
public class Position {

    // Class-level logger; BaseDao skips static members when it builds the column mapping.
    private static Log logger=LogFactory.getLog(Position.class);
    // Primary key, marked with @Id.
    @Id
    private int id;
    // No @Column annotation, so the column name is derived from the field name.
    private int companyId;
    // Persisted in the "createTime" column even though the field is called refreshTime.
    @Column("createTime")
    private Date refreshTime;// time of the last refresh
    // Not backed by a database column: @NotColumn fields are left out of the mapping.
    @NotColumn
    private String refreshTimeStr;
}

类注解@DataBase和@Table分别注明该类跟哪个库哪张表对应。

  • 可增。refreshTimeStr是数据库中不存在的字段,加上@NotColumn注解。static成员变量不在数据库中,不需要加@NotColumn。从DB中select出数据后,不会给实体的static成员变量和带@NotColumn注解的成员变量赋值。
  • 可减。数据库position表中还有其他很多字段,这里Position中都可以没有。当“select *”时实际上提交的请求类似“select id,companyid,createtime”(列名取自@Column注解或属性名,并统一转成小写,各列的先后顺序不固定)。
  • 可不同。数据库中的字段名是createTime,但我们在代码中用refreshTime更好理解。
  • 必须注明@Id。在进行update时需要知道主键。

数据访问

   1 import java.io.Serializable;
   2 import java.lang.reflect.Field;
   3 import java.lang.reflect.Modifier;
   4 import java.lang.reflect.ParameterizedType;
   5 import java.sql.ResultSet;
   6 import java.sql.SQLException;
   7 import java.sql.Statement;
   8 import java.sql.Timestamp;
   9 import java.text.SimpleDateFormat;
  10 import java.util.ArrayList;
  11 import java.util.Date;
  12 import java.util.HashMap;
  13 import java.util.HashSet;
  14 import java.util.List;
  15 import java.util.Map;
  16 import java.util.TimeZone;
  17 import java.util.Map.Entry;
  18 import java.util.Set;
  19 import java.util.concurrent.Callable;
  20 import java.util.concurrent.ExecutionException;
  21 import java.util.concurrent.ExecutorService;
  22 import java.util.concurrent.Executors;
  23 import java.util.concurrent.Future;
  24 import java.util.concurrent.TimeUnit;
  25 import java.util.concurrent.TimeoutException;
  26 
  27 import org.apache.commons.lang.StringUtils;
  28 import org.apache.commons.logging.Log;
  29 import org.apache.commons.logging.LogFactory;
  30 
  35 
  36 /**
  37  * 
  38  * @Author:orisun
  39  * @Since:2015-9-29
  40  * @Version:1.0
  41  */
  42 public class BaseDao<T, PK extends Serializable> {
  43 
  44     private static Log logger = LogFactory.getLog(BaseDao.class);
  45     protected final Class<T> aclass;
  46     protected final String TABLE;
  47     // SQL中的关键字,防SQL攻击
  48     private static Set<String> sqlKeywords = new HashSet<String>();
  49     private Map<String, Field> column2Field = new HashMap<String, Field>();// DB字段名=》类属性名
  50     private String allColumns = "";
  51     private final SimpleDateFormat sdf = new SimpleDateFormat("yyyyMMddHHmmss");
  52     private static Set<Class<?>> validType = new HashSet<Class<?>>(); // 若要和DB类型对应,合法的java类型
  53     private static ExecutorService exec = Executors.newCachedThreadPool();
  54     private KVreport kvReporter = KVreport.getReporter();
  55     private int dbErrorKey = SystemConfig.getIntValue("db_error_key", -1);// 每次DB操作发生异常时上报
  56     private int dbTimeKey = SystemConfig.getIntValue("db_time_key", -1);// 每次的DB操作耗时都上报
  57 
  58     static {
  59         sqlKeywords.add("and");
  60         sqlKeywords.add("or");
  61         sqlKeywords.add("insert");
  62         sqlKeywords.add("select");
  63         sqlKeywords.add("delete");
  64         sqlKeywords.add("update");
  65         sqlKeywords.add("count");
  66         sqlKeywords.add("chr");
  67         sqlKeywords.add("mid");
  68         sqlKeywords.add("truncate");
  69         sqlKeywords.add("trunc");
  70         sqlKeywords.add("char");
  71         sqlKeywords.add("declare");
  72         sqlKeywords.add("like");
  73         sqlKeywords.add("%");
  74         sqlKeywords.add("<");
  75         sqlKeywords.add(">");
  76         sqlKeywords.add("=");
  77         sqlKeywords.add("\"");
  78         sqlKeywords.add("'");
  79         sqlKeywords.add(")");
  80         sqlKeywords.add("(");
  81         // 防止Xss攻击
  82         sqlKeywords.add("script");
  83         sqlKeywords.add("alert");
  84 
  85         validType.add(int.class);
  86         validType.add(Integer.class);
  87         validType.add(byte.class);
  88         validType.add(Byte.class);
  89         validType.add(Float.class);
  90         validType.add(float.class);
  91         validType.add(Short.class);
  92         validType.add(short.class);
  93         validType.add(Long.class);
  94         validType.add(long.class);
  95         validType.add(String.class);
  96         validType.add(Double.class);
  97         validType.add(double.class);
  98         validType.add(Date.class);
  99         validType.add(Timestamp.class);
 100     }
 101 
 102     /**
 103      * 判断str中是否包含SQL关键字
 104      * 
 105      * @param str
 106      * @return
 107      */
 108     protected boolean containSql(String str) {
 109         String[] arr = str.split("\\s+");
 110         for (String ele : arr) {
 111             if (sqlKeywords.contains(ele)) {
 112                 return true;
 113             }
 114         }
 115         return false;
 116     }
 117 
 118     @SuppressWarnings("unchecked")
 119     public BaseDao() throws Exception {
 120         // 获得超类的泛型参数(即T和PK)的首元素的实际类型(即T在运行时对应的实际类型)
 121         this.aclass = (Class<T>) ((ParameterizedType) getClass()
 122                 .getGenericSuperclass()).getActualTypeArguments()[0];
 123         if (aclass.isAnnotationPresent(Table.class)) {
 124             Table table = (Table) aclass.getAnnotation(Table.class);
 125             String name = table.name();
 126             if (name != null) {
 127                 this.TABLE = name;
 128             } else {
 129                 this.TABLE = "";
 130             }
 131         } else {
 132             this.TABLE = "";
 133         }
 134 
 135         if (this.TABLE == null || "".equals(this.TABLE)) {
 136             throw new Exception("have not specify the table name for "
 137                     + aclass.getCanonicalName());
 138         }
 139         Field[] fileds = aclass.getDeclaredFields();
 140         for (int i = 0; i < fileds.length; i++) {
 141             Field field = fileds[i];
 142             field.setAccessible(true);
 143             String columnName = field.getName();
 144             // 丢弃2种成员变量:静态和带NotColumn注解的
 145             if (!field.isAnnotationPresent(NotColumn.class)
 146                     && (field.getModifiers() & Modifier.STATIC) != Modifier.STATIC) {
 147                 if (field.isAnnotationPresent(Column.class)) {
 148                     columnName = field.getAnnotation(Column.class).value();
 149                 }
 150                 if (field.isAnnotationPresent(Id.class)
 151                         && field.getAnnotation(Id.class).auto_increment() == true) {
 152                     assert field.getType() == Integer.class
 153                             || field.getType() == Long.class;
 154                 }
 155                 column2Field.put(columnName.toLowerCase(), field);
 156             }
 157         }
 158         allColumns = StringUtils.join(column2Field.keySet(), ",");
 159     }
 160 
 161     /**
 162      * 获取一个主库连接
 163      * 
 164      * @return
 165      * @throws SQLException
 166      */
 167     public PooledConnection getMasterConn() throws SQLException {
 168         ConnectionPools pools = DaoHelperPool.getConnPool(aclass);
 169         if (pools != null) {
 170             PooledConnection conn = pools.getMasterPool().getConnection();
 171             return conn;
 172         }
 173         return null;
 174     }
 175 
 176     /**
 177      * 获取一个从库连接。从库连接没有时获取主库连接
 178      * 
 179      * @return
 180      * @throws SQLException
 181      */
 182     public PooledConnection getSlaveConn() throws SQLException {
 183         ConnectionPools pools = DaoHelperPool.getConnPool(aclass);
 184         if (pools != null) {
 185             PooledConnection conn = pools.getSlavePool().getConnection();
 186             return conn;
 187         }
 188         return null;
 189     }
 190 
 191     /**
 192      * 关闭一个从库的物理连接
 193      * 
 194      * @param conn
 195      */
 196     private void closeSlaveConnection(PooledConnection conn) {
 197         ConnectionPools pools = DaoHelperPool.getConnPool(aclass);
 198         if (pools != null) {
 199             ConnectionPool pool = pools.getSlavePool();
 200             pool.closeConnection(conn.getConnection());
 201         }
 202     }
 203 
 204     /**
 205      * 把主库连接返回连接池
 206      * 
 207      * @param conn
 208      */
 209     public void retrunMasterConn(PooledConnection conn) {
 210         ConnectionPools pools = DaoHelperPool.getConnPool(aclass);
 211         if (pools != null) {
 212             ConnectionPool pool = pools.getMasterPool();
 213             pool.returnConnection(conn);
 214         }
 215     }
 216 
 217     /**
 218      * 把从库连接返回连接池
 219      * 
 220      * @param conn
 221      */
 222     public void retrunSlaveConn(PooledConnection conn) {
 223         ConnectionPools pools = DaoHelperPool.getConnPool(aclass);
 224         if (pools != null) {
 225             ConnectionPool pool = pools.getSlavePool();
 226             pool.returnConnection(conn);
 227         }
 228     }
 229 
 230     /**
 231      * 分页读取数据<br>
 232      * 注意:使用完ResultSet后一定要调用ResultSet.close()
 233      * 
 234      * @param columns
 235      *            各列用逗号分隔,不区分大小写,允许使用"*"
 236      * @param where
 237      * @param pageNo
 238      *            页数,编号从1始
 239      * @param pageSize
 240      *            每页的大小,即使数据库中有充足的数据,返回的量也可能略少于pageSize
 241      * @return
 242      */
 243     @Deprecated
 244     public ResultSet getListByPage(String columns, String where, int pageNo,
 245             int pageSize) {
 246         if (columns == null || columns.length() == 0) {
 247             return null;
 248         }
 249         if (columns.contains("*")) {
 250             columns = allColumns;
 251         }
 252         columns = columns.toLowerCase();
 253         if (pageNo * pageSize > 5000) {
 254             logger.error("pageNo*pageSize  can't more than 5000");
 255             return null;
 256         }
 257         PooledConnection conn = null;
 258         ResultSet resultSet = null;
 259         // 当数据库设置了主键自增时,select出的结果默认就是按主键递增排序好的
 260         StringBuilder sql = new StringBuilder();
 261         sql.append("select ");
 262         sql.append(columns);
 263         sql.append(" from ");
 264         sql.append(TABLE);
 265         if (where != null && where.length() > 0) {
 266             sql.append(" where ");
 267             sql.append(where);
 268         }
 269         sql.append(" limit ");
 270         sql.append(pageSize * (pageNo - 1));
 271         sql.append(",");
 272         sql.append(pageSize);
 273         int timeout = 0;
 274         long begin = System.currentTimeMillis();
 275         try {
 276             conn = this.getSlaveConn();
 277             timeout = conn.getQueryTimeOut();
 278             final Statement statement = conn.getConnection().createStatement();
 279             final String sqlF = sql.toString();
 280             Future<ResultSet> futureResult = exec
 281                     .submit(new Callable<ResultSet>() {
 282                         @Override
 283                         public ResultSet call() throws Exception {
 284                             return statement.executeQuery(sqlF);
 285                         }
 286                     });
 287             resultSet = futureResult.get(timeout, TimeUnit.MILLISECONDS);
 288             // 如果返回结果数为0,则返回的ResultSet为null
 289             resultSet.last();
 290             if (resultSet.getRow() == 0) {
 291                 resultSet = null;
 292             } else {
 293                 resultSet.beforeFirst();
 294             }
 295         } catch (SQLException | InterruptedException | ExecutionException e) {
 296             logger.error("read data from " + TABLE + " failed", e);
 297             kvReporter.send(dbErrorKey, 1);
 298         } catch (TimeoutException e) {
 299             if (conn != null) {
 300                 // 超时,则直接关闭物理连接
 301                 this.closeSlaveConnection(conn);
 302             }
 303             logger.error("sql query timeout, SQL=" + sql.toString()
 304                     + ", time limit is " + timeout);
 305             kvReporter.send(dbErrorKey, 1);
 306         } finally {
 307             if (conn != null) {
 308                 // 正常使用完,返还连接
 309                 this.retrunSlaveConn(conn);
 310             }
 311             long end = System.currentTimeMillis();
 312             kvReporter.send(dbTimeKey, end - begin);
 313         }
 314         return resultSet;
 315     }
 316 
 317     /**
 318      * 分页读取数据
 319      * 
 320      * @param columns
 321      *            各列用逗号分隔,不区分大小写,允许使用"*"
 322      * @param where
 323      * @param pageNo
 324      *            页数,编号从1始
 325      * @param pageSize
 326      *            每页的大小,即使数据库中有充足的数据,返回的量也可能略少于pageSize
 327      * @return 发生异常时返回null,通常是TimeoutException或SQLException
 328      */
 329     public List<T> getDataByPage(String columns, String where, int pageNo,
 330             int pageSize) {
 331         return getDataByPage(columns, where, pageNo, pageSize, null);
 332     }
 333 
 334     /**
 335      * 分页读取数据
 336      * 
 337      * @param columns
 338      *            各列用逗号分隔,不区分大小写,允许使用"*"
 339      * @param where
 340      * @param pageNo
 341      *            页数,编号从1始
 342      * @param pageSize
 343      *            每页的大小,即使数据库中有充足的数据,返回的量也可能略少于pageSize
 344      * @param forceIndex
 345      *            显式指定要使用的索引名称
 346      * @return 发生异常时返回null,通常是TimeoutException或SQLException
 347      */
 348     public List<T> getDataByPage(String columns, String where, int pageNo,
 349             int pageSize, String forceIndex) {
 350         List<T> rect = new ArrayList<T>();
 351         if (columns == null || columns.length() == 0) {
 352             return rect;
 353         }
 354         if (columns.contains("*")) {
 355             columns = allColumns;
 356         }
 357         columns = columns.toLowerCase();
 358         if (pageNo * pageSize > 5000) {
 359             logger.error("pageNo*pageSize  can't more than 5000");
 360             return rect;
 361         }
 362         PooledConnection conn = null;
 363         ResultSet resultSet = null;
 364         Set<String> columnSet = new HashSet<String>();
 365         String[] arr = columns.split(",");
 366         for (String col : arr) {
 367             if (col.length() > 0) {
 368                 columnSet.add(col);
 369             }
 370         }
 371         // 当数据库设置了主键自增时,select出的结果默认就是按主键递增排序好的
 372         StringBuilder sql = new StringBuilder();
 373         sql.append("select ");
 374         sql.append(columns);
 375         sql.append(" from ");
 376         sql.append(TABLE);
 377         if (forceIndex != null && forceIndex.length() > 0) {
 378             sql.append(" force index(");
 379             sql.append(forceIndex);
 380             sql.append(")");
 381         }
 382         if (where != null && where.length() > 0) {
 383             sql.append(" where ");
 384             sql.append(where);
 385         }
 386         sql.append(" limit ");
 387         sql.append(pageSize * (pageNo - 1));
 388         sql.append(",");
 389         sql.append(pageSize);
 390         int timeout = 0;
 391         long begin = System.currentTimeMillis();
 392         try {
 393             conn = this.getSlaveConn();
 394             final Statement statement = conn.getConnection().createStatement();
 395             timeout = conn.getQueryTimeOut();
 396             final String sqlF = sql.toString();
 397             Future<ResultSet> futureResult = exec
 398                     .submit(new Callable<ResultSet>() {
 399                         @Override
 400                         public ResultSet call() throws Exception {
 401                             return statement.executeQuery(sqlF);
 402                         }
 403                     });
 404             resultSet = futureResult.get(timeout, TimeUnit.MILLISECONDS);
 405             while (resultSet.next()) {
 406                 T inst = po2Vo(resultSet, columnSet);
 407                 rect.add(inst);
 408             }
 409         } catch (SQLException | InterruptedException | ExecutionException e) {
 410             rect = null;
 411             logger.error("read data from " + TABLE + " failed", e);
 412             kvReporter.send(dbErrorKey, 1);
 413         } catch (TimeoutException e) {
 414             rect = null;
 415             if (conn != null) {
 416                 // 超时则关闭DB连接,这样就会造成返回连接池中的有无效连接,从连接池中获取连接时需要判断一下连接是否可用。
 417                 this.closeSlaveConnection(conn);
 418             }
 419             logger.error("sql query timeout, SQL=" + sql.toString()
 420                     + ", time limit is " + timeout);
 421             kvReporter.send(dbErrorKey, 1);
 422         } finally {
 423             try {
 424                 if (resultSet != null) {
 425                     resultSet.close();
 426                 }
 427             } catch (SQLException e) {
 428                 logger.error("close ResultSet failed", e);
 429                 kvReporter.send(dbErrorKey, 1);
 430             }
 431             if (conn != null) {
 432                 // 正常使用完,返还连接
 433                 this.retrunSlaveConn(conn);
 434             }
 435             long end = System.currentTimeMillis();
 436             kvReporter.send(dbTimeKey, end - begin);
 437         }
 438         return rect;
 439     }
 440 
 441     /**
 442      * in查询
 443      * 
 444      * @param columns
 445      *            要获取哪几列
 446      * @param collections
 447      * @param targetColumn
 448      *            在哪一列上进行where in查询
 449      * @return 发生异常时返回null
 450      */
 451     public <K extends Number> List<T> getIn(String columns, Set<K> collections,
 452             String targetColumn) {
 453         List<T> rect = new ArrayList<T>();
 454         if (columns == null || columns.length() == 0 || targetColumn == null
 455                 || targetColumn.length() == 0 || collections == null
 456                 || collections.size() == 0) {
 457             return rect;
 458         }
 459         if (columns.contains("*")) {
 460             columns = allColumns;
 461         }
 462         columns = columns.toLowerCase();
 463         PooledConnection conn = null;
 464         ResultSet resultSet = null;
 465         Set<String> columnSet = new HashSet<String>();
 466         String[] arr = columns.split(",");
 467         for (String col : arr) {
 468             if (col.length() > 0) {
 469                 columnSet.add(col);
 470             }
 471         }
 472         // 当数据库设置了主键自增时,select出的结果默认就是按主键递增排序好的
 473         StringBuilder sql = new StringBuilder();
 474         sql.append("select ");
 475         sql.append(columns);
 476         sql.append(" from ");
 477         sql.append(TABLE);
 478         sql.append(" where ");
 479         sql.append(targetColumn);
 480         sql.append(" in (");
 481         for (Number ele : collections) {
 482             sql.append(ele);
 483             sql.append(",");
 484         }
 485         sql.setCharAt(sql.length() - 1, ')');
 486         int timeout = 0;
 487         long begin = System.currentTimeMillis();
 488         try {
 489             conn = this.getSlaveConn();
 490             final Statement statement = conn.getConnection().createStatement();
 491             timeout = conn.getQueryTimeOut();
 492             final String sqlF = sql.toString();
 493             Future<ResultSet> futureResult = exec
 494                     .submit(new Callable<ResultSet>() {
 495                         @Override
 496                         public ResultSet call() throws Exception {
 497                             return statement.executeQuery(sqlF);
 498                         }
 499 
 500                     });
 501             resultSet = futureResult.get(timeout, TimeUnit.MILLISECONDS);
 502             while (resultSet.next()) {
 503                 T inst = po2Vo(resultSet, columnSet);
 504                 rect.add(inst);
 505             }
 506         } catch (SQLException | InterruptedException | ExecutionException e) {
 507             rect = null;
 508             logger.error("read data from " + TABLE + " failed", e);
 509             kvReporter.send(dbErrorKey, 1);
 510         } catch (TimeoutException e) {
 511             rect = null;
 512             if (conn != null) {
 513                 // 超时则关闭DB连接,这样就会造成返回连接池中的有无效连接,从连接池中获取连接时需要判断一下连接是否可用。
 514                 this.closeSlaveConnection(conn);
 515             }
 516             logger.error("sql query timeout, SQL=" + sql.toString()
 517                     + ", time limit is " + timeout);
 518             kvReporter.send(dbErrorKey, 1);
 519         } finally {
 520             try {
 521                 if (resultSet != null) {
 522                     resultSet.close();
 523                 }
 524             } catch (SQLException e) {
 525                 logger.error("close ResultSet failed", e);
 526                 kvReporter.send(dbErrorKey, 1);
 527             }
 528             if (conn != null) {
 529                 // 正常使用完,返还连接
 530                 this.retrunSlaveConn(conn);
 531             }
 532             long end = System.currentTimeMillis();
 533             kvReporter.send(dbTimeKey, end - begin);
 534         }
 535         return rect;
 536     }
 537 
 538     /**
 539      * 从迭代器ResultSet中读出一个实体
 540      * 
 541      * @param resultSet
 542      * @return
 543      */
 544     private T po2Vo(ResultSet resultSet, Set<String> columns) {
 545         try {
 546             @SuppressWarnings("unchecked")
 547             T inst = (T) Class.forName(aclass.getName()).newInstance();
 548             for (Entry<String, Field> entry : column2Field.entrySet()) {
 549                 Field field = entry.getValue();
 550                 String columnName = entry.getKey();
 551                 if (columns.contains("*") || columns.contains(columnName)) {
 552                     if (field.getType() == Integer.class
 553                             || field.getType() == int.class) {
 554                         field.set(inst, resultSet.getInt(columnName));
 555                     } else if (field.getType() == Long.class
 556                             || field.getType() == long.class) {
 557                         field.set(inst, resultSet.getLong(columnName));
 558                     } else if (field.getType() == Double.class
 559                             || field.getType() == double.class) {
 560                         field.set(inst, resultSet.getDouble(columnName));
 561                     } else if (field.getType() == String.class) {
 562                         field.set(inst, resultSet.getString(columnName));
 563                     }
 564                     // MySQL中的datetime和timestamp都只精确到秒
 565                     else if (field.getType() == Date.class) {
 566                         // resultSet.getDate()返回年月日(精确到天),resultSet.getTime()返回时分秒部分(精确到秒)
 567                         field.set(inst, new Date(resultSet.getDate(columnName)
 568                                 .getTime()
 569                                 + resultSet.getTime(columnName).getTime()
 570                                 + TimeZone.getDefault().getRawOffset()));// 注意加上时间偏置
 571                     } else if (field.getType() == Short.class
 572                             || field.getType() == short.class) {
 573                         field.set(inst, resultSet.getShort(columnName));
 574                     } else if (field.getType() == Timestamp.class) {
 575                         field.set(inst, resultSet.getTimestamp(columnName));
 576                     } else if (field.getType() == Byte.class
 577                             || field.getType() == byte.class) {
 578                         field.set(inst, resultSet.getByte(columnName));
 579                     } else if (field.getType() == Float.class
 580                             || field.getType() == float.class) {
 581                         field.set(inst, resultSet.getFloat(columnName));
 582                     }
 583                 }
 584             }
 585             return inst;
 586         } catch (Exception e) {
 587             logger.error("parse column failed", e);
 588             kvReporter.send(dbErrorKey, 1);
 589             return null;
 590         }
 591     }
 592 
 593     /**
 594      * 根据主键获得一个实体
 595      * 
 596      * @param id
 597      * @return
 598      */
 599     public T getById(PK id) {
 600         T rect = null;
 601         PooledConnection conn = null;
 602         ResultSet resultSet = null;
 603         String sql = "select " + allColumns + " from " + TABLE + " where id="
 604                 + id;
 605         long begin = System.currentTimeMillis();
 606         try {
 607             conn = this.getSlaveConn();
 608             Statement statement = conn.getConnection().createStatement();
 609             resultSet = statement.executeQuery(sql);
 610             if (resultSet.next()) {
 611                 rect = po2Vo(resultSet, column2Field.keySet());
 612             }
 613         } catch (SQLException e) {
 614             logger.error("read data from " + TABLE + " failed", e);
 615             kvReporter.send(dbErrorKey, 1);
 616         } finally {
 617             if (resultSet != null) {
 618                 try {
 619                     resultSet.close();
 620                 } catch (SQLException e) {
 621                     logger.error("close ResultSet failed", e);
 622                     kvReporter.send(dbErrorKey, 1);
 623                 }
 624             }
 625             if (conn != null) {
 626                 // 正常使用完,返还连接
 627                 this.retrunSlaveConn(conn);
 628                 ;
 629             }
 630             long end = System.currentTimeMillis();
 631             kvReporter.send(dbTimeKey, end - begin);
 632         }
 633         return rect;
 634     }
 635 
 636     /**
 637      * 根据条件删除一条记录
 638      * 
 639      * @param condition
 640      * @return 成功返回非负数,失败返回-1
 641      */
 642     public int delete(String condition) {
 643         int rect = -1;
 644 
 645         String sql = "delete from " + TABLE + " where " + condition;
 646         PooledConnection conn = null;
 647         long begin = System.currentTimeMillis();
 648         try {
 649             conn = this.getMasterConn();
 650             rect = conn.executeUpdate(sql);
 651         } catch (SQLException e) {
 652             logger.error("delete from " + TABLE + " failed, condition:"
 653                     + condition, e);
 654             kvReporter.send(dbErrorKey, 1);
 655         } finally {
 656             if (conn != null) {
 657                 this.retrunMasterConn(conn);
 658             }
 659             long end = System.currentTimeMillis();
 660             kvReporter.send(dbTimeKey, end - begin);
 661         }
 662         return rect;
 663     }
 664 
 665     /**
 666      * 删除指定column值属于集合ids的行<br>
 667      * 注意:column必须是索引
 668      * 
 669      * @param ids
 670      * @param column
 671      * @return
 672      */
 673     public int deleteIn(List<? extends Number> ids, String column) {
 674         int rect = -1;
 675         StringBuilder sql = new StringBuilder();
 676         sql.append("delete from ");
 677         sql.append(TABLE);
 678         sql.append(" where ");
 679         sql.append(column);
 680         sql.append(" in (");
 681         for (Number ele : ids) {
 682             sql.append(ele);
 683             sql.append(",");
 684         }
 685         sql.setCharAt(sql.length() - 1, ')');
 686         PooledConnection conn = null;
 687         long begin = System.currentTimeMillis();
 688         try {
 689             conn = this.getMasterConn();
 690             rect = conn.executeUpdate(sql.toString());
 691         } catch (SQLException e) {
 692             logger.error("delete from " + TABLE + " failed", e);
 693             kvReporter.send(dbErrorKey, 1);
 694         } finally {
 695             if (conn != null) {
 696                 this.retrunMasterConn(conn);
 697             }
 698             long end = System.currentTimeMillis();
 699             kvReporter.send(dbTimeKey, end - begin);
 700         }
 701         return rect;
 702     }
 703 
 704     /**
 705      * 插入一条新记录
 706      * 
 707      * @param entity
 708      * @return 成功返回受影响的行数,失败返回-1
 709      */
 710     public int insert(T entity) {
 711         int rect = -1;
 712         StringBuilder columnNames = new StringBuilder();
 713         StringBuilder values = new StringBuilder();
 714         try {
 715             for (Entry<String, Field> entry : column2Field.entrySet()) {
 716                 Field field = entry.getValue();
 717                 if ((!field.isAnnotationPresent(Id.class) || field
 718                         .getAnnotation(Id.class).auto_increment() == false)
 719                         && validType.contains(field.getType())) {
 720                     String columnName = entry.getKey();
 721                     columnNames.append(columnName);
 722                     columnNames.append(",");
 723                     Object value = field.get(entity);
 724                     if (value == null) {
 725                         values.append("null");
 726                     } else if (field.getType() == Date.class
 727                             || field.getType() == Timestamp.class) {
 728                         values.append(sdf.format((Date) value));
 729                     } else if (field.getType() == String.class) {
 730                         String sv = (String) value;
 731                         if (containSql(sv)) {
 732                             logger.warn("danger! sql injection:" + sv);
 733                             sv = "";
 734                         }
 735                         values.append("'" + sv + "'");
 736                     } else {
 737                         values.append(value);
 738                     }
 739                     values.append(",");
 740                 }
 741             }
 742 
 743         } catch (Exception e) {
 744             logger.error("reflect " + entity.getClass().getCanonicalName()
 745                     + " entity failed", e);
 746             kvReporter.send(dbErrorKey, 1);
 747         }
 748 
 749         if (columnNames.length() > 0) {
 750             StringBuilder sql = new StringBuilder("insert into ");
 751             sql.append(TABLE);
 752             sql.append(" (");
 753             sql.append(columnNames.subSequence(0, columnNames.length() - 1));
 754             sql.append(") values (");
 755             sql.append(values.subSequence(0, values.length() - 1));
 756             sql.append(")");
 757             PooledConnection conn = null;
 758             long begin = System.currentTimeMillis();
 759             try {
 760                 conn = this.getMasterConn();
 761                 rect = conn.executeUpdate(sql.toString());
 762             } catch (Exception e) {
 763                 // 如果 是因为唯一键值冲突,则不打印日志
 764                 if (!e.getMessage().contains("Duplicate entry")) {
 765                     logger.error("insert data into " + TABLE + " failed", e);
 766                     kvReporter.send(dbErrorKey, 1);
 767                 }
 768                 rect = -1;
 769             } finally {
 770                 if (conn != null) {
 771                     this.retrunMasterConn(conn);
 772                 }
 773                 long end = System.currentTimeMillis();
 774                 kvReporter.send(dbTimeKey, end - begin);
 775             }
 776         }
 777         return rect;
 778     }
 779 
 780     /**
 781      * 批量插入数据
 782      * 
 783      * @param entities
 784      * @return 成功返回受影响的行数,失败返回-1
 785      */
 786     public int batchInsert(List<T> entities) {
 787         if (entities == null || entities.size() == 0) {
 788             return 0;
 789         }
 790         int rect = -1;
 791         StringBuilder sqlBuffer = new StringBuilder("insert into ");
 792         sqlBuffer.append(TABLE);
 793         StringBuilder columnNames = new StringBuilder();
 794         for (Entry<String, Field> entry : column2Field.entrySet()) {
 795             Field field = entry.getValue();
 796             if ((!field.isAnnotationPresent(Id.class) || field.getAnnotation(
 797                     Id.class).auto_increment() == false)
 798                     && validType.contains(field.getType())) {
 799                 String columnName = entry.getKey();
 800                 columnNames.append(columnName);
 801                 columnNames.append(",");
 802             }
 803         }
 804         sqlBuffer.append(" (");
 805         sqlBuffer.append(columnNames.subSequence(0, columnNames.length() - 1));
 806         sqlBuffer.append(") values ");
 807 
 808         for (T entity : entities) {
 809             StringBuilder values = new StringBuilder();
 810             try {
 811                 for (Entry<String, Field> entry : column2Field.entrySet()) {
 812                     Field field = entry.getValue();
 813                     if ((!field.isAnnotationPresent(Id.class) || field
 814                             .getAnnotation(Id.class).auto_increment() == false)
 815                             && validType.contains(field.getType())) {
 816                         Object value = field.get(entity);
 817                         if (value == null) {
 818                             values.append("null");
 819                         } else if (field.getType() == Date.class
 820                                 || field.getType() == Timestamp.class) {
 821                             values.append(sdf.format((Date) value));
 822                         } else if (field.getType() == String.class) {
 823                             String sv = (String) value;
 824                             if (containSql(sv)) {
 825                                 logger.warn("danger! sql injection:" + sv);
 826                                 sv = "";
 827                             }
 828                             values.append("'" + sv + "'");
 829                         } else {
 830                             values.append(value);
 831                         }
 832                         values.append(",");
 833                     }
 834                 }
 835 
 836             } catch (Exception e) {
 837                 logger.error("reflect " + entity.getClass().getCanonicalName()
 838                         + " entity failed", e);
 839                 kvReporter.send(dbErrorKey, 1);
 840             }
 841 
 842             if (values.length() > 0) {
 843                 sqlBuffer.append("(");
 844                 sqlBuffer.append(values.subSequence(0, values.length() - 1));
 845                 sqlBuffer.append("),");
 846             }
 847         }
 848         PooledConnection conn = null;
 849         String sql = sqlBuffer.substring(0, sqlBuffer.length() - 1).toString();
 850         long begin = System.currentTimeMillis();
 851         try {
 852             conn = this.getMasterConn();
 853             rect = conn.executeUpdate(sql);
 854         } catch (Exception e) {
 855             // 如果是因为唯一键(包括主键在内)值冲突,则不打印日志
 856             if (!e.getMessage().contains("Duplicate entry")) {
 857                 logger.error("insert data into " + TABLE + " failed, sql="
 858                         + sql, e);
 859                 kvReporter.send(dbErrorKey, 1);
 860             }
 861             rect = -1;
 862         } finally {
 863             if (conn != null) {
 864                 this.retrunMasterConn(conn);
 865             }
 866             long end = System.currentTimeMillis();
 867             kvReporter.send(dbTimeKey, end - begin);
 868         }
 869         return rect;
 870     }
 871 
 872     /**
 873      * 根据主键更新一条记录
 874      * 
 875      * @param entity
 876      * @return 如果记录不存在则返回0,更新成功返回正数,更新失败返回-1
 877      */
 878     public int update(T entity) {
 879         int rect = -1;
 880         List<String> columnNames = new ArrayList<String>();
 881         List<String> values = new ArrayList<String>();
 882         String condition = null;
 883         try {
 884             for (Entry<String, Field> entry : column2Field.entrySet()) {
 885                 Field field = entry.getValue();
 886                 String columnName = entry.getKey();
 887                 Object value = field.get(entity);
 888                 if (field.isAnnotationPresent(Id.class)) {
 889                     if (value == null) {
 890                         return 0;
 891                     }
 892                     if (field.getType() == Date.class
 893                             || field.getType() == Timestamp.class) {
 894                         condition = columnName + "=" + sdf.format((Date) value);
 895                     } else if (field.getType() == String.class) {
 896                         String sv = (String) value;
 897                         if (containSql(sv)) {
 898                             logger.warn("danger! sql injection:" + sv);
 899                             sv = "";
 900                         }
 901                         condition = columnName + "='" + sv + "'";
 902                     } else {
 903                         condition = columnName + "=" + value;
 904                     }
 905                 } else if (validType.contains(field.getType())) {
 906                     if (value != null) {
 907                         columnNames.add(columnName);
 908                         if (field.getType() == Date.class
 909                                 || field.getType() == Timestamp.class) {
 910                             values.add(sdf.format((Date) value));
 911                         } else if (field.getType() == String.class) {
 912                             String sv = (String) value;
 913                             if (containSql(sv)) {
 914                                 logger.warn("danger! sql injection:" + sv);
 915                                 sv = "";
 916                             }
 917                             values.add("'" + sv + "'");
 918                         } else {
 919                             values.add(value.toString());
 920                         }
 921                     }
 922                 }
 923             }
 924 
 925         } catch (Exception e) {
 926             logger.error("reflect " + entity.getClass().getCanonicalName()
 927                     + " entity failed", e);
 928             kvReporter.send(dbErrorKey, 1);
 929         }
 930 
 931         if (columnNames.size() > 0 && condition != null) {
 932             StringBuilder sql = new StringBuilder("update ");
 933             sql.append(TABLE);
 934             sql.append(" set ");
 935             int i = 0;
 936             for (; i < columnNames.size() - 1; i++) {
 937                 sql.append(columnNames.get(i) + "=" + values.get(i) + ",");
 938             }
 939             sql.append(columnNames.get(i) + "=" + values.get(i));
 940             sql.append(" where " + condition);
 941             PooledConnection conn = null;
 942             long begin = System.currentTimeMillis();
 943             try {
 944                 conn = this.getMasterConn();
 945                 rect = conn.executeUpdate(sql.toString());
 946             } catch (Exception e) {
 947                 logger.error("update " + TABLE + " failed", e);
 948                 kvReporter.send(dbErrorKey, 1);
 949                 rect = -1;
 950             } finally {
 951                 if (conn != null) {
 952                     this.retrunMasterConn(conn);
 953                 }
 954                 long end = System.currentTimeMillis();
 955                 kvReporter.send(dbTimeKey, end - begin);
 956             }
 957         }
 958         return rect;
 959     }
 960 
 961     /**
 962      * 根据主键删除一条记录
 963      * 
 964      * @param id
 965      * @return
 966      */
 967     public int deleteById(PK id) {
 968         String where = "> id;
 969         return delete(where);
 970     }
 971 
 972     /**
 973      * 根据where条件进行count
 974      * 
 975      * @param condition
 976      * @return
 977      */
 978     public int count(String condition) {
 979         int rect = -1;
 980 
 981         String sql = "select count(*) from " + TABLE + " where " + condition;
 982         PooledConnection conn = null;
 983         long begin = System.currentTimeMillis();
 984         try {
 985             conn = this.getSlaveConn();
 986             ResultSet resultSet = conn.executeQuery(sql);
 987             if (resultSet != null) {
 988                 if (resultSet.next()) {
 989                     rect = resultSet.getInt(1); // 编号从1开始
 990                 }
 991                 resultSet.close();
 992             }
 993         } catch (SQLException e) {
 994             logger.error("count(*) from " + TABLE + " failed, condition:"
 995                     + condition, e);
 996             kvReporter.send(dbErrorKey, 1);
 997         } finally {
 998             if (conn != null) {
 999                 this.retrunSlaveConn(conn);
1000             }
1001             long end = System.currentTimeMillis();
1002             kvReporter.send(dbTimeKey, end - begin);
1003         }
1004         return rect;
1005     }
1006 
1007     /**
1008      * 慎用:慢查询,容易造成DB阻塞!<br>
1009      * 删除表里的所有数据
1010      * 
1011      * @return
1012      */
1013     public int deleteAll() {
1014         return delete("1=1");
1015     }
1016 }
View Code

相关文章: