Skip to content

Commit 19e76c0

Browse files
Add Llama service and update pom
1 parent d142e38 commit 19e76c0

File tree

6 files changed

+302
-1
lines changed

6 files changed

+302
-1
lines changed

pom.xml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
<!-- force overriding property at command line, use ${maven.build.timestamp}-->
8181
<timestamp>${maven.build.timestamp}</timestamp>
8282
<maven.build.timestamp.format>yyyyMMddHHmm</maven.build.timestamp.format>
83-
<version>${version}</version>
83+
<version>${version}</version>
8484
<GitBranch>${git.branch}</GitBranch>
8585
<username>${NODE_NAME}</username>
8686
<platform>${NODE_LABELS}</platform>
@@ -614,6 +614,15 @@
614614
</dependency>
615615
<!-- LeapMotion end -->
616616

617+
<!-- Llama begin -->
618+
<dependency>
619+
<groupId>de.kherud</groupId>
620+
<artifactId>llama</artifactId>
621+
<version>1.1.0</version>
622+
<scope>provided</scope>
623+
</dependency>
624+
<!-- Llama end -->
625+
617626
<!-- LocalSpeech begin -->
618627
<dependency>
619628
<groupId>org.myrobotlab.audio</groupId>
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
package org.myrobotlab.service;

import de.kherud.llama.LlamaModel;
import de.kherud.llama.Parameters;
import org.myrobotlab.framework.Service;
import org.myrobotlab.logging.Level;
import org.myrobotlab.logging.LoggingFactory;
import org.myrobotlab.programab.Response;
import org.myrobotlab.service.config.LlamaConfig;
import org.myrobotlab.service.data.Utterance;
import org.myrobotlab.service.interfaces.ResponsePublisher;
import org.myrobotlab.service.interfaces.UtterancePublisher;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.URL;
import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
import java.nio.channels.ReadableByteChannel;
import java.util.stream.StreamSupport;

/**
 * Llama - a local large-language-model inference service backed by the
 * de.kherud:llama JNI binding around llama.cpp. Loads a GGUF model from a
 * configured directory (downloading it first when a URL is known), generates
 * a reply for incoming text, and publishes the reply both as an
 * {@link Utterance} and as a ProgramAB-style {@link Response}.
 */
public class Llama extends Service<LlamaConfig> implements UtterancePublisher, ResponsePublisher {

  /**
   * The loaded llama.cpp model. Transient so serialization of the service
   * never attempts to serialize native state; null until loadModel() succeeds.
   */
  private transient LlamaModel model;

  /**
   * Constructor of service, reservedkey typically is a services name and inId
   * will be its process id
   *
   * @param reservedKey the service name
   * @param inId        process id
   */
  public Llama(String reservedKey, String inId) {
    super(reservedKey, inId);
  }

  /**
   * Loads the GGUF model at the given path with CPU-only inference
   * (zero GPU layers). The configured user prompt is registered as the
   * anti-prompt so generation stops when the model starts a new user turn.
   *
   * @param modelPath absolute path to the .gguf model file
   */
  public void loadModel(String modelPath) {
    Parameters params = new Parameters.Builder()
        .setNGpuLayers(0)
        .setTemperature(0.7f)
        .setPenalizeNl(true)
        .setMirostat(Parameters.MiroStat.V2)
        // stop generating once the model emits the user-turn marker
        .setAntiPrompt(new String[]{config.userPrompt})
        .build();
    model = new LlamaModel(modelPath, params);
  }

  /**
   * Generates a reply for the given text and publishes it as both an
   * Utterance and a Response.
   *
   * @param text the user's input text
   * @return the published Response, or null when no model is loaded
   */
  public Response getResponse(String text) {
    if (model == null) {
      error("Model is not loaded.");
      return null;
    }

    String prompt = config.systemPrompt + config.systemMessage + "\n" + text + "\n";
    // generate() streams output tokens; concatenate them into the full reply
    String response = StreamSupport.stream(model.generate(prompt).spliterator(), false)
        .map(LlamaModel.Output::toString)
        .reduce("", (a, b) -> a + b);

    Utterance utterance = new Utterance();
    utterance.username = getName();
    utterance.text = response;
    utterance.isBot = true;
    // channel metadata is unused for a purely local service
    utterance.channel = "";
    utterance.channelType = "";
    utterance.channelBotName = getName();
    utterance.channelName = "";
    invoke("publishUtterance", utterance);
    Response res = new Response("friend", getName(), response, null);
    invoke("publishResponse", res);
    return res;
  }

  /**
   * Resolves a model file name to an absolute path, searching the
   * user-configured model directories first, then the service data
   * directory. If the model is absent but a download URL is configured,
   * the model is downloaded into the data directory.
   *
   * @param model the model file name, e.g. "llama-2-7b.Q4_K_M.gguf"
   * @return the absolute path of the model, or null if it cannot be located
   *         or downloaded
   */
  public String findModelPath(String model) {
    // First, we loop over all user-defined model directories
    for (String dir : config.modelPaths) {
      File path = new File(dir + fs + model);
      if (path.exists()) {
        return path.getAbsolutePath();
      }
    }

    // Now, we check our data directory for any downloaded models
    File path = new File(getDataDir() + fs + model);
    if (path.exists()) {
      return path.getAbsolutePath();
    } else if (config.modelUrls.containsKey(model)) {
      // Model was not in data but we do have a URL for it.
      String url = config.modelUrls.get(model);
      // BUGFIX: the original never closed the ReadableByteChannel (and the
      // underlying HTTP stream) - all three resources are now managed by
      // try-with-resources and closed in reverse order on every path.
      try (ReadableByteChannel readableByteChannel = Channels.newChannel(new URL(url).openStream());
           FileOutputStream fileOutputStream = new FileOutputStream(path);
           FileChannel fileChannel = fileOutputStream.getChannel()) {
        info("Downloading model %s to path %s from URL %s", model, path, url);
        fileChannel.transferFrom(readableByteChannel, 0, Long.MAX_VALUE);
      } catch (IOException e) {
        // BUGFIX: remove a partially written file so a later call does not
        // find it via path.exists() and treat it as a complete model
        path.delete();
        throw new RuntimeException("Failed to download model " + model + " from " + url, e);
      }
      return path.getAbsolutePath();
    }

    // Cannot find the model anywhere
    error("Could not locate model {}, add its URL to download it or add a directory where it is located", model);
    return null;
  }

  /**
   * Applies configuration; if a model is selected, attempts to locate
   * (or download) and load it immediately.
   */
  @Override
  public LlamaConfig apply(LlamaConfig c) {
    super.apply(c);

    if (config.selectedModel != null && !config.selectedModel.isEmpty()) {
      String modelPath = findModelPath(config.selectedModel);
      if (modelPath != null) {
        loadModel(modelPath);
      } else {
        error("Could not find selected model {}", config.selectedModel);
      }
    }

    return config;
  }

  /** Pass-through publish point for the framework's pub/sub routing. */
  @Override
  public Utterance publishUtterance(Utterance utterance) {
    return utterance;
  }

  /** Pass-through publish point for the framework's pub/sub routing. */
  @Override
  public Response publishResponse(Response response) {
    return response;
  }

  public static void main(String[] args) {
    try {

      LoggingFactory.init(Level.INFO);

      // Runtime runtime = Runtime.getInstance();
      // Runtime.startConfig("gpt3-01");

      WebGui webgui = (WebGui) Runtime.create("webgui", "WebGui");
      webgui.autoStartBrowser(false);
      webgui.startService();

      Llama llama = (Llama) Runtime.start("llama", "Llama");

      System.out.println(llama.getResponse("Hello!").msg);

    } catch (Exception e) {
      log.error("main threw", e);
    }
  }
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package org.myrobotlab.service.config;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * User-editable configuration for the Llama service: prompt templates,
 * the model to load, directories to search for model files, and known
 * download URLs for models that are not yet on disk.
 */
public class LlamaConfig extends ServiceConfig {

  // Text placed at the very start of the generated prompt
  public String systemPrompt = "";

  // Additional system-level text appended directly after the system prompt
  public String systemMessage = "";

  /**
   * The prompt that is prefixed to every user request.
   * No whitespace is stripped, so ensure that
   * the prompt is formatted so that a whitespace-stripped
   * user request does not cause tokenizer errors.
   */
  public String userPrompt = "### User:\n";

  /**
   * The prompt that the AI should use, should not
   * have a trailing space. Any trailing space
   * (but not newlines) are stripped to prevent
   * tokenizer errors.
   */
  public String assistantPrompt = "### Assistant:\n";

  // File name of the model loaded when the service config is applied
  public String selectedModel = "llama-2-7b-guanaco-qlora.Q4_K_M.gguf";

  // User-supplied directories searched (in order) for model files
  public List<String> modelPaths = new ArrayList<>();

  // Known download locations, keyed by model file name
  public Map<String, String> modelUrls = new HashMap<>();

  {
    // populate default download URLs; the map stays mutable so users can
    // register additional models
    modelUrls.put("stablebeluga-7b.Q4_K_M.gguf",
        "https://huggingface.co/TheBloke/StableBeluga-7B-GGUF/resolve/main/stablebeluga-7b.Q4_K_M.gguf");
    modelUrls.put("llama-2-7b-guanaco-qlora.Q4_K_M.gguf",
        "https://huggingface.co/TheBloke/llama-2-7B-Guanaco-QLoRA-GGUF/resolve/main/llama-2-7b-guanaco-qlora.Q4_K_M.gguf");
  }
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package org.myrobotlab.service.meta;
2+
3+
import org.myrobotlab.service.meta.abstracts.MetaData;
4+
5+
public class LlamaMeta extends MetaData {
6+
7+
public LlamaMeta() {
8+
addDescription(
9+
"A large language model inference engine based on the widely used " +
10+
"llama.cpp project. Can run most GGUF models."
11+
);
12+
13+
addDependency("de.kherud", "llama", "1.1.0");
14+
}
15+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
angular.module('mrlapp.service.LlamaGui', []).controller('LlamaGuiCtrl', ['$scope', 'mrl', function($scope, mrl) {
    console.info('LlamaGuiCtrl')
    var _self = this
    var msg = this.msg
    // rolling transcript of user requests and bot utterances
    $scope.utterances = []
    $scope.maxRecords = 500
    $scope.text = null

    // GOOD TEMPLATE TO FOLLOW
    this.updateState = function(service) {
        $scope.service = service
    }

    // init scope variables
    $scope.onTime = null
    $scope.onEpoch = null

    // append an entry to the transcript, dropping the oldest entry once
    // the buffer exceeds maxRecords (extracted from the duplicated
    // push-and-trim logic in onUtterance/onRequest)
    var pushUtterance = function(entry) {
        $scope.utterances.push(entry)
        if ($scope.utterances.length > $scope.maxRecords) {
            $scope.utterances.shift()
        }
    }

    // dispatch messages published by the Llama service
    this.onMsg = function(inMsg) {
        let data = inMsg.data[0]
        switch (inMsg.method) {
        case 'onState':
            _self.updateState(data)
            $scope.$apply()
            break
        case 'onUtterance':
            pushUtterance(data)
            $scope.$apply()
            break
        case 'onRequest':
            // BUGFIX: 'request' was assigned without var/let, creating an
            // accidental global (and a ReferenceError in strict mode)
            let request = {"username":"friend", "text":data}
            pushUtterance(request)
            $scope.$apply()
            break
        case 'onEpoch':
            $scope.onEpoch = data
            $scope.$apply()
            break
        default:
            console.error("ERROR - unhandled method " + $scope.name + " " + inMsg.method)
            break
        }
    }

    msg.subscribe('publishRequest')
    msg.subscribe('publishUtterance')
    msg.subscribe(this)
}
])
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
<!-- Input row: free-text field wired to the Llama service's getResponse() -->
<div class="row">
  <form class="form-inline">
    text<br/>
    <input class="form-control" type="text" ng-model="text" placeholder="send text" title="send text">
    <button class="btn btn-default" ng-click="msg.getResponse(text)">send text</button><br/>
  </form>
</div>

<!-- Transcript: one row per utterance collected by LlamaGuiCtrl -->
<div class="row">
  <table class="table table-hover">
    <tbody>
      <tr ng-repeat="e in utterances" >
        <td>
          <small>{{e.username}}</small>
        </td>
        <td>
          <small>{{e.channel}}</small>
        </td>
        <td>
          <!-- pre-wrap preserves newlines in multi-line model output -->
          <small style="white-space: pre-wrap">{{e.text}}</small>
        </td>
      </tr>
    </tbody>
  </table>
</div>

0 commit comments

Comments
 (0)