Skip to content

Commit d3075bd

Browse files
committed
2 parents babfa3a + 2a06ffa commit d3075bd

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

GPT-2/Gpt2Trainer.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,11 @@ public Gpt2Trainer(DataSet dataset, Gpt2Encoder encoder, HParams hParams,
4747
public int SampleEvery { get; set; } = 100;
4848
public int SampleNum { get; set; } = 1;
4949

50-
public void Train(string checkpoint, string run, int? counter, CancellationToken cancellation) {
51-
new Session().UseSelf(session => {
50+
public void Train(string checkpoint, string run, int? counter, dynamic sessionConfig = null, CancellationToken cancellation = default) {
51+
// Honor the caller-supplied session config when present; otherwise use a default session.
Session sess = sessionConfig != null
52+
? Session.NewDyn(config: sessionConfig)
53+
: new Session();
54+
sess.UseSelf(session => {
5255
var context = tf.placeholder(tf.int32, new TensorShape(this.batchSize, null));
5356
var output = Gpt2Model.Model(this.hParams, input: context);
5457
Tensor labels = context[Range.All, Range.StartAt(1)];

GPT-2/TrainCommand.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using ManyConsole.CommandLineUtils;
99
using numpy;
1010
using Python.Runtime;
11+
using tensorflow.core.protobuf.config_pb2;
1112
using DataSet = System.Collections.Generic.List<numpy.ndarray>;
1213
class TrainCommand: ConsoleCommand {
1314
public override int Run(string[] remainingArguments) {
@@ -34,12 +35,15 @@ public override int Run(string[] remainingArguments) {
3435
var random = this.Seed == null ? new Random() : new Random(this.Seed.Value);
3536
var stop = new CancellationTokenSource();
3637
Console.CancelKeyPress += delegate { stop.Cancel(); };
38+
dynamic config = config_pb2.ConfigProto();
39+
config.gpu_options.allow_growth = true;
3740
new Gpt2Trainer(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random) {
3841
SaveEvery = this.SaveEvery,
3942
SampleNum = this.SampleNum,
4043
SampleEvery = this.SampleEvery,
4144
}
4245
.Train(checkpoint, this.RunName,
46+
sessionConfig: config,
4347
counter: checkpoint == "fresh" ? 1 : (int?)null,
4448
cancellation: stop.Token);
4549

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# Gradient-Samples
22
Samples for [Gradient](https://losttech.software/gradient.html), a TensorFlow binding for .NET
33

4-
**[Billion Songs](https://github.com/losttech/BillionSongs)** a repository with
4+
**[Billion Songs](https://github.com/losttech/BillionSongs)** - a separate repository with a
55
deep learning-powered song lyrics generator in an ASP.NET Core web site.
66

7+
See also:
8+
[Writing billion songs with C# and Deep Learning](https://habr.com/post/453232/)
9+
710
**BasicMath** - creates two constant tensors and performs simple algebraic operations on them
811

912
**CharRNN** - generates semi-sensical text in the style of input. For example (Shakespeare):

0 commit comments

Comments
 (0)