Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,30 @@ python scripts/import_corpus.py /absolute/path/to/file.xlsx
- 엑셀 시트명은 `jd_embed_corpus`, `question_embed_corpus`를 사용합니다.
- `source_analysis_id`, `source_question_id` 기준으로 `INSERT ... ON CONFLICT DO UPDATE` 방식으로 적재합니다.

## Corpus Embedding Sync Script

관리자 API 대신 Python 스크립트로 corpus 임베딩을 일괄 동기화할 수 있습니다.

실행:

```bash
source .venv/bin/activate
pip install -r scripts/requirements-corpus-import.txt
python scripts/sync_corpus_embeddings.py --env-file .env
```

옵션 예시:

```bash
python scripts/sync_corpus_embeddings.py --env-file .env --limit 100
python scripts/sync_corpus_embeddings.py --env-file .env --job-only
python scripts/sync_corpus_embeddings.py --env-file .env --question-only --batch-size 16
```

- `.env`의 `DB_URL`, `DB_USERNAME`, `DB_PASSWORD`, `COHERE_API_KEY`를 사용합니다.
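- 기본 모델은 `embed-v4.0`, 기본 출력 차원은 `1024`, 기본 배치 크기는 `32`입니다. 각각 환경 변수 `APP_CORPUS_EMBEDDING_MODEL`, `APP_CORPUS_EMBEDDING_OUTPUT_DIMENSION`, `APP_CORPUS_EMBEDDING_BATCH_SIZE`로 변경할 수 있으며, 배치 크기는 `--batch-size` 옵션이 우선합니다.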
- `mock_job_posting_embeddings`, `mock_question_embeddings` 테이블에 `INSERT ... ON CONFLICT DO UPDATE` 방식으로 적재합니다.

## CI/CD

- `CI`: `main`, `develop` 브랜치 push 및 PR에서 테스트와 Docker 이미지 빌드를 실행합니다.
Expand Down
1 change: 1 addition & 0 deletions scripts/requirements-corpus-import.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
openpyxl>=3.1.0
psycopg[binary]>=3.1.0
requests>=2.31.0
276 changes: 276 additions & 0 deletions scripts/sync_corpus_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
#!/usr/bin/env python3
"""Sync corpus embeddings from Postgres to pgvector tables via Cohere."""

from __future__ import annotations

import argparse
import os
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import requests

try:
import psycopg
except ImportError: # pragma: no cover - fallback for older environments
psycopg = None

try:
import psycopg2
from psycopg2.extras import RealDictCursor
except ImportError: # pragma: no cover - optional fallback
psycopg2 = None
RealDictCursor = None


# Reads the id and embedding text of every job posting corpus row that is
# flagged as valid for embedding and has non-null text, ordered by id.
# fetch_all() may append a "limit %s" clause to this statement.
JOB_POSTING_SELECT_SQL = """
select id, embedding_text
from mock_job_posting_corpus
where is_valid_for_embedding = true
and embedding_text is not null
order by id asc
"""

# Same selection as above, but against the question corpus table.
QUESTION_SELECT_SQL = """
select id, embedding_text
from mock_question_corpus
where is_valid_for_embedding = true
and embedding_text is not null
order by id asc
"""

# Upserts one job posting embedding keyed by corpus_id. Parameters, in order:
# corpus id, embedding model name, and the embedding as text that is cast to
# the "vector" type by the statement. On conflict the model, embedding and
# updated_at columns are overwritten; created_at is left as-is.
UPSERT_JOB_POSTING_SQL = """
insert into mock_job_posting_embeddings (corpus_id, embedding_model, embedding, created_at, updated_at)
values (%s, %s, %s::vector, now(), now())
on conflict (corpus_id) do update
set embedding_model = excluded.embedding_model,
embedding = excluded.embedding,
updated_at = now()
"""

# Same upsert as above, but into the question embeddings table.
UPSERT_QUESTION_SQL = """
insert into mock_question_embeddings (corpus_id, embedding_model, embedding, created_at, updated_at)
values (%s, %s, %s::vector, now(), now())
on conflict (corpus_id) do update
set embedding_model = excluded.embedding_model,
embedding = excluded.embedding,
updated_at = now()
"""


@dataclass
class SyncStats:
    """Counters for the number of embedding rows upserted by one sync run."""

    # Rows upserted from the job posting corpus.
    job_posting_embeddings_upserted: int = 0
    # Rows upserted from the question corpus.
    question_embeddings_upserted: int = 0


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the embedding sync script.

    Returns:
        A namespace exposing ``env_file``, ``limit``, ``batch_size``,
        ``job_only`` and ``question_only``.
    """
    cli = argparse.ArgumentParser(description="Sync corpus embeddings into pgvector tables.")
    cli.add_argument("--env-file", default=".env", help="Path to env file")

    # Optional integer settings; None means the flag was not given.
    integer_options = (
        ("--limit", "Limit rows per corpus type"),
        ("--batch-size", "Batch size for Cohere embed requests"),
    )
    for flag, description in integer_options:
        cli.add_argument(flag, type=int, default=None, help=description)

    # Boolean switches restricting the run to one corpus type.
    switch_options = (
        ("--job-only", "Sync only job posting corpus embeddings"),
        ("--question-only", "Sync only question corpus embeddings"),
    )
    for flag, description in switch_options:
        cli.add_argument(flag, action="store_true", help=description)

    return cli.parse_args()


def load_env_file(env_path: Path) -> None:
    """Load ``KEY=VALUE`` pairs from an env file into ``os.environ``.

    A missing file is silently ignored. Blank lines, lines starting with
    ``#`` and lines without ``=`` are skipped. Variables that already exist
    in the environment keep their value. Keys and values are stripped of
    surrounding whitespace but otherwise used verbatim (quotes are kept).

    Args:
        env_path: Location of the env file.
    """
    if not env_path.exists():
        return

    content = env_path.read_text(encoding="utf-8")
    for entry in (raw.strip() for raw in content.splitlines()):
        if not entry or entry.startswith("#"):
            continue
        name, separator, value = entry.partition("=")
        if not separator:
            # No "=" on this line, so it is not an assignment.
            continue
        os.environ.setdefault(name.strip(), value.strip())


def jdbc_to_postgres_dsn(jdbc_url: str) -> str:
    """Return *jdbc_url* without a leading ``jdbc:`` prefix.

    URLs that do not start with ``jdbc:`` are returned unchanged.
    """
    prefix = "jdbc:"
    if not jdbc_url.startswith(prefix):
        return jdbc_url
    return jdbc_url[len(prefix):]


def connect():
    """Open a database connection from the ``DB_*`` environment variables.

    Reads ``DB_URL`` (a ``jdbc:`` prefix is dropped), ``DB_USERNAME`` and
    ``DB_PASSWORD``. The ``psycopg`` module is used when it was importable;
    otherwise ``psycopg2`` is used with ``RealDictCursor`` as cursor factory.

    Returns:
        The connection object created by the available driver.

    Raises:
        SystemExit: If ``DB_URL`` or ``DB_USERNAME`` is missing or empty, or
            if neither driver module is installed.
    """
    env = os.environ
    db_url = env.get("DB_URL")
    db_user = env.get("DB_USERNAME")
    db_password = env.get("DB_PASSWORD")

    if not (db_url and db_user):
        raise SystemExit("DB_URL and DB_USERNAME must be set in environment or env file.")

    dsn = jdbc_to_postgres_dsn(db_url)
    credentials = {"user": db_user, "password": db_password}

    if psycopg is not None:
        return psycopg.connect(dsn, **credentials)
    if psycopg2 is None:
        raise SystemExit("Install psycopg or psycopg2-binary before running this script.")
    return psycopg2.connect(dsn, cursor_factory=RealDictCursor, **credentials)


def fetch_all(cur, query: str, limit: int | None) -> list[dict[str, Any]]:
    """Run *query* and return every row as a dict keyed by column name.

    Args:
        cur: An open DB-API style cursor.
        query: The select statement to execute.
        limit: When not ``None``, a ``limit %s`` clause bound to this value
            is appended to *query*.

    Returns:
        One dict per row. Rows that are already dicts are passed through,
        rows exposing ``_mapping`` are converted from it, and any other row
        is treated as a sequence and zipped with the column names from
        ``cur.description``. An empty result yields ``[]``.
    """
    effective_query = query
    params: tuple[Any, ...] = ()
    if limit is not None:
        effective_query += "\nlimit %s"
        params = (limit,)

    cur.execute(effective_query, params)
    rows = cur.fetchall()
    if not rows:
        return []

    # Column names are identical for every row, so resolve them at most once
    # and only if a plain sequence row actually needs them.
    column_names: list[str] | None = None
    normalized_rows: list[dict[str, Any]] = []
    for row in rows:
        if isinstance(row, dict):
            normalized_rows.append(row)
        elif hasattr(row, "_mapping"):
            normalized_rows.append(dict(row._mapping))
        else:
            if column_names is None:
                # Description entries either expose ``.name`` or are plain
                # sequences whose first item is the column name.
                column_names = [
                    column.name if hasattr(column, "name") else column[0]
                    for column in cur.description
                ]
            normalized_rows.append(dict(zip(column_names, row)))
    return normalized_rows


def chunked(items: list[Any], size: int) -> list[list[Any]]:
    """Split *items* into consecutive slices of at most *size* elements.

    A *size* below 1 is treated as 1. An empty input yields an empty list.
    """
    step = max(1, size)
    starts = range(0, len(items), step)
    return [items[start : start + step] for start in starts]


def create_requests_session(cohere_api_key: str) -> requests.Session:
    """Create a ``requests`` session preconfigured for JSON calls to Cohere.

    Args:
        cohere_api_key: API key sent as a bearer token on every request.

    Returns:
        A session whose default headers carry the authorization token and
        JSON content negotiation headers.
    """
    default_headers = {
        "Authorization": f"Bearer {cohere_api_key}",
        "Content-Type": "application/json",
        "Accept": "application/json",
    }
    http = requests.Session()
    http.headers.update(default_headers)
    return http


def _retry_delay_seconds(retry_after: str | None, attempt: int) -> float:
    """Return how long to wait before retrying a rate-limited request.

    A numeric ``Retry-After`` value is honored. Anything else (missing,
    an HTTP-date, negative, NaN or infinite) falls back to exponential
    backoff capped at 60 seconds.
    """
    fallback = float(min(60, 2 ** (attempt + 1)))
    if not retry_after:
        return fallback
    try:
        delay = float(retry_after)
    except ValueError:
        # Retry-After may also be an HTTP-date, which float() cannot parse.
        return fallback
    # The chained comparison is False for negative, NaN and infinite values,
    # none of which time.sleep() accepts.
    return delay if 0 <= delay < float("inf") else fallback


def embed_documents(session: requests.Session, texts: list[str], model: str, output_dimension: int) -> list[list[float]]:
    """Embed *texts* with the Cohere v2 embed endpoint.

    The request is retried when the API answers with HTTP 429, up to five
    attempts in total. Other error statuses are raised immediately.

    Args:
        session: Session used to issue the POST request.
        texts: Documents to embed, sent as ``search_document`` inputs.
        model: Name of the embedding model.
        output_dimension: Requested embedding dimension.

    Returns:
        The ``embeddings.float`` list of the response. An empty *texts*
        list returns ``[]`` without calling the API.

    Raises:
        RuntimeError: If the response body lacks a list at
            ``embeddings.float``.
        Exception: Whatever ``raise_for_status()`` raises for a non-429
            error status, or for the last 429 once attempts are exhausted.
    """
    if not texts:
        return []

    max_attempts = 5
    rate_limited_response = None
    for attempt in range(max_attempts):
        response = session.post(
            "https://api.cohere.com/v2/embed",
            json={
                "texts": texts,
                "model": model,
                "input_type": "search_document",
                "output_dimension": output_dimension,
                "embedding_types": ["float"],
            },
            timeout=(5, 60),
        )

        if response.status_code != 429:
            response.raise_for_status()
            data = response.json()
            # Guard each level so a malformed body surfaces as the
            # RuntimeError below instead of an AttributeError.
            payload = data.get("embeddings") if isinstance(data, dict) else None
            embeddings = payload.get("float") if isinstance(payload, dict) else None
            if not isinstance(embeddings, list):
                raise RuntimeError(f"Unexpected Cohere response: {data}")
            return embeddings

        rate_limited_response = response
        if attempt == max_attempts - 1:
            # No retry follows the last attempt, so do not wait for one.
            break

        sleep_seconds = _retry_delay_seconds(response.headers.get("Retry-After"), attempt)
        print(
            f"Cohere rate limit hit for batch of {len(texts)} texts. "
            f"Retrying in {sleep_seconds:.1f}s... (attempt {attempt + 1}/{max_attempts})",
            flush=True,
        )
        time.sleep(sleep_seconds)

    if rate_limited_response is not None:
        rate_limited_response.raise_for_status()
    raise RuntimeError("Failed to get embedding response from Cohere.")


def vector_literal(values: list[float]) -> str:
    """Render *values* as a bracketed, comma-separated text literal.

    Each value is formatted with eight decimal places, e.g. ``[0.50000000]``.
    """
    components = [format(component, ".8f") for component in values]
    return "[{}]".format(",".join(components))


def upsert_embeddings(cur, sql: str, corpus_ids: list[int], embeddings: list[list[float]], model: str) -> int:
    """Execute the upsert statement once per (corpus id, embedding) pair.

    Args:
        cur: Cursor used to execute *sql*.
        sql: Upsert statement taking corpus id, model name and vector text.
        corpus_ids: Corpus row ids, aligned by position with *embeddings*.
        embeddings: One embedding per corpus id.
        model: Embedding model name stored alongside each vector.

    Returns:
        The number of rows submitted, i.e. ``len(corpus_ids)``.

    Raises:
        RuntimeError: If the two lists differ in length.
    """
    row_count = len(corpus_ids)
    if row_count != len(embeddings):
        raise RuntimeError("Embedding count does not match corpus row count.")

    for corpus_id, embedding in zip(corpus_ids, embeddings):
        statement_params = (corpus_id, model, vector_literal(embedding))
        cur.execute(sql, statement_params)
    return row_count


def sync_dataset(
    cur,
    session: requests.Session,
    select_sql: str,
    upsert_sql: str,
    limit: int | None,
    batch_size: int,
    model: str,
    output_dimension: int,
) -> int:
    """Embed one corpus table and upsert the resulting vectors.

    Rows are read with *select_sql*, embedded in batches of *batch_size*
    and written with *upsert_sql* on the same cursor.

    Args:
        cur: Cursor used for both the select and the upserts.
        session: HTTP session used for the embedding requests.
        select_sql: Statement returning ``id`` and ``embedding_text``.
        upsert_sql: Statement storing one embedding per corpus id.
        limit: Optional cap on the number of rows read.
        batch_size: Number of texts sent per embedding request.
        model: Embedding model name.
        output_dimension: Requested embedding dimension.

    Returns:
        The total number of rows upserted; 0 when nothing was selected.
    """
    pending_rows = fetch_all(cur, select_sql, limit)

    total_upserted = 0
    # An empty selection produces no batches, leaving the total at 0.
    for group in chunked(pending_rows, batch_size):
        ids = [int(record["id"]) for record in group]
        documents = [str(record["embedding_text"]) for record in group]
        vectors = embed_documents(session, documents, model, output_dimension)
        total_upserted += upsert_embeddings(cur, upsert_sql, ids, vectors, model)
    return total_upserted


def main() -> int:
    """Run the embedding sync from the command line.

    Parses the CLI options, loads the env file, resolves the embedding
    settings from the environment and syncs the selected corpus types inside
    the connection's context manager (presumably committing on success and
    rolling back on error; confirm for the driver in use). A short summary
    is printed at the end.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If both ``--job-only`` and ``--question-only`` are
            given, or if ``COHERE_API_KEY`` is not set.
    """
    args = parse_args()
    if args.job_only and args.question_only:
        raise SystemExit("--job-only and --question-only cannot be used together.")

    load_env_file(Path(args.env_file))
    env = os.environ

    cohere_api_key = env.get("COHERE_API_KEY")
    if not cohere_api_key:
        raise SystemExit("COHERE_API_KEY must be set in environment or env file.")

    model = env.get("APP_CORPUS_EMBEDDING_MODEL", "embed-v4.0")
    output_dimension = int(env.get("APP_CORPUS_EMBEDDING_OUTPUT_DIMENSION", "1024"))

    # The CLI flag wins; the environment default is only read without it.
    batch_size = args.batch_size
    if not batch_size:
        batch_size = int(env.get("APP_CORPUS_EMBEDDING_BATCH_SIZE", "32"))

    # Trailing positional arguments shared by both sync_dataset() calls.
    shared_args = (args.limit, batch_size, model, output_dimension)
    sync_job_postings = not args.question_only
    sync_questions = not args.job_only

    stats = SyncStats()
    session = create_requests_session(cohere_api_key)
    conn = connect()
    try:
        with conn, conn.cursor() as cur:
            if sync_job_postings:
                stats.job_posting_embeddings_upserted = sync_dataset(
                    cur, session, JOB_POSTING_SELECT_SQL, UPSERT_JOB_POSTING_SQL, *shared_args
                )
            if sync_questions:
                stats.question_embeddings_upserted = sync_dataset(
                    cur, session, QUESTION_SELECT_SQL, UPSERT_QUESTION_SQL, *shared_args
                )
    finally:
        session.close()
        conn.close()

    summary_lines = (
        "Embedding sync completed",
        f"jobPostingEmbeddingsUpserted={stats.job_posting_embeddings_upserted}",
        f"questionEmbeddingsUpserted={stats.question_embeddings_upserted}",
        f"embeddingModel={model}",
        f"batchSize={batch_size}",
    )
    for summary_line in summary_lines:
        print(summary_line)
    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,14 @@

public interface DetailClassificationRepository extends JpaRepository<DetailClassification, Long> {
List<DetailClassification> findAllByMiddleClassificationId(Long middleClassificationId);
/**
 * Loads a single {@code DetailClassification} by id and, in the same query,
 * fetch-joins its {@code middleClassification} and that entity's
 * {@code classification}.
 *
 * @param id identifier of the detail classification to load
 * @return the matching entity with both parent associations fetched, or an
 *         empty {@code Optional} when no row has this id
 */
@Query("""
SELECT dc
FROM DetailClassification dc
JOIN FETCH dc.middleClassification mc
JOIN FETCH mc.classification
WHERE dc.id = :id
""")
Optional<DetailClassification> findWithHierarchyById(@Param("id") Long id);
Optional<DetailClassification> findByDetailNameIgnoreCase(String detailName);
long countByDetailNameIgnoreCase(String detailName);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package com.jobdri.jobdri_api.domain.jobposting.controller;

import com.jobdri.jobdri_api.domain.jobposting.dto.request.JobPostingGenerateRequest;
import com.jobdri.jobdri_api.domain.jobposting.dto.response.JobPostingGenerateResponse;
import com.jobdri.jobdri_api.domain.jobposting.service.JobPostingAiService;
import com.jobdri.jobdri_api.domain.user.service.UserService;
import com.jobdri.jobdri_api.global.apiPayload.ApiResponse;
import com.jobdri.jobdri_api.global.security.UserDetailsImpl;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import org.springframework.security.core.annotation.AuthenticationPrincipal;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * REST controller exposing the job posting draft generation endpoint under
 * {@code /api/job-postings}.
 */
@RestController
@RequiredArgsConstructor
@RequestMapping("/api/job-postings")
@Tag(name = "Actual JobPosting", description = "실제 채용 공고 초안 생성 API")
public class ActualJobPostingController {

    private final JobPostingAiService jobPostingAiService;
    private final UserService userService;

    /**
     * Generates a job posting draft for the given request.
     *
     * <p>The authenticated principal is validated first; its result is not
     * used further by this method.
     *
     * @param userDetails the authenticated principal, possibly {@code null}
     * @param request     the validated generation request body
     * @return a success response wrapping the generated draft
     */
    @Operation(summary = "실제 채용 공고 초안 생성", description = "회사 정보와 직무 정보를 바탕으로 AI가 실제 채용 공고 본문 초안을 생성합니다.")
    @PostMapping("/generate")
    public ApiResponse<JobPostingGenerateResponse> generateJobPosting(
            @AuthenticationPrincipal UserDetailsImpl userDetails,
            @Valid @RequestBody JobPostingGenerateRequest request
    ) {
        validateAuthenticatedUser(userDetails);
        JobPostingGenerateResponse draft = jobPostingAiService.generateJobPosting(request);
        return ApiResponse.onSuccess("채용 공고 초안 생성에 성공했습니다.", draft);
    }

    /**
     * Passes the principal's user, or {@code null} when there is no
     * principal, to {@code UserService.validateUser} and returns its result.
     */
    private com.jobdri.jobdri_api.domain.user.entity.User validateAuthenticatedUser(UserDetailsImpl userDetails) {
        com.jobdri.jobdri_api.domain.user.entity.User principalUser =
                userDetails == null ? null : userDetails.getUser();
        return userService.validateUser(principalUser);
    }
}
Loading
Loading