import hashlib
import logging
import pathlib
from collections import defaultdict
from copy import copy
from typing import Callable
import numpy as np
import pandas as pd
import polars as pl
from lbfextract.fextract.schemas import Config
from lbfextract.transcription_factor_analysis.accessibility_extraction import get_chromatin_accessibility_coverage, \
get_chromatin_accessibility_entropy
from lbfextract.transcription_factor_analysis.schemas import AccessibilityConfig
logger = logging.getLogger(__name__)
[docs]
class ResultsLoader:
signal_types = defaultdict(dict, {
"coverage": {"fun": get_chromatin_accessibility_coverage,
"validator": AccessibilityConfig},
"entropy": {"fun": get_chromatin_accessibility_entropy,
"validator": AccessibilityConfig}
})
[docs]
@classmethod
def register_signal_type(cls, signal_name, fun: Callable, validator: Config):
cls.signal_types[signal_name]["fun"] = fun
cls.signal_types[signal_name]["validator"] = validator
def __init__(self, path_to_res_summary: pathlib.Path,
accessibility_extraction_config: dict,
signal_length: int = 4000,
flanking_signal_indices: tuple = (1000, 3000),
normalize: bool = False,
path_to_sample_sheet: pathlib.Path = None,
grouping_column: str = None,
signal_type: str = "coverage"):
self.signal_type = self.check_signal_type_compatibility(signal_type)
self.path_to_res_summary = path_to_res_summary
self.signal_length = signal_length
self.accessibility_extraction_config = accessibility_extraction_config or dict(start=1800, end=2200)
self.flanking_signal_indices = flanking_signal_indices
self.normalize = normalize
self.grouping_column = grouping_column
self.sample_sheet = self.load_sample_sheet(
path_to_sample_sheet) if path_to_sample_sheet else self.generate_sample_sheet()
[docs]
def check_signal_type_compatibility(self, signal_type: str) -> str:
signal_types_keys = list(self.signal_types.keys())
check = True if signal_type in signal_types_keys else False
if not check:
raise ValueError(
f"signal: {self.signal_type} is not compatible, possible signals: {' '.join(list(self.signal_types.keys()))}")
return signal_type
[docs]
def load_sample_sheet(self, path_to_sample_sheet: pathlib.Path) -> pd.DataFrame:
sample_sheet = pd.read_csv(path_to_sample_sheet, sep=",", index_col=0)
not_float_cols = (sample_sheet.dtypes.apply(lambda x: x.name) == "object").to_list()
sample_sheet.loc[:, not_float_cols] = sample_sheet.loc[:, not_float_cols].copy().applymap(
lambda x: "NA" if pd.isna(x) else x)
return sample_sheet
@staticmethod
def _hash_path_sample_results(sample_path: str | pathlib.Path):
if isinstance(sample_path, pathlib.Path):
sample_path = str(sample_path)
input_bytes = sample_path.encode('utf-8')
md5_hash = hashlib.md5()
md5_hash.update(input_bytes)
hash_result = md5_hash.hexdigest()
return hash_result
[docs]
def generate_sample_sheet(self):
paths_to_sample_result = list(self.path_to_res_summary.glob("**/*csv"))
sample_names = [i.parent.parent.stem for i in paths_to_sample_result]
bed_file_metadata = [i.parent.stem for i in paths_to_sample_result]
index_df = range(len(paths_to_sample_result))
sample_sheet_df = pd.DataFrame(
columns=["group", "tumor_fraction", "cov", "signal_type"],
index=index_df)
sample_sheet_df["sample_name"] = pd.Series(sample_names, index=index_df)
sample_sheet_df["bed_file_metadata"] = pd.Series(bed_file_metadata, index=index_df)
sample_sheet_df["path_to_res_summary"] = pd.Series(self.path_to_res_summary, index=index_df)
sample_sheet_df["paths_to_sample_result"] = paths_to_sample_result
sample_sheet_df.index = sample_sheet_df.paths_to_sample_result.apply(
lambda x: self._hash_path_sample_results(x)
).to_list()
return sample_sheet_df
[docs]
def get_res_df_polars(self) -> pl.DataFrame:
dfs = []
for count, i in enumerate(self.path_to_res_summary.glob("**/*csv")):
path_exists = i.exists()
path_hash = self._hash_path_sample_results(str(i))
path_hash_exists = path_hash in self.sample_sheet.index
if not path_exists or not path_hash_exists:
msg = f"path {i} {'exists' if path_exists else 'does not exist'}"
msg1 = f"path hash {path_hash} {'is in index' if path_hash_exists else 'is not in index'}"
logger.warning(f"{msg + ' but ' + msg1 if path_exists else msg1 + ' but ' + msg}")
continue
parsed_name = self.sample_sheet.loc[path_hash]
index_transcription_factor_df = ["genomic_interval"] + [str(i) for i in list(range(self.signal_length))]
gi_df = pl.scan_csv(i, new_columns=index_transcription_factor_df).collect()
parsed_sample_name_df = pl.DataFrame([
copy(parsed_name.to_list()) for _ in range(gi_df.shape[0])],
schema=parsed_name.index.to_list()
)
dfs.append(pl.concat([parsed_sample_name_df, gi_df], how="horizontal"))
return pl.concat(dfs, how="vertical")
[docs]
def get_accessibility(self, df_: pl.DataFrame) -> np.ndarray:
metadata = self.signal_types[self.signal_type]["validator"](self.accessibility_extraction_config)
return self.signal_types[self.signal_type]["fun"](df_, metadata)
[docs]
def normalize_df(self, df: pl.DataFrame) -> pl.DataFrame:
signal_col_index = [str(i) for i in list(range(self.signal_length))]
signal_col_index_l_flank = [str(i) for i in list(range(self.flanking_signal_indices[0]))]
signal_col_index_r_flank = [
str(i) for i
in list(range(self.flanking_signal_indices[1], self.signal_length))
]
df[signal_col_index] = (
df[signal_col_index] /
(
0.5 * (df[signal_col_index_l_flank].mean(axis=1) + df[signal_col_index_r_flank].mean(axis=1))
)
)
return df
[docs]
def load(self) -> pl.DataFrame:
res_polar_df = self.normalize_df(self.get_res_df_polars()) if self.normalize else self.get_res_df_polars()
signal_col_index = [str(i) for i in list(range(self.signal_length))]
res_polar_df = res_polar_df.with_columns(
pl.Series(self.get_accessibility(res_polar_df[signal_col_index])).alias("amplitude"))
return res_polar_df