From 139523410b1152cf011abbbd2639293d670e3270 Mon Sep 17 00:00:00 2001 From: Johney Xu Date: Sun, 4 Sep 2022 20:04:33 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E7=AC=AC=E4=B8=89=E6=AC=A1=E4=BD=9C?= =?UTF-8?q?=E4=B8=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- advance/template/gen/annotation/file.go | 44 ++++++- advance/template/gen/http/gen_http.go | 40 +++++- advance/template/gen/main.go | 124 +++++++++++++++++- .../gen/testdata/user_service_gen.txt | 2 +- 4 files changed, 195 insertions(+), 15 deletions(-) diff --git a/advance/template/gen/annotation/file.go b/advance/template/gen/annotation/file.go index 2b0cf3e..955a387 100644 --- a/advance/template/gen/annotation/file.go +++ b/advance/template/gen/annotation/file.go @@ -10,11 +10,20 @@ type SingleFileEntryVisitor struct { } func (s *SingleFileEntryVisitor) Get() File { - panic("implement me") + if s.file == nil { + return File{} + } + return s.file.Get() } func (s *SingleFileEntryVisitor) Visit(node ast.Node) ast.Visitor { - panic("implement me") + if file, ok := node.(*ast.File); ok { + s.file = &fileVisitor{ + ans: newAnnotations(file, file.Doc), + } + return s.file + } + return s } type fileVisitor struct { @@ -24,11 +33,27 @@ type fileVisitor struct { } func (f *fileVisitor) Get() File { - panic("implement me") + types := make([]Type, 0, len(f.types)) + for _, t := range f.types { + types = append(types, t.Get()) + } + return File{ + Annotations: f.ans, + Types: types, + } } func (f *fileVisitor) Visit(node ast.Node) ast.Visitor { - panic("implement me") + typ, ok := node.(*ast.TypeSpec) + if ok { + res := &typeVisitor{ + ans: newAnnotations(typ, typ.Doc), + fields: make([]Field, 0), + } + f.types = append(f.types, res) + return res + } + return f } type File struct { @@ -42,11 +67,18 @@ type typeVisitor struct { } func (t *typeVisitor) Get() Type { - panic("implement me") + return Type{ + Annotations: t.ans, + Fields: t.fields, + } } func (t *typeVisitor) Visit(node ast.Node) (w ast.Visitor) { - panic("implement me") + if fd, ok := node.(*ast.Field); ok { + t.fields = append(t.fields, Field{Annotations: newAnnotations(fd, fd.Doc)}) + return nil + } + return t } type Type struct { diff --git a/advance/template/gen/http/gen_http.go b/advance/template/gen/http/gen_http.go index b0cb6d3..11b5170 100644 --- a/advance/template/gen/http/gen_http.go +++ b/advance/template/gen/http/gen_http.go @@ -6,7 +6,45 @@ import ( ) // 这部分和课堂的很像,但是有一些地方被我改掉了 -const serviceTpl = ` +const serviceTpl = `package {{.Package}} + +import ( + "bytes" + "context" + "encoding/json" + "io/ioutil" + "net/http" +) + +{{ $service :=.GenName -}} +type {{ $service }} struct { + Endpoint string + Path string + Client http.Client +} +{{range $idx, $method := .Methods}} +func (s *{{$service}}) {{$method.Name}}(ctx context.Context, req *{{$method.ReqTypeName}}) (*{{$method.RespTypeName}}, error) { + url := s.Endpoint + s.Path + "/{{$method.Name}}" + bs, err := json.Marshal(req) + if err != nil { + return nil, err + } + body := &bytes.Buffer{} + body.Write(bs) + httpReq, err := http.NewRequestWithContext(ctx, "POST", url, body) + if err != nil { + return nil, err + } + httpResp, err := s.Client.Do(httpReq) + if err != nil { + return nil, err + } + bs, err = ioutil.ReadAll(httpResp.Body) + resp := &{{$method.RespTypeName}}{} + err = json.Unmarshal(bs, resp) + return resp, err +} +{{end}} ` func Gen(writer io.Writer, def ServiceDefinition) error { diff --git a/advance/template/gen/main.go b/advance/template/gen/main.go index 27fc4ec..36c445e 100644 --- a/advance/template/gen/main.go +++ b/advance/template/gen/main.go @@ -1,11 +1,21 @@ package main import ( + "bufio" + "bytes" + "errors" "fmt" - "gitee.com/geektime-geekbang/geektime-go/advance/template/gen/annotation" - "gitee.com/geektime-geekbang/geektime-go/advance/template/gen/http" + "go/ast" + "go/parser" + "go/token" "os" + "path" + "path/filepath" + "strings" "unicode" + + "gitee.com/geektime-geekbang/geektime-go/advance/template/gen/annotation" + "gitee.com/geektime-geekbang/geektime-go/advance/template/gen/http" ) // 实际上 main 函数这里要考虑接收参数 @@ -43,7 +53,25 @@ func gen(src string) error { // 根据 defs 来生成代码 // src 是源代码所在目录,在测试里面它是 ./testdata func genFiles(src string, defs []http.ServiceDefinition) error { - panic("implement me") + for _, def := range defs { + bs := &bytes.Buffer{} + err := http.Gen(bs, def) + file, err := os.OpenFile(path.Join(src, underscoreName(def.Name)+"_gen.go"), os.O_WRONLY|os.O_CREATE, 0666) + + if err != nil { + return err + } + defer file.Close() + + if err != nil { + fmt.Printf("open file error=%v\n", err) + return err + } + writer := bufio.NewWriter(file) + writer.Write(bs.Bytes()) + writer.Flush() + } + return nil } func parseFiles(srcFiles []string) ([]http.ServiceDefinition, error) { @@ -51,8 +79,15 @@ func parseFiles(srcFiles []string) ([]http.ServiceDefinition, error) { for _, src := range srcFiles { fmt.Println(src) // 你需要利用 annotation 里面的东西来扫描 src,然后生成 file - panic("implement me") - var file annotation.File + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, src, nil, parser.ParseComments) + if err != nil { + return nil, err + } + sfev := &annotation.SingleFileEntryVisitor{} + ast.Walk(sfev, f) + + file := sfev.Get() for _, typ := range file.Types { _, ok := typ.Annotations.Get("HttpClient") @@ -72,12 +107,87 @@ func parseFiles(srcFiles []string) ([]http.ServiceDefinition, error) { // 你需要利用 typ 来构造一个 http.ServiceDefinition // 注意你可能需要检测用户的定义是否符合你的预期 func parseServiceDefinition(pkg string, typ annotation.Type) (http.ServiceDefinition, error) { - panic("implement me") + result := &http.ServiceDefinition{ + Package: pkg, + } + for _, a := range typ.Annotations.Ans { + if a.Key == "ServiceName" { + result.Name = a.Value + } + } + if result.Name == "" { + result.Name = typ.Annotations.Node.Name.Name + } + method := &http.ServiceMethod{} + fields := typ.Fields + for _, field := range fields { + method.Name = field.Annotations.Node.Names[0].Name + + mAns := field.Annotations.Ans + for _, ma := range mAns { + if ma.Key == "Path" { + method.Path = ma.Value + } + } + if method.Path == "" { + method.Path = "/" + method.Name + } + params := field.Annotations.Node.Type.(*ast.FuncType).Params + if len(params.List) != 2 { + return *result, errors.New("gen: 方法必须接收两个参数,其中第一个参数是 context.Context,第二个参数请求") + } + method.ReqTypeName = params.List[1].Type.(*ast.StarExpr).X.(*ast.Ident).Name + results := field.Annotations.Node.Type.(*ast.FuncType).Results + if len(results.List) != 2 { + return *result, errors.New("gen: 方法必须返回两个参数,其中第一个返回值是响应,第二个返回值是error") + } + method.RespTypeName = results.List[0].Type.(*ast.StarExpr).X.(*ast.Ident).Name + result.Methods = append(result.Methods, *method) + } + return *result, nil } // 返回符合条件的 Go 源代码文件,也就是你要用 AST 来分析这些文件的代码 func scanFiles(src string) ([]string, error) { - panic("implement me") + srcFiles := make([]string, 0, 10) + files, err := os.ReadDir(src) + if err != nil { + return nil, err + } + for _, file := range files { + if strings.HasSuffix(file.Name(), ".go") && + !strings.HasSuffix(file.Name(), "_test.go") && + !strings.HasSuffix(file.Name(), "gen.go") { + src, err = filepath.Abs(src) + if err != nil { + return nil, err + } + srcFiles = append(srcFiles, filepath.Join(src, file.Name())) + } + } + return srcFiles, nil + //srcAbs, err := filepath.Abs(src) + //if err != nil { + // return nil, err + //} + //files := make([]string, 0) + //if err := filepath.Walk(src, func(filePath string, f os.FileInfo, err error) error { + // if f == nil { + // return err + // } + // if f.IsDir() { + // return nil + // } + // if strings.HasSuffix(f.Name(), ".go") && + // !strings.HasSuffix(f.Name(), "_test.go") && + // !strings.HasSuffix(f.Name(), "gen.go") { + // files = append(files, path.Join(srcAbs, filePath)) + // } + // return nil + //}); err != nil { + // return nil, err + //} + //return files, nil } // underscoreName 驼峰转字符串命名,在决定生成的文件名的时候需要这个方法 diff --git a/advance/template/gen/testdata/user_service_gen.txt b/advance/template/gen/testdata/user_service_gen.txt index 031e0a4..35f7c40 100644 --- a/advance/template/gen/testdata/user_service_gen.txt +++ b/advance/template/gen/testdata/user_service_gen.txt @@ -37,7 +37,7 @@ func (s *UserServiceGen) Get(ctx context.Context, req *GetUserReq) (*GetUserResp } func (s *UserServiceGen) Update(ctx context.Context, req *UpdateUserReq) (*UpdateUserResp, error) { - url := s.Endpoint + s.Path + "/user/update" + url := s.Endpoint + s.Path + "/Update" bs, err := json.Marshal(req) if err != nil { return nil, err -- Gitee From 4f791293f4121af4db70f090ca018176a6d28b39 Mon Sep 17 00:00:00 2001 From: Johney Xu Date: Mon, 26 Sep 2022 20:00:04 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E7=AC=AC=E4=B8=83=E5=91=A8=E4=BD=9C?= =?UTF-8?q?=E4=B8=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- orm/homework1/aggregate.go | 18 +++++-- orm/homework1/select.go | 99 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 110 insertions(+), 7 deletions(-) diff --git a/orm/homework1/aggregate.go b/orm/homework1/aggregate.go index e34a8f7..8bdcb1c 100644 --- a/orm/homework1/aggregate.go +++ b/orm/homework1/aggregate.go @@ -21,15 +21,27 @@ func (a Aggregate) As(alias string) Aggregate { // EQ 例如 C("id").Eq(12) func (a Aggregate) EQ(arg any) Predicate { - panic("implement me") + return Predicate{ + left: a, + op: opEQ, + right: exprOf(arg), + } } func (a Aggregate) LT(arg any) Predicate { - panic("implement me") + return Predicate{ + left: a, + op: opLT, + right: exprOf(arg), + } } func (a Aggregate) GT(arg any) Predicate { - panic("implement me") + return Predicate{ + left: a, + op: opGT, + right: exprOf(arg), + } } func Avg(c string) Aggregate { diff --git a/orm/homework1/select.go b/orm/homework1/select.go index b6295eb..a1f4624 100644 --- a/orm/homework1/select.go +++ b/orm/homework1/select.go @@ -67,7 +67,47 @@ func (s *Selector[T]) Build() (*Query, error) { } } - panic("implement me") + if len(s.orderBy) > 0 { + s.sb.WriteString(" ORDER BY ") + if err = s.buildOrderBy(); err != nil { + return nil, err + } + } + + if len(s.groupBy) > 0 { + s.sb.WriteString(" GROUP BY ") + for i, c := range s.groupBy { + if i > 0 { + s.sb.WriteByte(',') + } + if err = s.buildColumn(c.name, c.alias); err != nil { + return nil, err + } + } + } + + if len(s.having) > 0 { + s.sb.WriteString(" HAVING ") + if err = s.buildPredicates(s.having); err != nil { + return nil, err + } + } + + if s.limit > 0 { + s.sb.WriteString(" LIMIT ?") + s.addArgs(s.limit) + } + + if s.offset > 0 { + s.sb.WriteString(" OFFSET ?") + s.addArgs(s.offset) + } + + //panic("implement me") + + if s.having != nil { + + } s.sb.WriteString(";") return &Query{ @@ -159,7 +199,52 @@ func (s *Selector[T]) buildColumn(c string, alias string) error { } func (s *Selector[T]) buildExpression(e Expression) error { - panic("implement me") + if e == nil { + return nil + } + switch exp := e.(type) { + case Column: + return s.buildColumn(exp.name, exp.alias) + case Aggregate: + return s.buildAggregate(exp, false) + case value: + s.sb.WriteByte('?') + s.addArgs(exp.val) + case Predicate: + _, lp := exp.left.(Predicate) + if lp { + s.sb.WriteByte('(') + } + if err := s.buildExpression(exp.left); err != nil { + return err + } + if lp { + s.sb.WriteByte(')') + } + + // 可能只有左边 + if exp.op == "" { + return nil + } + + s.sb.WriteByte(' ') + s.sb.WriteString(exp.op.String()) + s.sb.WriteByte(' ') + + _, rp := exp.right.(Predicate) + if rp { + s.sb.WriteByte('(') + } + if err := s.buildExpression(exp.right); err != nil { + return err + } + if rp { + s.sb.WriteByte(')') + } + default: + return errs.NewErrUnsupportedExpressionType(exp) + } + return nil } // Where 用于构造 WHERE 查询条件。如果 ps 长度为 0,那么不会构造 WHERE 部分 @@ -271,9 +356,15 @@ type OrderBy struct { } func Asc(col string) OrderBy { - panic("implement me") + return OrderBy{ + col: col, + order: "ASC", + } } func Desc(col string) OrderBy { - panic("implement me") + return OrderBy{ + col: col, + order: "DESC", + } } -- Gitee