import glob
import itertools
from pathlib import Path
from typing import List, Tuple, Optional, Dict, NamedTuple, Union, Callable
import string

import numpy as np
import torch
from scipy.spatial.distance import squareform, pdist, cdist
from Bio import SeqIO
#import biotite.structure as bs
#from biotite.structure.io.pdbx import PDBxFile, get_structure
#from biotite.database import rcsb
from tqdm import tqdm
import pandas as pd

# This is an efficient way to delete lowercase characters and insertion
# characters from a string: str.translate drops every key mapped to None.
deletekeys = dict.fromkeys(string.ascii_lowercase)
deletekeys["."] = None
deletekeys["*"] = None
translation = str.maketrans(deletekeys)


def read_sequence(filename: str) -> Tuple[str, str]:
    """Read the first (reference) sequence from a fasta or MSA file.

    Returns:
        A ``(description, sequence)`` tuple for the first record.

    Raises:
        StopIteration: if the file contains no fasta records.
    """
    record = next(SeqIO.parse(filename, "fasta"))
    return record.description, str(record.seq)


def remove_insertions(sequence: str) -> str:
    """Remove any insertions from the sequence.

    Deletes lowercase ASCII letters as well as ``.`` and ``*`` characters.
    Needed to load aligned sequences in an MSA.
    """
    return sequence.translate(translation)


def read_msa(filename: str) -> List[Tuple[str, str]]:
    """Read every sequence from an MSA file, removing insertions.

    Returns:
        A list of ``(description, sequence)`` tuples in file order.
    """
    return [
        (record.description, remove_insertions(str(record.seq)))
        for record in SeqIO.parse(filename, "fasta")
    ]


def greedy_select(
    msa: List[Tuple[str, str]], num_seqs: int, mode: str = "max"
) -> List[Tuple[str, str]]:
    """Greedily pick ``num_seqs`` sequences from the MSA by hamming distance.

    The first sequence is always kept. Each following pick is the unselected
    sequence whose mean hamming distance to the already selected sequences is
    largest (``mode="max"``) or smallest (``mode="min"``); ties go to the
    sequence that appears first in the MSA.
    Alternatively, can use hhfilter.

    Args:
        msa: ``(description, sequence)`` tuples; sequences must share a length.
        num_seqs: number of sequences to keep.
        mode: ``"max"`` or ``"min"``.

    Returns:
        The selected entries in their original MSA order. ``msa`` itself is
        returned when it has at most ``num_seqs`` entries, and an empty list
        when ``num_seqs`` is not positive.
    """
    assert mode in ("max", "min")
    if len(msa) <= num_seqs:
        return msa
    if num_seqs <= 0:
        # Nothing was requested; without this guard the seed sequence below
        # would be returned even though the caller asked for zero sequences.
        return []

    # One row per sequence, one uint8 code per alignment column.
    array = np.array([list(seq) for _, seq in msa], dtype=np.bytes_).view(np.uint8)
    optfunc = np.argmax if mode == "max" else np.argmin

    indices = [0]
    is_candidate = np.ones(len(msa), dtype=bool)
    is_candidate[0] = False
    # Running sum of distances from every selected sequence to each sequence.
    # Keeping only the sum avoids re-stacking and re-averaging a matrix that
    # grows by one row per pick.
    distance_sum = np.zeros(len(msa))
    for _ in range(num_seqs - 1):
        distance_sum += cdist(array[indices[-1:]], array, "hamming")[0]
        candidates = np.flatnonzero(is_candidate)
        mean_distance = distance_sum[candidates] / len(indices)
        index = int(candidates[optfunc(mean_distance)])
        indices.append(index)
        is_candidate[index] = False

    indices = sorted(indices)
    return [msa[idx] for idx in indices]