Files
2026-06-15 22:46:12 +08:00

113 lines
2.5 KiB
Go

package proxy
import (
"context"
"encoding/base64"
"errors"
"io"
"net/http"
"strings"
"time"
"synctv/server/internal/room"
)
type (
	// Rooms looks up a room by its code.
	Rooms interface {
		Get(ctx context.Context, code string) (room.Room, error)
	}

	// Proxy serves HTTP requests by relaying the media source of a room whose
	// source is configured in proxy mode.
	Proxy struct {
		rooms   Rooms        // room lookup used to resolve the requested code
		enabled bool         // when false, every request is rejected
		client  *http.Client // shared client used for all upstream requests
	}
)
// New returns a Proxy that resolves rooms through rooms and relays their
// sources only when enabled is true.
//
// timeout bounds how long to wait for upstream response headers once the
// request has been written. It deliberately does not bound the response body:
// proxied media streams may legitimately run far longer than any header
// timeout. Zero means no limit.
// maxIdleConns caps idle upstream connections across all hosts (zero means no
// limit), and responseHeaderBytes caps the size of upstream response headers
// (zero selects the net/http default).
func New(rooms Rooms, enabled bool, timeout time.Duration, maxIdleConns int, responseHeaderBytes int64) *Proxy {
	// Start from the default transport so dialing, TLS handshakes and idle
	// connections keep their standard timeouts. There is no client-wide
	// deadline any more to bound those phases.
	var transport *http.Transport
	if dt, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = dt.Clone()
	} else {
		transport = &http.Transport{}
	}
	transport.Proxy = http.ProxyFromEnvironment
	transport.MaxIdleConns = maxIdleConns
	transport.ResponseHeaderTimeout = timeout
	transport.MaxResponseHeaderBytes = responseHeaderBytes

	// No http.Client.Timeout: it also covers reading the response body and
	// would cut long-running streams off mid-playback.
	return &Proxy{
		rooms:   rooms,
		enabled: enabled,
		client:  &http.Client{Transport: transport},
	}
}
// errUpstreamBody marks a failure that happened while relaying the upstream
// body. Such a failure occurs after the status line and headers were already
// sent to the client, when an error response can no longer be written.
var errUpstreamBody = errors.New("proxy: relaying upstream body")

// ServeHTTP relays the source of the room named by the "code" path value.
// It responds with 403 when the proxy is disabled and 404 when the room lookup
// fails. It responds with 400 when the room has no source or its source is not
// in proxy mode, and with 502 when the upstream request fails before any
// response has been sent to the client.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if !p.enabled {
		http.Error(w, "proxy is disabled", http.StatusForbidden)
		return
	}
	code := r.PathValue("code")
	rm, err := p.rooms.Get(r.Context(), code)
	if err != nil {
		http.Error(w, err.Error(), http.StatusNotFound)
		return
	}
	if rm.Source == nil || rm.Source.Mode != room.SourceModeProxy {
		http.Error(w, "room source is not in proxy mode", http.StatusBadRequest)
		return
	}
	if err := p.stream(w, r, *rm.Source); err != nil {
		// Once the upstream status has been relayed, the response is committed.
		// http.Error's WriteHeader would be ignored and its message appended to
		// the media bytes. A canceled context means the client has gone away,
		// so nobody is left to answer.
		if errors.Is(err, errUpstreamBody) || errors.Is(err, context.Canceled) {
			return
		}
		// NOTE(review): err may embed the upstream URL (e.g. a *url.Error from
		// the client), which is then shown to the requester. Consider returning
		// a generic message and logging the detail instead.
		http.Error(w, err.Error(), http.StatusBadGateway)
	}
}

// stream issues a GET for src.URL and relays the upstream status code, an
// allow-listed set of response headers, and the body to w.
//
// Selected client headers (Range, User-Agent, Accept, Accept-Encoding) are
// forwarded first. Then src.Headers are applied, overriding any client value
// with the same name, and Basic credentials are added when src carries a
// username or password. The request is tied to r's context, so it is aborted
// if the client disconnects.
//
// Errors returned before anything is written to w are returned as-is. An
// error from copying the body happens after the headers are sent, so it is
// joined with errUpstreamBody.
func (p *Proxy) stream(w http.ResponseWriter, r *http.Request, src room.Source) error {
	req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, src.URL, nil)
	if err != nil {
		return err
	}
	copyRequestHeader(req.Header, r.Header, "Range", "User-Agent", "Accept", "Accept-Encoding")
	for k, v := range src.Headers {
		req.Header.Set(k, v)
	}
	if src.Username != "" || src.Password != "" {
		req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(src.Username+":"+src.Password)))
	}
	resp, err := p.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	copyResponseHeaders(w.Header(), resp.Header)
	w.WriteHeader(resp.StatusCode)
	if _, err := io.Copy(w, resp.Body); err != nil {
		return errors.Join(errUpstreamBody, err)
	}
	return nil
}
// copyRequestHeader forwards the named headers from src into dst.
// For each key, only the first value in src is taken, and it replaces
// whatever dst already holds. Keys whose value in src is missing or
// whitespace-only leave dst unchanged.
func copyRequestHeader(dst, src http.Header, keys ...string) {
	for i := range keys {
		name := keys[i]
		first := src.Get(name)
		if strings.TrimSpace(first) == "" {
			continue
		}
		dst.Set(name, first)
	}
}
// copyResponseHeaders relays an allow-list of upstream response headers into
// dst, appending every value of each listed header. Headers not in the list
// (e.g. Set-Cookie, Connection) are dropped.
//
// Content-Encoding must travel with the body. When an Accept-Encoding value
// is forwarded upstream, net/http does not transparently decompress the
// response, so the relayed bytes stay encoded and Content-Length counts the
// encoded size. Without Content-Encoding the client would misread them.
func copyResponseHeaders(dst, src http.Header) {
	for _, key := range []string{
		"Content-Type",
		"Content-Encoding",
		"Content-Length",
		"Content-Range",
		"Accept-Ranges",
		"Cache-Control",
		"Last-Modified",
		"ETag",
	} {
		for _, value := range src.Values(key) {
			dst.Add(key, value)
		}
	}
}