112 lines
2.2 KiB
Go
112 lines
2.2 KiB
Go
|
// Copyright 2017 The Go Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package poll_test
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
"internal/poll"
|
||
|
"os"
|
||
|
"sync"
|
||
|
"syscall"
|
||
|
"testing"
|
||
|
)
|
||
|
|
||
|
type loggedFD struct {
|
||
|
Net string
|
||
|
FD *poll.FD
|
||
|
Err error
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
logMu sync.Mutex
|
||
|
loggedFDs map[syscall.Handle]*loggedFD
|
||
|
)
|
||
|
|
||
|
func logFD(net string, fd *poll.FD, err error) {
|
||
|
logMu.Lock()
|
||
|
defer logMu.Unlock()
|
||
|
|
||
|
loggedFDs[fd.Sysfd] = &loggedFD{
|
||
|
Net: net,
|
||
|
FD: fd,
|
||
|
Err: err,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func init() {
|
||
|
loggedFDs = make(map[syscall.Handle]*loggedFD)
|
||
|
*poll.LogInitFD = logFD
|
||
|
}
|
||
|
|
||
|
func findLoggedFD(h syscall.Handle) (lfd *loggedFD, found bool) {
|
||
|
logMu.Lock()
|
||
|
defer logMu.Unlock()
|
||
|
|
||
|
lfd, found = loggedFDs[h]
|
||
|
return lfd, found
|
||
|
}
|
||
|
|
||
|
// checkFileIsNotPartOfNetpoll verifies that f is not managed by netpoll.
|
||
|
// It returns error, if check fails.
|
||
|
func checkFileIsNotPartOfNetpoll(f *os.File) error {
|
||
|
lfd, found := findLoggedFD(syscall.Handle(f.Fd()))
|
||
|
if !found {
|
||
|
return fmt.Errorf("%v fd=%v: is not found in the log", f.Name(), f.Fd())
|
||
|
}
|
||
|
if lfd.FD.IsPartOfNetpoll() {
|
||
|
return fmt.Errorf("%v fd=%v: is part of netpoll, but should not be (logged: net=%v err=%v)", f.Name(), f.Fd(), lfd.Net, lfd.Err)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func TestFileFdsAreInitialised(t *testing.T) {
|
||
|
exe, err := os.Executable()
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
f, err := os.Open(exe)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
defer f.Close()
|
||
|
|
||
|
err = checkFileIsNotPartOfNetpoll(f)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestSerialFdsAreInitialised(t *testing.T) {
|
||
|
for _, name := range []string{"COM1", "COM2", "COM3", "COM4"} {
|
||
|
t.Run(name, func(t *testing.T) {
|
||
|
h, err := syscall.CreateFile(syscall.StringToUTF16Ptr(name),
|
||
|
syscall.GENERIC_READ|syscall.GENERIC_WRITE,
|
||
|
0,
|
||
|
nil,
|
||
|
syscall.OPEN_EXISTING,
|
||
|
syscall.FILE_ATTRIBUTE_NORMAL|syscall.FILE_FLAG_OVERLAPPED,
|
||
|
0)
|
||
|
if err != nil {
|
||
|
if errno, ok := err.(syscall.Errno); ok {
|
||
|
switch errno {
|
||
|
case syscall.ERROR_FILE_NOT_FOUND,
|
||
|
syscall.ERROR_ACCESS_DENIED:
|
||
|
t.Log("Skipping: ", err)
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
f := os.NewFile(uintptr(h), name)
|
||
|
defer f.Close()
|
||
|
|
||
|
err = checkFileIsNotPartOfNetpoll(f)
|
||
|
if err != nil {
|
||
|
t.Fatal(err)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|