Skip to content

Commit

Permalink
Address issue dotnet#510
Browse files Browse the repository at this point in the history
Adding 'num_batches_tracked' to BatcNorm{123}d
  • Loading branch information
NiklasGustafsson committed Feb 10, 2022
1 parent 3fc452d commit 0316808
Show file tree
Hide file tree
Showing 14 changed files with 229 additions and 13 deletions.
Binary file added bug510.dat
Binary file not shown.
24 changes: 24 additions & 0 deletions bug510.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
import src.Python.exportsd as exportsd

class BasicConv1d(torch.nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super().__init__()

self.stack = torch.nn.Sequential(
torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, bias=False, **kwargs),
torch.nn.BatchNorm1d(out_channels),
torch.nn.ReLU(inplace=True)
)

def forward(self, x):
return self.stack(x)

if __name__ == '__main__':
# Create model
model = BasicConv1d(1, 32)

#Export model to .dat file for ingestion into TorchSharp
f = open("bug510.dat", "wb")
exportsd.save_state_dict(model.to("cpu").state_dict(), f)
f.close()
4 changes: 4 additions & 0 deletions src/Native/LibTorchSharp/THSNN.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ EXPORT_API(void) THSNN_BatchNorm1d_set_var(const NNModule module, const Tens
EXPORT_API(void) THSNN_BatchNorm2d_set_var(const NNModule module, const Tensor weight);
EXPORT_API(void) THSNN_BatchNorm3d_set_var(const NNModule module, const Tensor weight);

EXPORT_API(Tensor) THSNN_BatchNorm1d_get_batches(const NNModule module);
EXPORT_API(Tensor) THSNN_BatchNorm2d_get_batches(const NNModule module);
EXPORT_API(Tensor) THSNN_BatchNorm3d_get_batches(const NNModule module);

EXPORT_API(NNModule) THSNN_InstanceNorm1d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule);
EXPORT_API(Tensor) THSNN_InstanceNorm1d_forward(const NNModule module, const Tensor tensor);
EXPORT_API(NNModule) THSNN_InstanceNorm2d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule);
Expand Down
27 changes: 27 additions & 0 deletions src/Native/LibTorchSharp/THSNormalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,15 @@ Tensor THSNN_BatchNorm1d_get_var(const NNModule module)
return nullptr;
}

Tensor THSNN_BatchNorm1d_get_batches(const NNModule module)
{
CATCH(
auto v = (*module)->as<torch::nn::BatchNorm1d>()->num_batches_tracked;
return v.defined() ? ResultTensor(v) : nullptr;
);
return nullptr;
}

void THSNN_BatchNorm1d_set_mean(const NNModule module, const Tensor bias)
{
CATCH(
Expand Down Expand Up @@ -259,6 +268,15 @@ Tensor THSNN_BatchNorm2d_get_var(const NNModule module)
return nullptr;
}

Tensor THSNN_BatchNorm2d_get_batches(const NNModule module)
{
CATCH(
auto v = (*module)->as<torch::nn::BatchNorm2d>()->num_batches_tracked;
return v.defined() ? ResultTensor(v) : nullptr;
);
return nullptr;
}

void THSNN_BatchNorm2d_set_mean(const NNModule module, const Tensor bias)
{
CATCH(
Expand Down Expand Up @@ -316,6 +334,15 @@ Tensor THSNN_BatchNorm3d_get_var(const NNModule module)
return nullptr;
}

Tensor THSNN_BatchNorm3d_get_batches(const NNModule module)
{
CATCH(
auto v = (*module)->as<torch::nn::BatchNorm3d>()->num_batches_tracked;
return v.defined() ? ResultTensor(v) : nullptr;
);
return nullptr;
}

void THSNN_BatchNorm3d_set_mean(const NNModule module, const Tensor bias)
{
CATCH(
Expand Down
19 changes: 19 additions & 0 deletions src/TorchSharp/NN/Module.cs
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,25 @@ public virtual bool has_parameter(string target)
return false;
}

/// <summary>
/// Returns the buffer given by target if it exists, otherwise throws an error.
/// </summary>
/// <param name="target">The fully-qualified string name of the buffer to look for.</param>
/// <returns>The tensor referenced by target</returns>
public virtual Tensor get_buffer(string target)
{
if (_internal_buffers.TryGetValue(target, out var parameter)) {
return parameter;
}
foreach (var child in named_children().Where(nc => target.StartsWith(nc.name))) {
var prefix = child.name + ".";
var p = child.module.get_buffer(target.Remove(0, prefix.Length));
if (p is not null)
return p;
}
return null;
}

/// <summary>
/// Returns the parameter given by target if it exists, otherwise throws an error.
/// </summary>
Expand Down
18 changes: 14 additions & 4 deletions src/TorchSharp/NN/Normalization/BatchNorm1D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ public override Tensor forward(Tensor tensor)
[DllImport("LibTorchSharp")]
private static extern IntPtr THSNN_BatchNorm1d_get_var(torch.nn.Module.HType module);
[DllImport("LibTorchSharp")]
private static extern IntPtr THSNN_BatchNorm1d_get_batches(torch.nn.Module.HType module);
[DllImport("LibTorchSharp")]
private static extern void THSNN_BatchNorm1d_set_mean(torch.nn.Module.HType module, IntPtr weight);
[DllImport("LibTorchSharp")]
private static extern void THSNN_BatchNorm1d_set_var(torch.nn.Module.HType module, IntPtr weight);
Expand Down Expand Up @@ -78,7 +80,7 @@ public Parameter weight {
}
}

public Parameter? running_mean {
public Tensor? running_mean {
get {
var res = THSNN_BatchNorm1d_get_mean(handle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; }
Expand All @@ -87,11 +89,11 @@ public Parameter? running_mean {
set {
THSNN_BatchNorm1d_set_mean(handle, (value is null ? IntPtr.Zero : value.Handle));
torch.CheckForErrors();
ConditionallyRegisterParameter("bias", value);
ConditionallyRegisterBuffer("running_mean", value);
}
}

public Parameter? running_var {
public Tensor? running_var {
get {
var res = THSNN_BatchNorm1d_get_var(handle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; }
Expand All @@ -100,7 +102,15 @@ public Parameter? running_var {
set {
THSNN_BatchNorm1d_set_var(handle, (value is null ? IntPtr.Zero : value.Handle));
torch.CheckForErrors();
ConditionallyRegisterParameter("bias", value);
ConditionallyRegisterBuffer("running_var", value);
}
}

public Tensor? num_batches_tracked {
get {
var res = THSNN_BatchNorm1d_get_batches(handle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; }
return new Parameter(res);
}
}

Expand Down
19 changes: 15 additions & 4 deletions src/TorchSharp/NN/Normalization/BatchNorm2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

using static TorchSharp.torch;

#nullable enable
namespace TorchSharp
{
using Modules;
Expand Down Expand Up @@ -46,6 +47,8 @@ public override Tensor forward(Tensor tensor)
[DllImport("LibTorchSharp")]
private static extern IntPtr THSNN_BatchNorm2d_get_var(torch.nn.Module.HType module);
[DllImport("LibTorchSharp")]
private static extern IntPtr THSNN_BatchNorm2d_get_batches(torch.nn.Module.HType module);
[DllImport("LibTorchSharp")]
private static extern void THSNN_BatchNorm2d_set_mean(torch.nn.Module.HType module, IntPtr weight);
[DllImport("LibTorchSharp")]
private static extern void THSNN_BatchNorm2d_set_var(torch.nn.Module.HType module, IntPtr weight);
Expand Down Expand Up @@ -76,7 +79,7 @@ public Parameter weight {
}
}

public Parameter running_mean {
public Tensor? running_mean {
get {
var res = THSNN_BatchNorm2d_get_mean(handle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; }
Expand All @@ -85,11 +88,11 @@ public Parameter running_mean {
set {
THSNN_BatchNorm2d_set_mean(handle, (value is null ? IntPtr.Zero : value.Handle));
torch.CheckForErrors();
ConditionallyRegisterParameter("bias", value);
ConditionallyRegisterBuffer("running_mean", value);
}
}

public Parameter running_var {
public Tensor? running_var {
get {
var res = THSNN_BatchNorm2d_get_var(handle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; }
Expand All @@ -98,7 +101,15 @@ public Parameter running_var {
set {
THSNN_BatchNorm2d_set_var(handle, (value is null ? IntPtr.Zero : value.Handle));
torch.CheckForErrors();
ConditionallyRegisterParameter("bias", value);
ConditionallyRegisterBuffer("running_var", value);
}
}

public Tensor? num_batches_tracked {
get {
var res = THSNN_BatchNorm2d_get_batches(handle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; }
return new Parameter(res);
}
}

Expand Down
18 changes: 14 additions & 4 deletions src/TorchSharp/NN/Normalization/BatchNorm3D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ public override Tensor forward(Tensor tensor)
[DllImport("LibTorchSharp")]
private static extern IntPtr THSNN_BatchNorm3d_get_var(torch.nn.Module.HType module);
[DllImport("LibTorchSharp")]
private static extern IntPtr THSNN_BatchNorm3d_get_batches(torch.nn.Module.HType module);
[DllImport("LibTorchSharp")]
private static extern void THSNN_BatchNorm3d_set_mean(torch.nn.Module.HType module, IntPtr weight);
[DllImport("LibTorchSharp")]
private static extern void THSNN_BatchNorm3d_set_var(torch.nn.Module.HType module, IntPtr weight);
Expand Down Expand Up @@ -77,7 +79,7 @@ public Parameter weight {
}
}

public Parameter? running_mean {
public Tensor? running_mean {
get {
var res = THSNN_BatchNorm3d_get_mean(handle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; }
Expand All @@ -86,11 +88,11 @@ public Parameter? running_mean {
set {
THSNN_BatchNorm3d_set_mean(handle, (value is null ? IntPtr.Zero : value.Handle));
torch.CheckForErrors();
ConditionallyRegisterParameter("bias", value);
ConditionallyRegisterBuffer("running_mean", value);
}
}

public Parameter? running_var {
public Tensor? running_var {
get {
var res = THSNN_BatchNorm3d_get_var(handle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; }
Expand All @@ -99,7 +101,15 @@ public Parameter? running_var {
set {
THSNN_BatchNorm3d_set_var(handle, (value is null ? IntPtr.Zero : value.Handle));
torch.CheckForErrors();
ConditionallyRegisterParameter("bias", value);
ConditionallyRegisterBuffer("running_var", value);
}
}

public Tensor? num_batches_tracked {
get {
var res = THSNN_BatchNorm3d_get_batches(handle);
if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; }
return new Parameter(res);
}
}

Expand Down
11 changes: 11 additions & 0 deletions src/TorchSharp/NN/Sequential.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ internal void Add(torch.nn.Module module)
}
}

public override IEnumerable<(string name, Tensor buffer)> named_buffers(bool recurse = true)
{
if (!recurse) yield break;

for (var i = 0; i < _names.Count; i++) {
foreach (var (n, p) in _modules[i].named_buffers(true)) {
yield return ($"{_names[i]}.{n}", p);
}
}
}

public override IEnumerable<(string name, torch.nn.Module module)> named_children()
{
for (var i = 0; i < _names.Count; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
</PropertyGroup>

<ItemGroup>
<Compile Include="..\TorchSharpTest\GlobalSuppressions.cs" Link="GlobalSuppressions.cs" />
<Compile Include="..\TorchSharpTest\GlobalSuppressions.cs" Link="GlobalSuppressions.cs">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</Compile>
<Compile Include="..\TorchSharpTest\NN.cs" Link="NN.cs" />
<Compile Include="..\TorchSharpTest\TestDistributions.cs" Link="TestDistributions.cs" />
<Compile Include="..\TorchSharpTest\TestLoadSave.cs" Link="TestLoadSave.cs" />
Expand All @@ -23,6 +25,12 @@
<Compile Include="..\TorchSharpTest\TestDisposeScopes.cs" Link="TestDisposeScopes.cs" />
</ItemGroup>

<ItemGroup>
<Content Include="..\TorchSharpTest\bug510.dat" Link="bug510.dat">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\src\TorchSharp\TorchSharp.csproj" />
</ItemGroup>
Expand Down
58 changes: 58 additions & 0 deletions test/TorchSharpTest/TestTorchTensorBugs.cs
Original file line number Diff line number Diff line change
Expand Up @@ -443,5 +443,63 @@ class Module500 : Module

public override torch.Tensor forward(torch.Tensor t) => bn1.forward(t);
}

[Fact]
public void ValidateIssue510()
{
var model = new Module510(1, 32);
model.forward(torch.randn(16, 1, 32));

var w0 = model.get_parameter("stack.0.weight").clone();
var w1 = model.get_parameter("stack.1.weight").clone();
var b1 = model.get_parameter("stack.1.bias").clone();
var rm = model.get_buffer("stack.1.running_mean").clone();
var rv = model.get_buffer("stack.1.running_var").clone();
var nm = model.get_buffer("stack.1.num_batches_tracked").clone();

model.load(@".\bug510.dat");

var w0_ = model.get_parameter("stack.0.weight");
var w1_ = model.get_parameter("stack.1.weight");
var b1_ = model.get_parameter("stack.1.bias");
var rm_ = model.get_buffer("stack.1.running_mean");
var rv_ = model.get_buffer("stack.1.running_var");
var nm_ = model.get_buffer("stack.1.num_batches_tracked");

Assert.NotEqual(w0, w0_);
Assert.NotEqual(w1, w1_);
Assert.NotEqual(b1, b1_);
Assert.NotEqual(rm, rm_);
Assert.NotEqual(rv, rv_);
Assert.Equal(1, nm.item<long>());
Assert.Equal(0, nm_.item<long>());
}

internal class Module510 : Module
{
private readonly Module stack;

public Module510(int in_channels, int out_channels, int kernel_size=3, int stride = 1, int padding = 0) : base(String.Empty)
{
var temp = BatchNorm1d(out_channels);
this.stack = Sequential(
Conv1d(in_channels, out_channels, 3, stride: stride, padding: padding, bias: false),
temp,
ReLU(inPlace: true)
);

temp.weight = Parameter(torch.randn(temp.weight.shape));
temp.bias = Parameter(torch.randn(temp.bias.shape));
if (temp.running_mean is not null) temp.running_mean = torch.randn(temp.running_mean.shape);
if (temp.running_var is not null) temp.running_var = torch.randn(temp.running_var.shape);

this.RegisterComponents();
}

public override torch.Tensor forward(torch.Tensor t)
{
return this.stack.forward(t);
}
}
}
}
10 changes: 10 additions & 0 deletions test/TorchSharpTest/TorchSharpTest.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@
<VSTestResultsDirectory>$(OutputPath)</VSTestResultsDirectory>
</PropertyGroup>

<ItemGroup>
<None Remove="bug510.dat" />
</ItemGroup>

<ItemGroup>
<Content Include="bug510.dat">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</Content>
</ItemGroup>

<ItemGroup>
<None Update="xunit.runner.json">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
Expand Down
Binary file added test/TorchSharpTest/bug510.dat
Binary file not shown.
Loading

0 comments on commit 0316808

Please sign in to comment.