153 lines
4.1 KiB
Go
153 lines
4.1 KiB
Go
package middleware
|
|
|
|
import (
|
|
"crypto/x509"
|
|
"encoding/hex"
|
|
"encoding/pem"
|
|
"errors"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"gorm.io/gorm"
|
|
|
|
"smoop-api/internal/models"
|
|
)
|
|
|
|
// Common helpers ---------------------------------------------------------------
|
|
|
|
func getHeader(c *gin.Context, k string) string {
|
|
return strings.TrimSpace(c.GetHeader(k))
|
|
}
|
|
|
|
// parseClientCertFromHeader decodes a PEM certificate carried in a request
// header and parses it into an *x509.Certificate.
//
// The header value may be percent-encoded in more than one way, so several
// interpretations are tried in order and the first one that yields a parsable
// certificate wins:
//
//  1. url.PathUnescape: decodes %XX sequences and leaves a literal '+' alone.
//     '+' is a valid base64 character, so this keeps the PEM body intact.
//  2. url.QueryUnescape: additionally maps '+' to a space, which covers
//     form-style encoders that write spaces as '+'.
//  3. The value exactly as received, in case it was not escaped at all or
//     contains sequences that are not valid percent-escapes.
//
// It returns an error when escaped is empty, when no interpretation contains a
// PEM block, or, if a PEM block was found but never parsed, the error from
// x509.ParseCertificate.
func parseClientCertFromHeader(escaped string) (*x509.Certificate, error) {
	if escaped == "" {
		return nil, errors.New("empty client cert header")
	}

	candidates := make([]string, 0, 3)
	if s, err := url.PathUnescape(escaped); err == nil {
		candidates = append(candidates, s)
	}
	if s, err := url.QueryUnescape(escaped); err == nil {
		candidates = append(candidates, s)
	}
	// Best effort: a value that is not valid percent-encoding is still tried raw.
	candidates = append(candidates, escaped)

	var parseErr error
	for _, raw := range candidates {
		// Anything following the first PEM block is intentionally ignored.
		block, _ := pem.Decode([]byte(raw))
		if block == nil {
			continue
		}
		cert, err := x509.ParseCertificate(block.Bytes)
		if err == nil {
			return cert, nil
		}
		parseErr = err
	}

	if parseErr != nil {
		return nil, parseErr
	}
	return nil, errors.New("failed to decode PEM in client cert header")
}
|
|
|
|
// extractGuidFromSAN scans the certificate's URI entries for one of the form
// urn:device:<GUID> and returns the <GUID> part.
//
// The "urn" scheme and the "device" prefix are matched case-insensitively.
// The first matching entry wins. The boolean is false when no entry matches;
// note that a match with nothing after the second colon yields ("", true).
func extractGuidFromSAN(cert *x509.Certificate) (string, bool) {
	for _, san := range cert.URIs {
		if !strings.EqualFold(san.Scheme, "urn") {
			continue
		}
		// For urn:device:<GUID> the opaque part is "device:<GUID>"; split it
		// at the first colon only, so the GUID itself may contain colons.
		sep := strings.IndexByte(san.Opaque, ':')
		if sep < 0 {
			continue
		}
		if kind := san.Opaque[:sep]; strings.EqualFold(kind, "device") {
			return san.Opaque[sep+1:], true
		}
	}
	return "", false
}
|
|
|
|
// normalizeSerialHex renders b as an upper-case hexadecimal string, two
// characters per byte. Leading zero bytes are kept as-is rather than stripped.
func normalizeSerialHex(b []byte) string {
	encoded := make([]byte, hex.EncodedLen(len(b)))
	hex.Encode(encoded, b)
	return strings.ToUpper(string(encoded))
}
|
|
|
|
func isSerialRevoked(db *gorm.DB, serialHex string) (bool, error) {
|
|
var cnt int64
|
|
if err := db.Model(&models.RevokedSerial{}).Where("serial_hex = ?", serialHex).Count(&cnt).Error; err != nil {
|
|
return false, err
|
|
}
|
|
return cnt > 0, nil
|
|
}
|
|
|
|
// Guard for /tasks/:guid and /renew/:guid -------------------------------------
|
|
// Requires Nginx to forward:
|
|
//
|
|
// X-SSL-Client-Verify, X-SSL-Client-Serial, X-SSL-Client-Cert
|
|
func MTLSGuard(db *gorm.DB) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" {
|
|
c.AbortWithStatus(401)
|
|
return
|
|
}
|
|
|
|
// Parse client cert (for GUID extraction) and serial
|
|
cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert"))
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"})
|
|
return
|
|
}
|
|
pathGuid := strings.TrimSpace(c.Param("guid"))
|
|
certGuid, ok := extractGuidFromSAN(cert)
|
|
if !ok || certGuid == "" {
|
|
c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"})
|
|
return
|
|
}
|
|
if pathGuid == "" || !strings.EqualFold(pathGuid, certGuid) {
|
|
c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"})
|
|
return
|
|
}
|
|
|
|
serialHex := normalizeSerialHex(cert.SerialNumber.Bytes())
|
|
rev, err := isSerialRevoked(db, serialHex)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"})
|
|
return
|
|
}
|
|
if rev {
|
|
c.AbortWithStatus(403)
|
|
return
|
|
}
|
|
|
|
// Stash for downstream (device + serial)
|
|
c.Set("mtlsDeviceGUID", certGuid)
|
|
c.Set("mtlsSerialHex", serialHex)
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
// Guard for /records/upload (guid comes as form field instead of path)
|
|
func MTLSGuardUpload(db *gorm.DB) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if getHeader(c, "X-SSL-Client-Verify") != "SUCCESS" {
|
|
c.AbortWithStatus(401)
|
|
return
|
|
}
|
|
cert, err := parseClientCertFromHeader(getHeader(c, "X-SSL-Client-Cert"))
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(401, gin.H{"error": "invalid client cert"})
|
|
return
|
|
}
|
|
certGuid, ok := extractGuidFromSAN(cert)
|
|
if !ok || certGuid == "" {
|
|
c.AbortWithStatusJSON(403, gin.H{"error": "guid not present in SAN"})
|
|
return
|
|
}
|
|
formGuid := strings.TrimSpace(c.PostForm("guid"))
|
|
if formGuid == "" {
|
|
// allow query fallback if you want
|
|
formGuid = strings.TrimSpace(c.Query("guid"))
|
|
}
|
|
if !strings.EqualFold(formGuid, certGuid) {
|
|
c.AbortWithStatusJSON(403, gin.H{"error": "guid mismatch"})
|
|
return
|
|
}
|
|
serialHex := normalizeSerialHex(cert.SerialNumber.Bytes())
|
|
rev, err := isSerialRevoked(db, serialHex)
|
|
if err != nil {
|
|
c.AbortWithStatusJSON(500, gin.H{"error": "revocation check failed"})
|
|
return
|
|
}
|
|
if rev {
|
|
c.AbortWithStatus(403)
|
|
return
|
|
}
|
|
c.Set("mtlsDeviceGUID", certGuid)
|
|
c.Set("mtlsSerialHex", serialHex)
|
|
c.Next()
|
|
}
|
|
}
|