Files
BoostAI/Backend/internal/handlers/api/messages/handler.go
2026-05-25 17:05:06 +01:00

709 lines
22 KiB
Go

package messages
import (
"boostai-backend/internal/database"
"boostai-backend/internal/handlers/api/shared"
"boostai-backend/internal/http/params"
"boostai-backend/internal/http/respond"
authmw "boostai-backend/internal/middleware"
"boostai-backend/internal/sqlc"
"errors"
"sort"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
)
// Handler groups the HTTP handlers of the messages API together with the
// database handle and the sqlc query set they run against.
type Handler struct {
	// db is kept for its Pool, which the handlers use to begin transactions.
	db *database.DB
	// queries is the sqlc query set; NewHandler builds it on db.Pool.
	queries *sqlc.Queries
}
// recipientResponse is the JSON shape of one entry in the list returned by
// ListRecipients.
type recipientResponse struct {
	ID int64 `json:"id"`
	Email string `json:"email"`
	Role string `json:"role"`
	FullName string `json:"full_name"`
	// The three pointer fields below have no omitempty, so a nil pointer is
	// encoded as JSON null rather than being left out.
	PreferredName *string `json:"preferred_name"`
	ProfileIconURL *string `json:"profile_icon_url"`
	Headline *string `json:"headline"`
}
// threadParticipantResponse is the JSON shape of one thread participant. It is
// used by both the thread summary and the thread detail responses.
type threadParticipantResponse struct {
	ID int64 `json:"id"`
	Email string `json:"email"`
	Role string `json:"role"`
	FullName string `json:"full_name"`
	// Nil profile fields are encoded as JSON null (no omitempty).
	PreferredName *string `json:"preferred_name"`
	ProfileIconURL *string `json:"profile_icon_url"`
	Headline *string `json:"headline"`
	// Nil timestamps are omitted from the output (omitempty).
	JoinedAt *time.Time `json:"joined_at,omitempty"`
	LastReadAt *time.Time `json:"last_read_at,omitempty"`
	ArchivedAt *time.Time `json:"archived_at,omitempty"`
}
// messageSenderResponse is the JSON shape of a message author. It is embedded
// in messageResponse and referenced by the thread summary as the sender of the
// last message; in the summary mapThreadSummary fills Email and Role with ""
// and leaves Headline nil.
type messageSenderResponse struct {
	ID int64 `json:"id"`
	Email string `json:"email"`
	Role string `json:"role"`
	FullName string `json:"full_name"`
	// Nil profile fields are encoded as JSON null (no omitempty).
	PreferredName *string `json:"preferred_name"`
	ProfileIconURL *string `json:"profile_icon_url"`
	Headline *string `json:"headline"`
}
// messageResponse is the JSON shape of a single message inside a thread
// detail response.
type messageResponse struct {
	ID int64 `json:"id"`
	ThreadID int64 `json:"thread_id"`
	Body string `json:"body"`
	// Nil timestamps are omitted from the output (omitempty).
	CreatedAt *time.Time `json:"created_at,omitempty"`
	UpdatedAt *time.Time `json:"updated_at,omitempty"`
	// Mine is true when the message was sent by the user making the request
	// (see mapThreadMessage).
	Mine bool `json:"mine"`
	Sender messageSenderResponse `json:"sender"`
}
// messageThreadSummaryResponse is the JSON shape of one entry in the list
// returned by ListThreads: the thread header, a preview of its last message
// and its participants.
type messageThreadSummaryResponse struct {
	ID int64 `json:"id"`
	Subject string `json:"subject"`
	CreatedByUserID int64 `json:"created_by_user_id"`
	CreatedAt *time.Time `json:"created_at,omitempty"`
	UpdatedAt *time.Time `json:"updated_at,omitempty"`
	UnreadCount int64 `json:"unread_count"`
	// LastMessageID is always emitted; mapThreadSummary treats a value that is
	// not greater than zero as "no last message".
	LastMessageID int64 `json:"last_message_id"`
	// LastMessageBody is nil (JSON null) when the stored body is blank.
	LastMessageBody *string `json:"last_message_body"`
	LastMessageCreatedAt *time.Time `json:"last_message_created_at,omitempty"`
	// LastMessageSender is only populated when LastMessageID > 0 and is
	// omitted from the output otherwise.
	LastMessageSender *messageSenderResponse `json:"last_message_sender,omitempty"`
	Participants []threadParticipantResponse `json:"participants"`
}
// messageThreadDetailResponse is the JSON shape returned by GetThread: the
// thread header together with all of its participants and messages.
type messageThreadDetailResponse struct {
	ID int64 `json:"id"`
	Subject string `json:"subject"`
	CreatedByUserID int64 `json:"created_by_user_id"`
	// Nil timestamps are omitted from the output (omitempty).
	CreatedAt *time.Time `json:"created_at,omitempty"`
	UpdatedAt *time.Time `json:"updated_at,omitempty"`
	UnreadCount int64 `json:"unread_count"`
	LastReadAt *time.Time `json:"last_read_at,omitempty"`
	Participants []threadParticipantResponse `json:"participants"`
	Messages []messageResponse `json:"messages"`
}
// createThreadRequest is the request body accepted by CreateThread.
type createThreadRequest struct {
	// Subject is required; it is whitespace-trimmed before validation.
	Subject string `json:"subject"`
	// RecipientIDs is filtered by normalizeRecipientIDs; at least one id must
	// remain afterwards.
	RecipientIDs []int64 `json:"recipient_ids"`
	// Body is optional. A first message is only created when it is non-empty
	// after trimming.
	Body string `json:"body"`
}
// createThreadResponse is the body CreateThread sends with its 201 reply.
type createThreadResponse struct {
	// ThreadID is the id of the thread that was just created.
	ThreadID int64 `json:"thread_id"`
}
// createThreadMessageRequest is the request body accepted by
// CreateThreadMessage.
type createThreadMessageRequest struct {
	// Body is required; it is whitespace-trimmed before validation.
	Body string `json:"body"`
}
// updateThreadRequest is the request body accepted by UpdateThread.
type updateThreadRequest struct {
	// Subject is required; it is whitespace-trimmed before validation.
	Subject string `json:"subject"`
}
// updateThreadMessageRequest is the request body accepted by
// UpdateThreadMessage.
type updateThreadMessageRequest struct {
	// Body is required; it is whitespace-trimmed before validation.
	Body string `json:"body"`
}
// NewHandler returns a Handler that keeps db and a sqlc query set bound to
// db.Pool.
func NewHandler(db *database.DB) *Handler {
	handler := new(Handler)
	handler.db = db
	handler.queries = sqlc.New(db.Pool)
	return handler
}
// ListRecipients replies with the rows that ListMessageRecipientsForUser
// yields for the authenticated user, mapped to recipientResponse and wrapped
// in the shared list envelope. A query failure is reported through
// respond.DatabaseError.
func (h *Handler) ListRecipients(c *fiber.Ctx) error {
	userID := authmw.CurrentUserID(c)

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	rows, err := h.queries.ListMessageRecipientsForUser(ctx, userID)
	if err != nil {
		return respond.DatabaseError(c, err)
	}

	// make keeps the slice non-nil even when the query returned no rows.
	payload := make([]recipientResponse, len(rows))
	for i := range rows {
		payload[i] = mapRecipient(rows[i])
	}
	return c.JSON(shared.ListResponse[recipientResponse]{Data: payload})
}
// ListThreads replies with one summary per thread returned by
// ListMessageThreadsForUser for the authenticated user. Participants are
// fetched with a second query and attached to their thread by thread id.
// A failure of either query is reported through respond.DatabaseError.
func (h *Handler) ListThreads(c *fiber.Ctx) error {
	userID := authmw.CurrentUserID(c)

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	threadRows, err := h.queries.ListMessageThreadsForUser(ctx, userID)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	participantRows, err := h.queries.ListMessageThreadParticipantsForUser(ctx, userID)
	if err != nil {
		return respond.DatabaseError(c, err)
	}

	// Bucket the flat participant rows by thread so every summary can pick up
	// its own participants with a single map lookup.
	byThread := make(map[int64][]threadParticipantResponse)
	for i := range participantRows {
		row := participantRows[i]
		byThread[row.ThreadID] = append(byThread[row.ThreadID], mapThreadParticipant(row))
	}

	summaries := make([]messageThreadSummaryResponse, len(threadRows))
	for i := range threadRows {
		row := threadRows[i]
		summaries[i] = mapThreadSummary(row, byThread[row.ThreadID])
	}
	return c.JSON(shared.ListResponse[messageThreadSummaryResponse]{Data: summaries})
}
// GetThread replies with the full detail of the thread named by the
// "threadId" path parameter, as loaded by loadThread for the authenticated
// user. pgx.ErrNoRows becomes a 404; any other load error is reported through
// respond.DatabaseError. A path-parameter error is returned unchanged.
func (h *Handler) GetThread(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}

	detail, err := h.loadThread(threadID, authmw.CurrentUserID(c))
	switch {
	case err == nil:
		return c.JSON(detail)
	case errors.Is(err, pgx.ErrNoRows):
		return respond.Error(c, fiber.StatusNotFound, "not_found", "Message thread not found")
	default:
		return respond.DatabaseError(c, err)
	}
}
// CreateThread starts a new message thread created by the authenticated user.
//
// The request body carries a subject, a list of recipient ids and an optional
// first message. Subject and body are whitespace-trimmed. An empty subject, or
// a recipient list that is empty after normalizeRecipientIDs has filtered it,
// is answered with 400. All database work runs inside one transaction; on
// success the handler replies 201 with the id of the new thread. Database
// failures are reported through respond.DatabaseError.
func (h *Handler) CreateThread(c *fiber.Ctx) error {
	currentUserID := authmw.CurrentUserID(c)
	var req createThreadRequest
	if err := c.BodyParser(&req); err != nil {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
	}
	subject := strings.TrimSpace(req.Subject)
	body := strings.TrimSpace(req.Body)
	if subject == "" {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "subject is required")
	}
	// Drops non-positive ids, duplicates and the caller's own id.
	recipientIDs := normalizeRecipientIDs(currentUserID, req.RecipientIDs)
	if len(recipientIDs) == 0 {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "At least one valid recipient is required")
	}
	ctx, cancel := shared.WithTimeout()
	defer cancel()
	tx, err := h.db.Pool.Begin(ctx)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	// Every early return below relies on this rollback. Its error is discarded
	// on purpose: pgx documents Rollback after a successful Commit as safe (it
	// only reports ErrTxClosed).
	defer func() {
		_ = tx.Rollback(ctx)
	}()
	queries := h.queries.WithTx(tx)
	// Check each recipient individually before anything is written; a missing
	// row means the caller may not message that user.
	for _, recipientID := range recipientIDs {
		if _, err := queries.GetMessageRecipientByIDForUser(ctx, sqlc.GetMessageRecipientByIDForUserParams{ID: currentUserID, ID_2: recipientID}); err != nil {
			if errors.Is(err, pgx.ErrNoRows) {
				return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "One or more recipients are not available for messaging")
			}
			return respond.DatabaseError(c, err)
		}
	}
	thread, err := queries.CreateMessageThread(ctx, sqlc.CreateMessageThreadParams{
		CreatedByUserID: currentUserID,
		Subject: subject,
	})
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	// Stays the zero (invalid) Timestamptz unless a first message is created.
	creatorReadAt := pgtype.Timestamptz{}
	if body != "" {
		message, err := queries.CreateThreadMessage(ctx, sqlc.CreateThreadMessageParams{
			ThreadID: thread.ID,
			SenderUserID: currentUserID,
			Body: body,
		})
		if err != nil {
			return respond.DatabaseError(c, err)
		}
		// The creator's read marker is set to the creation time of their own
		// first message, converted to UTC.
		if message.CreatedAt.Valid {
			creatorReadAt = pgtype.Timestamptz{Time: message.CreatedAt.Time.UTC(), Valid: true}
		}
	}
	// The creator is added as a participant first, carrying the read marker.
	if err := queries.AddMessageThreadParticipant(ctx, sqlc.AddMessageThreadParticipantParams{
		ThreadID: thread.ID,
		UserID: currentUserID,
		LastReadAt: creatorReadAt,
	}); err != nil {
		return respond.DatabaseError(c, err)
	}
	// Recipients are added with LastReadAt left at its zero value.
	for _, recipientID := range recipientIDs {
		if err := queries.AddMessageThreadParticipant(ctx, sqlc.AddMessageThreadParticipantParams{
			ThreadID: thread.ID,
			UserID: recipientID,
		}); err != nil {
			return respond.DatabaseError(c, err)
		}
	}
	// NOTE(review): presumably refreshes the thread's updated timestamp —
	// confirm against the TouchMessageThread query.
	if err := queries.TouchMessageThread(ctx, thread.ID); err != nil {
		return respond.DatabaseError(c, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return respond.DatabaseError(c, err)
	}
	return c.Status(fiber.StatusCreated).JSON(createThreadResponse{ThreadID: thread.ID})
}
// CreateThreadMessage appends a message from the authenticated user to the
// thread named by the "threadId" path parameter.
//
// The body is whitespace-trimmed and must not be empty (400 otherwise). The
// thread lookup, the insert, the thread touch and the read marker update all
// run in one transaction. pgx.ErrNoRows from the thread lookup becomes a 404;
// other database failures are reported through respond.DatabaseError. On
// success the handler replies 201 with {"status": "ok"}.
func (h *Handler) CreateThreadMessage(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}
	currentUserID := authmw.CurrentUserID(c)
	var req createThreadMessageRequest
	if err := c.BodyParser(&req); err != nil {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
	}
	body := strings.TrimSpace(req.Body)
	if body == "" {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "body is required")
	}
	ctx, cancel := shared.WithTimeout()
	defer cancel()
	tx, err := h.db.Pool.Begin(ctx)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	// Every early return below relies on this rollback. Its error is discarded
	// on purpose: pgx documents Rollback after a successful Commit as safe.
	defer func() {
		_ = tx.Rollback(ctx)
	}()
	queries := h.queries.WithTx(tx)
	// The result is discarded; the lookup only serves to confirm that the
	// query returns this thread for the current user.
	if _, err := queries.GetMessageThreadForUser(ctx, sqlc.GetMessageThreadForUserParams{ID: threadID, SenderUserID: currentUserID}); err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return respond.Error(c, fiber.StatusNotFound, "not_found", "Message thread not found")
		}
		return respond.DatabaseError(c, err)
	}
	if _, err := queries.CreateThreadMessage(ctx, sqlc.CreateThreadMessageParams{
		ThreadID: threadID,
		SenderUserID: currentUserID,
		Body: body,
	}); err != nil {
		return respond.DatabaseError(c, err)
	}
	// NOTE(review): presumably refreshes the thread's updated timestamp —
	// confirm against the TouchMessageThread query.
	if err := queries.TouchMessageThread(ctx, threadID); err != nil {
		return respond.DatabaseError(c, err)
	}
	// NOTE(review): presumably keeps the sender's own message from counting as
	// unread — confirm against the MarkMessageThreadRead query. Unlike in
	// MarkThreadRead, pgx.ErrNoRows is not special-cased here.
	if _, err := queries.MarkMessageThreadRead(ctx, sqlc.MarkMessageThreadReadParams{ThreadID: threadID, UserID: currentUserID}); err != nil {
		return respond.DatabaseError(c, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return respond.DatabaseError(c, err)
	}
	return c.Status(fiber.StatusCreated).JSON(fiber.Map{"status": "ok"})
}
// UpdateThread changes the subject of the thread named by the "threadId" path
// parameter.
//
// The new subject is whitespace-trimmed and must not be empty (400). The
// thread must be returned by GetMessageThreadForUser for the authenticated
// user (404 otherwise) and must have been created by that user (403
// otherwise). Other database failures are reported through
// respond.DatabaseError. On success the handler replies {"status": "ok"}.
func (h *Handler) UpdateThread(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}
	userID := authmw.CurrentUserID(c)

	var payload updateThreadRequest
	if parseErr := c.BodyParser(&payload); parseErr != nil {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
	}
	newSubject := strings.TrimSpace(payload.Subject)
	if len(newSubject) == 0 {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "subject is required")
	}

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	// Both queries below translate their errors the same way: a missing row
	// is a 404, anything else is a database error.
	reply := func(queryErr error) error {
		if errors.Is(queryErr, pgx.ErrNoRows) {
			return respond.Error(c, fiber.StatusNotFound, "not_found", "Message thread not found")
		}
		return respond.DatabaseError(c, queryErr)
	}

	lookup := sqlc.GetMessageThreadForUserParams{ID: threadID, SenderUserID: userID}
	existing, err := h.queries.GetMessageThreadForUser(ctx, lookup)
	if err != nil {
		return reply(err)
	}
	if existing.CreatedByUserID != userID {
		return respond.Error(c, fiber.StatusForbidden, "forbidden", "Only the conversation starter can edit the thread title")
	}

	change := sqlc.UpdateMessageThreadSubjectParams{ThreadID: threadID, Subject: newSubject}
	if _, err := h.queries.UpdateMessageThreadSubject(ctx, change); err != nil {
		return reply(err)
	}
	return c.JSON(fiber.Map{"status": "ok"})
}
// DeleteThread deletes the thread named by the "threadId" path parameter.
//
// The thread must be returned by GetMessageThreadForUser for the authenticated
// user (404 otherwise) and must have been created by that user (403
// otherwise). Other database failures are reported through
// respond.DatabaseError. On success the handler replies {"status": "ok"}.
func (h *Handler) DeleteThread(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}
	userID := authmw.CurrentUserID(c)

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	// Both queries below translate their errors the same way: a missing row
	// is a 404, anything else is a database error.
	reply := func(queryErr error) error {
		if errors.Is(queryErr, pgx.ErrNoRows) {
			return respond.Error(c, fiber.StatusNotFound, "not_found", "Message thread not found")
		}
		return respond.DatabaseError(c, queryErr)
	}

	lookup := sqlc.GetMessageThreadForUserParams{ID: threadID, SenderUserID: userID}
	existing, err := h.queries.GetMessageThreadForUser(ctx, lookup)
	if err != nil {
		return reply(err)
	}
	if existing.CreatedByUserID != userID {
		return respond.Error(c, fiber.StatusForbidden, "forbidden", "Only the conversation starter can delete this conversation")
	}

	if _, err := h.queries.DeleteMessageThread(ctx, threadID); err != nil {
		return reply(err)
	}
	return c.JSON(fiber.Map{"status": "ok"})
}
// UpdateThreadMessage replaces the body of the message named by the
// "messageId" path parameter inside the thread named by "threadId".
//
// The new body is whitespace-trimmed and must not be empty (400). The update
// and the thread touch run in one transaction. pgx.ErrNoRows from the update
// becomes a 404; other database failures are reported through
// respond.DatabaseError. On success the handler replies {"status": "ok"}.
func (h *Handler) UpdateThreadMessage(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}
	messageID, err := params.Int64PathParam(c, "messageId")
	if err != nil {
		return err
	}
	currentUserID := authmw.CurrentUserID(c)
	var req updateThreadMessageRequest
	if err := c.BodyParser(&req); err != nil {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "Unable to parse request body")
	}
	body := strings.TrimSpace(req.Body)
	if body == "" {
		return respond.Error(c, fiber.StatusBadRequest, "invalid_request", "body is required")
	}
	ctx, cancel := shared.WithTimeout()
	defer cancel()
	tx, err := h.db.Pool.Begin(ctx)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	// Every early return below relies on this rollback. Its error is discarded
	// on purpose: pgx documents Rollback after a successful Commit as safe.
	defer func() {
		_ = tx.Rollback(ctx)
	}()
	queries := h.queries.WithTx(tx)
	// The query receives message, thread and user id together.
	// NOTE(review): presumably it only matches messages the user sent in that
	// thread — confirm against the UpdateThreadMessageBody query.
	if _, err := queries.UpdateThreadMessageBody(ctx, sqlc.UpdateThreadMessageBodyParams{
		Body: body,
		MessageID: messageID,
		ThreadID: threadID,
		UserID: currentUserID,
	}); err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return respond.Error(c, fiber.StatusNotFound, "not_found", "Message not found")
		}
		return respond.DatabaseError(c, err)
	}
	// NOTE(review): presumably refreshes the thread's updated timestamp —
	// confirm against the TouchMessageThread query.
	if err := queries.TouchMessageThread(ctx, threadID); err != nil {
		return respond.DatabaseError(c, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return respond.DatabaseError(c, err)
	}
	return c.JSON(fiber.Map{"status": "ok"})
}
// DeleteThreadMessage deletes the message named by the "messageId" path
// parameter from the thread named by "threadId".
//
// The delete and the thread touch run in one transaction. pgx.ErrNoRows from
// the delete becomes a 404; other database failures are reported through
// respond.DatabaseError. On success the handler replies {"status": "ok"}.
func (h *Handler) DeleteThreadMessage(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}
	messageID, err := params.Int64PathParam(c, "messageId")
	if err != nil {
		return err
	}
	currentUserID := authmw.CurrentUserID(c)
	ctx, cancel := shared.WithTimeout()
	defer cancel()
	tx, err := h.db.Pool.Begin(ctx)
	if err != nil {
		return respond.DatabaseError(c, err)
	}
	// Every early return below relies on this rollback. Its error is discarded
	// on purpose: pgx documents Rollback after a successful Commit as safe.
	defer func() {
		_ = tx.Rollback(ctx)
	}()
	queries := h.queries.WithTx(tx)
	// The query receives message, thread and user id together.
	// NOTE(review): presumably it only matches messages the user sent in that
	// thread — confirm against the DeleteThreadMessage query.
	if _, err := queries.DeleteThreadMessage(ctx, sqlc.DeleteThreadMessageParams{
		MessageID: messageID,
		ThreadID: threadID,
		UserID: currentUserID,
	}); err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			return respond.Error(c, fiber.StatusNotFound, "not_found", "Message not found")
		}
		return respond.DatabaseError(c, err)
	}
	// NOTE(review): presumably refreshes the thread's updated timestamp —
	// confirm against the TouchMessageThread query.
	if err := queries.TouchMessageThread(ctx, threadID); err != nil {
		return respond.DatabaseError(c, err)
	}
	if err := tx.Commit(ctx); err != nil {
		return respond.DatabaseError(c, err)
	}
	return c.JSON(fiber.Map{"status": "ok"})
}
// MarkThreadRead runs MarkMessageThreadRead for the authenticated user on the
// thread named by the "threadId" path parameter. pgx.ErrNoRows becomes a 404;
// any other query error is reported through respond.DatabaseError. On success
// the handler replies {"status": "ok"}.
func (h *Handler) MarkThreadRead(c *fiber.Ctx) error {
	threadID, err := params.Int64PathParam(c, "threadId")
	if err != nil {
		return err
	}

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	args := sqlc.MarkMessageThreadReadParams{ThreadID: threadID, UserID: authmw.CurrentUserID(c)}
	_, err = h.queries.MarkMessageThreadRead(ctx, args)
	switch {
	case err == nil:
		return c.JSON(fiber.Map{"status": "ok"})
	case errors.Is(err, pgx.ErrNoRows):
		return respond.Error(c, fiber.StatusNotFound, "not_found", "Message thread not found")
	default:
		return respond.DatabaseError(c, err)
	}
}
// loadThread assembles the detail view of one thread for currentUserID from
// three queries: the thread header, its participants and its messages.
//
// The first query error is returned unchanged together with a zero-value
// response, so callers can test it for pgx.ErrNoRows. The queries share one
// timeout context but do not run inside a transaction.
func (h *Handler) loadThread(threadID, currentUserID int64) (messageThreadDetailResponse, error) {
	var none messageThreadDetailResponse

	ctx, cancel := shared.WithTimeout()
	defer cancel()

	header, err := h.queries.GetMessageThreadForUser(ctx, sqlc.GetMessageThreadForUserParams{ID: threadID, SenderUserID: currentUserID})
	if err != nil {
		return none, err
	}
	participantRows, err := h.queries.ListParticipantsForThreadForUser(ctx, sqlc.ListParticipantsForThreadForUserParams{ThreadID: threadID, UserID: currentUserID})
	if err != nil {
		return none, err
	}
	messageRows, err := h.queries.ListMessagesForThreadForUser(ctx, sqlc.ListMessagesForThreadForUserParams{ThreadID: threadID, UserID: currentUserID})
	if err != nil {
		return none, err
	}

	// Both slices come from make, so they are non-nil even when empty.
	participants := make([]threadParticipantResponse, len(participantRows))
	for i := range participantRows {
		participants[i] = mapThreadParticipantByThread(participantRows[i])
	}
	messages := make([]messageResponse, len(messageRows))
	for i := range messageRows {
		messages[i] = mapThreadMessage(messageRows[i], currentUserID)
	}

	var detail messageThreadDetailResponse
	detail.ID = header.ID
	detail.Subject = header.Subject
	detail.CreatedByUserID = header.CreatedByUserID
	detail.CreatedAt = shared.TimePointer(header.CreatedAt)
	detail.UpdatedAt = shared.TimePointer(header.UpdatedAt)
	detail.UnreadCount = header.UnreadCount
	detail.LastReadAt = shared.TimePointer(header.LastReadAt)
	detail.Participants = participants
	detail.Messages = messages
	return detail, nil
}
// mapRecipient converts one ListMessageRecipientsForUser row into its API
// representation. The role is converted to a plain string and the optional
// text columns go through shared.TextPointer.
func mapRecipient(row sqlc.ListMessageRecipientsForUserRow) recipientResponse {
	var out recipientResponse
	out.ID = row.UserID
	out.Email = row.UserEmail
	out.Role = string(row.UserRole)
	out.FullName = row.UserFullName
	out.PreferredName = shared.TextPointer(row.PreferredName)
	out.ProfileIconURL = shared.TextPointer(row.ProfileIconUrl)
	out.Headline = shared.TextPointer(row.Headline)
	return out
}
// mapRecipientByID converts one GetMessageRecipientByIDForUser row into its
// API representation. It mirrors mapRecipient for the single-row query's row
// type.
func mapRecipientByID(row sqlc.GetMessageRecipientByIDForUserRow) recipientResponse {
	var out recipientResponse
	out.ID = row.UserID
	out.Email = row.UserEmail
	out.Role = string(row.UserRole)
	out.FullName = row.UserFullName
	out.PreferredName = shared.TextPointer(row.PreferredName)
	out.ProfileIconURL = shared.TextPointer(row.ProfileIconUrl)
	out.Headline = shared.TextPointer(row.Headline)
	return out
}
// mapThreadParticipant converts one ListMessageThreadParticipantsForUser row
// into its API representation. Optional text columns go through
// shared.TextPointer and timestamps through shared.TimePointer.
func mapThreadParticipant(row sqlc.ListMessageThreadParticipantsForUserRow) threadParticipantResponse {
	var out threadParticipantResponse
	out.ID = row.UserID
	out.Email = row.UserEmail
	out.Role = string(row.UserRole)
	out.FullName = row.UserFullName
	out.PreferredName = shared.TextPointer(row.PreferredName)
	out.ProfileIconURL = shared.TextPointer(row.ProfileIconUrl)
	out.Headline = shared.TextPointer(row.Headline)
	out.JoinedAt = shared.TimePointer(row.JoinedAt)
	out.LastReadAt = shared.TimePointer(row.LastReadAt)
	out.ArchivedAt = shared.TimePointer(row.ArchivedAt)
	return out
}
// mapThreadParticipantByThread converts one ListParticipantsForThreadForUser
// row into its API representation. It mirrors mapThreadParticipant for the
// per-thread query's row type.
func mapThreadParticipantByThread(row sqlc.ListParticipantsForThreadForUserRow) threadParticipantResponse {
	var out threadParticipantResponse
	out.ID = row.UserID
	out.Email = row.UserEmail
	out.Role = string(row.UserRole)
	out.FullName = row.UserFullName
	out.PreferredName = shared.TextPointer(row.PreferredName)
	out.ProfileIconURL = shared.TextPointer(row.ProfileIconUrl)
	out.Headline = shared.TextPointer(row.Headline)
	out.JoinedAt = shared.TimePointer(row.JoinedAt)
	out.LastReadAt = shared.TimePointer(row.LastReadAt)
	out.ArchivedAt = shared.TimePointer(row.ArchivedAt)
	return out
}
// mapThreadSummary converts one ListMessageThreadsForUser row plus the
// already-mapped participants of that thread into a thread summary.
//
// participants may be nil, which is what a map lookup yields for a thread
// without participant rows. It is replaced by an empty slice so the field is
// always encoded as a JSON array and never as null, matching the thread
// detail response.
//
// LastMessageBody is nil when the stored body is blank. LastMessageSender is
// only set when the row carries a last message (LastMessageID > 0); the row is
// not read for the sender's email, role or headline, so Email and Role are
// empty strings and Headline stays nil.
func mapThreadSummary(row sqlc.ListMessageThreadsForUserRow, participants []threadParticipantResponse) messageThreadSummaryResponse {
	if participants == nil {
		// encoding/json writes a nil slice as null; clients iterating over
		// "participants" should always receive an array.
		participants = []threadParticipantResponse{}
	}

	response := messageThreadSummaryResponse{
		ID:                   row.ThreadID,
		Subject:              row.Subject,
		CreatedByUserID:      row.CreatedByUserID,
		CreatedAt:            shared.TimePointer(row.ThreadCreatedAt),
		UpdatedAt:            shared.TimePointer(row.ThreadUpdatedAt),
		UnreadCount:          row.UnreadCount,
		LastMessageID:        row.LastMessageID,
		LastMessageBody:      stringPointerOrNil(row.LastMessageBody),
		LastMessageCreatedAt: shared.TimePointer(row.LastMessageCreatedAt),
		Participants:         participants,
	}
	if row.LastMessageID > 0 {
		response.LastMessageSender = &messageSenderResponse{
			ID:             row.LastMessageSenderUserID,
			Email:          "",
			Role:           "",
			FullName:       valueOrEmpty(row.LastMessageSenderFullName),
			PreferredName:  shared.TextPointer(row.LastMessageSenderPreferredName),
			ProfileIconURL: shared.TextPointer(row.LastMessageSenderProfileIconUrl),
		}
	}
	return response
}
// mapThreadMessage converts one ListMessagesForThreadForUser row into its API
// representation. Mine is set when the row's sender is currentUserID.
func mapThreadMessage(row sqlc.ListMessagesForThreadForUserRow, currentUserID int64) messageResponse {
	var sender messageSenderResponse
	sender.ID = row.SenderUserID
	sender.Email = row.SenderEmail
	sender.Role = string(row.SenderRole)
	sender.FullName = row.SenderFullName
	sender.PreferredName = shared.TextPointer(row.SenderPreferredName)
	sender.ProfileIconURL = shared.TextPointer(row.SenderProfileIconUrl)
	sender.Headline = shared.TextPointer(row.SenderHeadline)

	var out messageResponse
	out.ID = row.ID
	out.ThreadID = row.ThreadID
	out.Body = row.Body
	out.CreatedAt = shared.TimePointer(row.CreatedAt)
	out.UpdatedAt = shared.TimePointer(row.UpdatedAt)
	out.Mine = currentUserID == row.SenderUserID
	out.Sender = sender
	return out
}
// normalizeRecipientIDs returns the usable recipient ids from values in
// ascending order: non-positive ids, currentUserID itself and repeated ids are
// dropped. The result is never nil, even when nothing remains.
func normalizeRecipientIDs(currentUserID int64, values []int64) []int64 {
	unique := make([]int64, 0, len(values))
	taken := make(map[int64]bool, len(values))
	for _, id := range values {
		if id <= 0 || id == currentUserID || taken[id] {
			continue
		}
		taken[id] = true
		unique = append(unique, id)
	}
	sort.Slice(unique, func(a, b int) bool {
		return unique[a] < unique[b]
	})
	return unique
}
// valueOrEmpty returns the string held by value, or "" when value is not
// valid.
func valueOrEmpty(value pgtype.Text) string {
	if value.Valid {
		return value.String
	}
	return ""
}
// stringPointerOrNil returns nil when value is empty or whitespace-only, and
// otherwise a pointer to a copy of value. The returned string is the original,
// untrimmed text.
func stringPointerOrNil(value string) *string {
	if len(strings.TrimSpace(value)) > 0 {
		// value is this call's own copy of the argument, so its address can
		// be handed out directly.
		return &value
	}
	return nil
}