CLIP-Fields: Weakly Supervised Semantic Fields for Robotic Memory

What are they doing?

They propose a way for a robot to build a “map” of the world around it in terms of multimodal scene encodings. These encodings, together with their labels (e.g. “chair”), are stored in a memory that is both differentiable and easy to search.
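
Concretely, this “database” is not a table of rows but a neural field: a network that maps a 3D coordinate to two embeddings, one aligned with CLIP’s visual space and one with SBERT’s text space (this matches the label_model in the code below, which returns two latents per point). The sketch here is my own minimal illustration under those assumptions; the class name, layer sizes, and the plain-MLP trunk are placeholders, not the authors’ implementation.

import torch
import torch.nn as nn

class SemanticFieldSketch(nn.Module):
    ## Toy stand-in for the CLIP-Fields memory: maps 3D points to two embeddings,
    ## one SBERT-aligned (labels/text) and one CLIP-aligned (vision).
    def __init__(self, clip_dim=512, sbert_dim=768, hidden_dim=256):
        super().__init__()
        ## Shared trunk over raw (x, y, z) coordinates; this is a simplification
        ## of the authors' architecture.
        self.trunk = nn.Sequential(
            nn.Linear(3, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.sbert_head = nn.Linear(hidden_dim, sbert_dim)  ## label/text-aligned head
        self.clip_head = nn.Linear(hidden_dim, clip_dim)    ## visually aligned head

    def forward(self, xyz):
        h = self.trunk(xyz)
        return self.sbert_head(h), self.clip_head(h)

## The whole "map" is just this network's weights, so it is differentiable end to end,
## and it can be searched by pushing points through it and comparing the embeddings.
field = SemanticFieldSketch()
label_latents, image_latents = field(torch.randn(1024, 3))  ## 1024 sample points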

How are they doing it?

  1. Data collection: they recorded RGB-D image sequences. Specifically, they used an iPhone 13 Pro, whose LiDAR sensor provides the depth images that accompany the RGB frames.

  2. Robot execution pipeline, as illustrated in the figure from the paper.

The robot then searches its memory for the stored point whose encodings have the highest similarity to the query’s semantic and visually aligned encodings.

Let us walk through part of the code that the authors wrote. Read the comments below along with the code:

This is where they generate the visually aligned (CLIP) and semantic (SBERT) encodings for a given text query:

import clip
import torch
import torch.nn.functional as F

## `model` (the CLIP model), `sentence_model` (the SBERT model), and DEVICE
## are set up earlier in the authors' notebook.

def calculate_clip_and_st_embeddings_for_queries(queries):
    ## queries is a list of strings, which CLIP first tokenizes
    all_clip_queries = clip.tokenize(queries)
    with torch.no_grad():

        ## encode the text with CLIP to get the "visually aligned" encoding
        all_clip_tokens = model.encode_text(all_clip_queries.to(DEVICE)).float()
        ## L2-normalize so that only the direction is kept (a unit vector);
        ## dot products between unit vectors are then cosine similarities
        all_clip_tokens = F.normalize(all_clip_tokens, p=2, dim=-1)

        ## encode the text with SBERT to get a purely semantic (language-only) encoding
        all_st_tokens = torch.from_numpy(sentence_model.encode(queries))
        ## same L2-normalization as above
        all_st_tokens = F.normalize(all_st_tokens, p=2, dim=-1).to(DEVICE)

    return all_clip_tokens, all_st_tokens

query = "Warm up my lunch"
clip_text_tokens, st_text_tokens = calculate_clip_and_st_embeddings_for_queries([query])
print("query =", query)
print("tokens shape =", clip_text_tokens.shape)

This is where they use the multimodal encodings from the query to perform a search in the robot’s memory:

import tqdm  ## torch and F are already imported above

def find_alignment_over_model(label_model, queries, dataloader, visual=False):
    ## encode the query with both CLIP (visual) and SBERT (semantic) using the function above
    clip_text_tokens, st_text_tokens = calculate_clip_and_st_embeddings_for_queries(queries)

    # We give different weights to visual and semantic alignment 
    # for different types of queries
    if visual:
        vision_weight = 10.0
        text_weight = 1.0
    else:
        vision_weight = 1.0
        text_weight = 10.0
    point_opacity = []

    ## iterate over every point stored in the memory, which is computationally expensive for large maps
    with torch.no_grad():
        for data in tqdm.tqdm(dataloader, total=len(dataloader)):
            # Find alignments between the query encodings and the stored point encodings

            ## for each batch of points, predict the semantic (label) and visual encodings, then L2-normalize them
            predicted_label_latents, predicted_image_latents = label_model(data.to(DEVICE))
            data_text_tokens = F.normalize(predicted_label_latents, p=2, dim=-1).to(DEVICE)
            data_visual_tokens = F.normalize(predicted_image_latents, p=2, dim=-1).to(DEVICE)


            ## since everything is L2-normalized, the dot product is the cosine similarity
            ## similarity between the query's SBERT encoding and each point's label encoding
            text_alignment = data_text_tokens @ st_text_tokens.T

            ## similarity between the query's CLIP text encoding and each point's visually aligned encoding
            visual_alignment = data_visual_tokens @ clip_text_tokens.T

            ## weighted average that prioritizes visual or textual alignment depending on the query type
            total_alignment = (text_weight * text_alignment) + (vision_weight * visual_alignment)
            total_alignment /= (text_weight + vision_weight)

            ## collect the weighted similarity scores for this batch of points
            point_opacity.append(total_alignment)

    point_opacity = torch.cat(point_opacity).T
    print(point_opacity.shape)
    return point_opacity

Because both the query and point embeddings are L2-normalized, the dot product is exactly the cosine similarity between them.
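
As a usage note, here is a hypothetical sketch of how these scores could be turned into a navigation target: point_opacity has one row per query and one column per stored point, so picking the best point is just an argmax. The name points_xyz (the (N, 3) coordinates of the stored points) is an assumption for illustration, not a variable from the authors’ code.

## Hypothetical final step: pick the 3D point whose stored encodings best match the query.
## `points_xyz` is assumed to be an (N, 3) tensor of the memory's point coordinates.
scores = find_alignment_over_model(label_model, ["Warm up my lunch"], dataloader, visual=False)
best_idx = scores[0].argmax().item()   ## best-matching point for the first (only) query
target_xyz = points_xyz[best_idx]      ## 3D location the robot should move toward
print("target location:", target_xyz)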

My opinions/further discussion