214 lines
4.3 KiB
Go
214 lines
4.3 KiB
Go
package zm_proto
|
|
|
|
import (
|
|
"fmt"
|
|
"github.com/emicklei/proto"
|
|
"os"
|
|
"path/filepath"
|
|
)
|
|
|
|
// Field describes one normal field of a protobuf message, as recorded by
// parseImport.
type Field struct {
	// Name is the field's name and Type is its type string as reported by
	// the parser (for example "int32" or the name of another message).
	Name, Type string
	// Repeated reports whether the field was declared repeated.
	Repeated bool
}
|
|
|
|
type Proto struct {
|
|
Imports map[string]Import
|
|
|
|
Enum map[string]bool
|
|
|
|
Dir []string
|
|
}
|
|
|
|
func NewProto() *Proto {
|
|
return &Proto{
|
|
Imports: make(map[string]Import),
|
|
Enum: make(map[string]bool),
|
|
}
|
|
}
|
|
|
|
func (p *Proto) AddImport(name string, imp Import) {
|
|
p.Imports[name] = imp
|
|
p.AddEnum(imp.Enum)
|
|
}
|
|
|
|
func (p *Proto) AddEnum(name map[string]bool) {
|
|
for k := range name {
|
|
p.Enum[k] = true
|
|
}
|
|
}
|
|
|
|
func (p *Proto) Include(dir ...string) *Proto {
|
|
p.Dir = append(p.Dir, dir...)
|
|
return p
|
|
}
|
|
|
|
func (p *Proto) FindPathByDir(protoImport string) string {
|
|
for _, dir := range p.Dir {
|
|
path := filepath.Join(dir, protoImport)
|
|
_, err := os.Stat(path)
|
|
if err == nil {
|
|
return path
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func (p *Proto) ParseAllImport() error {
|
|
for _, dir := range p.Dir {
|
|
readDir, err := os.ReadDir(dir)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read dir %s: %w", dir, err)
|
|
}
|
|
for _, f := range readDir {
|
|
if f.IsDir() {
|
|
continue
|
|
}
|
|
if filepath.Ext(f.Name()) != ".proto" {
|
|
continue
|
|
}
|
|
if p.Imports[f.Name()].Name != "" {
|
|
continue
|
|
}
|
|
if err = p.parseImport(filepath.Join(dir, f.Name())); err != nil {
|
|
return fmt.Errorf("failed to parse file %s: %w", f.Name(), err)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *Proto) ParseImport(protoImport string) error {
|
|
path := p.FindPathByDir(protoImport)
|
|
if path == "" {
|
|
return fmt.Errorf("import %s not found", path)
|
|
}
|
|
err := p.parseImport(path)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse import %s: %w", path, err)
|
|
}
|
|
for _, i := range p.Imports[protoImport].Imports {
|
|
if err = p.ParseImport(i); err != nil {
|
|
return fmt.Errorf("failed to parse import %s: %w", i, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *Proto) parseImport(path string) error {
|
|
//ext := filepath.Ext(path)
|
|
//name := filepath.Base(path)[:len(filepath.Base(path))-len(ext)]
|
|
|
|
name := filepath.Base(path)
|
|
reader, err := os.Open(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer reader.Close()
|
|
|
|
parser := proto.NewParser(reader)
|
|
definition, err := parser.Parse()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var msgs []*Message
|
|
var enums = make(map[string]bool)
|
|
var imports []string
|
|
|
|
var imp = Import{
|
|
Name: name,
|
|
Path: path,
|
|
Proto: p,
|
|
}
|
|
proto.Walk(definition,
|
|
proto.WithImport(func(i *proto.Import) {
|
|
//ext1 := filepath.Ext(i.Filename)
|
|
//name1 := filepath.Base(i.Filename)[:len(filepath.Base(i.Filename))-len(ext1)]
|
|
imports = append(imports, i.Filename)
|
|
}),
|
|
|
|
proto.WithEnum(func(enum *proto.Enum) {
|
|
enums[enum.Name] = true
|
|
}),
|
|
proto.WithNormalField(func(field *proto.NormalField) {
|
|
lastmsg := msgs[len(msgs)-1]
|
|
lastmsg.Fields = append(lastmsg.Fields, Field{
|
|
Name: field.Name,
|
|
Type: field.Type,
|
|
Repeated: field.Repeated,
|
|
})
|
|
|
|
}),
|
|
proto.WithMessage(func(message *proto.Message) {
|
|
msgs = append(msgs, &Message{
|
|
Name: message.Name,
|
|
Fields: []Field{},
|
|
Import: &imp,
|
|
})
|
|
}))
|
|
|
|
imp.Imports = imports
|
|
imp.Enum = enums
|
|
|
|
for _, msg := range msgs {
|
|
imp.Messages = append(imp.Messages, *msg)
|
|
}
|
|
|
|
p.AddImport(name, imp)
|
|
return nil
|
|
}
|
|
|
|
// defaultTypeMap gives the sample value MessageToMap uses for a field whose
// type is a protobuf scalar type. Types absent from this map are looked up
// as message names instead.
var defaultTypeMap = map[string]any{
	"int32":    0,
	"int64":    0,
	"uint32":   0,
	"uint64":   0,
	"sint32":   0,
	"sint64":   0,
	"fixed32":  0,
	"fixed64":  0,
	"sfixed32": 0,
	"sfixed64": 0,

	"float":  0.0,
	"double": 0.0,

	"bool":   false,
	"string": "ok",

	// "byte" is not a protobuf scalar type but is kept for backward
	// compatibility; "bytes" is the real type and shares its placeholder.
	"byte":  "byte",
	"bytes": "byte",
}
|
|
|
|
func (p *Proto) MessageToMap(name string) (string, map[string]any) {
|
|
var msg = make(map[string]any)
|
|
var erlMod = ""
|
|
for _, i := range p.Imports {
|
|
for _, m := range i.Messages {
|
|
if m.Name == name {
|
|
ext := filepath.Ext(i.Name)
|
|
erlMod = filepath.Base(i.Name)[:len(filepath.Base(i.Name))-len(ext)]
|
|
|
|
for _, field := range m.Fields {
|
|
defaultValue := defaultTypeMap[field.Type]
|
|
if defaultValue == nil {
|
|
_, defaultValue = p.MessageToMap(field.Type)
|
|
}
|
|
if field.Repeated {
|
|
if defaultValue == nil {
|
|
defaultValue = []any{}
|
|
}
|
|
}
|
|
msg[field.Name] = defaultValue
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if len(msg) == 0 || erlMod == "" {
|
|
return "single_str", nil
|
|
}
|
|
return erlMod, msg
|
|
}
|
|
|
|
// MessageToErlMod z
|
|
func (p *Proto) MessageToErlMod(name string) string {
|
|
for _, i := range p.Imports {
|
|
for _, m := range i.Messages {
|
|
if m.Name == name {
|
|
ext := filepath.Ext(i.Name)
|
|
erlMod := filepath.Base(i.Name)[:len(filepath.Base(i.Name))-len(ext)]
|
|
return erlMod
|
|
}
|
|
}
|
|
}
|
|
return "single_str"
|
|
}
|