Module.Load throws Mismatched state_dict sizes exception on BatchNorm1d #510
Thanks for sharing that. Is there a small repro that exhibits the problem? I thought the repro I came up with for the previous BatchNorm1d bug covered it, but apparently not.
You fixed the issue from before where BatchNorm1d could not evaluate tensors with a single batch, so that's good! I just need to get my model loaded and I think I'll be good to go. From my Python training code base:

```python
import torch.nn as nn

class BasicConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.stack = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, bias=False, **kwargs),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.stack(x)

if __name__ == '__main__':
    # Create model
    model = BasicConv1d(1, 32)

    # Export model to .dat file for ingestion into TorchSharp
    with open(r"C:\model.dat", "wb") as f:
        exportsd.save_state_dict(model.to("cpu").state_dict(), f)
```

Then import into TorchSharp:

```csharp
// Create the model
BasicConv1d module = new BasicConv1d();

// Load the model from the .dat file.
module.load(@"C:\model.dat"); // Exception occurs here
```
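For readers without the `exportsd` helper at hand, the idea of the round trip can be sketched with a simplified, purely hypothetical flat format (entry count, then name, length, and raw float32 data per tensor). This is not the actual exportsd/TorchSharp wire format, and plain lists stand in for tensors; it only illustrates why the reader and the writer must agree on exactly which entries a state dict contains:

```python
import io
import struct

def save_state_dict(state_dict, f):
    """Write a {name: list-of-floats} dict in a toy flat binary format."""
    f.write(struct.pack("<i", len(state_dict)))
    for name, values in state_dict.items():
        encoded = name.encode("utf-8")
        f.write(struct.pack("<i", len(encoded)))
        f.write(encoded)
        f.write(struct.pack("<i", len(values)))
        f.write(struct.pack(f"<{len(values)}f", *values))

def load_state_dict(f):
    """Read the toy format back; the reader must expect the same entries."""
    count = struct.unpack("<i", f.read(4))[0]
    result = {}
    for _ in range(count):
        name_len = struct.unpack("<i", f.read(4))[0]
        name = f.read(name_len).decode("utf-8")
        n = struct.unpack("<i", f.read(4))[0]
        result[name] = list(struct.unpack(f"<{n}f", f.read(4 * n)))
    return result

# Round-trip through an in-memory buffer.
buf = io.BytesIO()
save_state_dict({"stack.1.running_mean": [0.0, 0.5], "stack.1.weight": [1.0, 1.0]}, buf)
buf.seek(0)
restored = load_state_dict(buf)
```

If the loader's expected entry list omits buffers such as `running_mean`, the stream position drifts and sizes stop matching, which is the shape of the error reported here.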
Thanks for that. I'll take a look at it later today.
Adding 'num_batches_tracked' to BatchNorm{1,2,3}d
@FusionCarcass -- I have a fix in a PR, and will release it together with another important fix (parameter groups).
@FusionCarcass Can you confirm that the bugs are addressed? If possible, please consider providing a simple example so others can build on the workflow of saving a model from PyTorch and loading it in TorchSharp.
@GeorgeS2019 I can take a look at it. Is this fix pushed out in a new NuGet version, or do I need to build from source? So far I've just been waiting on the NuGet releases.
@FusionCarcass Thank you for testing. If possible, we need more use cases for loading PyTorch state dicts with TorchSharp.
The Python model:

```python
import torch

class BasicConv1d(torch.nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.stack = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, bias=False, **kwargs),
            torch.nn.BatchNorm1d(out_channels),
            torch.nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.stack(x)
```

The corresponding TorchSharp model:

```csharp
internal class BasicConv1d : Module
{
    private readonly Module stack;

    public BasicConv1d(int in_channels, int out_channels, int kernel_size = 3, int stride = 1, int padding = 0) : base(String.Empty)
    {
        var temp = BatchNorm1d(out_channels);
        this.stack = Sequential(
            Conv1d(in_channels, out_channels, 3, stride: stride, padding: padding, bias: false),
            temp,
            ReLU(inPlace: true)
        );

        temp.weight = Parameter(torch.randn(temp.weight.shape));
        temp.bias = Parameter(torch.randn(temp.bias.shape));
        if (temp.running_mean is not null) temp.running_mean = torch.randn(temp.running_mean.shape);
        if (temp.running_var is not null) temp.running_var = torch.randn(temp.running_var.shape);

        this.RegisterComponents();
    }

    public override torch.Tensor forward(torch.Tensor t)
    {
        return this.stack.forward(t);
    }
}
```

Creating the model whose save dict is loaded, in PyTorch:

```python
# Create model
model = BasicConv1d(1, 32)
```

Loading the save dict in TorchSharp:

```csharp
public void ValidateIssue510()
{
    // Create model
    var model = new BasicConv1d(1, 32);
    model.forward(torch.randn(16, 1, 32));

    var w0 = model.get_parameter("stack.0.weight").clone();
    var w1 = model.get_parameter("stack.1.weight").clone();
    var b1 = model.get_parameter("stack.1.bias").clone();
    var rm = model.get_buffer("stack.1.running_mean").clone();
    var rv = model.get_buffer("stack.1.running_var").clone();
    var nm = model.get_buffer("stack.1.num_batches_tracked").clone();

    model.load("bug510.dat");

    var w0_ = model.get_parameter("stack.0.weight");
    var w1_ = model.get_parameter("stack.1.weight");
    var b1_ = model.get_parameter("stack.1.bias");
    var rm_ = model.get_buffer("stack.1.running_mean");
    var rv_ = model.get_buffer("stack.1.running_var");
    var nm_ = model.get_buffer("stack.1.num_batches_tracked");

    Assert.NotEqual(w0, w0_);
    Assert.NotEqual(w1, w1_);
    Assert.NotEqual(b1, b1_);
    Assert.NotEqual(rm, rm_);
    Assert.NotEqual(rv, rv_);
    Assert.Equal(1, nm.item<long>());
    Assert.Equal(0, nm_.item<long>());
}
```
@FusionCarcass It seems num_batches_tracked was previously missing and has now been added. num_batches_tracked is used to update running_mean and running_var.

```csharp
// I wonder if it is possible to register a new parameter not defined internally?
temp.register_parameter("num_batches_tracked", new Parameter(temp.state_dict()["num_batches_tracked"], requires_grad: false));
```
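For context on why num_batches_tracked matters: in PyTorch, when a BatchNorm layer is constructed with `momentum=None`, the running statistics are updated as a cumulative moving average weighted by `1 / num_batches_tracked` rather than an exponential moving average. A pure-Python sketch of that update rule (the function name and scalar stand-ins are illustrative, not the PyTorch API):

```python
def update_running_mean(running_mean, batch_mean, num_batches_tracked, momentum=None):
    """Sketch of the BatchNorm running-stat update.

    With momentum=None, PyTorch falls back to a cumulative moving average
    weighted by 1 / num_batches_tracked; otherwise it uses an exponential
    moving average with the given momentum.
    """
    num_batches_tracked += 1
    factor = momentum if momentum is not None else 1.0 / num_batches_tracked
    running_mean = (1.0 - factor) * running_mean + factor * batch_mean
    return running_mean, num_batches_tracked

# With momentum=None, the running mean converges to the plain average
# of the batch means seen so far: here (2 + 4 + 6) / 3 = 4.
rm, n = 0.0, 0
for batch_mean in [2.0, 4.0, 6.0]:
    rm, n = update_running_mean(rm, batch_mean, n)
```

This is why losing the num_batches_tracked buffer in a load silently changes the statistics a restored model would compute if training resumed.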
When loading a model from a .dat file exported from Python, the Module.load method throws the exception below. I printed out all of the registered parameters, and the only ones that didn't show up are the BatchNorm1d state entries: running_mean, running_var, and num_batches_tracked.
I tried to work around this problem by registering those entries with the register_parameter function. That eliminates the exception below, but I run into a different issue where the bias is not loaded correctly after registering the other parameters: the bias parameter is still set to torch.zeros(N).
The load method should probably take the registered buffers into consideration if running_mean and running_var are not going to be treated as parameters.
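To make that last point concrete, here is a small, hypothetical sketch (plain Python dicts and lists standing in for the module's tensors) of a load routine that consults registered buffers as well as parameters. If the buffers are left out of the target set, the keys no longer line up and a mismatch error is the natural result:

```python
def load_state_dict(parameters, buffers, saved):
    """Load saved entries into both parameters and buffers, in place.

    parameters / buffers: {name: list-of-floats} registered on the module.
    saved: {name: list-of-floats} read from the exported file.
    """
    targets = {**parameters, **buffers}
    if set(saved) != set(targets):
        raise KeyError(f"Mismatched state_dict keys: {set(saved) ^ set(targets)}")
    for name, values in saved.items():
        if len(targets[name]) != len(values):
            raise ValueError(f"Mismatched state_dict sizes for {name}")
        targets[name][:] = values  # copy in place, like Tensor.copy_

# Demo: with buffers included in the target set, loading succeeds.
params = {"stack.1.weight": [0.0, 0.0]}
bufs = {"stack.1.running_mean": [0.0, 0.0]}
saved = {"stack.1.weight": [1.0, 2.0], "stack.1.running_mean": [3.0, 4.0]}
load_state_dict(params, bufs, saved)
```

Passing an empty buffer dict for the same saved file would raise the KeyError, which mirrors the mismatch reported in this issue.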
Here are the fixes I tried.