Source code for lbfextract.fextract_fragment_length_distribution.plugin
from __future__ import annotations
import logging
import pathlib
from typing import Any, Optional
import click
import matplotlib
import numpy as np
import pandas as pd
import pysam
from scipy.signal import find_peaks
from scipy.signal import savgol_filter
import lbfextract.fextract.signal_transformer
import lbfextract.fextract_fragment_length_distribution.signal_summarizers
from lbfextract.core import App
from lbfextract.fextract.schemas import AppExtraConfig, ReadFetcherConfig, Config
from lbfextract.fextract_fragment_length_distribution.schemas import SingleSignalTransformerConfig
from lbfextract.plotting_lib.plotting_functions import plot_fragment_length_distribution
from lbfextract.utils import generate_time_stamp, sanitize_file_name
from lbfextract.utils_classes import Signal
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
[docs]
def calculate_reference_distribution(path_to_sample, min_length, max_length, chr_name, start, end):
    """
    Estimate the fragment length distribution of a sample over one genomic region.

    Every read returned by ``fetch(chr_name, start, end)`` whose ``tlen`` lies in
    ``[min_length, max_length)`` adds one count to the bin ``tlen - min_length``.
    Reads outside that range are ignored.

    :param path_to_sample: path to the BAM file to read
    :param min_length: smallest template length counted (inclusive)
    :param max_length: largest template length counted (exclusive)
    :param chr_name: name of the contig to fetch reads from
    :param start: start coordinate of the region passed to ``fetch``
    :param end: end coordinate of the region passed to ``fetch``
    :return: array of length ``max_length - min_length`` whose entries sum to 1, or an
        all-zero array when no read in the region fell inside the length range
    """
    array_fragment_lengths = np.zeros(max_length - min_length)
    # The context manager guarantees the BAM handle is released even if fetch() raises.
    with pysam.AlignmentFile(path_to_sample, "rb") as alignment_file:
        for read in alignment_file.fetch(chr_name, start, end):
            tlen = read.tlen
            if min_length <= tlen < max_length:
                array_fragment_lengths[tlen - min_length] += 1
    total = array_fragment_lengths.sum()
    if total == 0:
        # Avoid 0/0: return the (all-zero) counts instead of an array full of NaN.
        return array_fragment_lengths
    return array_fragment_lengths / total
[docs]
def get_peaks(distribution, height=0.0001, distance=100):
    """
    Locate the peaks of a smoothed distribution.

    The input is first smoothed with a Savitzky-Golay filter (window length 10,
    polynomial order 3); peaks are then searched on the smoothed curve.

    :param distribution: 1-D array of values to search for peaks
    :param height: minimum height forwarded to ``scipy.signal.find_peaks``
    :param distance: minimum distance between peaks forwarded to ``scipy.signal.find_peaks``
    :return: array with the indices of the detected peaks
    """
    smoothed = savgol_filter(distribution, window_length=10, polyorder=3)
    peak_indices, _properties = find_peaks(smoothed, height=height, distance=distance)
    return peak_indices
[docs]
def subsample_fragment_lengths(x, n):
    """
    Resample a 1-D vector of counts to exactly ``n`` draws.

    Bin indices are drawn with replacement, with probability proportional to the
    positive entries of ``x``. When ``x`` has no positive entry, all bins are
    drawn with equal probability.

    :param x: 1-D array of counts, one entry per bin
    :param n: number of draws to take
    :return: array with the shape and dtype of ``x`` holding the number of draws per bin;
        its entries sum to ``n`` (an empty ``x`` is returned as an empty array)
    """
    new_x = np.zeros_like(x)
    n_bins = x.shape[0]
    if n_bins == 0:
        # Nothing to sample from.
        return new_x

    # Only positive entries carry probability mass.
    weights = np.where(x > 0, x, 0).astype(float)
    total = weights.sum()
    if total > 0:
        probabilities = weights / total
    else:
        # No mass at all: fall back to a uniform draw, decided up front instead of
        # by parsing the text of the ValueError raised by np.random.choice.
        probabilities = np.full(n_bins, 1.0 / n_bins)

    subsampled_reads = np.random.choice(np.arange(n_bins), size=n, p=probabilities, replace=True)
    # One pass over the draws instead of one comparison sweep per bin.
    new_x[...] = np.bincount(subsampled_reads, minlength=n_bins)
    return new_x
[docs]
def get_position_coefficient(config, array):
    """
    Return the ratio between the extracted region width and the first dimension of ``array``.

    The width is twice the sum of ``window`` and ``flanking_region_window`` read from
    ``config.ctx["read_fetcher_config"]``.

    :param config: object exposing a ``ctx`` mapping with a ``"read_fetcher_config"`` entry
    :param array: array whose ``shape[0]`` is used as the divisor
    :return: ``2 * (window + flanking_region_window) / array.shape[0]``
    """
    fetcher_cfg = config.ctx["read_fetcher_config"]
    half_span = fetcher_cfg.window + fetcher_cfg.flanking_region_window
    n_positions = array.shape[0]
    return (half_span * 2) / n_positions
[docs]
class FextractHooks:
    """
    Hook implementations of the fragment length distribution feature extraction method.
    """

    @lbfextract.hookimpl
    def transform_single_intervals(self, transformed_reads: pd.DataFrame,
                                   config: SingleSignalTransformerConfig,
                                   extra_config: AppExtraConfig) -> Signal:
        """
        Accumulate the fragment length distribution signal over all intervals.

        :param transformed_reads: DataFrame with one row per interval; the ``Start`` and ``End``
            columns of the first row define the region length, and ``Strand`` is read when
            ``config.flip_based_on_strand`` is set
        :param config: config specific to the function
        :param extra_config: config containing context information plus extra parameters
        :return: Signal whose array has shape (fragment lengths, positions) before optional
            binning, tagged with the selected signal transformer name
        """
        # Parameters shared by every extractor class.
        base_params = {
            "min_fragment_length": config.min_fragment_length,
            "max_fragment_length": config.max_fragment_length,
            "gc_correction": config.gc_correction,
            "tag": config.tag,
        }
        # Extractor class name and extractor-specific parameters per fld type.
        extractor_specs = {
            "fld": ("TfbsFragmentLengthDistribution", {}),
            "fld_middle": ("TfbsFragmentLengthDistributionMiddlePoint", {}),
            "fld_middle_n": ("TfbsFragmentLengthDistributionMiddleNPoints", {"n": config.w}),
            "fld_dyad": ("TfbsFragmentLengthDistributionDyad", {"n": config.w, "peaks": config.peaks}),
            "fld_peter_ulz": ("PeterUlzFragmentLengthDistribution",
                              {"read_start": config.read_start, "read_end": config.read_end}),
        }
        class_name, extra_params = extractor_specs[config.signal_transformer]
        extractor_class = getattr(lbfextract.fextract_fragment_length_distribution.signal_summarizers,
                                  class_name)
        fld_extractor = extractor_class(**base_params, **extra_params)
        tag = (config.signal_transformer,)

        relative_fragment_len = config.max_fragment_length - config.min_fragment_length
        # Positional access: the DataFrame index is not guaranteed to contain the label 0.
        region_length = transformed_reads["End"].iloc[0] - transformed_reads["Start"].iloc[0]
        tensor = np.zeros((relative_fragment_len, region_length))

        for interval in transformed_reads.itertuples():
            interval_signal = fld_extractor(interval)
            if config.flip_based_on_strand:
                if interval.Strand == "-":
                    interval_signal = np.fliplr(interval_signal)
                elif interval.Strand != "+":
                    # Honour the message: the interval is kept and accumulated unflipped.
                    logger.warning("Strand %s not recognized. Treating it as +.", interval.Strand)
            tensor += interval_signal

        if config.n_bins_pos:
            # Sum the columns (positions) falling into each positional bin.
            tensor = np.hstack([chunk.sum(axis=1)[:, None]
                                for chunk in np.array_split(tensor, config.n_bins_pos, axis=1)])
        if config.n_bins_len:
            # Sum the rows (fragment lengths) falling into each length bin.
            tensor = np.vstack([chunk.sum(axis=0)
                                for chunk in np.array_split(tensor, config.n_bins_len, axis=0)])
        if config.subsample:
            # Without an explicit n, use the smallest per-position total.
            n = config.n or int(tensor.sum(axis=0).min())
            tensor = np.apply_along_axis(lambda column: subsample_fragment_lengths(column, n), 0, tensor)

        return Signal(array=tensor, tags=tag, metadata=None)

    @lbfextract.hookimpl
    def transform_all_intervals(self, single_intervals_transformed_reads: Signal, config: Any,
                                extra_config: Any) -> Signal:
        """
        Normalize each column of the signal so that it sums to 1.

        :param single_intervals_transformed_reads: Signal object containing the signals per interval
        :param config: config specific to the function
        :param extra_config: extra configuration that may be used in the hook implementation
        :return: Signal with the column-normalized array and the tags of the input signal
        """
        array = single_intervals_transformed_reads.array
        col_sums = array.sum(axis=0)
        mask = col_sums > 0
        # Float output so that fractions are not truncated when the input holds integers.
        # Columns whose sum is not positive keep the initial value 1 in every row.
        normalized_array = np.ones_like(array, dtype=float)
        normalized_array[:, mask] = array[:, mask] / col_sums[mask]
        # Return the normalized array; previously it was computed and then discarded.
        return Signal(array=normalized_array, tags=single_intervals_transformed_reads.tags, metadata=None)

    @lbfextract.hookimpl
    def plot_signal(self, signal: Signal, extra_config: Any) -> matplotlib.figure.Figure:
        """
        Plot the signal as a heatmap and save it as a PNG file in the output directory.

        :param signal: Signal object containing the signals per interval
        :param extra_config: extra configuration that may be used in the hook implementation
        :return: the figure that was saved
        """
        ctx = extra_config.ctx
        time_stamp = generate_time_stamp()
        run_id = ctx["id"]
        signal_type = "_".join(signal.tags) if signal.tags else ""
        start_pos = ctx["single_signal_transformer_config"].min_fragment_length
        end_pos = signal.array.shape[0] + start_pos
        sample_name = ctx["path_to_bam"].stem
        # Keep only the part of the BED file name before the first dot.
        interval_name = ctx["path_to_bed"].stem.split(".", 1)[0]

        fig = plot_fragment_length_distribution(signal.array, start_pos, end_pos,
                                                title=f"{sample_name} {signal_type} {interval_name}")
        file_name = sanitize_file_name(f"{time_stamp}__{run_id}__{signal_type}__heatmap.png")
        fig.savefig(ctx["output_path"] / file_name, dpi=600)
        return fig

    @lbfextract.hookimpl
    def save_signal(self,
                    signal: Signal,
                    extra_config: Any) -> pathlib.Path:
        """
        Write the signal array to a CSV file in the output directory.

        :param signal: Signal object containing the signals per interval
        :param extra_config: extra configuration that may be used in the hook implementation
        :return: path of the CSV file that was written
        """
        output_path = extra_config.ctx["output_path"]
        time_stamp = generate_time_stamp()
        run_id = extra_config.ctx["id"]
        # Same guard as in plot_signal: tolerate a signal without tags.
        signal_type = "_".join(signal.tags) if signal.tags else ""
        file_name = sanitize_file_name(f"{time_stamp}__{run_id}__{signal_type}__signal.csv")
        path_to_signal = output_path / file_name
        pd.DataFrame(signal.array).to_csv(path_to_signal)
        return path_to_signal
[docs]
class CliHook:
    r"""
    This CliHook implements the CLI interface for the extract_fragment_length_distribution feature extraction
    method.

    **extract_fragment_length_distribution**

    Given a set of genomic intervals having the same length w, extract_fragment_length_distribution calculates the
    fragment length distribution at each position, which can be represented as:

    .. math::
        \mathbf{d}_l = \left( \frac{1}{|F|} \sum_{\substack{f \in F \\ |f| = p \\ l \in f}} \mathbb{1} \right)^{p_e}_{p_s}

    Where :math:`l` represents the genomic position, :math:`f` represents a fragment of the set :math:`F`,
    :math:`p` represents the fragment length, :math:`p_e` represents the maximum fragment length
    and :math:`p_s` represents the minimum fragment length
    """

    @lbfextract.hookimpl_cli
    def get_command(self) -> click.Command:
        """
        Build the click command exposing extract_fragment_length_distribution.

        :return: the click command object
        """
        # The BAM and BED options share the same path validation settings.
        input_path_type = click.Path(exists=False,
                                     file_okay=True,
                                     dir_okay=True,
                                     writable=False,
                                     readable=True,
                                     resolve_path=False,
                                     allow_dash=True,
                                     path_type=pathlib.Path,
                                     executable=False)
        output_dir_type = click.Path(exists=False,
                                     file_okay=False,
                                     dir_okay=True,
                                     writable=True,
                                     readable=True,
                                     resolve_path=False,
                                     allow_dash=True,
                                     path_type=pathlib.Path,
                                     executable=False)
        # Region used to estimate the sample-wide reference distribution for fld_dyad.
        reference_chr = "chr12"
        reference_start = 34_300_000
        reference_end = 34_500_000

        @click.command(
            short_help="It extracts the fragment length distribution signal from a BAM file for a BED file provided."
        )
        # Both inputs are required: their stems are used to build the output directory.
        @click.option('--path_to_bam', type=input_path_type, required=True,
                      help='path to the bam file to be used')
        @click.option('--path_to_bed', type=input_path_type, required=True,
                      help='path to the bed file to be used')
        @click.option('--output_path', type=output_dir_type,
                      help='path to the output directory')
        @click.option("--skip_read_fetching", is_flag=True, show_default=True,
                      help='Boolean flag. When it is set, the fetching of the reads is skipped and the latest '
                           'timestamp of this run (identified by the id) is retrieved')
        @click.option("--exp_id", default=None, type=str, show_default=True,
                      help="run id")
        @click.option("--window", default=1000, type=int, show_default=True,
                      help="Integer describing the number of bases to be extracted around the middle point of an "
                           "interval present in the bedfile")
        @click.option("--flanking_window", default=1000, type=int, show_default=True,
                      help="Integer describing the number of bases to be extracted after the window")
        @click.option("--extra_bases", default=2000, type=int, show_default=True,
                      help="Integer describing the number of bases to be extracted from the bamfile when removing the "
                           "unused bases to be sure to get all the proper pairs, which may be mapping up to 2000 bs")
        @click.option("--n_binding_sites", default=1000, type=int, show_default=True,
                      help="number of intervals to be used to extract the signal, if it is higher than the provided "
                           "intervals, all the intervals will be used")
        @click.option("--cores", default=1, type=int, show_default=True,
                      help="number of cores to be used")
        @click.option("--min_fragment_length", default=100, type=int, show_default=True,
                      help="minimum fragment length to be considered")
        @click.option("--max_fragment_length", default=300, type=int, show_default=True,
                      help="maximum fragment length to be considered")
        @click.option("--n_reads", default=1000, type=int, show_default=True,
                      help="number of reads to be subsampled at each position")
        @click.option("--subsample", is_flag=True, show_default=False,
                      help="Boolean flag. When it is set, the reads are subsampled at each position")
        @click.option("--n_bins_len", type=int, show_default=True,
                      help="number of bins to be used to discretize the fragment length")
        @click.option("--n_bins_pos", type=int, show_default=True,
                      help="number of bins to be used to discretize the position")
        @click.option("--flip_based_on_strand", is_flag=True, show_default=False,
                      default=False,
                      help="Boolean flag. When it is set, the signal is flipped based on the strand")
        @click.option('--gc_correction_tag', type=str,
                      default=None, help='tag to be used to extract gc coefficient per read from a bam file')
        @click.option("--w", default=5, type=int, show_default=True,
                      help="window used for the number of bases around either the middle point in the"
                           " fld_middle_around or the number of bases around the center of the dyad in fld_dyad")
        @click.option("--fld_type",
                      type=click.Choice(["fld", "fld_middle", "fld_middle_n", "fld_dyad", "fld_peter_ulz"],
                                        case_sensitive=False),
                      show_default=True, default="fld",
                      help="type of fragment length distribution to be extracted")
        @click.option("--read_start", default=53, type=int, show_default=True,
                      help="start of the read to be used to extract coverage")
        @click.option("--read_end", default=113, type=int, show_default=True,
                      help="end of the read to be used to extract coverage")
        def extract_fragment_length_distribution(
                path_to_bam: pathlib.Path, path_to_bed: pathlib.Path,
                output_path: Optional[pathlib.Path],
                skip_read_fetching: bool,
                window: int,
                flanking_window: int,
                extra_bases: int,
                n_binding_sites: int,
                min_fragment_length: int,
                max_fragment_length: int,
                n_reads: int,
                subsample: bool,
                n_bins_pos: Optional[int],
                n_bins_len: Optional[int],
                cores: int,
                w: int,
                fld_type: str,
                exp_id: Optional[str],
                read_start: int,
                read_end: int,
                flip_based_on_strand: bool = False,
                gc_correction_tag: Optional[str] = None,
        ):
            """
            Given a set of genomic intervals having the same length w, the extract_fragment_length_distribution feature
            extraction method extracts the fragment length distribution at each position of the genomic intervals used.
            """
            read_fetcher_config = {
                "window": window,
                "flanking_region_window": flanking_window,
                "extra_bases": extra_bases,
                "n_binding_sites": n_binding_sites,
            }
            reads_transformer_config = Config({})
            single_signal_transformer_config = {
                "min_fragment_length": min_fragment_length,
                "max_fragment_length": max_fragment_length,
                "signal_transformer": fld_type,
                "n": n_reads,
                "subsample": subsample,
                "n_bins_pos": n_bins_pos,
                "n_bins_len": n_bins_len,
                "flip_based_on_strand": flip_based_on_strand,
                "gc_correction": bool(gc_correction_tag),
                "tag": gc_correction_tag,
                "read_start": read_start,
                "read_end": read_end,
            }
            if w:
                single_signal_transformer_config["w"] = w

            if fld_type == "fld_dyad":
                # The first peak of the reference distribution is passed to the dyad extractor.
                distribution = calculate_reference_distribution(path_to_sample=path_to_bam,
                                                                min_length=min_fragment_length,
                                                                max_length=max_fragment_length,
                                                                chr_name=reference_chr,
                                                                start=reference_start,
                                                                end=reference_end)
                peaks = get_peaks(distribution) + min_fragment_length
                if len(peaks) == 0:
                    # Fail with an explanation instead of an IndexError on peaks[0].
                    raise click.ClickException(
                        f"No peak found in the reference fragment length distribution of {path_to_bam} "
                        f"({reference_chr}:{reference_start}-{reference_end}); fld_dyad cannot be computed."
                    )
                single_signal_transformer_config["peaks"] = [peaks[0]]

            plot_signal_config = {}
            save_signal_config = {}
            extra_config = {
                "cores": cores
            }

            # Fall back to the working directory when no output path is given.
            output_path = output_path or pathlib.Path.cwd()
            output_path.mkdir(parents=True, exist_ok=True)
            output_path_interval_spec = output_path / f"{path_to_bam.stem}" / f"{path_to_bed.stem}"
            output_path_interval_spec.mkdir(parents=True, exist_ok=True)

            res = App(plugins_name=["fragment_length_distribution", ],
                      path_to_bam=path_to_bam,
                      path_to_bed=path_to_bed,
                      output_path=output_path_interval_spec,
                      skip_read_fetching=skip_read_fetching,
                      read_fetcher_config=ReadFetcherConfig(read_fetcher_config),
                      reads_transformer_config=reads_transformer_config,
                      single_signal_transformer_config=SingleSignalTransformerConfig(
                          single_signal_transformer_config),
                      transform_all_intervals_config=Config({}),
                      plot_signal_config=Config(plot_signal_config),
                      save_signal_config=Config(save_signal_config),
                      extra_config=AppExtraConfig(extra_config),
                      id=exp_id).run()
            return res

        return extract_fragment_length_distribution
# Module-level instances of the two hook classes defined in this module.
# NOTE(review): presumably these are the objects the plugin loader looks up — confirm against lbfextract.core.
hook = FextractHooks()
hook_cli = CliHook()