9.9 KB

  1. """
  2. Performs a single pass over an input FASTA/FOFN, and collects
  3. all ZMWs. For each ZMW it calculates the expected molecular size by picking
  4. the internal median subread length.
  5. The script outputs a JSON file with a whitelist of ZMWs selected by a given
  6. strategy (random, longest, etc.) and desired coverage of a genome.
  7. Author: Ivan Sovic
  8. """
import argparse
import contextlib
import copy
import itertools
import json
import logging
import os
import random
import sys

import falcon_kit.FastaReader as FastaReader
# NOTE(review): module path reconstructed from a garbled line ("import as io");
# main() calls io.yield_abspath_from_fofn, which lives in falcon_kit.io — confirm.
import falcon_kit.io as io
import falcon_kit.mains.fasta_filter as fasta_filter
from falcon_kit.mains.fasta_filter import ZMWTuple
from falcon_kit.util.system import set_random_seed
  23. LOG = logging.getLogger()
  24. STRATEGY_RANDOM = 'random'
  25. STRATEGY_LONGEST = 'longest'
  26. def strategy_func_random(zmws):
  27. """
  28. >>> random.seed(12345); strategy_func_random([])
  29. []
  30. >>> random.seed(12345); strategy_func_random([('synthetic/1', 9)])
  31. [('synthetic/1', 9)]
  32. >>> random.seed(12345); strategy_func_random([('synthetic/1', 9), ('synthetic/2', 21), ('synthetic/3', 9), ('synthetic/4', 15), ('synthetic/5', 20)])
  33. [('synthetic/5', 20), ('synthetic/3', 9), ('synthetic/2', 21), ('synthetic/1', 9), ('synthetic/4', 15)]
  34. """
  35. ret = copy.deepcopy(zmws)
  36. random.shuffle(ret)
  37. return ret
  38. def strategy_func_longest(zmws):
  39. """
  40. >>> strategy_func_longest([])
  41. []
  42. >>> strategy_func_longest([('synthetic/1', 9)])
  43. [('synthetic/1', 9)]
  44. >>> strategy_func_longest([('synthetic/1', 9), ('synthetic/2', 21), ('synthetic/3', 9), ('synthetic/4', 15), ('synthetic/5', 20)])
  45. [('synthetic/2', 21), ('synthetic/5', 20), ('synthetic/4', 15), ('synthetic/1', 9), ('synthetic/3', 9)]
  46. """
  47. return sorted(zmws, key = lambda x: x[1], reverse = True)
  48. STRATEGY_TYPE_TO_FUNC = { STRATEGY_RANDOM: strategy_func_random,
  49. STRATEGY_LONGEST: strategy_func_longest
  50. }
  51. def get_strategy_func(strategy_type):
  52. """
  53. >>> get_strategy_func(STRATEGY_RANDOM) == strategy_func_random
  54. True
  55. >>> get_strategy_func(STRATEGY_LONGEST) == strategy_func_longest
  56. True
  57. >>> try:
  58. ... get_strategy_func('nonexistent_strategy')
  59. ... print('False')
  60. ... except:
  61. ... print('True')
  62. True
  63. """
  64. assert strategy_type in STRATEGY_TYPE_TO_FUNC, 'Unknown strategy type: "{}"'.format(str(strategy_type))
  65. return STRATEGY_TYPE_TO_FUNC[strategy_type]
  66. def select_zmws(zmws, min_requested_bases):
  67. """
  68. >>> select_zmws([], 0)
  69. ([], 0)
  70. >>> select_zmws([], 10)
  71. ([], 0)
  72. >>> select_zmws([('zmw/1', 1), ('zmw/2', 2), ('zmw/3', 5), ('zmw/4', 7), ('zmw/5', 10), ('zmw/6', 15)], 10)
  73. (['zmw/1', 'zmw/2', 'zmw/3', 'zmw/4'], 15)
  74. >>> select_zmws([('zmw/1', 1), ('zmw/2', 2), ('zmw/3', 5), ('zmw/4', 7), ('zmw/5', 10), ('zmw/6', 15)], 20)
  75. (['zmw/1', 'zmw/2', 'zmw/3', 'zmw/4', 'zmw/5'], 25)
  76. >>> select_zmws([('zmw/1', 1), ('zmw/1', 2), ('zmw/1', 5), ('zmw/1', 7), ('zmw/1', 10), ('zmw/1', 15)], 20)
  77. (['zmw/1', 'zmw/1', 'zmw/1', 'zmw/1', 'zmw/1'], 25)
  78. """
  79. # Select the first N ZMWs which sum up to the desired coverage.
  80. num_bases = 0
  81. subsampled_zmws = []
  82. for zmw_name, seq_len in zmws:
  83. num_bases += seq_len
  84. subsampled_zmws.append(zmw_name)
  85. if num_bases >= min_requested_bases:
  86. break
  87. return subsampled_zmws, num_bases
  88. def calc_stats(total_unique_molecular_bases, total_bases, output_bases, genome_size, coverage):
  89. """
  90. >>> calc_stats(0, 0, 0, 0, 0) == \
  91. {'genome_size': 0, 'coverage': 0, 'total_bases': 0, 'total_unique_molecular_bases': 0, \
  92. 'output_bases': 0, 'unique_molecular_avg_cov': 0.0, 'output_avg_cov': 0.0, 'total_avg_cov': 0.0}
  93. True
  94. >>> calc_stats(10000, 100000, 2000, 1000, 2) == \
  95. {'genome_size': 1000, 'coverage': 2, 'total_bases': 100000, 'total_unique_molecular_bases': 10000, \
  96. 'output_bases': 2000, 'unique_molecular_avg_cov': 10.0, 'output_avg_cov': 2.0, 'total_avg_cov': 100.0}
  97. True
  98. """
  99. unique_molecular_avg_cov = 0.0 if genome_size == 0 else float(total_unique_molecular_bases) / float(genome_size)
  100. total_avg_cov = 0.0 if genome_size == 0 else float(total_bases) / float(genome_size)
  101. output_avg_cov = 0.0 if genome_size == 0 else float(output_bases) / float(genome_size)
  102. ret = {}
  103. ret['genome_size'] = genome_size
  104. ret['coverage'] = coverage
  105. ret['total_bases'] = total_bases
  106. ret['total_unique_molecular_bases'] = total_unique_molecular_bases
  107. ret['output_bases'] = output_bases
  108. ret['total_avg_cov'] = total_avg_cov
  109. ret['unique_molecular_avg_cov'] = unique_molecular_avg_cov
  110. ret['output_avg_cov'] = output_avg_cov
  111. return ret
  112. def collect_zmws(yield_zmwtuple_func):
  113. """
  114. >>> collect_zmws([])
  115. ([], 0, 0)
  116. >>> collect_zmws([\
  117. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=0, subread_end=1000, seq_len=1000, subread_record=None, subread_header='test/1/0_1000', subread_id=0), \
  118. ])
  119. ([('test/1', 1000)], 1000, 1000)
  120. >>> collect_zmws([\
  121. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=123, subread_end=456, seq_len=1000, subread_record=None, subread_header='test/1/0_1000', subread_id=0), \
  122. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=123, subread_end=456, seq_len=2000, subread_record=None, subread_header='test/1/1000_3000', subread_id=0), \
  123. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=123, subread_end=456, seq_len=3000, subread_record=None, subread_header='test/1/3000_6000', subread_id=0), \
  124. ])
  125. ([('test/1', 2000)], 2000, 6000)
  126. >>> collect_zmws([\
  127. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=123, subread_end=456, seq_len=1000, subread_record=None, subread_header='test/1/0_1000', subread_id=0), \
  128. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=123, subread_end=456, seq_len=2000, subread_record=None, subread_header='test/1/1000_3000', subread_id=1), \
  129. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=123, subread_end=456, seq_len=3000, subread_record=None, subread_header='test/1/3000_6000', subread_id=2), \
  130. ZMWTuple(movie_name='test' , zmw_id='2', subread_start=123, subread_end=456, seq_len=10000, subread_record=None, subread_header='header2', subread_id=3), \
  131. ])
  132. ([('test/1', 2000), ('test/2', 10000)], 12000, 16000)
  133. """
  134. zmws = []
  135. unique_molecular_size = 0
  136. total_size = 0
  137. for zmw_id, zmw_subreads in itertools.groupby(yield_zmwtuple_func, lambda x: x.zmw_id):
  138. zmw_subreads_list = list(zmw_subreads)
  139. zrec = fasta_filter.internal_median_zmw_subread(zmw_subreads_list)
  140. movie_zmw = zrec.movie_name + '/' + zrec.zmw_id
  141. unique_molecular_size += zrec.seq_len
  142. total_size += sum([zmw.seq_len for zmw in zmw_subreads_list])
  143. zmws.append((movie_zmw, zrec.seq_len))
  144. return zmws, unique_molecular_size, total_size
  145. def yield_record(input_files):
  146. for input_fn in input_files:
  147. with open(input_fn, 'r') as fp_in:
  148. fasta_records = FastaReader.yield_fasta_record(fp_in,
  149. for record in fasta_records:
  150. yield record
  151. def run(yield_zmw_tuple_func, coverage, genome_size, strategy_func):
  152. zmws, total_unique_molecular_bases, total_bases = collect_zmws(yield_zmw_tuple_func)
  153. zmws = strategy_func(zmws)
  154. subsampled_zmws, output_bases = select_zmws(zmws, coverage * genome_size)
  155. stats_dict = calc_stats(total_unique_molecular_bases, total_bases, output_bases, genome_size, coverage)
  156. return subsampled_zmws, zmws, stats_dict
  157. def parse_args(argv):
  158. parser = argparse.ArgumentParser(description="Produces a list of ZMW where the median unique molecular "\
  159. "coverage sums up to the desired coverage of the given genome size.s",
  160. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  161. parser.add_argument('--strategy', type=str, default='random',
  162. help='Subsampling strategy: random, longest')
  163. parser.add_argument('--coverage', type=float, default=60,
  164. help='Desired coverage for subsampling.')
  165. parser.add_argument('--genome-size', type=float, default=0,
  166. help='Genome size estimate of the input dataset.', required=True)
  167. parser.add_argument('--random-seed', type=int, default=12345,
  168. help='Seed value used for the random generator.', required=False)
  169. parser.add_argument('input_fn', type=str, default='input.fofn',
  170. help='Input FASTA files or a FOFN. (Streaming is not allowed).')
  171. parser.add_argument('out_prefix', type=str, default='input.cov',
  172. help='Prefix of the output files to generate.')
  173. args = parser.parse_args(argv[1:])
  174. return args
  175. def main(argv=sys.argv):
  176. args = parse_args(argv)
  177. logging.basicConfig(level=logging.INFO)
  178. strategy_func = get_strategy_func(args.strategy)
  179.'Using subsampling strategy: "{strategy}"'.format(strategy=args.strategy))
  180. set_random_seed(args.random_seed)
  181. input_files = list(io.yield_abspath_from_fofn(args.input_fn))
  182. zmws_whitelist, zmws_all, stats_dict = run(
  183. fasta_filter.yield_zmwtuple(yield_record(input_files), None, False), args.coverage, args.genome_size, strategy_func)
  184. out_zmw_whitelist = args.out_prefix + '.whitelist.json'
  185. out_all_zmws = args.out_prefix + '.all.json'
  186. out_zmw_stats = args.out_prefix + '.stats.json'
  187. with open(out_zmw_whitelist, 'w') as fp_out_whitelist, \
  188. open(out_all_zmws, 'w') as fp_out_all_zmws, \
  189. open(out_zmw_stats, 'w') as fp_out_stats:
  190. # Write out the whitelist.
  191. fp_out_whitelist.write(json.dumps(zmws_whitelist))
  192. # Write the entire list of ZMWs and lengths, might be very informative.
  193. fp_out_all_zmws.write(json.dumps(zmws_all))
  194. # Write out the stats.
  195. fp_out_stats.write(json.dumps(stats_dict))
  196. if __name__ == "__main__": # pragma: no cover
  197. main(sys.argv) # pragma: no cover