
text-motion alignment pre-trained model #12

Open
Wenretium opened this issue Apr 7, 2024 · 16 comments
Labels
question Further information is requested

Comments

@Wenretium

Hi! I am very interested in your work, especially the text-motion alignment pre-trained model. I hope to see your model and code soon.

@Wenretium
Author

Or please allow me to ask some questions about this part. Reading your paper, I found that your model design is very similar to TMR: Text-to-Motion Retrieval Using Contrastive 3D Human Motion Synthesis. Can I say that your model is based on it, adding support for hand motions and replacing MPNet with sBERT? Did you train on TMR's code framework?

@LinghaoChan
Collaborator

LinghaoChan commented Apr 7, 2024

@Wenretium Thanks for your interest! Your comment is really in-depth and insightful.

Most of the answers are right. However, we did not train our model in TMR's framework; we implemented the model ourselves before they released their code. We plan to release our code in about 2 weeks. You can try the demo first.

Best,

Ling-Hao CHEN

@Wenretium
Author

I get it. Thanks for your quick reply!

@LinghaoChan LinghaoChan added the "question" (Further information is requested) label on Apr 7, 2024
@LinghaoChan
Collaborator

Hi @Wenretium !

We have released the TMR training code in ./OpenTMA. Please check it out!

[Note]: As the research target in this project is to clarify how to use text-motion alignment, TMR is renamed TMA in the ICML-24 version.

@Wenretium
Author

Thank you very much! You provided very detailed code documentation.

@LinghaoChan
Collaborator

> Thank you very much! You provided very detailed code documentation.

@Wenretium You're welcome. Any questions, feel free to discuss!

@Wenretium Wenretium reopened this May 20, 2024
@Wenretium
Author

Hello! I have another question. Since you didn't provide a full demo for the text-motion alignment loss, I implemented one based on my own understanding.

# Load text and motion data
import numpy as np
import torch
from os.path import join as pjoin
from collections import OrderedDict
from tma.models.architectures.temos.textencoder.distillbert_actor import DistilbertActorAgnosticEncoder
from tma.models.architectures.temos.motionencoder.actor import ActorAgnosticEncoder

modelpath = 'distilbert-base-uncased'

textencoder = DistilbertActorAgnosticEncoder(modelpath, num_layers=4)
motionencoder = ActorAgnosticEncoder(nfeats=263, vae=True, num_layers=4)

"""
Load the model here.
You need to normalize the motion data with mean and std.
For Motion-X, they are stored in './deps/t2m/motionx/vector_623/Comp_v6_KLD01/meta/*.npy'
"""
# Load the checkpoint and split its state dict between the two encoders.
state_dict = torch.load('humanml3d.ckpt', map_location="cpu")["state_dict"]

textencoder_dict = OrderedDict()
for k, v in state_dict.items():
    if k.split(".")[0] == "textencoder":
        name = k.replace("textencoder.", "")
        textencoder_dict[name] = v
textencoder.load_state_dict(textencoder_dict, strict=True)

motionencoder_dict = OrderedDict()
for k, v in state_dict.items():
    if k.split(".")[0] == "motionencoder":
        name = k.replace("motionencoder.", "")
        motionencoder_dict[name] = v
motionencoder.load_state_dict(motionencoder_dict, strict=True)

text = ["a person wonders in an oval path and ends where he started"]
motion = np.load('/path/to/HumanML3D/new_joint_vecs/000014.npy')
mean = np.load(pjoin('/path/to/Comp_v6_KLD01/meta', 'mean.npy'))
std = np.load(pjoin('/path/to/Comp_v6_KLD01/meta', 'std.npy'))
motion = (motion - mean) / std
motion = torch.Tensor(motion).unsqueeze(0)
lengths = [motion.shape[1]]
text_emb = textencoder(text).loc
motion_emb = motionencoder(motion, lengths).loc
print(torch.mean(text_emb - motion_emb))

My question is: did I load the pretrained model correctly? And in HumanTOMATO, did you compute the text-motion alignment loss as 'torch.mean(text_emb - motion_emb)'?

@LinghaoChan
Collaborator

@Wenretium Thanks for the reminder. I detail the implementation here.

infoloss = InfoNCE(0.1)
filter_model = SentenceTransformer('sentence-transformers/paraphrase-MiniLM-L6-v2')
# if TMA supervision is enabled
if args.supervision:
    # generated motions
    all_supervise_motion = torch.cat(gen_supervise_tensor_list, dim=0).cuda()
    # motion length = token length * 4 (the upsampling rate is 4)
    full_m_tokens_len = (m_tokens_len.detach() * 4).tolist()
    # get TMR_motion_embedding
    TMR_motion_embedding = t2m_TMR_motionencoder(all_supervise_motion, full_m_tokens_len).loc
    # get TMR_text_embedding
    TMR_text_embedding = t2m_TMR_textencoder(texts).loc
    with torch.no_grad():
        # sbert embeddings of the captions, used only to filter false negatives
        text_embedding = filter_model.encode(texts)
        text_embedding = torch.tensor(text_embedding).cuda()
        normalized = F.normalize(text_embedding, p=2, dim=1)
        # cosine similarity between all caption pairs
        emb_dist = normalized.matmul(normalized.T)
    loss_infonce = infoloss((TMR_motion_embedding, TMR_text_embedding), emb_dist)

    all_loss = loss_cls + args.lambdainfo * loss_infonce
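
For reference, here is a minimal sketch of what an InfoNCE class with sbert-based false-negative filtering might look like, reconstructed to match the call infoloss((TMR_motion_embedding, TMR_text_embedding), emb_dist) above. This is an assumption in the TMR style, not the repository's actual implementation; in particular the 0.9 filtering threshold and the symmetric loss are guesses.

import torch
import torch.nn as nn
import torch.nn.functional as F

class InfoNCE(nn.Module):
    # Hypothetical reconstruction; only the call signature is taken from the thread.
    def __init__(self, temperature=0.1, filter_threshold=0.9):
        super().__init__()
        self.temperature = temperature
        self.filter_threshold = filter_threshold

    def forward(self, embeddings, emb_dist=None):
        motion_emb, text_emb = embeddings
        motion_emb = F.normalize(motion_emb, p=2, dim=1)
        text_emb = F.normalize(text_emb, p=2, dim=1)
        # (N, N) similarity logits between every motion and every caption.
        logits = motion_emb @ text_emb.T / self.temperature
        N = logits.shape[0]
        if emb_dist is not None:
            # Off-diagonal pairs whose captions are near-duplicates under sbert
            # are false negatives; mask them out of the softmax.
            eye = torch.eye(N, dtype=torch.bool, device=logits.device)
            logits = logits.masked_fill((emb_dist > self.filter_threshold) & ~eye, float("-inf"))
        labels = torch.arange(N, device=logits.device)
        # Symmetric cross-entropy over both retrieval directions.
        return 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels))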

Any questions are welcome.

@no-Seaweed

Hi @LinghaoChan , I would like to ask a followup question.

I am using the trained h3d checkpoint you provided to reproduce the Recall results. The tables below show what I get.

Protocol A

T2M

| Metrics | Recall @1 | Recall @2 | Recall @3 | Recall @5 | Recall @10 |
|-------------|-----------|-----------|-----------|-----------|------------|
| HumanTOMATO | 0.207 | 0.323 | 0.553 | 1.106 | 2.096 |

M2T

| Metrics | Recall @1 | Recall @2 | Recall @3 | Recall @5 | Recall @10 |
|-------------|-----------|-----------|-----------|-----------|------------|
| HumanTOMATO | 0.230 | 0.323 | 0.484 | 0.783 | 1.428 |

Protocol D

T2M

| Metrics | Recall @1 | Recall @2 | Recall @3 | Recall @5 | Recall @10 |
|-------------|-----------|-----------|-----------|-----------|------------|
| HumanTOMATO | 13.776 | 21.401 | 27.390 | 36.351 | 51.670 |

M2T

| Metrics | Recall @1 | Recall @2 | Recall @3 | Recall @5 | Recall @10 |
|-------------|-----------|-----------|-----------|-----------|------------|
| HumanTOMATO | 10.620 | 17.415 | 23.013 | 31.560 | 47.524 |

Since the *_embedding.npy files are not provided, I used the demo code to get all text_emb and motion_emb of the test set. Both have shape (~4000, 256). In the retrieval code, I changed it to the following:

import argparse
import os

import numpy as np

def neg_recall(mat, k_value):
    """
    Protocol D: for each query, sample 31 random negatives and count a hit
    if the true match ranks within the top k_value against them.

    Parameters:
    mat (np.ndarray): The (N, N) similarity matrix.
    k_value (int): The K in Recall@K.

    Returns:
    int: The number of hits.
    """
    neg_lists = []
    N = len(mat)

    # For each row in the matrix...
    for i in range(N):
        array = np.arange(N)
        np.random.shuffle(array)
        neg_list = list(array[:32])
        # If the current row index is in the negative list, remove it.
        if i in neg_list:
            neg_list.remove(i)
        else:
            neg_list.pop()
        # Append the negative list to the list of negative lists.
        neg_lists.append(neg_list)

    # Initialize a counter for the number of hits.
    hits = 0

    # For each row in the matrix...
    for rowid in range(len(mat)):
        row = mat[rowid]
        negscores = list(row[neg_lists[rowid]])
        count_large = 0

        # For each score in the negative scores...
        for one_score in negscores:
            # If the true-match score is less than this negative score, increment the counter.
            if row[rowid] < one_score:
                count_large += 1

        # If the number of scores that are larger is less than or equal to k_value - 1, increment the hits counter.
        if count_large <= k_value - 1:
            hits += 1

    # Return the number of hits.
    return hits


def main(args):
    """
    This function is the main entry point of the script. It loads embeddings, calculates similarities,
    and prints recall metrics.

    Parameters:
    args (argparse.Namespace): The command-line arguments.

    Returns:
    None
    """

    # Retrieve the list of experiment directories, retrieval type, and protocol from the command-line arguments.
    expdirs = args.expdirs
    retrieval_type = args.retrieval_type
    protocal = args.protocal

    # Define a list of values for K (the number of top elements to consider in the recall calculation).
    K_list = [1, 2, 3, 5, 10]

    # Initialize a list of lists to store the recall values for each experiment directory.
    RecK_list = [[] for i in expdirs]

    # For each experiment directory...
    for index in range(len(expdirs)):
        # Retrieve the current experiment directory.
        exp_dir = expdirs[index]

        # Set the directory containing the embeddings to the experiment directory.
        emb_dir = exp_dir

        # Define the paths to the motion, text, and SBERT embeddings.
        motion_emb_dir = os.path.join(emb_dir, "humanml3d_motion_embs.npy")
        text_emb_dir = os.path.join(emb_dir, "humanml3d_text_embs.npy")
        # sbert_emb_dir = os.path.join(emb_dir, "sbert_embedding.npy")

        # Load the embeddings from the files.
        text_embedding = np.load(text_emb_dir)
        motion_embedding = np.load(motion_emb_dir)
        # sbert_embedding = np.load(sbert_emb_dir)

        # Normalize the SBERT embeddings.
        # sbert_embedding = sbert_embedding / np.linalg.norm(
        #     sbert_embedding, axis=1, keepdims=True
        # )

        # Calculate the text-to-motion and motion-to-text similarity matrices.
        T2M_logits = text_embedding @ (motion_embedding.T)
        M2T_logits = motion_embedding @ (text_embedding.T)

        # Depending on the retrieval type, select the appropriate similarity matrix.
        if retrieval_type == "T2M":
            logits_matrix = T2M_logits
        elif retrieval_type == "M2T":
            logits_matrix = M2T_logits

        # Calculate the SBERT similarity matrix (required for Protocol B).
        # sbert_sim = sbert_embedding @ (sbert_embedding.T)
        # N = sbert_embedding.shape[0]
        N = text_embedding.shape[0]

        # Initialize a list to store the target lists.
        target_list = []

        # If the protocol is A or B...
        if protocal == "A" or protocal == "B":
            for i in range(N):
                target_list_i = []
                for j in range(N):
                    # If the protocol is A and the other embedding is the same as the current one, add it to the target list.
                    if protocal == "A":
                        if j == i:
                            target_list_i.append(j)
                    # If the protocol is B and the SBERT similarity between the other embedding and the current one is at least 0.9, add it to the target list.
                    elif protocal == "B":
                        if sbert_sim[i][j] >= 0.9:
                            target_list_i.append(j)

                # Add the target list for this embedding to the list of target lists.
                target_list.append(target_list_i)

            # Sort the indices of the embeddings in the similarity matrix in descending order of similarity.
            sorted_embedding_idx = np.argsort(-logits_matrix, axis=1)
            for k in K_list:
                hits = 0
                for i in range(N):
                    # Get the top K embeddings in the sorted list.
                    pred = list(sorted_embedding_idx[i][:k])
                    # If any of the top K embeddings are in the target list for this embedding, increment the hits counter.
                    for item in pred:
                        if item in target_list[i]:
                            hits += 1
                            break
                # Calculate the recall for this value of K and add it to the list of recall values for this experiment directory.
                RecK_list[index].append("%.3f" % (100.0 * (hits / N)))

        # If the protocol is D...
        elif protocal == "D":
            for k in K_list:
                # Calculate the negative recall for this value of K and add it to the list of recall values for this experiment directory.
                hits = neg_recall(logits_matrix, k)
                RecK_list[index].append("%.3f" % (100.0 * (hits / N)))

    # To markdown table format
    print("|   Metrics   |", end="  ")
    for k in K_list:
        print(f"Recall @{k} |", end="  ")
    print()
    print("|-------------|", end="  ")
    for k in K_list:
        print("--------- |", end="  ")
    print()
    for l in range(len(RecK_list)):
        exp_name = expdirs[l].split("/")[-2]
        print(f"|{exp_name} |", end="  ")
        for item in RecK_list[l]:
            print(item, end="   |")
        print("")
    print()


if __name__ == "__main__":
    # Create a parser for command-line arguments
    parser = argparse.ArgumentParser()

    # Add arguments for retrieval type, protocol, and experiment directories
    parser.add_argument("--retrieval_type", default="T2M", type=str, help="T2M or M2T")
    parser.add_argument("--protocal", default="A", type=str, help="A, B, or D")
    parser.add_argument("--expdirs", nargs="+")

    # Parse the command-line arguments
    args = parser.parse_args()

    # Call the main function with the parsed arguments
    main(args)

What I did was simply remove the requirement for sbert. Am I missing something here? Retraining the TMA model is a bit costly for me.

Looking forward to your reply!

@LinghaoChan
Collaborator

@no-Seaweed The difference between Protocol A and B is whether sbert filtering is used.
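
For completeness, here is a minimal sketch of how the missing sbert embeddings could be generated so the commented-out Protocol B branch above can run. The model name is taken from the filter_model used earlier in this thread; whether the paper's Protocol B uses this exact checkpoint is an assumption.

import numpy as np
from sentence_transformers import SentenceTransformer

def build_sbert_embeddings(captions, out_path="sbert_embedding.npy"):
    # Encode all test-set captions with the sentence encoder.
    model = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
    emb = model.encode(captions)
    # L2-normalize so that emb @ emb.T gives cosine similarities, matching
    # the sbert_sim >= 0.9 filter in the retrieval script above.
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    np.save(out_path, emb)
    return emb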

@no-Seaweed

> @no-Seaweed The difference between Protocol A and B is whether sbert filtering is used.

Thank you for your reply. I was able to get those embedding .npy files by executing

python -m test --cfg configs/configs_temos/H3D-TMR.yaml --cfg_assets configs/assets.yaml --nodebug

with minor modifications to the code.

I got the following result:
Protocol A

T2M

| Metrics | Recall @1 | Recall @2 | Recall @3 | Recall @5 | Recall @10 |
|---------|-----------|-----------|-----------|-----------|------------|
| test | 5.050 | 9.883 | 13.741 | 19.744 | 31.274 |

M2T

| Metrics | Recall @1 | Recall @2 | Recall @3 | Recall @5 | Recall @10 |
|---------|-----------|-----------|-----------|-----------|------------|
| test | 6.545 | 11.617 | 15.475 | 22.150 | 33.052 |

Protocol B

T2M

| Metrics | Recall @1 | Recall @2 | Recall @3 | Recall @5 | Recall @10 |
|---------|-----------|-----------|-----------|-----------|------------|
| test | 9.059 | 15.258 | 20.416 | 27.893 | 40.399 |

M2T

| Metrics | Recall @1 | Recall @2 | Recall @3 | Recall @5 | Recall @10 |
|---------|-----------|-----------|-----------|-----------|------------|
| test | 11.422 | 16.840 | 21.088 | 27.633 | 38.470 |

Protocol D

T2M

| Metrics | Recall @1 | Recall @2 | Recall @3 | Recall @5 | Recall @10 |
|---------|-----------|-----------|-----------|-----------|------------|
| test | 70.243 | 83.138 | 88.578 | 93.000 | 96.294 |

M2T

| Metrics | Recall @1 | Recall @2 | Recall @3 | Recall @5 | Recall @10 |
|---------|-----------|-----------|-----------|-----------|------------|
| test | 70.178 | 83.875 | 88.882 | 92.826 | 96.272 |

Looks good, though the numbers are a bit off compared to the tables in the supplementary material.

@LinghaoChan
Collaborator

@no-Seaweed Seems good. The jitter is normal (Protocol D samples its negatives at random, so numbers vary slightly between runs).
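
A side note: the Protocol D numbers can be made reproducible by seeding NumPy before evaluation, since neg_recall draws its negatives with np.random.shuffle. This is an editor's suggestion, not something the repository does:

import numpy as np

# Seed the global NumPy RNG so neg_recall's random negative sampling is
# repeatable across runs.
np.random.seed(42)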

@xjli360

xjli360 commented Sep 29, 2024

> @Wenretium Thanks for the reminder. I detail the implementation here. […]

Hi, I have a question: how does loss_infonce backpropagate gradients? As we all know, the input of vq.decoder must be of int type (the code indices), but integer indices are not differentiable. How did you handle this? Can you provide more code references?
If I input float values, I get this:

self.vq_model.forward_decoder(_pred_ids.unsqueeze_(-1))
*** RuntimeError: gather(): Expected dtype int64 for index

@LinghaoChan
Copy link
Collaborator

Why is loss_infonce related to the vq? @xjli360

@xjli360

xjli360 commented Sep 30, 2024

> Why is loss_infonce related to the vq? @xjli360

Because the GPT predicts code tokens, which need the VQ decoder to obtain the motion; the motion then passes through t2m_TMR_motionencoder to get a latent, which is paired with the music latent to compute the InfoNCE loss.

@LinghaoChan
Collaborator

> Because the GPT predicts code tokens, which need the VQ decoder to obtain the motion; the motion then passes through t2m_TMR_motionencoder to get a latent, which is paired with the music latent to compute the InfoNCE loss.

I see what you mean. When introducing the TMA supervision, it cannot accept the codes directly; it has to process the motion decoded from the codes. As you know, directly using the max-probability code is not differentiable. Therefore, we activate the GPT logits via Gumbel-Softmax, not max. @xjli360
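
To make that concrete, here is a minimal sketch of the differentiable decode path described above. The names logits, codebook, and vq_decoder are placeholders, not the repository's actual API, and it assumes the decoder can accept code embeddings directly (which may require a small change to forward_decoder):

import torch.nn.functional as F

def decode_differentiably(logits, codebook, vq_decoder, tau=1.0):
    # logits:   (B, T, K) GPT output over the K codebook entries
    # codebook: (K, D)    VQ-VAE codebook embedding matrix
    # hard=True is straight-through: one-hot forward pass, soft gradients back.
    one_hot = F.gumbel_softmax(logits, tau=tau, hard=True)  # (B, T, K)
    codes = one_hot @ codebook                              # (B, T, D)
    # Feeding continuous code embeddings avoids the "Expected dtype int64 for
    # index" gather() error, and gradients flow from the decoded motion back
    # into the GPT logits.
    return vq_decoder(codes)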
