generic_scatter_uows.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import argparse
  2. import collections
  3. import glob
  4. import logging
  5. import os
  6. import sys
  7. from .. import io
  8. LOG = logging.getLogger()
  9. def yield_uows(n, all_uows):
  10. uows_per_chunk = (len(all_uows) + n - 1) / n
  11. for uow in all_uows:
  12. yield [uow]
  13. def run(all_uow_list_fn, pattern, nchunks_max):
  14. all_uows = io.deserialize(all_uow_list_fn)
  15. n = min(nchunks_max, len(all_uows))
  16. LOG.info('Num chunks = {} (<= {})'.format(n, nchunks_max))
  17. all_dn = os.path.abspath(os.path.dirname(all_uow_list_fn))
  18. for i, uows in enumerate(yield_uows(n, all_uows)):
  19. key = '{:02d}'.format(i)
  20. fn = pattern.replace('%', key)
  21. LOG.info('Writing {} units-of-work to "{}" ({}).'.format(len(uows), fn, key))
  22. one_dn = os.path.abspath(os.path.dirname(fn))
  23. rel_dn = os.path.relpath(all_dn, one_dn)
  24. def fixpath(rel):
  25. try:
  26. if not os.path.isabs(rel):
  27. return os.path.join('.', os.path.normpath(os.path.join(rel_dn, rel)))
  28. except Exception:
  29. # in case of non-string?
  30. pass
  31. return rel
  32. for one_uow in uows:
  33. if isinstance(one_uow, dict):
  34. input_dict = one_uow['input']
  35. for k, v in list(input_dict.items()):
  36. input_dict[k] = fixpath(v)
  37. io.serialize(fn, uows)
  38. class HelpF(argparse.RawTextHelpFormatter, argparse.ArgumentDefaultsHelpFormatter):
  39. pass
  40. def parse_args(argv):
  41. description = 'Scatter a single unit-of-work from many units-of-work.'
  42. epilog = ''
  43. parser = argparse.ArgumentParser(
  44. description=description,
  45. epilog=epilog,
  46. formatter_class=HelpF,
  47. )
  48. parser.add_argument(
  49. '--all-uow-list-fn',
  50. help='Input. JSON list of all units of work.')
  51. parser.add_argument(
  52. '--nchunks-max', type=int,
  53. help='Input. Maximum number of output files.')
  54. parser.add_argument(
  55. '--pattern',
  56. help='Output. The "%" will be replaced by a zero-padded number. (Probably should be ".json")')
  57. args = parser.parse_args(argv[1:])
  58. return args
  59. def main(argv=sys.argv):
  60. args = parse_args(argv)
  61. logging.basicConfig(level=logging.INFO)
  62. run(**vars(args))
  63. if __name__ == '__main__': # pragma: no cover
  64. main()