Update vendoring

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar
2023-11-15 21:38:54 +01:00
parent c63fb35ba8
commit 2ff2d84283
57 changed files with 4299 additions and 1606 deletions

202
vendor/github.com/NVIDIA/go-nvlib/LICENSE generated vendored Normal file
View File

@@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

4
vendor/github.com/NVIDIA/go-nvlib/NOTICE generated vendored Normal file
View File

@@ -0,0 +1,4 @@
The file pkg/pciids/default_pci.ids is distributed under the 3-clause BSD License.
Maintained by Albert Pool, Martin Mares, and other volunteers from
the PCI ID Project at https://pci-ids.ucw.cz/.

View File

@@ -0,0 +1,98 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package device
import (
"github.com/NVIDIA/go-nvlib/pkg/nvml"
)
// Interface provides the API to the 'device' package
type Interface interface {
AssertValidMigProfileFormat(profile string) error
GetDevices() ([]Device, error)
GetMigDevices() ([]MigDevice, error)
GetMigProfiles() ([]MigProfile, error)
NewDevice(d nvml.Device) (Device, error)
NewDeviceByUUID(uuid string) (Device, error)
NewMigDevice(d nvml.Device) (MigDevice, error)
NewMigDeviceByUUID(uuid string) (MigDevice, error)
NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error)
ParseMigProfile(profile string) (MigProfile, error)
VisitDevices(func(i int, d Device) error) error
VisitMigDevices(func(i int, d Device, j int, m MigDevice) error) error
VisitMigProfiles(func(p MigProfile) error) error
}
type devicelib struct {
nvml nvml.Interface
skippedDevices map[string]struct{}
verifySymbols *bool
migProfiles []MigProfile
}
var _ Interface = &devicelib{}
// New creates a new instance of the 'device' interface
func New(opts ...Option) Interface {
d := &devicelib{}
for _, opt := range opts {
opt(d)
}
if d.nvml == nil {
d.nvml = nvml.New()
}
if d.verifySymbols == nil {
verify := true
d.verifySymbols = &verify
}
if d.skippedDevices == nil {
WithSkippedDevices(
"DGX Display",
"NVIDIA DGX Display",
)(d)
}
return d
}
// WithNvml provides an Option to set the NVML library used by the 'device' interface
func WithNvml(nvml nvml.Interface) Option {
return func(d *devicelib) {
d.nvml = nvml
}
}
// WithVerifySymbols provides an option to toggle whether to verify select symbols exist in dynamic libraries before calling them
func WithVerifySymbols(verify bool) Option {
return func(d *devicelib) {
d.verifySymbols = &verify
}
}
// WithSkippedDevices provides an Option to set devices to be skipped by model name
func WithSkippedDevices(names ...string) Option {
return func(d *devicelib) {
if d.skippedDevices == nil {
d.skippedDevices = make(map[string]struct{})
}
for _, name := range names {
d.skippedDevices[name] = struct{}{}
}
}
}
// Option defines a function for passing options to the New() call
type Option func(*devicelib)

View File

@@ -0,0 +1,473 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package device
import (
"fmt"
"github.com/NVIDIA/go-nvlib/pkg/nvml"
)
// Device defines the set of extended functions associated with a device.Device
type Device interface {
nvml.Device
GetArchitectureAsString() (string, error)
GetBrandAsString() (string, error)
GetCudaComputeCapabilityAsString() (string, error)
GetMigDevices() ([]MigDevice, error)
GetMigProfiles() ([]MigProfile, error)
IsMigCapable() (bool, error)
IsMigEnabled() (bool, error)
VisitMigDevices(func(j int, m MigDevice) error) error
VisitMigProfiles(func(p MigProfile) error) error
}
type device struct {
nvml.Device
lib *devicelib
migProfiles []MigProfile
}
var _ Device = &device{}
// NewDevice builds a new Device from an nvml.Device
func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) {
return d.newDevice(dev)
}
// NewDeviceByUUID builds a new Device from a UUID
func (d *devicelib) NewDeviceByUUID(uuid string) (Device, error) {
dev, ret := d.nvml.DeviceGetHandleByUUID(uuid)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret)
}
return d.newDevice(dev)
}
// newDevice creates a device from an nvml.Device
func (d *devicelib) newDevice(dev nvml.Device) (*device, error) {
return &device{dev, d, nil}, nil
}
// GetArchitectureAsString returns the Device architecture as a string
func (d *device) GetArchitectureAsString() (string, error) {
arch, ret := d.GetArchitecture()
if ret != nvml.SUCCESS {
return "", fmt.Errorf("error getting device architecture: %v", ret)
}
switch arch {
case nvml.DEVICE_ARCH_KEPLER:
return "Kepler", nil
case nvml.DEVICE_ARCH_MAXWELL:
return "Maxwell", nil
case nvml.DEVICE_ARCH_PASCAL:
return "Pascal", nil
case nvml.DEVICE_ARCH_VOLTA:
return "Volta", nil
case nvml.DEVICE_ARCH_TURING:
return "Turing", nil
case nvml.DEVICE_ARCH_AMPERE:
return "Ampere", nil
case nvml.DEVICE_ARCH_ADA:
return "Ada", nil
case nvml.DEVICE_ARCH_HOPPER:
return "Hopper", nil
case nvml.DEVICE_ARCH_UNKNOWN:
return "Unknown", nil
}
return "", fmt.Errorf("error interpreting device architecture as string: %v", arch)
}
// GetBrandAsString returns the Device architecture as a string
func (d *device) GetBrandAsString() (string, error) {
brand, ret := d.GetBrand()
if ret != nvml.SUCCESS {
return "", fmt.Errorf("error getting device brand: %v", ret)
}
switch brand {
case nvml.BRAND_UNKNOWN:
return "Unknown", nil
case nvml.BRAND_QUADRO:
return "Quadro", nil
case nvml.BRAND_TESLA:
return "Tesla", nil
case nvml.BRAND_NVS:
return "NVS", nil
case nvml.BRAND_GRID:
return "Grid", nil
case nvml.BRAND_GEFORCE:
return "GeForce", nil
case nvml.BRAND_TITAN:
return "Titan", nil
case nvml.BRAND_NVIDIA_VAPPS:
return "NvidiaVApps", nil
case nvml.BRAND_NVIDIA_VPC:
return "NvidiaVPC", nil
case nvml.BRAND_NVIDIA_VCS:
return "NvidiaVCS", nil
case nvml.BRAND_NVIDIA_VWS:
return "NvidiaVWS", nil
// Deprecated in favor of nvml.BRAND_NVIDIA_CLOUD_GAMING
//case nvml.BRAND_NVIDIA_VGAMING:
// return "VGaming", nil
case nvml.BRAND_NVIDIA_CLOUD_GAMING:
return "NvidiaCloudGaming", nil
case nvml.BRAND_QUADRO_RTX:
return "QuadroRTX", nil
case nvml.BRAND_NVIDIA_RTX:
return "NvidiaRTX", nil
case nvml.BRAND_NVIDIA:
return "Nvidia", nil
case nvml.BRAND_GEFORCE_RTX:
return "GeForceRTX", nil
case nvml.BRAND_TITAN_RTX:
return "TitanRTX", nil
}
return "", fmt.Errorf("error interpreting device brand as string: %v", brand)
}
// GetCudaComputeCapabilityAsString returns the Device's CUDA compute capability as a version string
func (d *device) GetCudaComputeCapabilityAsString() (string, error) {
major, minor, ret := d.GetCudaComputeCapability()
if ret != nvml.SUCCESS {
return "", fmt.Errorf("error getting CUDA compute capability: %v", ret)
}
return fmt.Sprintf("%d.%d", major, minor), nil
}
// IsMigCapable checks if a device is capable of having MIG paprtitions created on it
func (d *device) IsMigCapable() (bool, error) {
if !d.lib.hasSymbol("nvmlDeviceGetMigMode") {
return false, nil
}
_, _, ret := nvml.Device(d).GetMigMode()
if ret == nvml.ERROR_NOT_SUPPORTED {
return false, nil
}
if ret != nvml.SUCCESS {
return false, fmt.Errorf("error getting MIG mode: %v", ret)
}
return true, nil
}
// IsMigEnabled checks if a device has MIG mode currently enabled on it
func (d *device) IsMigEnabled() (bool, error) {
if !d.lib.hasSymbol("nvmlDeviceGetMigMode") {
return false, nil
}
mode, _, ret := nvml.Device(d).GetMigMode()
if ret == nvml.ERROR_NOT_SUPPORTED {
return false, nil
}
if ret != nvml.SUCCESS {
return false, fmt.Errorf("error getting MIG mode: %v", ret)
}
return (mode == nvml.DEVICE_MIG_ENABLE), nil
}
// VisitMigDevices walks a top-level device and invokes a callback function for each MIG device configured on it
func (d *device) VisitMigDevices(visit func(int, MigDevice) error) error {
capable, err := d.IsMigCapable()
if err != nil {
return fmt.Errorf("error checking if GPU is MIG capable: %v", err)
}
if !capable {
return nil
}
count, ret := nvml.Device(d).GetMaxMigDeviceCount()
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting max MIG device count: %v", ret)
}
for i := 0; i < count; i++ {
device, ret := nvml.Device(d).GetMigDeviceHandleByIndex(i)
if ret == nvml.ERROR_NOT_FOUND {
continue
}
if ret == nvml.ERROR_INVALID_ARGUMENT {
continue
}
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting MIG device handle at index '%v': %v", i, ret)
}
mig, err := d.lib.NewMigDevice(device)
if err != nil {
return fmt.Errorf("error creating new MIG device wrapper: %v", err)
}
err = visit(i, mig)
if err != nil {
return fmt.Errorf("error visiting MIG device: %v", err)
}
}
return nil
}
// VisitMigProfiles walks a top-level device and invokes a callback function for each unique MIG Profile that can be configured on it
func (d *device) VisitMigProfiles(visit func(MigProfile) error) error {
capable, err := d.IsMigCapable()
if err != nil {
return fmt.Errorf("error checking if GPU is MIG capable: %v", err)
}
if !capable {
return nil
}
memory, ret := d.GetMemoryInfo()
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device memory info: %v", ret)
}
for i := 0; i < nvml.GPU_INSTANCE_PROFILE_COUNT; i++ {
giProfileInfo, ret := d.GetGpuInstanceProfileInfo(i)
if ret == nvml.ERROR_NOT_SUPPORTED {
continue
}
if ret == nvml.ERROR_INVALID_ARGUMENT {
continue
}
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting GPU Instance profile info: %v", ret)
}
for j := 0; j < nvml.COMPUTE_INSTANCE_PROFILE_COUNT; j++ {
for k := 0; k < nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT; k++ {
p, err := d.lib.NewMigProfile(i, j, k, giProfileInfo.MemorySizeMB, memory.Total)
if err != nil {
return fmt.Errorf("error creating MIG profile: %v", err)
}
// NOTE: The NVML API doesn't currently let us query the set of
// valid Compute Instance profiles without first instantiating
// a GPU Instance to check against. In theory, it should be
// possible to get this information without a reference to a
// GPU instance, but no API is provided for that at the moment.
// We run the checks below to weed out invalid profiles
// heuristically, given what we know about how they are
// physically constructed. In the future we should do this via
// NVML once a proper API for this exists.
pi := p.GetInfo()
if pi.C > pi.G {
continue
}
if (pi.C < pi.G) && ((pi.C * 2) > (pi.G + 1)) {
continue
}
err = visit(p)
if err != nil {
return fmt.Errorf("error visiting MIG profile: %v", err)
}
}
}
}
return nil
}
// GetMigDevices gets the set of MIG devices associated with a top-level device
func (d *device) GetMigDevices() ([]MigDevice, error) {
var migs []MigDevice
err := d.VisitMigDevices(func(j int, m MigDevice) error {
migs = append(migs, m)
return nil
})
if err != nil {
return nil, err
}
return migs, nil
}
// GetMigProfiles gets the set of unique MIG profiles associated with a top-level device
func (d *device) GetMigProfiles() ([]MigProfile, error) {
// Return the cached list if available
if d.migProfiles != nil {
return d.migProfiles, nil
}
// Otherwise generate it...
var profiles []MigProfile
err := d.VisitMigProfiles(func(p MigProfile) error {
profiles = append(profiles, p)
return nil
})
if err != nil {
return nil, err
}
// And cache it before returning
d.migProfiles = profiles
return profiles, nil
}
// isSkipped checks whether the device should be skipped.
func (d *device) isSkipped() (bool, error) {
name, ret := d.GetName()
if ret != nvml.SUCCESS {
return false, fmt.Errorf("error getting device name: %v", ret)
}
if _, exists := d.lib.skippedDevices[name]; exists {
return true, nil
}
return false, nil
}
// VisitDevices visits each top-level device and invokes a callback function for it
func (d *devicelib) VisitDevices(visit func(int, Device) error) error {
count, ret := d.nvml.DeviceGetCount()
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device count: %v", ret)
}
for i := 0; i < count; i++ {
device, ret := d.nvml.DeviceGetHandleByIndex(i)
if ret != nvml.SUCCESS {
return fmt.Errorf("error getting device handle for index '%v': %v", i, ret)
}
dev, err := d.newDevice(device)
if err != nil {
return fmt.Errorf("error creating new device wrapper: %v", err)
}
isSkipped, err := dev.isSkipped()
if err != nil {
return fmt.Errorf("error checking whether device is skipped: %v", err)
}
if isSkipped {
continue
}
err = visit(i, dev)
if err != nil {
return fmt.Errorf("error visiting device: %v", err)
}
}
return nil
}
// VisitMigDevices walks a top-level device and invokes a callback function for each MIG device configured on it
func (d *devicelib) VisitMigDevices(visit func(int, Device, int, MigDevice) error) error {
err := d.VisitDevices(func(i int, dev Device) error {
err := dev.VisitMigDevices(func(j int, mig MigDevice) error {
err := visit(i, dev, j, mig)
if err != nil {
return fmt.Errorf("error visiting MIG device: %v", err)
}
return nil
})
if err != nil {
return fmt.Errorf("error visiting device: %v", err)
}
return nil
})
if err != nil {
return fmt.Errorf("error visiting devices: %v", err)
}
return nil
}
// VisitMigProfiles walks a top-level device and invokes a callback function for each unique MIG profile found on them
func (d *devicelib) VisitMigProfiles(visit func(MigProfile) error) error {
visited := make(map[string]bool)
err := d.VisitDevices(func(i int, dev Device) error {
err := dev.VisitMigProfiles(func(p MigProfile) error {
if visited[p.String()] {
return nil
}
err := visit(p)
if err != nil {
return fmt.Errorf("error visiting MIG profile: %v", err)
}
visited[p.String()] = true
return nil
})
if err != nil {
return fmt.Errorf("error visiting device: %v", err)
}
return nil
})
if err != nil {
return fmt.Errorf("error visiting devices: %v", err)
}
return nil
}
// GetDevices gets the set of all top-level devices
func (d *devicelib) GetDevices() ([]Device, error) {
var devs []Device
err := d.VisitDevices(func(i int, dev Device) error {
devs = append(devs, dev)
return nil
})
if err != nil {
return nil, err
}
return devs, nil
}
// GetMigDevices gets the set of MIG devices across all top-level devices
func (d *devicelib) GetMigDevices() ([]MigDevice, error) {
var migs []MigDevice
err := d.VisitMigDevices(func(i int, dev Device, j int, m MigDevice) error {
migs = append(migs, m)
return nil
})
if err != nil {
return nil, err
}
return migs, nil
}
// GetMigProfiles gets the set of unique MIG profiles across all top-level devices
func (d *devicelib) GetMigProfiles() ([]MigProfile, error) {
// Return the cached list if available
if d.migProfiles != nil {
return d.migProfiles, nil
}
// Otherwise generate it...
var profiles []MigProfile
err := d.VisitMigProfiles(func(p MigProfile) error {
profiles = append(profiles, p)
return nil
})
if err != nil {
return nil, err
}
// And cache it before returning
d.migProfiles = profiles
return profiles, nil
}
// hasSymbol checks to see if the given symbol is present in the NVML library.
// If devicelib is configured to not verify symbols, then all symbols are assumed to exist.
func (d *devicelib) hasSymbol(symbol string) bool {
if !*d.verifySymbols {
return true
}
return d.nvml.Lookup(symbol) == nil
}

View File

@@ -0,0 +1,157 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package device
import (
"fmt"
"github.com/NVIDIA/go-nvlib/pkg/nvml"
)
// MigDevice defines the set of extended functions associated with a MIG device
type MigDevice interface {
nvml.Device
GetProfile() (MigProfile, error)
}
type migdevice struct {
nvml.Device
lib *devicelib
profile MigProfile
}
var _ MigDevice = &migdevice{}
// NewMigDevice builds a new MigDevice from an nvml.Device
func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) {
isMig, ret := handle.IsMigDeviceHandle()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error checking if device is a MIG device: %v", ret)
}
if !isMig {
return nil, fmt.Errorf("not a MIG device")
}
return &migdevice{handle, d, nil}, nil
}
// NewMigDeviceByUUID builds a new MigDevice from a UUID
func (d *devicelib) NewMigDeviceByUUID(uuid string) (MigDevice, error) {
dev, ret := d.nvml.DeviceGetHandleByUUID(uuid)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting device handle for uuid '%v': %v", uuid, ret)
}
return d.NewMigDevice(dev)
}
// GetProfile returns the MIG profile associated with a MIG device
func (m *migdevice) GetProfile() (MigProfile, error) {
if m.profile != nil {
return m.profile, nil
}
parent, ret := m.Device.GetDeviceHandleFromMigDeviceHandle()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting parent device handle: %v", ret)
}
parentMemoryInfo, ret := parent.GetMemoryInfo()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting parent memory info: %v", ret)
}
attributes, ret := m.Device.GetAttributes()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device attributes: %v", ret)
}
giID, ret := m.Device.GetGpuInstanceId()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device GPU Instance ID: %v", ret)
}
ciID, ret := m.Device.GetComputeInstanceId()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting MIG device Compute Instance ID: %v", ret)
}
gi, ret := parent.GetGpuInstanceById(giID)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting GPU Instance: %v", ret)
}
ci, ret := gi.GetComputeInstanceById(ciID)
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting Compute Instance: %v", ret)
}
giInfo, ret := gi.GetInfo()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting GPU Instance info: %v", ret)
}
ciInfo, ret := ci.GetInfo()
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting Compute Instance info: %v", ret)
}
for i := 0; i < nvml.GPU_INSTANCE_PROFILE_COUNT; i++ {
giProfileInfo, ret := parent.GetGpuInstanceProfileInfo(i)
if ret == nvml.ERROR_NOT_SUPPORTED {
continue
}
if ret == nvml.ERROR_INVALID_ARGUMENT {
continue
}
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting GPU Instance profile info: %v", ret)
}
if giProfileInfo.Id != giInfo.ProfileId {
continue
}
for j := 0; j < nvml.COMPUTE_INSTANCE_PROFILE_COUNT; j++ {
for k := 0; k < nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT; k++ {
ciProfileInfo, ret := gi.GetComputeInstanceProfileInfo(j, k)
if ret == nvml.ERROR_NOT_SUPPORTED {
continue
}
if ret == nvml.ERROR_INVALID_ARGUMENT {
continue
}
if ret != nvml.SUCCESS {
return nil, fmt.Errorf("error getting Compute Instance profile info: %v", ret)
}
if ciProfileInfo.Id != ciInfo.ProfileId {
continue
}
p, err := m.lib.NewMigProfile(i, j, k, attributes.MemorySizeMB, parentMemoryInfo.Total)
if err != nil {
return nil, fmt.Errorf("error creating MIG profile: %v", err)
}
m.profile = p
return p, nil
}
}
}
return nil, fmt.Errorf("no matching profile IDs found")
}

View File

@@ -0,0 +1,331 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package device
import (
"fmt"
"math"
"sort"
"strconv"
"strings"
"github.com/NVIDIA/go-nvlib/pkg/nvml"
)
const (
// AttributeMediaExtensions holds the string representation for the media extension MIG profile attribute.
AttributeMediaExtensions = "me"
)
// MigProfile represents a specific MIG profile.
// Examples include "1g.5gb", "2g.10gb", "1c.2g.10gb", or "1c.1g.5gb+me", etc.
type MigProfile interface {
String() string
GetInfo() MigProfileInfo
Equals(other MigProfile) bool
Matches(profile string) bool
}
// MigProfileInfo holds all info associated with a specific MIG profile
type MigProfileInfo struct {
C int
G int
GB int
Attributes []string
GIProfileID int
CIProfileID int
CIEngProfileID int
}
var _ MigProfile = &MigProfileInfo{}
// NewProfile constructs a new Profile struct using info from the giProfiles and ciProfiles used to create it.
func (d *devicelib) NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error) {
giSlices := 0
switch giProfileID {
case nvml.GPU_INSTANCE_PROFILE_1_SLICE,
nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV1,
nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV2:
giSlices = 1
case nvml.GPU_INSTANCE_PROFILE_2_SLICE,
nvml.GPU_INSTANCE_PROFILE_2_SLICE_REV1:
giSlices = 2
case nvml.GPU_INSTANCE_PROFILE_3_SLICE:
giSlices = 3
case nvml.GPU_INSTANCE_PROFILE_4_SLICE:
giSlices = 4
case nvml.GPU_INSTANCE_PROFILE_6_SLICE:
giSlices = 6
case nvml.GPU_INSTANCE_PROFILE_7_SLICE:
giSlices = 7
case nvml.GPU_INSTANCE_PROFILE_8_SLICE:
giSlices = 8
default:
return nil, fmt.Errorf("invalid GPU Instance Profile ID: %v", giProfileID)
}
ciSlices := 0
switch ciProfileID {
case nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE,
nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE_REV1:
ciSlices = 1
case nvml.COMPUTE_INSTANCE_PROFILE_2_SLICE:
ciSlices = 2
case nvml.COMPUTE_INSTANCE_PROFILE_3_SLICE:
ciSlices = 3
case nvml.COMPUTE_INSTANCE_PROFILE_4_SLICE:
ciSlices = 4
case nvml.COMPUTE_INSTANCE_PROFILE_6_SLICE:
ciSlices = 6
case nvml.COMPUTE_INSTANCE_PROFILE_7_SLICE:
ciSlices = 7
case nvml.COMPUTE_INSTANCE_PROFILE_8_SLICE:
ciSlices = 8
default:
return nil, fmt.Errorf("invalid Compute Instance Profile ID: %v", ciProfileID)
}
var attrs []string
switch giProfileID {
case nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV1,
nvml.GPU_INSTANCE_PROFILE_2_SLICE_REV1:
attrs = append(attrs, AttributeMediaExtensions)
}
p := &MigProfileInfo{
C: ciSlices,
G: giSlices,
GB: int(getMigMemorySizeGB(deviceMemorySizeBytes, migMemorySizeMB)),
Attributes: attrs,
GIProfileID: giProfileID,
CIProfileID: ciProfileID,
CIEngProfileID: ciEngProfileID,
}
return p, nil
}
// AssertValidMigProfileFormat checks if the string is in the proper format to represent a MIG profile
func (d *devicelib) AssertValidMigProfileFormat(profile string) error {
_, _, _, _, err := parseMigProfile(profile)
return err
}
// ParseMigProfile converts a string representation of a MigProfile into an object
func (d *devicelib) ParseMigProfile(profile string) (MigProfile, error) {
profiles, err := d.GetMigProfiles()
if err != nil {
return nil, fmt.Errorf("error getting list of possible MIG profiles: %v", err)
}
for _, p := range profiles {
if p.Matches(profile) {
return p, nil
}
}
return nil, fmt.Errorf("unable to parse profile string into a valid profile")
}
// String returns the string representation of a Profile
func (p MigProfileInfo) String() string {
var suffix string
if len(p.Attributes) > 0 {
suffix = "+" + strings.Join(p.Attributes, ",")
}
if p.C == p.G {
return fmt.Sprintf("%dg.%dgb%s", p.G, p.GB, suffix)
}
return fmt.Sprintf("%dc.%dg.%dgb%s", p.C, p.G, p.GB, suffix)
}
// GetInfo returns detailed info about a Profile
func (p MigProfileInfo) GetInfo() MigProfileInfo {
return p
}
// Equals checks if two Profiles are identical or not
func (p MigProfileInfo) Equals(other MigProfile) bool {
o := other.GetInfo()
if p.C != o.C {
return false
}
if p.G != o.G {
return false
}
if p.GB != o.GB {
return false
}
if p.GIProfileID != o.GIProfileID {
return false
}
if p.CIProfileID != o.CIProfileID {
return false
}
if p.CIEngProfileID != o.CIEngProfileID {
return false
}
return true
}
// Matches checks if a MigProfile matches the string passed in
func (p MigProfileInfo) Matches(profile string) bool {
c, g, gb, attrs, err := parseMigProfile(profile)
if err != nil {
return false
}
if c != p.C {
return false
}
if g != p.G {
return false
}
if gb != p.GB {
return false
}
if len(attrs) != len(p.Attributes) {
return false
}
sort.Strings(attrs)
sort.Strings(p.Attributes)
for i, a := range p.Attributes {
if a != attrs[i] {
return false
}
}
return true
}
func parseMigProfile(profile string) (int, int, int, []string, error) {
// If we are handed the empty string, we cannot parse it
if profile == "" {
return -1, -1, -1, nil, fmt.Errorf("profile is the empty string")
}
// Split by + to separate out attributes
split := strings.SplitN(profile, "+", 2)
// Check to make sure the c, g, and gb values match
c, g, gb, err := parseMigProfileFields(split[0])
if err != nil {
return -1, -1, -1, nil, fmt.Errorf("cannot parse fields of '%v': %v", profile, err)
}
// If we have no attributes we are done
if len(split) == 1 {
return c, g, gb, nil, nil
}
// Make sure we have the same set of attributes
attrs, err := parseMigProfileAttributes(split[1])
if err != nil {
return -1, -1, -1, nil, fmt.Errorf("cannot parse attributes of '%v': %v", profile, err)
}
return c, g, gb, attrs, nil
}
func parseMigProfileField(s string, field string) (int, error) {
if strings.TrimSpace(s) != s {
return -1, fmt.Errorf("leading or trailing spaces on '%%d%s'", field)
}
if !strings.HasSuffix(s, field) {
return -1, fmt.Errorf("missing '%s' from '%%d%s'", field, field)
}
v, err := strconv.Atoi(strings.TrimSuffix(s, field))
if err != nil {
return -1, fmt.Errorf("malformed number in '%%d%s'", field)
}
return v, nil
}
func parseMigProfileFields(s string) (int, int, int, error) {
var err error
var c, g, gb int
split := strings.SplitN(s, ".", 3)
if len(split) == 3 {
c, err = parseMigProfileField(split[0], "c")
if err != nil {
return -1, -1, -1, err
}
g, err = parseMigProfileField(split[1], "g")
if err != nil {
return -1, -1, -1, err
}
gb, err = parseMigProfileField(split[2], "gb")
if err != nil {
return -1, -1, -1, err
}
return c, g, gb, err
}
if len(split) == 2 {
g, err = parseMigProfileField(split[0], "g")
if err != nil {
return -1, -1, -1, err
}
gb, err = parseMigProfileField(split[1], "gb")
if err != nil {
return -1, -1, -1, err
}
return g, g, gb, nil
}
return -1, -1, -1, fmt.Errorf("parsed wrong number of fields, expected 2 or 3")
}
func parseMigProfileAttributes(s string) ([]string, error) {
attr := strings.Split(s, ",")
if len(attr) == 0 {
return nil, fmt.Errorf("empty attribute list")
}
unique := make(map[string]int)
for _, a := range attr {
if unique[a] > 0 {
return nil, fmt.Errorf("non unique attribute in list")
}
if a == "" {
return nil, fmt.Errorf("empty attribute in list")
}
if strings.TrimSpace(a) != a {
return nil, fmt.Errorf("leading or trailing spaces in attribute")
}
if a[0] >= '0' && a[0] <= '9' {
return nil, fmt.Errorf("attribute begins with a number")
}
for _, c := range a {
if (c < 'a' || c > 'z') && (c < 'A' || c > 'Z') && (c < '0' || c > '9') {
return nil, fmt.Errorf("non alpha-numeric character or digit in attribute")
}
}
unique[a]++
}
return attr, nil
}
func getMigMemorySizeGB(totalDeviceMemory, migMemorySizeMB uint64) uint64 {
const fracDenominator = 8
const oneMB = 1024 * 1024
const oneGB = 1024 * 1024 * 1024
fractionalGpuMem := (float64(migMemorySizeMB) * oneMB) / float64(totalDeviceMemory)
fractionalGpuMem = math.Ceil(fractionalGpuMem*fracDenominator) / fracDenominator
totalMemGB := float64((totalDeviceMemory + oneGB - 1) / oneGB)
return uint64(math.Round(fractionalGpuMem * totalMemGB))
}

View File

@@ -0,0 +1,102 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package info
import (
"fmt"
"os"
"path/filepath"
"strings"
"github.com/NVIDIA/go-nvml/pkg/dl"
)
// Interface provides the API to the info package
type Interface interface {
HasDXCore() (bool, string)
HasNvml() (bool, string)
IsTegraSystem() (bool, string)
}
type infolib struct {
root string
}
var _ Interface = &infolib{}
// HasDXCore returns true if DXCore is detected on the system.
func (i *infolib) HasDXCore() (bool, string) {
const (
libraryName = "libdxcore.so"
)
if err := assertHasLibrary(libraryName); err != nil {
return false, fmt.Sprintf("could not load DXCore library: %v", err)
}
return true, "found DXCore library"
}
// HasNvml returns true if NVML is detected on the system
func (i *infolib) HasNvml() (bool, string) {
const (
libraryName = "libnvidia-ml.so.1"
)
if err := assertHasLibrary(libraryName); err != nil {
return false, fmt.Sprintf("could not load NVML library: %v", err)
}
return true, "found NVML library"
}
// IsTegraSystem returns true if the system is detected as a Tegra-based system
func (i *infolib) IsTegraSystem() (bool, string) {
tegraReleaseFile := filepath.Join(i.root, "/etc/nv_tegra_release")
tegraFamilyFile := filepath.Join(i.root, "/sys/devices/soc0/family")
if info, err := os.Stat(tegraReleaseFile); err == nil && !info.IsDir() {
return true, fmt.Sprintf("%v found", tegraReleaseFile)
}
if info, err := os.Stat(tegraFamilyFile); err != nil || info.IsDir() {
return false, fmt.Sprintf("%v file not found", tegraFamilyFile)
}
contents, err := os.ReadFile(tegraFamilyFile)
if err != nil {
return false, fmt.Sprintf("could not read %v", tegraFamilyFile)
}
if strings.HasPrefix(strings.ToLower(string(contents)), "tegra") {
return true, fmt.Sprintf("%v has 'tegra' prefix", tegraFamilyFile)
}
return false, fmt.Sprintf("%v has no 'tegra' prefix", tegraFamilyFile)
}
// assertHasLibrary returns an error if the specified library cannot be loaded
func assertHasLibrary(libraryName string) error {
const (
libraryLoadFlags = dl.RTLD_LAZY
)
lib := dl.New(libraryName, libraryLoadFlags)
if err := lib.Open(); err != nil {
return err
}
defer lib.Close()
return nil
}

View File

@@ -0,0 +1,39 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package info
// Option defines a function for passing options to the New() call
type Option func(*infolib)
// New creates a new instance of the 'info' interface
func New(opts ...Option) Interface {
i := &infolib{}
for _, opt := range opts {
opt(i)
}
if i.root == "" {
i.root = "/"
}
return i
}
// WithRoot provides a Option to set the root of the 'info' interface
func WithRoot(root string) Option {
return func(i *infolib) {
i.root = root
}
}

44
vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/ci.go generated vendored Normal file
View File

@@ -0,0 +1,44 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
type nvmlComputeInstance nvml.ComputeInstance
var _ ComputeInstance = (*nvmlComputeInstance)(nil)
// GetInfo() returns info about a Compute Instance
func (ci nvmlComputeInstance) GetInfo() (ComputeInstanceInfo, Return) {
i, r := nvml.ComputeInstance(ci).GetInfo()
info := ComputeInstanceInfo{
Device: nvmlDevice(i.Device),
GpuInstance: nvmlGpuInstance(i.GpuInstance),
Id: i.Id,
ProfileId: i.ProfileId,
Placement: ComputeInstancePlacement(i.Placement),
}
return info, Return(r)
}
// Destroy() destroys a Compute Instance
func (ci nvmlComputeInstance) Destroy() Return {
r := nvml.ComputeInstance(ci).Destroy()
return Return(r)
}

104
vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/ci_mock.go generated vendored Normal file
View File

@@ -0,0 +1,104 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package nvml
import (
"sync"
)
// Ensure, that ComputeInstanceMock does implement ComputeInstance.
// If this is not the case, regenerate this file with moq.
var _ ComputeInstance = &ComputeInstanceMock{}
// ComputeInstanceMock is a mock implementation of ComputeInstance.
//
// func TestSomethingThatUsesComputeInstance(t *testing.T) {
//
// // make and configure a mocked ComputeInstance
// mockedComputeInstance := &ComputeInstanceMock{
// DestroyFunc: func() Return {
// panic("mock out the Destroy method")
// },
// GetInfoFunc: func() (ComputeInstanceInfo, Return) {
// panic("mock out the GetInfo method")
// },
// }
//
// // use mockedComputeInstance in code that requires ComputeInstance
// // and then make assertions.
//
// }
type ComputeInstanceMock struct {
// DestroyFunc mocks the Destroy method.
DestroyFunc func() Return
// GetInfoFunc mocks the GetInfo method.
GetInfoFunc func() (ComputeInstanceInfo, Return)
// calls tracks calls to the methods.
calls struct {
// Destroy holds details about calls to the Destroy method.
Destroy []struct {
}
// GetInfo holds details about calls to the GetInfo method.
GetInfo []struct {
}
}
lockDestroy sync.RWMutex
lockGetInfo sync.RWMutex
}
// Destroy calls DestroyFunc.
func (mock *ComputeInstanceMock) Destroy() Return {
if mock.DestroyFunc == nil {
panic("ComputeInstanceMock.DestroyFunc: method is nil but ComputeInstance.Destroy was just called")
}
callInfo := struct {
}{}
mock.lockDestroy.Lock()
mock.calls.Destroy = append(mock.calls.Destroy, callInfo)
mock.lockDestroy.Unlock()
return mock.DestroyFunc()
}
// DestroyCalls gets all the calls that were made to Destroy.
// Check the length with:
//
// len(mockedComputeInstance.DestroyCalls())
func (mock *ComputeInstanceMock) DestroyCalls() []struct {
} {
var calls []struct {
}
mock.lockDestroy.RLock()
calls = mock.calls.Destroy
mock.lockDestroy.RUnlock()
return calls
}
// GetInfo calls GetInfoFunc.
func (mock *ComputeInstanceMock) GetInfo() (ComputeInstanceInfo, Return) {
if mock.GetInfoFunc == nil {
panic("ComputeInstanceMock.GetInfoFunc: method is nil but ComputeInstance.GetInfo was just called")
}
callInfo := struct {
}{}
mock.lockGetInfo.Lock()
mock.calls.GetInfo = append(mock.calls.GetInfo, callInfo)
mock.lockGetInfo.Unlock()
return mock.GetInfoFunc()
}
// GetInfoCalls gets all the calls that were made to GetInfo.
// Check the length with:
//
// len(mockedComputeInstance.GetInfoCalls())
func (mock *ComputeInstanceMock) GetInfoCalls() []struct {
} {
var calls []struct {
}
mock.lockGetInfo.RLock()
calls = mock.calls.GetInfo
mock.lockGetInfo.RUnlock()
return calls
}

133
vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/consts.go generated vendored Normal file
View File

@@ -0,0 +1,133 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
// Return constants
const (
SUCCESS = Return(nvml.SUCCESS)
ERROR_UNINITIALIZED = Return(nvml.ERROR_UNINITIALIZED)
ERROR_INVALID_ARGUMENT = Return(nvml.ERROR_INVALID_ARGUMENT)
ERROR_NOT_SUPPORTED = Return(nvml.ERROR_NOT_SUPPORTED)
ERROR_NO_PERMISSION = Return(nvml.ERROR_NO_PERMISSION)
ERROR_ALREADY_INITIALIZED = Return(nvml.ERROR_ALREADY_INITIALIZED)
ERROR_NOT_FOUND = Return(nvml.ERROR_NOT_FOUND)
ERROR_INSUFFICIENT_SIZE = Return(nvml.ERROR_INSUFFICIENT_SIZE)
ERROR_INSUFFICIENT_POWER = Return(nvml.ERROR_INSUFFICIENT_POWER)
ERROR_DRIVER_NOT_LOADED = Return(nvml.ERROR_DRIVER_NOT_LOADED)
ERROR_TIMEOUT = Return(nvml.ERROR_TIMEOUT)
ERROR_IRQ_ISSUE = Return(nvml.ERROR_IRQ_ISSUE)
ERROR_LIBRARY_NOT_FOUND = Return(nvml.ERROR_LIBRARY_NOT_FOUND)
ERROR_FUNCTION_NOT_FOUND = Return(nvml.ERROR_FUNCTION_NOT_FOUND)
ERROR_CORRUPTED_INFOROM = Return(nvml.ERROR_CORRUPTED_INFOROM)
ERROR_GPU_IS_LOST = Return(nvml.ERROR_GPU_IS_LOST)
ERROR_RESET_REQUIRED = Return(nvml.ERROR_RESET_REQUIRED)
ERROR_OPERATING_SYSTEM = Return(nvml.ERROR_OPERATING_SYSTEM)
ERROR_LIB_RM_VERSION_MISMATCH = Return(nvml.ERROR_LIB_RM_VERSION_MISMATCH)
ERROR_IN_USE = Return(nvml.ERROR_IN_USE)
ERROR_MEMORY = Return(nvml.ERROR_MEMORY)
ERROR_NO_DATA = Return(nvml.ERROR_NO_DATA)
ERROR_VGPU_ECC_NOT_SUPPORTED = Return(nvml.ERROR_VGPU_ECC_NOT_SUPPORTED)
ERROR_INSUFFICIENT_RESOURCES = Return(nvml.ERROR_INSUFFICIENT_RESOURCES)
ERROR_UNKNOWN = Return(nvml.ERROR_UNKNOWN)
)
// Device architecture constants
const (
DEVICE_ARCH_KEPLER = nvml.DEVICE_ARCH_KEPLER
DEVICE_ARCH_MAXWELL = nvml.DEVICE_ARCH_MAXWELL
DEVICE_ARCH_PASCAL = nvml.DEVICE_ARCH_PASCAL
DEVICE_ARCH_VOLTA = nvml.DEVICE_ARCH_VOLTA
DEVICE_ARCH_TURING = nvml.DEVICE_ARCH_TURING
DEVICE_ARCH_AMPERE = nvml.DEVICE_ARCH_AMPERE
DEVICE_ARCH_ADA = nvml.DEVICE_ARCH_ADA
DEVICE_ARCH_HOPPER = nvml.DEVICE_ARCH_HOPPER
DEVICE_ARCH_UNKNOWN = nvml.DEVICE_ARCH_UNKNOWN
)
// Device brand constants
const (
BRAND_UNKNOWN = BrandType(nvml.BRAND_UNKNOWN)
BRAND_QUADRO = BrandType(nvml.BRAND_QUADRO)
BRAND_TESLA = BrandType(nvml.BRAND_TESLA)
BRAND_NVS = BrandType(nvml.BRAND_NVS)
BRAND_GRID = BrandType(nvml.BRAND_GRID)
BRAND_GEFORCE = BrandType(nvml.BRAND_GEFORCE)
BRAND_TITAN = BrandType(nvml.BRAND_TITAN)
BRAND_NVIDIA_VAPPS = BrandType(nvml.BRAND_NVIDIA_VAPPS)
BRAND_NVIDIA_VPC = BrandType(nvml.BRAND_NVIDIA_VPC)
BRAND_NVIDIA_VCS = BrandType(nvml.BRAND_NVIDIA_VCS)
BRAND_NVIDIA_VWS = BrandType(nvml.BRAND_NVIDIA_VWS)
BRAND_NVIDIA_CLOUD_GAMING = BrandType(nvml.BRAND_NVIDIA_CLOUD_GAMING)
BRAND_NVIDIA_VGAMING = BrandType(nvml.BRAND_NVIDIA_VGAMING)
BRAND_QUADRO_RTX = BrandType(nvml.BRAND_QUADRO_RTX)
BRAND_NVIDIA_RTX = BrandType(nvml.BRAND_NVIDIA_RTX)
BRAND_NVIDIA = BrandType(nvml.BRAND_NVIDIA)
BRAND_GEFORCE_RTX = BrandType(nvml.BRAND_GEFORCE_RTX)
BRAND_TITAN_RTX = BrandType(nvml.BRAND_TITAN_RTX)
BRAND_COUNT = BrandType(nvml.BRAND_COUNT)
)
// MIG Mode constants
const (
DEVICE_MIG_ENABLE = nvml.DEVICE_MIG_ENABLE
DEVICE_MIG_DISABLE = nvml.DEVICE_MIG_DISABLE
)
// GPU Instance Profiles
const (
GPU_INSTANCE_PROFILE_1_SLICE = nvml.GPU_INSTANCE_PROFILE_1_SLICE
GPU_INSTANCE_PROFILE_2_SLICE = nvml.GPU_INSTANCE_PROFILE_2_SLICE
GPU_INSTANCE_PROFILE_3_SLICE = nvml.GPU_INSTANCE_PROFILE_3_SLICE
GPU_INSTANCE_PROFILE_4_SLICE = nvml.GPU_INSTANCE_PROFILE_4_SLICE
GPU_INSTANCE_PROFILE_6_SLICE = nvml.GPU_INSTANCE_PROFILE_6_SLICE
GPU_INSTANCE_PROFILE_7_SLICE = nvml.GPU_INSTANCE_PROFILE_7_SLICE
GPU_INSTANCE_PROFILE_8_SLICE = nvml.GPU_INSTANCE_PROFILE_8_SLICE
GPU_INSTANCE_PROFILE_1_SLICE_REV1 = nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV1
GPU_INSTANCE_PROFILE_1_SLICE_REV2 = nvml.GPU_INSTANCE_PROFILE_1_SLICE_REV2
GPU_INSTANCE_PROFILE_2_SLICE_REV1 = nvml.GPU_INSTANCE_PROFILE_2_SLICE_REV1
GPU_INSTANCE_PROFILE_COUNT = nvml.GPU_INSTANCE_PROFILE_COUNT
)
// Compute Instance Profiles
const (
COMPUTE_INSTANCE_PROFILE_1_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE
COMPUTE_INSTANCE_PROFILE_2_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_2_SLICE
COMPUTE_INSTANCE_PROFILE_3_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_3_SLICE
COMPUTE_INSTANCE_PROFILE_4_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_4_SLICE
COMPUTE_INSTANCE_PROFILE_6_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_6_SLICE
COMPUTE_INSTANCE_PROFILE_7_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_7_SLICE
COMPUTE_INSTANCE_PROFILE_8_SLICE = nvml.COMPUTE_INSTANCE_PROFILE_8_SLICE
COMPUTE_INSTANCE_PROFILE_1_SLICE_REV1 = nvml.COMPUTE_INSTANCE_PROFILE_1_SLICE_REV1
COMPUTE_INSTANCE_PROFILE_COUNT = nvml.COMPUTE_INSTANCE_PROFILE_COUNT
)
// Compute Instance Engine Profiles
const (
COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED = nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_SHARED
COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT = nvml.COMPUTE_INSTANCE_ENGINE_PROFILE_COUNT
)
// Event Types
const (
EventTypeXidCriticalError = nvml.EventTypeXidCriticalError
EventTypeSingleBitEccError = nvml.EventTypeSingleBitEccError
EventTypeDoubleBitEccError = nvml.EventTypeDoubleBitEccError
)

180
vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/device.go generated vendored Normal file
View File

@@ -0,0 +1,180 @@
/**
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvml
import "github.com/NVIDIA/go-nvml/pkg/nvml"
type nvmlDevice nvml.Device
var _ Device = (*nvmlDevice)(nil)
// GetIndex returns the index of a Device
func (d nvmlDevice) GetIndex() (int, Return) {
i, r := nvml.Device(d).GetIndex()
return i, Return(r)
}
// GetPciInfo returns the PCI info of a Device
func (d nvmlDevice) GetPciInfo() (PciInfo, Return) {
p, r := nvml.Device(d).GetPciInfo()
return PciInfo(p), Return(r)
}
// GetMemoryInfo returns the memory info of a Device
func (d nvmlDevice) GetMemoryInfo() (Memory, Return) {
p, r := nvml.Device(d).GetMemoryInfo()
return Memory(p), Return(r)
}
// GetUUID returns the UUID of a Device
func (d nvmlDevice) GetUUID() (string, Return) {
u, r := nvml.Device(d).GetUUID()
return u, Return(r)
}
// GetMinorNumber returns the minor number of a Device
func (d nvmlDevice) GetMinorNumber() (int, Return) {
m, r := nvml.Device(d).GetMinorNumber()
return m, Return(r)
}
// IsMigDeviceHandle returns whether a Device is a MIG device or not
func (d nvmlDevice) IsMigDeviceHandle() (bool, Return) {
b, r := nvml.Device(d).IsMigDeviceHandle()
return b, Return(r)
}
// GetDeviceHandleFromMigDeviceHandle returns the parent Device of a MIG device
func (d nvmlDevice) GetDeviceHandleFromMigDeviceHandle() (Device, Return) {
p, r := nvml.Device(d).GetDeviceHandleFromMigDeviceHandle()
return nvmlDevice(p), Return(r)
}
// SetMigMode sets the MIG mode of a Device
func (d nvmlDevice) SetMigMode(mode int) (Return, Return) {
r1, r2 := nvml.Device(d).SetMigMode(mode)
return Return(r1), Return(r2)
}
// GetMigMode returns the MIG mode of a Device
func (d nvmlDevice) GetMigMode() (int, int, Return) {
s1, s2, r := nvml.Device(d).GetMigMode()
return s1, s2, Return(r)
}
// GetGpuInstanceById returns the GPU Instance associated with a particular ID
func (d nvmlDevice) GetGpuInstanceById(id int) (GpuInstance, Return) {
gi, r := nvml.Device(d).GetGpuInstanceById(id)
return nvmlGpuInstance(gi), Return(r)
}
// GetGpuInstanceProfileInfo returns the profile info of a GPU Instance
func (d nvmlDevice) GetGpuInstanceProfileInfo(profile int) (GpuInstanceProfileInfo, Return) {
p, r := nvml.Device(d).GetGpuInstanceProfileInfo(profile)
return GpuInstanceProfileInfo(p), Return(r)
}
// GetGpuInstancePossiblePlacements returns the possible placements of a GPU Instance
func (d nvmlDevice) GetGpuInstancePossiblePlacements(info *GpuInstanceProfileInfo) ([]GpuInstancePlacement, Return) {
nvmlPlacements, r := nvml.Device(d).GetGpuInstancePossiblePlacements((*nvml.GpuInstanceProfileInfo)(info))
var placements []GpuInstancePlacement
for _, p := range nvmlPlacements {
placements = append(placements, GpuInstancePlacement(p))
}
return placements, Return(r)
}
// GetGpuInstances returns the set of GPU Instances associated with a Device
func (d nvmlDevice) GetGpuInstances(info *GpuInstanceProfileInfo) ([]GpuInstance, Return) {
nvmlGis, r := nvml.Device(d).GetGpuInstances((*nvml.GpuInstanceProfileInfo)(info))
var gis []GpuInstance
for _, gi := range nvmlGis {
gis = append(gis, nvmlGpuInstance(gi))
}
return gis, Return(r)
}
// CreateGpuInstanceWithPlacement creates a GPU Instance with a specific placement
func (d nvmlDevice) CreateGpuInstanceWithPlacement(info *GpuInstanceProfileInfo, placement *GpuInstancePlacement) (GpuInstance, Return) {
gi, r := nvml.Device(d).CreateGpuInstanceWithPlacement((*nvml.GpuInstanceProfileInfo)(info), (*nvml.GpuInstancePlacement)(placement))
return nvmlGpuInstance(gi), Return(r)
}
// GetMaxMigDeviceCount returns the maximum number of MIG devices that can be created on a Device
func (d nvmlDevice) GetMaxMigDeviceCount() (int, Return) {
m, r := nvml.Device(d).GetMaxMigDeviceCount()
return m, Return(r)
}
// GetMigDeviceHandleByIndex returns the handle to a MIG device given its index
func (d nvmlDevice) GetMigDeviceHandleByIndex(Index int) (Device, Return) {
h, r := nvml.Device(d).GetMigDeviceHandleByIndex(Index)
return nvmlDevice(h), Return(r)
}
// GetGpuInstanceId returns the GPU Instance ID of a MIG device
func (d nvmlDevice) GetGpuInstanceId() (int, Return) {
gi, r := nvml.Device(d).GetGpuInstanceId()
return gi, Return(r)
}
// GetComputeInstanceId returns the Compute Instance ID of a MIG device
func (d nvmlDevice) GetComputeInstanceId() (int, Return) {
ci, r := nvml.Device(d).GetComputeInstanceId()
return ci, Return(r)
}
// GetCudaComputeCapability returns the compute capability major and minor versions for a device
func (d nvmlDevice) GetCudaComputeCapability() (int, int, Return) {
major, minor, r := nvml.Device(d).GetCudaComputeCapability()
return major, minor, Return(r)
}
// GetAttributes returns the device attributes for a MIG device
func (d nvmlDevice) GetAttributes() (DeviceAttributes, Return) {
a, r := nvml.Device(d).GetAttributes()
return DeviceAttributes(a), Return(r)
}
// GetName returns the product name of a Device
func (d nvmlDevice) GetName() (string, Return) {
n, r := nvml.Device(d).GetName()
return n, Return(r)
}
// GetBrand returns the brand of a Device
func (d nvmlDevice) GetBrand() (BrandType, Return) {
b, r := nvml.Device(d).GetBrand()
return BrandType(b), Return(r)
}
// GetArchitecture returns the architecture of a Device
func (d nvmlDevice) GetArchitecture() (DeviceArchitecture, Return) {
a, r := nvml.Device(d).GetArchitecture()
return DeviceArchitecture(a), Return(r)
}
// RegisterEvents registers the specified event set and type with the device
func (d nvmlDevice) RegisterEvents(EventTypes uint64, Set EventSet) Return {
return Return(nvml.Device(d).RegisterEvents(EventTypes, nvml.EventSet(Set)))
}
// GetSupportedEventTypes returns the events supported by the device
func (d nvmlDevice) GetSupportedEventTypes() (uint64, Return) {
e, r := nvml.Device(d).GetSupportedEventTypes()
return e, Return(r)
}

1023
vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/device_mock.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,39 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
// Wait watches for an event with the specified timeout
func (e EventSet) Wait(Timeoutms uint32) (EventData, Return) {
d, r := nvml.EventSet(e).Wait(Timeoutms)
eventData := EventData{
Device: nvmlDevice(d.Device),
EventType: d.EventType,
EventData: d.EventData,
GpuInstanceId: d.GpuInstanceId,
ComputeInstanceId: d.ComputeInstanceId,
}
return eventData, Return(r)
}
// Free deletes the event set
func (e EventSet) Free() Return {
return Return(nvml.EventSet(e).Free())
}

71
vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/gi.go generated vendored Normal file
View File

@@ -0,0 +1,71 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
type nvmlGpuInstance nvml.GpuInstance
var _ GpuInstance = (*nvmlGpuInstance)(nil)
// GetInfo returns info about a GPU Intsance
func (gi nvmlGpuInstance) GetInfo() (GpuInstanceInfo, Return) {
i, r := nvml.GpuInstance(gi).GetInfo()
info := GpuInstanceInfo{
Device: nvmlDevice(i.Device),
Id: i.Id,
ProfileId: i.ProfileId,
Placement: GpuInstancePlacement(i.Placement),
}
return info, Return(r)
}
// GetComputeInstanceById returns the Compute Instance associated with a particular ID.
func (gi nvmlGpuInstance) GetComputeInstanceById(id int) (ComputeInstance, Return) {
ci, r := nvml.GpuInstance(gi).GetComputeInstanceById(id)
return nvmlComputeInstance(ci), Return(r)
}
// GetComputeInstanceProfileInfo returns info about a given Compute Instance profile
func (gi nvmlGpuInstance) GetComputeInstanceProfileInfo(profile int, engProfile int) (ComputeInstanceProfileInfo, Return) {
p, r := nvml.GpuInstance(gi).GetComputeInstanceProfileInfo(profile, engProfile)
return ComputeInstanceProfileInfo(p), Return(r)
}
// CreateComputeInstance creates a Compute Instance within the GPU Instance
func (gi nvmlGpuInstance) CreateComputeInstance(info *ComputeInstanceProfileInfo) (ComputeInstance, Return) {
ci, r := nvml.GpuInstance(gi).CreateComputeInstance((*nvml.ComputeInstanceProfileInfo)(info))
return nvmlComputeInstance(ci), Return(r)
}
// GetComputeInstances returns the set of Compute Instances associated with a GPU Instance
func (gi nvmlGpuInstance) GetComputeInstances(info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) {
nvmlCis, r := nvml.GpuInstance(gi).GetComputeInstances((*nvml.ComputeInstanceProfileInfo)(info))
var cis []ComputeInstance
for _, ci := range nvmlCis {
cis = append(cis, nvmlComputeInstance(ci))
}
return cis, Return(r)
}
// Destroy destroys a GPU Instance
func (gi nvmlGpuInstance) Destroy() Return {
r := nvml.GpuInstance(gi).Destroy()
return Return(r)
}

286
vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/gi_mock.go generated vendored Normal file
View File

@@ -0,0 +1,286 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package nvml
import (
"sync"
)
// Ensure, that GpuInstanceMock does implement GpuInstance.
// If this is not the case, regenerate this file with moq.
var _ GpuInstance = &GpuInstanceMock{}
// GpuInstanceMock is a mock implementation of GpuInstance.
//
// func TestSomethingThatUsesGpuInstance(t *testing.T) {
//
// // make and configure a mocked GpuInstance
// mockedGpuInstance := &GpuInstanceMock{
// CreateComputeInstanceFunc: func(Info *ComputeInstanceProfileInfo) (ComputeInstance, Return) {
// panic("mock out the CreateComputeInstance method")
// },
// DestroyFunc: func() Return {
// panic("mock out the Destroy method")
// },
// GetComputeInstanceByIdFunc: func(ID int) (ComputeInstance, Return) {
// panic("mock out the GetComputeInstanceById method")
// },
// GetComputeInstanceProfileInfoFunc: func(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) {
// panic("mock out the GetComputeInstanceProfileInfo method")
// },
// GetComputeInstancesFunc: func(Info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) {
// panic("mock out the GetComputeInstances method")
// },
// GetInfoFunc: func() (GpuInstanceInfo, Return) {
// panic("mock out the GetInfo method")
// },
// }
//
// // use mockedGpuInstance in code that requires GpuInstance
// // and then make assertions.
//
// }
type GpuInstanceMock struct {
// CreateComputeInstanceFunc mocks the CreateComputeInstance method.
CreateComputeInstanceFunc func(Info *ComputeInstanceProfileInfo) (ComputeInstance, Return)
// DestroyFunc mocks the Destroy method.
DestroyFunc func() Return
// GetComputeInstanceByIdFunc mocks the GetComputeInstanceById method.
GetComputeInstanceByIdFunc func(ID int) (ComputeInstance, Return)
// GetComputeInstanceProfileInfoFunc mocks the GetComputeInstanceProfileInfo method.
GetComputeInstanceProfileInfoFunc func(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return)
// GetComputeInstancesFunc mocks the GetComputeInstances method.
GetComputeInstancesFunc func(Info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return)
// GetInfoFunc mocks the GetInfo method.
GetInfoFunc func() (GpuInstanceInfo, Return)
// calls tracks calls to the methods.
calls struct {
// CreateComputeInstance holds details about calls to the CreateComputeInstance method.
CreateComputeInstance []struct {
// Info is the Info argument value.
Info *ComputeInstanceProfileInfo
}
// Destroy holds details about calls to the Destroy method.
Destroy []struct {
}
// GetComputeInstanceById holds details about calls to the GetComputeInstanceById method.
GetComputeInstanceById []struct {
// ID is the ID argument value.
ID int
}
// GetComputeInstanceProfileInfo holds details about calls to the GetComputeInstanceProfileInfo method.
GetComputeInstanceProfileInfo []struct {
// Profile is the Profile argument value.
Profile int
// EngProfile is the EngProfile argument value.
EngProfile int
}
// GetComputeInstances holds details about calls to the GetComputeInstances method.
GetComputeInstances []struct {
// Info is the Info argument value.
Info *ComputeInstanceProfileInfo
}
// GetInfo holds details about calls to the GetInfo method.
GetInfo []struct {
}
}
lockCreateComputeInstance sync.RWMutex
lockDestroy sync.RWMutex
lockGetComputeInstanceById sync.RWMutex
lockGetComputeInstanceProfileInfo sync.RWMutex
lockGetComputeInstances sync.RWMutex
lockGetInfo sync.RWMutex
}
// CreateComputeInstance calls CreateComputeInstanceFunc.
func (mock *GpuInstanceMock) CreateComputeInstance(Info *ComputeInstanceProfileInfo) (ComputeInstance, Return) {
if mock.CreateComputeInstanceFunc == nil {
panic("GpuInstanceMock.CreateComputeInstanceFunc: method is nil but GpuInstance.CreateComputeInstance was just called")
}
callInfo := struct {
Info *ComputeInstanceProfileInfo
}{
Info: Info,
}
mock.lockCreateComputeInstance.Lock()
mock.calls.CreateComputeInstance = append(mock.calls.CreateComputeInstance, callInfo)
mock.lockCreateComputeInstance.Unlock()
return mock.CreateComputeInstanceFunc(Info)
}
// CreateComputeInstanceCalls gets all the calls that were made to CreateComputeInstance.
// Check the length with:
//
// len(mockedGpuInstance.CreateComputeInstanceCalls())
func (mock *GpuInstanceMock) CreateComputeInstanceCalls() []struct {
Info *ComputeInstanceProfileInfo
} {
var calls []struct {
Info *ComputeInstanceProfileInfo
}
mock.lockCreateComputeInstance.RLock()
calls = mock.calls.CreateComputeInstance
mock.lockCreateComputeInstance.RUnlock()
return calls
}
// Destroy calls DestroyFunc.
func (mock *GpuInstanceMock) Destroy() Return {
if mock.DestroyFunc == nil {
panic("GpuInstanceMock.DestroyFunc: method is nil but GpuInstance.Destroy was just called")
}
callInfo := struct {
}{}
mock.lockDestroy.Lock()
mock.calls.Destroy = append(mock.calls.Destroy, callInfo)
mock.lockDestroy.Unlock()
return mock.DestroyFunc()
}
// DestroyCalls gets all the calls that were made to Destroy.
// Check the length with:
//
// len(mockedGpuInstance.DestroyCalls())
func (mock *GpuInstanceMock) DestroyCalls() []struct {
} {
var calls []struct {
}
mock.lockDestroy.RLock()
calls = mock.calls.Destroy
mock.lockDestroy.RUnlock()
return calls
}
// GetComputeInstanceById calls GetComputeInstanceByIdFunc.
func (mock *GpuInstanceMock) GetComputeInstanceById(ID int) (ComputeInstance, Return) {
if mock.GetComputeInstanceByIdFunc == nil {
panic("GpuInstanceMock.GetComputeInstanceByIdFunc: method is nil but GpuInstance.GetComputeInstanceById was just called")
}
callInfo := struct {
ID int
}{
ID: ID,
}
mock.lockGetComputeInstanceById.Lock()
mock.calls.GetComputeInstanceById = append(mock.calls.GetComputeInstanceById, callInfo)
mock.lockGetComputeInstanceById.Unlock()
return mock.GetComputeInstanceByIdFunc(ID)
}
// GetComputeInstanceByIdCalls gets all the calls that were made to GetComputeInstanceById.
// Check the length with:
//
// len(mockedGpuInstance.GetComputeInstanceByIdCalls())
func (mock *GpuInstanceMock) GetComputeInstanceByIdCalls() []struct {
ID int
} {
var calls []struct {
ID int
}
mock.lockGetComputeInstanceById.RLock()
calls = mock.calls.GetComputeInstanceById
mock.lockGetComputeInstanceById.RUnlock()
return calls
}
// GetComputeInstanceProfileInfo calls GetComputeInstanceProfileInfoFunc.
func (mock *GpuInstanceMock) GetComputeInstanceProfileInfo(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return) {
if mock.GetComputeInstanceProfileInfoFunc == nil {
panic("GpuInstanceMock.GetComputeInstanceProfileInfoFunc: method is nil but GpuInstance.GetComputeInstanceProfileInfo was just called")
}
callInfo := struct {
Profile int
EngProfile int
}{
Profile: Profile,
EngProfile: EngProfile,
}
mock.lockGetComputeInstanceProfileInfo.Lock()
mock.calls.GetComputeInstanceProfileInfo = append(mock.calls.GetComputeInstanceProfileInfo, callInfo)
mock.lockGetComputeInstanceProfileInfo.Unlock()
return mock.GetComputeInstanceProfileInfoFunc(Profile, EngProfile)
}
// GetComputeInstanceProfileInfoCalls gets all the calls that were made to GetComputeInstanceProfileInfo.
// Check the length with:
//
// len(mockedGpuInstance.GetComputeInstanceProfileInfoCalls())
func (mock *GpuInstanceMock) GetComputeInstanceProfileInfoCalls() []struct {
Profile int
EngProfile int
} {
var calls []struct {
Profile int
EngProfile int
}
mock.lockGetComputeInstanceProfileInfo.RLock()
calls = mock.calls.GetComputeInstanceProfileInfo
mock.lockGetComputeInstanceProfileInfo.RUnlock()
return calls
}
// GetComputeInstances calls GetComputeInstancesFunc.
func (mock *GpuInstanceMock) GetComputeInstances(Info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return) {
if mock.GetComputeInstancesFunc == nil {
panic("GpuInstanceMock.GetComputeInstancesFunc: method is nil but GpuInstance.GetComputeInstances was just called")
}
callInfo := struct {
Info *ComputeInstanceProfileInfo
}{
Info: Info,
}
mock.lockGetComputeInstances.Lock()
mock.calls.GetComputeInstances = append(mock.calls.GetComputeInstances, callInfo)
mock.lockGetComputeInstances.Unlock()
return mock.GetComputeInstancesFunc(Info)
}
// GetComputeInstancesCalls gets all the calls that were made to GetComputeInstances.
// Check the length with:
//
// len(mockedGpuInstance.GetComputeInstancesCalls())
func (mock *GpuInstanceMock) GetComputeInstancesCalls() []struct {
Info *ComputeInstanceProfileInfo
} {
var calls []struct {
Info *ComputeInstanceProfileInfo
}
mock.lockGetComputeInstances.RLock()
calls = mock.calls.GetComputeInstances
mock.lockGetComputeInstances.RUnlock()
return calls
}
// GetInfo calls GetInfoFunc.
func (mock *GpuInstanceMock) GetInfo() (GpuInstanceInfo, Return) {
if mock.GetInfoFunc == nil {
panic("GpuInstanceMock.GetInfoFunc: method is nil but GpuInstance.GetInfo was just called")
}
callInfo := struct {
}{}
mock.lockGetInfo.Lock()
mock.calls.GetInfo = append(mock.calls.GetInfo, callInfo)
mock.lockGetInfo.Unlock()
return mock.GetInfoFunc()
}
// GetInfoCalls gets all the calls that were made to GetInfo.
// Check the length with:
//
// len(mockedGpuInstance.GetInfoCalls())
func (mock *GpuInstanceMock) GetInfoCalls() []struct {
} {
var calls []struct {
}
mock.lockGetInfo.RLock()
calls = mock.calls.GetInfo
mock.lockGetInfo.RUnlock()
return calls
}

127
vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/nvml.go generated vendored Normal file
View File

@@ -0,0 +1,127 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"sync"
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
type nvmlLib struct {
sync.Mutex
refcount int
}
var _ Interface = (*nvmlLib)(nil)
// New creates a new instance of the NVML Interface
func New(opts ...Option) Interface {
o := &options{}
for _, opt := range opts {
opt(o)
}
var nvmlOptions []nvml.LibraryOption
if o.libraryPath != "" {
nvmlOptions = append(nvmlOptions, nvml.WithLibraryPath(o.libraryPath))
}
nvml.SetLibraryOptions(nvmlOptions...)
return &nvmlLib{}
}
// Lookup checks whether the specified symbol exists in the configured NVML library.
func (n *nvmlLib) Lookup(name string) error {
// TODO: For now we rely on the default NVML library and perform the lookups against this.
return nvml.GetLibrary().Lookup(name)
}
// Init initializes an NVML Interface
func (n *nvmlLib) Init() Return {
ret := nvml.Init()
if ret != nvml.SUCCESS {
return Return(ret)
}
n.Lock()
defer n.Unlock()
if n.refcount == 0 {
errorStringFunc = nvml.ErrorString
}
n.refcount++
return SUCCESS
}
// Shutdown shuts down an NVML Interface
func (n *nvmlLib) Shutdown() Return {
ret := nvml.Shutdown()
if ret != nvml.SUCCESS {
return Return(ret)
}
n.Lock()
defer n.Unlock()
n.refcount--
if n.refcount == 0 {
errorStringFunc = defaultErrorStringFunc
}
return SUCCESS
}
// DeviceGetCount returns the total number of GPU Devices
func (n *nvmlLib) DeviceGetCount() (int, Return) {
c, r := nvml.DeviceGetCount()
return c, Return(r)
}
// DeviceGetHandleByIndex returns a Device handle given its index
func (n *nvmlLib) DeviceGetHandleByIndex(index int) (Device, Return) {
d, r := nvml.DeviceGetHandleByIndex(index)
return nvmlDevice(d), Return(r)
}
// DeviceGetHandleByUUID returns a Device handle given its UUID
func (n *nvmlLib) DeviceGetHandleByUUID(uuid string) (Device, Return) {
d, r := nvml.DeviceGetHandleByUUID(uuid)
return nvmlDevice(d), Return(r)
}
// SystemGetDriverVersion returns the version of the installed NVIDIA driver
func (n *nvmlLib) SystemGetDriverVersion() (string, Return) {
v, r := nvml.SystemGetDriverVersion()
return v, Return(r)
}
// SystemGetCudaDriverVersion returns the version of CUDA associated with the NVIDIA driver
func (n *nvmlLib) SystemGetCudaDriverVersion() (int, Return) {
v, r := nvml.SystemGetCudaDriverVersion()
return v, Return(r)
}
// ErrorString returns the error string associated with a given return value
func (n *nvmlLib) ErrorString(ret Return) string {
return nvml.ErrorString(nvml.Return(ret))
}
// EventSetCreate creates an event set
func (n *nvmlLib) EventSetCreate() (EventSet, Return) {
e, r := nvml.EventSetCreate()
return EventSet(e), Return(r)
}

428
vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/nvml_mock.go generated vendored Normal file
View File

@@ -0,0 +1,428 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package nvml
import (
"sync"
)
// Ensure, that InterfaceMock does implement Interface.
// If this is not the case, regenerate this file with moq.
var _ Interface = &InterfaceMock{}
// InterfaceMock is a mock implementation of Interface.
//
// func TestSomethingThatUsesInterface(t *testing.T) {
//
// // make and configure a mocked Interface
// mockedInterface := &InterfaceMock{
// DeviceGetCountFunc: func() (int, Return) {
// panic("mock out the DeviceGetCount method")
// },
// DeviceGetHandleByIndexFunc: func(Index int) (Device, Return) {
// panic("mock out the DeviceGetHandleByIndex method")
// },
// DeviceGetHandleByUUIDFunc: func(UUID string) (Device, Return) {
// panic("mock out the DeviceGetHandleByUUID method")
// },
// ErrorStringFunc: func(r Return) string {
// panic("mock out the ErrorString method")
// },
// EventSetCreateFunc: func() (EventSet, Return) {
// panic("mock out the EventSetCreate method")
// },
// InitFunc: func() Return {
// panic("mock out the Init method")
// },
// LookupFunc: func(s string) error {
// panic("mock out the Lookup method")
// },
// ShutdownFunc: func() Return {
// panic("mock out the Shutdown method")
// },
// SystemGetCudaDriverVersionFunc: func() (int, Return) {
// panic("mock out the SystemGetCudaDriverVersion method")
// },
// SystemGetDriverVersionFunc: func() (string, Return) {
// panic("mock out the SystemGetDriverVersion method")
// },
// }
//
// // use mockedInterface in code that requires Interface
// // and then make assertions.
//
// }
type InterfaceMock struct {
// DeviceGetCountFunc mocks the DeviceGetCount method.
DeviceGetCountFunc func() (int, Return)
// DeviceGetHandleByIndexFunc mocks the DeviceGetHandleByIndex method.
DeviceGetHandleByIndexFunc func(Index int) (Device, Return)
// DeviceGetHandleByUUIDFunc mocks the DeviceGetHandleByUUID method.
DeviceGetHandleByUUIDFunc func(UUID string) (Device, Return)
// ErrorStringFunc mocks the ErrorString method.
ErrorStringFunc func(r Return) string
// EventSetCreateFunc mocks the EventSetCreate method.
EventSetCreateFunc func() (EventSet, Return)
// InitFunc mocks the Init method.
InitFunc func() Return
// LookupFunc mocks the Lookup method.
LookupFunc func(s string) error
// ShutdownFunc mocks the Shutdown method.
ShutdownFunc func() Return
// SystemGetCudaDriverVersionFunc mocks the SystemGetCudaDriverVersion method.
SystemGetCudaDriverVersionFunc func() (int, Return)
// SystemGetDriverVersionFunc mocks the SystemGetDriverVersion method.
SystemGetDriverVersionFunc func() (string, Return)
// calls tracks calls to the methods.
calls struct {
// DeviceGetCount holds details about calls to the DeviceGetCount method.
DeviceGetCount []struct {
}
// DeviceGetHandleByIndex holds details about calls to the DeviceGetHandleByIndex method.
DeviceGetHandleByIndex []struct {
// Index is the Index argument value.
Index int
}
// DeviceGetHandleByUUID holds details about calls to the DeviceGetHandleByUUID method.
DeviceGetHandleByUUID []struct {
// UUID is the UUID argument value.
UUID string
}
// ErrorString holds details about calls to the ErrorString method.
ErrorString []struct {
// R is the r argument value.
R Return
}
// EventSetCreate holds details about calls to the EventSetCreate method.
EventSetCreate []struct {
}
// Init holds details about calls to the Init method.
Init []struct {
}
// Lookup holds details about calls to the Lookup method.
Lookup []struct {
// S is the s argument value.
S string
}
// Shutdown holds details about calls to the Shutdown method.
Shutdown []struct {
}
// SystemGetCudaDriverVersion holds details about calls to the SystemGetCudaDriverVersion method.
SystemGetCudaDriverVersion []struct {
}
// SystemGetDriverVersion holds details about calls to the SystemGetDriverVersion method.
SystemGetDriverVersion []struct {
}
}
lockDeviceGetCount sync.RWMutex
lockDeviceGetHandleByIndex sync.RWMutex
lockDeviceGetHandleByUUID sync.RWMutex
lockErrorString sync.RWMutex
lockEventSetCreate sync.RWMutex
lockInit sync.RWMutex
lockLookup sync.RWMutex
lockShutdown sync.RWMutex
lockSystemGetCudaDriverVersion sync.RWMutex
lockSystemGetDriverVersion sync.RWMutex
}
// DeviceGetCount calls DeviceGetCountFunc.
func (mock *InterfaceMock) DeviceGetCount() (int, Return) {
if mock.DeviceGetCountFunc == nil {
panic("InterfaceMock.DeviceGetCountFunc: method is nil but Interface.DeviceGetCount was just called")
}
callInfo := struct {
}{}
mock.lockDeviceGetCount.Lock()
mock.calls.DeviceGetCount = append(mock.calls.DeviceGetCount, callInfo)
mock.lockDeviceGetCount.Unlock()
return mock.DeviceGetCountFunc()
}
// DeviceGetCountCalls gets all the calls that were made to DeviceGetCount.
// Check the length with:
//
// len(mockedInterface.DeviceGetCountCalls())
func (mock *InterfaceMock) DeviceGetCountCalls() []struct {
} {
var calls []struct {
}
mock.lockDeviceGetCount.RLock()
calls = mock.calls.DeviceGetCount
mock.lockDeviceGetCount.RUnlock()
return calls
}
// DeviceGetHandleByIndex calls DeviceGetHandleByIndexFunc.
func (mock *InterfaceMock) DeviceGetHandleByIndex(Index int) (Device, Return) {
if mock.DeviceGetHandleByIndexFunc == nil {
panic("InterfaceMock.DeviceGetHandleByIndexFunc: method is nil but Interface.DeviceGetHandleByIndex was just called")
}
callInfo := struct {
Index int
}{
Index: Index,
}
mock.lockDeviceGetHandleByIndex.Lock()
mock.calls.DeviceGetHandleByIndex = append(mock.calls.DeviceGetHandleByIndex, callInfo)
mock.lockDeviceGetHandleByIndex.Unlock()
return mock.DeviceGetHandleByIndexFunc(Index)
}
// DeviceGetHandleByIndexCalls gets all the calls that were made to DeviceGetHandleByIndex.
// Check the length with:
//
// len(mockedInterface.DeviceGetHandleByIndexCalls())
func (mock *InterfaceMock) DeviceGetHandleByIndexCalls() []struct {
Index int
} {
var calls []struct {
Index int
}
mock.lockDeviceGetHandleByIndex.RLock()
calls = mock.calls.DeviceGetHandleByIndex
mock.lockDeviceGetHandleByIndex.RUnlock()
return calls
}
// DeviceGetHandleByUUID calls DeviceGetHandleByUUIDFunc.
func (mock *InterfaceMock) DeviceGetHandleByUUID(UUID string) (Device, Return) {
if mock.DeviceGetHandleByUUIDFunc == nil {
panic("InterfaceMock.DeviceGetHandleByUUIDFunc: method is nil but Interface.DeviceGetHandleByUUID was just called")
}
callInfo := struct {
UUID string
}{
UUID: UUID,
}
mock.lockDeviceGetHandleByUUID.Lock()
mock.calls.DeviceGetHandleByUUID = append(mock.calls.DeviceGetHandleByUUID, callInfo)
mock.lockDeviceGetHandleByUUID.Unlock()
return mock.DeviceGetHandleByUUIDFunc(UUID)
}
// DeviceGetHandleByUUIDCalls gets all the calls that were made to DeviceGetHandleByUUID.
// Check the length with:
//
// len(mockedInterface.DeviceGetHandleByUUIDCalls())
func (mock *InterfaceMock) DeviceGetHandleByUUIDCalls() []struct {
UUID string
} {
var calls []struct {
UUID string
}
mock.lockDeviceGetHandleByUUID.RLock()
calls = mock.calls.DeviceGetHandleByUUID
mock.lockDeviceGetHandleByUUID.RUnlock()
return calls
}
// ErrorString calls ErrorStringFunc.
func (mock *InterfaceMock) ErrorString(r Return) string {
if mock.ErrorStringFunc == nil {
panic("InterfaceMock.ErrorStringFunc: method is nil but Interface.ErrorString was just called")
}
callInfo := struct {
R Return
}{
R: r,
}
mock.lockErrorString.Lock()
mock.calls.ErrorString = append(mock.calls.ErrorString, callInfo)
mock.lockErrorString.Unlock()
return mock.ErrorStringFunc(r)
}
// ErrorStringCalls gets all the calls that were made to ErrorString.
// Check the length with:
//
// len(mockedInterface.ErrorStringCalls())
func (mock *InterfaceMock) ErrorStringCalls() []struct {
R Return
} {
var calls []struct {
R Return
}
mock.lockErrorString.RLock()
calls = mock.calls.ErrorString
mock.lockErrorString.RUnlock()
return calls
}
// EventSetCreate calls EventSetCreateFunc.
func (mock *InterfaceMock) EventSetCreate() (EventSet, Return) {
if mock.EventSetCreateFunc == nil {
panic("InterfaceMock.EventSetCreateFunc: method is nil but Interface.EventSetCreate was just called")
}
callInfo := struct {
}{}
mock.lockEventSetCreate.Lock()
mock.calls.EventSetCreate = append(mock.calls.EventSetCreate, callInfo)
mock.lockEventSetCreate.Unlock()
return mock.EventSetCreateFunc()
}
// EventSetCreateCalls gets all the calls that were made to EventSetCreate.
// Check the length with:
//
// len(mockedInterface.EventSetCreateCalls())
func (mock *InterfaceMock) EventSetCreateCalls() []struct {
} {
var calls []struct {
}
mock.lockEventSetCreate.RLock()
calls = mock.calls.EventSetCreate
mock.lockEventSetCreate.RUnlock()
return calls
}
// Init calls InitFunc.
func (mock *InterfaceMock) Init() Return {
if mock.InitFunc == nil {
panic("InterfaceMock.InitFunc: method is nil but Interface.Init was just called")
}
callInfo := struct {
}{}
mock.lockInit.Lock()
mock.calls.Init = append(mock.calls.Init, callInfo)
mock.lockInit.Unlock()
return mock.InitFunc()
}
// InitCalls gets all the calls that were made to Init.
// Check the length with:
//
// len(mockedInterface.InitCalls())
func (mock *InterfaceMock) InitCalls() []struct {
} {
var calls []struct {
}
mock.lockInit.RLock()
calls = mock.calls.Init
mock.lockInit.RUnlock()
return calls
}
// Lookup calls LookupFunc.
func (mock *InterfaceMock) Lookup(s string) error {
if mock.LookupFunc == nil {
panic("InterfaceMock.LookupFunc: method is nil but Interface.Lookup was just called")
}
callInfo := struct {
S string
}{
S: s,
}
mock.lockLookup.Lock()
mock.calls.Lookup = append(mock.calls.Lookup, callInfo)
mock.lockLookup.Unlock()
return mock.LookupFunc(s)
}
// LookupCalls gets all the calls that were made to Lookup.
// Check the length with:
//
// len(mockedInterface.LookupCalls())
func (mock *InterfaceMock) LookupCalls() []struct {
S string
} {
var calls []struct {
S string
}
mock.lockLookup.RLock()
calls = mock.calls.Lookup
mock.lockLookup.RUnlock()
return calls
}
// Shutdown calls ShutdownFunc.
func (mock *InterfaceMock) Shutdown() Return {
if mock.ShutdownFunc == nil {
panic("InterfaceMock.ShutdownFunc: method is nil but Interface.Shutdown was just called")
}
callInfo := struct {
}{}
mock.lockShutdown.Lock()
mock.calls.Shutdown = append(mock.calls.Shutdown, callInfo)
mock.lockShutdown.Unlock()
return mock.ShutdownFunc()
}
// ShutdownCalls gets all the calls that were made to Shutdown.
// Check the length with:
//
// len(mockedInterface.ShutdownCalls())
func (mock *InterfaceMock) ShutdownCalls() []struct {
} {
var calls []struct {
}
mock.lockShutdown.RLock()
calls = mock.calls.Shutdown
mock.lockShutdown.RUnlock()
return calls
}
// SystemGetCudaDriverVersion calls SystemGetCudaDriverVersionFunc.
func (mock *InterfaceMock) SystemGetCudaDriverVersion() (int, Return) {
if mock.SystemGetCudaDriverVersionFunc == nil {
panic("InterfaceMock.SystemGetCudaDriverVersionFunc: method is nil but Interface.SystemGetCudaDriverVersion was just called")
}
callInfo := struct {
}{}
mock.lockSystemGetCudaDriverVersion.Lock()
mock.calls.SystemGetCudaDriverVersion = append(mock.calls.SystemGetCudaDriverVersion, callInfo)
mock.lockSystemGetCudaDriverVersion.Unlock()
return mock.SystemGetCudaDriverVersionFunc()
}
// SystemGetCudaDriverVersionCalls gets all the calls that were made to SystemGetCudaDriverVersion.
// Check the length with:
//
// len(mockedInterface.SystemGetCudaDriverVersionCalls())
func (mock *InterfaceMock) SystemGetCudaDriverVersionCalls() []struct {
} {
var calls []struct {
}
mock.lockSystemGetCudaDriverVersion.RLock()
calls = mock.calls.SystemGetCudaDriverVersion
mock.lockSystemGetCudaDriverVersion.RUnlock()
return calls
}
// SystemGetDriverVersion calls SystemGetDriverVersionFunc.
func (mock *InterfaceMock) SystemGetDriverVersion() (string, Return) {
if mock.SystemGetDriverVersionFunc == nil {
panic("InterfaceMock.SystemGetDriverVersionFunc: method is nil but Interface.SystemGetDriverVersion was just called")
}
callInfo := struct {
}{}
mock.lockSystemGetDriverVersion.Lock()
mock.calls.SystemGetDriverVersion = append(mock.calls.SystemGetDriverVersion, callInfo)
mock.lockSystemGetDriverVersion.Unlock()
return mock.SystemGetDriverVersionFunc()
}
// SystemGetDriverVersionCalls gets all the calls that were made to SystemGetDriverVersion.
// Check the length with:
//
// len(mockedInterface.SystemGetDriverVersionCalls())
func (mock *InterfaceMock) SystemGetDriverVersionCalls() []struct {
} {
var calls []struct {
}
mock.lockSystemGetDriverVersion.RLock()
calls = mock.calls.SystemGetDriverVersion
mock.lockSystemGetDriverVersion.RUnlock()
return calls
}

32
vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/options.go generated vendored Normal file
View File

@@ -0,0 +1,32 @@
/**
# Copyright 2023 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvml
// options represents the options that could be passed to the nvml contructor.
type options struct {
libraryPath string
}
// Option represents a functional option to control behaviour.
type Option func(*options)
// WithLibraryPath sets the NVML library name to use.
func WithLibraryPath(libraryPath string) Option {
return func(o *options) {
o.libraryPath = libraryPath
}
}

93
vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/return.go generated vendored Normal file
View File

@@ -0,0 +1,93 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"fmt"
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
// String returns the string representation of a Return
func (r Return) String() string {
return errorStringFunc(nvml.Return(r))
}
// Error returns the string representation of a Return
func (r Return) Error() string {
return errorStringFunc(nvml.Return(r))
}
// Assigned to nvml.ErrorString if the system nvml library is in use
var errorStringFunc = defaultErrorStringFunc
var defaultErrorStringFunc = func(r nvml.Return) string {
switch Return(r) {
case SUCCESS:
return "SUCCESS"
case ERROR_UNINITIALIZED:
return "ERROR_UNINITIALIZED"
case ERROR_INVALID_ARGUMENT:
return "ERROR_INVALID_ARGUMENT"
case ERROR_NOT_SUPPORTED:
return "ERROR_NOT_SUPPORTED"
case ERROR_NO_PERMISSION:
return "ERROR_NO_PERMISSION"
case ERROR_ALREADY_INITIALIZED:
return "ERROR_ALREADY_INITIALIZED"
case ERROR_NOT_FOUND:
return "ERROR_NOT_FOUND"
case ERROR_INSUFFICIENT_SIZE:
return "ERROR_INSUFFICIENT_SIZE"
case ERROR_INSUFFICIENT_POWER:
return "ERROR_INSUFFICIENT_POWER"
case ERROR_DRIVER_NOT_LOADED:
return "ERROR_DRIVER_NOT_LOADED"
case ERROR_TIMEOUT:
return "ERROR_TIMEOUT"
case ERROR_IRQ_ISSUE:
return "ERROR_IRQ_ISSUE"
case ERROR_LIBRARY_NOT_FOUND:
return "ERROR_LIBRARY_NOT_FOUND"
case ERROR_FUNCTION_NOT_FOUND:
return "ERROR_FUNCTION_NOT_FOUND"
case ERROR_CORRUPTED_INFOROM:
return "ERROR_CORRUPTED_INFOROM"
case ERROR_GPU_IS_LOST:
return "ERROR_GPU_IS_LOST"
case ERROR_RESET_REQUIRED:
return "ERROR_RESET_REQUIRED"
case ERROR_OPERATING_SYSTEM:
return "ERROR_OPERATING_SYSTEM"
case ERROR_LIB_RM_VERSION_MISMATCH:
return "ERROR_LIB_RM_VERSION_MISMATCH"
case ERROR_IN_USE:
return "ERROR_IN_USE"
case ERROR_MEMORY:
return "ERROR_MEMORY"
case ERROR_NO_DATA:
return "ERROR_NO_DATA"
case ERROR_VGPU_ECC_NOT_SUPPORTED:
return "ERROR_VGPU_ECC_NOT_SUPPORTED"
case ERROR_INSUFFICIENT_RESOURCES:
return "ERROR_INSUFFICIENT_RESOURCES"
case ERROR_UNKNOWN:
return "ERROR_UNKNOWN"
default:
return fmt.Sprintf("Unknown return value: %d", r)
}
}

147
vendor/github.com/NVIDIA/go-nvlib/pkg/nvml/types.go generated vendored Normal file
View File

@@ -0,0 +1,147 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvml
import (
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
// Interface defines the functions implemented by an NVML library
//
//go:generate moq -out nvml_mock.go . Interface
type Interface interface {
DeviceGetCount() (int, Return)
DeviceGetHandleByIndex(Index int) (Device, Return)
DeviceGetHandleByUUID(UUID string) (Device, Return)
ErrorString(r Return) string
EventSetCreate() (EventSet, Return)
Init() Return
Lookup(string) error
Shutdown() Return
SystemGetCudaDriverVersion() (int, Return)
SystemGetDriverVersion() (string, Return)
}
// Device defines the functions implemented by an NVML device
//
//go:generate moq -out device_mock.go . Device
type Device interface {
CreateGpuInstanceWithPlacement(*GpuInstanceProfileInfo, *GpuInstancePlacement) (GpuInstance, Return)
GetArchitecture() (DeviceArchitecture, Return)
GetAttributes() (DeviceAttributes, Return)
GetBrand() (BrandType, Return)
GetComputeInstanceId() (int, Return)
GetCudaComputeCapability() (int, int, Return)
GetDeviceHandleFromMigDeviceHandle() (Device, Return)
GetGpuInstanceById(ID int) (GpuInstance, Return)
GetGpuInstanceId() (int, Return)
GetGpuInstancePossiblePlacements(*GpuInstanceProfileInfo) ([]GpuInstancePlacement, Return)
GetGpuInstanceProfileInfo(Profile int) (GpuInstanceProfileInfo, Return)
GetGpuInstances(Info *GpuInstanceProfileInfo) ([]GpuInstance, Return)
GetIndex() (int, Return)
GetMaxMigDeviceCount() (int, Return)
GetMemoryInfo() (Memory, Return)
GetMigDeviceHandleByIndex(Index int) (Device, Return)
GetMigMode() (int, int, Return)
GetMinorNumber() (int, Return)
GetName() (string, Return)
GetPciInfo() (PciInfo, Return)
GetSupportedEventTypes() (uint64, Return)
GetUUID() (string, Return)
IsMigDeviceHandle() (bool, Return)
RegisterEvents(uint64, EventSet) Return
SetMigMode(Mode int) (Return, Return)
}
// GpuInstance defines the functions implemented by a GpuInstance
//
//go:generate moq -out gi_mock.go . GpuInstance
type GpuInstance interface {
CreateComputeInstance(Info *ComputeInstanceProfileInfo) (ComputeInstance, Return)
Destroy() Return
GetComputeInstanceById(ID int) (ComputeInstance, Return)
GetComputeInstanceProfileInfo(Profile int, EngProfile int) (ComputeInstanceProfileInfo, Return)
GetComputeInstances(Info *ComputeInstanceProfileInfo) ([]ComputeInstance, Return)
GetInfo() (GpuInstanceInfo, Return)
}
// ComputeInstance defines the functions implemented by a ComputeInstance
//
//go:generate moq -out ci_mock.go . ComputeInstance
type ComputeInstance interface {
Destroy() Return
GetInfo() (ComputeInstanceInfo, Return)
}
// GpuInstanceInfo holds info about a GPU Instance
type GpuInstanceInfo struct {
Device Device
Id uint32
ProfileId uint32
Placement GpuInstancePlacement
}
// ComputeInstanceInfo holds info about a Compute Instance
type ComputeInstanceInfo struct {
Device Device
GpuInstance GpuInstance
Id uint32
ProfileId uint32
Placement ComputeInstancePlacement
}
// EventData defines NVML event Data
type EventData struct {
Device Device
EventType uint64
EventData uint64
GpuInstanceId uint32
ComputeInstanceId uint32
}
// EventSet defines NVML event Data
type EventSet nvml.EventSet
// Return defines an NVML return type
type Return nvml.Return
// Memory holds info about GPU device memory
type Memory nvml.Memory
// PciInfo holds info about the PCI connections of a GPU dvice
type PciInfo nvml.PciInfo
// GpuInstanceProfileInfo holds info about a GPU Instance Profile
type GpuInstanceProfileInfo nvml.GpuInstanceProfileInfo
// GpuInstancePlacement holds placement info about a GPU Instance
type GpuInstancePlacement nvml.GpuInstancePlacement
// ComputeInstanceProfileInfo holds info about a Compute Instance Profile
type ComputeInstanceProfileInfo nvml.ComputeInstanceProfileInfo
// ComputeInstancePlacement holds placement info about a Compute Instance
type ComputeInstancePlacement nvml.ComputeInstancePlacement
// DeviceAttributes stores information about MIG devices
type DeviceAttributes nvml.DeviceAttributes
// DeviceArchitecture represents the hardware architecture of a GPU device
type DeviceArchitecture nvml.DeviceArchitecture
// BrandType represents the brand of a GPU device
type BrandType nvml.BrandType

View File

@@ -0,0 +1,94 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package bytes
import (
"encoding/binary"
"unsafe"
)
// Raw returns just the bytes without any assumptions about layout
type Raw interface {
Raw() *[]byte
}
// Reader used to read various data sizes in the byte array
type Reader interface {
Read8(pos int) uint8
Read16(pos int) uint16
Read32(pos int) uint32
Read64(pos int) uint64
Len() int
}
// Writer used to write various sizes of data in the byte array
type Writer interface {
Write8(pos int, value uint8)
Write16(pos int, value uint16)
Write32(pos int, value uint32)
Write64(pos int, value uint64)
Len() int
}
// Bytes object for manipulating arbitrary byte arrays
type Bytes interface {
Raw
Reader
Writer
Slice(offset int, size int) Bytes
LittleEndian() Bytes
BigEndian() Bytes
}
var nativeByteOrder binary.ByteOrder
func init() {
buf := [2]byte{}
*(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0x00FF)
switch buf {
case [2]byte{0xFF, 0x00}:
nativeByteOrder = binary.LittleEndian
case [2]byte{0x00, 0xFF}:
nativeByteOrder = binary.BigEndian
default:
panic("Unable to infer byte order")
}
}
// New raw bytearray
func New(data *[]byte) Bytes {
return (*native)(data)
}
// NewLittleEndian little endian ordering of bytes
func NewLittleEndian(data *[]byte) Bytes {
if nativeByteOrder == binary.LittleEndian {
return (*native)(data)
}
return (*swapbo)(data)
}
// NewBigEndian big endian ordering of bytes
func NewBigEndian(data *[]byte) Bytes {
if nativeByteOrder == binary.BigEndian {
return (*native)(data)
}
return (*swapbo)(data)
}

View File

@@ -0,0 +1,78 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package bytes
import (
"unsafe"
)
type native []byte
var _ Bytes = (*native)(nil)
func (b *native) Read8(pos int) uint8 {
return (*b)[pos]
}
func (b *native) Read16(pos int) uint16 {
return *(*uint16)(unsafe.Pointer(&((*b)[pos])))
}
func (b *native) Read32(pos int) uint32 {
return *(*uint32)(unsafe.Pointer(&((*b)[pos])))
}
func (b *native) Read64(pos int) uint64 {
return *(*uint64)(unsafe.Pointer(&((*b)[pos])))
}
func (b *native) Write8(pos int, value uint8) {
(*b)[pos] = value
}
func (b *native) Write16(pos int, value uint16) {
*(*uint16)(unsafe.Pointer(&((*b)[pos]))) = value
}
func (b *native) Write32(pos int, value uint32) {
*(*uint32)(unsafe.Pointer(&((*b)[pos]))) = value
}
func (b *native) Write64(pos int, value uint64) {
*(*uint64)(unsafe.Pointer(&((*b)[pos]))) = value
}
func (b *native) Slice(offset int, size int) Bytes {
nb := (*b)[offset : offset+size]
return &nb
}
func (b *native) LittleEndian() Bytes {
return NewLittleEndian((*[]byte)(b))
}
func (b *native) BigEndian() Bytes {
return NewBigEndian((*[]byte)(b))
}
func (b *native) Raw() *[]byte {
return (*[]byte)(b)
}
func (b *native) Len() int {
return len(*b)
}

View File

@@ -0,0 +1,112 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package bytes
import (
"unsafe"
)
type swapbo []byte
var _ Bytes = (*swapbo)(nil)
func (b *swapbo) Read8(pos int) uint8 {
return (*b)[pos]
}
func (b *swapbo) Read16(pos int) uint16 {
buf := [2]byte{}
buf[0] = (*b)[pos+1]
buf[1] = (*b)[pos+0]
return *(*uint16)(unsafe.Pointer(&buf[0]))
}
func (b *swapbo) Read32(pos int) uint32 {
buf := [4]byte{}
buf[0] = (*b)[pos+3]
buf[1] = (*b)[pos+2]
buf[2] = (*b)[pos+1]
buf[3] = (*b)[pos+0]
return *(*uint32)(unsafe.Pointer(&buf[0]))
}
func (b *swapbo) Read64(pos int) uint64 {
buf := [8]byte{}
buf[0] = (*b)[pos+7]
buf[1] = (*b)[pos+6]
buf[2] = (*b)[pos+5]
buf[3] = (*b)[pos+4]
buf[4] = (*b)[pos+3]
buf[5] = (*b)[pos+2]
buf[6] = (*b)[pos+1]
buf[7] = (*b)[pos+0]
return *(*uint64)(unsafe.Pointer(&buf[0]))
}
func (b *swapbo) Write8(pos int, value uint8) {
(*b)[pos] = value
}
func (b *swapbo) Write16(pos int, value uint16) {
buf := [2]byte{}
*(*uint16)(unsafe.Pointer(&buf[0])) = value
(*b)[pos+0] = buf[1]
(*b)[pos+1] = buf[0]
}
func (b *swapbo) Write32(pos int, value uint32) {
buf := [4]byte{}
*(*uint32)(unsafe.Pointer(&buf[0])) = value
(*b)[pos+0] = buf[3]
(*b)[pos+1] = buf[2]
(*b)[pos+2] = buf[1]
(*b)[pos+3] = buf[0]
}
func (b *swapbo) Write64(pos int, value uint64) {
buf := [8]byte{}
*(*uint64)(unsafe.Pointer(&buf[0])) = value
(*b)[pos+0] = buf[7]
(*b)[pos+1] = buf[6]
(*b)[pos+2] = buf[5]
(*b)[pos+3] = buf[4]
(*b)[pos+4] = buf[3]
(*b)[pos+5] = buf[2]
(*b)[pos+6] = buf[1]
(*b)[pos+7] = buf[0]
}
func (b *swapbo) Slice(offset int, size int) Bytes {
nb := (*b)[offset : offset+size]
return &nb
}
func (b *swapbo) LittleEndian() Bytes {
return NewLittleEndian((*[]byte)(b))
}
func (b *swapbo) BigEndian() Bytes {
return NewBigEndian((*[]byte)(b))
}
func (b *swapbo) Raw() *[]byte {
return (*[]byte)(b)
}
func (b *swapbo) Len() int {
return len(*b)
}

147
vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/config.go generated vendored Normal file
View File

@@ -0,0 +1,147 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvpci
import (
"fmt"
"os"
"github.com/NVIDIA/go-nvlib/pkg/nvpci/bytes"
)
const (
// PCICfgSpaceStandardSize represents the size in bytes of the standard config space
PCICfgSpaceStandardSize = 256
// PCICfgSpaceExtendedSize represents the size in bytes of the extended config space
PCICfgSpaceExtendedSize = 4096
// PCICapabilityListPointer represents offset for the capability list pointer
PCICapabilityListPointer = 0x34
// PCIStatusCapabilityList represents the status register bit which indicates capability list support
PCIStatusCapabilityList = 0x10
// PCIStatusBytePosition represents the position of the status register
PCIStatusBytePosition = 0x06
)
// ConfigSpace PCI configuration space (standard extended) file path
type ConfigSpace struct {
Path string
}
// ConfigSpaceIO Interface for reading and writing raw and preconfigured values
type ConfigSpaceIO interface {
bytes.Bytes
GetVendorID() uint16
GetDeviceID() uint16
GetPCICapabilities() (*PCICapabilities, error)
}
type configSpaceIO struct {
bytes.Bytes
}
// PCIStandardCapability standard PCI config space
type PCIStandardCapability struct {
bytes.Bytes
}
// PCIExtendedCapability extended PCI config space
type PCIExtendedCapability struct {
bytes.Bytes
Version uint8
}
// PCICapabilities combines the standard and extended config space
type PCICapabilities struct {
Standard map[uint8]*PCIStandardCapability
Extended map[uint16]*PCIExtendedCapability
}
func (cs *ConfigSpace) Read() (ConfigSpaceIO, error) {
config, err := os.ReadFile(cs.Path)
if err != nil {
return nil, fmt.Errorf("failed to open file: %v", err)
}
return &configSpaceIO{bytes.New(&config)}, nil
}
func (cs *configSpaceIO) GetVendorID() uint16 {
return cs.Read16(0)
}
func (cs *configSpaceIO) GetDeviceID() uint16 {
return cs.Read16(2)
}
func (cs *configSpaceIO) GetPCICapabilities() (*PCICapabilities, error) {
caps := &PCICapabilities{
make(map[uint8]*PCIStandardCapability),
make(map[uint16]*PCIExtendedCapability),
}
support := cs.Read8(PCIStatusBytePosition) & PCIStatusCapabilityList
if support == 0 {
return nil, fmt.Errorf("pci device does not support capability list")
}
soffset := cs.Read8(PCICapabilityListPointer)
if int(soffset) >= cs.Len() {
return nil, fmt.Errorf("capability list pointer out of bounds")
}
for soffset != 0 {
if soffset == 0xff {
return nil, fmt.Errorf("config space broken")
}
if int(soffset) >= PCICfgSpaceStandardSize {
return nil, fmt.Errorf("standard capability list pointer out of bounds")
}
data := cs.Read32(int(soffset))
id := uint8(data & 0xff)
caps.Standard[id] = &PCIStandardCapability{
cs.Slice(int(soffset), cs.Len()-int(soffset)),
}
soffset = uint8((data >> 8) & 0xff)
}
if cs.Len() <= PCICfgSpaceStandardSize {
return caps, nil
}
eoffset := uint16(PCICfgSpaceStandardSize)
for eoffset != 0 {
if eoffset == 0xffff {
return nil, fmt.Errorf("config space broken")
}
if int(eoffset) >= PCICfgSpaceExtendedSize {
return nil, fmt.Errorf("extended capability list pointer out of bounds")
}
// |31 20|19 16|15 0|
// |--------------------|------|-------------------------|
// | Next Cap Offset |Vers. |PCI Express Ext. Cap ID |
data := cs.Read32(int(eoffset))
id := uint16(data & 0xffff)
version := uint8((data >> 16) & 0xf)
caps.Extended[id] = &PCIExtendedCapability{
cs.Slice(int(eoffset), cs.Len()-int(eoffset)),
version,
}
eoffset = uint16((data >> 20) & 0xfff)
}
return caps, nil
}

29
vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/logger.go generated vendored Normal file
View File

@@ -0,0 +1,29 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvpci
import "log"
type logger interface {
Warningf(string, ...interface{})
}
type simpleLogger struct{}
func (l simpleLogger) Warningf(format string, v ...interface{}) {
log.Printf("WARNING: "+format, v)
}

105
vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/mlxpci.go generated vendored Normal file
View File

@@ -0,0 +1,105 @@
/*
* Copyright (c) NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvpci
import (
"fmt"
"strings"
)
const (
// PCIMellanoxVendorID represents PCI vendor id for Mellanox
PCIMellanoxVendorID uint16 = 0x15b3
// PCINetworkControllerClass represents the PCI class for network controllers
PCINetworkControllerClass uint32 = 0x020000
// PCIBridgeClass represents the PCI class for network controllers
PCIBridgeClass uint32 = 0x060400
)
// GetNetworkControllers returns all Mellanox Network Controller PCI devices on the system
func (p *nvpci) GetNetworkControllers() ([]*NvidiaPCIDevice, error) {
devices, err := p.GetAllDevices()
if err != nil {
return nil, fmt.Errorf("error getting all NVIDIA devices: %v", err)
}
var filtered []*NvidiaPCIDevice
for _, d := range devices {
if d.IsNetworkController() {
filtered = append(filtered, d)
}
}
return filtered, nil
}
// GetPciBridges retrieves all Mellanox PCI(e) Bridges
func (p *nvpci) GetPciBridges() ([]*NvidiaPCIDevice, error) {
devices, err := p.GetAllDevices()
if err != nil {
return nil, fmt.Errorf("error getting all NVIDIA devices: %v", err)
}
var filtered []*NvidiaPCIDevice
for _, d := range devices {
if d.IsPciBridge() {
filtered = append(filtered, d)
}
}
return filtered, nil
}
// IsNetworkController if class == 0x300
func (d *NvidiaPCIDevice) IsNetworkController() bool {
return d.Class == PCINetworkControllerClass
}
// IsPciBridge if class == 0x0604
func (d *NvidiaPCIDevice) IsPciBridge() bool {
return d.Class == PCIBridgeClass
}
// IsDPU returns if a device is a DPU
func (d *NvidiaPCIDevice) IsDPU() bool {
if !strings.Contains(d.DeviceName, "BlueField") {
return false
}
// DPU is a multifunction device hence look only for the .0 function
// and ignore subfunctions like .1, .2, etc.
if strings.HasSuffix(d.Address, ".0") {
return true
}
return false
}
// GetDPUs returns all Mellanox DPU devices on the system
func (p *nvpci) GetDPUs() ([]*NvidiaPCIDevice, error) {
devices, err := p.GetNetworkControllers()
if err != nil {
return nil, fmt.Errorf("error getting all network controllers: %v", err)
}
var filtered []*NvidiaPCIDevice
for _, d := range devices {
if d.IsDPU() {
filtered = append(filtered, d)
}
}
return filtered, nil
}

View File

@@ -0,0 +1,127 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package mmio
import (
"fmt"
"os"
"syscall"
"unsafe"
"github.com/NVIDIA/go-nvlib/pkg/nvpci/bytes"
)
// Mmio memory map a region
type Mmio interface {
bytes.Raw
bytes.Reader
bytes.Writer
Sync() error
Close() error
Slice(offset int, size int) Mmio
LittleEndian() Mmio
BigEndian() Mmio
}
type mmio struct {
bytes.Bytes
}
func open(path string, offset int, size int, flags int) (Mmio, error) {
var mmapFlags int
switch flags {
case os.O_RDONLY:
mmapFlags = syscall.PROT_READ
case os.O_RDWR:
mmapFlags = syscall.PROT_READ | syscall.PROT_WRITE
default:
return nil, fmt.Errorf("invalid flags: %v", flags)
}
file, err := os.OpenFile(path, flags, 0)
if err != nil {
return nil, fmt.Errorf("failed to open file: %v", err)
}
defer file.Close()
fi, err := file.Stat()
if err != nil {
return nil, fmt.Errorf("failed to get file info: %v", err)
}
if size > int(fi.Size()) {
return nil, fmt.Errorf("requested size larger than file size")
}
if size < 0 {
size = int(fi.Size())
}
mmap, err := syscall.Mmap(
int(file.Fd()),
int64(offset),
size,
mmapFlags,
syscall.MAP_SHARED)
if err != nil {
return nil, fmt.Errorf("failed to mmap file: %v", err)
}
return &mmio{bytes.New(&mmap)}, nil
}
// OpenRO open region readonly
func OpenRO(path string, offset int, size int) (Mmio, error) {
return open(path, offset, size, os.O_RDONLY)
}
// OpenRW open region read write
func OpenRW(path string, offset int, size int) (Mmio, error) {
return open(path, offset, size, os.O_RDWR)
}
func (m *mmio) Slice(offset int, size int) Mmio {
return &mmio{m.Bytes.Slice(offset, size)}
}
func (m *mmio) LittleEndian() Mmio {
return &mmio{m.Bytes.LittleEndian()}
}
func (m *mmio) BigEndian() Mmio {
return &mmio{m.Bytes.BigEndian()}
}
func (m *mmio) Close() error {
err := syscall.Munmap(*m.Bytes.Raw())
if err != nil {
return fmt.Errorf("failed to munmap file: %v", err)
}
return nil
}
func (m *mmio) Sync() error {
_, _, errno := syscall.Syscall(
syscall.SYS_MSYNC,
uintptr(unsafe.Pointer(&(*m.Bytes.Raw())[0])),
uintptr(m.Len()),
uintptr(syscall.MS_SYNC|syscall.MS_INVALIDATE))
if errno != 0 {
return fmt.Errorf("failed to msync file: %v", errno)
}
return nil
}

View File

@@ -0,0 +1,74 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package mmio
import (
"fmt"
"github.com/NVIDIA/go-nvlib/pkg/nvpci/bytes"
)
type mockMmio struct {
mmio
source *[]byte
offset int
rw bool
}
func mockOpen(source *[]byte, offset int, size int, rw bool) (Mmio, error) {
if size < 0 {
size = len(*source) - offset
}
if (offset + size) > len(*source) {
return nil, fmt.Errorf("offset+size out of range")
}
data := append([]byte{}, (*source)[offset:offset+size]...)
m := &mockMmio{}
m.Bytes = bytes.New(&data).LittleEndian()
m.source = source
m.offset = offset
m.rw = rw
return m, nil
}
// MockOpenRO open read only
func MockOpenRO(source *[]byte, offset int, size int) (Mmio, error) {
return mockOpen(source, offset, size, false)
}
// MockOpenRW open read write
func MockOpenRW(source *[]byte, offset int, size int) (Mmio, error) {
return mockOpen(source, offset, size, true)
}
func (m *mockMmio) Close() error {
m = &mockMmio{}
return nil
}
func (m *mockMmio) Sync() error {
if !m.rw {
return fmt.Errorf("opened read-only")
}
for i := range *m.Bytes.Raw() {
(*m.source)[m.offset+i] = (*m.Bytes.Raw())[i]
}
return nil
}

158
vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/mock.go generated vendored Normal file
View File

@@ -0,0 +1,158 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvpci
import (
"fmt"
"os"
"path/filepath"
"github.com/NVIDIA/go-nvlib/pkg/nvpci/bytes"
)
// MockNvpci mock pci device
type MockNvpci struct {
*nvpci
}
var _ Interface = (*MockNvpci)(nil)
// NewMockNvpci create new mock PCI and remove old devices
func NewMockNvpci() (mock *MockNvpci, rerr error) {
rootDir, err := os.MkdirTemp(os.TempDir(), "")
if err != nil {
return nil, err
}
defer func() {
if rerr != nil {
os.RemoveAll(rootDir)
}
}()
mock = &MockNvpci{
New(WithPCIDevicesRoot(rootDir)).(*nvpci),
}
return mock, nil
}
// Cleanup remove the mocked PCI devices root folder
func (m *MockNvpci) Cleanup() {
os.RemoveAll(m.pciDevicesRoot)
}
// AddMockA100 Create an A100 like GPU mock device
func (m *MockNvpci) AddMockA100(address string, numaNode int) error {
deviceDir := filepath.Join(m.pciDevicesRoot, address)
err := os.MkdirAll(deviceDir, 0755)
if err != nil {
return err
}
vendor, err := os.Create(filepath.Join(deviceDir, "vendor"))
if err != nil {
return err
}
_, err = vendor.WriteString(fmt.Sprintf("0x%x", PCINvidiaVendorID))
if err != nil {
return err
}
class, err := os.Create(filepath.Join(deviceDir, "class"))
if err != nil {
return err
}
_, err = class.WriteString(fmt.Sprintf("0x%x", PCI3dControllerClass))
if err != nil {
return err
}
device, err := os.Create(filepath.Join(deviceDir, "device"))
if err != nil {
return err
}
_, err = device.WriteString("0x20bf")
if err != nil {
return err
}
_, err = os.Create(filepath.Join(deviceDir, "nvidia"))
if err != nil {
return err
}
err = os.Symlink(filepath.Join(deviceDir, "nvidia"), filepath.Join(deviceDir, "driver"))
if err != nil {
return err
}
_, err = os.Create(filepath.Join(deviceDir, "20"))
if err != nil {
return err
}
err = os.Symlink(filepath.Join(deviceDir, "20"), filepath.Join(deviceDir, "iommu_group"))
if err != nil {
return err
}
numa, err := os.Create(filepath.Join(deviceDir, "numa_node"))
if err != nil {
return err
}
_, err = numa.WriteString(fmt.Sprintf("%v", numaNode))
if err != nil {
return err
}
config, err := os.Create(filepath.Join(deviceDir, "config"))
if err != nil {
return err
}
_data := make([]byte, PCICfgSpaceStandardSize)
data := bytes.New(&_data)
data.Write16(0, PCINvidiaVendorID)
data.Write16(2, uint16(0x20bf))
data.Write8(PCIStatusBytePosition, PCIStatusCapabilityList)
_, err = config.Write(*data.Raw())
if err != nil {
return err
}
bar0 := []uint64{0x00000000c2000000, 0x00000000c2ffffff, 0x0000000000040200}
resource, err := os.Create(filepath.Join(deviceDir, "resource"))
if err != nil {
return err
}
_, err = resource.WriteString(fmt.Sprintf("0x%x 0x%x 0x%x", bar0[0], bar0[1], bar0[2]))
if err != nil {
return err
}
pmcID := uint32(0x170000a1)
resource0, err := os.Create(filepath.Join(deviceDir, "resource0"))
if err != nil {
return err
}
_data = make([]byte, bar0[1]-bar0[0]+1)
data = bytes.New(&_data).LittleEndian()
data.Write32(0, pmcID)
_, err = resource0.Write(*data.Raw())
if err != nil {
return err
}
return nil
}

430
vendor/github.com/NVIDIA/go-nvlib/pkg/nvpci/nvpci.go generated vendored Normal file
View File

@@ -0,0 +1,430 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvpci
import (
"fmt"
"os"
"path"
"path/filepath"
"sort"
"strconv"
"strings"
"github.com/NVIDIA/go-nvlib/pkg/pciids"
)
const (
// PCIDevicesRoot represents base path for all pci devices under sysfs
PCIDevicesRoot = "/sys/bus/pci/devices"
// PCINvidiaVendorID represents PCI vendor id for NVIDIA
PCINvidiaVendorID uint16 = 0x10de
// PCIVgaControllerClass represents the PCI class for VGA Controllers
PCIVgaControllerClass uint32 = 0x030000
// PCI3dControllerClass represents the PCI class for 3D Graphics accellerators
PCI3dControllerClass uint32 = 0x030200
// PCINvSwitchClass represents the PCI class for NVSwitches
PCINvSwitchClass uint32 = 0x068000
// UnknownDeviceString is the device name to set for devices not found in the PCI database
UnknownDeviceString = "UNKNOWN_DEVICE"
// UnknownClassString is the class name to set for devices not found in the PCI database
UnknownClassString = "UNKNOWN_CLASS"
)
// Interface allows us to get a list of all NVIDIA PCI devices
type Interface interface {
GetAllDevices() ([]*NvidiaPCIDevice, error)
Get3DControllers() ([]*NvidiaPCIDevice, error)
GetVGAControllers() ([]*NvidiaPCIDevice, error)
GetNVSwitches() ([]*NvidiaPCIDevice, error)
GetGPUs() ([]*NvidiaPCIDevice, error)
GetGPUByIndex(int) (*NvidiaPCIDevice, error)
GetGPUByPciBusID(string) (*NvidiaPCIDevice, error)
GetNetworkControllers() ([]*NvidiaPCIDevice, error)
GetPciBridges() ([]*NvidiaPCIDevice, error)
GetDPUs() ([]*NvidiaPCIDevice, error)
}
// MemoryResources a more human readable handle
type MemoryResources map[int]*MemoryResource
// ResourceInterface exposes some higher level functions of resources
type ResourceInterface interface {
GetTotalAddressableMemory(bool) (uint64, uint64)
}
type nvpci struct {
logger logger
pciDevicesRoot string
pcidbPath string
}
var _ Interface = (*nvpci)(nil)
var _ ResourceInterface = (*MemoryResources)(nil)
// NvidiaPCIDevice represents a PCI device for an NVIDIA product
type NvidiaPCIDevice struct {
Path string
Address string
Vendor uint16
Class uint32
ClassName string
Device uint16
DeviceName string
Driver string
IommuGroup int
NumaNode int
Config *ConfigSpace
Resources MemoryResources
IsVF bool
}
// IsVGAController if class == 0x300
func (d *NvidiaPCIDevice) IsVGAController() bool {
return d.Class == PCIVgaControllerClass
}
// Is3DController if class == 0x302
func (d *NvidiaPCIDevice) Is3DController() bool {
return d.Class == PCI3dControllerClass
}
// IsNVSwitch if class == 0x068
func (d *NvidiaPCIDevice) IsNVSwitch() bool {
return d.Class == PCINvSwitchClass
}
// IsGPU either VGA for older cards or 3D for newer
func (d *NvidiaPCIDevice) IsGPU() bool {
return d.IsVGAController() || d.Is3DController()
}
// IsResetAvailable some devices can be reset without rebooting,
// check if applicable
func (d *NvidiaPCIDevice) IsResetAvailable() bool {
_, err := os.Stat(path.Join(d.Path, "reset"))
return err == nil
}
// Reset perform a reset to apply a new configuration at HW level
func (d *NvidiaPCIDevice) Reset() error {
err := os.WriteFile(path.Join(d.Path, "reset"), []byte("1"), 0)
if err != nil {
return fmt.Errorf("unable to write to reset file: %v", err)
}
return nil
}
// New interface that allows us to get a list of all NVIDIA PCI devices
func New(opts ...Option) Interface {
n := &nvpci{}
for _, opt := range opts {
opt(n)
}
if n.logger == nil {
n.logger = &simpleLogger{}
}
if n.pciDevicesRoot == "" {
n.pciDevicesRoot = PCIDevicesRoot
}
return n
}
// Option defines a function for passing options to the New() call
type Option func(*nvpci)
// WithLogger provides an Option to set the logger for the library
func WithLogger(logger logger) Option {
return func(n *nvpci) {
n.logger = logger
}
}
// WithPCIDevicesRoot provides an Option to set the root path
// for PCI devices on the system.
func WithPCIDevicesRoot(root string) Option {
return func(n *nvpci) {
n.pciDevicesRoot = root
}
}
// WithPCIDatabasePath provides an Option to set the path
// to the pciids database file.
func WithPCIDatabasePath(path string) Option {
return func(n *nvpci) {
n.pcidbPath = path
}
}
// GetAllDevices returns all Nvidia PCI devices on the system
func (p *nvpci) GetAllDevices() ([]*NvidiaPCIDevice, error) {
deviceDirs, err := os.ReadDir(p.pciDevicesRoot)
if err != nil {
return nil, fmt.Errorf("unable to read PCI bus devices: %v", err)
}
var nvdevices []*NvidiaPCIDevice
for _, deviceDir := range deviceDirs {
deviceAddress := deviceDir.Name()
nvdevice, err := p.GetGPUByPciBusID(deviceAddress)
if err != nil {
return nil, fmt.Errorf("error constructing NVIDIA PCI device %s: %v", deviceAddress, err)
}
if nvdevice == nil {
continue
}
nvdevices = append(nvdevices, nvdevice)
}
addressToID := func(address string) uint64 {
address = strings.ReplaceAll(address, ":", "")
address = strings.ReplaceAll(address, ".", "")
id, _ := strconv.ParseUint(address, 16, 64)
return id
}
sort.Slice(nvdevices, func(i, j int) bool {
return addressToID(nvdevices[i].Address) < addressToID(nvdevices[j].Address)
})
return nvdevices, nil
}
// GetGPUByPciBusID constructs an NvidiaPCIDevice for the specified address (PCI Bus ID)
func (p *nvpci) GetGPUByPciBusID(address string) (*NvidiaPCIDevice, error) {
devicePath := filepath.Join(p.pciDevicesRoot, address)
vendor, err := os.ReadFile(path.Join(devicePath, "vendor"))
if err != nil {
return nil, fmt.Errorf("unable to read PCI device vendor id for %s: %v", address, err)
}
vendorStr := strings.TrimSpace(string(vendor))
vendorID, err := strconv.ParseUint(vendorStr, 0, 16)
if err != nil {
return nil, fmt.Errorf("unable to convert vendor string to uint16: %v", vendorStr)
}
if uint16(vendorID) != PCINvidiaVendorID && uint16(vendorID) != PCIMellanoxVendorID {
return nil, nil
}
class, err := os.ReadFile(path.Join(devicePath, "class"))
if err != nil {
return nil, fmt.Errorf("unable to read PCI device class for %s: %v", address, err)
}
classStr := strings.TrimSpace(string(class))
classID, err := strconv.ParseUint(classStr, 0, 32)
if err != nil {
return nil, fmt.Errorf("unable to convert class string to uint32: %v", classStr)
}
device, err := os.ReadFile(path.Join(devicePath, "device"))
if err != nil {
return nil, fmt.Errorf("unable to read PCI device id for %s: %v", address, err)
}
deviceStr := strings.TrimSpace(string(device))
deviceID, err := strconv.ParseUint(deviceStr, 0, 16)
if err != nil {
return nil, fmt.Errorf("unable to convert device string to uint16: %v", deviceStr)
}
driver, err := filepath.EvalSymlinks(path.Join(devicePath, "driver"))
if err == nil {
driver = filepath.Base(driver)
} else if os.IsNotExist(err) {
driver = ""
} else {
return nil, fmt.Errorf("unable to detect driver for %s: %v", address, err)
}
var iommuGroup int64
iommu, err := filepath.EvalSymlinks(path.Join(devicePath, "iommu_group"))
if err == nil {
iommuGroupStr := strings.TrimSpace(filepath.Base(iommu))
iommuGroup, err = strconv.ParseInt(iommuGroupStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("unable to convert iommu_group string to int64: %v", iommuGroupStr)
}
} else if os.IsNotExist(err) {
iommuGroup = -1
} else {
return nil, fmt.Errorf("unable to detect iommu_group for %s: %v", address, err)
}
// device is a virtual function (VF) if "physfn" symlink exists
var isVF bool
_, err = filepath.EvalSymlinks(path.Join(devicePath, "physfn"))
if err == nil {
isVF = true
}
if err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("unable to resolve %s: %v", path.Join(devicePath, "physfn"), err)
}
numa, err := os.ReadFile(path.Join(devicePath, "numa_node"))
if err != nil {
return nil, fmt.Errorf("unable to read PCI NUMA node for %s: %v", address, err)
}
numaStr := strings.TrimSpace(string(numa))
numaNode, err := strconv.ParseInt(numaStr, 0, 64)
if err != nil {
return nil, fmt.Errorf("unable to convert NUMA node string to int64: %v", numaNode)
}
config := &ConfigSpace{
Path: path.Join(devicePath, "config"),
}
resource, err := os.ReadFile(path.Join(devicePath, "resource"))
if err != nil {
return nil, fmt.Errorf("unable to read PCI resource file for %s: %v", address, err)
}
resources := make(map[int]*MemoryResource)
for i, line := range strings.Split(strings.TrimSpace(string(resource)), "\n") {
values := strings.Split(line, " ")
if len(values) != 3 {
return nil, fmt.Errorf("more than 3 entries in line '%d' of resource file", i)
}
start, _ := strconv.ParseUint(values[0], 0, 64)
end, _ := strconv.ParseUint(values[1], 0, 64)
flags, _ := strconv.ParseUint(values[2], 0, 64)
if (end - start) != 0 {
resources[i] = &MemoryResource{
uintptr(start),
uintptr(end),
flags,
fmt.Sprintf("%s/resource%d", devicePath, i),
}
}
}
pciDB := pciids.NewDB()
deviceName, err := pciDB.GetDeviceName(uint16(vendorID), uint16(deviceID))
if err != nil {
p.logger.Warningf("unable to get device name: %v\n", err)
deviceName = UnknownDeviceString
}
className, err := pciDB.GetClassName(uint32(classID))
if err != nil {
p.logger.Warningf("unable to get class name for device: %v\n", err)
className = UnknownClassString
}
nvdevice := &NvidiaPCIDevice{
Path: devicePath,
Address: address,
Vendor: uint16(vendorID),
Class: uint32(classID),
Device: uint16(deviceID),
Driver: driver,
IommuGroup: int(iommuGroup),
NumaNode: int(numaNode),
Config: config,
Resources: resources,
IsVF: isVF,
DeviceName: deviceName,
ClassName: className,
}
return nvdevice, nil
}
// Get3DControllers returns all NVIDIA 3D Controller PCI devices on the system
func (p *nvpci) Get3DControllers() ([]*NvidiaPCIDevice, error) {
devices, err := p.GetAllDevices()
if err != nil {
return nil, fmt.Errorf("error getting all NVIDIA devices: %v", err)
}
var filtered []*NvidiaPCIDevice
for _, d := range devices {
if d.Is3DController() {
filtered = append(filtered, d)
}
}
return filtered, nil
}
// GetVGAControllers returns all NVIDIA VGA Controller PCI devices on the system
func (p *nvpci) GetVGAControllers() ([]*NvidiaPCIDevice, error) {
devices, err := p.GetAllDevices()
if err != nil {
return nil, fmt.Errorf("error getting all NVIDIA devices: %v", err)
}
var filtered []*NvidiaPCIDevice
for _, d := range devices {
if d.IsVGAController() {
filtered = append(filtered, d)
}
}
return filtered, nil
}
// GetNVSwitches returns all NVIDIA NVSwitch PCI devices on the system
func (p *nvpci) GetNVSwitches() ([]*NvidiaPCIDevice, error) {
devices, err := p.GetAllDevices()
if err != nil {
return nil, fmt.Errorf("error getting all NVIDIA devices: %v", err)
}
var filtered []*NvidiaPCIDevice
for _, d := range devices {
if d.IsNVSwitch() {
filtered = append(filtered, d)
}
}
return filtered, nil
}
// GetGPUs returns all NVIDIA GPU devices on the system
func (p *nvpci) GetGPUs() ([]*NvidiaPCIDevice, error) {
devices, err := p.GetAllDevices()
if err != nil {
return nil, fmt.Errorf("error getting all NVIDIA devices: %v", err)
}
var filtered []*NvidiaPCIDevice
for _, d := range devices {
if d.IsGPU() && !d.IsVF {
filtered = append(filtered, d)
}
}
return filtered, nil
}
// GetGPUByIndex returns an NVIDIA GPU device at a particular index
func (p *nvpci) GetGPUByIndex(i int) (*NvidiaPCIDevice, error) {
gpus, err := p.GetGPUs()
if err != nil {
return nil, fmt.Errorf("error getting all gpus: %v", err)
}
if i < 0 || i >= len(gpus) {
return nil, fmt.Errorf("invalid index '%d'", i)
}
return gpus[i], nil
}

View File

@@ -0,0 +1,140 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nvpci
import (
"fmt"
"sort"
"github.com/NVIDIA/go-nvlib/pkg/nvpci/mmio"
)
const (
pmcEndianRegister = 0x4
pmcLittleEndian = 0x0
pmcBigEndian = 0x01000001
)
// MemoryResource represents a mmio region
type MemoryResource struct {
Start uintptr
End uintptr
Flags uint64
Path string
}
// OpenRW read write mmio region
func (mr *MemoryResource) OpenRW() (mmio.Mmio, error) {
rw, err := mmio.OpenRW(mr.Path, 0, int(mr.End-mr.Start+1))
if err != nil {
return nil, fmt.Errorf("failed to open file for mmio: %v", err)
}
switch rw.Read32(pmcEndianRegister) {
case pmcBigEndian:
return rw.BigEndian(), nil
case pmcLittleEndian:
return rw.LittleEndian(), nil
}
return nil, fmt.Errorf("unknown endianness for mmio: %v", err)
}
// OpenRO read only mmio region
func (mr *MemoryResource) OpenRO() (mmio.Mmio, error) {
ro, err := mmio.OpenRO(mr.Path, 0, int(mr.End-mr.Start+1))
if err != nil {
return nil, fmt.Errorf("failed to open file for mmio: %v", err)
}
switch ro.Read32(pmcEndianRegister) {
case pmcBigEndian:
return ro.BigEndian(), nil
case pmcLittleEndian:
return ro.LittleEndian(), nil
}
return nil, fmt.Errorf("unknown endianness for mmio: %v", err)
}
// From Bit Twiddling Hacks, great resource for all low level bit manipulations
func calcNextPowerOf2(n uint64) uint64 {
n--
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n |= n >> 32
n++
return n
}
// GetTotalAddressableMemory will accumulate the 32bit and 64bit memory windows
// of each BAR and round the value if needed to the next power of 2; first
// return value is the accumulated 32bit addresable memory size the second one
// is the accumulated 64bit addressable memory size in bytes. These values are
// needed to configure virtualized environments.
func (mrs MemoryResources) GetTotalAddressableMemory(roundUp bool) (uint64, uint64) {
const pciIOVNumBAR = 6
const pciBaseAddressMemTypeMask = 0x06
const pciBaseAddressMemType32 = 0x00 /* 32 bit address */
const pciBaseAddressMemType64 = 0x04 /* 64 bit address */
// We need to sort the resources so the first 6 entries are the BARs
// How a map is represented in memory is not guaranteed, it is not an
// array. Keys do not have an order.
keys := make([]int, 0, len(mrs))
for k := range mrs {
keys = append(keys, k)
}
sort.Ints(keys)
numBAR := 0
memSize32bit := uint64(0)
memSize64bit := uint64(0)
for _, key := range keys {
// The PCIe spec only defines 5 BARs per device, we're
// discarding everything after the 5th entry of the resources
// file, see lspci.c
if key >= pciIOVNumBAR || numBAR == pciIOVNumBAR {
break
}
numBAR = numBAR + 1
region := mrs[key]
flags := region.Flags & pciBaseAddressMemTypeMask
memType32bit := flags == pciBaseAddressMemType32
memType64bit := flags == pciBaseAddressMemType64
memSize := (region.End - region.Start) + 1
if memType32bit {
memSize32bit = memSize32bit + uint64(memSize)
}
if memType64bit {
memSize64bit = memSize64bit + uint64(memSize)
}
}
if roundUp {
memSize32bit = calcNextPowerOf2(memSize32bit)
memSize64bit = calcNextPowerOf2(memSize64bit)
}
return memSize32bit, memSize64bit
}

File diff suppressed because it is too large Load Diff

444
vendor/github.com/NVIDIA/go-nvlib/pkg/pciids/pciids.go generated vendored Normal file
View File

@@ -0,0 +1,444 @@
package pciids
import (
"bufio"
"bytes"
_ "embed" // Fallback is the embedded pci.ids db file
"fmt"
"io"
"os"
"strconv"
"strings"
)
// token what the Lexer retruns
type token int
const (
// ILLEGAL a token which the Lexer does not understand
ILLEGAL token = iota
// EOF end of file
EOF
// WS whitespace
WS
// NEWLINE '\n'
NEWLINE
// COMMENT '# something'
COMMENT
// VENDOR PCI vendor
VENDOR
// SUBVENDOR PCI subvendor
SUBVENDOR
// DEVICE PCI device
DEVICE
// CLASS PCI class
CLASS
// SUBCLASS PCI subclass
SUBCLASS
// PROGIF PCI programming interface
PROGIF
)
// literal values from the Lexer
type literal struct {
ID string
name string
SubName string
}
// scanner a lexical scanner
type scanner struct {
r *bufio.Reader
isVendor bool
}
// newScanner well a new scanner ...
func newScanner(r io.Reader) *scanner {
return &scanner{r: bufio.NewReader(r)}
}
// Since the pci.ids is line base we're consuming a whole line rather then only
// a single rune/char
func (s *scanner) readline() []byte {
ln, err := s.r.ReadBytes('\n')
if err == io.EOF {
return []byte{'E', 'O', 'F'}
}
if err != nil {
fmt.Printf("ReadBytes failed with %v", err)
return []byte{}
}
return ln
}
func scanClass(line []byte) (token, literal) {
class := string(line[1:])
return CLASS, scanEntry([]byte(class), 2)
}
func scanSubVendor(line []byte) (token, literal) {
trim0 := strings.TrimSpace(string(line))
subv := string(trim0[:4])
trim1 := strings.TrimSpace(trim0[4:])
subd := string(trim1[:4])
subn := strings.TrimSpace(trim1[4:])
return SUBVENDOR, literal{subv, subd, subn}
}
func scanEntry(line []byte, offset uint) literal {
trim := strings.TrimSpace(string(line))
id := string(trim[:offset])
name := strings.TrimSpace(trim[offset:])
return literal{id, name, ""}
}
func isLeadingOneTab(ln []byte) bool { return (ln[0] == '\t') && (ln[1] != '\t') }
func isLeadingTwoTabs(ln []byte) bool { return (ln[0] == '\t') && (ln[1] == '\t') }
func isHexDigit(ln []byte) bool { return (ln[0] >= '0' && ln[0] <= '9') }
func isHexLetter(ln []byte) bool { return (ln[0] >= 'a' && ln[0] <= 'f') }
func isVendor(ln []byte) bool { return isHexDigit(ln) || isHexLetter(ln) }
func isEOF(ln []byte) bool { return (ln[0] == 'E' && ln[1] == 'O' && ln[2] == 'F') }
func isComment(ln []byte) bool { return (ln[0] == '#') }
func isSubVendor(ln []byte) bool { return isLeadingTwoTabs(ln) }
func isDevice(ln []byte) bool { return isLeadingOneTab(ln) }
func isNewline(ln []byte) bool { return (ln[0] == '\n') }
// List of known device classes, subclasses and programming interfaces
func isClass(ln []byte) bool { return (ln[0] == 'C') }
func isProgIf(ln []byte) bool { return isLeadingTwoTabs(ln) }
func isSubClass(ln []byte) bool { return isLeadingOneTab(ln) }
// unread places the previously read rune back on the reader.
func (s *scanner) unread() { _ = s.r.UnreadRune() }
// scan returns the next token and literal value.
func (s *scanner) scan() (tok token, lit literal) {
line := s.readline()
if isEOF(line) {
return EOF, literal{}
}
if isNewline(line) {
return NEWLINE, literal{ID: string('\n')}
}
if isComment(line) {
return COMMENT, literal{ID: string(line)}
}
// vendors
if isVendor(line) {
s.isVendor = true
return VENDOR, scanEntry(line, 4)
}
if isSubVendor(line) && s.isVendor {
return scanSubVendor(line)
}
if isDevice(line) && s.isVendor {
return DEVICE, scanEntry(line, 4)
}
// classes
if isClass(line) {
s.isVendor = false
return scanClass(line)
}
if isProgIf(line) && !s.isVendor {
return PROGIF, scanEntry(line, 2)
}
if isSubClass(line) && !s.isVendor {
return SUBCLASS, scanEntry(line, 2)
}
return ILLEGAL, literal{ID: string(line)}
}
// parser reads the tokens returned by the Lexer and constructs the AST
type parser struct {
s *scanner
buf struct {
tok token
lit literal
n int
}
}
// Various locations of pci.ids for different distributions. These may be more
// up to date then the embedded pci.ids db
var defaultPCIdbPaths = []string{
"/usr/share/misc/pci.ids", // Ubuntu
"/usr/local/share/pci.ids", // RHEL like with manual update
"/usr/share/hwdata/pci.ids", // RHEL like
"/usr/share/pci.ids", // SUSE
}
// This is a fallback if all of the locations fail
//
//go:embed default_pci.ids
var defaultPCIdb []byte
// NewDB Parse the PCI DB in its default locations or use the default
// builtin pci.ids db.
func NewDB(opts ...Option) Interface {
db := &pcidb{}
for _, opt := range opts {
opt(db)
}
pcidbs := defaultPCIdbPaths
if db.path != "" {
pcidbs = append([]string{db.path}, defaultPCIdbPaths...)
}
return newParser(pcidbs).parse()
}
// Option defines a function for passing options to the NewDB() call
type Option func(*pcidb)
// WithFilePath provides an Option to set the file path
// for the pciids database used by pciids interface.
// The file path provided takes precedence over all other
// paths.
func WithFilePath(path string) Option {
return func(db *pcidb) {
db.path = path
}
}
// newParser will attempt to read the db pci.ids from well known places or fall
// back to an internal db
func newParser(pcidbs []string) *parser {
for _, db := range pcidbs {
file, err := os.ReadFile(db)
if err != nil {
continue
}
return newParserFromReader(bufio.NewReader(bytes.NewReader(file)))
}
// We're using go embed above to have the byte array
// correctly initialized with the internal shipped db
// if we cannot find an up to date in the filesystem
return newParserFromReader(bufio.NewReader(bytes.NewReader(defaultPCIdb)))
}
func newParserFromReader(r *bufio.Reader) *parser {
return &parser{s: newScanner(r)}
}
func (p *parser) scan() (tok token, lit literal) {
if p.buf.n != 0 {
p.buf.n = 0
return p.buf.tok, p.buf.lit
}
tok, lit = p.s.scan()
p.buf.tok, p.buf.lit = tok, lit
return
}
func (p *parser) unscan() { p.buf.n = 1 }
var _ Interface = (*pcidb)(nil)
// Interface returns textual description of specific attributes of PCI devices
type Interface interface {
GetDeviceName(uint16, uint16) (string, error)
GetClassName(uint32) (string, error)
}
// GetDeviceName return the textual description of the PCI device
func (d *pcidb) GetDeviceName(vendorID uint16, deviceID uint16) (string, error) {
vendor, ok := d.vendors[vendorID]
if !ok {
return "", fmt.Errorf("failed to find vendor with id '%x'", vendorID)
}
device, ok := vendor.devices[deviceID]
if !ok {
return "", fmt.Errorf("failed to find device with id '%x'", deviceID)
}
return device.name, nil
}
// GetClassName resturn the textual description of the PCI device class
func (d *pcidb) GetClassName(classID uint32) (string, error) {
class, ok := d.classes[classID]
if !ok {
return "", fmt.Errorf("failed to find class with id '%x'", classID)
}
return class.name, nil
}
// pcidb The complete set of PCI vendors and PCI classes
type pcidb struct {
vendors map[uint16]vendor
classes map[uint32]class
path string
}
// vendor PCI vendors/devices/subVendors/SubDevices
type vendor struct {
name string
devices map[uint16]device
}
// subVendor PCI subVendor
type subVendor struct {
SubDevices map[uint16]SubDevice
}
// SubDevice PCI SubDevice
type SubDevice struct {
name string
}
// device PCI device
type device struct {
name string
subVendors map[uint16]subVendor
}
// class PCI classes/subClasses/Programming Interfaces
type class struct {
name string
subClasses map[uint32]subClass
}
// subClass PCI subClass
type subClass struct {
name string
progIfs map[uint8]progIf
}
// progIf PCI Programming Interface
type progIf struct {
name string
}
// parse parses a PCI IDS entry
func (p *parser) parse() Interface {
db := &pcidb{
vendors: map[uint16]vendor{},
classes: map[uint32]class{},
}
// Used for housekeeping, breadcrumb for aggregated types
var hkVendor vendor
var hkDevice device
var hkClass class
var hkSubClass subClass
var hkFullID uint32 = 0
var hkFullName [2]string
for {
tok, lit := p.scan()
// We're ignoring COMMENT, NEWLINE
// An EOF will break the loop
if tok == EOF {
break
}
// PCI vendors -------------------------------------------------
if tok == VENDOR {
id, _ := strconv.ParseUint(lit.ID, 16, 16)
db.vendors[uint16(id)] = vendor{
name: lit.name,
devices: map[uint16]device{},
}
hkVendor = db.vendors[uint16(id)]
}
if tok == DEVICE {
id, _ := strconv.ParseUint(lit.ID, 16, 16)
hkVendor.devices[uint16(id)] = device{
name: lit.name,
subVendors: map[uint16]subVendor{},
}
hkDevice = hkVendor.devices[uint16(id)]
}
if tok == SUBVENDOR {
id, _ := strconv.ParseUint(lit.ID, 16, 16)
hkDevice.subVendors[uint16(id)] = subVendor{
SubDevices: map[uint16]SubDevice{},
}
subvendor := hkDevice.subVendors[uint16(id)]
subid, _ := strconv.ParseUint(lit.name, 16, 16)
subvendor.SubDevices[uint16(subid)] = SubDevice{
name: lit.SubName,
}
}
// PCI classes -------------------------------------------------
if tok == CLASS {
id, _ := strconv.ParseUint(lit.ID, 16, 32)
db.classes[uint32(id)] = class{
name: lit.name,
subClasses: map[uint32]subClass{},
}
hkClass = db.classes[uint32(id)]
hkFullID = uint32(id) << 16
hkFullID = hkFullID & 0xFFFF0000
hkFullName[0] = fmt.Sprintf("%s (%02x)", lit.name, id)
}
if tok == SUBCLASS {
id, _ := strconv.ParseUint(lit.ID, 16, 8)
hkClass.subClasses[uint32(id)] = subClass{
name: lit.name,
progIfs: map[uint8]progIf{},
}
hkSubClass = hkClass.subClasses[uint32(id)]
// Clear the last detected sub class
hkFullID = hkFullID & 0xFFFF0000
hkFullID = hkFullID | uint32(id)<<8
// Clear the last detected prog iface
hkFullID = hkFullID & 0xFFFFFF00
hkFullName[1] = fmt.Sprintf("%s (%02x)", lit.name, id)
db.classes[uint32(hkFullID)] = class{
name: hkFullName[0] + " | " + hkFullName[1],
}
}
if tok == PROGIF {
id, _ := strconv.ParseUint(lit.ID, 16, 8)
hkSubClass.progIfs[uint8(id)] = progIf{
name: lit.name,
}
finalID := hkFullID | uint32(id)
name := fmt.Sprintf("%s (%02x)", lit.name, id)
finalName := hkFullName[0] + " | " + hkFullName[1] + " | " + name
db.classes[finalID] = class{
name: finalName,
}
}
if tok == ILLEGAL {
fmt.Printf("warning: illegal token %s %s cannot parse PCI IDS, database may be incomplete ", lit.ID, lit.name)
}
}
return db
}

View File

@@ -15,7 +15,9 @@
package dl
import (
"errors"
"fmt"
"runtime"
"unsafe"
)
@@ -25,45 +27,72 @@ import (
import "C"
const (
RTLD_LAZY = C.RTLD_LAZY
RTLD_NOW = C.RTLD_NOW
RTLD_GLOBAL = C.RTLD_GLOBAL
RTLD_LOCAL = C.RTLD_LOCAL
RTLD_LAZY = C.RTLD_LAZY
RTLD_NOW = C.RTLD_NOW
RTLD_GLOBAL = C.RTLD_GLOBAL
RTLD_LOCAL = C.RTLD_LOCAL
RTLD_NODELETE = C.RTLD_NODELETE
RTLD_NOLOAD = C.RTLD_NOLOAD
RTLD_DEEPBIND = C.RTLD_DEEPBIND
RTLD_NOLOAD = C.RTLD_NOLOAD
)
type DynamicLibrary struct{
Name string
Flags int
type DynamicLibrary struct {
Name string
Flags int
handle unsafe.Pointer
}
func New(name string, flags int) *DynamicLibrary {
return &DynamicLibrary{
Name: name,
Flags: flags,
Name: name,
Flags: flags,
handle: nil,
}
}
}
func withOSLock(action func() error) error {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
return action()
}
func dlError() error {
lastErr := C.dlerror()
if lastErr == nil {
return nil
}
return errors.New(C.GoString(lastErr))
}
func (dl *DynamicLibrary) Open() error {
name := C.CString(dl.Name)
defer C.free(unsafe.Pointer(name))
handle := C.dlopen(name, C.int(dl.Flags))
if handle == C.NULL {
return fmt.Errorf("%s", C.GoString(C.dlerror()))
if err := withOSLock(func() error {
handle := C.dlopen(name, C.int(dl.Flags))
if handle == nil {
return dlError()
}
dl.handle = handle
return nil
}); err != nil {
return err
}
dl.handle = handle
return nil
}
func (dl *DynamicLibrary) Close() error {
err := C.dlclose(dl.handle)
if err != 0 {
return fmt.Errorf("%s", C.GoString(C.dlerror()))
if dl.handle == nil {
return nil
}
if err := withOSLock(func() error {
if C.dlclose(dl.handle) != 0 {
return dlError()
}
dl.handle = nil
return nil
}); err != nil {
return err
}
return nil
}
@@ -72,11 +101,17 @@ func (dl *DynamicLibrary) Lookup(symbol string) error {
sym := C.CString(symbol)
defer C.free(unsafe.Pointer(sym))
C.dlerror() // Clear out any previous errors
C.dlsym(dl.handle, sym)
err := C.dlerror()
if unsafe.Pointer(err) == C.NULL {
var pointer unsafe.Pointer
if err := withOSLock(func() error {
// Call dlError() to clear out any previous errors.
dlError()
pointer = C.dlsym(dl.handle, sym)
if pointer == nil {
return fmt.Errorf("symbol %q not found: %w", symbol, dlError())
}
return nil
}); err != nil {
return err
}
return fmt.Errorf("%s", C.GoString(err))
return nil
}

26
vendor/github.com/NVIDIA/go-nvml/pkg/dl/dl_linux.go generated vendored Normal file
View File

@@ -0,0 +1,26 @@
/**
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package dl
// #cgo LDFLAGS: -ldl
// #include <dlfcn.h>
// #include <stdlib.h>
import "C"
const (
RTLD_DEEPBIND = C.RTLD_DEEPBIND
)

37
vendor/github.com/NVIDIA/go-nvml/pkg/nvml/api.go generated vendored Normal file
View File

@@ -0,0 +1,37 @@
/**
# Copyright 2023 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvml
// Library defines a set of functions defined on the underlying dynamic library.
type Library interface {
Lookup(string) error
}
// dynamicLibrary is an interface for abstacting the underlying library.
// This also allows for mocking and testing.
//go:generate moq -stub -out dynamicLibrary_mock.go . dynamicLibrary
type dynamicLibrary interface {
Lookup(string) error
Open() error
Close() error
}
// Interface represents the interace for the NVML library.
type Interface interface {
GetLibrary() Library
}

View File

@@ -18,7 +18,8 @@
package nvml
/*
#cgo LDFLAGS: -Wl,--unresolved-symbols=ignore-in-object-files
#cgo linux LDFLAGS: -Wl,--export-dynamic -Wl,--unresolved-symbols=ignore-in-object-files
#cgo darwin LDFLAGS: -Wl,-undefined,dynamic_lookup
#cgo CFLAGS: -DNVML_NO_UNVERSIONED_FUNC_DEFS=1
#include "nvml.h"
#include <stdlib.h>

View File

@@ -0,0 +1,157 @@
// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package nvml
import (
"sync"
)
// Ensure, that dynamicLibraryMock does implement dynamicLibrary.
// If this is not the case, regenerate this file with moq.
var _ dynamicLibrary = &dynamicLibraryMock{}
// dynamicLibraryMock is a mock implementation of dynamicLibrary.
//
// func TestSomethingThatUsesdynamicLibrary(t *testing.T) {
//
// // make and configure a mocked dynamicLibrary
// mockeddynamicLibrary := &dynamicLibraryMock{
// CloseFunc: func() error {
// panic("mock out the Close method")
// },
// LookupFunc: func(s string) error {
// panic("mock out the Lookup method")
// },
// OpenFunc: func() error {
// panic("mock out the Open method")
// },
// }
//
// // use mockeddynamicLibrary in code that requires dynamicLibrary
// // and then make assertions.
//
// }
type dynamicLibraryMock struct {
// CloseFunc mocks the Close method.
CloseFunc func() error
// LookupFunc mocks the Lookup method.
LookupFunc func(s string) error
// OpenFunc mocks the Open method.
OpenFunc func() error
// calls tracks calls to the methods.
calls struct {
// Close holds details about calls to the Close method.
Close []struct {
}
// Lookup holds details about calls to the Lookup method.
Lookup []struct {
// S is the s argument value.
S string
}
// Open holds details about calls to the Open method.
Open []struct {
}
}
lockClose sync.RWMutex
lockLookup sync.RWMutex
lockOpen sync.RWMutex
}
// Close calls CloseFunc.
func (mock *dynamicLibraryMock) Close() error {
callInfo := struct {
}{}
mock.lockClose.Lock()
mock.calls.Close = append(mock.calls.Close, callInfo)
mock.lockClose.Unlock()
if mock.CloseFunc == nil {
var (
errOut error
)
return errOut
}
return mock.CloseFunc()
}
// CloseCalls gets all the calls that were made to Close.
// Check the length with:
//
// len(mockeddynamicLibrary.CloseCalls())
func (mock *dynamicLibraryMock) CloseCalls() []struct {
} {
var calls []struct {
}
mock.lockClose.RLock()
calls = mock.calls.Close
mock.lockClose.RUnlock()
return calls
}
// Lookup calls LookupFunc.
func (mock *dynamicLibraryMock) Lookup(s string) error {
callInfo := struct {
S string
}{
S: s,
}
mock.lockLookup.Lock()
mock.calls.Lookup = append(mock.calls.Lookup, callInfo)
mock.lockLookup.Unlock()
if mock.LookupFunc == nil {
var (
errOut error
)
return errOut
}
return mock.LookupFunc(s)
}
// LookupCalls gets all the calls that were made to Lookup.
// Check the length with:
//
// len(mockeddynamicLibrary.LookupCalls())
func (mock *dynamicLibraryMock) LookupCalls() []struct {
S string
} {
var calls []struct {
S string
}
mock.lockLookup.RLock()
calls = mock.calls.Lookup
mock.lockLookup.RUnlock()
return calls
}
// Open calls OpenFunc.
func (mock *dynamicLibraryMock) Open() error {
callInfo := struct {
}{}
mock.lockOpen.Lock()
mock.calls.Open = append(mock.calls.Open, callInfo)
mock.lockOpen.Unlock()
if mock.OpenFunc == nil {
var (
errOut error
)
return errOut
}
return mock.OpenFunc()
}
// OpenCalls gets all the calls that were made to Open.
// Check the length with:
//
// len(mockeddynamicLibrary.OpenCalls())
func (mock *dynamicLibraryMock) OpenCalls() []struct {
} {
var calls []struct {
}
mock.lockOpen.RLock()
calls = mock.calls.Open
mock.lockOpen.RUnlock()
return calls
}

93
vendor/github.com/NVIDIA/go-nvml/pkg/nvml/gpm.go generated vendored Normal file
View File

@@ -0,0 +1,93 @@
// Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package nvml
// nvml.GpmMetricsGet()
type GpmMetricsGetVType struct {
metricsGet *GpmMetricsGetType
}
func GpmMetricsGetV(MetricsGet *GpmMetricsGetType) GpmMetricsGetVType {
return GpmMetricsGetVType{MetricsGet}
}
func (MetricsGetV GpmMetricsGetVType) V1() Return {
MetricsGetV.metricsGet.Version = 1
return nvmlGpmMetricsGet(MetricsGetV.metricsGet)
}
func GpmMetricsGet(MetricsGet *GpmMetricsGetType) Return {
MetricsGet.Version = GPM_METRICS_GET_VERSION
return nvmlGpmMetricsGet(MetricsGet)
}
// nvml.GpmSampleFree()
func GpmSampleFree(GpmSample GpmSample) Return {
return nvmlGpmSampleFree(GpmSample)
}
// nvml.GpmSampleAlloc()
func GpmSampleAlloc(GpmSample *GpmSample) Return {
return nvmlGpmSampleAlloc(GpmSample)
}
// nvml.GpmSampleGet()
func GpmSampleGet(Device Device, GpmSample GpmSample) Return {
return nvmlGpmSampleGet(Device, GpmSample)
}
func (Device Device) GpmSampleGet(GpmSample GpmSample) Return {
return GpmSampleGet(Device, GpmSample)
}
// nvml.GpmQueryDeviceSupport()
type GpmSupportV struct {
device Device
}
func GpmQueryDeviceSupportV(Device Device) GpmSupportV {
return GpmSupportV{Device}
}
func (Device Device) GpmQueryDeviceSupportV() GpmSupportV {
return GpmSupportV{Device}
}
func (GpmSupportV GpmSupportV) V1() (GpmSupport, Return) {
var GpmSupport GpmSupport
GpmSupport.Version = 1
ret := nvmlGpmQueryDeviceSupport(GpmSupportV.device, &GpmSupport)
return GpmSupport, ret
}
func GpmQueryDeviceSupport(Device Device) (GpmSupport, Return) {
var GpmSupport GpmSupport
GpmSupport.Version = GPM_SUPPORT_VERSION
ret := nvmlGpmQueryDeviceSupport(Device, &GpmSupport)
return GpmSupport, ret
}
func (Device Device) GpmQueryDeviceSupport() (GpmSupport, Return) {
return GpmQueryDeviceSupport(Device)
}
// nvml.GpmMigSampleGet()
func GpmMigSampleGet(Device Device, GpuInstanceId int, GpmSample GpmSample) Return {
return nvmlGpmMigSampleGet(Device, uint32(GpuInstanceId), GpmSample)
}
func (Device Device) GpmMigSampleGet(GpuInstanceId int, GpmSample GpmSample) Return {
return GpmMigSampleGet(Device, GpuInstanceId, GpmSample)
}

View File

@@ -14,45 +14,21 @@
package nvml
import (
"fmt"
"github.com/NVIDIA/go-nvml/pkg/dl"
)
import "C"
const (
nvmlLibraryName = "libnvidia-ml.so.1"
nvmlLibraryLoadFlags = dl.RTLD_LAZY | dl.RTLD_GLOBAL
)
var nvml *dl.DynamicLibrary
// nvml.Init()
func Init() Return {
lib := dl.New(nvmlLibraryName, nvmlLibraryLoadFlags)
err := lib.Open()
if err != nil {
if err := libnvml.load(); err != nil {
return ERROR_LIBRARY_NOT_FOUND
}
nvml = lib
updateVersionedSymbols()
return nvmlInit()
}
// nvml.InitWithFlags()
func InitWithFlags(Flags uint32) Return {
lib := dl.New(nvmlLibraryName, nvmlLibraryLoadFlags)
err := lib.Open()
if err != nil {
if err := libnvml.load(); err != nil {
return ERROR_LIBRARY_NOT_FOUND
}
nvml = lib
return nvmlInitWithFlags(Flags)
}
@@ -63,156 +39,10 @@ func Shutdown() Return {
return ret
}
err := nvml.Close()
err := libnvml.close()
if err != nil {
panic(fmt.Sprintf("error closing %s: %v", nvmlLibraryName, err))
panic(err)
}
return ret
}
// Default all versioned APIs to v1 (to infer the types)
var nvmlInit = nvmlInit_v1
var nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v1
var nvmlDeviceGetCount = nvmlDeviceGetCount_v1
var nvmlDeviceGetHandleByIndex = nvmlDeviceGetHandleByIndex_v1
var nvmlDeviceGetHandleByPciBusId = nvmlDeviceGetHandleByPciBusId_v1
var nvmlDeviceGetNvLinkRemotePciInfo = nvmlDeviceGetNvLinkRemotePciInfo_v1
var nvmlDeviceRemoveGpu = nvmlDeviceRemoveGpu_v1
var nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v1
var nvmlEventSetWait = nvmlEventSetWait_v1
var nvmlDeviceGetAttributes = nvmlDeviceGetAttributes_v1
var nvmlComputeInstanceGetInfo = nvmlComputeInstanceGetInfo_v1
var DeviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v1
var DeviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v1
var DeviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v1
var GetBlacklistDeviceCount = GetExcludedDeviceCount
var GetBlacklistDeviceInfoByIndex = GetExcludedDeviceInfoByIndex
var nvmlDeviceGetGpuInstancePossiblePlacements = nvmlDeviceGetGpuInstancePossiblePlacements_v1
var nvmlVgpuInstanceGetLicenseInfo = nvmlVgpuInstanceGetLicenseInfo_v1
type BlacklistDeviceInfo = ExcludedDeviceInfo
type ProcessInfo_v1Slice []ProcessInfo_v1
type ProcessInfo_v2Slice []ProcessInfo_v2
func (pis ProcessInfo_v1Slice) ToProcessInfoSlice() []ProcessInfo {
var newInfos []ProcessInfo
for _, pi := range pis {
info := ProcessInfo{
Pid: pi.Pid,
UsedGpuMemory: pi.UsedGpuMemory,
GpuInstanceId: 0xFFFFFFFF, // GPU instance ID is invalid in v1
ComputeInstanceId: 0xFFFFFFFF, // Compute instance ID is invalid in v1
}
newInfos = append(newInfos, info)
}
return newInfos
}
func (pis ProcessInfo_v2Slice) ToProcessInfoSlice() []ProcessInfo {
var newInfos []ProcessInfo
for _, pi := range pis {
info := ProcessInfo{
Pid: pi.Pid,
UsedGpuMemory: pi.UsedGpuMemory,
GpuInstanceId: pi.GpuInstanceId,
ComputeInstanceId: pi.ComputeInstanceId,
}
newInfos = append(newInfos, info)
}
return newInfos
}
// updateVersionedSymbols()
func updateVersionedSymbols() {
err := nvml.Lookup("nvmlInit_v2")
if err == nil {
nvmlInit = nvmlInit_v2
}
err = nvml.Lookup("nvmlDeviceGetPciInfo_v2")
if err == nil {
nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v2
}
err = nvml.Lookup("nvmlDeviceGetPciInfo_v3")
if err == nil {
nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v3
}
err = nvml.Lookup("nvmlDeviceGetCount_v2")
if err == nil {
nvmlDeviceGetCount = nvmlDeviceGetCount_v2
}
err = nvml.Lookup("nvmlDeviceGetHandleByIndex_v2")
if err == nil {
nvmlDeviceGetHandleByIndex = nvmlDeviceGetHandleByIndex_v2
}
err = nvml.Lookup("nvmlDeviceGetHandleByPciBusId_v2")
if err == nil {
nvmlDeviceGetHandleByPciBusId = nvmlDeviceGetHandleByPciBusId_v2
}
err = nvml.Lookup("nvmlDeviceGetNvLinkRemotePciInfo_v2")
if err == nil {
nvmlDeviceGetNvLinkRemotePciInfo = nvmlDeviceGetNvLinkRemotePciInfo_v2
}
// Unable to overwrite nvmlDeviceRemoveGpu() because the v2 function takes
// a different set of parameters than the v1 function.
//err = nvml.Lookup("nvmlDeviceRemoveGpu_v2")
//if err == nil {
// nvmlDeviceRemoveGpu = nvmlDeviceRemoveGpu_v2
//}
err = nvml.Lookup("nvmlDeviceGetGridLicensableFeatures_v2")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v2
}
err = nvml.Lookup("nvmlDeviceGetGridLicensableFeatures_v3")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v3
}
err = nvml.Lookup("nvmlDeviceGetGridLicensableFeatures_v4")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v4
}
err = nvml.Lookup("nvmlEventSetWait_v2")
if err == nil {
nvmlEventSetWait = nvmlEventSetWait_v2
}
err = nvml.Lookup("nvmlDeviceGetAttributes_v2")
if err == nil {
nvmlDeviceGetAttributes = nvmlDeviceGetAttributes_v2
}
err = nvml.Lookup("nvmlComputeInstanceGetInfo_v2")
if err == nil {
nvmlComputeInstanceGetInfo = nvmlComputeInstanceGetInfo_v2
}
err = nvml.Lookup("nvmlDeviceGetComputeRunningProcesses_v2")
if err == nil {
DeviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v2
}
err = nvml.Lookup("nvmlDeviceGetComputeRunningProcesses_v3")
if err == nil {
DeviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v3
}
err = nvml.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v2")
if err == nil {
DeviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v2
}
err = nvml.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v3")
if err == nil {
DeviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v3
}
err = nvml.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v2")
if err == nil {
DeviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v2
}
err = nvml.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v3")
if err == nil {
DeviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v3
}
err = nvml.Lookup("nvmlDeviceGetGpuInstancePossiblePlacements_v2")
if err == nil {
nvmlDeviceGetGpuInstancePossiblePlacements = nvmlDeviceGetGpuInstancePossiblePlacements_v2
}
err = nvml.Lookup("nvmlVgpuInstanceGetLicenseInfo_v2")
if err == nil {
nvmlVgpuInstanceGetLicenseInfo = nvmlVgpuInstanceGetLicenseInfo_v2
}
}

302
vendor/github.com/NVIDIA/go-nvml/pkg/nvml/lib.go generated vendored Normal file
View File

@@ -0,0 +1,302 @@
/**
# Copyright 2023 NVIDIA CORPORATION
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
**/
package nvml
import (
"errors"
"fmt"
"sync"
"github.com/NVIDIA/go-nvml/pkg/dl"
)
import "C"
const (
defaultNvmlLibraryName = "libnvidia-ml.so.1"
defaultNvmlLibraryLoadFlags = dl.RTLD_LAZY | dl.RTLD_GLOBAL
)
var errLibraryNotLoaded = errors.New("library not loaded")
var errLibraryAlreadyLoaded = errors.New("library already loaded")
// library represents an nvml library.
// This includes a reference to the underlying DynamicLibrary
type library struct {
sync.Mutex
path string
flags int
dl dynamicLibrary
}
// libnvml is a global instance of the nvml library.
var libnvml = library{
path: defaultNvmlLibraryName,
flags: defaultNvmlLibraryLoadFlags,
}
var _ Interface = (*library)(nil)
// GetLibrary returns a the library as a Library interface.
func (l *library) GetLibrary() Library {
return l
}
// GetLibrary returns a representation of the underlying library that implements the Library interface.
func GetLibrary() Library {
return libnvml.GetLibrary()
}
// Lookup checks whether the specified library symbol exists in the library.
// Note that this requires that the library be loaded.
func (l *library) Lookup(name string) error {
if l == nil || l.dl == nil {
return fmt.Errorf("error looking up %s: %w", name, errLibraryNotLoaded)
}
return l.dl.Lookup(name)
}
// newDynamicLibrary is a function variable that can be overridden for testing.
var newDynamicLibrary = func(path string, flags int) dynamicLibrary {
return dl.New(path, flags)
}
// load initializes the library and updates the versioned symbols.
// Multiple calls to an already loaded library will return without error.
func (l *library) load() error {
l.Lock()
defer l.Unlock()
if l.dl != nil {
return nil
}
dl := newDynamicLibrary(l.path, l.flags)
err := dl.Open()
if err != nil {
return fmt.Errorf("error opening %s: %w", l.path, err)
}
l.dl = dl
l.updateVersionedSymbols()
return nil
}
// close the underlying library and ensure that the global pointer to the
// library is set to nil to ensure that subsequent calls to open will reinitialize it.
// Multiple calls to an already closed nvml library will return without error.
func (l *library) close() error {
l.Lock()
defer l.Unlock()
if l.dl == nil {
return nil
}
err := l.dl.Close()
if err != nil {
return fmt.Errorf("error closing %s: %w", l.path, err)
}
l.dl = nil
return nil
}
// Default all versioned APIs to v1 (to infer the types)
var nvmlInit = nvmlInit_v1
var nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v1
var nvmlDeviceGetCount = nvmlDeviceGetCount_v1
var nvmlDeviceGetHandleByIndex = nvmlDeviceGetHandleByIndex_v1
var nvmlDeviceGetHandleByPciBusId = nvmlDeviceGetHandleByPciBusId_v1
var nvmlDeviceGetNvLinkRemotePciInfo = nvmlDeviceGetNvLinkRemotePciInfo_v1
var nvmlDeviceRemoveGpu = nvmlDeviceRemoveGpu_v1
var nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v1
var nvmlEventSetWait = nvmlEventSetWait_v1
var nvmlDeviceGetAttributes = nvmlDeviceGetAttributes_v1
var nvmlComputeInstanceGetInfo = nvmlComputeInstanceGetInfo_v1
var DeviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v1
var DeviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v1
var DeviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v1
var GetBlacklistDeviceCount = GetExcludedDeviceCount
var GetBlacklistDeviceInfoByIndex = GetExcludedDeviceInfoByIndex
var nvmlDeviceGetGpuInstancePossiblePlacements = nvmlDeviceGetGpuInstancePossiblePlacements_v1
var nvmlVgpuInstanceGetLicenseInfo = nvmlVgpuInstanceGetLicenseInfo_v1
// BlacklistDeviceInfo was replaced by ExcludedDeviceInfo
type BlacklistDeviceInfo = ExcludedDeviceInfo
type ProcessInfo_v1Slice []ProcessInfo_v1
type ProcessInfo_v2Slice []ProcessInfo_v2
func (pis ProcessInfo_v1Slice) ToProcessInfoSlice() []ProcessInfo {
var newInfos []ProcessInfo
for _, pi := range pis {
info := ProcessInfo{
Pid: pi.Pid,
UsedGpuMemory: pi.UsedGpuMemory,
GpuInstanceId: 0xFFFFFFFF, // GPU instance ID is invalid in v1
ComputeInstanceId: 0xFFFFFFFF, // Compute instance ID is invalid in v1
}
newInfos = append(newInfos, info)
}
return newInfos
}
func (pis ProcessInfo_v2Slice) ToProcessInfoSlice() []ProcessInfo {
var newInfos []ProcessInfo
for _, pi := range pis {
info := ProcessInfo{
Pid: pi.Pid,
UsedGpuMemory: pi.UsedGpuMemory,
GpuInstanceId: pi.GpuInstanceId,
ComputeInstanceId: pi.ComputeInstanceId,
}
newInfos = append(newInfos, info)
}
return newInfos
}
// updateVersionedSymbols checks for versioned symbols in the loaded dynamic library.
// If newer versioned symbols exist, these replace the default `v1` symbols initialized above.
// When new versioned symbols are added, these would have to be initialized above and have
// corresponding checks and subsequent assignments added below.
func (l *library) updateVersionedSymbols() {
err := l.Lookup("nvmlInit_v2")
if err == nil {
nvmlInit = nvmlInit_v2
}
err = l.Lookup("nvmlDeviceGetPciInfo_v2")
if err == nil {
nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v2
}
err = l.Lookup("nvmlDeviceGetPciInfo_v3")
if err == nil {
nvmlDeviceGetPciInfo = nvmlDeviceGetPciInfo_v3
}
err = l.Lookup("nvmlDeviceGetCount_v2")
if err == nil {
nvmlDeviceGetCount = nvmlDeviceGetCount_v2
}
err = l.Lookup("nvmlDeviceGetHandleByIndex_v2")
if err == nil {
nvmlDeviceGetHandleByIndex = nvmlDeviceGetHandleByIndex_v2
}
err = l.Lookup("nvmlDeviceGetHandleByPciBusId_v2")
if err == nil {
nvmlDeviceGetHandleByPciBusId = nvmlDeviceGetHandleByPciBusId_v2
}
err = l.Lookup("nvmlDeviceGetNvLinkRemotePciInfo_v2")
if err == nil {
nvmlDeviceGetNvLinkRemotePciInfo = nvmlDeviceGetNvLinkRemotePciInfo_v2
}
// Unable to overwrite nvmlDeviceRemoveGpu() because the v2 function takes
// a different set of parameters than the v1 function.
//err = l.Lookup("nvmlDeviceRemoveGpu_v2")
//if err == nil {
// nvmlDeviceRemoveGpu = nvmlDeviceRemoveGpu_v2
//}
err = l.Lookup("nvmlDeviceGetGridLicensableFeatures_v2")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v2
}
err = l.Lookup("nvmlDeviceGetGridLicensableFeatures_v3")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v3
}
err = l.Lookup("nvmlDeviceGetGridLicensableFeatures_v4")
if err == nil {
nvmlDeviceGetGridLicensableFeatures = nvmlDeviceGetGridLicensableFeatures_v4
}
err = l.Lookup("nvmlEventSetWait_v2")
if err == nil {
nvmlEventSetWait = nvmlEventSetWait_v2
}
err = l.Lookup("nvmlDeviceGetAttributes_v2")
if err == nil {
nvmlDeviceGetAttributes = nvmlDeviceGetAttributes_v2
}
err = l.Lookup("nvmlComputeInstanceGetInfo_v2")
if err == nil {
nvmlComputeInstanceGetInfo = nvmlComputeInstanceGetInfo_v2
}
err = l.Lookup("nvmlDeviceGetComputeRunningProcesses_v2")
if err == nil {
DeviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v2
}
err = l.Lookup("nvmlDeviceGetComputeRunningProcesses_v3")
if err == nil {
DeviceGetComputeRunningProcesses = deviceGetComputeRunningProcesses_v3
}
err = l.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v2")
if err == nil {
DeviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v2
}
err = l.Lookup("nvmlDeviceGetGraphicsRunningProcesses_v3")
if err == nil {
DeviceGetGraphicsRunningProcesses = deviceGetGraphicsRunningProcesses_v3
}
err = l.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v2")
if err == nil {
DeviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v2
}
err = l.Lookup("nvmlDeviceGetMPSComputeRunningProcesses_v3")
if err == nil {
DeviceGetMPSComputeRunningProcesses = deviceGetMPSComputeRunningProcesses_v3
}
err = l.Lookup("nvmlDeviceGetGpuInstancePossiblePlacements_v2")
if err == nil {
nvmlDeviceGetGpuInstancePossiblePlacements = nvmlDeviceGetGpuInstancePossiblePlacements_v2
}
err = l.Lookup("nvmlVgpuInstanceGetLicenseInfo_v2")
if err == nil {
nvmlVgpuInstanceGetLicenseInfo = nvmlVgpuInstanceGetLicenseInfo_v2
}
}
// LibraryOption represents a functional option to configure the underlying NVML library
type LibraryOption func(*library)
// WithLibraryPath provides an option to set the library name to be used by the NVML library.
func WithLibraryPath(path string) LibraryOption {
return func(l *library) {
l.path = path
}
}
// SetLibraryOptions applies the specified options to the NVML library.
// If this is called when a library is already loaded, and error is raised.
func SetLibraryOptions(opts ...LibraryOption) error {
libnvml.Lock()
defer libnvml.Unlock()
if libnvml.dl != nil {
return errLibraryAlreadyLoaded
}
for _, opt := range opts {
opt(&libnvml)
}
if libnvml.path == "" {
libnvml.path = defaultNvmlLibraryName
}
if libnvml.flags == 0 {
libnvml.flags = defaultNvmlLibraryLoadFlags
}
return nil
}

View File

@@ -18,7 +18,8 @@
package nvml
/*
#cgo LDFLAGS: -Wl,--unresolved-symbols=ignore-in-object-files
#cgo linux LDFLAGS: -Wl,--export-dynamic -Wl,--unresolved-symbols=ignore-in-object-files
#cgo darwin LDFLAGS: -Wl,-undefined,dynamic_lookup
#cgo CFLAGS: -DNVML_NO_UNVERSIONED_FUNC_DEFS=1
#include "nvml.h"
#include <stdlib.h>

View File

@@ -438,7 +438,7 @@ func GetVgpuVersion() (VgpuVersion, VgpuVersion, Return) {
// nvml.SetVgpuVersion()
func SetVgpuVersion(VgpuVersion *VgpuVersion) Return {
return SetVgpuVersion(VgpuVersion)
return nvmlSetVgpuVersion(VgpuVersion)
}
// nvml.VgpuInstanceClearAccountingPids()

View File

@@ -352,9 +352,9 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
// Greater asserts that the first element is greater than the second
//
// assert.Greater(t, 2, 1)
// assert.Greater(t, float64(2), float64(1))
// assert.Greater(t, "b", "a")
// assert.Greater(t, 2, 1)
// assert.Greater(t, float64(2), float64(1))
// assert.Greater(t, "b", "a")
func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -364,10 +364,10 @@ func Greater(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface
// GreaterOrEqual asserts that the first element is greater than or equal to the second
//
// assert.GreaterOrEqual(t, 2, 1)
// assert.GreaterOrEqual(t, 2, 2)
// assert.GreaterOrEqual(t, "b", "a")
// assert.GreaterOrEqual(t, "b", "b")
// assert.GreaterOrEqual(t, 2, 1)
// assert.GreaterOrEqual(t, 2, 2)
// assert.GreaterOrEqual(t, "b", "a")
// assert.GreaterOrEqual(t, "b", "b")
func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -377,9 +377,9 @@ func GreaterOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...in
// Less asserts that the first element is less than the second
//
// assert.Less(t, 1, 2)
// assert.Less(t, float64(1), float64(2))
// assert.Less(t, "a", "b")
// assert.Less(t, 1, 2)
// assert.Less(t, float64(1), float64(2))
// assert.Less(t, "a", "b")
func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -389,10 +389,10 @@ func Less(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{})
// LessOrEqual asserts that the first element is less than or equal to the second
//
// assert.LessOrEqual(t, 1, 2)
// assert.LessOrEqual(t, 2, 2)
// assert.LessOrEqual(t, "a", "b")
// assert.LessOrEqual(t, "b", "b")
// assert.LessOrEqual(t, 1, 2)
// assert.LessOrEqual(t, 2, 2)
// assert.LessOrEqual(t, "a", "b")
// assert.LessOrEqual(t, "b", "b")
func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -402,8 +402,8 @@ func LessOrEqual(t TestingT, e1 interface{}, e2 interface{}, msgAndArgs ...inter
// Positive asserts that the specified element is positive
//
// assert.Positive(t, 1)
// assert.Positive(t, 1.23)
// assert.Positive(t, 1)
// assert.Positive(t, 1.23)
func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -414,8 +414,8 @@ func Positive(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
// Negative asserts that the specified element is negative
//
// assert.Negative(t, -1)
// assert.Negative(t, -1.23)
// assert.Negative(t, -1)
// assert.Negative(t, -1.23)
func Negative(t TestingT, e interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()

View File

@@ -22,9 +22,9 @@ func Conditionf(t TestingT, comp Comparison, msg string, args ...interface{}) bo
// Containsf asserts that the specified string, list(array, slice...) or map contains the
// specified substring or element.
//
// assert.Containsf(t, "Hello World", "World", "error message %s", "formatted")
// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted")
// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted")
// assert.Containsf(t, "Hello World", "World", "error message %s", "formatted")
// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted")
// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted")
func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -56,7 +56,7 @@ func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string
// Emptyf asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
// assert.Emptyf(t, obj, "error message %s", "formatted")
// assert.Emptyf(t, obj, "error message %s", "formatted")
func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -66,7 +66,7 @@ func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) boo
// Equalf asserts that two objects are equal.
//
// assert.Equalf(t, 123, 123, "error message %s", "formatted")
// assert.Equalf(t, 123, 123, "error message %s", "formatted")
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). Function equality
@@ -81,8 +81,8 @@ func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, ar
// EqualErrorf asserts that a function returned an error (i.e. not `nil`)
// and that it is equal to the provided error.
//
// actualObj, err := SomeFunction()
// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted")
// actualObj, err := SomeFunction()
// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted")
func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -90,10 +90,27 @@ func EqualErrorf(t TestingT, theError error, errString string, msg string, args
return EqualError(t, theError, errString, append([]interface{}{msg}, args...)...)
}
// EqualExportedValuesf asserts that the types of two objects are equal and their public
// fields are also equal. This is useful for comparing structs that have private fields
// that could potentially differ.
//
// type S struct {
// Exported int
// notExported int
// }
// assert.EqualExportedValuesf(t, S{1, 2}, S{1, 3}, "error message %s", "formatted") => true
// assert.EqualExportedValuesf(t, S{1, 2}, S{2, 3}, "error message %s", "formatted") => false
func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...)
}
// EqualValuesf asserts that two objects are equal or convertable to the same types
// and equal.
//
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -103,10 +120,10 @@ func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg stri
// Errorf asserts that a function returned an error (i.e. not `nil`).
//
// actualObj, err := SomeFunction()
// if assert.Errorf(t, err, "error message %s", "formatted") {
// assert.Equal(t, expectedErrorf, err)
// }
// actualObj, err := SomeFunction()
// if assert.Errorf(t, err, "error message %s", "formatted") {
// assert.Equal(t, expectedErrorf, err)
// }
func Errorf(t TestingT, err error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -126,8 +143,8 @@ func ErrorAsf(t TestingT, err error, target interface{}, msg string, args ...int
// ErrorContainsf asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
// actualObj, err := SomeFunction()
// assert.ErrorContainsf(t, err, expectedErrorSubString, "error message %s", "formatted")
func ErrorContainsf(t TestingT, theError error, contains string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -147,7 +164,7 @@ func ErrorIsf(t TestingT, err error, target error, msg string, args ...interface
// Eventuallyf asserts that given condition will be met in waitFor time,
// periodically checking target function each tick.
//
// assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted")
// assert.Eventuallyf(t, func() bool { return true; }, time.Second, 10*time.Millisecond, "error message %s", "formatted")
func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -155,9 +172,34 @@ func Eventuallyf(t TestingT, condition func() bool, waitFor time.Duration, tick
return Eventually(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...)
}
// EventuallyWithTf asserts that given condition will be met in waitFor time,
// periodically checking target function each tick. In contrast to Eventually,
// it supplies a CollectT to the condition function, so that the condition
// function can use the CollectT to call other assertions.
// The condition is considered "met" if no errors are raised in a tick.
// The supplied CollectT collects all errors from one tick (if there are any).
// If the condition is not met before waitFor, the collected errors of
// the last tick are copied to t.
//
// externalValue := false
// go func() {
// time.Sleep(8*time.Second)
// externalValue = true
// }()
// assert.EventuallyWithTf(t, func(c *assert.CollectT, "error message %s", "formatted") {
// // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false")
func EventuallyWithTf(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return EventuallyWithT(t, condition, waitFor, tick, append([]interface{}{msg}, args...)...)
}
// Exactlyf asserts that two objects are equal in value and type.
//
// assert.Exactlyf(t, int32(123), int64(123), "error message %s", "formatted")
// assert.Exactlyf(t, int32(123), int64(123), "error message %s", "formatted")
func Exactlyf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -183,7 +225,7 @@ func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}
// Falsef asserts that the specified value is false.
//
// assert.Falsef(t, myBool, "error message %s", "formatted")
// assert.Falsef(t, myBool, "error message %s", "formatted")
func Falsef(t TestingT, value bool, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -202,9 +244,9 @@ func FileExistsf(t TestingT, path string, msg string, args ...interface{}) bool
// Greaterf asserts that the first element is greater than the second
//
// assert.Greaterf(t, 2, 1, "error message %s", "formatted")
// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted")
// assert.Greaterf(t, "b", "a", "error message %s", "formatted")
// assert.Greaterf(t, 2, 1, "error message %s", "formatted")
// assert.Greaterf(t, float64(2), float64(1), "error message %s", "formatted")
// assert.Greaterf(t, "b", "a", "error message %s", "formatted")
func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -214,10 +256,10 @@ func Greaterf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...in
// GreaterOrEqualf asserts that the first element is greater than or equal to the second
//
// assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted")
// assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted")
// assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted")
// assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted")
// assert.GreaterOrEqualf(t, 2, 1, "error message %s", "formatted")
// assert.GreaterOrEqualf(t, 2, 2, "error message %s", "formatted")
// assert.GreaterOrEqualf(t, "b", "a", "error message %s", "formatted")
// assert.GreaterOrEqualf(t, "b", "b", "error message %s", "formatted")
func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -228,7 +270,7 @@ func GreaterOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, arg
// HTTPBodyContainsf asserts that a specified handler returns a
// body that contains a string.
//
// assert.HTTPBodyContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted")
// assert.HTTPBodyContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool {
@@ -241,7 +283,7 @@ func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url
// HTTPBodyNotContainsf asserts that a specified handler returns a
// body that does not contain a string.
//
// assert.HTTPBodyNotContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted")
// assert.HTTPBodyNotContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool {
@@ -253,7 +295,7 @@ func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, u
// HTTPErrorf asserts that a specified handler returns an error status code.
//
// assert.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}}
// assert.HTTPErrorf(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}}
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
@@ -265,7 +307,7 @@ func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string,
// HTTPRedirectf asserts that a specified handler returns a redirect status code.
//
// assert.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}}
// assert.HTTPRedirectf(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}}
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
@@ -277,7 +319,7 @@ func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url stri
// HTTPStatusCodef asserts that a specified handler returns a specified status code.
//
// assert.HTTPStatusCodef(t, myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted")
// assert.HTTPStatusCodef(t, myHandler, "GET", "/notImplemented", nil, 501, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPStatusCodef(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, statuscode int, msg string, args ...interface{}) bool {
@@ -289,7 +331,7 @@ func HTTPStatusCodef(t TestingT, handler http.HandlerFunc, method string, url st
// HTTPSuccessf asserts that a specified handler returns a success status code.
//
// assert.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted")
// assert.HTTPSuccessf(t, myHandler, "POST", "http://www.google.com", nil, "error message %s", "formatted")
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
@@ -301,7 +343,7 @@ func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url strin
// Implementsf asserts that an object is implemented by the specified interface.
//
// assert.Implementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
// assert.Implementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -311,7 +353,7 @@ func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, ms
// InDeltaf asserts that the two numerals are within delta of each other.
//
// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted")
// assert.InDeltaf(t, math.Pi, 22/7.0, 0.01, "error message %s", "formatted")
func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -353,9 +395,9 @@ func InEpsilonSlicef(t TestingT, expected interface{}, actual interface{}, epsil
// IsDecreasingf asserts that the collection is decreasing
//
// assert.IsDecreasingf(t, []int{2, 1, 0}, "error message %s", "formatted")
// assert.IsDecreasingf(t, []float{2, 1}, "error message %s", "formatted")
// assert.IsDecreasingf(t, []string{"b", "a"}, "error message %s", "formatted")
// assert.IsDecreasingf(t, []int{2, 1, 0}, "error message %s", "formatted")
// assert.IsDecreasingf(t, []float{2, 1}, "error message %s", "formatted")
// assert.IsDecreasingf(t, []string{"b", "a"}, "error message %s", "formatted")
func IsDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -365,9 +407,9 @@ func IsDecreasingf(t TestingT, object interface{}, msg string, args ...interface
// IsIncreasingf asserts that the collection is increasing
//
// assert.IsIncreasingf(t, []int{1, 2, 3}, "error message %s", "formatted")
// assert.IsIncreasingf(t, []float{1, 2}, "error message %s", "formatted")
// assert.IsIncreasingf(t, []string{"a", "b"}, "error message %s", "formatted")
// assert.IsIncreasingf(t, []int{1, 2, 3}, "error message %s", "formatted")
// assert.IsIncreasingf(t, []float{1, 2}, "error message %s", "formatted")
// assert.IsIncreasingf(t, []string{"a", "b"}, "error message %s", "formatted")
func IsIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -377,9 +419,9 @@ func IsIncreasingf(t TestingT, object interface{}, msg string, args ...interface
// IsNonDecreasingf asserts that the collection is not decreasing
//
// assert.IsNonDecreasingf(t, []int{1, 1, 2}, "error message %s", "formatted")
// assert.IsNonDecreasingf(t, []float{1, 2}, "error message %s", "formatted")
// assert.IsNonDecreasingf(t, []string{"a", "b"}, "error message %s", "formatted")
// assert.IsNonDecreasingf(t, []int{1, 1, 2}, "error message %s", "formatted")
// assert.IsNonDecreasingf(t, []float{1, 2}, "error message %s", "formatted")
// assert.IsNonDecreasingf(t, []string{"a", "b"}, "error message %s", "formatted")
func IsNonDecreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -389,9 +431,9 @@ func IsNonDecreasingf(t TestingT, object interface{}, msg string, args ...interf
// IsNonIncreasingf asserts that the collection is not increasing
//
// assert.IsNonIncreasingf(t, []int{2, 1, 1}, "error message %s", "formatted")
// assert.IsNonIncreasingf(t, []float{2, 1}, "error message %s", "formatted")
// assert.IsNonIncreasingf(t, []string{"b", "a"}, "error message %s", "formatted")
// assert.IsNonIncreasingf(t, []int{2, 1, 1}, "error message %s", "formatted")
// assert.IsNonIncreasingf(t, []float{2, 1}, "error message %s", "formatted")
// assert.IsNonIncreasingf(t, []string{"b", "a"}, "error message %s", "formatted")
func IsNonIncreasingf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -409,7 +451,7 @@ func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg strin
// JSONEqf asserts that two JSON strings are equivalent.
//
// assert.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted")
// assert.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted")
func JSONEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -420,7 +462,7 @@ func JSONEqf(t TestingT, expected string, actual string, msg string, args ...int
// Lenf asserts that the specified object has specific length.
// Lenf also fails if the object has a type that len() not accept.
//
// assert.Lenf(t, mySlice, 3, "error message %s", "formatted")
// assert.Lenf(t, mySlice, 3, "error message %s", "formatted")
func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -430,9 +472,9 @@ func Lenf(t TestingT, object interface{}, length int, msg string, args ...interf
// Lessf asserts that the first element is less than the second
//
// assert.Lessf(t, 1, 2, "error message %s", "formatted")
// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted")
// assert.Lessf(t, "a", "b", "error message %s", "formatted")
// assert.Lessf(t, 1, 2, "error message %s", "formatted")
// assert.Lessf(t, float64(1), float64(2), "error message %s", "formatted")
// assert.Lessf(t, "a", "b", "error message %s", "formatted")
func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -442,10 +484,10 @@ func Lessf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...inter
// LessOrEqualf asserts that the first element is less than or equal to the second
//
// assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted")
// assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted")
// assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted")
// assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted")
// assert.LessOrEqualf(t, 1, 2, "error message %s", "formatted")
// assert.LessOrEqualf(t, 2, 2, "error message %s", "formatted")
// assert.LessOrEqualf(t, "a", "b", "error message %s", "formatted")
// assert.LessOrEqualf(t, "b", "b", "error message %s", "formatted")
func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -455,8 +497,8 @@ func LessOrEqualf(t TestingT, e1 interface{}, e2 interface{}, msg string, args .
// Negativef asserts that the specified element is negative
//
// assert.Negativef(t, -1, "error message %s", "formatted")
// assert.Negativef(t, -1.23, "error message %s", "formatted")
// assert.Negativef(t, -1, "error message %s", "formatted")
// assert.Negativef(t, -1.23, "error message %s", "formatted")
func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -467,7 +509,7 @@ func Negativef(t TestingT, e interface{}, msg string, args ...interface{}) bool
// Neverf asserts that the given condition doesn't satisfy in waitFor time,
// periodically checking the target function each tick.
//
// assert.Neverf(t, func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted")
// assert.Neverf(t, func() bool { return false; }, time.Second, 10*time.Millisecond, "error message %s", "formatted")
func Neverf(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -477,7 +519,7 @@ func Neverf(t TestingT, condition func() bool, waitFor time.Duration, tick time.
// Nilf asserts that the specified object is nil.
//
// assert.Nilf(t, err, "error message %s", "formatted")
// assert.Nilf(t, err, "error message %s", "formatted")
func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -496,10 +538,10 @@ func NoDirExistsf(t TestingT, path string, msg string, args ...interface{}) bool
// NoErrorf asserts that a function returned no error (i.e. `nil`).
//
// actualObj, err := SomeFunction()
// if assert.NoErrorf(t, err, "error message %s", "formatted") {
// assert.Equal(t, expectedObj, actualObj)
// }
// actualObj, err := SomeFunction()
// if assert.NoErrorf(t, err, "error message %s", "formatted") {
// assert.Equal(t, expectedObj, actualObj)
// }
func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -519,9 +561,9 @@ func NoFileExistsf(t TestingT, path string, msg string, args ...interface{}) boo
// NotContainsf asserts that the specified string, list(array, slice...) or map does NOT contain the
// specified substring or element.
//
// assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted")
// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted")
// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted")
// assert.NotContainsf(t, "Hello World", "Earth", "error message %s", "formatted")
// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted")
// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted")
func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -532,9 +574,9 @@ func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, a
// NotEmptyf asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
// if assert.NotEmptyf(t, obj, "error message %s", "formatted") {
// assert.Equal(t, "two", obj[1])
// }
// if assert.NotEmptyf(t, obj, "error message %s", "formatted") {
// assert.Equal(t, "two", obj[1])
// }
func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -544,7 +586,7 @@ func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{})
// NotEqualf asserts that the specified values are NOT equal.
//
// assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted")
// assert.NotEqualf(t, obj1, obj2, "error message %s", "formatted")
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses).
@@ -557,7 +599,7 @@ func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string,
// NotEqualValuesf asserts that two objects are not equal even when converted to the same type
//
// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted")
// assert.NotEqualValuesf(t, obj1, obj2, "error message %s", "formatted")
func NotEqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -576,7 +618,7 @@ func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interf
// NotNilf asserts that the specified object is not nil.
//
// assert.NotNilf(t, err, "error message %s", "formatted")
// assert.NotNilf(t, err, "error message %s", "formatted")
func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -586,7 +628,7 @@ func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bo
// NotPanicsf asserts that the code inside the specified PanicTestFunc does NOT panic.
//
// assert.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted")
// assert.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted")
func NotPanicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -596,8 +638,8 @@ func NotPanicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bo
// NotRegexpf asserts that a specified regexp does not match a string.
//
// assert.NotRegexpf(t, regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted")
// assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted")
// assert.NotRegexpf(t, regexp.MustCompile("starts"), "it's starting", "error message %s", "formatted")
// assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted")
func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -607,7 +649,7 @@ func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ..
// NotSamef asserts that two pointers do not reference the same object.
//
// assert.NotSamef(t, ptr1, ptr2, "error message %s", "formatted")
// assert.NotSamef(t, ptr1, ptr2, "error message %s", "formatted")
//
// Both arguments must be pointer variables. Pointer variable sameness is
// determined based on the equality of both type and value.
@@ -621,7 +663,7 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string,
// NotSubsetf asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
//
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted")
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted")
func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -639,7 +681,7 @@ func NotZerof(t TestingT, i interface{}, msg string, args ...interface{}) bool {
// Panicsf asserts that the code inside the specified PanicTestFunc panics.
//
// assert.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted")
// assert.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted")
func Panicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -651,7 +693,7 @@ func Panicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool
// panics, and that the recovered panic value is an error that satisfies the
// EqualError comparison.
//
// assert.PanicsWithErrorf(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted")
// assert.PanicsWithErrorf(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted")
func PanicsWithErrorf(t TestingT, errString string, f PanicTestFunc, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -662,7 +704,7 @@ func PanicsWithErrorf(t TestingT, errString string, f PanicTestFunc, msg string,
// PanicsWithValuef asserts that the code inside the specified PanicTestFunc panics, and that
// the recovered panic value equals the expected panic value.
//
// assert.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted")
// assert.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted")
func PanicsWithValuef(t TestingT, expected interface{}, f PanicTestFunc, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -672,8 +714,8 @@ func PanicsWithValuef(t TestingT, expected interface{}, f PanicTestFunc, msg str
// Positivef asserts that the specified element is positive
//
// assert.Positivef(t, 1, "error message %s", "formatted")
// assert.Positivef(t, 1.23, "error message %s", "formatted")
// assert.Positivef(t, 1, "error message %s", "formatted")
// assert.Positivef(t, 1.23, "error message %s", "formatted")
func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -683,8 +725,8 @@ func Positivef(t TestingT, e interface{}, msg string, args ...interface{}) bool
// Regexpf asserts that a specified regexp matches a string.
//
// assert.Regexpf(t, regexp.MustCompile("start"), "it's starting", "error message %s", "formatted")
// assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted")
// assert.Regexpf(t, regexp.MustCompile("start"), "it's starting", "error message %s", "formatted")
// assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted")
func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -694,7 +736,7 @@ func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...in
// Samef asserts that two pointers reference the same object.
//
// assert.Samef(t, ptr1, ptr2, "error message %s", "formatted")
// assert.Samef(t, ptr1, ptr2, "error message %s", "formatted")
//
// Both arguments must be pointer variables. Pointer variable sameness is
// determined based on the equality of both type and value.
@@ -708,7 +750,7 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg
// Subsetf asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
//
// assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted")
// assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted")
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -718,7 +760,7 @@ func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args
// Truef asserts that the specified value is true.
//
// assert.Truef(t, myBool, "error message %s", "formatted")
// assert.Truef(t, myBool, "error message %s", "formatted")
func Truef(t TestingT, value bool, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -728,7 +770,7 @@ func Truef(t TestingT, value bool, msg string, args ...interface{}) bool {
// WithinDurationf asserts that the two times are within duration delta of each other.
//
// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted")
// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted")
func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -738,7 +780,7 @@ func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta tim
// WithinRangef asserts that a time is within a time range (inclusive).
//
// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
// assert.WithinRangef(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second), "error message %s", "formatted")
func WithinRangef(t TestingT, actual time.Time, start time.Time, end time.Time, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()

File diff suppressed because it is too large Load Diff

View File

@@ -46,36 +46,36 @@ func isOrdered(t TestingT, object interface{}, allowedComparesResults []CompareT
// IsIncreasing asserts that the collection is increasing
//
// assert.IsIncreasing(t, []int{1, 2, 3})
// assert.IsIncreasing(t, []float{1, 2})
// assert.IsIncreasing(t, []string{"a", "b"})
// assert.IsIncreasing(t, []int{1, 2, 3})
// assert.IsIncreasing(t, []float{1, 2})
// assert.IsIncreasing(t, []string{"a", "b"})
func IsIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess}, "\"%v\" is not less than \"%v\"", msgAndArgs...)
}
// IsNonIncreasing asserts that the collection is not increasing
//
// assert.IsNonIncreasing(t, []int{2, 1, 1})
// assert.IsNonIncreasing(t, []float{2, 1})
// assert.IsNonIncreasing(t, []string{"b", "a"})
// assert.IsNonIncreasing(t, []int{2, 1, 1})
// assert.IsNonIncreasing(t, []float{2, 1})
// assert.IsNonIncreasing(t, []string{"b", "a"})
func IsNonIncreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareEqual, compareGreater}, "\"%v\" is not greater than or equal to \"%v\"", msgAndArgs...)
}
// IsDecreasing asserts that the collection is decreasing
//
// assert.IsDecreasing(t, []int{2, 1, 0})
// assert.IsDecreasing(t, []float{2, 1})
// assert.IsDecreasing(t, []string{"b", "a"})
// assert.IsDecreasing(t, []int{2, 1, 0})
// assert.IsDecreasing(t, []float{2, 1})
// assert.IsDecreasing(t, []string{"b", "a"})
func IsDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareGreater}, "\"%v\" is not greater than \"%v\"", msgAndArgs...)
}
// IsNonDecreasing asserts that the collection is not decreasing
//
// assert.IsNonDecreasing(t, []int{1, 1, 2})
// assert.IsNonDecreasing(t, []float{1, 2})
// assert.IsNonDecreasing(t, []string{"a", "b"})
// assert.IsNonDecreasing(t, []int{1, 1, 2})
// assert.IsNonDecreasing(t, []float{1, 2})
// assert.IsNonDecreasing(t, []string{"a", "b"})
func IsNonDecreasing(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return isOrdered(t, object, []CompareType{compareLess, compareEqual}, "\"%v\" is not less than or equal to \"%v\"", msgAndArgs...)
}

View File

@@ -8,7 +8,6 @@ import (
"fmt"
"math"
"os"
"path/filepath"
"reflect"
"regexp"
"runtime"
@@ -76,6 +75,77 @@ func ObjectsAreEqual(expected, actual interface{}) bool {
return bytes.Equal(exp, act)
}
// copyExportedFields iterates downward through nested data structures and creates a copy
// that only contains the exported struct fields.
func copyExportedFields(expected interface{}) interface{} {
if isNil(expected) {
return expected
}
expectedType := reflect.TypeOf(expected)
expectedKind := expectedType.Kind()
expectedValue := reflect.ValueOf(expected)
switch expectedKind {
case reflect.Struct:
result := reflect.New(expectedType).Elem()
for i := 0; i < expectedType.NumField(); i++ {
field := expectedType.Field(i)
isExported := field.IsExported()
if isExported {
fieldValue := expectedValue.Field(i)
if isNil(fieldValue) || isNil(fieldValue.Interface()) {
continue
}
newValue := copyExportedFields(fieldValue.Interface())
result.Field(i).Set(reflect.ValueOf(newValue))
}
}
return result.Interface()
case reflect.Ptr:
result := reflect.New(expectedType.Elem())
unexportedRemoved := copyExportedFields(expectedValue.Elem().Interface())
result.Elem().Set(reflect.ValueOf(unexportedRemoved))
return result.Interface()
case reflect.Array, reflect.Slice:
result := reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len())
for i := 0; i < expectedValue.Len(); i++ {
index := expectedValue.Index(i)
if isNil(index) {
continue
}
unexportedRemoved := copyExportedFields(index.Interface())
result.Index(i).Set(reflect.ValueOf(unexportedRemoved))
}
return result.Interface()
case reflect.Map:
result := reflect.MakeMap(expectedType)
for _, k := range expectedValue.MapKeys() {
index := expectedValue.MapIndex(k)
unexportedRemoved := copyExportedFields(index.Interface())
result.SetMapIndex(k, reflect.ValueOf(unexportedRemoved))
}
return result.Interface()
default:
return expected
}
}
// ObjectsExportedFieldsAreEqual determines if the exported (public) fields of two objects are
// considered equal. This comparison of only exported fields is applied recursively to nested data
// structures.
//
// This function does no assertion of any kind.
func ObjectsExportedFieldsAreEqual(expected, actual interface{}) bool {
expectedCleaned := copyExportedFields(expected)
actualCleaned := copyExportedFields(actual)
return ObjectsAreEqualValues(expectedCleaned, actualCleaned)
}
// ObjectsAreEqualValues gets whether two objects are equal, or if their
// values are equal.
func ObjectsAreEqualValues(expected, actual interface{}) bool {
@@ -141,12 +211,11 @@ func CallerInfo() []string {
}
parts := strings.Split(file, "/")
file = parts[len(parts)-1]
if len(parts) > 1 {
filename := parts[len(parts)-1]
dir := parts[len(parts)-2]
if (dir != "assert" && dir != "mock" && dir != "require") || file == "mock_test.go" {
path, _ := filepath.Abs(file)
callers = append(callers, fmt.Sprintf("%s:%d", path, line))
if (dir != "assert" && dir != "mock" && dir != "require") || filename == "mock_test.go" {
callers = append(callers, fmt.Sprintf("%s:%d", file, line))
}
}
@@ -273,7 +342,7 @@ type labeledContent struct {
// labeledOutput returns a string consisting of the provided labeledContent. Each labeled output is appended in the following manner:
//
// \t{{label}}:{{align_spaces}}\t{{content}}\n
// \t{{label}}:{{align_spaces}}\t{{content}}\n
//
// The initial carriage return is required to undo/erase any padding added by testing.T.Errorf. The "\t{{label}}:" is for the label.
// If a label is shorter than the longest label provided, padding spaces are added to make all the labels match in length. Once this
@@ -296,7 +365,7 @@ func labeledOutput(content ...labeledContent) string {
// Implements asserts that an object is implemented by the specified interface.
//
// assert.Implements(t, (*MyInterface)(nil), new(MyObject))
// assert.Implements(t, (*MyInterface)(nil), new(MyObject))
func Implements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -328,7 +397,7 @@ func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs
// Equal asserts that two objects are equal.
//
// assert.Equal(t, 123, 123)
// assert.Equal(t, 123, 123)
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). Function equality
@@ -369,7 +438,7 @@ func validateEqualArgs(expected, actual interface{}) error {
// Same asserts that two pointers reference the same object.
//
// assert.Same(t, ptr1, ptr2)
// assert.Same(t, ptr1, ptr2)
//
// Both arguments must be pointer variables. Pointer variable sameness is
// determined based on the equality of both type and value.
@@ -389,7 +458,7 @@ func Same(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) b
// NotSame asserts that two pointers do not reference the same object.
//
// assert.NotSame(t, ptr1, ptr2)
// assert.NotSame(t, ptr1, ptr2)
//
// Both arguments must be pointer variables. Pointer variable sameness is
// determined based on the equality of both type and value.
@@ -457,7 +526,7 @@ func truncatingFormat(data interface{}) string {
// EqualValues asserts that two objects are equal or convertable to the same types
// and equal.
//
// assert.EqualValues(t, uint32(123), int32(123))
// assert.EqualValues(t, uint32(123), int32(123))
func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -475,9 +544,53 @@ func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interfa
}
// EqualExportedValues asserts that the types of two objects are equal and their public
// fields are also equal. This is useful for comparing structs that have private fields
// that could potentially differ.
//
// type S struct {
// Exported int
// notExported int
// }
// assert.EqualExportedValues(t, S{1, 2}, S{1, 3}) => true
// assert.EqualExportedValues(t, S{1, 2}, S{2, 3}) => false
func EqualExportedValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
aType := reflect.TypeOf(expected)
bType := reflect.TypeOf(actual)
if aType != bType {
return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...)
}
if aType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...)
}
if bType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...)
}
expected = copyExportedFields(expected)
actual = copyExportedFields(actual)
if !ObjectsAreEqualValues(expected, actual) {
diff := diff(expected, actual)
expected, actual = formatUnequalValues(expected, actual)
return Fail(t, fmt.Sprintf("Not equal (comparing only exported fields): \n"+
"expected: %s\n"+
"actual : %s%s", expected, actual, diff), msgAndArgs...)
}
return true
}
// Exactly asserts that two objects are equal in value and type.
//
// assert.Exactly(t, int32(123), int64(123))
// assert.Exactly(t, int32(123), int64(123))
func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -496,7 +609,7 @@ func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}
// NotNil asserts that the specified object is not nil.
//
// assert.NotNil(t, err)
// assert.NotNil(t, err)
func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
if !isNil(object) {
return true
@@ -530,7 +643,7 @@ func isNil(object interface{}) bool {
[]reflect.Kind{
reflect.Chan, reflect.Func,
reflect.Interface, reflect.Map,
reflect.Ptr, reflect.Slice},
reflect.Ptr, reflect.Slice, reflect.UnsafePointer},
kind)
if isNilableKind && value.IsNil() {
@@ -542,7 +655,7 @@ func isNil(object interface{}) bool {
// Nil asserts that the specified object is nil.
//
// assert.Nil(t, err)
// assert.Nil(t, err)
func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
if isNil(object) {
return true
@@ -585,7 +698,7 @@ func isEmpty(object interface{}) bool {
// Empty asserts that the specified object is empty. I.e. nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
// assert.Empty(t, obj)
// assert.Empty(t, obj)
func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
pass := isEmpty(object)
if !pass {
@@ -602,9 +715,9 @@ func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
// NotEmpty asserts that the specified object is NOT empty. I.e. not nil, "", false, 0 or either
// a slice or a channel with len == 0.
//
// if assert.NotEmpty(t, obj) {
// assert.Equal(t, "two", obj[1])
// }
// if assert.NotEmpty(t, obj) {
// assert.Equal(t, "two", obj[1])
// }
func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
pass := !isEmpty(object)
if !pass {
@@ -633,7 +746,7 @@ func getLen(x interface{}) (ok bool, length int) {
// Len asserts that the specified object has specific length.
// Len also fails if the object has a type that len() not accept.
//
// assert.Len(t, mySlice, 3)
// assert.Len(t, mySlice, 3)
func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -651,7 +764,7 @@ func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{})
// True asserts that the specified value is true.
//
// assert.True(t, myBool)
// assert.True(t, myBool)
func True(t TestingT, value bool, msgAndArgs ...interface{}) bool {
if !value {
if h, ok := t.(tHelper); ok {
@@ -666,7 +779,7 @@ func True(t TestingT, value bool, msgAndArgs ...interface{}) bool {
// False asserts that the specified value is false.
//
// assert.False(t, myBool)
// assert.False(t, myBool)
func False(t TestingT, value bool, msgAndArgs ...interface{}) bool {
if value {
if h, ok := t.(tHelper); ok {
@@ -681,7 +794,7 @@ func False(t TestingT, value bool, msgAndArgs ...interface{}) bool {
// NotEqual asserts that the specified values are NOT equal.
//
// assert.NotEqual(t, obj1, obj2)
// assert.NotEqual(t, obj1, obj2)
//
// Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses).
@@ -704,7 +817,7 @@ func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{
// NotEqualValues asserts that two objects are not equal even when converted to the same type
//
// assert.NotEqualValues(t, obj1, obj2)
// assert.NotEqualValues(t, obj1, obj2)
func NotEqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -763,9 +876,9 @@ func containsElement(list interface{}, element interface{}) (ok, found bool) {
// Contains asserts that the specified string, list(array, slice...) or map contains the
// specified substring or element.
//
// assert.Contains(t, "Hello World", "World")
// assert.Contains(t, ["Hello", "World"], "World")
// assert.Contains(t, {"Hello": "World"}, "Hello")
// assert.Contains(t, "Hello World", "World")
// assert.Contains(t, ["Hello", "World"], "World")
// assert.Contains(t, {"Hello": "World"}, "Hello")
func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -786,9 +899,9 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
// NotContains asserts that the specified string, list(array, slice...) or map does NOT contain the
// specified substring or element.
//
// assert.NotContains(t, "Hello World", "Earth")
// assert.NotContains(t, ["Hello", "World"], "Earth")
// assert.NotContains(t, {"Hello": "World"}, "Earth")
// assert.NotContains(t, "Hello World", "Earth")
// assert.NotContains(t, ["Hello", "World"], "Earth")
// assert.NotContains(t, {"Hello": "World"}, "Earth")
func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -796,10 +909,10 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
ok, found := containsElement(s, contains)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", s), msgAndArgs...)
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", s), msgAndArgs...)
}
if found {
return Fail(t, fmt.Sprintf("\"%s\" should not contain \"%s\"", s, contains), msgAndArgs...)
return Fail(t, fmt.Sprintf("%#v should not contain %#v", s, contains), msgAndArgs...)
}
return true
@@ -809,7 +922,7 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
// Subset asserts that the specified list(array, slice...) contains all
// elements given in the specified subset(array, slice...).
//
// assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]")
// assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]")
func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -818,49 +931,44 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
return true // we consider nil to be equal to the nil set
}
defer func() {
if e := recover(); e != nil {
ok = false
}
}()
listKind := reflect.TypeOf(list).Kind()
subsetKind := reflect.TypeOf(subset).Kind()
if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...)
}
subsetKind := reflect.TypeOf(subset).Kind()
if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
}
subsetValue := reflect.ValueOf(subset)
if subsetKind == reflect.Map && listKind == reflect.Map {
listValue := reflect.ValueOf(list)
subsetKeys := subsetValue.MapKeys()
subsetMap := reflect.ValueOf(subset)
actualMap := reflect.ValueOf(list)
for i := 0; i < len(subsetKeys); i++ {
subsetKey := subsetKeys[i]
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()
for _, k := range subsetMap.MapKeys() {
ev := subsetMap.MapIndex(k)
av := actualMap.MapIndex(k)
if !ObjectsAreEqual(subsetElement, listElement) {
return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, subsetElement), msgAndArgs...)
if !av.IsValid() {
return Fail(t, fmt.Sprintf("%#v does not contain %#v", list, subset), msgAndArgs...)
}
if !ObjectsAreEqual(ev.Interface(), av.Interface()) {
return Fail(t, fmt.Sprintf("%#v does not contain %#v", list, subset), msgAndArgs...)
}
}
return true
}
for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
subsetList := reflect.ValueOf(subset)
for i := 0; i < subsetList.Len(); i++ {
element := subsetList.Index(i).Interface()
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
return Fail(t, fmt.Sprintf("%#v could not be applied builtin len()", list), msgAndArgs...)
}
if !found {
return Fail(t, fmt.Sprintf("\"%s\" does not contain \"%s\"", list, element), msgAndArgs...)
return Fail(t, fmt.Sprintf("%#v does not contain %#v", list, element), msgAndArgs...)
}
}
@@ -870,7 +978,7 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
// NotSubset asserts that the specified list(array, slice...) contains not all
// elements given in the specified subset(array, slice...).
//
// assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]")
// assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]")
func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -879,34 +987,28 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
return Fail(t, "nil is the empty set which is a subset of every set", msgAndArgs...)
}
defer func() {
if e := recover(); e != nil {
ok = false
}
}()
listKind := reflect.TypeOf(list).Kind()
subsetKind := reflect.TypeOf(subset).Kind()
if listKind != reflect.Array && listKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", list, listKind), msgAndArgs...)
}
subsetKind := reflect.TypeOf(subset).Kind()
if subsetKind != reflect.Array && subsetKind != reflect.Slice && listKind != reflect.Map {
return Fail(t, fmt.Sprintf("%q has an unsupported type %s", subset, subsetKind), msgAndArgs...)
}
subsetValue := reflect.ValueOf(subset)
if subsetKind == reflect.Map && listKind == reflect.Map {
listValue := reflect.ValueOf(list)
subsetKeys := subsetValue.MapKeys()
subsetMap := reflect.ValueOf(subset)
actualMap := reflect.ValueOf(list)
for i := 0; i < len(subsetKeys); i++ {
subsetKey := subsetKeys[i]
subsetElement := subsetValue.MapIndex(subsetKey).Interface()
listElement := listValue.MapIndex(subsetKey).Interface()
for _, k := range subsetMap.MapKeys() {
ev := subsetMap.MapIndex(k)
av := actualMap.MapIndex(k)
if !ObjectsAreEqual(subsetElement, listElement) {
if !av.IsValid() {
return true
}
if !ObjectsAreEqual(ev.Interface(), av.Interface()) {
return true
}
}
@@ -914,8 +1016,9 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
return Fail(t, fmt.Sprintf("%q is a subset of %q", subset, list), msgAndArgs...)
}
for i := 0; i < subsetValue.Len(); i++ {
element := subsetValue.Index(i).Interface()
subsetList := reflect.ValueOf(subset)
for i := 0; i < subsetList.Len(); i++ {
element := subsetList.Index(i).Interface()
ok, found := containsElement(list, element)
if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", list), msgAndArgs...)
@@ -1060,7 +1163,7 @@ func didPanic(f PanicTestFunc) (didPanic bool, message interface{}, stack string
// Panics asserts that the code inside the specified PanicTestFunc panics.
//
// assert.Panics(t, func(){ GoCrazy() })
// assert.Panics(t, func(){ GoCrazy() })
func Panics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1076,7 +1179,7 @@ func Panics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool {
// PanicsWithValue asserts that the code inside the specified PanicTestFunc panics, and that
// the recovered panic value equals the expected panic value.
//
// assert.PanicsWithValue(t, "crazy error", func(){ GoCrazy() })
// assert.PanicsWithValue(t, "crazy error", func(){ GoCrazy() })
func PanicsWithValue(t TestingT, expected interface{}, f PanicTestFunc, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1097,7 +1200,7 @@ func PanicsWithValue(t TestingT, expected interface{}, f PanicTestFunc, msgAndAr
// panics, and that the recovered panic value is an error that satisfies the
// EqualError comparison.
//
// assert.PanicsWithError(t, "crazy error", func(){ GoCrazy() })
// assert.PanicsWithError(t, "crazy error", func(){ GoCrazy() })
func PanicsWithError(t TestingT, errString string, f PanicTestFunc, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1117,7 +1220,7 @@ func PanicsWithError(t TestingT, errString string, f PanicTestFunc, msgAndArgs .
// NotPanics asserts that the code inside the specified PanicTestFunc does NOT panic.
//
// assert.NotPanics(t, func(){ RemainCalm() })
// assert.NotPanics(t, func(){ RemainCalm() })
func NotPanics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1132,7 +1235,7 @@ func NotPanics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool {
// WithinDuration asserts that the two times are within duration delta of each other.
//
// assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second)
// assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second)
func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1148,7 +1251,7 @@ func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration,
// WithinRange asserts that a time is within a time range (inclusive).
//
// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
// assert.WithinRange(t, time.Now(), time.Now().Add(-time.Second), time.Now().Add(time.Second))
func WithinRange(t TestingT, actual, start, end time.Time, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1207,7 +1310,7 @@ func toFloat(x interface{}) (float64, bool) {
// InDelta asserts that the two numerals are within delta of each other.
//
// assert.InDelta(t, math.Pi, 22/7.0, 0.01)
// assert.InDelta(t, math.Pi, 22/7.0, 0.01)
func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1380,10 +1483,10 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
// NoError asserts that a function returned no error (i.e. `nil`).
//
// actualObj, err := SomeFunction()
// if assert.NoError(t, err) {
// assert.Equal(t, expectedObj, actualObj)
// }
// actualObj, err := SomeFunction()
// if assert.NoError(t, err) {
// assert.Equal(t, expectedObj, actualObj)
// }
func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool {
if err != nil {
if h, ok := t.(tHelper); ok {
@@ -1397,10 +1500,10 @@ func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool {
// Error asserts that a function returned an error (i.e. not `nil`).
//
// actualObj, err := SomeFunction()
// if assert.Error(t, err) {
// assert.Equal(t, expectedError, err)
// }
// actualObj, err := SomeFunction()
// if assert.Error(t, err) {
// assert.Equal(t, expectedError, err)
// }
func Error(t TestingT, err error, msgAndArgs ...interface{}) bool {
if err == nil {
if h, ok := t.(tHelper); ok {
@@ -1415,8 +1518,8 @@ func Error(t TestingT, err error, msgAndArgs ...interface{}) bool {
// EqualError asserts that a function returned an error (i.e. not `nil`)
// and that it is equal to the provided error.
//
// actualObj, err := SomeFunction()
// assert.EqualError(t, err, expectedErrorString)
// actualObj, err := SomeFunction()
// assert.EqualError(t, err, expectedErrorString)
func EqualError(t TestingT, theError error, errString string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1438,8 +1541,8 @@ func EqualError(t TestingT, theError error, errString string, msgAndArgs ...inte
// ErrorContains asserts that a function returned an error (i.e. not `nil`)
// and that the error contains the specified substring.
//
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
// actualObj, err := SomeFunction()
// assert.ErrorContains(t, err, expectedErrorSubString)
func ErrorContains(t TestingT, theError error, contains string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1472,8 +1575,8 @@ func matchRegexp(rx interface{}, str interface{}) bool {
// Regexp asserts that a specified regexp matches a string.
//
// assert.Regexp(t, regexp.MustCompile("start"), "it's starting")
// assert.Regexp(t, "start...$", "it's not starting")
// assert.Regexp(t, regexp.MustCompile("start"), "it's starting")
// assert.Regexp(t, "start...$", "it's not starting")
func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1490,8 +1593,8 @@ func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface
// NotRegexp asserts that a specified regexp does not match a string.
//
// assert.NotRegexp(t, regexp.MustCompile("starts"), "it's starting")
// assert.NotRegexp(t, "^start", "it's not starting")
// assert.NotRegexp(t, regexp.MustCompile("starts"), "it's starting")
// assert.NotRegexp(t, "^start", "it's not starting")
func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1603,7 +1706,7 @@ func NoDirExists(t TestingT, path string, msgAndArgs ...interface{}) bool {
// JSONEq asserts that two JSON strings are equivalent.
//
// assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`)
// assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`)
func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1726,7 +1829,7 @@ type tHelper interface {
// Eventually asserts that given condition will be met in waitFor time,
// periodically checking target function each tick.
//
// assert.Eventually(t, func() bool { return true; }, time.Second, 10*time.Millisecond)
// assert.Eventually(t, func() bool { return true; }, time.Second, 10*time.Millisecond)
func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
@@ -1756,10 +1859,93 @@ func Eventually(t TestingT, condition func() bool, waitFor time.Duration, tick t
}
}
// CollectT implements the TestingT interface and collects all errors.
type CollectT struct {
errors []error
}
// Errorf collects the error.
func (c *CollectT) Errorf(format string, args ...interface{}) {
c.errors = append(c.errors, fmt.Errorf(format, args...))
}
// FailNow panics.
func (c *CollectT) FailNow() {
panic("Assertion failed")
}
// Reset clears the collected errors.
func (c *CollectT) Reset() {
c.errors = nil
}
// Copy copies the collected errors to the supplied t.
func (c *CollectT) Copy(t TestingT) {
if tt, ok := t.(tHelper); ok {
tt.Helper()
}
for _, err := range c.errors {
t.Errorf("%v", err)
}
}
// EventuallyWithT asserts that given condition will be met in waitFor time,
// periodically checking target function each tick. In contrast to Eventually,
// it supplies a CollectT to the condition function, so that the condition
// function can use the CollectT to call other assertions.
// The condition is considered "met" if no errors are raised in a tick.
// The supplied CollectT collects all errors from one tick (if there are any).
// If the condition is not met before waitFor, the collected errors of
// the last tick are copied to t.
//
// externalValue := false
// go func() {
// time.Sleep(8*time.Second)
// externalValue = true
// }()
// assert.EventuallyWithT(t, func(c *assert.CollectT) {
// // add assertions as needed; any assertion failure will fail the current tick
// assert.True(c, externalValue, "expected 'externalValue' to be true")
// }, 1*time.Second, 10*time.Second, "external state has not changed to 'true'; still false")
func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
collect := new(CollectT)
ch := make(chan bool, 1)
timer := time.NewTimer(waitFor)
defer timer.Stop()
ticker := time.NewTicker(tick)
defer ticker.Stop()
for tick := ticker.C; ; {
select {
case <-timer.C:
collect.Copy(t)
return Fail(t, "Condition never satisfied", msgAndArgs...)
case <-tick:
tick = nil
collect.Reset()
go func() {
condition(collect)
ch <- len(collect.errors) == 0
}()
case v := <-ch:
if v {
return true
}
tick = ticker.C
}
}
}
// Never asserts that the given condition doesn't satisfy in waitFor time,
// periodically checking the target function each tick.
//
// assert.Never(t, func() bool { return false; }, time.Second, 10*time.Millisecond)
// assert.Never(t, func() bool { return false; }, time.Second, 10*time.Millisecond)
func Never(t TestingT, condition func() bool, waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()

View File

@@ -1,39 +1,40 @@
// Package assert provides a set of comprehensive testing tools for use with the normal Go testing system.
//
// Example Usage
// # Example Usage
//
// The following is a complete example using assert in a standard test function:
// import (
// "testing"
// "github.com/stretchr/testify/assert"
// )
//
// func TestSomething(t *testing.T) {
// import (
// "testing"
// "github.com/stretchr/testify/assert"
// )
//
// var a string = "Hello"
// var b string = "Hello"
// func TestSomething(t *testing.T) {
//
// assert.Equal(t, a, b, "The two words should be the same.")
// var a string = "Hello"
// var b string = "Hello"
//
// }
// assert.Equal(t, a, b, "The two words should be the same.")
//
// }
//
// if you assert many times, use the format below:
//
// import (
// "testing"
// "github.com/stretchr/testify/assert"
// )
// import (
// "testing"
// "github.com/stretchr/testify/assert"
// )
//
// func TestSomething(t *testing.T) {
// assert := assert.New(t)
// func TestSomething(t *testing.T) {
// assert := assert.New(t)
//
// var a string = "Hello"
// var b string = "Hello"
// var a string = "Hello"
// var b string = "Hello"
//
// assert.Equal(a, b, "The two words should be the same.")
// }
// assert.Equal(a, b, "The two words should be the same.")
// }
//
// Assertions
// # Assertions
//
// Assertions allow you to easily write test code, and are global funcs in the `assert` package.
// All assertion functions take, as the first argument, the `*testing.T` object provided by the

View File

@@ -23,7 +23,7 @@ func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (
// HTTPSuccess asserts that a specified handler returns a success status code.
//
// assert.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil)
// assert.HTTPSuccess(t, myHandler, "POST", "http://www.google.com", nil)
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool {
@@ -45,7 +45,7 @@ func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, value
// HTTPRedirect asserts that a specified handler returns a redirect status code.
//
// assert.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}}
// assert.HTTPRedirect(t, myHandler, "GET", "/a/b/c", url.Values{"a": []string{"b", "c"}}
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool {
@@ -67,7 +67,7 @@ func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, valu
// HTTPError asserts that a specified handler returns an error status code.
//
// assert.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}}
// assert.HTTPError(t, myHandler, "POST", "/a/b/c", url.Values{"a": []string{"b", "c"}}
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool {
@@ -89,7 +89,7 @@ func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values
// HTTPStatusCode asserts that a specified handler returns a specified status code.
//
// assert.HTTPStatusCode(t, myHandler, "GET", "/notImplemented", nil, 501)
// assert.HTTPStatusCode(t, myHandler, "GET", "/notImplemented", nil, 501)
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, statuscode int, msgAndArgs ...interface{}) bool {
@@ -124,7 +124,7 @@ func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) s
// HTTPBodyContains asserts that a specified handler returns a
// body that contains a string.
//
// assert.HTTPBodyContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
// assert.HTTPBodyContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool {
@@ -144,7 +144,7 @@ func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string,
// HTTPBodyNotContains asserts that a specified handler returns a
// body that does not contain a string.
//
// assert.HTTPBodyNotContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
// assert.HTTPBodyNotContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
//
// Returns whether the assertion was successful (true) or not (false).
func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool {

View File

@@ -1,24 +1,25 @@
// Package require implements the same assertions as the `assert` package but
// stops test execution when a test fails.
//
// Example Usage
// # Example Usage
//
// The following is a complete example using require in a standard test function:
// import (
// "testing"
// "github.com/stretchr/testify/require"
// )
//
// func TestSomething(t *testing.T) {
// import (
// "testing"
// "github.com/stretchr/testify/require"
// )
//
// var a string = "Hello"
// var b string = "Hello"
// func TestSomething(t *testing.T) {
//
// require.Equal(t, a, b, "The two words should be the same.")
// var a string = "Hello"
// var b string = "Hello"
//
// }
// require.Equal(t, a, b, "The two words should be the same.")
//
// Assertions
// }
//
// # Assertions
//
// The `require` package have same global functions as in the `assert` package,
// but instead of returning a boolean result they call `t.FailNow()`.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff