代码拉取完成,页面将自动刷新
package tcode
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
)
// mapping selects how query() maps result rows onto Go values.
// ignored selects which columns updateByPk writes in its SET clause.
type (
	mapping uint8
	ignored uint8
)

// These constants were originally declared in one iota block, so the ignored
// values continue the numbering after the mapping values. The explicit values
// keep that numbering: in particular the zero value of ignored matches no
// update strategy.
const (
	// noMapping: the statement runs but nothing is scanned.
	noMapping mapping = 0
	// resultSingleVar: one row scanned into caller-supplied pointers (ToVar).
	resultSingleVar mapping = 1
	// resultStruct: one row scanned into a T (ToStruct).
	resultStruct mapping = 2
	// resultSlice: every row scanned into a []T (ToSlice).
	resultSlice mapping = 3
	// resultRawTable: every row scanned into a *rawTable (ToRawTable).
	resultRawTable mapping = 4

	// notIgnoredEveryColumn: do not skip any column.
	notIgnoredEveryColumn ignored = 5
	// ignoredEveryEmptyColumn: skip every column whose value equals the one in
	// a fresh NewInstance() (see updateByPk).
	ignoredEveryEmptyColumn ignored = 6
)
type (
	// TableSlice is a named slice of values of the Table type T.
	TableSlice[T Table] []T

	// rawTableSqlInfo is the view of *sqlInfo returned by NewSqlInfo for
	// ad-hoc SQL whose results go into variables or a rawTable.
	rawTableSqlInfo interface {
		// ToVar scans the single result row into the given pointers.
		ToVar(ptrVar ...any) error
		// ToRawTable runs the query (optionally paged) and returns all rows.
		ToRawTable(page *page) (*rawTable, error)
		// Exec runs the statement as a non-query.
		Exec() (rowsAffected int64, lastInsertId int64, err error)
		// String returns the final SQL text (after IN-slice expansion).
		String() (sqlStr string, err error)
	}
)
// sqlInfo accumulates a SQL statement and its positional parameters, and holds
// the state used to execute it and map its results onto T.
type sqlInfo[T Table] struct {
	ctx       context.Context // source of the DB/transaction connection and configuration
	statement strings.Builder // SQL text appended by AppendSQL
	params    []any           // positional parameters, one per '?'

	result      T         // single-row result (ToStruct) or the row being inserted/updated/deleted
	resultSlice *[]T      // multi-row result (ToSlice)
	rawTable    *rawTable // result of ToRawTable
	ptrVar      []any     // scan destinations for the current query
	mapping     mapping   // how query() maps result rows

	// LastInsertId and RowsAffected are filled in by Exec from the driver's sql.Result.
	LastInsertId, RowsAffected int64

	// updateMechanism is the update strategy. SqlScript defaults it to
	// ignoredEveryEmptyColumn (skip every empty column); notIgnoredEveryColumn
	// writes every column.
	updateMechanism ignored
	// possibleQueryEmptyColumn must be true when result columns may be NULL;
	// scanning then goes through reTryScan, which is slower. Prefer non-NULL
	// columns with proper defaults in the database.
	possibleQueryEmptyColumn bool
}
// ToVar executes the query and scans the single expected result row into
// ptrVar, one destination per selected column.
//
// Every element of ptrVar must be non-nil. When the query returns no row the
// destinations are left unchanged and nil is returned; when it returns more
// than one row an error is returned. Errors are logged via FuncLog.
func (sqlInfo *sqlInfo[T]) ToVar(ptrVar ...any) (err error) {
	for i, p := range ptrVar {
		if p == nil {
			err = fmt.Errorf("err: ptrVar index[%d] is nil", i)
			FuncLog(err)
			return err
		}
	}
	// Replace, not append: scan targets left by an earlier ToVar/ToStruct/ToSlice
	// on the same sqlInfo would otherwise be scanned into as well.
	sqlInfo.ptrVar = ptrVar
	sqlInfo.mapping = resultSingleVar
	if err = sqlInfo.query(); err != nil {
		FuncLog(err)
		return err
	}
	return nil
}
// ToStruct executes the query and scans the single expected result row into a
// fresh T obtained from NewInstance. If no row is returned, that fresh,
// unpopulated instance is returned with a nil error; more than one row is an error.
func (sqlInfo *sqlInfo[T]) ToStruct() (t T, err error) {
	fresh := sqlInfo.result.NewInstance().(T)
	sqlInfo.result = fresh
	sqlInfo.mapping = resultStruct
	if err = sqlInfo.query(); err != nil {
		return t, err
	}
	return sqlInfo.result, nil
}
// ToSlice executes the query and returns every result row as a T.
//
// When page is non-nil, the total row count is resolved first (see applyPage).
// If it is zero, an empty slice is returned together with a "no data" error.
// Otherwise a LIMIT/OFFSET clause for the requested page is appended before
// the query runs.
func (sqlInfo *sqlInfo[T]) ToSlice(page *page) (*[]T, error) {
	sqlInfo.resultSlice = new([]T)
	sqlInfo.mapping = resultSlice
	if page != nil {
		hasData, err := sqlInfo.applyPage(page)
		if err != nil {
			return nil, err
		}
		if !hasData {
			return &[]T{}, errors.New("no data")
		}
	}
	if err := sqlInfo.query(); err != nil {
		return nil, err
	}
	return sqlInfo.resultSlice, nil
}

// ToRawTable executes the query and returns every result row in a rawTable.
// Paging works as in ToSlice; on the "no data" path the empty rawTable is
// returned together with the error.
func (sqlInfo *sqlInfo[T]) ToRawTable(page *page) (*rawTable, error) {
	// Values built with SqlScript/Select carry no rawTable; without this,
	// query() would dereference a nil *rawTable.
	if sqlInfo.rawTable == nil {
		sqlInfo.rawTable = &rawTable{}
	}
	sqlInfo.mapping = resultRawTable
	if page != nil {
		hasData, err := sqlInfo.applyPage(page)
		if err != nil {
			return sqlInfo.rawTable, err
		}
		if !hasData {
			return sqlInfo.rawTable, errors.New("no data")
		}
	}
	if err := sqlInfo.query(); err != nil {
		return sqlInfo.rawTable, err
	}
	return sqlInfo.rawTable, nil
}

// applyPage resolves the total row count for page and records it with
// page.setTotalCount. The count comes from page.FuncCustomTotal when set;
// otherwise the statement is wrapped in a SELECT COUNT(*) subquery.
// If the count is positive, it appends a LIMIT/OFFSET clause for
// page.CurrentPage. It reports whether any data is available. Errors are
// logged before being returned.
func (sqlInfo *sqlInfo[T]) applyPage(page *page) (bool, error) {
	var total int
	var err error
	if page.FuncCustomTotal == nil {
		var sqlStr string
		sqlStr, err = sqlInfo.String()
		if err != nil {
			FuncLog(err)
			return false, err
		}
		info := NewSqlInfo(sqlInfo.ctx, fmt.Sprintf("SELECT COUNT(*) FROM (%s) AS count_query", sqlStr), sqlInfo.params...)
		err = info.ToVar(&total)
	} else {
		err = page.FuncCustomTotal(&total)
	}
	if err != nil {
		FuncLog(err)
		return false, err
	}
	page.setTotalCount(total)
	if total <= 0 {
		return false, nil
	}
	// Clamp to the first page: CurrentPage < 1 would produce a negative OFFSET,
	// which databases reject.
	currentPage := page.CurrentPage
	if currentPage < 1 {
		currentPage = 1
	}
	sqlInfo.AppendSQL(fmt.Sprintf(" LIMIT %d OFFSET %d ", page.PageSize, (currentPage-1)*page.PageSize))
	return true, nil
}
// String finalizes the statement and returns the SQL text that will be executed.
//
// It first expands slice parameters of IN conditions with handlerSqlIn. That
// rewrites sqlInfo.statement and sqlInfo.params in place.
//
// If the SQL contains a single quote, a "warn:" error is logged and also
// returned alongside the SQL. query and Exec treat any non-nil error from
// String as fatal, so such statements are never executed.
// NOTE(review): confirm that rejecting inline string literals is intended.
//
// When the context configuration has DebugSQL set:
//   - the statement is logged with its parameters inlined, for display only;
//   - a mismatch between the number of '?' and parameters is returned as an error.
func (sqlInfo *sqlInfo[T]) String() (sqlStr string, err error) {
	// Expand slice parameters of IN conditions.
	if err = handlerSqlIn(sqlInfo); err != nil {
		return "", err
	}
	sqlStr = sqlInfo.statement.String()
	if strings.Contains(sqlStr, "'") {
		err = errors.New("warn: sqlStr statement contains [']")
		FuncLog(err)
	}
	if getContextConf(sqlInfo.ctx).DebugSQL {
		split := strings.Split(sqlStr, "?")
		params := sqlInfo.params
		paramsLen := len(params)
		if len(split)-1 != paramsLen {
			err = fmt.Errorf("err: parameter conditions do not match [?]:%d; value:%d", len(split)-1, paramsLen)
			FuncLog(err)
			return "", err
		}
		// Inline the parameters for the debug log only; the executed SQL keeps its placeholders.
		var debugSQL strings.Builder
		for i, param := range params {
			debugSQL.WriteString(split[i])
			debugSQL.WriteString(singleQ + ConvertToString(param) + singleQ)
		}
		debugSQL.WriteString(split[paramsLen])
		FuncLog("debug sqlStr:", debugSQL.String())
	}
	return sqlStr, err
}
// AppendSQL appends statement verbatim to the SQL being built and adds param
// to the positional parameter list. It returns the receiver for chaining.
func (sqlInfo *sqlInfo[T]) AppendSQL(statement string, param ...any) *sqlInfo[T] {
	sqlInfo.params = append(sqlInfo.params, param...)
	sqlInfo.statement.WriteString(statement)
	return sqlInfo
}
// handlerSqlIn expands slice parameters bound to IN conditions.
//
// IN conditions are located with findInReg, which is defined elsewhere.
// NOTE(review): it presumably matches something like "in (?)"; confirm.
//
// How the statement and parameters are rewritten:
//   - A '?' inside such a match is replaced by one '?' per element of its
//     parameter (converted with ConvertToStringSlice), and the parameter is
//     replaced by those elements. Every other '?' keeps its parameter.
//   - findInReg matches that remain in the rewritten SQL are replaced by "=?".
//     Per the original comment, these are IN conditions with a single value,
//     and repeated calls do not reprocess already-handled conditions.
//   - Parameters after the last placeholder are carried over unchanged.
//
// It returns an error, without modifying sqlInfo, when an IN parameter expands
// to an empty slice or when the SQL has more '?' placeholders than parameters.
func handlerSqlIn[T Table](sqlInfo *sqlInfo[T]) error {
	pSql := []byte(sqlInfo.statement.String())
	// Locate every IN condition.
	index := findInReg.FindAllIndex(pSql, -1)
	ii := len(index)
	if ii <= 0 {
		return nil
	}
	params := sqlInfo.params
	var resultSql []byte
	var resultParam []any
	a := 0 // next IN match to consume
	b := 0 // index of the parameter bound to the current '?'
	for i := range pSql {
		if pSql[i] != '?' {
			resultSql = append(resultSql, pSql[i])
			continue
		}
		// More placeholders than parameters: report it instead of panicking
		// with an index out of range.
		if b >= len(params) {
			err := fmt.Errorf("err: parameter conditions do not match, placeholder [?] #%d has no value (params:%d)", b+1, len(params))
			FuncLog(err)
			return err
		}
		if a < ii && i > index[a][0] && i < index[a][1] {
			slice := ConvertToStringSlice(params[b])
			if len(slice) <= 0 {
				err := errors.New("slice len is 0")
				FuncLog(err)
				return err
			}
			for j := range slice {
				if j > 0 {
					resultSql = append(resultSql, ',')
				}
				resultSql = append(resultSql, '?')
				resultParam = append(resultParam, slice[j])
			}
			a++
		} else {
			resultSql = append(resultSql, '?')
			resultParam = append(resultParam, params[b])
		}
		b++
	}
	// Turn IN conditions that still match (single-value ones) into "=?".
	// Already-handled conditions are not reprocessed on later calls.
	index = findInReg.FindAllIndex(resultSql, -1)
	for i := range index {
		for j := index[i][0]; j < index[i][1]; j++ {
			resultSql[j] = ' '
		}
		resultSql[index[i][0]] = '='
		resultSql[index[i][0]+1] = '?'
	}
	// Do not drop parameters that follow the last placeholder.
	resultParam = append(resultParam, params[b:]...)
	sqlInfo.statement.Reset()
	sqlInfo.statement.Write(resultSql)
	sqlInfo.params = resultParam
	return nil
}
// PossibleQueryEmptyColumn switches ToVar, ToStruct and ToSlice to
// NULL-tolerant scanning (reTryScan). Use it when result columns may be NULL.
// It is slower, so prefer giving database columns non-NULL defaults instead.
// It returns the receiver for chaining.
func (sqlInfo *sqlInfo[T]) PossibleQueryEmptyColumn() *sqlInfo[T] {
	receiver := sqlInfo
	receiver.possibleQueryEmptyColumn = true
	return receiver
}
// query executes the statement as a SELECT and maps the rows according to
// sqlInfo.mapping:
//   - noMapping: nothing is scanned; a warning is logged.
//   - resultSingleVar: at most one row is scanned into sqlInfo.ptrVar.
//   - resultStruct: at most one row is scanned into sqlInfo.result.
//   - resultSlice: every row is appended to *sqlInfo.resultSlice.
//   - resultRawTable: every row is scanned into sqlInfo.rawTable.
//
// For the single-row mappings, zero rows leave the targets untouched, and a
// second row is an error. The query runs on the context's transaction when
// present, otherwise on the context's DB connection. Its duration is reported
// via FuncSQLLog, and errors are logged via FuncLog.
func (sqlInfo *sqlInfo[T]) query() error {
	sqlStr, err := sqlInfo.String()
	if err != nil {
		FuncLog(err)
		return err
	}
	var rows *sql.Rows
	tx := GetContextTxConn(sqlInfo.ctx)
	startTime := time.Now().UnixNano()
	if tx != nil {
		rows, err = tx.QueryContext(sqlInfo.ctx, sqlStr, sqlInfo.params...)
	} else {
		db := GetContextDBConn(sqlInfo.ctx)
		if db == nil {
			err = errors.New("please set ctx dbConn")
			FuncLog(err)
			return err
		}
		rows, err = db.QueryContext(sqlInfo.ctx, sqlStr, sqlInfo.params...)
	}
	endTime := time.Now().UnixNano()
	FuncSQLLog(float64(endTime-startTime)/1e6, sqlStr, sqlInfo.params)
	if err != nil {
		FuncLog(err)
		return err
	}
	defer func() {
		// Dedicated variable: closing must not clobber the function's err.
		if closeErr := rows.Close(); closeErr != nil {
			FuncLog(closeErr)
		}
	}()
	columns, err := rows.Columns()
	if err != nil {
		FuncLog(err)
		return err
	}
	switch sqlInfo.mapping {
	case noMapping:
		FuncLog("warn: no mapping")
		return nil
	case resultSingleVar:
		if err = sqlInfo.scanAtMostOneRow(rows); err != nil {
			return err
		}
	case resultStruct:
		sqlInfo.ptrVar = sqlInfo.result.RawColumnContainer(columns...)
		if err = sqlInfo.scanAtMostOneRow(rows); err != nil {
			return err
		}
	case resultSlice:
		// empty is used to restore the fields' default values between rows.
		empty := sqlInfo.result.NewInstance()
		sliceTable := sqlInfo.result.NewInstance()
		sqlInfo.ptrVar = sliceTable.RawColumnContainer(columns...)
		for rows.Next() {
			if err = sqlInfo.scanCurrentRow(rows); err != nil {
				FuncLog(err)
				return err
			}
			// NOTE(review): NewTable presumably returns a copy of the scanned row — confirm.
			*sqlInfo.resultSlice = append(*sqlInfo.resultSlice, sliceTable.NewTable().(T))
			sliceTable.CopyFrom(empty) // clear the previous row (restore field defaults)
		}
	case resultRawTable:
		// Values built with SqlScript/Select carry no rawTable.
		if sqlInfo.rawTable == nil {
			sqlInfo.rawTable = &rawTable{}
		}
		for rows.Next() {
			sqlInfo.ptrVar = sqlInfo.rawTable.RawColumnContainer(columns...)
			if err = rows.Scan(sqlInfo.ptrVar...); err != nil {
				FuncLog(err)
				return err
			}
		}
		sqlInfo.rawTable.columnNames = columns
	default:
		err = errors.New("err: unknown mapping type")
		FuncLog(err)
		return err
	}
	// rows.Next returning false can mean an iteration error (lost connection,
	// cancelled context, ...) rather than end of data; don't report partial results as success.
	if err = rows.Err(); err != nil {
		FuncLog(err)
		return err
	}
	return nil
}

// scanCurrentRow scans the row rows is positioned on into sqlInfo.ptrVar,
// using reTryScan when possibleQueryEmptyColumn is enabled.
func (sqlInfo *sqlInfo[T]) scanCurrentRow(rows *sql.Rows) error {
	if sqlInfo.possibleQueryEmptyColumn {
		return reTryScan(sqlInfo.ptrVar, rows)
	}
	return rows.Scan(sqlInfo.ptrVar...)
}

// scanAtMostOneRow scans the first row, if any, into sqlInfo.ptrVar and fails
// if a second row exists. Errors are logged.
func (sqlInfo *sqlInfo[T]) scanAtMostOneRow(rows *sql.Rows) error {
	if rows.Next() {
		if err := sqlInfo.scanCurrentRow(rows); err != nil {
			FuncLog(err)
			return err
		}
	}
	if rows.Next() {
		err := errors.New("err: require one row data,but returns multi row")
		FuncLog(err)
		return err
	}
	return nil
}
// insert builds and executes "INSERT INTO <table> VALUES (...)" for data.
//
// Values:
//   - There is one positional parameter per entry of data.Columns(), converted
//     with ConvertToString.
//   - No column list is emitted, so Columns() must cover every table column in
//     the table's column order.
//   - Unset fields are inserted with Go's zero values (int=0, string="",
//     time="0000-01-01 00:00:00", ...).
//   - Non-auto-increment primary keys must be set by the caller, e.g. with
//     tcode.FuncGenId.
//
// When Exec succeeds with a non-zero LastInsertId and the primary-key column
// (GetPkColumnName) is backed by an integer field, the id is written back into data.
func (sqlInfo *sqlInfo[T]) insert(data T) (rowsAffected int64, lastInsertId int64, err error) {
	sqlInfo.result = data
	columns := data.Columns()
	container := data.RawColumnContainer(columns...)
	sqlInfo.AppendSQL("INSERT INTO " + data.TableName() + " VALUES (")
	for i := range container {
		if i > 0 {
			sqlInfo.AppendSQL(",")
		}
		sqlInfo.AppendSQL("?", ConvertToString(container[i]))
	}
	sqlInfo.AppendSQL(")")
	rowsAffected, lastInsertId, err = sqlInfo.Exec()
	if err != nil || lastInsertId == 0 {
		return rowsAffected, lastInsertId, err
	}
	index := StringInIndex(GetPkColumnName(sqlInfo.ctx), columns)
	if index <= -1 {
		return rowsAffected, lastInsertId, err
	}
	// Echo the generated key back into the struct. Plain int/uint fields are the
	// most common primary-key types, so they must be handled too.
	switch p := container[index].(type) {
	case *int:
		*p = int(lastInsertId)
	case *int8:
		*p = int8(lastInsertId)
	case *int16:
		*p = int16(lastInsertId)
	case *int32:
		*p = int32(lastInsertId)
	case *int64:
		*p = lastInsertId
	case *uint:
		*p = uint(lastInsertId)
	case *uint8:
		*p = uint8(lastInsertId)
	case *uint16:
		*p = uint16(lastInsertId)
	case *uint32:
		*p = uint32(lastInsertId)
	case *uint64:
		*p = uint64(lastInsertId)
	}
	return rowsAffected, lastInsertId, err
}
// insertBatch inserts every element of *datas with one multi-row
// "INSERT INTO <table> VALUES (...),(...)" statement.
//
// The table name comes from the first element. Each row's values come from
// that element's own Columns()/RawColumnContainer, converted with
// ConvertToString. A nil pointer or an empty slice is a no-op that returns
// zeros and a nil error. Generated keys are not written back into the elements.
func (sqlInfo *sqlInfo[T]) insertBatch(datas *[]T) (rowsAffected int64, lastInsertId int64, err error) {
	// Guard the nil pointer before dereferencing it.
	if datas == nil || len(*datas) == 0 {
		return rowsAffected, lastInsertId, err
	}
	ds := *datas
	sqlInfo.result = ds[0]
	sqlInfo.AppendSQL("INSERT INTO " + ds[0].TableName() + " VALUES ")
	for i := range ds {
		if i > 0 {
			sqlInfo.AppendSQL(",")
		}
		sqlInfo.AppendSQL("(")
		container := ds[i].RawColumnContainer(ds[i].Columns()...)
		for j := range container {
			if j > 0 {
				sqlInfo.AppendSQL(",")
			}
			sqlInfo.AppendSQL("?", ConvertToString(container[j]))
		}
		sqlInfo.AppendSQL(")")
	}
	return sqlInfo.Exec()
}
// updateByPk builds and executes "UPDATE <table> SET ... WHERE <pk>=?" for
// data, where <pk> is GetPkColumnName(sqlInfo.ctx).
//
// The SET clause depends on sqlInfo.updateMechanism:
//   - notIgnoredEveryColumn: every column, the primary key included, is written.
//   - ignoredEveryEmptyColumn: the primary key is skipped, as is every column
//     whose value (compared via ConvertToString) equals the one in a fresh
//     NewInstance(). In this mode it is an error if the primary key is empty
//     or if no column is left to write.
//
// Any other mechanism is rejected with "no action was taken".
func (sqlInfo *sqlInfo[T]) updateByPk(data T) (rowsAffected int64, lastInsertId int64, err error) {
	sqlInfo.result = data
	sqlInfo.AppendSQL("UPDATE " + data.TableName() + " SET ")
	columns := data.Columns()
	container := data.RawColumnContainer(columns...)
	switch sqlInfo.updateMechanism {
	case notIgnoredEveryColumn: // do not skip any column
		for i := range columns {
			if i > 0 {
				sqlInfo.AppendSQL(",")
			}
			sqlInfo.AppendSQL(columns[i]+"=?", ConvertToString(container[i]))
		}
		pkColumnName := GetPkColumnName(sqlInfo.ctx)
		condition := data.RawColumnContainer(pkColumnName)
		sqlInfo.AppendSQL(" WHERE "+pkColumnName+"=?", ConvertToString(condition[0]))
		return sqlInfo.Exec()
	case ignoredEveryEmptyColumn: // skip every empty column
		empty := data.NewInstance()
		emptyEqualiser := empty.RawColumnContainer(columns...)
		pkColumnName := GetPkColumnName(sqlInfo.ctx)
		k := 0 // number of columns written to the SET clause
		for i := range columns {
			if columns[i] == pkColumnName { // the key is the WHERE condition, never SET
				continue
			}
			value := ConvertToString(container[i])
			emptyValue := ConvertToString(emptyEqualiser[i])
			if value == emptyValue {
				continue
			}
			if k > 0 {
				sqlInfo.AppendSQL(",")
			}
			sqlInfo.AppendSQL(columns[i]+"=?", value)
			k++
		}
		condition := data.RawColumnContainer(pkColumnName)
		emptyEqualiser = empty.RawColumnContainer(pkColumnName)
		value := ConvertToString(condition[0])
		if value == ConvertToString(emptyEqualiser[0]) {
			err = fmt.Errorf("condition '%s' column value is empty", pkColumnName)
			FuncLog(err)
			return rowsAffected, lastInsertId, err
		}
		if k == 0 {
			// Nothing to SET: sending "UPDATE t SET  WHERE ..." would only fail with a syntax error.
			err = fmt.Errorf("err: nothing to update, every column of %s other than '%s' is empty", data.TableName(), pkColumnName)
			FuncLog(err)
			return rowsAffected, lastInsertId, err
		}
		sqlInfo.AppendSQL(" WHERE "+pkColumnName+"=?", value)
		return sqlInfo.Exec()
	default:
		err = errors.New("no action was taken")
		FuncLog(err)
		return rowsAffected, lastInsertId, err
	}
}
// deleteByPk executes "DELETE FROM <table> WHERE <pk>=?" for data, where <pk>
// is GetPkColumnName(sqlInfo.ctx).
//
// Like updateByPk, it refuses to run when the primary-key value (compared via
// ConvertToString) equals that of a fresh NewInstance(). This prevents an
// unset key from deleting whatever row happens to hold the zero key.
func (sqlInfo *sqlInfo[T]) deleteByPk(data T) (rowsAffected int64, lastInsertId int64, err error) {
	sqlInfo.result = data
	pkColumnName := GetPkColumnName(sqlInfo.ctx)
	container := data.RawColumnContainer(pkColumnName)
	emptyEqualiser := data.NewInstance().RawColumnContainer(pkColumnName)
	if ConvertToString(container[0]) == ConvertToString(emptyEqualiser[0]) {
		err = fmt.Errorf("condition '%s' column value is empty", pkColumnName)
		FuncLog(err)
		return rowsAffected, lastInsertId, err
	}
	sqlInfo.AppendSQL("DELETE FROM "+data.TableName()+" WHERE "+pkColumnName+"=?", container[0])
	return sqlInfo.Exec()
}
// Exec runs the built statement as a non-query (INSERT/UPDATE/DELETE, ...).
//
// The statement runs on the first available of:
//  1. the transaction stored in the context (GetContextTxConn);
//  2. the configuration's DB directly, when SkipDefaultTransaction is set;
//  3. a new transaction opened with Transaction just for this statement.
//
// On success the affected-row count and last insert id are returned and also
// stored in sqlInfo.RowsAffected / sqlInfo.LastInsertId. Execution time is
// reported via FuncSQLLog, and errors are logged via FuncLog.
func (sqlInfo *sqlInfo[T]) Exec() (rowsAffected int64, lastInsertId int64, err error) {
	sqlStr, err := sqlInfo.String()
	if err != nil {
		FuncLog(err)
		return 0, 0, err
	}

	var result sql.Result
	ctxTx := GetContextTxConn(sqlInfo.ctx)
	began := time.Now()
	if ctxTx != nil {
		result, err = ctxTx.ExecContext(sqlInfo.ctx, sqlStr, sqlInfo.params...)
	} else if conf := getContextConf(sqlInfo.ctx); conf.SkipDefaultTransaction {
		if conf.DB == nil {
			err = errors.New("please set ctx dbConn")
			FuncLog(err)
			return 0, 0, err
		}
		result, err = conf.DB.ExecContext(sqlInfo.ctx, sqlStr, sqlInfo.params...)
	} else {
		// No ambient transaction: wrap this single statement in its own.
		err = Transaction(sqlInfo.ctx, func(txCtx context.Context) error {
			var execErr error
			result, execErr = GetContextTxConn(txCtx).ExecContext(txCtx, sqlStr, sqlInfo.params...)
			return execErr
		})
	}
	FuncSQLLog(float64(time.Since(began))/1e6, sqlStr, sqlInfo.params)
	if err != nil {
		FuncLog(err)
		return 0, 0, err
	}

	if sqlInfo.RowsAffected, err = result.RowsAffected(); err != nil {
		FuncLog(err)
		return 0, 0, err
	}
	if sqlInfo.LastInsertId, err = result.LastInsertId(); err != nil {
		FuncLog(err)
		return sqlInfo.RowsAffected, 0, err
	}
	return sqlInfo.RowsAffected, sqlInfo.LastInsertId, nil
}
// SqlScript starts a statement for table type T from a raw SQL fragment and
// its positional parameters.
//
// Defaults: no result mapping yet, the ignoredEveryEmptyColumn update
// strategy, and possibleQueryEmptyColumn disabled.
func SqlScript[T Table](ctx context.Context, statement string, param ...any) *sqlInfo[T] {
	info := &sqlInfo[T]{
		ctx:                      ctx,
		mapping:                  noMapping,
		updateMechanism:          ignoredEveryEmptyColumn,
		possibleQueryEmptyColumn: false,
	}
	return info.AppendSQL(statement, param...)
}
// Select starts a "SELECT <columns> FROM <table> " statement for T.
// The table name comes from TableName() on the zero value of T. With no
// columns, "*" is selected.
func Select[T Table](ctx context.Context, columns ...string) *sqlInfo[T] {
	projection := "*"
	if len(columns) > 0 {
		projection = strings.Join(columns, ",")
	}
	info := SqlScript[T](ctx, "")
	return info.AppendSQL("SELECT " + projection + " FROM " + info.result.TableName() + " ")
}
// NewSqlInfo creates an ad-hoc statement from raw SQL and its positional
// parameters. Its results are read with ToVar or ToRawTable, or it is run
// with Exec.
func NewSqlInfo(ctx context.Context, statement string, param ...any) rawTableSqlInfo {
	info := SqlScript[*rawTable](ctx, statement, param...)
	info.mapping = resultRawTable
	info.rawTable = new(rawTable)
	return info
}
// Insert inserts t into its table. Unset fields are stored as Go zero values
// (int=0, string="", time="0000-01-01 00:00:00", ...). Non-auto-increment
// primary keys must be set by the caller, e.g. with tcode.FuncGenId.
func Insert(ctx context.Context, t Table) (rowsAffected int64, lastInsertId int64, err error) {
	script := SqlScript[Table](ctx, "")
	rowsAffected, lastInsertId, err = script.insert(t)
	return rowsAffected, lastInsertId, err
}
// InsertBatch inserts all rows in *ts with a single multi-row INSERT.
func InsertBatch(ctx context.Context, ts *[]Table) (rowsAffected int64, lastInsertId int64, err error) {
	script := SqlScript[Table](ctx, "")
	rowsAffected, lastInsertId, err = script.insertBatch(ts)
	return rowsAffected, lastInsertId, err
}
// Save inserts t, or updates it by primary key when Exist reports that its
// row is already present.
//
// The existence check and the write are separate statements.
// NOTE(review): they are not atomic unless ctx carries a transaction.
func Save(ctx context.Context, t Table) (rowsAffected int64, lastInsertId int64, err error) {
	var exist bool
	if exist, err = Exist(ctx, t); err != nil {
		return 0, 0, err
	}
	if !exist {
		return Insert(ctx, t)
	}
	return UpdateByPk(ctx, t)
}
// Exist reports whether a row with t's primary key exists in t's table.
// A primary key that ConvertToString renders as "" or "0" is treated as unset,
// and false is returned without querying.
func Exist(ctx context.Context, t Table) (bool, error) {
	pkColName := GetPkColumnName(ctx)
	pkVal := ConvertToString(t.RawColumnContainer(pkColName)[0])
	if pkVal == "" || pkVal == "0" {
		return false, nil
	}
	countSQL := fmt.Sprintf("SELECT count(%s) FROM %s WHERE %s=?", pkColName, t.TableName(), pkColName)
	var count int
	if err := NewSqlInfo(ctx, countSQL, pkVal).ToVar(&count); err != nil {
		return false, err
	}
	return count > 0, nil
}
// UpdateByPk updates t's row, matched by primary key only, skipping every
// empty column.
func UpdateByPk(ctx context.Context, t Table) (rowsAffected int64, lastInsertId int64, err error) {
	script := SqlScript[Table](ctx, "")
	rowsAffected, lastInsertId, err = script.updateByPk(t)
	return rowsAffected, lastInsertId, err
}
// UpdateNotIgnoredEveryColumnByPk updates t's row, matched by primary key
// only, writing every column (no column is skipped).
func UpdateNotIgnoredEveryColumnByPk(ctx context.Context, t Table) (rowsAffected int64, lastInsertId int64, err error) {
	script := SqlScript[Table](ctx, "")
	script.updateMechanism = notIgnoredEveryColumn
	rowsAffected, lastInsertId, err = script.updateByPk(t)
	return rowsAffected, lastInsertId, err
}
// DeleteByPk deletes t's row, matched by primary key only.
func DeleteByPk(ctx context.Context, t Table) (rowsAffected int64, lastInsertId int64, err error) {
	script := SqlScript[Table](ctx, "")
	rowsAffected, lastInsertId, err = script.deleteByPk(t)
	return rowsAffected, lastInsertId, err
}
// Transaction runs call inside a new transaction. The transaction is opened
// on the context's DB connection with GetContextTxOptions.
//
// The transaction is stored (WithTX) in the context passed to call, so
// query/Exec through that context use it. A new transaction is always begun
// from the DB connection, even if ctx already carries one.
//
// Outcomes:
//   - call returns nil: the transaction is committed, and any Commit error is
//     returned.
//   - call returns an error: the transaction is rolled back and that error is
//     returned; a Rollback failure is only logged.
//   - call panics: the transaction is rolled back and an error describing the
//     panic is returned (with any Rollback failure appended).
func Transaction(ctx context.Context, call func(ctx context.Context) error) (err error) {
	db := GetContextDBConn(ctx)
	if db == nil {
		err = errors.New("please set ctx dbConn")
		FuncLog(err)
		return err
	}
	tx, err := db.BeginTx(ctx, GetContextTxOptions(ctx))
	if err != nil {
		return err
	}
	// err is a named result so the outcome of Commit (or a recovered panic)
	// actually reaches the caller.
	defer func() {
		if p := recover(); p != nil {
			FuncLog(p)
			// Report the panic instead of swallowing it: returning nil here
			// would claim success for a rolled-back transaction.
			err = fmt.Errorf("err: transaction rolled back after panic: %v", p)
			if rbErr := tx.Rollback(); rbErr != nil {
				FuncLog(rbErr)
				err = fmt.Errorf("%w; rollback: %v", err, rbErr)
			}
			return
		}
		if err != nil {
			FuncLog(err)
			if rbErr := tx.Rollback(); rbErr != nil {
				FuncLog(rbErr)
			}
			return
		}
		if commitErr := tx.Commit(); commitErr != nil {
			FuncLog(commitErr)
			err = commitErr
		}
	}()
	return call(WithTX(ctx, tx))
}
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。