七叶笔记 » golang编程 » 「GoLang」对mysql操作的简单封装

「GoLang」对mysql操作的简单封装

前言

网上找了一些关于Go语言my SQL 操作封装的博客,发现没有太适合我习惯的,于是我参考了一篇博客,对其进行了些调整,已将源码上传至码云,有兴趣的可以拿去用,有修改意见的也可以评论。

源码

 package mysql

import (

"database/sql"

_ "github.com/go-sql-driver/mysql"

"strings"

)

type Db struct {

link *sql.DB //存储连接对象

host string //数据库地址

port string //端口

dbUser string //数据库用户

dbPwd string //数据库密码

dbName string //存储库

tableName string //存储表名

field string //存储字段

allFields []string //存储当前表所有字段

where string //存储where条件

order string //存储order条件

limit string //存储limit条件

}

//构造方法

func DbNew(host string, port string, dbUser string, dbPwd string, dbName string, table string) Db {

var this Db

this.setProp(host, port, dbUser, dbPwd, dbName, table)

//2.初始化连接数据库

this.getConnect()

//3.获得当前表的所有字段

this.getFields()

return this

}

/**

* 初始化属性

*/
func (this *Db) setProp(host string, port string, dbUser string, dbPwd string, dbName string, table string) {

this.field = "*"

this.host = host

if port == "" {

port = "3306"

}

this.port = port

this.dbUser = dbUser

this.dbPwd = dbPwd

this.dbName = dbName

this.tableName = table

}

/**

* 初始化连接数据库操作

*/
func (this *Db) getConnect() {

//1.连接数据库

db, err := sql.Open("mysql", this.dbUser+":"+this.dbPwd+"@tcp("+this.host+":"+this.port+")/"+this.dbName+"?charset=utf8")

//2.判断连接

if err != nil {

panic(err)

}

this.link = db

}

/**

* 设置要查询的字段信息

* @param string $field 要查询的字段

* @return object 返回自己,保证连贯操作

*/
func (this *Db) Field(field string) *Db {

this.field = field

return this

}

/**

* order排序条件

* @param string $order 以此为基准进行排序

* @return $this 返回自己,保证连贯操作

*/
func (this *Db) Order(order string) *Db {

this.order = `order by ` + order

return this

}

/**

* limit条件

*/
func (this *Db) Limit(limit string) *Db {

this.limit = "limit " + limit

return this

}

/**

* where条件

*/
func (this *Db) Where(where string) *Db {

this.where = `where ` + where

return this

}

/**

* 执行并发送SQL语句(增删改)

*/
func (this *Db) exec(sql string) (errCode int, lastId interface{}, error interface{}) {

//执行sql

res, errExec := this.link.Exec(sql)

if errExec != nil {

return 0, nil, errExec

}

result, errLast := res.LastInsertId()

if errLast != nil {

return 0, result, errLast

}

return 1, result, nil

}

/**

* 执行并发送SQL(查询)

*/
func (this *Db) query(sql string) (errCode int, rows map[int]map[string]interface{}, error interface{}) {

rows2, err := this.link.Query(sql)

if err != nil {

return 0, nil, err

}

//返回所有列

cols, err := rows2.Columns()

if err != nil {

return 0, nil, err

}

//这里表示一行所有列的值,用[]byte表示

vals := make([][]byte, len(cols))

//这里表示一行填充数据

scans := make([]interface{}, len(cols))

//这里scans引用vals,把数据填充到[]byte里

for k, _ := range vals {

scans[k] = &vals[k]

}

i := 0

result := make(map[int]map[string]interface{})

for rows2.Next() {

//填充数据

rows2.Scan(scans...) //将slic地址传入

//每行数据

row := make(map[string]interface{})

//把vals中的数据复制到row中

for k, v := range vals {

key := cols[k]

//这里把[]byte数据转成string

row[key] = string(v)

}

//放入结果集

result[i] = row

i++

}

return 1, result, nil

}

/**

* 添加操作

*/
func (this *Db) Add(data map[string]interface{}) (errCode int, lastId interface{}, error interface{}) {

//过滤非法字段

for k, v := range data {

if res := in_array(k, this.allFields); res != true {

delete(data, k)

} else {

str += k + ` = '` + v.(string) + `',`

//将map中取出的键转为字符串拼接

key += k + `,`

//将map中取出的值转为字符串拼接

value += `'` + v.(string) + `',`

}

}

//去掉逗号

key = strings. Trim (key, ",")

value = strings.Trim(value, ",")

//准备SQL语句

sql := `insert into ` + this.tableName + ` (` + key + `) values (` + value + `)`

// //执行并发送SQL

return this.exec(sql)

}

/**

* 修改操作

*/
func (this *Db) Update(data map[string]interface{}) (errCode int, lastId interface{}, error interface{}) {

//过滤非法字段

for k, v := range data {

if res := in_array(k, this.allFields); res != true {

delete(data, k)

} else {

str += k + ` = '` + v.(string) + `',`

//将map中取出的键转为字符串拼接

key += k + `,`

//将map中取出的值转为字符串拼接

value += `'` + v.(string) + `',`

}

}

//去掉逗号

str = strings.Trim(str, ",")

key = strings.Trim(key, ",")

value = strings.Trim(value, ",")

//判断是否有条件

if this.where == "" {

panic("没有条件")

}

sql := `update ` + this.tableName + ` set ` + str + ` ` + this.where

return this.exec(sql)

}

/**

* 删除操作

*/
func (this *Db) Delete() (errCode int, result interface{}, error interface{}) {

sql := `delete from ` + this.tableName + ` ` + this.where

//执行并发送

return this.exec(sql)

}

/**

* 获取当前表的所有字段

*/
func (this *Db) getFields() {

//查看表结构

sql := "DESC " + this.tableName

//执行并发送SQL

result, err := this.link.Query(sql)

if err != nil {

panic(err)

}

this.allFields = make([]string, 0)

for result.Next() {

var field string

var Type interface{}

var Null string

var Key string

var Default interface{}

var Extra string

err := result.Scan(&field, &Type, &Null, &Key, &Default, &Extra)

if err != nil {

panic(err)

}

this.allFields = append(this.allFields, field)

}

}

/**

* 查询多条数据

*/
func (this *Db) Get() (errCode int, result interface{}, error interface{}) {

sql := `select ` + this.field + ` from ` + this.tableName + ` ` + this.where + ` ` + this.order + ` ` + this.limit

//执行并发送SQL

return this.query(sql)

}

/**

* 查询一条数据

*/
func (this *Db) Find() (errCode int, result map[string]interface{}, error interface{}) {

sql := `select ` + this.field + ` from ` + this.tableName + ` ` + this.where + ` ` + this.order + ` ` + ` limit 1`

//执行并发送sql

errCode, resQuery, error := this.query(sql)

if (errCode != 1) || (len(resQuery) == 0) {

return errCode, nil, error

}

return errCode, resQuery[0], error

}

/**

* 统计总条数

*/
func (this *Db) Count() (errCode int, count interface{}, error interface{}) {

var resQuery map[int]map[string]interface{}

//准备SQL语句

sql := `select count(*) as total from ` + this.tableName + ` ` + this.where + ` ` + ` ` + this.limit

errCode, resQuery, error = this.query(sql)

if errCode != 1 {

return errCode, count, error

}

return errCode, resQuery[0]["total"], error

}

var (

key string

value string

conditions string

str string

)

//是否存在数组内

func in_array(need interface{}, needArr []string) bool {

for _, v := range needArr {

if need == v {

return true

}

}

return false

}  

例子

统计结果

 Db := mysql.DbNew("127.0.0.1", "3306", "root", "root", "demo", "tb_user")

errCode,res,err :=Db.Where("name='1345284190@qq.com'").Count()

fmt.Println(errCode,res,err)  

获取单个结果

 Db := mysql.DbNew("127.0.0.1", "3306", "root", "root", "demo", "tb_user")

errCode,res,err :=Db.Where("name='bbb'").Find()

fmt.Println(errCode,res,err)  

获取多个结果

 Db := mysql.DbNew("127.0.0.1", "3306", "root", "root", "demo", "tb_user")

errCode,res,err :=Db.Where("name='bbb'").Get()

fmt.Println(errCode,res,err)  

添加操作

 Db := mysql.DbNew("127.0.0.1", "3306", "root", "root", "demo", "tb_user")

data := make(map[string]interface{})

data["email"] = "1345284190@qq.com"

data["create_time"] = time. Unix (time.Now().Unix(),0).Format("2006-01-02 15:04:05")

errCode,res,err :=Db.Add(data)

fmt.Println(errCode,res,err)  

更新操作

 Db := mysql.DbNew("127.0.0.1", "3306", "root", "root", "demo", "tb_user")

data := make(map[string]interface{})

data["name"] = "222@qq.com"

data["create_time"] = time.Unix(time.Now().Unix(),0).Format("2006-01-02 15:04:05")

errCode,res,err :=Db.Where("id = 12").Update(data)

fmt.Println(errCode,res,err)  

删除操作

 Db := mysql.DbNew("127.0.0.1", "3306", "root", "root", "demo", "tb_user")

errCode,res,err :=Db.Where("name='bbb'").Delete()

fmt.Println(errCode,res,err)  

写在最后

码云链接:

主图来源:

相关文章