Quellcode durchsuchen

Make sure we eval env at every level

Bertrand Chenal vor 7 Jahren
Ursprung
Commit
23c7572b90
1 geänderte Dateien mit 34 neuen und 17 gelöschten Zeilen
  1. 34 17
      baker.py

+ 34 - 17
baker.py

@@ -314,7 +314,18 @@ class Env(ChainMap):
     def __init__(self, *dicts):
         return super().__init__(*filter(lambda x: x is not None, dicts))
 
-    def fmt(self, string):
+    def fmt_env(self, child_env):
+        if not isinstance(child_env, str):
+            new_env = {}
+            for key, val in child_env.items():
+                # env wrap-around!
+                new_val = self.fmt(val)
+                if new_val == val:
+                    continue
+                new_env[key] = new_val
+            return Env(new_env, child_env)
+
+    def fmt_string(self, string):
         try:
             return string.format(**self)
         except KeyError as exc:
@@ -329,6 +340,11 @@ class Env(ChainMap):
             msg = 'Unable to format "%s", positional argument not supported'
             raise FmtException(msg)
 
+    def fmt(self, what):
+        if isinstance(what, str):
+            return self.fmt_string(what)
+        return self.fmt_env(what)
+
 
 def get_passphrase(key_path):
     service = 'SSH private key'
@@ -569,14 +585,10 @@ def run_task(task, host, cli, parent_env=None):
     # Prepare environment
     env = Env(
         {},
-        # Env from parent task
-        parent_env,
         # Env on the task itself
         task.get('env'),
-        # Top-level env
-        cli.cfg.get('env'),
-        # OS env
-        os.environ,
+        # Env from parent task
+        parent_env,
     ).new_child()
 
     env.update({
@@ -606,13 +618,15 @@ def run_task(task, host, cli, parent_env=None):
     return res
 
 
-def run_batch(task, hosts, cli, env=None):
+
+def run_batch(task, hosts, cli, global_env=None):
     '''
     Run one task on a list of hosts
     '''
     out = None
     export_env = {}
-    env = Env(export_env, task.get('env'), env)
+    task_env = global_env.fmt(task.get('env', {}))
+    parent_env = Env(export_env, task_env, global_env)
     if task.get('multi'):
         parent_sudo = task.sudo
         for multi in task.multi:
@@ -632,11 +646,9 @@ def run_batch(task, hosts, cli, env=None):
             if network:
                 spellcheck(cli.cfg.networks, network)
                 hosts = cli.cfg.networks[network].hosts
-            child_env = Env({}, multi.get('env', {}), env)
-            for k, v in child_env.items():
-                # env wrap-around!
-                child_env[k] = child_env.fmt(child_env[k])
-            out = run_batch(sub_task, hosts, cli, child_env)
+            child_env = multi.get('env', {})
+            child_env = parent_env.fmt(child_env)
+            out = run_batch(sub_task, hosts, cli, Env(child_env, parent_env))
             out = out.decode() if isinstance(out, bytes) else out
             export_env['_'] = out
             if multi.export:
@@ -645,10 +657,10 @@ def run_batch(task, hosts, cli, env=None):
     else:
         res = None
         if task.once and (task.local or task.python):
-            res = run_task(task, None, cli, env)
+            res = run_task(task, None, cli, parent_env)
         else:
             for host in hosts:
-                res = run_task(task, host, cli, env)
+                res = run_task(task, host, cli, parent_env)
                 if task.once:
                     break
         out = res and res.stdout.strip() or ''
@@ -791,8 +803,13 @@ def main():
         log_handler.setLevel(level)
         logger.setLevel(level)
 
+        base_env = Env(
+            cli.env, # Highest-priority
+            cli.cfg.get('env'),
+            os.environ, # Lowest
+        )
         for task in cli.tasks:
-            run_batch(task, cli.hosts, cli, cli.env)
+            run_batch(task, cli.hosts, cli, base_env)
     except BakerException as e:
         if cli and cli.verbose > 2:
             raise