前言
网上找了一些关于Go语言my SQL 操作封装的博客,发现没有太适合我习惯的,于是我参考了一篇博客,对其进行了些调整,已将源码上传至码云,有兴趣的可以拿去用,有修改意见的也可以评论。
源码
package mysql
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
"strings"
)
type Db struct {
link *sql.DB //存储连接对象
host string //数据库地址
port string //端口
dbUser string //数据库用户
dbPwd string //数据库密码
dbName string //存储库
tableName string //存储表名
field string //存储字段
allFields []string //存储当前表所有字段
where string //存储where条件
order string //存储order条件
limit string //存储limit条件
}
//构造方法
func DbNew(host string, port string, dbUser string, dbPwd string, dbName string, table string) Db {
var this Db
this.setProp(host, port, dbUser, dbPwd, dbName, table)
//2.初始化连接数据库
this.getConnect()
//3.获得当前表的所有字段
this.getFields()
return this
}
/**
* 初始化属性
*/
func (this *Db) setProp(host string, port string, dbUser string, dbPwd string, dbName string, table string) {
this.field = "*"
this.host = host
if port == "" {
port = "3306"
}
this.port = port
this.dbUser = dbUser
this.dbPwd = dbPwd
this.dbName = dbName
this.tableName = table
}
/**
* 初始化连接数据库操作
*/
func (this *Db) getConnect() {
//1.连接数据库
db, err := sql.Open("mysql", this.dbUser+":"+this.dbPwd+"@tcp("+this.host+":"+this.port+")/"+this.dbName+"?charset=utf8")
//2.判断连接
if err != nil {
panic(err)
}
this.link = db
}
/**
* 设置要查询的字段信息
* @param string $field 要查询的字段
* @return object 返回自己,保证连贯操作
*/
func (this *Db) Field(field string) *Db {
this.field = field
return this
}
/**
* order排序条件
* @param string $order 以此为基准进行排序
* @return $this 返回自己,保证连贯操作
*/
func (this *Db) Order(order string) *Db {
this.order = `order by ` + order
return this
}
/**
* limit条件
*/
func (this *Db) Limit(limit string) *Db {
this.limit = "limit " + limit
return this
}
/**
* where条件
*/
func (this *Db) Where(where string) *Db {
this.where = `where ` + where
return this
}
/**
* 执行并发送SQL语句(增删改)
*/
func (this *Db) exec(sql string) (errCode int, lastId interface{}, error interface{}) {
//执行sql
res, errExec := this.link.Exec(sql)
if errExec != nil {
return 0, nil, errExec
}
result, errLast := res.LastInsertId()
if errLast != nil {
return 0, result, errLast
}
return 1, result, nil
}
/**
* 执行并发送SQL(查询)
*/
func (this *Db) query(sql string) (errCode int, rows map[int]map[string]interface{}, error interface{}) {
rows2, err := this.link.Query(sql)
if err != nil {
return 0, nil, err
}
//返回所有列
cols, err := rows2.Columns()
if err != nil {
return 0, nil, err
}
//这里表示一行所有列的值,用[]byte表示
vals := make([][]byte, len(cols))
//这里表示一行填充数据
scans := make([]interface{}, len(cols))
//这里scans引用vals,把数据填充到[]byte里
for k, _ := range vals {
scans[k] = &vals[k]
}
i := 0
result := make(map[int]map[string]interface{})
for rows2.Next() {
//填充数据
rows2.Scan(scans...) //将slic地址传入
//每行数据
row := make(map[string]interface{})
//把vals中的数据复制到row中
for k, v := range vals {
key := cols[k]
//这里把[]byte数据转成string
row[key] = string(v)
}
//放入结果集
result[i] = row
i++
}
return 1, result, nil
}
/**
* 添加操作
*/
func (this *Db) Add(data map[string]interface{}) (errCode int, lastId interface{}, error interface{}) {
//过滤非法字段
for k, v := range data {
if res := in_array(k, this.allFields); res != true {
delete(data, k)
} else {
str += k + ` = '` + v.(string) + `',`
//将map中取出的键转为字符串拼接
key += k + `,`
//将map中取出的值转为字符串拼接
value += `'` + v.(string) + `',`
}
}
//去掉逗号
key = strings. Trim (key, ",")
value = strings.Trim(value, ",")
//准备SQL语句
sql := `insert into ` + this.tableName + ` (` + key + `) values (` + value + `)`
// //执行并发送SQL
return this.exec(sql)
}
/**
* 修改操作
*/
func (this *Db) Update(data map[string]interface{}) (errCode int, lastId interface{}, error interface{}) {
//过滤非法字段
for k, v := range data {
if res := in_array(k, this.allFields); res != true {
delete(data, k)
} else {
str += k + ` = '` + v.(string) + `',`
//将map中取出的键转为字符串拼接
key += k + `,`
//将map中取出的值转为字符串拼接
value += `'` + v.(string) + `',`
}
}
//去掉逗号
str = strings.Trim(str, ",")
key = strings.Trim(key, ",")
value = strings.Trim(value, ",")
//判断是否有条件
if this.where == "" {
panic("没有条件")
}
sql := `update ` + this.tableName + ` set ` + str + ` ` + this.where
return this.exec(sql)
}
/**
* 删除操作
*/
func (this *Db) Delete() (errCode int, result interface{}, error interface{}) {
sql := `delete from ` + this.tableName + ` ` + this.where
//执行并发送
return this.exec(sql)
}
/**
* 获取当前表的所有字段
*/
func (this *Db) getFields() {
//查看表结构
sql := "DESC " + this.tableName
//执行并发送SQL
result, err := this.link.Query(sql)
if err != nil {
panic(err)
}
this.allFields = make([]string, 0)
for result.Next() {
var field string
var Type interface{}
var Null string
var Key string
var Default interface{}
var Extra string
err := result.Scan(&field, &Type, &Null, &Key, &Default, &Extra)
if err != nil {
panic(err)
}
this.allFields = append(this.allFields, field)
}
}
/**
* 查询多条数据
*/
func (this *Db) Get() (errCode int, result interface{}, error interface{}) {
sql := `select ` + this.field + ` from ` + this.tableName + ` ` + this.where + ` ` + this.order + ` ` + this.limit
//执行并发送SQL
return this.query(sql)
}
/**
* 查询一条数据
*/
func (this *Db) Find() (errCode int, result map[string]interface{}, error interface{}) {
sql := `select ` + this.field + ` from ` + this.tableName + ` ` + this.where + ` ` + this.order + ` ` + ` limit 1`
//执行并发送sql
errCode, resQuery, error := this.query(sql)
if (errCode != 1) || (len(resQuery) == 0) {
return errCode, nil, error
}
return errCode, resQuery[0], error
}
/**
* 统计总条数
*/
func (this *Db) Count() (errCode int, count interface{}, error interface{}) {
var resQuery map[int]map[string]interface{}
//准备SQL语句
sql := `select count(*) as total from ` + this.tableName + ` ` + this.where + ` ` + ` ` + this.limit
errCode, resQuery, error = this.query(sql)
if errCode != 1 {
return errCode, count, error
}
return errCode, resQuery[0]["total"], error
}
var (
key string
value string
conditions string
str string
)
//是否存在数组内
func in_array(need interface{}, needArr []string) bool {
for _, v := range needArr {
if need == v {
return true
}
}
return false
}
例子
统计结果
Db := mysql.DbNew("127.0.0.1", "3306", "root", "root", "demo", "tb_user")
errCode,res,err :=Db.Where("name='1345284190@qq.com'").Count()
fmt.Println(errCode,res,err)
获取单个结果
Db := mysql.DbNew("127.0.0.1", "3306", "root", "root", "demo", "tb_user")
errCode,res,err :=Db.Where("name='bbb'").Find()
fmt.Println(errCode,res,err)
获取多个结果
Db := mysql.DbNew("127.0.0.1", "3306", "root", "root", "demo", "tb_user")
errCode,res,err :=Db.Where("name='bbb'").Get()
fmt.Println(errCode,res,err)
添加操作
Db := mysql.DbNew("127.0.0.1", "3306", "root", "root", "demo", "tb_user")
data := make(map[string]interface{})
data["email"] = "1345284190@qq.com"
data["create_time"] = time. Unix (time.Now().Unix(),0).Format("2006-01-02 15:04:05")
errCode,res,err :=Db.Add(data)
fmt.Println(errCode,res,err)
更新操作
Db := mysql.DbNew("127.0.0.1", "3306", "root", "root", "demo", "tb_user")
data := make(map[string]interface{})
data["name"] = "222@qq.com"
data["create_time"] = time.Unix(time.Now().Unix(),0).Format("2006-01-02 15:04:05")
errCode,res,err :=Db.Where("id = 12").Update(data)
fmt.Println(errCode,res,err)
删除操作
Db := mysql.DbNew("127.0.0.1", "3306", "root", "root", "demo", "tb_user")
errCode,res,err :=Db.Where("name='bbb'").Delete()
fmt.Println(errCode,res,err)
写在最后
码云链接:
主图来源: