From a7c3719cfb120f96e3896e055680a473839160bc Mon Sep 17 00:00:00 2001 From: Yuta Nagano <52748151+yutanagano@users.noreply.github.com> Date: Fri, 16 Aug 2024 18:48:49 +0900 Subject: [PATCH 01/11] Implement initial failing test for residue representations --- src/sceptr/model.py | 4 ++++ tests/test_functional_api.py | 9 +++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/sceptr/model.py b/src/sceptr/model.py index 23bae7d..ce71f38 100644 --- a/src/sceptr/model.py +++ b/src/sceptr/model.py @@ -13,6 +13,10 @@ BATCH_SIZE = 512 +class ResidueRepresentations: + pass + + class Sceptr: """ Loads a trained state of a SCEPTR (variant) and provides an easy interface for generating TCR representations and making inferences from them. diff --git a/tests/test_functional_api.py b/tests/test_functional_api.py index c9397f3..687fce1 100644 --- a/tests/test_functional_api.py +++ b/tests/test_functional_api.py @@ -1,4 +1,5 @@ import sceptr +from sceptr.model import ResidueRepresentations import numpy as np import pandas as pd import pytest @@ -14,8 +15,12 @@ def test_embed(dummy_data): result = sceptr.calc_vector_representations(dummy_data) assert type(result) == np.ndarray - assert len(result.shape) == 2 - assert result.shape[0] == 3 + assert result.shape == (3, 64) + +def test_residue_embed(dummy_data): + result = sceptr.calc_residue_representations(dummy_data) + + assert type(result) == ResidueRepresentations def test_cdist(dummy_data): From 9e736a3de61187eaae1e24b225b6659435ca09f6 Mon Sep 17 00:00:00 2001 From: Yuta Nagano <52748151+yutanagano@users.noreply.github.com> Date: Fri, 16 Aug 2024 20:47:35 +0900 Subject: [PATCH 02/11] Implement minimum passing code --- src/sceptr/__init__.py | 22 ++++++++++++++++-- src/sceptr/model.py | 53 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/sceptr/__init__.py b/src/sceptr/__init__.py index 00b07ae..dd81f2b 100644 --- a/src/sceptr/__init__.py +++ b/src/sceptr/__init__.py @@ -4,7 +4,7 @@ """ from sceptr import variant -from sceptr.model import Sceptr +from sceptr.model import Sceptr, ResidueRepresentations import sys from numpy import ndarray from pandas import DataFrame @@ -53,7 +53,7 @@ def calc_pdist_vector(instances: DataFrame) -> ndarray: def calc_vector_representations(instances: DataFrame) -> ndarray: """ - Map a table of TCRs provided as a pandas DataFrame in the above format to their corresponding vector representations. + Map TCRs to their corresponding vector representations. Parameters ---------- @@ -69,6 +69,24 @@ def calc_vector_representations(instances: DataFrame) -> ndarray: return get_default_model().calc_vector_representations(instances) +def calc_residue_representations(instances: DataFrame) -> ResidueRepresentations: + """ + Given multiple TCRs, map each TCR to a set of amino acid residue-level representations. + The residue-level representations are taken from the output of the penultimate self-attention layer, and are the same ones used by the :py:func:`~sceptr.variant.average_pooling` variant when generating TCR receptor-level representations. + + Parameters + ---------- + instances : DataFrame + DataFrame in the :ref:`prescribed format `. + + Returns + ------- + :py:class:`~sceptr.model.ResidueRepresentations` + For details on how to interpret/use this output, please refer to the documentation for :py:class:`~sceptr.model.ResidueRepresentations`. + """ + return get_default_model().calc_residue_representations(instances) + + def get_default_model() -> Sceptr: if "_DEFAULT_MODEL" not in dir(sys.modules[__name__]): setattr(sys.modules[__name__], "_DEFAULT_MODEL", variant.default()) diff --git a/src/sceptr/model.py b/src/sceptr/model.py index ce71f38..bdaff17 100644 --- a/src/sceptr/model.py +++ b/src/sceptr/model.py @@ -14,7 +14,8 @@ class ResidueRepresentations: - pass + compartment_mask: ndarray + representation_array: ndarray class Sceptr: @@ -41,7 +42,7 @@ def __init__( def calc_vector_representations(self, instances: DataFrame) -> ndarray: """ - Map a table of TCRs provided as a pandas DataFrame in the above format to their corresponding vector representations. + Map TCRs to their corresponding vector representations. Parameters ---------- @@ -57,6 +58,54 @@ def calc_vector_representations(self, instances: DataFrame) -> ndarray: torch_representations = self._calc_torch_representations(instances) return torch_representations.cpu().numpy() + def calc_residue_representations(self, instances: DataFrame) -> ResidueRepresentations: + """ + Given multiple TCRs, map each TCR to a set of amino acid residue-level representations. + The residue-level representations are taken from the output of the penultimate self-attention layer, and are the same ones used by the :py:func:`~sceptr.variant.average_pooling` variant when generating TCR receptor-level representations. + + Parameters + ---------- + instances : DataFrame + DataFrame in the :ref:`prescribed format `. + + Returns + ------- + :py:class:`~sceptr.model.ResidueRepresentations` + For details on how to interpret/use this output, please refer to the documentation for :py:class:`~sceptr.model.ResidueRepresentations`. + """ + instances = instances.copy() + + for col in ("TRAV", "CDR3A", "TRAJ", "TRBV", "CDR3B", "TRBJ"): + if col not in instances: + instances[col] = None + + tcrs = schema.generate_tcr_series(instances) + + residue_reps_collection = [] + compartment_masks_collection = [] + + for idx in range(0, len(tcrs), BATCH_SIZE): + batch = tcrs.iloc[idx : idx + BATCH_SIZE] + tokenised_batch = [self._tokeniser.tokenise(tcr) for tcr in batch] + padded_batch = utils.rnn.pad_sequence( + sequences=tokenised_batch, + batch_first=True, + padding_value=DefaultTokenIndex.NULL, + ).to(self._device) + + raw_token_embeddings = self._bert._embed(padded_batch) + padding_mask = self._bert._get_padding_mask(padded_batch) + residue_reps = self._bert._self_attention_stack.get_token_embeddings_at_penultimate_layer(raw_token_embeddings, padding_mask) + compartment_masks = padded_batch[:, :, 3] + + residue_reps_collection.append(residue_reps) + compartment_masks_collection.append(compartment_masks) + + residue_reps_combined = torch.concatenate(residue_reps_collection, dim=0) + compartment_masks_combined = torch.concatenate(compartment_masks_collection, dim=0) + + return ResidueRepresentations() + @torch.no_grad() def _calc_torch_representations(self, instances: DataFrame) -> FloatTensor: instances = instances.copy() From 1b1f9db8e6967055f4b71bdc10a63913a3fb8361 Mon Sep 17 00:00:00 2001 From: Yuta Nagano <52748151+yutanagano@users.noreply.github.com> Date: Sat, 17 Aug 2024 13:15:09 +0900 Subject: [PATCH 03/11] Modify github release workflow to only trigger on stable releases --- .github/workflows/publish_to_pypi.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish_to_pypi.yaml b/.github/workflows/publish_to_pypi.yaml index ecab5ab..cf4a0a9 100644 --- a/.github/workflows/publish_to_pypi.yaml +++ b/.github/workflows/publish_to_pypi.yaml @@ -1,7 +1,7 @@ name: publish to PyPI on: release: - types: [published] + types: [released] jobs: build: runs-on: ubuntu-latest From 01403e778794219c4fe7241b40eca6e3697e7afb Mon Sep 17 00:00:00 2001 From: Yuta Nagano <52748151+yutanagano@users.noreply.github.com> Date: Sat, 17 Aug 2024 13:15:31 +0900 Subject: [PATCH 04/11] Implement basic inline documentation --- docs/sceptr_model.rst | 3 ++ src/sceptr/model.py | 89 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/docs/sceptr_model.rst b/docs/sceptr_model.rst index be1c3fc..16aca8d 100644 --- a/docs/sceptr_model.rst +++ b/docs/sceptr_model.rst @@ -3,3 +3,6 @@ .. autoclass:: sceptr.model.Sceptr() :members: + +.. autoclass:: sceptr.model.ResidueRepresentations() + :members: diff --git a/src/sceptr/model.py b/src/sceptr/model.py index bdaff17..763bf4e 100644 --- a/src/sceptr/model.py +++ b/src/sceptr/model.py @@ -8,14 +8,98 @@ from libtcrlm.tokeniser import Tokeniser from libtcrlm.tokeniser.token_indices import DefaultTokenIndex from libtcrlm import schema +from turtle import Shape BATCH_SIZE = 512 class ResidueRepresentations: - compartment_mask: ndarray + """ + An object containing information necessary to interpret and operate on residue-level representations from the SCEPTR family of models. + Instances of this class can be obtained via the :py:func:`sceptr.calc_residue_representations` function and a method of the same name on the :py:class:`~sceptr.model.Sceptr` class. + + This feature is implemented for the curious users who would like to tinker around and examine what kind of information SCEPTR focuses on at the individual amino acid residue level, and do so without completely hacking into the source code of :py:mod:`sceptr`. + For some examples of how to use instances of this class to make useful examinations of SCEPTR's residue-level embeddings, please refer to the "Examples" section below. + + Attributes + ---------- + representation_array : ndarray + A numpy float ndarray containing the residue-level representation data. + The array is of shape :math:`(N, M, D)` where :math:`N` is the number of TCRs in the original input, :math:`M` is the maximum number of residues amongst the input TCRs when put into its tokenised form, and :math:`D` is the dimensionality of the model variant that produced the result. + + compartment_mask : ndarray + A numpy integer array containing information on which indices in the `representation_array` correspond to tokens that come from each CDR loop of the input TCRs. + The array is of shape :math:`(N, M)` where :math:`N` is the number of TCRs in the original input, and :math:`M` is the maximum number of residues amongst the input TCRs when put into its tokenised form. + Entries in `compartment_mask` have the following values: + + +------------------------------+------------------+ + | If residue at index is from: | Entry has value: | + +==============================+==================+ + | None (padding token) | 0 | + +------------------------------+------------------+ + | CDR1A | 1 | + +------------------------------+------------------+ + | CDR2A | 2 | + +------------------------------+------------------+ + | CDR3A | 3 | + +------------------------------+------------------+ + | CDR1B | 4 | + +------------------------------+------------------+ + | CDR2B | 5 | + +------------------------------+------------------+ + | CDR3B | 6 | + +------------------------------+------------------+ + + Within each CDR loop compartment, residues are ordered from C- to N-terminal from left to right. + + Examples + -------- + As an example, let's see how one could get the residue-level representations for the beta-chain CDR3 amino acid sequences of all input TCR sequences. + Say we have some DataFrame ``tcrs`` that contains the sequence data for four TCRs. + + >>> from pandas import DataFrame + >>> tcrs = DataFrame( + ... data = { + ... "TRAV": ["TRAV38-1*01", "TRAV3*01", "TRAV13-2*01", "TRAV38-2/DV8*01"], + ... "CDR3A": ["CAHRSAGGGTSYGKLTF", "CAVDNARLMF", "CAERIRKGQVLTGGGNKLTF", "CAYRSAGGGTSYGKLTF"], + ... "TRBV": ["TRBV2*01", "TRBV25-1*01", "TRBV9*01", "TRBV2*01"], + ... "CDR3B": ["CASSEFQGDNEQFF", "CASSDGSFNEQFF", "CASSVGDLLTGELFF", "CASSPGTGGNEQYF"], + ... }, + ... index = [0,1,2,3] + ... ) + >>> print(tcrs) + TRAV CDR3A TRBV CDR3B + 0 TRAV38-1*01 CAHRSAGGGTSYGKLTF TRBV2*01 CASSEFQGDNEQFF + 1 TRAV3*01 CAVDNARLMF TRBV25-1*01 CASSDGSFNEQFF + 2 TRAV13-2*01 CAERIRKGQVLTGGGNKLTF TRBV9*01 CASSVGDLLTGELFF + 3 TRAV38-2/DV8*01 CAYRSAGGGTSYGKLTF TRBV2*01 CASSPGTGGNEQYF + + We can get the residue-level representations for those TCRs like so: + + >>> import sceptr + >>> res_reps = sceptr.calc_residue_representations(tcrs) + + Now, we can iterate through the residue-level representation subarray corresponding to each TCR, and filter out/obtain the representations for the beta chain CDR3 sequence. + + >>> cdr3b_reps = [] + >>> for reps, mask in zip(res_reps.representation_array, res_reps.compartment_mask): + ... # reps.shape == (M, D) + ... # mask.shape == (M,) + ... cdr3b_rep = reps[mask == 6] # collect only the residue representations for the beta CDR3 sequence + ... cdr3b_reps.append(cdr3b_rep) + + Now we have a list containing four numpy ndarrays, each of which is a matrix whose row vectors are representations of individual CDR3B amino acid residues. + + >>> type(cdr3b_reps[0]) + numpy.ndarray + >>> cdr3b_reps[0].shape + (14, 64) + + Note that the zeroth element of the shape tuple above is 14 because the CDR3B sequence of the first TCR in ``tcrs`` is 14 residues long, and the first element of the shape tuple is 64 because the model dimensionality of the default SCEPTR variant is 64. + """ representation_array: ndarray + compartment_mask: ndarray class Sceptr: @@ -63,6 +147,9 @@ def calc_residue_representations(self, instances: DataFrame) -> ResidueRepresent Given multiple TCRs, map each TCR to a set of amino acid residue-level representations. The residue-level representations are taken from the output of the penultimate self-attention layer, and are the same ones used by the :py:func:`~sceptr.variant.average_pooling` variant when generating TCR receptor-level representations. + .. note :: + This method is currently only supported on SCEPTR model variants such as the default one that 1) use both the alpha and beta chains, and 2) take into account all three CDR loops from each chain. + Parameters ---------- instances : DataFrame From b2292d9291551864224bb9874b2f2b42deb5106f Mon Sep 17 00:00:00 2001 From: Yuta Nagano <52748151+yutanagano@users.noreply.github.com> Date: Tue, 20 Aug 2024 12:30:53 +0900 Subject: [PATCH 05/11] Fix bugs and improve residue representation unit tests --- src/sceptr/model.py | 17 ++++++++++++----- tests/test_functional_api.py | 11 +++++++---- 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/sceptr/model.py b/src/sceptr/model.py index 763bf4e..e1b1dd3 100644 --- a/src/sceptr/model.py +++ b/src/sceptr/model.py @@ -8,7 +8,6 @@ from libtcrlm.tokeniser import Tokeniser from libtcrlm.tokeniser.token_indices import DefaultTokenIndex from libtcrlm import schema -from turtle import Shape BATCH_SIZE = 512 @@ -101,6 +100,10 @@ class ResidueRepresentations: representation_array: ndarray compartment_mask: ndarray + def __init__(self, representation_array: ndarray, compartment_mask: ndarray) -> None: + self.representation_array = representation_array + self.compartment_mask = compartment_mask + class Sceptr: """ @@ -142,6 +145,7 @@ def calc_vector_representations(self, instances: DataFrame) -> ndarray: torch_representations = self._calc_torch_representations(instances) return torch_representations.cpu().numpy() + @torch.no_grad() def calc_residue_representations(self, instances: DataFrame) -> ResidueRepresentations: """ Given multiple TCRs, map each TCR to a set of amino acid residue-level representations. @@ -182,16 +186,19 @@ def calc_residue_representations(self, instances: DataFrame) -> ResidueRepresent raw_token_embeddings = self._bert._embed(padded_batch) padding_mask = self._bert._get_padding_mask(padded_batch) + residue_reps = self._bert._self_attention_stack.get_token_embeddings_at_penultimate_layer(raw_token_embeddings, padding_mask) - compartment_masks = padded_batch[:, :, 3] + residue_reps = residue_reps[:, 1:, :] + + compartment_masks = padded_batch[:, 1:, 3] residue_reps_collection.append(residue_reps) compartment_masks_collection.append(compartment_masks) - residue_reps_combined = torch.concatenate(residue_reps_collection, dim=0) - compartment_masks_combined = torch.concatenate(compartment_masks_collection, dim=0) + residue_reps_combined = torch.concatenate(residue_reps_collection, dim=0).cpu().numpy() + compartment_masks_combined = torch.concatenate(compartment_masks_collection, dim=0).cpu().numpy() - return ResidueRepresentations() + return ResidueRepresentations(residue_reps_combined, compartment_masks_combined) @torch.no_grad() def _calc_torch_representations(self, instances: DataFrame) -> FloatTensor: diff --git a/tests/test_functional_api.py b/tests/test_functional_api.py index 687fce1..5a693ed 100644 --- a/tests/test_functional_api.py +++ b/tests/test_functional_api.py @@ -14,24 +14,27 @@ def dummy_data(): def test_embed(dummy_data): result = sceptr.calc_vector_representations(dummy_data) - assert type(result) == np.ndarray + assert isinstance(result, np.ndarray) assert result.shape == (3, 64) + def test_residue_embed(dummy_data): result = sceptr.calc_residue_representations(dummy_data) - assert type(result) == ResidueRepresentations + assert isinstance(result, ResidueRepresentations) + assert result.representation_array.shape == (3, 47, 64) + assert result.compartment_mask.shape == (3, 47) def test_cdist(dummy_data): result = sceptr.calc_cdist_matrix(dummy_data, dummy_data) - assert type(result) == np.ndarray + assert isinstance(result, np.ndarray) assert result.shape == (3, 3) def test_pdist(dummy_data): result = sceptr.calc_pdist_vector(dummy_data) - assert type(result) == np.ndarray + assert isinstance(result, np.ndarray) assert result.shape == (3,) From 1e5b51143923f82f169724f2621c521a20549fbc Mon Sep 17 00:00:00 2001 From: Yuta Nagano <52748151+yutanagano@users.noreply.github.com> Date: Tue, 20 Aug 2024 12:50:58 +0900 Subject: [PATCH 06/11] Revert "Modify github release workflow to only trigger on stable releases" This reverts commit 7b9275679cb784129fc7e10d772a13e7f3658941. --- .github/workflows/publish_to_pypi.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish_to_pypi.yaml b/.github/workflows/publish_to_pypi.yaml index cf4a0a9..ecab5ab 100644 --- a/.github/workflows/publish_to_pypi.yaml +++ b/.github/workflows/publish_to_pypi.yaml @@ -1,7 +1,7 @@ name: publish to PyPI on: release: - types: [released] + types: [published] jobs: build: runs-on: ubuntu-latest From 18b01a6dc12ed4d1049e7af3efcec50f8c1bff56 Mon Sep 17 00:00:00 2001 From: Yuta Nagano <52748151+yutanagano@users.noreply.github.com> Date: Tue, 20 Aug 2024 12:56:52 +0900 Subject: [PATCH 07/11] Ensure residue representation method is only supported on SCEPTR variants that use the CdrTokensier --- src/sceptr/model.py | 5 ++++- tests/test_variants.py | 32 ++++++++++++++++++++++++++++---- 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/src/sceptr/model.py b/src/sceptr/model.py index e1b1dd3..e74ffa1 100644 --- a/src/sceptr/model.py +++ b/src/sceptr/model.py @@ -5,7 +5,7 @@ from numpy import ndarray from pandas import DataFrame from libtcrlm.bert import Bert -from libtcrlm.tokeniser import Tokeniser +from libtcrlm.tokeniser import Tokeniser, CdrTokeniser from libtcrlm.tokeniser.token_indices import DefaultTokenIndex from libtcrlm import schema @@ -164,6 +164,9 @@ def calc_residue_representations(self, instances: DataFrame) -> ResidueRepresent :py:class:`~sceptr.model.ResidueRepresentations` For details on how to interpret/use this output, please refer to the documentation for :py:class:`~sceptr.model.ResidueRepresentations`. """ + if not isinstance(self._tokeniser, CdrTokeniser): + raise NotImplementedError("The calc_residue_representations method is currently only supported on SCEPTR model variants that 1) use both the alpha and beta chains, and 2) take into account all three CDR loops from each chain.") + instances = instances.copy() for col in ("TRAV", "CDR3A", "TRAJ", "TRBV", "CDR3B", "TRBJ"): diff --git a/tests/test_variants.py b/tests/test_variants.py index 071df5d..96ec164 100644 --- a/tests/test_variants.py +++ b/tests/test_variants.py @@ -1,5 +1,5 @@ from sceptr import variant -from sceptr.model import Sceptr +from sceptr.model import Sceptr, ResidueRepresentations import numpy as np import pandas as pd import pytest @@ -39,18 +39,42 @@ def test_load_variant(self, model): def test_embed(self, model, dummy_data): result = model.calc_vector_representations(dummy_data) - assert type(result) == np.ndarray + assert isinstance(result, np.ndarray) assert len(result.shape) == 2 assert result.shape[0] == 3 def test_cdist(self, model, dummy_data): result = model.calc_cdist_matrix(dummy_data, dummy_data) - assert type(result) == np.ndarray + assert isinstance(result, np.ndarray) assert result.shape == (3, 3) def test_pdist(self, model, dummy_data): result = model.calc_pdist_vector(dummy_data) - assert type(result) == np.ndarray + assert isinstance(result, np.ndarray) assert result.shape == (3,) + + def test_residue_representations(self, model, dummy_data): + if model.name in ( + "SCEPTR", + "SCEPTR (MLM only)", + "SCEPTR (left-aligned)", + "SCEPTR (small)", + "SCEPTR (BLOSUM)", + "SCEPTR (average-pooling)", + "SCEPTR (finetuned)" + ): + result = model.calc_residue_representations(dummy_data) + + assert isinstance(result, ResidueRepresentations) + assert len(result.representation_array.shape) == 3 + assert result.representation_array.shape[:2] == (3, 47) + assert result.compartment_mask.shape == (3, 47) + + if model.name in ( + "SCEPTR (CDR3 only)", + "A SCEPTR", + ): + with pytest.raises(NotImplementedError): + model.calc_residue_representations(dummy_data) From 2360dd66c1fcf0ee4e3472e1b336a3c120698b04 Mon Sep 17 00:00:00 2001 From: Yuta Nagano <52748151+yutanagano@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:35:26 +0900 Subject: [PATCH 08/11] Implement doctest builder and fix doc bugs --- docs/conf.py | 2 +- docs/usage.rst | 8 ++++---- src/sceptr/_model_saves/__init__.py | 2 +- src/sceptr/model.py | 2 +- tox.ini | 1 + 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index c75ea74..390dbed 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -16,7 +16,7 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = ["sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.napoleon"] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.napoleon", "sphinx.ext.doctest"] templates_path = ["_templates"] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] diff --git a/docs/usage.rst b/docs/usage.rst index 1bf0bd5..3ee4a56 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -35,14 +35,14 @@ As the name suggests, :py:func:`~sceptr.calc_cdist_matrix` gives you an easy way >>> import sceptr >>> cdist_matrix = sceptr.calc_cdist_matrix(tcrs.iloc[:2], tcrs.iloc[2:]) >>> print(cdist_matrix) -[[1.2849896 0.75219345] - [1.4653426 1.4646543 ]] +[[1.2849896 0.7521934] + [1.4653426 1.4646543]] If you're only interested in calculating distances within a set, :py:func:`~sceptr.calc_pdist_vector` gives you a one-dimensional array of within-set distances. >>> pdist_vector = sceptr.calc_pdist_vector(tcrs) >>> print(pdist_vector) -[1.4135991 1.2849895 0.7521934 1.4653426 1.4646543 1.287208 ] +[1.4135991 1.2849895 0.75219345 1.4653426 1.4646543 1.287208 ] .. tip:: The end result of using the :py:func:`~sceptr.calc_cdist_matrix` and :py:func:`~sceptr.calc_pdist_vector` functions are equivalent to generating sceptr's TCR representations first with :py:func:`~sceptr.calc_vector_representations`, then using `scipy `_'s `cdist `_ or `pdist `_ functions to get the corresponding matrix or vector, respectively. @@ -52,7 +52,7 @@ If you want to directly operate on sceptr's TCR representations, you can use :py >>> reps = sceptr.calc_vector_representations(tcrs) >>> print(reps.shape) -(4,64) +(4, 64) .. _model_variants: diff --git a/src/sceptr/_model_saves/__init__.py b/src/sceptr/_model_saves/__init__.py index 7bfc01b..d17c94b 100644 --- a/src/sceptr/_model_saves/__init__.py +++ b/src/sceptr/_model_saves/__init__.py @@ -12,7 +12,7 @@ def load_variant(model_name: str) -> Sceptr: config = json.load(f) with (model_save_dir / "state_dict.pt").open("rb") as f: - state_dict = torch.load(f) + state_dict = torch.load(f, weights_only=True) config_reader = ConfigReader(config) diff --git a/src/sceptr/model.py b/src/sceptr/model.py index e74ffa1..c7e810c 100644 --- a/src/sceptr/model.py +++ b/src/sceptr/model.py @@ -91,7 +91,7 @@ class ResidueRepresentations: Now we have a list containing four numpy ndarrays, each of which is a matrix whose row vectors are representations of individual CDR3B amino acid residues. >>> type(cdr3b_reps[0]) - numpy.ndarray + >>> cdr3b_reps[0].shape (14, 64) diff --git a/tox.ini b/tox.ini index 7b029eb..98551f6 100644 --- a/tox.ini +++ b/tox.ini @@ -26,4 +26,5 @@ deps = sphinx>=6 sphinx-book-theme commands = + sphinx-build -b doctest docs docs/_build sphinx-build docs docs/_build From dc12a66b6b8aeff2a25db618e924513d83fd278b Mon Sep 17 00:00:00 2001 From: Yuta Nagano <52748151+yutanagano@users.noreply.github.com> Date: Tue, 20 Aug 2024 13:57:46 +0900 Subject: [PATCH 09/11] Improve readability of ResidueRepresentations repr --- src/sceptr/model.py | 3 +++ tests/test_residue_representations.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) create mode 100644 tests/test_residue_representations.py diff --git a/src/sceptr/model.py b/src/sceptr/model.py index c7e810c..91e7554 100644 --- a/src/sceptr/model.py +++ b/src/sceptr/model.py @@ -104,6 +104,9 @@ def __init__(self, representation_array: ndarray, compartment_mask: ndarray) -> self.representation_array = representation_array self.compartment_mask = compartment_mask + def __repr__(self) -> str: + return f"ResidueRepresentations[num_tcrs: {self.representation_array.shape[0]}, rep_dim: {self.representation_array.shape[2]}]" + class Sceptr: """ diff --git a/tests/test_residue_representations.py b/tests/test_residue_representations.py new file mode 100644 index 0000000..64b2355 --- /dev/null +++ b/tests/test_residue_representations.py @@ -0,0 +1,14 @@ +import numpy as np +import pytest +from sceptr.model import ResidueRepresentations + + +def test_repr(res_reps): + assert res_reps.__repr__() == "ResidueRepresentations[num_tcrs: 3, rep_dim: 64]" + + +@pytest.fixture +def res_reps() -> ResidueRepresentations: + rep_array = np.zeros((3, 10, 64)) + comp_mask = np.zeros_like(rep_array, dtype=int) + return ResidueRepresentations(rep_array, comp_mask) From 8f1a0a0fc53a09fea68ed661c20e9eb5196940e8 Mon Sep 17 00:00:00 2001 From: Yuta Nagano <52748151+yutanagano@users.noreply.github.com> Date: Tue, 20 Aug 2024 14:27:47 +0900 Subject: [PATCH 10/11] Improve ResidueRepresentations docstring --- src/sceptr/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sceptr/model.py b/src/sceptr/model.py index 91e7554..8c25986 100644 --- a/src/sceptr/model.py +++ b/src/sceptr/model.py @@ -78,13 +78,13 @@ class ResidueRepresentations: >>> import sceptr >>> res_reps = sceptr.calc_residue_representations(tcrs) + >>> print(res_reps) + ResidueRepresentations[num_tcrs: 4, rep_dim: 64] Now, we can iterate through the residue-level representation subarray corresponding to each TCR, and filter out/obtain the representations for the beta chain CDR3 sequence. >>> cdr3b_reps = [] >>> for reps, mask in zip(res_reps.representation_array, res_reps.compartment_mask): - ... # reps.shape == (M, D) - ... # mask.shape == (M,) ... cdr3b_rep = reps[mask == 6] # collect only the residue representations for the beta CDR3 sequence ... cdr3b_reps.append(cdr3b_rep) From b9b8167cf0ccbec0901d41c65d98e86965a9fe75 Mon Sep 17 00:00:00 2001 From: Yuta Nagano <52748151+yutanagano@users.noreply.github.com> Date: Tue, 20 Aug 2024 14:49:31 +0900 Subject: [PATCH 11/11] Improve usage docs and reflect new residue_reps function --- docs/usage.rst | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/docs/usage.rst b/docs/usage.rst index 3ee4a56..0c86ac5 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -29,7 +29,10 @@ To begin analysing TCR data with sceptr, you must first load the TCR data into m 2 TRAV13-2*01 CAERIRKGQVLTGGGNKLTF TRBV9*01 CASSVGDLLTGELFF 3 TRAV38-2/DV8*01 CAYRSAGGGTSYGKLTF TRBV2*01 CASSPGTGGNEQYF -:py:mod:`sceptr` exposes three intuitive functions: :py:func:`~sceptr.calc_cdist_matrix`, :py:func:`~sceptr.calc_pdist_vector`, and :py:func:`~sceptr.calc_vector_representations`. + +``calc_cdist_matrix`` +********************* + As the name suggests, :py:func:`~sceptr.calc_cdist_matrix` gives you an easy way to calculate a cross-distance matrix between two sets of TCRs. >>> import sceptr @@ -38,6 +41,9 @@ As the name suggests, :py:func:`~sceptr.calc_cdist_matrix` gives you an easy way [[1.2849896 0.7521934] [1.4653426 1.4646543]] +``calc_pdist_vector`` +********************* + If you're only interested in calculating distances within a set, :py:func:`~sceptr.calc_pdist_vector` gives you a one-dimensional array of within-set distances. >>> pdist_vector = sceptr.calc_pdist_vector(tcrs) @@ -48,16 +54,30 @@ If you're only interested in calculating distances within a set, :py:func:`~scep The end result of using the :py:func:`~sceptr.calc_cdist_matrix` and :py:func:`~sceptr.calc_pdist_vector` functions are equivalent to generating sceptr's TCR representations first with :py:func:`~sceptr.calc_vector_representations`, then using `scipy `_'s `cdist `_ or `pdist `_ functions to get the corresponding matrix or vector, respectively. But on machines with `CUDA-enabled GPUs `_, directly using sceptr's :py:func:`~sceptr.calc_cdist_matrix` and :py:func:`~sceptr.calc_pdist_vector` functions will run faster, as it internally runs all computations on the GPU. +``calc_vector_representations`` +******************************* + If you want to directly operate on sceptr's TCR representations, you can use :py:func:`~sceptr.calc_vector_representations`. >>> reps = sceptr.calc_vector_representations(tcrs) >>> print(reps.shape) (4, 64) +``calc_residue_representations`` +******************************** + +The package also provides the user with an easy way to get access to SCEPTR's internal representations of each individual amino acid residue in the tokenised form of its input TCRs, as outputted by the penultimate layer of its self-attention stack. +Interested users can use :py:func:`~sceptr.calc_residue_representations`. +Please refer to the documentation for the :py:class:`~sceptr.model.ResidueRepresentations` class for details on how to interpret the output. + +>>> res_reps = sceptr.calc_residue_representations(tcrs) +>>> print(res_reps) +ResidueRepresentations[num_tcrs: 4, rep_dim: 64] + .. _model_variants: -Model Variants (:py:mod:`sceptr.variant`) ------------------------------------------ +Model variants +-------------- The :py:mod:`sceptr.variant` submodule allows users access a variety of non-default SCEPTR model variants, and use them for TCR analysis. The submodule exposes functions which return :py:class:`~sceptr.model.Sceptr` objects with the model state of the chosen variant loaded.