Given a Hugging Face model, e.g.

    from transformers import AutoModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained("bert-large-uncased", num_labels=2)
I can access a layer’s tensor like this:

    # Shape [1024, 1024]
    model.state_dict()["bert.encoder.layer.0.attention.self.query.weight"]
[out]:
    tensor([[ 0.0167, -0.0422, -0.0425,  ...,  0.0302, -0.0341,  0.0251],
            [ 0.0323,  0.0347, -0.0041,  ..., -0.0722,  0.0031, -0.0351],
            [ 0.0387, -0.0293, -0.0694,  ...,  0.0492,  0.0201, -0.0727],
            ...,
            [ 0.0035,  0.0081, -0.0337,  ...,  0.0460,  0.0268,  0.0747],
            [ 0.0513,  0.0131,  0.0735,  ..., -0.0127,  0.0144, -0.0400],
            [ 0.0385,  0.0013, -0.0272,  ...,  0.0148,  0.0399,  0.0339]])
Suppose I have another tensor of the same shape that was defined somewhere else. For illustration I’m creating a random tensor here, but it can be any pre-defined tensor.

    import torch

    replacement_layer = torch.rand([1024, 1024])
Note: I’m not trying to replace a layer with a random tensor but replace it with a pre-defined one.
When I tried to replace the layer’s tensor through state_dict(), it didn’t seem to work:

    import torch
    from transformers import AutoModelForSequenceClassification

    # The model with a layer that we want to replace.
    model = AutoModelForSequenceClassification.from_pretrained("bert-large-uncased", num_labels=2)

    # A replacement layer.
    replacement_layer = torch.rand([1024, 1024])

    # Replacing the layer in the state_dict.
    model.state_dict()["bert.encoder.layer.0.attention.self.query.weight"] = replacement_layer

    # Check that the layer is replaced. No, it is not =(
    assert torch.equal(
        model.state_dict()["bert.encoder.layer.0.attention.self.query.weight"],
        replacement_layer)
How can I replace a layer’s tensor in a Hugging Face PyTorch model with another tensor of the same shape?
Answer
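A state_dict is something special. Each call to model.state_dict() builds a new dictionary, so it behaves more like an on-the-fly snapshot than the actual contents of the model. Assigning to one of its keys only changes that temporary dictionary; the model’s own parameter is left untouched.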
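To make this concrete, here is a minimal sketch using a tiny `nn.Linear` as a stand-in (an assumption for illustration, so nothing has to be downloaded; the names `tiny` and `replacement` are hypothetical). It shows that two calls return different dictionary objects and that rebinding a key does not reach the module, while an in-place copy into the returned tensor does, because that tensor shares storage with the parameter.

```python
import torch
from torch import nn

# A tiny stand-in model (hypothetical; not the BERT model from the question).
tiny = nn.Linear(4, 4)
replacement = torch.rand(4, 4)

# Each call to state_dict() builds a new dictionary object.
sd_a = tiny.state_dict()
sd_b = tiny.state_dict()
same_dict = sd_a is sd_b

# Rebinding a key only changes the temporary dictionary, not the module.
sd_a["weight"] = replacement
changed_by_assignment = torch.equal(tiny.weight, replacement)

# The tensors inside the dictionary share storage with the parameters,
# so an in-place copy does modify the module's weight.
with torch.no_grad():
    tiny.state_dict()["weight"].copy_(replacement)
changed_by_copy = torch.equal(tiny.weight, replacement)

print(same_dict, changed_by_assignment, changed_by_copy)
```

So the failing assertion in the question comes from the assignment into the dictionary, not from the tensors themselves being copies.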
You can access a model’s layers directly with dot notation. Note that a number such as 0 in a state_dict key usually indicates an index rather than an attribute name, so it becomes layer[0]. You’ll also need to wrap your tensor in a torch.nn.Parameter for it to work within the model.
So this should work:
    model.bert.encoder.layer[0].attention.self.query.weight = torch.nn.Parameter(replacement_layer)
or in full:
    # Note: I used the base model for testing.
    import torch
    from transformers import AutoModelForSequenceClassification

    # The model with a layer that we want to replace.
    model: torch.nn.Module = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

    # A replacement layer.
    replacement_layer = torch.rand([768, 768])

    model.bert.encoder.layer[0].attention.self.query.weight = torch.nn.Parameter(replacement_layer)

    # Check that the layer is replaced.
    assert torch.equal(
        model.state_dict()["bert.encoder.layer.0.attention.self.query.weight"],
        replacement_layer)
    assert torch.equal(
        model.bert.encoder.layer[0].attention.self.query.weight,
        replacement_layer)
    print("Success!")
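As a possible alternative to assigning a new Parameter, the following sketch shows two other routes that should give the same result: copying into the existing parameter in place, or editing a state dict and loading it back with `load_state_dict`. It uses a small `nn.Sequential` as a stand-in (an assumption for illustration; its keys look like "0.weight" rather than the BERT key above), and `replacement` / `replacement2` are hypothetical names.

```python
import torch
from torch import nn

# Small stand-in model; its state_dict keys are "0.weight", "0.bias", "1.weight", "1.bias".
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# Option 1: copy the new values into the existing parameter in place.
replacement = torch.rand(4, 4)
with torch.no_grad():
    model[0].weight.copy_(replacement)
copied_ok = torch.equal(model[0].weight, replacement)

# Option 2: edit a state dict, then load it back into the model.
replacement2 = torch.rand(4, 4)
sd = model.state_dict()
sd["0.weight"] = replacement2
model.load_state_dict(sd)
loaded_ok = torch.equal(model[0].weight, replacement2)

print(copied_ok, loaded_ok)
```

Both options keep the original Parameter object and only change its values, which can matter if something else (for example an optimizer) already holds a reference to it.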