package middleware import ( "crypto/x509" "encoding/hex" "encoding/pem" "errors" "net/url" "strings" "github.com/gin-gonic/gin" "gorm.io/gorm" "smoop-api/internal/models" ) // Common helpers --------------------------------------------------------------- func getHeader(c *gin.Context, k string) string { return strings.TrimSpace(c.GetHeader(k)) } func parseClientCertFromHeader(escaped string) (*x509.Certificate, error) { if escaped == "" { return nil, errors.New("empty client cert header") } raw, err := url.QueryUnescape(escaped) if err != nil { // Some Nginx builds already space-escape; still try raw raw = escaped } block, _ := pem.Decode([]byte(raw)) if block == nil { return nil, errors.New("failed to decode PEM in client cert header") } return x509.ParseCertificate(block.Bytes) } func extractGuidFromSAN(cert *x509.Certificate) (string, bool) { for _, u := range cert.URIs { if strings.EqualFold(u.Scheme, "urn") { // urn:device: => u.Opaque == "device:" parts := strings.SplitN(u.Opaque, ":", 2) if len(parts) == 2 && strings.EqualFold(parts[0], "device") { return parts[1], true } } } return "", false } func normalizeSerialHex(b []byte) string { h := strings.ToUpper(hex.EncodeToString(b)) // remove leading zeros normalization optional, keep as-is return h } func isSerialRevoked(db *gorm.DB, serialHex string) (bool, error) { var cnt int64 if err := db.Model(&models.RevokedSerial{}).Where("serial_hex = ?", serialHex).Count(&cnt).Error; err != nil { return false, err } return cnt > 0, nil } // Guard for /tasks/:guid and /renew/:guid ------------------------------------- // Requires Nginx to forward: // // X-SSL-Client-Verify, X-SSL-Client-Serial, X-SSL-Client-Cert func MTLSGuard(db *gorm.DB) gin.HandlerFunc { return func(c *gin.Context) { if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" { c.AbortWithStatus(401) return } // Parse client cert (for GUID extraction) and serial cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert")) if err != nil { c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"}) return } pathGuid := strings.TrimSpace(c.Param("guid")) certGuid, ok := extractGuidFromSAN(cert) if !ok || certGuid == "" { c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"}) return } if pathGuid == "" || !strings.EqualFold(pathGuid, certGuid) { c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"}) return } serialHex := normalizeSerialHex(cert.SerialNumber.Bytes()) rev, err := isSerialRevoked(db, serialHex) if err != nil { c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"}) return } if rev { c.AbortWithStatus(403) return } // Stash for downstream (device + serial) c.Set("mtlsDeviceGUID", certGuid) c.Set("mtlsSerialHex", serialHex) c.Next() } } // Guard for /records/upload (guid comes as form field instead of path) func MTLSGuardUpload(db *gorm.DB) gin.HandlerFunc { return func(c *gin.Context) { if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" { c.AbortWithStatus(401) return } cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert")) if err != nil { c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"}) return } certGuid, ok := extractGuidFromSAN(cert) if !ok || certGuid == "" { c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"}) return } formGuid := strings.TrimSpace(c.PostForm("guid")) if formGuid == "" { // allow query fallback if you want formGuid = strings.TrimSpace(c.Query("guid")) } if !strings.EqualFold(formGuid, certGuid) { c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"}) return } serialHex := normalizeSerialHex(cert.SerialNumber.Bytes()) rev, err := isSerialRevoked(db, serialHex) if err != nil { c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"}) return } if rev { c.AbortWithStatus(403) return } c.Set("mtlsDeviceGUID", certGuid) c.Set("mtlsSerialHex", serialHex) c.Next() } }