Quellcode durchsuchen

Fix attribute list on Command node; better logging

Bertrand Chenal vor 7 Jahren
Ursprung
Commit
63211c4aaa
1 geänderte Dateien mit 29 neuen und 21 gelöschten Zeilen
  1. 29 21
      baker.py

+ 29 - 21
baker.py

@@ -72,6 +72,7 @@ def edits(word):
         if right:
             yield left + right[1:]
 
+
 def gen_candidates(wordlist):
     candidates = defaultdict(set)
     for word in wordlist:
@@ -80,12 +81,14 @@ def gen_candidates(wordlist):
                 candidates[ed2].add(word)
     return candidates
 
+
 def spell(candidates,  word):
     matches = set(chain.from_iterable(
         candidates[ed] for ed in edits(word) if ed in candidates
     ))
     return matches
 
+
 def spellcheck(objdict, word):
     if word in objdict:
         return
@@ -99,8 +102,8 @@ def spellcheck(objdict, word):
     matches = spell(candidates, word)
     if matches:
         msg += ', try: %s' % ' or '.join(matches)
-    logger.error(msg)
-    exit(1)
+    abort(msg)
+
 
 class ObjectDict(dict):
     """
@@ -127,8 +130,7 @@ class Node:
     @staticmethod
     def fail(path, kind):
         msg = 'Error while parsing config: expecting "%s" while parsing "%s"'
-        logger.error(msg % (kind, '->'.join(path)))
-        sys.exit()
+        abort(msg % (kind, '->'.join(path)))
 
     @classmethod
     def parse(cls, cfg, path=tuple()):
@@ -156,8 +158,7 @@ class Node:
                         matches = spell(candidates, key)
                         if matches:
                             msg += ', try: %s' % ' or '.join(matches)
-                        logger.error(msg)
-                        exit(1)
+                        abort(msg)
 
                 for name, child_class in children.items():
                     if name not in cfg:
@@ -213,6 +214,7 @@ class Multi(Node):
     _children = {
         'task': Atom,
         'export': Atom,
+        'network': Atom,
         'env': EnvNode,
     }
 
@@ -225,6 +227,9 @@ class Command(Node):
         'local': Atom,
         'once': Atom,
         'run': Atom,
+        'send': Atom,
+        'to': Atom,
+        'assert': Atom,
         'env': EnvNode,
         'multi': MultiList,
     }
@@ -276,12 +281,10 @@ class Env(ChainMap):
             matches = spell(candidates, key)
             if msg:
                 msg += ', try: %s' % ' or '.join(matches)
-            logger.error(msg )
-            exit(1)
+            abort(msg )
         except IndexError as exc:
             msg = 'Unable to format "%s", positional argument not supported'
-            logger.error(msg)
-            sys.exit()
+            abort(msg)
 
 def get_passphrase(key_path):
     service = 'SSH private key'
@@ -312,8 +315,7 @@ def connect(host, auth, with_sudo=False):
         connect_kwargs['key_filename'] = auth.ssh_private_key
         if not os.path.exists(auth.ssh_private_key):
             msg = 'Private key file "%s" not found' % auth.ssh_private_key
-            logger.error(msg)
-            sys.exit()
+            abort(msg)
         ssh_pass = get_passphrase(auth.ssh_private_key)
         connect_kwargs['password'] = ssh_pass
 
@@ -346,6 +348,9 @@ def run_local(task, env, cli):
 def run_remote(task, host, env, cli):
     res = None
     host = env.fmt(host)
+    env = env.new_child({
+        'host': host,
+    })
     con = connect(host, cli.cfg.auth, bool(task.sudo))
     if task.run:
         cmd = env.fmt(task.run)
@@ -382,8 +387,7 @@ def run_remote(task, host, env, cli):
                     rem_file = posixpath.join(rem_dir, f)
                     con.put(os.path.abspath(rel_f), remote=rem_file)
     else:
-        logger.error('Unable to run task "%s"' % task.name)
-        sys.exit()
+        abort('Unable to run task "%s"' % task.name)
 
     return res
 
@@ -427,8 +431,7 @@ def run_task(task, host, cli, parent_env=None):
         if ok:
             logger.info('Assert ok')
         else:
-            logger.error('Assert "%s" failed!' % assert_)
-            sys.exit()
+            abort('Assert "%s" failed!' % assert_)
     return res
 
 
@@ -489,13 +492,19 @@ def base_cli(args=None):
     cli = parser.parse_args(args=args)
     return ObjectDict(vars(cli))
 
+def abort(msg):
+    logger.error(msg)
+    sys.exit(1)
 
 def load(path, prefix=None):
     load_sections = ('networks', 'tasks', 'auth', 'env')
 
-    logger.info('Load config %s' % path)
-    cfg = yaml_load(open(path))
-    cfg = ConfigRoot.parse(cfg)
+    if os.path.isfile(path):
+        logger.info('Load config %s' % path)
+        cfg = yaml_load(open(path))
+        cfg = ConfigRoot.parse(cfg)
+    else:
+        abort('Config file "%s" not found' % path)
 
     # Define useful defaults
     cfg.networks = cfg.networks or ObjectDict()
@@ -559,8 +568,7 @@ def main():
             matches = spell(cfg.networks, name) | spell(cfg.tasks, name)
             if matches:
                 msg += ', try: %s' % ' or '.join(matches)
-            logger.error(msg)
-            sys.exit()
+            abort(msg)
 
     for custom_task in cli.run:
         task = Command.parse(yaml_load(custom_task))