-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathxuanwu.py
155 lines (123 loc) · 4.55 KB
/
xuanwu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import base
from ptsd.loader import Loader
from Cheetah.Template import Template
from os import path, mkdir
import os
# Command-line handling: ``python xuanwu.py <thrift_file> <output_folder>``.
# The names bound here are module-level globals read by the generator
# functions defined further down.
if len(sys.argv) != 3:
    # sys.exit(msg) prints msg to stderr and exits with status 1, so callers
    # (shell scripts, make, CI) can detect the misuse; a bare sys.exit()
    # would report success.
    sys.exit("usage: \n\tpython xuanwu.py thrift_file_path output_folder_path")
# Go package name; main() overwrites it with the loader's namespace.
namespace = ""
thrift_file = sys.argv[1]
out_path = sys.argv[2]
# Thrift file name without its final extension ("user.thrift" -> "user").
# splitext keeps the whole name when there is no extension, instead of
# collapsing it to "".
filename = path.splitext(path.basename(thrift_file))[0]
# Guarantee a trailing separator so out_path + namespace names a sub-directory.
if not out_path.endswith(path.sep):
    out_path = out_path + path.sep
# Go import path of the output folder: everything after the first "/src/"
# component (GOPATH layout). A "/" is prepended so that a relative path that
# starts with "src/" is recognised as well; absolute paths are unaffected.
src_path = "/" + out_path.replace("\\", "/")
try:
    src_path = src_path[src_path.index("/src/") + 5:].strip("/")
except ValueError:
    sys.exit("output_folder_path should contains '/src/', for xuanwu to use absolute go path import")
def fieldElem(field, key):
    """Look up an annotation value on a Thrift field.

    :param field: object exposing ``annotations``, each item having
        ``name.value`` and ``value.value``.
    :param key: annotation name to search for.
    :returns: ``value.value`` of the first annotation whose ``name.value``
        equals ``key``, or ``""`` when no annotation matches.
    """
    matching_values = (
        annotation.value.value
        for annotation in field.annotations
        if annotation.name.value == key
    )
    return next(matching_values, "")
def struct_import(obj):
    """Compute the Go import set and search/index flags for one Thrift struct.

    Mutates ``obj`` in place, setting:

    * ``imports`` -- set of Go import paths the generated file needs.
    * ``searchIndex`` / ``searchType`` -- when the first field carries a
      ``searchIndex`` annotation, the index is that value and the type is the
      lower-cased struct name; otherwise the index is the lower-cased struct
      name and the type is ``"simple"``.
    * ``stringFilterFields`` -- filter fields of type ``string`` or
      ``list<string>``.
    * ``need_mapping`` -- True when ``stringFilterFields`` is non-empty.
    * ``need_index`` -- True when any field has an ``index`` attribute.
    * ``need_searchmore`` -- True when any filter field is ``list<string>``.

    :param obj: struct object from the Thrift loader; it is read for
        ``fields``, ``filterFields``, ``search`` and ``name.value``
        (NOTE(review): schema comes from base.load_thrift -- confirm there).
    """
    imports = set(["bytes", "fmt"])
    obj.imports = imports
    if obj.search is not None:
        imports.add("github.com/mattbaird/elastigo/core")

    # The search index may be overridden by an annotation on the first field.
    # A struct without fields has nothing to read, so it falls back to the
    # defaults instead of raising IndexError.
    searchIndex = fieldElem(obj.fields[0], "searchIndex") if obj.fields else ""
    if searchIndex == "":
        obj.searchIndex = obj.name.value.lower()
        obj.searchType = "simple"
    else:
        obj.searchIndex = searchIndex
        obj.searchType = obj.name.value.lower()
    if obj.searchIndex == "flow":
        imports.add("encoding/json")
        imports.add("github.com/mattbaird/elastigo/api")

    for field in obj.fields:
        # Packages of foreign types and bound packages. Empty or missing values
        # contribute nothing; in particular a None must never enter the set.
        if field.foreign_package:
            imports.add(field.foreign_package)
        bind_package = getattr(field, "bindPackage", None)
        if bind_package:
            imports.add(bind_package)

        if hasattr(field, "rule"):
            imports.add("regexp")
        is_temporal = field.widget_type in ("date", "time", "datetime")
        if field.type in ("i32", "i64", "bool", "double") and not is_temporal:
            imports.add("strconv")
        if field.type == "string" and field.widget_type in ("richtext", "textarea", "opinion"):
            imports.add("regexp")
        if field.type == "list<string>":
            imports.add("strings")
        if is_temporal:
            imports.add("time")
        if hasattr(field, "meta"):
            imports.add("encoding/json")

    obj.stringFilterFields = [
        f for f in obj.filterFields if f.type in ("string", "list<string>")
    ]
    obj.need_mapping = bool(obj.stringFilterFields)
    obj.need_index = any(hasattr(f, "index") for f in obj.fields)
    obj.need_searchmore = any(f.type == "list<string>" for f in obj.filterFields)
    if obj.need_mapping or obj.searchIndex == "flow":
        imports.add("github.com/mattbaird/elastigo/indices")
def write_file(fname, content):
    """Write ``content`` to ``fname``, creating missing parent directories.

    The file is opened in text mode ``"w"``, so any existing content is
    replaced. A bare filename (no directory part) is written relative to the
    current working directory.

    :param fname: destination file path.
    :param content: string to write.
    """
    directory = path.dirname(fname)
    # dirname() is "" for a bare filename, and os.makedirs("") would raise.
    if directory and not path.isdir(directory):
        try:
            os.makedirs(directory)
        except OSError:
            # Another process may have created it between the check and the
            # makedirs() call; only a directory that still is missing is an error.
            if not path.isdir(directory):
                raise
    with open(fname, "w") as f:
        f.write(content)
def transform_struct(obj):
    """Generate the Go model file and its web file for one Thrift struct.

    Calls ``struct_import(obj)`` first so ``obj`` carries the attributes the
    templates read. Then it renders ``tmpl/go.tmpl`` to
    ``<out_path><namespace>/gen_<name>.go`` and ``tmpl/model_web.tmpl`` to
    ``<out_path><namespace>/web/gen_<name>.go``, where ``<name>`` is the
    lower-cased struct name. Template paths are relative to the current
    working directory. Templates are decoded as UTF-8, and the template
    context is ``namespace``, ``filename`` and ``obj``.
    """
    struct_import(obj)
    stem = obj.name.value.lower()

    def render(tpl_path):
        # Read through ``with`` so the template file handle is closed at once
        # instead of being left for the garbage collector.
        with open(tpl_path, 'r') as fp:
            source = fp.read().decode("utf8")
        context = {"namespace": namespace, "filename": filename, "obj": obj}
        return str(Template(source, searchList=[context]))

    write_file(out_path + namespace + '/gen_' + stem + ".go",
               render('tmpl/go.tmpl'))
    write_file(out_path + namespace + '/web/gen_' + stem + ".go",
               render('tmpl/model_web.tmpl'))
def transform(module):
    """Generate Go code for every struct, const and enum of a loaded module.

    Each entry of ``module.structs`` goes through ``transform_struct``. When
    ``module.consts`` is non-empty it is rendered with ``tmpl/go_const.tmpl``
    into ``<out_path><namespace>/gen_<namespace>_const.go``. When
    ``module.enums`` is non-empty it is rendered with ``tmpl/go_enum.tmpl``
    (with ``name`` set to the Thrift file name) into
    ``gen_<namespace>_enum.go`` in the same directory.
    """
    for struct in module.structs:
        transform_struct(struct)

    def render(tpl_path, context):
        # ``with`` closes the template file handle promptly.
        with open(tpl_path, 'r') as fp:
            source = fp.read()
        return str(Template(source, searchList=[context]))

    # The template is rendered before the output file is opened, so a template
    # error no longer leaves a truncated file behind. write_file also creates
    # the package directory (including missing parents) when needed.
    if module.consts:
        code = render('tmpl/go_const.tmpl',
                      {"namespace": namespace, "objs": module.consts})
        write_file(out_path + "%s/gen_%s_const.go" % (namespace, namespace), code)
    if module.enums:
        code = render('tmpl/go_enum.tmpl', {
            "namespace": namespace,
            "objs": module.enums,
            "name": filename,
        })
        write_file(out_path + "%s/gen_%s_enum.go" % (namespace, namespace), code)
def main(thrift_idl):
    """Load ``thrift_idl`` and generate the Go package for it.

    Sets the module-level ``namespace`` from the loader. It renders
    ``tmpl/go_package.tmpl`` into ``<out_path><namespace>/gen_init.go``,
    then runs ``transform`` on every loaded module.

    :param thrift_idl: path of the Thrift IDL file, handed to
        ``base.load_thrift``.
    """
    global namespace
    loader = base.load_thrift(thrift_idl)
    namespace = loader.namespace
    # ``with`` closes the template file handle instead of leaking it.
    with open('tmpl/go_package.tmpl', 'r') as fp:
        tpl = fp.read()
    code = unicode(Template(tpl, searchList=[{"namespace": namespace}]))
    # write_file uses os.makedirs, so an output folder whose parents do not
    # exist yet is created too (plain mkdir only created the last component).
    write_file(out_path + namespace + '/gen_init.go', code)
    for module in loader.modules.values():
        transform(module)


main(thrift_file)