Skip to content

Commit 6f3a6d4

Browse files
committed
Feature: Simplify brain creation and input layer population.
1 parent cafae99 commit 6f3a6d4

File tree

15 files changed

+74
-86
lines changed

15 files changed

+74
-86
lines changed

G33kShell.Desktop/Console/Screensavers/AI/AiBrainBase.cs

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@ namespace G33kShell.Desktop.Console.Screensavers.AI;
1717
public abstract class AiBrainBase
1818
{
1919
[JsonProperty] private NeuralNetwork m_qNet;
20+
private readonly double[] m_inputVector;
2021

21-
public int InputSize { get; private set; }
22+
public int InputSize { get; }
2223
public int[] HiddenLayers { get; private set; }
2324
public int OutputSize { get; private set; }
2425

@@ -28,14 +29,26 @@ protected AiBrainBase(int inputSize, int[] hiddenLayers, int outputSize)
2829
HiddenLayers = hiddenLayers;
2930
OutputSize = outputSize;
3031
m_qNet = new NeuralNetwork(inputSize, hiddenLayers, outputSize, learningRate: 0.05);
32+
m_inputVector = new double[inputSize];
3133
}
3234

3335
protected AiBrainBase(AiBrainBase toCopy) =>
3436
m_qNet = toCopy.m_qNet.Clone();
3537

3638
protected int ChooseHighestOutput(IAiGameState state) => ArgMax(GetOutputs(state));
3739

38-
protected double[] GetOutputs(IAiGameState state) => m_qNet.Predict(state.ToInputVector());
40+
protected double[] GetOutputs(IAiGameState state)
41+
{
42+
#if DEBUG
43+
Array.Fill(m_inputVector, 0xDE);
44+
#endif
45+
state.FillInputVector(m_inputVector);
46+
#if DEBUG
47+
if (m_inputVector.Contains(0xDE))
48+
throw new Exception("Input vector contains uninitialized data.");
49+
#endif
50+
return m_qNet.Predict(m_inputVector);
51+
}
3952

4053
/// <summary>
4154
/// Finds the index of the maximum value in the array.
@@ -57,11 +70,10 @@ private static int ArgMax(double[] values)
5770

5871
public byte[] Save() => JsonConvert.SerializeObject(this).Compress();
5972

60-
public void Load(byte[] brainBytes) => JsonConvert.PopulateObject(brainBytes.DecompressToString(), this);
61-
62-
public AiBrainBase Randomize()
73+
public AiBrainBase Load(byte[] brainBytes)
6374
{
64-
m_qNet.Randomize();
75+
if (brainBytes != null)
76+
JsonConvert.PopulateObject(brainBytes.DecompressToString(), this);
6577
return this;
6678
}
6779

G33kShell.Desktop/Console/Screensavers/AI/AiGameBase.cs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// THE SOFTWARE IS PROVIDED AS IS, WITHOUT WARRANTY OF ANY KIND.
1111
using System;
1212
using System.Collections.Generic;
13+
using JetBrains.Annotations;
1314

1415
namespace G33kShell.Desktop.Console.Screensavers.AI;
1516

@@ -38,11 +39,11 @@ public abstract class AiGameBase
3839

3940
public abstract IEnumerable<(string Name, string Value)> ExtraGameStats();
4041

41-
protected AiGameBase(int arenaWidth, int arenaHeight, AiBrainBase brain)
42+
protected AiGameBase(int arenaWidth, int arenaHeight, [NotNull] AiBrainBase brain)
4243
{
4344
ArenaWidth = arenaWidth;
4445
ArenaHeight = arenaHeight;
45-
Brain = brain;
46+
Brain = brain ?? throw new ArgumentNullException(nameof(brain));
4647
}
4748

4849
/// <summary>
@@ -54,10 +55,4 @@ protected AiGameBase(int arenaWidth, int arenaHeight, AiBrainBase brain)
5455
/// Reset all game state back to the same initial conditions.
5556
/// </summary>
5657
public abstract AiGameBase ResetGame();
57-
58-
public void LoadBrainData(byte[] brainBytes)
59-
{
60-
if (brainBytes != null && brainBytes.Length > 0)
61-
Brain.Load(brainBytes);
62-
}
6358
}

G33kShell.Desktop/Console/Screensavers/AI/AiGameCanvasBase.cs

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ protected void TrainAi(ScreenData screen, Action<byte[]> saveBrainBytes)
5959
}
6060

6161
System.Console.WriteLine("Starting training...");
62-
var brain = CreateGame().Brain;
62+
var brain = CreateBrain();
6363
System.Console.WriteLine($"Brain layers: {brain.InputSize} : {brain.HiddenLayers.ToCsv(' ')} : {brain.OutputSize}");
6464

6565
m_stopTraining = false;
@@ -85,7 +85,7 @@ public override void StopScreensaver()
8585

8686
private void TrainAiImpl(Action<byte[]> saveBrainBytes)
8787
{
88-
m_nextGenBrains ??= Enumerable.Range(0, InitialPopSize).Select(_ => CreateGame().Brain).ToList();
88+
m_nextGenBrains ??= Enumerable.Range(0, InitialPopSize).Select(_ => CreateBrain()).ToList();
8989

9090
var games = new (double AverageRating, AiGameBase Game, AiBrainBase Brain)[m_nextGenBrains.Count];
9191
Parallel.For(0, games.Length, i =>
@@ -102,13 +102,10 @@ private void TrainAiImpl(Action<byte[]> saveBrainBytes)
102102
var gameCount = 1;
103103
for (var trial = 0; trial < 4 && !m_stopTraining; trial++, gameCount++)
104104
{
105-
var game = CreateGameWithSeed(Random.Shared.Next());
106-
game.Brain = baseGame.Brain;
105+
var game = CreateGameWithSeed(Random.Shared.Next(), baseGame.Brain);
107106
while (!game.IsGameOver)
108107
game.Tick();
109108
totalRating += game.Rating;
110-
if (game.Rating <= 0.0001)
111-
break; // No score, no point in continuing.
112109
}
113110

114111
totalRating /= gameCount;
@@ -170,7 +167,7 @@ private void TrainAiImpl(Action<byte[]> saveBrainBytes)
170167
nextBrains.AddRange(m_goatBrains.Select(o => o.Brain.Clone()));
171168

172169
// Spawn 5% pure randoms.
173-
nextBrains.AddRange(Enumerable.Range(0, (int)(m_currentPopSize * 0.05)).Select(_ => CreateGameWithSeed(0).Brain.Clone().Randomize()));
170+
nextBrains.AddRange(Enumerable.Range(0, (int)(m_currentPopSize * 0.05)).Select(_ => CreateBrain()));
174171

175172
// Elite get to be parents.
176173
var breeders = orderedGames.Select(o => (o.AverageRating, o.Brain)).ToList();
@@ -187,13 +184,14 @@ private void TrainAiImpl(Action<byte[]> saveBrainBytes)
187184
m_nextGenBrains = nextBrains;
188185
}
189186

190-
private AiGameBase CreateGameWithSeed(int seed)
187+
private AiGameBase CreateGameWithSeed(int seed, AiBrainBase brain = null)
191188
{
192-
var game = CreateGame();
189+
var game = CreateGame(brain ?? CreateBrain());
193190
game.GameRand = new Random(seed);
194191
game.ResetGame();
195192
return game;
196193
}
197194

198-
protected abstract AiGameBase CreateGame();
195+
protected abstract AiGameBase CreateGame(AiBrainBase brain);
196+
protected abstract AiBrainBase CreateBrain();
199197
}

G33kShell.Desktop/Console/Screensavers/AI/IAiGameState.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@ namespace G33kShell.Desktop.Console.Screensavers.AI;
1515
/// </summary>
1616
public interface IAiGameState
1717
{
18-
double[] ToInputVector();
18+
void FillInputVector(double[] inputVector);
1919
}

G33kShell.Desktop/Console/Screensavers/Asteroids/AsteroidsCanvas.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ private void PlayGame(ScreenData screen)
4747
{
4848
if (m_game == null)
4949
{
50-
m_game = (Game)CreateGame().ResetGame();
51-
m_game.LoadBrainData(Settings.Instance.AsteroidsBrain);
50+
var brain = new Brain().Load(Settings.Instance.AsteroidsBrain);
51+
m_game = (Game)CreateGame(brain).ResetGame();
5252
}
5353

5454
m_game.Tick();
@@ -105,6 +105,6 @@ private void DrawGame(ScreenData screen, AiGameBase aiGame)
105105
screen.PrintAt(2, 0, $"Score: {game.Score.ToString().PadLeft(5, '0')} Shield: {game.Ship.Shield.ToProgressBar(73)}");
106106
}
107107

108-
protected override AiGameBase CreateGame() =>
109-
new Game(ArenaWidth, ArenaHeight * 2);
108+
protected override AiGameBase CreateGame(AiBrainBase brain) => new Game(ArenaWidth, ArenaHeight * 2, (Brain)brain);
109+
protected override AiBrainBase CreateBrain() => new Brain();
110110
}

G33kShell.Desktop/Console/Screensavers/Asteroids/Brain.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,7 @@ namespace G33kShell.Desktop.Console.Screensavers.Asteroids;
1515

1616
public class Brain : AiBrainBase
1717
{
18-
public const int BrainInputCount = 6;
19-
20-
public Brain() : base(BrainInputCount, [16, 8], 4)
18+
public Brain() : base(6, [16, 8], 4)
2119
{
2220
}
2321

G33kShell.Desktop/Console/Screensavers/Asteroids/Game.cs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,10 @@ public override double Rating
3737
{
3838
get
3939
{
40-
if (TurnEquality < 0.3)
41-
return 0.0; // Penalize wonky-turners.
4240
if (m_thrustTicks == 0)
4341
return 0.0; // Penalize non-thrusters.
4442

45-
var rating = 1.0 + Score * HitRatio * 0.001;
43+
var rating = 1.0 + Score * Score * HitRatio * Math.Min(m_gameTicks, 20 * 60) * TurnEquality * 0.00001;
4644
return rating * rating;
4745
}
4846
}
@@ -74,7 +72,7 @@ private double TurnEquality
7472

7573
private double HitRatio => m_bulletsFired == 0 ? 0.0 : (double)Score / m_bulletsFired;
7674

77-
public Game(int arenaWidth, int arenaHeight) : base(arenaWidth, arenaHeight, new Brain())
75+
public Game(int arenaWidth, int arenaHeight, Brain brain) : base(arenaWidth, arenaHeight, brain)
7876
{
7977
}
8078

@@ -150,7 +148,6 @@ public override void Tick()
150148
Bullets.RemoveAll(o => o.IsExpired);
151149

152150
// Check for bullet/asteroid collisions.
153-
var hitDetected = false;
154151
if (Bullets.Count > 0)
155152
{
156153
var bulletsToRemove = new List<Bullet>();
@@ -176,7 +173,6 @@ public override void Tick()
176173
hitAsteroid.Explode(Asteroids);
177174
Score++;
178175
m_ticksSinceScore = 0;
179-
hitDetected = true;
180176
}
181177

182178
for (var i = 0; i < bulletsToRemove.Count; i++)

G33kShell.Desktop/Console/Screensavers/Asteroids/GameState.cs

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ namespace G33kShell.Desktop.Console.Screensavers.Asteroids;
2121
/// </summary>
2222
public class GameState : IAiGameState
2323
{
24-
private readonly double[] m_inputVector = new double[Brain.BrainInputCount];
2524
private readonly Ship m_ship;
2625
private readonly List<Asteroid> m_asteroids;
2726
private readonly float m_arenaDiagonal;
@@ -34,32 +33,30 @@ public GameState(Ship ship, List<Asteroid> asteroids, int arenaWidth, int arenaH
3433
m_arenaDiagonal = new Vector2(arenaWidth, arenaHeight).Length();
3534
}
3635

37-
public double[] ToInputVector()
36+
public void FillInputVector(double[] inputVector)
3837
{
3938
// Bias.
40-
m_inputVector[0] = 1.0;
39+
inputVector[0] = 1.0;
4140

4241
// Find nearest asteroid.
4342
var asteroid = m_asteroids.Count == 0 ? null : m_asteroids.FastFindMin(o => Vector2.DistanceSquared(o.Position, m_ship.Position));
4443
if (asteroid != null)
4544
{
4645
var relativePos = asteroid.Position - m_ship.Position;
4746
var angleToAsteroid = Vector2.Dot(Vector2.Normalize(relativePos), m_ship.Theta.ToDirection()).Clamp(-1.0f, 1.0f);
48-
m_inputVector[1] = angleToAsteroid;
49-
m_inputVector[2] = 1.0 - relativePos.Length() / m_arenaDiagonal;
47+
inputVector[1] = angleToAsteroid;
48+
inputVector[2] = 1.0 - relativePos.Length() / m_arenaDiagonal;
5049
}
5150
else
5251
{
53-
m_inputVector[1] = 0.0;
54-
m_inputVector[2] = 0.0;
52+
inputVector[1] = 0.0;
53+
inputVector[2] = 0.0;
5554
}
5655

5756
var relativeVelocity = asteroid?.Velocity ?? Vector2.Zero - m_ship.Velocity;
5857
relativeVelocity = relativeVelocity.LengthSquared() > 0.0f ? Vector2.Normalize(relativeVelocity) : Vector2.Zero;
59-
m_inputVector[3] = relativeVelocity.X.Clamp(-1.0f, 1.0f);
60-
m_inputVector[4] = relativeVelocity.Y.Clamp(-1.0f, 1.0f);
61-
m_inputVector[5] = (m_ship.Theta / Math.Tau).Clamp(-1.0f, 1.0f);
62-
63-
return m_inputVector;
58+
inputVector[3] = relativeVelocity.X.Clamp(-1.0f, 1.0f);
59+
inputVector[4] = relativeVelocity.Y.Clamp(-1.0f, 1.0f);
60+
inputVector[5] = (m_ship.Theta / Math.Tau).Clamp(-1.0f, 1.0f);
6461
}
6562
}

G33kShell.Desktop/Console/Screensavers/Pong/Game.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public override double Rating
5858
yield return ("Rallies", m_rallies.ToString());
5959
}
6060

61-
public Game(int arenaWidth, int arenaHeight) : base(arenaWidth, arenaHeight, new Brain())
61+
public Game(int arenaWidth, int arenaHeight, Brain brain) : base(arenaWidth, arenaHeight, brain)
6262
{
6363
}
6464

G33kShell.Desktop/Console/Screensavers/Pong/GameState.cs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,8 @@ public GameState(Vector2[] bats, Vector2 ballPosition, Vector2 ballVelocity, int
3434
m_arenaHeight = arenaHeight;
3535
}
3636

37-
public double[] ToInputVector()
37+
public void FillInputVector(double[] inputVector)
3838
{
39-
var inputVector = new double[Brain.BrainInputCount];
40-
4139
// Encode bat positions.
4240
inputVector[0] = m_bats[0].Y / m_arenaHeight * 2.0f - 1.0f;
4341
inputVector[1] = m_bats[1].Y / m_arenaHeight * 2.0f - 1.0f;
@@ -50,7 +48,5 @@ public double[] ToInputVector()
5048
inputVector[5] = m_ballPosition.Y / m_arenaHeight * 2.0f - 1.0f;
5149
inputVector[6] = m_ballVelocity.X.Clamp(-1.0f, 1.0f);
5250
inputVector[7] = m_ballVelocity.Y.Clamp(-1.0f, 1.0f);
53-
54-
return inputVector;
5551
}
5652
}

0 commit comments

Comments
 (0)