Skip to content

Commit

Permalink
Added feature to support training in TensorFlowTransform. (#1063)
Browse files Browse the repository at this point in the history
  • Loading branch information
zeahmed authored Oct 8, 2018
1 parent a9d8ae4 commit d2ed0ad
Show file tree
Hide file tree
Showing 9 changed files with 1,026 additions and 86 deletions.
65 changes: 65 additions & 0 deletions src/Microsoft.ML.Legacy/CSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16650,6 +16650,71 @@ public sealed partial class TensorFlowScorer : Microsoft.ML.Runtime.EntryPoints.
/// </summary>
public string[] OutputColumns { get; set; }

/// <summary>
/// Training labels.
/// </summary>
public string LabelColumn { get; set; }

/// <summary>
/// TensorFlow label node.
/// </summary>
public string TensorFlowLabel { get; set; }

/// <summary>
/// The name of the optimization operation in the TensorFlow graph.
/// </summary>
public string OptimizationOperation { get; set; }

/// <summary>
/// The name of the operation in the TensorFlow graph to compute training loss (Optional).
/// </summary>
public string LossOperation { get; set; }

/// <summary>
/// The name of the operation in the TensorFlow graph to compute performance metric during training (Optional).
/// </summary>
public string MetricOperation { get; set; }

/// <summary>
/// Number of samples to use for mini-batch training. Defaults to 64.
/// </summary>
public int BatchSize { get; set; } = 64;

/// <summary>
/// Number of training iterations. Defaults to 5.
/// </summary>
public int Epoch { get; set; } = 5;

/// <summary>
/// The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional).
/// </summary>
public string LearningRateOperation { get; set; }

/// <summary>
/// Learning rate to use during optimization. Defaults to 0.01.
/// </summary>
public float LearningRate { get; set; } = 0.01f;

/// <summary>
/// Shuffle data before each iteration. Defaults to true.
/// </summary>
public bool Shuffle { get; set; } = true;

/// <summary>
/// Name of the input in the TensorFlow graph that specifies the location for saving/restoring models from disk.
/// Defaults to "save/Const".
/// </summary>
public string SaveLocationOperation { get; set; } = "save/Const";

// NOTE(review): the original summary here was word-for-word the same as the one on
// SaveLocationOperation. Presumably this property names the graph operation that performs
// the save rather than the location input - confirm and reword.
/// <summary>
/// Name of the input in the TensorFlow graph that specifies the location for saving/restoring models from disk.
/// Defaults to "save/control_dependency".
/// </summary>
public string SaveOperation { get; set; } = "save/control_dependency";

/// <summary>
/// Retrain TensorFlow model. Defaults to false.
/// </summary>
public bool ReTrain { get; set; } = false;

/// <summary>
/// Input dataset
/// </summary>
Expand Down
42 changes: 42 additions & 0 deletions src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Text;
using size_t = System.UIntPtr;
using TF_Tensor = System.IntPtr;
using TF_Status = System.IntPtr;

#pragma warning disable MSML_ParameterLocalVarName

Expand Down Expand Up @@ -74,6 +75,18 @@ internal partial class TFTensor : TFDisposableThreadSafe
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern unsafe TF_Tensor TF_NewTensor(TFDataType dataType, IntPtr zeroDims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg);

// Native signature:
// extern size_t TF_StringEncode (const char *src, size_t src_len, char *dst, size_t dst_len, TF_Status *status);
// Managed mapping: size_t is System.UIntPtr and TF_Status is System.IntPtr (see the using-aliases
// at the top of this file), so the status argument is the raw native status handle.
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern unsafe size_t TF_StringEncode(byte* src, size_t src_len, sbyte* dst, size_t dst_len, TF_Status status);

// Native signature:
// extern size_t TF_StringDecode (const char *src, size_t src_len, const char **dst, size_t *dst_len, TF_Status *status);
// dst and dst_len are out-pointers written by the native call; status is the raw native status handle.
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern unsafe size_t TF_StringDecode(sbyte* src, size_t src_len, sbyte** dst, size_t* dst_len, TF_Status status);

// Native signature:
// extern size_t TF_StringEncodedSize (size_t len);
// Takes only a length, so no unsafe context is required for this binding.
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern size_t TF_StringEncodedSize(size_t len);

internal TFTensor(IntPtr handle) : base(handle) { }

internal static Deallocator FreeTensorDataDelegate = FreeTensorData;
Expand Down Expand Up @@ -409,6 +422,33 @@ public TFTensor(long[] data) : base(SetupTensor(TFDataType.Int64, data, size: 8)
/// <param name="data">Data.</param>
public TFTensor(Complex[] data) : base(SetupTensor(TFDataType.Complex128, data, size: 16)) { }

/// <summary>
/// Creates a zero-dimensional <see cref="TFDataType.String"/> tensor whose single element
/// is the TF_StringEncode-encoded form of <paramref name="buffer"/>.
/// </summary>
/// <param name="buffer">Raw bytes to store in the tensor. May be empty, must not be null.</param>
/// <returns>
/// The newly created tensor, or <c>null</c> when TF_StringEncode reports a non-OK status.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
internal static unsafe TFTensor CreateString(byte[] buffer)
{
    if (buffer == null)
        throw new ArgumentNullException(nameof(buffer));

    //
    // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
    // TF_StringEncode-encoded bytes.
    //
    var size = TF_StringEncodedSize((UIntPtr)buffer.Length);
    IntPtr handle = TF_AllocateTensor(TFDataType.String, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));

    // Clear offset table
    IntPtr dst = TF_TensorData(handle);
    Marshal.WriteInt64(dst, 0);

    // An empty array has no element 0 to take the address of (&buffer[0] would throw
    // IndexOutOfRangeException), so pin a one-byte placeholder in that case. The length
    // handed to the native call is still buffer.Length, i.e. 0 bytes are read from it.
    byte[] pinned = buffer.Length > 0 ? buffer : new byte[1];

    using (var status = new TFStatus())
    {
        fixed (byte* src = &pinned[0])
        {
            TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(dst + 8), size, status.handle);

            // NOTE(review): on this failure path the tensor allocated above is not released;
            // no release function is visible from this block - confirm and free the handle.
            if (status.StatusCode != TFCode.Ok)
                return null;
        }
    }
    return new TFTensor(handle);
}

// Convenience function to factor out the setup of a new tensor from an array
internal static IntPtr SetupTensor(TFDataType dt, long[] dims, Array data, int count, int size)
{
Expand Down Expand Up @@ -591,6 +631,8 @@ public static Type TypeFromTensorType(TFDataType type)
{
case TFDataType.Float:
return typeof(float);
case TFDataType.Float_ref:
return typeof(float);
case TFDataType.Double:
return typeof(double);
case TFDataType.Int32:
Expand Down
8 changes: 7 additions & 1 deletion src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1666,7 +1666,13 @@ internal enum TFDataType : uint
/// <summary>
/// 64-bit unsigned integers
/// </summary>
UInt64 = 23
UInt64 = 23,

/// <summary>
/// Float reference type. It is used for defining the types of Variables.
/// Please see https://www.tensorflow.org/api_docs/python/tf/DType for more details.
/// </summary>
Float_ref = 101
}

/// <summary>
Expand Down
11 changes: 11 additions & 0 deletions src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ private static PrimitiveType Tf2MlNetTypeOrNull(TFDataType type)
{
case TFDataType.Float:
return NumberType.R4;
case TFDataType.Float_ref:
return NumberType.R4;
case TFDataType.Double:
return NumberType.R8;
case TFDataType.UInt16:
Expand All @@ -144,6 +146,12 @@ private static PrimitiveType Tf2MlNetTypeOrNull(TFDataType type)
return NumberType.U4;
case TFDataType.UInt64:
return NumberType.U8;
case TFDataType.Int16:
return NumberType.I2;
case TFDataType.Int32:
return NumberType.I4;
case TFDataType.Int64:
return NumberType.I8;
default:
return null;
}
Expand Down Expand Up @@ -336,6 +344,9 @@ internal static bool IsTypeSupported(TFDataType tfoutput)
case TFDataType.UInt16:
case TFDataType.UInt32:
case TFDataType.UInt64:
case TFDataType.Int16:
case TFDataType.Int32:
case TFDataType.Int64:
return true;
default:
return false;
Expand Down
Loading

0 comments on commit d2ed0ad

Please sign in to comment.