From 031680820d27df9658f862499dccc8684e965719 Mon Sep 17 00:00:00 2001 From: Niklas Gustafsson Date: Thu, 10 Feb 2022 10:30:09 -0800 Subject: [PATCH] Address issue #510 Adding 'num_batches_tracked' to BatcNorm{123}d --- bug510.dat | Bin 0 -> 1036 bytes bug510.py | 24 ++++++++ src/Native/LibTorchSharp/THSNN.h | 4 ++ src/Native/LibTorchSharp/THSNormalization.cpp | 27 ++++++++ src/TorchSharp/NN/Module.cs | 19 ++++++ .../NN/Normalization/BatchNorm1D.cs | 18 ++++-- .../NN/Normalization/BatchNorm2D.cs | 19 ++++-- .../NN/Normalization/BatchNorm3D.cs | 18 ++++-- src/TorchSharp/NN/Sequential.cs | 11 ++++ .../TorchSharpTest.WithCudaBinaries.csproj | 10 ++- test/TorchSharpTest/TestTorchTensorBugs.cs | 58 ++++++++++++++++++ test/TorchSharpTest/TorchSharpTest.csproj | 10 +++ test/TorchSharpTest/bug510.dat | Bin 0 -> 1036 bytes test/TorchSharpTest/bug510.py | 24 ++++++++ 14 files changed, 229 insertions(+), 13 deletions(-) create mode 100644 bug510.dat create mode 100644 bug510.py create mode 100644 test/TorchSharpTest/bug510.dat create mode 100644 test/TorchSharpTest/bug510.py diff --git a/bug510.dat b/bug510.dat new file mode 100644 index 0000000000000000000000000000000000000000..ba41d979cebf6aa40491edb043febb85dc9d4257 GIT binary patch literal 1036 zcmZSMD=tY)&ek)~D^JZ#&nRJIR$yda;q0=n>~`5ckI7s1&A5ET_SodeeI{ye?be+5 zw>MztJKGlr59~b_aojG2@1WhBDb98wJ=}I{;_urXnDgK6+JS5PGTaU97EE|+XK`1- zF5~=8JCC&ub|+L=>=>AG_FR#iZI`h4rd`B$#eK8mAMf)DW7_{9W9c5A{=NG&Y}@Tl zcz&~ExPEQ#g~&+T5H|}u5&QXeTeh#=>rr~y?tCNL{zsJz`^C-(?)Q*9zgOYhAv^8Q zKlbq?bMKE=7T#Y_zjI%MXqe57?jSpkW&i9JO*j_^CuHA{kmc1nx=h_`%X5N1!{;%Bx{|kG62p_in zafj9ZhqKkb37xNPw=Ouc?}78xeG?MAZNJRdu+#DRve#ibgWZ8!W;VwJ)AyxGHQKpE zvD+`7eQuw}H%Ge}!dL8^`FHQjnfh|C!q)$LR&2&Y~mFuAUbE z`vvOu?A>*wabMEXh5Mj!VhD>9Mg<0j2K#};JdjjksF#$PSj+};&%go^n0ty!^YSwD y(&KYe6Y~avyM)o*RhC!;cGkc~pES%_d8N7WNr@%N8L7qbB}KrDpPIq~k4XTu4j3{3 literal 0 HcmV?d00001 diff --git a/bug510.py b/bug510.py new file mode 100644 index 000000000..7fa461630 --- /dev/null +++ b/bug510.py @@ -0,0 +1,24 @@ +import torch +import src.Python.exportsd as exportsd + +class BasicConv1d(torch.nn.Module): + def __init__(self, in_channels, out_channels, **kwargs): + super().__init__() + + self.stack = torch.nn.Sequential( + torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, bias=False, **kwargs), + torch.nn.BatchNorm1d(out_channels), + torch.nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.stack(x) + +if __name__ == '__main__': + # Create model + model = BasicConv1d(1, 32) + + #Export model to .dat file for ingestion into TorchSharp + f = open("bug510.dat", "wb") + exportsd.save_state_dict(model.to("cpu").state_dict(), f) + f.close() \ No newline at end of file diff --git a/src/Native/LibTorchSharp/THSNN.h b/src/Native/LibTorchSharp/THSNN.h index 89ea6c58d..7949a85d5 100644 --- a/src/Native/LibTorchSharp/THSNN.h +++ b/src/Native/LibTorchSharp/THSNN.h @@ -182,6 +182,10 @@ EXPORT_API(void) THSNN_BatchNorm1d_set_var(const NNModule module, const Tens EXPORT_API(void) THSNN_BatchNorm2d_set_var(const NNModule module, const Tensor weight); EXPORT_API(void) THSNN_BatchNorm3d_set_var(const NNModule module, const Tensor weight); +EXPORT_API(Tensor) THSNN_BatchNorm1d_get_batches(const NNModule module); +EXPORT_API(Tensor) THSNN_BatchNorm2d_get_batches(const NNModule module); +EXPORT_API(Tensor) THSNN_BatchNorm3d_get_batches(const NNModule module); + EXPORT_API(NNModule) THSNN_InstanceNorm1d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule); EXPORT_API(Tensor) THSNN_InstanceNorm1d_forward(const NNModule module, const Tensor tensor); EXPORT_API(NNModule) THSNN_InstanceNorm2d_ctor(const int64_t features, const double eps, const double momentum, const bool affine, const bool track_running_stats, NNAnyModule* outAsAnyModule); diff --git a/src/Native/LibTorchSharp/THSNormalization.cpp b/src/Native/LibTorchSharp/THSNormalization.cpp index 90a982089..2d88db262 100644 --- a/src/Native/LibTorchSharp/THSNormalization.cpp +++ b/src/Native/LibTorchSharp/THSNormalization.cpp @@ -202,6 +202,15 @@ Tensor THSNN_BatchNorm1d_get_var(const NNModule module) return nullptr; } +Tensor THSNN_BatchNorm1d_get_batches(const NNModule module) +{ + CATCH( + auto v = (*module)->as()->num_batches_tracked; + return v.defined() ? ResultTensor(v) : nullptr; + ); + return nullptr; +} + void THSNN_BatchNorm1d_set_mean(const NNModule module, const Tensor bias) { CATCH( @@ -259,6 +268,15 @@ Tensor THSNN_BatchNorm2d_get_var(const NNModule module) return nullptr; } +Tensor THSNN_BatchNorm2d_get_batches(const NNModule module) +{ + CATCH( + auto v = (*module)->as()->num_batches_tracked; + return v.defined() ? ResultTensor(v) : nullptr; + ); + return nullptr; +} + void THSNN_BatchNorm2d_set_mean(const NNModule module, const Tensor bias) { CATCH( @@ -316,6 +334,15 @@ Tensor THSNN_BatchNorm3d_get_var(const NNModule module) return nullptr; } +Tensor THSNN_BatchNorm3d_get_batches(const NNModule module) +{ + CATCH( + auto v = (*module)->as()->num_batches_tracked; + return v.defined() ? ResultTensor(v) : nullptr; + ); + return nullptr; +} + void THSNN_BatchNorm3d_set_mean(const NNModule module, const Tensor bias) { CATCH( diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs index 682ccd173..0ffed83aa 100644 --- a/src/TorchSharp/NN/Module.cs +++ b/src/TorchSharp/NN/Module.cs @@ -520,6 +520,25 @@ public virtual bool has_parameter(string target) return false; } + /// + /// Returns the buffer given by target if it exists, otherwise throws an error. + /// + /// The fully-qualified string name of the buffer to look for. + /// The tensor referenced by target + public virtual Tensor get_buffer(string target) + { + if (_internal_buffers.TryGetValue(target, out var parameter)) { + return parameter; + } + foreach (var child in named_children().Where(nc => target.StartsWith(nc.name))) { + var prefix = child.name + "."; + var p = child.module.get_buffer(target.Remove(0, prefix.Length)); + if (p is not null) + return p; + } + return null; + } + /// /// Returns the parameter given by target if it exists, otherwise throws an error. /// diff --git a/src/TorchSharp/NN/Normalization/BatchNorm1D.cs b/src/TorchSharp/NN/Normalization/BatchNorm1D.cs index 4940350fd..6c8dda3f1 100644 --- a/src/TorchSharp/NN/Normalization/BatchNorm1D.cs +++ b/src/TorchSharp/NN/Normalization/BatchNorm1D.cs @@ -48,6 +48,8 @@ public override Tensor forward(Tensor tensor) [DllImport("LibTorchSharp")] private static extern IntPtr THSNN_BatchNorm1d_get_var(torch.nn.Module.HType module); [DllImport("LibTorchSharp")] + private static extern IntPtr THSNN_BatchNorm1d_get_batches(torch.nn.Module.HType module); + [DllImport("LibTorchSharp")] private static extern void THSNN_BatchNorm1d_set_mean(torch.nn.Module.HType module, IntPtr weight); [DllImport("LibTorchSharp")] private static extern void THSNN_BatchNorm1d_set_var(torch.nn.Module.HType module, IntPtr weight); @@ -78,7 +80,7 @@ public Parameter weight { } } - public Parameter? running_mean { + public Tensor? running_mean { get { var res = THSNN_BatchNorm1d_get_mean(handle); if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } @@ -87,11 +89,11 @@ public Parameter? running_mean { set { THSNN_BatchNorm1d_set_mean(handle, (value is null ? IntPtr.Zero : value.Handle)); torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); + ConditionallyRegisterBuffer("running_mean", value); } } - public Parameter? running_var { + public Tensor? running_var { get { var res = THSNN_BatchNorm1d_get_var(handle); if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } @@ -100,7 +102,15 @@ public Parameter? running_var { set { THSNN_BatchNorm1d_set_var(handle, (value is null ? IntPtr.Zero : value.Handle)); torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); + ConditionallyRegisterBuffer("running_var", value); + } + } + + public Tensor? num_batches_tracked { + get { + var res = THSNN_BatchNorm1d_get_batches(handle); + if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } + return new Parameter(res); } } diff --git a/src/TorchSharp/NN/Normalization/BatchNorm2D.cs b/src/TorchSharp/NN/Normalization/BatchNorm2D.cs index 2c51e1e1c..5098aae54 100644 --- a/src/TorchSharp/NN/Normalization/BatchNorm2D.cs +++ b/src/TorchSharp/NN/Normalization/BatchNorm2D.cs @@ -5,6 +5,7 @@ using static TorchSharp.torch; +#nullable enable namespace TorchSharp { using Modules; @@ -46,6 +47,8 @@ public override Tensor forward(Tensor tensor) [DllImport("LibTorchSharp")] private static extern IntPtr THSNN_BatchNorm2d_get_var(torch.nn.Module.HType module); [DllImport("LibTorchSharp")] + private static extern IntPtr THSNN_BatchNorm2d_get_batches(torch.nn.Module.HType module); + [DllImport("LibTorchSharp")] private static extern void THSNN_BatchNorm2d_set_mean(torch.nn.Module.HType module, IntPtr weight); [DllImport("LibTorchSharp")] private static extern void THSNN_BatchNorm2d_set_var(torch.nn.Module.HType module, IntPtr weight); @@ -76,7 +79,7 @@ public Parameter weight { } } - public Parameter running_mean { + public Tensor? running_mean { get { var res = THSNN_BatchNorm2d_get_mean(handle); if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } @@ -85,11 +88,11 @@ public Parameter running_mean { set { THSNN_BatchNorm2d_set_mean(handle, (value is null ? IntPtr.Zero : value.Handle)); torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); + ConditionallyRegisterBuffer("running_mean", value); } } - public Parameter running_var { + public Tensor? running_var { get { var res = THSNN_BatchNorm2d_get_var(handle); if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } @@ -98,7 +101,15 @@ public Parameter running_var { set { THSNN_BatchNorm2d_set_var(handle, (value is null ? IntPtr.Zero : value.Handle)); torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); + ConditionallyRegisterBuffer("running_var", value); + } + } + + public Tensor? num_batches_tracked { + get { + var res = THSNN_BatchNorm2d_get_batches(handle); + if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } + return new Parameter(res); } } diff --git a/src/TorchSharp/NN/Normalization/BatchNorm3D.cs b/src/TorchSharp/NN/Normalization/BatchNorm3D.cs index 8b0b86426..1e0227e81 100644 --- a/src/TorchSharp/NN/Normalization/BatchNorm3D.cs +++ b/src/TorchSharp/NN/Normalization/BatchNorm3D.cs @@ -47,6 +47,8 @@ public override Tensor forward(Tensor tensor) [DllImport("LibTorchSharp")] private static extern IntPtr THSNN_BatchNorm3d_get_var(torch.nn.Module.HType module); [DllImport("LibTorchSharp")] + private static extern IntPtr THSNN_BatchNorm3d_get_batches(torch.nn.Module.HType module); + [DllImport("LibTorchSharp")] private static extern void THSNN_BatchNorm3d_set_mean(torch.nn.Module.HType module, IntPtr weight); [DllImport("LibTorchSharp")] private static extern void THSNN_BatchNorm3d_set_var(torch.nn.Module.HType module, IntPtr weight); @@ -77,7 +79,7 @@ public Parameter weight { } } - public Parameter? running_mean { + public Tensor? running_mean { get { var res = THSNN_BatchNorm3d_get_mean(handle); if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } @@ -86,11 +88,11 @@ public Parameter? running_mean { set { THSNN_BatchNorm3d_set_mean(handle, (value is null ? IntPtr.Zero : value.Handle)); torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); + ConditionallyRegisterBuffer("running_mean", value); } } - public Parameter? running_var { + public Tensor? running_var { get { var res = THSNN_BatchNorm3d_get_var(handle); if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } @@ -99,7 +101,15 @@ public Parameter? running_var { set { THSNN_BatchNorm3d_set_var(handle, (value is null ? IntPtr.Zero : value.Handle)); torch.CheckForErrors(); - ConditionallyRegisterParameter("bias", value); + ConditionallyRegisterBuffer("running_var", value); + } + } + + public Tensor? num_batches_tracked { + get { + var res = THSNN_BatchNorm3d_get_batches(handle); + if (res == IntPtr.Zero) { torch.CheckForErrors(); return null; } + return new Parameter(res); } } diff --git a/src/TorchSharp/NN/Sequential.cs b/src/TorchSharp/NN/Sequential.cs index 43ff20206..8321c55bc 100644 --- a/src/TorchSharp/NN/Sequential.cs +++ b/src/TorchSharp/NN/Sequential.cs @@ -54,6 +54,17 @@ internal void Add(torch.nn.Module module) } } + public override IEnumerable<(string name, Tensor buffer)> named_buffers(bool recurse = true) + { + if (!recurse) yield break; + + for (var i = 0; i < _names.Count; i++) { + foreach (var (n, p) in _modules[i].named_buffers(true)) { + yield return ($"{_names[i]}.{n}", p); + } + } + } + public override IEnumerable<(string name, torch.nn.Module module)> named_children() { for (var i = 0; i < _names.Count; i++) { diff --git a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj index 768c20c7e..f090a7570 100644 --- a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj +++ b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj @@ -12,7 +12,9 @@ - + + Always + @@ -23,6 +25,12 @@ + + + PreserveNewest + + + diff --git a/test/TorchSharpTest/TestTorchTensorBugs.cs b/test/TorchSharpTest/TestTorchTensorBugs.cs index dc77566eb..a8d4dedb7 100644 --- a/test/TorchSharpTest/TestTorchTensorBugs.cs +++ b/test/TorchSharpTest/TestTorchTensorBugs.cs @@ -443,5 +443,63 @@ class Module500 : Module public override torch.Tensor forward(torch.Tensor t) => bn1.forward(t); } + + [Fact] + public void ValidateIssue510() + { + var model = new Module510(1, 32); + model.forward(torch.randn(16, 1, 32)); + + var w0 = model.get_parameter("stack.0.weight").clone(); + var w1 = model.get_parameter("stack.1.weight").clone(); + var b1 = model.get_parameter("stack.1.bias").clone(); + var rm = model.get_buffer("stack.1.running_mean").clone(); + var rv = model.get_buffer("stack.1.running_var").clone(); + var nm = model.get_buffer("stack.1.num_batches_tracked").clone(); + + model.load(@".\bug510.dat"); + + var w0_ = model.get_parameter("stack.0.weight"); + var w1_ = model.get_parameter("stack.1.weight"); + var b1_ = model.get_parameter("stack.1.bias"); + var rm_ = model.get_buffer("stack.1.running_mean"); + var rv_ = model.get_buffer("stack.1.running_var"); + var nm_ = model.get_buffer("stack.1.num_batches_tracked"); + + Assert.NotEqual(w0, w0_); + Assert.NotEqual(w1, w1_); + Assert.NotEqual(b1, b1_); + Assert.NotEqual(rm, rm_); + Assert.NotEqual(rv, rv_); + Assert.Equal(1, nm.item()); + Assert.Equal(0, nm_.item()); + } + + internal class Module510 : Module + { + private readonly Module stack; + + public Module510(int in_channels, int out_channels, int kernel_size=3, int stride = 1, int padding = 0) : base(String.Empty) + { + var temp = BatchNorm1d(out_channels); + this.stack = Sequential( + Conv1d(in_channels, out_channels, 3, stride: stride, padding: padding, bias: false), + temp, + ReLU(inPlace: true) + ); + + temp.weight = Parameter(torch.randn(temp.weight.shape)); + temp.bias = Parameter(torch.randn(temp.bias.shape)); + if (temp.running_mean is not null) temp.running_mean = torch.randn(temp.running_mean.shape); + if (temp.running_var is not null) temp.running_var = torch.randn(temp.running_var.shape); + + this.RegisterComponents(); + } + + public override torch.Tensor forward(torch.Tensor t) + { + return this.stack.forward(t); + } + } } } \ No newline at end of file diff --git a/test/TorchSharpTest/TorchSharpTest.csproj b/test/TorchSharpTest/TorchSharpTest.csproj index de85f1656..4821ca305 100644 --- a/test/TorchSharpTest/TorchSharpTest.csproj +++ b/test/TorchSharpTest/TorchSharpTest.csproj @@ -10,6 +10,16 @@ $(OutputPath) + + + + + + + PreserveNewest + + + PreserveNewest diff --git a/test/TorchSharpTest/bug510.dat b/test/TorchSharpTest/bug510.dat new file mode 100644 index 0000000000000000000000000000000000000000..ba41d979cebf6aa40491edb043febb85dc9d4257 GIT binary patch literal 1036 zcmZSMD=tY)&ek)~D^JZ#&nRJIR$yda;q0=n>~`5ckI7s1&A5ET_SodeeI{ye?be+5 zw>MztJKGlr59~b_aojG2@1WhBDb98wJ=}I{;_urXnDgK6+JS5PGTaU97EE|+XK`1- zF5~=8JCC&ub|+L=>=>AG_FR#iZI`h4rd`B$#eK8mAMf)DW7_{9W9c5A{=NG&Y}@Tl zcz&~ExPEQ#g~&+T5H|}u5&QXeTeh#=>rr~y?tCNL{zsJz`^C-(?)Q*9zgOYhAv^8Q zKlbq?bMKE=7T#Y_zjI%MXqe57?jSpkW&i9JO*j_^CuHA{kmc1nx=h_`%X5N1!{;%Bx{|kG62p_in zafj9ZhqKkb37xNPw=Ouc?}78xeG?MAZNJRdu+#DRve#ibgWZ8!W;VwJ)AyxGHQKpE zvD+`7eQuw}H%Ge}!dL8^`FHQjnfh|C!q)$LR&2&Y~mFuAUbE z`vvOu?A>*wabMEXh5Mj!VhD>9Mg<0j2K#};JdjjksF#$PSj+};&%go^n0ty!^YSwD y(&KYe6Y~avyM)o*RhC!;cGkc~pES%_d8N7WNr@%N8L7qbB}KrDpPIq~k4XTu4j3{3 literal 0 HcmV?d00001 diff --git a/test/TorchSharpTest/bug510.py b/test/TorchSharpTest/bug510.py new file mode 100644 index 000000000..7fa461630 --- /dev/null +++ b/test/TorchSharpTest/bug510.py @@ -0,0 +1,24 @@ +import torch +import src.Python.exportsd as exportsd + +class BasicConv1d(torch.nn.Module): + def __init__(self, in_channels, out_channels, **kwargs): + super().__init__() + + self.stack = torch.nn.Sequential( + torch.nn.Conv1d(in_channels, out_channels, kernel_size=3, bias=False, **kwargs), + torch.nn.BatchNorm1d(out_channels), + torch.nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.stack(x) + +if __name__ == '__main__': + # Create model + model = BasicConv1d(1, 32) + + #Export model to .dat file for ingestion into TorchSharp + f = open("bug510.dat", "wb") + exportsd.save_state_dict(model.to("cpu").state_dict(), f) + f.close() \ No newline at end of file