完成基本的创建使用功能

This commit is contained in:
2024-09-11 20:19:47 +08:00
parent 28a84ad4d7
commit 951572a1f5
22 changed files with 783 additions and 31 deletions
+60
View File
@@ -0,0 +1,60 @@
package gormx
import (
"gorm.io/gorm"
)
type LimitHandler func(db *gorm.DB) *gorm.DB
func Finder[T any](model T, db *gorm.DB, objs ...LimitHandler) (int64, []T, error) {
for _, obj := range objs {
db = obj(db)
}
var models []T
err := db.Model(&model).Find(&models).Error
if err != nil {
return 0, nil, err
}
var sum int64
err = db.Model(&model).Count(&sum).Error
return sum, models, err
}
// FinderPage 分页查找
func FinderPage[T any](db *gorm.DB, page, limit int, objs ...LimitHandler) (int64, []T, error) {
return pageFinder[T](db, pageHandle(page, limit), objs...)
}
func pageFinder[T any](db *gorm.DB, limit LimitHandler, objs ...LimitHandler) (int64, []T, error) {
// 定义表
var model T
db = db.Model(&model)
for _, obj := range objs {
db = obj(db)
}
// 查找全部数量
var sum int64
if err := db.Model(&model).Count(&sum).Error; err != nil {
return 0, nil, err
}
// 分页查找
db = limit(db)
var models []T
if err := db.Find(&models).Error; err != nil {
return 0, nil, err
}
return sum, models, nil
}
func pageHandle(page, limit int) LimitHandler {
return func(db *gorm.DB) *gorm.DB {
// 用户输入起始位
var beginIndex int
if page < 0 {
beginIndex = -1
} else {
beginIndex = (page - 1) * limit
}
return db.Offset(beginIndex).Limit(limit)
}
}
+27
View File
@@ -0,0 +1,27 @@
package gormx
import (
"errors"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
type Config struct {
Type string `ini:"type"`
DSN string `ini:"dsn"`
}
func New(cfg Config) (db *gorm.DB, err error) {
switch cfg.Type {
case "sqlite3":
// dsn := "exe.db"
db, err = gorm.Open(sqlite.Open(cfg.DSN), &gorm.Config{})
case "mysql":
// dsn := "userRepo:pass@tcp(127.0.0.1:3306)/dbname?charset=utf8mb4&parseTime=True&loc=Local"
db, err = gorm.Open(mysql.Open(cfg.DSN), &gorm.Config{})
default:
err = errors.New("mode not supported")
}
return
}
+108
View File
@@ -0,0 +1,108 @@
package gormx
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
)
type ListInt64 []int64
// Value 接口,Value 返回 json value any -> string
func (j ListInt64) Value() (driver.Value, error) {
return json.Marshal(j)
}
// Scan 接口,Scan 将 value 扫描至 Jsonb
func (j *ListInt64) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
}
err := json.Unmarshal(bytes, j)
if err != nil {
return err
}
return nil
}
type ListUint []uint
// Value 接口,Value 返回 json value any -> string
func (j ListUint) Value() (driver.Value, error) {
return json.Marshal(j)
}
// Scan 接口,Scan 将 value 扫描至 Jsonb
func (j *ListUint) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
}
err := json.Unmarshal(bytes, j)
if err != nil {
return err
}
return nil
}
type ListInt []int
// Value 接口,Value 返回 json value any -> string
func (j ListInt) Value() (driver.Value, error) {
return json.Marshal(j)
}
// Scan 接口,Scan 将 value 扫描至 Jsonb
func (j *ListInt) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
}
err := json.Unmarshal(bytes, j)
if err != nil {
return err
}
return nil
}
type ListString []string
// Value 接口,Value 返回 json value any -> string
func (j ListString) Value() (driver.Value, error) {
return json.Marshal(j)
}
// Scan 接口,Scan 将 value 扫描至 Jsonb
func (j *ListString) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
}
err := json.Unmarshal(bytes, j)
if err != nil {
return err
}
return nil
}
type MapString map[string]string
// Value 接口,Value 返回 json value any -> string
func (j MapString) Value() (driver.Value, error) {
return json.Marshal(j)
}
// Scan 接口,Scan 将 value 扫描至 Jsonb
func (j *MapString) Scan(value interface{}) error {
bytes, ok := value.([]byte)
if !ok {
return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value))
}
err := json.Unmarshal(bytes, j)
if err != nil {
return err
}
return nil
}