// Path: Backend/internal/middleware/auth.go package middleware import ( "errors" "strconv" "strings" "time" "boostai-backend/internal/config" "boostai-backend/internal/sqlc" "github.com/gofiber/fiber/v2" "github.com/golang-jwt/jwt/v5" ) const DefaultTokenTTL = 7 * 24 * time.Hour type AuthClaims struct { UserID int64 `json:"user_id"` Role sqlc.UserRole `json:"role"` Email string `json:"email"` jwt.RegisteredClaims } type AuthMiddleware struct { cfg *config.Config } func NewAuthMiddleware(cfg *config.Config) *AuthMiddleware { return &AuthMiddleware{cfg: cfg} } func (m *AuthMiddleware) CreateToken(userID int64, role sqlc.UserRole, email string, ttl time.Duration) (string, error) { if ttl <= 0 { ttl = DefaultTokenTTL } now := time.Now().UTC() claims := AuthClaims{ UserID: userID, Role: role, Email: email, RegisteredClaims: jwt.RegisteredClaims{ Subject: email, IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(ttl)), }, } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) return token.SignedString([]byte(m.cfg.JWTSecret)) } func (m *AuthMiddleware) RequireAuth() fiber.Handler { return func(c *fiber.Ctx) error { claims, err := m.parseClaims(c) if err != nil { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ "error": "unauthorized", "message": "Authentication required", }) } c.Locals("auth.user_id", claims.UserID) c.Locals("auth.role", claims.Role) c.Locals("auth.email", claims.Email) return c.Next() } } func (m *AuthMiddleware) RequireTeacher() fiber.Handler { return func(c *fiber.Ctx) error { if CurrentUserRole(c) != sqlc.UserRoleTeacher { return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ "error": "forbidden", "message": "Teacher access required", }) } return c.Next() } } func (m *AuthMiddleware) RequireTeacherSelf(param string) fiber.Handler { return func(c *fiber.Ctx) error { if CurrentUserRole(c) != sqlc.UserRoleTeacher { return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ "error": "forbidden", "message": "Teacher access required", }) } paramID, err := parsePositiveParam(c, param) if err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ "error": "invalid_request", "message": "Invalid path parameter: " + param, }) } if CurrentUserID(c) != paramID { return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ "error": "forbidden", "message": "You can only access your own teacher resources", }) } return c.Next() } } func (m *AuthMiddleware) RequireStudentSelfOrTeacher(param string) fiber.Handler { return func(c *fiber.Ctx) error { paramID, err := parsePositiveParam(c, param) if err != nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ "error": "invalid_request", "message": "Invalid path parameter: " + param, }) } if CurrentUserRole(c) == sqlc.UserRoleTeacher || CurrentUserID(c) == paramID { return c.Next() } return c.Status(fiber.StatusForbidden).JSON(fiber.Map{ "error": "forbidden", "message": "You can only access your own student resources", }) } } func CurrentUserID(c *fiber.Ctx) int64 { value, ok := c.Locals("auth.user_id").(int64) if !ok { return 0 } return value } func CurrentUserRole(c *fiber.Ctx) sqlc.UserRole { value, ok := c.Locals("auth.role").(sqlc.UserRole) if !ok { return "" } return value } func (m *AuthMiddleware) parseClaims(c *fiber.Ctx) (*AuthClaims, error) { tokenValue := strings.TrimSpace(c.Cookies(m.cfg.SessionCookie)) if tokenValue == "" { authorization := strings.TrimSpace(c.Get("Authorization")) if strings.HasPrefix(strings.ToLower(authorization), "bearer ") { tokenValue = strings.TrimSpace(authorization[7:]) } } if tokenValue == "" { return nil, errors.New("missing token") } parsed, err := jwt.ParseWithClaims(tokenValue, &AuthClaims{}, func(token *jwt.Token) (any, error) { if token.Method != jwt.SigningMethodHS256 { return nil, errors.New("unexpected signing method") } return []byte(m.cfg.JWTSecret), nil }) if err != nil { return nil, err } claims, ok := parsed.Claims.(*AuthClaims) if !ok || !parsed.Valid { return nil, errors.New("invalid token") } return claims, nil } func parsePositiveParam(c *fiber.Ctx, param string) (int64, error) { value := strings.TrimSpace(c.Params(param)) parsed, err := strconv.ParseInt(value, 10, 64) if err != nil || parsed <= 0 { return 0, errors.New("invalid param") } return parsed, nil }