gen-golang-code.py 7.2 KB


#coding:utf8
'''
Generate Go struct definitions for a MySQL table.

Reads the table layout with ``desc <table>;`` and prints a Go source file
containing a struct for that table -- with beego ORM tags in ORM mode, or
``db`` tags otherwise.
see doc: http://beego.me/docs/mvc/model/models.md
'''
  6. import sys
  7. import logging
  8. import time
  9. import getopt
  10. import _mysql
  11. def now():
  12. return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
  13. def escape(var):
  14. '''这里连接数据库都是使用utf8的。'''
  15. if var is None:
  16. return ''
  17. if isinstance(var, unicode):
  18. var = var.encode('utf8')
  19. if not isinstance(var, str):
  20. var = str(var)
  21. return _mysql.escape_string(var)
  22. def dbOpen(host, port, user, password, dbname):
  23. conn = _mysql.connect(
  24. db=dbname,
  25. host=host,
  26. user=user,
  27. passwd=password,
  28. port=port)
  29. conn.query("set names utf8;")
  30. return conn
  31. ROW_OF_IDX=0
  32. ROW_OF_KEY=1
  33. DEBUG=0
  34. def query(db, sql, dataType=ROW_OF_KEY):
  35. global DEBUG
  36. if DEBUG and not sql.startswith('SELECT'):
  37. print 'in debug,just print:', sql
  38. return
  39. rows = ()
  40. try:
  41. db.query(sql)
  42. res = db.store_result()
  43. if res:
  44. rows = res.fetch_row(res.num_rows(), dataType)
  45. except Exception, e:
  46. print "[%s]\t[%s]" % (e, sql)
  47. raise e
  48. return rows
  49. def execute(db, sql):
  50. if DEBUG:
  51. logging.warning("Debuging, just print SQL:%s", sql)
  52. return -110
  53. try:
  54. db.query(sql)
  55. res = db.affected_rows()
  56. if res < 0 or res == 0xFFFFFFFFFFFFFFFF:
  57. # ps : 0xFFFFFFFFFFFFFFFF (64位的-1)
  58. # 这个值与驱动、系统、硬件CPU位数都可能有关
  59. logging.error('MySQL execute error n=[%d], sql=%s', res, sql)
  60. return res
  61. except Exception, e:
  62. logging.error("err=[%s]\tsql=[%s]", e, sql)
  63. if e[0] == 1062:
  64. return 0
  65. raise e
  66. return -120
  67. def convertType(typ):
  68. typ = typ.lower()
  69. if typ.find('int') >= 0:
  70. return "int64", '%d'
  71. elif typ.find('char') >= 0 or typ.find("text") >= 0 or typ.find("enum") >= 0:
  72. return "string", "'%s'"
  73. elif typ.find("decimal") >= 0:
  74. return "float64",'%f'
  75. elif typ.find('datetime') >= 0:
  76. return "time.Time", "'%s'"
  77. elif typ.find("bool") >= 0:
  78. return "bool", "'%s'"
  79. else:
  80. return typ, "'%v'"
  81. def gen_model(host,port,user,password,dbname,table, orm):
  82. db = dbOpen(host, port, user, password, dbname)
  83. desc = query(db, "desc %s;" % escape(table), ROW_OF_KEY)
  84. indent = ' ' * 4
  85. imports =[]
  86. const = [indent +'_tablename = "%s"' % table]
  87. vars =[]
  88. struct = ["type %s struct{" % table.title()]
  89. fields = []
  90. formats = []
  91. pk = None
  92. field_define =[]
  93. field_tags = []
  94. field_comments = []
  95. for row in desc:
  96. field = row['Field']
  97. typ, fmt = convertType(row['Type'])
  98. if typ == 'time' and ('time' not in imports):
  99. imports.append("time")
  100. tags = []
  101. if not orm:
  102. tags.append('db:"%s"' % field)
  103. else:
  104. tag = ["column(%s)" % field]
  105. if row['Null'] == "YES":
  106. tag.append("null")
  107. if row['Type'].startswith('decimal'):
  108. tag.append("digits(10);decimals(2)")
  109. if row['Type'].startswith('datetime'):
  110. if field.find("created")>=0:
  111. tag.append("auto_now_add")
  112. if field.find("update")>=0:
  113. tag.append("auto_now")
  114. tag.append("type(datetime)")
  115. if row['Key'].upper().find("PRI") >= 0:
  116. tag.append('pk')
  117. tags.append('orm:"%s"' % ";".join(tag))
  118. tags.append('json:"%s,omitempty"' % field)
  119. field_define.append('%s %s' % (field.title(), typ))
  120. field_tags.append(tags)
  121. field_comments.append(row['Type'])
  122. if not row['Extra'].find("auto_increment") >= 0:
  123. fields.append("`%s`" % field)
  124. formats.append(fmt)
  125. cols_indent = {}
  126. for tags in field_tags:
  127. for i, t in enumerate(tags):
  128. cols_indent[i] = max(len(t)+1, cols_indent.get(i,0))
  129. cols_indent[len(tags)] = 0
  130. for i, tags in enumerate(field_tags):
  131. struct.append(" %s `%s` // %s" %(field_define[i],
  132. ''.join([t + (" " * (cols_indent[j] - len(t))) for j, t in enumerate(tags)]).strip(),
  133. field_comments[i]))
  134. struct.append("}")
  135. vars.append((indent + '_fiels_map = []string{%s}') % ', '.join( [f for f in fields]))
  136. if not orm:
  137. const.append((indent + '_values_fmt = "%s"') % ','.join(formats))
  138. insert = ('_INSERT = fmt.Sprintf("INSERT INTO `%s`(%s) VALUES %s", '
  139. '_tablename, strings.Join(_fiels_map,","), _values_fmt)')
  140. imports.insert(0, indent+'"strings"')
  141. vars.append(indent+insert)
  142. if pk:
  143. delete = '_DELETE = fmt.Sprintf("DELETE FROM `%s` WHERE %s" ,_tablename,"'+ pk +'")'
  144. vars.append(indent + delete)
  145. if orm:
  146. imports.append(indent + '"github.com/astaxie/beego/orm"') # for ORM http://beego.me/docs/mvc/model/orm.md
  147. else:
  148. imports.append(indent + '"my.company/lib/core/mysql"') # for ORM
  149. print "package %s\n\n" % table
  150. if imports:
  151. print "import (\n%s\n)\n\n" % '\n'.join(imports)
  152. print "const(\n%s\n)\n\n" % ('\n'.join(const))
  153. if vars:
  154. print "var(\n%s\n)\n\n" % ('\n'.join(vars))
  155. print '\n'.join(struct)
  156. if orm:
  157. print '\nfunc (self *%s) TableName() string {\n%sreturn _tablename\n}' % (table.title(), indent)
  158. print '\nfunc (self *%s) TableEngine() string {\n%sreturn "INNODB"\n}' % (table.title(), indent)
  159. print '\nfunc init() {\n // 需要在init中注册定义的model\n orm.RegisterModel(new(%s))\n}' % table.title()
  160. def main():
  161. def usage():
  162. print "--help: print this message"
  163. print "-h --host, MySQL host"
  164. print "-P --port, MySQL port"
  165. print "-u --user, MySQL user"
  166. print "-p --password, MySQL password"
  167. print "-D --database, MySQL Database"
  168. print "-t --table name"
  169. print "-o --orm, used OMR define struct."
  170. try:
  171. opts, args = getopt.getopt(sys.argv[1:], "Hh:P:u:p:D:t:ol:",
  172. ["--help","redis=","host=","port=","user=","password=",
  173. "database=",'table=',"--orm"])
  174. except getopt.GetoptError:
  175. print usage()
  176. return
  177. host = 'd5ctestingdb.mysql.rds.aliyuncs.com'
  178. user = "d5c"
  179. port = 3306
  180. password = 'D5ctesting'
  181. dbname = "test_db"
  182. table = "platform_categories"
  183. orm = True
  184. for o, a in opts:
  185. if o in ("-H","--help"):
  186. usage()
  187. sys.exit()
  188. elif o in ("-o","--orm"):
  189. orm = True
  190. elif o in ("-h","--host"):
  191. host=a
  192. elif o in ("-P","--port"):
  193. port=int(a)
  194. elif o in ("-u","--user"):
  195. user=a
  196. elif o in("-p","--password"):
  197. password=a
  198. elif o in ("-D","--database"):
  199. dbname=a
  200. elif o in ('-t', '--table'):
  201. table=a
  202. logging.info("mysql=[%s:%s@%s:%s/%s?table=%s]",
  203. user, password, host, port, dbname, table)
  204. gen_model(host,port,user,password,dbname,table,orm)
  205. if __name__ == '__main__':
  206. main()