Before Fine Tune
This commit is contained in:
@@ -11,6 +11,13 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type requestStyle string
|
||||
|
||||
const (
|
||||
requestStyleResponses requestStyle = "responses"
|
||||
requestStyleChatCompletions requestStyle = "chat_completions"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
endpoint string
|
||||
apiKey string
|
||||
@@ -109,14 +116,9 @@ func (s *Service) ReviewSubmission(ctx context.Context, input AssignmentReviewIn
|
||||
return nil, fmt.Errorf("marshal AI review input: %w", err)
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"model": s.model,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "system",
|
||||
"content": []map[string]any{{
|
||||
"type": "input_text",
|
||||
"text": strings.TrimSpace(`You are reviewing student homework submissions for a teacher workflow.
|
||||
outputText, err := s.runStructuredRequest(
|
||||
ctx,
|
||||
strings.TrimSpace(`You are reviewing student homework submissions for a teacher workflow.
|
||||
|
||||
You must assess the student's understanding by looking at the student's final answer and working against the saved correct answer when one is available. Do not re-grade weighting.
|
||||
|
||||
@@ -142,56 +144,11 @@ Interpretation guidance:
|
||||
- support = the student shows meaningful gaps and likely needs targeted help.
|
||||
- redo = the student should redo the assignment because understanding is broadly too weak or incomplete.
|
||||
|
||||
Review the full assignment in one pass and produce a short assignment-level summary.`),
|
||||
}},
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{{
|
||||
"type": "input_text",
|
||||
"text": string(payloadJSON),
|
||||
}},
|
||||
},
|
||||
},
|
||||
"text": map[string]any{
|
||||
"format": map[string]any{
|
||||
"type": "json_schema",
|
||||
"name": "assignment_review",
|
||||
"strict": true,
|
||||
"schema": reviewSchema(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal AI review request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.endpoint, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build AI review request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("api-key", s.apiKey)
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send AI review request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read AI review response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("AI review request failed: status %d body %s", resp.StatusCode, strings.TrimSpace(string(respBytes)))
|
||||
}
|
||||
|
||||
outputText, err := extractOutputText(respBytes)
|
||||
Review the full assignment in one pass and produce a short assignment-level summary.`),
|
||||
string(payloadJSON),
|
||||
"assignment_review",
|
||||
reviewSchema(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -215,14 +172,9 @@ func (s *Service) PlanRedoAssignment(ctx context.Context, input RedoPlanInput) (
|
||||
return nil, fmt.Errorf("marshal redo plan input: %w", err)
|
||||
}
|
||||
|
||||
body := map[string]any{
|
||||
"model": s.model,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "system",
|
||||
"content": []map[string]any{{
|
||||
"type": "input_text",
|
||||
"text": strings.TrimSpace(`You are planning the next redo assignment for a student.
|
||||
outputText, err := s.runStructuredRequest(
|
||||
ctx,
|
||||
strings.TrimSpace(`You are planning the next redo assignment for a student.
|
||||
|
||||
You are NOT writing final math questions. You are only producing a structured topic+difficulty blueprint for a later generator layer.
|
||||
|
||||
@@ -238,55 +190,10 @@ Rules:
|
||||
- reason on each item should briefly explain why that topic/difficulty belongs in the redo set.
|
||||
- Do not invent topics outside the allowed topic vocabulary.
|
||||
- Do not output prose outside the JSON schema.`),
|
||||
}},
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{{
|
||||
"type": "input_text",
|
||||
"text": string(payloadJSON),
|
||||
}},
|
||||
},
|
||||
},
|
||||
"text": map[string]any{
|
||||
"format": map[string]any{
|
||||
"type": "json_schema",
|
||||
"name": "redo_assignment_plan",
|
||||
"strict": true,
|
||||
"schema": redoPlanSchema(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal redo plan request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.endpoint, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build redo plan request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("api-key", s.apiKey)
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send redo plan request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read redo plan response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("redo plan request failed: status %d body %s", resp.StatusCode, strings.TrimSpace(string(respBytes)))
|
||||
}
|
||||
|
||||
outputText, err := extractOutputText(respBytes)
|
||||
string(payloadJSON),
|
||||
"redo_assignment_plan",
|
||||
redoPlanSchema(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -300,6 +207,146 @@ Rules:
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (s *Service) runStructuredRequest(ctx context.Context, systemPrompt, userPrompt, schemaName string, schema map[string]any) (string, error) {
|
||||
respBytes, err := s.sendStructuredRequest(ctx, systemPrompt, userPrompt, schemaName, schema)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch s.requestStyle() {
|
||||
case requestStyleChatCompletions:
|
||||
return extractChatCompletionText(respBytes)
|
||||
default:
|
||||
return extractResponsesOutputText(respBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) sendStructuredRequest(ctx context.Context, systemPrompt, userPrompt, schemaName string, schema map[string]any) ([]byte, error) {
|
||||
body, err := s.buildStructuredRequestBody(systemPrompt, userPrompt, schemaName, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal AI review request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, s.endpoint, bytes.NewReader(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("build AI review request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
s.applyAuthHeader(req)
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send AI review request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read AI review response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("AI review request failed: status %d body %s", resp.StatusCode, strings.TrimSpace(string(respBytes)))
|
||||
}
|
||||
|
||||
return respBytes, nil
|
||||
}
|
||||
|
||||
func (s *Service) buildStructuredRequestBody(systemPrompt, userPrompt, schemaName string, schema map[string]any) (map[string]any, error) {
|
||||
switch s.requestStyle() {
|
||||
case requestStyleChatCompletions:
|
||||
body := map[string]any{
|
||||
"model": s.model,
|
||||
"messages": []map[string]any{
|
||||
{
|
||||
"role": "system",
|
||||
"content": systemPrompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": userPrompt,
|
||||
},
|
||||
},
|
||||
"temperature": 0,
|
||||
"response_format": map[string]any{
|
||||
"type": "json_schema",
|
||||
"json_schema": map[string]any{
|
||||
"name": schemaName,
|
||||
"strict": true,
|
||||
"schema": schema,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if s.shouldDisableThinking() {
|
||||
body["chat_template_kwargs"] = map[string]any{
|
||||
"enable_thinking": false,
|
||||
}
|
||||
}
|
||||
|
||||
return body, nil
|
||||
default:
|
||||
return map[string]any{
|
||||
"model": s.model,
|
||||
"input": []map[string]any{
|
||||
{
|
||||
"role": "system",
|
||||
"content": []map[string]any{{
|
||||
"type": "input_text",
|
||||
"text": systemPrompt,
|
||||
}},
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": []map[string]any{{
|
||||
"type": "input_text",
|
||||
"text": userPrompt,
|
||||
}},
|
||||
},
|
||||
},
|
||||
"text": map[string]any{
|
||||
"format": map[string]any{
|
||||
"type": "json_schema",
|
||||
"name": schemaName,
|
||||
"strict": true,
|
||||
"schema": schema,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) requestStyle() requestStyle {
|
||||
endpoint := strings.ToLower(strings.TrimSpace(s.endpoint))
|
||||
if strings.Contains(endpoint, "/chat/completions") {
|
||||
return requestStyleChatCompletions
|
||||
}
|
||||
return requestStyleResponses
|
||||
}
|
||||
|
||||
func (s *Service) applyAuthHeader(req *http.Request) {
|
||||
if s.isAzureEndpoint() {
|
||||
req.Header.Set("api-key", s.apiKey)
|
||||
return
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+s.apiKey)
|
||||
}
|
||||
|
||||
func (s *Service) shouldDisableThinking() bool {
|
||||
return s.requestStyle() == requestStyleChatCompletions && !s.isAzureEndpoint()
|
||||
}
|
||||
|
||||
func (s *Service) isAzureEndpoint() bool {
|
||||
endpoint := strings.ToLower(strings.TrimSpace(s.endpoint))
|
||||
return strings.Contains(endpoint, "cognitiveservices.azure.com") || strings.Contains(endpoint, ".openai.azure.com")
|
||||
}
|
||||
|
||||
func reviewSchema() map[string]any {
|
||||
return map[string]any{
|
||||
"type": "object",
|
||||
@@ -358,7 +405,7 @@ func redoPlanSchema() map[string]any {
|
||||
}
|
||||
}
|
||||
|
||||
func extractOutputText(respBytes []byte) (string, error) {
|
||||
func extractResponsesOutputText(respBytes []byte) (string, error) {
|
||||
var direct struct {
|
||||
OutputText string `json:"output_text"`
|
||||
}
|
||||
@@ -380,6 +427,47 @@ func extractOutputText(respBytes []byte) (string, error) {
|
||||
return "", fmt.Errorf("AI review response did not contain structured output text")
|
||||
}
|
||||
|
||||
func extractChatCompletionText(respBytes []byte) (string, error) {
|
||||
var payload struct {
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Content any `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(respBytes, &payload); err != nil {
|
||||
return "", fmt.Errorf("decode AI review chat completion response: %w", err)
|
||||
}
|
||||
|
||||
for _, choice := range payload.Choices {
|
||||
if text := strings.TrimSpace(extractMessageContent(choice.Message.Content)); text != "" {
|
||||
return text, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("AI review chat completion response did not contain message content")
|
||||
}
|
||||
|
||||
func extractMessageContent(content any) string {
|
||||
switch typed := content.(type) {
|
||||
case string:
|
||||
return typed
|
||||
case []any:
|
||||
parts := make([]string, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
text := strings.TrimSpace(findOutputText(item))
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
parts = append(parts, text)
|
||||
}
|
||||
return strings.Join(parts, "\n")
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func findOutputText(value any) string {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
|
||||
Reference in New Issue
Block a user