# consensus.py

from builtins import range
from ctypes import (POINTER, c_char_p, c_uint, c_double, string_at, pointer)
from falcon_kit.multiproc import Pool
from falcon_kit import falcon
import argparse
import logging
import multiprocessing
import os
import re
import sys
import falcon_kit
import falcon_kit.util.io as io
import collections

LOG = logging.getLogger()

falcon.generate_consensus.argtypes = [
    POINTER(c_char_p), c_uint, c_uint, c_uint, c_double]
falcon.generate_consensus.restype = POINTER(falcon_kit.ConsensusData)
falcon.free_consensus_data.argtypes = [POINTER(falcon_kit.ConsensusData)]
falcon.generate_consensus_from_mapping.argtypes = [
    POINTER(c_char_p), POINTER(POINTER(falcon_kit.AlnRange)), c_uint, c_uint, c_uint, c_double]
falcon.generate_consensus_from_mapping.restype = POINTER(falcon_kit.ConsensusData)
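# The argtypes/restype declarations above tell ctypes how to marshal arguments
# into the C library exposed as `falcon`; without an explicit restype, ctypes
# would assume c_int and truncate the returned pointers on 64-bit platforms.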
  23. """
  24. SeqTuple encodes a single line in a block for consensus. Legacy code used only the 'name' and 'seq' (read from input),
  25. but if the coordinates are already known, we can use this info.
  26. The `qlen` and `tlen` are necessary because this consensus code can clip the end of a sequence if it's
  27. beyond a certain threshold. If it's clipped, the start/end coordinates can fall within the clipped region,
  28. which means that the internal alignment will have to be triggered.
  29. The 'tstart' and 'tend' relate to the seed read, and the 'qstart' and 'qend' to the currenr read on the same line.
  30. The current query should be in the same strand as the target. For consistency, we added a 'qstranq' as well, but
  31. in the current LA4Falcon output it will always be 0.
  32. The 'aln' field can be used to provide an alignment directly from the tool which determined that these sequences
  33. need to go into the same block. This could be used downstream to prevent quadratic memory consumption during error
  34. correction, and speed up the process.
  35. Parameter 'is_trimmed' is a bool, indicating that the sequence was trimmed from the back because it exceeded the maximum length.
  36. """
  37. SeqTuple = collections.namedtuple('SeqTuple', ['name', 'seq', 'qstrand', 'qstart', 'qend', 'qlen', 'tstart', 'tend', 'tlen', 'aln', 'is_mapped', 'is_trimmed'])
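
# Illustrative example (values are hypothetical): a legacy two-column input
# line carries only 'name' and 'seq', so all mapping fields stay at the
# "unknown" defaults that get_seq_data() below fills in:
#
#   SeqTuple(name='read/1/0_5', seq='ACGTA',
#            qstrand=0, qstart=-1, qend=-1, qlen=-1,
#            tstart=-1, tend=-1, tlen=-1,
#            aln='*', is_mapped=False, is_trimmed=False)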

def get_longest_reads(seqs, max_n_read, max_cov_aln, sort=True):
    # Accepting the `sort` kwarg allows us to avoid a redundant sort
    # in get_consensus_with_trim(), where the reads arrive pre-sorted.
    if sort:
        seqs = seqs[:1] + sorted(seqs[1:], key=lambda x: -len(x.seq))
    longest_n_reads = max_n_read
    if max_cov_aln > 0:
        longest_n_reads = 1
        seed_len = len(seqs[0].seq)
        read_cov = 0
        for seq in seqs[1:]:
            if read_cov // seed_len > max_cov_aln:
                break
            longest_n_reads += 1
            read_cov += len(seq.seq)
        longest_n_reads = min(longest_n_reads, max_n_read)
    return seqs[:longest_n_reads]
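
# Worked example (hypothetical numbers): with a 10 kb seed and max_cov_aln=3,
# reads are taken longest-first while read_cov accumulates their lengths. Once
# read_cov // 10000 exceeds 3 (i.e. more than ~30 kb of aligned bases), the
# loop stops; the result is additionally capped at max_n_read reads.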

def get_alignment(seq1, seq0, edge_tolerance=1000):
    kup = falcon_kit.kup
    K = 8
    lk_ptr = kup.allocate_kmer_lookup(1 << (K * 2))
    sa_ptr = kup.allocate_seq(len(seq0))
    sda_ptr = kup.allocate_seq_addr(len(seq0))
    kup.add_sequence(0, K, seq0, len(seq0), sda_ptr, sa_ptr, lk_ptr)
    kup.mask_k_mer(1 << (K * 2), lk_ptr, 16)
    kmer_match_ptr = kup.find_kmer_pos_for_seq(
        seq1, len(seq1), K, sda_ptr, lk_ptr)
    kmer_match = kmer_match_ptr[0]
    aln_range_ptr = kup.find_best_aln_range2(kmer_match_ptr, K, K * 50, 25)
    #x,y = zip( * [ (kmer_match.query_pos[i], kmer_match.target_pos[i]) for i in range(kmer_match.count )] )
    aln_range = aln_range_ptr[0]
    kup.free_kmer_match(kmer_match_ptr)
    s1, e1, s0, e0, km_score = aln_range.s1, aln_range.e1, aln_range.s2, aln_range.e2, aln_range.score
    e1 += K + K // 2
    e0 += K + K // 2
    kup.free_aln_range(aln_range)
    len_1 = len(seq1)
    len_0 = len(seq0)
    if e1 > len_1:
        e1 = len_1
    if e0 > len_0:
        e0 = len_0
    aln_size = 1
    if e1 - s1 > 500:
        aln_size = max(e1 - s1, e0 - s0)
        aln_score = int(km_score * 48)
        aln_q_s = s1
        aln_q_e = e1
        aln_t_s = s0
        aln_t_e = e0
    kup.free_seq_addr_array(sda_ptr)
    kup.free_seq_array(sa_ptr)
    kup.free_kmer_lookup(lk_ptr)
    if s1 > edge_tolerance and s0 > edge_tolerance:
        return 0, 0, 0, 0, 0, 0, "none"
    if len_1 - e1 > edge_tolerance and len_0 - e0 > edge_tolerance:
        return 0, 0, 0, 0, 0, 0, "none"
    if e1 - s1 > 500 and aln_size > 500:
        return s1, s1 + aln_q_e - aln_q_s, s0, s0 + aln_t_e - aln_t_s, aln_size, aln_score, "aln"
    else:
        return 0, 0, 0, 0, 0, 0, "none"
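
# Return contract of get_alignment(), as consumed in get_consensus_with_trim():
#   (q_start, q_end, t_start, t_end, aln_size, aln_score, status)
# where status is "aln" for a usable overlap and "none" otherwise, e.g.
# (hypothetical call):
#   s1, e1, s0, e0, size, score, status = get_alignment(read.seq, seed.seq)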

def get_trimmed_seq(seq, s, e):
    # Mapping info is useless after clipping, so just reset it.
    ret = SeqTuple(name=seq.name, seq=seq.seq[s:e],
                   qstrand=seq.qstrand, qstart=-1, qend=-1, qlen=-1,
                   tstart=-1, tend=-1, tlen=-1,
                   aln='*', is_mapped=False, is_trimmed=True)
    return ret

def get_consensus_core(seqs, min_cov, K, min_idt, allow_external_mapping):
    seqs_ptr = (c_char_p * len(seqs))()
    seqs_ptr[:] = [bytes(val.seq, encoding='ascii') for val in seqs]
    all_seqs_mapped = False
    if allow_external_mapping:
        all_seqs_mapped = True
        for seq in seqs:
            if not seq.is_mapped:
                all_seqs_mapped = False
                break
    if not all_seqs_mapped:
        LOG.info('Internally mapping the sequences.')
        consensus_data_ptr = falcon.generate_consensus(
            seqs_ptr, len(seqs), min_cov, K, min_idt)
    else:
        LOG.info('Using external mapping coordinates from input.')
        aln_ranges_ptr = (POINTER(falcon_kit.AlnRange) * len(seqs))()
        for i, seq in enumerate(seqs):
            a = falcon_kit.AlnRange(seq.qstart, seq.qend, seq.tstart, seq.tend, (seq.qend - seq.qstart))
            aln_ranges_ptr[i] = pointer(a)
        consensus_data_ptr = falcon.generate_consensus_from_mapping(
            seqs_ptr, aln_ranges_ptr, len(seqs), min_cov, K, min_idt)
        del aln_ranges_ptr
    del seqs_ptr
    if not consensus_data_ptr:
        LOG.warning("====>get_consensus_core return consensus_data_ptr={}".format(consensus_data_ptr))
        return ''
    # assert consensus_data_ptr
    consensus = string_at(consensus_data_ptr[0].sequence)[:]
    #eff_cov = consensus_data_ptr[0].eff_cov[:len(consensus)]
    LOG.debug(' Freeing')
    falcon.free_consensus_data(consensus_data_ptr)
    return consensus.decode('ascii')
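
# Note on the external-mapping path above: one AlnRange per input sequence is
# handed to the C call generate_consensus_from_mapping(); its last field is
# populated here with the query span (qend - qstart).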

def get_consensus_without_trim(c_input):
    seqs, seed_id, config = c_input
    LOG.debug('Starting get_consensus_without_trim(len(seqs)=={}, seed_id={})'.format(
        len(seqs), seed_id))
    min_cov, K, max_n_read, min_idt, edge_tolerance, trim_size, min_cov_aln, max_cov_aln, allow_external_mapping = config
    if len(seqs) > max_n_read:
        seqs = get_longest_reads(seqs, max_n_read, max_cov_aln, sort=True)
    consensus = get_consensus_core(seqs, min_cov, K, min_idt, allow_external_mapping)
    LOG.debug(' Finishing get_consensus_without_trim(seed_id={})'.format(seed_id))
    return consensus, seed_id

def get_consensus_with_trim(c_input):
    seqs, seed_id, config = c_input
    LOG.debug('Starting get_consensus_with_trim(len(seqs)=={}, seed_id={})'.format(
        len(seqs), seed_id))
    min_cov, K, max_n_read, min_idt, edge_tolerance, trim_size, min_cov_aln, max_cov_aln, allow_external_mapping = config
    trim_seqs = []
    seed = seqs[0]
    for seq in seqs[1:]:
        aln_data = get_alignment(seq.seq, seed.seq, edge_tolerance)
        s1, e1, s2, e2, aln_size, aln_score, c_status = aln_data
        if c_status == "none":
            continue
        if aln_score > 1000 and e1 - s1 > 500:
            e1 -= trim_size
            s1 += trim_size
            trim_seqs.append((e1 - s1, get_trimmed_seq(seq, s1, e1)))
            # trim_seqs.append((e1 - s1, seq.seq[s1:e1]))
    trim_seqs.sort(key=lambda x: -x[0])  # use the longest alignment first
    trim_seqs = [x[1] for x in trim_seqs]
    trim_seqs = [seed] + trim_seqs
    if len(trim_seqs[1:]) > max_n_read:
        # seqs are already sorted; don't sort again
        trim_seqs = get_longest_reads(
            trim_seqs, max_n_read, max_cov_aln, sort=False)
    consensus = get_consensus_core(trim_seqs, min_cov, K, min_idt, allow_external_mapping)
    LOG.debug(' Finishing get_consensus_with_trim(seed_id={})'.format(seed_id))
    return consensus, seed_id

def get_seq_data(config, min_n_read, min_len_aln):
    max_len = 128000
    min_cov, K, max_n_read, min_idt, edge_tolerance, trim_size, min_cov_aln, max_cov_aln, allow_external_mapping = config
    seqs = []
    seed_id = None
    seed_len = 0
    seqs_data = []
    read_cov = 0
    read_ids = set()
    with sys.stdin as f:
        for line in f:
            split_line = line.strip().split()
            if len(split_line) < 2:
                continue
            qname = split_line[0]
            qseq = split_line[1]
            qstrand, qstart, qend, qlen = 0, -1, -1, -1
            tstart, tend, tlen = -1, -1, -1
            aln, is_mapped, is_trimmed = '*', False, False
            if len(split_line) >= 10:
                qstrand = int(split_line[2])
                qstart = int(split_line[3])
                qend = int(split_line[4])
                qlen = int(split_line[5])
                tstart = int(split_line[6])
                tend = int(split_line[7])
                tlen = int(split_line[8])
                aln = split_line[9]
                is_mapped = True
            new_seq = SeqTuple(name=qname, seq=qseq,
                               qstrand=qstrand, qstart=qstart, qend=qend, qlen=qlen,
                               tstart=tstart, tend=tend, tlen=tlen,
                               aln=aln, is_mapped=is_mapped, is_trimmed=is_trimmed)
            if len(new_seq.seq) > max_len:
                new_seq = get_trimmed_seq(new_seq, 0, max_len - 1)
            if new_seq.name not in ("+", "-", "*"):
                if len(new_seq.seq) >= min_len_aln:
                    if len(seqs) == 0:
                        seqs.append(new_seq)  # the "seed"
                        seed_len = len(new_seq.seq)
                        seed_id = new_seq.name
                    if new_seq.name not in read_ids:  # avoid using the same read twice; the seed is used again here by design
                        seqs.append(new_seq)
                        read_ids.add(new_seq.name)
                        read_cov += len(new_seq.seq)
            elif split_line[0] == "+":
                if len(seqs) >= min_n_read and read_cov // seed_len >= min_cov_aln:
                    seqs = get_longest_reads(
                        seqs, max_n_read, max_cov_aln, sort=True)
                    yield (seqs, seed_id, config)
                #seqs_data.append( (seqs, seed_id) )
                seqs = []
                read_ids = set()
                seed_id = None
                read_cov = 0
            elif split_line[0] == "*":
                seqs = []
                read_ids = set()
                seed_id = None
                read_cov = 0
            elif split_line[0] == "-":
                # yield (seqs, seed_id)
                #seqs_data.append( (seqs, seed_id) )
                break
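
# Input sketch, matching the parsing above: each block read from stdin is
#   <seed_name> <seed_seq> [qstrand qstart qend qlen tstart tend tlen aln]
#   <read_name> <read_seq> [ ... optional mapping columns ... ]
#   ...
# terminated by a line whose first field is "+" (emit a consensus job) or
# "*" (drop the current block); a line whose first field is "-" ends the stream.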

def format_seq(seq, col):
    return "\n".join([seq[i:(i + col)] for i in range(0, len(seq), col)])
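
# For example, format_seq("ACGTACGTAC", 4) returns "ACGT\nACGT\nAC".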

def parse_args(argv):
    parser = argparse.ArgumentParser(
        description='a simple multi-processor consensus sequence generator',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--n-core', type=int, default=24,
                        help='number of processes used for generating consensus; '
                        '0 for main process only')
    parser.add_argument('--min-cov', type=int, default=6,
                        help='minimum coverage to break the consensus')
    parser.add_argument('--min-cov-aln', type=int, default=10,
                        help='minimum coverage of alignment data; a seed read with less than MIN_COV_ALN average depth'
                        ' of coverage will be completely ignored')
    parser.add_argument('--max-cov-aln', type=int, default=0,  # 0 to emulate previous behavior
                        help='maximum coverage of alignment data; a seed read with more than MAX_COV_ALN average depth'
                        ' of coverage of the longest alignments will be capped, and excess shorter alignments will be ignored')
    parser.add_argument('--min-len-aln', type=int, default=0,  # 0 to emulate previous behavior
                        help='minimum length of a sequence in an alignment to be used in consensus; any shorter sequence will be completely ignored')
    parser.add_argument('--min-n-read', type=int, default=10,
                        help='1 + minimum number of reads used in generating the consensus; a seed read with fewer alignments will '
                        'be completely ignored')
    parser.add_argument('--max-n-read', type=int, default=500,
                        help='1 + maximum number of reads used in generating the consensus')
    parser.add_argument('--trim', action="store_true", default=False,
                        help='trim the input sequences with k-mer sparse dynamic programming to find the mapped range')
    parser.add_argument('--output-full', action="store_true", default=False,
                        help='output uncorrected regions too')
    parser.add_argument('--output-multi', action="store_true", default=False,
                        help='output multiple corrected regions')
    parser.add_argument('--min-idt', type=float, default=0.70,
                        help='minimum identity of the alignments used for correction')
    parser.add_argument('--edge-tolerance', type=int, default=1000,
                        help='for trimming, if the unaligned edge length is > edge_tolerance, ignore the read')
    parser.add_argument('--trim-size', type=int, default=50,
                        help='the size for trimming both ends from the initial sparse aligned region')
    parser.add_argument('--allow-external-mapping', action="store_true", default=False,
                        help='if provided, externally determined mapping coordinates will be used for error correction')
    parser.add_argument('-v', '--verbose-level', type=float, default=2.0,
                        help='logging level (WARNING=3, INFO=2, DEBUG=1)')
    return parser.parse_args(argv[1:])

def run(args):
    logging.basicConfig(level=int(round(10 * args.verbose_level)))
    # logging.basicConfig(level=logging.NOTSET,
    #     format='%(asctime)s: [%(module)s:%(funcName)s()line:%(lineno)d] - %(levelname)s : %(message)s')
    assert args.n_core <= multiprocessing.cpu_count(), 'Requested n_core={} > cpu_count={}'.format(
        args.n_core, multiprocessing.cpu_count())

    def Start():
        LOG.info('====>Started a worker in {} from parent {}'.format(
            os.getpid(), os.getppid()))
    exe_pool = Pool(args.n_core, initializer=Start)
    if args.trim:
        get_consensus = get_consensus_with_trim
    else:
        get_consensus = get_consensus_without_trim
    K = 8
    config = args.min_cov, K, \
        args.max_n_read, args.min_idt, args.edge_tolerance, \
        args.trim_size, args.min_cov_aln, args.max_cov_aln, \
        args.allow_external_mapping
    # TODO: pass a config object, not a tuple, so we can add fields
    LOG.debug("====>args={}".format(args))
    LOG.debug("====>get_consensus={}".format(get_consensus))
    LOG.debug("====>config={}".format(config))
    inputs = []
    for datum in get_seq_data(config, args.min_n_read, args.min_len_aln):
        inputs.append((get_consensus, datum))
    try:
        LOG.info('====>running {!r}'.format(get_consensus))
        for res in exe_pool.imap(io.run_func, inputs):
            process_get_consensus_result(res, args)
        LOG.info('====>finished {!r}'.format(get_consensus))
    except:
        LOG.exception('====>failed gen_consensus')
        exe_pool.terminate()
        raise

good_region = re.compile("[ACGT]+")


def process_get_consensus_result(res, args, limit=500):
    cns, seed_id = res
    seed_id = int(seed_id)
    if not cns:
        LOG.warning("====>process_get_consensus_result() data error! res={}".format(res))
        return
    if len(cns) < limit:
        LOG.debug("====>process_get_consensus_result() len(cns)={} < limit[{}]".format(len(cns), limit))
        return
    if args.output_full:
        print('>{:d}_f'.format(seed_id))
        print(cns)
    else:
        cns = good_region.findall(cns)
        if args.output_multi:
            seq_i = 0
            for cns_seq in cns:
                if len(cns_seq) < limit:
                    continue
                if seq_i >= 10:
                    break
                print(">prolog/%s%01d/%d_%d" % (seed_id, seq_i, 0, len(cns_seq)))
                print(format_seq(cns_seq, 80))
                seq_i += 1
        else:
            if len(cns) == 0:
                return
            cns.sort(key=lambda x: len(x))
            print('>{:d}'.format(seed_id))
            print(cns[-1])
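
# Output naming, as implemented above: --output-full prints the raw consensus
# as '>SEED_f'; --output-multi prints up to 10 good regions (runs of [ACGT] of
# at least `limit` bases) with headers built from the seed id and region index,
# wrapped at 80 columns; otherwise only the longest good region is printed
# as '>SEED'.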

def main(argv=sys.argv):
    args = parse_args(argv)
    run(args)


if __name__ == "__main__":
    main(sys.argv)
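
# Usage sketch (hedged: the LA4Falcon flags below are assumptions and depend on
# your DALIGNER/FALCON setup). This script reads alignment blocks from stdin
# and writes FASTA to stdout, e.g.:
#
#   LA4Falcon -H<length_cutoff> -fso raw_reads.db reads.las \
#       | python consensus.py --n-core 8 --min-idt 0.70 > consensus.fasta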