fasta_subsample.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. """
Performs a single pass over the input FASTA/FOFN and collects all ZMWs.
For each ZMW it estimates the expected molecular size as the length of the
ZMW's internal median subread.
ZMWs are then ordered by the chosen strategy (random or longest) and taken in
that order until their summed size reaches the desired coverage of the genome
(coverage * genome size). The script writes three JSON files: the whitelist of
selected ZMWs, the full list of ZMWs with their sizes, and summary statistics.
  7. Author: Ivan Sovic
  8. """
  9. from falcon_kit.mains.fasta_filter import ZMWTuple
  10. from falcon_kit.util.system import set_random_seed
  11. import falcon_kit.FastaReader as FastaReader
  12. import falcon_kit.mains.fasta_filter as fasta_filter
  13. import falcon_kit.io as io
  14. import os
  15. import sys
  16. import argparse
  17. import logging
  18. import contextlib
  19. import itertools
  20. import random
  21. import json
  22. import copy
  23. LOG = logging.getLogger()
  24. STRATEGY_RANDOM = 'random'
  25. STRATEGY_LONGEST = 'longest'
  26. def strategy_func_random(zmws):
  27. """
  28. >>> random.seed(12345); strategy_func_random([])
  29. []
  30. >>> random.seed(12345); strategy_func_random([('synthetic/1', 9)])
  31. [('synthetic/1', 9)]
  32. >>> random.seed(12345); strategy_func_random([('synthetic/1', 9), ('synthetic/2', 21), ('synthetic/3', 9), ('synthetic/4', 15), ('synthetic/5', 20)])
  33. [('synthetic/5', 20), ('synthetic/3', 9), ('synthetic/2', 21), ('synthetic/1', 9), ('synthetic/4', 15)]
  34. """
  35. ret = copy.deepcopy(zmws)
  36. random.shuffle(ret)
  37. return ret
  38. def strategy_func_longest(zmws):
  39. """
  40. >>> strategy_func_longest([])
  41. []
  42. >>> strategy_func_longest([('synthetic/1', 9)])
  43. [('synthetic/1', 9)]
  44. >>> strategy_func_longest([('synthetic/1', 9), ('synthetic/2', 21), ('synthetic/3', 9), ('synthetic/4', 15), ('synthetic/5', 20)])
  45. [('synthetic/2', 21), ('synthetic/5', 20), ('synthetic/4', 15), ('synthetic/1', 9), ('synthetic/3', 9)]
  46. """
  47. return sorted(zmws, key = lambda x: x[1], reverse = True)
  48. STRATEGY_TYPE_TO_FUNC = { STRATEGY_RANDOM: strategy_func_random,
  49. STRATEGY_LONGEST: strategy_func_longest
  50. }
  51. def get_strategy_func(strategy_type):
  52. """
  53. >>> get_strategy_func(STRATEGY_RANDOM) == strategy_func_random
  54. True
  55. >>> get_strategy_func(STRATEGY_LONGEST) == strategy_func_longest
  56. True
  57. >>> try:
  58. ... get_strategy_func('nonexistent_strategy')
  59. ... print('False')
  60. ... except:
  61. ... print('True')
  62. True
  63. """
  64. assert strategy_type in STRATEGY_TYPE_TO_FUNC, 'Unknown strategy type: "{}"'.format(str(strategy_type))
  65. return STRATEGY_TYPE_TO_FUNC[strategy_type]
  66. def select_zmws(zmws, min_requested_bases):
  67. """
  68. >>> select_zmws([], 0)
  69. ([], 0)
  70. >>> select_zmws([], 10)
  71. ([], 0)
  72. >>> select_zmws([('zmw/1', 1), ('zmw/2', 2), ('zmw/3', 5), ('zmw/4', 7), ('zmw/5', 10), ('zmw/6', 15)], 10)
  73. (['zmw/1', 'zmw/2', 'zmw/3', 'zmw/4'], 15)
  74. >>> select_zmws([('zmw/1', 1), ('zmw/2', 2), ('zmw/3', 5), ('zmw/4', 7), ('zmw/5', 10), ('zmw/6', 15)], 20)
  75. (['zmw/1', 'zmw/2', 'zmw/3', 'zmw/4', 'zmw/5'], 25)
  76. >>> select_zmws([('zmw/1', 1), ('zmw/1', 2), ('zmw/1', 5), ('zmw/1', 7), ('zmw/1', 10), ('zmw/1', 15)], 20)
  77. (['zmw/1', 'zmw/1', 'zmw/1', 'zmw/1', 'zmw/1'], 25)
  78. """
  79. # Select the first N ZMWs which sum up to the desired coverage.
  80. num_bases = 0
  81. subsampled_zmws = []
  82. for zmw_name, seq_len in zmws:
  83. num_bases += seq_len
  84. subsampled_zmws.append(zmw_name)
  85. if num_bases >= min_requested_bases:
  86. break
  87. return subsampled_zmws, num_bases
  88. def calc_stats(total_unique_molecular_bases, total_bases, output_bases, genome_size, coverage):
  89. """
  90. >>> calc_stats(0, 0, 0, 0, 0) == \
  91. {'genome_size': 0, 'coverage': 0, 'total_bases': 0, 'total_unique_molecular_bases': 0, \
  92. 'output_bases': 0, 'unique_molecular_avg_cov': 0.0, 'output_avg_cov': 0.0, 'total_avg_cov': 0.0}
  93. True
  94. >>> calc_stats(10000, 100000, 2000, 1000, 2) == \
  95. {'genome_size': 1000, 'coverage': 2, 'total_bases': 100000, 'total_unique_molecular_bases': 10000, \
  96. 'output_bases': 2000, 'unique_molecular_avg_cov': 10.0, 'output_avg_cov': 2.0, 'total_avg_cov': 100.0}
  97. True
  98. """
  99. unique_molecular_avg_cov = 0.0 if genome_size == 0 else float(total_unique_molecular_bases) / float(genome_size)
  100. total_avg_cov = 0.0 if genome_size == 0 else float(total_bases) / float(genome_size)
  101. output_avg_cov = 0.0 if genome_size == 0 else float(output_bases) / float(genome_size)
  102. ret = {}
  103. ret['genome_size'] = genome_size
  104. ret['coverage'] = coverage
  105. ret['total_bases'] = total_bases
  106. ret['total_unique_molecular_bases'] = total_unique_molecular_bases
  107. ret['output_bases'] = output_bases
  108. ret['total_avg_cov'] = total_avg_cov
  109. ret['unique_molecular_avg_cov'] = unique_molecular_avg_cov
  110. ret['output_avg_cov'] = output_avg_cov
  111. return ret
  112. def collect_zmws(yield_zmwtuple_func):
  113. """
  114. >>> collect_zmws([])
  115. ([], 0, 0)
  116. >>> collect_zmws([\
  117. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=0, subread_end=1000, seq_len=1000, subread_record=None, subread_header='test/1/0_1000', subread_id=0), \
  118. ])
  119. ([('test/1', 1000)], 1000, 1000)
  120. >>> collect_zmws([\
  121. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=123, subread_end=456, seq_len=1000, subread_record=None, subread_header='test/1/0_1000', subread_id=0), \
  122. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=123, subread_end=456, seq_len=2000, subread_record=None, subread_header='test/1/1000_3000', subread_id=0), \
  123. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=123, subread_end=456, seq_len=3000, subread_record=None, subread_header='test/1/3000_6000', subread_id=0), \
  124. ])
  125. ([('test/1', 2000)], 2000, 6000)
  126. >>> collect_zmws([\
  127. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=123, subread_end=456, seq_len=1000, subread_record=None, subread_header='test/1/0_1000', subread_id=0), \
  128. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=123, subread_end=456, seq_len=2000, subread_record=None, subread_header='test/1/1000_3000', subread_id=1), \
  129. ZMWTuple(movie_name='test' , zmw_id='1', subread_start=123, subread_end=456, seq_len=3000, subread_record=None, subread_header='test/1/3000_6000', subread_id=2), \
  130. ZMWTuple(movie_name='test' , zmw_id='2', subread_start=123, subread_end=456, seq_len=10000, subread_record=None, subread_header='header2', subread_id=3), \
  131. ])
  132. ([('test/1', 2000), ('test/2', 10000)], 12000, 16000)
  133. """
  134. zmws = []
  135. unique_molecular_size = 0
  136. total_size = 0
  137. for zmw_id, zmw_subreads in itertools.groupby(yield_zmwtuple_func, lambda x: x.zmw_id):
  138. zmw_subreads_list = list(zmw_subreads)
  139. zrec = fasta_filter.internal_median_zmw_subread(zmw_subreads_list)
  140. movie_zmw = zrec.movie_name + '/' + zrec.zmw_id
  141. unique_molecular_size += zrec.seq_len
  142. total_size += sum([zmw.seq_len for zmw in zmw_subreads_list])
  143. zmws.append((movie_zmw, zrec.seq_len))
  144. return zmws, unique_molecular_size, total_size
  145. def yield_record(input_files):
  146. for input_fn in input_files:
  147. with open(input_fn, 'r') as fp_in:
  148. fasta_records = FastaReader.yield_fasta_record(fp_in, log=LOG.info)
  149. for record in fasta_records:
  150. yield record
  151. def run(yield_zmw_tuple_func, coverage, genome_size, strategy_func):
  152. zmws, total_unique_molecular_bases, total_bases = collect_zmws(yield_zmw_tuple_func)
  153. zmws = strategy_func(zmws)
  154. subsampled_zmws, output_bases = select_zmws(zmws, coverage * genome_size)
  155. stats_dict = calc_stats(total_unique_molecular_bases, total_bases, output_bases, genome_size, coverage)
  156. return subsampled_zmws, zmws, stats_dict
  157. def parse_args(argv):
  158. parser = argparse.ArgumentParser(description="Produces a list of ZMW where the median unique molecular "\
  159. "coverage sums up to the desired coverage of the given genome size.s",
  160. formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  161. parser.add_argument('--strategy', type=str, default='random',
  162. help='Subsampling strategy: random, longest')
  163. parser.add_argument('--coverage', type=float, default=60,
  164. help='Desired coverage for subsampling.')
  165. parser.add_argument('--genome-size', type=float, default=0,
  166. help='Genome size estimate of the input dataset.', required=True)
  167. parser.add_argument('--random-seed', type=int, default=12345,
  168. help='Seed value used for the random generator.', required=False)
  169. parser.add_argument('input_fn', type=str, default='input.fofn',
  170. help='Input FASTA files or a FOFN. (Streaming is not allowed).')
  171. parser.add_argument('out_prefix', type=str, default='input.cov',
  172. help='Prefix of the output files to generate.')
  173. args = parser.parse_args(argv[1:])
  174. return args
  175. def main(argv=sys.argv):
  176. args = parse_args(argv)
  177. logging.basicConfig(level=logging.INFO)
  178. strategy_func = get_strategy_func(args.strategy)
  179. LOG.info('Using subsampling strategy: "{strategy}"'.format(strategy=args.strategy))
  180. set_random_seed(args.random_seed)
  181. input_files = list(io.yield_abspath_from_fofn(args.input_fn))
  182. zmws_whitelist, zmws_all, stats_dict = run(
  183. fasta_filter.yield_zmwtuple(yield_record(input_files), None, False), args.coverage, args.genome_size, strategy_func)
  184. out_zmw_whitelist = args.out_prefix + '.whitelist.json'
  185. out_all_zmws = args.out_prefix + '.all.json'
  186. out_zmw_stats = args.out_prefix + '.stats.json'
  187. with open(out_zmw_whitelist, 'w') as fp_out_whitelist, \
  188. open(out_all_zmws, 'w') as fp_out_all_zmws, \
  189. open(out_zmw_stats, 'w') as fp_out_stats:
  190. # Write out the whitelist.
  191. fp_out_whitelist.write(json.dumps(zmws_whitelist))
  192. # Write the entire list of ZMWs and lengths, might be very informative.
  193. fp_out_all_zmws.write(json.dumps(zmws_all))
  194. # Write out the stats.
  195. fp_out_stats.write(json.dumps(stats_dict))
  196. if __name__ == "__main__": # pragma: no cover
  197. main(sys.argv) # pragma: no cover