Skip to content

Commit 03aa035

Browse files
authored
Added a model subclassing sample: ResNetBlock (#2)
also: use Gradient preview 6.2
1 parent f8d26f2 commit 03aa035

File tree

16 files changed

+140
-11
lines changed

16 files changed

+140
-11
lines changed

BasicMath/BasicMath.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
</PropertyGroup>
99

1010
<ItemGroup>
11-
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.1" />
11+
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.2" />
1212
</ItemGroup>
1313

1414
</Project>

CharRNN/CharRNN.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
<ItemGroup>
1717
<PackageReference Include="CommandLineParser" Version="2.3.0" />
18-
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.1" />
18+
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.2" />
1919
<PackageReference Include="Newtonsoft.Json" Version="12.0.1" />
2020
</ItemGroup>
2121

FSharp/BasicMathF/BasicMathF.fsproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
<ItemGroup>
1313
<PackageReference Include="FSharp.Interop.Dynamic" Version="4.0.3.130" />
14-
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.1" />
14+
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.2" />
1515
</ItemGroup>
1616

1717
</Project>

FSharp/FashionMnistF/FashionMnist.fs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,6 @@ let main argv =
4242
let accuracy = Core.Operators.float (Dyn.getIndex evalResult [1] : numpy.float64)
4343
printfn "Test accuracy: %f" accuracy
4444

45+
model.summary()
46+
4547
0 // return an integer exit code

FSharp/FashionMnistF/FashionMnistF.fsproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@
1111

1212
<ItemGroup>
1313
<PackageReference Include="FSharp.Interop.Dynamic" Version="4.0.3.130" />
14-
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.1" />
14+
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.2" />
1515
</ItemGroup>
1616
</Project>

FashionMnistClassification/FashionMnistClassification.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
</PropertyGroup>
1010

1111
<ItemGroup>
12-
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.1" />
12+
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.2" />
1313
</ItemGroup>
1414

1515
</Project>

GPT-2/GPT-2.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
<ItemGroup>
2323
<PackageReference Include="CsvHelper" Version="12.1.2" />
24-
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.1" />
24+
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.2" />
2525
<PackageReference Include="ManyConsole.CommandLineUtils" Version="1.0.3-alpha" />
2626
<PackageReference Include="morelinq" Version="3.1.0" />
2727
<PackageReference Include="Newtonsoft.Json" Version="12.0.1" />

Gradient-Samples.sln

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ Project("{778DAE3C-4631-46EA-AA77-85C1314464D9}") = "BasicMathVB", "VB\BasicMath
3131
EndProject
3232
Project("{778DAE3C-4631-46EA-AA77-85C1314464D9}") = "FashionMnistVB", "VB\FashionMnistVB\FashionMnistVB.vbproj", "{1077A08A-AB90-4286-BB0C-0F2E9D620595}"
3333
EndProject
34-
Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "FashionMnistF", "FSharp\FashionMnistF\FashionMnistF.fsproj", "{670A666C-1A56-40B0-874D-72673623112E}"
34+
Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "FashionMnistF", "FSharp\FashionMnistF\FashionMnistF.fsproj", "{670A666C-1A56-40B0-874D-72673623112E}"
35+
EndProject
36+
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ResNetBlock", "ResNetBlock\ResNetBlock.csproj", "{C4832B9C-793B-4D6A-B19F-B8FCAE16AACE}"
3537
EndProject
3638
Global
3739
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -79,6 +81,10 @@ Global
7981
{670A666C-1A56-40B0-874D-72673623112E}.Debug|Any CPU.Build.0 = Debug|Any CPU
8082
{670A666C-1A56-40B0-874D-72673623112E}.Release|Any CPU.ActiveCfg = Release|Any CPU
8183
{670A666C-1A56-40B0-874D-72673623112E}.Release|Any CPU.Build.0 = Release|Any CPU
84+
{C4832B9C-793B-4D6A-B19F-B8FCAE16AACE}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
85+
{C4832B9C-793B-4D6A-B19F-B8FCAE16AACE}.Debug|Any CPU.Build.0 = Debug|Any CPU
86+
{C4832B9C-793B-4D6A-B19F-B8FCAE16AACE}.Release|Any CPU.ActiveCfg = Release|Any CPU
87+
{C4832B9C-793B-4D6A-B19F-B8FCAE16AACE}.Release|Any CPU.Build.0 = Release|Any CPU
8288
EndGlobalSection
8389
GlobalSection(SolutionProperties) = preSolution
8490
HideSolutionNode = FALSE

LinearSVM/LinearSVM.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
</PropertyGroup>
88

99
<ItemGroup>
10-
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.1" />
10+
<PackageReference Include="Gradient" Version="0.1.10-tech-preview6.2" />
1111
<PackageReference Include="ManyConsole.CommandLineUtils" Version="1.0.3-alpha" />
1212
</ItemGroup>
1313

ResNetBlock/ResNetBlock.cs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
namespace Gradient.Samples {
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using Gradient.ManualWrappers;
5+
using SharPy.Runtime;
6+
using tensorflow;
7+
using tensorflow.keras;
8+
using tensorflow.keras.layers;
9+
10+
class ResNetBlock: Model {
11+
const int PartCount = 3;
12+
readonly PythonList<Conv2D> convs = new PythonList<Conv2D>();
13+
readonly PythonList<BatchNormalization> batchNorms = new PythonList<BatchNormalization>();
14+
public ResNetBlock(int kernelSize, int[] filters) {
15+
for (int part = 0; part < PartCount; part++) {
16+
this.convs.Add(this.Track(part == 1
17+
? Conv2D.NewDyn(filters[part], kernel_size: kernelSize, padding: "same")
18+
: Conv2D.NewDyn(filters[part], kernel_size: (1, 1))));
19+
this.batchNorms.Add(this.Track(new BatchNormalization()));
20+
}
21+
}
22+
23+
public override dynamic call(IEnumerable<IGraphNodeBase> inputs, ImplicitContainer<IGraphNodeBase> training, IGraphNodeBase mask) {
24+
return this.callImpl((Tensor)inputs.Single(), training);
25+
}
26+
27+
public override object call(object inputs, bool training, IGraphNodeBase mask = null) {
28+
return this.callImpl((Tensor)inputs, training);
29+
}
30+
31+
public override dynamic call(object inputs, ImplicitContainer<IGraphNodeBase> training = null, IEnumerable<IGraphNodeBase> mask = null) {
32+
return this.callImpl((Tensor)inputs, training?.Value);
33+
}
34+
35+
object callImpl(IGraphNodeBase inputs, dynamic training) {
36+
IGraphNodeBase result = inputs;
37+
38+
var batchNormExtraArgs = new PythonDict<string, object>();
39+
if (training != null)
40+
batchNormExtraArgs["training"] = training;
41+
42+
for (int part = 0; part < PartCount; part++) {
43+
result = this.convs[part].apply(result);
44+
result = this.batchNorms[part].apply(result, kwargs: batchNormExtraArgs);
45+
if (part + 1 != PartCount)
46+
result = tf.nn.relu(result);
47+
}
48+
49+
result += (Tensor)result + inputs;
50+
51+
return tf.nn.relu(result);
52+
}
53+
}
54+
}

0 commit comments

Comments
 (0)