diff --git a/api/relay.go b/api/relay.go index b07204b9..d6af74f4 100644 --- a/api/relay.go +++ b/api/relay.go @@ -2,6 +2,7 @@ package api import ( "context" + "encoding/json" "fmt" "math/big" "strings" @@ -204,9 +205,7 @@ func (app *ApiServer) relay(c *fiber.Ctx) error { return fiber.NewError(fiber.StatusInternalServerError, "failed to handle relay: "+err.Error()) } receipt := transactionToReceipt(msg, wallet) - return c.JSON(map[string]interface{}{ - "receipt": receipt, - }) + return sendRelayResponse(c, receipt) } isUser := false @@ -254,9 +253,21 @@ func (app *ApiServer) relay(c *fiber.Ctx) error { return fiber.NewError(fiber.StatusInternalServerError, "failed to handle relay: "+err.Error()) } receipt := transactionToReceipt(msg, wallet) - return c.JSON(map[string]interface{}{ + return sendRelayResponse(c, receipt) +} + +func sendRelayResponse(c *fiber.Ctx, receipt map[string]interface{}) error { + body, err := json.Marshal(map[string]interface{}{ "receipt": receipt, }) + if err != nil { + return fiber.NewError(fiber.StatusInternalServerError, "failed to encode relay response: "+err.Error()) + } + + c.Set(fiber.HeaderCacheControl, "no-store, no-transform") + c.Set(fiber.HeaderContentLength, fmt.Sprint(len(body))) + c.Set(fiber.HeaderContentType, fiber.MIMEApplicationJSONCharsetUTF8) + return c.Send(body) } func (app *ApiServer) handleRelay(ctx context.Context, logger *zap.Logger, decodedTx *v1.ManageEntityLegacy) (*v1.Transaction, error) { diff --git a/api/relay_test.go b/api/relay_test.go new file mode 100644 index 00000000..1e6f72ff --- /dev/null +++ b/api/relay_test.go @@ -0,0 +1,41 @@ +package api + +import ( + "encoding/json" + "io" + "net/http/httptest" + "testing" + + "github.com/gofiber/fiber/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSendRelayResponsePreventsResponseTransformation(t *testing.T) { + app := fiber.New() + app.Post("/relay-response", func(c *fiber.Ctx) error { + return sendRelayResponse(c, map[string]interface{}{ + "transactionHash": "0xabc", + "status": true, + }) + }) + + req := httptest.NewRequest("POST", "/relay-response", nil) + req.Header.Set("Accept-Encoding", "gzip") + res, err := app.Test(req, -1) + require.NoError(t, err) + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + require.NoError(t, err) + + assert.Equal(t, fiber.StatusOK, res.StatusCode) + assert.Equal(t, "no-store, no-transform", res.Header.Get(fiber.HeaderCacheControl)) + assert.Equal(t, fiber.MIMEApplicationJSONCharsetUTF8, res.Header.Get(fiber.HeaderContentType)) + assert.Equal(t, len(body), int(res.ContentLength)) + + var decoded map[string]map[string]interface{} + require.NoError(t, json.Unmarshal(body, &decoded)) + assert.Equal(t, "0xabc", decoded["receipt"]["transactionHash"]) + assert.Equal(t, true, decoded["receipt"]["status"]) +}