"""Convenience wrappers for running PyRAT analyses from experiment scripts."""

import csv
import sys
import warnings
from itertools import islice

from loguru import logger

# The PyRAT sources are not an installed package, so their directory must be
# on sys.path before the project imports below. NOTE(review): this is a
# relative path, so it is resolved against the current working directory.
sys.path.append("pyrat39/")

from analysis_param import AnalysisParam  # noqa: E402
from sound_pyrat.analyzer_split import analyze_splits  # noqa: E402
from pyrat import load_img  # noqa: E402
from sound_pyrat.domains.domain_factory import create_box_maker  # noqa: E402
import torch  # noqa: E402
from sound_pyrat.utils_analyzer import (  # noqa: E402
    analyze,
    read_inputs,
    read_model,
    analyze_adv,
)

# Silence noisy floating-point warnings (NaN / division by zero). The
# ``message`` argument is a regex matched against the start of the warning
# text, so these act as prefix filters.
warnings.filterwarnings("ignore", message="invalid value encountered in *")
warnings.filterwarnings("ignore", message="divide by zero encountered in reciprocal")


def launch_pyrat(model_path, property_path, domains=None, split_timeout=0):
    """Run a PyRAT analysis of a model against a property file.

    Args:
        model_path: path to the model.
        property_path: path to the property file.
        domains: list of abstract domains to propagate through the network,
            e.g. "zonotopes" or "poly". Defaults to an empty list.
        split_timeout: 0 runs a single analysis without input splitting;
            any other value runs the split-based analysis and is used as its
            timeout (``params.timeout``).

    Returns:
        A ``(result, time)`` tuple taken from the analysis result object.
    """
    # Fresh list per call: a shared mutable default could otherwise leak
    # state between calls if anything downstream mutates it.
    domains = [] if domains is None else domains

    # Remove every loguru handler so PyRAT's internal logging stays silent.
    logger.remove()
    params = AnalysisParam(
        domains=domains,
        squeeze=True,
        verbose=False,
        check="skip",
        by_layer=True,
    )
    pyrat_model, bounds, to_verify, to_counter = read_inputs(
        model_path=model_path, prop_path=property_path, params=params
    )

    model_input = pyrat_model.root[0].input[0]
    # Record whether the input carries a leading (1, ...) dimension BEFORE
    # rewriting the shape. Re-testing the shape after the rewrite is always
    # false, which previously meant the bounds were never reshaped to match.
    squeeze_input = model_input.input_shape == (1, 3)
    if squeeze_input:
        model_input.input_shape = (3,)

    bounds = params.box_maker(bounds[0], bounds[1]).to_type(pyrat_model.dtype)
    if squeeze_input:
        bounds = bounds.reshape((3,))

    if split_timeout == 0:
        result = analyze(bounds, pyrat_model, to_verify, to_counter, params)
        print("Output bounds:\n", result.single_res.output_bounds)
        print("Result = {}, Time = {:.2f} s".format(result.result, result.time))
    else:
        params.timeout = split_timeout
        result = analyze_splits(bounds, pyrat_model, to_verify, to_counter, params)
        result.print_res()
    return result.result, result.time


def local_robustness(model_path, image, label, pert, domains=None, total_labels=10):
    """Check local robustness of a classifier around a single image.

    Args:
        model_path: path to the model.
        image: input image to analyse.
        label: true class label of ``image``.
        pert: perturbation radius, passed to PyRAT as ``epsilon``.
        domains: list of abstract domains to use. Defaults to an empty list.
        total_labels: number of output classes of the model (default 10).

    Returns:
        A ``(result, time)`` tuple taken from ``analyze_adv``'s result.
    """
    domains = [] if domains is None else domains
    # Torch-backed boxes in float32 with sound=False.
    box_maker = create_box_maker("torch", sound=False, dtype=torch.float32)
    params = AnalysisParam(
        total_labels=total_labels,
        epsilon=pert,
        box_maker=box_maker,
        domains=domains,
        by_layer=True,
        force_analysis=True,
        true_label=label,
    )
    pyrat_model = read_model(model_path, params=params)
    res = analyze_adv(image, model_path, pyrat_model, params)
    return res.result, res.time


def read_images(max_images=1, data_dir="fmnist"):
    """Load labelled images listed in ``<data_dir>/img_labels.csv``.

    Each non-empty CSV row has the form ``<file name>,<integer label>``. The
    image is loaded from ``<data_dir>/images/<file name>``.

    Args:
        max_images: maximum number of rows to load. The default of 1 keeps
            the previous hard-coded behaviour; ``None`` loads every row.
        data_dir: dataset directory containing ``img_labels.csv`` and an
            ``images/`` sub-directory.

    Returns:
        A list of ``(image, label)`` tuples, with ``label`` as an ``int``.
    """
    images = []
    # Remove every loguru handler so loading stays silent.
    logger.remove()
    with open(f"{data_dir}/img_labels.csv", newline="", encoding="utf-8") as f:
        # Skip blank rows, which would otherwise fail to unpack.
        rows = (row for row in csv.reader(f) if row)
        for name, label in islice(rows, max_images):
            image = load_img(f"{data_dir}/images/{name}", grayscale=False)
            images.append((image, int(label.strip())))
    return images