Skip to content

Commit ca23ba7

Browse files
authored
Add linearStateRewardPredictor and fix names (#302)
Add linearStateRewardPredictor and fix names
1 parent a17b4e6 commit ca23ba7

File tree

18 files changed

+162
-40
lines changed

18 files changed

+162
-40
lines changed

docs/jlearch/jlearch-architecture.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ classDiagram
3737
UtBotSymbolicEngine *-- InterproceduralUnitGraph
3838
3939
class Predictors
40-
class NNStateRewardPredictor
40+
class StateRewardPredictor
4141
class NNRewardGuidedSelector
4242
4343
@@ -50,12 +50,13 @@ classDiagram
5050
5151
UtBotSymbolicEngine *-- BasePathSelector
5252
53-
Predictors o-- NNStateRewardPredictor
53+
Predictors o-- StateRewardPredictor
5454
NNRewardGuidedSelector ..> Predictors
5555
NNRewardGuidedSelector *-- FeatureExtractor
5656
57-
NNStateRewardPredictorSmile --|> NNStateRewardPredictor
58-
NNStateRewardPredictorTorch --|> NNStateRewardPredictor
57+
NNStateRewardPredictorSmile --|> StateRewardPredictor
58+
StateRewardPredictorTorch --|> StateRewardPredictor
59+
LinearStateRewardPredictor --|> StateRewardPredictor
5960
6061
NNStateRewardGuidedSelectorWithRecalculationWeight --|> NNRewardGuidedSelector
6162
NNStateRewardGuidedSelectorWithoutRecalculationWeight --|> NNRewardGuidedSelector
@@ -129,12 +130,13 @@ For creating `FeatureExtractor`, it uses `FeatureExtractorFactory` from `EngineA
129130
It is interface in framework-module, that allows to use implementation from analytics module.
130131
* `extractFeatures(state: ExecutionState)` - create features list for state and store it in `state.features`. Now we extract all features, which were described in [paper](https://files.sri.inf.ethz.ch/website/papers/ccs21-learch.pdf). In feature, we can extend the feature list by other features, for example, NeuroSMT.
131132

132-
# NNStateRewardPredictor
133+
# StateRewardPredictor
133134

134-
Interface for reward predictors. Now it has two implementations in `analytics` module:
135+
Interface for reward predictors. Now it has three implementations in `analytics` module:
135136

136137
* `NNStateRewardPredictorSmile`: it uses our own format to store feedforward neural network, and it uses `Smile` library to do multiplication of matrix.
137138
* `NNStateRewardPredictorTorch`: it assumed that a model is any type of model in `pt` format. It uses the `Deep Java library` to use such models.
139+
* `LinearStateRewardPredictor`: it uses our own format to store weights vector: line of doubles, separated by comma with bias as last weight.
138140

139141
It should be created at the beginning of work and stored at `Predictors` class to be used in `NNRewardGuidedSelector` from the `framework` module.
140142

utbot-analytics/src/main/kotlin/org/utbot/predictors/FeedForwardNetwork.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.utbot.predictors
22

3+
import org.utbot.predictors.util.ModelBuildingException
34
import smile.math.matrix.Matrix
45
import kotlin.math.max
56

@@ -26,7 +27,7 @@ internal fun buildModel(nnJson: NNJson): FeedForwardNetwork {
2627
operations.add {
2728
when (nnJson.activationLayers[i]) {
2829
ActivationFunctions.ReLU -> reLU(it)
29-
else -> error("Unsupported activation")
30+
else -> throw ModelBuildingException("Unsupported activation")
3031
}
3132
}
3233
}
Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package org.utbot.predictors
22

3+
import org.utbot.analytics.StateRewardPredictor
34
import mu.KotlinLogging
4-
import org.utbot.analytics.UtBotAbstractPredictor
55
import org.utbot.framework.PathSelectorType
66
import org.utbot.framework.UtSettings
7+
import org.utbot.predictors.util.PredictorLoadingException
8+
import org.utbot.predictors.util.WeightsLoadingException
9+
import org.utbot.predictors.util.splitByCommaIntoDoubleArray
10+
import smile.math.MathEx.dot
711
import smile.math.matrix.Matrix
812
import java.io.File
913

@@ -16,32 +20,39 @@ private val logger = KotlinLogging.logger {}
1620
*/
1721
private fun loadWeights(path: String): Matrix {
1822
val weightsFile = File("${UtSettings.rewardModelPath}/${path}")
23+
lateinit var weightsArray: DoubleArray
1924

20-
if (!weightsFile.exists()) {
21-
error("There is no file with weights with path: ${weightsFile.absolutePath}")
22-
}
25+
try {
26+
if (!weightsFile.exists()) {
27+
error("There is no file with weights with path: ${weightsFile.absolutePath}")
28+
}
2329

24-
val weightsArray = weightsFile.readText().splitByCommaIntoDoubleArray()
30+
weightsArray = weightsFile.readText().splitByCommaIntoDoubleArray()
31+
} catch (e: Exception) {
32+
throw WeightsLoadingException(e)
33+
}
2534

2635
return Matrix(weightsArray)
2736
}
2837

29-
class LinearStateRewardPredictor(weightsPath: String = DEFAULT_WEIGHT_PATH) :
30-
UtBotAbstractPredictor<List<List<Double>>, List<Double>> {
38+
class LinearStateRewardPredictor(weightsPath: String = DEFAULT_WEIGHT_PATH, scalerPath: String = DEFAULT_SCALER_PATH) :
39+
StateRewardPredictor {
3140
private lateinit var weights: Matrix
41+
private lateinit var scaler: StandardScaler
3242

3343
init {
3444
try {
3545
weights = loadWeights(weightsPath)
36-
} catch (e: Exception) {
46+
scaler = loadScaler(scalerPath)
47+
} catch (e: PredictorLoadingException) {
3748
logger.info(e) {
3849
"Error while initialization of LinearStateRewardPredictor. Changing pathSelectorType on INHERITORS_SELECTOR"
3950
}
4051
UtSettings.pathSelectorType = PathSelectorType.INHERITORS_SELECTOR
4152
}
4253
}
4354

44-
override fun predict(input: List<List<Double>>): List<Double> {
55+
fun predict(input: List<List<Double>>): List<Double> {
4556
// add 1 to each feature vector
4657
val matrixValues = input
4758
.map { (it + 1.0).toDoubleArray() }
@@ -51,4 +62,11 @@ class LinearStateRewardPredictor(weightsPath: String = DEFAULT_WEIGHT_PATH) :
5162

5263
return X.mm(weights).col(0).toList()
5364
}
65+
66+
override fun predict(input: List<Double>): Double {
67+
var inputArray = Matrix(input.toDoubleArray()).sub(scaler.mean).div(scaler.variance).col(0)
68+
inputArray += 1.0
69+
70+
return dot(inputArray, weights.col(0))
71+
}
5472
}

utbot-analytics/src/main/kotlin/org/utbot/predictors/NNJson.kt

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package org.utbot.predictors
22

33
import com.google.gson.Gson
44
import org.utbot.framework.UtSettings
5+
import org.utbot.predictors.util.ModelLoadingException
56
import java.io.FileReader
67
import java.nio.file.Paths
78

@@ -33,10 +34,16 @@ data class NNJson(
3334

3435
internal fun loadModel(path: String): NNJson {
3536
val modelFile = Paths.get(UtSettings.rewardModelPath, path).toFile()
36-
val nnJson: NNJson =
37-
Gson().fromJson(FileReader(modelFile), NNJson::class.java) ?: run {
38-
error("Empty model")
39-
}
37+
lateinit var nnJson: NNJson
38+
39+
try {
40+
nnJson =
41+
Gson().fromJson(FileReader(modelFile), NNJson::class.java) ?: run {
42+
error("Empty model")
43+
}
44+
} catch (e: Exception) {
45+
throw ModelLoadingException(e)
46+
}
4047

4148
return nnJson
4249
}

utbot-analytics/src/main/kotlin/org/utbot/predictors/NNStateRewardPredictor.kt

Lines changed: 0 additions & 8 deletions
This file was deleted.

utbot-analytics/src/main/kotlin/org/utbot/predictors/NNStateRewardPredictorBase.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
package org.utbot.predictors
22

33
import mu.KotlinLogging
4+
import org.utbot.analytics.StateRewardPredictor
45
import org.utbot.framework.PathSelectorType
56
import org.utbot.framework.UtSettings
7+
import org.utbot.predictors.util.PredictorLoadingException
68
import smile.math.matrix.Matrix
79

810
private const val DEFAULT_MODEL_PATH = "nn.json"
9-
private const val DEFAULT_SCALER_PATH = "scaler.txt"
1011

1112
private val logger = KotlinLogging.logger {}
1213

1314
private fun getModel(path: String) = buildModel(loadModel(path))
1415

1516
class NNStateRewardPredictorBase(modelPath: String = DEFAULT_MODEL_PATH, scalerPath: String = DEFAULT_SCALER_PATH) :
16-
NNStateRewardPredictor {
17+
StateRewardPredictor {
1718
private lateinit var nn: FeedForwardNetwork
1819
private lateinit var scaler: StandardScaler
1920

2021
init {
2122
try {
2223
nn = getModel(modelPath)
2324
scaler = loadScaler(scalerPath)
24-
} catch (e: Exception) {
25+
} catch (e: PredictorLoadingException) {
2526
logger.info(e) {
2627
"Error while initialization of NNStateRewardPredictorBase. Changing pathSelectorType on INHERITORS_SELECTOR"
2728
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package org.utbot.predictors
2+
3+
import org.utbot.analytics.StateRewardPredictorFactory
4+
import org.utbot.framework.StateRewardPredictorType
5+
import org.utbot.framework.UtSettings
6+
7+
/**
8+
* Creates [StateRewardPredictor], by checking the [UtSettings] configuration.
9+
*/
10+
class StateRewardPredictorFactoryImpl : StateRewardPredictorFactory {
11+
override operator fun invoke() = when (UtSettings.stateRewardPredictorType) {
12+
StateRewardPredictorType.BASE -> NNStateRewardPredictorBase()
13+
StateRewardPredictorType.TORCH -> StateRewardPredictorTorch()
14+
StateRewardPredictorType.LINEAR -> LinearStateRewardPredictor()
15+
}
16+
}

utbot-analytics/src/main/kotlin/org/utbot/predictors/NNStateRewardPredictorTorch.kt renamed to utbot-analytics/src/main/kotlin/org/utbot/predictors/StateRewardPredictorTorch.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ import ai.djl.ndarray.NDArray
66
import ai.djl.ndarray.NDList
77
import ai.djl.translate.Translator
88
import ai.djl.translate.TranslatorContext
9+
import org.utbot.analytics.StateRewardPredictor
910
import org.utbot.framework.UtSettings
1011
import java.io.Closeable
1112
import java.nio.file.Paths
1213

13-
class NNStateRewardPredictorTorch : NNStateRewardPredictor, Closeable {
14+
class StateRewardPredictorTorch : StateRewardPredictor, Closeable {
1415
val model: Model = Model.newInstance("model")
1516

1617
init {
Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
11
package org.utbot.predictors
22

33
import org.utbot.framework.UtSettings
4+
import org.utbot.predictors.util.ScalerLoadingException
5+
import org.utbot.predictors.util.splitByCommaIntoDoubleArray
46
import smile.math.matrix.Matrix
57
import java.nio.file.Paths
68

9+
10+
internal const val DEFAULT_SCALER_PATH = "scaler.txt"
11+
712
data class StandardScaler(val mean: Matrix?, val variance: Matrix?)
813

914
internal fun loadScaler(path: String): StandardScaler =
10-
Paths.get(UtSettings.rewardModelPath, path).toFile().bufferedReader().use {
11-
val mean = it.readLine()?.splitByCommaIntoDoubleArray() ?: error("There is not mean in $path")
12-
val variance = it.readLine()?.splitByCommaIntoDoubleArray() ?: error("There is not variance in $path")
13-
StandardScaler(Matrix(mean), Matrix(variance))
15+
try {
16+
Paths.get(UtSettings.rewardModelPath, path).toFile().bufferedReader().use {
17+
val mean = it.readLine()?.splitByCommaIntoDoubleArray() ?: error("There is not mean in $path")
18+
val variance = it.readLine()?.splitByCommaIntoDoubleArray() ?: error("There is not variance in $path")
19+
StandardScaler(Matrix(mean), Matrix(variance))
20+
}
21+
} catch (e: Exception) {
22+
throw ScalerLoadingException(e)
1423
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package org.utbot.predictors.util
2+
3+
sealed class PredictorLoadingException(msg: String?, cause: Throwable? = null) : Exception(msg, cause)
4+
5+
class WeightsLoadingException(e: Throwable) : PredictorLoadingException("Error while loading weights", e)
6+
7+
class ModelLoadingException(e: Throwable) : PredictorLoadingException("Error while loading model", e)
8+
9+
class ScalerLoadingException(e: Throwable) : PredictorLoadingException("Error while loading scaler", e)
10+
11+
class ModelBuildingException(msg: String) : PredictorLoadingException(msg)

0 commit comments

Comments
 (0)