baker.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468
  1. from getpass import getpass
  2. from hashlib import md5
  3. from itertools import chain
  4. from collections import ChainMap, OrderedDict
  5. import argparse
  6. import logging
  7. import os
  8. import posixpath
  9. import sys
  10. from fabric import Connection, Config
  11. from invoke import run
  12. from tanker import yaml_load
  13. import invoke
  14. import yaml
  15. try:
  16. import keyring
  17. except ImportError:
  18. keyring = None
  19. __version__ = '0.0'
  20. fmt = '%(levelname)s:%(asctime).19s: %(message)s'
  21. logging.basicConfig(format=fmt)
  22. logger = logging.getLogger('baker')
  23. logger.setLevel(logging.INFO)
  24. try:
  25. import colorama
  26. colorama.init()
  27. MAGENTA = colorama.Fore.MAGENTA
  28. RED = colorama.Fore.RED
  29. RESET = colorama.Style.RESET_ALL
  30. # We define custom handler ..
  31. class Handler(logging.StreamHandler):
  32. def format(self, record):
  33. if record.levelname == 'INFO':
  34. record.msg = MAGENTA + record.msg + RESET
  35. elif record.levelname in ('WARNING', 'ERROR', 'CRITICAL'):
  36. record.msg = RED + record.msg + RESET
  37. return super(Handler, self).format(record)
  38. # .. and plug it
  39. handler = Handler()
  40. handler.setFormatter(logging.Formatter(fmt))
  41. logger.addHandler(handler)
  42. logger.propagate = 0
  43. except ImportError:
  44. pass
  45. def yaml_load(stream):
  46. class OrderedLoader(yaml.Loader):
  47. pass
  48. def construct_mapping(loader, node):
  49. loader.flatten_mapping(node)
  50. return OrderedDict(loader.construct_pairs(node))
  51. OrderedLoader.add_constructor(
  52. yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
  53. construct_mapping)
  54. return yaml.load(stream, OrderedLoader)
  55. class ObjectDict(dict):
  56. """
  57. Simple objet sub-class that allows to transform a dict into an
  58. object, like: `ObjectDict({'ham': 'spam'}).ham == 'spam'`
  59. """
  60. def __getattr__(self, key):
  61. if key in self:
  62. return self[key]
  63. return None
  64. def __setattr__(self, key, value):
  65. self[key] = value
  66. class Node:
  67. @staticmethod
  68. def fail(path, kind):
  69. msg = 'Error while parsing config: expecting "%s" while parsing "%s"'
  70. logger.error(msg % (kind, ' -> '.join(path)))
  71. sys.exit()
  72. @classmethod
  73. def parse(cls, cfg, path=tuple()):
  74. children = getattr(cls, '_children', None)
  75. type_name = children and type(children).__name__ \
  76. or ' or '.join((c.__name__ for c in cls._type))
  77. res = None
  78. if type_name == 'dict':
  79. if not isinstance(cfg, dict):
  80. cls.fail(path, type_name)
  81. res = ObjectDict()
  82. for name, child_class in children.items():
  83. if name == '*':
  84. continue
  85. if name not in cfg:
  86. continue
  87. res[name] = child_class.parse(cfg.pop(name), path + (name,))
  88. if '*' in children:
  89. child_class = children['*']
  90. for name, value in cfg.items():
  91. res[name] = child_class.parse(value, path + (name,))
  92. elif type_name == 'list':
  93. if not isinstance(cfg, list):
  94. cls.fail(path, type_name)
  95. child_class = children[0]
  96. res = [child_class.parse(c, path+ ('[]',)) for c in cfg]
  97. else:
  98. if not isinstance(cfg, cls._type):
  99. cls.fail(path, type_name)
  100. res = cfg
  101. return cls.setup(res, path)
  102. @classmethod
  103. def setup(cls, values, path):
  104. return values
  105. class Atom(Node):
  106. _type = (str, bool)
  107. class AtomList(Node):
  108. _children = [Atom]
  109. class Hosts(Node):
  110. _children = [Atom]
  111. class Auth(Node):
  112. _children = {'*': Atom}
  113. class EnvNode(Node):
  114. _children = {'*': Atom}
  115. class HostGroup(Node):
  116. _children = {
  117. 'hosts': Hosts,
  118. }
  119. class Network(Node):
  120. _children = {
  121. '*': HostGroup,
  122. }
  123. class Multi(Node):
  124. _children = {
  125. '*': Atom,
  126. 'env': EnvNode,
  127. }
  128. class MultiList(Node):
  129. _children = [Multi]
  130. class Command(Node):
  131. _children = {
  132. '*': Atom,
  133. 'env': EnvNode,
  134. 'multi': MultiList,
  135. }
  136. @classmethod
  137. def setup(cls, values, path):
  138. if path:
  139. values['name'] = path[-1]
  140. return values
  141. class Task(Node):
  142. _children = {
  143. '*': Command,
  144. }
  145. class ConfigRoot(Node):
  146. _children = {
  147. 'networks': Network,
  148. 'tasks': Task,
  149. 'auth': Auth,
  150. 'env': EnvNode,
  151. # 'load': ? -> todo allows to load other files and merge them
  152. }
  153. class Env(ChainMap):
  154. def __init__(self, *dicts):
  155. return super().__init__(*filter(bool, dicts))
  156. def fmt(self, string):
  157. try:
  158. return string.format(**self)
  159. except KeyError as exc:
  160. msg = 'Unable to format "%s" (missing: "%s")'
  161. logger.error(msg % (string, exc.args[0]))
  162. sys.exit()
  163. except IndexError as exc:
  164. msg = 'Unable to format "%s", positional argument not supported'
  165. logger.error(msg)
  166. sys.exit()
  167. def get_passphrase(key_path):
  168. service = 'SSH private key'
  169. csum = md5(open(key_path, 'rb').read()).digest().hex()
  170. ssh_pass = keyring.get_password(service, csum)
  171. if not ssh_pass:
  172. ssh_pass = getpass('Password for %s: ' % key_path)
  173. keyring.set_password(service, csum, ssh_pass)
  174. return ssh_pass
  175. def get_sudo_passwd():
  176. service = "Sudo password"
  177. passwd = keyring.get_password(service, '-')
  178. if not passwd:
  179. passwd = getpass('Sudo password:')
  180. keyring.set_password(service, '-', passwd)
  181. return passwd
  182. CONNECTION_CACHE = {}
  183. def connect(host, auth, with_sudo=False):
  184. if (host, with_sudo) in CONNECTION_CACHE:
  185. return CONNECTION_CACHE[host, with_sudo]
  186. connect_kwargs = {}
  187. if auth and auth.get('ssh_private_key'):
  188. connect_kwargs['key_filename'] = auth.ssh_private_key
  189. if not os.path.exists(auth.ssh_private_key):
  190. msg = 'Private key file "%s" not found' % auth.ssh_private_key
  191. logger.error(msg)
  192. sys.exit()
  193. ssh_pass = get_passphrase(auth.ssh_private_key)
  194. connect_kwargs['password'] = ssh_pass
  195. if with_sudo:
  196. config = Config(overrides={
  197. 'sudo': {
  198. 'password': get_sudo_passwd()
  199. }
  200. })
  201. else:
  202. config = None
  203. con = Connection(host, config=config, connect_kwargs=connect_kwargs)
  204. CONNECTION_CACHE[host, with_sudo] = con
  205. return con
  206. def run_local(task, env, cli):
  207. # Run local task
  208. cmd = env.fmt(task.local)
  209. logger.info(env.fmt('RUN {task_name} locally'))
  210. if cli.dry_run:
  211. logger.info('[DRY-RUN] ' + cmd)
  212. return None
  213. res = run(cmd, env=env)
  214. return res
  215. def run_remote(task, host, env, cli):
  216. res = None
  217. host = env.fmt(host)
  218. env = env.new_child({
  219. 'host': host,
  220. })
  221. con = connect(host, cli.cfg.auth, bool(task.sudo))
  222. if task.run:
  223. cmd = env.fmt(task.run)
  224. logger.info(env.fmt('RUN {task_name} ON {host}'))
  225. if cli.dry_run:
  226. logger.info('[DRY-RUN] ' + cmd)
  227. else:
  228. res = con.run(cmd, pty=True, env=env)
  229. elif task.sudo:
  230. cmd = env.fmt(task.sudo)
  231. logger.info(env.fmt('SUDO {task_name} ON {host}'))
  232. if cli.dry_run:
  233. logger.info('[DRY-RUN] %s' + cmd)
  234. else:
  235. res = con.sudo(cmd)
  236. elif task.send:
  237. local_path = env.fmt(task.send)
  238. remote_path = env.fmt(task.to)
  239. logger.info(f'SEND {local_path} TO {host}:{remote_path}')
  240. if cli.dry_run:
  241. logger.info('[DRY-RUN]')
  242. elif os.path.isfile(local_path):
  243. con.put(local_path, remote=remote_path)
  244. else:
  245. for root, subdirs, files in os.walk(local_path):
  246. rel_dir = os.path.relpath(root, local_path)
  247. rem_dir = posixpath.join(remote_path, rel_dir)
  248. con.run('mkdir -p {}'.format(rem_dir))
  249. for f in files:
  250. rel_f = os.path.join(root, f)
  251. rem_file = posixpath.join(rem_dir, f)
  252. con.put(os.path.abspath(rel_f), remote=rem_file)
  253. else:
  254. logger.error('Unable to run task "%s"' % task.name)
  255. sys.exit()
  256. return res
  257. def run_task(task, host, cli, parent_env=None):
  258. '''
  259. Execute one task on one host (or locally)
  260. '''
  261. # Prepare environment
  262. env = Env(
  263. # Cli is top priority
  264. dict(e.split('=') for e in cli.env),
  265. # Then comes env from parent task
  266. parent_env,
  267. # Env on the task itself
  268. task.get('env'),
  269. # Top-level env
  270. cli.cfg.get('env'),
  271. # OS env
  272. os.environ,
  273. ).new_child()
  274. env.update({
  275. 'task_desc': env.fmt(task.desc),
  276. 'task_name': task.name,
  277. })
  278. if task.local:
  279. res = run_local(task, env, cli)
  280. else:
  281. res = run_remote(task, host, env, cli)
  282. if task.get('assert'):
  283. env.update({
  284. 'stdout': res.stdout,
  285. 'stderr': res.stderr,
  286. })
  287. assert_ = env.fmt(task['assert'])
  288. ok = eval(assert_, dict(env))
  289. if ok:
  290. logger.info('Assert ok')
  291. else:
  292. logger.error('Assert "%s" failed!' % assert_)
  293. sys.exit()
  294. return res
  295. def run_batch(task, hosts, cli, env=None):
  296. '''
  297. Run one task on a list of hosts
  298. '''
  299. env = Env(task.get('env'), env)
  300. res = None
  301. export_env = {}
  302. if task.get('multi'):
  303. for multi in task.multi:
  304. task = multi.task
  305. sub_task = cli.cfg.tasks[task]
  306. network = multi.get('network')
  307. if network:
  308. hosts = cli.cfg.networks[network].hosts
  309. child_env = multi.get('env', {}).copy()
  310. for k, v in child_env.items():
  311. # env wrap-around!
  312. child_env[k] = env.fmt(child_env[k])
  313. run_env = Env(export_env, child_env, env)
  314. res = run_batch(sub_task, hosts, cli, run_env)
  315. if multi.export:
  316. export_env[multi.export] = res and res.stdout.strip() or ''
  317. else:
  318. if task.once and task.local:
  319. res = run_task(task, None, cli, env)
  320. return res
  321. for host in hosts:
  322. res = run_task(task, host, cli, env)
  323. if task.once:
  324. break
  325. return res
  326. def base_cli(args=None):
  327. parser = argparse.ArgumentParser()
  328. parser.add_argument('names', nargs='*',
  329. help='Hosts and commands to run them on')
  330. parser.add_argument('-c', '--config', default='bk.yaml',
  331. help='Config file')
  332. parser.add_argument('-r', '--run', nargs='*', default=[],
  333. help='Run custom task')
  334. parser.add_argument('-d', '--dry-run', action='store_true',
  335. help='Do not run actual tasks, just print them')
  336. parser.add_argument('-e', '--env', nargs='*', default=[],
  337. help='Add value to execution environment '
  338. '(ex: -e foo=bar "name=John Doe")')
  339. parser.add_argument('-s', '--sudo', default='auto',
  340. help='Enable sudo (auto|yes|no')
  341. parser.add_argument('-v', '--verbose', action='count',
  342. default=0, help='Increase verbosity')
  343. cli = parser.parse_args(args=args)
  344. return ObjectDict(vars(cli))
  345. def main():
  346. cli = base_cli()
  347. if cli.verbose:
  348. level = 'INFO' if cli.verbose == 1 else 'DEBUG'
  349. logger.setLevel(level)
  350. logger.info('Log level set to: %s' % level)
  351. # Load config
  352. logger.info('Load config %s' % cli.config)
  353. cfg = yaml_load(open(cli.config))
  354. cfg = ConfigRoot.parse(cfg)
  355. cli.cfg = cfg
  356. # Define useful defaults
  357. cfg.networks = cfg.networks or {}
  358. cfg.tasks = cfg.tasks or {}
  359. # Make sure we don't have overlap between hosts and tasks
  360. items = list(cfg.networks) + list(cfg.tasks)
  361. msg = 'Name collision between tasks and networks'
  362. assert len(set(items)) == len(items), msg
  363. # Build task list
  364. tasks = []
  365. networks = []
  366. for name in cli.names:
  367. if name in cfg.networks:
  368. host = cfg.networks[name]
  369. networks.append(host)
  370. elif name in cfg.tasks:
  371. task = cfg.tasks[name]
  372. tasks.append(task)
  373. else:
  374. logger.error('Name "%s" not understood' % name)
  375. sys.exit()
  376. for custom_task in cli.run:
  377. task = Command.parse(yaml_load(custom_task))
  378. task.desc = 'Custom command'
  379. tasks.append(task)
  380. hosts = list(chain.from_iterable(n.hosts for n in networks))
  381. try:
  382. for task in tasks:
  383. run_batch(task, hosts, cli)
  384. except invoke.exceptions.Failure as e:
  385. logger.error(str(e))
  386. if __name__ == '__main__':
  387. main()