Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug510 #511

Merged
merged 6 commits into from
Feb 14, 2022
Merged

Bug510 #511

merged 6 commits into from
Feb 14, 2022

Conversation

NiklasGustafsson
Copy link
Contributor

Addressing #510

Adding 'num_batches_tracked' to BatcNorm{123}d
src/Native/LibTorchSharp/THSNormalization.cpp Outdated Show resolved Hide resolved
src/Native/LibTorchSharp/THSNormalization.cpp Outdated Show resolved Hide resolved
src/TorchSharp/NN/Module.cs Show resolved Hide resolved
src/TorchSharp/NN/Module.cs Outdated Show resolved Hide resolved
src/TorchSharp/NN/Normalization/BatchNorm1D.cs Outdated Show resolved Hide resolved
@lostmsu
Copy link
Contributor

lostmsu commented Feb 13, 2022

Mostly looks good to me, except the get_buffer handling of nested buffers.

Copy link
Contributor

@lostmsu lostmsu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd squash.

var splits = target.Split('.');
if (splits.Length > 1) {
foreach (var child in named_children().Where(nc => nc.name == splits[0])) {
if (child.module.has_parameter(target.Remove(0, splits[0].Length + 1)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: target[(splits[0].Length + 2)..]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth doing if it's more than trivially faster, otherwise I'll leave as is.

Comment on lines +536 to +550
if (target is null) throw new ArgumentNullException("target");
if (_internal_buffers.TryGetValue(target, out var buffer)) {
return buffer;
}

var splits = target.Split('.');
if (splits.Length > 1) {
foreach (var child in named_children().Where(nc => nc.name == splits[0])) {
var p = child.module.get_buffer(target.Remove(0, splits[0].Length + 1));
if (p is not null)
return p;
}
}
return null;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: this repeats a bit. Perhaps make a T get_named<T>(name, internal_collection, getter)?

@NiklasGustafsson NiklasGustafsson merged commit 897c8ac into dotnet:main Feb 14, 2022
@NiklasGustafsson NiklasGustafsson deleted the bug510 branch August 5, 2022 14:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Module.Load throws Mismatched state_dict sizes exception on BatchNorm1d
2 participants