Skip to content

Commit fe482ea

Browse files
authored
Merge pull request #220 from HQ-Q/newapi
新增向量嵌入接口
2 parents 2e3175f + 912f9bd commit fe482ea

File tree

6 files changed

+307
-4
lines changed

6 files changed

+307
-4
lines changed
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package com.plexpt.chatgpt;
2+
3+
import cn.hutool.core.util.RandomUtil;
4+
import cn.hutool.http.ContentType;
5+
import cn.hutool.http.Header;
6+
import com.alibaba.fastjson.JSON;
7+
import com.plexpt.chatgpt.api.Api;
8+
import com.plexpt.chatgpt.entity.BaseResponse;
9+
import com.plexpt.chatgpt.entity.embedding.EmbeddingRequest;
10+
import com.plexpt.chatgpt.entity.embedding.EmbeddingResult;
11+
import com.plexpt.chatgpt.exception.ChatException;
12+
import io.reactivex.Single;
13+
import lombok.*;
14+
import lombok.extern.slf4j.Slf4j;
15+
import okhttp3.OkHttpClient;
16+
import okhttp3.Request;
17+
import okhttp3.Response;
18+
import retrofit2.Retrofit;
19+
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
20+
import retrofit2.converter.jackson.JacksonConverterFactory;
21+
22+
import java.net.Proxy;
23+
import java.util.Collections;
24+
import java.util.List;
25+
import java.util.Objects;
26+
import java.util.concurrent.TimeUnit;
27+
28+
/**
29+
* 向量client
30+
*
31+
* @author hq
32+
* @version 1.0
33+
* @date 2023/12/12
34+
*/
35+
@Slf4j
36+
@Getter
37+
@Setter
38+
@Builder
39+
@AllArgsConstructor
40+
@NoArgsConstructor
41+
public class Embedding {
42+
43+
/**
44+
* keys
45+
*/
46+
private String apiKey;
47+
48+
private List<String> apiKeyList;
49+
/**
50+
* 自定义api host使用builder的方式构造client
51+
*/
52+
@Builder.Default
53+
private String apiHost = Api.DEFAULT_API_HOST;
54+
private Api apiClient;
55+
private OkHttpClient okHttpClient;
56+
/**
57+
* 超时 默认300
58+
*/
59+
@Builder.Default
60+
private long timeout = 300;
61+
/**
62+
* okhttp 代理
63+
*/
64+
@Builder.Default
65+
private Proxy proxy = Proxy.NO_PROXY;
66+
67+
68+
public Embedding init() {
69+
OkHttpClient.Builder client = new OkHttpClient.Builder();
70+
client.addInterceptor(chain -> {
71+
Request original = chain.request();
72+
String key = apiKey;
73+
if (apiKeyList != null && !apiKeyList.isEmpty()) {
74+
key = RandomUtil.randomEle(apiKeyList);
75+
}
76+
Request request = original.newBuilder()
77+
.header(Header.AUTHORIZATION.getValue(), "Bearer " + key)
78+
.header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue())
79+
.method(original.method(), original.body())
80+
.build();
81+
return chain.proceed(request);
82+
}).addInterceptor(chain -> {
83+
Request original = chain.request();
84+
Response response = chain.proceed(original);
85+
if (!response.isSuccessful()) {
86+
String errorMsg = response.body().string();
87+
log.error("请求异常:{}", errorMsg);
88+
BaseResponse baseResponse = JSON.parseObject(errorMsg, BaseResponse.class);
89+
if (Objects.nonNull(baseResponse.getError())) {
90+
log.error(baseResponse.getError().getMessage());
91+
throw new ChatException(baseResponse.getError().getMessage());
92+
}
93+
throw new ChatException("error");
94+
}
95+
return response;
96+
});
97+
98+
client.connectTimeout(timeout, TimeUnit.SECONDS);
99+
client.writeTimeout(timeout, TimeUnit.SECONDS);
100+
client.readTimeout(timeout, TimeUnit.SECONDS);
101+
if (Objects.nonNull(proxy)) {
102+
client.proxy(proxy);
103+
}
104+
this.okHttpClient = client.build();
105+
this.apiClient = new Retrofit.Builder()
106+
.baseUrl(this.apiHost)
107+
.client(okHttpClient)
108+
.addCallAdapterFactory(RxJava2CallAdapterFactory.create())
109+
.addConverterFactory(JacksonConverterFactory.create())
110+
.build()
111+
.create(Api.class);
112+
return this;
113+
}
114+
115+
116+
/**
117+
* 生成向量
118+
*/
119+
public EmbeddingResult createEmbeddings(EmbeddingRequest request) {
120+
Single<EmbeddingResult> embeddingResultSingle = this.apiClient.createEmbeddings(request);
121+
return embeddingResultSingle.blockingGet();
122+
}
123+
124+
125+
/**
126+
* 生成向量
127+
*/
128+
public EmbeddingResult createEmbeddings(String input, String user) {
129+
EmbeddingRequest request = EmbeddingRequest.builder()
130+
.input(Collections.singletonList(input))
131+
.model(EmbeddingRequest.EmbeddingModelEnum.TEXT_EMBEDDING_ADA_002.getModelName())
132+
.user(user)
133+
.build();
134+
Single<EmbeddingResult> embeddingResultSingle = this.apiClient.createEmbeddings(request);
135+
return embeddingResultSingle.blockingGet();
136+
}
137+
138+
}

src/main/java/com/plexpt/chatgpt/api/Api.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,16 @@
77
import com.plexpt.chatgpt.entity.billing.UseageResponse;
88
import com.plexpt.chatgpt.entity.chat.ChatCompletion;
99
import com.plexpt.chatgpt.entity.chat.ChatCompletionResponse;
10-
10+
import com.plexpt.chatgpt.entity.embedding.EmbeddingRequest;
11+
import com.plexpt.chatgpt.entity.embedding.EmbeddingResult;
1112
import com.plexpt.chatgpt.entity.images.Edits;
1213
import com.plexpt.chatgpt.entity.images.Generations;
1314
import com.plexpt.chatgpt.entity.images.ImagesRensponse;
1415
import com.plexpt.chatgpt.entity.images.Variations;
1516
import io.reactivex.Single;
1617
import okhttp3.MultipartBody;
17-
import okhttp3.RequestBody;
1818
import retrofit2.http.*;
1919

20-
import java.util.Map;
21-
2220

2321
/**
2422
*
@@ -95,4 +93,10 @@ Single<UseageResponse> usage(@Query("start_date") String startDate,
9593
@Query("end_date") String endDate);
9694

9795

96+
/**
97+
* 生成向量
98+
*/
99+
@POST("v1/embeddings")
100+
Single<EmbeddingResult> createEmbeddings(@Body EmbeddingRequest request);
101+
98102
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package com.plexpt.chatgpt.entity.embedding;
2+
3+
import lombok.Data;
4+
5+
import java.util.List;
6+
7+
/**
8+
* 向量
9+
*
10+
* @author hq
11+
* @version 1.0
12+
* @date 2023/12/12
13+
*/
14+
@Data
15+
public class EmbeddingData {
16+
17+
/**
18+
* The type of object returned, should be "embedding"
19+
*/
20+
String object;
21+
22+
/**
23+
* The embedding vector
24+
*/
25+
List<Double> embedding;
26+
27+
/**
28+
* The position of this embedding in the list
29+
*/
30+
Integer index;
31+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package com.plexpt.chatgpt.entity.embedding;
2+
3+
import lombok.*;
4+
5+
import java.util.List;
6+
7+
/**
8+
* 生成向量请求参数
9+
*
10+
* @author hq
11+
* @version 1.0
12+
* @date 2023/12/12
13+
*/
14+
15+
@Builder
16+
@NoArgsConstructor
17+
@AllArgsConstructor
18+
@Data
19+
public class EmbeddingRequest {
20+
21+
/**
22+
* 向量模型
23+
*/
24+
private String model;
25+
26+
/**
27+
* 需要转成向量的文本
28+
*/
29+
private List<String> input;
30+
31+
/**
32+
* 代表最终用户的唯一标识符,这将有助于 OpenAI 监控和检测滥用行为
33+
*/
34+
private String user;
35+
36+
37+
/**
38+
* 向量模型枚举
39+
*
40+
* @author hq
41+
* @version 1.0
42+
* @date 2023/12/12
43+
*/
44+
@Getter
45+
@AllArgsConstructor
46+
public enum EmbeddingModelEnum {
47+
/**
48+
* text-embedding-ada-002
49+
*/
50+
TEXT_EMBEDDING_ADA_002("text-embedding-ada-002"),
51+
;
52+
53+
/**
54+
* modelName
55+
*/
56+
private final String modelName;
57+
}
58+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package com.plexpt.chatgpt.entity.embedding;
2+
3+
import com.plexpt.chatgpt.entity.billing.Usage;
4+
import lombok.Data;
5+
6+
import java.util.List;
7+
8+
/**
9+
* 向量结果
10+
*
11+
* @author hq
12+
* @version 1.0
13+
* @date 2023/12/12
14+
*/
15+
@Data
16+
public class EmbeddingResult {
17+
18+
/**
19+
* The GPTmodel used for generating embeddings
20+
*/
21+
String model;
22+
23+
/**
24+
* The type of object returned, should be "list"
25+
*/
26+
String object;
27+
28+
/**
29+
* A list of the calculated embeddings
30+
*/
31+
List<EmbeddingData> data;
32+
33+
/**
34+
* The API usage for this request
35+
*/
36+
Usage usage;
37+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package com.plexpt.chatgpt;
2+
3+
import com.plexpt.chatgpt.entity.embedding.EmbeddingResult;
4+
import org.junit.Before;
5+
import org.junit.Test;
6+
7+
/**
8+
* @author hq
9+
* @version 1.0
10+
* @date 2023/12/13
11+
*/
12+
public class EmbeddingTest {
13+
14+
private Embedding embedding;
15+
16+
@Before
17+
public void before() {
18+
//Proxy proxy = Proxys.http("127.0.0.1", 1080);
19+
20+
embedding = Embedding.builder()
21+
.apiKey("sk-6MeitSJhboJdWhGJLZTaH1T3BlbkFJdmbrrY7dgAnucJo6Arn7G")
22+
.timeout(900)
23+
//.proxy(proxy)
24+
.apiHost("https://api.openai.com/") //代理地址
25+
.build()
26+
.init();
27+
}
28+
29+
30+
@Test
31+
public void setEmbedding(){
32+
EmbeddingResult embeddingResult = embedding.createEmbeddings("123445", "user1");
33+
System.out.println(embeddingResult);
34+
}
35+
}

0 commit comments

Comments
 (0)