Skip to content

Commit 5dcaf5e

Browse files
committed
Use WNetGetUniversalNameW instead of net use
1 parent e932c38 commit 5dcaf5e

File tree

4 files changed

+286
-321
lines changed

4 files changed

+286
-321
lines changed

pkg/windows/windows.go

Lines changed: 22 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ import (
44
"errors"
55
"fmt"
66
"os"
7-
"os/exec"
8-
"path/filepath"
97
"regexp"
108
"strings"
119
"unicode"
@@ -62,29 +60,14 @@ func IsWindowsNetworkMount(fp string) bool {
6260
return windowsNetworkMountRegex.MatchString(fp)
6361
}
6462

65-
// commander is an interface for exec.Command function.
66-
type commander interface {
67-
Command(name string, args ...string) *exec.Cmd
68-
}
69-
70-
// realCommander implements commander interface and is used by default.
71-
type realCommander struct{}
72-
73-
// Command calls exec.Command function.
74-
func (realCommander) Command(name string, args ...string) *exec.Cmd {
75-
//nolint:gosec // Wrapper for dependency injection in tests; callers provide validated command values.
76-
return exec.Command(name, args...)
77-
}
63+
const driveRemote = 4
7864

79-
// nolint:gochecknoglobals
80-
// commander replaces exec.Command function. It is initialized in init()
81-
// and can be overwritten in tests.
82-
var cmd commander
83-
84-
// nolint:gochecknoinits
85-
func init() {
86-
cmd = realCommander{}
87-
}
65+
// nolint:gochecknoglobals // Package-level seams for stubbing Win32 API calls in tests.
66+
var (
67+
getDriveType = systemGetDriveType
68+
getUniversalName = systemGetUniversalName
69+
getConnectionName = systemGetConnectionName
70+
)
8871

8972
// FormatLocalFilePath maps entity filepath to unc path, if neither
9073
// localFile, nor entity file are existing.
@@ -108,183 +91,39 @@ func FormatLocalFilePath(localFile, entity string) (string, error) {
10891
}
10992

11093
// toUncPath converts a filepath to a Universal Naming Convention path
111-
// by querying remote drive information via `net use` cmd.
94+
// by querying Windows networking APIs for mapped drive information.
11295
func toUncPath(fp string) (string, error) {
11396
letter, rest := splitDrive(fp)
11497
if letter == "" {
11598
return fp, nil
11699
}
117100

118-
out, err := netUseOutput()
101+
drive := letter + `:\`
102+
driveType, err := getDriveType(drive)
119103
if err != nil {
120-
return "", fmt.Errorf("failed to execute net use command: %s", err)
121-
}
122-
123-
drives, err := parseNetUseOutput(string(out))
124-
if err != nil {
125-
return "", fmt.Errorf("failed to parse net use command output: %s", err)
126-
}
127-
128-
if drive, ok := drives[driveLetter(letter)]; ok {
129-
return string(drive) + rest, nil
130-
}
131-
132-
return fp, nil
133-
}
134-
135-
func netUseOutput() ([]byte, error) {
136-
var (
137-
out []byte
138-
err error
139-
cmdErrs []string
140-
)
141-
142-
cmds := [][]string{
143-
{"net", "use"},
144-
{"net.exe", "use"},
104+
return "", fmt.Errorf("failed to get drive type for %q: %w", drive, err)
145105
}
146106

147-
if winDir := os.Getenv("WINDIR"); winDir != "" {
148-
cmds = append(cmds, []string{filepath.Join(winDir, "System32", "net.exe"), "use"})
107+
if driveType != driveRemote {
108+
return fp, nil
149109
}
150110

151-
for _, args := range cmds {
152-
out, err = cmd.Command(args[0], args[1:]...).Output()
111+
if rest != "" {
112+
uncPath, err := getUniversalName(fp)
153113
if err == nil {
154-
return out, nil
114+
return uncPath, nil
155115
}
156-
157-
cmdErrs = append(cmdErrs, fmt.Sprintf("%q: %s", strings.Join(args, " "), err))
158-
}
159-
160-
return nil, errors.New(strings.Join(cmdErrs, "; "))
161-
}
162-
163-
// driveLetter represents the letter of a drive.
164-
type driveLetter string
165-
166-
// remoteDrive represents the path to a remote drive.
167-
type remoteDrive string
168-
169-
// remoteDrives maps drive letters to remote drives.
170-
type remoteDrives map[driveLetter]remoteDrive
171-
172-
// parseNetUseOutput parses the drives from net use output.
173-
func parseNetUseOutput(text string) (remoteDrives, error) {
174-
var (
175-
cols netUseColumns
176-
err error
177-
)
178-
179-
lines := strings.Split(strings.ReplaceAll(text, "\r\n", "\n"), "\n")
180-
181-
drives := make(remoteDrives)
182-
183-
for _, line := range lines[1 : len(lines)-1] {
184-
if len(strings.TrimSpace(line)) == 0 || strings.ContainsAny(line, "---") {
185-
continue
116+
if !errors.Is(err, errNotSupported) {
117+
return "", fmt.Errorf("failed to get universal path for %q: %w", fp, err)
186118
}
187-
188-
if cols.Empty() {
189-
cols, err = parseNetUseColumns(line)
190-
if err != nil {
191-
return nil, fmt.Errorf("%s from 'net use' output: %s", err, strings.Join(lines, "\n"))
192-
}
193-
194-
continue
195-
}
196-
197-
local := line[cols.Local.Start : cols.Local.Start+cols.Remote.Width]
198-
local = strings.ToUpper(strings.TrimSpace(local))
199-
200-
if len(strings.Split(local, ":")) == 0 || strings.Split(local, ":")[0] == "" {
201-
continue
202-
}
203-
204-
letter := strings.Split(local, ":")[0][0]
205-
if !unicode.IsLetter(rune(letter)) {
206-
continue
207-
}
208-
209-
remote := strings.TrimSpace(line[cols.Remote.Start : cols.Remote.Start+cols.Remote.Width])
210-
if remote == "" {
211-
continue
212-
}
213-
214-
drives[driveLetter(letter)] = remoteDrive(strings.TrimSpace(remote))
215-
}
216-
217-
return drives, nil
218-
}
219-
220-
// netUseColumn represents a column of the 'net use' windows command output.
221-
// It has a start and end position and is used to parse the listed mapped
222-
// network drives.
223-
type netUseColumn struct {
224-
Start int
225-
Width int
226-
}
227-
228-
// Empty returns true, if netUseColumn is unset.
229-
func (c netUseColumn) Empty() bool {
230-
if c.Start == 0 && c.Width == 0 {
231-
return true
232119
}
233120

234-
return false
235-
}
236-
237-
// netUseColumn represents the column of the 'net use' windows command output.
238-
// Only the local and remote column are of importance here.
239-
type netUseColumns struct {
240-
Local netUseColumn
241-
Remote netUseColumn
242-
}
243-
244-
// Empty returns true, if all netUseColumns are unset.
245-
func (c netUseColumns) Empty() bool {
246-
return c.Local.Empty() && c.Remote.Empty()
247-
}
248-
249-
// parseNetUseColumns parses the column line of the 'net use' windows command
250-
// to determine their start position and width.
251-
func parseNetUseColumns(line string) (netUseColumns, error) {
252-
re := regexp.MustCompile(`[a-zA-Z]+[^a-zA-Z]*`)
253-
matches := re.FindAllString(line, -1)
254-
255-
var (
256-
cols netUseColumns
257-
start int
258-
)
259-
260-
for _, match := range matches {
261-
key := strings.ToLower(strings.TrimSpace(match))
262-
263-
switch key {
264-
case "local":
265-
cols.Local = netUseColumn{
266-
Start: start,
267-
Width: len(match),
268-
}
269-
case "remote":
270-
cols.Remote = netUseColumn{
271-
Start: start,
272-
Width: len(match),
273-
}
274-
}
275-
276-
start += len(match)
277-
}
278-
279-
if cols.Local.Empty() {
280-
return netUseColumns{}, errors.New("failed to parse local column")
281-
}
282-
283-
if cols.Remote.Empty() {
284-
return netUseColumns{}, errors.New("failed to parse remote column")
121+
connection, err := getConnectionName(letter + ":")
122+
if err != nil {
123+
return "", fmt.Errorf("failed to get connection name for %q: %w", letter+":", err)
285124
}
286125

287-
return cols, nil
126+
return connection + rest, nil
288127
}
289128

290129
// splitDrive splits a filepath into the drive letter and the path.

pkg/windows/windows_api_stub.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//go:build !windows
2+
3+
package windows
4+
5+
import "errors"
6+
7+
var errNotSupported = errors.New("operation not supported")
8+
9+
func systemGetDriveType(string) (uint32, error) {
10+
return 0, errNotSupported
11+
}
12+
13+
func systemGetConnectionName(string) (string, error) {
14+
return "", errNotSupported
15+
}
16+
17+
func systemGetUniversalName(string) (string, error) {
18+
return "", errNotSupported
19+
}

pkg/windows/windows_api_windows.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
//go:build windows
2+
3+
package windows
4+
5+
import (
6+
"errors"
7+
"fmt"
8+
"unsafe"
9+
10+
xsyswindows "golang.org/x/sys/windows"
11+
)
12+
13+
var errNotSupported = errors.New("operation not supported")
14+
15+
const (
16+
universalNameInfoLevel = 0x00000001
17+
errorMoreData = 234
18+
)
19+
20+
type universalNameInfo struct {
21+
UniversalName *uint16
22+
}
23+
24+
var (
25+
modKernel32 = xsyswindows.NewLazySystemDLL("kernel32.dll")
26+
procGetDriveTypeW = modKernel32.NewProc("GetDriveTypeW")
27+
modMpr = xsyswindows.NewLazySystemDLL("mpr.dll")
28+
procWNetGetConnectionW = modMpr.NewProc("WNetGetConnectionW")
29+
procWNetGetUniversalW = modMpr.NewProc("WNetGetUniversalNameW")
30+
)
31+
32+
func systemGetDriveType(rootPath string) (uint32, error) {
33+
path, err := xsyswindows.UTF16PtrFromString(rootPath)
34+
if err != nil {
35+
return 0, err
36+
}
37+
38+
driveType, _, callErr := procGetDriveTypeW.Call(uintptr(unsafe.Pointer(path)))
39+
if driveType == 0 {
40+
return 0, fmt.Errorf("GetDriveTypeW: %w", callErr)
41+
}
42+
43+
return uint32(driveType), nil
44+
}
45+
46+
func systemGetConnectionName(localName string) (string, error) {
47+
name, err := xsyswindows.UTF16PtrFromString(localName)
48+
if err != nil {
49+
return "", err
50+
}
51+
52+
size := uint32(260)
53+
for {
54+
buffer := make([]uint16, size)
55+
ret, _, _ := procWNetGetConnectionW.Call(
56+
uintptr(unsafe.Pointer(name)),
57+
uintptr(unsafe.Pointer(&buffer[0])),
58+
uintptr(unsafe.Pointer(&size)),
59+
)
60+
61+
switch ret {
62+
case 0:
63+
return xsyswindows.UTF16ToString(buffer), nil
64+
case errorMoreData:
65+
continue
66+
case 50:
67+
return "", errNotSupported
68+
default:
69+
return "", fmt.Errorf("WNetGetConnectionW returned %d", ret)
70+
}
71+
}
72+
}
73+
74+
func systemGetUniversalName(path string) (string, error) {
75+
name, err := xsyswindows.UTF16PtrFromString(path)
76+
if err != nil {
77+
return "", err
78+
}
79+
80+
size := uint32(1024)
81+
for {
82+
buffer := make([]byte, size)
83+
ret, _, _ := procWNetGetUniversalW.Call(
84+
uintptr(unsafe.Pointer(name)),
85+
uintptr(universalNameInfoLevel),
86+
uintptr(unsafe.Pointer(&buffer[0])),
87+
uintptr(unsafe.Pointer(&size)),
88+
)
89+
90+
switch ret {
91+
case 0:
92+
info := (*universalNameInfo)(unsafe.Pointer(&buffer[0]))
93+
return xsyswindows.UTF16PtrToString(info.UniversalName), nil
94+
case errorMoreData:
95+
continue
96+
case 50, 1200:
97+
return "", errNotSupported
98+
default:
99+
return "", fmt.Errorf("WNetGetUniversalNameW returned %d", ret)
100+
}
101+
}
102+
}

0 commit comments

Comments
 (0)