Source code for lbfextract.fextract_fragment_length_distribution_in_batch.plugin

import logging
import operator
import pathlib
from functools import reduce
from typing import Any, Optional

import click
import numpy as np
import pandas as pd
import pyranges
import pysam
from matplotlib import pyplot as plt

import lbfextract.fextract.signal_transformer
import lbfextract.fextract_fragment_length_distribution
import lbfextract.fextract_fragment_length_distribution.signal_summarizers
from lbfextract.core import App
from lbfextract.fextract.schemas import Config, ReadFetcherConfig, AppExtraConfig
from lbfextract.fextract_batch_coverage.schemas import PlotConfig
from lbfextract.fextract_entropy_in_batch.schemas import SignalSummarizer
from lbfextract.fextract_fragment_length_distribution.plugin import calculate_reference_distribution, get_peaks
from lbfextract.fextract_fragment_length_distribution.schemas import SingleSignalTransformerConfig
from lbfextract.plotting_lib.plotting_functions import plot_fragment_length_distribution
from lbfextract.utils import load_temporary_bed_file, get_tmp_fextract_file_name, filter_bam, generate_time_stamp, \
    check_input_bed, check_input_bam, filter_out_empty_bed_files, sanitize_file_name
from lbfextract.utils_classes import Signal

# Module-level logger; used for the strand warning in IntervalIteratorFld.__next__ and the
# "Saving signal" message in FextractHooks.save_signal.
logger = logging.getLogger(__name__)


def subsample_fragment_lengths(x, n):
    """Draw ``n`` reads (with replacement) from the histogram ``x`` and re-bin them.

    Each bin is drawn with probability proportional to its positive count; bins with a
    non-positive (or NaN) count get probability 0. When ``x`` has no positive count at all,
    the bins are drawn uniformly instead.

    :param x: 1-D array of per-bin counts (e.g. one column of a fragment-length x position tensor).
        Integer and float arrays are both accepted.
    :param n: number of draws.
    :return: integer array of length ``len(x)`` holding how many draws fell in each bin
        (it sums to ``n``).
    """
    x = np.asarray(x)
    n_bins = len(x)
    # Work in float so integer histograms do not truncate their probabilities to 0.
    weights = np.where(x > 0, x, 0).astype(float)
    total = weights.sum()
    if total > 0:
        probabilities = weights / total
    else:
        # No signal in this column: fall back to a uniform distribution over the bins.
        probabilities = np.ones(n_bins) / n_bins
    # size=n is required in both cases; without it np.random.choice returns a scalar and
    # np.bincount below fails.
    subsampled_reads = np.random.choice(n_bins, size=n, p=probabilities, replace=True)
    return np.bincount(subsampled_reads, minlength=n_bins)
def optimize_tensor_subsampling(tensor, n):
    """Subsample every column of a 2-D ``tensor`` independently to ``n`` draws.

    :param tensor: 2-D array; each column is treated as a histogram over its rows.
    :param n: number of draws per column, passed to ``subsample_fragment_lengths``.
    :return: new array shaped and typed like ``tensor`` whose columns are the resampled counts.
    """
    # Unpacking the shape also rejects inputs that are not exactly 2-D.
    _, column_count = tensor.shape
    resampled = np.zeros_like(tensor)
    for column_index in range(column_count):
        resampled[:, column_index] = subsample_fragment_lengths(tensor[:, column_index], n)
    return resampled
def identity(x):
    """Return the argument itself, unmodified.

    Serves as the initial ``signal_transformer`` of ``IntervalIteratorFld`` until a real one is set.
    """
    result = x
    return result
[docs] class IntervalIteratorFld: def __init__(self, df: pd.DataFrame, path_to_bam: pathlib.Path, multiple_iterators: bool, extra_bases: int, flip_based_on_strand: bool | None = None): self.df = df self.df_by_name = self.df.groupby("Name") self.sequence = list(self.df_by_name.groups.keys()) self.path_to_bam = path_to_bam self._index = 0 self.multiple_iterators = multiple_iterators self.extra_bases = extra_bases self.signal_transformer = identity self.max_fragment_length = None self.min_fragment_length = None self.n_bins_pos = None self.n_bins_len = None self.subsample = None self.n = None self.flip_based_on_strand = flip_based_on_strand self.bamfile = pysam.AlignmentFile(self.path_to_bam) def __iter__(self): return self
[docs] def set_signal_transformer(self, signal_transformer): self.signal_transformer = signal_transformer
def __getstate__(self): state = self.__dict__.copy() state["bamfile"] = None return state def __setstate__(self, state): self.__dict__.update(state) self.path_to_bam = pysam.AlignmentFile(state['path_to_bam']) def __next__(self): if self._index < len(self.sequence): key = self.sequence[self._index] df = self.df_by_name.get_group(key).copy() df["reads_per_interval"] = [ self.bamfile.fetch(row.Chromosome, row.Start, row.End, multiple_iterators=False) for row in df.itertuples() ] df["Start"] += self.extra_bases df["End"] -= self.extra_bases relative_fragment_len = self.max_fragment_length - self.min_fragment_length region_length = df["End"].iloc[0] - df["Start"].iloc[0] tensor = np.zeros((relative_fragment_len, region_length)) for interval in df.itertuples(): if self.flip_based_on_strand: if interval.Strand == "+": tensor += self.signal_transformer(interval) if interval.Strand == "-": tensor += np.fliplr(self.signal_transformer(interval)) else: logger.warning(f"Strand {interval.Strand} not recognized. Trating as +") tensor += self.signal_transformer(interval) else: tensor += self.signal_transformer(interval) del df def bin_tensor(t, bins, axis): bin_edges = np.linspace(0, t.shape[axis], bins + 1, dtype=int) return np.add.reduceat(t, bin_edges[:-1], axis=axis) if self.n_bins_pos: tensor = bin_tensor(tensor, self.n_bins_pos, axis=1) if self.n_bins_len: tensor = bin_tensor(tensor, self.n_bins_len, axis=0) if self.subsample: n = self.n or int(tensor.sum(axis=0).min()) tensor = optimize_tensor_subsampling(tensor, n) col_sums = tensor.sum(axis=0) non_zero_mask = col_sums > 0 tensor[:, non_zero_mask] /= col_sums[non_zero_mask] self._index += 1 return {key: tensor} else: raise StopIteration
class FextractHooks:
    """Hook implementations of the ``fragment_length_distribution_in_batch`` plugin.

    NOTE(review): the methods are registered through ``lbfextract.hookimpl``; parameter names are
    kept exactly as-is because hook frameworks such as pluggy pass arguments by name.
    """

    @lbfextract.hookimpl
    def fetch_reads(self, path_to_bam: pathlib.Path, path_to_bed: pathlib.Path, config: Any,
                    extra_config: Any) -> IntervalIteratorFld:
        """Filter the BAM file to the regions of all non-empty BED files and wrap them in an iterator.

        :param path_to_bam: path to the bam file
        :param path_to_bed: path handed to ``filter_out_empty_bed_files``; the CLI only accepts a
            directory here (presumably containing one BED file per interval set)
        :param config: read-fetcher configuration (``f``, ``F``, ``extra_bases``, ``window``,
            ``flanking_region_window``, ``n_binding_sites``)
        :param extra_config: extra configuration; ``ctx["run_id"]`` and ``cores`` are read
        :return: ``IntervalIteratorFld`` over the concatenated intervals of every BED file, reading
            from the filtered temporary BAM file
        :raises ValueError: if no non-empty BED file is found
        """
        # NOTE(review): ``or`` also replaces an explicit 0 with the default flag value.
        config_f = config.f or 2
        config_F = config.F or 3868
        check_input_bed(path_to_bed)
        check_input_bam(path_to_bam)
        bed_files_paths = filter_out_empty_bed_files(path_to_bed)
        run_id = extra_config.ctx["run_id"]
        bed_frames = [
            load_temporary_bed_file(bed_file=bed,
                                    extra_bases=config.extra_bases,
                                    window=config.window,
                                    flanking_region_window=config.flanking_region_window,
                                    n_binding_sites=config.n_binding_sites,
                                    run_id=run_id)[1].as_df()
            for bed in bed_files_paths
        ]
        if not bed_frames:
            raise ValueError(f"No non-empty BED files found in {path_to_bed}")
        concat_bed = pd.concat(bed_frames)
        path_to_tmp_file = get_tmp_fextract_file_name(run_id)
        pyranges.PyRanges(concat_bed).to_csv(path_to_tmp_file, sep="\t", header=None)
        tmp_bam_file = filter_bam(path_to_bam, path_to_tmp_file, cores=extra_config.cores, run_id=run_id,
                                  f=config_f, F=config_F)
        return IntervalIteratorFld(concat_bed, tmp_bam_file, multiple_iterators=True,
                                   extra_bases=config.extra_bases)

    @lbfextract.hookimpl
    def transform_single_intervals(self, transformed_reads: IntervalIteratorFld,
                                   config: SingleSignalTransformerConfig,
                                   extra_config: Any) -> IntervalIteratorFld:
        """Configure ``transformed_reads`` with the options and extractor selected in ``config``.

        :param transformed_reads: iterator returned by ``fetch_reads``; modified in place
        :param config: config specific to the function; ``signal_transformer`` selects the extractor
        :param extra_config: config containing context information plus extra parameters (unused here)
        :return: the same ``transformed_reads`` object
        :raises KeyError: if ``config.signal_transformer`` is not a known extractor name
        """
        transformed_reads.flip_based_on_strand = config.flip_based_on_strand
        transformed_reads.max_fragment_length = config.max_fragment_length
        transformed_reads.min_fragment_length = config.min_fragment_length
        transformed_reads.n_bins_pos = config.n_bins_pos
        transformed_reads.n_bins_len = config.n_bins_len
        transformed_reads.subsample = config.subsample
        transformed_reads.n = config.n
        # name -> (class in signal_summarizers, extractor-specific kwargs). The kwargs are built
        # lazily so only the options the selected extractor needs are read from ``config``.
        extractors = {
            "fld": ("TfbsFragmentLengthDistribution", lambda: {}),
            "fld_middle": ("TfbsFragmentLengthDistributionMiddlePoint", lambda: {}),
            "fld_middle_n": ("TfbsFragmentLengthDistributionMiddleNPoints", lambda: {"n": config.w}),
            "fld_dyad": ("TfbsFragmentLengthDistributionDyad",
                         lambda: {"n": config.w, "peaks": config.peaks}),
            "fld_peter_ulz": ("PeterUlzFragmentLengthDistribution",
                              lambda: {"read_start": config.read_start, "read_end": config.read_end}),
        }
        try:
            class_name, specific_params = extractors[config.signal_transformer]
        except KeyError:
            raise KeyError(f"Unknown signal transformer {config.signal_transformer!r}; "
                           f"expected one of {sorted(extractors)}") from None
        fld_extractor_params = {
            "min_fragment_length": config.min_fragment_length,
            "max_fragment_length": config.max_fragment_length,
            "gc_correction": config.gc_correction,
            "tag": config.tag,
            **specific_params(),
        }
        fld_extractor_cls = getattr(lbfextract.fextract_fragment_length_distribution.signal_summarizers, class_name)
        transformed_reads.set_signal_transformer(fld_extractor_cls(**fld_extractor_params))
        return transformed_reads

    @lbfextract.hookimpl
    def transform_all_intervals(self, single_intervals_transformed_reads: IntervalIteratorFld, config: Any,
                                extra_config: Any) -> Signal:
        """Run the iterator and merge its ``{name: matrix}`` outputs into a single dictionary.

        :param single_intervals_transformed_reads: configured ``IntervalIteratorFld``
        :param config: config specific to the transform_all_intervals hook (unused here)
        :param extra_config: extra configuration that may be used in the hook implementation (unused here)
        :return: ``Signal`` whose ``array`` maps each interval-group name to its matrix
        """
        summarized_signal_per_bed = {}
        for per_group in single_intervals_transformed_reads:
            # Later groups win on duplicate names, like the dict ``|=`` merge.
            summarized_signal_per_bed.update(per_group)
        return Signal(
            array=summarized_signal_per_bed,
            metadata=None,
            tags=("fragment_length_distribution_in_batch",),
        )

    @lbfextract.hookimpl
    def save_signal(self, signal: Signal, config: Any, extra_config: Any) -> pathlib.Path:
        """Save every matrix of ``signal.array`` into one compressed ``.npz`` archive.

        :param signal: Signal whose ``array`` is a mapping name -> matrix
        :param config: config specific to the save signal hook (unused here)
        :param extra_config: extra configuration; ``ctx["output_path"]`` and ``ctx["id"]`` are read
        :return: path of the written ``.npz`` file
        """
        output_path = extra_config.ctx["output_path"]
        time_stamp = generate_time_stamp()
        run_id = extra_config.ctx["id"]
        signal_type = "_".join(signal.tags)
        # np.savez_compressed appends ".npz" when it is missing; include it so the returned path
        # is the file that is actually written.
        file_path = output_path / f"{time_stamp}__{run_id}__{signal_type}__signal.npz"
        logger.info(f"Saving signal to {file_path}")
        np.savez_compressed(file_path, **signal.array)
        return file_path

    @lbfextract.hookimpl
    def plot_signal(self, signal: Signal, extra_config: Any, config: PlotConfig, ) -> dict[str, pathlib.Path]:
        """Save one heatmap PNG per matrix of ``signal.array``.

        :param signal: Signal whose ``array`` is a mapping name -> matrix
        :param extra_config: extra configuration; ``ctx`` provides ``id``, ``output_path``,
            ``path_to_bam`` and ``single_signal_transformer_config``
        :param config: object containing the configuration to create the plots (unused here)
        :return: mapping name -> path of the saved figure
        """
        time_stamp = generate_time_stamp()
        run_id = extra_config.ctx["id"]
        signal_type = "_".join(signal.tags) if signal.tags else ""
        fig_pths = {}
        for name in signal.array:
            array = signal.array[name]
            file_name_sanitized = sanitize_file_name(f"{time_stamp}__{run_id}__{signal_type}__{name}__heatmap.png")
            # NOTE(review): rows are assumed to be single fragment lengths starting at the minimum;
            # with n_bins_len binning this axis range is only approximate.
            start_pos = extra_config.ctx["single_signal_transformer_config"].min_fragment_length
            end_pos = array.shape[0] + start_pos
            fig = plot_fragment_length_distribution(array, start_pos, end_pos)
            sample_name = extra_config.ctx["path_to_bam"].stem
            fig.suptitle(f"{sample_name} {signal_type} {name}".capitalize(), fontsize=25)
            output_path = extra_config.ctx["output_path"] / file_name_sanitized
            fig.savefig(output_path, dpi=600)
            fig_pths[name] = output_path
            plt.close(fig)
        return fig_pths
class CliHook:
    """Provides the click command of the ``fragment_length_distribution_in_batch`` plugin."""

    @lbfextract.hookimpl_cli
    def get_command(self) -> click.Command:
        """Build the ``extract_fragment_length_distribution_in_batch`` click command.

        :return: the click command wrapping ``App(...).run()`` for this plugin
        """

        @click.command(
            short_help="It extracts the fragment length distribution signal from a BAM file for each BED file provided.")
        @click.option('--path_to_bam',
                      type=click.Path(exists=False, file_okay=True, dir_okay=True, writable=False, readable=True,
                                      resolve_path=False, allow_dash=True, path_type=pathlib.Path,
                                      executable=False),
                      required=True,
                      help='path to the bam file to be used')
        @click.option('--path_to_bed',
                      type=click.Path(exists=False, file_okay=False, dir_okay=True, writable=False, readable=True,
                                      resolve_path=False, allow_dash=True, path_type=pathlib.Path,
                                      executable=False),
                      required=True,
                      help='path to the directory containing the bed files to be used')
        @click.option('--output_path',
                      type=click.Path(exists=False, file_okay=False, dir_okay=True, writable=True, readable=True,
                                      resolve_path=False, allow_dash=True, path_type=pathlib.Path,
                                      executable=False),
                      help='path to the output directory (defaults to the current working directory)')
        @click.option("--skip_read_fetching", is_flag=True, show_default=True,
                      help='Boolean flag. When it is set, the fetching of the reads is skipped and the latest '
                           'timestamp of this run (identified by the id) is retrieved')
        @click.option("--exp_id", default=None, type=str, show_default=True, help="run id")
        @click.option("--window", default=1000, type=int, show_default=True,
                      help="Integer describing the number of bases to be extracted around the middle point of an "
                           "interval present in the bedfile")
        @click.option("--flanking_window", default=1000, type=int, show_default=True,
                      help="Integer describing the number of bases to be extracted after the window")
        @click.option("--extra_bases", default=2000, type=int, show_default=True,
                      help="Integer describing the number of bases to be extracted from the bamfile when removing the "
                           "unused bases to be sure to get all the proper pairs, which may be mapping up to 2000 bp")
        @click.option("--n_binding_sites", default=1000, type=int, show_default=True,
                      help="number of intervals to be used to extract the signal, if it is higher than the provided "
                           "intervals, all the intervals will be used")
        @click.option("--cores", default=1, type=int, show_default=True, help="number of cores to be used")
        @click.option("--min_fragment_length", default=100, type=int, show_default=True,
                      help="minimum fragment length to be considered")
        @click.option("--max_fragment_length", default=300, type=int, show_default=True,
                      help="maximum fragment length to be considered")
        @click.option("--n_reads", default=1000, type=int, show_default=True,
                      help="number of reads to be subsampled at each position")
        @click.option("--subsample", is_flag=True, show_default=False,
                      help="Boolean flag. When it is set, the reads are subsampled at each position")
        @click.option("--n_bins_len", type=int, show_default=True,
                      help="number of bins to be used to discretize the fragment length")
        @click.option("--n_bins_pos", type=int, show_default=True,
                      help="number of bins to be used to discretize the position")
        @click.option('--gc_correction_tag', type=str, default=None,
                      help='tag to be used to extract gc coefficient per read from a bam file')
        @click.option("--fld_type",
                      type=click.Choice(["fld", "fld_middle", "fld_middle_n", "fld_dyad"], case_sensitive=False),
                      show_default=True, default="fld",
                      help="type of fragment length distribution to be extracted")
        @click.option("--w", default=5, type=int, show_default=True,
                      help="window used for the number of bases around either the middle point in "
                           "fld_middle_n or the center of the dyad in fld_dyad")
        def extract_fragment_length_distribution_in_batch(path_to_bam: pathlib.Path, path_to_bed: pathlib.Path,
                                                          output_path: pathlib.Path, skip_read_fetching: bool,
                                                          window: int, flanking_window: int, extra_bases: int,
                                                          n_binding_sites: int, min_fragment_length: int,
                                                          max_fragment_length: int, n_reads: int, subsample: bool,
                                                          n_bins_pos: int, n_bins_len: int, cores: int,
                                                          fld_type: str, w: int, exp_id: Optional[str],
                                                          gc_correction_tag: Optional[str] = None):
            """
            Given a set of genomic intervals having the same length, the
            extract_fragment_length_distribution_in_batch feature extraction method extracts the fragment
            length distribution at each position of the genomic intervals, for multiple BED files at the
            same time.
            """
            read_fetcher_config = {
                "window": window,
                "flanking_region_window": flanking_window,
                "extra_bases": extra_bases,
                "n_binding_sites": n_binding_sites,
            }
            reads_transformer_config = Config({})
            single_signal_transformer_config = {
                "min_fragment_length": min_fragment_length,
                "max_fragment_length": max_fragment_length,
                "signal_transformer": fld_type,
                "n": n_reads,
                "subsample": subsample,
                "n_bins_pos": n_bins_pos,
                "n_bins_len": n_bins_len,
                "gc_correction": bool(gc_correction_tag),
                "tag": gc_correction_tag,
            }
            if w:
                single_signal_transformer_config["w"] = w
            if fld_type == "fld_dyad":
                # The reference distribution for the dyad peak is estimated on a fixed chr12 window.
                distribution = calculate_reference_distribution(path_to_sample=path_to_bam,
                                                                min_length=min_fragment_length,
                                                                max_length=max_fragment_length,
                                                                chr_name="chr12",
                                                                start=34_300_000,
                                                                end=34_500_000)
                peaks = get_peaks(distribution) + min_fragment_length
                if len(peaks) == 0:
                    raise click.ClickException(
                        "No peak found in the reference fragment length distribution; cannot use fld_dyad")
                single_signal_transformer_config["peaks"] = [peaks[0]]
            transform_all_intervals_config = {}
            plot_signal_config = {}
            save_signal_config = {}
            extra_config = {
                "cores": cores
            }
            output_path = output_path or pathlib.Path.cwd()
            output_path.mkdir(parents=True, exist_ok=True)
            # Results are grouped per sample and per BED collection.
            output_path_interval_spec = output_path / path_to_bam.stem / path_to_bed.stem
            output_path_interval_spec.mkdir(parents=True, exist_ok=True)
            App(plugins_name=["fragment_length_distribution_in_batch", "coverage_in_batch"],
                path_to_bam=path_to_bam,
                path_to_bed=path_to_bed,
                output_path=output_path_interval_spec,
                skip_read_fetching=skip_read_fetching,
                read_fetcher_config=ReadFetcherConfig(read_fetcher_config),
                reads_transformer_config=reads_transformer_config,
                single_signal_transformer_config=SingleSignalTransformerConfig(single_signal_transformer_config),
                transform_all_intervals_config=SignalSummarizer(transform_all_intervals_config),
                plot_signal_config=PlotConfig(plot_signal_config),
                save_signal_config=Config(save_signal_config),
                extra_config=AppExtraConfig(extra_config),
                id=exp_id).run()

        return extract_fragment_length_distribution_in_batch
# Module-level plugin objects exposing this plugin's hook implementations and CLI command
# (presumably looked up by name by the lbfextract plugin manager — TODO confirm).
hook = FextractHooks()
hook_cli = CliHook()