diff --git a/advance/ctx/context.go b/advance/ctx/context.go deleted file mode 100644 index 97fb03830a9f2be72590857da5d38da9ca744cae..0000000000000000000000000000000000000000 --- a/advance/ctx/context.go +++ /dev/null @@ -1,160 +0,0 @@ -package ctx - -import ( - "context" - "fmt" - "golang.org/x/sync/errgroup" - "sync" - "sync/atomic" - "testing" - "time" -) - -type Cache interface { - Get(key string) (string, error) -} - -type OtherCache interface { - GetValue(ctx context.Context, key string) (any, error) -} - -// CacheAdapter 适配器强调的是不同接口之间进行适配 -// 装饰器强调的是添加额外的功能 -type CacheAdapter struct { - Cache -} - -func (c *CacheAdapter) GetValue(ctx context.Context, key string) (any, error) { - return c.Cache.Get(key) -} - -// 已有的,不是线程安全的 -type memoryMap struct { - // 如果你这样添加锁,那么就是一种侵入式的写法, - // 那么你就需要测试这个类 - // 而且有些时候,这个是第三方的依赖,你都改不了 - // lock sync.RWMutex - m map[string]string -} - -func (m *memoryMap) Get(key string) (string, error) { - return m.m[key], nil -} - -var s = &SafeCache{ - Cache: &memoryMap{}, -} - -// SafeCache 我要改造为线程安全的 -// 无侵入式地改造 -type SafeCache struct { - Cache - lock sync.RWMutex -} - -func (s *SafeCache) Get(key string) (string, error) { - s.lock.RLock() - defer s.lock.RUnlock() - return s.Cache.Get(key) -} - -// type valueCtx struct { -// context.Context -// vals map[any]any -// } - -// func TestSourceCode(t *testing.T) { -// ctx := context.WithCancel(context.Background()) -// } - -func TestErrgroup(t *testing.T) { - eg, ctx := errgroup.WithContext(context.Background()) - var result int64 = 0 - for i := 0; i < 10; i++ { - delta := i - eg.Go(func() error { - atomic.AddInt64(&result, int64(delta)) - return nil - }) - } - if err := eg.Wait(); err != nil { - t.Fatal(err) - } - ctx.Err() - fmt.Println(result) -} - -func TestBusinessTimeout(t *testing.T) { - ctx := context.Background() - timeoutCtx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - end := make(chan struct{}, 1) - go func() { - MyBusiness() - end <- struct{}{} - }() - ch := timeoutCtx.Done() - select { - case <-ch: - fmt.Println("timeout") - case <-end: - fmt.Println("business end") - } -} - -func MyBusiness() { - time.Sleep(500 * time.Millisecond) - fmt.Println("hello, world") -} - -func TestParentValueCtx(t *testing.T) { - ctx := context.Background() - childCtx := context.WithValue(ctx, "map", map[string]string{}) - ccChild := context.WithValue(childCtx, "key1", "value1") - m := ccChild.Value("map").(map[string]string) - m["key1"] = "val1" - val := childCtx.Value("key1") - fmt.Println(val) - val = childCtx.Value("map") - fmt.Println(val) -} - -func TestParentCtx(t *testing.T) { - ctx := context.Background() - dlCtx, cancel := context.WithDeadline(ctx, time.Now().Add(time.Minute)) - childCtx := context.WithValue(dlCtx, "key", 123) - cancel() - err := childCtx.Err() - fmt.Println(err) -} - -func TestContext(t *testing.T) { - ctx := context.Background() - valCtx := context.WithValue(ctx, "abc", 123) - val := valCtx.Value("abc") - fmt.Println(val) -} - -// func TestContext(t *testing.T) { -// ctx := context.Background() -// timeoutCtx, cancel := context.WithTimeout(ctx, time.Second) -// defer cancel() -// time.Sleep(2 * time.Second) -// err := timeoutCtx.Err() -// fmt.Println(err) -// } - -// func SomeBusiness() { -// ctx := context.TODO() -// Step1() -// } - -// -// func Step1(ctx context.Context) { -// var db *sql.DB -// db.ExecContext(ctx, "UPDATE XXXX", 1) -// } - -type A struct { - ctx context.Context -} diff --git a/advance/ctx/context_demo.go b/advance/ctx/context_demo.go deleted file mode 100644 index d061274b03b239015203164ce57105b7d6f87f45..0000000000000000000000000000000000000000 --- a/advance/ctx/context_demo.go +++ /dev/null @@ -1 +0,0 @@ -package ctx diff --git a/advance/ctx/context_test.go b/advance/ctx/context_test.go deleted file mode 100644 index 164e0af36015d26891fd278cd51117c94c396810..0000000000000000000000000000000000000000 --- a/advance/ctx/context_test.go +++ /dev/null @@ -1,67 +0,0 @@ -package ctx - -import ( - "context" - "fmt" - "testing" - "time" -) - -func TestContext(t *testing.T) { - ctx := context.Background() - parent := context.WithValue(ctx, "my key", "my value") - sub := context.WithValue(ctx, "my key", "my new value") - - fmt.Printf("%v \n", parent.Value("my key")) - fmt.Printf("%v \n", sub.Value("my key")) -} - -func TestContext_timeout(t *testing.T) { - bg := context.Background() - timeoutCtx, cancel1 := context.WithTimeout(bg, time.Second) - subCtx, cancel2 := context.WithTimeout(timeoutCtx, 3*time.Second) - go func() { - // 一秒钟之后就会过期,然后输出 timeout - <-subCtx.Done() - fmt.Printf("timout") - }() - - time.Sleep(2 * time.Second) - cancel2() - cancel1() -} - -func TestTimeoutExample(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - bsChan := make(chan struct{}) - go func() { - slowBusiness() - bsChan <- struct{}{} - }() - select { - case <-ctx.Done(): - fmt.Println("timeout") - case <-bsChan: - fmt.Println("business end") - } -} - -func slowBusiness() { - time.Sleep(2 * time.Second) -} - -func TestTimeoutTimeAfter(t *testing.T) { - bsChan := make(chan struct{}) - go func() { - slowBusiness() - bsChan <- struct{}{} - }() - - timer := time.AfterFunc(time.Second, func() { - fmt.Println("timeout") - }) - <-bsChan - fmt.Println("business end") - timer.Stop() -} diff --git a/advance/ctx/graceful_shutdown/main.go b/advance/ctx/graceful_shutdown/main.go index 1fdb6b099b04a705f7c6066952e63ffba6bf7d18..1ce71b5f0d153f09a10ec6eae43184aa17bc50ce 100644 --- a/advance/ctx/graceful_shutdown/main.go +++ b/advance/ctx/graceful_shutdown/main.go @@ -2,10 +2,11 @@ package main import ( "context" - "gitee.com/geektime-geekbang/geektime-go/advance/ctx/graceful_shutdown/service" "log" "net/http" "time" + + "gitee.com/geektime-geekbang/geektime-go/advance/ctx/graceful_shutdown/service" ) // 注意要从命令行启动,否则不同的 IDE 可能会吞掉关闭信号 @@ -15,7 +16,13 @@ func main() { _, _ = writer.Write([]byte("hello")) })) s2 := service.NewServer("admin", "localhost:8081") - app := service.NewApp([]*service.Server{s1, s2}, service.WithShutdownCallbacks(StoreCacheToDBCallback)) + app := service.NewApp( + []*service.Server{s1, s2}, + service.WithShutdownCallbacks(StoreCacheToDBCallback), + service.WithShutDownTimeout(service.DefaultShutdownTimeout), + service.WithWaitTime(service.DefaultWaitTime), + service.WithCbTimeout(service.DefaultCbTimeout), + ) app.StartAndServe() } @@ -26,6 +33,7 @@ func StoreCacheToDBCallback(ctx context.Context) { // 这里我们简单的睡一段时间来模拟 log.Printf("刷新缓存中……") time.Sleep(1 * time.Second) + done <- struct{}{} }() select { case <-ctx.Done(): diff --git a/advance/ctx/graceful_shutdown/service/const.go b/advance/ctx/graceful_shutdown/service/const.go new file mode 100644 index 0000000000000000000000000000000000000000..f16974ed51ddaeb513c2fe400032e63f837a2b39 --- /dev/null +++ b/advance/ctx/graceful_shutdown/service/const.go @@ -0,0 +1,27 @@ +package service + +import ( + "os" + "runtime" + "syscall" + "time" +) + +var Signals = map[string][]os.Signal{ + "darwin": {os.Interrupt, os.Kill, + syscall.SIGKILL, syscall.SIGSTOP, syscall.SIGHUP, syscall.SIGINT, syscall.SIGQUIT, + syscall.SIGILL, syscall.SIGABRT, syscall.SIGSYS, syscall.SIGTERM}, + "linux": {os.Interrupt, os.Kill, + syscall.SIGKILL, syscall.SIGSTOP, syscall.SIGHUP, + syscall.SIGINT, syscall.SIGQUIT, syscall.SIGILL, + syscall.SIGABRT, syscall.SIGSYS, syscall.SIGTERM}, + "windows": {os.Interrupt, os.Kill, + syscall.SIGKILL, syscall.SIGHUP, syscall.SIGINT, syscall.SIGQUIT, + syscall.SIGILL, syscall.SIGABRT, syscall.SIGTERM}, +}[runtime.GOOS] + +const ( + DefaultShutdownTimeout = time.Second * 30 + DefaultWaitTime = time.Second * 10 + DefaultCbTimeout = time.Second * 3 +) diff --git a/advance/ctx/graceful_shutdown/service/shutdown.go b/advance/ctx/graceful_shutdown/service/shutdown.go index 6cfafa42780a7f1b1d868dc3b33fbc51e2555dfc..212100d5b12411ff3692a65b8bc53c664347fa9c 100644 --- a/advance/ctx/graceful_shutdown/service/shutdown.go +++ b/advance/ctx/graceful_shutdown/service/shutdown.go @@ -4,10 +4,13 @@ import ( "context" "log" "net/http" + "os" + "os/signal" + "sync" "time" ) -// 典型的 Option 设计模式 +// Option 典型的 Option 设计模式 type Option func(*App) // ShutdownCallback 采用 context.Context 来控制超时,而不是用 time.After 是因为 @@ -15,12 +18,35 @@ type Option func(*App) // - 我们还希望用户知道,他的回调必须要在一定时间内处理完毕,而且他必须显式处理超时错误 type ShutdownCallback func(ctx context.Context) -// 你需要实现这个方法 +// WithShutdownCallbacks 你需要实现这个方法 func WithShutdownCallbacks(cbs ...ShutdownCallback) Option { - panic("implement me") + return func(app *App) { + app.cbs = cbs + } +} + +// WithShutDownTimeout 配置优雅退出超时时间 +func WithShutDownTimeout(d time.Duration) Option { + return func(app *App) { + app.shutdownTimeout = d + } +} + +// WithWaitTime 配置等待时间 +func WithWaitTime(d time.Duration) Option { + return func(app *App) { + app.waitTime = d + } } -// 这里我已经预先定义好了各种可配置字段 +// WithCbTimeout 配置回掉超时时间 +func WithCbTimeout(d time.Duration) Option { + return func(app *App) { + app.cbTimeout = d + } +} + +// App 这里我已经预先定义好了各种可配置字段 type App struct { servers []*Server @@ -37,7 +63,18 @@ type App struct { // NewApp 创建 App 实例,注意设置默认值,同时使用这些选项 func NewApp(servers []*Server, opts ...Option) *App { - panic("implement me") + app := &App{ + servers: servers, + shutdownTimeout: DefaultShutdownTimeout, + waitTime: DefaultWaitTime, + cbTimeout: DefaultCbTimeout, + } + + for _, opt := range opts { + opt(app) + } + + return app } // StartAndServe 你主要要实现这个方法 @@ -47,36 +84,83 @@ func (app *App) StartAndServe() { go func() { if err := srv.Start(); err != nil { if err == http.ErrServerClosed { - log.Printf("服务器%s已关闭", srv.name) + log.Printf("服务器 %s 已关闭\n", srv.name) } else { - log.Printf("服务器%s异常退出", srv.name) + log.Printf("服务器 %s 异常退出\n", srv.name) } - } }() } // 从这里开始优雅退出监听系统信号,强制退出以及超时强制退出。 // 优雅退出的具体步骤在 shutdown 里面实现 // 所以你需要在这里恰当的位置,调用 shutdown + quit := make(chan os.Signal, 1) + signal.Notify(quit, Signals...) + select { + case s := <-quit: + log.Printf("接收到信号: %s,服务器将执行优雅退出\n", s) + // 错误的,应该放在开了协程之后 + // app.shutdown() + go func() { + select { + case s := <-quit: + log.Printf("接受到第二次信号: %s,将强制退出", s) + os.Exit(1) + case <-time.After(app.shutdownTimeout): + log.Printf("服务器优雅退出时间超时,已经强制停止") + os.Exit(1) + } + }() + // 正确的 + app.shutdown() + } + + log.Println("已经停止服务器") } // shutdown 你要设计这里面的执行步骤。 func (app *App) shutdown() { log.Println("开始关闭应用,停止接收新请求") // 你需要在这里让所有的 server 拒绝新请求 + for _, s := range app.servers { + s.rejectReq() + } log.Println("等待正在执行请求完结") // 在这里等待一段时间 + time.Sleep(app.waitTime) + var wg sync.WaitGroup log.Println("开始关闭服务器") // 并发关闭服务器,同时要注意协调所有的 server 都关闭之后才能步入下一个阶段 + for _, s := range app.servers { + wg.Add(1) + go func(s *Server) { + err := s.stop() + if err != nil { + log.Println("服务器关闭错误: ", err) + } + wg.Done() + }(s) + } + wg.Wait() log.Println("开始执行自定义回调") + ctx := context.Background() // 并发执行回调,要注意协调所有的回调都执行完才会步入下一个阶段 + for _, cb := range app.cbs { + wg.Add(1) + go func(cb ShutdownCallback) { + cbContext, cancel := context.WithTimeout(ctx, app.cbTimeout) + defer cancel() + cb(cbContext) + wg.Done() + }(cb) + } + wg.Wait() // 释放资源 log.Println("开始释放资源") - app.close() } func (app *App) close() { @@ -131,6 +215,6 @@ func (s *Server) rejectReq() { } func (s *Server) stop() error { - log.Printf("服务器%s关闭中", s.name) + log.Printf("服务器 %s 关闭中", s.name) return s.srv.Shutdown(context.Background()) } diff --git a/advance/reflect/accessor.go b/advance/reflect/accessor.go deleted file mode 100644 index 30af5d2855b37db50540e180da2cb2718043de36..0000000000000000000000000000000000000000 --- a/advance/reflect/accessor.go +++ /dev/null @@ -1,41 +0,0 @@ -package reflect - -import ( - "errors" - "reflect" -) - -type ReflectAccessor struct { - val reflect.Value - typ reflect.Type -} - -func NewReflectAccessor(val any) (*ReflectAccessor, error) { - typ := reflect.TypeOf(val) - if typ.Kind() != reflect.Pointer || typ.Elem().Kind() != reflect.Struct { - return nil, errors.New("invalid entity") - } - return &ReflectAccessor{ - val: reflect.ValueOf(val).Elem(), - typ: typ.Elem(), - }, nil -} - -func (r *ReflectAccessor) Field(field string) (int, error) { - if _, ok := r.typ.FieldByName(field); !ok { - return 0, errors.New("非法字段") - } - return r.val.FieldByName(field).Interface().(int), nil -} - -func (r *ReflectAccessor) SetField(field string, val int) error { - if _, ok := r.typ.FieldByName(field); !ok { - return errors.New("非法字段") - } - fdVal := r.val.FieldByName(field) - if !fdVal.CanSet() { - return errors.New("无法设置新值的字段") - } - fdVal.Set(reflect.ValueOf(val)) - return nil -} diff --git a/advance/reflect/accessor_test.go b/advance/reflect/accessor_test.go deleted file mode 100644 index 121d6d0a800a77bbbb6701af0808ad0cf65c0b42..0000000000000000000000000000000000000000 --- a/advance/reflect/accessor_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package reflect - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestReflectAccessor_Field(t *testing.T) { - testCases := []struct { - name string - entity interface{} - field string - wantVal int - wantErr error - }{ - { - name: "normal case", - entity: &User{Age: 18}, - field: "Age", - wantVal: 18, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - accessor, err := NewReflectAccessor(tc.entity) - if err != nil { - assert.Equal(t, tc.wantErr, err) - return - } - val, err := accessor.Field(tc.field) - assert.Equal(t, tc.wantErr, err) - if err != nil { - return - } - assert.Equal(t, tc.wantVal, val) - }) - } -} - -func TestReflectAccessor_SetField(t *testing.T) { - testCases := []struct { - name string - entity *User - field string - newVal int - wantErr error - }{ - { - name: "normal case", - entity: &User{}, - field: "Age", - newVal: 18, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - accessor, err := NewReflectAccessor(tc.entity) - if err != nil { - assert.Equal(t, tc.wantErr, err) - return - } - err = accessor.SetField(tc.field, tc.newVal) - assert.Equal(t, tc.wantErr, err) - if err != nil { - return - } - assert.Equal(t, tc.newVal, tc.entity.Age) - }) - } -} - -type User struct { - Age int -} diff --git a/advance/reflect/fields.go b/advance/reflect/fields.go deleted file mode 100644 index 616074ff8b4143655b436bdc3369768698194b4f..0000000000000000000000000000000000000000 --- a/advance/reflect/fields.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2021 gotomicro -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package reflect - -import ( - "errors" - "reflect" -) - -// IterateFields 返回所有的字段名字 -// val 只能是结构体,或者结构体指针,可以是多重指针 -func IterateFields(input any) (map[string]any, error) { - typ := reflect.TypeOf(input) - val := reflect.ValueOf(input) - - // 处理指针,要拿到指针指向的东西 - // 这里我们综合考虑了多重指针的效果 - for typ.Kind() == reflect.Ptr { - typ = typ.Elem() - val = val.Elem() - } - - // 如果不是结构体,就返回 error - if typ.Kind() != reflect.Struct { - return nil, errors.New("非法类型") - } - - num := typ.NumField() - res := make(map[string]any, num) - for i := 0; i < num; i++ { - fd := typ.Field(i) - fdVal := val.Field(i) - if fd.IsExported() { - res[fd.Name] = fdVal.Interface() - } else { - // 为了演示效果,不公开字段我们用零值来填充 - res[fd.Name] = reflect.Zero(fd.Type).Interface() - } - } - return res, nil -} - -func SetField(entity any, field string, newVal any) error { - val := reflect.ValueOf(entity) - typ := val.Type() - if typ.Kind() != reflect.Ptr || typ.Elem().Kind() != reflect.Struct { - return errors.New("非法类型") - } - typ = typ.Elem() - val = val.Elem() - fd := val.FieldByName(field) - if _, found := typ.FieldByName(field); !found { - return errors.New("字段不存在") - } - if !fd.CanSet() { - return errors.New("不可修改字段") - } - fd.Set(reflect.ValueOf(newVal)) - return nil -} diff --git a/advance/reflect/fields_test.go b/advance/reflect/fields_test.go deleted file mode 100644 index 1c84f77b5ddfb9a9fba31666b200a15dd84954de..0000000000000000000000000000000000000000 --- a/advance/reflect/fields_test.go +++ /dev/null @@ -1,116 +0,0 @@ -package reflect - -import ( - "errors" - "gitbub.com/flycash/geekbang-middle-camp/advance/reflect/types" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestIterateFields(t *testing.T) { - up := &types.User{} - up2 := &up - testCases := []struct { - name string - input any - wantFields map[string]any - wantErr error - }{ - { - // 普通结构体 - name: "normal struct", - input: types.User{ - Name: "Tom", - // age: 18, - }, - wantFields: map[string]any{ - "Name": "Tom", - "age": 0, - }, - }, - { - // 指针 - name: "pointer", - input: &types.User{ - Name: "Tom", - }, - wantFields: map[string]any{ - "Name": "Tom", - "age": 0, - }, - }, - { - // 多重指针 - name: "multiple pointer", - input: up2, - wantFields: map[string]any{ - "Name": "", - "age": 0, - }, - }, - { - // 非法输入 - name: "slice", - input: []string{}, - wantErr: errors.New("非法类型"), - }, - { - // 非法指针输入 - name: "pointer to map", - input: &(map[string]string{}), - wantErr: errors.New("非法类型"), - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - res, err := IterateFields(tc.input) - assert.Equal(t, tc.wantErr, err) - if err != nil { - return - } - assert.Equal(t, tc.wantFields, res) - }) - } -} - -func TestSetField(t *testing.T) { - testCases := []struct { - name string - field string - newVal any - entity any - wantErr error - }{ - { - name: "struct", - entity: types.User{}, - field: "Name", - wantErr: errors.New("非法类型"), - }, - { - name: "private field", - entity: &types.User{}, - field: "age", - wantErr: errors.New("不可修改字段"), - }, - { - name: "invalid field", - entity: &types.User{}, - field: "invalid_field", - wantErr: errors.New("不可修改字段"), - }, - { - name: "pass", - entity: &types.User{}, - field: "Name", - newVal: "Tom", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := SetField(tc.entity, tc.field, tc.newVal) - assert.Equal(t, tc.wantErr, err) - }) - } -} diff --git a/advance/reflect/func.go b/advance/reflect/func.go deleted file mode 100644 index 588079d39d7a5f0e0079d6c75e7c2b9b2730aaec..0000000000000000000000000000000000000000 --- a/advance/reflect/func.go +++ /dev/null @@ -1,56 +0,0 @@ -package reflect - -import ( - "errors" - "reflect" -) - -// IterateFuncs 输出方法信息,并执行调用 -func IterateFuncs(val any) (map[string]*FuncInfo, error) { - typ := reflect.TypeOf(val) - if typ.Kind() != reflect.Struct && typ.Kind() != reflect.Ptr { - return nil, errors.New("非法类型") - } - num := typ.NumMethod() - result := make(map[string]*FuncInfo, num) - for i := 0; i < num; i++ { - f := typ.Method(i) - numIn := f.Type.NumIn() - ps := make([]reflect.Value, 0, f.Type.NumIn()) - // 第一个参数永远都是接收器,类似于 java 的 this 概念 - ps = append(ps, reflect.ValueOf(val)) - in := make([]reflect.Type, 0, f.Type.NumIn()) - for j := 0; j < numIn; j++ { - p := f.Type.In(j) - in = append(in, p) - if j > 0 { - ps = append(ps, reflect.Zero(p)) - } - } - // 调用结果 - ret := f.Func.Call(ps) - outNum := f.Type.NumOut() - out := make([]reflect.Type, 0, outNum) - res := make([]any, 0, outNum) - for k := 0; k < outNum; k++ { - out = append(out, f.Type.Out(k)) - res = append(res, ret[k].Interface()) - } - result[f.Name] = &FuncInfo{ - Name: f.Name, - In: in, - Out: out, - Result: res, - } - } - return result, nil -} - -type FuncInfo struct { - Name string - In []reflect.Type - Out []reflect.Type - - // 反射调用得到的结果 - Result []any -} diff --git a/advance/reflect/func_test.go b/advance/reflect/func_test.go deleted file mode 100644 index 48acb0484e09557d6a6f7209f62613f2e2c96e1c..0000000000000000000000000000000000000000 --- a/advance/reflect/func_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package reflect - -import ( - "gitbub.com/flycash/geekbang-middle-camp/advance/reflect/types" - "github.com/stretchr/testify/assert" - "reflect" - "testing" -) - -func TestIterateFuncs(t *testing.T) { - testCases := []struct { - name string - input any - wantRes map[string]*FuncInfo - wantErr error - }{ - { - // 普通结构体 - name: "normal struct", - input: types.User{}, - wantRes: map[string]*FuncInfo{ - "GetAge": { - Name: "GetAge", - In: []reflect.Type{reflect.TypeOf(types.User{})}, - Out: []reflect.Type{reflect.TypeOf(0)}, - Result: []any{0}, - }, - }, - }, - { - // 指针 - name: "pointer", - input: &types.User{}, - wantRes: map[string]*FuncInfo{ - "GetAge": { - Name: "GetAge", - In: []reflect.Type{reflect.TypeOf(&types.User{})}, - Out: []reflect.Type{reflect.TypeOf(0)}, - Result: []any{0}, - }, - "ChangeName": { - Name: "ChangeName", - In: []reflect.Type{reflect.TypeOf(&types.User{}), reflect.TypeOf("")}, - Out: []reflect.Type{}, - Result: []any{}, - }, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - res, err := IterateFuncs(tc.input) - assert.Equal(t, tc.wantErr, err) - if err != nil { - return - } - assert.Equal(t, tc.wantRes, res) - }) - } -} diff --git a/advance/reflect/iterate.go b/advance/reflect/iterate.go deleted file mode 100644 index 3db6d86a7687725537a2d535494e28fa3ce1af2e..0000000000000000000000000000000000000000 --- a/advance/reflect/iterate.go +++ /dev/null @@ -1,56 +0,0 @@ -package reflect - -import ( - "errors" - "reflect" -) - -// Iterate 迭代数组,切片,或者字符串 -func Iterate(input any) ([]any, error) { - val := reflect.ValueOf(input) - typ := val.Type() - kind := typ.Kind() - if kind != reflect.Array && kind != reflect.Slice && kind != reflect.String { - return nil, errors.New("非法类型") - } - res := make([]any, 0, val.Len()) - for i := 0; i < val.Len(); i++ { - ele := val.Index(i) - res = append(res, ele.Interface()) - } - return res, nil -} - -// IterateMapV1 返回键,值 -func IterateMapV1(input any) ([]any, []any, error) { - val := reflect.ValueOf(input) - if val.Kind() != reflect.Map { - return nil, nil, errors.New("非法类型") - } - l := val.Len() - keys := make([]any, 0, l) - values := make([]any, 0, l) - for _, k := range val.MapKeys() { - keys = append(keys, k.Interface()) - v := val.MapIndex(k) - values = append(values, v.Interface()) - } - return keys, values, nil -} - -// IterateMapV2 返回键,值 -func IterateMapV2(input any) ([]any, []any, error) { - val := reflect.ValueOf(input) - if val.Kind() != reflect.Map { - return nil, nil, errors.New("非法类型") - } - l := val.Len() - keys := make([]any, 0, l) - values := make([]any, 0, l) - itr := val.MapRange() - for itr.Next() { - keys = append(keys, itr.Key().Interface()) - values = append(values, itr.Value().Interface()) - } - return keys, values, nil -} diff --git a/advance/reflect/iterate_test.go b/advance/reflect/iterate_test.go deleted file mode 100644 index e8a95214bcc36d711bb0d241f0f479514e8dc27d..0000000000000000000000000000000000000000 --- a/advance/reflect/iterate_test.go +++ /dev/null @@ -1,91 +0,0 @@ -package reflect - -import ( - "errors" - "fmt" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestIterate(t *testing.T) { - testCases := []struct { - name string - input any - wantRes []any - wantErr error - }{ - { - name: "slice", - input: []int{1, 2, 3}, - wantRes: []any{1, 2, 3}, - }, - { - name: "array", - input: [5]int{1, 2, 3, 4, 5}, - wantRes: []any{1, 2, 3, 4, 5}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - res, err := Iterate(tc.input) - assert.Equal(t, tc.wantErr, err) - if err != nil { - return - } - assert.Equal(t, tc.wantRes, res) - }) - } -} - -func TestIterateMap(t *testing.T) { - testCases := []struct { - name string - input any - wantKeys []any - wantValues []any - wantErr error - }{ - { - name: "nil", - input: nil, - wantErr: errors.New("非法类型"), - }, - { - name: "happy case", - input: map[string]string{ - "a_k": "a_v", - }, - wantKeys: []any{"a_k"}, - wantValues: []any{"a_v"}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - keys, vals, err := IterateMapV1(tc.input) - assert.Equal(t, tc.wantErr, err) - if err != nil { - return - } - assert.Equal(t, tc.wantKeys, keys) - assert.Equal(t, tc.wantValues, vals) - - keys, vals, err = IterateMapV2(tc.input) - assert.Equal(t, tc.wantErr, err) - if err != nil { - return - } - assert.Equal(t, tc.wantKeys, keys) - assert.Equal(t, tc.wantValues, vals) - }) - } -} - -type UserService struct { - GetByIdV1 func() -} - -func (u *UserService) GetByIdV2() { - fmt.Println("aa") -} diff --git a/advance/reflect/types/user.go b/advance/reflect/types/user.go deleted file mode 100644 index 852445f58af012d9ce65136b0863e9dbc46f13cb..0000000000000000000000000000000000000000 --- a/advance/reflect/types/user.go +++ /dev/null @@ -1,22 +0,0 @@ -package types - -import "fmt" - -type User struct { - Name string - // 因为同属一个包,所以 age 还可以被测试访问到 - // 如果是不同包,就访问不到了 - age int -} - -func (u User) GetAge() int { - return u.age -} - -func (u *User) ChangeName(newName string) { - u.Name = newName -} - -func (u User) private() { - fmt.Println("private") -} diff --git a/advance/sync/channel_test.go b/advance/sync/channel_test.go deleted file mode 100644 index 49075891a24463e704ce2f548b0ad8dcc116d08a..0000000000000000000000000000000000000000 --- a/advance/sync/channel_test.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2021 gotomicro -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sync - -import ( - "fmt" - "testing" - "time" -) - -func TestChannelReceive(t *testing.T) { - ch := make(chan string, 1) - go func() { - data := <-ch - fmt.Printf("g1 receiver %s", data) - }() - - go func() { - data := <-ch - fmt.Printf("g2 receiver %s", data) - }() - ch <- "daming" - time.Sleep(3 * time.Second) -} diff --git a/advance/sync/demo/array_list.go b/advance/sync/demo/array_list.go deleted file mode 100644 index a8e47b6bee990fe2f99a59b7b568db8fc2802463..0000000000000000000000000000000000000000 --- a/advance/sync/demo/array_list.go +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2021 gotomicro -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package demo - -// ArrayList 基于切片的简单封装 -type ArrayList[T any] struct { - vals []T -} - -func NewArrayList[T any](cap int) *ArrayList[T] { - panic("implement me") -} - -// NewArrayListOf 直接使用 ts,而不会执行复制 -func NewArrayListOf[T any](ts []T) *ArrayList[T] { - return &ArrayList[T]{ - vals: ts, - } -} - -func (a *ArrayList[T]) Get(index int) (T, error) { - // TODO implement me - panic("implement me") -} - -func (a *ArrayList[T]) Append(t T) error { - // TODO implement me - panic("implement me") -} - -// Add 在ArrayList下标为index的位置插入一个元素 -// 当index等于ArrayList长度等同于append -func (a *ArrayList[T]) Add(index int, t T) error { - if index < 0 || index > len(a.vals) { - return newErrIndexOutOfRange(len(a.vals), index) - } - a.vals = append(a.vals, t) - copy(a.vals[index+1:], a.vals[index:]) - a.vals[index] = t - return nil -} - -func (a *ArrayList[T]) Set(index int, t T) error { - // TODO implement me - panic("implement me") -} - -func (a *ArrayList[T]) Delete(index int) (T, error) { - // TODO implement me - panic("implement me") -} - -func (a *ArrayList[T]) Len() int { - // TODO implement me - panic("implement me") -} - -func (a *ArrayList[T]) Cap() int { - return cap(a.vals) -} - -func (a *ArrayList[T]) Range(fn func(index int, t T) error) error { - for key, value := range a.vals { - e := fn(key, value) - if e != nil { - return e - } - } - return nil -} - -func (a *ArrayList[T]) AsSlice() []T { - slice := make([]T, len(a.vals)) - copy(slice, a.vals) - return slice -} diff --git a/advance/sync/demo/array_list_test.go b/advance/sync/demo/array_list_test.go deleted file mode 100644 index c9a31c88fd70abda71ceebfee0c24bf86ebf1b7b..0000000000000000000000000000000000000000 --- a/advance/sync/demo/array_list_test.go +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright 2021 gotomicro -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package demo - -import ( - "errors" - "fmt" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestArrayList_Add(t *testing.T) { - testCases := []struct { - name string - list *ArrayList[int] - index int - newVal int - wantSlice []int - wantErr error - }{ - // 仿照这个例子,继续添加测试 - // 你需要综合考虑下标的各种可能取值 - // 往两边增加,往中间加 - // 下标可能是负数,也可能超出你的长度 - { - name: "add num to index left", - list: NewArrayListOf[int]([]int{1, 2, 3}), - newVal: 100, - index: 0, - wantSlice: []int{100, 1, 2, 3}, - }, - { - name: "add num to index right", - list: NewArrayListOf[int]([]int{1, 2, 3}), - newVal: 100, - index: 3, - wantSlice: []int{1, 2, 3, 100}, - }, - { - name: "add num to index mid", - list: NewArrayListOf[int]([]int{1, 2, 3}), - newVal: 100, - index: 1, - wantSlice: []int{1, 100, 2, 3}, - }, - { - name: "add num to index -1", - list: NewArrayListOf[int]([]int{1, 2, 3}), - newVal: 100, - index: -1, - wantErr: fmt.Errorf("ekit: 下标超出范围,长度 %d, 下标 %d", 3, -1), - }, - { - name: "add num to index OutOfRange", - list: NewArrayListOf[int]([]int{1, 2, 3}), - newVal: 100, - index: 4, - wantErr: fmt.Errorf("ekit: 下标超出范围,长度 %d, 下标 %d", 3, 4), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - err := tc.list.Add(tc.index, tc.newVal) - assert.Equal(t, tc.wantErr, err) - // 因为返回了 error,所以我们不用继续往下比较了 - if err != nil { - return - } - assert.Equal(t, tc.wantSlice, tc.list.vals) - }) - } -} - -// -// func TestArrayList_Append(t *testing.T) { -// // 这个比较简单,只需要增加元素,然后判断一下 Append 之后是否符合预期 -// } - -func TestArrayList_Cap(t *testing.T) { - testCases := []struct { - name string - expectCap int - list *ArrayList[int] - }{ - { - name: "与实际容量相等", - expectCap: 5, - list: &ArrayList[int]{ - vals: make([]int, 5), - }, - }, - { - name: "用户传入nil", - expectCap: 0, - list: &ArrayList[int]{ - vals: nil, - }, - }, - } - for _, testCase := range testCases { - t.Run(testCase.name, func(t *testing.T) { - actual := testCase.list.Cap() - assert.Equal(t, testCase.expectCap, actual) - }) - } -} - -func BenchmarkArrayList_Cap(b *testing.B) { - list := &ArrayList[int]{ - vals: make([]int, 0), - } - - b.Run("Cap", func(b *testing.B) { - for i := 0; i < b.N; i++ { - list.Cap() - } - }) - - b.Run("Runtime cap", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = cap(list.vals) - } - }) -} - -// func TestArrayList_Append(t *testing.T) { -// // 这个比较简单,只需要增加元素,然后判断一下 Append 之后是否符合预期 -// } - -// func TestArrayList_Delete(t *testing.T) { -// testCases := []struct { -// name string -// list *ArrayList[int] -// index int -// wantSlice []int -// wantVal int -// wantErr error -// }{ -// // 仿照这个例子,继续添加测试 -// // 你需要综合考虑下标的各种可能取值 -// // 往两边增加,往中间加 -// // 下标可能是负数,也可能超出你的长度 -// { -// name: "index 0", -// list: NewArrayListOf[int]([]int{123, 100}), -// index: 0, -// wantSlice: []int{100}, -// wantVal: 123, -// }, -// } -// -// for _, tc := range testCases { -// t.Run(tc.name, func(t *testing.T) { -// val, err := tc.list.Delete(tc.index) -// assert.Equal(t, tc.wantErr, err) -// // 因为返回了 error,所以我们不用继续往下比较了 -// if err != nil { -// return -// } -// assert.Equal(t, tc.wantSlice, tc.list.vals) -// assert.Equal(t, tc.wantVal, val) -// }) -// } -// } -// -// func TestArrayList_Get(t *testing.T) { -// testCases := []struct { -// name string -// list *ArrayList[int] -// index int -// wantVal int -// wantErr error -// }{ -// // 仿照这个例子,继续添加测试 -// // 你需要综合考虑下标的各种可能取值 -// // 往两边增加,往中间加 -// // 下标可能是负数,也可能超出你的长度 -// { -// name: "index 0", -// list: NewArrayListOf[int]([]int{123, 100}), -// index: 0, -// wantVal: 123, -// }, -// } -// -// for _, tc := range testCases { -// t.Run(tc.name, func(t *testing.T) { -// val, err := tc.list.Get(tc.index) -// assert.Equal(t, tc.wantErr, err) -// // 因为返回了 error,所以我们不用继续往下比较了 -// if err != nil { -// return -// } -// assert.Equal(t, tc.wantVal, val) -// }) -// } -// } -func TestArrayList_Range(t *testing.T) { - // 设计两个测试用例,用求和来作为场景 - // 一个测试用例是计算全部元素的和 - // 一个测试用例是计算元素的和,如果遇到了第一个负数,那么就中断返回 - // 测试最终断言求的和是否符合预期 - testCases := []struct { - name string - list *ArrayList[int] - index int - wantVal int - wantErr error - }{ - { - name: "计算全部元素的和", - list: &ArrayList[int]{ - vals: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - }, - wantVal: 55, - wantErr: nil, - }, - { - name: "测试中断", - list: &ArrayList[int]{ - vals: []int{1, 2, 3, 4, -5, 6, 7, 8, -9, 10}, - }, - wantVal: 41, - wantErr: errors.New("index 4 is error"), - }, - { - name: "测试数组为nil", - list: &ArrayList[int]{ - vals: nil, - }, - wantVal: 0, - wantErr: nil, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - result := 0 - err := tc.list.Range(func(index int, num int) error { - if num < 0 { - return fmt.Errorf("index %d is error", index) - } - result += num - return nil - }) - - assert.Equal(t, tc.wantErr, err) - if err != nil { - return - } - assert.Equal(t, tc.wantVal, result) - }) - } -} - -// -// func TestArrayList_Len(t *testing.T) { -// -// } -// -// func TestArrayList_Set(t *testing.T) { -// -// } -// - -func TestArrayList_AsSlice(t *testing.T) { - vals := []int{1, 2, 3} - a := NewArrayListOf[int](vals) - slice := a.AsSlice() - // 内容相同 - assert.Equal(t, slice, vals) - aAddr := fmt.Sprintf("%p", vals) - sliceAddr := fmt.Sprintf("%p", slice) - // 但是地址不同,也就是意味着 slice 必须是一个新创建的 - assert.NotEqual(t, aAddr, sliceAddr) -} - -// -// // 为其它所有的公开方法都加上例子 -// func ExampleArrayList_Add() { -// list := NewArrayListOf[int]([]int{1, 2, 3}) -// _ = list.Add(0, 9) -// fmt.Println(list.AsSlice()) -// // Output: -// // [9 1 2 3] -// } diff --git a/advance/sync/demo/cas.go b/advance/sync/demo/cas.go deleted file mode 100644 index 8ee85a90898453247b889c6efdf54c8981862f15..0000000000000000000000000000000000000000 --- a/advance/sync/demo/cas.go +++ /dev/null @@ -1,12 +0,0 @@ -package demo - -type Lock struct { - state int -} - -// compare and swap -func (l *Lock) CAS(oldValue int, newValue int) { - if l.state == oldValue { - l.state = newValue - } -} diff --git a/advance/sync/demo/channel_demo_test.go b/advance/sync/demo/channel_demo_test.go deleted file mode 100644 index 109e1d450d5adbebf257bc27ee62b8bd24d397cf..0000000000000000000000000000000000000000 --- a/advance/sync/demo/channel_demo_test.go +++ /dev/null @@ -1,96 +0,0 @@ -package demo - -import ( - "fmt" - "testing" - "time" -) - -func TestChannel(t *testing.T) { - ch := make(chan string, 4) - go func() { - str := <-ch - fmt.Println(str) - }() - go func() { - str := <-ch - fmt.Println(str) - }() - go func() { - str := <-ch - fmt.Println(str) - }() - - ch <- "hello" - ch <- "hello" - time.Sleep(time.Second) -} - -func TestBroker(t *testing.T) { - b := &Broker{ - consumers: make([]*Consumer, 0, 10), - } - c1 := &Consumer{ - ch: make(chan string, 1), - } - c2 := &Consumer{ - ch: make(chan string, 1), - } - b.Subscribe(c1) - b.Subscribe(c2) - - b.Produce("hello") - fmt.Println(<-c1.ch) - fmt.Println(<-c2.ch) -} - -type Broker struct { - consumers []*Consumer -} - -func (b *Broker) Produce(msg string) { - for _, c := range b.consumers { - c.ch <- msg - } -} - -func (b *Broker) Subscribe(c *Consumer) { - b.consumers = append(b.consumers, c) -} - -type Consumer struct { - ch chan string -} - -type Broker1 struct { - ch chan string - consumers []func(s string) -} - -func (b *Broker1) Produce(msg string) { - b.ch <- msg -} - -func (b *Broker1) Subscribe(consume func(s string)) { - b.consumers = append(b.consumers, consume) -} - -func (b *Broker1) Start() { - go func() { - s := <-b.ch - for _, c := range b.consumers { - c(s) - } - }() -} - -func NewBroker1() *Broker1 { - b := &Broker1{ch: make(chan string, 10), consumers: make([]func(s string), 0, 10)} - go func() { - s := <-b.ch - for _, c := range b.consumers { - c(s) - } - }() - return b -} diff --git a/advance/sync/demo/class_demo_sync.go b/advance/sync/demo/class_demo_sync.go deleted file mode 100644 index 1a352265fcd6e2b24a24492b25ac642ce80445de..0000000000000000000000000000000000000000 --- a/advance/sync/demo/class_demo_sync.go +++ /dev/null @@ -1,91 +0,0 @@ -package demo - -import "sync" - -var PublicResource map[string]string -var PublicLock sync.RWMutex - -var privateResource map[string]string -var privateLock sync.RWMutex - -func NewFeature() { - privateLock.Lock() - defer privateLock.Unlock() - privateResource["a"] = "b" -} - -var safeResourceInstance safeResource - -type safeResource struct { - resource map[string]string - lock sync.RWMutex -} - -func (s *safeResource) Add(key string, value string) { - s.lock.Lock() - defer s.lock.RUnlock() - s.resource[key] = value -} - -type SafeMap[K comparable, V any] struct { - values map[K]V - lock sync.RWMutex -} - -// 已经有 key,返回对应的值,然后 loaded = true -// 没有,则放进去,返回 loaded false -// goroutine 1 => ("key1", 1) -// goroutine 2 => ("key1", 2) - -func (s *SafeMap[K, V]) LoadOrStoreV3(key K, newValue V) (V, bool) { - s.lock.RLock() - oldVal, ok := s.values[key] - s.lock.RUnlock() - if ok { - return oldVal, true - } - s.lock.Lock() - defer s.lock.Unlock() - oldVal, ok = s.values[key] - if ok { - return oldVal, true - } - // goroutine1 先进来,那么这里就会变成 key1 => 1 - // goroutine2 进来,那么这里就会变成 key1 => 2 - s.values[key] = newValue - return newValue, false -} - -func (s *SafeMap[K, V]) LoadOrStoreV2(key K, newValue V) (V, bool) { - s.lock.RLock() - oldVal, ok := s.values[key] - s.lock.RUnlock() - if ok { - return oldVal, true - } - s.lock.Lock() - defer s.lock.Unlock() - // goroutine1 先进来,那么这里就会变成 key1 => 1 - // goroutine2 进来,那么这里就会变成 key1 => 2 - s.values[key] = newValue - return newValue, false -} - -func (s *SafeMap[K, V]) LoadOrStoreV1(key K, newValue V) (V, bool) { - s.lock.RLock() - oldVal, ok := s.values[key] - defer s.lock.RUnlock() - if ok { - return oldVal, true - } - s.lock.Lock() - defer s.lock.Unlock() - oldVal, ok = s.values[key] - if ok { - return oldVal, true - } - // goroutine1 先进来,那么这里就会变成 key1 => 1 - // goroutine2 进来,那么这里就会变成 key1 => 2 - s.values[key] = newValue - return newValue, false -} diff --git a/advance/sync/demo/class_demo_sync_test.go b/advance/sync/demo/class_demo_sync_test.go deleted file mode 100644 index 06ee3439c27e293e5ba482056d3cb2427073b39f..0000000000000000000000000000000000000000 --- a/advance/sync/demo/class_demo_sync_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package demo - -import ( - "fmt" - "testing" - "time" -) - -func TestDeferRLock(t *testing.T) { - sm := SafeMap[string, string]{ - values: make(map[string]string, 4), - } - sm.LoadOrStoreV1("a", "b") - fmt.Println("hello") -} - -func TestOverride(t *testing.T) { - sm := SafeMap[string, string]{ - values: make(map[string]string, 4), - } - go func() { - time.Sleep(time.Second) - sm.LoadOrStoreV2("a", "b") - }() - - go func() { - time.Sleep(time.Second) - sm.LoadOrStoreV2("a", "c") - }() - - go func() { - time.Sleep(time.Second) - sm.LoadOrStoreV2("a", "d") - }() - - go func() { - time.Sleep(time.Second) - sm.LoadOrStoreV2("a", "e") - }() - time.Sleep(time.Second) - fmt.Println("hello") -} diff --git a/advance/sync/demo/errors.go b/advance/sync/demo/errors.go deleted file mode 100644 index 434bd830a9d0770667e9968d067cea918fe846c7..0000000000000000000000000000000000000000 --- a/advance/sync/demo/errors.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2021 gotomicro -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package demo - -import "fmt" - -// newErrIndexOutOfRange 创建一个代表 -func newErrIndexOutOfRange(length int, index int) error { - return fmt.Errorf("ekit: 下标超出范围,长度 %d, 下标 %d", length, index) -} diff --git a/advance/sync/demo/my_pool.go b/advance/sync/demo/my_pool.go deleted file mode 100644 index c2d2b5df627a213b02b8162cee598acf03edd286..0000000000000000000000000000000000000000 --- a/advance/sync/demo/my_pool.go +++ /dev/null @@ -1,25 +0,0 @@ -package demo - -import ( - "sync" - "unsafe" -) - -type MyPool struct { - p sync.Pool - maxCnt int32 - cnt int32 -} - -func (p *MyPool) Get() any { - return p.p.Get() -} - -func (p *MyPool) Put(val any) { - // 大对象不放回去 - if unsafe.Sizeof(val) > 1024 { - return - } - - p.p.Put(val) -} diff --git a/advance/sync/demo/once_demo.go b/advance/sync/demo/once_demo.go deleted file mode 100644 index b1438464ef07342dc5d7c2f9759dd1556370ab47..0000000000000000000000000000000000000000 --- a/advance/sync/demo/once_demo.go +++ /dev/null @@ -1,21 +0,0 @@ -package demo - -import ( - "fmt" - "sync" -) - -type OnceClose struct { - close sync.Once -} - -func (o *OnceClose) Close() error { - o.close.Do(func() { - fmt.Println("close") - }) - return nil -} - -func init() { - // 在这里的动作,肯定执行一次 -} diff --git a/advance/sync/demo/once_demo_test.go b/advance/sync/demo/once_demo_test.go deleted file mode 100644 index 839f6584e6a0991893b2fbc1fb346b0e7de9f492..0000000000000000000000000000000000000000 --- a/advance/sync/demo/once_demo_test.go +++ /dev/null @@ -1,12 +0,0 @@ -package demo - -import ( - "testing" -) - -func TestOnceClose_Close(t *testing.T) { - o := &OnceClose{} - for i := 0; i < 100; i++ { - o.Close() - } -} diff --git a/advance/sync/demo/pool_demo.go b/advance/sync/demo/pool_demo.go deleted file mode 100644 index 25a1a2e6f801d8e0af7a377835c81f78415314e2..0000000000000000000000000000000000000000 --- a/advance/sync/demo/pool_demo.go +++ /dev/null @@ -1,21 +0,0 @@ -package demo - -import ( - "fmt" - "sync" -) - -type MyCache struct { - pool sync.Pool -} - -func NewMyCache() *MyCache { - return &MyCache{ - pool: sync.Pool{ - New: func() any { - fmt.Println("hhh, new") - return []byte{} - }, - }, - } -} diff --git a/advance/sync/demo/pool_demo_test.go b/advance/sync/demo/pool_demo_test.go deleted file mode 100644 index 93fc12ba939efa70f24032b0b22a943b990f8775..0000000000000000000000000000000000000000 --- a/advance/sync/demo/pool_demo_test.go +++ /dev/null @@ -1,35 +0,0 @@ -package demo - -import ( - "fmt" - "sync" - "testing" -) - -func TestPool(t *testing.T) { - pool := sync.Pool{ - New: func() any { - return &User{} - }, - } - u1 := pool.Get().(*User) - u1.ID = 12 - u1.Name = "Tom" - // 一通操作 - // 放回去之前要先重置掉 - u1.Reset() - pool.Put(u1) - - u2 := pool.Get().(*User) - fmt.Println(u2) -} - -type User struct { - ID uint64 - Name string -} - -func (u *User) Reset() { - u.ID = 0 - u.Name = "" -} diff --git a/advance/sync/demo/safe_list.go b/advance/sync/demo/safe_list.go deleted file mode 100644 index 336c0e7c5b88e9169821376289ca0d7868dbbe55..0000000000000000000000000000000000000000 --- a/advance/sync/demo/safe_list.go +++ /dev/null @@ -1,20 +0,0 @@ -package demo - -import "sync" - -type SafeList[T any] struct { - List[T] - lock sync.RWMutex -} - -func (s *SafeList[T]) Get(index int) (T, error) { - s.lock.RLock() - defer s.lock.RUnlock() - return s.List.Get(index) -} - -func (s *SafeList[T]) Append(t T) error { - s.lock.Lock() - defer s.lock.Unlock() - return s.List.Append(t) -} diff --git a/advance/sync/demo/task_pool_demo.go b/advance/sync/demo/task_pool_demo.go deleted file mode 100644 index f6eb4c8eea79a197a685c6774f1b574e24841c93..0000000000000000000000000000000000000000 --- a/advance/sync/demo/task_pool_demo.go +++ /dev/null @@ -1,77 +0,0 @@ -package demo - -type TaskPool struct { - ch chan struct{} -} - -func NewTaskPool(limit int) *TaskPool { - t := &TaskPool{ - ch: make(chan struct{}, limit), - } - // 提前准备好了令牌 - for i := 0; i < limit; i++ { - t.ch <- struct{}{} - } - return t -} - -func (t *TaskPool) Do(f func()) { - token := <-t.ch - // 异步执行 - go func() { - f() - t.ch <- token - }() - - // 同步执行 - // f() - // t.ch <- token -} - -type TaskPoolWithCache struct { - cache chan func() -} - -func NewTaskPoolWithCache(limit int, cacheSize int) *TaskPoolWithCache { - t := &TaskPoolWithCache{ - cache: make(chan func(), cacheSize), - } - // 直接把 goroutine 开好 - for i := 0; i < limit; i++ { - go func() { - for { - // 在 goroutine 里面不断尝试从 cache 里面拿到任务 - select { - case task, ok := <-t.cache: - if !ok { - return - } - task() - } - } - }() - } - return t -} - -func (t *TaskPoolWithCache) Do(f func()) { - t.cache <- f -} - -// 显式控制生命周期 -// func (t *TaskPoolWithCache) Start() { -// for i := 0; i < t.limit; i++ { -// go func() { -// for { -// // 在 goroutine 里面不断尝试从 cache 里面拿到任务 -// select { -// case task, ok := <-t.cache: -// if !ok { -// return -// } -// task() -// } -// } -// }() -// } -// } diff --git a/advance/sync/demo/task_pool_demo_test.go b/advance/sync/demo/task_pool_demo_test.go deleted file mode 100644 index 5b07ab024e4890bc4270c2063fd34292e68b64d3..0000000000000000000000000000000000000000 --- a/advance/sync/demo/task_pool_demo_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package demo - -import ( - "fmt" - "testing" - "time" -) - -func TestTaskPool_Do(t1 *testing.T) { - tp := NewTaskPool(2) - tp.Do(func() { - time.Sleep(time.Second) - fmt.Println("task1") - }) - - tp.Do(func() { - time.Sleep(time.Second) - fmt.Println("task2") - }) - - tp.Do(func() { - MyTask(1, "13") - }) -} - -func TestTaskPoolWithCache_Do(t1 *testing.T) { - tp := NewTaskPoolWithCache(2, 10) - tp.Do(func() { - time.Sleep(time.Second) - fmt.Println("task1") - }) - - tp.Do(func() { - time.Sleep(time.Second) - fmt.Println("task2") - }) - - id := 1 - name := "Tom" - tp.Do(func() { - MyTask(id, name) - }) - - time.Sleep(2 * time.Second) -} - -func MyTask(a int, b string) { - // -} diff --git a/advance/sync/demo/types.go b/advance/sync/demo/types.go deleted file mode 100644 index 543caaa2b7d3eef2af67efc5a0eafdee3e8eb54e..0000000000000000000000000000000000000000 --- a/advance/sync/demo/types.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2021 gotomicro -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package demo - -// List 接口 -// 该接口只定义清楚各个方法的行为和表现 -type List[T any] interface { - // Get 返回对应下标的元素, - // 在下标超出范围的情况下,返回错误 - Get(index int) (T, error) - // Append 在末尾追加元素 - Append(t T) error - // Add 在特定下标处增加一个新元素 - // 如果下标超出范围,应该返回错误 - Add(index int, t T) error - // Set 重置 index 位置的值 - // 如果下标超出范围,应该返回错误 - Set(index int, t T) error - // Delete 删除目标元素的位置,并且返回该位置的值 - // 如果 index 超出下标,应该返回错误 - Delete(index int) (T, error) - // Len 返回长度 - Len() int - // Cap 返回容量 - Cap() int - // Range 遍历 List 的所有元素 - Range(fn func(index int, t T) error) error - // AsSlice 将 List 转化为一个切片 - // 不允许返回nil,在没有元素的情况下, - // 必须返回一个长度和容量都为 0 的切片 - // AsSlice 每次调用都必须返回一个全新的切片 - AsSlice() []T -} diff --git a/advance/sync/double_check.go b/advance/sync/double_check.go deleted file mode 100644 index ae7b97699ba7fd282962fcd2b98308470b67977c..0000000000000000000000000000000000000000 --- a/advance/sync/double_check.go +++ /dev/null @@ -1,77 +0,0 @@ -//go:build answer - -package sync - -import "sync" - -type SafeMap[K comparable, V any] struct { - m map[K]V - mutex sync.RWMutex -} - -// LoadOrStore loaded 代表是返回老的对象,还是返回了新的对象 -func (s *SafeMap[K, V]) LoadOrStore(key K, - newVale V) (val V, loaded bool) { - s.mutex.RLock() - val, ok := s.m[key] - s.mutex.RUnlock() - if ok { - return val, true - } - s.mutex.Lock() - defer s.mutex.Unlock() - val, ok = s.m[key] - if ok { - return val, true - } - s.m[key] = newVale - return newVale, false -} - -type valProvider[V any] func() V - -func (s *SafeMap[K, V]) LoadOrStoreHeavy(key K, p valProvider[V]) (val interface{}, loaded bool) { - s.mutex.RLock() - val, ok := s.m[key] - s.mutex.RUnlock() - if ok { - return val, true - } - s.mutex.Lock() - defer s.mutex.Unlock() - val, ok = s.m[key] - if ok { - return val, true - } - newVale := p() - s.m[key] = newVale - return newVale, false -} - -func (s *SafeMap[K, V]) CheckAndDoSomething() { - s.mutex.Lock() - // check and do something - s.mutex.Unlock() -} - -func (s *SafeMap[K, V]) CheckAndDoSomething1() { - s.mutex.RLock() - // check 第一次检查 - s.mutex.RUnlock() - - s.mutex.Lock() - // check and doSomething - defer s.mutex.Unlock() -} - -type Counter struct { - i int -} - -func (c *Counter) Incr() { - c.i++ -} - -func (c *Counter) Get() int { - return c.i -} \ No newline at end of file diff --git a/advance/sync/mutex.go b/advance/sync/mutex.go deleted file mode 100644 index 58187f4d58ddec5922e64a361707e4a3790054a6..0000000000000000000000000000000000000000 --- a/advance/sync/mutex.go +++ /dev/null @@ -1,134 +0,0 @@ - -package sync - -import ( - "sync" -) - -// PublicResource 你永远不知道你的用户拿了它会干啥 -// 他即便不用 PublicResourceLock 你也毫无办法 -// 如果你用这个resource,一定要用锁 -var PublicResource interface{} -var PublicResourceLock sync.Mutex - -// privateResource 要好一点,祈祷你的同事会来看你的注释,知道要用锁 -// 很多库都是这么写的,我也写了很多类似的代码=。= -var privateResource interface{} -var privateResourceLock sync.Mutex - -// safeResource 很棒,所有的期望对资源的操作都只能通过定义在上 safeResource 上的方法来进行 -type safeResource struct { - resource interface{} - lock sync.Mutex -} - -func (s *safeResource) DoSomethingToResource() { - s.lock.Lock() - defer s.lock.Unlock() -} - -var l = sync.RWMutex{} - -func RecursiveA() { - l.Lock() - defer l.Unlock() - RecursiveB() -} - -func RecursiveB() { - RecursiveC() -} - -func RecursiveC() { - l.Lock() - defer l.Unlock() - RecursiveA() -} - -// 锁的伪代码 -// type Lock struct { -// state int -// } -// -// func (l *Lock) Lock() { -// -// i = 0 -// for locked = CAS(UN_LOCK, LOCKED); !locked && i < 10 { -// i ++ -// } -// -// if locked { -// return -// } -// -// // 将自己的线程加入阻塞队列 -// enqueue() -// } - -type List[T any] interface { - Get(index int) T - Set(index int, t T) - DeleteAt(index int) T - Append(t T) -} - -type ArrayList[T any] struct { - vals []T -} - -func (a *ArrayList[T]) Get(index int) T { - return a.vals[index] -} - -func (a *ArrayList[T]) Set(index int, t T) { - if index >= len(a.vals) || index < 0 { - panic("index 超出范围") - } - a.vals[index] = t -} - -func (a *ArrayList[T]) DeleteAt(index int) T { - if index >= len(a.vals) || index < 0 { - panic("index 超出范围") - } - res := a.vals[index] - a.vals = append(a.vals[:index], a.vals[index+1:]...) - return res -} - -func (a *ArrayList[T]) Append(t T) { - a.vals = append(a.vals, t) -} - -func NewArrayList[T any](initCap int) *ArrayList[T] { - return &ArrayList[T]{vals: make([]T, 0, initCap)} -} - -type SafeListDecorator[T any] struct { - l List[T] - mutex sync.RWMutex -} - -func (s *SafeListDecorator[T]) Get(index int) T { - s.mutex.RLock() - defer s.mutex.RUnlock() - return s.l.Get(index) -} -func (s *SafeListDecorator[T]) Set(index int, t T) { - s.mutex.Lock() - defer s.mutex.Unlock() - s.l.Set(index, t) -} -func (s *SafeListDecorator[T]) DeleteAt(index int) T { - s.mutex.Lock() - defer s.mutex.Unlock() - return s.l.DeleteAt(index) -} -func (s *SafeListDecorator[T]) Append(t T) { - s.mutex.Lock() - defer s.mutex.Unlock() - s.l.Append(t) -} - - - diff --git a/advance/sync/mutex_demo.go b/advance/sync/mutex_demo.go deleted file mode 100644 index fd61932b1add9705c792dd985a4530881e370467..0000000000000000000000000000000000000000 --- a/advance/sync/mutex_demo.go +++ /dev/null @@ -1,63 +0,0 @@ - -package sync - -import "sync" - -type SafeMap[K comparable, V any] struct { - m map[K]V - mutex sync.RWMutex -} - -// LoadOrStore loaded 代表是返回老的对象,还是返回了新的对象 -// g1 (key1, 123) g2 (key1, 456) -func (s *SafeMap[K, V]) LoadOrStore(key K, newVal V) (val V, loaded bool) { - oldVal, ok := s.get(key) - if ok { - return oldVal, true - } - s.mutex.Lock() - defer s.mutex.Unlock() - oldVal, ok = s.m[key] - if ok { - return oldVal, true - } - s.m[key]= newVal - return newVal, false -} - -func (s *SafeMap[K, V]) get(key K) (V, bool){ - s.mutex.RLock() - defer s.mutex.RUnlock() - oldVal, ok := s.m[key] - return oldVal, ok -} -type ConcurrentArrayList[T any] struct { - mutex sync.RWMutex - vals []T -} - -func NewConcurrentArrayList[T any](initCap int) *ConcurrentArrayList[T]{ - return &ConcurrentArrayList[T]{ - vals: make([]T, 0, initCap), - } -} - -func (c *ConcurrentArrayList[T]) Get(index int) T { - c.mutex.RLock() - defer c.mutex.RUnlock() - res := c.vals[index] - - return res -} - -func (c *ConcurrentArrayList[T]) DeleteAt(index int) T { - c.mutex.Lock() - defer c.mutex.Unlock() - res := c.vals[index] - c.vals = append(c.vals[:index], c.vals[index+1:]...) - return res -} - -func (c *ConcurrentArrayList[T]) Append(val T) { - -} \ No newline at end of file diff --git a/advance/sync/mutex_test.go b/advance/sync/mutex_test.go deleted file mode 100644 index 469b6a1edb69c1ac6790047c18ea7232b491b67c..0000000000000000000000000000000000000000 --- a/advance/sync/mutex_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package sync - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestArrayList_DeleteAt(t *testing.T) { - testCases := []struct{ - name string - index int - input []int - wantVals []int - } { - { - // 删除第一个 - name: "first", - index: 0, - input: []int{1, 2, 3}, - wantVals: []int{2, 3}, - }, - { - // 删除最后一个 - name: "last", - index: 2, - input: []int{1, 2, 3}, - wantVals: []int{1, 2}, - }, - { - // 删除中间一个 - name:"middle", - index: 2, - input: []int{1, 2, 3, 4, 5}, - wantVals: []int{1, 2, 4, 5}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - a := NewArrayList[int](12) - a.vals = tc.input - _ = a.DeleteAt(tc.index) - assert.Equal(t, tc.wantVals, a.vals) - }) - } -} diff --git a/advance/sync/pool_test.go b/advance/sync/pool_test.go deleted file mode 100644 index 00abf53b630f9c2bf045abc97f1deda8821dc325..0000000000000000000000000000000000000000 --- a/advance/sync/pool_test.go +++ /dev/null @@ -1,20 +0,0 @@ -package sync - -import ( - "sync" - "testing" -) - -func TestPool(t *testing.T) { - p := sync.Pool{ - New: func() interface{} { - // 创建函数,sync.Pool 会回调 - return nil - }, - } - - obj := p.Get() - // 在这里使用取出来的对象 - // 用完再还回去 - p.Put(obj) -} diff --git a/advance/sync/wait_group_test.go b/advance/sync/wait_group_test.go deleted file mode 100644 index 2a9587650cdd7d6424e856caa7c4d1688d0c9b40..0000000000000000000000000000000000000000 --- a/advance/sync/wait_group_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package sync - -import ( - "fmt" - "golang.org/x/sync/errgroup" - "sync" - "sync/atomic" - "testing" -) - -func TestWaitGroup(t *testing.T) { - wg := sync.WaitGroup{} - var result int64 = 0 - for i := 0; i < 10; i++ { - wg.Add(1) - go func(delta int) { - atomic.AddInt64(&result, int64(delta)) - defer wg.Done() - }(i) - } - - wg.Wait() - fmt.Println(result) -} - -func TestErrgroup(t *testing.T) { - eg := errgroup.Group{} - var result int64 = 0 - for i := 0; i < 10; i++ { - delta := i - eg.Go(func() error { - atomic.AddInt64(&result, int64(delta)) - return nil - }) - } - if err := eg.Wait(); err != nil { - t.Fatal(err) - } - fmt.Println(result) -} diff --git a/advance/unsafe/layout.go b/advance/unsafe/layout.go deleted file mode 100644 index 40466351fcd4e83dd67a974b3e786ac40b6c943d..0000000000000000000000000000000000000000 --- a/advance/unsafe/layout.go +++ /dev/null @@ -1,17 +0,0 @@ -package unsafe - -import ( - "fmt" - "reflect" -) - -// PrintFieldOffset 用来打印字段偏移量 -// 用于研究内存布局 -// 只接受结构体作为输入 -func PrintFieldOffset(entity any) { - typ := reflect.TypeOf(entity) - for i := 0; i < typ.NumField(); i++ { - fd := typ.Field(i) - fmt.Printf("%s: %d \n", fd.Name, fd.Offset) - } -} diff --git a/advance/unsafe/layout_test.go b/advance/unsafe/layout_test.go deleted file mode 100644 index 63ea7706884ce35e7c926bcaf039f200ffb8d64e..0000000000000000000000000000000000000000 --- a/advance/unsafe/layout_test.go +++ /dev/null @@ -1,19 +0,0 @@ -package unsafe - -import ( - "fmt" - "gitbub.com/flycash/geekbang-middle-camp/advance/unsafe/types" - "testing" - "unsafe" -) - -func TestPrintFieldOffset(t *testing.T) { - fmt.Println(unsafe.Sizeof(types.User{})) - PrintFieldOffset(types.User{}) - - fmt.Println(unsafe.Sizeof(types.UserV1{})) - PrintFieldOffset(types.UserV1{}) - - fmt.Println(unsafe.Sizeof(types.UserV2{})) - PrintFieldOffset(types.UserV2{}) -} diff --git a/advance/unsafe/types/user.go b/advance/unsafe/types/user.go deleted file mode 100644 index 6c11cdef3303e6582a5f7291f66fcdd15fcf5f8a..0000000000000000000000000000000000000000 --- a/advance/unsafe/types/user.go +++ /dev/null @@ -1,23 +0,0 @@ -package types - -type User struct { - Name string - age int32 - Alias []byte - Address string -} - -type UserV1 struct { - Name string - age int32 - agev1 int32 - Alias []byte - Address string -} - -type UserV2 struct { - Name string - Alias []byte - Address string - age int32 -} diff --git a/advance/unsafe/unsafe.go b/advance/unsafe/unsafe.go deleted file mode 100644 index 1a40f916b4afa7e7e068f7f66325ea1e469c4b5a..0000000000000000000000000000000000000000 --- a/advance/unsafe/unsafe.go +++ /dev/null @@ -1,68 +0,0 @@ -package unsafe - -import ( - "errors" - "fmt" - "reflect" - "unsafe" -) - -type FieldAccessor interface { - Field(field string) (int, error) - SetField(field string, val int) error -} - -type UnsafeAccessor struct { - fields map[string]FieldMeta - entityAddr unsafe.Pointer -} - -func NewUnsafeAccessor(entity interface{}) (*UnsafeAccessor, error) { - if entity == nil { - return nil, errors.New("invalid entity") - } - val := reflect.ValueOf(entity) - typ := reflect.TypeOf(entity) - val.UnsafeAddr() - if typ.Kind() != reflect.Pointer || typ.Elem().Kind() != reflect.Struct { - return nil, errors.New("invalid entity") - } - fields := make(map[string]FieldMeta, typ.Elem().NumField()) - elemType := typ.Elem() - for i := 0; i < elemType.NumField(); i++ { - fd := elemType.Field(i) - fields[fd.Name] = FieldMeta{offset: fd.Offset} - } - return &UnsafeAccessor{entityAddr: val.UnsafePointer(), fields: fields}, nil -} - -func (u *UnsafeAccessor) Field(field string) (int, error) { - fdMeta, ok := u.fields[field] - if !ok { - return 0, fmt.Errorf("invalid field %s", field) - } - ptr := unsafe.Pointer(uintptr(u.entityAddr) + fdMeta.offset) - if ptr == nil { - return 0, fmt.Errorf("invalid address of the field: %s", field) - } - res := *(*int)(ptr) - return res, nil -} - -func (u *UnsafeAccessor) SetField(field string, val int) error { - fdMeta, ok := u.fields[field] - if !ok { - return fmt.Errorf("invalid field %s", field) - } - ptr := unsafe.Pointer(uintptr(u.entityAddr) + fdMeta.offset) - if ptr == nil { - return fmt.Errorf("invalid address of the field: %s", field) - } - *(*int)(ptr) = val - return nil -} - -type FieldMeta struct { - // offset 后期在我们考虑组合,或者复杂类型字段的时候,它的含义衍生为表达相当于最外层的结构体的偏移量 - offset uintptr -} diff --git a/advance/unsafe/unsafe_test.go b/advance/unsafe/unsafe_test.go deleted file mode 100644 index 4f414a9794d6dc8902e1a82a65ca6be3ae492322..0000000000000000000000000000000000000000 --- a/advance/unsafe/unsafe_test.go +++ /dev/null @@ -1,76 +0,0 @@ -package unsafe - -import ( - "github.com/stretchr/testify/assert" - "testing" -) - -func TestUnsafeAccessor_Field(t *testing.T) { - testCases := []struct { - name string - entity interface{} - field string - wantVal int - wantErr error - }{ - { - name: "normal case", - entity: &User{Age: 18}, - field: "Age", - wantVal: 18, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - accessor, err := NewUnsafeAccessor(tc.entity) - if err != nil { - assert.Equal(t, tc.wantErr, err) - return - } - val, err := accessor.Field(tc.field) - assert.Equal(t, tc.wantErr, err) - if err != nil { - return - } - assert.Equal(t, tc.wantVal, val) - }) - } -} - -func TestUnsafeAccessor_SetField(t *testing.T) { - testCases := []struct { - name string - entity *User - field string - newVal int - wantErr error - }{ - { - name: "normal case", - entity: &User{}, - field: "Age", - newVal: 18, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - accessor, err := NewUnsafeAccessor(tc.entity) - if err != nil { - assert.Equal(t, tc.wantErr, err) - return - } - err = accessor.SetField(tc.field, tc.newVal) - assert.Equal(t, tc.wantErr, err) - if err != nil { - return - } - assert.Equal(t, tc.newVal, tc.entity.Age) - }) - } -} - -type User struct { - Age int -}