From 84149a54de0e24e7806d0ec4919e29ffcfc296bd Mon Sep 17 00:00:00 2001 From: shine <1042864399@qq.com> Date: Fri, 26 Sep 2025 16:16:05 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E5=AF=B9=20proto=20?= =?UTF-8?q?=E7=9A=84=E8=A7=A3=E6=9E=90=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/pb2port.go | 127 ++++++++++++++++++++ go.mod | 12 +- go.sum | 3 + parser/zm_proto/import.go | 11 ++ parser/zm_proto/message.go | 12 ++ parser/zm_proto/port.go | 63 ++++++++++ parser/zm_proto/proto.go | 213 ++++++++++++++++++++++++++++++++++ parser/zm_proto/proto_test.go | 13 +++ template/template_engine.go | 4 + worker/proto_to_erlang.go | 74 ++++++++++++ 10 files changed, 528 insertions(+), 4 deletions(-) create mode 100644 cmd/pb2port.go create mode 100644 parser/zm_proto/import.go create mode 100644 parser/zm_proto/message.go create mode 100644 parser/zm_proto/port.go create mode 100644 parser/zm_proto/proto.go create mode 100644 parser/zm_proto/proto_test.go create mode 100644 worker/proto_to_erlang.go diff --git a/cmd/pb2port.go b/cmd/pb2port.go new file mode 100644 index 0000000..3db2218 --- /dev/null +++ b/cmd/pb2port.go @@ -0,0 +1,127 @@ +package cmd + +import ( + "complie-erlang/config" + "complie-erlang/worker" + "fmt" + "github.com/spf13/cobra" + "log" + "os" + "path/filepath" + "strings" +) + +type Pb2Port struct { + debug bool + + protoPath string // proto 文件 + out string // 输出文件 + + tplDir string // 模版目录 + author string // 作者 + mainTemplate string // 默认模版 +} + +func (s *Pb2Port) run(_ *cobra.Command, arg []string) { + + if len(arg) == 0 { + log.Println("请输入功能名称") + return + } + + worker1 := worker.NewProto2ErlangWorker() + + // 获取可执行文件所在目录 + exePath, err := os.Executable() + if err != nil { + log.Fatalf("获取可执行文件路径失败: %v", err) + return + } + if err := worker1.LoadTemplates(filepath.Join(filepath.Dir(exePath), s.tplDir)); err != nil { + log.Printf("Err 加载模版报错: %v", err) + return + } + + defaultArgs := []config.DefaultArg{ + { + Key: "Author", + Value: s.author, + }, + } + + args, err := worker1.LoadTemplatesArgs(s.protoPath, defaultArgs) + if err != nil { + log.Printf("Err 读取配置文件失败: %v path: %s", err, s.protoPath) + return + } + + args["Module"] = filepath.Base(s.out)[:len(filepath.Base(s.out))-len(filepath.Ext(s.out))] + arg = append(arg, "port") + args["Desc"] = arg + + template, err := worker1.ExecuteTemplate(s.mainTemplate, args) + if err != nil { + log.Printf("Err 模版生成出错: %v", err) + return + } + + if s.debug { + for key, arg1 := range args { + fmt.Println(key, "-->", arg1) + } + } + + _, err = os.Stat(s.out) + if err != nil && !os.IsNotExist(err) { + log.Printf("Err 文件错误: %v", err) + return + } + if err == nil { + s.out = filepath.Base(s.out)[:len(filepath.Base(s.out))-len(filepath.Ext(s.out))] + s.out = s.out + "_gen.erl" + } + + if s.debug { + fmt.Println("template:", template) + return + } + + if err := os.WriteFile(s.out, []byte(template), 0644); err != nil { + log.Printf("Err 写入文件失败: %v", err) + return + } + fmt.Println("ok") +} +func init() { + var singleSet = new(Pb2Port) + var logsCmd = &cobra.Command{ + Use: "init", + Short: "根据proto 文件构建 功能模版", + Long: `构建功能数据`, + Run: singleSet.run, + } + + var ( + out = "" + protoPath = "" + ) + + // 写入默认数据 + if currentDir, err := os.Getwd(); err == nil { + out = fmt.Sprintf("%s_port.erl", filepath.Base(currentDir)) + pluginSpilt := strings.Split(currentDir, "plugin") + if len(pluginSpilt) > 0 { + pluginSpilt = pluginSpilt[:len(pluginSpilt)-1] + protoPath = filepath.Join(strings.Join(pluginSpilt, "plugin"), fmt.Sprintf("\\gpb\\game\\pro_%s.proto", filepath.Base(currentDir))) + } + } + + logsCmd.PersistentFlags().BoolVar(&singleSet.debug, "debug", false, "是否启动调试模式") + logsCmd.PersistentFlags().StringVar(&singleSet.protoPath, "proto", protoPath, "读取文件") + logsCmd.PersistentFlags().StringVar(&singleSet.out, "out", out, "输出文件") + logsCmd.PersistentFlags().StringVar(&singleSet.author, "author", "st,sutong@youkia.net", "作者") + logsCmd.PersistentFlags().StringVar(&singleSet.mainTemplate, "main_tpl", "ErlangPort", "主模版") + logsCmd.PersistentFlags().StringVar(&singleSet.tplDir, "tpl", "./templates/*.tpl", "模版地址") + + rootCmd.AddCommand(logsCmd) +} diff --git a/go.mod b/go.mod index 2da4073..dedd18b 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,12 @@ module complie-erlang go 1.23.4 require ( - github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/spf13/cobra v1.10.1 // indirect - github.com/spf13/pflag v1.0.9 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect + github.com/emicklei/proto v1.14.2 + github.com/spf13/cobra v1.10.1 + gopkg.in/yaml.v2 v2.4.0 +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.9 // indirect ) diff --git a/go.sum b/go.sum index 584657f..6d7ecf6 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/emicklei/proto v1.14.2 h1:wJPxPy2Xifja9cEMrcA/g08art5+7CGJNFNk35iXC1I= +github.com/emicklei/proto v1.14.2/go.mod h1:rn1FgRS/FANiZdD2djyH7TMA9jdRDcYQ9IEN9yvjX0A= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -6,6 +8,7 @@ github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= diff --git a/parser/zm_proto/import.go b/parser/zm_proto/import.go new file mode 100644 index 0000000..10c934f --- /dev/null +++ b/parser/zm_proto/import.go @@ -0,0 +1,11 @@ +package zm_proto + +type Import struct { + Name string + Path string + Messages []Message + Imports []string + Enum map[string]bool + + Proto *Proto +} diff --git a/parser/zm_proto/message.go b/parser/zm_proto/message.go new file mode 100644 index 0000000..1b249ae --- /dev/null +++ b/parser/zm_proto/message.go @@ -0,0 +1,12 @@ +package zm_proto + +type Message struct { + Name string + Fields []Field + + Import *Import +} + +func (m *Message) GetImportName() string { + return m.Import.Name +} diff --git a/parser/zm_proto/port.go b/parser/zm_proto/port.go new file mode 100644 index 0000000..05e3ff2 --- /dev/null +++ b/parser/zm_proto/port.go @@ -0,0 +1,63 @@ +package zm_proto + +import ( + "fmt" + "os" + "regexp" + "strings" +) + +type Port struct { + Cmd string // 地址 + PortDesc string // 描述 + ClientPB string // + ServerPB string + ConnectType string + IsPush bool + Parallel bool + + Proto *Proto +} + +func ParsePort(Proto *Proto, path string) ([]Port, error) { + bytes, err := os.ReadFile(path) + if err != nil { + return nil, err + } + mustPort := regexp.MustCompile(`/\*/*Cmd=.+\*/`) + + result := mustPort.FindAllStringSubmatch(string(bytes), -1) + //fmt.Println("result", result) + var ports []Port + for _, res := range result { + for _, match := range res { + match = match[2 : len(match)-2] + // fmt.Println("match", match) + var p = Port{Proto: Proto} + for _, fieldStr := range strings.Split(match, ";") { + fieldslice := strings.Split(fieldStr, "=") + if len(fieldslice) != 2 { + return nil, fmt.Errorf("invalid field: %s, match: %s", fieldStr, match) + } + switch fieldslice[0] { + case "Cmd": + p.Cmd = fieldslice[1] + case "PortDesc": + p.PortDesc = fieldslice[1] + case "ClientPB": + p.ClientPB = fieldslice[1] + case "ServerPB": + p.ServerPB = fieldslice[1] + case "ConnectType": + p.ConnectType = fieldslice[1] + case "IsPush": + p.IsPush = fieldslice[1] == "true" + case "Parallel": + p.Parallel = fieldslice[1] == "true" + } + } + ports = append(ports, p) + } + } + return ports, nil +} diff --git a/parser/zm_proto/proto.go b/parser/zm_proto/proto.go new file mode 100644 index 0000000..85f5bb7 --- /dev/null +++ b/parser/zm_proto/proto.go @@ -0,0 +1,213 @@ +package zm_proto + +import ( + "fmt" + "github.com/emicklei/proto" + "os" + "path/filepath" +) + +type Field struct { + Name string + Type string + Repeated bool +} + +type Proto struct { + Imports map[string]Import + + Enum map[string]bool + + Dir []string +} + +func NewProto() *Proto { + return &Proto{ + Imports: make(map[string]Import), + Enum: make(map[string]bool), + } +} + +func (p *Proto) AddImport(name string, imp Import) { + p.Imports[name] = imp + p.AddEnum(imp.Enum) +} + +func (p *Proto) AddEnum(name map[string]bool) { + for k := range name { + p.Enum[k] = true + } +} + +func (p *Proto) Include(dir ...string) *Proto { + p.Dir = append(p.Dir, dir...) + return p +} + +func (p *Proto) FindPathByDir(protoImport string) string { + for _, dir := range p.Dir { + path := filepath.Join(dir, protoImport) + _, err := os.Stat(path) + if err == nil { + return path + } + } + return "" +} + +func (p *Proto) ParseAllImport() error { + for _, dir := range p.Dir { + readDir, err := os.ReadDir(dir) + if err != nil { + return fmt.Errorf("failed to read dir %s: %w", dir, err) + } + for _, f := range readDir { + if f.IsDir() { + continue + } + if filepath.Ext(f.Name()) != ".proto" { + continue + } + if p.Imports[f.Name()].Name != "" { + continue + } + if err = p.parseImport(filepath.Join(dir, f.Name())); err != nil { + return fmt.Errorf("failed to parse file %s: %w", f.Name(), err) + } + } + } + return nil +} + +func (p *Proto) ParseImport(protoImport string) error { + path := p.FindPathByDir(protoImport) + if path == "" { + return fmt.Errorf("import %s not found", path) + } + err := p.parseImport(path) + if err != nil { + return fmt.Errorf("failed to parse import %s: %w", path, err) + } + for _, i := range p.Imports[protoImport].Imports { + if err = p.ParseImport(i); err != nil { + return fmt.Errorf("failed to parse import %s: %w", i, err) + } + } + return nil +} + +func (p *Proto) parseImport(path string) error { + //ext := filepath.Ext(path) + //name := filepath.Base(path)[:len(filepath.Base(path))-len(ext)] + + name := filepath.Base(path) + reader, err := os.Open(path) + if err != nil { + return err + } + defer reader.Close() + + parser := proto.NewParser(reader) + definition, err := parser.Parse() + if err != nil { + return err + } + + var msgs []*Message + var enums = make(map[string]bool) + var imports []string + + var imp = Import{ + Name: name, + Path: path, + Proto: p, + } + proto.Walk(definition, + proto.WithImport(func(i *proto.Import) { + //ext1 := filepath.Ext(i.Filename) + //name1 := filepath.Base(i.Filename)[:len(filepath.Base(i.Filename))-len(ext1)] + imports = append(imports, i.Filename) + }), + + proto.WithEnum(func(enum *proto.Enum) { + enums[enum.Name] = true + }), + proto.WithNormalField(func(field *proto.NormalField) { + lastmsg := msgs[len(msgs)-1] + lastmsg.Fields = append(lastmsg.Fields, Field{ + Name: field.Name, + Type: field.Type, + Repeated: field.Repeated, + }) + + }), + proto.WithMessage(func(message *proto.Message) { + msgs = append(msgs, &Message{ + Name: message.Name, + Fields: []Field{}, + Import: &imp, + }) + })) + + imp.Imports = imports + imp.Enum = enums + + for _, msg := range msgs { + imp.Messages = append(imp.Messages, *msg) + } + + p.AddImport(name, imp) + return nil +} + +var defaultTypeMap = map[string]any{ + "int32": 0, + "int64": 0, + "bool": false, + "string": "ok", + "byte": "byte", +} + +func (p *Proto) MessageToMap(name string) (string, map[string]any) { + var msg = make(map[string]any) + var erlMod = "" + for _, i := range p.Imports { + for _, m := range i.Messages { + if m.Name == name { + ext := filepath.Ext(i.Name) + erlMod = filepath.Base(i.Name)[:len(filepath.Base(i.Name))-len(ext)] + + for _, field := range m.Fields { + defaultValue := defaultTypeMap[field.Type] + if defaultValue == nil { + _, defaultValue = p.MessageToMap(field.Type) + } + if field.Repeated { + if defaultValue == nil { + defaultValue = []any{} + } + } + msg[field.Name] = defaultValue + } + } + } + } + if len(msg) == 0 || erlMod == "" { + return "single_str", nil + } + return erlMod, msg +} + +// MessageToErlMod z +func (p *Proto) MessageToErlMod(name string) string { + for _, i := range p.Imports { + for _, m := range i.Messages { + if m.Name == name { + ext := filepath.Ext(i.Name) + erlMod := filepath.Base(i.Name)[:len(filepath.Base(i.Name))-len(ext)] + return erlMod + } + } + } + return "single_str" +} diff --git a/parser/zm_proto/proto_test.go b/parser/zm_proto/proto_test.go new file mode 100644 index 0000000..dae5d7f --- /dev/null +++ b/parser/zm_proto/proto_test.go @@ -0,0 +1,13 @@ +package zm_proto + +import ( + "fmt" + "testing" +) + +func TestP1(t *testing.T) { + pb := NewProto() + err := pb.ParseImport("game\\pro_array.proto") + fmt.Println(err) + fmt.Println(pb) +} diff --git a/template/template_engine.go b/template/template_engine.go index cdfc85d..bb5e624 100644 --- a/template/template_engine.go +++ b/template/template_engine.go @@ -14,6 +14,10 @@ var DefaultFuncMap = template.FuncMap{ "title": strings.Title, "join": strings.Join, "list": func(strs []string) string { return strings.Join(strs, ",") }, + "cmd2func": func(str string) string { + splits := strings.Split(str, "/") + return splits[len(splits)-1] + }, } func hd(str []string) string { diff --git a/worker/proto_to_erlang.go b/worker/proto_to_erlang.go new file mode 100644 index 0000000..385702c --- /dev/null +++ b/worker/proto_to_erlang.go @@ -0,0 +1,74 @@ +package worker + +import ( + "complie-erlang/config" + "complie-erlang/parser/zm_proto" + "complie-erlang/template" + "path/filepath" + "time" +) + +type Proto2ErlangWorker struct { + *zm_proto.Proto + Template *template.Template +} + +func NewProto2ErlangWorker() *Proto2ErlangWorker { + return &Proto2ErlangWorker{ + zm_proto.NewProto(), + template.NewTemplate(), + } +} + +// LoadTemplates 加载模版 +func (p *Proto2ErlangWorker) LoadTemplates(templatePath string) error { + return p.Template.ParseGlob(templatePath) +} + +// LoadTemplatesArgs 解析模版数据 +func (p *Proto2ErlangWorker) LoadTemplatesArgs(protoPath string, DefaultArgs []config.DefaultArg) (map[string]any, error) { + p.Include(filepath.Dir(protoPath)) + + if err := p.ParseImport(filepath.Base(protoPath)); err != nil { + return nil, err + } + parsePort, err := zm_proto.ParsePort(p.Proto, protoPath) + if err != nil { + return nil, err + } + + var templatArgs = make(map[string]any) + var defaultArgs = make(map[string]any) + + for _, defaultArg := range DefaultArgs { + defaultArgs[defaultArg.Key] = defaultArg.Value + } + templatArgs["default"] = defaultArgs + + var ports []any + for _, port := range parsePort { + if port.IsPush { + continue + } + var portMap = map[string]any{ + "Cmd": port.Cmd, + "PortDesc": port.PortDesc, + "ClientPB": port.ClientPB, + "ServerPB": port.ServerPB, + "ClientProto": p.Proto.MessageToErlMod(port.ClientPB), + "ServerProto": p.Proto.MessageToErlMod(port.ServerPB), + } + ports = append(ports, portMap) + } + templatArgs["ports"] = ports + + // 其他默认值 + templatArgs["CreateAt"] = time.Now().Format(time.DateTime) + + return templatArgs, nil +} + +// ExecuteTemplate 组装模版 +func (p *Proto2ErlangWorker) ExecuteTemplate(template string, templatArgs map[string]any) (string, error) { + return p.Template.ExecuteTemplate(template, templatArgs) +}