Skip to content

Commit

Permalink
feat: Improve SSE stream and add input validation (main)
Browse files Browse the repository at this point in the history
- Handle potential write errors in SSE stream and gracefully
  close the connection if necessary.
- Add input validation to prevent integer overflow when setting
  the counter value.  Return a 400 error if the value is
  too large.
- Add logging to handle write errors for the SSE stream.
- Fix nosec warnings in main.go and routes.go.
  • Loading branch information
JasonLovesDoggo committed Dec 1, 2024
1 parent deb82af commit 1f70c0b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func main() {
StartTime = time.Now()
// Initialize the Gin router
r := CreateRouter()
srv := &http.Server{
srv := &http.Server{ // #nosec G112 -- Due to the use of SSE endpoints, we cannot close the server early
Addr: ":" + os.Getenv("PORT"),
Handler: r,
}
Expand Down
23 changes: 20 additions & 3 deletions routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"errors"
"fmt"
"io"
"log"
"math"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -59,7 +61,11 @@ func StreamValueView(c *gin.Context) {
// Send initial value
initialVal := Client.Get(context.Background(), dbKey).Val()
if count, err := strconv.Atoi(initialVal); err == nil {
c.Writer.WriteString(fmt.Sprintf("data: {\"value\":%d}\n\n", count))
_, err := c.Writer.WriteString(fmt.Sprintf("data: {\"value\":%d}\n\n", count))
if err != nil {
log.Printf("Error writing to client: %v", err)
return
}
c.Writer.Flush()
}

Expand All @@ -72,7 +78,11 @@ func StreamValueView(c *gin.Context) {
if !ok {
return false
}
c.Writer.WriteString(fmt.Sprintf("data: {\"value\":%d}\n\n", count))
_, err := c.Writer.WriteString(fmt.Sprintf("data: {\"value\":%d}\n\n", count))
if err != nil {
log.Printf("Error writing to client: %v", err)
return false // Stream closed by client or server error
}
c.Writer.Flush()
return true
}
Expand All @@ -94,8 +104,15 @@ func HitView(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get data. Try again later."})
return
}
// check if val is is greater than the max value of an int
if val > math.MaxInt {
c.JSON(http.StatusBadRequest, gin.H{"error": "Value is too large. Max value is " + strconv.Itoa(math.
MaxInt), "message": "If you are seeing this error and have a legitimate use case, please contact me @ abacus@jasoncameron.dev"})
return
}
go func() {
utils.SetStream(dbKey, int(val))
utils.SetStream(dbKey, int(val)) // #nosec G115 -- This is safe as we perform a check (
// see above) to ensure val is within the range of an int.
Client.Expire(context.Background(), dbKey, utils.BaseTTLPeriod)
}()
if c.Query("callback") != "" {
Expand Down

0 comments on commit 1f70c0b

Please sign in to comment.