1 Star 0 Fork 0

Erdian718/sqlx

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
stmt.go 2.09 KB
一键复制 编辑 原始数据 按行查看 历史
Erdian718 提交于 2025-02-20 13:11 +08:00 . Support linq APIs
package sqlx
import (
"context"
"database/sql"
"reflect"
"gitee.com/erdian718/linq"
)
// Stmt is a prepared statement.
//
// It wraps a *sql.Stmt together with the context used for every execution and
// a reusable argument buffer. The buffer and the named-parameter map are
// rewritten on every Exec/Query call.
type Stmt struct {
// err, when non-nil, is returned by Exec and yielded by Query instead of
// touching the underlying statement.
err error
// ctx is passed to ExecContext/QueryContext on every execution.
ctx context.Context
// raw is the underlying database/sql prepared statement.
raw *sql.Stmt
// nargs holds the current value of each named parameter, keyed by name.
// Only names already present in the map are ever assigned.
nargs map[string]any
// pargs is the reusable argument list handed to the driver: positional
// arguments first, followed by one sql.Named entry per key in nargs.
pargs []any
}
// Exec runs the prepared statement with the given arguments and returns the
// execution result; no rows are produced.
//
// If the statement carries a stored error, that error is returned and the
// underlying statement is not executed.
func (s *Stmt) Exec(args ...any) (Result, error) {
	if stored := s.err; stored != nil {
		return nil, stored
	}
	params := s.buildArgs(args)
	return s.raw.ExecContext(s.ctx, params...)
}
// Query runs the prepared statement with the given arguments and returns a
// sequence over the resulting rows.
//
// If the statement carries a stored error, the returned sequence is
// linq.Error for that error. Otherwise the query is executed each time the
// sequence is iterated; the same Row value is yielded once per result row, and
// any query, column or iteration error is yielded as the error element. The
// underlying rows are closed when iteration ends.
func (s *Stmt) Query(args ...any) linq.Seq[Row] {
	if s.err != nil {
		return linq.Error[Row](s.err)
	}

	return func(yield func(Row, error) bool) {
		params := s.buildArgs(args)
		rows, err := s.raw.QueryContext(s.ctx, params...)
		if err != nil {
			yield(Row{}, err)
			return
		}
		defer rows.Close()

		columns, err := rows.Columns()
		row := Row{rows: rows, columns: columns}
		if err != nil {
			yield(row, err)
			return
		}
		row.fields = make([]any, len(columns))

		for rows.Next() {
			if more := yield(row, nil); !more {
				return
			}
		}
		if err = rows.Err(); err != nil {
			yield(row, err)
		}
	}
}
// reset prepares the statement for a new set of arguments. When the argument
// buffer is non-empty it is truncated (keeping its capacity) and every named
// parameter value is set back to nil; the parameter names themselves are kept.
func (s *Stmt) reset() {
	if len(s.pargs) == 0 {
		return
	}
	s.pargs = s.pargs[:0]
	for name := range s.nargs {
		s.nargs[name] = nil
	}
}
// buildArgs converts the caller-supplied arguments into the list passed to the
// driver. It resets the statement state, processes each argument with
// buildArg, then appends one sql.Named entry for every named parameter.
// The returned slice is the statement's own buffer and is overwritten by the
// next call.
func (s *Stmt) buildArgs(args []any) []any {
	s.reset()
	for i := range args {
		s.buildArg(args[i])
	}
	for param, current := range s.nargs {
		named := sql.Named(param, current)
		s.pargs = append(s.pargs, named)
	}
	return s.pargs
}
// buildArg folds a single caller-supplied argument into the statement state.
//
//   - A sql.NamedArg is handed to buildNamedArg.
//   - A struct, or a non-nil pointer to a struct, is handed to buildStructArg
//     so that its fields can supply named parameter values.
//   - A nil pointer to a struct has no fields to read and is skipped, leaving
//     the named parameters untouched.
//   - Anything else, including an untyped nil, is appended to s.pargs as a
//     positional argument.
//
// NOTE(review): struct values such as time.Time also take the struct path and
// are therefore not forwarded as positional arguments — confirm this is
// intended.
func (s *Stmt) buildArg(arg any) {
	if namedArg, ok := arg.(sql.NamedArg); ok {
		s.buildNamedArg(namedArg)
		return
	}

	// reflect.ValueOf(nil) is the zero Value and calling Type on it panics,
	// so an untyped nil must be forwarded before any reflection happens.
	if arg == nil {
		s.pargs = append(s.pargs, arg)
		return
	}

	v := reflect.ValueOf(arg)
	t := v.Type()
	switch t.Kind() {
	case reflect.Struct:
		s.buildStructArg(t, v)
		return
	case reflect.Pointer:
		if et := t.Elem(); et.Kind() == reflect.Struct {
			// Elem of a nil pointer is an invalid Value; do not hand it to
			// the field lookup.
			if !v.IsNil() {
				s.buildStructArg(et, v.Elem())
			}
			return
		}
	}
	s.pargs = append(s.pargs, arg)
}
// buildNamedArg stores the value of a named argument, but only when the
// statement already has a parameter with that name; unknown names are ignored.
func (s *Stmt) buildNamedArg(arg sql.NamedArg) {
	_, known := s.nargs[arg.Name]
	if !known {
		return
	}
	s.nargs[arg.Name] = arg.Value
}
// buildStructArg fills named parameters from a struct value. For every named
// parameter it asks getFieldValue for the matching field; parameters for which
// the lookup returns an error keep their current value.
func (s *Stmt) buildStructArg(t reflect.Type, v reflect.Value) {
	for param := range s.nargs {
		field, err := getFieldValue(t, v, param, false)
		if err != nil {
			continue
		}
		s.nargs[param] = field.Interface()
	}
}
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/erdian718/sqlx.git
git@gitee.com:erdian718/sqlx.git
erdian718
sqlx
sqlx
main

搜索帮助