Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making TensorFlowTransform trainable. #1063

Merged
merged 24 commits into from
Oct 8, 2018
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
22d08d6
Making TensorflowTransform trainable.
zeahmed Sep 25, 2018
1fc443a
Merge remote-tracking branch 'upstream/master' into tf_training
zeahmed Sep 26, 2018
ef67dd1
Added model files and script to create model in python.
zeahmed Sep 26, 2018
c1eafa6
Added functionality to serialized TensorFlow model after retraining.
zeahmed Sep 27, 2018
07ae61c
Added comments. Added checking for parameters and types.
zeahmed Oct 1, 2018
6ec4a5a
Merge remote-tracking branch 'upstream/master' into tf_training
zeahmed Oct 1, 2018
36bf551
Addressed reviewers' comments.
zeahmed Oct 1, 2018
570bd02
Addressed reviewers' comments.
zeahmed Oct 2, 2018
3cc20e4
Added more tests...
zeahmed Oct 2, 2018
49bdf59
Merge remote-tracking branch 'upstream/master' into tf_training
zeahmed Oct 2, 2018
a9a90ab
Fixed failing tests.
zeahmed Oct 3, 2018
345c504
Using progress channel to report metrics.
zeahmed Oct 3, 2018
883f9b4
Moved models to models repo.
zeahmed Oct 3, 2018
18c70f3
Removed extra array copy.
zeahmed Oct 3, 2018
8ea8e9d
Addressed reviewers' comments.
zeahmed Oct 4, 2018
735f7ca
Addressed reviewers' comments.
zeahmed Oct 4, 2018
9c07a53
Addressed reviewers' comments.
zeahmed Oct 4, 2018
14ec78b
Addressed reviewers' comments.
zeahmed Oct 5, 2018
afb3ca1
Addressed reviewers' comments.
zeahmed Oct 5, 2018
5a6cc08
Added try-finally in tests so that folders are deleted even when ther…
zeahmed Oct 8, 2018
4297376
Merge conflicts.
zeahmed Oct 8, 2018
c0fe791
Merge remote-tracking branch 'upstream/master' into tf_training
zeahmed Oct 8, 2018
0984b0f
Merge remote-tracking branch 'upstream/master' into tf_training
zeahmed Oct 8, 2018
18040ca
Merge branch 'tf_training' of https://github.com/zeahmed/machinelearn…
zeahmed Oct 8, 2018
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions src/Microsoft.ML.Legacy/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15801,6 +15801,66 @@ public sealed partial class TensorFlowScorer : Microsoft.ML.Runtime.EntryPoints.
/// </summary>
public string[] OutputColumns { get; set; }

/// <summary>
/// Training labels.
/// </summary>
public string LabeLColumn { get; set; } = "Label";

/// <summary>
/// The name of the optimization operation in the TensorFlow graph.
/// </summary>
public string OptimizationOperation { get; set; }

/// <summary>
/// The name of the operation in the TensorFlow graph to compute training loss (Optional)
/// </summary>
public string LossOperation { get; set; }

/// <summary>
/// The name of the operation in the TensorFlow graph to compute performance metric during training (Optional)
/// </summary>
public string MetricOperation { get; set; }

/// <summary>
/// Number of samples to use for mini-batch training.
/// </summary>
public int BatchSize { get; set; } = 64;

/// <summary>
/// Number of training iterations.
/// </summary>
public int Epoch { get; set; } = 5;

/// <summary>
/// The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional).
/// </summary>
public string LearningRateOperation { get; set; }

/// <summary>
/// Learning rate to use during optimization.
/// </summary>
public float LearningRate { get; set; } = 0.01f;

/// <summary>
/// Shuffle data before each iteration.
/// </summary>
public bool Shuffle { get; set; } = true;

/// <summary>
/// Name of the input in TensorFlow graph that specifiy the location for saving/restoring models from disk.
/// </summary>
public string SaveLocationOperation { get; set; } = "save/Const";

/// <summary>
/// Name of the input in TensorFlow graph that specifiy the location for saving/restoring models from disk.
/// </summary>
public string SaveOperation { get; set; } = "save/control_dependency";

/// <summary>
/// Retrain TensorFlow model.
/// </summary>
public bool ReTrain { get; set; } = false;

/// <summary>
/// Input dataset
/// </summary>
Expand Down
40 changes: 40 additions & 0 deletions src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Text;
using size_t = System.UIntPtr;
using TF_Tensor = System.IntPtr;
using TF_Status = System.IntPtr;

#pragma warning disable MSML_ParameterLocalVarName

Expand Down Expand Up @@ -74,6 +75,18 @@ internal partial class TFTensor : TFDisposableThreadSafe
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern unsafe TF_Tensor TF_NewTensor(TFDataType dataType, IntPtr zeroDims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg);

// extern size_t TF_StringEncode (const char *src, size_t src_len, char *dst, size_t dst_len, TF_Status *status);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern unsafe size_t TF_StringEncode(byte* src, size_t src_len, sbyte* dst, size_t dst_len, TF_Status status);

// extern size_t TF_StringDecode (const char *src, size_t src_len, const char **dst, size_t *dst_len, TF_Status *status);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern unsafe size_t TF_StringDecode(sbyte* src, size_t src_len, sbyte** dst, size_t* dst_len, TF_Status status);

// extern size_t TF_StringEncodedSize (size_t len);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern size_t TF_StringEncodedSize(size_t len);

internal TFTensor(IntPtr handle) : base(handle) { }

internal static Deallocator FreeTensorDataDelegate = FreeTensorData;
Expand Down Expand Up @@ -409,6 +422,31 @@ public TFTensor(long[] data) : base(SetupTensor(TFDataType.Int64, data, size: 8)
/// <param name="data">Data.</param>
public TFTensor(Complex[] data) : base(SetupTensor(TFDataType.Complex128, data, size: 16)) { }

internal static unsafe TFTensor CreateString(byte[] buffer)
{
if (buffer == null)
throw new ArgumentNullException(nameof(buffer));
//
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by
// TF_StringEncode-encoded bytes.
//
var size = TF_StringEncodedSize((UIntPtr)buffer.Length);
IntPtr handle = TF_AllocateTensor(TFDataType.String, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));

// Clear offset table
IntPtr dst = TF_TensorData(handle);
Marshal.WriteInt64(dst, 0);
var status = new TFStatus();
Copy link

@yaeldekel yaeldekel Oct 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

status [](start = 16, length = 6)

using (var status ...) #Resolved

fixed (byte* src = &buffer[0])
{
TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(dst + 8), size, status.handle);
var ok = status.StatusCode == TFCode.Ok;
if (!ok)
return null;
}
return new TFTensor(handle);
}

// Convenience function to factor out the setup of a new tensor from an array
internal static IntPtr SetupTensor(TFDataType dt, long[] dims, Array data, int count, int size)
{
Expand Down Expand Up @@ -591,6 +629,8 @@ public static Type TypeFromTensorType(TFDataType type)
{
case TFDataType.Float:
return typeof(float);
case TFDataType.Float_ref:
return typeof(float);
case TFDataType.Double:
return typeof(double);
case TFDataType.Int32:
Expand Down
12 changes: 9 additions & 3 deletions src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ public TFStatus() : base(TF_NewStatus())

// extern void TF_DeleteStatus (TF_Status *);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern unsafe void TF_DeleteStatus(TF_Status status);
internal static extern unsafe void TF_DeleteStatus(TF_Status status);
Copy link

@yaeldekel yaeldekel Oct 4, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

internal [](start = 8, length = 8)

Can this one and TF_GetCode go back to being private? #Closed


internal override void NativeDispose(IntPtr handle)
{
Expand All @@ -313,7 +313,7 @@ public void SetStatusCode(TFCode code, string msg)

// extern TF_Code TF_GetCode (const TF_Status *s);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern unsafe TFCode TF_GetCode(TF_Status s);
internal static extern unsafe TFCode TF_GetCode(TF_Status s);

/// <summary>
/// Gets the status code for the status code.
Expand Down Expand Up @@ -1666,7 +1666,13 @@ internal enum TFDataType : uint
/// <summary>
/// 64-bit unsigned integers
/// </summary>
UInt64 = 23
UInt64 = 23,

/// <summary>
/// Float reference type. It used for defining types of Variables.
/// Please https://www.tensorflow.org/api_docs/python/tf/DType for more details.
/// </summary>
Float_ref = 101
}

/// <summary>
Expand Down
11 changes: 11 additions & 0 deletions src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ private static PrimitiveType Tf2MlNetTypeOrNull(TFDataType type)
{
case TFDataType.Float:
return NumberType.R4;
case TFDataType.Float_ref:
return NumberType.R4;
case TFDataType.Double:
return NumberType.R8;
case TFDataType.UInt16:
Expand All @@ -144,6 +146,12 @@ private static PrimitiveType Tf2MlNetTypeOrNull(TFDataType type)
return NumberType.U4;
case TFDataType.UInt64:
return NumberType.U8;
case TFDataType.Int16:
return NumberType.I2;
case TFDataType.Int32:
return NumberType.I4;
case TFDataType.Int64:
return NumberType.I8;
default:
return null;
}
Expand Down Expand Up @@ -336,6 +344,9 @@ internal static bool IsTypeSupported(TFDataType tfoutput)
case TFDataType.UInt16:
case TFDataType.UInt32:
case TFDataType.UInt64:
case TFDataType.Int16:
case TFDataType.Int32:
case TFDataType.Int64:
return true;
default:
return false;
Expand Down
Loading