generic_scatter_one_uow.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. """
  2. This must not run in a tmpdir. The 'inputs' paths will
  3. end up relative to the rundir.
  4. """
  5. import argparse
  6. import collections
  7. import glob
  8. import logging
  9. import os
  10. import sys
  11. from .. import io
  12. LOG = logging.getLogger()
  13. # Here is some stuff basically copied from pypeflow.sample_tasks.py.
  14. def validate(bash_template, inputs, outputs, parameterss):
  15. LOG.info('bash_script_from_template({}\n\tinputs={!r},\n\toutputs={!r})'.format(
  16. bash_template, inputs, outputs))
  17. def validate_dict(mydict):
  18. "Python identifiers are illegal as keys."
  19. try:
  20. collections.namedtuple('validate', list(mydict.keys()))
  21. except ValueError as exc:
  22. LOG.exception('Bad key name in task definition dict {!r}'.format(mydict))
  23. raise
  24. validate_dict(inputs)
  25. validate_dict(outputs)
  26. validate_dict(parameterss)
  27. def run(all_uow_list_fn, split_idx, one_uow_list_fn):
  28. all_uows = io.deserialize(all_uow_list_fn)
  29. all_dn = os.path.abspath(os.path.dirname(all_uow_list_fn))
  30. one_dn = os.path.abspath(os.path.dirname(one_uow_list_fn))
  31. rel_dn = os.path.relpath(all_dn, one_dn)
  32. one_uow = all_uows[split_idx]
  33. def fixpath(rel):
  34. try:
  35. if not os.path.isabs(rel):
  36. return os.path.join('.', os.path.normpath(os.path.join(rel_dn, rel)))
  37. except Exception:
  38. # in case of non-string?
  39. pass
  40. return rel
  41. if isinstance(one_uow, dict):
  42. input_dict = one_uow['input']
  43. for k, v in list(input_dict.items()):
  44. input_dict[k] = fixpath(v)
  45. io.serialize(one_uow_list_fn, [one_uow])
  46. class HelpF(argparse.RawTextHelpFormatter, argparse.ArgumentDefaultsHelpFormatter):
  47. pass
  48. def parse_args(argv):
  49. description = 'Scatter a single unit-of-work from many units-of-work.'
  50. epilog = ''
  51. parser = argparse.ArgumentParser(
  52. description=description,
  53. epilog=epilog,
  54. formatter_class=HelpF,
  55. )
  56. parser.add_argument(
  57. '--all-uow-list-fn',
  58. help='Input. JSON list of all units of work.')
  59. parser.add_argument(
  60. '--split-idx', type=int,
  61. help='Input. Index into the all-uow-list for our single unit-of-work.')
  62. parser.add_argument(
  63. '--one-uow-list-fn',
  64. help='Output. JSON list of a single unit-of-work.')
  65. args = parser.parse_args(argv[1:])
  66. return args
  67. def main(argv=sys.argv):
  68. args = parse_args(argv)
  69. logging.basicConfig(level=logging.INFO)
  70. run(**vars(args))
  71. if __name__ == '__main__': # pragma: no cover
  72. main()