Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/import_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def resolve_detail_classification_id(cur, cache: dict[tuple[str | None, str | No
row = fetch_one(
cur,
"""
select dcm.detail_classification_id as id
select ccm.detail_classification_id as id
from corpus_classification_mappings ccm
join detail_classifications dcm on dcm.id = ccm.detail_classification_id
where ccm.source_job_group_l1 = %s
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.jobdri.jobdri_api.domain.corpus.service;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpHeaders;
Expand All @@ -17,6 +19,7 @@
public class CohereCorpusEmbeddingClient implements CorpusEmbeddingClient {

private final RestClient.Builder restClientBuilder;
private final ObjectMapper objectMapper;

@Value("${cohere.api.key:}")
private String cohereApiKey;
Expand Down Expand Up @@ -47,7 +50,7 @@ public List<float[]> embed(List<String> texts, InputType inputType) {
.defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.build();

EmbedResponse response = client.post()
String responseBody = client.post()
.uri("/v2/embed")
.body(new EmbedRequest(
texts,
Expand All @@ -57,15 +60,9 @@ public List<float[]> embed(List<String> texts, InputType inputType) {
List.of("float")
))
.retrieve()
.body(EmbedResponse.class);
.body(String.class);

if (response == null || response.embeddings() == null || response.embeddings().floatEmbeddings() == null) {
throw new IllegalStateException("Cohere 임베딩 응답이 비어 있습니다.");
}

return response.embeddings().floatEmbeddings().stream()
.map(this::toFloatArray)
.toList();
return parseEmbeddings(responseBody);
}

private float[] toFloatArray(List<Double> values) {
Expand All @@ -76,6 +73,36 @@ private float[] toFloatArray(List<Double> values) {
return array;
}

private List<float[]> parseEmbeddings(String responseBody) {
if (!StringUtils.hasText(responseBody)) {
throw new IllegalStateException("Cohere 임베딩 응답이 비어 있습니다.");
}

try {
JsonNode root = objectMapper.readTree(responseBody);
JsonNode floatEmbeddings = root.path("embeddings").path("float");
if (!floatEmbeddings.isArray()) {
throw new IllegalStateException("Cohere 임베딩 응답 형식이 예상과 다릅니다.");
}

List<float[]> result = new java.util.ArrayList<>();
for (JsonNode embeddingNode : floatEmbeddings) {
if (!embeddingNode.isArray()) {
throw new IllegalStateException("Cohere 임베딩 벡터 형식이 예상과 다릅니다.");
}

float[] vector = new float[embeddingNode.size()];
for (int i = 0; i < embeddingNode.size(); i++) {
vector[i] = embeddingNode.get(i).floatValue();
}
result.add(vector);
}
return result;
} catch (Exception e) {
throw new IllegalStateException("Cohere 임베딩 응답 파싱에 실패했습니다.", e);
}
}

private record EmbedRequest(
List<String> texts,
String model,
Expand All @@ -85,12 +112,4 @@ private record EmbedRequest(
) {
}

private record EmbedResponse(Embeddings embeddings) {
}

private record Embeddings(
@com.fasterxml.jackson.annotation.JsonProperty("float")
List<List<Double>> floatEmbeddings
) {
}
}
Loading