package messages

import (
	"errors"
	"sort"
	"strings"
	"time"

	"github.com/gofiber/fiber/v2"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"

	"boostai-backend/internal/database"
	"boostai-backend/internal/handlers/api/shared"
	"boostai-backend/internal/http/params"
	"boostai-backend/internal/http/respond"
	authmw "boostai-backend/internal/middleware"
	"boostai-backend/internal/sqlc"
)

// Handler serves the messaging endpoints. It keeps the database handle (used
// to open transactions) and a sqlc query set bound to the same pool.
type Handler struct {
	db      *database.DB
	queries *sqlc.Queries
}

// recipientResponse is one user the current user may address a message to.
type recipientResponse struct {
	ID             int64   `json:"id"`
	Email          string  `json:"email"`
	Role           string  `json:"role"`
	FullName       string  `json:"full_name"`
	PreferredName  *string `json:"preferred_name"`
	ProfileIconURL *string `json:"profile_icon_url"`
	Headline       *string `json:"headline"`
}

// threadParticipantResponse is a thread member plus their per-thread
// membership timestamps.
type threadParticipantResponse struct {
	ID             int64      `json:"id"`
	Email          string     `json:"email"`
	Role           string     `json:"role"`
	FullName       string     `json:"full_name"`
	PreferredName  *string    `json:"preferred_name"`
	ProfileIconURL *string    `json:"profile_icon_url"`
	Headline       *string    `json:"headline"`
	JoinedAt       *time.Time `json:"joined_at,omitempty"`
	LastReadAt     *time.Time `json:"last_read_at,omitempty"`
	ArchivedAt     *time.Time `json:"archived_at,omitempty"`
}

// messageSenderResponse describes the author of a message.
type messageSenderResponse struct {
	ID             int64   `json:"id"`
	Email          string  `json:"email"`
	Role           string  `json:"role"`
	FullName       string  `json:"full_name"`
	PreferredName  *string `json:"preferred_name"`
	ProfileIconURL *string `json:"profile_icon_url"`
	Headline       *string `json:"headline"`
}

// messageResponse is a single message within a thread. Mine is true when the
// requesting user sent it.
type messageResponse struct {
	ID        int64                 `json:"id"`
	ThreadID  int64                 `json:"thread_id"`
	Body      string                `json:"body"`
	CreatedAt *time.Time            `json:"created_at,omitempty"`
	UpdatedAt *time.Time            `json:"updated_at,omitempty"`
	Mine      bool                  `json:"mine"`
	Sender    messageSenderResponse `json:"sender"`
}

// messageThreadSummaryResponse is one row of the thread list, including a
// preview of the latest message.
type messageThreadSummaryResponse struct {
	ID                   int64                       `json:"id"`
	Subject              string                      `json:"subject"`
	CreatedByUserID      int64                       `json:"created_by_user_id"`
	CreatedAt            *time.Time                  `json:"created_at,omitempty"`
	UpdatedAt            *time.Time                  `json:"updated_at,omitempty"`
	UnreadCount          int64                       `json:"unread_count"`
	LastMessageID        int64                       `json:"last_message_id"`
	LastMessageBody      *string                     `json:"last_message_body"`
	LastMessageCreatedAt *time.Time                  `json:"last_message_created_at,omitempty"`
	LastMessageSender    *messageSenderResponse      `json:"last_message_sender,omitempty"`
	Participants         []threadParticipantResponse `json:"participants"`
}

// messageThreadDetailResponse is a full thread: metadata, participants and
// every message visible to the requesting user.
type messageThreadDetailResponse struct {
	ID              int64                       `json:"id"`
	Subject         string                      `json:"subject"`
	CreatedByUserID int64                       `json:"created_by_user_id"`
	CreatedAt       *time.Time                  `json:"created_at,omitempty"`
	UpdatedAt       *time.Time                  `json:"updated_at,omitempty"`
	UnreadCount     int64                       `json:"unread_count"`
	LastReadAt      *time.Time                  `json:"last_read_at,omitempty"`
	Participants    []threadParticipantResponse `json:"participants"`
	Messages        []messageResponse           `json:"messages"`
}

// createThreadRequest is the body of CreateThread. Body is optional; when it
// is blank the thread is created without an initial message.
type createThreadRequest struct {
	Subject      string  `json:"subject"`
	RecipientIDs []int64 `json:"recipient_ids"`
	Body         string  `json:"body"`
}

type createThreadResponse struct {
	ThreadID int64 `json:"thread_id"`
}

type createThreadMessageRequest struct {
	Body string `json:"body"`
}

type updateThreadRequest struct {
	Subject string `json:"subject"`
}

type updateThreadMessageRequest struct {
	Body string `json:"body"`
}

// NewHandler builds a Handler whose queries run against db.Pool.
func NewHandler(db *database.DB) *Handler {
	return &Handler{db: db, queries: sqlc.New(db.Pool)}
}

// threadLookupError converts an error from a thread-scoped query into an HTTP
// response: pgx.ErrNoRows becomes 404 "Message thread not found", anything
// else is reported as a database error.
func threadLookupError(c *fiber.Ctx, err error) error {
	if errors.Is(err, pgx.ErrNoRows) {
		return respond.Error(c, fiber.StatusNotFound, "not_found", "Message thread not found")
	}
	return respond.DatabaseError(c, err)
}

// messageLookupError converts an error from a message-scoped query into an
// HTTP response: pgx.ErrNoRows becomes 404 "Message not found", anything else
// is reported as a database error.
func messageLookupError(c *fiber.Ctx, err error) error {
	if errors.Is(err, pgx.ErrNoRows) {
		return respond.Error(c, fiber.StatusNotFound, "not_found", "Message not found")
	}
	return respond.DatabaseError(c, err)
}

// ListRecipients responds with the users returned by
// ListMessageRecipientsForUser for the current user, wrapped in a ListResponse.
func (h *Handler) ListRecipients(c *fiber.Ctx) error {
	currentUserID := authmw.CurrentUserID(c)
	ctx, cancel := shared.WithTimeout()
	defer cancel()

	recipients, err := h.queries.ListMessageRecipientsForUser(ctx, currentUserID)
	if err != nil {
		return respond.DatabaseError(c, err)
	}

	items := make([]recipientResponse, 0, len(recipients))
	for _, recipient := range recipients {
		items = append(items, mapRecipient(recipient))
	}
	return c.JSON(shared.ListResponse[recipientResponse]{Data: items})
}

// ListThreads responds with the current user's thread summaries. Participants
// are fetched in a single query and grouped by thread ID in memory rather
// than queried per thread.
func (h *Handler) ListThreads(c *fiber.Ctx) error {
	currentUserID := authmw.CurrentUserID(c)
	ctx, cancel := shared.WithTimeout()
	defer cancel()

	threads, err := h.queries.ListMessageThreadsForUser(ctx, currentUserID)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	participants, err := h.queries.ListMessageThreadParticipantsForUser(ctx, currentUserID)
	if err != nil {
		return respond.DatabaseError(c, err)
	}

	participantsByThread := make(map[int64][]threadParticipantResponse)
	for _, participant := range participants {
		participantsByThread[participant.ThreadID] = append(participantsByThread[participant.ThreadID], mapThreadParticipant(participant))
	}

	items := make([]messageThreadSummaryResponse, 0, len(threads))
	for _, thread := range threads {
		items = append(items, mapThreadSummary(thread, participantsByThread[thread.ThreadID]))
	}
	return c.JSON(shared.ListResponse[messageThreadSummaryResponse]{Data: items})
}

// GetThread responds with the full thread identified by the :threadId path
// parameter, or 404 when the thread lookup returns no rows for this user.
func (h *Handler) GetThread(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}

	thread, err := h.loadThread(threadID, authmw.CurrentUserID(c))
	if err != nil {
		return threadLookupError(c, err)
	}
	return c.JSON(thread)
}

// CreateThread starts a new thread with the given subject and recipients,
// optionally posting an initial message, and responds 201 with the new
// thread ID. All writes happen in a single transaction.
func (h *Handler) CreateThread(c *fiber.Ctx) error {
	currentUserID := authmw.CurrentUserID(c)

	var req createThreadRequest
	if err := c.BodyParser(&req); err != nil {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
	}

	subject := strings.TrimSpace(req.Subject)
	body := strings.TrimSpace(req.Body)
	if subject == "" {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "subject is required")
	}

	recipientIDs := normalizeRecipientIDs(currentUserID, req.RecipientIDs)
	if len(recipientIDs) == 0 {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "At least one valid recipient is required")
	}

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	tx, err := h.db.Pool.Begin(ctx)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	// pgx documents Rollback after a successful Commit as safe (it returns
	// ErrTxClosed), so this only undoes work on the early-return paths.
	defer func() { _ = tx.Rollback(ctx) }()
	queries := h.queries.WithTx(tx)

	// Every recipient must be resolvable for the sender; a single unknown ID
	// rejects the whole request before the thread row is created.
	for _, recipientID := range recipientIDs {
		if _, err := queries.GetMessageRecipientByIDForUser(ctx, sqlc.GetMessageRecipientByIDForUserParams{ID: currentUserID, ID_2: recipientID}); err != nil {
			if errors.Is(err, pgx.ErrNoRows) {
				return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "One or more recipients are not available for messaging")
			}
			return respond.DatabaseError(c, err)
		}
	}

	thread, err := queries.CreateMessageThread(ctx, sqlc.CreateMessageThreadParams{
		CreatedByUserID: currentUserID,
		Subject:         subject,
	})
	if err != nil {
		return respond.DatabaseError(c, err)
	}

	// The creator's last-read marker is set to the initial message's
	// timestamp so their own opening message does not show as unread. With
	// no initial message it stays NULL.
	creatorReadAt := pgtype.Timestamptz{}
	if body != "" {
		message, err := queries.CreateThreadMessage(ctx, sqlc.CreateThreadMessageParams{
			ThreadID:     thread.ID,
			SenderUserID: currentUserID,
			Body:         body,
		})
		if err != nil {
			return respond.DatabaseError(c, err)
		}
		if message.CreatedAt.Valid {
			creatorReadAt = pgtype.Timestamptz{Time: message.CreatedAt.Time.UTC(), Valid: true}
		}
	}

	if err := queries.AddMessageThreadParticipant(ctx, sqlc.AddMessageThreadParticipantParams{
		ThreadID:   thread.ID,
		UserID:     currentUserID,
		LastReadAt: creatorReadAt,
	}); err != nil {
		return respond.DatabaseError(c, err)
	}
	for _, recipientID := range recipientIDs {
		if err := queries.AddMessageThreadParticipant(ctx, sqlc.AddMessageThreadParticipantParams{
			ThreadID: thread.ID,
			UserID:   recipientID,
		}); err != nil {
			return respond.DatabaseError(c, err)
		}
	}

	if err := queries.TouchMessageThread(ctx, thread.ID); err != nil {
		return respond.DatabaseError(c, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return respond.DatabaseError(c, err)
	}
	return c.Status(fiber.StatusCreated).JSON(createThreadResponse{ThreadID: thread.ID})
}

// CreateThreadMessage posts a message to the :threadId thread, touches the
// thread, and marks it read for the sender, all in one transaction. It
// responds 201 on success and 404 when the thread lookup returns no rows.
func (h *Handler) CreateThreadMessage(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}
	currentUserID := authmw.CurrentUserID(c)

	var req createThreadMessageRequest
	if err := c.BodyParser(&req); err != nil {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
	}
	body := strings.TrimSpace(req.Body)
	if body == "" {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "body is required")
	}

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	tx, err := h.db.Pool.Begin(ctx)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	// Safe after Commit; only rolls back on early returns.
	defer func() { _ = tx.Rollback(ctx) }()
	queries := h.queries.WithTx(tx)

	// NOTE(review): GetMessageThreadForUser presumably yields no rows when the
	// user is not a participant, which makes this the access check — confirm
	// against the SQL.
	if _, err := queries.GetMessageThreadForUser(ctx, sqlc.GetMessageThreadForUserParams{ID: threadID, SenderUserID: currentUserID}); err != nil {
		return threadLookupError(c, err)
	}

	if _, err := queries.CreateThreadMessage(ctx, sqlc.CreateThreadMessageParams{
		ThreadID:     threadID,
		SenderUserID: currentUserID,
		Body:         body,
	}); err != nil {
		return respond.DatabaseError(c, err)
	}
	if err := queries.TouchMessageThread(ctx, threadID); err != nil {
		return respond.DatabaseError(c, err)
	}
	if _, err := queries.MarkMessageThreadRead(ctx, sqlc.MarkMessageThreadReadParams{ThreadID: threadID, UserID: currentUserID}); err != nil {
		return respond.DatabaseError(c, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return respond.DatabaseError(c, err)
	}
	return c.Status(fiber.StatusCreated).JSON(fiber.Map{"status": "ok"})
}

// UpdateThread renames the :threadId thread. Only the thread's creator may do
// so; other users who can see the thread get 403.
func (h *Handler) UpdateThread(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}
	currentUserID := authmw.CurrentUserID(c)

	var req updateThreadRequest
	if err := c.BodyParser(&req); err != nil {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
	}
	subject := strings.TrimSpace(req.Subject)
	if subject == "" {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "subject is required")
	}

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	thread, err := h.queries.GetMessageThreadForUser(ctx, sqlc.GetMessageThreadForUserParams{ID: threadID, SenderUserID: currentUserID})
	if err != nil {
		return threadLookupError(c, err)
	}
	if thread.CreatedByUserID != currentUserID {
		return respond.Error(c, fiber.StatusForbidden, "forbidden", "Only the conversation starter can edit the thread title")
	}

	if _, err := h.queries.UpdateMessageThreadSubject(ctx, sqlc.UpdateMessageThreadSubjectParams{
		ThreadID: threadID,
		Subject:  subject,
	}); err != nil {
		return threadLookupError(c, err)
	}
	return c.JSON(fiber.Map{"status": "ok"})
}

// DeleteThread deletes the :threadId thread. Only the thread's creator may do
// so; other users who can see the thread get 403.
func (h *Handler) DeleteThread(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}
	currentUserID := authmw.CurrentUserID(c)

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	thread, err := h.queries.GetMessageThreadForUser(ctx, sqlc.GetMessageThreadForUserParams{ID: threadID, SenderUserID: currentUserID})
	if err != nil {
		return threadLookupError(c, err)
	}
	if thread.CreatedByUserID != currentUserID {
		return respond.Error(c, fiber.StatusForbidden, "forbidden", "Only the conversation starter can delete this conversation")
	}

	if _, err := h.queries.DeleteMessageThread(ctx, threadID); err != nil {
		return threadLookupError(c, err)
	}
	return c.JSON(fiber.Map{"status": "ok"})
}

// UpdateThreadMessage replaces the body of :messageId in :threadId and
// touches the thread, in one transaction. A no-rows result from the update
// becomes 404.
func (h *Handler) UpdateThreadMessage(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}
	messageID, err := params.Int64PathParam(c, "messageId")
	if err != nil {
		return err
	}
	currentUserID := authmw.CurrentUserID(c)

	var req updateThreadMessageRequest
	if err := c.BodyParser(&req); err != nil {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
	}
	body := strings.TrimSpace(req.Body)
	if body == "" {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "body is required")
	}

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	tx, err := h.db.Pool.Begin(ctx)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	// Safe after Commit; only rolls back on early returns.
	defer func() { _ = tx.Rollback(ctx) }()
	queries := h.queries.WithTx(tx)

	// NOTE(review): UserID presumably restricts the update to the message's
	// own sender — confirm in UpdateThreadMessageBody's SQL.
	if _, err := queries.UpdateThreadMessageBody(ctx, sqlc.UpdateThreadMessageBodyParams{
		Body:      body,
		MessageID: messageID,
		ThreadID:  threadID,
		UserID:    currentUserID,
	}); err != nil {
		return messageLookupError(c, err)
	}
	if err := queries.TouchMessageThread(ctx, threadID); err != nil {
		return respond.DatabaseError(c, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return respond.DatabaseError(c, err)
	}
	return c.JSON(fiber.Map{"status": "ok"})
}

// DeleteThreadMessage deletes :messageId from :threadId and touches the
// thread, in one transaction. A no-rows result from the delete becomes 404.
func (h *Handler) DeleteThreadMessage(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}
	messageID, err := params.Int64PathParam(c, "messageId")
	if err != nil {
		return err
	}
	currentUserID := authmw.CurrentUserID(c)

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	tx, err := h.db.Pool.Begin(ctx)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	// Safe after Commit; only rolls back on early returns.
	defer func() { _ = tx.Rollback(ctx) }()
	queries := h.queries.WithTx(tx)

	if _, err := queries.DeleteThreadMessage(ctx, sqlc.DeleteThreadMessageParams{
		MessageID: messageID,
		ThreadID:  threadID,
		UserID:    currentUserID,
	}); err != nil {
		return messageLookupError(c, err)
	}
	if err := queries.TouchMessageThread(ctx, threadID); err != nil {
		return respond.DatabaseError(c, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return respond.DatabaseError(c, err)
	}
	return c.JSON(fiber.Map{"status": "ok"})
}

// MarkThreadRead records that the current user has read the :threadId
// thread. A no-rows result becomes 404.
func (h *Handler) MarkThreadRead(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	if _, err := h.queries.MarkMessageThreadRead(ctx, sqlc.MarkMessageThreadReadParams{ThreadID: threadID, UserID: authmw.CurrentUserID(c)}); err != nil {
		return threadLookupError(c, err)
	}
	return c.JSON(fiber.Map{"status": "ok"})
}

// loadThread fetches a thread, its participants and its messages for
// currentUserID and assembles the detail response. Errors from any query are
// returned unchanged, so pgx.ErrNoRows from the thread lookup reaches the
// caller as-is. The three reads use one timeout but are not wrapped in a
// transaction.
func (h *Handler) loadThread(threadID, currentUserID int64) (messageThreadDetailResponse, error) {
	queryCtx, cancel := shared.WithTimeout()
	defer cancel()

	thread, err := h.queries.GetMessageThreadForUser(queryCtx, sqlc.GetMessageThreadForUserParams{ID: threadID, SenderUserID: currentUserID})
	if err != nil {
		return messageThreadDetailResponse{}, err
	}
	participants, err := h.queries.ListParticipantsForThreadForUser(queryCtx, sqlc.ListParticipantsForThreadForUserParams{ThreadID: threadID, UserID: currentUserID})
	if err != nil {
		return messageThreadDetailResponse{}, err
	}
	messages, err := h.queries.ListMessagesForThreadForUser(queryCtx, sqlc.ListMessagesForThreadForUserParams{ThreadID: threadID, UserID: currentUserID})
	if err != nil {
		return messageThreadDetailResponse{}, err
	}

	participantItems := make([]threadParticipantResponse, 0, len(participants))
	for _, participant := range participants {
		participantItems = append(participantItems, mapThreadParticipantByThread(participant))
	}
	messageItems := make([]messageResponse, 0, len(messages))
	for _, message := range messages {
		messageItems = append(messageItems, mapThreadMessage(message, currentUserID))
	}

	return messageThreadDetailResponse{
		ID:              thread.ID,
		Subject:         thread.Subject,
		CreatedByUserID: thread.CreatedByUserID,
		CreatedAt:       shared.TimePointer(thread.CreatedAt),
		UpdatedAt:       shared.TimePointer(thread.UpdatedAt),
		UnreadCount:     thread.UnreadCount,
		LastReadAt:      shared.TimePointer(thread.LastReadAt),
		Participants:    participantItems,
		Messages:        messageItems,
	}, nil
}

// mapRecipient converts a recipient-list row into its API shape.
func mapRecipient(row sqlc.ListMessageRecipientsForUserRow) recipientResponse {
	return recipientResponse{
		ID:             row.UserID,
		Email:          row.UserEmail,
		Role:           string(row.UserRole),
		FullName:       row.UserFullName,
		PreferredName:  shared.TextPointer(row.PreferredName),
		ProfileIconURL: shared.TextPointer(row.ProfileIconUrl),
		Headline:       shared.TextPointer(row.Headline),
	}
}

// mapRecipientByID converts a single-recipient lookup row into its API shape.
// NOTE(review): not referenced in this file; remove it if no other file in
// the package uses it.
func mapRecipientByID(row sqlc.GetMessageRecipientByIDForUserRow) recipientResponse {
	return recipientResponse{
		ID:             row.UserID,
		Email:          row.UserEmail,
		Role:           string(row.UserRole),
		FullName:       row.UserFullName,
		PreferredName:  shared.TextPointer(row.PreferredName),
		ProfileIconURL: shared.TextPointer(row.ProfileIconUrl),
		Headline:       shared.TextPointer(row.Headline),
	}
}

// mapThreadParticipant converts a participant row from the thread-list query.
func mapThreadParticipant(row sqlc.ListMessageThreadParticipantsForUserRow) threadParticipantResponse {
	return threadParticipantResponse{
		ID:             row.UserID,
		Email:          row.UserEmail,
		Role:           string(row.UserRole),
		FullName:       row.UserFullName,
		PreferredName:  shared.TextPointer(row.PreferredName),
		ProfileIconURL: shared.TextPointer(row.ProfileIconUrl),
		Headline:       shared.TextPointer(row.Headline),
		JoinedAt:       shared.TimePointer(row.JoinedAt),
		LastReadAt:     shared.TimePointer(row.LastReadAt),
		ArchivedAt:     shared.TimePointer(row.ArchivedAt),
	}
}

// mapThreadParticipantByThread converts a participant row from the
// single-thread query.
func mapThreadParticipantByThread(row sqlc.ListParticipantsForThreadForUserRow) threadParticipantResponse {
	return threadParticipantResponse{
		ID:             row.UserID,
		Email:          row.UserEmail,
		Role:           string(row.UserRole),
		FullName:       row.UserFullName,
		PreferredName:  shared.TextPointer(row.PreferredName),
		ProfileIconURL: shared.TextPointer(row.ProfileIconUrl),
		Headline:       shared.TextPointer(row.Headline),
		JoinedAt:       shared.TimePointer(row.JoinedAt),
		LastReadAt:     shared.TimePointer(row.LastReadAt),
		ArchivedAt:     shared.TimePointer(row.ArchivedAt),
	}
}

// mapThreadSummary converts a thread-list row plus its participants into a
// summary. A nil participants slice is emitted as [] rather than null, which
// matches the detail endpoint.
func mapThreadSummary(row sqlc.ListMessageThreadsForUserRow, participants []threadParticipantResponse) messageThreadSummaryResponse {
	if participants == nil {
		participants = []threadParticipantResponse{}
	}

	response := messageThreadSummaryResponse{
		ID:                   row.ThreadID,
		Subject:              row.Subject,
		CreatedByUserID:      row.CreatedByUserID,
		CreatedAt:            shared.TimePointer(row.ThreadCreatedAt),
		UpdatedAt:            shared.TimePointer(row.ThreadUpdatedAt),
		UnreadCount:          row.UnreadCount,
		LastMessageID:        row.LastMessageID,
		LastMessageBody:      stringPointerOrNil(row.LastMessageBody),
		LastMessageCreatedAt: shared.TimePointer(row.LastMessageCreatedAt),
		Participants:         participants,
	}

	// A positive LastMessageID means the thread has at least one message. The
	// summary row carries no sender email, role or headline, so those are
	// left empty/nil.
	if row.LastMessageID > 0 {
		response.LastMessageSender = &messageSenderResponse{
			ID:             row.LastMessageSenderUserID,
			Email:          "",
			Role:           "",
			FullName:       valueOrEmpty(row.LastMessageSenderFullName),
			PreferredName:  shared.TextPointer(row.LastMessageSenderPreferredName),
			ProfileIconURL: shared.TextPointer(row.LastMessageSenderProfileIconUrl),
		}
	}
	return response
}

// mapThreadMessage converts a message row. Mine is set when currentUserID sent
// the message.
func mapThreadMessage(row sqlc.ListMessagesForThreadForUserRow, currentUserID int64) messageResponse {
	return messageResponse{
		ID:        row.ID,
		ThreadID:  row.ThreadID,
		Body:      row.Body,
		CreatedAt: shared.TimePointer(row.CreatedAt),
		UpdatedAt: shared.TimePointer(row.UpdatedAt),
		Mine:      row.SenderUserID == currentUserID,
		Sender: messageSenderResponse{
			ID:             row.SenderUserID,
			Email:          row.SenderEmail,
			Role:           string(row.SenderRole),
			FullName:       row.SenderFullName,
			PreferredName:  shared.TextPointer(row.SenderPreferredName),
			ProfileIconURL: shared.TextPointer(row.SenderProfileIconUrl),
			Headline:       shared.TextPointer(row.SenderHeadline),
		},
	}
}

// normalizeRecipientIDs drops non-positive IDs, the sender's own ID and
// duplicates, then returns the remainder in ascending order. The result is
// never nil.
func normalizeRecipientIDs(currentUserID int64, values []int64) []int64 {
	seen := make(map[int64]struct{}, len(values))
	normalized := make([]int64, 0, len(values))
	for _, value := range values {
		if value <= 0 || value == currentUserID {
			continue
		}
		if _, ok := seen[value]; ok {
			continue
		}
		seen[value] = struct{}{}
		normalized = append(normalized, value)
	}
	sort.Slice(normalized, func(i, j int) bool { return normalized[i] < normalized[j] })
	return normalized
}

// valueOrEmpty returns the text of a nullable column, or "" when it is NULL.
func valueOrEmpty(value pgtype.Text) string {
	if !value.Valid {
		return ""
	}
	return value.String
}

// stringPointerOrNil returns nil for an empty or whitespace-only string and a
// pointer to value otherwise. value is already the callee's own copy, so
// returning its address does not alias the caller's data.
func stringPointerOrNil(value string) *string {
	if strings.TrimSpace(value) == "" {
		return nil
	}
	return &value
}