snakemake.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. """Exact copy of falcon_unzip/tasks/snakemake.py
  2. TODO: Consolidate.
  3. """
  4. from future.utils import viewitems
  5. from future.utils import itervalues
  6. from builtins import object
  7. import json
  8. import os
  9. import re
  10. def find_wildcards(pattern):
  11. """
  12. >>> find_wildcards('{foo}/{bar}')
  13. ['bar', 'foo']
  14. """
  15. re_wildcard = re.compile(r'\{(\w+)\}')
  16. found = [mo.group(1) for mo in re_wildcard.finditer(pattern)]
  17. return list(sorted(found))
  18. class SnakemakeRuleWriter(object):
  19. def legalize(self, rule_name):
  20. return self.re_bad_char.sub('_', rule_name, count=0)
  21. def unique_rule_name(self, basename):
  22. rule_name = basename
  23. if rule_name in self.rule_names:
  24. i = 1
  25. while rule_name in self.rule_names:
  26. rule_name = basename + str(i)
  27. i += 1
  28. self.rule_names.add(rule_name)
  29. return rule_name
  30. def write_dynamic_rules(self, rule_name, input_json, inputs, shell_template,
  31. parameters, wildcard_outputs, output_json):
  32. """Lots of conventions.
  33. input_json: should have a key 'mapped_inputs', which is a map of key->filename
  34. Those filenames will be symlinked here, according to the patterns in wildcard_inputs.
  35. shell_template: for the parallel task
  36. output_json: This will contain only key->filename, based on wildcard_outputs.
  37. inputs: These include non-wildcards too.
  38. (For now, we assume inputs/outputs is just one per parallel task.)
  39. """
  40. # snakemake does not like paths starting with './'; they can lead to mismatches.
  41. # So we run normpath everywhere.
  42. input_json = os.path.normpath(input_json)
  43. output_json = os.path.normpath(output_json)
  44. # snakemake cannot use already-generated files as dynamic outputs (the wildcard_inputs for the parallel task),
  45. # so we rename them and plan to symlink.
  46. wildcard_inputs = dict(inputs)
  47. nonwildcard_inputs = dict()
  48. for (key, fn) in list(viewitems(wildcard_inputs)):
  49. if '{' not in fn:
  50. del wildcard_inputs[key]
  51. nonwildcard_inputs[key] = fn
  52. continue
  53. dn, bn = os.path.split(wildcard_inputs[key])
  54. wildcard_inputs[key] = os.path.join(dn + '.symlink', bn)
  55. rule_name = self.unique_rule_name(rule_name)
  56. dynamic_output_kvs = ', '.join("%s=dynamic('%s')"%(k, os.path.normpath(v)) for (k, v) in viewitems(wildcard_inputs))
  57. dynamic_input_kvs = ', '.join("%s=ancient(dynamic('%s'))"%(k, os.path.normpath(v)) for (k, v) in viewitems(wildcard_outputs))
  58. rule_parameters = {k: v for (k, v) in viewitems(parameters) if not k.startswith('_')}
  59. params = ','.join('\n %s="%s"'%(k,v) for (k, v) in viewitems(rule_parameters))
  60. pattern_kv_list = list()
  61. for (name, wi) in viewitems(wildcard_inputs):
  62. fn_pattern = wi
  63. fn_pattern = fn_pattern.replace('{', '{{')
  64. fn_pattern = fn_pattern.replace('}', '}}')
  65. pattern_kv_list.append('%s="%s"' %(name, fn_pattern))
  66. wi_pattern_kvs = ' '.join(pattern_kv_list)
  67. rule = """
  68. rule dynamic_%(rule_name)s_split:
  69. input: %(input_json)r
  70. output: %(dynamic_output_kvs)s
  71. shell: 'python3 -m falcon_kit.mains.copy_mapped --special-split={input} %(wi_pattern_kvs)s'
  72. """%(locals())
  73. self.write(rule)
  74. input_wildcards = set() # Not sure yet whether input must match output wildcards.
  75. for wi_fn in itervalues(wildcard_inputs):
  76. found = find_wildcards(wi_fn)
  77. input_wildcards.update(found)
  78. wildcards = list(sorted(input_wildcards))
  79. params_plus_wildcards = {k: '{%s}'%k for k in wildcards}
  80. params_plus_wildcards.update(parameters)
  81. # The parallel script uses all inputs, not just wildcards.
  82. all_inputs = dict(wildcard_inputs)
  83. all_inputs.update(nonwildcard_inputs)
  84. self.write_script_rule(all_inputs, wildcard_outputs, params_plus_wildcards, shell_template, rule_name=None)
  85. wo_str_lists_list = ['%s=[str(i) for i in input.%s]' %(name, name) for name in list(wildcard_outputs.keys())]
  86. wo_pattern_kv_list = ['%s="%s"' %(name, os.path.normpath(patt)) for (name, patt) in viewitems(wildcard_outputs)]
  87. wo_str_lists_kvs = ',\n '.join(wo_str_lists_list)
  88. wo_pattern_kvs = ',\n '.join(wo_pattern_kv_list)
  89. wildcards = list()
  90. for wi_fn in itervalues(wildcard_outputs):
  91. found = find_wildcards(wi_fn)
  92. if wildcards:
  93. assert wildcards == found, 'snakemake requires all outputs (and inputs?) to have the same wildcards'
  94. else:
  95. wildcards = found
  96. wildcards_comma_sep = ', '.join('"%s"' %k for k in wildcards)
  97. rule = '''
  98. rule dynamic_%(rule_name)s_merge:
  99. input: %(dynamic_input_kvs)s
  100. output: %(output_json)r
  101. run:
  102. snake_merge_multi_dynamic(output[0],
  103. dict(
  104. %(wo_str_lists_kvs)s
  105. ),
  106. dict(
  107. %(wo_pattern_kvs)s
  108. ),
  109. [%(wildcards_comma_sep)s] # all wildcards
  110. )
  111. '''%(locals())
  112. self.write(rule)
  113. def write_script_rule(self, inputs, outputs, parameters, shell_template, rule_name):
  114. assert '_bash_' not in parameters
  115. first_output_name, first_output_fn = list(outputs.items())[0] # for rundir, since we cannot sub wildcards in shell
  116. if not rule_name:
  117. rule_name = os.path.dirname(first_output_fn)
  118. rule_name = self.unique_rule_name(self.legalize(rule_name))
  119. wildcard_rundir = os.path.normpath(os.path.dirname(first_output_fn)) # unsubstituted
  120. # We use snake_string_path b/c normpath drops leading ./, but we do NOT want abspath.
  121. input_kvs = ', '.join('%s=%s'%(k, snake_string_path(v)) for k,v in
  122. sorted(viewitems(inputs)))
  123. output_kvs = ', '.join('%s=%s'%(k, snake_string_path(v)) for k,v in
  124. sorted(viewitems(outputs)))
  125. rule_parameters = {k: v for (k, v) in viewitems(parameters) if not k.startswith('_')}
  126. #rule_parameters['reltopdir'] = os.path.relpath('.', wildcard_rundir) # in case we need this later
  127. params = ','.join('\n %s="%s"'%(k,v) for (k, v) in viewitems(rule_parameters))
  128. shell = snake_shell(shell_template, wildcard_rundir)
  129. # cd $(dirname '{output.%(first_output_name)s}')
  130. rule = """
  131. rule static_%(rule_name)s:
  132. input: %(input_kvs)s
  133. output: %(output_kvs)s
  134. params:%(params)s
  135. shell:
  136. '''
  137. outdir=$(dirname {output[0]})
  138. #mkdir -p ${{outdir}}
  139. cd ${{outdir}}
  140. date
  141. %(shell)s
  142. date
  143. '''
  144. """%(locals())
  145. self.write(rule)
  146. def __call__(self, inputs, outputs, parameters, shell_template, rule_name=None):
  147. self.write_script_rule(inputs, outputs, parameters, shell_template, rule_name)
  148. def __init__(self, writer):
  149. self.write = writer.write
  150. self.rule_names = set() # to ensure uniqueness
  151. self.re_bad_char = re.compile(r'\W')
  152. self.write("""
  153. # THIS IS CURRENTLY BROKEN.
  154. import json
  155. import os
  156. #import snakemake.utils
  157. def snake_merge_dynamic_dict(reldir, input_fns, pattern, wildcards):
  158. '''Assume each wildcard appears at most once in the pattern.
  159. '''
  160. for k in wildcards:
  161. pattern = pattern.replace('{%s}' %k, '(?P<%s>\w+)' %k)
  162. re_dynamic = re.compile(pattern)
  163. mapped = list()
  164. for fn in input_fns:
  165. mo = re_dynamic.search(fn)
  166. assert mo, '{!r} did not match {!r}'.format(fn, re_dynamic.pattern)
  167. file_description = dict()
  168. file_description['wildcards'] = dict(mo.groupdict())
  169. file_description['fn'] = os.path.relpath(fn, reldir)
  170. mapped.append(file_description)
  171. return mapped
  172. def snake_merge_multi_dynamic(output_fn, dict_of_input_fns, dict_of_patterns, wildcards):
  173. outdir = os.path.normpath(os.path.dirname(output_fn))
  174. if not os.path.isdir(outdir):
  175. os.makedirs(outdir)
  176. assert list(sorted(dict_of_input_fns.keys())) == list(sorted(dict_of_patterns.keys()))
  177. all_mapped = dict()
  178. for i in dict_of_patterns.keys():
  179. input_fns = dict_of_input_fns[i]
  180. pattern = dict_of_patterns[i]
  181. mapped = snake_merge_dynamic_dict(outdir, input_fns, pattern, wildcards)
  182. all_mapped[i] = mapped
  183. all_grouped = dict()
  184. for i, mapped in all_mapped.items():
  185. #print(i, mapped)
  186. for file_description in mapped:
  187. #print(file_description)
  188. #print(file_description['wildcards'])
  189. #print(list(sorted(file_description['wildcards'].items())))
  190. wildkey = ','.join('{}={}'.format(k,v) for k,v in sorted(file_description['wildcards'].items()))
  191. if wildkey not in all_grouped:
  192. new_group = dict(
  193. wildcards=dict(file_description['wildcards']),
  194. fns=dict(),
  195. )
  196. all_grouped[wildkey] = new_group
  197. group = all_grouped[wildkey]
  198. wildcards = file_description['wildcards']
  199. assert wildcards == group['wildcards'], '{!r} should match {!r} by snakemake convention'.format(
  200. wildcards, group['wildcards'])
  201. fn = file_description['fn']
  202. group['fns'][i] = fn
  203. ser = json.dumps(all_grouped, indent=2, separators=(',', ': ')) + '\\n'
  204. with open(output_fn, 'w') as out:
  205. out.write(ser)
  206. """)
  207. prefix = """
  208. shell.prefix('''
  209. # Add -e vs. in falcon_unzip.
  210. set -vex
  211. hostname
  212. pwd
  213. ''')
  214. """
  215. self.write(prefix)
  216. class SnakemakeDynamic(object):
  217. """Not currently used."""
  218. def __init__(self, path):
  219. self.path = path
  220. def snake_string_path(p):
  221. """normpath drops leading ./
  222. """
  223. if isinstance(p, SnakemakeDynamic):
  224. return "dynamic('{}')".format(
  225. os.path.normpath(p.path))
  226. else:
  227. return "'{}'".format(
  228. os.path.normpath(p))
  229. def snake_shell(template, rundir):
  230. reltopdir = os.path.relpath('.', rundir)
  231. def makerel(mo):
  232. return os.path.join(reltopdir, mo.group(0))
  233. re_inout = re.compile(r'{(?:input|output)')
  234. return re_inout.sub(makerel, template, count =0)