Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions internal/policy/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ func (p *Policy) Filter(client http.Client, input any, scanner Scanner, cd Custo

contents := []string{}
for _, message := range converted.Messages {
contents = append(contents, message.Content)
contents = append(contents, message.Content.String())
}

result, err := p.scan(contents, scanner, cd, log)
Expand All @@ -460,7 +460,7 @@ func (p *Policy) Filter(client http.Client, input any, scanner Scanner, cd Custo

for index, c := range result.Updated {
newMessages = append(newMessages, anthropic.Message{
Content: c,
Content: anthropic.FlexContent{Text: c},
Role: converted.Messages[index].Role,
})
}
Expand Down
75 changes: 73 additions & 2 deletions internal/provider/anthropic/anthropic.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
package anthropic

import (
"encoding/json"
"strings"
)

type Metadata struct {
UserId string `json:"user_id"`
}
Expand All @@ -16,9 +21,75 @@ type CompletionRequest struct {
Stream bool `json:"stream,omitempty"`
}

type FlexContent struct {
Text string
Raw []interface{}
}

func (f *FlexContent) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err == nil {
f.Text = s
return nil
}

var arr []interface{}
if err := json.Unmarshal(data, &arr); err != nil {
return err
}
f.Raw = arr
return nil
}

func (f *FlexContent) String() string {
if f == nil {
return ""
}
if f.Text != "" {
return f.Text
}
if len(f.Raw) > 0 {
var builder strings.Builder
for _, block := range f.Raw {
builder.WriteString(extractAnthropicContentText(block))
}
return builder.String()
}
Comment thread
sergei-bronnikov marked this conversation as resolved.
return ""
}

func extractAnthropicContentText(value any) string {
switch v := value.(type) {
case string:
return v
case []any:
var builder strings.Builder
for _, item := range v {
builder.WriteString(extractAnthropicContentText(item))
}
return builder.String()
case map[string]any:
if text, ok := v["text"].(string); ok {
return text
}
if content, ok := v["content"].(string); ok {
return content
}
if raw, ok := v["raw"].([]any); ok {
var builder strings.Builder
for _, item := range raw {
builder.WriteString(extractAnthropicContentText(item))
}
return builder.String()
}
}

return ""
}

type Message struct {
Content string `json:"content"`
Role string `json:"role"`
Content FlexContent `json:"content"`
Role string `json:"role"`
}

type MessagesRequest struct {
Expand Down
2 changes: 1 addition & 1 deletion internal/provider/anthropic/cost.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func (ce *CostEstimator) CountMessagesTokens(messages []Message) int {
count := 0

for _, message := range messages {
count += ce.tc.Count(message.Content) + anthropicMessageOverhead
count += ce.tc.Count(message.Content.String()) + anthropicMessageOverhead
}

return count + anthropicMessageOverhead
Expand Down
40 changes: 40 additions & 0 deletions internal/server/web/proxy/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
body, err := io.ReadAll(c.Request.Body)
if err != nil {
logError(logWithCid, "error when reading request body", prod, err)
JSON(c, http.StatusInternalServerError, "[BricksLLM] error reading request body")
c.Abort()
return
}

Expand All @@ -415,6 +417,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, cr)
if err != nil {
logError(logWithCid, "error when unmarshalling anthropic completion request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling anthropic completion request")
c.Abort()
return
}

Expand All @@ -440,6 +444,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, cr)
if err != nil {
logError(logWithCid, "error when unmarshalling bedrock anthropic completion request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling bedrock anthropic completion request")
c.Abort()
return
}

Expand All @@ -465,6 +471,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, mr)
if err != nil {
logError(logWithCid, "error when unmarshalling anthropic messages request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling anthropic messages request")
c.Abort()
return
}

Expand All @@ -488,6 +496,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, mr)
if err != nil {
logError(logWithCid, "error when unmarshalling anthropic messages request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling anthropic messages request")
c.Abort()
return
}

Expand Down Expand Up @@ -624,6 +634,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, ccr)
if err != nil {
logError(logWithCid, "error when unmarshalling vllm chat completions request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling vllm chat completions request")
c.Abort()
return
}

Expand All @@ -645,6 +657,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, cr)
if err != nil {
logError(logWithCid, "error when unmarshalling vllm completions request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling vllm completions request")
c.Abort()
return
}

Expand All @@ -666,6 +680,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, ccr)
if err != nil {
logError(logWithCid, "error when unmarshalling deepinfra chat completions request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling deepinfra chat completions request")
c.Abort()
return
}

Expand All @@ -686,6 +702,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, cr)
if err != nil {
logError(logWithCid, "error when unmarshalling deepinfra completions request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling deepinfra completions request")
c.Abort()
return
}

Expand All @@ -706,6 +724,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, er)
if err != nil {
logError(logWithCid, "error when unmarshalling deepinfra embeddings request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling deepinfra embeddings request")
c.Abort()
return
}

Expand All @@ -723,6 +743,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, ccr)
if err != nil {
logError(logWithCid, "error when unmarshalling azure openai chat completion request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling azure openai chat completion request")
c.Abort()
return
}

Expand All @@ -744,6 +766,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, cr)
if err != nil {
logError(logWithCid, "error when unmarshalling azure openai completions request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling azure openai completions request")
c.Abort()
return
}

Expand All @@ -765,6 +789,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, er)
if err != nil {
logError(logWithCid, "error when unmarshalling azure openai embedding request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling azure openai embedding request")
c.Abort()
return
}

Expand All @@ -789,6 +815,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal([]byte(cleaned), ccr)
if err != nil {
logError(logWithCid, "error when unmarshalling chat completion request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling chat completion request")
c.Abort()
return
}

Expand All @@ -812,6 +840,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, responsesReq)
if err != nil {
logError(logWithCid, "error when unmarshalling openai responses request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling openai responses request")
c.Abort()
return
}

Expand Down Expand Up @@ -876,6 +906,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err = json.Unmarshal(body, er)
if err != nil {
logError(logWithCid, "error when unmarshalling embedding request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling embedding request")
c.Abort()
return
}

Expand All @@ -894,6 +926,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err := json.Unmarshal(body, ir)
if err != nil {
logError(logWithCid, "error when unmarshalling create image request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling create image request")
c.Abort()
return
}
enrichedEvent.Request = ir
Expand All @@ -913,6 +947,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err := json.Unmarshal(body, ier)
if err != nil {
logError(logWithCid, "error when unmarshalling edit image request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling edit image request")
c.Abort()
return
}
enrichedEvent.Request = ier
Expand Down Expand Up @@ -941,6 +977,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err := json.Unmarshal(body, ir)
if err != nil {
logError(logWithCid, "error when unmarshalling image variations request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling image variations request")
c.Abort()
return
}
enrichedEvent.Request = ir
Expand Down Expand Up @@ -968,6 +1006,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag
err := json.Unmarshal(body, sr)
if err != nil {
logError(logWithCid, "error when unmarshalling create speech request", prod, err)
JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling create speech request")
c.Abort()
return
}

Expand Down
Loading