diff --git a/github/enterprise_licenses.go b/github/enterprise_licenses.go index 2fbb0cea771..73da6c5e66f 100644 --- a/github/enterprise_licenses.go +++ b/github/enterprise_licenses.go @@ -89,12 +89,12 @@ type LastLicenseSyncProperties struct { Error string `json:"error"` } -// GetConsumedLicenses collect information about the number of consumed licenses and a collection with all the users with consumed enterprise licenses. +// ListConsumedLicenses collect information about the number of consumed licenses and a collection with all the users with consumed enterprise licenses. // // GitHub API docs: https://docs.github.com/enterprise-cloud@latest/rest/enterprise-admin/licensing?apiVersion=2022-11-28#list-enterprise-consumed-licenses // //meta:operation GET /enterprises/{enterprise}/consumed-licenses -func (s *EnterpriseService) GetConsumedLicenses(ctx context.Context, enterprise string, opts *ListOptions) (*EnterpriseConsumedLicenses, *Response, error) { +func (s *EnterpriseService) ListConsumedLicenses(ctx context.Context, enterprise string, opts *ListOptions) (*EnterpriseConsumedLicenses, *Response, error) { u := fmt.Sprintf("enterprises/%v/consumed-licenses", enterprise) u, err := addOptions(u, opts) if err != nil { diff --git a/github/enterprise_licenses_test.go b/github/enterprise_licenses_test.go index dd7e817abef..18b8aea89cb 100644 --- a/github/enterprise_licenses_test.go +++ b/github/enterprise_licenses_test.go @@ -14,7 +14,7 @@ import ( "github.com/google/go-cmp/cmp" ) -func TestEnterpriseService_GetConsumedLicenses(t *testing.T) { +func TestEnterpriseService_ListConsumedLicenses(t *testing.T) { t.Parallel() client, mux, _ := setup(t) @@ -49,9 +49,9 @@ func TestEnterpriseService_GetConsumedLicenses(t *testing.T) { opt := &ListOptions{Page: 2, PerPage: 10} ctx := t.Context() - licenses, _, err := client.Enterprise.GetConsumedLicenses(ctx, "e", opt) + licenses, _, err := client.Enterprise.ListConsumedLicenses(ctx, "e", opt) if err != nil { - t.Errorf("Enterprise.GetConsumedLicenses returned error: %v", err) + t.Errorf("Enterprise.ListConsumedLicenses returned error: %v", err) } userName := "User One" @@ -90,17 +90,17 @@ func TestEnterpriseService_GetConsumedLicenses(t *testing.T) { } if !cmp.Equal(licenses, want) { - t.Errorf("Enterprise.GetConsumedLicenses returned %+v, want %+v", licenses, want) + t.Errorf("Enterprise.ListConsumedLicenses returned %+v, want %+v", licenses, want) } - const methodName = "GetConsumedLicenses" + const methodName = "ListConsumedLicenses" testBadOptions(t, methodName, func() (err error) { - _, _, err = client.Enterprise.GetConsumedLicenses(ctx, "\n", opt) + _, _, err = client.Enterprise.ListConsumedLicenses(ctx, "\n", opt) return err }) testNewRequestAndDoFailure(t, methodName, client, func() (*Response, error) { - got, resp, err := client.Enterprise.GetConsumedLicenses(ctx, "e", opt) + got, resp, err := client.Enterprise.ListConsumedLicenses(ctx, "e", opt) if got != nil { t.Errorf("testNewRequestAndDoFailure %v = %#v, want nil", methodName, got) } diff --git a/github/github-iterators.go b/github/github-iterators.go index 384d4718cd8..bbcfb2df27e 100644 --- a/github/github-iterators.go +++ b/github/github-iterators.go @@ -2857,6 +2857,41 @@ func (s *EnterpriseService) ListCodeSecurityConfigurationsIter(ctx context.Conte } } +// ListConsumedLicensesIter returns an iterator that paginates through all results of ListConsumedLicenses. +func (s *EnterpriseService) ListConsumedLicensesIter(ctx context.Context, enterprise string, opts *ListOptions) iter.Seq2[*EnterpriseLicensedUsers, error] { + return func(yield func(*EnterpriseLicensedUsers, error) bool) { + // Create a copy of opts to avoid mutating the caller's struct + if opts == nil { + opts = &ListOptions{} + } else { + opts = Ptr(*opts) + } + + for { + results, resp, err := s.ListConsumedLicenses(ctx, enterprise, opts) + if err != nil { + yield(nil, err) + return + } + + var iterItems []*EnterpriseLicensedUsers + if results != nil { + iterItems = results.Users + } + for _, item := range iterItems { + if !yield(item, nil) { + return + } + } + + if resp.NextPage == 0 { + break + } + opts.Page = resp.NextPage + } + } +} + // ListEnterpriseNetworkConfigurationsIter returns an iterator that paginates through all results of ListEnterpriseNetworkConfigurations. func (s *EnterpriseService) ListEnterpriseNetworkConfigurationsIter(ctx context.Context, enterprise string, opts *ListOptions) iter.Seq2[*NetworkConfiguration, error] { return func(yield func(*NetworkConfiguration, error) bool) { diff --git a/github/github-iterators_test.go b/github/github-iterators_test.go index ca1fb8d76c8..d9ce0e5973f 100644 --- a/github/github-iterators_test.go +++ b/github/github-iterators_test.go @@ -6135,6 +6135,78 @@ func TestEnterpriseService_ListCodeSecurityConfigurationsIter(t *testing.T) { } } +func TestEnterpriseService_ListConsumedLicensesIter(t *testing.T) { + t.Parallel() + client, mux, _ := setup(t) + var callNum int + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + callNum++ + switch callNum { + case 1: + w.Header().Set("Link", `; rel="next"`) + fmt.Fprint(w, `{"users": [{},{},{}]}`) + case 2: + fmt.Fprint(w, `{"users": [{},{},{},{}]}`) + case 3: + fmt.Fprint(w, `{"users": [{},{}]}`) + case 4: + w.WriteHeader(http.StatusNotFound) + case 5: + fmt.Fprint(w, `{"users": [{},{}]}`) + } + }) + + iter := client.Enterprise.ListConsumedLicensesIter(t.Context(), "", nil) + var gotItems int + for _, err := range iter { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + if want := 7; gotItems != want { + t.Errorf("client.Enterprise.ListConsumedLicensesIter call 1 got %v items; want %v", gotItems, want) + } + + opts := &ListOptions{} + iter = client.Enterprise.ListConsumedLicensesIter(t.Context(), "", opts) + gotItems = 0 + for _, err := range iter { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + } + if want := 2; gotItems != want { + t.Errorf("client.Enterprise.ListConsumedLicensesIter call 2 got %v items; want %v", gotItems, want) + } + + iter = client.Enterprise.ListConsumedLicensesIter(t.Context(), "", nil) + gotItems = 0 + for _, err := range iter { + gotItems++ + if err == nil { + t.Error("expected error; got nil") + } + } + if gotItems != 1 { + t.Errorf("client.Enterprise.ListConsumedLicensesIter call 3 got %v items; want 1 (an error)", gotItems) + } + + iter = client.Enterprise.ListConsumedLicensesIter(t.Context(), "", nil) + gotItems = 0 + iter(func(item *EnterpriseLicensedUsers, err error) bool { + gotItems++ + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + return false + }) + if gotItems != 1 { + t.Errorf("client.Enterprise.ListConsumedLicensesIter call 4 got %v items; want 1 (an error)", gotItems) + } +} + func TestEnterpriseService_ListEnterpriseNetworkConfigurationsIter(t *testing.T) { t.Parallel() client, mux, _ := setup(t)