diff --git a/internal/policy/policy.go b/internal/policy/policy.go index 3ee0e6d..58b853a 100644 --- a/internal/policy/policy.go +++ b/internal/policy/policy.go @@ -438,7 +438,7 @@ func (p *Policy) Filter(client http.Client, input any, scanner Scanner, cd Custo contents := []string{} for _, message := range converted.Messages { - contents = append(contents, message.Content) + contents = append(contents, message.Content.String()) } result, err := p.scan(contents, scanner, cd, log) @@ -460,7 +460,7 @@ func (p *Policy) Filter(client http.Client, input any, scanner Scanner, cd Custo for index, c := range result.Updated { newMessages = append(newMessages, anthropic.Message{ - Content: c, + Content: anthropic.FlexContent{Text: c}, Role: converted.Messages[index].Role, }) } diff --git a/internal/provider/anthropic/anthropic.go b/internal/provider/anthropic/anthropic.go index 27aa6d2..8603e06 100644 --- a/internal/provider/anthropic/anthropic.go +++ b/internal/provider/anthropic/anthropic.go @@ -1,5 +1,10 @@ package anthropic +import ( + "encoding/json" + "strings" +) + type Metadata struct { UserId string `json:"user_id"` } @@ -16,9 +21,75 @@ type CompletionRequest struct { Stream bool `json:"stream,omitempty"` } +type FlexContent struct { + Text string + Raw []interface{} +} + +func (f *FlexContent) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err == nil { + f.Text = s + return nil + } + + var arr []interface{} + if err := json.Unmarshal(data, &arr); err != nil { + return err + } + f.Raw = arr + return nil +} + +func (f *FlexContent) String() string { + if f == nil { + return "" + } + if f.Text != "" { + return f.Text + } + if len(f.Raw) > 0 { + var builder strings.Builder + for _, block := range f.Raw { + builder.WriteString(extractAnthropicContentText(block)) + } + return builder.String() + } + return "" +} + +func extractAnthropicContentText(value any) string { + switch v := value.(type) { + case string: + return v + case []any: + var builder strings.Builder + for _, item := range v { + builder.WriteString(extractAnthropicContentText(item)) + } + return builder.String() + case map[string]any: + if text, ok := v["text"].(string); ok { + return text + } + if content, ok := v["content"].(string); ok { + return content + } + if raw, ok := v["raw"].([]any); ok { + var builder strings.Builder + for _, item := range raw { + builder.WriteString(extractAnthropicContentText(item)) + } + return builder.String() + } + } + + return "" +} + type Message struct { - Content string `json:"content"` - Role string `json:"role"` + Content FlexContent `json:"content"` + Role string `json:"role"` } type MessagesRequest struct { diff --git a/internal/provider/anthropic/cost.go b/internal/provider/anthropic/cost.go index f5ef452..896b92c 100644 --- a/internal/provider/anthropic/cost.go +++ b/internal/provider/anthropic/cost.go @@ -198,7 +198,7 @@ func (ce *CostEstimator) CountMessagesTokens(messages []Message) int { count := 0 for _, message := range messages { - count += ce.tc.Count(message.Content) + anthropicMessageOverhead + count += ce.tc.Count(message.Content.String()) + anthropicMessageOverhead } return count + anthropicMessageOverhead diff --git a/internal/server/web/proxy/middleware.go b/internal/server/web/proxy/middleware.go index 4c304ef..8a16f34 100644 --- a/internal/server/web/proxy/middleware.go +++ b/internal/server/web/proxy/middleware.go @@ -393,6 +393,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag body, err := io.ReadAll(c.Request.Body) if err != nil { logError(logWithCid, "error when reading request body", prod, err) + JSON(c, http.StatusInternalServerError, "[BricksLLM] error reading request body") + c.Abort() return } @@ -415,6 +417,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, cr) if err != nil { logError(logWithCid, "error when unmarshalling anthropic completion request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling anthropic completion request") + c.Abort() return } @@ -440,6 +444,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, cr) if err != nil { logError(logWithCid, "error when unmarshalling bedrock anthropic completion request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling bedrock anthropic completion request") + c.Abort() return } @@ -465,6 +471,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, mr) if err != nil { logError(logWithCid, "error when unmarshalling anthropic messages request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling anthropic messages request") + c.Abort() return } @@ -488,6 +496,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, mr) if err != nil { logError(logWithCid, "error when unmarshalling anthropic messages request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling anthropic messages request") + c.Abort() return } @@ -624,6 +634,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, ccr) if err != nil { logError(logWithCid, "error when unmarshalling vllm chat completions request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling vllm chat completions request") + c.Abort() return } @@ -645,6 +657,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, cr) if err != nil { logError(logWithCid, "error when unmarshalling vllm completions request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling vllm completions request") + c.Abort() return } @@ -666,6 +680,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, ccr) if err != nil { logError(logWithCid, "error when unmarshalling deepinfra chat completions request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling deepinfra chat completions request") + c.Abort() return } @@ -686,6 +702,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, cr) if err != nil { logError(logWithCid, "error when unmarshalling deepinfra completions request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling deepinfra completions request") + c.Abort() return } @@ -706,6 +724,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, er) if err != nil { logError(logWithCid, "error when unmarshalling deepinfra embeddings request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling deepinfra embeddings request") + c.Abort() return } @@ -723,6 +743,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, ccr) if err != nil { logError(logWithCid, "error when unmarshalling azure openai chat completion request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling azure openai chat completion request") + c.Abort() return } @@ -744,6 +766,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, cr) if err != nil { logError(logWithCid, "error when unmarshalling azure openai completions request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling azure openai completions request") + c.Abort() return } @@ -765,6 +789,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, er) if err != nil { logError(logWithCid, "error when unmarshalling azure openai embedding request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling azure openai embedding request") + c.Abort() return } @@ -789,6 +815,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal([]byte(cleaned), ccr) if err != nil { logError(logWithCid, "error when unmarshalling chat completion request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling chat completion request") + c.Abort() return } @@ -812,6 +840,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, responsesReq) if err != nil { logError(logWithCid, "error when unmarshalling openai responses request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling openai responses request") + c.Abort() return } @@ -876,6 +906,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err = json.Unmarshal(body, er) if err != nil { logError(logWithCid, "error when unmarshalling embedding request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling embedding request") + c.Abort() return } @@ -894,6 +926,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err := json.Unmarshal(body, ir) if err != nil { logError(logWithCid, "error when unmarshalling create image request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling create image request") + c.Abort() return } enrichedEvent.Request = ir @@ -913,6 +947,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err := json.Unmarshal(body, ier) if err != nil { logError(logWithCid, "error when unmarshalling edit image request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling edit image request") + c.Abort() return } enrichedEvent.Request = ier @@ -941,6 +977,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err := json.Unmarshal(body, ir) if err != nil { logError(logWithCid, "error when unmarshalling image variations request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling image variations request") + c.Abort() return } enrichedEvent.Request = ir @@ -968,6 +1006,8 @@ func getMiddleware(cpm CustomProvidersManager, rm routeManager, pm PoliciesManag err := json.Unmarshal(body, sr) if err != nil { logError(logWithCid, "error when unmarshalling create speech request", prod, err) + JSON(c, http.StatusBadRequest, "[BricksLLM] error when unmarshalling create speech request") + c.Abort() return }