Skip to content

Commit 438df25

Browse files
committed
GPT-2: enable passing custom session config for training
1 parent 8c0f45f commit 438df25

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

GPT-2/Gpt2Trainer.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,11 @@ public Gpt2Trainer(DataSet dataset, Gpt2Encoder encoder, HParams hParams,
4747
public int SampleEvery { get; set; } = 100;
4848
public int SampleNum { get; set; } = 1;
4949

50-
public void Train(string checkpoint, string run, int? counter, CancellationToken cancellation) {
51-
new Session().UseSelf(session => {
50+
public void Train(string checkpoint, string run, int? counter, dynamic sessionConfig = null, CancellationToken cancellation = default) {
51+
Session sess = sessionConfig == null
52+
? Session.NewDyn(config: sessionConfig)
53+
: new Session();
54+
sess.UseSelf(session => {
5255
var context = tf.placeholder(tf.int32, new TensorShape(this.batchSize, null));
5356
var output = Gpt2Model.Model(this.hParams, input: context);
5457
Tensor labels = context[Range.All, Range.StartAt(1)];

GPT-2/TrainCommand.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using ManyConsole.CommandLineUtils;
99
using numpy;
1010
using Python.Runtime;
11+
using tensorflow.core.protobuf.config_pb2;
1112
using DataSet = System.Collections.Generic.List<numpy.ndarray>;
1213
class TrainCommand: ConsoleCommand {
1314
public override int Run(string[] remainingArguments) {
@@ -34,12 +35,15 @@ public override int Run(string[] remainingArguments) {
3435
var random = this.Seed == null ? new Random() : new Random(this.Seed.Value);
3536
var stop = new CancellationTokenSource();
3637
Console.CancelKeyPress += delegate { stop.Cancel(); };
38+
dynamic config = config_pb2.ConfigProto();
39+
config.gpu_options.allow_growth = true;
3740
new Gpt2Trainer(dataset, encoder, hParams, this.BatchSize, this.SampleLength, random) {
3841
SaveEvery = this.SaveEvery,
3942
SampleNum = this.SampleNum,
4043
SampleEvery = this.SampleEvery,
4144
}
4245
.Train(checkpoint, this.RunName,
46+
sessionConfig: config,
4347
counter: checkpoint == "fresh" ? 1 : (int?)null,
4448
cancellation: stop.Token);
4549

0 commit comments

Comments
 (0)