Skip to content

Commit cafae99

Browse files
committed
Other: Multi-game based AI training. (WIP)
1 parent 9e5cde8 commit cafae99

File tree

5 files changed

+126
-91
lines changed

5 files changed

+126
-91
lines changed

G33kShell.Desktop/Console/Screensavers/AI/AiGameCanvasBase.cs

Lines changed: 65 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,18 @@ namespace G33kShell.Desktop.Console.Screensavers.AI;
2323
/// </summary>
2424
public abstract class AiGameCanvasBase : ScreensaverBase
2525
{
26-
private int m_generation;
27-
private double m_savedRating;
2826
private const int InitialPopSize = 300;
2927
private const int MinPopSize = 150;
28+
private const int MaxGoatBrains = 5;
29+
30+
private readonly List<(double Rating, AiBrainBase Brain)> m_goatBrains = new List<(double Rating, AiBrainBase Brain)>(MaxGoatBrains);
31+
private int m_generation;
32+
private double m_savedRating;
3033
private int m_generationsSinceImprovement;
3134
private int m_currentPopSize = InitialPopSize;
3235
private Task m_trainingTask;
3336
private bool m_stopTraining;
34-
35-
protected AiGameBase[] m_games;
37+
private List<AiBrainBase> m_nextGenBrains;
3638

3739
protected int ArenaWidth { get; }
3840
protected int ArenaHeight { get; }
@@ -83,35 +85,70 @@ public override void StopScreensaver()
8385

8486
private void TrainAiImpl(Action<byte[]> saveBrainBytes)
8587
{
86-
m_games ??= Enumerable.Range(0, m_currentPopSize).Select(_ => CreateGameWithSeed(m_generation)).ToArray();
88+
m_nextGenBrains ??= Enumerable.Range(0, InitialPopSize).Select(_ => CreateGame().Brain).ToList();
8789

88-
Parallel.For(0, m_games.Length, i =>
90+
var games = new (double AverageRating, AiGameBase Game, AiBrainBase Brain)[m_nextGenBrains.Count];
91+
Parallel.For(0, games.Length, i =>
8992
{
90-
var game = m_games[i];
91-
while (!game.IsGameOver)
92-
game.Tick();
93+
// Play the base game.
94+
var baseGame = CreateGameWithSeed(m_generation);
95+
while (!baseGame.IsGameOver && !m_stopTraining)
96+
baseGame.Tick();
97+
98+
var totalRating = baseGame.Rating;
99+
if (baseGame.Rating > 0.0)
100+
{
101+
// Play several more games.
102+
var gameCount = 1;
103+
for (var trial = 0; trial < 4 && !m_stopTraining; trial++, gameCount++)
104+
{
105+
var game = CreateGameWithSeed(Random.Shared.Next());
106+
game.Brain = baseGame.Brain;
107+
while (!game.IsGameOver)
108+
game.Tick();
109+
totalRating += game.Rating;
110+
if (game.Rating <= 0.0001)
111+
break; // No score, no point in continuing.
112+
}
113+
114+
totalRating /= gameCount;
115+
}
116+
117+
games[i] = (totalRating, baseGame, baseGame.Brain);
93118
});
94119

95120
// Select the breeders.
96-
var orderedGames = m_games.OrderByDescending(o => o.Rating).ToArray();
97-
121+
var orderedGames = games.OrderByDescending(o => o.AverageRating).ToArray();
122+
var theBest = orderedGames[0];
123+
98124
// Report summary of results.
99125
m_generation++;
100-
var veryBest = orderedGames[0];
101-
var stats = $"Gen {m_generation}|Pop {m_currentPopSize}|Rating {veryBest.Rating:F1}|GOAT {m_savedRating:F1}";
102-
var extraStats = veryBest.ExtraGameStats().Select(o => $" {o.Name}: {o.Value}").ToArray().ToCsv().Trim();
126+
var stats = $"Gen {m_generation}|Pop {m_currentPopSize}|Rating {theBest.AverageRating:F1}|GOAT {m_savedRating:F1}";
127+
var extraStats = theBest.Game.ExtraGameStats().Select(o => $" {o.Name}: {o.Value}").ToArray().ToCsv().Trim();
103128
if (!string.IsNullOrEmpty(extraStats))
104129
stats += $"|{extraStats}";
105130
System.Console.WriteLine(stats);
131+
132+
// Remember the GOAT brains.
133+
var worstGoatRating = m_goatBrains.Count > 0 ? m_goatBrains.FastFindMin(o => o.Rating).Rating : 0.0;
134+
for (var i = 0; i < orderedGames.Length; i++)
135+
{
136+
if (orderedGames[i].AverageRating > worstGoatRating)
137+
m_goatBrains.Add((orderedGames[i].AverageRating, orderedGames[i].Brain));
138+
}
139+
140+
while (m_goatBrains.Count > MaxGoatBrains)
141+
{
142+
var toRemove = m_goatBrains.FastFindMin(o => o.Rating);
143+
m_goatBrains.Remove(toRemove);
144+
}
106145

107146
// Persist brain improvements.
108-
AiBrainBase goatBrain = null;
109-
if (veryBest.Rating > m_savedRating)
147+
if (theBest.AverageRating > m_savedRating)
110148
{
111-
m_savedRating = veryBest.Rating;
149+
m_savedRating = theBest.AverageRating;
112150
System.Console.WriteLine("Saved.");
113-
saveBrainBytes(veryBest.Brain.Save());
114-
goatBrain = veryBest.Brain.Clone();
151+
saveBrainBytes(theBest.Brain.Save());
115152

116153
m_generationsSinceImprovement = 0;
117154
}
@@ -124,39 +161,30 @@ private void TrainAiImpl(Action<byte[]> saveBrainBytes)
124161
{
125162
m_generationsSinceImprovement = 0;
126163
m_currentPopSize = InitialPopSize;
127-
System.Console.WriteLine("Stagnation detected Increasing population size.");
164+
System.Console.WriteLine("Stagnation detected - Increasing population size.");
128165
}
129166
}
130167

131168
// Build the brains for the next generation.
132-
var nextBrains = new List<AiBrainBase>(m_games.Length);
133-
134-
// The GOAT lives on.
135-
if (goatBrain != null)
136-
{
137-
nextBrains.Add(goatBrain.Clone());
138-
nextBrains.AddRange(Enumerable.Range(0, (int)(m_currentPopSize * 0.05)).Select(_ => goatBrain.Clone().Mutate(0.03)));
139-
}
169+
var nextBrains = new List<AiBrainBase>(games.Length);
170+
nextBrains.AddRange(m_goatBrains.Select(o => o.Brain.Clone()));
140171

141172
// Spawn 5% pure randoms.
142-
nextBrains.AddRange(Enumerable.Range(0, (int)(m_currentPopSize * 0.05)).Select(_ => veryBest.Brain.Clone().Randomize()));
173+
nextBrains.AddRange(Enumerable.Range(0, (int)(m_currentPopSize * 0.05)).Select(_ => CreateGameWithSeed(0).Brain.Clone().Randomize()));
143174

144175
// Elite get to be parents.
176+
var breeders = orderedGames.Select(o => (o.AverageRating, o.Brain)).ToList();
177+
breeders.AddRange(m_goatBrains);
145178
while (nextBrains.Count < m_currentPopSize)
146179
{
147-
var mumBrain = Random.Shared.RouletteSelection(m_games, o => o.Rating).Brain;
148-
var dadBrain = Random.Shared.RouletteSelection(m_games, o => o.Rating).Brain;
180+
var mumBrain = Random.Shared.RouletteSelection(breeders, o => o.AverageRating).Brain;
181+
var dadBrain = Random.Shared.RouletteSelection(breeders, o => o.AverageRating).Brain;
149182
var childBrain = mumBrain.Clone().CrossWith(dadBrain, 0.5).Mutate(0.05);
150183
nextBrains.Add(childBrain);
151184
}
152185

153186
// Make the next generation of games.
154-
m_games = nextBrains.Select(o =>
155-
{
156-
var newGame = CreateGameWithSeed(m_generation);
157-
newGame.Brain = o;
158-
return newGame;
159-
}).ToArray();
187+
m_nextGenBrains = nextBrains;
160188
}
161189

162190
private AiGameBase CreateGameWithSeed(int seed)

G33kShell.Desktop/Console/Screensavers/Asteroids/AsteroidsCanvas.cs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ namespace G33kShell.Desktop.Console.Screensavers.Asteroids;
2525
[UsedImplicitly]
2626
public class AsteroidsCanvas : AiGameCanvasBase
2727
{
28+
private Game m_game;
29+
2830
public AsteroidsCanvas(int screenWidth, int screenHeight) : base(screenWidth, screenHeight, 60)
2931
{
3032
Name = "asciiroids";
@@ -43,17 +45,17 @@ public override void UpdateFrame(ScreenData screen)
4345
[UsedImplicitly]
4446
private void PlayGame(ScreenData screen)
4547
{
46-
if (m_games == null)
48+
if (m_game == null)
4749
{
48-
m_games = [CreateGame().ResetGame()];
49-
m_games[0].LoadBrainData(Settings.Instance.AsteroidsBrain);
50+
m_game = (Game)CreateGame().ResetGame();
51+
m_game.LoadBrainData(Settings.Instance.AsteroidsBrain);
5052
}
5153

52-
m_games[0].Tick();
53-
DrawGame(screen, m_games[0]);
54+
m_game.Tick();
55+
DrawGame(screen, m_game);
5456

55-
if (((Game)m_games[0]).IsGameOver)
56-
m_games[0].ResetGame();
57+
if (m_game.IsGameOver)
58+
m_game.ResetGame();
5759
}
5860

5961
private void DrawGame(ScreenData screen, AiGameBase aiGame)

G33kShell.Desktop/Console/Screensavers/Asteroids/Game.cs

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@
1717

1818
namespace G33kShell.Desktop.Console.Screensavers.Asteroids;
1919

20-
[DebuggerDisplay("Rating = {Rating}, Score = {Score}, Lives = {m_lives}")]
20+
[DebuggerDisplay("Rating = {Rating}, Score = {Score}")]
2121
public class Game : AiGameBase
2222
{
2323
private int m_bulletsFired;
2424
private int m_gameTicks;
2525
private int m_leftTurns;
2626
private int m_rightTurns;
2727
private GameState m_gameState;
28-
private int m_lives;
2928
private int m_thrustTicks;
29+
private int m_ticksSinceScore;
3030

3131
/// <summary>
3232
/// Score used for public display.
@@ -42,7 +42,8 @@ public override double Rating
4242
if (m_thrustTicks == 0)
4343
return 0.0; // Penalize non-thrusters.
4444

45-
return Score * HitRatio * HitRatio * m_gameTicks * 0.001;
45+
var rating = 1.0 + Score * HitRatio * 0.001;
46+
return rating * rating;
4647
}
4748
}
4849

@@ -66,7 +67,7 @@ private double TurnEquality
6667
}
6768
}
6869

69-
public override bool IsGameOver => m_lives == 0 || m_gameTicks > 200_000;
70+
public override bool IsGameOver => m_ticksSinceScore >= 200_000 || Ship.Shield <= 0.02;
7071
public Ship Ship { get; private set; }
7172
public List<Asteroid> Asteroids { get; } = [];
7273
public List<Bullet> Bullets { get; } = [];
@@ -90,7 +91,7 @@ public override Game ResetGame()
9091
m_leftTurns = 0;
9192
m_rightTurns = 0;
9293
m_thrustTicks = 0;
93-
m_lives = 3;
94+
m_ticksSinceScore = 0;
9495

9596
EnsureMinimumAsteroidCount();
9697

@@ -115,6 +116,7 @@ public override void Tick()
115116
return;
116117

117118
m_gameTicks++;
119+
m_ticksSinceScore++;
118120

119121
// Spawn asteroids.
120122
EnsureMinimumAsteroidCount();
@@ -148,32 +150,38 @@ public override void Tick()
148150
Bullets.RemoveAll(o => o.IsExpired);
149151

150152
// Check for bullet/asteroid collisions.
151-
var bulletsToRemove = new List<Bullet>();
152-
for (var index = 0; index < Bullets.Count; index++)
153+
var hitDetected = false;
154+
if (Bullets.Count > 0)
153155
{
154-
var bullet = Bullets[index];
155-
Asteroid hitAsteroid = null;
156-
// ReSharper disable once ForCanBeConvertedToForeach
157-
// ReSharper disable once LoopCanBeConvertedToQuery
158-
for (var i = 0; i < Asteroids.Count; i++)
156+
var bulletsToRemove = new List<Bullet>();
157+
for (var index = 0; index < Bullets.Count; index++)
159158
{
160-
if (Asteroids[i].IsInvulnerable || !Asteroids[i].Contains(bullet.Position))
161-
continue;
162-
hitAsteroid = Asteroids[i];
163-
break;
164-
}
159+
var bullet = Bullets[index];
160+
Asteroid hitAsteroid = null;
161+
// ReSharper disable once ForCanBeConvertedToForeach
162+
// ReSharper disable once LoopCanBeConvertedToQuery
163+
for (var i = 0; i < Asteroids.Count; i++)
164+
{
165+
if (Asteroids[i].IsInvulnerable || !Asteroids[i].Contains(bullet.Position))
166+
continue;
167+
hitAsteroid = Asteroids[i];
168+
break;
169+
}
165170

166-
if (hitAsteroid == null)
167-
continue; // Bullet not hitting anything.
171+
if (hitAsteroid == null)
172+
continue; // Bullet not hitting anything.
168173

169-
// Bullet hit an asteroid.
170-
bulletsToRemove.Add(bullet);
171-
hitAsteroid.Explode(Asteroids);
172-
Score++;
173-
}
174+
// Bullet hit an asteroid.
175+
bulletsToRemove.Add(bullet);
176+
hitAsteroid.Explode(Asteroids);
177+
Score++;
178+
m_ticksSinceScore = 0;
179+
hitDetected = true;
180+
}
174181

175-
for (var i = 0; i < bulletsToRemove.Count; i++)
176-
Bullets.Remove(bulletsToRemove[i]);
182+
for (var i = 0; i < bulletsToRemove.Count; i++)
183+
Bullets.Remove(bulletsToRemove[i]);
184+
}
177185

178186
// Check for ship/asteroid collisions.
179187
const float shipRadius = 4.5f;
@@ -182,18 +190,12 @@ public override void Tick()
182190
var distance = Asteroids[i].DistanceTo(Ship.Position);
183191
if (distance < Asteroids[i].Radius + shipRadius)
184192
{
185-
var newShield = Ship.Shield - 0.08;
186-
if (newShield < 0.0)
193+
Ship.Shield = Math.Max(0.0, Ship.Shield - 0.08);
194+
if (Ship.Shield <= 0.001)
187195
{
188-
// Ship is dead - Reset asteroids and bullets.
189-
m_lives--;
190-
Asteroids.Clear();
191-
Bullets.Clear();
192-
Ship.Reset();
196+
// Ship is dead.
193197
return;
194198
}
195-
196-
Ship.Shield = newShield;
197199
break;
198200
}
199201
}

G33kShell.Desktop/Console/Screensavers/Pong/PongCanvas.cs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ namespace G33kShell.Desktop.Console.Screensavers.Pong;
2727
public class PongCanvas : AiGameCanvasBase
2828
{
2929
private readonly FIGletFont m_font;
30+
private Game m_game;
3031

3132
public PongCanvas(int screenWidth, int screenHeight) : base(screenWidth, screenHeight, 60)
3233
{
@@ -49,17 +50,17 @@ public override void UpdateFrame(ScreenData screen)
4950
[UsedImplicitly]
5051
private void PlayGame(ScreenData screen)
5152
{
52-
if (m_games == null)
53+
if (m_game == null)
5354
{
54-
m_games = [CreateGame().ResetGame()];
55-
m_games[0].LoadBrainData(Settings.Instance.PongBrain);
55+
m_game = (Game)CreateGame().ResetGame();
56+
m_game.LoadBrainData(Settings.Instance.PongBrain);
5657
}
5758

58-
DrawGame(screen, m_games[0]);
59+
DrawGame(screen, m_game);
5960

60-
m_games[0].Tick();
61-
if (((Game)m_games[0]).IsGameOver)
62-
m_games[0].ResetGame();
61+
m_game.Tick();
62+
if (m_game.IsGameOver)
63+
m_game.ResetGame();
6364
}
6465

6566
private void DrawGame(ScreenData screen, AiGameBase aiGame)

G33kShell.Desktop/Console/Screensavers/Snake/SnakeCanvas.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ namespace G33kShell.Desktop.Console.Screensavers.Snake;
2222
[UsedImplicitly]
2323
public class SnakeCanvas : AiGameCanvasBase
2424
{
25+
private Game m_game;
26+
2527
public SnakeCanvas(int screenWidth, int screenHeight) : base(screenWidth, screenHeight, 60)
2628
{
2729
Name = "snake";
@@ -40,15 +42,15 @@ public override void UpdateFrame(ScreenData screen)
4042
[UsedImplicitly]
4143
private void PlayGame(ScreenData screen)
4244
{
43-
if (m_games == null)
45+
if (m_game == null)
4446
{
45-
m_games = [new Game(ArenaWidth, ArenaHeight).ResetGame()];
46-
m_games[0].LoadBrainData(Settings.Instance.SnakeBrain);
47+
m_game = (Game)new Game(ArenaWidth, ArenaHeight).ResetGame();
48+
m_game.LoadBrainData(Settings.Instance.SnakeBrain);
4749
}
4850

49-
DrawGame(screen, m_games[0]);
51+
DrawGame(screen, m_game);
5052

51-
m_games[0].Tick();
53+
m_game.Tick();
5254
}
5355

5456
private static void DrawGame(ScreenData screen, AiGameBase aiGame)

0 commit comments

Comments
 (0)