Source code for lbfextract.fextract.lib
from __future__ import annotations
import logging
import os
import pathlib
import shutil
import tempfile
from typing import Any
import dill as pickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyranges
import pysam
import lbfextract.fextract
import lbfextract.fextract.signal_transformer
from lbfextract.fextract.schemas import ReadFetcherConfig, Config, SingleSignalTransformerConfig, \
SignalSummarizer, AppExtraConfig
from lbfextract.plotting_lib.plotting_functions import plot_signal
from lbfextract.utils import filter_bam, load_temporary_bed_file, generate_time_stamp, get_tmp_bam_name, \
write_yml, load_reads_from_dir, sanitize_file_name
from lbfextract.utils_classes import Signal
logger = logging.getLogger(__name__)
[docs]
class FextractHooks:
[docs]
@lbfextract.hookimpl
def fetch_reads(self, path_to_bam: pathlib.Path,
                path_to_bed: pathlib.Path,
                config: ReadFetcherConfig,
                extra_config: AppExtraConfig) -> pd.DataFrame:
    """
    Hook implementing the strategy to fetch the reads overlapping the intervals of a BED file.

    The BED file is loaded with ``load_temporary_bed_file``, the BAM file is pre-filtered with ``filter_bam``
    and, for each interval, the iterator returned by ``pysam.AlignmentFile.fetch`` is stored in a
    ``reads_per_interval`` column.

    :param path_to_bam: path to the BAM file; must exist and be non-empty
    :param path_to_bed: path to the BED file; must exist and be non-empty
    :param config: ReadFetcherConfig; ``f``, ``F``, ``extra_bases``, ``window``, ``flanking_region_window`` and
        ``n_binding_sites`` are read from it
    :param extra_config: AppExtraConfig; ``cores`` and ``ctx["run_id"]`` are read from it
    :return: DataFrame of the loaded intervals, shrunk by ``config.extra_bases`` through
        ``PyRanges.slack(-config.extra_bases)``, with the additional ``reads_per_interval`` column
    :raises ValueError: if one of the input files is missing or empty, if no interval is left after loading the
        BED file, or if ``config.window`` is 0 and at least one interval has Start == End
    """
    if not path_to_bam.exists():
        raise ValueError(f"The bam file ({path_to_bam}) does not exist")
    if not path_to_bed.exists():
        raise ValueError(f"The bed file ({path_to_bed}) does not exist")
    if path_to_bed.stat().st_size == 0:
        raise ValueError(f"The bed ({path_to_bed}) file is empty")
    if path_to_bam.stat().st_size == 0:
        raise ValueError(f"The bam file ({path_to_bam}) is empty")

    # Any falsy value (None, but also an explicit 0) falls back to the defaults below.
    # NOTE(review): presumably SAM flag filters (include / exclude) consumed by filter_bam - confirm there.
    config_f = config.f or 2
    config_F = config.F or 3868
    run_id = extra_config.ctx["run_id"]

    temporary_bed_file_name, bed_file = load_temporary_bed_file(
        bed_file=path_to_bed,
        extra_bases=config.extra_bases,
        window=config.window,
        flanking_region_window=config.flanking_region_window,
        n_binding_sites=config.n_binding_sites,
        run_id=run_id
    )

    # The emptiness check has to come before any column access: an empty interval set has no
    # "Start"/"End" columns to compare and would otherwise fail with an unrelated error.
    if bed_file.empty:
        raise ValueError("The bed file is empty")

    intervals = bed_file.as_df()
    if config.window == 0 and (intervals["Start"] == intervals["End"]).any():
        raise ValueError(
            "The bed file contains intervals with the same start and end but window is set to 0. "
            "Please either provide interval of size greater than 0 or set the window size to a value "
            "greater than 0"
        )

    # filtering the bam file to avoid having to go through it each time while fetching
    tmp_bam_file = filter_bam(path_to_bam, temporary_bed_file_name, cores=extra_config.cores,
                              run_id=run_id, f=config_f, F=config_F)

    # The alignment file is deliberately left open: the fetch iterators stored below are returned
    # unconsumed, so they still need the handle after this function returns.
    bamfile = pysam.AlignmentFile(tmp_bam_file)
    intervals["reads_per_interval"] = [
        bamfile.fetch(row.Chromosome, row.Start, row.End, multiple_iterators=False)
        for row in intervals.itertuples(index=False)
    ]
    return pyranges.PyRanges(intervals).slack(-config.extra_bases).as_df()
[docs]
@lbfextract.hookimpl
def save_fetched_reads(self,
                       reads_per_interval_container: pd.DataFrame,
                       config: Config,
                       extra_config: AppExtraConfig
                       ) -> pathlib.Path:
    """
    Hook implementing the strategy to save the reads fetched for the intervals.

    Writes the ``Chromosome``/``Start``/``End`` columns as a headerless tab-separated BED file, writes a
    ``metadata.yml`` naming that BED file and the temporary sorted BAM, and moves the temporary BAM and its
    index into the output directory unless a file with the same name is already there.

    :param reads_per_interval_container: DataFrame containing information about the genomic regions and the
        reads mapping to them
    :param config: not used by this implementation
    :param extra_config: AppExtraConfig; ``ctx["path_to_bam"]``, ``ctx["run_id"]`` and ``ctx["output_path"]``
        are read from it
    :return: path of the directory in which the files were saved
    """
    ctx = extra_config.ctx
    sample = ctx["path_to_bam"].stem

    # Temporary files live in $FRAGMENTOMICS_TMP when it is set (and non-empty), else in the system tmp dir.
    tmp_root = pathlib.Path(os.environ.get("FRAGMENTOMICS_TMP") or tempfile.gettempdir())
    tmp_bam_stub = tmp_root / get_tmp_bam_name(ctx["path_to_bam"], run_id=ctx["run_id"], )
    sorted_bam = tmp_bam_stub.with_suffix('.sorted.bam')
    sorted_bam_index = tmp_bam_stub.with_suffix('.sorted.bam.bai')

    output_dir = pathlib.Path(ctx["output_path"]) / "fatched_reads"
    output_dir.mkdir(parents=True, exist_ok=True)

    bed_name = f"{sample}_fetched_reads.bed"
    coordinates = reads_per_interval_container[["Chromosome", "Start", "End"]]
    coordinates.to_csv(output_dir / bed_name, sep="\t", header=False, index=False)

    write_yml({"bed": bed_name, "bam": sorted_bam.name}, output_dir / "metadata.yml")

    # BAM first, then its index; a file already present in the output directory is left untouched.
    for tmp_file in (sorted_bam, sorted_bam_index):
        if (output_dir / tmp_file.name).exists():
            continue
        logger.debug(f"Moving {tmp_file} to {output_dir}")
        shutil.move(str(tmp_file), str(output_dir))

    return output_dir
[docs]
@lbfextract.hookimpl
def load_fetched_reads(self, config: Config, extra_config: AppExtraConfig) -> pd.DataFrame:
    """
    Hook implementing the strategy to load previously saved reads.

    Delegates to ``load_reads_from_dir`` on the ``fatched_reads`` sub-directory of the output path.

    :param config: not used by this implementation
    :param extra_config: AppExtraConfig; ``ctx["output_path"]`` and ``ctx["read_fetcher_config"].extra_bases``
        are read from it
    :return: whatever ``load_reads_from_dir`` returns for that directory
    """
    ctx = extra_config.ctx
    reads_dir = pathlib.Path(ctx["output_path"]) / "fatched_reads"
    extra_bases = ctx["read_fetcher_config"].extra_bases
    return load_reads_from_dir(reads_dir, extra_bases=extra_bases)
[docs]
@lbfextract.hookimpl
def transform_reads(
        self,
        reads_per_interval_container: pd.DataFrame,
        config: Config,
        extra_config: AppExtraConfig,
) -> pd.DataFrame:
    """
    Hook implementing the read-level transformation step; this implementation is a pass-through.

    :param reads_per_interval_container: DataFrame with the intervals and their reads
    :param config: not used by this implementation
    :param extra_config: not used by this implementation
    :return: the very same ``reads_per_interval_container`` object, unmodified
    """
    return reads_per_interval_container
[docs]
@lbfextract.hookimpl
def transform_single_intervals(self,
                               transformed_reads: pd.DataFrame,
                               config: SingleSignalTransformerConfig,
                               extra_config: AppExtraConfig
                               ) -> Signal:
    """
    Hook implementing the per-interval signal extraction.

    The class named by ``config.signal_transformer`` is looked up in
    ``lbfextract.fextract.signal_transformer``, instantiated with the matching keyword arguments and called
    once per row of ``transformed_reads``; the per-row results are stacked vertically.

    :param transformed_reads: DataFrame with one row per interval; must provide the columns ``Chromosome``,
        ``Start``, ``End``, ``Name``, ``Score`` and ``Strand``
    :param config: SingleSignalTransformerConfig selecting the transformer and providing its parameters
    :param extra_config: not used by this implementation
    :return: Signal holding the stacked array, the BED columns as ``metadata["bed_file_df"]`` and the
        transformer name as only tag
    :raises KeyError: if ``config.signal_transformer`` is not one of the registered names
    """
    # Keyword arguments accepted by every transformer.
    shared = {"gc_correction": config.gc_correction, "tag": config.tag}

    # transformer name -> (class name in lbfextract.fextract.signal_transformer, constructor kwargs)
    registry = {
        "coverage": ("TFBSCoverage", {**shared}),
        "coverage_dyads": ("TFBSCoverageAroundDyads", {"n": config.n, **shared, "peaks": config.peaks}),
        "middle_point_coverage": ("TFBSMiddlePointCoverage", {**shared}),
        "middle_n_points_coverage": ("TFBSNmiddlePointCoverage", {"n": config.n, **shared}),
        "sliding_window_coverage": ("TFBSSlidingWindowCoverage", {"window_size": config.window_size, **shared}),
        "peter_ulz_coverage": (
            "PeterUlzCoverage",
            {"read_start": config.read_start, "read_end": config.read_end, **shared},
        ),
        "wps_coverage": (
            "WPSCoverage",
            {
                "window_size": config.window_size,
                "min_fragment_length": config.min_fragment_length,
                "max_fragment_length": config.max_fragment_length,
                **shared,
            },
        ),
    }

    class_name, extractor_kwargs = registry[config.signal_transformer]
    extractor_cls = getattr(lbfextract.fextract.signal_transformer, class_name)
    extractor = extractor_cls(**extractor_kwargs)

    per_interval_signals = [extractor(row) for row in transformed_reads.itertuples(index=False)]
    array = np.vstack(per_interval_signals)

    if config.flip_based_on_strand:
        # Reverse the rows of minus-strand intervals left to right.
        on_minus_strand = transformed_reads["Strand"].values == "-"
        array[on_minus_strand] = np.fliplr(array[on_minus_strand])

    bed_columns = ['Chromosome', 'Start', 'End', 'Name', 'Score', 'Strand']
    return Signal(
        array=array,
        metadata={"bed_file_df": transformed_reads[bed_columns]},
        tags=(config.signal_transformer,)
    )
[docs]
@lbfextract.hookimpl
def transform_all_intervals(self,
                            single_intervals_transformed_reads: Signal,
                            config: SignalSummarizer,
                            extra_config: AppExtraConfig) -> Signal:
    """
    Hook implementing the summarization of the per-interval signals into one signal.

    Every row is divided by the mean of its flanking positions (the first and the last ``flanking_window``
    columns) and the normalized rows are then reduced along axis 0 with the configured method. Rows whose
    flanking mean is exactly 0 are left as all zeros.

    :param single_intervals_transformed_reads: Signal whose ``array`` is 2-dimensional (intervals x positions)
    :param config: SignalSummarizer; ``summarization_method`` must be one of ``mean``, ``median``, ``max``,
        ``min`` or ``skip`` (``skip`` returns the normalized 2D array unreduced). Its ``bed_file`` attribute is
        overwritten with ``extra_config.ctx["path_to_bed"]``
    :param extra_config: AppExtraConfig; ``ctx["path_to_bed"]`` and
        ``ctx["read_fetcher_config"].flanking_region_window`` are read from it
    :return: Signal with the summarized array, no metadata and the input tags followed by the summarization
        method
    :raises KeyError: if ``config.summarization_method`` is not a supported method
    """
    config.bed_file = extra_config.ctx["path_to_bed"]

    per_interval = single_intervals_transformed_reads.array
    n_positions = per_interval.shape[1]

    # A missing (or 0) flanking window falls back to one third of the signal length.
    default_flanking_region = n_positions // 3
    flanking_window = extra_config.ctx["read_fetcher_config"].flanking_region_window or default_flanking_region

    summary_method = {
        "mean": np.nanmean,
        "median": np.nanmedian,
        "max": np.nanmax,
        "min": np.nanmin,
        "skip": lambda x, axis: x
    }

    # Flanks are the first and the last `flanking_window` positions. The right flank has to be measured
    # against the signal length; comparing the index vector with itself selects every position.
    indices = np.arange(n_positions)
    index_flanking = np.logical_or(
        indices < flanking_window,
        indices >= (n_positions - flanking_window)
    )

    # The ratios are fractional, so an integer input must not dictate the dtype of the output buffer
    # (the assignment below would silently truncate). Floating inputs keep their dtype.
    out_dtype = per_interval.dtype if np.issubdtype(per_interval.dtype, np.floating) else np.float64
    normalized_array = np.zeros_like(per_interval, dtype=out_dtype)

    means_flanking = per_interval[:, index_flanking].mean(axis=1)
    mask = means_flanking != 0  # avoid dividing by zero; those rows stay zero
    normalized_array[mask, :] = per_interval[mask, :] / means_flanking[mask, None]

    array = summary_method[config.summarization_method](normalized_array, axis=0)
    return Signal(
        array=array,
        metadata=None,
        tags=(*single_intervals_transformed_reads.tags, config.summarization_method))
[docs]
@lbfextract.hookimpl
def save_signal(self,
                signal: Signal,
                config: Any,
                extra_config: AppExtraConfig
                ) -> pathlib.Path:
    """
    Hook implementing the strategy to save a signal; the Signal object is pickled to the output directory.

    The file is named ``<time stamp>__<id>__<tags joined by "_">__signal.pkl`` and passed through
    ``sanitize_file_name``.

    :param signal: Signal to be serialized; its ``tags`` are used in the file name
    :param config: not used by this implementation
    :param extra_config: AppExtraConfig; ``ctx["id"]`` and ``ctx["output_path"]`` are read from it
    :return: path of the written pickle file
    """
    time_stamp = generate_time_stamp()
    # NOTE(review): read from ctx["id"], while the fetching hooks use ctx["run_id"] - confirm both are intended.
    run_id = extra_config.ctx["id"]
    signal_type = "_".join(signal.tags)
    file_name = f"{time_stamp}__{run_id}__{signal_type}__signal.pkl"
    file_name_sanitized = sanitize_file_name(file_name)
    # Wrapped in pathlib.Path like in the other hooks, so that a plain string output path works too.
    output_path = pathlib.Path(extra_config.ctx["output_path"]) / file_name_sanitized
    with open(output_path, "wb") as f:
        pickle.dump(signal, f)
    return output_path
[docs]
@lbfextract.hookimpl
def plot_signal(self, signal: Signal,
                config: Any,
                extra_config: AppExtraConfig) -> plt.Figure:
    """
    Hook implementing the plotting of a signal; the figure is also saved as PNG in the output directory.

    The title reports the signal type (tags joined by ``_``), the stem of the BAM path and the stem of the BED
    path up to its first dot. The file is named
    ``<time stamp>__<id>__<signal type>_signal_plot.png`` (after ``sanitize_file_name``) and written at 600 dpi.

    :param signal: Signal to be plotted; ``array`` is drawn and ``tags`` label the plot
    :param config: not used by this implementation
    :param extra_config: AppExtraConfig; ``ctx["path_to_bam"]``, ``ctx["path_to_bed"]``, ``ctx["id"]`` and
        ``ctx["output_path"]`` are read from it
    :return: the matplotlib figure; it is not closed here
    """
    big_fs = 20
    medium_fs = 15
    signal_type = "_".join(signal.tags) if signal.tags else ""
    with plt.style.context('seaborn-v0_8-whitegrid'):
        fig, ax = plt.subplots(1, figsize=(10, 10))
        title = (
            f"SIGNAL TYPE: {signal_type}\n"
            f"ID: {extra_config.ctx['path_to_bam'].stem} \n"
            f"BED file: {extra_config.ctx['path_to_bed'].stem.split('.', 1)[0]}"
        )
        ax.set_title(title, fontsize=big_fs)
        # Inside a method the bare name resolves to the module-level plotting helper, not to this hook.
        fig, _ = plot_signal(signal.array, apply_savgol=False, ax=ax, fig=fig, label=signal_type)
        ax.set_ylabel(signal_type, fontsize=medium_fs)
        ax.set_xlabel("Position", fontsize=medium_fs)
        # Only the size of the tick labels has to change; tick_params does that without replacing the
        # labels by a fixed list, so the tick locator and formatter keep working.
        ax.tick_params(axis="both", labelsize=medium_fs)
    file_name = f"{generate_time_stamp()}__{extra_config.ctx['id']}__{signal_type}_signal_plot.png"
    file_name_sanitized = sanitize_file_name(file_name)
    # Wrapped in pathlib.Path like in the other hooks, so that a plain string output path works too.
    output_path = pathlib.Path(extra_config.ctx["output_path"]) / file_name_sanitized
    fig.savefig(output_path, dpi=600)
    return fig