Skip to content

Commit

Permalink
Add NameEntityRecognition and Q&A deep learning tasks. (#6760)
Browse files Browse the repository at this point in the history
* NER

* QA almost done, runtime error

* QA finished

* fixes from PR comments

* fixed build

* build fixes

* perf changes

* made disposable

* fixed not disposing model

* added some disposables to TensorFlow for memory

* build testing

* fixing build

* added missing dispose

* build fixes

* build fixes

* testing macos fix
  • Loading branch information
michaelgsharp authored Jul 24, 2023
1 parent 321158d commit 65c7ca9
Show file tree
Hide file tree
Showing 49 changed files with 90,032 additions and 383 deletions.
2 changes: 1 addition & 1 deletion build/ci/job-template.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ jobs:
steps:
# Extra MacOS step required to install OS-specific dependencies
- ${{ if and(contains(parameters.pool.vmImage, 'macOS'), not(contains(parameters.name, 'cross'))) }}:
- script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=TRUE && brew update && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula
- script: export HOMEBREW_NO_INSTALLED_DEPENDENTS_CHECK=TRUE && brew unlink libomp && brew install $(Build.SourcesDirectory)/build/libomp.rb --build-from-source --formula
displayName: Install MacOS build dependencies
# Extra Apple MacOS step required to install OS-specific dependencies
- ${{ if and(contains(parameters.pool.vmImage, 'macOS'), contains(parameters.name, 'cross')) }}:
Expand Down
13 changes: 13 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/BPE.cs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,14 @@ public override IReadOnlyList<Token> Tokenize(string sequence)
return null;
}

/// <summary>
/// Map the tokenized Id to its string form.
/// This model does not implement the mapping: every call throws.
/// </summary>
/// <param name="id">The Id to map to the string.</param>
/// <param name="skipSpecialTokens">Indicates whether special tokens should be skipped during decoding.</param>
/// <returns>Never returns normally.</returns>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public override string? IdToString(int id, bool skipSpecialTokens = false) => throw new NotImplementedException();

/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
Expand Down Expand Up @@ -443,6 +451,11 @@ internal List<Token> TokenizeWithCache(string sequence)
return tokens;
}

/// <summary>
/// Character validation is not implemented for this model: every call throws.
/// </summary>
/// <param name="ch">The character to check.</param>
/// <returns>Never returns normally.</returns>
/// <exception cref="NotImplementedException">Always thrown.</exception>
public override bool IsValidChar(char ch) => throw new NotImplementedException();

internal static readonly List<Token> EmptyTokensList = new();
}
}
27 changes: 27 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/EnglishRoberta.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,28 @@ public EnglishRoberta(Stream vocabularyStream, Stream mergeStream, Stream highes
public override string? IdToToken(int id, bool skipSpecialTokens = false)
{
    // When special tokens are skipped, a negative Id yields no token.
    if (skipSpecialTokens && id < 0)
    {
        return null;
    }

    // Ids missing from the reverse vocabulary also yield null.
    return _vocabReverse.TryGetValue(id, out var token) ? token : null;
}

/// <summary>
/// Map the tokenized Id to the original string.
/// </summary>
/// <remarks>
/// The vocabulary token for the Id is translated character by character through the
/// unicode-to-byte map; characters that have no entry in that map are dropped.
/// </remarks>
/// <param name="id">The Id to map to the string.</param>
/// <param name="skipSpecialTokens">Indicates whether special tokens should be skipped during decoding.</param>
/// <returns>
/// The translated string, or <see langword="null"/> when <paramref name="skipSpecialTokens"/> is set and
/// <paramref name="id"/> is negative, or when the Id is not in the vocabulary.
/// </returns>
public override string? IdToString(int id, bool skipSpecialTokens = false)
{
    if (skipSpecialTokens && id < 0)
    {
        return null;
    }

    if (!_vocabReverse.TryGetValue(id, out var token))
    {
        return null;
    }

    // The output can never be longer than the token, so one buffer of that size suffices.
    var buffer = new char[token.Length];
    int length = 0;

    foreach (char c in token)
    {
        // Single lookup per character; unmapped characters are skipped.
        if (_unicodeToByte.TryGetValue(c, out var mapped))
        {
            buffer[length++] = mapped;
        }
    }

    return new string(buffer, 0, length);
}

/// <summary>
/// Save the model data into the vocabulary, merges, and occurrence mapping files.
/// </summary>
Expand Down Expand Up @@ -565,6 +587,11 @@ private List<Token> BpeToken(Span<char> token, Span<int> indexMapping)

return pairs;
}

/// <summary>
/// Check whether the character is a key of this model's byte-to-unicode map.
/// </summary>
/// <param name="ch">The character to check.</param>
/// <returns><see langword="true"/> if the map contains <paramref name="ch"/> as a key; otherwise <see langword="false"/>.</returns>
public override bool IsValidChar(char ch) => _byteToUnicode.ContainsKey(ch);
}

/// <summary>
Expand Down
10 changes: 10 additions & 0 deletions src/Microsoft.ML.Tokenizers/Model/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ public abstract class Model
/// <returns>The mapped token of the Id.</returns>
public abstract string? IdToToken(int id, bool skipSpecialTokens = false);

/// <summary>
/// Map the tokenized Id to its string form. How the string is produced is defined by each derived model.
/// </summary>
/// <param name="id">The Id to map to the string.</param>
/// <param name="skipSpecialTokens">Indicates whether special tokens should be skipped during decoding.</param>
/// <returns>
/// The string mapped from the Id; may be <see langword="null"/>.
/// NOTE(review): the conditions for a null result are decided by the derived model - confirm per implementation.
/// </returns>
public abstract string? IdToString(int id, bool skipSpecialTokens = false);

/// <summary>
/// Gets the dictionary mapping tokens to Ids.
/// </summary>
Expand All @@ -57,6 +59,14 @@ public abstract class Model
/// Gets a trainer object to use in training the model.
/// </summary>
public abstract Trainer? GetTrainer();

/// <summary>
/// Return true if the char is valid in the tokenizer; otherwise return false.
/// </summary>
/// <param name="ch">The character to check.</param>
/// <returns><see langword="true"/> if <paramref name="ch"/> is valid for this model; otherwise <see langword="false"/>.</returns>
public abstract bool IsValidChar(char ch);

}

}
10 changes: 9 additions & 1 deletion src/Microsoft.ML.Tokenizers/Tokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ public TokenizerResult Encode(string sequence)

foreach (int id in ids)
{
tokens.Add(Model.IdToToken(id) ?? "");
if (Model.GetType() == typeof(EnglishRoberta))
tokens.Add(Model.IdToString(id) ?? "");
else
tokens.Add(Model.IdToToken(id) ?? "");
}

return Decoder?.Decode(tokens) ?? string.Join("", tokens);
Expand Down Expand Up @@ -187,5 +190,10 @@ public void TrainFromFiles(
// To Do: support added vocabulary in the tokenizer which will include this returned special_tokens.
// self.add_special_tokens(&special_tokens);
}

/// <summary>
/// Check whether a character is valid for this tokenizer by delegating to the underlying model.
/// </summary>
/// <param name="ch">The character to check.</param>
/// <returns>The result of the model's <c>IsValidChar</c> for <paramref name="ch"/>.</returns>
public bool IsValidChar(char ch) => Model.IsValidChar(ch);
}
}
36 changes: 28 additions & 8 deletions src/Microsoft.ML.TorchSharp/AutoFormerV2/ObjectDetectionTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,12 @@
using static TorchSharp.torch.optim.lr_scheduler;
using Microsoft.ML.TorchSharp.Utils;
using Microsoft.ML;
using Microsoft.ML.TorchSharp.NasBert;
using System.IO;
using Microsoft.ML.Data.IO;
using Microsoft.ML.TorchSharp.Loss;
using Microsoft.ML.Transforms.Image;
using static Microsoft.ML.TorchSharp.AutoFormerV2.ObjectDetectionTrainer;
using Microsoft.ML.TorchSharp.AutoFormerV2;
using Microsoft.ML.Tokenizers;
using Microsoft.ML.TorchSharp.Extensions;
using Microsoft.ML.TorchSharp.NasBert.Models;
using static Microsoft.ML.TorchSharp.NasBert.NasBertTrainer;
using TorchSharp.Modules;
using System.Text;
using static Microsoft.ML.Data.AnnotationUtils;

[assembly: LoadableClass(typeof(ObjectDetectionTransformer), null, typeof(SignatureLoadModel),
Expand Down Expand Up @@ -503,7 +496,7 @@ private void CheckInputSchema(SchemaShape inputSchema)
}
}

public class ObjectDetectionTransformer : RowToRowTransformerBase
public class ObjectDetectionTransformer : RowToRowTransformerBase, IDisposable
{
private protected readonly Device Device;
private protected readonly AutoFormerV2 Model;
Expand All @@ -522,6 +515,7 @@ public class ObjectDetectionTransformer : RowToRowTransformerBase

private static readonly FuncStaticMethodInfo1<object, Delegate> _decodeInitMethodInfo
= new FuncStaticMethodInfo1<object, Delegate>(DecodeInit<int>);
private bool _disposedValue;

internal ObjectDetectionTransformer(IHostEnvironment env, ObjectDetectionTrainer.Options options, AutoFormerV2 model, DataViewSchema.DetachedColumn labelColumn)
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ObjectDetectionTransformer)))
Expand Down Expand Up @@ -992,5 +986,31 @@ private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> a
return col => (activeOutput(0) || activeOutput(1) || activeOutput(2)) && _inputColIndices.Any(i => i == col);
}
}

/// <summary>
/// Releases the model held by this transformer. Safe to call more than once;
/// only the first call does any work.
/// </summary>
/// <param name="disposing">
/// <see langword="true"/> when called from <see cref="Dispose()"/>; <see langword="false"/> when called
/// from the finalizer. The model is released on both paths.
/// </param>
protected virtual void Dispose(bool disposing)
{
    if (_disposedValue)
    {
        return;
    }

    // This method also runs from the finalizer, which executes even when the constructor
    // threw before the model was assigned. Guard against null so the finalizer cannot
    // fail with a NullReferenceException.
    Model?.Dispose();
    _disposedValue = true;
}

// Finalizer: forwards to Dispose(bool) with disposing set to false.
// Keep cleanup logic in 'Dispose(bool disposing)', not here.
~ObjectDetectionTransformer() => Dispose(disposing: false);

/// <summary>
/// Releases the resources held by this transformer and suppresses its finalizer.
/// </summary>
public void Dispose()
{
    // Keep cleanup logic in 'Dispose(bool disposing)', not here.
    Dispose(true);

    // Cleanup has already run, so the finalizer no longer needs to.
    GC.SuppressFinalize(this);
}
}
}
16 changes: 16 additions & 0 deletions src/Microsoft.ML.TorchSharp/NasBert/BertModelType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Text;

namespace Microsoft.ML.TorchSharp.NasBert
{
    /// <summary>
    /// The kinds of BERT-style model distinguished by the NasBert code.
    /// </summary>
    internal enum BertModelType
    {
        /// <summary>The NasBert model.</summary>
        NasBert = 0,

        /// <summary>The Roberta model.</summary>
        Roberta = 1
    }
}
4 changes: 3 additions & 1 deletion src/Microsoft.ML.TorchSharp/NasBert/BertTaskType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public enum BertTaskType
None = 0,
MaskedLM = 1,
TextClassification = 2,
SentenceRegression = 3
SentenceRegression = 3,
NameEntityRecognition = 4,
QuestionAnswering = 5
}
}
3 changes: 0 additions & 3 deletions src/Microsoft.ML.TorchSharp/NasBert/Models/BaseHead.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Text;
using TorchSharp;

namespace Microsoft.ML.TorchSharp.NasBert.Models
Expand Down
11 changes: 5 additions & 6 deletions src/Microsoft.ML.TorchSharp/NasBert/Models/BaseModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,22 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.TorchSharp.Utils;
using TorchSharp;

namespace Microsoft.ML.TorchSharp.NasBert.Models
{
internal abstract class BaseModel : torch.nn.Module<torch.Tensor, torch.Tensor, torch.Tensor>
{
protected readonly NasBertTrainer.NasBertOptions Options;
public BertTaskType HeadType => Options.TaskType;
public BertModelType EncoderType => Options.ModelType;

//public ModelType EncoderType => Options.ModelType;
public BertTaskType HeadType => Options.TaskType;

#pragma warning disable CA1024 // Use properties where appropriate: Modules should be fields in TorchSharp
public abstract TransformerEncoder GetEncoder();

public abstract BaseHead GetHead();

#pragma warning restore CA1024 // Use properties where appropriate

protected BaseModel(NasBertTrainer.NasBertOptions options)
Expand Down
36 changes: 36 additions & 0 deletions src/Microsoft.ML.TorchSharp/NasBert/Models/ModelPrediction.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using TorchSharp;

namespace Microsoft.ML.TorchSharp.NasBert.Models
{
/// <summary>
/// A <see cref="NasBertModel"/> whose head is a <see cref="PredictionHead"/>:
/// the forward pass extracts features from the source tokens and feeds them to that head.
/// </summary>
internal sealed class ModelForPrediction : NasBertModel
{
    // The field name deliberately breaks the _camelCase rule; see the suppression justification.
    [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_PrivateFieldName:Private field name not in: _camelCase format", Justification = "Has to match TorchSharp model.")]
    private readonly PredictionHead PredictionHead;

    /// <summary>
    /// Gets the head applied to the extracted features.
    /// </summary>
    public override BaseHead GetHead()
    {
        return PredictionHead;
    }

    /// <summary>
    /// Builds the model and its prediction head, then initializes and registers the components.
    /// </summary>
    /// <param name="options">Trainer options; supplies the head's input dimension and dropout rate.</param>
    /// <param name="padIndex">Forwarded unchanged to the base model.</param>
    /// <param name="symbolsCount">Forwarded unchanged to the base model.</param>
    /// <param name="numClasses">Number of classes the prediction head is created with.</param>
    public ModelForPrediction(NasBertTrainer.NasBertOptions options, int padIndex, int symbolsCount, int numClasses)
        : base(options, padIndex, symbolsCount)
    {
        // The head must exist before Initialize() and RegisterComponents() run.
        PredictionHead = new PredictionHead(
            inputDim: Options.EncoderOutputDim,
            numClasses: numClasses,
            dropoutRate: Options.PoolerDropout);

        Initialize();
        RegisterComponents();
    }

    /// <summary>
    /// Runs feature extraction followed by the prediction head.
    /// </summary>
    /// <param name="srcTokens">Input tokens passed to feature extraction.</param>
    /// <param name="tokenMask">Not read by this implementation.</param>
    /// <returns>The head's output tensor, moved to the outer dispose scope.</returns>
    [System.Diagnostics.CodeAnalysis.SuppressMessage("Naming", "MSML_GeneralName:This name should be PascalCased", Justification = "Need to match TorchSharp.")]
    public override torch.Tensor forward(torch.Tensor srcTokens, torch.Tensor tokenMask = null)
    {
        // Tensors created below belong to this scope; only the result is moved out of it.
        using var scope = torch.NewDisposeScope();

        var features = ExtractFeatures(srcTokens);
        var output = PredictionHead.call(features);

        return output.MoveToOuterDisposeScope();
    }
}
}
Loading

0 comments on commit 65c7ca9

Please sign in to comment.