Skip to content

Commit

Permalink
merged from main
Browse files Browse the repository at this point in the history
  • Loading branch information
qcampbel committed Aug 8, 2024
2 parents 8124855 + c2ae2f9 commit 5a276a2
Show file tree
Hide file tree
Showing 34 changed files with 218 additions and 8,402 deletions.
2 changes: 1 addition & 1 deletion mdagent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(
use_human_tool=False,
uploaded_files=[], # user input files to add to path registry
run_id="",
use_memory=True,
use_memory=False,
):
self.llm = _make_llm(model, temp, streaming)
if tools_model is None:
Expand Down
6 changes: 3 additions & 3 deletions mdagent/tools/base_tools/analysis_tools/rdf_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ def _run(self, **input):
Log_id=trajectory_id,
)
fig_id = self.path_registry.get_fileid(plot_name, type=FileType.FIGURE)

plt.savefig(f"{self.path_registry.ckpt_figures}/rdf_{trajectory_id}.png")
file_path = f"{self.path_registry.ckpt_figures}/rdf_{trajectory_id}.png"
plt.savefig(file_path)
self.path_registry.map_path(
fig_id,
plot_name,
file_path,
description=f"RDF plot for the trajectory file with id: {trajectory_id}",
)
plt.close()
Expand Down
149 changes: 75 additions & 74 deletions mdagent/tools/base_tools/analysis_tools/rgy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,92 +5,80 @@
import numpy as np
from langchain.tools import BaseTool

from mdagent.utils import FileType, PathRegistry
from mdagent.utils import FileType, PathRegistry, load_single_traj


class RadiusofGyration:
def __init__(self, path_registry):
self.path_registry = path_registry
self.includes_top = [".h5", ".lh5", ".pdb"]

def _grab_files(self, pdb_id: str) -> None:
if "_" in pdb_id:
pdb_id = pdb_id.split("_")[0]
self.pdb_id = pdb_id
all_names = self.path_registry._list_all_paths()
try:
self.pdb_path = [
name
for name in all_names
if pdb_id in name and ".pdb" in name and "records" in name
][0]
except IndexError:
raise ValueError(f"No pdb file found for {pdb_id}")
try:
self.dcd_path = [
name
for name in all_names
if pdb_id in name and ".dcd" in name and "records" in name
][0]
except IndexError:
self.dcd_path = None
pass
return None

def _load_traj(self, pdb_id: str) -> None:
self._grab_files(pdb_id)
if self.dcd_path:
self.traj = md.load(self.dcd_path, top=self.pdb_path)
else:
self.traj = md.load(self.pdb_path)
return None
self.top_file = ""
self.traj_file = ""
self.traj = None

def _load_traj(self, top_file: str, traj_file: str):
self.traj_file = traj_file
self.top_file = top_file
self.traj = load_single_traj(
path_registry=self.path_registry,
top_fileid=top_file,
traj_fileid=traj_file,
traj_required=True,
)

def rad_gyration_per_frame(self, pdb_id: str) -> str:
self._load_traj(pdb_id)
def rgy_per_frame(self, force_recompute: bool = False) -> str:
rg_per_frame = md.compute_rg(self.traj)

self.rgy_file = (
f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.pdb_id}.csv"
)

np.savetxt(
self.rgy_file, rg_per_frame, delimiter=",", header="Radius of Gyration (nm)"
)
self.path_registry.map_path(
f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.pdb_id}.csv",
self.rgy_file,
description=f"Radii of gyration per frame for {self.pdb_id}",
f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv"
)
return f"Radii of gyration saved to {self.rgy_file}"

def rad_gyration_average(self, pdb_id: str) -> str:
_ = self.rad_gyration_per_frame(pdb_id)
rgy_id = f"rgy_{self.traj_file}"
if rgy_id in self.path_registry.list_path_names() and force_recompute is False:
print("RGY already computed, skipping re-compute")
# todo -> maybe allow re-compute & save under different id/path
else:
np.savetxt(
self.rgy_file,
rg_per_frame,
delimiter=",",
header="Radius of Gyration (nm)",
)
self.path_registry.map_path(
f"rgy_{self.traj_file}",
self.rgy_file,
description=f"Radii of gyration per frame for {self.traj_file}",
)
return f"Radii of gyration saved to {self.rgy_file} with id {rgy_id}."

def rgy_average(self) -> str:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
avg_rg = rg_per_frame.mean()

return f"Average radius of gyration: {avg_rg:.2f} nm"

def plot_rad_gyration(self, pdb_id: str) -> str:
_ = self.rad_gyration_per_frame(pdb_id)
def plot_rgy(self) -> str:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
fig_analysis = f"rgy_{self.pdb_id}"
fig_analysis = f"rgy_{self.traj_file}"
plot_name = self.path_registry.write_file_name(
type=FileType.FIGURE, fig_analysis=fig_analysis, file_format="png"
)
print("plot_name: ", plot_name)
plot_id = self.path_registry.get_fileid(
file_name=plot_name, type=FileType.FIGURE
)

if plot_name.endswith(".png"):
plot_name = plot_name.split(".png")[0]
plot_path = f"{self.path_registry.ckpt_figures}/{plot_name}"
plt.plot(rg_per_frame)
plt.xlabel("Frame")
plt.ylabel("Radius of Gyration (nm)")
plt.title(f"{pdb_id} - Radius of Gyration Over Time")
plt.title(f"{self.traj_file} - Radius of Gyration Over Time")

plt.savefig(f"{self.path_registry.ckpt_figures}/{plot_name}")
plt.savefig(f"{plot_path}")
self.path_registry.map_path(
plot_id,
f"{self.path_registry.ckpt_figures}/{plot_name}",
description=f"Plot of radii of gyration over time for {self.pdb_id}",
plot_path,
description=f"Plot of radii of gyration over time for {self.traj_file}",
)
plt.close()
plt.clf()
Expand All @@ -100,20 +88,24 @@ def plot_rad_gyration(self, pdb_id: str) -> str:
class RadiusofGyrationAverage(BaseTool):
name = "RadiusofGyrationAverage"
description = """This tool calculates the average radius of gyration
for the given trajectory file. Give this tool the
protein ID (PDB ID) only. The tool will automatically find the necessary files."""
for a trajectory. Give this tool BOTH the trajectory file ID and the
topology file ID."""

path_registry: Optional[PathRegistry]

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, pdb_id: str) -> str:
def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
RGY = RadiusofGyration(self.path_registry)
return "Succeeded. " + RGY.rad_gyration_average(pdb_id)
return "Succeeded. " + RGY.rgy_average()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
Expand All @@ -127,8 +119,9 @@ async def _arun(self, query: str) -> str:
class RadiusofGyrationPerFrame(BaseTool):
name = "RadiusofGyrationPerFrame"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory file. Give this tool the
protein ID (PDB ID) only. The tool will automatically find the necessary files.
at each frame of a given trajectory.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the radii of gyration to a csv file and
map it to the registry."""

Expand All @@ -138,11 +131,15 @@ def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, pdb_id: str) -> str:
def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY = RadiusofGyration(self.path_registry)
return "Succeeded. " + RGY.rad_gyration_per_frame(pdb_id)
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.rgy_per_frame()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
Expand All @@ -157,8 +154,8 @@ class RadiusofGyrationPlot(BaseTool):
name = "RadiusofGyrationPlot"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory file and plots it.
Give this tool the protein ID (PDB ID) only.
The tool will automatically find the necessary files.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the plot to a png file and map it to the registry."""

path_registry: Optional[PathRegistry]
Expand All @@ -167,11 +164,15 @@ def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, pdb_id: str) -> str:
def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
RGY = RadiusofGyration(self.path_registry)
return "Succeeded. " + RGY.plot_rad_gyration(pdb_id)
return "Succeeded. " + RGY.plot_rgy()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
Expand Down
67 changes: 48 additions & 19 deletions mdagent/tools/base_tools/analysis_tools/secondary_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,16 @@ def write_raw_x(
The file id of the saved file.
"""
file_name = path_registry.write_file_name(
FileType.RECORD,
record_type=x,
FileType.RECORD, record_type=x, file_format="npy"
)
file_id = path_registry.get_fileid(file_name, FileType.RECORD)

file_path = f"{path_registry.ckpt_records}/{x}_{traj_id}.npy"
file_path = f"{path_registry.ckpt_records}/{file_name}"
np.save(file_path, values)

path_registry.map_path(
file_id,
file_name,
file_path,
description=f"{x} values for trajectory with id: {traj_id}",
)
return file_id
Expand All @@ -43,8 +42,12 @@ def write_raw_x(
class ComputeDSSP(BaseTool):
name = "ComputeDSSP"
description = """Compute the DSSP (secondary structure) assignment
for a protein trajectory. Input is a trajectory file ID
for a protein trajectory. Input is a trajectory file ID and
a target_frames, which can be "first", "last", or "all",
and an optional topology file ID.
Input "first" to get DSSP of only the first frame.
Input "last" to get DSSP of only the last frame.
Input "all" to get DSSP of all frames in trajectory, combined.
The output is an array with the DSSP code for each
residue at each time point."""
path_registry: PathRegistry = PathRegistry.get_instance()
Expand All @@ -71,21 +74,21 @@ def _dssp_natural_language(self) -> dict[str, str]:
used. Otherwise, the full set of codes is used."""
if self.simplified:
return {
"H": "helix",
"E": "strand",
"C": "coil",
"NA": "not assigned, not a protein residue",
"H": "residues in helix",
"E": "residues in strand",
"C": "residues in coil",
"NA": "residues not assigned, not a protein residue",
}
return {
"H": "alpha helix",
"B": "beta bridge",
"E": "extended strand",
"G": "three helix",
"I": "five helix",
"T": "hydrogen bonded turn",
"S": "bend",
" ": "loop or irregular",
"NA": "not assigned, not a protein residue",
"H": "residues in alpha helix",
"B": "residues in beta bridge",
"E": "residues in extended strand",
"G": "residues in three helix",
"I": "residues in five helix",
"T": "residues in hydrogen bonded turn",
"S": "residues in bend",
" ": "residues in loop or irregular",
"NA": "residues not assigned, not a protein residue",
}

def _convert_dssp_counts(self, dssp_counts: dict) -> dict:
Expand Down Expand Up @@ -140,7 +143,32 @@ def _compute_dssp(self, traj: md.Trajectory) -> np.ndarray:
"""
return md.compute_dssp(traj, simplified=self.simplified)

def _run(self, traj_file: str, top_file: Optional[str] = None) -> str:
def _get_frame(self, traj, target_frames):
"""
Retrieves the target frame(s) of the trajectory for DSSP.
Args:
traj: the trajectory
target_frames: the target frames to select. can be first, last, or all
Returns:
the trajectory with only target frames"""

if target_frames.lower().strip() == "all":
return traj
if target_frames.lower().strip() == "first":
return traj[0]
if target_frames.lower().strip() == "last":
return traj[-1]
else:
raise ValueError("Target Frames must be 'all', 'first', or 'last'.")

def _run(
self,
traj_file: str,
top_file: Optional[str] = None,
target_frames: str = "last",
) -> str:
"""
Computes the DSSP assignments for a trajectory and saves the results
to a file.
Expand All @@ -160,6 +188,7 @@ def _run(self, traj_file: str, top_file: Optional[str] = None) -> str:
)
if not traj:
raise Exception("Trajectory could not be loaded.")
traj = self._get_frame(traj, target_frames)
except Exception as e:
print("Error loading trajectory: ", e)
return str(e)
Expand Down
Loading

0 comments on commit 5a276a2

Please sign in to comment.