package proxy import ( "context" "encoding/base64" "errors" "io" "net/http" "strings" "time" "synctv/server/internal/room" ) type Rooms interface { Get(ctx context.Context, code string) (room.Room, error) } type Proxy struct { rooms Rooms enabled bool client *http.Client } func New(rooms Rooms, enabled bool, timeout time.Duration, maxIdleConns int, responseHeaderBytes int64) *Proxy { return &Proxy{ rooms: rooms, enabled: enabled, client: &http.Client{ Timeout: timeout, Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, MaxIdleConns: maxIdleConns, ResponseHeaderTimeout: timeout, MaxResponseHeaderBytes: responseHeaderBytes, }, }, } } func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !p.enabled { http.Error(w, "proxy is disabled", http.StatusForbidden) return } code := r.PathValue("code") rm, err := p.rooms.Get(r.Context(), code) if err != nil { http.Error(w, err.Error(), http.StatusNotFound) return } if rm.Source == nil || rm.Source.Mode != room.SourceModeProxy { http.Error(w, "room source is not in proxy mode", http.StatusBadRequest) return } if err := p.stream(w, r, *rm.Source); err != nil { if !errors.Is(err, context.Canceled) { http.Error(w, err.Error(), http.StatusBadGateway) } } } func (p *Proxy) stream(w http.ResponseWriter, r *http.Request, src room.Source) error { req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, src.URL, nil) if err != nil { return err } copyRequestHeader(req.Header, r.Header, "Range", "User-Agent", "Accept", "Accept-Encoding") for k, v := range src.Headers { req.Header.Set(k, v) } if src.Username != "" || src.Password != "" { req.Header.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString([]byte(src.Username+":"+src.Password))) } resp, err := p.client.Do(req) if err != nil { return err } defer resp.Body.Close() copyResponseHeaders(w.Header(), resp.Header) w.WriteHeader(resp.StatusCode) _, err = io.Copy(w, resp.Body) return err } func copyRequestHeader(dst, src http.Header, keys ...string) { for _, key := range keys { value := src.Get(key) if strings.TrimSpace(value) != "" { dst.Set(key, value) } } } func copyResponseHeaders(dst, src http.Header) { for _, key := range []string{ "Content-Type", "Content-Length", "Content-Range", "Accept-Ranges", "Cache-Control", "Last-Modified", "ETag", } { for _, value := range src.Values(key) { dst.Add(key, value) } } }