Explorar o código

Implement spellchecking

Bertrand Chenal hai 7 anos
pai
achega
6d6f24ee1c
Modificáronse 2 ficheiros con 93 adicións e 23 borrados
  1. 1 1
      README.md
  2. 92 22
      baker.py

+ 1 - 1
README.md

@@ -104,7 +104,7 @@ WHAT?
 
 # Roadmap
 
-- Document: assert, send and auth
+- Document: assert, send dry-run and auth
 - Add tests
 - Implement: import of other config file, add a contrib directory with
   ready-made tasks for common operations

+ 92 - 22
baker.py

@@ -1,7 +1,7 @@
 from getpass import getpass
 from hashlib import md5
 from itertools import chain
-from collections import ChainMap, OrderedDict
+from collections import ChainMap, OrderedDict, defaultdict
 import argparse
 import logging
 import os
@@ -65,6 +65,43 @@ def yaml_load(stream):
     return yaml.load(stream, OrderedLoader)
 
 
+def edits(word):
+    yield word
+    splits = ((word[:i], word[i:]) for i in range(len(word) + 1))
+    for left, right in splits:
+        if right:
+            yield left + right[1:]
+
+def gen_candidates(wordlist):
+    candidates = defaultdict(set)
+    for word in wordlist:
+        for ed1 in edits(word):
+            for ed2 in edits(ed1):
+                candidates[ed2].add(word)
+    return candidates
+
+def spell(candidates,  word):
+    matches = set(chain.from_iterable(
+        candidates[ed] for ed in edits(word) if ed in candidates
+    ))
+    return matches
+
+def spellcheck(objdict, word):
+    if word in objdict:
+        return
+
+    candidates = objdict.get('_candidates')
+    if not candidates:
+        candidates = gen_candidates(list(objdict))
+        objdict['_candidates'] = candidates
+
+    msg = '"%s" not found in %s' % (word, objdict._path)
+    matches = spell(candidates, word)
+    if matches:
+        msg += ', try: %s' % ' or '.join(matches)
+    logger.error(msg)
+    exit(1)
+
 class ObjectDict(dict):
     """
     Simple objet sub-class that allows to transform a dict into an
@@ -73,18 +110,24 @@ class ObjectDict(dict):
     def __getattr__(self, key):
         if key in self:
             return self[key]
-        return None
+        else:
+            return None
 
     def __setattr__(self, key, value):
         self[key] = value
 
+    def __iter__(self):
+        for key in self.keys():
+            if key.startswith('_'):
+                continue
+            yield key
 
 class Node:
 
     @staticmethod
     def fail(path, kind):
         msg = 'Error while parsing config: expecting "%s" while parsing "%s"'
-        logger.error(msg % (kind, ' -> '.join(path)))
+        logger.error(msg % (kind, '->'.join(path)))
         sys.exit()
 
     @classmethod
@@ -97,17 +140,29 @@ class Node:
             if not isinstance(cfg, dict):
                 cls.fail(path, type_name)
             res = ObjectDict()
-            for name, child_class in children.items():
-                if name == '*':
-                    continue
-                if name not in cfg:
-                    continue
-                res[name] = child_class.parse(cfg.pop(name), path + (name,))
 
             if '*' in children:
+                assert len(children) == 1, "Don't mix '*' and other keys"
                 child_class = children['*']
                 for name, value in cfg.items():
                     res[name] = child_class.parse(value, path + (name,))
+            else:
+                # Enforce known pre-defined
+                for key in cfg:
+                    if key not in children:
+                        path = ' -> '.join(path)
+                        msg = 'Attribute "%s" not understoodin %s' % (key, path)
+                        candidates = gen_candidates(children.keys())
+                        matches = spell(candidates, key)
+                        if matches:
+                            msg += ', try: %s' % ' or '.join(matches)
+                        logger.error(msg)
+                        exit(1)
+
+                for name, child_class in children.items():
+                    if name not in cfg:
+                        continue
+                    res[name] = child_class.parse(cfg.pop(name), path + (name,))
 
         elif type_name == 'list':
             if not isinstance(cfg, list):
@@ -124,6 +179,8 @@ class Node:
 
     @classmethod
     def setup(cls, values, path):
+        if isinstance(values, dict):
+            values['_path'] = '->'.join(path)
         return values
 
 
@@ -154,7 +211,8 @@ class Network(Node):
 
 class Multi(Node):
     _children = {
-        '*': Atom,
+        'task': Atom,
+        'export': Atom,
         'env': EnvNode,
     }
 
@@ -163,7 +221,10 @@ class MultiList(Node):
 
 class Command(Node):
     _children = {
-        '*': Atom,
+        'desc': Atom,
+        'local': Atom,
+        'once': Atom,
+        'run': Atom,
         'env': EnvNode,
         'multi': MultiList,
     }
@@ -209,9 +270,14 @@ class Env(ChainMap):
         try:
             return string.format(**self)
         except KeyError as exc:
-            msg = 'Unable to format "%s" (missing: "%s")'
-            logger.error(msg % (string, exc.args[0]))
-            sys.exit()
+            msg = 'Unable to format "%s" (missing: "%s")'% (string, exc.args[0])
+            candidates = gen_candidates(self.keys())
+            key = exc.args[0]
+            matches = spell(candidates, key)
+            if msg:
+                msg += ', try: %s' % ' or '.join(matches)
+            logger.error(msg )
+            exit(1)
         except IndexError as exc:
             msg = 'Unable to format "%s", positional argument not supported'
             logger.error(msg)
@@ -268,6 +334,7 @@ def connect(host, auth, with_sudo=False):
 def run_local(task, env, cli):
     # Run local task
     cmd = env.fmt(task.local)
+    # TODO log only task_desc and let desc contains env info like {host}
     logger.info(env.fmt('RUN {task_name} locally'))
     if cli.dry_run:
         logger.info('[DRY-RUN] ' + cmd)
@@ -279,9 +346,6 @@ def run_local(task, env, cli):
 def run_remote(task, host, env, cli):
     res = None
     host = env.fmt(host)
-    env = env.new_child({
-        'host': host,
-    })
     con = connect(host, cli.cfg.auth, bool(task.sudo))
     if task.run:
         cmd = env.fmt(task.run)
@@ -345,7 +409,7 @@ def run_task(task, host, cli, parent_env=None):
     env.update({
         'task_desc': env.fmt(task.desc),
         'task_name': task.name,
-        'host': host,
+        'host': host or '',
     })
 
     if task.local:
@@ -379,9 +443,11 @@ def run_batch(task, hosts, cli, env=None):
     if task.get('multi'):
         for multi in task.multi:
             task = multi.task
+            spellcheck(cli.cfg.tasks, task)
             sub_task = cli.cfg.tasks[task]
             network = multi.get('network')
             if network:
+                spellcheck(cli.cfg.networks, network)
                 hosts = cli.cfg.networks[network].hosts
             child_env = multi.get('env', {}).copy()
             for k, v in child_env.items():
@@ -432,8 +498,8 @@ def load(path, prefix=None):
     cfg = ConfigRoot.parse(cfg)
 
     # Define useful defaults
-    cfg.networks = cfg.networks or {}
-    cfg.tasks = cfg.tasks or {}
+    cfg.networks = cfg.networks or ObjectDict()
+    cfg.tasks = cfg.tasks or ObjectDict()
 
     if prefix:
         fn = lambda x: '/'.join(prefix + [x])
@@ -442,7 +508,7 @@ def load(path, prefix=None):
             if not cfg.get(section):
                 continue
             items = cfg[section].items()
-            cfg[section] = {fn(k): v for k, v in items}
+            cfg[section] = {fn(k): v for k, v in items if not k.startswith('_')}
 
     # Recursive load
     if cfg.load:
@@ -489,7 +555,11 @@ def main():
             task = cfg.tasks[name]
             tasks.append(task)
         else:
-            logger.error('Name "%s" not understood' % name)
+            msg = 'Name "%s" not understood' % name
+            matches = spell(cfg.networks, name) | spell(cfg.tasks, name)
+            if matches:
+                msg += ', try: %s' % ' or '.join(matches)
+            logger.error(msg)
             sys.exit()
 
     for custom_task in cli.run: