|
|
@@ -314,7 +314,18 @@ class Env(ChainMap):
|
|
|
def __init__(self, *dicts):
|
|
|
return super().__init__(*filter(lambda x: x is not None, dicts))
|
|
|
|
|
|
- def fmt(self, string):
|
|
|
+ def fmt_env(self, child_env):
|
|
|
+ if not isinstance(child_env, str):
|
|
|
+ new_env = {}
|
|
|
+ for key, val in child_env.items():
|
|
|
+ # env wrap-around!
|
|
|
+ new_val = self.fmt(val)
|
|
|
+ if new_val == val:
|
|
|
+ continue
|
|
|
+ new_env[key] = new_val
|
|
|
+ return Env(new_env, child_env)
|
|
|
+
|
|
|
+ def fmt_string(self, string):
|
|
|
try:
|
|
|
return string.format(**self)
|
|
|
except KeyError as exc:
|
|
|
@@ -329,6 +340,11 @@ class Env(ChainMap):
|
|
|
msg = 'Unable to format "%s", positional argument not supported'
|
|
|
raise FmtException(msg)
|
|
|
|
|
|
+ def fmt(self, what):
|
|
|
+ if isinstance(what, str):
|
|
|
+ return self.fmt_string(what)
|
|
|
+ return self.fmt_env(what)
|
|
|
+
|
|
|
|
|
|
def get_passphrase(key_path):
|
|
|
service = 'SSH private key'
|
|
|
@@ -569,14 +585,10 @@ def run_task(task, host, cli, parent_env=None):
|
|
|
# Prepare environment
|
|
|
env = Env(
|
|
|
{},
|
|
|
- # Env from parent task
|
|
|
- parent_env,
|
|
|
# Env on the task itself
|
|
|
task.get('env'),
|
|
|
- # Top-level env
|
|
|
- cli.cfg.get('env'),
|
|
|
- # OS env
|
|
|
- os.environ,
|
|
|
+ # Env from parent task
|
|
|
+ parent_env,
|
|
|
).new_child()
|
|
|
|
|
|
env.update({
|
|
|
@@ -606,13 +618,15 @@ def run_task(task, host, cli, parent_env=None):
|
|
|
return res
|
|
|
|
|
|
|
|
|
-def run_batch(task, hosts, cli, env=None):
|
|
|
+
|
|
|
+def run_batch(task, hosts, cli, global_env=None):
|
|
|
'''
|
|
|
Run one task on a list of hosts
|
|
|
'''
|
|
|
out = None
|
|
|
export_env = {}
|
|
|
- env = Env(export_env, task.get('env'), env)
|
|
|
+ task_env = global_env.fmt(task.get('env', {}))
|
|
|
+ parent_env = Env(export_env, task_env, global_env)
|
|
|
if task.get('multi'):
|
|
|
parent_sudo = task.sudo
|
|
|
for multi in task.multi:
|
|
|
@@ -632,11 +646,9 @@ def run_batch(task, hosts, cli, env=None):
|
|
|
if network:
|
|
|
spellcheck(cli.cfg.networks, network)
|
|
|
hosts = cli.cfg.networks[network].hosts
|
|
|
- child_env = Env({}, multi.get('env', {}), env)
|
|
|
- for k, v in child_env.items():
|
|
|
- # env wrap-around!
|
|
|
- child_env[k] = child_env.fmt(child_env[k])
|
|
|
- out = run_batch(sub_task, hosts, cli, child_env)
|
|
|
+ child_env = multi.get('env', {})
|
|
|
+ child_env = parent_env.fmt(child_env)
|
|
|
+ out = run_batch(sub_task, hosts, cli, Env(child_env, parent_env))
|
|
|
out = out.decode() if isinstance(out, bytes) else out
|
|
|
export_env['_'] = out
|
|
|
if multi.export:
|
|
|
@@ -645,10 +657,10 @@ def run_batch(task, hosts, cli, env=None):
|
|
|
else:
|
|
|
res = None
|
|
|
if task.once and (task.local or task.python):
|
|
|
- res = run_task(task, None, cli, env)
|
|
|
+ res = run_task(task, None, cli, parent_env)
|
|
|
else:
|
|
|
for host in hosts:
|
|
|
- res = run_task(task, host, cli, env)
|
|
|
+ res = run_task(task, host, cli, parent_env)
|
|
|
if task.once:
|
|
|
break
|
|
|
out = res and res.stdout.strip() or ''
|
|
|
@@ -791,8 +803,13 @@ def main():
|
|
|
log_handler.setLevel(level)
|
|
|
logger.setLevel(level)
|
|
|
|
|
|
+ base_env = Env(
|
|
|
+ cli.env, # Highest-priority
|
|
|
+ cli.cfg.get('env'),
|
|
|
+ os.environ, # Lowest
|
|
|
+ )
|
|
|
for task in cli.tasks:
|
|
|
- run_batch(task, cli.hosts, cli, cli.env)
|
|
|
+ run_batch(task, cli.hosts, cli, base_env)
|
|
|
except BakerException as e:
|
|
|
if cli and cli.verbose > 2:
|
|
|
raise
|